/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.ml.action;

import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.concurrent.Executor;
import java.util.function.Predicate;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.support.ActionFilters;
import org.elasticsearch.action.support.master.AcknowledgedResponse;
import org.elasticsearch.action.support.master.AcknowledgedTransportMasterNodeAction;
import org.elasticsearch.cluster.AckedClusterStateUpdateTask;
import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.cluster.ClusterStateUpdateTask;
import org.elasticsearch.cluster.ack.AckedRequest;
import org.elasticsearch.cluster.block.ClusterBlockException;
import org.elasticsearch.cluster.block.ClusterBlockLevel;
import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver;
import org.elasticsearch.cluster.metadata.Metadata;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.inject.Inject;
import org.elasticsearch.common.logging.HeaderWarning;
import org.elasticsearch.common.util.concurrent.EsExecutors;
import org.elasticsearch.common.util.set.Sets;
import org.elasticsearch.core.SuppressForbidden;
import org.elasticsearch.license.License;
import org.elasticsearch.license.LicenseUtils;
import org.elasticsearch.license.XPackLicenseState;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.TransportService;
import org.elasticsearch.xpack.core.ml.MachineLearningField;
import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction;
import org.elasticsearch.xpack.core.ml.action.PutTrainedModelAliasAction;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelType;
import org.elasticsearch.xpack.core.ml.inference.assignment.AllocationStatus;
import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignment;
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.ml.inference.ModelAliasMetadata;
import org.elasticsearch.xpack.ml.inference.assignment.TrainedModelAssignmentMetadata;
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
import org.elasticsearch.xpack.ml.notifications.InferenceAuditor;

public class TransportPutTrainedModelAliasAction
extends AcknowledgedTransportMasterNodeAction<PutTrainedModelAliasAction.Request> {
    private static final Logger logger = LogManager.getLogger(TransportPutTrainedModelAliasAction.class);
    private final XPackLicenseState licenseState;
    private final TrainedModelProvider trainedModelProvider;
    private final InferenceAuditor auditor;

    @Inject
    public TransportPutTrainedModelAliasAction(TransportService transportService, TrainedModelProvider trainedModelProvider, ClusterService clusterService, ThreadPool threadPool, XPackLicenseState licenseState, ActionFilters actionFilters, InferenceAuditor auditor, IndexNameExpressionResolver indexNameExpressionResolver) {
        super("cluster:admin/xpack/ml/inference/model_aliases/put", transportService, clusterService, threadPool, actionFilters, PutTrainedModelAliasAction.Request::new, indexNameExpressionResolver, (Executor)EsExecutors.DIRECT_EXECUTOR_SERVICE);
        this.licenseState = licenseState;
        this.trainedModelProvider = trainedModelProvider;
        this.auditor = auditor;
    }

    protected void masterOperation(Task task, final PutTrainedModelAliasAction.Request request, ClusterState state, ActionListener<AcknowledgedResponse> listener) throws Exception {
        boolean mlSupported = MachineLearningField.ML_API_FEATURE.check(this.licenseState);
        Predicate<TrainedModelConfig> isLicensed = model -> mlSupported || model.getLicenseLevel() == License.OperationMode.BASIC;
        String oldModelId = ModelAliasMetadata.fromState(state).getModelId(request.getModelAlias());
        if (oldModelId != null && !request.isReassign()) {
            listener.onFailure((Exception)ExceptionsHelper.badRequestException((String)"cannot assign model_alias [{}] to model_id [{}] as model_alias already refers to [{}]. Set parameter [reassign] to [true] if model_alias should be reassigned.", (Object[])new Object[]{request.getModelAlias(), request.getModelId(), oldModelId}));
            return;
        }
        HashSet<String> modelIds = new HashSet<String>();
        modelIds.add(request.getModelAlias());
        modelIds.add(request.getModelId());
        if (oldModelId != null) {
            modelIds.add(oldModelId);
        }
        this.trainedModelProvider.getTrainedModels(modelIds, GetTrainedModelsAction.Includes.empty(), true, null, (ActionListener<List<TrainedModelConfig>>)ActionListener.wrap(models -> {
            TrainedModelConfig newModel = null;
            TrainedModelConfig oldModel = null;
            for (TrainedModelConfig config : models) {
                if (config.getModelId().equals(request.getModelId())) {
                    newModel = config;
                }
                if (config.getModelId().equals(oldModelId)) {
                    oldModel = config;
                }
                if (!config.getModelId().equals(request.getModelAlias())) continue;
                listener.onFailure((Exception)ExceptionsHelper.badRequestException((String)"model_alias cannot be the same as an existing trained model_id", (Object[])new Object[0]));
                return;
            }
            if (newModel == null) {
                listener.onFailure((Exception)((Object)ExceptionsHelper.missingTrainedModel((String)request.getModelId())));
                return;
            }
            if (!isLicensed.test(newModel)) {
                listener.onFailure((Exception)LicenseUtils.newComplianceException((String)"ml"));
                return;
            }
            if (oldModel != null) {
                HashSet newInputFields;
                HashSet oldInputFields;
                if (newModel.getInferenceConfig() != null && oldModel.getInferenceConfig() != null && !newModel.getInferenceConfig().getName().equals(oldModel.getInferenceConfig().getName())) {
                    listener.onFailure((Exception)ExceptionsHelper.badRequestException((String)"cannot reassign model_alias [{}] to model [{}] with inference config type [{}] from model [{}] with type [{}]", (Object[])new Object[]{request.getModelAlias(), newModel.getModelId(), newModel.getInferenceConfig().getName(), oldModel.getModelId(), oldModel.getInferenceConfig().getName()}));
                    return;
                }
                if (!Objects.equals(newModel.getModelType(), oldModel.getModelType())) {
                    listener.onFailure((Exception)ExceptionsHelper.badRequestException((String)"cannot reassign model_alias [{}] to model [{}] with type [{}] from model [{}] with type [{}]", (Object[])new Object[]{request.getModelAlias(), newModel.getModelId(), Optional.ofNullable(newModel.getModelType()).orElse(TrainedModelType.TREE_ENSEMBLE).toString(), oldModel.getModelId(), Optional.ofNullable(oldModel.getModelType()).orElse(TrainedModelType.TREE_ENSEMBLE).toString()}));
                    return;
                }
                if (newModel.getModelType() == TrainedModelType.PYTORCH) {
                    List<TrainedModelAssignment> oldAssignments = TrainedModelAssignmentMetadata.assignmentsForModelId(state, oldModelId);
                    List<TrainedModelAssignment> newAssignments = TrainedModelAssignmentMetadata.assignmentsForModelId(state, newModel.getModelId());
                    if (!oldAssignments.isEmpty()) {
                        if (newAssignments.isEmpty()) {
                            listener.onFailure((Exception)ExceptionsHelper.badRequestException((String)"cannot reassign model_alias [{}] to model [{}] from model [{}] as it is not yet deployed", (Object[])new Object[]{request.getModelAlias(), newModel.getModelId(), oldModel.getModelId()}));
                            return;
                        }
                        for (TrainedModelAssignment oldAssignment : oldAssignments) {
                            Optional oldAllocationStatus = oldAssignment.calculateAllocationStatus();
                            if (!oldAllocationStatus.isPresent() || !((AllocationStatus)oldAllocationStatus.get()).calculateState().isAnyOf(new AllocationStatus.State[]{AllocationStatus.State.FULLY_ALLOCATED, AllocationStatus.State.STARTED})) continue;
                            for (TrainedModelAssignment newAssignment : newAssignments) {
                                Optional newAllocationStatus = newAssignment.calculateAllocationStatus();
                                if (!newAllocationStatus.isEmpty() && !((AllocationStatus)newAllocationStatus.get()).calculateState().equals((Object)AllocationStatus.State.STARTING)) continue;
                                listener.onFailure((Exception)ExceptionsHelper.badRequestException((String)"cannot reassign model_alias [{}] to model [{}]  from model [{}] as it is not yet allocated to any nodes", (Object[])new Object[]{request.getModelAlias(), newModel.getModelId(), oldModel.getModelId()}));
                                return;
                            }
                        }
                    }
                }
                if (Sets.difference(oldInputFields = new HashSet(oldModel.getInput().getFieldNames()), newInputFields = new HashSet(newModel.getInput().getFieldNames())).size() > oldInputFields.size() / 2 || Sets.intersection(newInputFields, oldInputFields).size() < oldInputFields.size() / 2) {
                    String warning = Messages.getMessage((String)"The input fields for new model [{0}] and for old model [{1}] differ significantly, model results may change drastically.", (Object[])new Object[]{request.getModelId(), oldModelId});
                    this.auditor.warning(oldModelId, warning);
                    logger.warn("[{}] {}", (Object)oldModelId, (Object)warning);
                    HeaderWarning.addWarning((String)warning, (Object[])new Object[0]);
                }
            }
            this.submitUnbatchedTask("update-model-alias", (ClusterStateUpdateTask)new AckedClusterStateUpdateTask((AckedRequest)request, listener){

                public ClusterState execute(ClusterState currentState) {
                    return TransportPutTrainedModelAliasAction.updateModelAlias(currentState, request);
                }
            });
        }, arg_0 -> listener.onFailure(arg_0)));
    }

    @SuppressForbidden(reason="legacy usage of unbatched task")
    private void submitUnbatchedTask(String source, ClusterStateUpdateTask task) {
        this.clusterService.submitUnbatchedStateUpdateTask(source, task);
    }

    static ClusterState updateModelAlias(ClusterState currentState, PutTrainedModelAliasAction.Request request) {
        ClusterState.Builder builder = ClusterState.builder((ClusterState)currentState);
        ModelAliasMetadata currentMetadata = ModelAliasMetadata.fromState(currentState);
        String currentModelId = currentMetadata.getModelId(request.getModelAlias());
        HashMap<String, ModelAliasMetadata.ModelAliasEntry> newMetadata = new HashMap<String, ModelAliasMetadata.ModelAliasEntry>(currentMetadata.modelAliases());
        if (currentModelId == null) {
            logger.info("creating new model_alias [{}] for model [{}]", (Object)request.getModelAlias(), (Object)request.getModelId());
        } else {
            logger.info("updating model_alias [{}] to refer to model [{}] from model [{}]", (Object)request.getModelAlias(), (Object)request.getModelId(), (Object)currentModelId);
        }
        newMetadata.put(request.getModelAlias(), new ModelAliasMetadata.ModelAliasEntry(request.getModelId()));
        ModelAliasMetadata modelAliasMetadata = new ModelAliasMetadata(newMetadata);
        builder.metadata(Metadata.builder((Metadata)currentState.getMetadata()).putCustom("trained_model_alias", (Metadata.Custom)modelAliasMetadata).build());
        return builder.build();
    }

    protected ClusterBlockException checkBlock(PutTrainedModelAliasAction.Request request, ClusterState state) {
        return state.blocks().globalBlockedException(ClusterBlockLevel.METADATA_WRITE);
    }
}

