/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.ml.inference.deployment;

import java.io.IOException;
import java.io.InputStream;
import java.time.Instant;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.Consumer;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.logging.log4j.message.Message;
import org.apache.logging.log4j.message.ParameterizedMessage;
import org.apache.lucene.util.SetOnce;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.ResourceNotFoundException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.ActionRequest;
import org.elasticsearch.action.ActionType;
import org.elasticsearch.action.search.SearchAction;
import org.elasticsearch.action.search.SearchRequest;
import org.elasticsearch.client.internal.Client;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.util.concurrent.AbstractRunnable;
import org.elasticsearch.common.util.concurrent.EsRejectedExecutionException;
import org.elasticsearch.common.xcontent.LoggingDeprecationHandler;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.index.query.IdsQueryBuilder;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.threadpool.Scheduler;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xcontent.DeprecationHandler;
import org.elasticsearch.xcontent.NamedXContentRegistry;
import org.elasticsearch.xcontent.XContentFactory;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.core.ClientHelper;
import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput;
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.IndexLocation;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModelLocation;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.VocabularyConfig;
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.ml.inference.deployment.ModelStats;
import org.elasticsearch.xpack.ml.inference.deployment.TrainedModelDeploymentTask;
import org.elasticsearch.xpack.ml.inference.nlp.NlpTask;
import org.elasticsearch.xpack.ml.inference.nlp.Vocabulary;
import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.TokenizationResult;
import org.elasticsearch.xpack.ml.inference.pytorch.process.NativePyTorchProcess;
import org.elasticsearch.xpack.ml.inference.pytorch.process.PyTorchProcessFactory;
import org.elasticsearch.xpack.ml.inference.pytorch.process.PyTorchResultProcessor;
import org.elasticsearch.xpack.ml.inference.pytorch.process.PyTorchStateStreamer;
import org.elasticsearch.xpack.ml.inference.pytorch.results.PyTorchInferenceResult;
import org.elasticsearch.xpack.ml.job.process.ProcessWorkerExecutorService;

public class DeploymentManager {
    private static final Logger logger = LogManager.getLogger(DeploymentManager.class);
    private static final AtomicLong requestIdCounter = new AtomicLong(1L);
    private final Client client;
    private final NamedXContentRegistry xContentRegistry;
    private final PyTorchProcessFactory pyTorchProcessFactory;
    private final ExecutorService executorServiceForDeployment;
    private final ExecutorService executorServiceForProcess;
    private final ThreadPool threadPool;
    private final ConcurrentMap<Long, ProcessContext> processContextByAllocation = new ConcurrentHashMap<Long, ProcessContext>();

    public DeploymentManager(Client client, NamedXContentRegistry xContentRegistry, ThreadPool threadPool, PyTorchProcessFactory pyTorchProcessFactory) {
        this.client = Objects.requireNonNull(client);
        this.xContentRegistry = Objects.requireNonNull(xContentRegistry);
        this.pyTorchProcessFactory = Objects.requireNonNull(pyTorchProcessFactory);
        this.threadPool = Objects.requireNonNull(threadPool);
        this.executorServiceForDeployment = threadPool.executor("ml_utility");
        this.executorServiceForProcess = threadPool.executor("ml_job_comms");
    }

    public void startDeployment(TrainedModelDeploymentTask task, ActionListener<TrainedModelDeploymentTask> listener) {
        this.doStartDeployment(task, listener);
    }

    public Optional<ModelStats> getStats(TrainedModelDeploymentTask task) {
        return Optional.ofNullable((ProcessContext)this.processContextByAllocation.get(task.getId())).map(processContext -> {
            PyTorchResultProcessor.ResultStats stats = processContext.getResultProcessor().getResultStats();
            return new ModelStats(processContext.startTime, stats.timingStats(), stats.lastUsed(), processContext.executorService.queueSize() + stats.numberOfPendingResults(), stats.errorCount(), processContext.rejectedExecutionCount.intValue(), processContext.timeoutCount.intValue(), processContext.inferenceThreads, processContext.modelThreads);
        });
    }

    ProcessContext addProcessContext(Long id, ProcessContext processContext) {
        return this.processContextByAllocation.putIfAbsent(id, processContext);
    }

    private void doStartDeployment(TrainedModelDeploymentTask task, ActionListener<TrainedModelDeploymentTask> finalListener) {
        logger.info("[{}] Starting model deployment", (Object)task.getModelId());
        ProcessContext processContext = new ProcessContext(task, this.executorServiceForProcess);
        if (this.addProcessContext(task.getId(), processContext) != null) {
            finalListener.onFailure((Exception)((Object)ExceptionsHelper.serverError((String)"[{}] Could not create inference process as one already exists", (Object[])new Object[]{task.getModelId()})));
            return;
        }
        ActionListener listener = ActionListener.wrap(arg_0 -> finalListener.onResponse(arg_0), failure -> {
            this.processContextByAllocation.remove(task.getId());
            finalListener.onFailure(failure);
        });
        ActionListener modelLoadedListener = ActionListener.wrap(success -> {
            this.executorServiceForProcess.execute(() -> processContext.getResultProcessor().process((NativePyTorchProcess)processContext.process.get()));
            listener.onResponse((Object)task);
        }, arg_0 -> ((ActionListener)listener).onFailure(arg_0));
        ActionListener getModelListener = ActionListener.wrap(getModelResponse -> {
            assert (getModelResponse.getResources().results().size() == 1);
            TrainedModelConfig modelConfig = (TrainedModelConfig)getModelResponse.getResources().results().get(0);
            processContext.modelInput.set((Object)modelConfig.getInput());
            assert (modelConfig.getInferenceConfig() instanceof NlpConfig);
            NlpConfig nlpConfig = (NlpConfig)modelConfig.getInferenceConfig();
            task.init((InferenceConfig)nlpConfig);
            SearchRequest searchRequest = this.vocabSearchRequest(nlpConfig.getVocabularyConfig(), modelConfig.getModelId());
            ClientHelper.executeAsyncWithOrigin((Client)this.client, (String)"ml", (ActionType)SearchAction.INSTANCE, (ActionRequest)searchRequest, (ActionListener)ActionListener.wrap(searchVocabResponse -> {
                if (searchVocabResponse.getHits().getHits().length == 0) {
                    listener.onFailure((Exception)new ResourceNotFoundException(Messages.getMessage((String)"Could not find vocabulary document [{1}] for trained model [{0}]", (Object[])new Object[]{task.getModelId(), VocabularyConfig.docId((String)modelConfig.getModelId())}), new Object[0]));
                    return;
                }
                Vocabulary vocabulary = this.parseVocabularyDocLeniently(searchVocabResponse.getHits().getAt(0));
                NlpTask nlpTask = new NlpTask(nlpConfig, vocabulary);
                NlpTask.Processor processor = nlpTask.createProcessor();
                processContext.nlpTaskProcessor.set((Object)processor);
                this.executorServiceForDeployment.execute(() -> this.startAndLoad(processContext, modelConfig.getLocation(), (ActionListener<Boolean>)modelLoadedListener));
            }, arg_0 -> ((ActionListener)listener).onFailure(arg_0)));
        }, arg_0 -> ((ActionListener)listener).onFailure(arg_0));
        ClientHelper.executeAsyncWithOrigin((Client)this.client, (String)"ml", (ActionType)GetTrainedModelsAction.INSTANCE, (ActionRequest)new GetTrainedModelsAction.Request(task.getModelId()), (ActionListener)getModelListener);
    }

    private SearchRequest vocabSearchRequest(VocabularyConfig vocabularyConfig, String modelId) {
        return (SearchRequest)this.client.prepareSearch(new String[]{vocabularyConfig.getIndex()}).setQuery((QueryBuilder)new IdsQueryBuilder().addIds(new String[]{VocabularyConfig.docId((String)modelId)})).setSize(1).setTrackTotalHits(false).request();
    }

    /*
     * Enabled aggressive exception aggregation
     */
    Vocabulary parseVocabularyDocLeniently(SearchHit hit) throws IOException {
        try (StreamInput stream = hit.getSourceRef().streamInput();){
            Vocabulary vocabulary;
            block14: {
                XContentParser parser = XContentFactory.xContent((XContentType)XContentType.JSON).createParser(this.xContentRegistry, (DeprecationHandler)LoggingDeprecationHandler.INSTANCE, (InputStream)stream);
                try {
                    vocabulary = (Vocabulary)Vocabulary.createParser(true).apply(parser, null);
                    if (parser == null) break block14;
                }
                catch (Throwable throwable) {
                    if (parser != null) {
                        try {
                            parser.close();
                        }
                        catch (Throwable throwable2) {
                            throwable.addSuppressed(throwable2);
                        }
                    }
                    throw throwable;
                }
                parser.close();
            }
            return vocabulary;
        }
        catch (IOException e) {
            logger.error((Message)new ParameterizedMessage("failed to parse trained model vocabulary [{}]", (Object)hit.getId()), (Throwable)e);
            throw e;
        }
    }

    private void startAndLoad(ProcessContext processContext, TrainedModelLocation modelLocation, ActionListener<Boolean> loadedListener) {
        try {
            processContext.startProcess();
            processContext.loadModel(modelLocation, loadedListener);
        }
        catch (Exception e) {
            loadedListener.onFailure(e);
        }
    }

    /*
     * WARNING - Removed try catching itself - possible behaviour change.
     */
    public void stopDeployment(TrainedModelDeploymentTask task) {
        ProcessContext processContext;
        ConcurrentMap<Long, ProcessContext> concurrentMap = this.processContextByAllocation;
        synchronized (concurrentMap) {
            processContext = (ProcessContext)this.processContextByAllocation.get(task.getId());
        }
        if (processContext != null) {
            logger.info("[{}] Stopping deployment", (Object)task.getModelId());
            processContext.stopProcess();
        } else {
            logger.warn("[{}] No process context to stop", (Object)task.getModelId());
        }
    }

    public void infer(TrainedModelDeploymentTask task, InferenceConfig config, Map<String, Object> doc, TimeValue timeout, ActionListener<InferenceResults> listener) {
        if (task.isStopped()) {
            listener.onFailure((Exception)((Object)ExceptionsHelper.conflictStatusException((String)"[{}] is stopping or stopped due to [{}]", (Object[])new Object[]{task.getModelId(), task.stoppedReason().orElse("")})));
            return;
        }
        ProcessContext processContext = (ProcessContext)this.processContextByAllocation.get(task.getId());
        if (processContext == null) {
            listener.onFailure((Exception)((Object)ExceptionsHelper.conflictStatusException((String)"[{}] process context missing", (Object[])new Object[]{task.getModelId()})));
            return;
        }
        long requestId = requestIdCounter.getAndIncrement();
        InferenceAction inferenceAction = new InferenceAction(task.getModelId(), requestId, timeout, processContext, config, doc, this.threadPool, listener);
        try {
            processContext.getExecutorService().execute((Runnable)((Object)inferenceAction));
        }
        catch (EsRejectedExecutionException e) {
            processContext.getRejectedExecutionCount().incrementAndGet();
            inferenceAction.onFailure((Exception)((Object)e));
        }
        catch (Exception e) {
            inferenceAction.onFailure(e);
        }
    }

    class ProcessContext {
        private final TrainedModelDeploymentTask task;
        private final SetOnce<NativePyTorchProcess> process = new SetOnce();
        private final SetOnce<NlpTask.Processor> nlpTaskProcessor = new SetOnce();
        private final SetOnce<TrainedModelInput> modelInput = new SetOnce();
        private final PyTorchResultProcessor resultProcessor;
        private final PyTorchStateStreamer stateStreamer;
        private final ProcessWorkerExecutorService executorService;
        private volatile Instant startTime;
        private volatile Integer inferenceThreads;
        private volatile Integer modelThreads;
        private AtomicInteger rejectedExecutionCount = new AtomicInteger();
        private AtomicInteger timeoutCount = new AtomicInteger();

        ProcessContext(TrainedModelDeploymentTask task, ExecutorService executorService) {
            this.task = Objects.requireNonNull(task);
            this.resultProcessor = new PyTorchResultProcessor(task.getModelId(), threadSettings -> {
                this.inferenceThreads = threadSettings.inferenceThreads();
                this.modelThreads = threadSettings.modelThreads();
            });
            this.stateStreamer = new PyTorchStateStreamer(DeploymentManager.this.client, executorService, DeploymentManager.this.xContentRegistry);
            this.executorService = new ProcessWorkerExecutorService(DeploymentManager.this.threadPool.getThreadContext(), "inference process", task.getParams().getQueueCapacity());
        }

        PyTorchResultProcessor getResultProcessor() {
            return this.resultProcessor;
        }

        synchronized void startProcess() {
            this.process.set((Object)DeploymentManager.this.pyTorchProcessFactory.createProcess(this.task, DeploymentManager.this.executorServiceForProcess, this.onProcessCrash()));
            this.startTime = Instant.now();
            DeploymentManager.this.executorServiceForProcess.submit(this.executorService::start);
        }

        synchronized void stopProcess() {
            this.resultProcessor.stop();
            this.executorService.shutdown();
            if (this.process.get() == null) {
                return;
            }
            try {
                this.stateStreamer.cancel();
                ((NativePyTorchProcess)this.process.get()).kill(true);
                DeploymentManager.this.processContextByAllocation.remove(this.task.getId());
            }
            catch (IOException e) {
                logger.error((Message)new ParameterizedMessage("[{}] Failed to kill process", (Object)this.task.getModelId()), (Throwable)e);
            }
        }

        private Consumer<String> onProcessCrash() {
            return reason -> {
                logger.error("[{}] inference process crashed due to reason [{}]", (Object)this.task.getModelId(), reason);
                this.resultProcessor.stop();
                this.executorService.shutdownWithError(new IllegalStateException((String)reason));
                DeploymentManager.this.processContextByAllocation.remove(this.task.getId());
                this.task.setFailed("inference process crashed due to reason [" + reason + "]");
            };
        }

        void loadModel(TrainedModelLocation modelLocation, ActionListener<Boolean> listener) {
            if (!(modelLocation instanceof IndexLocation)) {
                throw new IllegalStateException("unsupported trained model location [" + modelLocation.getClass().getSimpleName() + "]");
            }
            ((NativePyTorchProcess)this.process.get()).loadModel(this.task.getModelId(), ((IndexLocation)modelLocation).getIndexName(), this.stateStreamer, listener);
        }

        AtomicInteger getTimeoutCount() {
            return this.timeoutCount;
        }

        ExecutorService getExecutorService() {
            return this.executorService;
        }

        AtomicInteger getRejectedExecutionCount() {
            return this.rejectedExecutionCount;
        }
    }

    static class InferenceAction
    extends AbstractRunnable
    implements ActionListener<InferenceResults> {
        private final String modelId;
        private final long requestId;
        private final TimeValue timeout;
        private final Scheduler.Cancellable timeoutHandler;
        private final ProcessContext processContext;
        private final InferenceConfig config;
        private final Map<String, Object> doc;
        private final ActionListener<InferenceResults> listener;
        private final AtomicBoolean notified = new AtomicBoolean();

        InferenceAction(String modelId, long requestId, TimeValue timeout, ProcessContext processContext, InferenceConfig config, Map<String, Object> doc, ThreadPool threadPool, ActionListener<InferenceResults> listener) {
            this.modelId = modelId;
            this.requestId = requestId;
            this.timeout = timeout;
            this.processContext = processContext;
            this.config = config;
            this.doc = doc;
            this.listener = listener;
            this.timeoutHandler = threadPool.schedule(this::onTimeout, (TimeValue)ExceptionsHelper.requireNonNull((Object)timeout, (String)"timeout"), "ml_utility");
        }

        void onTimeout() {
            if (this.notified.compareAndSet(false, true)) {
                this.processContext.getTimeoutCount().incrementAndGet();
                this.processContext.getResultProcessor().ignoreResponseWithoutNotifying(String.valueOf(this.requestId));
                this.listener.onFailure((Exception)((Object)new ElasticsearchStatusException("timeout [{}] waiting for inference result", RestStatus.REQUEST_TIMEOUT, new Object[]{this.timeout})));
                return;
            }
            logger.debug("[{}] request [{}] received timeout after [{}] but listener already alerted", (Object)this.modelId, (Object)this.requestId, (Object)this.timeout);
        }

        public void onResponse(InferenceResults inferenceResults) {
            this.onSuccess(inferenceResults);
        }

        void onSuccess(InferenceResults inferenceResults) {
            this.timeoutHandler.cancel();
            if (this.notified.compareAndSet(false, true)) {
                this.listener.onResponse((Object)inferenceResults);
                return;
            }
            logger.debug("[{}] request [{}] received inference response but listener already notified", (Object)this.modelId, (Object)this.requestId);
        }

        public void onFailure(Exception e) {
            this.timeoutHandler.cancel();
            if (this.notified.compareAndSet(false, true)) {
                this.processContext.getResultProcessor().ignoreResponseWithoutNotifying(String.valueOf(this.requestId));
                this.listener.onFailure(e);
                return;
            }
            logger.debug(() -> new ParameterizedMessage("[{}] request [{}] received failure but listener already notified", (Object)this.modelId, (Object)this.requestId), (Throwable)e);
        }

        protected void doRun() throws Exception {
            if (this.notified.get()) {
                logger.debug(() -> new ParameterizedMessage("[{}] skipping inference on request [{}] as it has timed out", (Object)this.modelId, (Object)this.requestId));
                return;
            }
            String requestIdStr = String.valueOf(this.requestId);
            try {
                List<String> text = Collections.singletonList(NlpTask.extractInput((TrainedModelInput)this.processContext.modelInput.get(), this.doc));
                NlpTask.Processor processor = (NlpTask.Processor)this.processContext.nlpTaskProcessor.get();
                processor.validateInputs(text);
                assert (this.config instanceof NlpConfig);
                NlpConfig nlpConfig = (NlpConfig)this.config;
                NlpTask.Request request = processor.getRequestBuilder(nlpConfig).buildRequest(text, requestIdStr, nlpConfig.getTokenization().getTruncate());
                logger.debug(() -> "Inference Request " + request.processInput.utf8ToString());
                if (request.tokenization.anyTruncated()) {
                    logger.debug("[{}] [{}] input truncated", (Object)this.modelId, (Object)this.requestId);
                }
                this.processContext.getResultProcessor().registerRequest(requestIdStr, (ActionListener<PyTorchInferenceResult>)ActionListener.wrap(inferenceResult -> this.processResult((PyTorchInferenceResult)inferenceResult, this.processContext, request.tokenization, processor.getResultProcessor((NlpConfig)this.config), this), this::onFailure));
                ((NativePyTorchProcess)this.processContext.process.get()).writeInferenceRequest(request.processInput);
            }
            catch (IOException e) {
                logger.error((Message)new ParameterizedMessage("[{}] error writing to inference process", (Object)this.processContext.task.getModelId()), (Throwable)e);
                this.onFailure((Exception)((Object)ExceptionsHelper.serverError((String)"Error writing to inference process", (Throwable)e)));
            }
            catch (Exception e) {
                this.onFailure(e);
            }
        }

        private void processResult(PyTorchInferenceResult inferenceResult, ProcessContext context, TokenizationResult tokenization, NlpTask.ResultProcessor inferenceResultsProcessor, ActionListener<InferenceResults> resultsListener) {
            if (inferenceResult.isError()) {
                resultsListener.onFailure((Exception)((Object)new ElasticsearchStatusException("Error in inference process: [" + inferenceResult.getError() + "]", RestStatus.INTERNAL_SERVER_ERROR, new Object[0])));
                return;
            }
            logger.debug(() -> new ParameterizedMessage("[{}] retrieved result for request [{}]", (Object)context.task.getModelId(), (Object)this.requestId));
            if (this.notified.get()) {
                logger.debug(() -> new ParameterizedMessage("[{}] skipping result processing for request [{}] as the request has timed out", (Object)context.task.getModelId(), (Object)this.requestId));
                return;
            }
            InferenceResults results = inferenceResultsProcessor.processResult(tokenization, inferenceResult);
            logger.debug(() -> new ParameterizedMessage("[{}] processed result for request [{}]", (Object)context.task.getModelId(), (Object)this.requestId));
            resultsListener.onResponse((Object)results);
        }
    }
}

