/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.esql.inference.completion;

import java.util.stream.IntStream;
import org.elasticsearch.compute.data.BytesRefBlock;
import org.elasticsearch.compute.data.Page;
import org.elasticsearch.compute.operator.DriverContext;
import org.elasticsearch.compute.operator.EvalOperator;
import org.elasticsearch.compute.operator.Operator;
import org.elasticsearch.core.Releasable;
import org.elasticsearch.core.Releasables;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.esql.inference.InferenceOperator;
import org.elasticsearch.xpack.esql.inference.InferenceRunner;
import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceExecutionConfig;
import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceRequestIterator;
import org.elasticsearch.xpack.esql.inference.completion.CompletionOperatorOutputBuilder;
import org.elasticsearch.xpack.esql.inference.completion.CompletionOperatorRequestIterator;

/**
 * ES|QL operator that issues a completion inference request for every row of its input pages.
 *
 * <p>For each incoming page, the prompt expression is evaluated and the resulting block is appended
 * as the last block of the page. Inference requests are built from that last block, and the output
 * builder receives the page projected back onto its original blocks (every block except the appended
 * prompt block) together with a {@link BytesRefBlock.Builder} for the completion results. Requests
 * are executed with {@link BulkInferenceExecutionConfig#DEFAULT}.
 */
public class CompletionOperator extends InferenceOperator {
    /** Evaluates the prompt for each row; owned by this operator and closed in {@link #doClose()}. */
    private final EvalOperator.ExpressionEvaluator promptEvaluator;

    /**
     * Creates a completion operator.
     *
     * @param driverContext   driver context passed through to {@link InferenceOperator}
     * @param inferenceRunner runner used to execute the inference requests
     * @param threadPool      thread pool passed through to {@link InferenceOperator}
     * @param inferenceId     id of the inference endpoint to call
     * @param promptEvaluator evaluator producing the prompt for each row; closed by {@link #doClose()}
     */
    public CompletionOperator(
        DriverContext driverContext,
        InferenceRunner inferenceRunner,
        ThreadPool threadPool,
        String inferenceId,
        EvalOperator.ExpressionEvaluator promptEvaluator
    ) {
        super(driverContext, inferenceRunner, BulkInferenceExecutionConfig.DEFAULT, threadPool, inferenceId);
        this.promptEvaluator = promptEvaluator;
    }

    /** Releases the prompt evaluator owned by this operator. */
    @Override
    protected void doClose() {
        Releasables.close(promptEvaluator);
    }

    @Override
    public String toString() {
        return description(inferenceId());
    }

    /**
     * Evaluates the prompt for {@code input}, appends it as the last block, and forwards the widened
     * page to {@link InferenceOperator#addInput}.
     *
     * <p>If anything in that chain throws, {@code input} is released via
     * {@code releasePageOnAnyThread} and the exception is rethrown unchanged.
     *
     * @param input the incoming page
     */
    @Override
    public void addInput(Page input) {
        try {
            super.addInput(input.appendBlock(promptEvaluator.eval(input)));
        } catch (Exception e) {
            // NOTE(review): only the original `input` page is released here. Whether the freshly
            // evaluated prompt block is also released when appendBlock or super.addInput throws
            // depends on Page/InferenceOperator ownership semantics not visible here — confirm.
            releasePageOnAnyThread(input);
            throw e;
        }
    }

    /**
     * Builds the request iterator over the prompt block, which {@link #addInput} appended as the last
     * block of the page.
     */
    @Override
    protected BulkInferenceRequestIterator requests(Page inputPage) {
        BytesRefBlock prompts = (BytesRefBlock) inputPage.getBlock(promptChannel(inputPage));
        return new CompletionOperatorRequestIterator(prompts, inferenceId());
    }

    /**
     * Creates the output builder for {@code input}. It receives a new {@code BytesRefBlock} builder
     * sized to the page's position count, plus the page projected onto every block except the
     * trailing prompt block.
     */
    @Override
    protected CompletionOperatorOutputBuilder outputBuilder(Page input) {
        BytesRefBlock.Builder outputBlockBuilder = blockFactory().newBytesRefBlockBuilder(input.getPositionCount());
        // Channels [0, promptChannel) are the blocks the page carried before addInput appended the prompt.
        int[] originalChannels = IntStream.range(0, promptChannel(input)).toArray();
        return new CompletionOperatorOutputBuilder(outputBlockBuilder, input.projectBlocks(originalChannels));
    }

    /** Returns the channel holding the prompt block, which {@link #addInput} always appends last. */
    private static int promptChannel(Page page) {
        return page.getBlockCount() - 1;
    }

    /** Shared description used by both {@link #toString()} and {@link Factory#describe()}. */
    private static String description(String inferenceId) {
        return "CompletionOperator[inference_id=[" + inferenceId + "]]";
    }

    /**
     * Factory creating one {@link CompletionOperator} per driver. Each operator gets the runner's
     * thread pool and a fresh prompt evaluator built for that driver's context.
     */
    public record Factory(
        InferenceRunner inferenceRunner,
        String inferenceId,
        EvalOperator.ExpressionEvaluator.Factory promptEvaluatorFactory
    ) implements Operator.OperatorFactory {
        @Override
        public String describe() {
            return description(inferenceId);
        }

        @Override
        public Operator get(DriverContext driverContext) {
            return new CompletionOperator(
                driverContext,
                inferenceRunner,
                inferenceRunner.threadPool(),
                inferenceId,
                promptEvaluatorFactory.get(driverContext)
            );
        }
    }
}

