/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.ml.action;

import java.util.Collections;
import java.util.List;
import java.util.Map;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.ActionRequest;
import org.elasticsearch.action.ActionType;
import org.elasticsearch.action.support.ActionFilters;
import org.elasticsearch.action.support.HandledTransportAction;
import org.elasticsearch.client.Client;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.inject.Inject;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.license.LicenseUtils;
import org.elasticsearch.license.XPackLicenseState;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.tasks.TaskId;
import org.elasticsearch.transport.TransportService;
import org.elasticsearch.xpack.core.ClientHelper;
import org.elasticsearch.xpack.core.ml.MachineLearningField;
import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction;
import org.elasticsearch.xpack.core.ml.action.InferTrainedModelDeploymentAction;
import org.elasticsearch.xpack.core.ml.action.InternalInferModelAction;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfigUpdate;
import org.elasticsearch.xpack.ml.inference.allocation.TrainedModelAllocationMetadata;
import org.elasticsearch.xpack.ml.inference.loadingservice.LocalModel;
import org.elasticsearch.xpack.ml.inference.loadingservice.ModelLoadingService;
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
import org.elasticsearch.xpack.ml.utils.TypedChainTaskExecutor;

public class TransportInternalInferModelAction
extends HandledTransportAction<InternalInferModelAction.Request, InternalInferModelAction.Response> {
    private final ModelLoadingService modelLoadingService;
    private final Client client;
    private final ClusterService clusterService;
    private final XPackLicenseState licenseState;
    private final TrainedModelProvider trainedModelProvider;

    @Inject
    public TransportInternalInferModelAction(TransportService transportService, ActionFilters actionFilters, ModelLoadingService modelLoadingService, Client client, ClusterService clusterService, XPackLicenseState licenseState, TrainedModelProvider trainedModelProvider) {
        super("cluster:internal/xpack/ml/inference/infer", transportService, actionFilters, InternalInferModelAction.Request::new);
        this.modelLoadingService = modelLoadingService;
        this.client = client;
        this.clusterService = clusterService;
        this.licenseState = licenseState;
        this.trainedModelProvider = trainedModelProvider;
    }

    protected void doExecute(Task task, InternalInferModelAction.Request request, ActionListener<InternalInferModelAction.Response> listener) {
        InternalInferModelAction.Response.Builder responseBuilder = InternalInferModelAction.Response.builder();
        if (MachineLearningField.ML_API_FEATURE.check(this.licenseState)) {
            responseBuilder.setLicensed(true);
            this.doInfer(task, request, responseBuilder, listener);
        } else {
            this.trainedModelProvider.getTrainedModel(request.getModelId(), GetTrainedModelsAction.Includes.empty(), (ActionListener<TrainedModelConfig>)ActionListener.wrap(trainedModelConfig -> {
                responseBuilder.setLicensed(this.licenseState.isAllowedByLicense(trainedModelConfig.getLicenseLevel()));
                if (this.licenseState.isAllowedByLicense(trainedModelConfig.getLicenseLevel()) || request.isPreviouslyLicensed()) {
                    this.doInfer(task, request, responseBuilder, listener);
                } else {
                    listener.onFailure((Exception)LicenseUtils.newComplianceException((String)"ml"));
                }
            }, arg_0 -> listener.onFailure(arg_0)));
        }
    }

    private void doInfer(Task task, InternalInferModelAction.Request request, InternalInferModelAction.Response.Builder responseBuilder, ActionListener<InternalInferModelAction.Response> listener) {
        if (this.isAllocatedModel(request.getModelId())) {
            this.inferAgainstAllocatedModel(task, request, responseBuilder, listener);
        } else {
            this.getModelAndInfer(request, responseBuilder, listener);
        }
    }

    private boolean isAllocatedModel(String modelId) {
        TrainedModelAllocationMetadata trainedModelAllocationMetadata = TrainedModelAllocationMetadata.fromState(this.clusterService.state());
        return trainedModelAllocationMetadata.isAllocated(modelId);
    }

    private void getModelAndInfer(InternalInferModelAction.Request request, InternalInferModelAction.Response.Builder responseBuilder, ActionListener<InternalInferModelAction.Response> listener) {
        ActionListener getModelListener = ActionListener.wrap(model -> {
            TypedChainTaskExecutor<InferenceResults> typedChainTaskExecutor = new TypedChainTaskExecutor<InferenceResults>(this.client.threadPool().executor("same"), r -> true, ex -> true);
            request.getObjectsToInfer().forEach(stringObjectMap -> typedChainTaskExecutor.add(chainedTask -> model.infer((Map<String, Object>)stringObjectMap, request.getUpdate(), (ActionListener<InferenceResults>)chainedTask)));
            typedChainTaskExecutor.execute((ActionListener<List<InferenceResults>>)ActionListener.wrap(inferenceResultsInterfaces -> {
                model.release();
                listener.onResponse((Object)responseBuilder.setInferenceResults(inferenceResultsInterfaces).setModelId(model.getModelId()).build());
            }, e -> {
                model.release();
                listener.onFailure(e);
            }));
        }, arg_0 -> listener.onFailure(arg_0));
        this.modelLoadingService.getModelForPipeline(request.getModelId(), (ActionListener<LocalModel>)getModelListener);
    }

    private void inferAgainstAllocatedModel(Task task, InternalInferModelAction.Request request, InternalInferModelAction.Response.Builder responseBuilder, ActionListener<InternalInferModelAction.Response> listener) {
        TypedChainTaskExecutor<InferenceResults> typedChainTaskExecutor = new TypedChainTaskExecutor<InferenceResults>(this.client.threadPool().executor("same"), r -> true, ex -> true);
        request.getObjectsToInfer().forEach(stringObjectMap -> typedChainTaskExecutor.add(chainedTask -> this.inferSingleDocAgainstAllocatedModel(task, request.getModelId(), request.getUpdate(), (Map<String, Object>)stringObjectMap, (ActionListener<InferenceResults>)chainedTask)));
        typedChainTaskExecutor.execute((ActionListener<List<InferenceResults>>)ActionListener.wrap(inferenceResults -> listener.onResponse((Object)responseBuilder.setInferenceResults(inferenceResults).setModelId(request.getModelId()).build()), arg_0 -> listener.onFailure(arg_0)));
    }

    private void inferSingleDocAgainstAllocatedModel(Task task, String modelId, InferenceConfigUpdate inferenceConfigUpdate, Map<String, Object> doc, ActionListener<InferenceResults> listener) {
        TaskId taskId = new TaskId(this.clusterService.localNode().getId(), task.getId());
        InferTrainedModelDeploymentAction.Request request = new InferTrainedModelDeploymentAction.Request(modelId, inferenceConfigUpdate, Collections.singletonList(doc), TimeValue.MAX_VALUE);
        request.setParentTask(taskId);
        ClientHelper.executeAsyncWithOrigin((Client)this.client, (String)"ml", (ActionType)InferTrainedModelDeploymentAction.INSTANCE, (ActionRequest)request, (ActionListener)ActionListener.wrap(r -> listener.onResponse((Object)r.getResults()), arg_0 -> listener.onFailure(arg_0)));
    }
}

