/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.ml.inference.loadingservice;

import java.util.ArrayDeque;
import java.util.Arrays;
import java.util.Collections;
import java.util.EnumSet;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Queue;
import java.util.Set;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.function.Supplier;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.lucene.util.SetOnce;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.ResourceNotFoundException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.cluster.ClusterChangedEvent;
import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.cluster.ClusterStateListener;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.breaker.CircuitBreaker;
import org.elasticsearch.common.breaker.CircuitBreakingException;
import org.elasticsearch.common.cache.Cache;
import org.elasticsearch.common.cache.CacheBuilder;
import org.elasticsearch.common.cache.CacheLoader;
import org.elasticsearch.common.cache.RemovalNotification;
import org.elasticsearch.common.settings.Setting;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.common.util.set.Sets;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.core.Strings;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.ingest.IngestMetadata;
import org.elasticsearch.license.License;
import org.elasticsearch.license.XPackLicenseState;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.tasks.TaskId;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction;
import org.elasticsearch.xpack.core.ml.inference.ModelAliasMetadata;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelType;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference.InferenceDefinition;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.core.ml.utils.NamedXContentObject;
import org.elasticsearch.xpack.ml.MachineLearning;
import org.elasticsearch.xpack.ml.inference.TrainedModelStatsService;
import org.elasticsearch.xpack.ml.inference.loadingservice.LocalModel;
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
import org.elasticsearch.xpack.ml.notifications.InferenceAuditor;

public class ModelLoadingService
implements ClusterStateListener {
    // Node-scoped memory-size setting "xpack.ml.inference_model.cache_size" (default "40%").
    // Read once in the constructor and used as the maximum weight of the local model cache.
    public static final Setting<ByteSizeValue> INFERENCE_MODEL_CACHE_SIZE = Setting.memorySizeSetting((String)"xpack.ml.inference_model.cache_size", (String)"40%", (Setting.Property[])new Setting.Property[]{Setting.Property.NodeScope});
    // Node-scoped time setting "xpack.ml.inference_model.time_to_live" (default 5 minutes), applied as the
    // cache's expire-after-access duration. NOTE(review): the second TimeValue (1ms) is presumably the
    // minimum allowed value - confirm against Setting.timeSetting.
    public static final Setting<TimeValue> INFERENCE_MODEL_CACHE_TTL = Setting.timeSetting((String)"xpack.ml.inference_model.time_to_live", (TimeValue)new TimeValue(5L, TimeUnit.MINUTES), (TimeValue)new TimeValue(1L, TimeUnit.MILLISECONDS), (Setting.Property[])new Setting.Property[]{Setting.Property.NodeScope});
    private static final Logger logger = LogManager.getLogger(ModelLoadingService.class);
    private final TrainedModelStatsService modelStatsService;
    // Cache of loaded models keyed by model id; entries are weighed by the model's ramBytesUsed().
    private final Cache<String, ModelAndConsumer> localModelCache;
    // Model keys (ids or aliases) currently considered referenced; maintained by clusterChanged.
    private final Set<String> referencedModels = new HashSet<String>();
    // Maps a model alias to the model id it currently resolves to.
    private final Map<String, String> modelAliasToId = new HashMap<String, String>();
    // Maps a model id to the aliases that point at it.
    private final Map<String, Set<String>> modelIdToModelAliases = new HashMap<String, Set<String>>();
    // Aliases whose switch to a new model id is deferred until that model's load completes
    // (applied by populateNewModelAlias).
    private final Map<String, Set<String>> modelIdToUpdatedModelAliases = new HashMap<String, Set<String>>();
    // Listeners waiting for an in-flight load, keyed by model id. This map is also the monitor that
    // handleLoadSuccess, handleLoadFailure and clusterChanged synchronize on.
    private final Map<String, Queue<ActionListener<LocalModel>>> loadingListeners = new HashMap<String, Queue<ActionListener<LocalModel>>>();
    private final TrainedModelProvider provider;
    // NOTE(review): presumably the model ids for which audit messages are suppressed; only removals
    // from this set are visible in this part of the file - confirm against auditIfNecessary.
    private final Set<String> shouldNotAudit;
    private final ThreadPool threadPool;
    private final InferenceAuditor auditor;
    // Value of INFERENCE_MODEL_CACHE_SIZE resolved at construction time.
    private final ByteSizeValue maxCacheSize;
    // Passed to every LocalModel created by this service.
    private final String localNode;
    // Breaker used to account for the memory of loaded model definitions; never null.
    private final CircuitBreaker trainedModelCircuitBreaker;
    private final XPackLicenseState licenseState;

    /**
     * Creates the service, builds the local model cache from the node settings and registers the
     * new instance as a cluster state listener.
     *
     * <p>The circuit breaker is validated first and the listener is registered last, so that a
     * failed construction never leaves a listener behind and a cluster state event can never be
     * delivered to an instance whose fields are still unassigned.
     *
     * @param trainedModelProvider       source of trained model configurations and definitions
     * @param auditor                    auditor used for model related notifications
     * @param threadPool                 thread pool kept for use by this service
     * @param clusterService             cluster service this instance listens to
     * @param modelStatsService          statistics service handed to every loaded model
     * @param settings                   node settings supplying the cache size and time-to-live
     * @param localNode                  value passed to every {@link LocalModel} created here
     * @param trainedModelCircuitBreaker breaker accounting for loaded model memory; must not be null
     * @param licenseState               license state used for feature usage tracking
     */
    public ModelLoadingService(TrainedModelProvider trainedModelProvider, InferenceAuditor auditor, ThreadPool threadPool, ClusterService clusterService, TrainedModelStatsService modelStatsService, Settings settings, String localNode, CircuitBreaker trainedModelCircuitBreaker, XPackLicenseState licenseState) {
        // Fail fast, before any side effect, when the mandatory breaker is missing.
        this.trainedModelCircuitBreaker = ExceptionsHelper.requireNonNull(trainedModelCircuitBreaker, "trainedModelCircuitBreaker");
        this.provider = trainedModelProvider;
        this.threadPool = threadPool;
        this.maxCacheSize = INFERENCE_MODEL_CACHE_SIZE.get(settings);
        this.auditor = auditor;
        this.modelStatsService = modelStatsService;
        this.shouldNotAudit = new HashSet<>();
        this.localNode = localNode;
        this.licenseState = licenseState;
        this.localModelCache = CacheBuilder.<String, ModelAndConsumer>builder()
            .setMaximumWeight(this.maxCacheSize.getBytes())
            .weigher((id, modelAndConsumer) -> modelAndConsumer.model.ramBytesUsed())
            .removalListener(this::cacheEvictionListener)
            .setExpireAfterAccess(INFERENCE_MODEL_CACHE_TTL.get(settings))
            .build();
        // Publish "this" only once every field has been assigned.
        clusterService.addListener(this);
    }

    /**
     * Resolves a model alias to the model id it currently points at.
     *
     * @param modelIdOrAlias a model id or a model alias
     * @return the mapped model id, or the argument itself when no alias mapping exists
     */
    public String getModelId(String modelIdOrAlias) {
        final String resolvedModelId = modelAliasToId.getOrDefault(modelIdOrAlias, modelIdOrAlias);
        return resolvedModelId;
    }

    /**
     * Tells whether the local cache currently holds the given model.
     *
     * @param modelId a model id or a model alias; aliases are resolved before the lookup
     * @return {@code true} when a cache entry exists for the resolved model id
     */
    boolean isModelCached(String modelId) {
        final String resolvedModelId = modelAliasToId.getOrDefault(modelId, modelId);
        final ModelAndConsumer cacheEntry = localModelCache.get(resolvedModelId);
        return cacheEntry != null;
    }

    /**
     * @return the maximum cache size resolved from {@link #INFERENCE_MODEL_CACHE_SIZE} at construction time
     */
    public ByteSizeValue getMaxCacheSize() {
        return maxCacheSize;
    }

    /**
     * @return the current total weight of the local model cache, expressed as a byte size
     */
    public ByteSizeValue getCurrentCacheSize() {
        final long currentWeightInBytes = localModelCache.weight();
        return ByteSizeValue.ofBytes(currentWeightInBytes);
    }

    /**
     * Requests a model on behalf of an ingest pipeline.
     *
     * @param modelId             the model id or alias to load
     * @param parentTaskId        task id forwarded to the load request
     * @param modelActionListener notified with the model or with the failure
     */
    public void getModelForPipeline(String modelId, TaskId parentTaskId, ActionListener<LocalModel> modelActionListener) {
        final Consumer requester = Consumer.PIPELINE;
        getModel(modelId, requester, parentTaskId, modelActionListener);
    }

    /**
     * Requests a model for internal inference; no parent task is attached to the request.
     *
     * @param modelId             the model id or alias to load
     * @param modelActionListener notified with the model or with the failure
     */
    public void getModelForInternalInference(String modelId, ActionListener<LocalModel> modelActionListener) {
        final Consumer requester = Consumer.INTERNAL;
        getModel(modelId, requester, null, modelActionListener);
    }

    /**
     * Requests a model on behalf of a search aggregation; no parent task is attached to the request.
     *
     * @param modelId             the model id or alias to load
     * @param modelActionListener notified with the model or with the failure
     */
    public void getModelForAggregation(String modelId, ActionListener<LocalModel> modelActionListener) {
        final Consumer requester = Consumer.SEARCH_AGGS;
        getModel(modelId, requester, null, modelActionListener);
    }

    /**
     * Requests a model on behalf of a search rescorer; no parent task is attached to the request.
     *
     * @param modelId             the model id or alias to load
     * @param modelActionListener notified with the model or with the failure
     */
    public void getModelForLearningToRank(String modelId, ActionListener<LocalModel> modelActionListener) {
        final Consumer requester = Consumer.SEARCH_RESCORER;
        getModel(modelId, requester, null, modelActionListener);
    }

    /**
     * Common entry point of the {@code getModelFor*} methods.
     *
     * <p>Serves the request from the cache when an entry exists for the resolved model id, after
     * checking that the consumer supports the cached model's inference config and acquiring a
     * reference on the model. Otherwise the request is handed to {@link #loadModelIfNecessary}.
     *
     * @param modelIdOrAlias      the model id or alias requested
     * @param consumer            the usage context requesting the model
     * @param parentTaskId        task id forwarded to an uncached load, may be {@code null}
     * @param modelActionListener notified with the model or with the failure
     */
    private void getModel(String modelIdOrAlias, Consumer consumer, TaskId parentTaskId, ActionListener<LocalModel> modelActionListener) {
        final String modelId = modelAliasToId.getOrDefault(modelIdOrAlias, modelIdOrAlias);
        final ModelAndConsumer cacheEntry = localModelCache.get(modelId);

        if (cacheEntry == null) {
            // Cache miss: either join an in-flight load or start a new one.
            final boolean cachedOrAlreadyLoading = loadModelIfNecessary(modelIdOrAlias, consumer, parentTaskId, modelActionListener);
            if (cachedOrAlreadyLoading) {
                logger.trace(() -> Strings.format("[%s] (model_alias [%s]) is loading or loaded, added new listener to queue", modelId, modelIdOrAlias));
            }
            return;
        }

        final LocalModel model = cacheEntry.model;
        if (consumer.inferenceConfigSupported(model.getInferenceConfig()) == false) {
            modelActionListener.onFailure(modelUnsupportedInUsageContext(modelId, model.getTrainedModelType(), model.getInferenceConfig(), consumer));
            return;
        }

        cacheEntry.consumers.add(consumer);
        try {
            // Acquiring may trip the circuit breaker; report that to the caller instead of throwing.
            model.acquire();
        } catch (CircuitBreakingException breakerFailure) {
            modelActionListener.onFailure(breakerFailure);
            return;
        }
        modelActionListener.onResponse(model);
        logger.trace(() -> Strings.format("[%s] (model_alias [%s]) loaded from cache", modelId, modelIdOrAlias));
    }

    /**
     * Serves a request that missed the cache in {@link #getModel}: re-checks the cache, joins an
     * in-flight load, or starts a new load.
     *
     * <p>All bookkeeping on {@code loadingListeners} and the cache lookup happen while holding the
     * {@code loadingListeners} monitor, the same monitor used by {@code handleLoadSuccess},
     * {@code handleLoadFailure} and {@code clusterChanged}. The caller's listener is notified, or
     * the load is started, in the {@code finally} clause, i.e. only after the monitor has been
     * released, so no foreign code runs under the lock.
     *
     * @param modelIdOrAlias      the model id or alias requested
     * @param consumer            the usage context requesting the model
     * @param parentTaskId        task id forwarded to a non-caching load, may be {@code null}
     * @param modelActionListener notified with the model or with the failure
     * @return {@code true} when the model was found in the cache (including the case where acquiring
     *         it failed) or a load for it is already in flight; {@code false} when this call started
     *         a new load
     */
    private boolean loadModelIfNecessary(String modelIdOrAlias, Consumer consumer, TaskId parentTaskId, ActionListener<LocalModel> modelActionListener) {
        // At most one of these is set inside the synchronized block; they are consumed in "finally".
        final SetOnce<Exception> exceptionToNotifyListener = new SetOnce<>();
        final SetOnce<LocalModel> localModelToNotifyListener = new SetOnce<>();
        final SetOnce<Runnable> modelLoadingRunnable = new SetOnce<>();
        try {
            synchronized (this.loadingListeners) {
                final String modelId = this.modelAliasToId.getOrDefault(modelIdOrAlias, modelIdOrAlias);

                // The model may have been cached between the caller's lookup and taking the lock.
                final ModelAndConsumer cachedModel = this.localModelCache.get(modelId);
                if (cachedModel != null) {
                    cachedModel.consumers.add(consumer);
                    try {
                        cachedModel.model.acquire();
                    } catch (CircuitBreakingException e) {
                        exceptionToNotifyListener.set(e);
                        return true;
                    }
                    localModelToNotifyListener.set(cachedModel.model);
                    return true;
                }

                // A queue already exists for this id: a load is in flight, so just join it.
                final Queue<ActionListener<LocalModel>> listeners = this.loadingListeners.computeIfPresent(
                    modelId,
                    (storedModelKey, listenerQueue) -> ModelLoadingService.addFluently(listenerQueue, modelActionListener)
                );
                if (listeners != null) {
                    return true;
                }

                if (!consumer.isAnyOf(Consumer.SEARCH_AGGS, Consumer.SEARCH_RESCORER) && !this.referencedModels.contains(modelId)) {
                    // Not a search consumer and not a referenced model: load for this caller only.
                    logger.trace(() -> Strings.format("[%s] (model_alias [%s]) not actively loading, eager loading without cache", modelId, modelIdOrAlias));
                    modelLoadingRunnable.set(() -> this.loadWithoutCaching(modelId, consumer, parentTaskId, modelActionListener));
                } else {
                    logger.trace(() -> Strings.format("[%s] (model_alias [%s]) attempting to load and cache", modelId, modelIdOrAlias));
                    this.loadingListeners.put(modelId, ModelLoadingService.addFluently(new ArrayDeque<ActionListener<LocalModel>>(), modelActionListener));
                    modelLoadingRunnable.set(() -> this.loadModel(modelId, consumer));
                }
                return false;
            }
        } finally {
            // Runs after the monitor has been released.
            assert (exceptionToNotifyListener.get() == null || localModelToNotifyListener.get() == null) : "both exception and local model set";
            if (exceptionToNotifyListener.get() != null) {
                assert (modelLoadingRunnable.get() == null) : "Exception encountered, model loading runnable should be null";
                modelActionListener.onFailure(exceptionToNotifyListener.get());
            } else if (localModelToNotifyListener.get() != null) {
                assert (modelLoadingRunnable.get() == null) : "Model was cached, model loading runnable should be null";
                modelActionListener.onResponse(localModelToNotifyListener.get());
            } else if (modelLoadingRunnable.get() != null) {
                modelLoadingRunnable.get().run();
            }
        }
    }

    /**
     * Loads a model configuration and then its inference definition, reporting the outcome to the
     * listeners queued for {@code modelId} through {@link #handleLoadSuccess} or
     * {@link #handleLoadFailure}.
     *
     * <p>The configured model size is reserved on the circuit breaker before the definition is
     * fetched; it is given back when fetching the definition fails.
     *
     * @param modelId  the id of the model to load
     * @param consumer the usage context that triggered the load
     */
    private void loadModel(String modelId, Consumer consumer) {
        provider.getTrainedModel(modelId, GetTrainedModelsAction.Includes.empty(), null, ActionListener.wrap(config -> {
            if (consumer.inferenceConfigSupported(config.getInferenceConfig()) == false) {
                handleLoadFailure(modelId, modelUnsupportedInUsageContext(modelId, config.getModelType(), config.getInferenceConfig(), consumer));
                return;
            }
            if (config.isAllocateOnly()) {
                handleLoadFailure(modelId, modelMustBeDeployedError(modelId));
                return;
            }
            auditNewReferencedModel(modelId);
            // Reserve the configured size up front; corrected once the real definition is known.
            trainedModelCircuitBreaker.addEstimateBytesAndMaybeBreak(config.getModelSize(), modelId);
            provider.getTrainedModelForInference(modelId, consumer == Consumer.INTERNAL, ActionListener.wrap(definition -> {
                try {
                    updateCircuitBreakerEstimate(modelId, definition, config);
                } catch (CircuitBreakingException breakerFailure) {
                    handleLoadFailure(modelId, breakerFailure);
                    return;
                }
                handleLoadSuccess(modelId, consumer, config, definition);
            }, definitionFailure -> {
                // The definition never materialised, so give the reservation back.
                trainedModelCircuitBreaker.addWithoutBreaking(-config.getModelSize());
                logger.warn(() -> "[" + modelId + "] failed to load model definition", definitionFailure);
                handleLoadFailure(modelId, definitionFailure);
            }));
        }, configFailure -> {
            if (consumer != Consumer.PIPELINE) {
                logger.warn(() -> "[" + modelId + "] failed to load model configuration ", configFailure);
            }
            handleLoadFailure(modelId, configFailure);
        }));
    }

    /**
     * Loads a model for a single caller without putting it in the cache.
     *
     * <p>The caller's listener is never added to {@code loadingListeners} on this path, so every
     * outcome, including an unsupported inference config, is reported directly to
     * {@code modelActionListener}. Routing a failure through {@code handleLoadFailure} here would
     * leave the caller without any response and could fail unrelated listeners queued for the same
     * model id.
     *
     * @param modelId             the id of the model to load
     * @param consumer            the usage context requesting the model
     * @param parentTaskId        task id forwarded to the configuration lookup, may be {@code null}
     * @param modelActionListener notified with the freshly built model or with the failure
     */
    private void loadWithoutCaching(String modelId, Consumer consumer, TaskId parentTaskId, ActionListener<LocalModel> modelActionListener) {
        this.provider.getTrainedModel(modelId, GetTrainedModelsAction.Includes.empty(), parentTaskId, ActionListener.wrap(trainedModelConfig -> {
            if (!consumer.inferenceConfigSupported(trainedModelConfig.getInferenceConfig())) {
                modelActionListener.onFailure(ModelLoadingService.modelUnsupportedInUsageContext(modelId, trainedModelConfig.getModelType(), trainedModelConfig.getInferenceConfig(), consumer));
                return;
            }
            if (trainedModelConfig.isAllocateOnly()) {
                modelActionListener.onFailure(ModelLoadingService.modelMustBeDeployedError(modelId));
                return;
            }
            // Reserve the configured size up front; corrected once the real definition is known.
            this.trainedModelCircuitBreaker.addEstimateBytesAndMaybeBreak(trainedModelConfig.getModelSize(), modelId);
            this.provider.getTrainedModelForInference(modelId, consumer == Consumer.INTERNAL, ActionListener.wrap(inferenceDefinition -> {
                InferenceConfig inferenceConfig = trainedModelConfig.getInferenceConfig() == null
                    ? ModelLoadingService.inferenceConfigFromTargetType(inferenceDefinition.getTargetType())
                    : trainedModelConfig.getInferenceConfig();
                try {
                    this.updateCircuitBreakerEstimate(modelId, inferenceDefinition, trainedModelConfig);
                } catch (CircuitBreakingException ex) {
                    modelActionListener.onFailure(ex);
                    return;
                }
                modelActionListener.onResponse(new LocalModel(trainedModelConfig.getModelId(), this.localNode, inferenceDefinition, trainedModelConfig.getInput(), trainedModelConfig.getDefaultFieldMap(), inferenceConfig, trainedModelConfig.getLicenseLevel(), trainedModelConfig.getModelType(), this.modelStatsService, this.trainedModelCircuitBreaker));
            }, e -> {
                // The definition never materialised, so give the reservation back.
                this.trainedModelCircuitBreaker.addWithoutBreaking(-trainedModelConfig.getModelSize());
                if (ExceptionsHelper.unwrapCause(e) instanceof ResourceNotFoundException) {
                    modelActionListener.onFailure(e);
                } else {
                    // Pass the failure as the cause rather than as a surplus format argument so it is not lost.
                    modelActionListener.onFailure(new ElasticsearchStatusException("failed to load model [{}] definition", RestStatus.INTERNAL_SERVER_ERROR, e, modelId));
                }
            }));
        }, modelActionListener::onFailure));
    }

    /**
     * Corrects the circuit breaker reservation once the real size of the definition is known.
     *
     * <p>The difference between the definition's {@code ramBytesUsed()} and the configured model
     * size is applied to the breaker. When the extra amount trips the breaker, the configured model
     * size is released before the exception is rethrown.
     *
     * @param modelId             label passed to the breaker
     * @param inferenceDefinition the loaded definition
     * @param trainedModelConfig  the configuration whose model size was reserved earlier
     * @throws CircuitBreakingException when reserving the additional bytes trips the breaker
     */
    private void updateCircuitBreakerEstimate(String modelId, InferenceDefinition inferenceDefinition, TrainedModelConfig trainedModelConfig) throws CircuitBreakingException {
        final long delta = inferenceDefinition.ramBytesUsed() - trainedModelConfig.getModelSize();
        if (delta == 0L) {
            return;
        }
        if (delta < 0L) {
            // Over-reserved: hand the surplus back.
            trainedModelCircuitBreaker.addWithoutBreaking(delta);
            return;
        }
        try {
            trainedModelCircuitBreaker.addEstimateBytesAndMaybeBreak(delta, modelId);
        } catch (CircuitBreakingException breakerFailure) {
            trainedModelCircuitBreaker.addWithoutBreaking(-trainedModelConfig.getModelSize());
            throw breakerFailure;
        }
    }

    /**
     * Builds the {@code BAD_REQUEST} error reported for a model that must be deployed before use.
     *
     * @param modelId the id inserted into the message
     * @return the exception to hand to a listener
     */
    private static ElasticsearchStatusException modelMustBeDeployedError(String modelId) {
        final String messageTemplate = "Model [{}] must be deployed to use. Please deploy with the start trained model deployment API.";
        return new ElasticsearchStatusException(messageTemplate, RestStatus.BAD_REQUEST, modelId);
    }

    /**
     * Builds the {@code BAD_REQUEST} error reported when a model cannot be used by a consumer.
     *
     * @param modelId         the id inserted into the message
     * @param modelType       the model type inserted into the message
     * @param inferenceConfig the model's inference config; {@code "_unknown_"} is reported as the
     *                        task when it, or its name, is {@code null}
     * @param consumer        the consumer whose {@code exceptionName()} is inserted into the message
     * @return the exception to hand to a listener
     */
    private static ElasticsearchStatusException modelUnsupportedInUsageContext(String modelId, TrainedModelType modelType, InferenceConfig inferenceConfig, Consumer consumer) {
        String taskName = "_unknown_";
        if (inferenceConfig != null) {
            final String configName = inferenceConfig.getName();
            if (configName != null) {
                taskName = configName;
            }
        }
        final String messageTemplate = "Trained model [{}] with type [{}] and task [{}] is currently not usable in [{}].";
        return new ElasticsearchStatusException(messageTemplate, RestStatus.BAD_REQUEST, modelId, modelType, taskName, consumer.exceptionName());
    }

    /*
     * WARNING - Removed try catching itself - possible behaviour change.
     */
    /**
     * Completes a successful load: builds the {@link LocalModel}, caches it when appropriate and
     * answers every listener queued for {@code modelId}.
     *
     * <p>The model is cached when the id or one of its aliases is referenced, or when the consumer is
     * a search consumer. Caching, alias bookkeeping and removal of the listener queue happen under
     * the {@code loadingListeners} monitor; the listeners are called after it is released, each with
     * its own acquired reference. When the cache loader reports that it loaded the model, one
     * reference is released at the end.
     *
     * @param modelId             the id the load was started for
     * @param consumer            the usage context that triggered the load
     * @param trainedModelConfig  the loaded configuration
     * @param inferenceDefinition the loaded definition
     */
    private void handleLoadSuccess(String modelId, Consumer consumer, TrainedModelConfig trainedModelConfig, InferenceDefinition inferenceDefinition) {
        InferenceConfig inferenceConfig = trainedModelConfig.getInferenceConfig();
        if (inferenceConfig == null) {
            inferenceConfig = inferenceConfigFromTargetType(inferenceDefinition.getTargetType());
        }
        final LocalModel loadedModel = new LocalModel(
            trainedModelConfig.getModelId(),
            localNode,
            inferenceDefinition,
            trainedModelConfig.getInput(),
            trainedModelConfig.getDefaultFieldMap(),
            inferenceConfig,
            trainedModelConfig.getLicenseLevel(),
            Optional.ofNullable(trainedModelConfig.getModelType()).orElse(TrainedModelType.TREE_ENSEMBLE),
            modelStatsService,
            trainedModelCircuitBreaker
        );
        final ModelAndConsumerLoader cacheLoader = new ModelAndConsumerLoader(new ModelAndConsumer(loadedModel, consumer));

        final Queue<ActionListener<LocalModel>> waitingListeners;
        synchronized (loadingListeners) {
            populateNewModelAlias(modelId);
            final boolean shouldCache = referencedModels.contains(modelId)
                || Sets.haveNonEmptyIntersection(modelIdToModelAliases.getOrDefault(modelId, new HashSet<>()), referencedModels)
                || consumer.isAnyOf(Consumer.SEARCH_AGGS, Consumer.SEARCH_RESCORER);
            if (shouldCache) {
                try {
                    localModelCache.computeIfAbsent(modelId, cacheLoader);
                    if (License.OperationMode.BASIC.equals(trainedModelConfig.getLicenseLevel()) == false) {
                        MachineLearning.ML_MODEL_INFERENCE_FEATURE.startTracking(licenseState, modelId);
                    }
                } catch (ExecutionException ee) {
                    logger.warn(() -> "[" + modelId + "] threw when attempting add to cache", ee);
                }
                shouldNotAudit.remove(modelId);
            }
            waitingListeners = loadingListeners.remove(modelId);
            if (waitingListeners == null) {
                // Nobody is waiting; drop the extra reference if the loader ran.
                if (cacheLoader.isLoaded()) {
                    loadedModel.release();
                }
                return;
            }
        }

        // Outside the monitor: hand each waiting listener its own reference to the model.
        for (ActionListener<LocalModel> waiting = waitingListeners.poll(); waiting != null; waiting = waitingListeners.poll()) {
            loadedModel.acquire();
            waiting.onResponse(loadedModel);
        }
        if (cacheLoader.isLoaded()) {
            loadedModel.release();
        }
    }

    /*
     * WARNING - Removed try catching itself - possible behaviour change.
     */
    /**
     * Completes a failed load by failing every listener queued for {@code modelId}.
     *
     * <p>The queue is removed and pending alias updates are applied under the
     * {@code loadingListeners} monitor; the listeners are called after it is released. Nothing is
     * notified when no queue exists for the id.
     *
     * @param modelId the id the load was started for
     * @param failure the failure passed to every queued listener
     */
    private void handleLoadFailure(String modelId, Exception failure) {
        final Queue<ActionListener<LocalModel>> waitingListeners;
        synchronized (loadingListeners) {
            waitingListeners = loadingListeners.remove(modelId);
            populateNewModelAlias(modelId);
            if (waitingListeners == null) {
                return;
            }
        }
        for (ActionListener<LocalModel> waiting = waitingListeners.poll(); waiting != null; waiting = waitingListeners.poll()) {
            waiting.onFailure(failure);
        }
    }

    /**
     * Applies the alias updates that were deferred for {@code modelId}: the pending aliases are
     * removed from {@code modelIdToUpdatedModelAliases} and each is pointed at {@code modelId}.
     *
     * @param modelId the model id whose deferred aliases should be applied
     */
    private void populateNewModelAlias(String modelId) {
        final Set<String> pendingAliases = modelIdToUpdatedModelAliases.remove(modelId);
        if (pendingAliases == null || pendingAliases.isEmpty()) {
            return;
        }
        logger.trace(() -> Strings.format("[%s] model is now loaded, setting new model_aliases %s", modelId, pendingAliases));
        pendingAliases.forEach(alias -> modelAliasToId.put(alias, modelId));
    }

    /**
     * Removal listener of the local model cache.
     *
     * <p>Audits the removal when its reason is {@code EVICTED}, stops license feature tracking when
     * the model is no longer referenced, persists the model's stats and finally releases the
     * reference on the model, even when one of the earlier steps throws.
     *
     * @param notification the removal notification carrying the cache key and the removed entry
     */
    private void cacheEvictionListener(RemovalNotification<String, ModelAndConsumer> notification) {
        final String cacheKey = notification.getKey();
        final ModelAndConsumer removedEntry = notification.getValue();
        try {
            final boolean evicted = notification.getRemovalReason() == RemovalNotification.RemovalReason.EVICTED;
            if (evicted) {
                final Supplier<String> msg = () -> Strings.format(
                    "model cache entry evicted.current cache [%s] current max [%s] model size [%s]. If this is undesired, consider updating setting [%s] or [%s].",
                    ByteSizeValue.ofBytes(localModelCache.weight()).getStringRep(),
                    maxCacheSize.getStringRep(),
                    ByteSizeValue.ofBytes(removedEntry.model.ramBytesUsed()).getStringRep(),
                    INFERENCE_MODEL_CACHE_SIZE.getKey(),
                    INFERENCE_MODEL_CACHE_TTL.getKey()
                );
                auditIfNecessary(cacheKey, msg);
            }
            final String modelId = modelAliasToId.getOrDefault(cacheKey, cacheKey);
            logger.trace(() -> Strings.format("Persisting stats for evicted model [%s] (model_aliases %s)", modelId, modelIdToModelAliases.getOrDefault(modelId, new HashSet<>())));
            if (referencedModels.contains(modelId) == false) {
                MachineLearning.ML_MODEL_INFERENCE_FEATURE.stopTracking(licenseState, modelId);
            }
            removedEntry.model.persistStats(referencedModels.contains(modelId) == false);
        } finally {
            removedEntry.model.release();
        }
    }

    /*
     * WARNING - Removed try catching itself - possible behaviour change.
     */
    public void clusterChanged(ClusterChangedEvent event) {
        Set removedModels;
        HashSet<String> referencedModelsBeforeClusterState;
        HashMap<String, Set<String>> oldIdToAliases;
        boolean prefetchModels;
        if (event.changedCustomMetadataSet().contains("trained_model_cache_metadata")) {
            logger.trace("Trained model cache invalidated on node [{}]", new org.apache.logging.log4j.util.Supplier[]{() -> event.state().nodes().getLocalNodeId()});
            this.localModelCache.invalidateAll();
        }
        if (!((prefetchModels = event.state().nodes().getLocalNode().isIngestNode()) || event.changedCustomMetadataSet().contains("ingest") || event.changedCustomMetadataSet().contains("trained_model_alias"))) {
            return;
        }
        ClusterState state = event.state();
        IngestMetadata currentIngestMetadata = (IngestMetadata)state.metadata().custom("ingest");
        HashSet<String> allReferencedModelKeys = event.changedCustomMetadataSet().contains("ingest") ? ModelLoadingService.countInferenceProcessors(currentIngestMetadata) : new HashSet<String>(this.referencedModels);
        HashSet<String> loadingModelBeforeClusterState = null;
        HashMap<String, Set> addedModelViaAliases = new HashMap<String, Set>();
        Map<String, Queue<ActionListener<LocalModel>>> map = this.loadingListeners;
        synchronized (map) {
            String modelId;
            oldIdToAliases = new HashMap<String, Set<String>>(this.modelIdToModelAliases);
            Map<String, String> changedAliases = this.gatherLazyChangedAliasesAndUpdateModelAliases(event, prefetchModels, allReferencedModelKeys);
            if (!prefetchModels) {
                return;
            }
            referencedModelsBeforeClusterState = new HashSet<String>(this.referencedModels);
            if (logger.isTraceEnabled()) {
                loadingModelBeforeClusterState = new HashSet<String>(this.loadingListeners.keySet());
            }
            removedModels = Sets.difference(referencedModelsBeforeClusterState, allReferencedModelKeys);
            this.referencedModels.removeAll(removedModels);
            this.shouldNotAudit.removeAll(removedModels);
            for (String string : removedModels) {
                modelId = changedAliases.getOrDefault(string, this.modelAliasToId.getOrDefault(string, string));
                boolean oldModelAliasesNotReferenced = Sets.haveEmptyIntersection(this.referencedModels, oldIdToAliases.getOrDefault(modelId, Collections.emptySet()));
                boolean modelIsNotReferenced = !this.referencedModels.contains(modelId);
                boolean newModelAliasesNotReferenced = Sets.haveEmptyIntersection(this.referencedModels, this.modelIdToModelAliases.getOrDefault(modelId, Collections.emptySet()));
                if (!oldModelAliasesNotReferenced || !newModelAliasesNotReferenced || !modelIsNotReferenced) continue;
                ModelAndConsumer modelAndConsumer = (ModelAndConsumer)this.localModelCache.get((Object)modelId);
                if (modelAndConsumer != null && modelAndConsumer.consumers.stream().noneMatch(c -> c.isAnyOf(Consumer.SEARCH_AGGS, Consumer.SEARCH_RESCORER))) {
                    logger.trace("[{} ({})] invalidated from cache", (Object)modelId, (Object)string);
                    this.localModelCache.invalidate((Object)modelId);
                }
                if (modelAndConsumer != null) continue;
                MachineLearning.ML_MODEL_INFERENCE_FEATURE.stopTracking(this.licenseState, modelId);
            }
            allReferencedModelKeys.removeAll(this.referencedModels);
            for (String string : allReferencedModelKeys) {
                modelId = changedAliases.getOrDefault(string, this.modelAliasToId.getOrDefault(string, string));
                if (this.referencedModels.contains(modelId)) continue;
                addedModelViaAliases.computeIfAbsent(modelId, k -> new HashSet()).add(string);
            }
            for (Map.Entry entry : changedAliases.entrySet()) {
                String modelAlias = (String)entry.getKey();
                String modelId2 = (String)entry.getValue();
                if (this.referencedModels.contains(modelAlias)) {
                    addedModelViaAliases.computeIfAbsent(modelId2, k -> new HashSet()).add(modelAlias);
                    String oldModelId = this.modelAliasToId.get(modelAlias);
                    if (oldModelId != null && this.localModelCache.get((Object)oldModelId) != null) {
                        this.modelIdToUpdatedModelAliases.computeIfAbsent(modelId2, k -> new HashSet()).add(modelAlias);
                        continue;
                    }
                    this.modelAliasToId.put(modelAlias, modelId2);
                    continue;
                }
                this.modelAliasToId.put(modelAlias, modelId2);
            }
            this.referencedModels.addAll(allReferencedModelKeys);
            for (String string : addedModelViaAliases.keySet()) {
                this.loadingListeners.computeIfAbsent(string, s -> new ArrayDeque());
            }
        }
        if (logger.isTraceEnabled()) {
            if (!this.loadingListeners.keySet().equals(loadingModelBeforeClusterState)) {
                logger.trace("cluster state event changed loading models: before {} after {}", loadingModelBeforeClusterState, this.loadingListeners.keySet());
            }
            if (!this.referencedModels.equals(referencedModelsBeforeClusterState)) {
                logger.trace("cluster state event changed referenced models: before {} after {}", referencedModelsBeforeClusterState, this.referencedModels);
            }
            if (!oldIdToAliases.equals(this.modelIdToModelAliases)) {
                logger.trace("model id to alias mappings changed. before {} after {}. Model alias to IDs {}", oldIdToAliases, this.modelIdToModelAliases, this.modelAliasToId);
            }
            if (!addedModelViaAliases.isEmpty()) {
                logger.trace("adding new models via model_aliases and ids: {}", addedModelViaAliases);
            }
            if (!this.modelIdToUpdatedModelAliases.isEmpty()) {
                logger.trace("delayed model aliases to update {}", this.modelIdToModelAliases);
            }
        }
        removedModels.forEach(this::auditUnreferencedModel);
        this.loadModelsForPipeline(addedModelViaAliases.keySet());
    }

    /**
     * Refreshes the alias bookkeeping from the {@code trained_model_alias} custom metadata of the new
     * cluster state and collects the alias re-assignments that must not be applied immediately.
     *
     * <p>When the event does not report a change to {@code trained_model_alias}, nothing is touched and an
     * empty map is returned. Otherwise {@code modelIdToModelAliases} is rebuilt from scratch, and for
     * every alias in the metadata:
     * <ul>
     *   <li>an alias that is not yet known is added to {@code modelAliasToId};</li>
     *   <li>an alias that already points at the same model id is left alone;</li>
     *   <li>an alias that now points at a different model id is either recorded in the returned map
     *       (when {@code prefetchModels} is true and the alias is in {@code allReferencedModelKeys}) or
     *       updated in {@code modelAliasToId} right away.</li>
     * </ul>
     * Finally, aliases that no longer exist in the metadata are removed from {@code modelAliasToId}.
     *
     * @param event                  the cluster change being processed
     * @param prefetchModels         whether referenced models are pre-loaded; gates deferral of alias changes
     * @param allReferencedModelKeys the model ids / aliases currently referenced; only read, never modified here
     * @return a mutable map of alias to new model id for the deferred re-assignments (possibly empty)
     */
    private Map<String, String> gatherLazyChangedAliasesAndUpdateModelAliases(
        ClusterChangedEvent event,
        boolean prefetchModels,
        Set<String> allReferencedModelKeys
    ) {
        Map<String, String> changedAliases = new HashMap<>();
        if (!event.changedCustomMetadataSet().contains("trained_model_alias")) {
            return changedAliases;
        }
        // Work on a private copy of the alias metadata.
        Map<String, ModelAliasMetadata.ModelAliasEntry> modelAliasesToIds = new HashMap<>(
            ModelAliasMetadata.fromState(event.state()).modelAliases()
        );
        this.modelIdToModelAliases.clear();
        for (Map.Entry<String, ModelAliasMetadata.ModelAliasEntry> aliasToId : modelAliasesToIds.entrySet()) {
            String modelAlias = aliasToId.getKey();
            String newModelId = aliasToId.getValue().getModelId();
            this.modelIdToModelAliases.computeIfAbsent(newModelId, k -> new HashSet<>()).add(modelAlias);

            String currentModelId = this.modelAliasToId.get(modelAlias);
            if (currentModelId != null && currentModelId.equals(newModelId)) {
                // Alias target did not change.
                continue;
            }
            boolean deferUpdate = currentModelId != null && prefetchModels && allReferencedModelKeys.contains(modelAlias);
            if (deferUpdate) {
                // A referenced alias moved to another model: report it instead of switching it here.
                changedAliases.put(modelAlias, newModelId);
            } else {
                // Brand-new alias, or a changed alias that does not need deferring.
                this.modelAliasToId.put(modelAlias, newModelId);
            }
        }
        // Forget aliases that are no longer part of the cluster state metadata.
        Set<String> removedAliases = Sets.difference(this.modelAliasToId.keySet(), modelAliasesToIds.keySet());
        this.modelAliasToId.keySet().removeAll(removedAliases);
        return changedAliases;
    }

    /**
     * Audits a message for a model unless the model id is currently listed in {@code shouldNotAudit}.
     *
     * <p>For a suppressed model id the message is only emitted at trace level (and only built if trace
     * logging is enabled). Otherwise the message is sent to the auditor, the model id is added to
     * {@code shouldNotAudit} so later calls for the same id are suppressed, and the message is logged at
     * info level.
     *
     * @param modelId the model the message is about
     * @param msg     lazily supplies the message text
     */
    private void auditIfNecessary(String modelId, Supplier<String> msg) {
        if (this.shouldNotAudit.contains(modelId)) {
            logger.trace(() -> Strings.format("[%s] %s", modelId, msg.get()));
            return;
        }
        // Resolve the supplier once so the audit entry and the log line carry the same text
        // and the (possibly non-trivial) message construction is not repeated.
        String message = msg.get();
        this.auditor.info(modelId, message);
        this.shouldNotAudit.add(modelId);
        logger.info("[{}] {}", modelId, message);
    }

    /**
     * Schedules loading of the given models on behalf of ingest pipelines.
     *
     * <p>Nothing is scheduled for an empty set. Otherwise a single task is submitted to the
     * {@code ml_utility} executor which calls {@code loadModel} for each id with
     * {@link Consumer#PIPELINE}.
     *
     * @param modelIds ids of the models to load
     */
    private void loadModelsForPipeline(Set<String> modelIds) {
        if (!modelIds.isEmpty()) {
            Runnable loadAll = () -> modelIds.forEach(id -> this.loadModel(id, Consumer.PIPELINE));
            this.threadPool.executor("ml_utility").execute(loadAll);
        }
    }

    /**
     * Sends an info-level audit message stating that the model is now referenced by ingest processors
     * and that a cache load is being attempted.
     *
     * @param modelId the model the audit message is attached to
     */
    private void auditNewReferencedModel(String modelId) {
        final String message = "referenced by ingest processors. Attempting to load model into cache";
        this.auditor.info(modelId, message);
    }

    /**
     * Sends an info-level audit message stating that no processor references the model any more.
     *
     * @param modelId the model the audit message is attached to
     */
    private void auditUnreferencedModel(String modelId) {
        final String message = "no longer referenced by any processors";
        this.auditor.info(modelId, message);
    }

    /**
     * Appends {@code object} to {@code queue} and hands the very same queue back, so the call can be
     * used as an expression (for example as the result of a map remapping function).
     *
     * @param queue  the queue to append to; it is modified in place
     * @param object the element to append
     * @param <T>    the element type
     * @return {@code queue} itself, now containing {@code object}
     */
    private static <T> Queue<T> addFluently(Queue<T> queue, T object) {
        final Queue<T> extended = queue;
        extended.add(object);
        return extended;
    }

    /**
     * Collects the model keys referenced by {@code inference} processors in the ingest pipelines.
     *
     * <p>For each pipeline, the {@code processors} entry of its config is inspected; every list element
     * that is a map with an {@code inference} entry which is itself a map contributes its non-null
     * {@code model_id} value (via {@code toString()}). Anything with a different shape is ignored.
     *
     * @param ingestMetadata the ingest metadata to scan; may be {@code null}
     * @return a new mutable set of the referenced model keys, empty when {@code ingestMetadata} is
     *         {@code null} or nothing matches
     */
    private static Set<String> countInferenceProcessors(IngestMetadata ingestMetadata) {
        Set<String> referencedKeys = new HashSet<>();
        if (ingestMetadata == null) {
            return referencedKeys;
        }
        for (var pipeline : ingestMetadata.getPipelines().values()) {
            if (!(pipeline.getConfig().get("processors") instanceof List<?> processorList)) {
                continue;
            }
            for (Object candidate : processorList) {
                if (candidate instanceof Map<?, ?> processor && processor.get("inference") instanceof Map<?, ?> inferenceConfig) {
                    Object modelId = inferenceConfig.get("model_id");
                    if (modelId != null) {
                        assert modelId instanceof String;
                        referencedKeys.add(modelId.toString());
                    }
                }
            }
        }
        return referencedKeys;
    }

    /**
     * Picks the parameterless inference configuration that corresponds to a target type.
     *
     * @param targetType the target type of the model; a {@code null} value fails the switch with a
     *                   {@link NullPointerException}
     * @return {@code RegressionConfig.EMPTY_PARAMS} for {@code REGRESSION},
     *         {@code ClassificationConfig.EMPTY_PARAMS} for {@code CLASSIFICATION}
     * @throws IncompatibleClassChangeError for any other constant
     *         (NOTE(review): this looks like the compiler-generated fallback of an exhaustive switch
     *         made explicit by decompilation - confirm against the original source)
     */
    private static InferenceConfig inferenceConfigFromTargetType(TargetType targetType) {
        return switch (targetType) {
            case REGRESSION -> RegressionConfig.EMPTY_PARAMS;
            case CLASSIFICATION -> ClassificationConfig.EMPTY_PARAMS;
            default -> throw new IncompatibleClassChangeError();
        };
    }

    /*
     * WARNING - Removed try catching itself - possible behaviour change.
     */
    /**
     * Registers a listener in the queue kept for {@code modelId} in {@code loadingListeners}, creating
     * the queue when there is none yet. The update runs while holding the monitor of
     * {@code loadingListeners}.
     *
     * @param modelId             the model the listener is registered for
     * @param modelLoadedListener the listener to enqueue
     */
    void addModelLoadedListener(String modelId, ActionListener<LocalModel> modelLoadedListener) {
        synchronized (this.loadingListeners) {
            this.loadingListeners.compute(modelId, (key, pending) -> {
                // Reuse the queue that is already registered, otherwise start a fresh one.
                Queue<ActionListener<LocalModel>> target = pending != null ? pending : new ArrayDeque<>();
                target.add(modelLoadedListener);
                return target;
            });
        }
    }

    /*
     * Uses 'sealed' constructs - enable with --sealed true
     */
    /**
     * The kinds of callers a model can be loaded for. Each constant decides whether a given
     * {@link InferenceConfig} is acceptable for it and provides a short label for itself
     * (presumably used when building error messages - confirm against callers).
     */
    public enum Consumer {
        /** Accepts a missing config or one that supports ingest pipelines. */
        PIPELINE {
            @Override
            public boolean inferenceConfigSupported(InferenceConfig config) {
                if (config == null) {
                    return true;
                }
                return config.supportsIngestPipeline();
            }

            @Override
            public String exceptionName() {
                return "ingest";
            }
        },

        /** Accepts a missing config or one that supports pipeline aggregations. */
        SEARCH_AGGS {
            @Override
            public boolean inferenceConfigSupported(InferenceConfig config) {
                if (config == null) {
                    return true;
                }
                return config.supportsPipelineAggregation();
            }

            @Override
            public String exceptionName() {
                return "search(aggregation)";
            }
        },

        /** Requires a config, and that config must support the search rescorer. */
        SEARCH_RESCORER {
            @Override
            public boolean inferenceConfigSupported(InferenceConfig config) {
                if (config == null) {
                    return false;
                }
                return config.supportsSearchRescorer();
            }

            @Override
            public String exceptionName() {
                return "search(rescorer)";
            }
        },

        /** Accepts every config, including a missing one. */
        INTERNAL {
            @Override
            public boolean inferenceConfigSupported(InferenceConfig config) {
                return true;
            }

            @Override
            public String exceptionName() {
                return "internal";
            }
        };

        /**
         * Tells whether this consumer can work with the given inference configuration.
         *
         * @param var1 the configuration to check; may be {@code null}
         * @return {@code true} if the configuration is acceptable for this consumer
         */
        public abstract boolean inferenceConfigSupported(@Nullable InferenceConfig var1);

        /**
         * @return the short label of this consumer
         */
        public abstract String exceptionName();

        /**
         * Checks whether this constant is one of the given constants.
         *
         * @param consumers the candidates to compare against
         * @return {@code true} if any candidate is this very constant, {@code false} otherwise
         *         (including when no candidates are given)
         */
        public boolean isAnyOf(Consumer... consumers) {
            for (Consumer candidate : consumers) {
                if (candidate == this) {
                    return true;
                }
            }
            return false;
        }
    }

    /**
     * Pairs a {@link LocalModel} with the set of {@link Consumer}s recorded for it. The set starts out
     * containing only the consumer passed to the constructor and is a mutable {@link EnumSet}.
     */
    private static class ModelAndConsumer {
        private final LocalModel model;
        private final EnumSet<Consumer> consumers;

        /**
         * @param model    the model being wrapped
         * @param consumer the first consumer recorded for the model
         */
        private ModelAndConsumer(LocalModel model, Consumer consumer) {
            final EnumSet<Consumer> initialConsumers = EnumSet.of(consumer);
            this.consumers = initialConsumers;
            this.model = model;
        }
    }

    private static class ModelAndConsumerLoader
    implements CacheLoader<String, ModelAndConsumer> {
        private boolean loaded;
        private final ModelAndConsumer modelAndConsumer;

        ModelAndConsumerLoader(ModelAndConsumer modelAndConsumer) {
            this.modelAndConsumer = modelAndConsumer;
        }

        boolean isLoaded() {
            return this.loaded;
        }

        public ModelAndConsumer load(String key) throws Exception {
            this.loaded = true;
            this.modelAndConsumer.model.acquire();
            return this.modelAndConsumer;
        }
    }
}

