/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.inference.services.elasticsearch;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.ActionRequest;
import org.elasticsearch.action.ActionType;
import org.elasticsearch.client.internal.Client;
import org.elasticsearch.client.internal.OriginSettingClient;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.inference.ChunkedInferenceServiceResults;
import org.elasticsearch.inference.ChunkingOptions;
import org.elasticsearch.inference.InferenceResults;
import org.elasticsearch.inference.InferenceService;
import org.elasticsearch.inference.InferenceServiceExtension;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.inference.InputType;
import org.elasticsearch.inference.Model;
import org.elasticsearch.inference.ServiceSettings;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xpack.core.ClientHelper;
import org.elasticsearch.xpack.core.inference.results.ErrorChunkedInferenceResults;
import org.elasticsearch.xpack.core.inference.results.ResultUtils;
import org.elasticsearch.xpack.core.inference.results.TextEmbeddingResults;
import org.elasticsearch.xpack.core.ml.action.CreateTrainedModelAssignmentAction;
import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction;
import org.elasticsearch.xpack.core.ml.action.InferTrainedModelDeploymentAction;
import org.elasticsearch.xpack.core.ml.action.PutTrainedModelAction;
import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction;
import org.elasticsearch.xpack.core.ml.action.StopTrainedModelDeploymentAction;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput;
import org.elasticsearch.xpack.core.ml.inference.results.ChunkedTextEmbeddingResults;
import org.elasticsearch.xpack.core.ml.inference.results.ErrorInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfigUpdate;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextEmbeddingConfigUpdate;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TokenizationConfigUpdate;
import org.elasticsearch.xpack.inference.services.ServiceUtils;
import org.elasticsearch.xpack.inference.services.elasticsearch.CustomElandInternalServiceSettings;
import org.elasticsearch.xpack.inference.services.elasticsearch.CustomElandModel;
import org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchModel;
import org.elasticsearch.xpack.inference.services.elasticsearch.MultilingualE5SmallInternalServiceSettings;
import org.elasticsearch.xpack.inference.services.elasticsearch.MultilingualE5SmallModel;
import org.elasticsearch.xpack.inference.services.settings.InternalServiceSettings;

/**
 * {@link InferenceService} that runs inference against trained models hosted inside the cluster, using the ML
 * trained-model actions ({@link GetTrainedModelsAction}, {@link PutTrainedModelAction},
 * {@link StartTrainedModelDeploymentAction}, {@link InferTrainedModelDeploymentAction},
 * {@link StopTrainedModelDeploymentAction}).
 *
 * <p>A model is classified by the {@code model_id} entry of its {@code service_settings}:
 * <ul>
 *   <li>ids contained in {@link #MULTILINGUAL_E5_SMALL_VALID_IDS} become a {@link MultilingualE5SmallModel};</li>
 *   <li>any other id becomes a {@link CustomElandModel}, which must already exist in the cluster
 *       (for example uploaded with eland).</li>
 * </ul>
 * Inference entry points only accept task types compatible with {@link TaskType#TEXT_EMBEDDING}.
 */
public class ElasticsearchInternalService implements InferenceService {
    public static final String NAME = "elasticsearch";

    static final String MULTILINGUAL_E5_SMALL_MODEL_ID = ".multilingual-e5-small";
    static final String MULTILINGUAL_E5_SMALL_MODEL_ID_LINUX_X86 = ".multilingual-e5-small_linux-x86_64";
    public static final Set<String> MULTILINGUAL_E5_SMALL_VALID_IDS = Set.of(
        MULTILINGUAL_E5_SMALL_MODEL_ID,
        MULTILINGUAL_E5_SMALL_MODEL_ID_LINUX_X86
    );

    /** Origin attached to the requests this service sends (both via the wrapping client and explicitly). */
    private static final String INFERENCE_ORIGIN = "inference";
    /** The only architecture for which a platform-specific E5 variant is selected. */
    private static final String LINUX_X86_64 = "linux-x86_64";
    /** Failure message from the put-trained-model call that {@link #putModel} treats as success. */
    private static final String MODEL_ALREADY_DEPLOYED_MESSAGE =
        "the model id is the same as the deployment id of a current model deployment";

    private static final Logger logger = LogManager.getLogger(ElasticsearchInternalService.class);

    private final OriginSettingClient client;

    /**
     * @param context factory context whose client is wrapped so that every request carries the inference origin
     */
    public ElasticsearchInternalService(InferenceServiceExtension.InferenceServiceFactoryContext context) {
        this.client = new OriginSettingClient(context.client(), INFERENCE_ORIGIN);
    }

    /**
     * Parses a user-supplied endpoint configuration into a model and reports it to {@code modelListener}.
     *
     * <p>The {@code service_settings} map is removed from {@code config}. Any parsing failure, including a
     * missing {@code model_id}, is delivered via {@code modelListener.onFailure} rather than thrown.
     *
     * @param inferenceEntityId     id of the inference endpoint being created
     * @param taskType              requested task type
     * @param config                request configuration; mutated as settings are consumed
     * @param platformArchitectures architectures used to pick / validate the E5 variant
     * @param modelListener         receives the parsed model or the failure
     */
    public void parseRequestConfig(
        String inferenceEntityId,
        TaskType taskType,
        Map<String, Object> config,
        Set<String> platformArchitectures,
        ActionListener<Model> modelListener
    ) {
        try {
            Map<String, Object> serviceSettingsMap = ServiceUtils.removeFromMapOrThrowIfNull(config, "service_settings");
            String modelId = (String) serviceSettingsMap.get("model_id");
            if (modelId == null) {
                throw new IllegalArgumentException("Error parsing request config, model id is missing");
            }
            if (MULTILINGUAL_E5_SMALL_VALID_IDS.contains(modelId)) {
                e5Case(inferenceEntityId, taskType, config, platformArchitectures, serviceSettingsMap, modelListener);
            } else {
                ServiceUtils.throwIfNotEmptyMap(config, name());
                customElandCase(inferenceEntityId, taskType, serviceSettingsMap, modelListener);
            }
        } catch (Exception e) {
            modelListener.onFailure(e);
        }
    }

    /**
     * Builds a {@link CustomElandModel} after asynchronously confirming that a trained model with the
     * configured id exists; fails with {@link IllegalArgumentException} when none is found.
     */
    private void customElandCase(
        String inferenceEntityId,
        TaskType taskType,
        Map<String, Object> serviceSettingsMap,
        ActionListener<Model> modelListener
    ) {
        String modelId = (String) serviceSettingsMap.get("model_id");
        GetTrainedModelsAction.Request request = new GetTrainedModelsAction.Request(modelId);

        ActionListener<GetTrainedModelsAction.Response> getModelsListener = modelListener.delegateFailureAndWrap(
            (delegate, response) -> {
                if (response.getResources().count() < 1) {
                    throw new IllegalArgumentException(
                        "Error parsing request config, model id does not match any models available on this platform. Was ["
                            + modelId
                            + "]. You may need to load it into the cluster using eland."
                    );
                }
                CustomElandInternalServiceSettings customElandInternalServiceSettings =
                    (CustomElandInternalServiceSettings) CustomElandInternalServiceSettings.fromMap(serviceSettingsMap).build();
                // Reject any service settings left over after parsing.
                ServiceUtils.throwIfNotEmptyMap(serviceSettingsMap, name());
                delegate.onResponse(new CustomElandModel(inferenceEntityId, taskType, name(), customElandInternalServiceSettings));
            }
        );

        client.execute(GetTrainedModelsAction.INSTANCE, request, getModelsListener);
    }

    /**
     * Builds a {@link MultilingualE5SmallModel}, defaulting the model id to the variant chosen for
     * {@code platformArchitectures} and rejecting ids that neither match that variant nor are the
     * platform-agnostic id. Leftover entries in {@code config} or the settings map are rejected.
     */
    private void e5Case(
        String inferenceEntityId,
        TaskType taskType,
        Map<String, Object> config,
        Set<String> platformArchitectures,
        Map<String, Object> serviceSettingsMap,
        ActionListener<Model> modelListener
    ) {
        InternalServiceSettings.Builder e5ServiceSettings = MultilingualE5SmallInternalServiceSettings.fromMap(serviceSettingsMap);

        if (e5ServiceSettings.getModelId() == null) {
            e5ServiceSettings.setModelId(selectDefaultModelVariantBasedOnClusterArchitecture(platformArchitectures));
        }

        if (modelVariantDoesNotMatchArchitecturesAndIsNotPlatformAgnostic(platformArchitectures, e5ServiceSettings)) {
            throw new IllegalArgumentException(
                "Error parsing request config, model id does not match any models available on this platform. Was ["
                    + e5ServiceSettings.getModelId()
                    + "]"
            );
        }

        ServiceUtils.throwIfNotEmptyMap(config, name());
        ServiceUtils.throwIfNotEmptyMap(serviceSettingsMap, name());

        modelListener.onResponse(
            new MultilingualE5SmallModel(
                inferenceEntityId,
                taskType,
                NAME,
                (MultilingualE5SmallInternalServiceSettings) e5ServiceSettings.build()
            )
        );
    }

    /**
     * @return {@code true} when the configured E5 id is neither the variant selected for
     *         {@code platformArchitectures} nor the platform-agnostic {@link #MULTILINGUAL_E5_SMALL_MODEL_ID}
     */
    private static boolean modelVariantDoesNotMatchArchitecturesAndIsNotPlatformAgnostic(
        Set<String> platformArchitectures,
        InternalServiceSettings.Builder e5ServiceSettings
    ) {
        String modelId = e5ServiceSettings.getModelId();
        return !modelId.equals(selectDefaultModelVariantBasedOnClusterArchitecture(platformArchitectures))
            && !modelId.equals(MULTILINGUAL_E5_SMALL_MODEL_ID);
    }

    /**
     * Parses a persisted configuration; this service stores no secrets, so {@code secrets} is ignored.
     */
    public ElasticsearchModel parsePersistedConfigWithSecrets(
        String inferenceEntityId,
        TaskType taskType,
        Map<String, Object> config,
        Map<String, Object> secrets
    ) {
        return parsePersistedConfig(inferenceEntityId, taskType, config);
    }

    /**
     * Parses a persisted configuration into an {@link ElasticsearchModel}. Unlike
     * {@link #parseRequestConfig}, no cluster lookup or architecture validation is performed.
     *
     * @throws IllegalArgumentException if {@code model_id} is missing from the service settings
     */
    public ElasticsearchModel parsePersistedConfig(String inferenceEntityId, TaskType taskType, Map<String, Object> config) {
        Map<String, Object> serviceSettingsMap = ServiceUtils.removeFromMapOrThrowIfNull(config, "service_settings");
        String modelId = (String) serviceSettingsMap.get("model_id");
        if (modelId == null) {
            throw new IllegalArgumentException("Error parsing request config, model id is missing");
        }

        if (MULTILINGUAL_E5_SMALL_VALID_IDS.contains(modelId)) {
            return new MultilingualE5SmallModel(
                inferenceEntityId,
                taskType,
                NAME,
                (MultilingualE5SmallInternalServiceSettings) MultilingualE5SmallInternalServiceSettings.fromMap(serviceSettingsMap).build()
            );
        }
        return new CustomElandModel(
            inferenceEntityId,
            taskType,
            name(),
            (CustomElandInternalServiceSettings) CustomElandInternalServiceSettings.fromMap(serviceSettingsMap).build()
        );
    }

    /**
     * Runs text-embedding inference on the deployment named after the model's inference entity id.
     * {@code query}, {@code taskSettings} and {@code inputType} are not used by this service.
     */
    public void infer(
        Model model,
        @Nullable String query,
        List<String> input,
        Map<String, Object> taskSettings,
        InputType inputType,
        TimeValue timeout,
        ActionListener<InferenceServiceResults> listener
    ) {
        try {
            checkCompatibleTaskType(model.getConfigurations().getTaskType());
        } catch (Exception e) {
            listener.onFailure(e);
            return;
        }

        InferTrainedModelDeploymentAction.Request request = InferTrainedModelDeploymentAction.Request.forTextInput(
            model.getConfigurations().getInferenceEntityId(),
            TextEmbeddingConfigUpdate.EMPTY_INSTANCE,
            input,
            timeout
        );

        client.execute(
            InferTrainedModelDeploymentAction.INSTANCE,
            request,
            listener.delegateFailureAndWrap((l, inferenceResult) -> l.onResponse(TextEmbeddingResults.of(inferenceResult.getResults())))
        );
    }

    /** Convenience overload of chunked inference without a query. */
    public void chunkedInfer(
        Model model,
        List<String> input,
        Map<String, Object> taskSettings,
        InputType inputType,
        ChunkingOptions chunkingOptions,
        TimeValue timeout,
        ActionListener<List<ChunkedInferenceServiceResults>> listener
    ) {
        chunkedInfer(model, null, input, taskSettings, inputType, chunkingOptions, timeout, listener);
    }

    /**
     * Runs chunked text-embedding inference. The window size and span come from {@code chunkingOptions}
     * when present, otherwise both are left {@code null}. Each per-input result is translated by
     * {@link #translateToChunkedResult}.
     */
    public void chunkedInfer(
        Model model,
        @Nullable String query,
        List<String> input,
        Map<String, Object> taskSettings,
        InputType inputType,
        ChunkingOptions chunkingOptions,
        TimeValue timeout,
        ActionListener<List<ChunkedInferenceServiceResults>> listener
    ) {
        try {
            checkCompatibleTaskType(model.getConfigurations().getTaskType());
        } catch (Exception e) {
            listener.onFailure(e);
            return;
        }

        TokenizationConfigUpdate configUpdate = chunkingOptions != null
            ? new TokenizationConfigUpdate(chunkingOptions.windowSize(), chunkingOptions.span())
            : new TokenizationConfigUpdate(null, null);

        InferTrainedModelDeploymentAction.Request request = InferTrainedModelDeploymentAction.Request.forTextInput(
            model.getConfigurations().getInferenceEntityId(),
            configUpdate,
            input,
            timeout
        );
        request.setChunkResults(true);

        client.execute(
            InferTrainedModelDeploymentAction.INSTANCE,
            request,
            listener.delegateFailureAndWrap((l, inferenceResult) -> l.onResponse(translateToChunkedResults(inferenceResult.getResults())))
        );
    }

    /** Translates each ML inference result, preserving order. */
    private static List<ChunkedInferenceServiceResults> translateToChunkedResults(List<InferenceResults> inferenceResults) {
        List<ChunkedInferenceServiceResults> translated = new ArrayList<>(inferenceResults.size());
        for (InferenceResults inferenceResult : inferenceResults) {
            translated.add(translateToChunkedResult(inferenceResult));
        }
        return translated;
    }

    /**
     * Maps an ML chunked embedding result to the inference-API equivalent and an ML error result to an
     * {@link ErrorChunkedInferenceResults}; any other result type is rejected.
     */
    private static ChunkedInferenceServiceResults translateToChunkedResult(InferenceResults inferenceResult) {
        if (inferenceResult instanceof ChunkedTextEmbeddingResults) {
            return org.elasticsearch.xpack.core.inference.results.ChunkedTextEmbeddingResults.ofMlResult(
                (ChunkedTextEmbeddingResults) inferenceResult
            );
        }
        if (inferenceResult instanceof ErrorInferenceResults) {
            return new ErrorChunkedInferenceResults(((ErrorInferenceResults) inferenceResult).getException());
        }
        throw ResultUtils.createInvalidChunkedResultException(inferenceResult.getWriteableName());
    }

    /**
     * Starts a trained-model deployment for {@code model}. Only {@link ElasticsearchModel}s with task type
     * {@link TaskType#TEXT_EMBEDDING} are accepted; other models fail with {@link IllegalStateException}.
     */
    public void start(Model model, ActionListener<Boolean> listener) {
        if (!(model instanceof ElasticsearchModel)) {
            listener.onFailure(notTextEmbeddingModelException(model));
            return;
        }

        if (model.getConfigurations().getTaskType() != TaskType.TEXT_EMBEDDING) {
            listener.onFailure(
                new IllegalStateException(TaskType.unsupportedTaskTypeErrorMsg(model.getConfigurations().getTaskType(), NAME))
            );
            return;
        }

        ElasticsearchModel esModel = (ElasticsearchModel) model;
        StartTrainedModelDeploymentAction.Request startRequest = esModel.getStartTrainedModelDeploymentActionRequest();
        ActionListener<CreateTrainedModelAssignmentAction.Response> responseListener =
            esModel.getCreateTrainedModelAssignmentActionListener(model, listener);

        client.execute(StartTrainedModelDeploymentAction.INSTANCE, startRequest, responseListener);
    }

    /** Stops the deployment with the given id and answers {@code true} on success. */
    public void stop(String inferenceEntityId, ActionListener<Boolean> listener) {
        client.execute(
            StopTrainedModelDeploymentAction.INSTANCE,
            new StopTrainedModelDeploymentAction.Request(inferenceEntityId),
            listener.delegateFailureAndWrap((delegatedResponseListener, response) -> delegatedResponseListener.onResponse(Boolean.TRUE))
        );
    }

    /**
     * Ensures the model's trained-model configuration exists.
     *
     * <ul>
     *   <li>E5 models: a put-trained-model request is sent; a failure reporting that the model id equals a
     *       current deployment id is treated as success.</li>
     *   <li>Custom eland models: nothing is sent; the model is assumed to be loaded already.</li>
     *   <li>Anything else fails.</li>
     * </ul>
     */
    public void putModel(Model model, ActionListener<Boolean> listener) {
        if (!(model instanceof ElasticsearchModel)) {
            listener.onFailure(notTextEmbeddingModelException(model));
            return;
        }

        if (model instanceof MultilingualE5SmallModel) {
            String modelId = ((MultilingualE5SmallModel) model).getServiceSettings().getModelId();
            List<String> fieldNames = List.of();
            TrainedModelInput input = new TrainedModelInput(fieldNames);
            TrainedModelConfig config = TrainedModelConfig.builder().setInput(input).setModelId(modelId).build();
            PutTrainedModelAction.Request putRequest = new PutTrainedModelAction.Request(config, false, true);

            ClientHelper.executeAsyncWithOrigin(
                client,
                INFERENCE_ORIGIN,
                PutTrainedModelAction.INSTANCE,
                putRequest,
                ActionListener.wrap(response -> listener.onResponse(Boolean.TRUE), e -> {
                    if (e instanceof ElasticsearchStatusException && e.getMessage().contains(MODEL_ALREADY_DEPLOYED_MESSAGE)) {
                        listener.onResponse(Boolean.TRUE);
                    } else {
                        listener.onFailure(e);
                    }
                })
            );
        } else if (model instanceof CustomElandModel) {
            logger.info("Custom eland model detected, model must have been already loaded into the cluster with eland.");
            listener.onResponse(Boolean.TRUE);
        } else {
            listener.onFailure(
                new IllegalArgumentException(
                    "Can not download model automatically for ["
                        + model.getConfigurations().getInferenceEntityId()
                        + "] you may need to download it through the trained models API or with eland."
                )
            );
        }
    }

    /**
     * Answers whether at least one trained model matches the model id in the model's
     * {@link InternalServiceSettings}.
     */
    public void isModelDownloaded(Model model, ActionListener<Boolean> listener) {
        if (!(model instanceof ElasticsearchModel)) {
            listener.onFailure(notTextEmbeddingModelException(model));
            return;
        }

        ServiceSettings serviceSettings = model.getServiceSettings();
        if (!(serviceSettings instanceof InternalServiceSettings)) {
            listener.onFailure(
                new IllegalArgumentException(
                    "Unable to determine supported model for ["
                        + model.getConfigurations().getInferenceEntityId()
                        + "] please verify the request and submit a bug report if necessary."
                )
            );
            return;
        }

        String modelId = ((InternalServiceSettings) serviceSettings).getModelId();
        GetTrainedModelsAction.Request getRequest = new GetTrainedModelsAction.Request(modelId);
        ActionListener<GetTrainedModelsAction.Response> getModelsResponseListener = listener.delegateFailure(
            (delegate, response) -> delegate.onResponse(response.getResources().count() >= 1 ? Boolean.TRUE : Boolean.FALSE)
        );

        ClientHelper.executeAsyncWithOrigin(client, INFERENCE_ORIGIN, GetTrainedModelsAction.INSTANCE, getRequest, getModelsResponseListener);
    }

    private static IllegalStateException notTextEmbeddingModelException(Model model) {
        return new IllegalStateException(
            "Error starting model, [" + model.getConfigurations().getInferenceEntityId() + "] is not a text embedding model"
        );
    }

    /**
     * @throws ElasticsearchStatusException with {@link RestStatus#BAD_REQUEST} when {@code taskType} is
     *         not compatible with {@link TaskType#TEXT_EMBEDDING}
     */
    private void checkCompatibleTaskType(TaskType taskType) {
        if (!TaskType.TEXT_EMBEDDING.isAnyOrSame(taskType)) {
            throw new ElasticsearchStatusException(TaskType.unsupportedTaskTypeErrorMsg(taskType, NAME), RestStatus.BAD_REQUEST);
        }
    }

    public boolean isInClusterService() {
        return true;
    }

    public TransportVersion getMinimalSupportedVersion() {
        return TransportVersions.ML_INFERENCE_L2_NORM_SIMILARITY_ADDED;
    }

    /** No resources are held, so closing is a no-op. */
    public void close() throws IOException {}

    public String name() {
        return NAME;
    }

    /**
     * Picks the E5 variant for the given architectures: the linux-x86_64 specific id when the set holds
     * exactly one architecture and it is {@code linux-x86_64}, otherwise the platform-agnostic id.
     */
    private static String selectDefaultModelVariantBasedOnClusterArchitecture(Set<String> modelArchitectures) {
        boolean homogenous = modelArchitectures.size() == 1;
        if (homogenous && LINUX_X86_64.equals(modelArchitectures.iterator().next())) {
            return MULTILINGUAL_E5_SMALL_MODEL_ID_LINUX_X86;
        }
        return MULTILINGUAL_E5_SMALL_MODEL_ID;
    }
}

