/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.inference.services.mistral;

import java.util.List;
import java.util.Map;
import java.util.Set;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.inference.ChunkedInferenceServiceResults;
import org.elasticsearch.inference.ChunkingOptions;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.inference.InputType;
import org.elasticsearch.inference.Model;
import org.elasticsearch.inference.SimilarityMeasure;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker;
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
import org.elasticsearch.xpack.inference.external.action.mistral.MistralActionCreator;
import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput;
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
import org.elasticsearch.xpack.inference.services.SenderService;
import org.elasticsearch.xpack.inference.services.ServiceComponents;
import org.elasticsearch.xpack.inference.services.ServiceUtils;
import org.elasticsearch.xpack.inference.services.mistral.embeddings.MistralEmbeddingsModel;
import org.elasticsearch.xpack.inference.services.mistral.embeddings.MistralEmbeddingsServiceSettings;

public class MistralService
extends SenderService {
    public static final String NAME = "mistral";

    public MistralService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) {
        super(factory, serviceComponents);
    }

    @Override
    protected void doInfer(Model model, List<String> input, Map<String, Object> taskSettings, InputType inputType, TimeValue timeout, ActionListener<InferenceServiceResults> listener) {
        MistralActionCreator actionCreator = new MistralActionCreator(this.getSender(), this.getServiceComponents());
        if (model instanceof MistralEmbeddingsModel) {
            MistralEmbeddingsModel mistralEmbeddingsModel = (MistralEmbeddingsModel)model;
            ExecutableAction action = mistralEmbeddingsModel.accept(actionCreator, taskSettings);
            action.execute(new DocumentsOnlyInput(input), timeout, listener);
        } else {
            listener.onFailure((Exception)ServiceUtils.createInvalidModelException(model));
        }
    }

    @Override
    protected void doInfer(Model model, String query, List<String> input, Map<String, Object> taskSettings, InputType inputType, TimeValue timeout, ActionListener<InferenceServiceResults> listener) {
        throw new UnsupportedOperationException("Mistral service does not support inference with query input");
    }

    @Override
    protected void doChunkedInfer(Model model, String query, List<String> input, Map<String, Object> taskSettings, InputType inputType, ChunkingOptions chunkingOptions, TimeValue timeout, ActionListener<List<ChunkedInferenceServiceResults>> listener) {
        MistralActionCreator actionCreator = new MistralActionCreator(this.getSender(), this.getServiceComponents());
        if (model instanceof MistralEmbeddingsModel) {
            MistralEmbeddingsModel mistralEmbeddingsModel = (MistralEmbeddingsModel)model;
            List<EmbeddingRequestChunker.BatchRequestAndListener> batchedRequests = new EmbeddingRequestChunker(input, 96, EmbeddingRequestChunker.EmbeddingType.FLOAT).batchRequestsWithListeners(listener);
            for (EmbeddingRequestChunker.BatchRequestAndListener request : batchedRequests) {
                ExecutableAction action = mistralEmbeddingsModel.accept(actionCreator, taskSettings);
                action.execute(new DocumentsOnlyInput(request.batch().inputs()), timeout, request.listener());
            }
        } else {
            listener.onFailure((Exception)ServiceUtils.createInvalidModelException(model));
        }
    }

    public String name() {
        return NAME;
    }

    public void parseRequestConfig(String modelId, TaskType taskType, Map<String, Object> config, Set<String> platfromArchitectures, ActionListener<Model> parsedModelListener) {
        try {
            Map<String, Object> serviceSettingsMap = ServiceUtils.removeFromMapOrThrowIfNull(config, "service_settings");
            Map<String, Object> taskSettingsMap = ServiceUtils.removeFromMapOrDefaultEmpty(config, "task_settings");
            MistralEmbeddingsModel model = MistralService.createModel(modelId, taskType, serviceSettingsMap, taskSettingsMap, serviceSettingsMap, TaskType.unsupportedTaskTypeErrorMsg((TaskType)taskType, (String)NAME), ConfigurationParseContext.REQUEST);
            ServiceUtils.throwIfNotEmptyMap(config, NAME);
            ServiceUtils.throwIfNotEmptyMap(serviceSettingsMap, NAME);
            ServiceUtils.throwIfNotEmptyMap(taskSettingsMap, NAME);
            parsedModelListener.onResponse((Object)model);
        }
        catch (Exception e) {
            parsedModelListener.onFailure(e);
        }
    }

    public Model parsePersistedConfigWithSecrets(String modelId, TaskType taskType, Map<String, Object> config, Map<String, Object> secrets) {
        Map<String, Object> serviceSettingsMap = ServiceUtils.removeFromMapOrThrowIfNull(config, "service_settings");
        Map<String, Object> taskSettingsMap = ServiceUtils.removeFromMapOrDefaultEmpty(config, "task_settings");
        Map<String, Object> secretSettingsMap = ServiceUtils.removeFromMapOrDefaultEmpty(secrets, "secret_settings");
        return this.createModelFromPersistent(modelId, taskType, serviceSettingsMap, taskSettingsMap, secretSettingsMap, ServiceUtils.parsePersistedConfigErrorMsg(modelId, NAME));
    }

    public Model parsePersistedConfig(String modelId, TaskType taskType, Map<String, Object> config) {
        Map<String, Object> serviceSettingsMap = ServiceUtils.removeFromMapOrThrowIfNull(config, "service_settings");
        Map<String, Object> taskSettingsMap = ServiceUtils.removeFromMapOrDefaultEmpty(config, "task_settings");
        return this.createModelFromPersistent(modelId, taskType, serviceSettingsMap, taskSettingsMap, null, ServiceUtils.parsePersistedConfigErrorMsg(modelId, NAME));
    }

    public TransportVersion getMinimalSupportedVersion() {
        return TransportVersions.ADD_MISTRAL_EMBEDDINGS_INFERENCE;
    }

    private static MistralEmbeddingsModel createModel(String modelId, TaskType taskType, Map<String, Object> serviceSettings, Map<String, Object> taskSettings, @Nullable Map<String, Object> secretSettings, String failureMessage, ConfigurationParseContext context) {
        if (taskType == TaskType.TEXT_EMBEDDING) {
            return new MistralEmbeddingsModel(modelId, taskType, NAME, serviceSettings, taskSettings, secretSettings, context);
        }
        throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST, new Object[0]);
    }

    private MistralEmbeddingsModel createModelFromPersistent(String inferenceEntityId, TaskType taskType, Map<String, Object> serviceSettings, Map<String, Object> taskSettings, Map<String, Object> secretSettings, String failureMessage) {
        return MistralService.createModel(inferenceEntityId, taskType, serviceSettings, taskSettings, secretSettings, failureMessage, ConfigurationParseContext.PERSISTENT);
    }

    public void checkModelConfig(Model model, ActionListener<Model> listener) {
        if (model instanceof MistralEmbeddingsModel) {
            MistralEmbeddingsModel embeddingsModel = (MistralEmbeddingsModel)model;
            ServiceUtils.getEmbeddingSize(model, this, (ActionListener<Integer>)listener.delegateFailureAndWrap((l, size) -> l.onResponse((Object)this.updateEmbeddingModelConfig(embeddingsModel, (int)size))));
        } else {
            listener.onResponse((Object)model);
        }
    }

    private MistralEmbeddingsModel updateEmbeddingModelConfig(MistralEmbeddingsModel embeddingsModel, int embeddingsSize) {
        MistralEmbeddingsServiceSettings embeddingServiceSettings = embeddingsModel.getServiceSettings();
        SimilarityMeasure similarityFromModel = embeddingsModel.getServiceSettings().similarity();
        SimilarityMeasure similarityToUse = similarityFromModel == null ? SimilarityMeasure.DOT_PRODUCT : similarityFromModel;
        MistralEmbeddingsServiceSettings serviceSettings = new MistralEmbeddingsServiceSettings(embeddingServiceSettings.model(), embeddingsSize, embeddingServiceSettings.maxInputTokens(), similarityToUse, embeddingServiceSettings.rateLimitSettings());
        return new MistralEmbeddingsModel(embeddingsModel, serviceSettings);
    }
}

