/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.inference.services.sagemaker;

import java.io.IOException;
import java.util.EnumSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.Executor;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.support.SubscribableListener;
import org.elasticsearch.common.CheckedSupplier;
import org.elasticsearch.common.util.LazyInitializable;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.core.Strings;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.inference.ChunkInferenceInput;
import org.elasticsearch.inference.ChunkedInference;
import org.elasticsearch.inference.InferenceService;
import org.elasticsearch.inference.InferenceServiceConfiguration;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.inference.InputType;
import org.elasticsearch.inference.Model;
import org.elasticsearch.inference.SettingsConfiguration;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.inference.UnifiedCompletionRequest;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker;
import org.elasticsearch.xpack.inference.common.amazon.AwsSecretSettings;
import org.elasticsearch.xpack.inference.services.ServiceUtils;
import org.elasticsearch.xpack.inference.services.sagemaker.SageMakerClient;
import org.elasticsearch.xpack.inference.services.sagemaker.SageMakerInferenceRequest;
import org.elasticsearch.xpack.inference.services.sagemaker.model.SageMakerModel;
import org.elasticsearch.xpack.inference.services.sagemaker.model.SageMakerModelBuilder;
import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerSchema;
import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerSchemas;
import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerStreamSchema;
import software.amazon.awssdk.services.sagemakerruntime.model.InvokeEndpointRequest;
import software.amazon.awssdk.services.sagemakerruntime.model.InvokeEndpointResponse;
import software.amazon.awssdk.services.sagemakerruntime.model.InvokeEndpointWithResponseStreamRequest;

/**
 * {@link InferenceService} implementation registered under the name {@value #NAME}.
 *
 * <p>Model parsing is delegated to the {@link SageMakerModelBuilder}, request/response translation to the
 * {@link SageMakerSchemas} registry, and the actual endpoint invocation to the {@link SageMakerClient}.
 */
public class SageMakerService
implements InferenceService {
    public static final String NAME = "amazon_sagemaker";
    private static final String DISPLAY_NAME = "Amazon SageMaker";
    private static final List<String> ALIASES = List.of("sagemaker", "amazonsagemaker");
    /** Batch size used by {@link #chunkedInfer} when the model does not declare one. */
    private static final int DEFAULT_BATCH_SIZE = 256;
    /** Timeout used for endpoint invocations when the caller passes {@code null}. */
    private static final TimeValue DEFAULT_TIMEOUT = TimeValue.THIRTY_SECONDS;
    private static final TransportVersion ML_INFERENCE_SAGEMAKER = TransportVersion.fromName("ml_inference_sagemaker");
    private final SageMakerModelBuilder modelBuilder;
    private final SageMakerClient client;
    private final SageMakerSchemas schemas;
    private final ThreadPool threadPool;
    /** Built on first access to {@link #getConfiguration()}, from the supplier handed to the constructor. */
    private final LazyInitializable<InferenceServiceConfiguration, RuntimeException> configuration;

    /**
     * Creates the service.
     *
     * @param modelBuilder     parses request and persisted configuration into {@link SageMakerModel}s
     * @param client           client used to invoke endpoints; closed by {@link #close()}
     * @param schemas          registry used to look up the schema for a model and the supported task types
     * @param threadPool       supplies the executor and thread context used during inference
     * @param configurationMap supplier of the settings configurations; only invoked when the service
     *                         configuration is first requested
     */
    public SageMakerService(SageMakerModelBuilder modelBuilder, SageMakerClient client, SageMakerSchemas schemas, ThreadPool threadPool, CheckedSupplier<Map<String, SettingsConfiguration>, RuntimeException> configurationMap) {
        this.modelBuilder = modelBuilder;
        this.client = client;
        this.schemas = schemas;
        this.threadPool = threadPool;
        this.configuration = new LazyInitializable<>(
            () -> new InferenceServiceConfiguration.Builder().setService(NAME)
                .setName(DISPLAY_NAME)
                .setTaskTypes(this.supportedTaskTypes())
                .setConfigurations(configurationMap.get())
                .build()
        );
    }

    /** Returns the service name, {@value #NAME}. */
    public String name() {
        return NAME;
    }

    /** Returns the alternative names this service declares. */
    public List<String> aliases() {
        return ALIASES;
    }

    /**
     * Builds a model from a request configuration and completes {@code parsedModelListener} with it,
     * or with whatever exception the model builder throws.
     */
    public void parseRequestConfig(String modelId, TaskType taskType, Map<String, Object> config, ActionListener<Model> parsedModelListener) {
        ActionListener.completeWith(parsedModelListener, () -> this.modelBuilder.fromRequest(modelId, taskType, NAME, config));
    }

    /** Builds a model from persisted configuration together with its persisted secrets. */
    public Model parsePersistedConfigWithSecrets(String modelId, TaskType taskType, Map<String, Object> config, Map<String, Object> secrets) {
        return this.modelBuilder.fromStorage(modelId, taskType, NAME, config, secrets);
    }

    /** Builds a model from persisted configuration, passing {@code null} for the secrets. */
    public Model parsePersistedConfig(String modelId, TaskType taskType, Map<String, Object> config) {
        return this.modelBuilder.fromStorage(modelId, taskType, NAME, config, null);
    }

    /** Returns the service configuration, computing it on first access. */
    public InferenceServiceConfiguration getConfiguration() {
        return this.configuration.getOrCompute();
    }

    /** Returns the task types reported by the schema registry. */
    public EnumSet<TaskType> supportedTaskTypes() {
        return this.schemas.supportedTaskTypes();
    }

    /** Returns the streaming task types reported by the schema registry. */
    public Set<TaskType> supportedStreamingTasks() {
        return this.schemas.supportedStreamingTasks();
    }

    /**
     * Runs a single inference call against the model's endpoint.
     *
     * <p>The model is first overridden with {@code taskSettings}. When {@code stream} is true the streaming
     * schema and {@link SageMakerClient#invokeStream} are used, otherwise the regular schema and
     * {@link SageMakerClient#invoke}. Client failures are translated by the schema before being passed to
     * {@code listener}; exceptions thrown while preparing the request are reported through
     * {@link #internalFailure}.
     *
     * @param model    must be a {@link SageMakerModel}, otherwise the listener is failed
     * @param timeout  invocation timeout; {@link #DEFAULT_TIMEOUT} is used when {@code null}
     * @param listener receives the translated results or the failure
     */
    public void infer(Model model, @Nullable String query, @Nullable Boolean returnDocuments, @Nullable Integer topN, List<String> input, boolean stream, Map<String, Object> taskSettings, InputType inputType, @Nullable TimeValue timeout, ActionListener<InferenceServiceResults> listener) {
        if (!(model instanceof SageMakerModel)) {
            listener.onFailure(ServiceUtils.createInvalidModelException(model));
            return;
        }
        SageMakerInferenceRequest inferenceRequest = new SageMakerInferenceRequest(query, returnDocuments, topN, input, stream, inputType);
        TimeValue effectiveTimeout = timeout != null ? timeout : DEFAULT_TIMEOUT;
        try {
            SageMakerModel sageMakerModel = ((SageMakerModel) model).override(taskSettings);
            SageMakerClient.RegionAndSecrets regionAndSecrets = this.regionAndSecrets(sageMakerModel);
            if (stream) {
                SageMakerStreamSchema schema = this.schemas.streamSchemaFor(sageMakerModel);
                InvokeEndpointWithResponseStreamRequest request = schema.streamRequest(sageMakerModel, inferenceRequest);
                this.client.invokeStream(
                    regionAndSecrets,
                    request,
                    effectiveTimeout,
                    ActionListener.wrap(
                        response -> listener.onResponse(schema.streamResponse(sageMakerModel, response)),
                        e -> listener.onFailure(schema.error(sageMakerModel, e))
                    )
                );
            } else {
                SageMakerSchema schema = this.schemas.schemaFor(sageMakerModel);
                InvokeEndpointRequest request = schema.request(sageMakerModel, inferenceRequest);
                this.client.invoke(
                    regionAndSecrets,
                    request,
                    effectiveTimeout,
                    ActionListener.wrap(
                        response -> listener.onResponse(schema.response(sageMakerModel, response, this.threadPool.getThreadContext())),
                        e -> listener.onFailure(schema.error(sageMakerModel, e))
                    )
                );
            }
        }
        catch (Exception failure) {
            listener.onFailure(internalFailure(model, failure));
        }
    }

    /**
     * Pairs the model's region with its AWS secrets.
     *
     * @throws ElasticsearchStatusException with status 500 when the model carries no secrets; with
     *                                      assertions enabled this case trips an assertion instead
     */
    private SageMakerClient.RegionAndSecrets regionAndSecrets(SageMakerModel model) {
        Optional<AwsSecretSettings> secrets = model.awsSecretSettings();
        if (secrets.isEmpty()) {
            assert (false) : "Cannot invoke a model without secrets";
            throw new ElasticsearchStatusException(
                Strings.format("Attempting to infer using a model without API keys, inference id [%s]", model.getInferenceEntityId()),
                RestStatus.INTERNAL_SERVER_ERROR
            );
        }
        return new SageMakerClient.RegionAndSecrets(model.region(), secrets.get());
    }

    /**
     * Converts {@code cause} into an {@link ElasticsearchStatusException}.
     *
     * <p>An exception that already is an {@link ElasticsearchStatusException} is returned unchanged so its
     * status is preserved; anything else is wrapped in a 500 that names the inference id.
     */
    private static ElasticsearchStatusException internalFailure(Model model, Exception cause) {
        if (cause instanceof ElasticsearchStatusException) {
            return (ElasticsearchStatusException) cause;
        }
        return new ElasticsearchStatusException(
            "Failed to call SageMaker for inference id [{}].",
            RestStatus.INTERNAL_SERVER_ERROR,
            cause,
            model.getInferenceEntityId()
        );
    }

    /**
     * Runs a unified chat-completion request through the model's streaming schema.
     *
     * <p>Unlike {@link #infer}, no task-settings override is applied to the model here.
     *
     * @param model    must be a {@link SageMakerModel}, otherwise the listener is failed
     * @param timeout  invocation timeout; {@link #DEFAULT_TIMEOUT} is used when {@code null}
     * @param listener receives the translated stream response or the failure
     */
    public void unifiedCompletionInfer(Model model, UnifiedCompletionRequest request, @Nullable TimeValue timeout, ActionListener<InferenceServiceResults> listener) {
        if (!(model instanceof SageMakerModel)) {
            listener.onFailure(ServiceUtils.createInvalidModelException(model));
            return;
        }
        TimeValue effectiveTimeout = timeout != null ? timeout : DEFAULT_TIMEOUT;
        try {
            SageMakerModel sageMakerModel = (SageMakerModel) model;
            SageMakerClient.RegionAndSecrets regionAndSecrets = this.regionAndSecrets(sageMakerModel);
            SageMakerStreamSchema schema = this.schemas.streamSchemaFor(sageMakerModel);
            InvokeEndpointWithResponseStreamRequest sagemakerRequest = schema.chatCompletionStreamRequest(sageMakerModel, request);
            this.client.invokeStream(
                regionAndSecrets,
                sagemakerRequest,
                effectiveTimeout,
                ActionListener.wrap(
                    response -> listener.onResponse(schema.chatCompletionStreamResponse(sageMakerModel, response)),
                    e -> listener.onFailure(schema.chatCompletionError(sageMakerModel, e))
                )
            );
        }
        catch (Exception failure) {
            listener.onFailure(internalFailure(model, failure));
        }
    }

    /**
     * Splits {@code input} into batches and runs {@link #infer} for each batch, one after another.
     *
     * <p>An empty {@code input} completes the listener with an empty list and does nothing else. Otherwise
     * the {@link EmbeddingRequestChunker} produces the batches together with per-batch listeners that are
     * tied to {@code listener}; each batch is started only after the previous batch's listener has run.
     *
     * @param model    must be a {@link SageMakerModel}, otherwise the listener is failed
     * @param timeout  passed through to {@link #infer} for every batch
     * @param listener receives the chunked results or the failure
     */
    public void chunkedInfer(Model model, String query, List<ChunkInferenceInput> input, Map<String, Object> taskSettings, InputType inputType, @Nullable TimeValue timeout, ActionListener<List<ChunkedInference>> listener) {
        if (!(model instanceof SageMakerModel)) {
            listener.onFailure(ServiceUtils.createInvalidModelException(model));
            return;
        }
        if (input.isEmpty()) {
            listener.onResponse(List.of());
            // Stop here: the listener is already completed and must not be handed to the chunker as well.
            return;
        }
        try {
            SageMakerModel sageMakerModel = ((SageMakerModel) model).override(taskSettings);
            List<EmbeddingRequestChunker.BatchRequestAndListener> batchedRequests = new EmbeddingRequestChunker(
                input,
                sageMakerModel.batchSize().orElse(DEFAULT_BATCH_SIZE),
                sageMakerModel.getConfigurations().getChunkingSettings()
            ).batchRequestsWithListeners(listener);
            SubscribableListener<Object> chain = SubscribableListener.newSucceeded(null);
            for (EmbeddingRequestChunker.BatchRequestAndListener request : batchedRequests) {
                // The step's listener is completed only after the batch listener has run, which is what
                // serialises the batches.
                chain = chain.andThen(
                    this.threadPool.executor("inference_utility"),
                    this.threadPool.getThreadContext(),
                    (l, ignored) -> this.infer(
                        sageMakerModel,
                        query,
                        null,
                        null,
                        request.batch().inputs().get(),
                        false,
                        null,
                        inputType,
                        timeout,
                        ActionListener.runAfter(request.listener(), () -> l.onResponse(null))
                    )
                );
            }
            // Failures of the chain itself are forwarded to the caller; successful completion is a no-op
            // here because results flow through the per-batch listeners.
            chain.addListener(ActionListener.noop().delegateResponse((l, e) -> listener.onFailure(internalFailure(model, e))));
        }
        catch (Exception failure) {
            listener.onFailure(internalFailure(model, failure));
        }
    }

    /** Completes the listener with {@code true} immediately; no start-up work is performed. */
    public void start(Model model, TimeValue timeout, ActionListener<Boolean> listener) {
        listener.onResponse(true);
    }

    /**
     * Delegates to the model builder to produce a model updated with {@code embeddingSize}.
     *
     * <p>A model that is not a {@link SageMakerModel} causes the exception created by
     * {@code ServiceUtils.invalidModelTypeForUpdateModelWithEmbeddingDetails} to be thrown.
     */
    public Model updateModelWithEmbeddingDetails(Model model, int embeddingSize) {
        if (model instanceof SageMakerModel) {
            return this.modelBuilder.updateModelWithEmbeddingDetails((SageMakerModel) model, embeddingSize);
        }
        throw ServiceUtils.invalidModelTypeForUpdateModelWithEmbeddingDetails(model.getClass());
    }

    /** Returns the transport version named {@code ml_inference_sagemaker}. */
    public TransportVersion getMinimalSupportedVersion() {
        return ML_INFERENCE_SAGEMAKER;
    }

    /** Closes the underlying {@link SageMakerClient}. */
    public void close() throws IOException {
        this.client.close();
    }
}

