/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.inference.services.sagemaker;

import java.io.Closeable;
import java.security.AccessController;
import java.security.PrivilegedExceptionAction;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Executor;
import java.util.concurrent.Flow;
import java.util.concurrent.Future;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.ExceptionsHelper;
import org.elasticsearch.SpecialPermission;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.support.ContextPreservingActionListener;
import org.elasticsearch.action.support.ListenerTimeouts;
import org.elasticsearch.common.cache.Cache;
import org.elasticsearch.common.cache.CacheBuilder;
import org.elasticsearch.common.cache.CacheLoader;
import org.elasticsearch.common.util.concurrent.FutureUtils;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.core.Tuple;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.inference.common.amazon.AwsSecretSettings;
import org.elasticsearch.xpack.inference.external.http.HttpSettings;
import org.reactivestreams.FlowAdapters;
import org.reactivestreams.Publisher;
import software.amazon.awssdk.auth.credentials.AwsBasicCredentials;
import software.amazon.awssdk.auth.credentials.AwsCredentials;
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider;
import software.amazon.awssdk.core.client.config.ClientOverrideConfiguration;
import software.amazon.awssdk.http.async.SdkAsyncHttpClient;
import software.amazon.awssdk.http.nio.netty.NettyNioAsyncHttpClient;
import software.amazon.awssdk.profiles.ProfileFile;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.sagemakerruntime.SageMakerRuntimeAsyncClient;
import software.amazon.awssdk.services.sagemakerruntime.SageMakerRuntimeAsyncClientBuilder;
import software.amazon.awssdk.services.sagemakerruntime.model.InvokeEndpointRequest;
import software.amazon.awssdk.services.sagemakerruntime.model.InvokeEndpointResponse;
import software.amazon.awssdk.services.sagemakerruntime.model.InvokeEndpointWithResponseStreamRequest;
import software.amazon.awssdk.services.sagemakerruntime.model.InvokeEndpointWithResponseStreamResponse;
import software.amazon.awssdk.services.sagemakerruntime.model.InvokeEndpointWithResponseStreamResponseHandler;
import software.amazon.awssdk.services.sagemakerruntime.model.ResponseStream;

public class SageMakerClient
implements Closeable {
    private static final Logger log = LogManager.getLogger(SageMakerClient.class);

    /**
     * Clients keyed by region and secrets. Entries expire 15 minutes after their last access, and the
     * removal listener closes each client as it leaves the cache.
     * <p>
     * The explicit type witness on {@code builder()} is required: without it the builder chain is
     * inferred as {@code CacheBuilder<Object, Object>} and the result cannot be assigned to this field.
     */
    private final Cache<RegionAndSecrets, SageMakerRuntimeAsyncClient> existingClients = CacheBuilder
        .<RegionAndSecrets, SageMakerRuntimeAsyncClient>builder()
        .removalListener(removal -> removal.getValue().close())
        .setExpireAfterAccess(TimeValue.timeValueMinutes(15L))
        .build();

    /** Loader used to create a client when the cache has no entry for a key. */
    private final CacheLoader<RegionAndSecrets, SageMakerRuntimeAsyncClient> clientFactory;

    /** Supplies the thread context, the timeout scheduling and the {@code inference_utility} executor. */
    private final ThreadPool threadPool;

    /**
     * Creates a client wrapper that builds SageMaker runtime clients on demand.
     *
     * @param clientFactory loader invoked when no cached client exists for a region/secrets key
     * @param threadPool    thread pool used for thread-context capture, timeouts and callback execution
     */
    public SageMakerClient(CacheLoader<RegionAndSecrets, SageMakerRuntimeAsyncClient> clientFactory, ThreadPool threadPool) {
        this.threadPool = threadPool;
        this.clientFactory = clientFactory;
    }

    /**
     * Sends a non-streaming endpoint invocation and reports the outcome to {@code listener}.
     * <p>
     * The client for {@code regionAndSecrets} is taken from the cache, or created through the client
     * factory. If creation fails the listener is failed immediately and no request is sent. Otherwise
     * the listener is wrapped so that the caller's thread context is restored, and a timeout is armed:
     * when it fires, the in-flight AWS future is cancelled and the listener is failed with
     * {@link RestStatus#REQUEST_TIMEOUT}. Both the success and the failure callback run on the
     * {@code inference_utility} executor.
     *
     * @param regionAndSecrets cache key identifying which client to use
     * @param request          the request passed to {@code invokeEndpoint}
     * @param timeout          how long to wait before failing the listener with a timeout
     * @param listener         receives the response or the failure
     */
    public void invoke(RegionAndSecrets regionAndSecrets, InvokeEndpointRequest request, TimeValue timeout, ActionListener<InvokeEndpointResponse> listener) {
        SageMakerRuntimeAsyncClient asyncClient;
        try {
            // The key must be passed with its declared type; an Object-typed key does not match computeIfAbsent(K, CacheLoader<K, V>).
            asyncClient = this.existingClients.computeIfAbsent(regionAndSecrets, this.clientFactory);
        } catch (ExecutionException e) {
            listener.onFailure(SageMakerClient.clientFailure(regionAndSecrets, e));
            return;
        }

        Executor utilityExecutor = this.threadPool.executor("inference_utility");
        ContextPreservingActionListener<InvokeEndpointResponse> contextPreservingListener = new ContextPreservingActionListener<>(
            this.threadPool.getThreadContext().newRestorableContext(false),
            listener
        );
        CompletableFuture<InvokeEndpointResponse> awsFuture = asyncClient.invokeEndpoint(request);
        ActionListener<InvokeEndpointResponse> timeoutListener = ListenerTimeouts.wrapWithTimeout(
            this.threadPool,
            timeout,
            utilityExecutor,
            contextPreservingListener,
            ignored -> {
                // Stop the AWS call first, then tell the caller why it never completed.
                FutureUtils.cancel(awsFuture);
                contextPreservingListener.onFailure(
                    new ElasticsearchStatusException("Request timed out after [{}]", RestStatus.REQUEST_TIMEOUT, timeout)
                );
            }
        );
        awsFuture.thenAcceptAsync(timeoutListener::onResponse, utilityExecutor)
            .exceptionallyAsync(t -> this.failAndMaybeThrowError(t, timeoutListener), utilityExecutor);
    }

    /**
     * Builds the exception reported when a client cannot be created for the given key.
     *
     * @param regionAndSecrets key whose region is included in the message
     * @param cause            the failure raised while creating the client
     * @return an {@link ElasticsearchStatusException} with {@link RestStatus#INTERNAL_SERVER_ERROR} wrapping {@code cause}
     */
    private static Exception clientFailure(RegionAndSecrets regionAndSecrets, Exception cause) {
        String region = regionAndSecrets.region();
        return new ElasticsearchStatusException(
            "failed to create SageMakerRuntime client for region [{}]",
            RestStatus.INTERNAL_SERVER_ERROR,
            cause,
            region
        );
    }

    /**
     * Routes a failed AWS future to {@code listener}.
     * <p>
     * A {@link CompletionException} is replaced by its cause. If the resulting throwable is an
     * {@link Exception} it is passed to the listener as-is. Anything else is offered to
     * {@code ExceptionsHelper.maybeError}, and any error found is handed to
     * {@code ExceptionsHelper.maybeDieOnAnotherThread}. It is then logged at WARN and passed to the
     * listener wrapped in a {@link RuntimeException}.
     *
     * @param t        the throwable reported by the future
     * @param listener the listener to fail
     * @return always {@code null}; the {@link Void} return type lets this be used as an {@code exceptionally} function
     */
    private Void failAndMaybeThrowError(Throwable t, ActionListener<?> listener) {
        Throwable failure = t instanceof CompletionException wrapper ? wrapper.getCause() : t;

        if (failure instanceof Exception exception) {
            listener.onFailure(exception);
            return null;
        }

        ExceptionsHelper.maybeError(failure).ifPresent(ExceptionsHelper::maybeDieOnAnotherThread);
        log.atWarn().withThrowable(failure).log("Unknown failure calling SageMaker.");
        listener.onFailure(new RuntimeException("Unknown failure calling SageMaker.", failure));
        return null;
    }

    /**
     * Sends a streaming endpoint invocation and reports the start of the stream to {@code listener}.
     * <p>
     * The client for {@code regionAndSecrets} is taken from the cache, or created through the client
     * factory. If creation fails the listener is failed immediately and no request is sent. The
     * listener receives a {@link SageMakerStream} as soon as the AWS response handler's
     * {@code onResponse} fires. The event stream publisher is attached separately through the
     * response processor. The timeout therefore covers the wait for the initial response: when it
     * fires, the AWS future is cancelled and the listener is failed with
     * {@link RestStatus#REQUEST_TIMEOUT}.
     *
     * @param regionAndSecrets cache key identifying which client to use
     * @param request          the request passed to {@code invokeEndpointWithResponseStream}
     * @param timeout          how long to wait for the initial response before failing the listener
     * @param listener         receives the stream handle or the failure
     */
    public void invokeStream(RegionAndSecrets regionAndSecrets, InvokeEndpointWithResponseStreamRequest request, TimeValue timeout, ActionListener<SageMakerStream> listener) {
        SageMakerRuntimeAsyncClient asyncClient;
        try {
            // The key must be passed with its declared type; an Object-typed key does not match computeIfAbsent(K, CacheLoader<K, V>).
            asyncClient = this.existingClients.computeIfAbsent(regionAndSecrets, this.clientFactory);
        } catch (ExecutionException e) {
            listener.onFailure(SageMakerClient.clientFailure(regionAndSecrets, e));
            return;
        }

        Executor utilityExecutor = this.threadPool.executor("inference_utility");
        ContextPreservingActionListener<SageMakerStream> contextPreservingListener = new ContextPreservingActionListener<>(
            this.threadPool.getThreadContext().newRestorableContext(false),
            listener
        );
        SageMakerStreamingResponseProcessor responseStreamProcessor = new SageMakerStreamingResponseProcessor();

        // The timeout callback is created before the AWS future exists, so the future is published through this reference.
        // NOTE(review): if the timeout fires before set() below, get() is still null - confirm FutureUtils.cancel tolerates that.
        AtomicReference<CompletableFuture<?>> cancelAwsRequestListener = new AtomicReference<>();
        ActionListener<SageMakerStream> timeoutListener = ListenerTimeouts.wrapWithTimeout(
            this.threadPool,
            timeout,
            utilityExecutor,
            contextPreservingListener,
            ignored -> {
                FutureUtils.cancel(cancelAwsRequestListener.get());
                contextPreservingListener.onFailure(
                    new ElasticsearchStatusException("Request timed out after [{}]", RestStatus.REQUEST_TIMEOUT, timeout)
                );
            }
        );

        InvokeEndpointWithResponseStreamResponseHandler responseStreamListener = InvokeEndpointWithResponseStreamResponseHandler.builder()
            .onResponse(response -> timeoutListener.onResponse(new SageMakerStream(response, responseStreamProcessor)))
            .onEventStream(publisher -> responseStreamProcessor.setPublisher(FlowAdapters.toFlowPublisher(publisher)))
            .build();

        CompletableFuture<Void> awsFuture = asyncClient.invokeEndpointWithResponseStream(request, responseStreamListener);
        cancelAwsRequestListener.set(awsFuture);
        awsFuture.exceptionallyAsync(t -> this.failAndMaybeThrowError(t, timeoutListener), utilityExecutor);
    }

    /**
     * Invalidates every entry in the client cache.
     */
    @Override
    public void close() {
        existingClients.invalidateAll();
    }

    /**
     * Key for the client cache: the region to call and the secrets to authenticate with.
     *
     * @param region         region name passed to the client builder
     * @param secretSettings source of the access key and secret key
     */
    public record RegionAndSecrets(String region, AwsSecretSettings secretSettings) {}

    /**
     * Connects the downstream subscriber to the AWS event-stream publisher, whichever arrives first.
     * <p>
     * {@link #subscribe} and {@link #setPublisher} each try to park their argument in {@link #holder}.
     * The first caller's compare-and-set succeeds and it waits. The second caller's compare-and-set
     * fails, so it takes the parked counterpart out of the holder and subscribes the subscriber to the
     * publisher. Only one subscriber is accepted; later ones receive {@link IllegalStateException}
     * through {@code onError}.
     */
    private static class SageMakerStreamingResponseProcessor implements Flow.Publisher<ResponseStream> {
        private static final Logger log = LogManager.getLogger(SageMakerStreamingResponseProcessor.class);

        /** Holds whichever side arrived first: (publisher, null) or (null, subscriber). Null while empty. */
        private final AtomicReference<Tuple<Flow.Publisher<ResponseStream>, Flow.Subscriber<? super ResponseStream>>> holder =
            new AtomicReference<>(null);

        /** Set by the first {@link #subscribe} call so that later calls are rejected. */
        private final AtomicBoolean subscribeCalledOnce = new AtomicBoolean(false);

        private SageMakerStreamingResponseProcessor() {}

        /**
         * Registers the single downstream subscriber.
         *
         * @param subscriber connected immediately if the publisher is already present, otherwise parked until it arrives
         */
        @Override
        public void subscribe(Flow.Subscriber<? super ResponseStream> subscriber) {
            if (!this.subscribeCalledOnce.compareAndSet(false, true)) {
                subscriber.onError(new IllegalStateException("Subscriber already set."));
                return;
            }
            if (this.holder.compareAndSet(null, Tuple.tuple(null, subscriber))) {
                log.debug("Subscriber waiting for connection.");
                return;
            }
            // The compare-and-set failed, so the publisher got here first.
            log.debug("Subscriber connecting to publisher.");
            Flow.Publisher<ResponseStream> publisher = this.holder.getAndSet(null).v1();
            publisher.subscribe(subscriber);
        }

        /**
         * Supplies the upstream publisher.
         *
         * @param publisher connected immediately if a subscriber is already parked, otherwise parked until one subscribes
         */
        private void setPublisher(Flow.Publisher<ResponseStream> publisher) {
            if (this.holder.compareAndSet(null, Tuple.tuple(publisher, null))) {
                log.debug("Publisher waiting for connection.");
                return;
            }
            // The compare-and-set failed, so the subscriber got here first.
            log.debug("Publisher connecting to subscriber.");
            Flow.Subscriber<? super ResponseStream> subscriber = this.holder.getAndSet(null).v2();
            publisher.subscribe(subscriber);
        }
    }

    /**
     * Result delivered to {@code invokeStream} listeners.
     *
     * @param response       the initial response received from the streaming invocation
     * @param responseStream publisher of the streamed {@link ResponseStream} events
     */
    public record SageMakerStream(InvokeEndpointWithResponseStreamResponse response, Flow.Publisher<ResponseStream> responseStream) {}

    public static class Factory
    implements CacheLoader<RegionAndSecrets, SageMakerRuntimeAsyncClient> {
        private final HttpSettings httpSettings;

        public Factory(HttpSettings httpSettings) {
            this.httpSettings = httpSettings;
        }

        public SageMakerRuntimeAsyncClient load(RegionAndSecrets key) throws Exception {
            SpecialPermission.check();
            return AccessController.doPrivileged(() -> {
                AwsBasicCredentials credentials = AwsBasicCredentials.create((String)key.secretSettings().accessKey().toString(), (String)key.secretSettings().secretKey().toString());
                StaticCredentialsProvider credentialsProvider = StaticCredentialsProvider.create((AwsCredentials)credentials);
                NettyNioAsyncHttpClient.Builder clientConfig = NettyNioAsyncHttpClient.builder().connectionTimeout(this.httpSettings.connectionTimeoutDuration());
                ClientOverrideConfiguration override = (ClientOverrideConfiguration)ClientOverrideConfiguration.builder().defaultProfileFileSupplier(() -> ((ProfileFile.Aggregator)ProfileFile.aggregator()).build()).defaultProfileFile(ProfileFile.aggregator().build()).retryPolicy(retryPolicy -> retryPolicy.numRetries(Integer.valueOf(3))).retryStrategy(retryStrategy -> retryStrategy.maxAttempts(3)).build();
                return (SageMakerRuntimeAsyncClient)((SageMakerRuntimeAsyncClientBuilder)((SageMakerRuntimeAsyncClientBuilder)((SageMakerRuntimeAsyncClientBuilder)((SageMakerRuntimeAsyncClientBuilder)SageMakerRuntimeAsyncClient.builder().credentialsProvider((AwsCredentialsProvider)credentialsProvider)).region(Region.of((String)key.region()))).httpClientBuilder((SdkAsyncHttpClient.Builder)clientConfig)).overrideConfiguration(override)).build();
            });
        }
    }
}

