/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.ml.inference.assignment;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Deque;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentLinkedDeque;
import java.util.concurrent.Executor;
import java.util.concurrent.ExecutorService;
import java.util.function.Consumer;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.ResourceNotFoundException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.search.SearchPhaseExecutionException;
import org.elasticsearch.action.support.SubscribableListener;
import org.elasticsearch.action.support.master.AcknowledgedResponse;
import org.elasticsearch.cluster.ClusterChangedEvent;
import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.cluster.ClusterStateListener;
import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.component.LifecycleListener;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.common.util.set.Sets;
import org.elasticsearch.core.Strings;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.inference.InferenceResults;
import org.elasticsearch.license.XPackLicenseState;
import org.elasticsearch.tasks.CancellableTask;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.tasks.TaskAwareRequest;
import org.elasticsearch.tasks.TaskId;
import org.elasticsearch.tasks.TaskManager;
import org.elasticsearch.threadpool.Scheduler;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.core.ml.MlMetadata;
import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction;
import org.elasticsearch.xpack.core.ml.action.UpdateTrainedModelAssignmentRoutingInfoAction;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelPrefixStrings;
import org.elasticsearch.xpack.core.ml.inference.assignment.AssignmentState;
import org.elasticsearch.xpack.core.ml.inference.assignment.RoutingInfo;
import org.elasticsearch.xpack.core.ml.inference.assignment.RoutingInfoUpdate;
import org.elasticsearch.xpack.core.ml.inference.assignment.RoutingState;
import org.elasticsearch.xpack.core.ml.inference.assignment.RoutingStateAndReason;
import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignment;
import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignmentMetadata;
import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.ml.MachineLearning;
import org.elasticsearch.xpack.ml.inference.assignment.TrainedModelAssignmentService;
import org.elasticsearch.xpack.ml.inference.deployment.DeploymentManager;
import org.elasticsearch.xpack.ml.inference.deployment.ModelStats;
import org.elasticsearch.xpack.ml.inference.deployment.NlpInferenceInput;
import org.elasticsearch.xpack.ml.inference.deployment.TrainedModelDeploymentTask;
import org.elasticsearch.xpack.ml.inference.pytorch.results.ThreadSettings;
import org.elasticsearch.xpack.ml.task.AbstractJobPersistentTasksExecutor;

/**
 * Node-local service that loads, tracks, resizes and stops the trained model deployments routed to this node.
 * <p>
 * On every cluster state change that touches metadata, the routing table of each {@link TrainedModelAssignment}
 * is inspected for the local node: deployments in STARTING/STARTED routes are queued for loading, allocation
 * counts are updated, and deployments that are no longer referenced (or whose node is shutting down) are stopped.
 * A task scheduled on the {@code ml_utility} executor drains the loading queue one deployment at a time.
 * Routing state transitions are reported through {@link TrainedModelAssignmentService#updateModelAssignmentState}.
 * <p>
 * NOTE: once {@link #stop()} has been called, {@code updateStoredState} and {@code stopDeploymentHelper} return
 * without completing the listeners passed to them.
 */
public class TrainedModelAssignmentNodeService implements ClusterStateListener {
    private static final String NODE_NO_LONGER_REFERENCED = "node no longer referenced in model routing table";
    private static final String ASSIGNMENT_NO_LONGER_EXISTS = "deployment assignment no longer exists";
    private static final String NODE_SHUTTING_DOWN = "node is shutting down";
    private static final String UNKNOWN_REASON = "_unknown_";
    private static final String UTILITY_THREAD_POOL = "ml_utility";
    private static final TimeValue MODEL_LOADING_CHECK_INTERVAL = TimeValue.timeValueSeconds(1L);
    private static final TimeValue CONTROL_MESSAGE_TIMEOUT = TimeValue.timeValueSeconds(60L);
    private static final Logger logger = LogManager.getLogger(TrainedModelAssignmentNodeService.class);

    private final TrainedModelAssignmentService trainedModelAssignmentService;
    private final DeploymentManager deploymentManager;
    private final TaskManager taskManager;
    /** Deployment id -> task for every deployment this node currently tracks (queued, loading or loaded). */
    private final Map<String, TrainedModelDeploymentTask> deploymentIdToTask;
    private final ThreadPool threadPool;
    /** Tasks waiting to be loaded; drained by {@link #loadQueuedModels(ActionListener)}. */
    private final Deque<TrainedModelDeploymentTask> loadingModels;
    private final XPackLicenseState licenseState;
    private final IndexNameExpressionResolver expressionResolver;
    private volatile Scheduler.Cancellable scheduledFuture;
    /** Last cluster state seen by {@link #clusterChanged(ClusterChangedEvent)}; may be null before the first event. */
    private volatile ClusterState latestState;
    /** When true, scheduling, routing-state updates and deployment stops become no-ops. */
    private volatile boolean stopped;
    private volatile String nodeId;

    /**
     * Creates the service and hooks it into the cluster service lifecycle. The local node id is read when the
     * cluster service starts, after which the loading loop is started; it is stopped before the cluster service stops.
     */
    public TrainedModelAssignmentNodeService(
        TrainedModelAssignmentService trainedModelAssignmentService,
        final ClusterService clusterService,
        DeploymentManager deploymentManager,
        IndexNameExpressionResolver expressionResolver,
        TaskManager taskManager,
        ThreadPool threadPool,
        XPackLicenseState licenseState
    ) {
        this.trainedModelAssignmentService = trainedModelAssignmentService;
        this.deploymentManager = deploymentManager;
        this.taskManager = taskManager;
        this.deploymentIdToTask = new ConcurrentHashMap<>();
        this.loadingModels = new ConcurrentLinkedDeque<>();
        this.threadPool = threadPool;
        this.licenseState = licenseState;
        clusterService.addLifecycleListener(new LifecycleListener() {
            @Override
            public void afterStart() {
                nodeId = clusterService.localNode().getId();
                start();
            }

            @Override
            public void beforeStop() {
                stop();
            }
        });
        this.expressionResolver = expressionResolver;
    }

    /**
     * Variant that takes the local node id up front instead of reading it from the cluster service on start.
     * Package-private; presumably intended for tests.
     */
    TrainedModelAssignmentNodeService(
        TrainedModelAssignmentService trainedModelAssignmentService,
        ClusterService clusterService,
        DeploymentManager deploymentManager,
        IndexNameExpressionResolver expressionResolver,
        TaskManager taskManager,
        ThreadPool threadPool,
        String nodeId,
        XPackLicenseState licenseState
    ) {
        this.trainedModelAssignmentService = trainedModelAssignmentService;
        this.deploymentManager = deploymentManager;
        this.taskManager = taskManager;
        this.deploymentIdToTask = new ConcurrentHashMap<>();
        this.loadingModels = new ConcurrentLinkedDeque<>();
        this.threadPool = threadPool;
        this.nodeId = nodeId;
        this.licenseState = licenseState;
        clusterService.addLifecycleListener(new LifecycleListener() {
            @Override
            public void afterStart() {
                start();
            }

            @Override
            public void beforeStop() {
                stop();
            }
        });
        this.expressionResolver = expressionResolver;
    }

    /** Clears the stopped flag and schedules the first run of the model loading loop. */
    void start() {
        stopped = false;
        schedule(false);
    }

    /**
     * Schedules the next run of {@link #loadQueuedModels(ActionListener)} on the utility executor.
     *
     * @param runImmediately submit the run right away instead of after {@link #MODEL_LOADING_CHECK_INTERVAL}
     */
    private void schedule(boolean runImmediately) {
        if (stopped) {
            // Never (re)schedule once the service has been stopped.
            return;
        }
        // The loading loop reports whether the next run should happen immediately; on failure fall back to the interval.
        ActionListener<Boolean> rescheduleListener = ActionListener.wrap(this::schedule, e -> schedule(false));
        Runnable loadQueuedModels = () -> loadQueuedModels(rescheduleListener);
        ExecutorService executor = threadPool.executor(UTILITY_THREAD_POOL);
        if (runImmediately) {
            executor.execute(loadQueuedModels);
        } else {
            scheduledFuture = threadPool.schedule(loadQueuedModels, MODEL_LOADING_CHECK_INTERVAL, executor);
        }
    }

    /** Sets the stopped flag and cancels the pending scheduled loading run, if any. */
    void stop() {
        stopped = true;
        Scheduler.Cancellable cancellable = scheduledFuture;
        if (cancellable != null) {
            cancellable.cancel();
        }
    }

    /**
     * Loads at most one queued deployment.
     *
     * @param rescheduleImmediately completed with {@code true} when the next run should start immediately and
     *                              {@code false} when it should wait for the regular interval
     */
    void loadQueuedModels(ActionListener<Boolean> rescheduleImmediately) {
        if (stopped) {
            rescheduleImmediately.onResponse(false);
            return;
        }
        ClusterState state = latestState;
        if (state != null) {
            // Loading reads from the inference indices, so wait until their primary shards are active.
            List<String> unassignedIndices = AbstractJobPersistentTasksExecutor.verifyIndicesPrimaryShardsAreActive(
                state,
                expressionResolver,
                true,
                ".ml-inference-*",
                InferenceIndexConstants.nativeDefinitionStore()
            );
            if (unassignedIndices.isEmpty() == false) {
                logger.trace("not loading models as indices {} primary shards are unassigned", unassignedIndices);
                rescheduleImmediately.onResponse(false);
                return;
            }
        }
        TrainedModelDeploymentTask loadingTask = loadingModels.poll();
        if (loadingTask == null) {
            rescheduleImmediately.onResponse(false);
            return;
        }
        loadModel(loadingTask, ActionListener.wrap(retry -> {
            if (Boolean.TRUE.equals(retry)) {
                // Re-queue at the tail. Only run again immediately if some other task is now at the head;
                // if this task is the only one queued, wait for the regular interval before retrying it.
                loadingModels.offer(loadingTask);
                rescheduleImmediately.onResponse(loadingModels.peek() != loadingTask);
            } else {
                rescheduleImmediately.onResponse(loadingModels.isEmpty() == false);
            }
        }, e -> rescheduleImmediately.onResponse(loadingModels.isEmpty() == false)));
    }

    /**
     * Starts the deployment for {@code loadingTask} and reports the routing outcome.
     *
     * @param retryListener completed with {@code true} when loading should be retried later (currently only on a
     *                      {@link SearchPhaseExecutionException} cause), otherwise with {@code false}
     */
    void loadModel(TrainedModelDeploymentTask loadingTask, ActionListener<Boolean> retryListener) {
        if (loadingTask.isStopped()) {
            if (logger.isTraceEnabled()) {
                logger.trace(
                    "[{}] attempted to load stopped task with reason [{}]",
                    loadingTask.getDeploymentId(),
                    loadingTask.stoppedReason().orElse(UNKNOWN_REASON)
                );
            }
            retryListener.onResponse(false);
            return;
        }
        Executor utilityExecutor = threadPool.executor(UTILITY_THREAD_POOL);
        SubscribableListener.<TrainedModelDeploymentTask>newForked(l -> deploymentManager.startDeployment(loadingTask, l))
            .andThen(utilityExecutor, threadPool.getThreadContext(), this::handleLoadSuccess)
            .addListener(
                retryListener.delegateResponse((retryL, ex) -> handleStartDeploymentFailure(loadingTask, retryL, ex)),
                utilityExecutor,
                threadPool.getThreadContext()
            );
    }

    /** Classifies a start-deployment failure: missing model and other errors fail the route, search failures are retried. */
    private void handleStartDeploymentFailure(TrainedModelDeploymentTask loadingTask, ActionListener<Boolean> retryListener, Exception ex) {
        String deploymentId = loadingTask.getDeploymentId();
        logger.warn(() -> "[" + deploymentId + "] Start deployment failed", ex);
        Throwable cause = ExceptionsHelper.unwrapCause(ex);
        if (cause instanceof ResourceNotFoundException) {
            String modelId = loadingTask.getParams().getModelId();
            logger.debug(() -> "[" + deploymentId + "] Start deployment failed as model [" + modelId + "] was not found", ex);
            handleLoadFailure(loadingTask, ExceptionsHelper.missingTrainedModel(modelId, ex), retryListener);
        } else if (cause instanceof SearchPhaseExecutionException) {
            logger.debug(() -> "[" + deploymentId + "] Start deployment failed, will retry", ex);
            retryListener.onResponse(true);
        } else {
            handleLoadFailure(loadingTask, ex, retryListener);
        }
    }

    /** Marks the route STOPPING, stops the deployment via {@code stopAfterCompletingPendingWork}, then marks it STOPPED. */
    public void gracefullyStopDeploymentAndNotify(TrainedModelDeploymentTask task, String reason, ActionListener<AcknowledgedResponse> listener) {
        logger.debug(() -> Strings.format("[%s] Gracefully stopping deployment due to reason %s", task.getDeploymentId(), reason));
        stopAndNotifyHelper(task, reason, listener, deploymentManager::stopAfterCompletingPendingWork);
    }

    /** Marks the route STOPPING, stops the deployment via {@code stopDeployment}, then marks it STOPPED. */
    public void stopDeploymentAndNotify(TrainedModelDeploymentTask task, String reason, ActionListener<AcknowledgedResponse> listener) {
        logger.debug(() -> Strings.format("[%s] Forcefully stopping deployment due to reason %s", task.getDeploymentId(), reason));
        stopAndNotifyHelper(task, reason, listener, deploymentManager::stopDeployment);
    }

    /**
     * Shared stop flow: forget the task, record STOPPING, run {@code stopDeploymentFunc}, then record STOPPED and
     * complete {@code listener}. A failure to record STOPPING is logged and the stop still proceeds.
     */
    private void stopAndNotifyHelper(
        TrainedModelDeploymentTask task,
        String reason,
        ActionListener<AcknowledgedResponse> listener,
        Consumer<TrainedModelDeploymentTask> stopDeploymentFunc
    ) {
        String deploymentId = task.getDeploymentId();
        deploymentIdToTask.remove(deploymentId);
        ActionListener<Void> notifyDeploymentOfStopped = updateRoutingStateToStoppedListener(deploymentId, reason, listener);
        updateStoredState(
            deploymentId,
            RoutingInfoUpdate.updateStateAndReason(new RoutingStateAndReason(RoutingState.STOPPING, reason)),
            ActionListener.wrap(success -> stopDeploymentHelper(task, reason, stopDeploymentFunc, notifyDeploymentOfStopped), e -> {
                if (ExceptionsHelper.unwrapCause(e) instanceof ResourceNotFoundException) {
                    logger.debug(
                        () -> Strings.format("[%s] failed to set routing state to stopping as assignment already removed", deploymentId),
                        e
                    );
                } else {
                    logger.warn(() -> "[" + deploymentId + "] failed to set routing state to stopping due to error", e);
                }
                stopDeploymentHelper(task, reason, stopDeploymentFunc, notifyDeploymentOfStopped);
            })
        );
    }

    /** Delegates inference for {@code task} to the deployment manager. */
    public void infer(
        TrainedModelDeploymentTask task,
        InferenceConfig config,
        NlpInferenceInput input,
        boolean skipQueue,
        TimeValue timeout,
        TrainedModelPrefixStrings.PrefixType prefixType,
        CancellableTask parentActionTask,
        boolean chunkResponse,
        ActionListener<InferenceResults> listener
    ) {
        deploymentManager.infer(task, config, input, skipQueue, timeout, prefixType, parentActionTask, chunkResponse, listener);
    }

    /** Returns the deployment manager's statistics for {@code task}, if any. */
    public Optional<ModelStats> modelStats(TrainedModelDeploymentTask task) {
        return deploymentManager.getStats(task);
    }

    /** Asks the deployment manager to clear the cache of {@code task}, bounded by the control-message timeout. */
    public void clearCache(TrainedModelDeploymentTask task, ActionListener<AcknowledgedResponse> listener) {
        deploymentManager.clearCache(task, CONTROL_MESSAGE_TIMEOUT, listener);
    }

    /** Builds the request used to register a {@link TrainedModelDeploymentTask} for {@code params}. */
    private TaskAwareRequest taskAwareRequest(final StartTrainedModelDeploymentAction.TaskParams params) {
        return new TaskAwareRequest() {
            @Override
            public void setParentTask(TaskId taskId) {
                throw new UnsupportedOperationException("parent task id for model deployment tasks shouldn't change");
            }

            @Override
            public void setRequestId(long requestId) {
                throw new UnsupportedOperationException("does not have request ID");
            }

            @Override
            public TaskId getParentTask() {
                return TaskId.EMPTY_TASK_ID;
            }

            @Override
            public Task createTask(long id, String type, String action, TaskId parentTaskId, Map<String, String> headers) {
                return new TrainedModelDeploymentTask(
                    id,
                    type,
                    action,
                    parentTaskId,
                    headers,
                    params,
                    TrainedModelAssignmentNodeService.this,
                    licenseState,
                    MachineLearning.ML_PYTORCH_MODEL_INFERENCE_FEATURE
                );
            }
        };
    }

    /**
     * Reconciles local deployments with the assignment metadata. The state is always recorded; the rest only runs
     * when metadata changed.
     */
    @Override
    public void clusterChanged(ClusterChangedEvent event) {
        latestState = event.state();
        if (event.metadataChanged() == false) {
            return;
        }
        ClusterState state = event.state();
        boolean isResetMode = MlMetadata.getMlMetadata(state).isResetMode();
        TrainedModelAssignmentMetadata modelAssignmentMetadata = TrainedModelAssignmentMetadata.fromState(state);
        String currentNode = state.nodes().getLocalNodeId();
        Set<String> shuttingDownNodes = Collections.unmodifiableSet(state.metadata().nodeShutdowns().getAllNodeIds());

        if (isResetMode == false) {
            updateNumberOfAllocations(modelAssignmentMetadata);
        }

        for (TrainedModelAssignment assignment : modelAssignmentMetadata.allAssignments().values()) {
            RoutingInfo routingInfo = assignment.getNodeRoutingTable().get(currentNode);
            if (routingInfo == null) {
                stopUnreferencedDeployment(assignment.getDeploymentId(), currentNode);
                continue;
            }
            if (assignment.getAssignmentState() != AssignmentState.STOPPING) {
                if (shouldAssignmentBeRestarted(routingInfo, assignment.getDeploymentId())) {
                    prepareAssignmentForRestart(assignment);
                }
                if (shouldLoadModel(routingInfo, assignment.getDeploymentId(), isResetMode)) {
                    StartTrainedModelDeploymentAction.TaskParams params = createStartTrainedModelDeploymentTaskParams(
                        assignment,
                        routingInfo.getCurrentAllocations()
                    );
                    try (ThreadContext.StoredContext ignored = threadPool.getThreadContext().newTraceContext()) {
                        prepareModelToLoad(params);
                    }
                }
            }
            if (shouldGracefullyShutdownDeployment(assignment, shuttingDownNodes, currentNode)) {
                gracefullyStopDeployment(assignment.getDeploymentId(), currentNode);
            }
        }

        // Tasks whose assignment disappeared entirely from the metadata.
        List<TrainedModelDeploymentTask> toCancel = new ArrayList<>();
        for (String deploymentId : Sets.difference(deploymentIdToTask.keySet(), modelAssignmentMetadata.allAssignments().keySet())) {
            TrainedModelDeploymentTask removed = deploymentIdToTask.remove(deploymentId);
            // Another thread may already have removed it between computing the difference and this remove.
            if (removed != null) {
                toCancel.add(removed);
            }
        }
        for (TrainedModelDeploymentTask task : toCancel) {
            stopDeploymentAsync(task, ASSIGNMENT_NO_LONGER_EXISTS, loggingStopListener(task));
        }
    }

    /** True when the route is STARTING again but the locally tracked task for it has failed. */
    private boolean shouldAssignmentBeRestarted(RoutingInfo routingInfo, String deploymentId) {
        if (routingInfo.getState() != RoutingState.STARTING) {
            return false;
        }
        // Single lookup avoids a containsKey/get race on the concurrent map.
        TrainedModelDeploymentTask task = deploymentIdToTask.get(deploymentId);
        return task != null && task.isFailed();
    }

    /** Forgets and unregisters the failed task so that a fresh one can be created for the assignment. */
    private void prepareAssignmentForRestart(TrainedModelAssignment trainedModelAssignment) {
        TrainedModelDeploymentTask failedTask = deploymentIdToTask.remove(trainedModelAssignment.getDeploymentId());
        if (failedTask != null) {
            taskManager.unregister(failedTask);
        }
    }

    /** True when the route is STARTING/STARTED, no task is tracked for the deployment, and ML is not in reset mode. */
    private boolean shouldLoadModel(RoutingInfo routingInfo, String deploymentId, boolean isResetMode) {
        return routingInfo.getState().isAnyOf(RoutingState.STARTING, RoutingState.STARTED)
            && deploymentIdToTask.containsKey(deploymentId) == false
            && isResetMode == false;
    }

    /** Copies the assignment's task params, replacing the allocation count with {@code currentAllocations}. */
    private static StartTrainedModelDeploymentAction.TaskParams createStartTrainedModelDeploymentTaskParams(
        TrainedModelAssignment trainedModelAssignment,
        int currentAllocations
    ) {
        StartTrainedModelDeploymentAction.TaskParams taskParams = trainedModelAssignment.getTaskParams();
        return new StartTrainedModelDeploymentAction.TaskParams(
            taskParams.getModelId(),
            trainedModelAssignment.getDeploymentId(),
            taskParams.getModelBytes(),
            currentAllocations,
            taskParams.getThreadsPerAllocation(),
            taskParams.getQueueCapacity(),
            taskParams.getCacheSize().orElse(null),
            taskParams.getPriority(),
            taskParams.getPerDeploymentMemoryBytes(),
            taskParams.getPerAllocationMemoryBytes()
        );
    }

    /**
     * True when this node is shutting down, its route is STOPPING, a local task exists, and either the assignment
     * has started routes or is routed to at most one node. Returns true if there is no route for this node.
     */
    private boolean shouldGracefullyShutdownDeployment(
        TrainedModelAssignment trainedModelAssignment,
        Set<String> shuttingDownNodes,
        String currentNode
    ) {
        RoutingInfo routingInfo = trainedModelAssignment.getNodeRoutingTable().get(currentNode);
        if (routingInfo == null) {
            return true;
        }
        boolean isCurrentNodeShuttingDown = shuttingDownNodes.contains(currentNode);
        boolean isRouteStopping = routingInfo.getState() == RoutingState.STOPPING;
        boolean hasDeploymentTask = deploymentIdToTask.containsKey(trainedModelAssignment.getDeploymentId());
        if (isCurrentNodeShuttingDown == false || isRouteStopping == false || hasDeploymentTask == false) {
            return false;
        }
        boolean hasStartedRoutes = trainedModelAssignment.hasStartedRoutes();
        boolean assignmentIsRoutedToOneOrFewerNodes = trainedModelAssignment.getNodeRoutingTable().size() <= 1;
        logger.debug(
            () -> Strings.format(
                "[%s] Checking if deployment can be gracefully shutdown on node %s, has other started routes: %s, single or no routed nodes: %s",
                trainedModelAssignment.getDeploymentId(),
                currentNode,
                hasStartedRoutes,
                assignmentIsRoutedToOneOrFewerNodes
            )
        );
        return hasStartedRoutes || assignmentIsRoutedToOneOrFewerNodes;
    }

    /** Stops the deployment after pending work completes, then records the route as STOPPED. */
    private void gracefullyStopDeployment(String deploymentId, String currentNode) {
        logger.debug(() -> Strings.format("[%s] Gracefully stopping deployment for shutting down node %s", deploymentId, currentNode));
        TrainedModelDeploymentTask task = deploymentIdToTask.remove(deploymentId);
        if (task == null) {
            logger.debug(
                () -> Strings.format(
                    "[%s] Unable to gracefully stop deployment for shutting down node %s because task does not exist",
                    deploymentId,
                    currentNode
                )
            );
            return;
        }
        ActionListener<AcknowledgedResponse> routingStateListener = ActionListener.wrap(
            r -> logger.debug(
                () -> Strings.format("[%s] Gracefully stopped deployment for shutting down node %s", task.getDeploymentId(), currentNode)
            ),
            e -> logger.error(
                () -> Strings.format("[%s] Failed to gracefully stop deployment for shutting down node %s", task.getDeploymentId(), currentNode),
                e
            )
        );
        ActionListener<Void> notifyDeploymentOfStopped = updateRoutingStateToStoppedListener(
            task.getDeploymentId(),
            NODE_SHUTTING_DOWN,
            routingStateListener
        );
        stopDeploymentAfterCompletingPendingWorkAsync(task, notifyDeploymentOfStopped);
    }

    /** Returns a listener that records the route as STOPPED whether the preceding stop succeeded or failed. */
    private ActionListener<Void> updateRoutingStateToStoppedListener(
        String deploymentId,
        String reason,
        ActionListener<AcknowledgedResponse> listener
    ) {
        RoutingInfoUpdate updateToStopped = RoutingInfoUpdate.updateStateAndReason(new RoutingStateAndReason(RoutingState.STOPPED, reason));
        return ActionListener.wrap(ignored -> {
            logger.debug(() -> Strings.format("[%s] Updating routing state to stopped", deploymentId));
            updateStoredState(deploymentId, updateToStopped, listener);
        }, e -> {
            logger.warn(() -> Strings.format("[%s] Failed to stop deployment due to error", deploymentId), e);
            updateStoredState(deploymentId, updateToStopped, listener);
        });
    }

    /** Stops a locally tracked deployment whose assignment no longer routes to this node. */
    private void stopUnreferencedDeployment(String deploymentId, String currentNode) {
        TrainedModelDeploymentTask task = deploymentIdToTask.remove(deploymentId);
        if (task == null) {
            return;
        }
        logger.debug(() -> Strings.format("[%s] Stopping unreferenced deployment for node %s", deploymentId, currentNode));
        stopDeploymentAsync(task, NODE_NO_LONGER_REFERENCED, loggingStopListener(task));
    }

    /** Listener that only logs the outcome of stopping {@code task}. */
    private static ActionListener<Void> loggingStopListener(TrainedModelDeploymentTask task) {
        return ActionListener.wrap(
            r -> logger.trace(() -> "[" + task.getDeploymentId() + "] stopped deployment"),
            e -> logger.warn(() -> "[" + task.getDeploymentId() + "] failed to fully stop deployment", e)
        );
    }

    private void stopDeploymentAsync(TrainedModelDeploymentTask task, String reason, ActionListener<Void> listener) {
        stopDeploymentHelper(task, reason, deploymentManager::stopDeployment, listener);
    }

    /**
     * Marks {@code task} stopped, then on the utility executor runs {@code stopDeploymentFunc}, unregisters the task
     * and forgets it. Does nothing (and does not complete {@code listener}) when the service is stopped.
     */
    private void stopDeploymentHelper(
        TrainedModelDeploymentTask task,
        String reason,
        Consumer<TrainedModelDeploymentTask> stopDeploymentFunc,
        ActionListener<Void> listener
    ) {
        if (stopped) {
            return;
        }
        task.markAsStopped(reason);
        threadPool.executor(UTILITY_THREAD_POOL).execute(() -> {
            try {
                stopDeploymentFunc.accept(task);
                taskManager.unregister(task);
                deploymentIdToTask.remove(task.getDeploymentId());
                listener.onResponse(null);
            } catch (Exception e) {
                listener.onFailure(e);
            }
        });
    }

    private void stopDeploymentAfterCompletingPendingWorkAsync(TrainedModelDeploymentTask task, ActionListener<Void> listener) {
        stopDeploymentHelper(task, NODE_SHUTTING_DOWN, deploymentManager::stopAfterCompletingPendingWork, listener);
    }

    /**
     * For STARTED routes on this node whose current and target allocations differ (and whose assignment has no
     * STARTING route anywhere), asks the deployment manager to apply the target and records the new count.
     */
    private void updateNumberOfAllocations(TrainedModelAssignmentMetadata assignments) {
        String localNodeId = nodeId;
        List<TrainedModelAssignment> assignmentsToUpdate = assignments.allAssignments()
            .values()
            .stream()
            .filter(a -> hasStartingAssignments(a) == false)
            .filter(a -> a.isRoutedToNode(localNodeId))
            .filter(a -> {
                RoutingInfo routingInfo = a.getNodeRoutingTable().get(localNodeId);
                return routingInfo.getState() == RoutingState.STARTED
                    && routingInfo.getCurrentAllocations() != routingInfo.getTargetAllocations();
            })
            .toList();

        for (TrainedModelAssignment assignment : assignmentsToUpdate) {
            TrainedModelDeploymentTask task = deploymentIdToTask.get(assignment.getDeploymentId());
            if (task == null) {
                logger.debug(() -> Strings.format("[%s] task was removed whilst updating number of allocations", assignment.getDeploymentId()));
                continue;
            }
            int targetAllocations = assignment.getNodeRoutingTable().get(localNodeId).getTargetAllocations();
            deploymentManager.updateNumAllocations(task, targetAllocations, CONTROL_MESSAGE_TIMEOUT, ActionListener.wrap(threadSettings -> {
                logger.debug("[{}] Updated number of allocations to [{}]", assignment.getDeploymentId(), threadSettings.numAllocations());
                task.updateNumberOfAllocations(threadSettings.numAllocations());
                updateStoredState(
                    assignment.getDeploymentId(),
                    RoutingInfoUpdate.updateNumberOfAllocations(threadSettings.numAllocations()),
                    ActionListener.noop()
                );
            },
                e -> logger.error(
                    () -> Strings.format("[%s] Could not update number of allocations to [%s]", assignment.getDeploymentId(), targetAllocations),
                    e
                )
            ));
        }
    }

    /** True when any route of {@code assignment} is in the STARTING state. */
    private static boolean hasStartingAssignments(TrainedModelAssignment assignment) {
        return assignment.getNodeRoutingTable().values().stream().anyMatch(routingInfo -> routingInfo.getState() == RoutingState.STARTING);
    }

    /** Returns the locally tracked task for {@code deploymentId}, or null. */
    TrainedModelDeploymentTask getTask(String deploymentId) {
        return deploymentIdToTask.get(deploymentId);
    }

    /**
     * Registers a deployment task and queues it for loading. If a task is already tracked for the deployment,
     * the newly registered one is unregistered again.
     */
    void prepareModelToLoad(StartTrainedModelDeploymentAction.TaskParams taskParams) {
        logger.debug(
            () -> Strings.format(
                "[%s] preparing to load model [%s] with task params: %s",
                taskParams.getDeploymentId(),
                taskParams.getModelId(),
                taskParams
            )
        );
        TrainedModelDeploymentTask task = (TrainedModelDeploymentTask) taskManager.register(
            "trained_model_assignment",
            "xpack/ml/trained_model_assignment[n]",
            taskAwareRequest(taskParams),
            false
        );
        if (deploymentIdToTask.putIfAbsent(taskParams.getDeploymentId(), task) == null) {
            loadingModels.offer(task);
        } else {
            taskManager.unregister(task);
        }
    }

    /** Records the route as STARTED (unless the task was stopped meanwhile); always completes with {@code false}. */
    private void handleLoadSuccess(ActionListener<Boolean> retryListener, TrainedModelDeploymentTask task) {
        logger.debug(
            () -> "["
                + task.getParams().getDeploymentId()
                + "] model ["
                + task.getParams().getModelId()
                + "] successfully loaded and ready for inference. Notifying master node"
        );
        if (task.isStopped()) {
            logger.debug(
                () -> Strings.format(
                    "[%s] model [%s] loaded successfully, but stopped before routing table was updated; reason [%s]",
                    task.getDeploymentId(),
                    task.getParams().getModelId(),
                    task.stoppedReason().orElse(UNKNOWN_REASON)
                )
            );
            retryListener.onResponse(false);
            return;
        }
        ActionListener<AcknowledgedResponse> logOutcome = ActionListener.wrap(
            r -> logger.debug(() -> "[" + task.getDeploymentId() + "] model loaded and accepting routes"),
            e -> {
                if (ExceptionsHelper.unwrapCause(e) instanceof ResourceNotFoundException) {
                    logger.debug(
                        () -> Strings.format(
                            "[%s] model [%s] loaded but failed to start accepting routes as assignment to this node was removed",
                            task.getDeploymentId(),
                            task.getParams().getModelId()
                        ),
                        e
                    );
                } else {
                    logger.warn(
                        () -> "[" + task.getDeploymentId() + "] model [" + task.getParams().getModelId() + "] loaded but failed to start accepting routes",
                        e
                    );
                }
            }
        );
        updateStoredState(
            task.getDeploymentId(),
            RoutingInfoUpdate.updateStateAndReason(new RoutingStateAndReason(RoutingState.STARTED, "")),
            ActionListener.runAfter(logOutcome, () -> retryListener.onResponse(false))
        );
    }

    /**
     * Sends {@code update} for this node's route of {@code deploymentId}. Does nothing (and does not complete
     * {@code listener}) when the service is stopped.
     */
    private void updateStoredState(String deploymentId, RoutingInfoUpdate update, ActionListener<AcknowledgedResponse> listener) {
        if (stopped) {
            return;
        }
        trainedModelAssignmentService.updateModelAssignmentState(
            new UpdateTrainedModelAssignmentRoutingInfoAction.Request(nodeId, deploymentId, update),
            ActionListener.wrap(success -> {
                logger.debug(() -> Strings.format("[%s] deployment routing info was updated with [%s] and master notified", deploymentId, update));
                listener.onResponse(AcknowledgedResponse.TRUE);
            }, error -> {
                logger.warn(() -> Strings.format("[%s] failed to update deployment routing info with [%s]", deploymentId, update), error);
                listener.onFailure(error);
            })
        );
    }

    /**
     * Logs the load failure (warn for ES client errors below 500, error otherwise), records the route as FAILED,
     * then stops the deployment and completes {@code retryListener} with {@code false}.
     */
    private void handleLoadFailure(TrainedModelDeploymentTask task, Exception ex, ActionListener<Boolean> retryListener) {
        if (ex instanceof ElasticsearchException esEx && esEx.status().getStatus() < 500) {
            logger.warn(() -> "[" + task.getDeploymentId() + "] model [" + task.getParams().getModelId() + "] failed to load", ex);
        } else {
            logger.error(() -> "[" + task.getDeploymentId() + "] model [" + task.getParams().getModelId() + "] failed to load", ex);
        }
        if (task.isStopped()) {
            // Model id is a format argument, not part of the pattern, so every placeholder gets its intended value.
            logger.debug(
                () -> Strings.format(
                    "[%s] model [%s] failed to load, but is now stopped; reason [%s]",
                    task.getDeploymentId(),
                    task.getParams().getModelId(),
                    task.stoppedReason().orElse(UNKNOWN_REASON)
                )
            );
        }
        Runnable stopTask = () -> stopDeploymentAsync(
            task,
            "model failed to load; reason [" + ex.getMessage() + "]",
            ActionListener.running(() -> retryListener.onResponse(false))
        );
        updateStoredState(
            task.getDeploymentId(),
            RoutingInfoUpdate.updateStateAndReason(new RoutingStateAndReason(RoutingState.FAILED, ExceptionsHelper.unwrapCause(ex).getMessage())),
            ActionListener.running(stopTask)
        );
    }

    /** Records this node's route for {@code task} as FAILED with {@code reason}, logging the outcome. */
    public void failAssignment(TrainedModelDeploymentTask task, String reason) {
        updateStoredState(
            task.getDeploymentId(),
            RoutingInfoUpdate.updateStateAndReason(new RoutingStateAndReason(RoutingState.FAILED, reason)),
            ActionListener.wrap(
                r -> logger.debug(
                    () -> Strings.format(
                        "[%s] Successfully updating assignment state to [%s] with reason [%s]",
                        task.getDeploymentId(),
                        RoutingState.FAILED,
                        reason
                    )
                ),
                e -> logger.error(
                    () -> Strings.format(
                        "[%s] Error while updating assignment state to [%s] with reason [%s]",
                        task.getDeploymentId(),
                        RoutingState.FAILED,
                        reason
                    ),
                    e
                )
            )
        );
    }
}

