/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.inference.services.elser;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.ResourceNotFoundException;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.ActionRequest;
import org.elasticsearch.action.ActionType;
import org.elasticsearch.client.internal.Client;
import org.elasticsearch.client.internal.OriginSettingClient;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.inference.ChunkedInferenceServiceResults;
import org.elasticsearch.inference.ChunkingOptions;
import org.elasticsearch.inference.InferenceResults;
import org.elasticsearch.inference.InferenceService;
import org.elasticsearch.inference.InferenceServiceExtension;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.inference.InputType;
import org.elasticsearch.inference.Model;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xpack.core.ClientHelper;
import org.elasticsearch.xpack.core.inference.results.ErrorChunkedInferenceResults;
import org.elasticsearch.xpack.core.inference.results.InferenceChunkedSparseEmbeddingResults;
import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults;
import org.elasticsearch.xpack.core.ml.action.CreateTrainedModelAssignmentAction;
import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction;
import org.elasticsearch.xpack.core.ml.action.InferModelAction;
import org.elasticsearch.xpack.core.ml.action.PutTrainedModelAction;
import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction;
import org.elasticsearch.xpack.core.ml.action.StopTrainedModelDeploymentAction;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput;
import org.elasticsearch.xpack.core.ml.inference.assignment.AllocationStatus;
import org.elasticsearch.xpack.core.ml.inference.results.ErrorInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.MlChunkedTextExpansionResults;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfigUpdate;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextExpansionConfigUpdate;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TokenizationConfigUpdate;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.inference.services.ServiceUtils;
import org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService;
import org.elasticsearch.xpack.inference.services.elser.ElserInternalModel;
import org.elasticsearch.xpack.inference.services.elser.ElserInternalServiceSettings;
import org.elasticsearch.xpack.inference.services.elser.ElserMlNodeTaskSettings;
import org.elasticsearch.xpack.inference.services.settings.InternalServiceSettings;

/**
 * Inference service named {@code "elser"} that serves sparse-embedding requests by
 * delegating to trained-model actions (put, get, start/stop deployment, infer) through an
 * {@link OriginSettingClient} created with the {@code "inference"} origin.
 */
public class ElserInternalService implements InferenceService {
    public static final String NAME = "elser";
    static final String ELSER_V1_MODEL = ".elser_model_1";
    static final String ELSER_V2_MODEL = ".elser_model_2";
    static final String ELSER_V2_MODEL_LINUX_X86 = ".elser_model_2_linux-x86_64";

    /** The model ids listed as valid for this service; immutable and not reassignable. */
    public static final Set<String> VALID_ELSER_MODEL_IDS = Set.of(ELSER_V1_MODEL, ELSER_V2_MODEL, ELSER_V2_MODEL_LINUX_X86);

    /** Key that {@link #parsePersistedConfig} rewrites to {@code "model_id"} when present. */
    private static final String OLD_MODEL_ID_FIELD_NAME = "model_version";

    private final OriginSettingClient client;

    /**
     * Creates the service, wrapping the context's client so requests carry the
     * {@code "inference"} origin.
     *
     * @param context factory context supplying the client
     */
    public ElserInternalService(InferenceServiceExtension.InferenceServiceFactoryContext context) {
        this.client = new OriginSettingClient(context.client(), "inference");
    }

    /**
     * @return always {@code true}
     */
    public boolean isInClusterService() {
        return true;
    }

    /**
     * Parses a request configuration into an {@link ElserInternalModel} and hands it to the listener.
     * When the service settings carry no model id, one is chosen from {@code modelArchitectures}.
     * Any exception raised while parsing is reported through {@code parsedModelListener.onFailure}
     * rather than thrown.
     *
     * @param inferenceEntityId   id of the inference entity being configured
     * @param taskType            requested task type; only {@link TaskType#SPARSE_EMBEDDING} is accepted
     * @param config              request configuration; NOTE(review): the ServiceUtils helpers
     *                            presumably remove the entries they read - confirm
     * @param modelArchitectures  architectures used to pick the default model id
     * @param parsedModelListener receives the parsed model or the parsing failure
     */
    public void parseRequestConfig(String inferenceEntityId, TaskType taskType, Map<String, Object> config, Set<String> modelArchitectures, ActionListener<Model> parsedModelListener) {
        try {
            Map<String, Object> serviceSettingsMap = ServiceUtils.removeFromMapOrThrowIfNull(config, "service_settings");
            InternalServiceSettings.Builder serviceSettingsBuilder = ElserInternalServiceSettings.fromMap(serviceSettingsMap);
            if (serviceSettingsBuilder.getModelId() == null) {
                serviceSettingsBuilder.setModelId(selectDefaultModelVersionBasedOnClusterArchitecture(modelArchitectures));
            }

            // Task settings are optional; an absent section is treated as an empty map.
            Map<String, Object> taskSettingsMap = config.containsKey("task_settings")
                ? ServiceUtils.removeFromMapOrThrowIfNull(config, "task_settings")
                : Map.of();
            ElserMlNodeTaskSettings taskSettings = taskSettingsFromMap(taskType, taskSettingsMap);

            // Whatever is still left in the maps at this point is rejected.
            ServiceUtils.throwIfNotEmptyMap(config, NAME);
            ServiceUtils.throwIfNotEmptyMap(serviceSettingsMap, NAME);
            ServiceUtils.throwIfNotEmptyMap(taskSettingsMap, NAME);

            parsedModelListener.onResponse(
                new ElserInternalModel(inferenceEntityId, taskType, NAME, (ElserInternalServiceSettings) serviceSettingsBuilder.build(), taskSettings)
            );
        } catch (Exception e) {
            parsedModelListener.onFailure(e);
        }
    }

    /**
     * Picks the default model id: the linux-x86_64 variant when the set contains exactly that one
     * architecture, otherwise the platform-agnostic v2 model.
     */
    private static String selectDefaultModelVersionBasedOnClusterArchitecture(Set<String> modelArchitectures) {
        boolean homogenous = modelArchitectures.size() == 1;
        if (homogenous && modelArchitectures.iterator().next().equals("linux-x86_64")) {
            return ELSER_V2_MODEL_LINUX_X86;
        }
        return ELSER_V2_MODEL;
    }

    /**
     * Parses a persisted configuration; {@code secrets} is ignored and the call is forwarded to
     * {@link #parsePersistedConfig(String, TaskType, Map)}.
     */
    public ElserInternalModel parsePersistedConfigWithSecrets(String inferenceEntityId, TaskType taskType, Map<String, Object> config, Map<String, Object> secrets) {
        return this.parsePersistedConfig(inferenceEntityId, taskType, config);
    }

    /**
     * Parses a persisted configuration into an {@link ElserInternalModel}. Unlike
     * {@link #parseRequestConfig}, no default model id is selected and leftover entries are not rejected.
     *
     * @param inferenceEntityId id of the inference entity
     * @param taskType          stored task type; only {@link TaskType#SPARSE_EMBEDDING} is accepted
     * @param config            persisted configuration; its service settings map is modified in place
     *                          when the old model id key is present
     * @return the parsed model
     */
    public ElserInternalModel parsePersistedConfig(String inferenceEntityId, TaskType taskType, Map<String, Object> config) {
        Map<String, Object> serviceSettingsMap = ServiceUtils.removeFromMapOrThrowIfNull(config, "service_settings");

        // Move a value stored under the old key to "model_id" before building the settings.
        if (serviceSettingsMap.containsKey(OLD_MODEL_ID_FIELD_NAME)) {
            String modelId = ServiceUtils.removeAsType(serviceSettingsMap, OLD_MODEL_ID_FIELD_NAME, String.class);
            serviceSettingsMap.put("model_id", modelId);
        }
        InternalServiceSettings.Builder serviceSettingsBuilder = ElserInternalServiceSettings.fromMap(serviceSettingsMap);

        Map<String, Object> taskSettingsMap = config.containsKey("task_settings")
            ? ServiceUtils.removeFromMapOrThrowIfNull(config, "task_settings")
            : Map.of();
        ElserMlNodeTaskSettings taskSettings = taskSettingsFromMap(taskType, taskSettingsMap);

        return new ElserInternalModel(inferenceEntityId, taskType, NAME, (ElserInternalServiceSettings) serviceSettingsBuilder.build(), taskSettings);
    }

    /**
     * Starts a trained-model deployment for the given model.
     * Fails the listener with an {@link IllegalStateException} when the model is not an
     * {@link ElserInternalModel} or its task type is not {@link TaskType#SPARSE_EMBEDDING}.
     *
     * @param model    model to deploy
     * @param listener receives {@code true} on success, or the failure
     */
    public void start(Model model, ActionListener<Boolean> listener) {
        if (!(model instanceof ElserInternalModel)) {
            listener.onFailure(
                new IllegalStateException("Error starting model, [" + model.getConfigurations().getInferenceEntityId() + "] is not an ELSER model")
            );
            return;
        }
        if (model.getConfigurations().getTaskType() != TaskType.SPARSE_EMBEDDING) {
            listener.onFailure(new IllegalStateException(TaskType.unsupportedTaskTypeErrorMsg(model.getConfigurations().getTaskType(), NAME)));
            return;
        }
        this.client.execute(StartTrainedModelDeploymentAction.INSTANCE, startDeploymentRequest(model), elserNotDownloadedListener(listener));
    }

    /**
     * Builds the start-deployment request from the model's service settings, using the inference
     * entity id as the second constructor argument and waiting for the STARTED allocation state.
     * The caller must already have verified that {@code model} is an {@link ElserInternalModel}.
     */
    private static StartTrainedModelDeploymentAction.Request startDeploymentRequest(Model model) {
        ElserInternalModel elserModel = (ElserInternalModel) model;
        ElserInternalServiceSettings serviceSettings = elserModel.getServiceSettings();
        StartTrainedModelDeploymentAction.Request startRequest = new StartTrainedModelDeploymentAction.Request(
            serviceSettings.getModelId(),
            model.getConfigurations().getInferenceEntityId()
        );
        startRequest.setNumberOfAllocations(serviceSettings.getNumAllocations());
        startRequest.setThreadsPerAllocation(serviceSettings.getNumThreads());
        startRequest.setWaitForState(AllocationStatus.State.STARTED);
        return startRequest;
    }

    /**
     * Wraps {@code listener} so that a successful start yields {@code true}, and a failure whose
     * unwrapped cause is a {@link ResourceNotFoundException} is replaced by one with a message
     * telling the user that ELSER must be downloaded first. Other failures pass through unchanged.
     */
    private static ActionListener<CreateTrainedModelAssignmentAction.Response> elserNotDownloadedListener(final ActionListener<Boolean> listener) {
        return new ActionListener<CreateTrainedModelAssignmentAction.Response>() {

            public void onResponse(CreateTrainedModelAssignmentAction.Response response) {
                listener.onResponse(Boolean.TRUE);
            }

            public void onFailure(Exception e) {
                if (ExceptionsHelper.unwrapCause(e) instanceof ResourceNotFoundException) {
                    listener.onFailure(
                        new ResourceNotFoundException(
                            "Could not start the ELSER service as the ELSER model for this platform cannot be found. ELSER needs to be downloaded before it can be started."
                        )
                    );
                    return;
                }
                listener.onFailure(e);
            }
        };
    }

    /**
     * Stops the deployment identified by {@code inferenceEntityId} with the force flag set.
     *
     * @param inferenceEntityId id passed to the stop-deployment request
     * @param listener          receives {@code true} once the stop action responds, or the failure
     */
    public void stop(String inferenceEntityId, ActionListener<Boolean> listener) {
        StopTrainedModelDeploymentAction.Request request = new StopTrainedModelDeploymentAction.Request(inferenceEntityId);
        request.setForce(true);
        this.client.execute(
            StopTrainedModelDeploymentAction.INSTANCE,
            request,
            listener.delegateFailureAndWrap((delegatedResponseListener, response) -> delegatedResponseListener.onResponse(Boolean.TRUE))
        );
    }

    /**
     * Runs non-chunked inference and converts the results to {@link SparseEmbeddingResults}.
     * {@code query} and {@code taskSettings} are not used by this implementation.
     *
     * @param model     model whose inference entity id addresses the request
     * @param query     unused
     * @param inputs    texts to run inference on
     * @param taskSettings unused
     * @param inputType forwarded to the inference request
     * @param timeout   forwarded to the inference request
     * @param listener  receives the results, or the failure (including an unsupported task type)
     */
    public void infer(Model model, @Nullable String query, List<String> inputs, Map<String, Object> taskSettings, InputType inputType, TimeValue timeout, ActionListener<InferenceServiceResults> listener) {
        try {
            this.checkCompatibleTaskType(model.getConfigurations().getTaskType());
        } catch (Exception e) {
            listener.onFailure(e);
            return;
        }
        InferModelAction.Request request = ElasticsearchInternalService.buildInferenceRequest(
            model.getConfigurations().getInferenceEntityId(),
            TextExpansionConfigUpdate.EMPTY_UPDATE,
            inputs,
            inputType,
            timeout,
            false
        );
        this.client.execute(
            InferModelAction.INSTANCE,
            request,
            listener.delegateFailureAndWrap((l, inferenceResult) -> l.onResponse(SparseEmbeddingResults.of(inferenceResult.getInferenceResults())))
        );
    }

    /**
     * Chunked inference without a query; forwards to the overload below with a {@code null} query.
     */
    public void chunkedInfer(Model model, List<String> input, Map<String, Object> taskSettings, InputType inputType, @Nullable ChunkingOptions chunkingOptions, TimeValue timeout, ActionListener<List<ChunkedInferenceServiceResults>> listener) {
        this.chunkedInfer(model, null, input, taskSettings, inputType, chunkingOptions, timeout, listener);
    }

    /**
     * Runs chunked inference and translates each result into a {@link ChunkedInferenceServiceResults}.
     * {@code query} and {@code taskSettings} are not used by this implementation.
     *
     * @param model           model whose inference entity id addresses the request
     * @param query           unused
     * @param inputs          texts to run inference on
     * @param taskSettings    unused
     * @param inputType       forwarded to the inference request
     * @param chunkingOptions optional window size and span; when {@code null} both are passed as {@code null}
     * @param timeout         forwarded to the inference request
     * @param listener        receives the translated results, or the failure
     */
    public void chunkedInfer(Model model, @Nullable String query, List<String> inputs, Map<String, Object> taskSettings, InputType inputType, @Nullable ChunkingOptions chunkingOptions, TimeValue timeout, ActionListener<List<ChunkedInferenceServiceResults>> listener) {
        try {
            this.checkCompatibleTaskType(model.getConfigurations().getTaskType());
        } catch (Exception e) {
            listener.onFailure(e);
            return;
        }
        TokenizationConfigUpdate configUpdate = chunkingOptions != null
            ? new TokenizationConfigUpdate(chunkingOptions.windowSize(), chunkingOptions.span())
            : new TokenizationConfigUpdate(null, null);
        InferModelAction.Request request = ElasticsearchInternalService.buildInferenceRequest(
            model.getConfigurations().getInferenceEntityId(),
            configUpdate,
            inputs,
            inputType,
            timeout,
            true
        );
        this.client.execute(
            InferModelAction.INSTANCE,
            request,
            listener.delegateFailureAndWrap((l, inferenceResult) -> l.onResponse(this.translateChunkedResults(inferenceResult.getInferenceResults())))
        );
    }

    /**
     * @throws ElasticsearchStatusException with {@link RestStatus#BAD_REQUEST} when
     *         {@code TaskType.SPARSE_EMBEDDING.isAnyOrSame(taskType)} is false
     */
    private void checkCompatibleTaskType(TaskType taskType) {
        if (!TaskType.SPARSE_EMBEDDING.isAnyOrSame(taskType)) {
            throw new ElasticsearchStatusException(TaskType.unsupportedTaskTypeErrorMsg(taskType, NAME), RestStatus.BAD_REQUEST);
        }
    }

    /**
     * Issues a put-trained-model request for the model id in the model's service settings.
     *
     * @param model    must be an {@link ElserInternalModel}, otherwise the listener is failed
     * @param listener receives {@code true} once the put action responds, or the failure
     */
    public void putModel(Model model, ActionListener<Boolean> listener) {
        if (!(model instanceof ElserInternalModel)) {
            // NOTE(review): the message says "starting" although this is the put path; the text is
            // kept as-is because callers or tests may match on it.
            listener.onFailure(
                new IllegalStateException("Error starting model, [" + model.getConfigurations().getInferenceEntityId() + "] is not an ELSER model")
            );
            return;
        }
        String modelId = ((ElserInternalModel) model).getServiceSettings().getModelId();
        TrainedModelInput input = new TrainedModelInput(List.of("text_field"));
        TrainedModelConfig config = TrainedModelConfig.builder().setInput(input).setModelId(modelId).validate(true).build();
        PutTrainedModelAction.Request putRequest = new PutTrainedModelAction.Request(config, false, true);
        ClientHelper.executeAsyncWithOrigin(
            this.client,
            "inference",
            PutTrainedModelAction.INSTANCE,
            putRequest,
            listener.delegateFailure((l, r) -> l.onResponse(Boolean.TRUE))
        );
    }

    /**
     * Reports whether a get-trained-models request for the model's id returns at least one resource.
     *
     * @param model    must be an {@link ElserInternalModel}, otherwise the listener is failed with an
     *                 {@link IllegalArgumentException}
     * @param listener receives {@code true} when the resource count is at least 1, {@code false}
     *                 otherwise, or the failure
     */
    public void isModelDownloaded(Model model, ActionListener<Boolean> listener) {
        ActionListener<GetTrainedModelsAction.Response> getModelsResponseListener = listener.delegateFailure((delegate, response) -> {
            if (response.getResources().count() < 1L) {
                delegate.onResponse(Boolean.FALSE);
            } else {
                delegate.onResponse(Boolean.TRUE);
            }
        });
        if (model instanceof ElserInternalModel) {
            ElserInternalModel elserModel = (ElserInternalModel) model;
            String modelId = elserModel.getServiceSettings().getModelId();
            GetTrainedModelsAction.Request getRequest = new GetTrainedModelsAction.Request(modelId);
            ClientHelper.executeAsyncWithOrigin(this.client, "inference", GetTrainedModelsAction.INSTANCE, getRequest, getModelsResponseListener);
        } else {
            listener.onFailure(
                new IllegalArgumentException(
                    "Can not download model automatically for ["
                        + model.getConfigurations().getInferenceEntityId()
                        + "] you may need to download it through the trained models API or with eland."
                )
            );
        }
    }

    /**
     * Validates the task type and returns the default task settings. The {@code config} map is
     * not read.
     *
     * @throws ElasticsearchStatusException with {@link RestStatus#BAD_REQUEST} when {@code taskType}
     *         is not {@link TaskType#SPARSE_EMBEDDING}
     */
    private static ElserMlNodeTaskSettings taskSettingsFromMap(TaskType taskType, Map<String, Object> config) {
        if (taskType != TaskType.SPARSE_EMBEDDING) {
            throw new ElasticsearchStatusException(TaskType.unsupportedTaskTypeErrorMsg(taskType, NAME), RestStatus.BAD_REQUEST);
        }
        return ElserMlNodeTaskSettings.DEFAULT;
    }

    /**
     * Converts inference results one-to-one, in order: chunked text-expansion results become chunked
     * sparse-embedding results and error results become chunked error results.
     *
     * @throws ElasticsearchStatusException with {@link RestStatus#INTERNAL_SERVER_ERROR} on any other
     *         result type
     */
    private List<ChunkedInferenceServiceResults> translateChunkedResults(List<InferenceResults> inferenceResults) {
        List<ChunkedInferenceServiceResults> translated = new ArrayList<>();
        for (InferenceResults inferenceResult : inferenceResults) {
            if (inferenceResult instanceof MlChunkedTextExpansionResults) {
                MlChunkedTextExpansionResults mlChunkedResult = (MlChunkedTextExpansionResults) inferenceResult;
                translated.add(InferenceChunkedSparseEmbeddingResults.ofMlResult(mlChunkedResult));
            } else if (inferenceResult instanceof ErrorInferenceResults) {
                ErrorInferenceResults error = (ErrorInferenceResults) inferenceResult;
                translated.add(new ErrorChunkedInferenceResults(error.getException()));
            } else {
                throw new ElasticsearchStatusException(
                    "Expected a chunked inference [{}] received [{}]",
                    RestStatus.INTERNAL_SERVER_ERROR,
                    "chunked_text_expansion_result",
                    inferenceResult.getWriteableName()
                );
            }
        }
        return translated;
    }

    /**
     * @return the service name, {@code "elser"}
     */
    public String name() {
        return NAME;
    }

    /**
     * No-op: this method releases nothing.
     */
    public void close() throws IOException {
    }

    /**
     * @return {@code TransportVersions.V_8_12_0}
     */
    public TransportVersion getMinimalSupportedVersion() {
        return TransportVersions.V_8_12_0;
    }
}

