/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.core.inference.action;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.action.ActionRequest;
import org.elasticsearch.action.ActionRequestValidationException;
import org.elasticsearch.action.ActionResponse;
import org.elasticsearch.action.ActionType;
import org.elasticsearch.common.VersionId;
import org.elasticsearch.common.io.stream.NamedWriteable;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.inference.InferenceResults;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xcontent.ObjectParser;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.ToXContent;
import org.elasticsearch.xcontent.ToXContentObject;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xpack.core.inference.results.LegacyTextEmbeddingResults;
import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults;
import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults;

/**
 * Action type for inference requests, registered under the transport action name {@value #NAME}.
 * <p>
 * Responses are deserialized with {@link Response#Response(StreamInput)}.
 */
public class InferenceAction extends ActionType<InferenceAction.Response> {
    public static final InferenceAction INSTANCE = new InferenceAction();
    public static final String NAME = "cluster:monitor/xpack/inference";

    public InferenceAction() {
        super(NAME, Response::new);
    }

    /**
     * Response wrapping the {@link InferenceServiceResults} of an inference call.
     * <p>
     * Supports three wire formats, selected by transport version:
     * <ul>
     *   <li>{@code INFERENCE_SERVICE_RESULTS_ADDED} and later: the results as a single named writeable;</li>
     *   <li>{@code INFERENCE_MULTIPLE_INPUTS} and later: a collection of legacy {@link InferenceResults};</li>
     *   <li>older: a single legacy {@link InferenceResults}.</li>
     * </ul>
     */
    public static class Response extends ActionResponse implements ToXContentObject {
        private final InferenceServiceResults results;

        /**
         * @param results the inference results to return
         */
        public Response(InferenceServiceResults results) {
            this.results = results;
        }

        /**
         * Reads a response, converting legacy wire formats into {@link InferenceServiceResults}.
         *
         * @param in the stream to read from
         * @throws IOException if reading from the stream fails
         */
        public Response(StreamInput in) throws IOException {
            super(in);
            if (in.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_SERVICE_RESULTS_ADDED)) {
                this.results = in.readNamedWriteable(InferenceServiceResults.class);
            } else if (in.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_MULTIPLE_INPUTS)) {
                this.results = transformToServiceResults(in.readNamedWriteableCollectionAsList(InferenceResults.class));
            } else {
                // Oldest format: exactly one legacy result on the wire.
                this.results = transformToServiceResults(List.of(in.readNamedWriteable(InferenceResults.class)));
            }
        }

        /**
         * Converts a list of legacy {@link InferenceResults} into {@link InferenceServiceResults}.
         * <p>
         * The type of the first element decides the conversion:
         * <ul>
         *   <li>a single {@link LegacyTextEmbeddingResults} is converted to text embedding results;</li>
         *   <li>{@link TextExpansionResults} elements are combined into {@link SparseEmbeddingResults}.</li>
         * </ul>
         *
         * @param parsedResults the legacy results; must not be empty
         * @return the converted results
         * @throws ElasticsearchStatusException (500) if the list is empty, contains more than one
         *         legacy text embedding result, mixes types, or has an unrecognized first element
         */
        public static InferenceServiceResults transformToServiceResults(List<? extends InferenceResults> parsedResults) {
            if (parsedResults.isEmpty()) {
                throw new ElasticsearchStatusException(
                    "Failed to transform results to response format, expected a non-empty list, please remove and re-add the service",
                    RestStatus.INTERNAL_SERVER_ERROR
                );
            }

            InferenceResults firstResult = parsedResults.get(0);
            if (firstResult instanceof LegacyTextEmbeddingResults) {
                // A legacy text embedding payload is expected to arrive as exactly one entry.
                if (parsedResults.size() > 1) {
                    throw new ElasticsearchStatusException(
                        "Failed to transform results to response format, malformed text embedding result, please remove and re-add the service",
                        RestStatus.INTERNAL_SERVER_ERROR
                    );
                }
                return ((LegacyTextEmbeddingResults) firstResult).transformToTextEmbeddingResults();
            }
            if (firstResult instanceof TextExpansionResults) {
                return transformToSparseEmbeddingResult(parsedResults);
            }
            throw new ElasticsearchStatusException(
                "Failed to transform results to response format, unknown embedding type received, please remove and re-add the service",
                RestStatus.INTERNAL_SERVER_ERROR
            );
        }

        /**
         * Combines a list that must consist solely of {@link TextExpansionResults} into sparse embedding results.
         *
         * @throws ElasticsearchStatusException (500) if any element is not a {@link TextExpansionResults}
         */
        private static SparseEmbeddingResults transformToSparseEmbeddingResult(List<? extends InferenceResults> parsedResults) {
            List<TextExpansionResults> textExpansionResults = new ArrayList<>(parsedResults.size());
            for (InferenceResults inferenceResults : parsedResults) {
                if (inferenceResults instanceof TextExpansionResults == false) {
                    throw new ElasticsearchStatusException(
                        "Failed to transform results to response format, please remove and re-add the service",
                        RestStatus.INTERNAL_SERVER_ERROR
                    );
                }
                textExpansionResults.add((TextExpansionResults) inferenceResults);
            }
            return SparseEmbeddingResults.of(textExpansionResults);
        }

        public InferenceServiceResults getResults() {
            return results;
        }

        /**
         * Writes the results in the format understood by the receiving node's transport version.
         * For versions before {@code INFERENCE_MULTIPLE_INPUTS}, only the first legacy result is written.
         */
        public void writeTo(StreamOutput out) throws IOException {
            if (out.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_SERVICE_RESULTS_ADDED)) {
                out.writeNamedWriteable(results);
            } else if (out.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_MULTIPLE_INPUTS)) {
                out.writeNamedWriteableCollection(results.transformToLegacyFormat());
            } else {
                out.writeNamedWriteable(results.transformToLegacyFormat().get(0));
            }
        }

        /**
         * Renders the results inside a single enclosing XContent object.
         */
        public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
            builder.startObject();
            results.toXContent(builder, params);
            builder.endObject();
            return builder;
        }

        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (o == null || getClass() != o.getClass()) {
                return false;
            }
            Response response = (Response) o;
            return Objects.equals(results, response.results);
        }

        public int hashCode() {
            return Objects.hash(results);
        }
    }

    /**
     * Inference request: the task type, the model id, the input strings, and optional task settings.
     */
    public static class Request extends ActionRequest {
        public static final ParseField INPUT = new ParseField("input");
        public static final ParseField TASK_SETTINGS = new ParseField("task_settings");
        static final ObjectParser<Builder, Void> PARSER = new ObjectParser<>(NAME, Builder::new);

        static {
            PARSER.declareStringArray(Builder::setInput, INPUT);
            PARSER.declareObject(Builder::setTaskSettings, (p, c) -> p.mapOrdered(), TASK_SETTINGS);
        }

        private final TaskType taskType;
        private final String modelId;
        private final List<String> input;
        private final Map<String, Object> taskSettings;

        /**
         * Builds a request from a parsed body plus the model id and task type supplied separately.
         * <p>
         * The body may contain {@code input} (string array) and {@code task_settings} (object, parsed as an
         * ordered map). When {@code task_settings} is absent it defaults to an empty map.
         *
         * @param modelId the model id; must not be {@code null}
         * @param taskType the task type name, resolved with {@link TaskType#fromString}
         * @param parser parser positioned on the request body
         * @return the built request
         * @throws NullPointerException if {@code modelId} is {@code null}
         * @throws ElasticsearchStatusException (400) if {@code taskType} is not a known task type
         */
        public static Request parseRequest(String modelId, String taskType, XContentParser parser) {
            Builder builder = PARSER.apply(parser, null);
            builder.setModelId(modelId);
            builder.setTaskType(taskType);
            return builder.build();
        }

        public Request(TaskType taskType, String modelId, List<String> input, Map<String, Object> taskSettings) {
            this.taskType = taskType;
            this.modelId = modelId;
            this.input = input;
            this.taskSettings = taskSettings;
        }

        /**
         * Reads a request from the transport stream. Nodes before {@code INFERENCE_MULTIPLE_INPUTS}
         * send a single input string, which is wrapped in a one-element list.
         *
         * @param in the stream to read from
         * @throws IOException if reading from the stream fails
         */
        public Request(StreamInput in) throws IOException {
            super(in);
            this.taskType = TaskType.fromStream(in);
            this.modelId = in.readString();
            if (in.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_MULTIPLE_INPUTS)) {
                this.input = in.readStringCollectionAsList();
            } else {
                this.input = List.of(in.readString());
            }
            this.taskSettings = in.readMap();
        }

        public TaskType getTaskType() {
            return taskType;
        }

        public String getModelId() {
            return modelId;
        }

        public List<String> getInput() {
            return input;
        }

        public Map<String, Object> getTaskSettings() {
            return taskSettings;
        }

        /**
         * Rejects requests whose input is missing or empty.
         *
         * @return a validation exception describing the problem, or {@code null} if the request is valid
         */
        public ActionRequestValidationException validate() {
            if (input == null) {
                ActionRequestValidationException e = new ActionRequestValidationException();
                e.addValidationError("missing input");
                return e;
            }
            if (input.isEmpty()) {
                ActionRequestValidationException e = new ActionRequestValidationException();
                e.addValidationError("input array is empty");
                return e;
            }
            return null;
        }

        /**
         * Writes the request in the format understood by the receiving node's transport version.
         */
        public void writeTo(StreamOutput out) throws IOException {
            super.writeTo(out);
            taskType.writeTo(out);
            out.writeString(modelId);
            if (out.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_MULTIPLE_INPUTS)) {
                out.writeStringCollection(input);
            } else {
                // Older nodes accept a single input only, so anything after the first element is dropped.
                // NOTE(review): consider rejecting multi-input requests to such nodes instead of truncating silently.
                out.writeString(input.get(0));
            }
            out.writeGenericMap(taskSettings);
        }

        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (o == null || getClass() != o.getClass()) {
                return false;
            }
            Request request = (Request) o;
            return taskType == request.taskType
                && Objects.equals(modelId, request.modelId)
                && Objects.equals(input, request.input)
                && Objects.equals(taskSettings, request.taskSettings);
        }

        public int hashCode() {
            return Objects.hash(taskType, modelId, input, taskSettings);
        }

        /**
         * Mutable builder used by {@link #PARSER} and {@link #parseRequest}. Task settings default to an empty map.
         */
        public static class Builder {
            private TaskType taskType;
            private String modelId;
            private List<String> input;
            private Map<String, Object> taskSettings = Map.of();

            private Builder() {
            }

            /**
             * @throws NullPointerException if {@code modelId} is {@code null}
             */
            public Builder setModelId(String modelId) {
                this.modelId = Objects.requireNonNull(modelId);
                return this;
            }

            /**
             * Resolves the task type from its string name.
             *
             * @throws ElasticsearchStatusException (400) if the name is not a known task type; the
             *         underlying {@link IllegalArgumentException} is kept as the cause
             */
            public Builder setTaskType(String taskTypeStr) {
                try {
                    this.taskType = Objects.requireNonNull(TaskType.fromString(taskTypeStr));
                } catch (IllegalArgumentException e) {
                    throw new ElasticsearchStatusException("Unknown task_type [{}]", RestStatus.BAD_REQUEST, e, taskTypeStr);
                }
                return this;
            }

            public Builder setInput(List<String> input) {
                this.input = input;
                return this;
            }

            public Builder setTaskSettings(Map<String, Object> taskSettings) {
                this.taskSettings = taskSettings;
                return this;
            }

            public Request build() {
                return new Request(taskType, modelId, input, taskSettings);
            }
        }
    }
}

