/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.ml.inference.loadingservice;

import java.util.ArrayDeque;
import java.util.Collections;
import java.util.EnumSet;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Queue;
import java.util.Set;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.logging.log4j.message.Message;
import org.apache.logging.log4j.message.ParameterizedMessage;
import org.apache.logging.log4j.util.MessageSupplier;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.cluster.ClusterChangedEvent;
import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.cluster.ClusterStateListener;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.breaker.CircuitBreaker;
import org.elasticsearch.common.breaker.CircuitBreakingException;
import org.elasticsearch.common.cache.Cache;
import org.elasticsearch.common.cache.CacheBuilder;
import org.elasticsearch.common.cache.CacheLoader;
import org.elasticsearch.common.cache.RemovalNotification;
import org.elasticsearch.common.settings.Setting;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.common.util.set.Sets;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.ingest.IngestMetadata;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction;
import org.elasticsearch.xpack.core.ml.inference.ModelAliasMetadata;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference.InferenceDefinition;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.ml.inference.TrainedModelStatsService;
import org.elasticsearch.xpack.ml.inference.loadingservice.LocalModel;
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
import org.elasticsearch.xpack.ml.notifications.InferenceAuditor;

/**
 * Loads trained models for local inference and keeps them in a node-local cache.
 * <p>
 * Models are requested either by an ingest pipeline ({@link Consumer#PIPELINE}) or by search
 * ({@link Consumer#SEARCH}). A model is inserted into the cache when it is referenced by an ingest
 * pipeline (directly or through one of its model aliases) or when it was requested for search.
 * Pipeline requests for models that are not referenced are loaded without caching.
 * <p>
 * The set of referenced models and the model alias mappings are derived from cluster state in
 * {@link #clusterChanged(ClusterChangedEvent)}. The cache is bounded by
 * {@link #INFERENCE_MODEL_CACHE_SIZE} (weighted by each model's {@code ramBytesUsed()}) and entries
 * expire after {@link #INFERENCE_MODEL_CACHE_TTL} without access.
 * <p>
 * Most mutable bookkeeping (listener queues, referenced models, alias maps) is mutated while holding
 * the {@code loadingListeners} monitor. Note that {@link #getModelId}, {@link #isModelCached} and the
 * fast path of {@code getModel} read {@code modelAliasToId} without that lock.
 */
public class ModelLoadingService implements ClusterStateListener {

    /** Maximum total weight of the model cache. Default {@code 40%}. Node scoped. */
    public static final Setting<ByteSizeValue> INFERENCE_MODEL_CACHE_SIZE = Setting.memorySizeSetting(
        "xpack.ml.inference_model.cache_size",
        "40%",
        Setting.Property.NodeScope
    );

    /** Cache entries expire after this long without access. Default 5 minutes, minimum 1 millisecond. Node scoped. */
    public static final Setting<TimeValue> INFERENCE_MODEL_CACHE_TTL = Setting.timeSetting(
        "xpack.ml.inference_model.time_to_live",
        new TimeValue(5L, TimeUnit.MINUTES),
        new TimeValue(1L, TimeUnit.MILLISECONDS),
        Setting.Property.NodeScope
    );

    private static final Logger logger = LogManager.getLogger(ModelLoadingService.class);

    private final TrainedModelStatsService modelStatsService;
    /** Cached models keyed by concrete model ID. */
    private final Cache<String, ModelAndConsumer> localModelCache;
    /** Model IDs or aliases referenced by inference processors in ingest pipelines. */
    private final Set<String> referencedModels = new HashSet<>();
    /** Translation from model alias to the model ID it currently resolves to on this node. */
    private final Map<String, String> modelAliasToId = new HashMap<>();
    /** Reverse of the alias metadata in cluster state: model ID to the aliases pointing at it. */
    private final Map<String, Set<String>> modelIdToModelAliases = new HashMap<>();
    /**
     * Alias re-pointings deferred until the new model ID finishes loading.
     * They are applied by {@link #populateNewModelAlias(String)}.
     */
    private final Map<String, Set<String>> modelIdToUpdatedModelAliases = new HashMap<>();
    /**
     * Listeners waiting for a model (keyed by model ID) that is currently loading.
     * This map's monitor is also the lock guarding the bookkeeping above.
     */
    private final Map<String, Queue<ActionListener<LocalModel>>> loadingListeners = new HashMap<>();
    private final TrainedModelProvider provider;
    /** Model IDs already audited for cache eviction. Cleared when the model is cached again or no longer referenced. */
    private final Set<String> shouldNotAudit;
    private final ThreadPool threadPool;
    private final InferenceAuditor auditor;
    private final ByteSizeValue maxCacheSize;
    private final String localNode;
    private final CircuitBreaker trainedModelCircuitBreaker;

    public ModelLoadingService(
        TrainedModelProvider trainedModelProvider,
        InferenceAuditor auditor,
        ThreadPool threadPool,
        ClusterService clusterService,
        TrainedModelStatsService modelStatsService,
        Settings settings,
        String localNode,
        CircuitBreaker trainedModelCircuitBreaker
    ) {
        this.provider = trainedModelProvider;
        this.threadPool = threadPool;
        this.maxCacheSize = INFERENCE_MODEL_CACHE_SIZE.get(settings);
        this.auditor = auditor;
        this.modelStatsService = modelStatsService;
        this.shouldNotAudit = new HashSet<>();
        this.localModelCache = CacheBuilder.<String, ModelAndConsumer>builder()
            .setMaximumWeight(this.maxCacheSize.getBytes())
            .weigher((id, modelAndConsumer) -> modelAndConsumer.model.ramBytesUsed())
            .removalListener(notification -> cacheEvictionListener(notification))
            .setExpireAfterAccess(INFERENCE_MODEL_CACHE_TTL.get(settings))
            .build();
        this.localNode = localNode;
        this.trainedModelCircuitBreaker = ExceptionsHelper.requireNonNull(trainedModelCircuitBreaker, "trainedModelCircuitBreaker");
        // Register last. A cluster-state callback can then never run against unassigned fields,
        // and a failing precondition above does not leave a stray listener registered.
        clusterService.addListener(this);
    }

    /** Resolves a model alias to its model ID. Returns the argument unchanged when it is not a known alias. */
    String getModelId(String modelIdOrAlias) {
        return modelAliasToId.getOrDefault(modelIdOrAlias, modelIdOrAlias);
    }

    /** Returns whether the model (given by ID or alias) currently has an entry in the local cache. */
    boolean isModelCached(String modelId) {
        return localModelCache.get(getModelId(modelId)) != null;
    }

    /** Obtains a model on behalf of an ingest pipeline. See {@link Consumer#PIPELINE}. */
    public void getModelForPipeline(String modelId, ActionListener<LocalModel> modelActionListener) {
        getModel(modelId, Consumer.PIPELINE, modelActionListener);
    }

    /** Obtains a model on behalf of search. Such models are always cached once loaded. */
    public void getModelForSearch(String modelId, ActionListener<LocalModel> modelActionListener) {
        getModel(modelId, Consumer.SEARCH, modelActionListener);
    }

    /**
     * Serves a model from the cache when possible. Otherwise it delegates to
     * {@link #loadModelIfNecessary}, which re-checks under the lock.
     */
    private void getModel(String modelIdOrAlias, Consumer consumer, ActionListener<LocalModel> modelActionListener) {
        final String modelId = getModelId(modelIdOrAlias);
        ModelAndConsumer cachedModel = localModelCache.get(modelId);
        if (cachedModel != null) {
            if (respondWithCachedModel(cachedModel, consumer, modelActionListener)) {
                logger.trace(() -> new ParameterizedMessage("[{}] (model_alias [{}]) loaded from cache", modelId, modelIdOrAlias));
            }
            return;
        }

        if (loadModelIfNecessary(modelIdOrAlias, consumer, modelActionListener)) {
            logger.trace(
                () -> new ParameterizedMessage(
                    "[{}] (model_alias [{}]) is loading or loaded, added new listener to queue",
                    modelId,
                    modelIdOrAlias
                )
            );
        }
    }

    /**
     * Hands a cached model to {@code listener}.
     * <p>
     * It records {@code consumer} on the entry, acquires a reference on the model and completes the
     * listener. If acquiring throws {@link CircuitBreakingException}, the listener is failed instead.
     *
     * @return {@code true} if the listener received the model, {@code false} if it was failed
     */
    private static boolean respondWithCachedModel(
        ModelAndConsumer cachedModel,
        Consumer consumer,
        ActionListener<LocalModel> listener
    ) {
        cachedModel.consumers.add(consumer);
        try {
            cachedModel.model.acquire();
        } catch (CircuitBreakingException e) {
            listener.onFailure(e);
            return false;
        }
        listener.onResponse(cachedModel.model);
        return true;
    }

    /**
     * Under the {@code loadingListeners} lock, this does one of the following:
     * <ul>
     *   <li>serves the model from the cache;</li>
     *   <li>queues the listener behind an in-flight load;</li>
     *   <li>starts a new load.</li>
     * </ul>
     * Pipeline requests for models not referenced by any pipeline are loaded without caching.
     *
     * @return {@code true} if the model was served from cache or the listener was queued behind an
     *         existing load; {@code false} if this call started a new load
     */
    private boolean loadModelIfNecessary(String modelIdOrAlias, Consumer consumer, ActionListener<LocalModel> modelActionListener) {
        synchronized (loadingListeners) {
            final String modelId = getModelId(modelIdOrAlias);
            ModelAndConsumer cachedModel = localModelCache.get(modelId);
            if (cachedModel != null) {
                respondWithCachedModel(cachedModel, consumer, modelActionListener);
                return true;
            }

            // If a load is already in flight, just wait for it.
            Queue<ActionListener<LocalModel>> listeners = loadingListeners.computeIfPresent(
                modelId,
                (storedModelKey, listenerQueue) -> addFluently(listenerQueue, modelActionListener)
            );
            if (listeners != null) {
                return true;
            }

            if (Consumer.PIPELINE == consumer && referencedModels.contains(modelId) == false) {
                logger.trace(
                    () -> new ParameterizedMessage(
                        "[{}] (model_alias [{}]) not actively loading, eager loading without cache",
                        modelId,
                        modelIdOrAlias
                    )
                );
                loadWithoutCaching(modelId, modelActionListener);
            } else {
                logger.trace(
                    () -> new ParameterizedMessage("[{}] (model_alias [{}]) attempting to load and cache", modelId, modelIdOrAlias)
                );
                loadingListeners.put(modelId, addFluently(new ArrayDeque<>(), modelActionListener));
                loadModel(modelId, consumer);
            }
            return false;
        }
    }

    /**
     * Fetches the model config and inference definition and reports the outcome to the queued listeners.
     * Outcomes go through {@link #handleLoadSuccess} or {@link #handleLoadFailure}.
     * <p>
     * The config's estimated heap memory is added to the circuit breaker before the definition is fetched.
     * It is then corrected to the definition's real size, or released if fetching the definition fails.
     * If the initial breaker check throws, {@code ActionListener.wrap} routes that to the
     * "failed to load model configuration" handler.
     */
    private void loadModel(String modelId, Consumer consumer) {
        provider.getTrainedModel(modelId, GetTrainedModelsAction.Includes.empty(), ActionListener.wrap(trainedModelConfig -> {
            trainedModelCircuitBreaker.addEstimateBytesAndMaybeBreak(trainedModelConfig.getEstimatedHeapMemory(), modelId);
            provider.getTrainedModelForInference(modelId, ActionListener.wrap(inferenceDefinition -> {
                try {
                    updateCircuitBreakerEstimate(modelId, inferenceDefinition, trainedModelConfig);
                } catch (CircuitBreakingException ex) {
                    handleLoadFailure(modelId, ex);
                    return;
                }
                handleLoadSuccess(modelId, consumer, trainedModelConfig, inferenceDefinition);
            }, failure -> {
                // Release the estimate reserved above; the definition never arrived.
                trainedModelCircuitBreaker.addWithoutBreaking(-trainedModelConfig.getEstimatedHeapMemory());
                logger.warn(new ParameterizedMessage("[{}] failed to load model definition", modelId), failure);
                handleLoadFailure(modelId, failure);
            }));
        }, failure -> {
            logger.warn(new ParameterizedMessage("[{}] failed to load model configuration", modelId), failure);
            handleLoadFailure(modelId, failure);
        }));
    }

    /**
     * Loads a model for a single listener without touching the cache or {@code loadingListeners}.
     * Circuit-breaker accounting is the same as in {@link #loadModel}.
     */
    private void loadWithoutCaching(String modelId, ActionListener<LocalModel> modelActionListener) {
        provider.getTrainedModel(modelId, GetTrainedModelsAction.Includes.empty(), ActionListener.wrap(trainedModelConfig -> {
            trainedModelCircuitBreaker.addEstimateBytesAndMaybeBreak(trainedModelConfig.getEstimatedHeapMemory(), modelId);
            provider.getTrainedModelForInference(modelId, ActionListener.wrap(inferenceDefinition -> {
                InferenceConfig inferenceConfig = trainedModelConfig.getInferenceConfig() == null
                    ? inferenceConfigFromTargetType(inferenceDefinition.getTargetType())
                    : trainedModelConfig.getInferenceConfig();
                try {
                    updateCircuitBreakerEstimate(modelId, inferenceDefinition, trainedModelConfig);
                } catch (CircuitBreakingException ex) {
                    modelActionListener.onFailure(ex);
                    return;
                }
                modelActionListener.onResponse(
                    new LocalModel(
                        trainedModelConfig.getModelId(),
                        localNode,
                        inferenceDefinition,
                        trainedModelConfig.getInput(),
                        trainedModelConfig.getDefaultFieldMap(),
                        inferenceConfig,
                        trainedModelConfig.getLicenseLevel(),
                        modelStatsService,
                        trainedModelCircuitBreaker
                    )
                );
            }, e -> {
                trainedModelCircuitBreaker.addWithoutBreaking(-trainedModelConfig.getEstimatedHeapMemory());
                modelActionListener.onFailure(e);
            }));
        }, modelActionListener::onFailure));
    }

    /**
     * Replaces the config's estimated heap memory on the breaker with the definition's actual
     * {@code ramBytesUsed()}.
     * <p>
     * A smaller real size is released without breaking. A larger one is added with breaking.
     * If that trips, the original estimate is released before the exception is rethrown.
     *
     * @throws CircuitBreakingException if the additional bytes would exceed the breaker limit
     */
    private void updateCircuitBreakerEstimate(String modelId, InferenceDefinition inferenceDefinition, TrainedModelConfig trainedModelConfig)
        throws CircuitBreakingException {
        long estimateDiff = inferenceDefinition.ramBytesUsed() - trainedModelConfig.getEstimatedHeapMemory();
        if (estimateDiff < 0) {
            trainedModelCircuitBreaker.addWithoutBreaking(estimateDiff);
        } else if (estimateDiff > 0) {
            try {
                trainedModelCircuitBreaker.addEstimateBytesAndMaybeBreak(estimateDiff, modelId);
            } catch (CircuitBreakingException ex) {
                trainedModelCircuitBreaker.addWithoutBreaking(-trainedModelConfig.getEstimatedHeapMemory());
                throw ex;
            }
        }
    }

    /**
     * Builds the {@link LocalModel} and, if it is still wanted, caches it. Then it completes every
     * listener queued for {@code modelId}.
     * <p>
     * "Still wanted" means one of:
     * <ul>
     *   <li>the model is referenced by a pipeline;</li>
     *   <li>one of its aliases is referenced by a pipeline;</li>
     *   <li>it was requested for search.</li>
     * </ul>
     */
    private void handleLoadSuccess(
        String modelId,
        Consumer consumer,
        TrainedModelConfig trainedModelConfig,
        InferenceDefinition inferenceDefinition
    ) {
        Queue<ActionListener<LocalModel>> listeners;
        InferenceConfig inferenceConfig = trainedModelConfig.getInferenceConfig() == null
            ? inferenceConfigFromTargetType(inferenceDefinition.getTargetType())
            : trainedModelConfig.getInferenceConfig();
        LocalModel loadedModel = new LocalModel(
            trainedModelConfig.getModelId(),
            localNode,
            inferenceDefinition,
            trainedModelConfig.getInput(),
            trainedModelConfig.getDefaultFieldMap(),
            inferenceConfig,
            trainedModelConfig.getLicenseLevel(),
            modelStatsService,
            trainedModelCircuitBreaker
        );
        final ModelAndConsumerLoader modelAndConsumerLoader = new ModelAndConsumerLoader(new ModelAndConsumer(loadedModel, consumer));
        synchronized (loadingListeners) {
            populateNewModelAlias(modelId);
            boolean cacheModel = referencedModels.contains(modelId)
                || Sets.haveNonEmptyIntersection(modelIdToModelAliases.getOrDefault(modelId, Collections.emptySet()), referencedModels)
                || consumer.equals(Consumer.SEARCH);
            if (cacheModel) {
                try {
                    // The loader only runs (and sets isLoaded()) if no entry exists for modelId yet.
                    localModelCache.computeIfAbsent(modelId, modelAndConsumerLoader);
                } catch (ExecutionException ee) {
                    logger.warn(() -> new ParameterizedMessage("[{}] threw when attempting add to cache", modelId), ee);
                }
                shouldNotAudit.remove(modelId);
            }
            listeners = loadingListeners.remove(modelId);
            if (listeners == null) {
                // Balance the acquire() done by ModelAndConsumerLoader#load, if it ran.
                if (modelAndConsumerLoader.isLoaded()) {
                    loadedModel.release();
                }
                return;
            }
        }

        // Complete listeners outside the lock. Each listener gets its own reference.
        for (ActionListener<LocalModel> listener = listeners.poll(); listener != null; listener = listeners.poll()) {
            loadedModel.acquire();
            listener.onResponse(loadedModel);
        }
        // Balance the acquire() done by ModelAndConsumerLoader#load, if it ran.
        // NOTE(review): the reference released on eviction (cacheEvictionListener) presumably
        // corresponds to LocalModel's initial reference; confirm against LocalModel.
        if (modelAndConsumerLoader.isLoaded()) {
            loadedModel.release();
        }
    }

    /**
     * Fails every listener queued for {@code modelId} with {@code failure}.
     * It also applies any deferred alias updates for that model.
     */
    private void handleLoadFailure(String modelId, Exception failure) {
        Queue<ActionListener<LocalModel>> listeners;
        synchronized (loadingListeners) {
            listeners = loadingListeners.remove(modelId);
            populateNewModelAlias(modelId);
            if (listeners == null) {
                return;
            }
        }
        for (ActionListener<LocalModel> listener = listeners.poll(); listener != null; listener = listeners.poll()) {
            listener.onFailure(failure);
        }
    }

    /**
     * Applies alias re-pointings deferred in {@link #clusterChanged} now that {@code modelId} has
     * finished loading (successfully or not). Callers hold the {@code loadingListeners} monitor.
     */
    private void populateNewModelAlias(String modelId) {
        Set<String> newModelAliases = modelIdToUpdatedModelAliases.remove(modelId);
        if (newModelAliases != null && newModelAliases.isEmpty() == false) {
            logger.trace(() -> new ParameterizedMessage("[{}] model is now loaded, setting new model_aliases {}", modelId, newModelAliases));
            for (String modelAlias : newModelAliases) {
                modelAliasToId.put(modelAlias, modelId);
            }
        }
    }

    /**
     * Cache removal hook. Size-based evictions are audited once per model (see
     * {@link #auditIfNecessary}).
     * <p>
     * For every removal it persists the model's stats and always releases the cache's reference on
     * the model. The {@code persistStats} flag is {@code true} when the model is no longer referenced
     * by a pipeline; its exact meaning is defined by {@code LocalModel#persistStats}.
     */
    private void cacheEvictionListener(RemovalNotification<String, ModelAndConsumer> notification) {
        try {
            if (notification.getRemovalReason() == RemovalNotification.RemovalReason.EVICTED) {
                MessageSupplier msg = () -> new ParameterizedMessage(
                    "model cache entry evicted.current cache [{}] current max [{}] model size [{}]. "
                        + "If this is undesired, consider updating setting [{}] or [{}].",
                    ByteSizeValue.ofBytes(localModelCache.weight()).getStringRep(),
                    maxCacheSize.getStringRep(),
                    ByteSizeValue.ofBytes(notification.getValue().model.ramBytesUsed()).getStringRep(),
                    INFERENCE_MODEL_CACHE_SIZE.getKey(),
                    INFERENCE_MODEL_CACHE_TTL.getKey()
                );
                auditIfNecessary(notification.getKey(), msg);
            }
            String modelId = modelAliasToId.getOrDefault(notification.getKey(), notification.getKey());
            logger.trace(
                () -> new ParameterizedMessage(
                    "Persisting stats for evicted model [{}] (model_aliases {})",
                    modelId,
                    modelIdToModelAliases.getOrDefault(modelId, Collections.emptySet())
                )
            );
            notification.getValue().model.persistStats(referencedModels.contains(modelId) == false);
        } finally {
            notification.getValue().model.release();
        }
    }

    /**
     * Reacts to ingest-pipeline and model-alias changes in cluster state.
     * <p>
     * Alias mappings are refreshed on every node. On ingest nodes it also:
     * <ul>
     *   <li>recomputes the referenced model set;</li>
     *   <li>invalidates cache entries that are no longer referenced and were not used by search;</li>
     *   <li>defers alias re-pointings for cached models until the new target loads;</li>
     *   <li>asynchronously loads newly referenced models.</li>
     * </ul>
     */
    @Override
    public void clusterChanged(ClusterChangedEvent event) {
        final boolean prefetchModels = event.state().nodes().getLocalNode().isIngestNode();
        // Without prefetching and without ingest/alias changes there is nothing to do.
        if (prefetchModels == false
            && event.changedCustomMetadataSet().contains("ingest") == false
            && event.changedCustomMetadataSet().contains("trained_model_alias") == false) {
            return;
        }

        ClusterState state = event.state();
        IngestMetadata currentIngestMetadata = (IngestMetadata) state.metadata().custom("ingest");
        Set<String> allReferencedModelKeys = event.changedCustomMetadataSet().contains("ingest")
            ? getReferencedModelKeys(currentIngestMetadata)
            : new HashSet<>(referencedModels);
        Set<String> removedModels;
        Map<String, Set<String>> addedModelViaAliases = new HashMap<>();
        synchronized (loadingListeners) {
            Map<String, Set<String>> oldIdToAliases = new HashMap<>(modelIdToModelAliases);
            Map<String, String> changedAliases = gatherLazyChangedAliasesAndUpdateModelAliases(event, prefetchModels, allReferencedModelKeys);

            // Only ingest nodes track referenced models and prefetch them.
            if (prefetchModels == false) {
                return;
            }

            Set<String> referencedModelsBeforeClusterState = new HashSet<>(referencedModels);
            Set<String> loadingModelBeforeClusterState = null;
            if (logger.isTraceEnabled()) {
                loadingModelBeforeClusterState = new HashSet<>(loadingListeners.keySet());
            }

            removedModels = Sets.difference(referencedModelsBeforeClusterState, allReferencedModelKeys);
            referencedModels.removeAll(removedModels);
            shouldNotAudit.removeAll(removedModels);

            // Invalidate cached models that are no longer referenced by id, by a previous alias or by a
            // current alias, unless search has used the cache entry.
            for (String removedKey : removedModels) {
                String modelId = changedAliases.getOrDefault(removedKey, modelAliasToId.getOrDefault(removedKey, removedKey));
                boolean oldModelAliasesNotReferenced = Sets.haveEmptyIntersection(
                    referencedModels,
                    oldIdToAliases.getOrDefault(modelId, Collections.emptySet())
                );
                boolean modelIsNotReferenced = referencedModels.contains(modelId) == false;
                boolean newModelAliasesNotReferenced = Sets.haveEmptyIntersection(
                    referencedModels,
                    modelIdToModelAliases.getOrDefault(modelId, Collections.emptySet())
                );
                if (oldModelAliasesNotReferenced && newModelAliasesNotReferenced && modelIsNotReferenced) {
                    ModelAndConsumer modelAndConsumer = localModelCache.get(modelId);
                    if (modelAndConsumer != null && modelAndConsumer.consumers.contains(Consumer.SEARCH) == false) {
                        logger.trace("[{} ({})] invalidated from cache", modelId, removedKey);
                        localModelCache.invalidate(modelId);
                    }
                }
            }

            // Whatever remains is newly referenced; group it by the model ID it resolves to.
            allReferencedModelKeys.removeAll(referencedModels);
            for (String newKey : allReferencedModelKeys) {
                String modelId = changedAliases.getOrDefault(newKey, modelAliasToId.getOrDefault(newKey, newKey));
                if (referencedModels.contains(modelId) == false) {
                    addedModelViaAliases.computeIfAbsent(modelId, k -> new HashSet<>()).add(newKey);
                }
            }

            for (Map.Entry<String, String> aliasAndModelId : changedAliases.entrySet()) {
                String modelAlias = aliasAndModelId.getKey();
                String modelId = aliasAndModelId.getValue();
                if (referencedModels.contains(modelAlias) == false) {
                    modelAliasToId.put(modelAlias, modelId);
                    continue;
                }
                // The alias is in use, so its new target must be loaded.
                addedModelViaAliases.computeIfAbsent(modelId, k -> new HashSet<>()).add(modelAlias);
                String oldModelId = modelAliasToId.get(modelAlias);
                if (oldModelId != null && localModelCache.get(oldModelId) != null) {
                    // Keep serving the cached old target; populateNewModelAlias switches once the new one loads.
                    modelIdToUpdatedModelAliases.computeIfAbsent(modelId, k -> new HashSet<>()).add(modelAlias);
                } else {
                    modelAliasToId.put(modelAlias, modelId);
                }
            }

            referencedModels.addAll(allReferencedModelKeys);
            // Mark new models as loading so concurrent requests queue behind the prefetch.
            for (String modelId : addedModelViaAliases.keySet()) {
                loadingListeners.computeIfAbsent(modelId, s -> new ArrayDeque<>());
            }

            // Logged while still holding the lock: these collections are only mutated under it.
            if (logger.isTraceEnabled()) {
                if (loadingListeners.keySet().equals(loadingModelBeforeClusterState) == false) {
                    logger.trace(
                        "cluster state event changed loading models: before {} after {}",
                        loadingModelBeforeClusterState,
                        loadingListeners.keySet()
                    );
                }
                if (referencedModels.equals(referencedModelsBeforeClusterState) == false) {
                    logger.trace(
                        "cluster state event changed referenced models: before {} after {}",
                        referencedModelsBeforeClusterState,
                        referencedModels
                    );
                }
                if (oldIdToAliases.equals(modelIdToModelAliases) == false) {
                    logger.trace(
                        "model id to alias mappings changed. before {} after {}. Model alias to IDs {}",
                        oldIdToAliases,
                        modelIdToModelAliases,
                        modelAliasToId
                    );
                }
                if (addedModelViaAliases.isEmpty() == false) {
                    logger.trace("adding new models via model_aliases and ids: {}", addedModelViaAliases);
                }
                if (modelIdToUpdatedModelAliases.isEmpty() == false) {
                    logger.trace("delayed model aliases to update {}", modelIdToUpdatedModelAliases);
                }
            }
        }
        removedModels.forEach(this::auditUnreferencedModel);
        loadModelsForPipeline(addedModelViaAliases.keySet());
    }

    /**
     * When the model-alias metadata changed, this rebuilds {@code modelIdToModelAliases} and updates
     * {@code modelAliasToId}. Aliases no longer present in cluster state are dropped.
     * <p>
     * A re-pointed alias that is referenced by a pipeline on an ingest node is not applied here.
     * It is returned so the caller can decide whether to defer it.
     *
     * @return re-pointed aliases (alias to new model ID) whose translation was not yet updated
     */
    private Map<String, String> gatherLazyChangedAliasesAndUpdateModelAliases(
        ClusterChangedEvent event,
        boolean prefetchModels,
        Set<String> allReferencedModelKeys
    ) {
        Map<String, String> changedAliases = new HashMap<>();
        if (event.changedCustomMetadataSet().contains("trained_model_alias")) {
            final Map<String, ModelAliasMetadata.ModelAliasEntry> modelAliasesToIds = new HashMap<>(
                ModelAliasMetadata.fromState(event.state()).modelAliases()
            );
            modelIdToModelAliases.clear();
            for (Map.Entry<String, ModelAliasMetadata.ModelAliasEntry> aliasToId : modelAliasesToIds.entrySet()) {
                String modelAlias = aliasToId.getKey();
                String newModelId = aliasToId.getValue().getModelId();
                modelIdToModelAliases.computeIfAbsent(newModelId, k -> new HashSet<>()).add(modelAlias);
                String currentModelId = modelAliasToId.get(modelAlias);
                if (currentModelId == null) {
                    modelAliasToId.put(modelAlias, newModelId);
                } else if (currentModelId.equals(newModelId) == false) {
                    if (prefetchModels && allReferencedModelKeys.contains(modelAlias)) {
                        changedAliases.put(modelAlias, newModelId);
                    } else {
                        modelAliasToId.put(modelAlias, newModelId);
                    }
                }
            }
            Set<String> removedAliases = Sets.difference(modelAliasToId.keySet(), modelAliasesToIds.keySet());
            modelAliasToId.keySet().removeAll(removedAliases);
        }
        return changedAliases;
    }

    /**
     * Audits {@code msg} for {@code modelId} the first time only. Later calls are trace-logged until
     * the model is cleared from {@code shouldNotAudit}. The message is built at most once per call.
     */
    private void auditIfNecessary(String modelId, MessageSupplier msg) {
        if (shouldNotAudit.contains(modelId)) {
            logger.trace(() -> new ParameterizedMessage("[{}] {}", modelId, msg.get().getFormattedMessage()));
            return;
        }
        String formattedMessage = msg.get().getFormattedMessage();
        auditor.info(modelId, formattedMessage);
        shouldNotAudit.add(modelId);
        logger.info("[{}] {}", modelId, formattedMessage);
    }

    /** Audits and loads each model on the {@code ml_utility} executor (no-op for an empty set). */
    private void loadModelsForPipeline(Set<String> modelIds) {
        if (modelIds.isEmpty()) {
            return;
        }
        threadPool.executor("ml_utility").execute(() -> {
            for (String modelId : modelIds) {
                auditNewReferencedModel(modelId);
                loadModel(modelId, Consumer.PIPELINE);
            }
        });
    }

    private void auditNewReferencedModel(String modelId) {
        auditor.info(modelId, "referenced by ingest processors. Attempting to load model into cache");
    }

    private void auditUnreferencedModel(String modelId) {
        auditor.info(modelId, "no longer referenced by any processors");
    }

    /** Adds {@code object} to {@code queue} and returns the same queue, for use inside map compute lambdas. */
    private static <T> Queue<T> addFluently(Queue<T> queue, T object) {
        queue.add(object);
        return queue;
    }

    /**
     * Collects the {@code model_id} of every {@code inference} processor in the pipelines' top-level
     * {@code processors} lists. Nested processors are not inspected.
     * Returns a mutable set, which is empty when there is no ingest metadata.
     */
    private static Set<String> getReferencedModelKeys(IngestMetadata ingestMetadata) {
        Set<String> allReferencedModelKeys = new HashSet<>();
        if (ingestMetadata == null) {
            return allReferencedModelKeys;
        }
        ingestMetadata.getPipelines().forEach((pipelineId, pipelineConfiguration) -> {
            Object processors = pipelineConfiguration.getConfigAsMap().get("processors");
            if (processors instanceof List<?>) {
                for (Object processor : (List<?>) processors) {
                    if (processor instanceof Map<?, ?> == false) {
                        continue;
                    }
                    Object processorConfig = ((Map<?, ?>) processor).get("inference");
                    if (processorConfig instanceof Map<?, ?> == false) {
                        continue;
                    }
                    Object modelId = ((Map<?, ?>) processorConfig).get("model_id");
                    if (modelId != null) {
                        assert modelId instanceof String;
                        allReferencedModelKeys.add(modelId.toString());
                    }
                }
            }
        });
        return allReferencedModelKeys;
    }

    /**
     * Default inference config for a model whose config does not specify one.
     *
     * @throws org.elasticsearch.ElasticsearchStatusException (bad request) for target types other than
     *         regression and classification
     */
    private static InferenceConfig inferenceConfigFromTargetType(TargetType targetType) {
        switch (targetType) {
            case REGRESSION:
                return RegressionConfig.EMPTY_PARAMS;
            case CLASSIFICATION:
                return ClassificationConfig.EMPTY_PARAMS;
            default:
                throw ExceptionsHelper.badRequestException("unsupported target type [{}]", targetType);
        }
    }

    /** Registers a listener to be completed when {@code modelId} finishes loading (testing hook). */
    void addModelLoadedListener(String modelId, ActionListener<LocalModel> modelLoadedListener) {
        synchronized (loadingListeners) {
            loadingListeners.computeIfAbsent(modelId, modelKey -> new ArrayDeque<>()).add(modelLoadedListener);
        }
    }

    /** Who requested a model; search-used cache entries are never invalidated by pipeline changes. */
    public enum Consumer {
        PIPELINE,
        SEARCH
    }

    /** A cached model plus the set of consumer kinds that have used this cache entry. */
    private static class ModelAndConsumer {
        private final LocalModel model;
        private final EnumSet<Consumer> consumers;

        private ModelAndConsumer(LocalModel model, Consumer consumer) {
            this.model = model;
            this.consumers = EnumSet.of(consumer);
        }
    }

    /**
     * Cache loader that supplies a pre-built entry and records whether the cache actually invoked it.
     * On load it acquires one reference on the model; {@code handleLoadSuccess} releases it afterwards.
     */
    private static class ModelAndConsumerLoader implements CacheLoader<String, ModelAndConsumer> {

        private boolean loaded;
        private final ModelAndConsumer modelAndConsumer;

        ModelAndConsumerLoader(ModelAndConsumer modelAndConsumer) {
            this.modelAndConsumer = modelAndConsumer;
        }

        boolean isLoaded() {
            return loaded;
        }

        @Override
        public ModelAndConsumer load(String key) throws Exception {
            loaded = true;
            modelAndConsumer.model.acquire();
            return modelAndConsumer;
        }
    }
}

