/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.inference.services.elser;

import java.io.IOException;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.ActionRequest;
import org.elasticsearch.action.ActionType;
import org.elasticsearch.client.internal.OriginSettingClient;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.inference.InferenceService;
import org.elasticsearch.inference.InferenceServiceExtension;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.inference.Model;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults;
import org.elasticsearch.xpack.core.ml.action.InferTrainedModelDeploymentAction;
import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction;
import org.elasticsearch.xpack.core.ml.inference.assignment.AllocationStatus;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfigUpdate;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextExpansionConfigUpdate;
import org.elasticsearch.xpack.inference.services.ServiceUtils;
import org.elasticsearch.xpack.inference.services.elser.ElserMlNodeModel;
import org.elasticsearch.xpack.inference.services.elser.ElserMlNodeServiceSettings;
import org.elasticsearch.xpack.inference.services.elser.ElserMlNodeTaskSettings;

/**
 * {@link InferenceService} that runs the ELSER models inside the cluster.
 *
 * <p>{@code start} deploys the configured model variant with {@link StartTrainedModelDeploymentAction}
 * and {@code infer} sends text to that deployment with {@link InferTrainedModelDeploymentAction},
 * converting the output to {@link SparseEmbeddingResults}. Only {@link TaskType#SPARSE_EMBEDDING}
 * is supported.
 */
public class ElserMlNodeService implements InferenceService {

    /** Service name; also used as the service identifier in error messages. */
    public static final String NAME = "elser";

    static final String ELSER_V1_MODEL = ".elser_model_1";
    /** Platform-agnostic ELSER v2 model; the default variant unless the linux-x86_64 build applies. */
    static final String ELSER_V2_MODEL = ".elser_model_2";
    /** ELSER v2 variant chosen by default when the only reported architecture is linux-x86_64. */
    static final String ELSER_V2_MODEL_LINUX_X86 = ".elser_model_2_linux-x86_64";

    /**
     * Every model variant known to this service. Declared {@code final} and built with
     * {@link Set#of} so it can be neither reassigned nor modified by callers.
     */
    public static final Set<String> VALID_ELSER_MODELS = Set.of(ELSER_V1_MODEL, ELSER_V2_MODEL, ELSER_V2_MODEL_LINUX_X86);

    /** Architecture name that selects {@code ELSER_V2_MODEL_LINUX_X86} as the default variant. */
    private static final String LINUX_X86_64_ARCHITECTURE = "linux-x86_64";

    /** Timeout applied to every inference request. TODO: take this from the request instead. */
    private static final TimeValue INFER_TIMEOUT = TimeValue.timeValueSeconds(10);

    private final OriginSettingClient client;

    /**
     * @param context factory context; its client is wrapped so that every request issued by this
     *                service carries the {@code "inference"} origin
     */
    public ElserMlNodeService(InferenceServiceExtension.InferenceServiceFactoryContext context) {
        this.client = new OriginSettingClient(context.client(), "inference");
    }

    /** @return always {@code true} */
    public boolean isInClusterService() {
        return true;
    }

    /**
     * Parses a model configuration supplied in a user request.
     *
     * <p>Removes the required {@code service_settings} section and the optional {@code task_settings}
     * section from {@code config}. If the service settings name no model variant, a default is chosen
     * from {@code modelArchitectures}. Any keys left over in {@code config} or in either settings
     * section are rejected via {@code ServiceUtils.throwIfNotEmptyMap}.
     *
     * @param modelId            id of the inference model being created
     * @param taskType           requested task type; must be {@link TaskType#SPARSE_EMBEDDING}
     * @param config             request configuration; recognised sections are removed from it
     * @param modelArchitectures architectures used to pick the default model variant
     *                           (presumably those of the nodes that would run the model — confirm with caller)
     * @return the parsed model
     * @throws ElasticsearchStatusException with {@code BAD_REQUEST} if {@code taskType} is not supported
     */
    public ElserMlNodeModel parseRequestConfig(String modelId, TaskType taskType, Map<String, Object> config, Set<String> modelArchitectures) {
        Map<String, Object> serviceSettingsMap = ServiceUtils.removeFromMapOrThrowIfNull(config, "service_settings");
        ElserMlNodeServiceSettings.Builder serviceSettingsBuilder = ElserMlNodeServiceSettings.fromMap(serviceSettingsMap);
        if (serviceSettingsBuilder.getModelVariant() == null) {
            serviceSettingsBuilder.setModelVariant(defaultModelVariant(modelArchitectures));
        }

        Map<String, Object> taskSettingsMap = removeTaskSettings(config);
        // Validates the task type before the leftover-key checks, as callers may rely on that error first.
        ElserMlNodeTaskSettings taskSettings = taskSettingsFromMap(taskType, taskSettingsMap);

        // Request configs are strict: every key must have been consumed by now.
        ServiceUtils.throwIfNotEmptyMap(config, NAME);
        ServiceUtils.throwIfNotEmptyMap(serviceSettingsMap, NAME);
        ServiceUtils.throwIfNotEmptyMap(taskSettingsMap, NAME);

        return new ElserMlNodeModel(modelId, taskType, NAME, serviceSettingsBuilder.build(), taskSettings);
    }

    /**
     * Parses a persisted configuration. {@code secrets} is ignored; the result is the same as
     * {@link #parsePersistedConfig(String, TaskType, Map)}.
     */
    public ElserMlNodeModel parsePersistedConfigWithSecrets(String modelId, TaskType taskType, Map<String, Object> config, Map<String, Object> secrets) {
        return parsePersistedConfig(modelId, taskType, config);
    }

    /**
     * Parses a persisted model configuration.
     *
     * <p>Unlike {@link #parseRequestConfig}, unrecognised keys are not rejected and no default
     * model variant is filled in.
     *
     * @throws ElasticsearchStatusException with {@code BAD_REQUEST} if {@code taskType} is not
     *                                      {@link TaskType#SPARSE_EMBEDDING}
     */
    public ElserMlNodeModel parsePersistedConfig(String modelId, TaskType taskType, Map<String, Object> config) {
        Map<String, Object> serviceSettingsMap = ServiceUtils.removeFromMapOrThrowIfNull(config, "service_settings");
        ElserMlNodeServiceSettings.Builder serviceSettingsBuilder = ElserMlNodeServiceSettings.fromMap(serviceSettingsMap);
        Map<String, Object> taskSettingsMap = removeTaskSettings(config);
        ElserMlNodeTaskSettings taskSettings = taskSettingsFromMap(taskType, taskSettingsMap);
        return new ElserMlNodeModel(modelId, taskType, NAME, serviceSettingsBuilder.build(), taskSettings);
    }

    /**
     * Deploys the model variant named in the model's service settings, using its allocation and
     * thread counts, and waits for the deployment to reach {@code STARTED}.
     *
     * @param model    must be an {@link ElserMlNodeModel} with task type {@link TaskType#SPARSE_EMBEDDING}
     * @param listener receives {@code true} once the deployment request succeeds; any other model or
     *                 task type fails it with an {@link IllegalStateException}
     */
    public void start(Model model, ActionListener<Boolean> listener) {
        if (!(model instanceof ElserMlNodeModel)) {
            listener.onFailure(
                new IllegalStateException("Error starting model, [" + model.getConfigurations().getModelId() + "] is not an elser model")
            );
            return;
        }
        if (model.getConfigurations().getTaskType() != TaskType.SPARSE_EMBEDDING) {
            listener.onFailure(
                new IllegalStateException(TaskType.unsupportedTaskTypeErrorMsg(model.getConfigurations().getTaskType(), NAME))
            );
            return;
        }

        ElserMlNodeModel elserModel = (ElserMlNodeModel) model;
        ElserMlNodeServiceSettings serviceSettings = elserModel.getServiceSettings();

        // The trained model to deploy is the configured variant; the deployment id is the inference model id.
        StartTrainedModelDeploymentAction.Request startRequest = new StartTrainedModelDeploymentAction.Request(
            serviceSettings.getModelVariant(),
            model.getConfigurations().getModelId()
        );
        startRequest.setNumberOfAllocations(serviceSettings.getNumAllocations());
        startRequest.setThreadsPerAllocation(serviceSettings.getNumThreads());
        startRequest.setWaitForState(AllocationStatus.State.STARTED);

        client.execute(
            StartTrainedModelDeploymentAction.INSTANCE,
            startRequest,
            listener.delegateFailureAndWrap((l, response) -> l.onResponse(Boolean.TRUE))
        );
    }

    /**
     * Runs text-expansion inference on the model's deployment and returns the results as
     * {@link SparseEmbeddingResults}.
     *
     * @param model        the deployed model; its task type must be accepted by
     *                     {@code TaskType.SPARSE_EMBEDDING.isAnyOrSame}
     * @param input        texts to embed
     * @param taskSettings ignored; this service has no per-request task settings
     * @param listener     receives the results, or an {@link ElasticsearchStatusException} with
     *                     {@code BAD_REQUEST} for an unsupported task type
     */
    public void infer(Model model, List<String> input, Map<String, Object> taskSettings, ActionListener<InferenceServiceResults> listener) {
        if (!TaskType.SPARSE_EMBEDDING.isAnyOrSame(model.getConfigurations().getTaskType())) {
            listener.onFailure(
                new ElasticsearchStatusException(
                    TaskType.unsupportedTaskTypeErrorMsg(model.getConfigurations().getTaskType(), NAME),
                    RestStatus.BAD_REQUEST
                )
            );
            return;
        }

        InferTrainedModelDeploymentAction.Request request = InferTrainedModelDeploymentAction.Request.forTextInput(
            model.getConfigurations().getModelId(),
            TextExpansionConfigUpdate.EMPTY_UPDATE,
            input,
            INFER_TIMEOUT
        );
        client.execute(
            InferTrainedModelDeploymentAction.INSTANCE,
            request,
            listener.delegateFailureAndWrap((l, inferenceResult) -> l.onResponse(SparseEmbeddingResults.of(inferenceResult.getResults())))
        );
    }

    /**
     * Picks the ELSER v2 variant to use when none was configured: the linux-x86_64 build if that is
     * the only architecture given, the platform-agnostic build otherwise (including an empty set).
     */
    private static String defaultModelVariant(Set<String> modelArchitectures) {
        boolean homogeneous = modelArchitectures.size() == 1;
        if (homogeneous && LINUX_X86_64_ARCHITECTURE.equals(modelArchitectures.iterator().next())) {
            return ELSER_V2_MODEL_LINUX_X86;
        }
        return ELSER_V2_MODEL;
    }

    /** Removes and returns the optional {@code task_settings} section of {@code config}, or an empty map if it is absent. */
    private static Map<String, Object> removeTaskSettings(Map<String, Object> config) {
        if (config.containsKey("task_settings")) {
            return ServiceUtils.removeFromMapOrThrowIfNull(config, "task_settings");
        }
        return Map.of();
    }

    /**
     * Builds the task settings for {@code taskType}. ELSER has no configurable task settings, so
     * {@code config} is not read and {@link ElserMlNodeTaskSettings#DEFAULT} is always returned.
     *
     * @throws ElasticsearchStatusException with {@code BAD_REQUEST} if {@code taskType} is not
     *                                      {@link TaskType#SPARSE_EMBEDDING}
     */
    private static ElserMlNodeTaskSettings taskSettingsFromMap(TaskType taskType, Map<String, Object> config) {
        if (taskType != TaskType.SPARSE_EMBEDDING) {
            throw new ElasticsearchStatusException(TaskType.unsupportedTaskTypeErrorMsg(taskType, NAME), RestStatus.BAD_REQUEST);
        }
        return ElserMlNodeTaskSettings.DEFAULT;
    }

    /** @return {@value #NAME} */
    public String name() {
        return NAME;
    }

    /** No-op; this class opens nothing that it would need to release here. */
    public void close() throws IOException {
    }

    /** @return the first transport version that supports this service */
    public TransportVersion getMinimalSupportedVersion() {
        return TransportVersions.ELSER_SERVICE_MODEL_VERSION_ADDED;
    }
}

