/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.inference.services.sagemaker.schema;

import java.util.Locale;
import java.util.concurrent.Flow;
import java.util.function.BiFunction;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.ExceptionsHelper;
import org.elasticsearch.common.CheckedBiFunction;
import org.elasticsearch.common.CheckedSupplier;
import org.elasticsearch.core.Tuple;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.inference.UnifiedCompletionRequest;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xpack.core.inference.results.StreamingChatCompletionResults;
import org.elasticsearch.xpack.core.inference.results.StreamingUnifiedChatCompletionResults;
import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException;
import org.elasticsearch.xpack.inference.services.sagemaker.SageMakerClient;
import org.elasticsearch.xpack.inference.services.sagemaker.SageMakerInferenceRequest;
import org.elasticsearch.xpack.inference.services.sagemaker.model.SageMakerModel;
import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerSchema;
import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerStreamSchemaPayload;
import software.amazon.awssdk.core.SdkBytes;
import software.amazon.awssdk.services.sagemakerruntime.model.InvokeEndpointWithResponseStreamRequest;
import software.amazon.awssdk.services.sagemakerruntime.model.InvokeEndpointWithResponseStreamResponseHandler;
import software.amazon.awssdk.services.sagemakerruntime.model.ResponseStream;

public class SageMakerStreamSchema
extends SageMakerSchema {
    private final SageMakerStreamSchemaPayload payload;

    public SageMakerStreamSchema(SageMakerStreamSchemaPayload payload) {
        super(payload);
        this.payload = payload;
    }

    public InvokeEndpointWithResponseStreamRequest streamRequest(SageMakerModel model, SageMakerInferenceRequest request) {
        return this.streamRequest(model, (CheckedSupplier<SdkBytes, Exception>)((CheckedSupplier)() -> this.payload.requestBytes(model, request)));
    }

    private InvokeEndpointWithResponseStreamRequest streamRequest(SageMakerModel model, CheckedSupplier<SdkBytes, Exception> body) {
        try {
            return (InvokeEndpointWithResponseStreamRequest)this.createStreamRequest(model).accept(this.payload.accept(model)).contentType(this.payload.contentType(model)).body((SdkBytes)body.get()).build();
        }
        catch (ElasticsearchStatusException e) {
            throw e;
        }
        catch (Exception e) {
            throw new ElasticsearchStatusException("Failed to create SageMaker request for [%s]", RestStatus.INTERNAL_SERVER_ERROR, (Throwable)e, new Object[]{model.getInferenceEntityId()});
        }
    }

    public InferenceServiceResults streamResponse(SageMakerModel model, SageMakerClient.SageMakerStream response) {
        return new StreamingChatCompletionResults(this.streamResponse(model, response, this.payload::streamResponseBody, this::error));
    }

    private <T> Flow.Publisher<T> streamResponse(final SageMakerModel model, SageMakerClient.SageMakerStream response, final CheckedBiFunction<SageMakerModel, SdkBytes, T, Exception> parseFunction, final BiFunction<SageMakerModel, Exception, Exception> errorFunction) {
        return downstream -> response.responseStream().subscribe(new Flow.Subscriber<ResponseStream>(this){
            private volatile Flow.Subscription upstream;

            @Override
            public void onSubscribe(Flow.Subscription subscription) {
                this.upstream = subscription;
                downstream.onSubscribe(subscription);
            }

            @Override
            public void onNext(ResponseStream item) {
                if (item.sdkEventType() == ResponseStream.EventType.PAYLOAD_PART) {
                    item.accept(InvokeEndpointWithResponseStreamResponseHandler.Visitor.builder().onPayloadPart(payloadPart -> {
                        try {
                            downstream.onNext(parseFunction.apply((Object)model, (Object)payloadPart.bytes()));
                        }
                        catch (Exception e) {
                            downstream.onError((Throwable)errorFunction.apply(model, e));
                        }
                    }).build());
                } else {
                    assert (this.upstream != null) : "upstream is unset";
                    this.upstream.request(1L);
                }
            }

            @Override
            public void onError(Throwable throwable) {
                if (throwable instanceof Exception) {
                    Exception e = (Exception)throwable;
                    downstream.onError((Throwable)errorFunction.apply(model, e));
                } else {
                    ExceptionsHelper.maybeError((Throwable)throwable).ifPresent(ExceptionsHelper::maybeDieOnAnotherThread);
                    RuntimeException e = new RuntimeException("Fatal while streaming SageMaker response for [" + model.getInferenceEntityId() + "]");
                    downstream.onError((Throwable)errorFunction.apply(model, e));
                }
            }

            @Override
            public void onComplete() {
                downstream.onComplete();
            }
        });
    }

    public InvokeEndpointWithResponseStreamRequest chatCompletionStreamRequest(SageMakerModel model, UnifiedCompletionRequest request) {
        return this.streamRequest(model, (CheckedSupplier<SdkBytes, Exception>)((CheckedSupplier)() -> this.payload.chatCompletionRequestBytes(model, request)));
    }

    public InferenceServiceResults chatCompletionStreamResponse(SageMakerModel model, SageMakerClient.SageMakerStream response) {
        return new StreamingUnifiedChatCompletionResults(this.streamResponse(model, response, this.payload::chatCompletionResponseBody, this::chatCompletionError));
    }

    public UnifiedChatCompletionException chatCompletionError(SageMakerModel model, Exception e) {
        if (e instanceof UnifiedChatCompletionException) {
            UnifiedChatCompletionException ucce = (UnifiedChatCompletionException)e;
            return ucce;
        }
        Tuple<String, RestStatus> error = this.errorMessageAndStatus(model, e);
        return new UnifiedChatCompletionException((RestStatus)error.v2(), (String)error.v1(), "error", ((RestStatus)error.v2()).name().toLowerCase(Locale.ROOT));
    }

    private InvokeEndpointWithResponseStreamRequest.Builder createStreamRequest(SageMakerModel model) {
        InvokeEndpointWithResponseStreamRequest.Builder request = InvokeEndpointWithResponseStreamRequest.builder();
        request.endpointName(model.endpointName());
        model.customAttributes().ifPresent(arg_0 -> ((InvokeEndpointWithResponseStreamRequest.Builder)request).customAttributes(arg_0));
        model.inferenceComponentName().ifPresent(arg_0 -> ((InvokeEndpointWithResponseStreamRequest.Builder)request).inferenceComponentName(arg_0));
        model.inferenceIdForDataCapture().ifPresent(arg_0 -> ((InvokeEndpointWithResponseStreamRequest.Builder)request).inferenceId(arg_0));
        model.sessionId().ifPresent(arg_0 -> ((InvokeEndpointWithResponseStreamRequest.Builder)request).sessionId(arg_0));
        model.targetContainerHostname().ifPresent(arg_0 -> ((InvokeEndpointWithResponseStreamRequest.Builder)request).targetContainerHostname(arg_0));
        model.targetVariant().ifPresent(arg_0 -> ((InvokeEndpointWithResponseStreamRequest.Builder)request).targetVariant(arg_0));
        return request;
    }
}

