/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.ml.inference.deployment;

import java.io.IOException;
import java.io.InputStream;
import java.time.Instant;
import java.util.Objects;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.Consumer;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.lucene.util.SetOnce;
import org.elasticsearch.ResourceNotFoundException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.ActionRequest;
import org.elasticsearch.action.ActionType;
import org.elasticsearch.action.search.SearchAction;
import org.elasticsearch.action.search.SearchRequest;
import org.elasticsearch.action.support.master.AcknowledgedResponse;
import org.elasticsearch.client.internal.Client;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.util.concurrent.EsRejectedExecutionException;
import org.elasticsearch.common.xcontent.LoggingDeprecationHandler;
import org.elasticsearch.core.Strings;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.index.query.IdsQueryBuilder;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.tasks.CancellableTask;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xcontent.DeprecationHandler;
import org.elasticsearch.xcontent.NamedXContentRegistry;
import org.elasticsearch.xcontent.XContentFactory;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xcontent.XContentParserConfiguration;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.core.ClientHelper;
import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput;
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.IndexLocation;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModelLocation;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.VocabularyConfig;
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.ml.inference.deployment.AbstractPyTorchAction;
import org.elasticsearch.xpack.ml.inference.deployment.ClearCacheControlMessagePytorchAction;
import org.elasticsearch.xpack.ml.inference.deployment.InferencePyTorchAction;
import org.elasticsearch.xpack.ml.inference.deployment.ModelStats;
import org.elasticsearch.xpack.ml.inference.deployment.NlpInferenceInput;
import org.elasticsearch.xpack.ml.inference.deployment.ThreadSettingsControlMessagePytorchAction;
import org.elasticsearch.xpack.ml.inference.deployment.TrainedModelDeploymentTask;
import org.elasticsearch.xpack.ml.inference.nlp.NlpTask;
import org.elasticsearch.xpack.ml.inference.nlp.Vocabulary;
import org.elasticsearch.xpack.ml.inference.pytorch.PriorityProcessWorkerExecutorService;
import org.elasticsearch.xpack.ml.inference.pytorch.process.PyTorchProcess;
import org.elasticsearch.xpack.ml.inference.pytorch.process.PyTorchProcessFactory;
import org.elasticsearch.xpack.ml.inference.pytorch.process.PyTorchResultProcessor;
import org.elasticsearch.xpack.ml.inference.pytorch.process.PyTorchStateStreamer;
import org.elasticsearch.xpack.ml.inference.pytorch.results.ThreadSettings;

/**
 * Keeps track of the PyTorch inference process contexts started through this
 * manager and routes deployment lifecycle calls (start/stop) as well as
 * inference and control requests to the matching context.
 */
public class DeploymentManager {
    private static final Logger logger = LogManager.getLogger(DeploymentManager.class);
    // Source of the request ids handed to every PyTorch action created by this class; static, so shared by all instances.
    private static final AtomicLong requestIdCounter = new AtomicLong(1L);
    // Used for the trained-model and vocabulary lookups and passed to each context's state streamer.
    private final Client client;
    // Registry supplied to the vocabulary parser and to each context's state streamer.
    private final NamedXContentRegistry xContentRegistry;
    // Creates the native process in ProcessContext.startAndLoad.
    private final PyTorchProcessFactory pyTorchProcessFactory;
    // The "ml_utility" executor: runs startAndLoad and the model-load callbacks.
    private final ExecutorService executorServiceForDeployment;
    // The "ml_native_inference_comms" executor: runs the result processor and the priority worker.
    private final ExecutorService executorServiceForProcess;
    private final ThreadPool threadPool;
    // Registered process contexts, keyed by the id of the deployment task.
    private final ConcurrentMap<Long, ProcessContext> processContextByAllocation = new ConcurrentHashMap<Long, ProcessContext>();
    // startDeployment refuses to register a new context once the map holds this many entries.
    private final int maxProcesses;

    /**
     * Creates a manager that resolves its two executors from {@code threadPool}.
     *
     * @param client                used for model/vocabulary lookups and state streaming
     * @param xContentRegistry      registry used when parsing stored documents
     * @param threadPool            provides the "ml_utility" and "ml_native_inference_comms" executors
     * @param pyTorchProcessFactory creates the native inference processes
     * @param maxProcesses          maximum number of process contexts accepted by {@code startDeployment}
     * @throws NullPointerException if any of the reference arguments is {@code null}
     */
    public DeploymentManager(
        Client client,
        NamedXContentRegistry xContentRegistry,
        ThreadPool threadPool,
        PyTorchProcessFactory pyTorchProcessFactory,
        int maxProcesses
    ) {
        this.client = Objects.requireNonNull(client);
        this.xContentRegistry = Objects.requireNonNull(xContentRegistry);
        this.pyTorchProcessFactory = Objects.requireNonNull(pyTorchProcessFactory);
        this.threadPool = Objects.requireNonNull(threadPool);
        // The thread pool has been null-checked above, so the lookups below cannot hit a null receiver.
        this.executorServiceForDeployment = this.threadPool.executor("ml_utility");
        this.executorServiceForProcess = this.threadPool.executor("ml_native_inference_comms");
        this.maxProcesses = maxProcesses;
    }

    /**
     * Builds a statistics snapshot for the deployment behind {@code task}.
     *
     * @param task the deployment task whose id is used to look up the process context
     * @return the stats, or an empty {@code Optional} when no process context is registered for the task
     */
    public Optional<ModelStats> getStats(TrainedModelDeploymentTask task) {
        ProcessContext context = processContextByAllocation.get(task.getId());
        if (context == null) {
            return Optional.empty();
        }
        PyTorchResultProcessor.ResultStats resultStats = context.getResultProcessor().getResultStats();
        PyTorchResultProcessor.RecentStats recent = resultStats.recentStats();
        return Optional.of(
            new ModelStats(
                context.startTime,
                resultStats.timingStats().getCount(),
                resultStats.timingStats().getAverage(),
                resultStats.timingStatsExcludingCacheHits().getAverage(),
                resultStats.lastUsed(),
                // worker queue size plus the result processor's pending-result count
                context.priorityProcessWorker.queueSize() + resultStats.numberOfPendingResults(),
                resultStats.errorCount(),
                resultStats.cacheHitCount(),
                context.rejectedExecutionCount.intValue(),
                context.timeoutCount.intValue(),
                context.numThreadsPerAllocation,
                context.numAllocations,
                resultStats.peakThroughput(),
                recent.requestsProcessed(),
                recent.avgInferenceTime(),
                recent.cacheHitCount()
            )
        );
    }

    /**
     * Registers {@code processContext} under {@code id} unless a context is already registered for that id.
     *
     * @param id             key to register the context under
     * @param processContext the context to register
     * @return the context already registered for {@code id}, or {@code null} when the new one was stored
     */
    ProcessContext addProcessContext(Long id, ProcessContext processContext) {
        ProcessContext alreadyRegistered = processContextByAllocation.putIfAbsent(id, processContext);
        return alreadyRegistered;
    }

    /**
     * Starts the inference process for {@code task}.
     * <p>
     * The steps are chained through listeners:
     * <ol>
     *   <li>reject the request if the node already holds {@code maxProcesses} contexts, or if a context
     *       is already registered for the task id;</li>
     *   <li>fetch the trained model config and require its inference config to be an {@link NlpConfig};</li>
     *   <li>search for the vocabulary document and build the NLP processor from it;</li>
     *   <li>start the native process and load the model on the deployment executor;</li>
     *   <li>once loaded, start the result processor on the process executor and notify the caller.</li>
     * </ol>
     * Any failure after the context has been registered removes the context again and stops it
     * before the failure is passed to {@code finalListener}.
     *
     * @param task          the deployment task to start
     * @param finalListener receives {@code task} once the model is loaded, or the failure
     */
    public void startDeployment(TrainedModelDeploymentTask task, ActionListener<TrainedModelDeploymentTask> finalListener) {
        logger.info("[{}] Starting model deployment", task.getDeploymentId());

        if (processContextByAllocation.size() >= maxProcesses) {
            finalListener.onFailure(
                ExceptionsHelper.serverError(
                    "[{}] Could not start inference process as the node reached the max number [{}] of processes",
                    task.getDeploymentId(),
                    maxProcesses
                )
            );
            return;
        }

        ProcessContext processContext = new ProcessContext(task);
        if (addProcessContext(task.getId(), processContext) != null) {
            finalListener.onFailure(
                ExceptionsHelper.serverError("[{}] Could not create inference process as one already exists", task.getDeploymentId())
            );
            return;
        }

        // From here on the context is registered: every failure path must unregister and stop it.
        ActionListener<TrainedModelDeploymentTask> failedDeploymentListener = ActionListener.wrap(
            response -> finalListener.onResponse(response),
            failure -> {
                ProcessContext failedContext = processContextByAllocation.remove(task.getId());
                if (failedContext != null) {
                    failedContext.stopProcess();
                }
                finalListener.onFailure(failure);
            }
        );

        ActionListener<Boolean> modelLoadedListener = ActionListener.wrap(success -> {
            // Start consuming results from the process, then report the deployment as started.
            executorServiceForProcess.execute(() -> processContext.getResultProcessor().process(processContext.process.get()));
            finalListener.onResponse(task);
        }, failedDeploymentListener::onFailure);

        ActionListener<GetTrainedModelsAction.Response> getModelListener = ActionListener.wrap(getModelResponse -> {
            assert getModelResponse.getResources().results().size() == 1;
            TrainedModelConfig modelConfig = getModelResponse.getResources().results().get(0);
            processContext.modelInput.set(modelConfig.getInput());

            InferenceConfig inferenceConfig = modelConfig.getInferenceConfig();
            if (inferenceConfig instanceof NlpConfig) {
                NlpConfig nlpConfig = (NlpConfig) inferenceConfig;
                task.init(nlpConfig);

                SearchRequest searchRequest = vocabSearchRequest(nlpConfig.getVocabularyConfig(), modelConfig.getModelId());
                ClientHelper.executeAsyncWithOrigin(client, "ml", SearchAction.INSTANCE, searchRequest, ActionListener.wrap(searchVocabResponse -> {
                    if (searchVocabResponse.getHits().getHits().length == 0) {
                        failedDeploymentListener.onFailure(
                            new ResourceNotFoundException(
                                Messages.getMessage(
                                    "Could not find vocabulary document [{1}] for trained model [{0}]",
                                    modelConfig.getModelId(),
                                    VocabularyConfig.docId(modelConfig.getModelId())
                                )
                            )
                        );
                        return;
                    }

                    Vocabulary vocabulary = parseVocabularyDocLeniently(searchVocabResponse.getHits().getAt(0));
                    NlpTask nlpTask = new NlpTask(nlpConfig, vocabulary);
                    NlpTask.Processor processor = nlpTask.createProcessor();
                    processContext.nlpTaskProcessor.set(processor);
                    // startAndLoad asserts that it runs on the "ml_utility" executor, so fork to it here.
                    executorServiceForDeployment.execute(() -> processContext.startAndLoad(modelConfig.getLocation(), modelLoadedListener));
                }, failedDeploymentListener::onFailure));
            } else {
                failedDeploymentListener.onFailure(
                    new IllegalArgumentException(
                        Strings.format(
                            "[%s] must be a pytorch model; found inference config of kind [%s]",
                            modelConfig.getModelId(),
                            modelConfig.getInferenceConfig().getWriteableName()
                        )
                    )
                );
            }
        }, failedDeploymentListener::onFailure);

        ClientHelper.executeAsyncWithOrigin(
            client,
            "ml",
            GetTrainedModelsAction.INSTANCE,
            new GetTrainedModelsAction.Request(task.getParams().getModelId()),
            getModelListener
        );
    }

    /**
     * Builds the search that fetches the vocabulary document of a model.
     *
     * @param vocabularyConfig supplies the index to search
     * @param modelId          model whose vocabulary document id is queried
     * @return a request for at most one hit, matched by document id, with total-hit tracking disabled
     */
    private SearchRequest vocabSearchRequest(VocabularyConfig vocabularyConfig, String modelId) {
        String vocabIndex = vocabularyConfig.getIndex();
        QueryBuilder byVocabDocId = new IdsQueryBuilder().addIds(VocabularyConfig.docId(modelId));
        return client.prepareSearch(vocabIndex)
            .setQuery(byVocabDocId)
            .setSize(1)
            .setTrackTotalHits(false)
            .request();
    }

    /**
     * Parses the {@link Vocabulary} stored in the source of {@code hit}.
     * <p>
     * The source is read as JSON using this manager's registry and the logging deprecation handler.
     * Both the stream and the parser are closed before returning, parser first.
     *
     * @param hit the search hit holding the vocabulary document
     * @return the parsed vocabulary
     * @throws IOException if opening, reading or closing the stream or parser fails; the failure is
     *                     logged together with the hit id before being rethrown
     */
    Vocabulary parseVocabularyDocLeniently(SearchHit hit) throws IOException {
        try (
            StreamInput stream = hit.getSourceRef().streamInput();
            XContentParser parser = XContentFactory.xContent(XContentType.JSON)
                .createParser(
                    XContentParserConfiguration.EMPTY.withRegistry(xContentRegistry)
                        .withDeprecationHandler(LoggingDeprecationHandler.INSTANCE),
                    stream
                )
        ) {
            return Vocabulary.PARSER.apply(parser, null);
        } catch (IOException e) {
            logger.error(() -> "failed to parse trained model vocabulary [" + hit.getId() + "]", e);
            throw e;
        }
    }

    /**
     * Unregisters the process context of {@code task} and stops it.
     * Logs a warning and does nothing else when no context is registered for the task.
     *
     * @param task the deployment task to stop
     */
    public void stopDeployment(TrainedModelDeploymentTask task) {
        ProcessContext removed = processContextByAllocation.remove(task.getId());
        if (removed == null) {
            logger.warn("[{}] No process context to stop", task.getDeploymentId());
            return;
        }
        logger.info("[{}] Stopping deployment, reason [{}]", task.getDeploymentId(), task.stoppedReason().orElse("unknown"));
        removed.stopProcess();
    }

    /**
     * Queues an inference request for the deployment behind {@code task}.
     *
     * @param task             the deployment to run inference on
     * @param config           inference configuration for this request
     * @param input            the input to infer on
     * @param skipQueue        when {@code true} the request is queued with HIGH priority instead of NORMAL
     * @param timeout          timeout handed to the inference action
     * @param parentActionTask the cancellable task handed to the inference action
     * @param listener         receives the result; failed immediately when the task is stopped or has no
     *                         process context
     */
    public void infer(
        TrainedModelDeploymentTask task,
        InferenceConfig config,
        NlpInferenceInput input,
        boolean skipQueue,
        TimeValue timeout,
        CancellableTask parentActionTask,
        ActionListener<InferenceResults> listener
    ) {
        ProcessContext processContext = getProcessContext(task, failure -> listener.onFailure(failure));
        if (processContext == null) {
            // The listener has already been failed by getProcessContext.
            return;
        }

        long requestId = requestIdCounter.getAndIncrement();
        InferencePyTorchAction inferenceAction = new InferencePyTorchAction(
            task.getDeploymentId(),
            requestId,
            timeout,
            processContext,
            config,
            input,
            threadPool,
            parentActionTask,
            listener
        );

        PriorityProcessWorkerExecutorService.RequestPriority priority;
        if (skipQueue) {
            priority = PriorityProcessWorkerExecutorService.RequestPriority.HIGH;
        } else {
            priority = PriorityProcessWorkerExecutorService.RequestPriority.NORMAL;
        }
        executePyTorchAction(processContext, priority, inferenceAction);
    }

    /**
     * Queues a thread-settings control message for the deployment behind {@code task}, using the
     * HIGHEST request priority.
     *
     * @param task                 the deployment to update
     * @param numAllocationThreads the value passed to the control message action
     * @param timeout              timeout handed to the control message action
     * @param listener             receives the resulting thread settings; failed immediately when the task is
     *                             stopped or has no process context
     */
    public void updateNumAllocations(
        TrainedModelDeploymentTask task,
        int numAllocationThreads,
        TimeValue timeout,
        ActionListener<ThreadSettings> listener
    ) {
        ProcessContext processContext = getProcessContext(task, failure -> listener.onFailure(failure));
        if (processContext == null) {
            // The listener has already been failed by getProcessContext.
            return;
        }

        long requestId = requestIdCounter.getAndIncrement();
        ThreadSettingsControlMessagePytorchAction threadSettingsAction = new ThreadSettingsControlMessagePytorchAction(
            task.getDeploymentId(),
            requestId,
            numAllocationThreads,
            timeout,
            processContext,
            threadPool,
            listener
        );
        executePyTorchAction(processContext, PriorityProcessWorkerExecutorService.RequestPriority.HIGHEST, threadSettingsAction);
    }

    /**
     * Queues a clear-cache control message for the deployment behind {@code task}, using the
     * HIGHEST request priority.
     *
     * @param task     the deployment whose cache should be cleared
     * @param timeout  timeout handed to the control message action
     * @param listener receives {@link AcknowledgedResponse#TRUE} when the action responds, or the failure;
     *                 failed immediately when the task is stopped or has no process context
     */
    public void clearCache(TrainedModelDeploymentTask task, TimeValue timeout, ActionListener<AcknowledgedResponse> listener) {
        ProcessContext processContext = getProcessContext(task, failure -> listener.onFailure(failure));
        if (processContext == null) {
            // The listener has already been failed by getProcessContext.
            return;
        }

        long requestId = requestIdCounter.getAndIncrement();
        // The action reports a Boolean; callers only get an acknowledgement, whatever its value.
        ActionListener<Boolean> clearedListener = ActionListener.wrap(
            cleared -> listener.onResponse(AcknowledgedResponse.TRUE),
            failure -> listener.onFailure(failure)
        );
        ClearCacheControlMessagePytorchAction clearCacheAction = new ClearCacheControlMessagePytorchAction(
            task.getDeploymentId(),
            requestId,
            timeout,
            processContext,
            threadPool,
            clearedListener
        );
        executePyTorchAction(processContext, PriorityProcessWorkerExecutorService.RequestPriority.HIGHEST, clearCacheAction);
    }

    /**
     * Submits {@code action} to the priority worker of {@code processContext}.
     * Any exception thrown by the submission is passed to {@code action.onFailure}; a rejected
     * execution additionally increments the context's rejected-execution counter first.
     *
     * @param processContext the context whose worker should run the action
     * @param priority       the priority to queue the action with
     * @param action         the action to run
     */
    void executePyTorchAction(
        ProcessContext processContext,
        PriorityProcessWorkerExecutorService.RequestPriority priority,
        AbstractPyTorchAction<?> action
    ) {
        try {
            processContext.getPriorityProcessWorker().executeWithPriority(action, priority, action.getRequestId());
        } catch (Exception e) {
            if (e instanceof EsRejectedExecutionException) {
                // Count rejections separately so they show up in the deployment stats.
                processContext.getRejectedExecutionCount().incrementAndGet();
            }
            action.onFailure(e);
        }
    }

    /**
     * Looks up the process context of {@code task}.
     *
     * @param task          the deployment task
     * @param errorConsumer receives a conflict-status exception when the task is stopped or no context
     *                      is registered for it
     * @return the context, or {@code null} after {@code errorConsumer} has been called
     */
    private ProcessContext getProcessContext(TrainedModelDeploymentTask task, Consumer<Exception> errorConsumer) {
        if (task.isStopped()) {
            Exception stoppedFailure = ExceptionsHelper.conflictStatusException(
                "[{}] is stopping or stopped due to [{}]",
                task.getDeploymentId(),
                task.stoppedReason().orElse("")
            );
            errorConsumer.accept(stoppedFailure);
            return null;
        }

        ProcessContext registered = processContextByAllocation.get(task.getId());
        if (registered == null) {
            Exception missingFailure = ExceptionsHelper.conflictStatusException("[{}] process context missing", task.getDeploymentId());
            errorConsumer.accept(missingFailure);
        }
        return registered;
    }

    /**
     * Per-deployment state: the native process, the NLP processor, the result processor,
     * the model state streamer, the priority worker and the counters reported by getStats.
     */
    class ProcessContext {
        private final TrainedModelDeploymentTask task;
        // Set in startAndLoad once the factory has created the process.
        private final SetOnce<PyTorchProcess> process = new SetOnce();
        // Set by startDeployment after the vocabulary has been parsed; closed in stopProcess/onProcessCrash.
        private final SetOnce<NlpTask.Processor> nlpTaskProcessor = new SetOnce();
        // Set by startDeployment from the trained model config.
        private final SetOnce<TrainedModelInput> modelInput = new SetOnce();
        private final PyTorchResultProcessor resultProcessor;
        private final PyTorchStateStreamer stateStreamer;
        private final PriorityProcessWorkerExecutorService priorityProcessWorker;
        // Assigned in startAndLoad right after the process has been created; null until then.
        private volatile Instant startTime;
        // Both thread values are written by the callback given to the result processor in the constructor.
        private volatile Integer numThreadsPerAllocation;
        private volatile Integer numAllocations;
        // Incremented by executePyTorchAction when the worker rejects a request.
        private final AtomicInteger rejectedExecutionCount = new AtomicInteger();
        // Exposed through getTimeoutCount; not incremented within this class.
        private final AtomicInteger timeoutCount = new AtomicInteger();
        // Set to true by stopProcess and onProcessCrash; never reset.
        private volatile boolean isStopped;

        /**
         * Creates the context for {@code task}; the native process itself is created later by startAndLoad.
         *
         * @param task the deployment task this context belongs to
         * @throws NullPointerException if {@code task} is {@code null}
         */
        ProcessContext(TrainedModelDeploymentTask task) {
            this.task = Objects.requireNonNull(task);
            // The callback keeps the thread settings reported by the result processor for getStats.
            resultProcessor = new PyTorchResultProcessor(task.getDeploymentId(), reportedSettings -> {
                numThreadsPerAllocation = reportedSettings.numThreadsPerAllocation();
                numAllocations = reportedSettings.numAllocations();
            });
            stateStreamer = new PyTorchStateStreamer(client, executorServiceForProcess, xContentRegistry);
            priorityProcessWorker = new PriorityProcessWorkerExecutorService(
                threadPool.getThreadContext(),
                "inference process",
                task.getParams().getQueueCapacity()
            );
        }

        /** @return the result processor created for this context */
        PyTorchResultProcessor getResultProcessor() {
            return resultProcessor;
        }

        /**
         * Creates the native process and loads the model into it.
         * <p>
         * Must be called from an "ml_utility" thread (asserted). If the context has already been
         * stopped, the listener is failed without creating a process. If the context is stopped while
         * the model loads, the process is killed and the listener is failed. Otherwise the priority
         * worker is started before the listener is notified.
         *
         * @param modelLocation  where the model state is stored
         * @param loadedListener receives the load result, or the failure
         */
        synchronized void startAndLoad(TrainedModelLocation modelLocation, ActionListener<Boolean> loadedListener) {
            assert Thread.currentThread().getName().contains("ml_utility")
                : Strings.format("Must execute from [%s] but thread is [%s]", "ml_utility", Thread.currentThread().getName());

            if (isStopped) {
                logger.debug("[{}] model stopped before it is started", task.getDeploymentId());
                loadedListener.onFailure(new IllegalArgumentException("model stopped before it is started"));
                return;
            }

            logger.debug("[{}] start and load", task.getDeploymentId());
            try {
                // Process creation is inside the try on purpose: this method runs as a bare executor task,
                // so an exception escaping here would never reach the listener and the caller would hang.
                process.set(pyTorchProcessFactory.createProcess(task, executorServiceForProcess, this::onProcessCrash));
                startTime = Instant.now();
                logger.debug("[{}] process started", task.getDeploymentId());

                loadModel(modelLocation, ActionListener.wrap(success -> {
                    if (isStopped) {
                        logger.debug("[{}] model loaded but process is stopped", task.getDeploymentId());
                        killProcessIfPresent();
                        loadedListener.onFailure(new IllegalStateException("model loaded but process is stopped"));
                        return;
                    }

                    logger.debug("[{}] model loaded, starting priority process worker thread", task.getDeploymentId());
                    startPriorityProcessWorker();
                    loadedListener.onResponse(success);
                }, failure -> loadedListener.onFailure(failure)));
            } catch (Exception e) {
                loadedListener.onFailure(e);
            }
        }

        /** Submits the priority worker's {@code start} method to the process executor. */
        void startPriorityProcessWorker() {
            Runnable workerLoop = priorityProcessWorker::start;
            executorServiceForProcess.submit(workerLoop);
        }

        /**
         * Marks the context as stopped and tears it down: stops the result processor, cancels the
         * state streamer, shuts down the priority worker (or, if it is already shut down, notifies its
         * queued runnables), kills the process if one was created, and closes the NLP processor if set.
         */
        synchronized void stopProcess() {
            isStopped = true;
            resultProcessor.stop();
            stateStreamer.cancel();

            if (priorityProcessWorker.isShutdown() == false) {
                priorityProcessWorker.shutdown();
            } else {
                priorityProcessWorker.notifyQueueRunnables();
            }

            killProcessIfPresent();

            NlpTask.Processor processor = nlpTaskProcessor.get();
            if (processor != null) {
                processor.close();
            }
        }

        /**
         * Kills the native process if one has been set; does nothing otherwise.
         * An {@link IOException} from the kill is logged and not rethrown.
         */
        private void killProcessIfPresent() {
            PyTorchProcess started = process.get();
            if (started == null) {
                return;
            }
            try {
                started.kill(true);
            } catch (IOException e) {
                logger.error(() -> "[" + task.getDeploymentId() + "] Failed to kill process", e);
            }
        }

        /**
         * Crash handler passed to the process factory.
         * Unregisters this context, marks it stopped, stops the result processor, cancels the state
         * streamer, shuts the priority worker down with an error carrying {@code reason}, closes the NLP
         * processor if set, and finally marks the task as failed.
         *
         * @param reason the crash reason reported for the process
         */
        private void onProcessCrash(String reason) {
            logger.error("[{}] inference process crashed due to reason [{}]", task.getDeploymentId(), reason);
            processContextByAllocation.remove(task.getId());
            isStopped = true;
            resultProcessor.stop();
            stateStreamer.cancel();
            priorityProcessWorker.shutdownWithError(new IllegalStateException(reason));

            NlpTask.Processor processor = nlpTaskProcessor.get();
            if (processor != null) {
                processor.close();
            }

            String failureMessage = "inference process crashed due to reason [" + reason + "]";
            task.setFailed(failureMessage);
        }

        /**
         * Asks the process to load the model from {@code modelLocation}.
         * Only {@link IndexLocation} is supported. The outcome reported by the process is forwarded to
         * {@code listener} through a task submitted to the deployment executor.
         *
         * @param modelLocation where the model state is stored
         * @param listener      receives the load result; failed directly when the context is stopped or
         *                      the location type is unsupported
         */
        void loadModel(TrainedModelLocation modelLocation, ActionListener<Boolean> listener) {
            if (isStopped) {
                listener.onFailure(new IllegalArgumentException("Process has stopped, model loading canceled"));
                return;
            }
            if ((modelLocation instanceof IndexLocation) == false) {
                listener.onFailure(
                    new IllegalStateException("unsupported trained model location [" + modelLocation.getClass().getSimpleName() + "]")
                );
                return;
            }

            IndexLocation indexLocation = (IndexLocation) modelLocation;
            // Hand the outcome over to the deployment executor rather than running the listener inline.
            ActionListener<Boolean> forkingListener = ActionListener.wrap(
                loaded -> executorServiceForDeployment.submit(() -> listener.onResponse(loaded)),
                failure -> executorServiceForDeployment.submit(() -> listener.onFailure(failure))
            );
            process.get().loadModel(task.getParams().getModelId(), indexLocation.getIndexName(), stateStreamer, forkingListener);
        }

        /** @return the mutable timeout counter reported by getStats */
        AtomicInteger getTimeoutCount() {
            return timeoutCount;
        }

        /** @return the priority worker that requests for this context are queued on */
        PriorityProcessWorkerExecutorService getPriorityProcessWorker() {
            return priorityProcessWorker;
        }

        /** @return the mutable counter of requests rejected by the priority worker */
        AtomicInteger getRejectedExecutionCount() {
            return rejectedExecutionCount;
        }

        /** @return the holder of the model input; empty until startDeployment has fetched the model config */
        SetOnce<TrainedModelInput> getModelInput() {
            return modelInput;
        }

        /** @return the holder of the native process; empty until startAndLoad has created it */
        SetOnce<PyTorchProcess> getProcess() {
            return process;
        }

        /** @return the holder of the NLP processor; empty until startDeployment has parsed the vocabulary */
        SetOnce<NlpTask.Processor> getNlpTaskProcessor() {
            return nlpTaskProcessor;
        }
    }
}

