/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.core.inference.action;

import java.io.IOException;
import java.util.ArrayList;
import java.util.EnumSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.action.ActionRequest;
import org.elasticsearch.action.ActionRequestValidationException;
import org.elasticsearch.action.ActionResponse;
import org.elasticsearch.action.ActionType;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.inference.InferenceResults;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.inference.InputType;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xcontent.ObjectParser;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.ToXContent;
import org.elasticsearch.xcontent.ToXContentObject;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xpack.core.inference.results.LegacyTextEmbeddingResults;
import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults;
import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults;

/**
 * Action type for inference requests, registered under {@link #NAME}.
 * <p>
 * The nested {@link Request} and {@link Response} classes define the transport wire format and the
 * XContent parsing/rendering for this action. Both handle older transport versions explicitly:
 * fields introduced in later versions are only written to, and only read from, streams whose
 * transport version is recent enough.
 */
public class InferenceAction extends ActionType<InferenceAction.Response> {
    /** Shared instance of this action type. */
    public static final InferenceAction INSTANCE = new InferenceAction();
    /** Name of this action; also used as the name of the request body parser. */
    public static final String NAME = "cluster:monitor/xpack/inference";

    public InferenceAction() {
        super(NAME);
    }

    /**
     * Response wrapping the {@link InferenceServiceResults} produced for a request.
     */
    public static class Response extends ActionResponse implements ToXContentObject {
        private final InferenceServiceResults results;

        public Response(InferenceServiceResults results) {
            this.results = results;
        }

        /**
         * Deserializes a response.
         * <p>
         * Streams on or after {@code V_8_12_0} carry an {@link InferenceServiceResults} named
         * writeable. Older streams carry a single {@link InferenceResults}, which is converted with
         * {@link #transformToServiceResults(List)}.
         *
         * @param in stream to read from
         * @throws IOException if reading from the stream fails
         */
        public Response(StreamInput in) throws IOException {
            super(in);
            if (in.getTransportVersion().onOrAfter(TransportVersions.V_8_12_0)) {
                this.results = in.readNamedWriteable(InferenceServiceResults.class);
            } else {
                this.results = transformToServiceResults(List.of(in.readNamedWriteable(InferenceResults.class)));
            }
        }

        /**
         * Converts a list of legacy per-item inference results into one {@link InferenceServiceResults}.
         * The runtime type of the first element selects the conversion.
         *
         * @param parsedResults legacy results to convert
         * @return the converted text embedding results when the first element is a
         *         {@link LegacyTextEmbeddingResults} (only a single element is accepted in that case),
         *         or sparse embedding results when the first element is a {@link TextExpansionResults}
         * @throws ElasticsearchStatusException with status 500 in any of these cases:
         *         the list is empty; it contains more than one legacy text embedding result;
         *         the first element has an unrecognized type; or, in the sparse case, any element
         *         is not a {@link TextExpansionResults}
         */
        public static InferenceServiceResults transformToServiceResults(List<? extends InferenceResults> parsedResults) {
            if (parsedResults.isEmpty()) {
                throw new ElasticsearchStatusException(
                    "Failed to transform results to response format, expected a non-empty list, please remove and re-add the service",
                    RestStatus.INTERNAL_SERVER_ERROR
                );
            }

            InferenceResults firstResult = parsedResults.get(0);
            if (firstResult instanceof LegacyTextEmbeddingResults) {
                // A legacy text embedding result already holds every embedding, so more than one is malformed.
                if (parsedResults.size() > 1) {
                    throw new ElasticsearchStatusException(
                        "Failed to transform results to response format, malformed text embedding result, please remove and re-add the service",
                        RestStatus.INTERNAL_SERVER_ERROR
                    );
                }
                return ((LegacyTextEmbeddingResults) firstResult).transformToTextEmbeddingResults();
            }

            if (firstResult instanceof TextExpansionResults) {
                return transformToSparseEmbeddingResult(parsedResults);
            }

            throw new ElasticsearchStatusException(
                "Failed to transform results to response format, unknown embedding type received, please remove and re-add the service",
                RestStatus.INTERNAL_SERVER_ERROR
            );
        }

        /**
         * Collects every element as a {@link TextExpansionResults} and builds sparse embedding results from them.
         *
         * @throws ElasticsearchStatusException with status 500 if any element is not a {@link TextExpansionResults}
         */
        private static SparseEmbeddingResults transformToSparseEmbeddingResult(List<? extends InferenceResults> parsedResults) {
            List<TextExpansionResults> textExpansionResults = new ArrayList<>(parsedResults.size());
            for (InferenceResults result : parsedResults) {
                if (!(result instanceof TextExpansionResults)) {
                    throw new ElasticsearchStatusException(
                        "Failed to transform results to response format, please remove and re-add the service",
                        RestStatus.INTERNAL_SERVER_ERROR
                    );
                }
                textExpansionResults.add((TextExpansionResults) result);
            }
            return SparseEmbeddingResults.of(textExpansionResults);
        }

        public InferenceServiceResults getResults() {
            return results;
        }

        /**
         * Serializes the results.
         * <p>
         * Streams on or after {@code V_8_12_0} get the full named writeable. Older streams get only
         * the first element of {@link InferenceServiceResults#transformToLegacyFormat()}.
         */
        @Override
        public void writeTo(StreamOutput out) throws IOException {
            if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_12_0)) {
                out.writeNamedWriteable(results);
            } else {
                // The legacy format holds one result, so any further legacy entries are not sent.
                out.writeNamedWriteable(results.transformToLegacyFormat().get(0));
            }
        }

        /** Renders the results inside a single enclosing XContent object. */
        @Override
        public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
            builder.startObject();
            results.toXContent(builder, params);
            builder.endObject();
            return builder;
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (o == null || getClass() != o.getClass()) {
                return false;
            }
            Response other = (Response) o;
            return Objects.equals(results, other.results);
        }

        @Override
        public int hashCode() {
            return Objects.hash(results);
        }
    }

    /**
     * Request to run inference for {@code input} with the endpoint identified by {@code inferenceEntityId}.
     */
    public static class Request extends ActionRequest {
        /** Timeout used when none is given, and when reading streams that predate the timeout field. */
        public static final TimeValue DEFAULT_TIMEOUT = TimeValue.timeValueSeconds(30L);
        public static final ParseField INPUT = new ParseField("input");
        public static final ParseField TASK_SETTINGS = new ParseField("task_settings");
        public static final ParseField QUERY = new ParseField("query");
        public static final ParseField TIMEOUT = new ParseField("timeout");

        /**
         * Parses a request body. The body fields are:
         * <ul>
         *   <li>{@code input}: string array;</li>
         *   <li>{@code task_settings}: object, read as an ordered map;</li>
         *   <li>{@code query}: string;</li>
         *   <li>{@code timeout}: string parsed as a {@link TimeValue}.</li>
         * </ul>
         */
        static final ObjectParser<Builder, Void> PARSER = new ObjectParser<>(NAME, Builder::new);

        static {
            PARSER.declareStringArray(Builder::setInput, INPUT);
            PARSER.declareObject(Builder::setTaskSettings, (p, c) -> p.mapOrdered(), TASK_SETTINGS);
            PARSER.declareString(Builder::setQuery, QUERY);
            PARSER.declareString(Builder::setInferenceTimeout, TIMEOUT);
        }

        /** Input types sent unchanged to peers before {@code ML_INFERENCE_REQUEST_INPUT_TYPE_UNSPECIFIED_ADDED}. */
        private static final EnumSet<InputType> validEnumsBeforeUnspecifiedAdded = EnumSet.of(InputType.INGEST, InputType.SEARCH);
        /**
         * Input types sent unchanged to peers before {@code ML_INFERENCE_REQUEST_INPUT_TYPE_CLASS_CLUSTER_ADDED}.
         * NOTE(review): this set is built with {@link EnumSet#range}, so it depends on the declaration
         * order of {@link InputType} between INGEST and UNSPECIFIED. Confirm it does not also pick up
         * newer constants.
         */
        private static final EnumSet<InputType> validEnumsBeforeClassificationClusteringAdded = EnumSet.range(
            InputType.INGEST,
            InputType.UNSPECIFIED
        );

        private final TaskType taskType;
        private final String inferenceEntityId;
        private final String query;
        private final List<String> input;
        private final Map<String, Object> taskSettings;
        private final InputType inputType;
        private final TimeValue inferenceTimeout;

        /**
         * Parses a request body and fills in the values that come from outside the body.
         *
         * @param inferenceEntityId id of the inference endpoint; must not be null
         * @param taskType task type to set on the builder
         * @param parser parser positioned at the request body
         * @return a builder holding the parsed body, the given id and task type, and input type UNSPECIFIED
         * @throws IOException if parsing fails
         * @throws NullPointerException if {@code inferenceEntityId} is null
         */
        public static Builder parseRequest(String inferenceEntityId, TaskType taskType, XContentParser parser) throws IOException {
            Builder builder = PARSER.apply(parser, null);
            builder.setInferenceEntityId(inferenceEntityId);
            builder.setTaskType(taskType);
            builder.setInputType(InputType.UNSPECIFIED);
            return builder;
        }

        public Request(
            TaskType taskType,
            String inferenceEntityId,
            String query,
            List<String> input,
            Map<String, Object> taskSettings,
            InputType inputType,
            TimeValue inferenceTimeout
        ) {
            this.taskType = taskType;
            this.inferenceEntityId = inferenceEntityId;
            this.query = query;
            this.input = input;
            this.taskSettings = taskSettings;
            this.inputType = inputType;
            this.inferenceTimeout = inferenceTimeout;
        }

        /**
         * Deserializes a request. Fields are read in the same order {@link #writeTo(StreamOutput)}
         * writes them. Fields missing from older streams get these values:
         * <ul>
         *   <li>before {@code V_8_12_0}: a single input string, wrapped in a list;</li>
         *   <li>no input type field: UNSPECIFIED;</li>
         *   <li>no query field: null;</li>
         *   <li>no timeout field: {@link #DEFAULT_TIMEOUT}.</li>
         * </ul>
         */
        public Request(StreamInput in) throws IOException {
            super(in);
            TransportVersion version = in.getTransportVersion();
            // Stream reads are positional: keep this statement order in sync with writeTo.
            this.taskType = TaskType.fromStream(in);
            this.inferenceEntityId = in.readString();
            this.input = version.onOrAfter(TransportVersions.V_8_12_0) ? in.readStringCollectionAsList() : List.of(in.readString());
            this.taskSettings = in.readGenericMap();
            this.inputType = version.onOrAfter(TransportVersions.ML_INFERENCE_REQUEST_INPUT_TYPE_ADDED)
                ? in.readEnum(InputType.class)
                : InputType.UNSPECIFIED;
            this.query = version.onOrAfter(TransportVersions.ML_INFERENCE_COHERE_RERANK) ? in.readOptionalString() : null;
            this.inferenceTimeout = version.onOrAfter(TransportVersions.ML_INFERENCE_TIMEOUT_ADDED) ? in.readTimeValue() : DEFAULT_TIMEOUT;
        }

        public TaskType getTaskType() {
            return taskType;
        }

        public String getInferenceEntityId() {
            return inferenceEntityId;
        }

        public List<String> getInput() {
            return input;
        }

        public String getQuery() {
            return query;
        }

        public Map<String, Object> getTaskSettings() {
            return taskSettings;
        }

        public InputType getInputType() {
            return inputType;
        }

        public TimeValue getInferenceTimeout() {
            return inferenceTimeout;
        }

        /**
         * Rejects a request whose input is missing or empty.
         *
         * @return a validation exception describing the first problem found, or null if the request is valid
         */
        @Override
        public ActionRequestValidationException validate() {
            if (input == null) {
                ActionRequestValidationException e = new ActionRequestValidationException();
                e.addValidationError("missing input");
                return e;
            }
            if (input.isEmpty()) {
                ActionRequestValidationException e = new ActionRequestValidationException();
                e.addValidationError("input array is empty");
                return e;
            }
            return null;
        }

        /**
         * Serializes the request. Each field is only written when the target transport version
         * supports it, and the input type is downgraded with {@link #getInputTypeToWrite}.
         */
        @Override
        public void writeTo(StreamOutput out) throws IOException {
            super.writeTo(out);
            TransportVersion version = out.getTransportVersion();
            taskType.writeTo(out);
            out.writeString(inferenceEntityId);
            if (version.onOrAfter(TransportVersions.V_8_12_0)) {
                out.writeStringCollection(input);
            } else {
                // Pre-8.12 peers accept a single input string, so inputs after the first are not sent.
                out.writeString(input.get(0));
            }
            out.writeGenericMap(taskSettings);
            if (version.onOrAfter(TransportVersions.ML_INFERENCE_REQUEST_INPUT_TYPE_ADDED)) {
                out.writeEnum(getInputTypeToWrite(inputType, version));
            }
            if (version.onOrAfter(TransportVersions.ML_INFERENCE_COHERE_RERANK)) {
                out.writeOptionalString(query);
            }
            if (version.onOrAfter(TransportVersions.ML_INFERENCE_TIMEOUT_ADDED)) {
                out.writeTimeValue(inferenceTimeout);
            }
        }

        /**
         * Returns the input type to serialize for a peer on {@code version}. Values that an older
         * peer presumably cannot read are replaced:
         * <ul>
         *   <li>before {@code ML_INFERENCE_REQUEST_INPUT_TYPE_UNSPECIFIED_ADDED}, anything other than
         *       INGEST or SEARCH is sent as INGEST;</li>
         *   <li>before {@code ML_INFERENCE_REQUEST_INPUT_TYPE_CLASS_CLUSTER_ADDED}, anything outside
         *       {@code validEnumsBeforeClassificationClusteringAdded} is sent as UNSPECIFIED.</li>
         * </ul>
         * In all other cases {@code inputType} is returned unchanged.
         */
        static InputType getInputTypeToWrite(InputType inputType, TransportVersion version) {
            if (version.before(TransportVersions.ML_INFERENCE_REQUEST_INPUT_TYPE_UNSPECIFIED_ADDED)
                && !validEnumsBeforeUnspecifiedAdded.contains(inputType)) {
                return InputType.INGEST;
            }
            if (version.before(TransportVersions.ML_INFERENCE_REQUEST_INPUT_TYPE_CLASS_CLUSTER_ADDED)
                && !validEnumsBeforeClassificationClusteringAdded.contains(inputType)) {
                return InputType.UNSPECIFIED;
            }
            return inputType;
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (o == null || getClass() != o.getClass()) {
                return false;
            }
            Request other = (Request) o;
            return taskType == other.taskType
                && Objects.equals(inferenceEntityId, other.inferenceEntityId)
                && Objects.equals(input, other.input)
                && Objects.equals(taskSettings, other.taskSettings)
                && inputType == other.inputType
                && Objects.equals(query, other.query)
                && Objects.equals(inferenceTimeout, other.inferenceTimeout);
        }

        @Override
        public int hashCode() {
            return Objects.hash(taskType, inferenceEntityId, input, taskSettings, inputType, query, inferenceTimeout);
        }

        @Override
        public String toString() {
            return "InferenceAction.Request(taskType="
                + getTaskType()
                + ", inferenceEntityId="
                + getInferenceEntityId()
                + ", query="
                + getQuery()
                + ", input="
                + getInput()
                + ", taskSettings="
                + getTaskSettings()
                + ", inputType="
                + getInputType()
                + ", timeout="
                + getInferenceTimeout()
                + ")";
        }

        /**
         * Mutable builder for {@link Request}.
         * <p>
         * Defaults: input type UNSPECIFIED, empty task settings, and {@link #DEFAULT_TIMEOUT} as the
         * timeout. Task type, entity id, input and query start out null.
         */
        public static class Builder {
            private TaskType taskType;
            private String inferenceEntityId;
            private List<String> input;
            private InputType inputType = InputType.UNSPECIFIED;
            private Map<String, Object> taskSettings = Map.of();
            private String query;
            private TimeValue timeout = DEFAULT_TIMEOUT;

            private Builder() {}

            /**
             * Sets the inference endpoint id.
             *
             * @throws NullPointerException if {@code inferenceEntityId} is null
             */
            public Builder setInferenceEntityId(String inferenceEntityId) {
                this.inferenceEntityId = Objects.requireNonNull(inferenceEntityId);
                return this;
            }

            public Builder setTaskType(TaskType taskType) {
                this.taskType = taskType;
                return this;
            }

            public Builder setInput(List<String> input) {
                this.input = input;
                return this;
            }

            public Builder setQuery(String query) {
                this.query = query;
                return this;
            }

            public Builder setInputType(InputType inputType) {
                this.inputType = inputType;
                return this;
            }

            public Builder setTaskSettings(Map<String, Object> taskSettings) {
                this.taskSettings = taskSettings;
                return this;
            }

            public Builder setInferenceTimeout(TimeValue inferenceTimeout) {
                this.timeout = inferenceTimeout;
                return this;
            }

            /** Parses {@code inferenceTimeout} as a {@link TimeValue}; used by the {@code timeout} body field. */
            private Builder setInferenceTimeout(String inferenceTimeout) {
                return setInferenceTimeout(TimeValue.parseTimeValue(inferenceTimeout, TIMEOUT.getPreferredName()));
            }

            public Request build() {
                return new Request(taskType, inferenceEntityId, query, input, taskSettings, inputType, timeout);
            }
        }
    }
}

