/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.inference.action;

import java.io.IOException;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.Executor;
import java.util.concurrent.ExecutorService;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.support.ActionFilters;
import org.elasticsearch.action.support.SubscribableListener;
import org.elasticsearch.action.support.master.TransportMasterNodeAction;
import org.elasticsearch.client.internal.Client;
import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.cluster.block.ClusterBlockException;
import org.elasticsearch.cluster.block.ClusterBlockLevel;
import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.VersionId;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.inject.Inject;
import org.elasticsearch.common.settings.ClusterSettings;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.concurrent.EsExecutors;
import org.elasticsearch.common.xcontent.XContentHelper;
import org.elasticsearch.core.Strings;
import org.elasticsearch.inference.InferenceService;
import org.elasticsearch.inference.InferenceServiceRegistry;
import org.elasticsearch.inference.Model;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.TransportService;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xcontent.XContentParserConfiguration;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.core.inference.action.PutInferenceModelAction;
import org.elasticsearch.xpack.core.ml.MachineLearningField;
import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignmentUtils;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.core.ml.utils.MlPlatformArchitecturesUtil;
import org.elasticsearch.xpack.inference.InferencePlugin;
import org.elasticsearch.xpack.inference.registry.ModelRegistry;

/**
 * Master-node transport action registered as {@code cluster:admin/xpack/inference/put} that creates an
 * inference endpoint.
 * <p>
 * Processing steps:
 * <ol>
 *   <li>The request body is parsed into a map.</li>
 *   <li>The task type and the target {@link InferenceService} are resolved from the body and URL.</li>
 *   <li>The remaining settings are handed to the service's {@code parseRequestConfig}.</li>
 *   <li>The parsed model is checked by the service and stored in the {@link ModelRegistry}.</li>
 *   <li>The model is put (if the service reports it is not yet downloaded) and started.</li>
 * </ol>
 * While the dynamic {@link InferencePlugin#SKIP_VALIDATE_AND_START} setting is enabled, the service's
 * config check and the final start step are skipped.
 */
public class TransportPutInferenceModelAction
    extends TransportMasterNodeAction<PutInferenceModelAction.Request, PutInferenceModelAction.Response> {

    private static final Logger logger = LogManager.getLogger(TransportPutInferenceModelAction.class);

    private final ModelRegistry modelRegistry;
    private final InferenceServiceRegistry serviceRegistry;
    private final Client client;

    /** Mirrors {@link InferencePlugin#SKIP_VALIDATE_AND_START}; updated by a cluster-settings consumer, hence volatile. */
    private volatile boolean skipValidationAndStart;

    @Inject
    public TransportPutInferenceModelAction(
        TransportService transportService,
        ClusterService clusterService,
        ThreadPool threadPool,
        ActionFilters actionFilters,
        IndexNameExpressionResolver indexNameExpressionResolver,
        ModelRegistry modelRegistry,
        InferenceServiceRegistry serviceRegistry,
        Client client,
        Settings settings
    ) {
        super(
            "cluster:admin/xpack/inference/put",
            transportService,
            clusterService,
            threadPool,
            actionFilters,
            PutInferenceModelAction.Request::new,
            indexNameExpressionResolver,
            PutInferenceModelAction.Response::new,
            EsExecutors.DIRECT_EXECUTOR_SERVICE
        );
        this.modelRegistry = modelRegistry;
        this.serviceRegistry = serviceRegistry;
        this.client = client;
        this.skipValidationAndStart = InferencePlugin.SKIP_VALIDATE_AND_START.get(settings);
        clusterService.getClusterSettings()
            .addSettingsUpdateConsumer(InferencePlugin.SKIP_VALIDATE_AND_START, this::setSkipValidationAndStart);
    }

    /**
     * Validates the request and, if it is acceptable, parses, stores and starts the new endpoint.
     * <p>
     * The following problems are reported to {@code listener} as bad-request failures:
     * <ul>
     *   <li>a missing or unknown {@code service};</li>
     *   <li>a service whose minimal supported version is after the cluster's minimum transport version;</li>
     *   <li>an endpoint ID that already has trained model assignments.</li>
     * </ul>
     * Task type resolution errors and request body parsing errors are thrown instead.
     */
    @Override
    protected void masterOperation(
        Task task,
        PutInferenceModelAction.Request request,
        ClusterState state,
        ActionListener<PutInferenceModelAction.Response> listener
    ) throws Exception {
        Map<String, Object> requestAsMap = requestToMap(request);
        // The task type and service name are removed from the map so they are not passed on to the
        // service's parseRequestConfig together with the service-specific settings.
        TaskType resolvedTaskType = resolveTaskType(request.getTaskType(), (String) requestAsMap.remove(TaskType.NAME));
        String serviceName = (String) requestAsMap.remove("service");
        if (serviceName == null) {
            listener.onFailure(
                new ElasticsearchStatusException("Inference endpoint configuration is missing the [service] setting", RestStatus.BAD_REQUEST)
            );
            return;
        }

        Optional<InferenceService> maybeService = serviceRegistry.getService(serviceName);
        if (maybeService.isEmpty()) {
            listener.onFailure(new ElasticsearchStatusException("Unknown service [{}]", RestStatus.BAD_REQUEST, serviceName));
            return;
        }
        InferenceService service = maybeService.get();

        // Refuse services that some nodes in a mixed-version cluster may not know about yet.
        if (service.getMinimalSupportedVersion().after(state.getMinTransportVersion())) {
            logger.warn(
                "Service [{}] requires version [{}] but minimum cluster version is [{}]",
                serviceName,
                service.getMinimalSupportedVersion(),
                state.getMinTransportVersion()
            );
            listener.onFailure(
                new ElasticsearchStatusException(
                    Strings.format(
                        "All nodes in the cluster are not aware of the service [%s].Wait for the cluster to finish upgrading and try again.",
                        serviceName
                    ),
                    RestStatus.BAD_REQUEST
                )
            );
            return;
        }

        // The endpoint ID must not collide with an existing trained model assignment.
        List<?> assignments = TrainedModelAssignmentUtils.modelAssignments(request.getInferenceEntityId(), clusterService.state());
        if (assignments != null && assignments.isEmpty() == false) {
            listener.onFailure(
                ExceptionsHelper.badRequestException(
                    "Model IDs must be unique. Requested model ID [{}] matches existing model IDs but must not.",
                    request.getInferenceEntityId()
                )
            );
            return;
        }

        if (service.isInClusterService() == false) {
            parseAndStoreModel(service, request.getInferenceEntityId(), resolvedTaskType, requestAsMap, Set.of(), listener);
            return;
        }

        // In-cluster services need the platform architectures of the ML nodes, which are fetched asynchronously.
        MlPlatformArchitecturesUtil.getMlNodesArchitecturesSet(listener.delegateFailureAndWrap((delegate, architectures) -> {
            // No ML nodes reported an architecture: when clusterIsInElasticCloud(...) is true, fall back to
            // linux-x86_64 (presumably the architecture of autoscaled ML nodes there; TODO confirm).
            Set<String> targetArchitectures = architectures.isEmpty() && clusterIsInElasticCloud(clusterService.getClusterSettings())
                ? Set.of("linux-x86_64")
                : architectures;
            parseAndStoreModel(service, request.getInferenceEntityId(), resolvedTaskType, requestAsMap, targetArchitectures, delegate);
        }), client, threadPool.executor("inference_utility"));
    }

    /**
     * Parses {@code config} into a model with the given service, then stores and starts it.
     * <p>
     * Unless validation is skipped, the parsed model is first passed through
     * {@link InferenceService#checkModelConfig}, and the model it yields is the one stored.
     */
    private void parseAndStoreModel(
        InferenceService service,
        String inferenceEntityId,
        TaskType taskType,
        Map<String, Object> config,
        Set<String> platformArchitectures,
        ActionListener<PutInferenceModelAction.Response> listener
    ) {
        ActionListener<Model> storeModelListener = listener.delegateFailureAndWrap(
            (delegate, verifiedModel) -> modelRegistry.storeModel(
                verifiedModel,
                delegate.delegateFailureAndWrap((l, stored) -> putAndStartModel(service, verifiedModel, l))
            )
        );

        ActionListener<Model> parsedModelListener = listener.delegateFailureAndWrap((delegate, model) -> {
            if (skipValidationAndStart) {
                storeModelListener.onResponse(model);
            } else {
                service.checkModelConfig(model, storeModelListener);
            }
        });

        service.parseRequestConfig(inferenceEntityId, taskType, config, platformArchitectures, parsedModelListener);
    }

    /**
     * Makes the model available to the service and starts it, completing {@code finalListener} with the
     * endpoint's configurations.
     * <p>
     * Steps:
     * <ol>
     *   <li>Ask whether the model is already downloaded. A failure of that check is treated as "not downloaded".</li>
     *   <li>If it is not downloaded, ask the service to put it.</li>
     *   <li>Unless validation and start are skipped, start it.</li>
     * </ol>
     * If the service reports that the model was not put, {@code finalListener} is failed.
     */
    private void putAndStartModel(InferenceService service, Model model, ActionListener<PutInferenceModelAction.Response> finalListener) {
        SubscribableListener.<Boolean>newForked(listener -> {
            // Best effort: an error while checking the download state falls through to putModel.
            ActionListener<Boolean> errorCatchingListener = ActionListener.wrap(listener::onResponse, e -> listener.onResponse(false));
            service.isModelDownloaded(model, errorCatchingListener);
        }).<Boolean>andThen((listener, isDownloaded) -> {
            if (isDownloaded == false) {
                service.putModel(model, listener);
            } else {
                listener.onResponse(true);
            }
        }).<PutInferenceModelAction.Response>andThen((listener, modelDidPut) -> {
            if (modelDidPut == false) {
                logger.warn("Failed to put model [{}]", model.getInferenceEntityId());
                // The listener must always be completed; otherwise the request never receives a response.
                listener.onFailure(
                    new ElasticsearchStatusException(
                        "Failed to put model [{}]",
                        RestStatus.INTERNAL_SERVER_ERROR,
                        model.getInferenceEntityId()
                    )
                );
                return;
            }
            if (skipValidationAndStart) {
                listener.onResponse(new PutInferenceModelAction.Response(model.getConfigurations()));
            } else {
                service.start(
                    model,
                    listener.delegateFailureAndWrap((l, started) -> l.onResponse(new PutInferenceModelAction.Response(model.getConfigurations())))
                );
            }
        }).addListener(finalListener);
    }

    /**
     * Parses the request's raw content, using its declared content type, into a map.
     *
     * @throws IOException if the content cannot be parsed
     */
    private Map<String, Object> requestToMap(PutInferenceModelAction.Request request) throws IOException {
        try (
            XContentParser parser = XContentHelper.createParser(
                XContentParserConfiguration.EMPTY,
                request.getContent(),
                request.getContentType()
            )
        ) {
            return parser.map();
        }
    }

    /** Settings-update consumer for {@link InferencePlugin#SKIP_VALIDATE_AND_START}. */
    private void setSkipValidationAndStart(boolean skipValidationAndStart) {
        this.skipValidationAndStart = skipValidationAndStart;
    }

    /** Rejects the request while a global metadata-write block is in place. */
    @Override
    protected ClusterBlockException checkBlock(PutInferenceModelAction.Request request, ClusterState state) {
        return state.blocks().globalBlockedException(ClusterBlockLevel.METADATA_WRITE);
    }

    /**
     * Returns the value of {@link MachineLearningField#USE_AUTO_MACHINE_MEMORY_PERCENT}.
     * This class uses that value as the indicator that the cluster runs in Elastic Cloud.
     */
    static boolean clusterIsInElasticCloud(ClusterSettings settings) {
        return settings.get(MachineLearningField.USE_AUTO_MACHINE_MEMORY_PERCENT);
    }

    /**
     * Resolves the effective task type from the URL and the {@code task_type} field of the body.
     *
     * @param urlTaskType  task type from the request URL ({@link TaskType#ANY} if none was given)
     * @param bodyTaskType raw task type string from the body, or {@code null} if absent
     * @return the body's task type when present, otherwise the URL's
     * @throws ElasticsearchStatusException with {@code BAD_REQUEST} if any of these hold:
     *         <ul>
     *           <li>both sources are unspecified or {@code any};</li>
     *           <li>the body says {@code any};</li>
     *           <li>the body's type is not compatible with the URL's per {@link TaskType#isAnyOrSame}.</li>
     *         </ul>
     */
    static TaskType resolveTaskType(TaskType urlTaskType, String bodyTaskType) {
        if (bodyTaskType == null) {
            if (urlTaskType == TaskType.ANY) {
                throw new ElasticsearchStatusException("model is missing required setting [task_type]", RestStatus.BAD_REQUEST);
            }
            return urlTaskType;
        }

        TaskType parsedBodyTask = TaskType.fromStringOrStatusException(bodyTaskType);
        if (parsedBodyTask == TaskType.ANY) {
            throw new ElasticsearchStatusException("task_type [any] is not valid type for inference", RestStatus.BAD_REQUEST);
        }
        if (parsedBodyTask.isAnyOrSame(urlTaskType) == false) {
            throw new ElasticsearchStatusException(
                "Cannot resolve conflicting task_type parameter in the request URL [{}] and the request body [{}]",
                RestStatus.BAD_REQUEST,
                urlTaskType.toString(),
                bodyTaskType
            );
        }
        return parsedBodyTask;
    }
}

