/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.inference.services.elasticsearch;

import java.io.IOException;
import java.util.List;
import java.util.concurrent.ExecutorService;
import java.util.function.Consumer;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.ExceptionsHelper;
import org.elasticsearch.ResourceNotFoundException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.ActionRequest;
import org.elasticsearch.action.ActionType;
import org.elasticsearch.action.support.SubscribableListener;
import org.elasticsearch.client.internal.Client;
import org.elasticsearch.client.internal.OriginSettingClient;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.inference.InferenceService;
import org.elasticsearch.inference.InferenceServiceExtension;
import org.elasticsearch.inference.InputType;
import org.elasticsearch.inference.Model;
import org.elasticsearch.inference.ServiceSettings;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.inference.UnparsedModel;
import org.elasticsearch.xpack.core.ClientHelper;
import org.elasticsearch.xpack.core.ml.MachineLearningField;
import org.elasticsearch.xpack.core.ml.action.CreateTrainedModelAssignmentAction;
import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction;
import org.elasticsearch.xpack.core.ml.action.InferModelAction;
import org.elasticsearch.xpack.core.ml.action.PutTrainedModelAction;
import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction;
import org.elasticsearch.xpack.core.ml.action.StopTrainedModelDeploymentAction;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelPrefixStrings;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfigUpdate;
import org.elasticsearch.xpack.core.ml.utils.MlPlatformArchitecturesUtil;
import org.elasticsearch.xpack.inference.services.elasticsearch.CustomElandModel;
import org.elasticsearch.xpack.inference.services.elasticsearch.ElasticRerankerModel;
import org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalModel;
import org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalServiceSettings;
import org.elasticsearch.xpack.inference.services.elasticsearch.ElserInternalModel;
import org.elasticsearch.xpack.inference.services.elasticsearch.MultilingualE5SmallModel;

/**
 * Shared base for inference services whose models run as trained model deployments on the cluster's ML nodes.
 *
 * <p>Provides the deployment lifecycle ({@link #start}, {@link #stop}), creation of built-in trained model configs,
 * selection of a platform-specific model variant from the ML nodes' architectures, and construction of
 * {@link InferModelAction} requests. Subclasses decide which inference entity ids are defaults via
 * {@link #isDefaultId(String)}.
 */
public abstract class BaseElasticsearchInternalService implements InferenceService {

    /** Origin attached to requests sent through {@link #client} and {@link ClientHelper#executeAsyncWithOrigin}. */
    private static final String INFERENCE_ORIGIN = "inference";

    /** Thread pool whose executor backs {@link #inferenceExecutor}. */
    private static final String UTILITY_THREAD_POOL_NAME = "inference_utility";

    /** When this is the only architecture reported by the ML nodes, the Linux x86 optimized variant is preferred. */
    private static final String LINUX_X86_64_ARCHITECTURE = "linux-x86_64";

    private static final Logger logger = LogManager.getLogger(BaseElasticsearchInternalService.class);

    protected final OriginSettingClient client;
    protected final ExecutorService inferenceExecutor;
    protected final Consumer<ActionListener<PreferredModelVariant>> preferredModelVariantFn;
    private final ClusterService clusterService;

    /**
     * Creates a service that resolves the preferred model variant from the architectures of the cluster's ML nodes.
     *
     * @param context factory context supplying the client, thread pool and cluster service
     */
    public BaseElasticsearchInternalService(InferenceServiceExtension.InferenceServiceFactoryContext context) {
        this.client = new OriginSettingClient(context.client(), INFERENCE_ORIGIN);
        this.inferenceExecutor = context.threadPool().executor(UTILITY_THREAD_POOL_NAME);
        this.preferredModelVariantFn = this::preferredVariantFromPlatformArchitecture;
        this.clusterService = context.clusterService();
    }

    /**
     * Creates a service with a caller-supplied strategy for resolving the preferred model variant.
     *
     * @param context                 factory context supplying the client, thread pool and cluster service
     * @param preferredModelVariantFn function that completes the given listener with the preferred variant
     */
    public BaseElasticsearchInternalService(
        InferenceServiceExtension.InferenceServiceFactoryContext context,
        Consumer<ActionListener<PreferredModelVariant>> preferredModelVariantFn
    ) {
        this.client = new OriginSettingClient(context.client(), INFERENCE_ORIGIN);
        this.inferenceExecutor = context.threadPool().executor(UTILITY_THREAD_POOL_NAME);
        this.preferredModelVariantFn = preferredModelVariantFn;
        this.clusterService = context.clusterService();
    }

    /**
     * Starts a trained model deployment for the given model.
     *
     * <p>If the model targets an existing deployment, the listener is completed with {@code true} immediately.
     * Otherwise the trained model config is put if it does not exist yet. Then a start-deployment request is sent.
     *
     * @param model         the model to start; must be an {@link ElasticsearchInternalModel}
     * @param timeout       timeout used for the start-deployment request
     * @param finalListener completed with the outcome. It is failed if the model is not an
     *                      {@link ElasticsearchInternalModel} or its task type is not supported by this service.
     */
    public void start(Model model, TimeValue timeout, ActionListener<Boolean> finalListener) {
        if (!(model instanceof ElasticsearchInternalModel esModel)) {
            finalListener.onFailure(notElasticsearchModelException(model));
            return;
        }
        if (!supportedTaskTypes().contains(model.getTaskType())) {
            finalListener.onFailure(
                new IllegalStateException(TaskType.unsupportedTaskTypeErrorMsg(model.getConfigurations().getTaskType(), name()))
            );
            return;
        }
        if (esModel.usesExistingDeployment()) {
            // The model points at an existing deployment, so there is nothing to start here.
            finalListener.onResponse(Boolean.TRUE);
            return;
        }

        SubscribableListener.<Boolean>newForked(modelExistsListener -> isBuiltinModelPut(model, modelExistsListener))
            .<Boolean>andThen((putListener, modelConfigExists) -> {
                if (modelConfigExists) {
                    putListener.onResponse(true);
                } else {
                    putModel(model, putListener);
                }
            })
            .<Boolean>andThen((startListener, modelDidPut) -> {
                StartTrainedModelDeploymentAction.Request startRequest = esModel.getStartTrainedModelDeploymentActionRequest(timeout);
                ActionListener<CreateTrainedModelAssignmentAction.Response> responseListener = esModel
                    .getCreateTrainedModelAssignmentActionListener(model, startListener);
                client.execute(StartTrainedModelDeploymentAction.INSTANCE, startRequest, responseListener);
            })
            .addListener(finalListener);
    }

    /**
     * Force-stops the trained model deployment backing the given persisted model.
     *
     * <p>If the model's service settings reference a deployment id, nothing is stopped and the listener gets {@code true}.
     *
     * @param unparsedModel persisted model configuration, parsed via {@code parsePersistedConfig}
     * @param listener      completed with {@code true} once the stop request succeeds. It is failed if the parsed model
     *                      is not an {@link ElasticsearchInternalModel}.
     */
    public void stop(UnparsedModel unparsedModel, ActionListener<Boolean> listener) {
        Model model = parsePersistedConfig(unparsedModel.inferenceEntityId(), unparsedModel.taskType(), unparsedModel.settings());
        if (!(model instanceof ElasticsearchInternalModel esModel)) {
            listener.onFailure(notElasticsearchModelException(model));
            return;
        }

        ElasticsearchInternalServiceSettings serviceSettings = esModel.getServiceSettings();
        if (serviceSettings.getDeploymentId() != null) {
            // Configured against an existing deployment: leave that deployment running.
            listener.onResponse(Boolean.TRUE);
            return;
        }

        StopTrainedModelDeploymentAction.Request request = new StopTrainedModelDeploymentAction.Request(esModel.mlNodeDeploymentId());
        request.setForce(true);
        client.execute(
            StopTrainedModelDeploymentAction.INSTANCE,
            request,
            listener.delegateFailureAndWrap((delegate, response) -> delegate.onResponse(Boolean.TRUE))
        );
    }

    /**
     * Builds the error reported when a model handed to this service is not an {@link ElasticsearchInternalModel}.
     *
     * <p>NOTE(review): the message says "Error starting model" although this is also used by the stop, put and
     * existence-check paths.
     *
     * @param model the offending model
     * @return an exception naming the model's inference entity id
     */
    protected static IllegalStateException notElasticsearchModelException(Model model) {
        return new IllegalStateException(
            "Error starting model, [" + model.getConfigurations().getInferenceEntityId() + "] is not an Elasticsearch service model"
        );
    }

    /**
     * Ensures the trained model config for the given model exists, creating it for the built-in models.
     *
     * <p>E5, ELSER and Elastic reranker models get a built-in config put via {@link #putBuiltInModel}. Custom eland
     * models are assumed to be already present and succeed immediately. Any other model type fails.
     *
     * @param model    the model whose trained model config is needed
     * @param listener completed with {@code true} on success, or failed as described above
     */
    protected void putModel(Model model, ActionListener<Boolean> listener) {
        if (!(model instanceof ElasticsearchInternalModel)) {
            listener.onFailure(notElasticsearchModelException(model));
            return;
        }

        if (model instanceof MultilingualE5SmallModel e5Model) {
            putBuiltInModel(e5Model.getServiceSettings().modelId(), listener);
        } else if (model instanceof ElserInternalModel elserModel) {
            putBuiltInModel(elserModel.getServiceSettings().modelId(), listener);
        } else if (model instanceof ElasticRerankerModel elasticRerankerModel) {
            putBuiltInModel(elasticRerankerModel.getServiceSettings().modelId(), listener);
        } else if (model instanceof CustomElandModel) {
            logger.info("Custom eland model detected, model must have been already loaded into the cluster with eland.");
            listener.onResponse(Boolean.TRUE);
        } else {
            listener.onFailure(
                new IllegalArgumentException(
                    "Can not download model automatically for ["
                        + model.getConfigurations().getInferenceEntityId()
                        + "] you may need to download it through the trained models API or with eland."
                )
            );
        }
    }

    /**
     * Puts a trained model config for a built-in model with a single {@code text_field} input.
     *
     * <p>A put failure whose message says the model id equals the deployment id of a current deployment is
     * treated as success.
     *
     * @param modelId  id of the built-in trained model
     * @param listener completed with {@code true} on success, otherwise failed with the put error
     */
    protected void putBuiltInModel(String modelId, ActionListener<Boolean> listener) {
        TrainedModelInput input = new TrainedModelInput(List.of("text_field"));
        TrainedModelConfig config = TrainedModelConfig.builder().setInput(input).setModelId(modelId).validate(true).build();
        PutTrainedModelAction.Request putRequest = new PutTrainedModelAction.Request(config, false, true);
        ClientHelper.executeAsyncWithOrigin(
            client,
            INFERENCE_ORIGIN,
            PutTrainedModelAction.INSTANCE,
            putRequest,
            ActionListener.wrap(response -> listener.onResponse(Boolean.TRUE), e -> {
                if (e instanceof ElasticsearchStatusException esException
                    && esException.getMessage().contains("the model id is the same as the deployment id of a current model deployment")) {
                    listener.onResponse(Boolean.TRUE);
                } else {
                    listener.onFailure(e);
                }
            })
        );
    }

    /**
     * Checks whether the trained model referenced by the model's service settings already exists.
     *
     * @param model    the model to check; must be an {@link ElasticsearchInternalModel} whose service settings are
     *                 {@link ElasticsearchInternalServiceSettings}
     * @param listener completed with {@code true} if the get-trained-models call returns at least one resource. It is
     *                 completed with {@code false} if none are returned or a {@link ResourceNotFoundException} is
     *                 raised, and failed otherwise.
     */
    protected void isBuiltinModelPut(Model model, ActionListener<Boolean> listener) {
        if (!(model instanceof ElasticsearchInternalModel)) {
            listener.onFailure(notElasticsearchModelException(model));
            return;
        }
        if (!(model.getServiceSettings() instanceof ElasticsearchInternalServiceSettings internalServiceSettings)) {
            listener.onFailure(
                new IllegalStateException(
                    "Can not check the download status of the model used by ["
                        + model.getConfigurations().getInferenceEntityId()
                        + "] as the model_id cannot be found."
                )
            );
            return;
        }

        ActionListener<GetTrainedModelsAction.Response> getModelsResponseListener = ActionListener.wrap(
            response -> listener.onResponse(response.getResources().count() > 0L),
            exception -> {
                if (exception instanceof ResourceNotFoundException) {
                    listener.onResponse(Boolean.FALSE);
                } else {
                    listener.onFailure(exception);
                }
            }
        );
        GetTrainedModelsAction.Request getRequest = new GetTrainedModelsAction.Request(internalServiceSettings.modelId());
        ClientHelper.executeAsyncWithOrigin(client, INFERENCE_ORIGIN, GetTrainedModelsAction.INSTANCE, getRequest, getModelsResponseListener);
    }

    /** Intentionally a no-op in this base class. */
    public void close() throws IOException {}

    /**
     * Picks between two model ids according to the preferred variant.
     *
     * @param preferredModelVariant the resolved variant; {@code null} selects the platform-agnostic model
     * @param linuxX86OptimizedModel model id to use for {@link PreferredModelVariant#LINUX_X86_OPTIMIZED}
     * @param platformAgnosticModel  model id to use otherwise
     * @return {@code linuxX86OptimizedModel} for the optimized variant, otherwise {@code platformAgnosticModel}
     */
    public static String selectDefaultModelVariantBasedOnClusterArchitecture(
        PreferredModelVariant preferredModelVariant,
        String linuxX86OptimizedModel,
        String platformAgnosticModel
    ) {
        return preferredModelVariant == PreferredModelVariant.LINUX_X86_OPTIMIZED ? linuxX86OptimizedModel : platformAgnosticModel;
    }

    /**
     * Resolves the preferred model variant from the set of ML node architectures.
     *
     * <p>The Linux x86 optimized variant is chosen in two cases. The first is when no architectures are reported
     * and {@link #isClusterInElasticCloud()} is true. The second is when {@code linux-x86_64} is the only reported
     * architecture. Otherwise the platform-agnostic variant is chosen.
     */
    private void preferredVariantFromPlatformArchitecture(ActionListener<PreferredModelVariant> preferredVariantListener) {
        MlPlatformArchitecturesUtil.getMlNodesArchitecturesSet(
            preferredVariantListener.delegateFailureAndWrap((delegate, architectures) -> {
                // NOTE(review): presumably ML nodes added later in Elastic Cloud are linux-x86_64 — confirm.
                boolean noNodesYetInCloud = architectures.isEmpty() && isClusterInElasticCloud();
                boolean onlyLinuxX86 = architectures.size() == 1 && architectures.contains(LINUX_X86_64_ARCHITECTURE);
                delegate.onResponse(
                    noNodesYetInCloud || onlyLinuxX86 ? PreferredModelVariant.LINUX_X86_OPTIMIZED : PreferredModelVariant.PLATFORM_AGNOSTIC
                );
            }),
            client,
            inferenceExecutor
        );
    }

    /**
     * Returns whether the {@code MAX_LAZY_ML_NODES} cluster setting is positive.
     *
     * <p>NOTE(review): used as the signal for "running in Elastic Cloud"; confirm this heuristic.
     */
    boolean isClusterInElasticCloud() {
        int maxMlLazyNodes = clusterService.getClusterSettings().get(MachineLearningField.MAX_LAZY_ML_NODES);
        return maxMlLazyNodes > 0;
    }

    /**
     * Builds a non-chunked text inference request.
     *
     * <p>{@link InputType#SEARCH} inputs use the search prefix and high priority. All other input types use the ingest
     * prefix and normal priority.
     *
     * @param id        deployment or model id to run inference against
     * @param update    inference config update to apply
     * @param inputs    text inputs
     * @param inputType drives the prefix type and priority
     * @param timeout   inference timeout
     * @return the configured request
     */
    public static InferModelAction.Request buildInferenceRequest(
        String id,
        InferenceConfigUpdate update,
        List<String> inputs,
        InputType inputType,
        TimeValue timeout
    ) {
        boolean isSearch = InputType.SEARCH == inputType;
        InferModelAction.Request request = InferModelAction.Request.forTextInput(id, update, inputs, true, timeout);
        request.setPrefixType(isSearch ? TrainedModelPrefixStrings.PrefixType.SEARCH : TrainedModelPrefixStrings.PrefixType.INGEST);
        request.setHighPriority(isSearch);
        request.setChunked(false);
        return request;
    }

    /**
     * Returns whether the given inference entity id is one of this service's default ids.
     * {@link #maybeStartDeployment} only auto-starts deployments for default ids.
     */
    abstract boolean isDefaultId(String inferenceEntityId);

    /**
     * Handles an inference failure by starting the deployment and retrying, when that is appropriate.
     *
     * <p>If the model has a default id and the (unwrapped) failure is a {@link ResourceNotFoundException}, the model
     * is started and {@code request} is re-sent. Otherwise {@code listener} is failed with {@code e}.
     *
     * @param model    the model the failed request targeted
     * @param e        the inference failure
     * @param request  the request to retry after starting
     * @param listener receives the retried response or the failure
     */
    protected void maybeStartDeployment(
        ElasticsearchInternalModel model,
        Exception e,
        InferModelAction.Request request,
        ActionListener<InferModelAction.Response> listener
    ) {
        if (isDefaultId(model.getInferenceEntityId()) && ExceptionsHelper.unwrapCause(e) instanceof ResourceNotFoundException) {
            start(
                model,
                request.getInferenceTimeout(),
                listener.delegateFailureAndWrap((l, started) -> client.execute(InferModelAction.INSTANCE, request, l))
            );
        } else {
            listener.onFailure(e);
        }
    }

    /** Model variant to prefer when a service offers both a platform-optimized and a platform-agnostic model. */
    public enum PreferredModelVariant {
        LINUX_X86_OPTIMIZED,
        PLATFORM_AGNOSTIC
    }
}

