/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.inference.services.googlevertexai;

import java.io.IOException;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Deque;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.function.BiFunction;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.xcontent.LoggingDeprecationHandler;
import org.elasticsearch.common.xcontent.XContentHelper;
import org.elasticsearch.common.xcontent.XContentParserUtils;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.xcontent.ConstructingObjectParser;
import org.elasticsearch.xcontent.DeprecationHandler;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentFactory;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xcontent.XContentParserConfiguration;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.core.inference.results.StreamingUnifiedChatCompletionResults;
import org.elasticsearch.xpack.inference.common.DelegatingProcessor;
import org.elasticsearch.xpack.inference.external.response.XContentUtils;
import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEvent;

/**
 * Streaming processor that turns Google Vertex AI server-sent events into unified chat completion chunks.
 * <p>
 * Each event's {@code data} is parsed as a single JSON object and converted into one
 * {@link StreamingUnifiedChatCompletionResults.ChatCompletionChunk} whose object type is {@code chat.completion.chunk}.
 * If any event of a batch fails to parse, the raw payload is logged and the exception produced by the configured
 * error parser is thrown.
 */
public class GoogleVertexAiUnifiedStreamingProcessor
    extends DelegatingProcessor<Deque<ServerSentEvent>, StreamingUnifiedChatCompletionResults.Results> {
    private static final Logger logger = LogManager.getLogger(GoogleVertexAiUnifiedStreamingProcessor.class);

    private static final String CANDIDATES_FIELD = "candidates";
    private static final String CONTENT_FIELD = "content";
    private static final String ROLE_FIELD = "role";
    private static final String PARTS_FIELD = "parts";
    private static final String TEXT_FIELD = "text";
    private static final String FINISH_REASON_FIELD = "finishReason";
    private static final String INDEX_FIELD = "index";
    private static final String USAGE_METADATA_FIELD = "usageMetadata";
    private static final String PROMPT_TOKEN_COUNT_FIELD = "promptTokenCount";
    private static final String CANDIDATES_TOKEN_COUNT_FIELD = "candidatesTokenCount";
    private static final String TOTAL_TOKEN_COUNT_FIELD = "totalTokenCount";
    private static final String MODEL_VERSION_FIELD = "modelVersion";
    private static final String RESPONSE_ID_FIELD = "responseId";
    private static final String FUNCTION_CALL_FIELD = "functionCall";
    private static final String FUNCTION_NAME_FIELD = "name";
    private static final String FUNCTION_ARGS_FIELD = "args";
    private static final String CHAT_COMPLETION_CHUNK = "chat.completion.chunk";
    private static final String FUNCTION_TYPE = "function";

    /** Converts a failed event payload plus its parse exception into the exception that is thrown. */
    private final BiFunction<String, Exception, Exception> errorParser;

    /**
     * @param errorParser maps the raw data of an event that could not be parsed, together with the parse failure,
     *                    to the exception thrown from {@link #next(Deque)}
     */
    public GoogleVertexAiUnifiedStreamingProcessor(BiFunction<String, Exception, Exception> errorParser) {
        this.errorParser = errorParser;
    }

    /**
     * Parses a batch of server-sent events and forwards the resulting chunks downstream as a single
     * {@link StreamingUnifiedChatCompletionResults.Results}. When the batch produces no chunks, one more item is
     * requested from upstream instead of emitting an empty result.
     *
     * @param events the events received in this batch
     * @throws Exception the exception built by {@code errorParser} for the first event that fails to parse
     */
    @Override
    protected void next(Deque<ServerSentEvent> events) throws Exception {
        XContentParserConfiguration parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(
            LoggingDeprecationHandler.INSTANCE
        );
        var results = new ArrayDeque<StreamingUnifiedChatCompletionResults.ChatCompletionChunk>(events.size());
        for (ServerSentEvent event : events) {
            try {
                parse(parserConfig, event.data()).forEachRemaining(results::offer);
            } catch (Exception e) {
                String eventString = event.data();
                logger.warn("Failed to parse event from Google Vertex AI provider: {}", eventString);
                throw errorParser.apply(eventString, e);
            }
        }

        if (results.isEmpty()) {
            upstream().request(1L);
        } else {
            downstream().onNext(new StreamingUnifiedChatCompletionResults.Results(results));
        }
    }

    /**
     * Parses one event payload, which must be a JSON object, into a single chunk.
     *
     * @return an iterator over exactly one chunk
     * @throws IOException if the payload is not valid JSON or does not match the expected chunk structure
     */
    private Iterator<StreamingUnifiedChatCompletionResults.ChatCompletionChunk> parse(
        XContentParserConfiguration parserConfig,
        String event
    ) throws IOException {
        try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, event)) {
            XContentUtils.moveToFirstToken(jsonParser);
            XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_OBJECT, jsonParser.currentToken(), jsonParser);
            StreamingUnifiedChatCompletionResults.ChatCompletionChunk chunk = GoogleVertexAiChatCompletionChunkParser.parse(jsonParser);
            return Collections.singleton(chunk).iterator();
        }
    }

    /**
     * Parses a Vertex AI streaming response object into a unified {@code chat.completion.chunk}: every candidate
     * becomes one choice, and the usage metadata (when present) becomes the chunk's usage.
     */
    public static class GoogleVertexAiChatCompletionChunkParser {
        @SuppressWarnings("unchecked")
        private static final ConstructingObjectParser<StreamingUnifiedChatCompletionResults.ChatCompletionChunk, Void> PARSER =
            new ConstructingObjectParser<>("google_vertexai_chat_completion_chunk", true, args -> {
                List<Candidate> candidates = (List<Candidate>) args[0];
                UsageMetadata usage = (UsageMetadata) args[1];
                String modelVersion = (String) args[2];
                String responseId = (String) args[3];

                boolean candidatesIsEmpty = candidates == null || candidates.isEmpty();
                List<StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice> choices = candidatesIsEmpty
                    ? Collections.emptyList()
                    : candidates.stream().map(GoogleVertexAiChatCompletionChunkParser::candidateToChoice).toList();

                return new StreamingUnifiedChatCompletionResults.ChatCompletionChunk(
                    responseId,
                    choices,
                    modelVersion,
                    CHAT_COMPLETION_CHUNK,
                    usageMetadataToChunk(usage)
                );
            });

        static {
            // candidates, modelVersion and responseId remain required: a payload lacking them (e.g. an error body)
            // fails to parse and is routed through the error parser rather than becoming an empty chunk.
            PARSER.declareObjectArray(
                ConstructingObjectParser.constructorArg(),
                (p, c) -> CandidateParser.parse(p),
                new ParseField(CANDIDATES_FIELD)
            );
            // usageMetadata is optional: usageMetadataToChunk accepts null, and UsageMetadataParser itself yields null
            // for an object without token counts, which a required argument would reject.
            PARSER.declareObject(
                ConstructingObjectParser.optionalConstructorArg(),
                (p, c) -> UsageMetadataParser.parse(p),
                new ParseField(USAGE_METADATA_FIELD)
            );
            PARSER.declareString(ConstructingObjectParser.constructorArg(), new ParseField(MODEL_VERSION_FIELD));
            PARSER.declareString(ConstructingObjectParser.constructorArg(), new ParseField(RESPONSE_ID_FIELD));
        }

        /** Converts Vertex AI token counts into the unified usage object, or returns {@code null} when absent. */
        @Nullable
        private static StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Usage usageMetadataToChunk(@Nullable UsageMetadata usage) {
            if (usage == null) {
                return null;
            }
            return new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Usage(
                usage.candidatesTokenCount(),
                usage.promptTokenCount(),
                usage.totalTokenCount()
            );
        }

        /**
         * Converts one candidate into a choice. Text parts are concatenated into the delta content and function-call
         * parts become tool calls. Content, role and tool calls are {@code null} in the delta when the candidate
         * carries none.
         */
        private static StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice candidateToChoice(Candidate candidate) {
            StringBuilder contentText = new StringBuilder();
            List<StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta.ToolCall> toolCalls = new ArrayList<>();
            String role = null;

            Content content = candidate.content();
            if (content != null && content.parts() != null && content.parts().isEmpty() == false) {
                role = content.role();
                for (Part part : content.parts()) {
                    if (part.text() != null) {
                        contentText.append(part.text());
                    }
                    FunctionCall functionCall = part.functionCall();
                    if (functionCall == null) {
                        continue;
                    }
                    var function = new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta.ToolCall.Function(
                        functionCall.args(),
                        functionCall.name()
                    );
                    // Tool-call deltas are identified by index, so each call within this candidate gets its own
                    // position rather than all sharing index 0 (which would make them indistinguishable when merged).
                    toolCalls.add(
                        new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta.ToolCall(
                            toolCalls.size(),
                            function.name(),
                            function,
                            FUNCTION_TYPE
                        )
                    );
                }
            }

            var delta = new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta(
                contentText.isEmpty() ? null : contentText.toString(),
                null,
                role,
                toolCalls.isEmpty() ? null : toolCalls
            );
            return new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice(delta, candidate.finishReason(), candidate.index());
        }

        /**
         * Parses a Vertex AI streaming response object. Unknown fields are ignored.
         *
         * @param parser a parser positioned on the response's {@code START_OBJECT} token
         * @return the converted chunk
         * @throws IOException if the object cannot be read or a required field is missing
         */
        public static StreamingUnifiedChatCompletionResults.ChatCompletionChunk parse(XContentParser parser) throws IOException {
            return PARSER.parse(parser, null);
        }
    }

    /** Parses the {@code usageMetadata} object; yields {@code null} when none of the token counts are present. */
    private static class UsageMetadataParser {
        private static final ConstructingObjectParser<UsageMetadata, Void> PARSER = new ConstructingObjectParser<>(
            "usageMetadata",
            true,
            args -> {
                if (Objects.isNull(args[0]) && Objects.isNull(args[1]) && Objects.isNull(args[2])) {
                    return null;
                }
                // Individual missing counts default to 0.
                return new UsageMetadata(
                    args[0] == null ? 0 : (Integer) args[0],
                    args[1] == null ? 0 : (Integer) args[1],
                    args[2] == null ? 0 : (Integer) args[2]
                );
            }
        );

        static {
            PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), new ParseField(PROMPT_TOKEN_COUNT_FIELD));
            PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), new ParseField(CANDIDATES_TOKEN_COUNT_FIELD));
            PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), new ParseField(TOTAL_TOKEN_COUNT_FIELD));
        }

        private UsageMetadataParser() {}

        public static UsageMetadata parse(XContentParser parser) throws IOException {
            return PARSER.parse(parser, null);
        }
    }

    /** Token counts reported by Vertex AI for a response. */
    private record UsageMetadata(int promptTokenCount, int candidatesTokenCount, int totalTokenCount) {}

    /**
     * Parses a {@code functionCall} part. The {@code args} object is re-serialized to a JSON string. If that
     * conversion fails, a warning is logged and the arguments are left {@code null}.
     */
    private static class FunctionCallParser {
        @SuppressWarnings("unchecked")
        private static final ConstructingObjectParser<FunctionCall, Void> PARSER = new ConstructingObjectParser<>(
            "functionCall",
            true,
            args -> {
                String name = (String) args[0];
                Map<String, Object> argsMap = (Map<String, Object>) args[1];
                if (argsMap == null) {
                    return new FunctionCall(name, null);
                }
                try {
                    XContentBuilder builder = XContentFactory.jsonBuilder().map(argsMap);
                    String json = XContentHelper.convertToJson(BytesReference.bytes(builder), false, XContentType.JSON);
                    return new FunctionCall(name, json);
                } catch (IOException e) {
                    logger.warn("Failed to parse and convert VertexAI function args to json", e);
                    return new FunctionCall(name, null);
                }
            }
        );

        static {
            PARSER.declareString(ConstructingObjectParser.constructorArg(), new ParseField(FUNCTION_NAME_FIELD));
            PARSER.declareObject(ConstructingObjectParser.optionalConstructorArg(), (p, c) -> p.map(), new ParseField(FUNCTION_ARGS_FIELD));
        }

        private FunctionCallParser() {}

        public static FunctionCall parse(XContentParser parser) throws IOException {
            return PARSER.parse(parser, null);
        }
    }

    /** A function call requested by the model; {@code args} holds the arguments as a JSON string, or {@code null}. */
    private record FunctionCall(String name, @Nullable String args) {}

    /** Parses one element of {@code content.parts}; both text and function call are optional. */
    private static class PartParser {
        private static final ConstructingObjectParser<Part, Void> PARSER = new ConstructingObjectParser<>(
            "part",
            true,
            args -> new Part((String) args[0], (FunctionCall) args[1])
        );

        static {
            PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), new ParseField(TEXT_FIELD));
            PARSER.declareObject(
                ConstructingObjectParser.optionalConstructorArg(),
                (p, c) -> FunctionCallParser.parse(p),
                new ParseField(FUNCTION_CALL_FIELD)
            );
        }

        private PartParser() {}

        public static Part parse(XContentParser parser) throws IOException {
            return PARSER.parse(parser, null);
        }
    }

    private record Part(@Nullable String text, @Nullable FunctionCall functionCall) {}

    /**
     * Parses a candidate's {@code content} object. Role and parts are optional because candidateToChoice already
     * handles a content without parts, and passes a missing role through as {@code null}.
     */
    private static class ContentParser {
        @SuppressWarnings("unchecked")
        private static final ConstructingObjectParser<Content, Void> PARSER = new ConstructingObjectParser<>(
            "content",
            true,
            args -> new Content((String) args[0], (List<Part>) args[1])
        );

        static {
            PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), new ParseField(ROLE_FIELD));
            PARSER.declareObjectArray(
                ConstructingObjectParser.optionalConstructorArg(),
                (p, c) -> PartParser.parse(p),
                new ParseField(PARTS_FIELD)
            );
        }

        private ContentParser() {}

        public static Content parse(XContentParser parser) throws IOException {
            return PARSER.parse(parser, null);
        }
    }

    private record Content(@Nullable String role, @Nullable List<Part> parts) {}

    /**
     * Parses one element of {@code candidates}. Content is optional so that a candidate carrying only a
     * {@code finishReason} still yields a choice; a missing {@code index} defaults to 0.
     */
    private static class CandidateParser {
        private static final ConstructingObjectParser<Candidate, Void> PARSER = new ConstructingObjectParser<>(
            "candidate",
            true,
            args -> {
                Content content = (Content) args[0];
                String finishReason = (String) args[1];
                int index = args[2] == null ? 0 : (Integer) args[2];
                return new Candidate(content, finishReason, index);
            }
        );

        static {
            PARSER.declareObject(
                ConstructingObjectParser.optionalConstructorArg(),
                (p, c) -> ContentParser.parse(p),
                new ParseField(CONTENT_FIELD)
            );
            PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), new ParseField(FINISH_REASON_FIELD));
            PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), new ParseField(INDEX_FIELD));
        }

        private CandidateParser() {}

        public static Candidate parse(XContentParser parser) throws IOException {
            return PARSER.parse(parser, null);
        }
    }

    private record Candidate(@Nullable Content content, @Nullable String finishReason, int index) {}
}

