/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.esql.plugin;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.Executor;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.Supplier;
import org.elasticsearch.ExceptionsHelper;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.OriginalIndices;
import org.elasticsearch.action.search.SearchRequest;
import org.elasticsearch.action.search.ShardSearchFailure;
import org.elasticsearch.cluster.RemoteException;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.common.util.concurrent.RunOnce;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.compute.data.BlockFactory;
import org.elasticsearch.compute.data.Page;
import org.elasticsearch.compute.operator.Driver;
import org.elasticsearch.compute.operator.DriverCompletionInfo;
import org.elasticsearch.compute.operator.DriverTaskRunner;
import org.elasticsearch.compute.operator.FailureCollector;
import org.elasticsearch.compute.operator.exchange.ExchangeService;
import org.elasticsearch.compute.operator.exchange.ExchangeSourceHandler;
import org.elasticsearch.core.Releasable;
import org.elasticsearch.core.Releasables;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.core.Tuple;
import org.elasticsearch.index.query.SearchExecutionContext;
import org.elasticsearch.logging.LogManager;
import org.elasticsearch.logging.Logger;
import org.elasticsearch.search.SearchService;
import org.elasticsearch.search.internal.SearchContext;
import org.elasticsearch.search.lookup.SourceProvider;
import org.elasticsearch.tasks.CancellableTask;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.tasks.TaskAwareRequest;
import org.elasticsearch.tasks.TaskCancelledException;
import org.elasticsearch.tasks.TaskId;
import org.elasticsearch.tasks.TaskManager;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.TransportException;
import org.elasticsearch.transport.TransportRequest;
import org.elasticsearch.transport.TransportService;
import org.elasticsearch.xpack.esql.action.EsqlExecutionInfo;
import org.elasticsearch.xpack.esql.core.expression.Attribute;
import org.elasticsearch.xpack.esql.core.expression.FoldContext;
import org.elasticsearch.xpack.esql.enrich.EnrichLookupService;
import org.elasticsearch.xpack.esql.enrich.LookupFromIndexService;
import org.elasticsearch.xpack.esql.inference.InferenceRunner;
import org.elasticsearch.xpack.esql.plan.physical.ExchangeSinkExec;
import org.elasticsearch.xpack.esql.plan.physical.ExchangeSourceExec;
import org.elasticsearch.xpack.esql.plan.physical.OutputExec;
import org.elasticsearch.xpack.esql.plan.physical.PhysicalPlan;
import org.elasticsearch.xpack.esql.planner.EsPhysicalOperationProviders;
import org.elasticsearch.xpack.esql.planner.LocalExecutionPlanner;
import org.elasticsearch.xpack.esql.planner.PhysicalSettings;
import org.elasticsearch.xpack.esql.planner.PlannerUtils;
import org.elasticsearch.xpack.esql.plugin.ClusterComputeHandler;
import org.elasticsearch.xpack.esql.plugin.ComputeContext;
import org.elasticsearch.xpack.esql.plugin.ComputeListener;
import org.elasticsearch.xpack.esql.plugin.ComputeResponse;
import org.elasticsearch.xpack.esql.plugin.DataNodeComputeHandler;
import org.elasticsearch.xpack.esql.plugin.EsqlFlags;
import org.elasticsearch.xpack.esql.plugin.QueryPragmas;
import org.elasticsearch.xpack.esql.plugin.ReinitializingSourceProvider;
import org.elasticsearch.xpack.esql.plugin.TransportActionServices;
import org.elasticsearch.xpack.esql.session.Configuration;
import org.elasticsearch.xpack.esql.session.EsqlCCSUtils;
import org.elasticsearch.xpack.esql.session.Result;

/**
 * Coordinates the execution of an ESQL physical plan.
 * <p>
 * The plan is split into a coordinator plan, which always runs on this node and collects the final pages, and an
 * optional data-node plan. The data-node plan is started on the local cluster's data nodes through
 * {@link DataNodeComputeHandler} and on remote clusters through {@link ClusterComputeHandler}. Both feed an
 * {@link ExchangeSourceHandler} registered under the query's session id.
 */
public class ComputeService {
    public static final String DATA_ACTION_NAME = "indices:data/read/esql/data";
    public static final String CLUSTER_ACTION_NAME = "indices:data/read/esql/cluster";
    /** Alias of the local cluster, both in remote-cluster index groupings and in {@link EsqlExecutionInfo}. */
    private static final String LOCAL_CLUSTER = "";
    private static final Logger LOGGER = LogManager.getLogger(ComputeService.class);
    private final SearchService searchService;
    private final BigArrays bigArrays;
    private final BlockFactory blockFactory;
    private final TransportService transportService;
    private final DriverTaskRunner driverRunner;
    private final EnrichLookupService enrichLookupService;
    private final LookupFromIndexService lookupFromIndexService;
    private final InferenceRunner inferenceRunner;
    private final ClusterService clusterService;
    /** Source of the numeric suffix used by {@link #newChildSession(String)}. */
    private final AtomicLong childSessionIdGenerator = new AtomicLong();
    private final DataNodeComputeHandler dataNodeComputeHandler;
    private final ClusterComputeHandler clusterComputeHandler;
    private final ExchangeService exchangeService;
    private final PhysicalSettings physicalSettings;

    public ComputeService(
        TransportActionServices transportActionServices,
        EnrichLookupService enrichLookupService,
        LookupFromIndexService lookupFromIndexService,
        ThreadPool threadPool,
        BigArrays bigArrays,
        BlockFactory blockFactory
    ) {
        this.searchService = transportActionServices.searchService();
        this.transportService = transportActionServices.transportService();
        this.exchangeService = transportActionServices.exchangeService();
        // Use the circuit-breaking variant for every BigArrays allocation made by computes.
        this.bigArrays = bigArrays.withCircuitBreaking();
        this.blockFactory = blockFactory;
        ExecutorService esqlExecutor = threadPool.executor("search");
        this.driverRunner = new DriverTaskRunner(this.transportService, esqlExecutor);
        this.enrichLookupService = enrichLookupService;
        this.lookupFromIndexService = lookupFromIndexService;
        this.inferenceRunner = transportActionServices.inferenceRunner();
        this.clusterService = transportActionServices.clusterService();
        this.dataNodeComputeHandler = new DataNodeComputeHandler(
            this,
            this.clusterService,
            this.searchService,
            this.transportService,
            this.exchangeService,
            esqlExecutor
        );
        this.clusterComputeHandler = new ClusterComputeHandler(
            this,
            this.exchangeService,
            this.transportService,
            esqlExecutor,
            this.dataNodeComputeHandler
        );
        this.physicalSettings = new PhysicalSettings(this.clusterService);
    }

    /**
     * Runs {@code physicalPlan} and completes {@code listener} with the pages gathered by the coordinator plan.
     * <p>
     * Without a data-node plan, the query must not reference any concrete index and runs entirely on this node.
     * Otherwise an {@link ExchangeSourceHandler} is registered under {@code sessionId} and is unregistered when the
     * listener completes. The final coordinator driver reads from it, while the local data nodes and every remote
     * cluster still in the RUNNING state feed it. If the query fails, every page collected so far is released before
     * the failure is propagated.
     *
     * @param sessionId     id of the query session; also the key of the registered exchange source
     * @param rootTask      parent task of the computes; {@link #cancelQueryOnFailure} cancels it and its descendants
     * @param flags         ESQL flags forwarded to every compute
     * @param physicalPlan  the full physical plan to execute
     * @param configuration query configuration (pragmas, partial-result policy, profiling)
     * @param foldContext   fold context used while planning locally
     * @param execInfo      per-cluster execution info, updated with shard counts, status and took time
     * @param listener      notified with the {@link Result} or the failure
     */
    public void execute(
        String sessionId,
        CancellableTask rootTask,
        EsqlFlags flags,
        PhysicalPlan physicalPlan,
        Configuration configuration,
        FoldContext foldContext,
        EsqlExecutionInfo execInfo,
        ActionListener<Result> listener
    ) {
        assert ThreadPool.assertCurrentThreadPool("esql_worker", "system_read", "search", "search_coordination");
        Tuple<PhysicalPlan, PhysicalPlan> coordinatorAndDataNodePlan = PlannerUtils.breakPlanBetweenCoordinatorAndDataNode(
            physicalPlan,
            configuration
        );
        List<Page> collectedPages = Collections.synchronizedList(new ArrayList<>());
        // If the query fails, the collected pages never reach a Result, so release them here.
        listener = listener.delegateResponse((l, e) -> {
            collectedPages.forEach(p -> Releasables.closeExpectNoException(() -> p.releaseBlocks()));
            l.onFailure(e);
        });
        OutputExec coordinatorPlan = new OutputExec(coordinatorAndDataNodePlan.v1(), collectedPages::add);
        PhysicalPlan dataNodePlan = coordinatorAndDataNodePlan.v2();
        if (dataNodePlan != null && !(dataNodePlan instanceof ExchangeSinkExec)) {
            String error = "expected data node plan starts with an ExchangeSink; got " + dataNodePlan;
            assert false : error;
            listener.onFailure(new IllegalStateException(error));
            return;
        }
        Map<String, OriginalIndices> clusterToConcreteIndices = transportService.getRemoteClusterService()
            .groupIndices(SearchRequest.DEFAULT_INDICES_OPTIONS, PlannerUtils.planConcreteIndices(physicalPlan).toArray(String[]::new));
        QueryPragmas queryPragmas = configuration.pragmas();
        Runnable cancelQueryOnFailure = cancelQueryOnFailure(rootTask);
        if (dataNodePlan == null) {
            // Coordinator-only query: nothing to ship to data nodes or remote clusters.
            if (!clusterToConcreteIndices.values().stream().allMatch(v -> v.indices().length == 0)) {
                String error = "expected no concrete indices without data node plan; got " + clusterToConcreteIndices;
                assert false : error;
                listener.onFailure(new IllegalStateException(error));
                return;
            }
            ComputeContext computeContext = new ComputeContext(
                newChildSession(sessionId),
                "single",
                LOCAL_CLUSTER,
                flags,
                List.of(),
                configuration,
                foldContext,
                null,
                null
            );
            updateShardCountForCoordinatorOnlyQuery(execInfo);
            try (
                ComputeListener computeListener = new ComputeListener(
                    transportService.getThreadPool(),
                    cancelQueryOnFailure,
                    listener.map(completionInfo -> {
                        updateExecutionInfoAfterCoordinatorOnlyQuery(execInfo);
                        return new Result(physicalPlan.output(), collectedPages, completionInfo, execInfo);
                    })
                )
            ) {
                runCompute(rootTask, computeContext, coordinatorPlan, computeListener.acquireCompute());
                return;
            }
        }
        if (clusterToConcreteIndices.values().stream().allMatch(v -> v.indices().length == 0)) {
            String error = "expected concrete indices with data node plan but got empty; data node plan " + dataNodePlan;
            assert false : error;
            listener.onFailure(new IllegalStateException(error));
            return;
        }
        Map<String, OriginalIndices> clusterToOriginalIndices = transportService.getRemoteClusterService()
            .groupIndices(SearchRequest.DEFAULT_INDICES_OPTIONS, PlannerUtils.planOriginalIndices(physicalPlan));
        // Pull the local cluster out of both groupings; what remains is handed to the remote-cluster handler below.
        OriginalIndices localOriginalIndices = clusterToOriginalIndices.remove(LOCAL_CLUSTER);
        OriginalIndices localConcreteIndices = clusterToConcreteIndices.remove(LOCAL_CLUSTER);
        List<Attribute> outputAttributes = physicalPlan.output();
        ExchangeSourceHandler exchangeSource = new ExchangeSourceHandler(
            queryPragmas.exchangeBufferSize(),
            transportService.getThreadPool().executor("search")
        );
        listener = ActionListener.runBefore(listener, () -> exchangeService.removeExchangeSourceHandler(sessionId));
        exchangeService.addExchangeSourceHandler(sessionId, exchangeSource);
        try (
            ComputeListener computeListener = new ComputeListener(
                transportService.getThreadPool(),
                cancelQueryOnFailure,
                listener.delegateFailureAndWrap((l, completionInfo) -> {
                    failIfAllShardsFailed(execInfo, collectedPages);
                    execInfo.markEndQuery();
                    l.onResponse(new Result(outputAttributes, collectedPages, completionInfo, execInfo));
                })
            );
            Releasable ignored = exchangeSource.addEmptySink()
        ) {
            AtomicBoolean localClusterWasInterrupted = new AtomicBoolean();
            try (
                ComputeListener localListener = new ComputeListener(
                    transportService.getThreadPool(),
                    cancelQueryOnFailure,
                    computeListener.acquireCompute().delegateFailure((l, completionInfo) -> {
                        if (execInfo.clusterInfo.containsKey(LOCAL_CLUSTER)) {
                            execInfo.swapCluster(LOCAL_CLUSTER, (k, v) -> {
                                TimeValue tookTime = execInfo.tookSoFar();
                                EsqlExecutionInfo.Cluster.Builder builder = new EsqlExecutionInfo.Cluster.Builder(v).setTook(tookTime);
                                if (v.getStatus() == EsqlExecutionInfo.Cluster.Status.RUNNING) {
                                    Integer failedShards = execInfo.getCluster(LOCAL_CLUSTER).getFailedShards();
                                    // PARTIAL if the local computation was stopped, lost shards, or recorded failures.
                                    boolean partial = localClusterWasInterrupted.get()
                                        || (failedShards != null && failedShards > 0)
                                        || !v.getFailures().isEmpty();
                                    builder.setStatus(
                                        partial ? EsqlExecutionInfo.Cluster.Status.PARTIAL : EsqlExecutionInfo.Cluster.Status.SUCCESSFUL
                                    );
                                }
                                return builder.build();
                            });
                        }
                        l.onResponse(completionInfo);
                    })
                )
            ) {
                // The final driver on the coordinator consumes the shared exchange source.
                runCompute(
                    rootTask,
                    new ComputeContext(
                        sessionId,
                        "final",
                        LOCAL_CLUSTER,
                        flags,
                        List.of(),
                        configuration,
                        foldContext,
                        () -> exchangeSource.createExchangeSource(),
                        null
                    ),
                    coordinatorPlan,
                    localListener.acquireCompute()
                );
                if (localConcreteIndices != null && localConcreteIndices.indices().length > 0) {
                    ActionListener<DriverCompletionInfo> dataNodesListener = localListener.acquireCompute();
                    dataNodeComputeHandler.startComputeOnDataNodes(
                        sessionId,
                        LOCAL_CLUSTER,
                        rootTask,
                        flags,
                        configuration,
                        dataNodePlan,
                        Set.of(localConcreteIndices.indices()),
                        localOriginalIndices,
                        exchangeSource,
                        cancelQueryOnFailure,
                        ActionListener.<ComputeResponse>wrap(r -> {
                            localClusterWasInterrupted.set(execInfo.isStopped());
                            execInfo.swapCluster(
                                LOCAL_CLUSTER,
                                (k, v) -> new EsqlExecutionInfo.Cluster.Builder(v).setTotalShards(r.getTotalShards())
                                    .setSuccessfulShards(r.getSuccessfulShards())
                                    .setSkippedShards(r.getSkippedShards())
                                    .setFailedShards(r.getFailedShards())
                                    .addFailures(r.failures)
                                    .build()
                            );
                            dataNodesListener.onResponse(r.getCompletionInfo());
                        }, e -> {
                            if (configuration.allowPartialResults() && EsqlCCSUtils.canAllowPartial(e)) {
                                // Tolerated failure: record it on the local cluster and complete without drivers.
                                execInfo.swapCluster(
                                    LOCAL_CLUSTER,
                                    (k, v) -> new EsqlExecutionInfo.Cluster.Builder(v).setStatus(EsqlExecutionInfo.Cluster.Status.PARTIAL)
                                        .addFailures(List.of(new ShardSearchFailure(e)))
                                        .build()
                                );
                                dataNodesListener.onResponse(DriverCompletionInfo.EMPTY);
                            } else {
                                dataNodesListener.onFailure(e);
                            }
                        })
                    );
                }
            }
            List<ClusterComputeHandler.RemoteCluster> remoteClusters = clusterComputeHandler.getRemoteClusters(
                clusterToConcreteIndices,
                clusterToOriginalIndices
            );
            for (ClusterComputeHandler.RemoteCluster cluster : remoteClusters) {
                // Only clusters that are still RUNNING get a remote compute.
                if (execInfo.getCluster(cluster.clusterAlias()).getStatus() != EsqlExecutionInfo.Cluster.Status.RUNNING) {
                    continue;
                }
                clusterComputeHandler.startComputeOnRemoteCluster(
                    sessionId,
                    rootTask,
                    configuration,
                    dataNodePlan,
                    exchangeSource,
                    cluster,
                    cancelQueryOnFailure,
                    execInfo,
                    computeListener.acquireCompute().delegateResponse((l, ex) -> l.onFailure(wrapRemoteFailure(cluster.clusterAlias(), ex)))
                );
            }
        }
    }

    /**
     * Wraps a failure reported for a remote cluster's compute in a {@link RemoteException} tagged with its alias.
     * A {@link TransportException} is first passed through {@link FailureCollector#unwrapTransportException}
     * (presumably to expose the underlying cause).
     */
    private static Exception wrapRemoteFailure(String clusterAlias, Exception e) {
        if (e instanceof TransportException) {
            return new RemoteException(clusterAlias, FailureCollector.unwrapTransportException((TransportException) e));
        }
        return new RemoteException(clusterAlias, e);
    }

    /** For a cross-cluster query that runs only on the coordinator, records zero shards of every kind for each cluster. */
    private static void updateShardCountForCoordinatorOnlyQuery(EsqlExecutionInfo execInfo) {
        if (execInfo.isCrossClusterSearch()) {
            for (String clusterAlias : execInfo.clusterAliases()) {
                execInfo.swapCluster(
                    clusterAlias,
                    (k, v) -> new EsqlExecutionInfo.Cluster.Builder(v).setTotalShards(0)
                        .setSuccessfulShards(0)
                        .setSkippedShards(0)
                        .setFailedShards(0)
                        .build()
                );
            }
        }
    }

    /**
     * Marks the end of a coordinator-only query. For cross-cluster queries, it also sets every cluster's took time to
     * the overall took time and moves clusters still RUNNING to SUCCESSFUL.
     */
    private static void updateExecutionInfoAfterCoordinatorOnlyQuery(EsqlExecutionInfo execInfo) {
        execInfo.markEndQuery();
        if (execInfo.isCrossClusterSearch()) {
            assert execInfo.planningTookTime() != null : "Planning took time should be set on EsqlExecutionInfo but is null";
            for (String clusterAlias : execInfo.clusterAliases()) {
                execInfo.swapCluster(clusterAlias, (k, v) -> {
                    EsqlExecutionInfo.Cluster.Builder builder = new EsqlExecutionInfo.Cluster.Builder(v).setTook(execInfo.overallTook());
                    if (v.getStatus() == EsqlExecutionInfo.Cluster.Status.RUNNING) {
                        builder.setStatus(EsqlExecutionInfo.Cluster.Status.SUCCESSFUL);
                    }
                    return builder.build();
                });
            }
        }
    }

    /**
     * Rethrows the collected shard failures when the query produced no rows, no cluster reported a successful shard,
     * and at least one shard failed. Otherwise returns normally.
     *
     * @param execInfo     per-cluster execution info holding shard counts and failures
     * @param finalResults pages produced by the coordinator plan
     */
    static void failIfAllShardsFailed(EsqlExecutionInfo execInfo, List<Page> finalResults) {
        if (finalResults.stream().anyMatch(p -> p.getPositionCount() > 0)) {
            return;
        }
        int totalFailedShards = 0;
        for (EsqlExecutionInfo.Cluster cluster : execInfo.clusterInfo.values()) {
            Integer successfulShards = cluster.getSuccessfulShards();
            if (successfulShards != null && successfulShards > 0) {
                return;
            }
            Integer failedShards = cluster.getFailedShards();
            if (failedShards != null) {
                totalFailedShards += failedShards;
            }
        }
        if (totalFailedShards == 0) {
            return;
        }
        FailureCollector failureCollector = new FailureCollector();
        for (EsqlExecutionInfo.Cluster cluster : execInfo.clusterInfo.values()) {
            Integer failedShards = cluster.getFailedShards();
            if (failedShards == null || failedShards <= 0) {
                continue;
            }
            assert !cluster.getFailures().isEmpty() : "expected failures for cluster [" + cluster.getClusterAlias() + "]";
            for (ShardSearchFailure failure : cluster.getFailures()) {
                Throwable cause = failure.getCause();
                if (cause instanceof Exception) {
                    failureCollector.unwrapAndCollect((Exception) cause);
                } else {
                    assert false : "unexpected failure: " + new AssertionError(failure.getCause());
                    failureCollector.unwrapAndCollect(failure);
                }
            }
        }
        ExceptionsHelper.reThrowIfNotNull(failureCollector.getFailure());
    }

    /**
     * Plans {@code plan} locally against the search contexts of {@code context} and runs the resulting drivers on the
     * {@code esql_worker} executor.
     * <p>
     * The search contexts are closed whichever way {@code listener} completes, and the drivers are closed after they
     * finish. Any failure while building shard contexts, planning or creating drivers is reported to {@code listener}
     * rather than thrown, so the search contexts are released and the caller's compute ref is completed. When
     * profiling is enabled, the completion info includes driver profiles.
     */
    void runCompute(CancellableTask task, ComputeContext context, PhysicalPlan plan, ActionListener<DriverCompletionInfo> listener) {
        listener = ActionListener.runBefore(listener, () -> Releasables.close(context.searchContexts()));
        try {
            List<EsPhysicalOperationProviders.ShardContext> contexts = createShardContexts(context);
            EsPhysicalOperationProviders physicalOperationProviders = new EsPhysicalOperationProviders(
                context.foldCtx(),
                contexts,
                searchService.getIndicesService().getAnalysis(),
                physicalSettings
            );
            LocalExecutionPlanner planner = new LocalExecutionPlanner(
                context.sessionId(),
                context.clusterAlias(),
                task,
                bigArrays,
                blockFactory,
                clusterService.getSettings(),
                context.configuration(),
                context.exchangeSourceSupplier(),
                context.exchangeSinkSupplier(),
                enrichLookupService,
                lookupFromIndexService,
                inferenceRunner,
                physicalOperationProviders,
                contexts
            );
            LOGGER.debug("Received physical plan:\n{}", plan);
            PhysicalPlan localPlan = PlannerUtils.localPlan(
                context.flags(),
                context.searchExecutionContexts(),
                context.configuration(),
                context.foldCtx(),
                plan
            );
            LocalExecutionPlanner.LocalExecutionPlan localExecutionPlan = planner.plan(context.taskDescription(), context.foldCtx(), localPlan);
            if (LOGGER.isDebugEnabled()) {
                LOGGER.debug("Local execution plan:\n{}", localExecutionPlan.describe());
            }
            List<Driver> drivers = localExecutionPlan.createDrivers(context.sessionId());
            if (drivers.isEmpty()) {
                throw new IllegalStateException("no drivers created");
            }
            LOGGER.debug("using {} drivers", drivers.size());
            driverRunner.executeDrivers(
                task,
                drivers,
                transportService.getThreadPool().executor("esql_worker"),
                ActionListener.releaseAfter(listener.map(ignored -> {
                    if (context.configuration().profile()) {
                        return DriverCompletionInfo.includingProfiles(
                            drivers,
                            context.taskDescription(),
                            clusterService.getClusterName().value(),
                            transportService.getLocalNode().getName(),
                            localPlan.toString()
                        );
                    }
                    return DriverCompletionInfo.excludingProfiles(drivers);
                }), () -> Releasables.close(drivers))
            );
        } catch (Exception e) {
            listener.onFailure(e);
        }
    }

    /**
     * Wraps each search context of {@code context} in a shard context. The index of each shard context is the
     * position of its search context, and its source provider is a {@link ReinitializingSourceProvider} over the
     * search context's own provider.
     */
    private static List<EsPhysicalOperationProviders.ShardContext> createShardContexts(ComputeContext context) {
        int shardCount = context.searchContexts().size();
        List<EsPhysicalOperationProviders.ShardContext> shardContexts = new ArrayList<>(shardCount);
        for (int i = 0; i < shardCount; i++) {
            SearchContext searchContext = context.searchContexts().get(i);
            SearchExecutionContext searchExecutionContext = new SearchExecutionContext(searchContext.getSearchExecutionContext()) {
                @Override
                public SourceProvider createSourceProvider() {
                    return new ReinitializingSourceProvider(() -> super.createSourceProvider());
                }
            };
            shardContexts.add(
                new EsPhysicalOperationProviders.DefaultShardContext(i, searchExecutionContext, searchContext.request().getAliasFilter())
            );
        }
        return shardContexts;
    }

    /**
     * Builds a sink plan with the same output as {@code plan} that reads from a matching exchange source. When
     * {@code enable} is set and {@link PlannerUtils#reductionPlan} yields a plan, that plan is placed on top of the source.
     */
    static PhysicalPlan reductionPlan(ExchangeSinkExec plan, boolean enable) {
        PhysicalPlan reducePlan = new ExchangeSourceExec(plan.source(), plan.output(), plan.isIntermediateAgg());
        if (enable) {
            PhysicalPlan reduction = PlannerUtils.reductionPlan(plan);
            if (reduction != null) {
                reducePlan = reduction.replaceChildren(List.of(reducePlan));
            }
        }
        return new ExchangeSinkExec(plan.source(), plan.output(), plan.isIntermediateAgg(), reducePlan);
    }

    /** Derives a child session id of the form {@code session + "/" + n}, with {@code n} unique per service instance. */
    String newChildSession(String session) {
        return session + "/" + childSessionIdGenerator.incrementAndGet();
    }

    /** Returns a runnable that, at most once, cancels {@code task} and its descendants. */
    Runnable cancelQueryOnFailure(CancellableTask task) {
        return new RunOnce(() -> {
            LOGGER.debug("cancelling ESQL task {} on failure", task);
            transportService.getTaskManager().cancelTaskAndDescendants(task, "cancelled on failure", false, ActionListener.noop());
        });
    }

    /**
     * Registers a cancellable {@code esql_compute_group} task whose parent is {@code parentTask} on the local node.
     * Registration happens under a new trace context that is restored afterwards.
     */
    CancellableTask createGroupTask(Task parentTask, Supplier<String> description) throws TaskCancelledException {
        TaskManager taskManager = transportService.getTaskManager();
        try (ThreadContext.StoredContext ignored = transportService.getThreadPool().getThreadContext().newTraceContext()) {
            TaskId parentTaskId = parentTask.taskInfo(transportService.getLocalNode().getId(), false).taskId();
            return (CancellableTask) taskManager.register(
                "transport",
                "esql_compute_group",
                new ComputeGroupTaskRequest(parentTaskId, description)
            );
        }
    }

    /** Creates ESQL flags backed by the current cluster settings. */
    public EsqlFlags createFlags() {
        return new EsqlFlags(clusterService.getClusterSettings());
    }

    /** Request used only to register a compute-group task under a parent task. */
    private static class ComputeGroupTaskRequest extends TransportRequest {
        private final Supplier<String> parentDescription;

        ComputeGroupTaskRequest(TaskId parentTask, Supplier<String> description) {
            this.parentDescription = description;
            setParentTask(parentTask);
        }

        @Override
        public Task createTask(long id, String type, String action, TaskId parentTaskId, Map<String, String> headers) {
            assert parentTaskId.isSet();
            return new CancellableTask(id, type, action, "", parentTaskId, headers);
        }

        @Override
        public String getDescription() {
            return "group [" + parentDescription.get() + "]";
        }
    }
}

