/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.ml.action;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.Executor;
import java.util.concurrent.atomic.AtomicInteger;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.FailedNodeException;
import org.elasticsearch.action.TaskOperationFailure;
import org.elasticsearch.action.support.ActionFilters;
import org.elasticsearch.action.support.tasks.TransportTasksAction;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.inject.Inject;
import org.elasticsearch.common.util.concurrent.AtomicArray;
import org.elasticsearch.common.util.concurrent.EsExecutors;
import org.elasticsearch.inference.InferenceResults;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.tasks.CancellableTask;
import org.elasticsearch.transport.TransportService;
import org.elasticsearch.xpack.core.ml.action.InferTrainedModelDeploymentAction;
import org.elasticsearch.xpack.core.ml.inference.results.ErrorInferenceResults;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.ml.inference.deployment.NlpInferenceInput;
import org.elasticsearch.xpack.ml.inference.deployment.TrainedModelDeploymentTask;

public class TransportInferTrainedModelDeploymentAction
extends TransportTasksAction<TrainedModelDeploymentTask, InferTrainedModelDeploymentAction.Request, InferTrainedModelDeploymentAction.Response, InferTrainedModelDeploymentAction.Response> {
    @Inject
    public TransportInferTrainedModelDeploymentAction(ClusterService clusterService, TransportService transportService, ActionFilters actionFilters) {
        super("cluster:monitor/xpack/ml/trained_models/deployment/infer", clusterService, transportService, actionFilters, InferTrainedModelDeploymentAction.Request::new, InferTrainedModelDeploymentAction.Response::new, InferTrainedModelDeploymentAction.Response::new, (Executor)EsExecutors.DIRECT_EXECUTOR_SERVICE);
    }

    protected InferTrainedModelDeploymentAction.Response newResponse(InferTrainedModelDeploymentAction.Request request, List<InferTrainedModelDeploymentAction.Response> tasks, List<TaskOperationFailure> taskOperationFailures, List<FailedNodeException> failedNodeExceptions) {
        if (!taskOperationFailures.isEmpty()) {
            throw ExceptionsHelper.taskOperationFailureToStatusException((TaskOperationFailure)taskOperationFailures.get(0));
        }
        if (!failedNodeExceptions.isEmpty()) {
            throw failedNodeExceptions.get(0);
        }
        if (tasks.isEmpty()) {
            throw new ElasticsearchStatusException("Unable to find model deployment task [{}] please stop and start the deployment or try again momentarily", RestStatus.NOT_FOUND, new Object[]{request.getId()});
        }
        assert (tasks.size() == 1);
        return tasks.get(0);
    }

    protected void taskOperation(CancellableTask actionTask, InferTrainedModelDeploymentAction.Request request, TrainedModelDeploymentTask task, ActionListener<InferTrainedModelDeploymentAction.Response> listener) {
        ArrayList<NlpInferenceInput> nlpInputs = new ArrayList<NlpInferenceInput>();
        if (request.getTextInput() != null) {
            for (String text : request.getTextInput()) {
                nlpInputs.add(NlpInferenceInput.fromText(text));
            }
        } else {
            for (Map doc : request.getDocs()) {
                nlpInputs.add(NlpInferenceInput.fromDoc(doc));
            }
        }
        AtomicInteger count = new AtomicInteger();
        AtomicArray results = new AtomicArray(nlpInputs.size());
        int slot = 0;
        for (NlpInferenceInput input : nlpInputs) {
            task.infer(input, request.getUpdate(), request.isHighPriority(), request.getInferenceTimeout(), actionTask, TransportInferTrainedModelDeploymentAction.orderedListener(count, (AtomicArray<InferenceResults>)results, slot++, nlpInputs.size(), listener));
        }
    }

    static ActionListener<InferenceResults> orderedListener(final AtomicInteger count, final AtomicArray<InferenceResults> results, final int slot, final int totalNumberOfResponses, final ActionListener<InferTrainedModelDeploymentAction.Response> finalListener) {
        return new ActionListener<InferenceResults>(){

            public void onResponse(InferenceResults response) {
                results.setOnce(slot, (Object)response);
                if (count.incrementAndGet() == totalNumberOfResponses) {
                    this.sendResponse();
                }
            }

            public void onFailure(Exception e) {
                results.setOnce(slot, (Object)new ErrorInferenceResults(e));
                if (count.incrementAndGet() == totalNumberOfResponses) {
                    this.sendResponse();
                }
            }

            private void sendResponse() {
                finalListener.onResponse((Object)new InferTrainedModelDeploymentAction.Response(results.asList()));
            }
        };
    }
}

