/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.inference.services.sagemaker.schema;

import java.util.Map;
import java.util.stream.Stream;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.common.ValidationException;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.core.Strings;
import org.elasticsearch.core.Tuple;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xpack.inference.services.sagemaker.SageMakerInferenceRequest;
import org.elasticsearch.xpack.inference.services.sagemaker.model.SageMakerModel;
import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerSchemaPayload;
import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerStoredServiceSchema;
import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerStoredTaskSchema;
import software.amazon.awssdk.services.sagemakerruntime.model.InternalDependencyException;
import software.amazon.awssdk.services.sagemakerruntime.model.InternalFailureException;
import software.amazon.awssdk.services.sagemakerruntime.model.InvokeEndpointRequest;
import software.amazon.awssdk.services.sagemakerruntime.model.InvokeEndpointResponse;
import software.amazon.awssdk.services.sagemakerruntime.model.ModelErrorException;
import software.amazon.awssdk.services.sagemakerruntime.model.ModelNotReadyException;
import software.amazon.awssdk.services.sagemakerruntime.model.SageMakerRuntimeException;
import software.amazon.awssdk.services.sagemakerruntime.model.ServiceUnavailableException;
import software.amazon.awssdk.services.sagemakerruntime.model.ValidationErrorException;

/**
 * Translates between Elasticsearch inference calls and the SageMaker runtime {@code InvokeEndpoint} API.
 * <p>
 * Payload-specific concerns (accept/content types, request and response bodies, service/task settings parsing and
 * named writeables) are delegated to the {@link SageMakerSchemaPayload} given at construction. This class copies the
 * model's optional endpoint options onto the request, surfaces SageMaker response metadata as response headers, and
 * maps SageMaker failures to {@link ElasticsearchStatusException}s carrying a {@link RestStatus}.
 */
public class SageMakerSchema {
    private static final String CUSTOM_ATTRIBUTES_HEADER = "X-elastic-sagemaker-custom-attributes";
    private static final String NEW_SESSION_HEADER = "X-elastic-sagemaker-new-session-id";
    private static final String CLOSED_SESSION_HEADER = "X-elastic-sagemaker-closed-session-id";

    // AWS error codes inspected on generic SageMakerRuntimeExceptions (those not matched by a dedicated exception type).
    private static final String ACCESS_DENIED_CODE = "AccessDeniedException";
    private static final String INCOMPLETE_SIGNATURE = "IncompleteSignature";
    private static final String INVALID_ACTION = "InvalidAction";
    private static final String INVALID_CLIENT_TOKEN = "InvalidClientTokenId";
    private static final String NOT_AUTHORIZED = "NotAuthorized";
    private static final String OPT_IN_REQUIRED = "OptInRequired";
    private static final String REQUEST_EXPIRED = "RequestExpired";
    private static final String THROTTLING_EXCEPTION = "ThrottlingException";

    private final SageMakerSchemaPayload schemaPayload;

    public SageMakerSchema(SageMakerSchemaPayload schemaPayload) {
        this.schemaPayload = schemaPayload;
    }

    /**
     * Builds the {@code InvokeEndpoint} request that sends {@code request} to {@code model}'s endpoint.
     *
     * @param model   the SageMaker model supplying the endpoint name and optional endpoint options
     * @param request the inference request whose body is encoded by the schema payload
     * @return the fully built SageMaker request
     * @throws ElasticsearchStatusException rethrown unchanged if the payload raises one; any other exception is
     *                                      wrapped with {@link RestStatus#INTERNAL_SERVER_ERROR}
     */
    public InvokeEndpointRequest request(SageMakerModel model, SageMakerInferenceRequest request) {
        try {
            return createRequest(model).accept(schemaPayload.accept(model))
                .contentType(schemaPayload.contentType(model))
                .body(schemaPayload.requestBytes(model, request))
                .build();
        } catch (ElasticsearchStatusException e) {
            throw e;
        } catch (Exception e) {
            throw new ElasticsearchStatusException(
                "Failed to create SageMaker request for [{}]",
                RestStatus.INTERNAL_SERVER_ERROR,
                e,
                model.getInferenceEntityId()
            );
        }
    }

    /**
     * Starts a request builder for the model's endpoint and copies each optional endpoint option onto it, but only
     * when the model has a value for that option.
     */
    private InvokeEndpointRequest.Builder createRequest(SageMakerModel model) {
        InvokeEndpointRequest.Builder request = InvokeEndpointRequest.builder();
        request.endpointName(model.endpointName());
        model.customAttributes().ifPresent(request::customAttributes);
        model.enableExplanations().ifPresent(request::enableExplanations);
        model.inferenceComponentName().ifPresent(request::inferenceComponentName);
        model.inferenceIdForDataCapture().ifPresent(request::inferenceId);
        model.sessionId().ifPresent(request::sessionId);
        model.targetContainerHostname().ifPresent(request::targetContainerHostname);
        model.targetModel().ifPresent(request::targetModel);
        model.targetVariant().ifPresent(request::targetVariant);
        return request;
    }

    /**
     * Converts a SageMaker response into inference results.
     * <p>
     * Response metadata headers are added to {@code threadContext} before the body is parsed, so they are present even
     * if parsing fails.
     *
     * @param model         the model the request was made for
     * @param response      the raw SageMaker response
     * @param threadContext receives the SageMaker metadata as response headers
     * @return the parsed inference results
     * @throws Exception declared for compatibility; in practice every {@link Exception} surfaces as an
     *                   {@link ElasticsearchStatusException} (existing ones unchanged, others wrapped with
     *                   {@link RestStatus#INTERNAL_SERVER_ERROR})
     */
    public InferenceServiceResults response(SageMakerModel model, InvokeEndpointResponse response, ThreadContext threadContext)
        throws Exception {
        try {
            addHeaders(response, threadContext);
            return schemaPayload.responseBody(model, response);
        } catch (ElasticsearchStatusException e) {
            throw e;
        } catch (Exception e) {
            throw new ElasticsearchStatusException(
                "Failed to translate SageMaker response for [{}]",
                RestStatus.INTERNAL_SERVER_ERROR,
                e,
                model.getInferenceEntityId()
            );
        }
    }

    /** Copies the custom attributes and session ids returned by SageMaker into response headers, skipping nulls. */
    private void addHeaders(InvokeEndpointResponse response, ThreadContext threadContext) {
        addHeaderIfPresent(threadContext, CUSTOM_ATTRIBUTES_HEADER, response.customAttributes());
        addHeaderIfPresent(threadContext, NEW_SESSION_HEADER, response.newSessionId());
        addHeaderIfPresent(threadContext, CLOSED_SESSION_HEADER, response.closedSessionId());
    }

    /** Adds {@code name: value} as a response header unless {@code value} is null. */
    private static void addHeaderIfPresent(ThreadContext threadContext, String name, String value) {
        if (value != null) {
            threadContext.addResponseHeader(name, value);
        }
    }

    /**
     * Converts a failure from the SageMaker call into an exception suitable for returning to the user.
     *
     * @param model the model the request was made for
     * @param e     the failure
     * @return {@code e} itself if it is already an {@link ElasticsearchStatusException}; otherwise a new one whose
     *         message and status come from {@link #errorMessageAndStatus} and whose cause is {@code e}
     */
    public Exception error(SageMakerModel model, Exception e) {
        if (e instanceof ElasticsearchStatusException) {
            return e;
        }
        Tuple<String, RestStatus> error = errorMessageAndStatus(model, e);
        return new ElasticsearchStatusException(error.v1(), error.v2(), e);
    }

    /**
     * Chooses a user-facing message and REST status for a SageMaker failure.
     * <p>
     * The mapping is resolved in this order:
     * <ol>
     *   <li>dedicated SageMaker runtime exception types;</li>
     *   <li>for any other {@link SageMakerRuntimeException}, the AWS error code, if present;</li>
     *   <li>a generic {@link RestStatus#INTERNAL_SERVER_ERROR} fallback.</li>
     * </ol>
     *
     * @param model the model whose inference entity id is included in the message
     * @param e     the failure to classify
     * @return the formatted message and the status to report
     */
    protected Tuple<String, RestStatus> errorMessageAndStatus(SageMakerModel model, Exception e) {
        Tuple<String, RestStatus> templateAndStatus = exceptionTypeTemplateAndStatus(e);
        if (templateAndStatus == null && e instanceof SageMakerRuntimeException) {
            templateAndStatus = errorCodeTemplateAndStatus((SageMakerRuntimeException) e);
        }
        if (templateAndStatus == null) {
            templateAndStatus = Tuple.tuple("Received an error from SageMaker for [%s]", RestStatus.INTERNAL_SERVER_ERROR);
        }
        return Tuple.tuple(Strings.format(templateAndStatus.v1(), model.getInferenceEntityId()), templateAndStatus.v2());
    }

    /**
     * Maps the dedicated SageMaker runtime exception types to a message template (one {@code %s} for the inference
     * entity id) and a status.
     *
     * @return the mapping, or {@code null} if {@code e} is none of those types
     */
    private static Tuple<String, RestStatus> exceptionTypeTemplateAndStatus(Exception e) {
        if (e instanceof InternalDependencyException) {
            return Tuple.tuple("Received an internal dependency error from SageMaker for [%s]", RestStatus.INTERNAL_SERVER_ERROR);
        }
        if (e instanceof InternalFailureException) {
            return Tuple.tuple("Received an internal failure from SageMaker for [%s]", RestStatus.INTERNAL_SERVER_ERROR);
        }
        if (e instanceof ModelErrorException) {
            return Tuple.tuple("Received a model error from SageMaker for [%s]", RestStatus.FAILED_DEPENDENCY);
        }
        if (e instanceof ModelNotReadyException) {
            return Tuple.tuple("Received a model not ready error from SageMaker for [%s]", RestStatus.TOO_MANY_REQUESTS);
        }
        if (e instanceof ServiceUnavailableException) {
            return Tuple.tuple("SageMaker is unavailable for [%s]", RestStatus.SERVICE_UNAVAILABLE);
        }
        if (e instanceof ValidationErrorException) {
            return Tuple.tuple("Received a validation error from SageMaker for [%s]", RestStatus.BAD_REQUEST);
        }
        return null;
    }

    /**
     * Maps the AWS error code of a generic {@link SageMakerRuntimeException} to a message template (one {@code %s} for
     * the inference entity id) and a status.
     *
     * @return the mapping, or {@code null} if there are no error details, no error code, or the code is not recognized
     */
    private static Tuple<String, RestStatus> errorCodeTemplateAndStatus(SageMakerRuntimeException e) {
        // A null selector makes a String switch throw NullPointerException, so a missing code must be filtered first.
        if (e.awsErrorDetails() == null || e.awsErrorDetails().errorCode() == null) {
            return null;
        }
        switch (e.awsErrorDetails().errorCode()) {
            case ACCESS_DENIED_CODE:
            case NOT_AUTHORIZED:
                return Tuple.tuple("Access and Secret key stored in [%s] do not have sufficient permissions.", RestStatus.BAD_REQUEST);
            case INCOMPLETE_SIGNATURE:
                // NOTE(review): reported as a server error, presumably because signing is not user-controlled — confirm.
                return Tuple.tuple("The request signature does not conform to AWS standards [%s]", RestStatus.INTERNAL_SERVER_ERROR);
            case INVALID_ACTION:
                return Tuple.tuple("The requested action is not valid for [%s]", RestStatus.BAD_REQUEST);
            case INVALID_CLIENT_TOKEN:
                return Tuple.tuple("Access key stored in [%s] does not exist in AWS", RestStatus.FORBIDDEN);
            case OPT_IN_REQUIRED:
                return Tuple.tuple("Access key stored in [%s] needs a subscription for the service", RestStatus.FORBIDDEN);
            case REQUEST_EXPIRED:
                return Tuple.tuple(
                    "The request reached SageMaker more than 15 minutes after the date stamp on the request for [%s]",
                    RestStatus.BAD_REQUEST
                );
            case THROTTLING_EXCEPTION:
                // Throttling is a "slow down and retry" signal, not a malformed request: report 429, matching how
                // ModelNotReadyException is reported above, instead of 400.
                return Tuple.tuple("SageMaker denied the request for [%s] due to request throttling", RestStatus.TOO_MANY_REQUESTS);
            default:
                return null;
        }
    }

    /** Parses the service settings through the schema payload, recording problems in {@code validationException}. */
    public SageMakerStoredServiceSchema apiServiceSettings(Map<String, Object> serviceSettings, ValidationException validationException) {
        return schemaPayload.apiServiceSettings(serviceSettings, validationException);
    }

    /** Parses the task settings through the schema payload, recording problems in {@code validationException}. */
    public SageMakerStoredTaskSchema apiTaskSettings(Map<String, Object> taskSettings, ValidationException validationException) {
        return schemaPayload.apiTaskSettings(taskSettings, validationException);
    }

    /** Returns the named-writeable registry entries contributed by the schema payload. */
    public Stream<NamedWriteableRegistry.Entry> namedWriteables() {
        return schemaPayload.namedWriteables();
    }
}

