/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.inference.services.azureaistudio;

import java.util.EnumSet;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Stream;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.util.LazyInitializable;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.inference.ChunkedInferenceServiceResults;
import org.elasticsearch.inference.ChunkingOptions;
import org.elasticsearch.inference.ChunkingSettings;
import org.elasticsearch.inference.EmptySettingsConfiguration;
import org.elasticsearch.inference.InferenceServiceConfiguration;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.inference.InputType;
import org.elasticsearch.inference.Model;
import org.elasticsearch.inference.SettingsConfiguration;
import org.elasticsearch.inference.SimilarityMeasure;
import org.elasticsearch.inference.TaskSettingsConfiguration;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.inference.configuration.SettingsConfigurationDisplayType;
import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType;
import org.elasticsearch.inference.configuration.SettingsConfigurationSelectOption;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder;
import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker;
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
import org.elasticsearch.xpack.inference.external.action.azureaistudio.AzureAiStudioActionCreator;
import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput;
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs;
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
import org.elasticsearch.xpack.inference.services.SenderService;
import org.elasticsearch.xpack.inference.services.ServiceComponents;
import org.elasticsearch.xpack.inference.services.ServiceUtils;
import org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioEndpointType;
import org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioModel;
import org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioProvider;
import org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioProviderCapabilities;
import org.elasticsearch.xpack.inference.services.azureaistudio.completion.AzureAiStudioChatCompletionModel;
import org.elasticsearch.xpack.inference.services.azureaistudio.completion.AzureAiStudioChatCompletionTaskSettings;
import org.elasticsearch.xpack.inference.services.azureaistudio.embeddings.AzureAiStudioEmbeddingsModel;
import org.elasticsearch.xpack.inference.services.azureaistudio.embeddings.AzureAiStudioEmbeddingsServiceSettings;
import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
import org.elasticsearch.xpack.inference.services.validation.ModelValidatorBuilder;

/**
 * Inference service for models deployed through Azure AI Studio.
 *
 * <p>Models can be created for {@link TaskType#TEXT_EMBEDDING} and {@link TaskType#COMPLETION};
 * streaming is advertised for completion only. Requests are turned into executable actions by an
 * {@link AzureAiStudioActionCreator} built from the sender and components provided by
 * {@link SenderService}.
 */
public class AzureAiStudioService extends SenderService {
    static final String NAME = "azureaistudio";

    /** Task types for which {@link #createModel} can build a model. */
    private static final EnumSet<TaskType> supportedTaskTypes = EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.COMPLETION);

    /** Maximum number of inputs grouped into one embedding request by the chunker. */
    private static final int EMBEDDING_MAX_BATCH_SIZE = 2048;

    public AzureAiStudioService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) {
        super(factory, serviceComponents);
    }

    /**
     * Executes a single, non-chunked inference call.
     *
     * <p>{@code model} must be an {@link AzureAiStudioModel}; any other model type is reported to
     * {@code listener} as an invalid-model failure. {@code inputType} is not used by this service.
     */
    @Override
    protected void doInfer(
        Model model,
        InferenceInputs inputs,
        Map<String, Object> taskSettings,
        InputType inputType,
        TimeValue timeout,
        ActionListener<InferenceServiceResults> listener
    ) {
        if (model instanceof AzureAiStudioModel azureModel) {
            AzureAiStudioActionCreator actionCreator = new AzureAiStudioActionCreator(getSender(), getServiceComponents());
            ExecutableAction action = azureModel.accept(actionCreator, taskSettings);
            action.execute(inputs, timeout, listener);
        } else {
            listener.onFailure(ServiceUtils.createInvalidModelException(model));
        }
    }

    /**
     * Splits the documents into batches with {@link EmbeddingRequestChunker} (using the model's
     * chunking settings and float embeddings) and issues one action per batch; each batch reports
     * to its own listener produced by the chunker.
     *
     * <p>Non-{@link AzureAiStudioModel} models are reported to {@code listener} as invalid.
     * {@code inputType} and {@code chunkingOptions} are not used here.
     */
    @Override
    protected void doChunkedInfer(
        Model model,
        DocumentsOnlyInput inputs,
        Map<String, Object> taskSettings,
        InputType inputType,
        ChunkingOptions chunkingOptions,
        TimeValue timeout,
        ActionListener<List<ChunkedInferenceServiceResults>> listener
    ) {
        if (model instanceof AzureAiStudioModel azureModel) {
            AzureAiStudioActionCreator actionCreator = new AzureAiStudioActionCreator(getSender(), getServiceComponents());
            List<EmbeddingRequestChunker.BatchRequestAndListener> batchedRequests = new EmbeddingRequestChunker(
                inputs.getInputs(),
                EMBEDDING_MAX_BATCH_SIZE,
                EmbeddingRequestChunker.EmbeddingType.FLOAT,
                azureModel.getConfigurations().getChunkingSettings()
            ).batchRequestsWithListeners(listener);
            for (EmbeddingRequestChunker.BatchRequestAndListener request : batchedRequests) {
                ExecutableAction action = azureModel.accept(actionCreator, taskSettings);
                action.execute(new DocumentsOnlyInput(request.batch().inputs()), timeout, request.listener());
            }
        } else {
            listener.onFailure(ServiceUtils.createInvalidModelException(model));
        }
    }

    /**
     * Parses a model from a user request and hands it to {@code parsedModelListener}.
     *
     * <p>{@code service_settings} is required; {@code task_settings} defaults to empty; chunking
     * settings are read only for text-embedding models. Entries left over in any of the maps after
     * parsing are rejected via {@code throwIfNotEmptyMap}. Every exception raised while parsing is
     * delivered through {@code parsedModelListener.onFailure} rather than thrown.
     */
    public void parseRequestConfig(String inferenceEntityId, TaskType taskType, Map<String, Object> config, ActionListener<Model> parsedModelListener) {
        try {
            Map<String, Object> serviceSettingsMap = ServiceUtils.removeFromMapOrThrowIfNull(config, "service_settings");
            Map<String, Object> taskSettingsMap = ServiceUtils.removeFromMapOrDefaultEmpty(config, "task_settings");

            ChunkingSettings chunkingSettings = null;
            if (TaskType.TEXT_EMBEDDING.equals(taskType)) {
                chunkingSettings = ChunkingSettingsBuilder.fromMap(ServiceUtils.removeFromMapOrDefaultEmpty(config, "chunking_settings"));
            }

            // The service-settings map is deliberately passed again as the secret-settings source:
            // in a request, secret fields are presumably supplied inside service_settings
            // (NOTE(review): confirm against the model constructors).
            AzureAiStudioModel model = createModel(
                inferenceEntityId,
                taskType,
                serviceSettingsMap,
                taskSettingsMap,
                chunkingSettings,
                serviceSettingsMap,
                TaskType.unsupportedTaskTypeErrorMsg(taskType, NAME),
                ConfigurationParseContext.REQUEST
            );

            ServiceUtils.throwIfNotEmptyMap(config, NAME);
            ServiceUtils.throwIfNotEmptyMap(serviceSettingsMap, NAME);
            ServiceUtils.throwIfNotEmptyMap(taskSettingsMap, NAME);

            parsedModelListener.onResponse(model);
        } catch (Exception e) {
            parsedModelListener.onFailure(e);
        }
    }

    /**
     * Rebuilds a stored model, including its secrets, from persisted configuration.
     *
     * <p>Chunking settings are read (possibly absent) only for text-embedding models.
     */
    public AzureAiStudioModel parsePersistedConfigWithSecrets(String inferenceEntityId, TaskType taskType, Map<String, Object> config, Map<String, Object> secrets) {
        Map<String, Object> serviceSettingsMap = ServiceUtils.removeFromMapOrThrowIfNull(config, "service_settings");
        Map<String, Object> taskSettingsMap = ServiceUtils.removeFromMapOrDefaultEmpty(config, "task_settings");
        Map<String, Object> secretSettingsMap = ServiceUtils.removeFromMapOrDefaultEmpty(secrets, "secret_settings");

        ChunkingSettings chunkingSettings = null;
        if (TaskType.TEXT_EMBEDDING.equals(taskType)) {
            chunkingSettings = ChunkingSettingsBuilder.fromMap(ServiceUtils.removeFromMap(config, "chunking_settings"));
        }

        return createModelFromPersistent(
            inferenceEntityId,
            taskType,
            serviceSettingsMap,
            taskSettingsMap,
            chunkingSettings,
            secretSettingsMap,
            ServiceUtils.parsePersistedConfigErrorMsg(inferenceEntityId, NAME)
        );
    }

    /**
     * Rebuilds a stored model from persisted configuration without secrets
     * ({@code null} is passed as the secret settings).
     */
    public Model parsePersistedConfig(String inferenceEntityId, TaskType taskType, Map<String, Object> config) {
        Map<String, Object> serviceSettingsMap = ServiceUtils.removeFromMapOrThrowIfNull(config, "service_settings");
        Map<String, Object> taskSettingsMap = ServiceUtils.removeFromMapOrDefaultEmpty(config, "task_settings");

        ChunkingSettings chunkingSettings = null;
        if (TaskType.TEXT_EMBEDDING.equals(taskType)) {
            chunkingSettings = ChunkingSettingsBuilder.fromMap(ServiceUtils.removeFromMap(config, "chunking_settings"));
        }

        return createModelFromPersistent(
            inferenceEntityId,
            taskType,
            serviceSettingsMap,
            taskSettingsMap,
            chunkingSettings,
            null,
            ServiceUtils.parsePersistedConfigErrorMsg(inferenceEntityId, NAME)
        );
    }

    /** Returns the lazily built settings schema for this service. */
    public InferenceServiceConfiguration getConfiguration() {
        return Configuration.get();
    }

    /** Returns the task types this service supports. */
    public EnumSet<TaskType> supportedTaskTypes() {
        return supportedTaskTypes;
    }

    /** Returns the service name, {@value #NAME}. */
    public String name() {
        return NAME;
    }

    /** Returns the oldest transport version able to handle this service (8.15.0). */
    public TransportVersion getMinimalSupportedVersion() {
        return TransportVersions.V_8_15_0;
    }

    /** Returns the task types for which streaming is supported. */
    public Set<TaskType> supportedStreamingTasks() {
        return COMPLETION_ONLY;
    }

    /**
     * Builds the concrete model for {@code taskType} and checks that the configured provider and
     * endpoint type support that task.
     *
     * @throws ElasticsearchStatusException (400) with {@code failureMessage} if the task type is
     *         neither text embedding nor completion, or if the provider/endpoint check fails
     */
    private static AzureAiStudioModel createModel(
        String inferenceEntityId,
        TaskType taskType,
        Map<String, Object> serviceSettings,
        Map<String, Object> taskSettings,
        ChunkingSettings chunkingSettings,
        @Nullable Map<String, Object> secretSettings,
        String failureMessage,
        ConfigurationParseContext context
    ) {
        // Identity comparisons (rather than a switch) keep a null taskType on the failure path.
        if (taskType == TaskType.TEXT_EMBEDDING) {
            AzureAiStudioEmbeddingsModel embeddingsModel = new AzureAiStudioEmbeddingsModel(
                inferenceEntityId,
                taskType,
                NAME,
                serviceSettings,
                taskSettings,
                chunkingSettings,
                secretSettings,
                context
            );
            checkProviderAndEndpointTypeForTask(
                TaskType.TEXT_EMBEDDING,
                embeddingsModel.getServiceSettings().provider(),
                embeddingsModel.getServiceSettings().endpointType()
            );
            return embeddingsModel;
        }

        if (taskType == TaskType.COMPLETION) {
            AzureAiStudioChatCompletionModel completionModel = new AzureAiStudioChatCompletionModel(
                inferenceEntityId,
                taskType,
                NAME,
                serviceSettings,
                taskSettings,
                secretSettings,
                context
            );
            checkProviderAndEndpointTypeForTask(
                TaskType.COMPLETION,
                completionModel.getServiceSettings().provider(),
                completionModel.getServiceSettings().endpointType()
            );
            return completionModel;
        }

        throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST);
    }

    /** {@link #createModel} with {@link ConfigurationParseContext#PERSISTENT}; secrets may be {@code null}. */
    private AzureAiStudioModel createModelFromPersistent(
        String inferenceEntityId,
        TaskType taskType,
        Map<String, Object> serviceSettings,
        Map<String, Object> taskSettings,
        ChunkingSettings chunkingSettings,
        @Nullable Map<String, Object> secretSettings,
        String failureMessage
    ) {
        return createModel(
            inferenceEntityId,
            taskType,
            serviceSettings,
            taskSettings,
            chunkingSettings,
            secretSettings,
            failureMessage,
            ConfigurationParseContext.PERSISTENT
        );
    }

    /** Validates {@code model} with the validator chosen for its task type. */
    public void checkModelConfig(Model model, ActionListener<Model> listener) {
        ModelValidatorBuilder.buildModelValidator(model.getTaskType()).validate(this, model, listener);
    }

    /**
     * Returns a copy of an embeddings model whose service settings carry {@code embeddingSize}
     * (presumably as the dimensions) and {@link SimilarityMeasure#DOT_PRODUCT} when no similarity
     * was configured. All other service settings are carried over unchanged.
     *
     * @throws RuntimeException from {@code invalidModelTypeForUpdateModelWithEmbeddingDetails}
     *         if {@code model} is not an {@link AzureAiStudioEmbeddingsModel}
     */
    public Model updateModelWithEmbeddingDetails(Model model, int embeddingSize) {
        if (model instanceof AzureAiStudioEmbeddingsModel embeddingsModel) {
            AzureAiStudioEmbeddingsServiceSettings serviceSettings = embeddingsModel.getServiceSettings();
            SimilarityMeasure similarityFromModel = serviceSettings.similarity();
            SimilarityMeasure similarityToUse = similarityFromModel == null ? SimilarityMeasure.DOT_PRODUCT : similarityFromModel;

            AzureAiStudioEmbeddingsServiceSettings updatedServiceSettings = new AzureAiStudioEmbeddingsServiceSettings(
                serviceSettings.target(),
                serviceSettings.provider(),
                serviceSettings.endpointType(),
                embeddingSize,
                serviceSettings.dimensionsSetByUser(),
                serviceSettings.maxInputTokens(),
                similarityToUse,
                serviceSettings.rateLimitSettings()
            );
            return new AzureAiStudioEmbeddingsModel(embeddingsModel, updatedServiceSettings);
        }
        throw ServiceUtils.invalidModelTypeForUpdateModelWithEmbeddingDetails(model.getClass());
    }

    /**
     * Returns a copy of a chat-completion model whose task settings use
     * {@code AzureAiStudioChatCompletionTaskSettings.DEFAULT_MAX_NEW_TOKENS} when no max-new-tokens
     * value was configured. Temperature, top-p and do-sample are carried over unchanged.
     *
     * @throws RuntimeException from {@code invalidModelTypeForUpdateModelWithChatCompletionDetails}
     *         if {@code model} is not an {@link AzureAiStudioChatCompletionModel}
     */
    public Model updateModelWithChatCompletionDetails(Model model) {
        if (model instanceof AzureAiStudioChatCompletionModel chatCompletionModel) {
            AzureAiStudioChatCompletionTaskSettings taskSettings = chatCompletionModel.getTaskSettings();
            Integer modelMaxNewTokens = taskSettings.maxNewTokens();
            Integer maxNewTokensToUse = modelMaxNewTokens == null
                ? AzureAiStudioChatCompletionTaskSettings.DEFAULT_MAX_NEW_TOKENS
                : modelMaxNewTokens;

            AzureAiStudioChatCompletionTaskSettings updatedTaskSettings = new AzureAiStudioChatCompletionTaskSettings(
                taskSettings.temperature(),
                taskSettings.topP(),
                taskSettings.doSample(),
                maxNewTokensToUse
            );
            return new AzureAiStudioChatCompletionModel(chatCompletionModel, updatedTaskSettings);
        }
        throw ServiceUtils.invalidModelTypeForUpdateModelWithChatCompletionDetails(model.getClass());
    }

    /**
     * Rejects provider/task and provider/endpoint/task combinations that
     * {@link AzureAiStudioProviderCapabilities} does not allow.
     *
     * @throws ElasticsearchStatusException (400) describing the unsupported combination
     */
    private static void checkProviderAndEndpointTypeForTask(
        TaskType taskType,
        AzureAiStudioProvider provider,
        AzureAiStudioEndpointType endpointType
    ) {
        if (AzureAiStudioProviderCapabilities.providerAllowsTaskType(provider, taskType) == false) {
            throw new ElasticsearchStatusException(
                Strings.format("The [%s] task type for provider [%s] is not available", taskType, provider),
                RestStatus.BAD_REQUEST
            );
        }

        if (AzureAiStudioProviderCapabilities.providerAllowsEndpointTypeForTask(provider, taskType, endpointType) == false) {
            throw new ElasticsearchStatusException(
                Strings.format("The [%s] endpoint type with [%s] task type for provider [%s] is not available", endpointType, taskType, provider),
                RestStatus.BAD_REQUEST
            );
        }
    }

    /**
     * Settings schema advertised for this service: target, endpoint type and provider fields, plus
     * the default secret and rate-limit settings, and a per-task-type settings section. The schema
     * is built by a {@link LazyInitializable} on first access through {@link #get()}.
     */
    public static class Configuration {
        private static final LazyInitializable<InferenceServiceConfiguration, RuntimeException> configuration =
            new LazyInitializable<>(() -> {
                Map<String, SettingsConfiguration> configurationMap = new HashMap<>();

                configurationMap.put(
                    "target",
                    new SettingsConfiguration.Builder().setDisplay(SettingsConfigurationDisplayType.TEXTBOX)
                        .setLabel("Target")
                        .setOrder(2)
                        .setRequired(true)
                        .setSensitive(false)
                        .setTooltip("The target URL of your Azure AI Studio model deployment.")
                        .setType(SettingsConfigurationFieldType.STRING)
                        .build()
                );

                configurationMap.put(
                    "endpoint_type",
                    new SettingsConfiguration.Builder().setDisplay(SettingsConfigurationDisplayType.DROPDOWN)
                        .setLabel("Endpoint Type")
                        .setOrder(3)
                        .setRequired(true)
                        .setSensitive(false)
                        .setTooltip("Specifies the type of endpoint that is used in your model deployment.")
                        .setType(SettingsConfigurationFieldType.STRING)
                        .setOptions(
                            Stream.of("token", "realtime")
                                .map(v -> new SettingsConfigurationSelectOption.Builder().setLabelAndValue(v).build())
                                .toList()
                        )
                        .build()
                );

                // NOTE(review): "provider" uses the same order (3) as "endpoint_type"; confirm
                // whether a distinct position was intended before changing it.
                configurationMap.put(
                    "provider",
                    new SettingsConfiguration.Builder().setDisplay(SettingsConfigurationDisplayType.DROPDOWN)
                        .setLabel("Provider")
                        .setOrder(3)
                        .setRequired(true)
                        .setSensitive(false)
                        .setTooltip("The model provider for your deployment.")
                        .setType(SettingsConfigurationFieldType.STRING)
                        .setOptions(
                            Stream.of("cohere", "meta", "microsoft_phi", "mistral", "openai", "databricks")
                                .map(v -> new SettingsConfigurationSelectOption.Builder().setLabelAndValue(v).build())
                                .toList()
                        )
                        .build()
                );

                configurationMap.putAll(DefaultSecretSettings.toSettingsConfiguration());
                configurationMap.putAll(RateLimitSettings.toSettingsConfiguration());

                return new InferenceServiceConfiguration.Builder().setProvider(NAME)
                    .setTaskTypes(
                        supportedTaskTypes.stream()
                            .map(
                                t -> new TaskSettingsConfiguration.Builder().setTaskType(t)
                                    // Unqualified enum labels compile on every Java version with switch expressions.
                                    .setConfiguration(switch (t) {
                                        case TEXT_EMBEDDING -> AzureAiStudioEmbeddingsModel.Configuration.get();
                                        case COMPLETION -> AzureAiStudioChatCompletionModel.Configuration.get();
                                        default -> EmptySettingsConfiguration.get();
                                    })
                                    .build()
                            )
                            .toList()
                    )
                    .setConfiguration(configurationMap)
                    .build();
            });

        /** Returns the service configuration, building it on first use. */
        public static InferenceServiceConfiguration get() {
            return configuration.getOrCompute();
        }
    }
}

