/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.ml.action;

import java.io.IOException;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.OptionalLong;
import java.util.Set;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.logging.log4j.message.Message;
import org.apache.logging.log4j.message.ParameterizedMessage;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.ResourceAlreadyExistsException;
import org.elasticsearch.ResourceNotFoundException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.ActionRequest;
import org.elasticsearch.action.ActionType;
import org.elasticsearch.action.support.ActionFilters;
import org.elasticsearch.action.support.master.AcknowledgedResponse;
import org.elasticsearch.action.support.master.TransportMasterNodeAction;
import org.elasticsearch.client.internal.Client;
import org.elasticsearch.client.internal.OriginSettingClient;
import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.cluster.block.ClusterBlockException;
import org.elasticsearch.cluster.block.ClusterBlockLevel;
import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver;
import org.elasticsearch.cluster.metadata.NodesShutdownMetadata;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.inject.Inject;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.common.util.set.Sets;
import org.elasticsearch.core.CheckedFunction;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.license.LicenseUtils;
import org.elasticsearch.license.XPackLicenseState;
import org.elasticsearch.persistent.PersistentTasksCustomMetadata;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.sort.FieldSortBuilder;
import org.elasticsearch.search.sort.SortBuilder;
import org.elasticsearch.search.sort.SortBuilders;
import org.elasticsearch.search.sort.SortOrder;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.TransportService;
import org.elasticsearch.xcontent.NamedXContentRegistry;
import org.elasticsearch.xpack.core.ml.MachineLearningField;
import org.elasticsearch.xpack.core.ml.action.CreateTrainedModelAllocationAction;
import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction;
import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelType;
import org.elasticsearch.xpack.core.ml.inference.allocation.AllocationStatus;
import org.elasticsearch.xpack.core.ml.inference.allocation.RoutingState;
import org.elasticsearch.xpack.core.ml.inference.allocation.RoutingStateAndReason;
import org.elasticsearch.xpack.core.ml.inference.allocation.TrainedModelAllocation;
import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.IndexLocation;
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.ml.MachineLearning;
import org.elasticsearch.xpack.ml.inference.allocation.TrainedModelAllocationMetadata;
import org.elasticsearch.xpack.ml.inference.allocation.TrainedModelAllocationService;
import org.elasticsearch.xpack.ml.inference.persistence.ChunkedTrainedModelRestorer;
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelDefinitionDoc;
import org.elasticsearch.xpack.ml.job.NodeLoadDetector;
import org.elasticsearch.xpack.ml.process.MlMemoryTracker;

public class TransportStartTrainedModelDeploymentAction
extends TransportMasterNodeAction<StartTrainedModelDeploymentAction.Request, CreateTrainedModelAllocationAction.Response> {
    private static final Logger logger = LogManager.getLogger(TransportStartTrainedModelDeploymentAction.class);
    private final XPackLicenseState licenseState;
    private final Client client;
    private final TrainedModelAllocationService trainedModelAllocationService;
    private final NamedXContentRegistry xContentRegistry;
    private final MlMemoryTracker memoryTracker;
    protected volatile int maxLazyMLNodes;
    protected volatile long maxMLNodeSize;

    @Inject
    public TransportStartTrainedModelDeploymentAction(TransportService transportService, Client client, ClusterService clusterService, ThreadPool threadPool, ActionFilters actionFilters, XPackLicenseState licenseState, IndexNameExpressionResolver indexNameExpressionResolver, Settings settings, TrainedModelAllocationService trainedModelAllocationService, NamedXContentRegistry xContentRegistry, MlMemoryTracker memoryTracker) {
        super("cluster:admin/xpack/ml/trained_models/deployment/start", transportService, clusterService, threadPool, actionFilters, StartTrainedModelDeploymentAction.Request::new, indexNameExpressionResolver, CreateTrainedModelAllocationAction.Response::new, "same");
        this.licenseState = Objects.requireNonNull(licenseState);
        this.client = new OriginSettingClient(Objects.requireNonNull(client), "ml");
        this.xContentRegistry = Objects.requireNonNull(xContentRegistry);
        this.memoryTracker = Objects.requireNonNull(memoryTracker);
        this.trainedModelAllocationService = Objects.requireNonNull(trainedModelAllocationService);
        this.maxLazyMLNodes = (Integer)MachineLearning.MAX_LAZY_ML_NODES.get(settings);
        this.maxMLNodeSize = ((ByteSizeValue)MachineLearning.MAX_ML_NODE_SIZE.get(settings)).getBytes();
        clusterService.getClusterSettings().addSettingsUpdateConsumer(MachineLearning.MAX_LAZY_ML_NODES, this::setMaxLazyMLNodes);
        clusterService.getClusterSettings().addSettingsUpdateConsumer(MachineLearning.MAX_ML_NODE_SIZE, this::setMaxMLNodeSize);
    }

    private void setMaxLazyMLNodes(int value) {
        this.maxLazyMLNodes = value;
    }

    private void setMaxMLNodeSize(ByteSizeValue value) {
        this.maxMLNodeSize = value.getBytes();
    }

    protected void masterOperation(Task task, StartTrainedModelDeploymentAction.Request request, ClusterState state, ActionListener<CreateTrainedModelAllocationAction.Response> listener) throws Exception {
        logger.trace(() -> new ParameterizedMessage("[{}] received deploy request", (Object)request.getModelId()));
        if (!MachineLearningField.ML_API_FEATURE.check(this.licenseState)) {
            listener.onFailure((Exception)LicenseUtils.newComplianceException((String)"ml"));
            return;
        }
        ActionListener waitForDeploymentToStart = ActionListener.wrap(modelAllocation -> this.waitForDeploymentState(request.getModelId(), request.getTimeout(), request.getWaitForState(), listener), e -> {
            logger.warn(() -> new ParameterizedMessage("[{}] creating new allocation failed", (Object)request.getModelId()), (Throwable)e);
            if (ExceptionsHelper.unwrapCause((Throwable)e) instanceof ResourceAlreadyExistsException) {
                e = new ElasticsearchStatusException("Cannot start deployment [{}] because it has already been started", RestStatus.CONFLICT, (Throwable)e, new Object[]{request.getModelId()});
            }
            listener.onFailure(e);
        });
        ActionListener getModelListener = ActionListener.wrap(getModelResponse -> {
            if (getModelResponse.getResources().results().size() > 1) {
                listener.onFailure((Exception)ExceptionsHelper.badRequestException((String)"cannot deploy more than one models at the same time; [{}] matches [{}] models]", (Object[])new Object[]{request.getModelId(), getModelResponse.getResources().results().size()}));
                return;
            }
            TrainedModelConfig trainedModelConfig = (TrainedModelConfig)getModelResponse.getResources().results().get(0);
            if (trainedModelConfig.getModelType() != TrainedModelType.PYTORCH) {
                listener.onFailure((Exception)ExceptionsHelper.badRequestException((String)"model [{}] of type [{}] cannot be deployed. Only PyTorch models can be deployed", (Object[])new Object[]{trainedModelConfig.getModelId(), trainedModelConfig.getModelType()}));
                return;
            }
            if (trainedModelConfig.getLocation() == null) {
                listener.onFailure((Exception)((Object)ExceptionsHelper.serverError((String)"model [{}] does not have location", (Object[])new Object[]{trainedModelConfig.getModelId()})));
                return;
            }
            this.validateModelDefinition(trainedModelConfig, (ActionListener<Void>)ActionListener.wrap(validate -> this.getModelBytes(trainedModelConfig, (ActionListener<Long>)ActionListener.wrap(modelBytes -> {
                StartTrainedModelDeploymentAction.TaskParams taskParams = new StartTrainedModelDeploymentAction.TaskParams(trainedModelConfig.getModelId(), modelBytes.longValue(), request.getInferenceThreads(), request.getModelThreads(), request.getQueueCapacity());
                PersistentTasksCustomMetadata persistentTasks = (PersistentTasksCustomMetadata)this.clusterService.state().getMetadata().custom("persistent_tasks");
                this.memoryTracker.refresh(persistentTasks, (ActionListener<Void>)ActionListener.wrap(aVoid -> this.trainedModelAllocationService.createNewModelAllocation(taskParams, (ActionListener<CreateTrainedModelAllocationAction.Response>)waitForDeploymentToStart), arg_0 -> ((ActionListener)listener).onFailure(arg_0)));
            }, arg_0 -> ((ActionListener)listener).onFailure(arg_0))), arg_0 -> ((ActionListener)listener).onFailure(arg_0)));
        }, arg_0 -> listener.onFailure(arg_0));
        GetTrainedModelsAction.Request getModelRequest = new GetTrainedModelsAction.Request(request.getModelId());
        this.client.execute((ActionType)GetTrainedModelsAction.INSTANCE, (ActionRequest)getModelRequest, getModelListener);
    }

    private void getModelBytes(TrainedModelConfig trainedModelConfig, ActionListener<Long> listener) {
        ChunkedTrainedModelRestorer restorer = new ChunkedTrainedModelRestorer(trainedModelConfig.getModelId(), this.client, this.threadPool.executor("ml_utility"), this.xContentRegistry);
        restorer.setSearchIndex(trainedModelConfig.getLocation().getResourceName());
        restorer.setSearchSize(1);
        restorer.restoreModelDefinition((CheckedFunction<TrainedModelDefinitionDoc, Boolean, IOException>)((CheckedFunction)doc -> {
            listener.onResponse((Object)doc.getTotalDefinitionLength());
            return false;
        }), success -> {}, arg_0 -> listener.onFailure(arg_0));
    }

    private void waitForDeploymentState(final String modelId, TimeValue timeout, AllocationStatus.State state, final ActionListener<CreateTrainedModelAllocationAction.Response> listener) {
        final DeploymentStartedPredicate predicate = new DeploymentStartedPredicate(modelId, state, this.maxLazyMLNodes, this.maxMLNodeSize);
        this.trainedModelAllocationService.waitForAllocationCondition(modelId, predicate, timeout, new TrainedModelAllocationService.WaitForAllocationListener(){

            public void onResponse(TrainedModelAllocation allocation) {
                if (predicate.exception != null) {
                    TransportStartTrainedModelDeploymentAction.this.deleteFailedDeployment(modelId, predicate.exception, (ActionListener<CreateTrainedModelAllocationAction.Response>)listener);
                } else {
                    listener.onResponse((Object)new CreateTrainedModelAllocationAction.Response(allocation));
                }
            }

            public void onFailure(Exception e) {
                listener.onFailure(e);
            }
        });
    }

    private void deleteFailedDeployment(String modelId, Exception exception, ActionListener<CreateTrainedModelAllocationAction.Response> listener) {
        this.trainedModelAllocationService.deleteModelAllocation(modelId, (ActionListener<AcknowledgedResponse>)ActionListener.wrap(pTask -> listener.onFailure(exception), e -> {
            logger.error((Message)new ParameterizedMessage("[{}] Failed to delete model allocation that had failed with the reason [{}]", (Object)modelId, (Object)exception.getMessage()), (Throwable)e);
            listener.onFailure(exception);
        }));
    }

    private void validateModelDefinition(TrainedModelConfig config, ActionListener<Void> listener) {
        if (!(config.getLocation() instanceof IndexLocation)) {
            listener.onResponse(null);
            return;
        }
        String modelId = config.getModelId();
        String[] requiredSourceFields = new String[]{TrainedModelDefinitionDoc.DEFINITION_LENGTH.getPreferredName(), TrainedModelDefinitionDoc.DOC_NUM.getPreferredName(), TrainedModelDefinitionDoc.TOTAL_DEFINITION_LENGTH.getPreferredName(), TrainedModelDefinitionDoc.EOS.getPreferredName()};
        Set<String> requiredSet = Set.of(requiredSourceFields);
        String index = ((IndexLocation)config.getLocation()).getIndexName();
        this.client.prepareSearch(new String[]{index}).setQuery((QueryBuilder)QueryBuilders.constantScoreQuery((QueryBuilder)QueryBuilders.boolQuery().filter((QueryBuilder)QueryBuilders.termQuery((String)TrainedModelConfig.MODEL_ID.getPreferredName(), (String)modelId)).filter((QueryBuilder)QueryBuilders.termQuery((String)InferenceIndexConstants.DOC_TYPE.getPreferredName(), (String)"trained_model_definition_doc")))).setFetchSource(requiredSourceFields, new String[0]).setSize(10000).setTrackTotalHits(true).addSort((SortBuilder)((FieldSortBuilder)SortBuilders.fieldSort((String)TrainedModelDefinitionDoc.DOC_NUM.getPreferredName()).order(SortOrder.ASC)).unmappedType("long")).execute(ActionListener.wrap(response -> {
            SearchHit[] hits = response.getHits().getHits();
            if (hits.length == 0) {
                listener.onFailure((Exception)new ResourceNotFoundException(Messages.getMessage((String)"Could not find trained model definition [{0}]", (Object[])new Object[]{modelId}), new Object[0]));
                return;
            }
            long firstTotalLength = ((Number)hits[0].getSourceAsMap().get(TrainedModelDefinitionDoc.TOTAL_DEFINITION_LENGTH.getPreferredName())).longValue();
            long summedLengths = 0L;
            for (SearchHit hit : hits) {
                Map fields = hit.getSourceAsMap();
                if (fields == null) {
                    listener.onFailure((Exception)ExceptionsHelper.badRequestException((String)"[{}] model definition [{}] is missing required fields {}. {}", (Object[])new Object[]{modelId, TrainedModelDefinitionDoc.docNum(modelId, Objects.requireNonNull(hit.getId())), List.of(requiredSourceFields), "Unable to deploy model, please delete and recreate the model definition"}));
                    return;
                }
                Set diff = Sets.difference(fields.keySet(), (Set)requiredSet);
                if (!diff.isEmpty()) {
                    listener.onFailure((Exception)ExceptionsHelper.badRequestException((String)"[{}] model definition [{}] is missing required fields {}. {}", (Object[])new Object[]{modelId, TrainedModelDefinitionDoc.docNum(modelId, Objects.requireNonNull(hit.getId())), diff, "Unable to deploy model, please delete and recreate the model definition"}));
                    return;
                }
                summedLengths += ((Number)fields.get(TrainedModelDefinitionDoc.DEFINITION_LENGTH.getPreferredName())).longValue();
                long totalLength = ((Number)fields.get(TrainedModelDefinitionDoc.TOTAL_DEFINITION_LENGTH.getPreferredName())).longValue();
                if (totalLength == firstTotalLength) continue;
                listener.onFailure((Exception)ExceptionsHelper.badRequestException((String)"[{}] [total_definition_length] must be the same in all model definition parts. The value [{}] in model definition part [{}] does not match the value [{}] in part [{}]. Unable to deploy model, please delete and recreate the model definition", (Object[])new Object[]{modelId, totalLength, TrainedModelDefinitionDoc.docNum(modelId, Objects.requireNonNull(hit.getId())), firstTotalLength, TrainedModelDefinitionDoc.docNum(modelId, Objects.requireNonNull(hits[0].getId()))}));
                return;
            }
            Boolean eos = (Boolean)hits[hits.length - 1].getSourceAsMap().get(TrainedModelDefinitionDoc.EOS.getPreferredName());
            if (summedLengths != firstTotalLength || eos == null || !eos.booleanValue()) {
                listener.onFailure((Exception)ExceptionsHelper.badRequestException((String)Messages.getMessage((String)"Model definition truncated. Unable to deserialize trained model definition [{0}]", (Object[])new Object[]{modelId}), (Object[])new Object[0]));
                return;
            }
            listener.onResponse(null);
        }, e -> {
            if (ExceptionsHelper.unwrapCause((Throwable)e) instanceof ResourceNotFoundException) {
                ResourceNotFoundException ex = new ResourceNotFoundException(Messages.getMessage((String)"Could not find trained model definition [{0}]", (Object[])new Object[]{modelId}), new Object[0]);
                ex.addSuppressed((Throwable)e);
                listener.onFailure((Exception)ex);
                return;
            }
            listener.onFailure(e);
        }));
    }

    protected ClusterBlockException checkBlock(StartTrainedModelDeploymentAction.Request request, ClusterState state) {
        return state.blocks().globalBlockedException(ClusterBlockLevel.METADATA_WRITE);
    }

    static Set<String> nodesShuttingDown(ClusterState state) {
        return NodesShutdownMetadata.getShutdowns((ClusterState)state).map(NodesShutdownMetadata::getAllNodeMetadataMap).map(Map::keySet).orElse(Collections.emptySet());
    }

    private static class DeploymentStartedPredicate
    implements Predicate<ClusterState> {
        private volatile Exception exception;
        private final String modelId;
        private final AllocationStatus.State waitForState;
        private final int maxLazyMLNodes;
        private final long maxMLNodeSize;

        DeploymentStartedPredicate(String modelId, AllocationStatus.State waitForState, int maxLazyMLNodes, long maxMLNodeSize) {
            this.modelId = (String)ExceptionsHelper.requireNonNull((Object)modelId, (String)"model_id");
            this.waitForState = waitForState;
            this.maxLazyMLNodes = maxLazyMLNodes;
            this.maxMLNodeSize = maxMLNodeSize;
        }

        @Override
        public boolean test(ClusterState clusterState) {
            TrainedModelAllocation trainedModelAllocation = TrainedModelAllocationMetadata.allocationForModelId(clusterState, this.modelId).orElse(null);
            if (trainedModelAllocation == null) {
                return true;
            }
            Set nodesAndState = trainedModelAllocation.getNodeRoutingTable().entrySet();
            HashMap<String, String> nodeFailuresAndReasons = new HashMap<String, String>();
            LinkedHashSet<String> nodesStillInitializing = new LinkedHashSet<String>();
            for (Map.Entry nodeIdAndState : nodesAndState) {
                if (RoutingState.FAILED.equals((Object)((RoutingStateAndReason)nodeIdAndState.getValue()).getState())) {
                    nodeFailuresAndReasons.put((String)nodeIdAndState.getKey(), ((RoutingStateAndReason)nodeIdAndState.getValue()).getReason());
                }
                if (!RoutingState.STARTING.equals((Object)((RoutingStateAndReason)nodeIdAndState.getValue()).getState())) continue;
                nodesStillInitializing.add((String)nodeIdAndState.getKey());
            }
            if (!nodeFailuresAndReasons.isEmpty()) {
                this.exception = new ElasticsearchStatusException("Could not start trained model deployment, the following nodes failed with errors [{}]", RestStatus.INTERNAL_SERVER_ERROR, new Object[]{nodeFailuresAndReasons});
                return true;
            }
            Set<String> nodesShuttingDown = TransportStartTrainedModelDeploymentAction.nodesShuttingDown(clusterState);
            List nodes = clusterState.nodes().getAllNodes().stream().filter(d -> !nodesShuttingDown.contains(d.getId())).filter(StartTrainedModelDeploymentAction.TaskParams::mayAllocateToNode).collect(Collectors.toList());
            OptionalLong smallestMLNode = nodes.stream().map(NodeLoadDetector::getNodeSize).flatMapToLong(OptionalLong::stream).min();
            if (nodesAndState.isEmpty() && this.maxLazyMLNodes <= nodes.size() && (smallestMLNode.isEmpty() || smallestMLNode.getAsLong() >= this.maxMLNodeSize)) {
                String msg = "Could not start deployment because no suitable nodes were found, allocation explanation [" + trainedModelAllocation.getReason() + "]";
                logger.warn("[{}] {}", (Object)this.modelId, (Object)msg);
                IllegalStateException detail = new IllegalStateException(msg);
                this.exception = new ElasticsearchStatusException("Could not start deployment because no ML nodes with sufficient capacity were found", RestStatus.TOO_MANY_REQUESTS, (Throwable)detail, new Object[0]);
                return true;
            }
            AllocationStatus allocationStatus = trainedModelAllocation.calculateAllocationStatus(nodes).orElse(null);
            if (allocationStatus == null || allocationStatus.calculateState().compareTo((Enum)this.waitForState) >= 0) {
                return true;
            }
            if (nodesStillInitializing.isEmpty()) {
                return true;
            }
            logger.trace(() -> new ParameterizedMessage("[{}] tested with state [{}] and nodes {} still initializing", new Object[]{this.modelId, trainedModelAllocation.getAllocationState(), nodesStillInitializing}));
            return false;
        }
    }
}

