/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.inference.action;

import java.util.Optional;
import java.util.Set;
import java.util.concurrent.Executor;
import java.util.concurrent.Flow;
import java.util.stream.Collectors;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.support.ActionFilters;
import org.elasticsearch.action.support.HandledTransportAction;
import org.elasticsearch.common.logging.DeprecationLogger;
import org.elasticsearch.common.util.concurrent.EsExecutors;
import org.elasticsearch.core.Strings;
import org.elasticsearch.inference.InferenceService;
import org.elasticsearch.inference.InferenceServiceRegistry;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.inference.Model;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.inference.UnparsedModel;
import org.elasticsearch.injection.guice.Inject;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.transport.TransportService;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
import org.elasticsearch.xpack.inference.action.task.StreamingTaskManager;
import org.elasticsearch.xpack.inference.registry.ModelRegistry;
import org.elasticsearch.xpack.inference.telemetry.InferenceStats;

/**
 * Transport action that runs an inference request against a stored inference endpoint.
 *
 * <p>The endpoint configuration (including secrets) is loaded from the {@link ModelRegistry}, the
 * {@link InferenceService} named by that configuration is resolved from the {@link InferenceServiceRegistry},
 * and the request is then forwarded to that service. Streaming requests are rejected with
 * {@link RestStatus#METHOD_NOT_ALLOWED} unless the service reports that it can stream the requested task type.
 */
public class TransportInferenceAction extends HandledTransportAction<InferenceAction.Request, InferenceAction.Response> {

    /** Task type used when registering the task that carries a streamed response. */
    private static final String STREAMING_INFERENCE_TASK_TYPE = "streaming_inference";

    /** Action name used when registering the task that carries a streamed response. */
    private static final String STREAMING_TASK_ACTION = "xpack/inference/streaming_inference[n]";

    // NOTE(review): not referenced anywhere in this class; remove together with its import if it stays unused.
    private static final DeprecationLogger DEPRECATION_LOGGER = DeprecationLogger.getLogger(TransportInferenceAction.class);

    private final ModelRegistry modelRegistry;
    private final InferenceServiceRegistry serviceRegistry;
    private final InferenceStats inferenceStats;
    private final StreamingTaskManager streamingTaskManager;

    @Inject
    public TransportInferenceAction(
        TransportService transportService,
        ActionFilters actionFilters,
        ModelRegistry modelRegistry,
        InferenceServiceRegistry serviceRegistry,
        InferenceStats inferenceStats,
        StreamingTaskManager streamingTaskManager
    ) {
        super(
            "cluster:monitor/xpack/inference",
            transportService,
            actionFilters,
            InferenceAction.Request::new,
            EsExecutors.DIRECT_EXECUTOR_SERVICE
        );
        this.modelRegistry = modelRegistry;
        this.serviceRegistry = serviceRegistry;
        this.inferenceStats = inferenceStats;
        this.streamingTaskManager = streamingTaskManager;
    }

    /**
     * Loads the endpoint named by the request, validates it against the request, and runs inference on it.
     *
     * <p>Fails the listener with {@code BAD_REQUEST} when the endpoint's service is not registered or when
     * the requested task type is not compatible with the endpoint's task type.
     */
    @Override
    protected void doExecute(Task task, InferenceAction.Request request, ActionListener<InferenceAction.Response> listener) {
        // delegateFailureAndWrap forwards lookup failures (and exceptions thrown inside the callback, such as
        // from parsing the persisted config) to the caller's listener.
        ActionListener<UnparsedModel> getModelListener = listener.delegateFailureAndWrap((delegate, unparsedModel) -> {
            Optional<InferenceService> service = serviceRegistry.getService(unparsedModel.service());
            if (service.isEmpty()) {
                delegate.onFailure(unknownServiceException(unparsedModel.service(), request.getInferenceEntityId()));
                return;
            }
            if (!request.getTaskType().isAnyOrSame(unparsedModel.taskType())) {
                // The request's task type does not accept the task type the endpoint was stored with.
                delegate.onFailure(incompatibleTaskTypeException(request.getTaskType(), unparsedModel.taskType()));
                return;
            }
            InferenceService inferenceService = service.get();
            Model model = inferenceService.parsePersistedConfigWithSecrets(
                unparsedModel.inferenceEntityId(),
                unparsedModel.taskType(),
                unparsedModel.settings(),
                unparsedModel.secrets()
            );
            inferOnService(model, request, inferenceService, delegate);
        });
        modelRegistry.getModelWithSecrets(request.getInferenceEntityId(), getModelListener);
    }

    /**
     * Counts the request and forwards it to {@code service}, or fails {@code listener} when a streaming
     * request targets a service that cannot stream the requested task type.
     */
    private void inferOnService(
        Model model,
        InferenceAction.Request request,
        InferenceService service,
        ActionListener<InferenceAction.Response> listener
    ) {
        if (request.isStreaming() && !service.canStream(request.getTaskType())) {
            listener.onFailure(unsupportedStreamingTaskException(request, service));
            return;
        }
        // Only requests that are actually dispatched to the service are counted.
        inferenceStats.incrementRequestCount(model);
        service.infer(
            model,
            request.getQuery(),
            request.getInput(),
            request.isStreaming(),
            request.getTaskSettings(),
            request.getInputType(),
            request.getInferenceTimeout(),
            createListener(request, listener)
        );
    }

    /**
     * Builds the {@code METHOD_NOT_ALLOWED} error for a streaming request the service cannot serve, listing the
     * task types the service does support for streaming (if any).
     */
    private ElasticsearchStatusException unsupportedStreamingTaskException(InferenceAction.Request request, InferenceService service) {
        Set<TaskType> supportedTasks = service.supportedStreamingTasks();
        if (supportedTasks.isEmpty()) {
            return new ElasticsearchStatusException(
                Strings.format("Streaming is not allowed for service [%s].", service.name()),
                RestStatus.METHOD_NOT_ALLOWED
            );
        }
        String validTasks = supportedTasks.stream().map(TaskType::toString).collect(Collectors.joining(","));
        return new ElasticsearchStatusException(
            Strings.format(
                "Streaming is not allowed for service [%s] and task [%s]. Supported tasks: [%s]",
                service.name(),
                request.getTaskType(),
                validTasks
            ),
            RestStatus.METHOD_NOT_ALLOWED
        );
    }

    /**
     * Adapts the service's results into an {@link InferenceAction.Response}.
     *
     * <p>For streaming requests a task processor is created via the {@link StreamingTaskManager}, subscribed to
     * the results' publisher, and attached to the response; otherwise the results are wrapped directly.
     */
    private ActionListener<InferenceServiceResults> createListener(
        InferenceAction.Request request,
        ActionListener<InferenceAction.Response> listener
    ) {
        if (request.isStreaming()) {
            return listener.delegateFailureAndWrap((l, inferenceResults) -> {
                // NOTE(review): left as a raw type because the processor's element type is not visible here;
                // parameterize once the publisher/Response element type is confirmed.
                Flow.Processor taskProcessor = streamingTaskManager.create(STREAMING_INFERENCE_TASK_TYPE, STREAMING_TASK_ACTION);
                inferenceResults.publisher().subscribe(taskProcessor);
                l.onResponse(new InferenceAction.Response(inferenceResults, taskProcessor));
            });
        }
        return listener.delegateFailureAndWrap((l, inferenceResults) -> l.onResponse(new InferenceAction.Response(inferenceResults)));
    }

    /** {@code BAD_REQUEST} error for an endpoint whose service is not registered. */
    private static ElasticsearchStatusException unknownServiceException(String service, String inferenceId) {
        return new ElasticsearchStatusException("Unknown service [{}] for model [{}]. ", RestStatus.BAD_REQUEST, service, inferenceId);
    }

    /** {@code BAD_REQUEST} error for a request whose task type does not match the endpoint's task type. */
    private static ElasticsearchStatusException incompatibleTaskTypeException(TaskType requested, TaskType expected) {
        return new ElasticsearchStatusException(
            "Incompatible task_type, the requested type [{}] does not match the model type [{}]",
            RestStatus.BAD_REQUEST,
            requested,
            expected
        );
    }
}

