/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.ml.action;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.Executor;
import java.util.function.Supplier;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.ExceptionsHelper;
import org.elasticsearch.ResourceNotFoundException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.ActionRequest;
import org.elasticsearch.action.ActionType;
import org.elasticsearch.action.support.ActionFilters;
import org.elasticsearch.action.support.HandledTransportAction;
import org.elasticsearch.client.internal.Client;
import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.util.concurrent.EsExecutors;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.inference.InputType;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.injection.guice.Inject;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.transport.TransportService;
import org.elasticsearch.xpack.core.ClientHelper;
import org.elasticsearch.xpack.core.inference.action.GetInferenceModelAction;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
import org.elasticsearch.xpack.core.ml.action.CoordinatedInferenceAction;
import org.elasticsearch.xpack.core.ml.action.InferModelAction;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelPrefixStrings;
import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignmentUtils;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.EmptyConfigUpdate;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfigUpdate;

/**
 * Transport action that serves a {@link CoordinatedInferenceAction.Request} by delegating either to an
 * ML trained model (via {@link InferModelAction}, executed with the ML origin) or to an inference-service
 * endpoint (via {@link InferenceAction}, executed with the inference origin). Both paths answer with an
 * {@link InferModelAction.Response}.
 */
public class TransportCoordinatedInferenceAction
    extends HandledTransportAction<CoordinatedInferenceAction.Request, InferModelAction.Response> {

    /** Translates a trained-model prefix type into the input type sent to the inference service. */
    private static final Map<TrainedModelPrefixStrings.PrefixType, InputType> PREFIX_TYPE_INPUT_TYPE_MAP = Map.of(
        TrainedModelPrefixStrings.PrefixType.INGEST,
        InputType.INTERNAL_INGEST,
        TrainedModelPrefixStrings.PrefixType.SEARCH,
        InputType.INTERNAL_SEARCH
    );

    private final Client client;
    private final ClusterService clusterService;

    @Inject
    public TransportCoordinatedInferenceAction(
        TransportService transportService,
        ActionFilters actionFilters,
        Client client,
        ClusterService clusterService
    ) {
        super(
            "cluster:internal/xpack/ml/coordinatedinference",
            transportService,
            actionFilters,
            CoordinatedInferenceAction.Request::new,
            EsExecutors.DIRECT_EXECUTOR_SERVICE
        );
        this.client = client;
        this.clusterService = clusterService;
    }

    /**
     * Routes the request.
     *
     * <p>A request that carries documents and is not explicitly typed as an NLP model is sent to the
     * in-cluster model path. Every other request goes through {@link #forNlp}.
     */
    @Override
    protected void doExecute(Task task, CoordinatedInferenceAction.Request request, ActionListener<InferModelAction.Response> listener) {
        boolean isNlpModel = request.getRequestModelType() == CoordinatedInferenceAction.Request.RequestModelType.NLP_MODEL;
        if (!isNlpModel && request.hasObjects()) {
            // Document input is only accepted by in-cluster models. If the model is missing, check whether
            // the id belongs to an inference-service endpoint so a more helpful error can be returned.
            doInClusterModel(request, wrapCheckForServiceModelOnMissing(request.getModelId(), listener));
        } else {
            forNlp(request, listener);
        }
    }

    /**
     * Chooses a path based on the cluster state.
     *
     * <p>If the cluster state has model assignments for the model id, the request runs against the
     * in-cluster model. Otherwise it is sent to the inference service. A {@link ResourceNotFoundException}
     * from the inference service is replaced by a 404 saying the id is neither kind of model.
     */
    private void forNlp(CoordinatedInferenceAction.Request request, ActionListener<InferModelAction.Response> listener) {
        ClusterState clusterState = clusterService.state();
        List<?> assignments = TrainedModelAssignmentUtils.modelAssignments(request.getModelId(), clusterState);
        if (assignments != null && !assignments.isEmpty()) {
            doInClusterModel(request, listener);
            return;
        }
        doInferenceServiceModel(
            request,
            ActionListener.wrap(
                listener::onResponse,
                e -> replaceErrorOnMissing(
                    e,
                    () -> new ElasticsearchStatusException(
                        "[" + request.getModelId() + "] is not an inference service model or a deployed ml model",
                        RestStatus.NOT_FOUND
                    ),
                    listener
                )
            )
        );
    }

    /**
     * Sends the request's text inputs to the inference service and translates the result.
     *
     * <p>The request uses {@link TaskType#ANY}. Its input type is derived from the request's prefix type.
     * The result is translated into an {@link InferModelAction.Response}.
     */
    private void doInferenceServiceModel(CoordinatedInferenceAction.Request request, ActionListener<InferModelAction.Response> listener) {
        InputType inputType = convertPrefixToInputType(request.getPrefixType());
        InferenceAction.Request inferenceRequest = new InferenceAction.Request(
            TaskType.ANY,
            request.getModelId(),
            null,
            null,
            null,
            request.getInputs(),
            request.getTaskSettings(),
            inputType,
            request.getInferenceTimeout(),
            false
        );
        ClientHelper.executeAsyncWithOrigin(
            client,
            ClientHelper.INFERENCE_ORIGIN,
            InferenceAction.INSTANCE,
            inferenceRequest,
            listener.delegateFailureAndWrap((l, r) -> l.onResponse(translateInferenceServiceResponse(r.getResults())))
        );
    }

    /**
     * Converts a trained-model prefix type into the inference-service input type.
     *
     * @param prefixType the prefix type from the request; may be {@code null}
     * @return the mapped input type, or {@link InputType#INTERNAL_INGEST} for unmapped or {@code null} values
     */
    static InputType convertPrefixToInputType(TrainedModelPrefixStrings.PrefixType prefixType) {
        if (prefixType == null) {
            // Maps built with Map.of throw NullPointerException on get(null), so fall back to the default first.
            return InputType.INTERNAL_INGEST;
        }
        InputType inputType = PREFIX_TYPE_INPUT_TYPE_MAP.get(prefixType);
        return inputType == null ? InputType.INTERNAL_INGEST : inputType;
    }

    /** Translates the request and runs it as an {@link InferModelAction} under the ML origin. */
    private void doInClusterModel(CoordinatedInferenceAction.Request request, ActionListener<InferModelAction.Response> listener) {
        InferModelAction.Request inferModelRequest = translateRequest(request);
        ClientHelper.executeAsyncWithOrigin(client, ClientHelper.ML_ORIGIN, InferModelAction.INSTANCE, inferModelRequest, listener);
    }

    /**
     * Builds the {@link InferModelAction.Request} equivalent of a coordinated request.
     *
     * <p>Document input ({@code hasObjects()}) maps to {@code forIngestDocs}; otherwise the text inputs map
     * to {@code forTextInput}. A missing config update is replaced by {@link EmptyConfigUpdate#INSTANCE}.
     * The prefix type and high-priority flag are copied over.
     */
    static InferModelAction.Request translateRequest(CoordinatedInferenceAction.Request request) {
        // Declared as the interface type: the non-null branch is an arbitrary InferenceConfigUpdate,
        // not necessarily an EmptyConfigUpdate.
        InferenceConfigUpdate inferenceConfigUpdate = request.getInferenceConfigUpdate() == null
            ? EmptyConfigUpdate.INSTANCE
            : request.getInferenceConfigUpdate();

        InferModelAction.Request inferModelRequest = request.hasObjects()
            ? InferModelAction.Request.forIngestDocs(
                request.getModelId(),
                request.getObjectsToInfer(),
                inferenceConfigUpdate,
                request.getPreviouslyLicensed(),
                request.getInferenceTimeout()
            )
            : InferModelAction.Request.forTextInput(
                request.getModelId(),
                inferenceConfigUpdate,
                request.getInputs(),
                request.getPreviouslyLicensed(),
                request.getInferenceTimeout()
            );
        inferModelRequest.setPrefixType(request.getPrefixType());
        inferModelRequest.setHighPriority(request.getHighPriority());
        return inferModelRequest;
    }

    /**
     * Wraps {@code listener} so that a missing in-cluster model is checked against the inference service.
     *
     * <p>When the failure unwraps to a {@link ResourceNotFoundException}, the inference-service model with
     * the same id is looked up:
     * <ul>
     *   <li>If it exists, the failure is replaced by a 400 explaining that the endpoint does not accept
     *       documents.</li>
     *   <li>If the lookup fails, the original error is reported.</li>
     * </ul>
     * All other failures pass through unchanged.
     */
    private ActionListener<InferModelAction.Response> wrapCheckForServiceModelOnMissing(
        String modelId,
        ActionListener<InferModelAction.Response> listener
    ) {
        return ActionListener.wrap(listener::onResponse, originalError -> {
            if (ExceptionsHelper.unwrapCause(originalError) instanceof ResourceNotFoundException) {
                ClientHelper.executeAsyncWithOrigin(
                    client,
                    ClientHelper.INFERENCE_ORIGIN,
                    GetInferenceModelAction.INSTANCE,
                    new GetInferenceModelAction.Request(modelId, TaskType.ANY),
                    ActionListener.wrap(
                        ignoredModel -> listener.onFailure(
                            new ElasticsearchStatusException(
                                "["
                                    + modelId
                                    + "] is configured for the _inference API and does not accept documents as input. "
                                    + "If using an inference ingest processor configure it with the [input_output] option instead of "
                                    + "[field_map].",
                                RestStatus.BAD_REQUEST
                            )
                        ),
                        lookupError -> listener.onFailure(originalError)
                    )
                );
            } else {
                listener.onFailure(originalError);
            }
        });
    }

    /**
     * Reports a failure, replacing a "resource not found" error with a caller-supplied exception.
     *
     * @param originalError    the failure to report
     * @param replaceOnMissing supplies the replacement when the cause unwraps to {@link ResourceNotFoundException}
     * @param listener         receives the (possibly replaced) failure
     */
    private void replaceErrorOnMissing(
        Exception originalError,
        Supplier<ElasticsearchStatusException> replaceOnMissing,
        ActionListener<InferModelAction.Response> listener
    ) {
        if (ExceptionsHelper.unwrapCause(originalError) instanceof ResourceNotFoundException) {
            listener.onFailure(replaceOnMissing.get());
        } else {
            listener.onFailure(originalError);
        }
    }

    /**
     * Converts inference-service results into an {@link InferModelAction.Response}.
     *
     * <p>The results come from {@code transformToCoordinationFormat()}. The response is built with a
     * {@code null} id and {@code false} as its last constructor argument.
     */
    static InferModelAction.Response translateInferenceServiceResponse(InferenceServiceResults inferenceResults) {
        // Copied into a new ArrayList whose element type is inferred from the Response constructor.
        // NOTE(review): the copy presumably adapts a wildcard-typed list; confirm against the API.
        return new InferModelAction.Response(new ArrayList<>(inferenceResults.transformToCoordinationFormat()), null, false);
    }
}

