/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.esql.plugin;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.Executor;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Supplier;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.ActionListenerResponseHandler;
import org.elasticsearch.action.ActionRunnable;
import org.elasticsearch.action.search.SearchRequest;
import org.elasticsearch.action.search.SearchShardsGroup;
import org.elasticsearch.action.search.SearchShardsRequest;
import org.elasticsearch.action.search.SearchShardsResponse;
import org.elasticsearch.action.search.TransportSearchShardsAction;
import org.elasticsearch.action.support.ChannelActionListener;
import org.elasticsearch.action.support.ContextPreservingActionListener;
import org.elasticsearch.action.support.RefCountingListener;
import org.elasticsearch.action.support.RefCountingRunnable;
import org.elasticsearch.action.support.SubscribableListener;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.VersionId;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.compute.data.BlockFactory;
import org.elasticsearch.compute.data.Page;
import org.elasticsearch.compute.operator.Driver;
import org.elasticsearch.compute.operator.DriverProfile;
import org.elasticsearch.compute.operator.DriverTaskRunner;
import org.elasticsearch.compute.operator.ResponseHeadersCollector;
import org.elasticsearch.compute.operator.exchange.ExchangeResponse;
import org.elasticsearch.compute.operator.exchange.ExchangeService;
import org.elasticsearch.compute.operator.exchange.ExchangeSink;
import org.elasticsearch.compute.operator.exchange.ExchangeSinkHandler;
import org.elasticsearch.compute.operator.exchange.ExchangeSourceHandler;
import org.elasticsearch.compute.operator.exchange.RemoteSink;
import org.elasticsearch.core.IOUtils;
import org.elasticsearch.core.Releasable;
import org.elasticsearch.core.Releasables;
import org.elasticsearch.core.Tuple;
import org.elasticsearch.index.Index;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.shard.IndexShard;
import org.elasticsearch.index.shard.ShardId;
import org.elasticsearch.index.shard.ShardNotFoundException;
import org.elasticsearch.logging.LogManager;
import org.elasticsearch.logging.Logger;
import org.elasticsearch.search.SearchService;
import org.elasticsearch.search.internal.AliasFilter;
import org.elasticsearch.search.internal.SearchContext;
import org.elasticsearch.search.internal.ShardSearchRequest;
import org.elasticsearch.tasks.CancellableTask;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.tasks.TaskCancelledException;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.TransportChannel;
import org.elasticsearch.transport.TransportRequest;
import org.elasticsearch.transport.TransportRequestHandler;
import org.elasticsearch.transport.TransportRequestOptions;
import org.elasticsearch.transport.TransportResponse;
import org.elasticsearch.transport.TransportResponseHandler;
import org.elasticsearch.transport.TransportService;
import org.elasticsearch.xpack.esql.enrich.EnrichLookupService;
import org.elasticsearch.xpack.esql.plan.physical.OutputExec;
import org.elasticsearch.xpack.esql.plan.physical.PhysicalPlan;
import org.elasticsearch.xpack.esql.planner.EsPhysicalOperationProviders;
import org.elasticsearch.xpack.esql.planner.LocalExecutionPlanner;
import org.elasticsearch.xpack.esql.planner.PlannerUtils;
import org.elasticsearch.xpack.esql.plugin.DataNodeRequest;
import org.elasticsearch.xpack.esql.plugin.QueryPragmas;
import org.elasticsearch.xpack.esql.session.EsqlConfiguration;

/**
 * Runs ES|QL physical plans.
 * <p>
 * A plan is split into a coordinator part, which runs on this node and collects the output pages, and an optional
 * data-node part, which is sent to every node holding a target shard via {@link #DATA_ACTION_NAME}. The data nodes
 * stream their pages back to the coordinator through the {@link ExchangeService}.
 */
public class ComputeService {
    private static final Logger LOGGER = LogManager.getLogger(ComputeService.class);

    /** Transport action that runs the data-node part of a plan on a node holding target shards. */
    public static final String DATA_ACTION_NAME = "indices:data/read/esql/data";

    private final SearchService searchService;
    private final BigArrays bigArrays;
    private final BlockFactory blockFactory;
    private final TransportService transportService;
    private final Executor esqlExecutor;
    private final DriverTaskRunner driverRunner;
    private final ExchangeService exchangeService;
    private final EnrichLookupService enrichLookupService;
    private final ClusterService clusterService;

    public ComputeService(
        SearchService searchService,
        TransportService transportService,
        ExchangeService exchangeService,
        EnrichLookupService enrichLookupService,
        ClusterService clusterService,
        ThreadPool threadPool,
        BigArrays bigArrays,
        BlockFactory blockFactory
    ) {
        this.searchService = searchService;
        this.transportService = transportService;
        this.bigArrays = bigArrays.withCircuitBreaking();
        this.blockFactory = blockFactory;
        this.esqlExecutor = threadPool.executor("esql");
        this.driverRunner = new DriverTaskRunner(transportService, this.esqlExecutor);
        this.exchangeService = exchangeService;
        this.enrichLookupService = enrichLookupService;
        this.clusterService = clusterService;
        // Register last: the handler reads instance fields (e.g. exchangeService) and may be invoked as soon as it
        // is registered, so every field must already be assigned before `this` is handed to the transport service.
        transportService.registerRequestHandler(DATA_ACTION_NAME, this.esqlExecutor, DataNodeRequest::new, new DataNodeRequestHandler());
    }

    /**
     * Executes {@code physicalPlan} and completes {@code listener} with the pages produced by the coordinator plan.
     * <p>
     * If the plan has no data-node part it must not reference any concrete index, and it runs entirely on this node.
     * Otherwise the target nodes are resolved from the search-shards API, the coordinator plan runs locally, and the
     * data-node plan is sent to each target node. The first failure cancels {@code rootTask} and its descendants.
     * On failure, every page collected so far is released before the failure is propagated.
     *
     * @param sessionId     identifier of the exchange session shared by coordinator and data nodes
     * @param rootTask      task the work is attributed to; cancelled on the first failure
     * @param physicalPlan  the complete plan to run
     * @param configuration query configuration (pragmas, profiling flag, ...)
     * @param listener      receives the collected pages and, when profiling is enabled, the driver profiles
     */
    public void execute(
        String sessionId,
        CancellableTask rootTask,
        PhysicalPlan physicalPlan,
        EsqlConfiguration configuration,
        ActionListener<Result> listener
    ) {
        Tuple<PhysicalPlan, PhysicalPlan> coordinatorAndDataNodePlan = PlannerUtils.breakPlanBetweenCoordinatorAndDataNode(
            physicalPlan,
            configuration
        );
        final List<Page> collectedPages = Collections.synchronizedList(new ArrayList<>());
        // On failure the pages are never handed to the caller, so their blocks must be released here.
        listener = listener.delegateResponse((l, e) -> {
            collectedPages.forEach(p -> Releasables.closeExpectNoException(p::releaseBlocks));
            l.onFailure(e);
        });
        OutputExec coordinatorPlan = new OutputExec(coordinatorAndDataNodePlan.v1(), collectedPages::add);
        PhysicalPlan dataNodePlan = coordinatorAndDataNodePlan.v2();
        Set<String> concreteIndices = PlannerUtils.planConcreteIndices(physicalPlan);
        QueryPragmas queryPragmas = configuration.pragmas();

        if (dataNodePlan == null) {
            // Coordinator-only plan: it must not need any shard data.
            if (concreteIndices.isEmpty() == false) {
                String error = "expected no concrete indices without data node plan; got " + concreteIndices;
                assert false : error;
                listener.onFailure(new IllegalStateException(error));
                return;
            }
            ComputeContext computeContext = new ComputeContext(sessionId, List.of(), configuration, null, null);
            runCompute(rootTask, computeContext, coordinatorPlan, listener.map(driverProfiles -> new Result(collectedPages, driverProfiles)));
            return;
        }
        if (concreteIndices.isEmpty()) {
            String error = "expected concrete indices with data node plan but got empty; data node plan " + dataNodePlan;
            assert false : error;
            listener.onFailure(new IllegalStateException(error));
            return;
        }

        QueryBuilder requestFilter = PlannerUtils.requestFilter(dataNodePlan);
        LOGGER.debug("Sending data node plan\n{}\n with filter [{}]", dataNodePlan, requestFilter);

        final ResponseHeadersCollector responseHeadersCollector = new ResponseHeadersCollector(
            transportService.getThreadPool().getThreadContext()
        );
        listener = ActionListener.runBefore(listener, responseHeadersCollector::finish);
        String[] originalIndices = PlannerUtils.planOriginalIndices(physicalPlan);
        computeTargetNodes(rootTask, requestFilter, concreteIndices, originalIndices, listener.delegateFailureAndWrap((delegate, targetNodes) -> {
            final ExchangeSourceHandler exchangeSource = exchangeService.createSourceHandler(
                sessionId,
                queryPragmas.exchangeBufferSize(),
                "esql"
            );
            final List<DriverProfile> collectedProfiles = configuration.profile()
                ? Collections.synchronizedList(new ArrayList<>())
                : null;
            try (
                Releasable ignored = exchangeSource::decRef;
                RefCountingListener requestRefs = new RefCountingListener(
                    delegate.map(unused -> new Result(collectedPages, collectedProfiles))
                )
            ) {
                final AtomicBoolean cancelled = new AtomicBoolean();
                // the final result is delivered only after the exchange source has completed
                exchangeSource.addCompletionListener(requestRefs.acquire());
                // run the coordinator plan on this node
                ComputeContext computeContext = new ComputeContext(sessionId, List.of(), configuration, exchangeSource, null);
                runCompute(rootTask, computeContext, coordinatorPlan, cancelOnFailure(rootTask, cancelled, requestRefs.acquire()).map(driverProfiles -> {
                    responseHeadersCollector.collect();
                    if (configuration.profile()) {
                        collectedProfiles.addAll(driverProfiles);
                    }
                    return null;
                }));
                // run the data-node plan on every target node
                runComputeOnRemoteNodes(
                    sessionId,
                    rootTask,
                    configuration,
                    dataNodePlan,
                    exchangeSource,
                    targetNodes,
                    () -> cancelOnFailure(rootTask, cancelled, requestRefs.acquire()).map(response -> {
                        responseHeadersCollector.collect();
                        if (configuration.profile()) {
                            collectedProfiles.addAll(response.profiles);
                        }
                        return null;
                    })
                );
            }
        }));
    }

    /**
     * Opens an exchange with every target node, registers a remote sink for it on {@code exchangeSource}, and then
     * sends it the data-node plan.
     * <p>
     * A placeholder remote sink backed by {@code blockingSinkFuture} is added first. It only answers once every
     * per-node listener has been released, so the exchange source cannot complete before all target nodes have been
     * handled.
     *
     * @param listener supplies a fresh listener per target node, completed with that node's {@link DataNodeResponse}
     */
    private void runComputeOnRemoteNodes(
        String sessionId,
        CancellableTask rootTask,
        EsqlConfiguration configuration,
        PhysicalPlan dataNodePlan,
        ExchangeSourceHandler exchangeSource,
        List<TargetNode> targetNodes,
        Supplier<ActionListener<DataNodeResponse>> listener
    ) {
        final SubscribableListener<Void> blockingSinkFuture = new SubscribableListener<>();
        exchangeSource.addRemoteSink(
            (sourceFinished, l) -> blockingSinkFuture.addListener(l.map(ignored -> new ExchangeResponse(null, true))),
            1
        );
        final QueryPragmas queryPragmas = configuration.pragmas();
        try (RefCountingRunnable exchangeRefs = new RefCountingRunnable(() -> blockingSinkFuture.onResponse(null))) {
            for (TargetNode targetNode : targetNodes) {
                ActionListener<DataNodeResponse> targetNodeListener = ActionListener.releaseAfter(listener.get(), exchangeRefs.acquire());
                ExchangeService.openExchange(
                    transportService,
                    targetNode.node(),
                    sessionId,
                    queryPragmas.exchangeBufferSize(),
                    esqlExecutor,
                    targetNodeListener.delegateFailureAndWrap((delegate, unused) -> {
                        RemoteSink remoteSink = exchangeService.newRemoteSink(rootTask, sessionId, transportService, targetNode.node());
                        exchangeSource.addRemoteSink(remoteSink, queryPragmas.concurrentExchangeClients());
                        transportService.sendChildRequest(
                            targetNode.node(),
                            DATA_ACTION_NAME,
                            new DataNodeRequest(sessionId, configuration, targetNode.shardIds(), targetNode.aliasFilters(), dataNodePlan),
                            rootTask,
                            TransportRequestOptions.EMPTY,
                            new ActionListenerResponseHandler<>(delegate, DataNodeResponse::new, esqlExecutor)
                        );
                    })
                );
            }
        }
    }

    /**
     * Wraps {@code listener} so that the first failure seen through any wrapper sharing {@code cancelled} also
     * cancels {@code task} and its descendants. The failure is propagated before cancelling.
     */
    private ActionListener<Void> cancelOnFailure(CancellableTask task, AtomicBoolean cancelled, ActionListener<Void> listener) {
        return listener.delegateResponse((l, e) -> {
            l.onFailure(e);
            if (cancelled.compareAndSet(false, true)) {
                LOGGER.debug("cancelling ESQL task {} on failure", task);
                transportService.getTaskManager().cancelTaskAndDescendants(task, "cancelled", false, ActionListener.noop());
            }
        });
    }

    /**
     * Plans {@code plan} locally, creates its drivers and runs them on the {@code esql_worker} executor.
     * <p>
     * The search contexts in {@code context} are closed before {@code listener} is notified. The drivers are closed
     * after the profiles (if requested) have been read from them.
     *
     * @param listener completed with one profile per driver when profiling is enabled, otherwise with {@code null}
     */
    void runCompute(CancellableTask task, ComputeContext context, PhysicalPlan plan, ActionListener<List<DriverProfile>> listener) {
        listener = ActionListener.runBefore(listener, () -> Releasables.close(context.searchContexts()));
        final List<Driver> drivers;
        try {
            LocalExecutionPlanner planner = new LocalExecutionPlanner(
                context.sessionId(),
                task,
                bigArrays,
                blockFactory,
                clusterService.getSettings(),
                context.configuration(),
                context.exchangeSource(),
                context.exchangeSink(),
                enrichLookupService,
                new EsPhysicalOperationProviders(context.searchContexts())
            );
            LOGGER.debug("Received physical plan:\n{}", plan);
            plan = PlannerUtils.localPlan(context.searchContexts(), context.configuration(), plan);
            LocalExecutionPlanner.LocalExecutionPlan localExecutionPlan = planner.plan(plan);
            if (LOGGER.isDebugEnabled()) {
                LOGGER.debug("Local execution plan:\n{}", localExecutionPlan.describe());
            }
            drivers = localExecutionPlan.createDrivers(context.sessionId());
            if (drivers.isEmpty()) {
                throw new IllegalStateException("no drivers created");
            }
            LOGGER.debug("using {} drivers", drivers.size());
        } catch (Exception e) {
            listener.onFailure(e);
            return;
        }
        ActionListener<Void> listenerCollectingStatus = listener.map(ignored -> {
            if (context.configuration().profile()) {
                return drivers.stream().map(d -> new DriverProfile(d.status().completedOperators())).toList();
            }
            return null;
        });
        listenerCollectingStatus = ActionListener.releaseAfter(listenerCollectingStatus, () -> Releasables.close(drivers));
        driverRunner.executeDrivers(task, drivers, transportService.getThreadPool().executor("esql_worker"), listenerCollectingStatus);
    }

    /**
     * Creates and pre-processes one {@link SearchContext} per shard in {@code shardIds}.
     * <p>
     * The shards are resolved synchronously, and a resolution failure is reported to {@code listener}. Context creation
     * waits until {@code ensureShardSearchActive} has called back for every shard. If any shard reported that it
     * had to wait, creation is forked to the {@code esql} executor; otherwise it runs on the thread that released the
     * last reference. If creating or pre-processing any context fails, the contexts created so far are closed.
     */
    private void acquireSearchContexts(
        List<ShardId> shardIds,
        EsqlConfiguration configuration,
        Map<Index, AliasFilter> aliasFilters,
        ActionListener<List<SearchContext>> listener
    ) {
        final List<IndexShard> targetShards = new ArrayList<>();
        try {
            for (ShardId shardId : shardIds) {
                IndexShard indexShard = searchService.getIndicesService().indexServiceSafe(shardId.getIndex()).getShard(shardId.id());
                targetShards.add(indexShard);
            }
        } catch (Exception e) {
            listener.onFailure(e);
            return;
        }
        final ActionRunnable<List<SearchContext>> doAcquire = ActionRunnable.supply(listener, () -> {
            final List<SearchContext> searchContexts = new ArrayList<>(targetShards.size());
            boolean success = false;
            try {
                for (IndexShard shard : targetShards) {
                    AliasFilter aliasFilter = aliasFilters.getOrDefault(shard.shardId().getIndex(), AliasFilter.EMPTY);
                    ShardSearchRequest shardRequest = new ShardSearchRequest(
                        shard.shardId(),
                        configuration.absoluteStartedTimeInMillis(),
                        aliasFilter
                    );
                    SearchContext context = searchService.createSearchContext(shardRequest, SearchService.NO_TIMEOUT);
                    searchContexts.add(context);
                }
                for (SearchContext searchContext : searchContexts) {
                    searchContext.preProcess();
                }
                success = true;
                return searchContexts;
            } finally {
                if (success == false) {
                    IOUtils.close(searchContexts);
                }
            }
        });
        final AtomicBoolean waitedForRefreshes = new AtomicBoolean();
        try (RefCountingRunnable refs = new RefCountingRunnable(() -> {
            if (waitedForRefreshes.get()) {
                esqlExecutor.execute(doAcquire);
            } else {
                doAcquire.run();
            }
        })) {
            for (IndexShard targetShard : targetShards) {
                final Releasable ref = refs.acquire();
                targetShard.ensureShardSearchActive(await -> {
                    try (ref) {
                        if (await) {
                            waitedForRefreshes.set(true);
                        }
                    }
                });
            }
        }
    }

    /**
     * Resolves which node should execute each shard of {@code concreteIndices}, grouping shards by node.
     * <p>
     * Remote-cluster indices are rejected. A search-shards request is sent to the local node in system context, and
     * its response headers are preserved. Skipped shard groups are ignored. A group with no allocated node fails the
     * request with {@link ShardNotFoundException}. Shards of indices outside {@code concreteIndices} are dropped. Each
     * remaining shard is assigned to the first of its allocated nodes, together with its alias filter (if any).
     */
    private void computeTargetNodes(
        Task parentTask,
        QueryBuilder filter,
        Set<String> concreteIndices,
        String[] originalIndices,
        ActionListener<List<TargetNode>> listener
    ) {
        var remoteIndices = transportService.getRemoteClusterService().groupIndices(SearchRequest.DEFAULT_INDICES_OPTIONS, originalIndices);
        // drop the empty key (presumably the local-cluster group); anything left refers to a remote cluster
        remoteIndices.remove("");
        if (remoteIndices.isEmpty() == false) {
            listener.onFailure(
                new IllegalArgumentException("ES|QL does not yet support querying remote indices " + Arrays.toString(originalIndices))
            );
            return;
        }
        ThreadContext threadContext = transportService.getThreadPool().getThreadContext();
        ActionListener<SearchShardsResponse> preservingContextListener = ContextPreservingActionListener.wrapPreservingContext(
            listener.map(resp -> {
                Map<String, DiscoveryNode> nodes = new HashMap<>();
                for (DiscoveryNode node : resp.getNodes()) {
                    nodes.put(node.getId(), node);
                }
                Map<String, List<ShardId>> nodeToShards = new HashMap<>();
                Map<String, Map<Index, AliasFilter>> nodeToAliasFilters = new HashMap<>();
                for (SearchShardsGroup group : resp.getGroups()) {
                    ShardId shardId = group.shardId();
                    if (group.skipped()) {
                        continue;
                    }
                    if (group.allocatedNodes().isEmpty()) {
                        throw new ShardNotFoundException(group.shardId(), "no shard copies found {}", group.shardId());
                    }
                    if (concreteIndices.contains(shardId.getIndexName()) == false) {
                        continue;
                    }
                    String targetNode = group.allocatedNodes().get(0);
                    nodeToShards.computeIfAbsent(targetNode, k -> new ArrayList<>()).add(shardId);
                    AliasFilter aliasFilter = resp.getAliasFilters().get(shardId.getIndex().getUUID());
                    if (aliasFilter != null) {
                        nodeToAliasFilters.computeIfAbsent(targetNode, k -> new HashMap<>()).put(shardId.getIndex(), aliasFilter);
                    }
                }
                List<TargetNode> targetNodes = new ArrayList<>(nodeToShards.size());
                for (Map.Entry<String, List<ShardId>> e : nodeToShards.entrySet()) {
                    DiscoveryNode node = nodes.get(e.getKey());
                    Map<Index, AliasFilter> aliasFilters = nodeToAliasFilters.getOrDefault(e.getKey(), Map.of());
                    targetNodes.add(new TargetNode(node, e.getValue(), aliasFilters));
                }
                return targetNodes;
            }),
            threadContext
        );
        try (ThreadContext.StoredContext ignored = threadContext.newStoredContextPreservingResponseHeaders()) {
            threadContext.markAsSystemContext();
            SearchShardsRequest searchShardsRequest = new SearchShardsRequest(
                originalIndices,
                SearchRequest.DEFAULT_INDICES_OPTIONS,
                filter,
                null,
                null,
                false,
                null
            );
            transportService.sendChildRequest(
                transportService.getLocalNode(),
                TransportSearchShardsAction.TYPE.name(),
                searchShardsRequest,
                parentTask,
                TransportRequestOptions.EMPTY,
                new ActionListenerResponseHandler<>(preservingContextListener, SearchShardsResponse::new, esqlExecutor)
            );
        }
    }

    /** Handles {@link #DATA_ACTION_NAME} requests by running the data-node plan on this node. */
    private class DataNodeRequestHandler implements TransportRequestHandler<DataNodeRequest> {
        @Override
        public void messageReceived(DataNodeRequest request, TransportChannel channel, Task task) {
            DataNodeRequestExecutor executor = new DataNodeRequestExecutor(
                request,
                (CancellableTask) task,
                exchangeService.getSinkHandler(request.sessionId()),
                request.configuration().pragmas().maxConcurrentShardsPerNode(),
                new ChannelActionListener<>(channel)
            );
            executor.start();
        }
    }

    /**
     * Inputs for a single {@link #runCompute} call. On the coordinator {@code searchContexts} is empty and
     * {@code exchangeSink} may be {@code null}; on a data node {@code exchangeSource} is {@code null}.
     */
    record ComputeContext(
        String sessionId,
        List<SearchContext> searchContexts,
        EsqlConfiguration configuration,
        ExchangeSourceHandler exchangeSource,
        ExchangeSinkHandler exchangeSink
    ) {}

    /** A node together with the shards it should execute and the alias filters of their indices. */
    record TargetNode(DiscoveryNode node, List<ShardId> shardIds, Map<Index, AliasFilter> aliasFilters) {}

    /**
     * Response of {@link #DATA_ACTION_NAME}. It carries the data node's driver profiles, which may be {@code null}.
     * For transport versions on or after {@code ESQL_PROFILE}, the profiles are serialized as a presence flag followed
     * by the list. Older versions carry no profiles.
     */
    private static class DataNodeResponse extends TransportResponse {
        private final List<DriverProfile> profiles;

        DataNodeResponse(List<DriverProfile> profiles) {
            this.profiles = profiles;
        }

        DataNodeResponse(StreamInput in) throws IOException {
            super(in);
            if (in.getTransportVersion().onOrAfter(TransportVersions.ESQL_PROFILE)) {
                this.profiles = in.readBoolean() ? in.readCollectionAsImmutableList(DriverProfile::new) : null;
            } else {
                this.profiles = null;
            }
        }

        @Override
        public void writeTo(StreamOutput out) throws IOException {
            if (out.getTransportVersion().onOrAfter(TransportVersions.ESQL_PROFILE)) {
                if (profiles == null) {
                    out.writeBoolean(false);
                } else {
                    out.writeBoolean(true);
                    out.writeCollection(profiles);
                }
            }
        }
    }

    /**
     * Result of {@link #execute}: the pages produced by the coordinator plan and the driver profiles. The profiles
     * are {@code null} when profiling is disabled.
     */
    public record Result(List<Page> pages, List<DriverProfile> profiles) {}

    /**
     * Runs a {@link DataNodeRequest} on this node in batches of at most {@code maxConcurrentShards} shards. It stops
     * early once the exchange sink is finished. When done, it waits for the exchange sink to complete before replying.
     */
    private class DataNodeRequestExecutor {
        private final DataNodeRequest request;
        private final CancellableTask parentTask;
        private final ExchangeSinkHandler exchangeSink;
        private final ActionListener<DataNodeResponse> listener;
        private final List<DriverProfile> driverProfiles;
        private final int maxConcurrentShards;
        // finished only once all batches have run; presumably keeps the sink handler open between batches
        private final ExchangeSink blockingSink;

        DataNodeRequestExecutor(
            DataNodeRequest request,
            CancellableTask parentTask,
            ExchangeSinkHandler exchangeSink,
            int maxConcurrentShards,
            ActionListener<DataNodeResponse> listener
        ) {
            this.request = request;
            this.parentTask = parentTask;
            this.exchangeSink = exchangeSink;
            this.listener = listener;
            this.driverProfiles = request.configuration().profile() ? Collections.synchronizedList(new ArrayList<>()) : List.of();
            this.maxConcurrentShards = maxConcurrentShards;
            this.blockingSink = exchangeSink.createExchangeSink();
        }

        /** Finishes the sink handler with a {@link TaskCancelledException} if the task is cancelled, then runs the first batch. */
        void start() {
            parentTask.addListener(
                () -> exchangeService.finishSinkHandler(request.sessionId(), new TaskCancelledException(parentTask.getReasonCancelled()))
            );
            runBatch(0);
        }

        private void runBatch(int startBatchIndex) {
            final EsqlConfiguration configuration = request.configuration();
            final String sessionId = request.sessionId();
            final int endBatchIndex = Math.min(startBatchIndex + maxConcurrentShards, request.shardIds().size());
            List<ShardId> shardIds = request.shardIds().subList(startBatchIndex, endBatchIndex);
            acquireSearchContexts(shardIds, configuration, request.aliasFilters(), ActionListener.wrap(searchContexts -> {
                assert ThreadPool.assertCurrentThreadPool("esql", "esql_worker");
                ComputeContext computeContext = new ComputeContext(sessionId, searchContexts, configuration, null, exchangeSink);
                runCompute(
                    parentTask,
                    computeContext,
                    request.plan(),
                    ActionListener.wrap(profiles -> onBatchCompleted(endBatchIndex, profiles), this::onFailure)
                );
            }, this::onFailure));
        }

        private void onBatchCompleted(int lastBatchIndex, List<DriverProfile> batchProfiles) {
            if (request.configuration().profile()) {
                driverProfiles.addAll(batchProfiles);
            }
            if (lastBatchIndex < request.shardIds().size() && exchangeSink.isFinished() == false) {
                runBatch(lastBatchIndex);
            } else {
                blockingSink.finish();
                // reply only once the exchange sink has completed, i.e. after its pages have been consumed
                exchangeSink.addCompletionListener(
                    ContextPreservingActionListener.wrapPreservingContext(
                        ActionListener.runBefore(
                            listener.map(nullValue -> new DataNodeResponse(driverProfiles)),
                            () -> exchangeService.finishSinkHandler(request.sessionId(), null)
                        ),
                        transportService.getThreadPool().getThreadContext()
                    )
                );
            }
        }

        private void onFailure(Exception e) {
            exchangeService.finishSinkHandler(request.sessionId(), e);
            listener.onFailure(e);
        }
    }
}

