/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.ml.action;

import java.io.IOException;
import java.time.Instant;
import java.util.List;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.Version;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.ActionRequest;
import org.elasticsearch.action.search.SearchRequest;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.action.support.ActionFilters;
import org.elasticsearch.action.support.master.MasterNodeRequest;
import org.elasticsearch.action.support.master.TransportMasterNodeAction;
import org.elasticsearch.client.Client;
import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.cluster.block.ClusterBlockException;
import org.elasticsearch.cluster.block.ClusterBlockLevel;
import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.inject.Inject;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.index.query.ConstantScoreQueryBuilder;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.license.License;
import org.elasticsearch.license.LicenseUtils;
import org.elasticsearch.license.XPackLicenseState;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.TransportService;
import org.elasticsearch.xpack.core.ClientHelper;
import org.elasticsearch.xpack.core.ml.action.PutTrainedModelAction;
import org.elasticsearch.xpack.core.ml.inference.ModelAliasMetadata;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;

/**
 * Master-node transport action that stores a new trained model configuration.
 *
 * <p>Before anything is persisted, the request's model definition is parsed and validated. It is
 * then checked against the cluster's minimum node version and against existing model aliases,
 * model tags and model ids. The request is rejected with a bad-request error if any naming
 * conflict is found.
 */
public class TransportPutTrainedModelAction
    extends TransportMasterNodeAction<PutTrainedModelAction.Request, PutTrainedModelAction.Response> {

    /** Index pattern searched when checking model ids and tags for conflicts. */
    private static final String INFERENCE_INDEX_PATTERN = ".ml-inference-*";

    private final TrainedModelProvider trainedModelProvider;
    private final XPackLicenseState licenseState;
    private final NamedXContentRegistry xContentRegistry;
    private final Client client;

    @Inject
    public TransportPutTrainedModelAction(TransportService transportService,
                                          ClusterService clusterService,
                                          ThreadPool threadPool,
                                          XPackLicenseState licenseState,
                                          ActionFilters actionFilters,
                                          IndexNameExpressionResolver indexNameExpressionResolver,
                                          Client client,
                                          TrainedModelProvider trainedModelProvider,
                                          NamedXContentRegistry xContentRegistry) {
        super(PutTrainedModelAction.NAME, transportService, clusterService, threadPool, actionFilters,
            PutTrainedModelAction.Request::new, indexNameExpressionResolver, PutTrainedModelAction.Response::new,
            ThreadPool.Names.SAME);
        this.licenseState = licenseState;
        this.trainedModelProvider = trainedModelProvider;
        this.xContentRegistry = xContentRegistry;
        this.client = client;
    }

    /**
     * Validates the requested model and, if every check passes, stores it.
     *
     * <p>The listener receives the stored config with its definition cleared. On any validation
     * failure or naming conflict it receives a bad-request exception instead.
     */
    @Override
    protected void masterOperation(PutTrainedModelAction.Request request,
                                   ClusterState state,
                                   ActionListener<PutTrainedModelAction.Response> listener) {
        TrainedModelConfig config = request.getTrainedModelConfig();
        Version minNodeVersion = state.nodes().getMinNodeVersion();

        if (minNodeVersion.before(Version.V_7_8_0)) {
            // The message has one placeholder: the required version (previously the model id was
            // passed first, so the model id was printed where the version belonged).
            listener.onFailure(ExceptionsHelper.badRequestException(
                "Creating a new model requires that all nodes are at least version [{}]",
                Version.V_7_8_0.toString()));
            return;
        }

        try {
            config.ensureParsedDefinition(xContentRegistry);
            config.getModelDefinition().getTrainedModel().validate();
        } catch (IOException ex) {
            listener.onFailure(ExceptionsHelper.badRequestException(
                "Failed to parse definition for [{}]", ex, config.getModelId()));
            return;
        } catch (ElasticsearchException ex) {
            listener.onFailure(ExceptionsHelper.badRequestException(
                "Definition for [{}] has validation failures.", ex, config.getModelId()));
            return;
        }

        if (config.getInferenceConfig().isTargetTypeSupported(
                config.getModelDefinition().getTrainedModel().targetType()) == false) {
            listener.onFailure(ExceptionsHelper.badRequestException(
                "Model [{}] inference config type [{}] does not support definition target type [{}]",
                config.getModelId(),
                config.getInferenceConfig().getName(),
                config.getModelDefinition().getTrainedModel().targetType()));
            return;
        }

        Version minCompatibilityVersion = config.getModelDefinition().getTrainedModel().getMinimalCompatibilityVersion();
        if (minNodeVersion.before(minCompatibilityVersion)) {
            listener.onFailure(ExceptionsHelper.badRequestException(
                "Definition for [{}] requires that all nodes are at least version [{}]",
                config.getModelId(),
                minCompatibilityVersion.toString()));
            return;
        }

        TrainedModelConfig trainedModelConfig = new TrainedModelConfig.Builder(config)
            .setVersion(Version.CURRENT)
            .setCreateTime(Instant.now())
            .setCreatedBy("api_user")
            .setLicenseLevel(License.OperationMode.PLATINUM.description())
            .setEstimatedHeapMemory(config.getModelDefinition().ramBytesUsed())
            .setEstimatedOperations(config.getModelDefinition().getTrainedModel().estimatedNumOperations())
            .build();

        if (ModelAliasMetadata.fromState(state).getModelId(trainedModelConfig.getModelId()) != null) {
            listener.onFailure(ExceptionsHelper.badRequestException(
                "requested model_id [{}] is the same as an existing model_alias. Model model_aliases and ids must be unique",
                config.getModelId()));
            return;
        }

        // Chain, in order: model id vs. existing tags -> tags vs. existing model ids -> store.
        ActionListener<Void> tagsModelIdCheckListener = ActionListener.wrap(
            r -> trainedModelProvider.storeTrainedModel(trainedModelConfig, ActionListener.<Boolean>wrap(
                stored -> {
                    // The response echoes the config without the (potentially large) definition.
                    TrainedModelConfig configToReturn = new TrainedModelConfig.Builder(trainedModelConfig)
                        .clearDefinition()
                        .build();
                    listener.onResponse(new PutTrainedModelAction.Response(configToReturn));
                },
                listener::onFailure)),
            listener::onFailure);

        ActionListener<Void> modelIdTagCheckListener = ActionListener.wrap(
            r -> checkTagsAgainstModelIds(config.getTags(), tagsModelIdCheckListener),
            listener::onFailure);

        checkModelIdAgainstTags(config.getModelId(), modelIdTagCheckListener);
    }

    /**
     * Fails {@code listener} if {@code modelId} is already used as a tag. This is checked with a
     * term query on the tags field over {@link #INFERENCE_INDEX_PATTERN}. Otherwise the listener
     * is completed with {@code null}.
     */
    private void checkModelIdAgainstTags(String modelId, ActionListener<Void> listener) {
        failIfAnyInferenceDocMatches(
            QueryBuilders.termQuery(TrainedModelConfig.TAGS.getPreferredName(), modelId),
            Messages.getMessage("The provided model_id {0} must not match existing tags.", modelId),
            listener);
    }

    /**
     * Fails {@code listener} if any of {@code tags} is already used as a model id. This is checked
     * with a terms query on the model id field over {@link #INFERENCE_INDEX_PATTERN}. Otherwise the
     * listener is completed with {@code null}. An empty tag list completes immediately without
     * searching.
     */
    private void checkTagsAgainstModelIds(List<String> tags, ActionListener<Void> listener) {
        if (tags.isEmpty()) {
            listener.onResponse(null);
            return;
        }
        failIfAnyInferenceDocMatches(
            QueryBuilders.termsQuery(TrainedModelConfig.MODEL_ID.getPreferredName(), tags),
            Messages.getMessage("The provided tags {0} must not match existing model_ids.", tags),
            listener);
    }

    /**
     * Searches {@link #INFERENCE_INDEX_PATTERN} (with the ML origin) for documents matching
     * {@code filter}.
     *
     * @param filter          filter clause wrapped in a constant-score bool query
     * @param conflictMessage message for the bad-request exception raised when a match exists
     * @param listener        completed with {@code null} if nothing matches; failed otherwise or
     *                        if the search itself fails
     */
    private void failIfAnyInferenceDocMatches(QueryBuilder filter, String conflictMessage, ActionListener<Void> listener) {
        QueryBuilder query = QueryBuilders.constantScoreQuery(QueryBuilders.boolQuery().filter(filter));
        // Only existence matters: fetch no hits and stop counting after the first match.
        SearchSourceBuilder sourceBuilder = new SearchSourceBuilder().query(query).size(0).trackTotalHitsUpTo(1);
        SearchRequest searchRequest = new SearchRequest(INFERENCE_INDEX_PATTERN).source(sourceBuilder);

        ClientHelper.executeAsyncWithOrigin(
            client.threadPool().getThreadContext(),
            ClientHelper.ML_ORIGIN,
            searchRequest,
            ActionListener.<SearchResponse>wrap(response -> {
                if (response.getHits().getTotalHits().value > 0L) {
                    listener.onFailure(ExceptionsHelper.badRequestException(conflictMessage));
                    return;
                }
                listener.onResponse(null);
            }, listener::onFailure),
            client::search);
    }

    /** Rejects the request while a global metadata-write block is in place. */
    @Override
    protected ClusterBlockException checkBlock(PutTrainedModelAction.Request request, ClusterState state) {
        return state.blocks().globalBlockedException(ClusterBlockLevel.METADATA_WRITE);
    }

    /**
     * Forwards the request to the master node only when the machine learning feature is
     * licensed. Otherwise the listener is failed with a license compliance exception.
     */
    @Override
    protected void doExecute(Task task,
                             PutTrainedModelAction.Request request,
                             ActionListener<PutTrainedModelAction.Response> listener) {
        if (licenseState.checkFeature(XPackLicenseState.Feature.MACHINE_LEARNING)) {
            super.doExecute(task, request, listener);
        } else {
            listener.onFailure(LicenseUtils.newComplianceException("ml"));
        }
    }
}

