/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.inference.chunking;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.common.util.concurrent.AtomicArray;
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
import org.elasticsearch.inference.ChunkedInferenceServiceResults;
import org.elasticsearch.inference.ChunkingSettings;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xpack.core.inference.results.ErrorChunkedInferenceResults;
import org.elasticsearch.xpack.core.inference.results.InferenceChunkedSparseEmbeddingResults;
import org.elasticsearch.xpack.core.inference.results.InferenceChunkedTextEmbeddingByteResults;
import org.elasticsearch.xpack.core.inference.results.InferenceChunkedTextEmbeddingFloatResults;
import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingByteResults;
import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingFloatResults;
import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults;
import org.elasticsearch.xpack.core.ml.inference.results.MlChunkedTextExpansionResults;
import org.elasticsearch.xpack.inference.chunking.ChunkerBuilder;
import org.elasticsearch.xpack.inference.chunking.WordBoundaryChunker;

public class EmbeddingRequestChunker {
    public static final int DEFAULT_WORDS_PER_CHUNK = 250;
    public static final int DEFAULT_CHUNK_OVERLAP = 100;
    private final List<BatchRequest> batchedRequests = new ArrayList<BatchRequest>();
    private final AtomicInteger resultCount = new AtomicInteger();
    private final int maxNumberOfInputsPerBatch;
    private final int wordsPerChunk;
    private final int chunkOverlap;
    private final EmbeddingType embeddingType;
    private final ChunkingSettings chunkingSettings;
    private List<List<String>> chunkedInputs;
    private List<AtomicArray<List<InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding>>> floatResults;
    private List<AtomicArray<List<InferenceTextEmbeddingByteResults.InferenceByteEmbedding>>> byteResults;
    private List<AtomicArray<List<SparseEmbeddingResults.Embedding>>> sparseResults;
    private AtomicArray<ErrorChunkedInferenceResults> errors;
    private ActionListener<List<ChunkedInferenceServiceResults>> finalListener;

    public EmbeddingRequestChunker(List<String> inputs, int maxNumberOfInputsPerBatch, EmbeddingType embeddingType) {
        this(inputs, maxNumberOfInputsPerBatch, 250, 100, embeddingType);
    }

    public EmbeddingRequestChunker(List<String> inputs, int maxNumberOfInputsPerBatch, int wordsPerChunk, int chunkOverlap, EmbeddingType embeddingType) {
        this.maxNumberOfInputsPerBatch = maxNumberOfInputsPerBatch;
        this.wordsPerChunk = wordsPerChunk;
        this.chunkOverlap = chunkOverlap;
        this.embeddingType = embeddingType;
        this.chunkingSettings = null;
        this.splitIntoBatchedRequests(inputs);
    }

    public EmbeddingRequestChunker(List<String> inputs, int maxNumberOfInputsPerBatch, EmbeddingType embeddingType, ChunkingSettings chunkingSettings) {
        this.maxNumberOfInputsPerBatch = maxNumberOfInputsPerBatch;
        this.wordsPerChunk = 250;
        this.chunkOverlap = 100;
        this.embeddingType = embeddingType;
        this.chunkingSettings = chunkingSettings;
        this.splitIntoBatchedRequests(inputs);
    }

    private void splitIntoBatchedRequests(List<String> inputs) {
        Function<String, List> chunkFunction;
        if (this.chunkingSettings != null) {
            chunker = ChunkerBuilder.fromChunkingStrategy(this.chunkingSettings.getChunkingStrategy());
            chunkFunction = input -> chunker.chunk((String)input, this.chunkingSettings);
        } else {
            chunker = new WordBoundaryChunker();
            chunkFunction = arg_0 -> this.lambda$splitIntoBatchedRequests$1((WordBoundaryChunker)chunker, arg_0);
        }
        this.chunkedInputs = new ArrayList<List<String>>(inputs.size());
        switch (this.embeddingType) {
            case FLOAT: {
                this.floatResults = new ArrayList<AtomicArray<List<InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding>>>(inputs.size());
                break;
            }
            case BYTE: {
                this.byteResults = new ArrayList<AtomicArray<List<InferenceTextEmbeddingByteResults.InferenceByteEmbedding>>>(inputs.size());
                break;
            }
            case SPARSE: {
                this.sparseResults = new ArrayList<AtomicArray<List<SparseEmbeddingResults.Embedding>>>(inputs.size());
            }
        }
        this.errors = new AtomicArray(inputs.size());
        for (int i = 0; i < inputs.size(); ++i) {
            List chunks = chunkFunction.apply(inputs.get(i));
            int numberOfSubBatches = this.addToBatches(chunks, i);
            switch (this.embeddingType) {
                case FLOAT: {
                    this.floatResults.add((AtomicArray<List<InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding>>)new AtomicArray(numberOfSubBatches));
                    break;
                }
                case BYTE: {
                    this.byteResults.add((AtomicArray<List<InferenceTextEmbeddingByteResults.InferenceByteEmbedding>>)new AtomicArray(numberOfSubBatches));
                    break;
                }
                case SPARSE: {
                    this.sparseResults.add((AtomicArray<List<SparseEmbeddingResults.Embedding>>)new AtomicArray(numberOfSubBatches));
                }
            }
            this.chunkedInputs.add(chunks);
        }
    }

    private int addToBatches(List<String> chunks, int inputIndex) {
        int toAdd;
        BatchRequest lastBatch;
        if (this.batchedRequests.isEmpty()) {
            lastBatch = new BatchRequest(new ArrayList<SubBatch>());
            this.batchedRequests.add(lastBatch);
        } else {
            lastBatch = this.batchedRequests.get(this.batchedRequests.size() - 1);
        }
        int freeSpace = this.maxNumberOfInputsPerBatch - lastBatch.size();
        assert (freeSpace >= 0);
        int chunkIndex = 0;
        if (freeSpace > 0) {
            int toAdd2 = Math.min(freeSpace, chunks.size());
            lastBatch.addSubBatch(new SubBatch(chunks.subList(0, toAdd2), new SubBatchPositionsAndCount(inputIndex, chunkIndex++, toAdd2)));
        }
        for (int start = freeSpace; start < chunks.size(); start += toAdd) {
            toAdd = Math.min(this.maxNumberOfInputsPerBatch, chunks.size() - start);
            BatchRequest batch = new BatchRequest(new ArrayList<SubBatch>());
            batch.addSubBatch(new SubBatch(chunks.subList(start, start + toAdd), new SubBatchPositionsAndCount(inputIndex, chunkIndex++, toAdd)));
            this.batchedRequests.add(batch);
        }
        return chunkIndex;
    }

    public List<BatchRequestAndListener> batchRequestsWithListeners(ActionListener<List<ChunkedInferenceServiceResults>> finalListener) {
        this.finalListener = finalListener;
        int numberOfRequests = this.batchedRequests.size();
        ArrayList<BatchRequestAndListener> requests = new ArrayList<BatchRequestAndListener>(numberOfRequests);
        for (BatchRequest batch : this.batchedRequests) {
            requests.add(new BatchRequestAndListener(batch, new DebatchingListener(batch.subBatches().stream().map(SubBatch::positions).collect(Collectors.toList()), numberOfRequests)));
        }
        return requests;
    }

    private ChunkedInferenceServiceResults mergeResultsWithInputs(int resultIndex) {
        return switch (this.embeddingType) {
            default -> throw new IncompatibleClassChangeError();
            case EmbeddingType.FLOAT -> this.mergeFloatResultsWithInputs(this.chunkedInputs.get(resultIndex), this.floatResults.get(resultIndex));
            case EmbeddingType.BYTE -> this.mergeByteResultsWithInputs(this.chunkedInputs.get(resultIndex), this.byteResults.get(resultIndex));
            case EmbeddingType.SPARSE -> this.mergeSparseResultsWithInputs(this.chunkedInputs.get(resultIndex), this.sparseResults.get(resultIndex));
        };
    }

    private InferenceChunkedTextEmbeddingFloatResults mergeFloatResultsWithInputs(List<String> chunks, AtomicArray<List<InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding>> debatchedResults) {
        ArrayList all = new ArrayList();
        for (int i = 0; i < debatchedResults.length(); ++i) {
            List subBatch = (List)debatchedResults.get(i);
            all.addAll(subBatch);
        }
        assert (chunks.size() == all.size());
        ArrayList<InferenceChunkedTextEmbeddingFloatResults.InferenceFloatEmbeddingChunk> embeddingChunks = new ArrayList<InferenceChunkedTextEmbeddingFloatResults.InferenceFloatEmbeddingChunk>();
        for (int i = 0; i < chunks.size(); ++i) {
            embeddingChunks.add(new InferenceChunkedTextEmbeddingFloatResults.InferenceFloatEmbeddingChunk(chunks.get(i), ((InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding)all.get(i)).values()));
        }
        return new InferenceChunkedTextEmbeddingFloatResults(embeddingChunks);
    }

    private InferenceChunkedTextEmbeddingByteResults mergeByteResultsWithInputs(List<String> chunks, AtomicArray<List<InferenceTextEmbeddingByteResults.InferenceByteEmbedding>> debatchedResults) {
        ArrayList all = new ArrayList();
        for (int i = 0; i < debatchedResults.length(); ++i) {
            List subBatch = (List)debatchedResults.get(i);
            all.addAll(subBatch);
        }
        assert (chunks.size() == all.size());
        ArrayList<InferenceChunkedTextEmbeddingByteResults.InferenceByteEmbeddingChunk> embeddingChunks = new ArrayList<InferenceChunkedTextEmbeddingByteResults.InferenceByteEmbeddingChunk>();
        for (int i = 0; i < chunks.size(); ++i) {
            embeddingChunks.add(new InferenceChunkedTextEmbeddingByteResults.InferenceByteEmbeddingChunk(chunks.get(i), ((InferenceTextEmbeddingByteResults.InferenceByteEmbedding)all.get(i)).values()));
        }
        return new InferenceChunkedTextEmbeddingByteResults(embeddingChunks, false);
    }

    private InferenceChunkedSparseEmbeddingResults mergeSparseResultsWithInputs(List<String> chunks, AtomicArray<List<SparseEmbeddingResults.Embedding>> debatchedResults) {
        ArrayList all = new ArrayList();
        for (int i = 0; i < debatchedResults.length(); ++i) {
            List subBatch = (List)debatchedResults.get(i);
            all.addAll(subBatch);
        }
        assert (chunks.size() == all.size());
        ArrayList<MlChunkedTextExpansionResults.ChunkedResult> embeddingChunks = new ArrayList<MlChunkedTextExpansionResults.ChunkedResult>();
        for (int i = 0; i < chunks.size(); ++i) {
            embeddingChunks.add(new MlChunkedTextExpansionResults.ChunkedResult(chunks.get(i), ((SparseEmbeddingResults.Embedding)all.get(i)).tokens()));
        }
        return new InferenceChunkedSparseEmbeddingResults(embeddingChunks);
    }

    private /* synthetic */ List lambda$splitIntoBatchedRequests$1(WordBoundaryChunker chunker, String input) {
        return chunker.chunk(input, this.wordsPerChunk, this.chunkOverlap);
    }

    public static enum EmbeddingType {
        FLOAT,
        BYTE,
        SPARSE;


        public static EmbeddingType fromDenseVectorElementType(DenseVectorFieldMapper.ElementType elementType) {
            return switch (elementType) {
                default -> throw new IncompatibleClassChangeError();
                case DenseVectorFieldMapper.ElementType.BYTE -> BYTE;
                case DenseVectorFieldMapper.ElementType.FLOAT -> FLOAT;
                case DenseVectorFieldMapper.ElementType.BIT -> throw new IllegalArgumentException("Bit vectors are not supported");
            };
        }
    }

    public record BatchRequest(List<SubBatch> subBatches) {
        public int size() {
            return this.subBatches.stream().mapToInt(SubBatch::size).sum();
        }

        public void addSubBatch(SubBatch sb) {
            this.subBatches.add(sb);
        }

        public List<String> inputs() {
            return this.subBatches.stream().flatMap(s -> s.requests().stream()).collect(Collectors.toList());
        }
    }

    record SubBatch(List<String> requests, SubBatchPositionsAndCount positions) {
        public int size() {
            return this.requests.size();
        }
    }

    record SubBatchPositionsAndCount(int inputIndex, int chunkIndex, int embeddingCount) {
    }

    public record BatchRequestAndListener(BatchRequest batch, ActionListener<InferenceServiceResults> listener) {
    }

    private class DebatchingListener
    implements ActionListener<InferenceServiceResults> {
        private final List<SubBatchPositionsAndCount> positions;
        private final int totalNumberOfRequests;

        DebatchingListener(List<SubBatchPositionsAndCount> positions, int totalNumberOfRequests) {
            this.positions = positions;
            this.totalNumberOfRequests = totalNumberOfRequests;
        }

        public void onResponse(InferenceServiceResults inferenceServiceResults) {
            switch (EmbeddingRequestChunker.this.embeddingType) {
                case FLOAT: {
                    this.handleFloatResults(inferenceServiceResults);
                    break;
                }
                case BYTE: {
                    this.handleByteResults(inferenceServiceResults);
                    break;
                }
                case SPARSE: {
                    this.handleSparseResults(inferenceServiceResults);
                }
            }
        }

        private void handleFloatResults(InferenceServiceResults inferenceServiceResults) {
            if (inferenceServiceResults instanceof InferenceTextEmbeddingFloatResults) {
                InferenceTextEmbeddingFloatResults floatEmbeddings = (InferenceTextEmbeddingFloatResults)inferenceServiceResults;
                if (this.failIfNumRequestsDoNotMatch(floatEmbeddings.embeddings().size())) {
                    return;
                }
                int start = 0;
                for (SubBatchPositionsAndCount pos : this.positions) {
                    EmbeddingRequestChunker.this.floatResults.get(pos.inputIndex()).setOnce(pos.chunkIndex(), floatEmbeddings.embeddings().subList(start, start + pos.embeddingCount()));
                    start += pos.embeddingCount();
                }
                if (EmbeddingRequestChunker.this.resultCount.incrementAndGet() == this.totalNumberOfRequests) {
                    this.sendResponse();
                }
            } else {
                this.onFailure((Exception)this.unexpectedResultTypeException(inferenceServiceResults.getWriteableName(), "text_embedding_service_results"));
            }
        }

        private void handleByteResults(InferenceServiceResults inferenceServiceResults) {
            if (inferenceServiceResults instanceof InferenceTextEmbeddingByteResults) {
                InferenceTextEmbeddingByteResults byteEmbeddings = (InferenceTextEmbeddingByteResults)inferenceServiceResults;
                if (this.failIfNumRequestsDoNotMatch(byteEmbeddings.embeddings().size())) {
                    return;
                }
                int start = 0;
                for (SubBatchPositionsAndCount pos : this.positions) {
                    EmbeddingRequestChunker.this.byteResults.get(pos.inputIndex()).setOnce(pos.chunkIndex(), byteEmbeddings.embeddings().subList(start, start + pos.embeddingCount()));
                    start += pos.embeddingCount();
                }
                if (EmbeddingRequestChunker.this.resultCount.incrementAndGet() == this.totalNumberOfRequests) {
                    this.sendResponse();
                }
            } else {
                this.onFailure((Exception)this.unexpectedResultTypeException(inferenceServiceResults.getWriteableName(), "text_embedding_service_byte_results"));
            }
        }

        private void handleSparseResults(InferenceServiceResults inferenceServiceResults) {
            if (inferenceServiceResults instanceof SparseEmbeddingResults) {
                SparseEmbeddingResults sparseEmbeddings = (SparseEmbeddingResults)inferenceServiceResults;
                if (this.failIfNumRequestsDoNotMatch(sparseEmbeddings.embeddings().size())) {
                    return;
                }
                int start = 0;
                for (SubBatchPositionsAndCount pos : this.positions) {
                    EmbeddingRequestChunker.this.sparseResults.get(pos.inputIndex()).setOnce(pos.chunkIndex(), sparseEmbeddings.embeddings().subList(start, start + pos.embeddingCount()));
                    start += pos.embeddingCount();
                }
                if (EmbeddingRequestChunker.this.resultCount.incrementAndGet() == this.totalNumberOfRequests) {
                    this.sendResponse();
                }
            } else {
                this.onFailure((Exception)this.unexpectedResultTypeException(inferenceServiceResults.getWriteableName(), "text_embedding_service_byte_results"));
            }
        }

        private boolean failIfNumRequestsDoNotMatch(int numberOfResults) {
            int numberOfRequests = this.positions.stream().mapToInt(SubBatchPositionsAndCount::embeddingCount).sum();
            if (numberOfRequests != numberOfResults) {
                this.onFailure((Exception)new ElasticsearchStatusException("Error the number of embedding responses [{}] does not equal the number of requests [{}]", RestStatus.INTERNAL_SERVER_ERROR, new Object[]{numberOfResults, numberOfRequests}));
                return true;
            }
            return false;
        }

        private ElasticsearchStatusException unexpectedResultTypeException(String got, String expected) {
            return new ElasticsearchStatusException("Unexpected inference result type [" + got + "], expected a [" + expected + "]", RestStatus.INTERNAL_SERVER_ERROR, new Object[0]);
        }

        public void onFailure(Exception e) {
            ErrorChunkedInferenceResults errorResult = new ErrorChunkedInferenceResults(e);
            for (SubBatchPositionsAndCount pos : this.positions) {
                EmbeddingRequestChunker.this.errors.set(pos.inputIndex(), (Object)errorResult);
            }
            if (EmbeddingRequestChunker.this.resultCount.incrementAndGet() == this.totalNumberOfRequests) {
                this.sendResponse();
            }
        }

        private void sendResponse() {
            ArrayList<ChunkedInferenceServiceResults> response = new ArrayList<ChunkedInferenceServiceResults>(EmbeddingRequestChunker.this.chunkedInputs.size());
            for (int i = 0; i < EmbeddingRequestChunker.this.chunkedInputs.size(); ++i) {
                if (EmbeddingRequestChunker.this.errors.get(i) != null) {
                    response.add((ChunkedInferenceServiceResults)EmbeddingRequestChunker.this.errors.get(i));
                    continue;
                }
                response.add(EmbeddingRequestChunker.this.mergeResultsWithInputs(i));
            }
            EmbeddingRequestChunker.this.finalListener.onResponse(response);
        }
    }
}

