/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.inference.services.custom;

import java.util.EnumSet;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.common.ValidationException;
import org.elasticsearch.common.util.LazyInitializable;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.core.Strings;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.inference.ChunkedInference;
import org.elasticsearch.inference.ChunkingSettings;
import org.elasticsearch.inference.InferenceServiceConfiguration;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.inference.InputType;
import org.elasticsearch.inference.Model;
import org.elasticsearch.inference.SimilarityMeasure;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.inference.validation.ServiceIntegrationValidator;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder;
import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker;
import org.elasticsearch.xpack.inference.external.action.ActionUtils;
import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction;
import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput;
import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput;
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs;
import org.elasticsearch.xpack.inference.external.http.sender.QueryAndDocsInputs;
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
import org.elasticsearch.xpack.inference.services.SenderService;
import org.elasticsearch.xpack.inference.services.ServiceComponents;
import org.elasticsearch.xpack.inference.services.ServiceUtils;
import org.elasticsearch.xpack.inference.services.custom.CustomModel;
import org.elasticsearch.xpack.inference.services.custom.CustomRequestManager;
import org.elasticsearch.xpack.inference.services.custom.CustomServiceSettings;
import org.elasticsearch.xpack.inference.services.custom.request.CompletionParameters;
import org.elasticsearch.xpack.inference.services.custom.request.CustomRequest;
import org.elasticsearch.xpack.inference.services.custom.request.EmbeddingParameters;
import org.elasticsearch.xpack.inference.services.custom.request.RequestParameters;
import org.elasticsearch.xpack.inference.services.custom.request.RerankParameters;
import org.elasticsearch.xpack.inference.services.validation.CustomServiceIntegrationValidator;

public class CustomService
extends SenderService {
    public static final String NAME = "custom";
    private static final String SERVICE_NAME = "Custom";
    private static final EnumSet<TaskType> supportedTaskTypes = EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING, TaskType.RERANK, TaskType.COMPLETION);

    public CustomService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) {
        super(factory, serviceComponents);
    }

    public String name() {
        return NAME;
    }

    public void parseRequestConfig(String inferenceEntityId, TaskType taskType, Map<String, Object> config, ActionListener<Model> parsedModelListener) {
        try {
            Map<String, Object> serviceSettingsMap = ServiceUtils.removeFromMapOrThrowIfNull(config, "service_settings");
            Map<String, Object> taskSettingsMap = ServiceUtils.removeFromMapOrDefaultEmpty(config, "task_settings");
            ChunkingSettings chunkingSettings = CustomService.extractChunkingSettings(config, taskType);
            CustomModel model = CustomService.createModel(inferenceEntityId, taskType, serviceSettingsMap, taskSettingsMap, serviceSettingsMap, chunkingSettings, ConfigurationParseContext.REQUEST);
            ServiceUtils.throwIfNotEmptyMap(config, NAME);
            ServiceUtils.throwIfNotEmptyMap(serviceSettingsMap, NAME);
            ServiceUtils.throwIfNotEmptyMap(taskSettingsMap, NAME);
            CustomService.validateConfiguration(model);
            parsedModelListener.onResponse((Object)model);
        }
        catch (Exception e) {
            parsedModelListener.onFailure(e);
        }
    }

    private static void validateConfiguration(CustomModel model) {
        try {
            new CustomRequest(CustomService.createParameters(model), model).createHttpRequest();
        }
        catch (IllegalStateException e) {
            ValidationException validationException = new ValidationException();
            validationException.addValidationError(Strings.format((String)"Failed to validate model configuration: %s", (Object[])new Object[]{e.getMessage()}));
            throw validationException;
        }
    }

    private static RequestParameters createParameters(CustomModel model) {
        return switch (model.getTaskType()) {
            case TaskType.RERANK -> RerankParameters.of(new QueryAndDocsInputs("test query", List.of("test input")));
            case TaskType.COMPLETION -> CompletionParameters.of(new ChatCompletionInput(List.of("test input")));
            case TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING -> EmbeddingParameters.of(new EmbeddingsInput(List.of("test input"), null, null), model.getServiceSettings().getInputTypeTranslator());
            default -> throw new IllegalStateException(Strings.format((String)"Unsupported task type [%s] for custom service", (Object[])new Object[]{model.getTaskType()}));
        };
    }

    private static ChunkingSettings extractChunkingSettings(Map<String, Object> config, TaskType taskType) {
        if (TaskType.TEXT_EMBEDDING.equals((Object)taskType)) {
            return ChunkingSettingsBuilder.fromMap(ServiceUtils.removeFromMap(config, "chunking_settings"));
        }
        return null;
    }

    public InferenceServiceConfiguration getConfiguration() {
        return Configuration.get();
    }

    public EnumSet<TaskType> supportedTaskTypes() {
        return supportedTaskTypes;
    }

    @Override
    protected void doUnifiedCompletionInfer(Model model, UnifiedChatInput inputs, TimeValue timeout, ActionListener<InferenceServiceResults> listener) {
        ServiceUtils.throwUnsupportedUnifiedCompletionOperation(NAME);
    }

    private static CustomModel createModelWithoutLoggingDeprecations(String inferenceEntityId, TaskType taskType, Map<String, Object> serviceSettings, Map<String, Object> taskSettings, @Nullable Map<String, Object> secretSettings, @Nullable ChunkingSettings chunkingSettings) {
        return CustomService.createModel(inferenceEntityId, taskType, serviceSettings, taskSettings, secretSettings, chunkingSettings, ConfigurationParseContext.PERSISTENT);
    }

    private static CustomModel createModel(String inferenceEntityId, TaskType taskType, Map<String, Object> serviceSettings, Map<String, Object> taskSettings, @Nullable Map<String, Object> secretSettings, @Nullable ChunkingSettings chunkingSettings, ConfigurationParseContext context) {
        if (!supportedTaskTypes.contains(taskType)) {
            throw new ElasticsearchStatusException(TaskType.unsupportedTaskTypeErrorMsg((TaskType)taskType, (String)NAME), RestStatus.BAD_REQUEST, new Object[0]);
        }
        return new CustomModel(inferenceEntityId, taskType, NAME, serviceSettings, taskSettings, secretSettings, chunkingSettings, context);
    }

    public CustomModel parsePersistedConfigWithSecrets(String inferenceEntityId, TaskType taskType, Map<String, Object> config, Map<String, Object> secrets) {
        Map<String, Object> serviceSettingsMap = ServiceUtils.removeFromMapOrThrowIfNull(config, "service_settings");
        Map<String, Object> taskSettingsMap = ServiceUtils.removeFromMapOrThrowIfNull(config, "task_settings");
        Map<String, Object> secretSettingsMap = ServiceUtils.removeFromMapOrThrowIfNull(secrets, "secret_settings");
        ChunkingSettings chunkingSettings = CustomService.extractChunkingSettings(config, taskType);
        return CustomService.createModelWithoutLoggingDeprecations(inferenceEntityId, taskType, serviceSettingsMap, taskSettingsMap, secretSettingsMap, chunkingSettings);
    }

    public CustomModel parsePersistedConfig(String inferenceEntityId, TaskType taskType, Map<String, Object> config) {
        Map<String, Object> serviceSettingsMap = ServiceUtils.removeFromMapOrThrowIfNull(config, "service_settings");
        Map<String, Object> taskSettingsMap = ServiceUtils.removeFromMapOrThrowIfNull(config, "task_settings");
        ChunkingSettings chunkingSettings = CustomService.extractChunkingSettings(config, taskType);
        return CustomService.createModelWithoutLoggingDeprecations(inferenceEntityId, taskType, serviceSettingsMap, taskSettingsMap, null, chunkingSettings);
    }

    @Override
    public void doInfer(Model model, InferenceInputs inputs, Map<String, Object> taskSettings, TimeValue timeout, ActionListener<InferenceServiceResults> listener) {
        if (!(model instanceof CustomModel)) {
            listener.onFailure((Exception)ServiceUtils.createInvalidModelException(model));
            return;
        }
        CustomModel customModel = (CustomModel)model;
        CustomModel overriddenModel = CustomModel.of(customModel, taskSettings);
        String failedToSendRequestErrorMessage = ActionUtils.constructFailedToSendRequestMessage(SERVICE_NAME);
        CustomRequestManager manager = CustomRequestManager.of(overriddenModel, this.getServiceComponents().threadPool());
        SenderExecutableAction action = new SenderExecutableAction(this.getSender(), manager, failedToSendRequestErrorMessage);
        action.execute(inputs, timeout, listener);
    }

    @Override
    protected void validateInputType(InputType inputType, Model model, ValidationException validationException) {
    }

    @Override
    protected void doChunkedInfer(Model model, EmbeddingsInput inputs, Map<String, Object> taskSettings, InputType inputType, TimeValue timeout, ActionListener<List<ChunkedInference>> listener) {
        if (!(model instanceof CustomModel)) {
            listener.onFailure((Exception)ServiceUtils.createInvalidModelException(model));
            return;
        }
        CustomModel customModel = (CustomModel)model;
        CustomModel overriddenModel = CustomModel.of(customModel, taskSettings);
        String failedToSendRequestErrorMessage = ActionUtils.constructFailedToSendRequestMessage(SERVICE_NAME);
        CustomRequestManager manager = CustomRequestManager.of(overriddenModel, this.getServiceComponents().threadPool());
        List<EmbeddingRequestChunker.BatchRequestAndListener> batchedRequests = new EmbeddingRequestChunker(inputs.getInputs(), customModel.getServiceSettings().getBatchSize(), customModel.getConfigurations().getChunkingSettings()).batchRequestsWithListeners(listener);
        for (EmbeddingRequestChunker.BatchRequestAndListener request : batchedRequests) {
            SenderExecutableAction action = new SenderExecutableAction(this.getSender(), manager, failedToSendRequestErrorMessage);
            action.execute(EmbeddingsInput.fromStrings(request.batch().inputs().get(), inputType), timeout, request.listener());
        }
    }

    public Model updateModelWithEmbeddingDetails(Model model, int embeddingSize) {
        CustomModel customModel;
        if (model instanceof CustomModel && (customModel = (CustomModel)model).getTaskType() == TaskType.TEXT_EMBEDDING) {
            CustomServiceSettings newServiceSettings = CustomService.getCustomServiceSettings(customModel, embeddingSize);
            return new CustomModel(customModel, newServiceSettings);
        }
        throw new ElasticsearchStatusException(Strings.format((String)"Can't update embedding details for model of type: [%s], task type: [%s]", (Object[])new Object[]{model.getClass().getSimpleName(), model.getTaskType()}), RestStatus.BAD_REQUEST, new Object[0]);
    }

    private static CustomServiceSettings getCustomServiceSettings(CustomModel customModel, int embeddingSize) {
        CustomServiceSettings serviceSettings = customModel.getServiceSettings();
        SimilarityMeasure similarityFromModel = serviceSettings.similarity();
        SimilarityMeasure similarityToUse = similarityFromModel == null ? SimilarityMeasure.DOT_PRODUCT : similarityFromModel;
        return new CustomServiceSettings(new CustomServiceSettings.TextEmbeddingSettings(similarityToUse, embeddingSize, serviceSettings.getMaxInputTokens(), serviceSettings.elementType()), serviceSettings.getUrl(), serviceSettings.getHeaders(), serviceSettings.getQueryParameters(), serviceSettings.getRequestContentString(), serviceSettings.getResponseJsonParser(), serviceSettings.rateLimitSettings(), serviceSettings.getBatchSize(), serviceSettings.getInputTypeTranslator());
    }

    public TransportVersion getMinimalSupportedVersion() {
        return TransportVersions.INFERENCE_CUSTOM_SERVICE_ADDED;
    }

    public boolean hideFromConfigurationApi() {
        return true;
    }

    public ServiceIntegrationValidator getServiceIntegrationValidator(TaskType taskType) {
        if (taskType == TaskType.RERANK) {
            return new CustomServiceIntegrationValidator();
        }
        return null;
    }

    public static class Configuration {
        private static final LazyInitializable<InferenceServiceConfiguration, RuntimeException> configuration = new LazyInitializable(() -> {
            HashMap configurationMap = new HashMap();
            return new InferenceServiceConfiguration.Builder().setService(CustomService.NAME).setName(CustomService.SERVICE_NAME).setTaskTypes(supportedTaskTypes).setConfigurations(configurationMap).build();
        });

        public static InferenceServiceConfiguration get() {
            return (InferenceServiceConfiguration)configuration.getOrCompute();
        }
    }
}

