/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.inference.services.huggingface.elser;

import java.util.ArrayList;
import java.util.EnumSet;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.util.LazyInitializable;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.inference.ChunkInferenceInput;
import org.elasticsearch.inference.ChunkedInference;
import org.elasticsearch.inference.InferenceServiceConfiguration;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.inference.InputType;
import org.elasticsearch.inference.Model;
import org.elasticsearch.inference.SettingsConfiguration;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding;
import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceError;
import org.elasticsearch.xpack.core.inference.results.EmbeddingResults;
import org.elasticsearch.xpack.core.inference.results.ResultUtils;
import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults;
import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
import org.elasticsearch.xpack.core.inference.results.TextEmbeddingUtils;
import org.elasticsearch.xpack.core.ml.inference.results.ErrorInferenceResults;
import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput;
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
import org.elasticsearch.xpack.inference.services.ServiceComponents;
import org.elasticsearch.xpack.inference.services.ServiceUtils;
import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceBaseService;
import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceModel;
import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceModelParameters;
import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserModel;
import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;

public class HuggingFaceElserService
extends HuggingFaceBaseService {
    public static final String NAME = "hugging_face_elser";
    private static final String SERVICE_NAME = "Hugging Face ELSER";
    private static final EnumSet<TaskType> supportedTaskTypes = EnumSet.of(TaskType.SPARSE_EMBEDDING);

    public HuggingFaceElserService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) {
        super(factory, serviceComponents);
    }

    public String name() {
        return NAME;
    }

    /*
     * Enabled force condition propagation
     * Lifted jumps to return sites
     */
    @Override
    protected HuggingFaceModel createModel(HuggingFaceModelParameters input) {
        switch (input.taskType()) {
            case SPARSE_EMBEDDING: {
                return new HuggingFaceElserModel(input.inferenceEntityId(), input.taskType(), NAME, input.serviceSettings(), input.secretSettings(), input.context());
            }
            default: {
                throw new ElasticsearchStatusException(input.failureMessage(), RestStatus.BAD_REQUEST, new Object[0]);
            }
        }
    }

    @Override
    protected void doUnifiedCompletionInfer(Model model, UnifiedChatInput inputs, TimeValue timeout, ActionListener<InferenceServiceResults> listener) {
        ServiceUtils.throwUnsupportedUnifiedCompletionOperation(NAME);
    }

    @Override
    protected void doChunkedInfer(Model model, EmbeddingsInput inputs, Map<String, Object> taskSettings, InputType inputType, TimeValue timeout, ActionListener<List<ChunkedInference>> listener) {
        ActionListener inferListener = listener.delegateFailureAndWrap((delegate, response) -> delegate.onResponse(HuggingFaceElserService.translateToChunkedResults(inputs, response)));
        this.doInfer(model, inputs, taskSettings, timeout, (ActionListener<InferenceServiceResults>)inferListener);
    }

    private static List<ChunkedInference> translateToChunkedResults(EmbeddingsInput inputs, InferenceServiceResults inferenceResults) {
        if (inferenceResults instanceof TextEmbeddingFloatResults) {
            TextEmbeddingFloatResults textEmbeddingResults = (TextEmbeddingFloatResults)inferenceResults;
            TextEmbeddingUtils.validateInputSizeAgainstEmbeddings((List)ChunkInferenceInput.inputs(inputs.getInputs()), (int)textEmbeddingResults.embeddings().size());
            ArrayList<ChunkedInference> results = new ArrayList<ChunkedInference>(inputs.getInputs().size());
            for (int i = 0; i < inputs.getInputs().size(); ++i) {
                results.add((ChunkedInference)new ChunkedInferenceEmbedding(List.of(new EmbeddingResults.Chunk((EmbeddingResults.Embedding)textEmbeddingResults.embeddings().get(i), new ChunkedInference.TextOffset(0, inputs.getInputs().get(i).input().length())))));
            }
            return results;
        }
        if (inferenceResults instanceof SparseEmbeddingResults) {
            SparseEmbeddingResults sparseEmbeddingResults = (SparseEmbeddingResults)inferenceResults;
            List inputsAsList = ChunkInferenceInput.inputs(EmbeddingsInput.of(inputs).getInputs());
            return ChunkedInferenceEmbedding.listOf((List)inputsAsList, (SparseEmbeddingResults)sparseEmbeddingResults);
        }
        if (inferenceResults instanceof ErrorInferenceResults) {
            ErrorInferenceResults error = (ErrorInferenceResults)inferenceResults;
            return List.of(new ChunkedInferenceError(error.getException()));
        }
        String expectedClasses = Strings.format((String)"One of [%s,%s]", (Object[])new Object[]{TextEmbeddingFloatResults.class.getSimpleName(), SparseEmbeddingResults.class.getSimpleName()});
        throw ResultUtils.createInvalidChunkedResultException((String)expectedClasses, (String)inferenceResults.getWriteableName());
    }

    public InferenceServiceConfiguration getConfiguration() {
        return Configuration.get();
    }

    public boolean hideFromConfigurationApi() {
        return true;
    }

    public EnumSet<TaskType> supportedTaskTypes() {
        return supportedTaskTypes;
    }

    public TransportVersion getMinimalSupportedVersion() {
        return TransportVersions.V_8_12_0;
    }

    public static class Configuration {
        private static final LazyInitializable<InferenceServiceConfiguration, RuntimeException> configuration = new LazyInitializable(() -> {
            HashMap<String, SettingsConfiguration> configurationMap = new HashMap<String, SettingsConfiguration>();
            configurationMap.put("url", new SettingsConfiguration.Builder(supportedTaskTypes).setDescription("The URL endpoint to use for the requests.").setLabel("URL").setRequired(Boolean.valueOf(true)).setSensitive(Boolean.valueOf(false)).setUpdatable(Boolean.valueOf(false)).setType(SettingsConfigurationFieldType.STRING).build());
            configurationMap.putAll(DefaultSecretSettings.toSettingsConfiguration(supportedTaskTypes));
            configurationMap.putAll(RateLimitSettings.toSettingsConfiguration(supportedTaskTypes));
            return new InferenceServiceConfiguration.Builder().setService(HuggingFaceElserService.NAME).setName(HuggingFaceElserService.SERVICE_NAME).setTaskTypes(supportedTaskTypes).setConfigurations(configurationMap).build();
        });

        public static InferenceServiceConfiguration get() {
            return (InferenceServiceConfiguration)configuration.getOrCompute();
        }
    }
}

