/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.ml.action;

import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Random;
import java.util.concurrent.Executor;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import java.util.stream.Collectors;
import org.elasticsearch.ResourceNotFoundException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.ActionRequest;
import org.elasticsearch.action.ActionType;
import org.elasticsearch.action.support.ActionFilters;
import org.elasticsearch.action.support.HandledTransportAction;
import org.elasticsearch.client.internal.Client;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.Randomness;
import org.elasticsearch.common.inject.Inject;
import org.elasticsearch.common.util.concurrent.AtomicArray;
import org.elasticsearch.common.util.concurrent.EsExecutors;
import org.elasticsearch.core.Strings;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.core.Tuple;
import org.elasticsearch.inference.InferenceResults;
import org.elasticsearch.license.License;
import org.elasticsearch.license.LicenseUtils;
import org.elasticsearch.license.XPackLicenseState;
import org.elasticsearch.tasks.CancellableTask;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.tasks.TaskCancelledException;
import org.elasticsearch.tasks.TaskId;
import org.elasticsearch.transport.TransportService;
import org.elasticsearch.xpack.core.ClientHelper;
import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction;
import org.elasticsearch.xpack.core.ml.action.InferModelAction;
import org.elasticsearch.xpack.core.ml.action.InferTrainedModelDeploymentAction;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelType;
import org.elasticsearch.xpack.core.ml.inference.assignment.AssignmentState;
import org.elasticsearch.xpack.core.ml.inference.assignment.RoutingState;
import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignment;
import org.elasticsearch.xpack.core.ml.inference.results.ErrorInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfigUpdate;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.ml.MachineLearning;
import org.elasticsearch.xpack.ml.inference.ModelAliasMetadata;
import org.elasticsearch.xpack.ml.inference.assignment.TrainedModelAssignmentMetadata;
import org.elasticsearch.xpack.ml.inference.loadingservice.LocalModel;
import org.elasticsearch.xpack.ml.inference.loadingservice.ModelLoadingService;
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
import org.elasticsearch.xpack.ml.utils.TypedChainTaskExecutor;

public class TransportInternalInferModelAction
extends HandledTransportAction<InferModelAction.Request, InferModelAction.Response> {
    private final ModelLoadingService modelLoadingService;
    private final Client client;
    private final ClusterService clusterService;
    private final XPackLicenseState licenseState;
    private final TrainedModelProvider trainedModelProvider;

    TransportInternalInferModelAction(String actionName, TransportService transportService, ActionFilters actionFilters, ModelLoadingService modelLoadingService, Client client, ClusterService clusterService, XPackLicenseState licenseState, TrainedModelProvider trainedModelProvider) {
        super(actionName, transportService, actionFilters, InferModelAction.Request::new, (Executor)EsExecutors.DIRECT_EXECUTOR_SERVICE);
        this.modelLoadingService = modelLoadingService;
        this.client = client;
        this.clusterService = clusterService;
        this.licenseState = licenseState;
        this.trainedModelProvider = trainedModelProvider;
    }

    @Inject
    public TransportInternalInferModelAction(TransportService transportService, ActionFilters actionFilters, ModelLoadingService modelLoadingService, Client client, ClusterService clusterService, XPackLicenseState licenseState, TrainedModelProvider trainedModelProvider) {
        this("cluster:internal/xpack/ml/inference/infer", transportService, actionFilters, modelLoadingService, client, clusterService, licenseState, trainedModelProvider);
    }

    protected void doExecute(Task task, InferModelAction.Request request, ActionListener<InferModelAction.Response> listener) {
        InferModelAction.Response.Builder responseBuilder = InferModelAction.Response.builder();
        TaskId parentTaskId = new TaskId(this.clusterService.localNode().getId(), task.getId());
        if (MachineLearning.INFERENCE_AGG_FEATURE.check(this.licenseState)) {
            responseBuilder.setLicensed(true);
            this.doInfer(task, request, responseBuilder, parentTaskId, listener);
        } else {
            this.trainedModelProvider.getTrainedModel(request.getId(), GetTrainedModelsAction.Includes.empty(), parentTaskId, (ActionListener<TrainedModelConfig>)ActionListener.wrap(trainedModelConfig -> {
                boolean allowed = trainedModelConfig.getLicenseLevel() == License.OperationMode.BASIC;
                responseBuilder.setLicensed(allowed);
                if (allowed || request.isPreviouslyLicensed()) {
                    this.doInfer(task, request, responseBuilder, parentTaskId, listener);
                } else {
                    listener.onFailure((Exception)LicenseUtils.newComplianceException((String)"ml"));
                }
            }, arg_0 -> listener.onFailure(arg_0)));
        }
    }

    private void doInfer(Task task, InferModelAction.Request request, InferModelAction.Response.Builder responseBuilder, TaskId parentTaskId, ActionListener<InferModelAction.Response> listener) {
        String concreteModelId = Optional.ofNullable(ModelAliasMetadata.fromState(this.clusterService.state()).getModelId(request.getId())).orElse(request.getId());
        responseBuilder.setId(concreteModelId);
        TrainedModelAssignmentMetadata trainedModelAssignmentMetadata = TrainedModelAssignmentMetadata.fromState(this.clusterService.state());
        TrainedModelAssignment assignment = trainedModelAssignmentMetadata.getDeploymentAssignment(concreteModelId);
        List<TrainedModelAssignment> assignments = assignment == null ? trainedModelAssignmentMetadata.getDeploymentsUsingModel(concreteModelId) : List.of(assignment);
        if (assignments.isEmpty()) {
            this.getModelAndInfer(request, responseBuilder, parentTaskId, (CancellableTask)task, listener);
        } else {
            this.inferAgainstAllocatedModel(assignments, request, responseBuilder, parentTaskId, listener);
        }
    }

    private void getModelAndInfer(InferModelAction.Request request, InferModelAction.Response.Builder responseBuilder, TaskId parentTaskId, CancellableTask task, ActionListener<InferModelAction.Response> listener) {
        ActionListener getModelListener = ActionListener.wrap(model -> {
            TypedChainTaskExecutor<InferenceResults> typedChainTaskExecutor = new TypedChainTaskExecutor<InferenceResults>(EsExecutors.DIRECT_EXECUTOR_SERVICE, r -> true, ex -> true);
            request.getObjectsToInfer().forEach(stringObjectMap -> typedChainTaskExecutor.add(chainedTask -> {
                if (task.isCancelled()) {
                    throw new TaskCancelledException(Strings.format((String)"Inference task cancelled with reason [%s]", (Object[])new Object[]{task.getReasonCancelled()}));
                }
                model.infer((Map<String, Object>)stringObjectMap, request.getUpdate(), (ActionListener<InferenceResults>)chainedTask);
            }));
            typedChainTaskExecutor.execute((ActionListener<List<InferenceResults>>)ActionListener.wrap(inferenceResultsInterfaces -> {
                model.release();
                listener.onResponse((Object)responseBuilder.addInferenceResults(inferenceResultsInterfaces).build());
            }, e -> {
                model.release();
                listener.onFailure(e);
            }));
        }, e -> {
            if (ExceptionsHelper.unwrapCause((Throwable)e) instanceof ResourceNotFoundException) {
                listener.onFailure(e);
                return;
            }
            this.trainedModelProvider.getTrainedModel(request.getId(), GetTrainedModelsAction.Includes.empty(), parentTaskId, (ActionListener<TrainedModelConfig>)ActionListener.wrap(trainedModelConfig -> {
                if (trainedModelConfig.getModelType() == TrainedModelType.PYTORCH) {
                    listener.onFailure((Exception)((Object)ExceptionsHelper.conflictStatusException((String)("Model [" + request.getId() + "] must be deployed to use. Please deploy with the start trained model deployment API."), (Object[])new Object[]{request.getId()})));
                } else {
                    listener.onFailure(e);
                }
            }, arg_0 -> ((ActionListener)listener).onFailure(arg_0)));
        });
        this.modelLoadingService.getModelForPipeline(request.getId(), parentTaskId, (ActionListener<LocalModel>)getModelListener);
    }

    private void inferAgainstAllocatedModel(List<TrainedModelAssignment> assignments, InferModelAction.Request request, InferModelAction.Response.Builder responseBuilder, TaskId parentTaskId, ActionListener<InferModelAction.Response> listener) {
        TrainedModelAssignment assignment = TransportInternalInferModelAction.pickAssignment(assignments);
        if (assignment.getAssignmentState() == AssignmentState.STOPPING || assignment.getAssignmentState() == AssignmentState.FAILED) {
            String message = "Trained model [" + assignment.getDeploymentId() + "] is [" + assignment.getAssignmentState() + "]";
            listener.onFailure((Exception)((Object)ExceptionsHelper.conflictStatusException((String)message, (Object[])new Object[0])));
            return;
        }
        List nodes = assignment.selectRandomStartedNodesWeighedOnAllocationsForNRequests(request.numberOfDocuments(), RoutingState.STARTED);
        if (nodes.isEmpty()) {
            nodes = assignment.selectRandomStartedNodesWeighedOnAllocationsForNRequests(request.numberOfDocuments(), RoutingState.STOPPING);
        }
        if (nodes.isEmpty()) {
            this.logger.trace(() -> Strings.format((String)"[%s] model deployment not allocated to any node", (Object[])new Object[]{assignment.getDeploymentId()}));
            listener.onFailure((Exception)((Object)ExceptionsHelper.conflictStatusException((String)("Trained model deployment [" + request.getId() + "] is not allocated to any nodes"), (Object[])new Object[0])));
            return;
        }
        assert (nodes.stream().mapToInt(Tuple::v2).sum() == request.numberOfDocuments()) : "mismatch; sum of node requests does not match number of documents in request";
        AtomicInteger count = new AtomicInteger();
        AtomicArray results = new AtomicArray(nodes.size());
        AtomicReference<Exception> failure = new AtomicReference<Exception>();
        int startPos = 0;
        int slot = 0;
        for (Tuple node : nodes) {
            InferTrainedModelDeploymentAction.Request deploymentRequest = request.getTextInput() == null ? InferTrainedModelDeploymentAction.Request.forDocs((String)assignment.getDeploymentId(), (InferenceConfigUpdate)request.getUpdate(), request.getObjectsToInfer().subList(startPos, startPos + (Integer)node.v2()), (TimeValue)request.getInferenceTimeout()) : InferTrainedModelDeploymentAction.Request.forTextInput((String)assignment.getDeploymentId(), (InferenceConfigUpdate)request.getUpdate(), request.getTextInput().subList(startPos, startPos + (Integer)node.v2()), (TimeValue)request.getInferenceTimeout());
            deploymentRequest.setHighPriority(request.isHighPriority());
            deploymentRequest.setPrefixType(request.getPrefixType());
            deploymentRequest.setNodes(new String[]{(String)node.v1()});
            deploymentRequest.setParentTask(parentTaskId);
            startPos += ((Integer)node.v2()).intValue();
            ClientHelper.executeAsyncWithOrigin((Client)this.client, (String)"ml", (ActionType)InferTrainedModelDeploymentAction.INSTANCE, (ActionRequest)deploymentRequest, TransportInternalInferModelAction.collectingListener(count, (AtomicArray<List<InferenceResults>>)results, failure, slot, nodes.size(), responseBuilder, listener));
            ++slot;
        }
    }

    static TrainedModelAssignment pickAssignment(List<TrainedModelAssignment> assignments) {
        assert (!assignments.isEmpty());
        if (assignments.size() == 1) {
            return assignments.get(0);
        }
        Map<AssignmentState, List<TrainedModelAssignment>> map = assignments.stream().collect(Collectors.groupingBy(TrainedModelAssignment::getAssignmentState));
        Random rng = Randomness.get();
        for (AssignmentState assignmentStat : new AssignmentState[]{AssignmentState.STARTED, AssignmentState.STARTING, AssignmentState.STOPPING, AssignmentState.FAILED}) {
            List<TrainedModelAssignment> bestPick = map.get(assignmentStat);
            if (bestPick == null) continue;
            Collections.shuffle(bestPick, rng);
            return bestPick.get(0);
        }
        throw new IllegalStateException();
    }

    private static ActionListener<InferTrainedModelDeploymentAction.Response> collectingListener(final AtomicInteger count, final AtomicArray<List<InferenceResults>> results, final AtomicReference<Exception> failure, final int slot, final int totalNumberOfResponses, final InferModelAction.Response.Builder responseBuilder, final ActionListener<InferModelAction.Response> finalListener) {
        return new ActionListener<InferTrainedModelDeploymentAction.Response>(){

            public void onResponse(InferTrainedModelDeploymentAction.Response response) {
                results.setOnce(slot, (Object)response.getResults());
                if (count.incrementAndGet() == totalNumberOfResponses) {
                    this.sendResponse();
                }
            }

            public void onFailure(Exception e) {
                failure.set(e);
                if (count.incrementAndGet() == totalNumberOfResponses) {
                    this.sendResponse();
                }
            }

            private void sendResponse() {
                if (failure.get() != null) {
                    finalListener.onFailure((Exception)failure.get());
                } else {
                    for (int i = 0; i < results.length(); ++i) {
                        List resultList = (List)results.get(i);
                        if (resultList == null) continue;
                        for (InferenceResults result : resultList) {
                            if (!(result instanceof ErrorInferenceResults)) continue;
                            ErrorInferenceResults errorResult = (ErrorInferenceResults)result;
                            finalListener.onFailure(errorResult.getException());
                            return;
                        }
                        responseBuilder.addInferenceResults(resultList);
                    }
                    finalListener.onResponse((Object)responseBuilder.build());
                }
            }
        };
    }
}

