/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.esql.plugin;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.Executor;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Supplier;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.ActionListenerResponseHandler;
import org.elasticsearch.action.search.SearchRequest;
import org.elasticsearch.action.search.SearchShardsGroup;
import org.elasticsearch.action.search.SearchShardsRequest;
import org.elasticsearch.action.search.SearchShardsResponse;
import org.elasticsearch.action.support.ChannelActionListener;
import org.elasticsearch.action.support.ContextPreservingActionListener;
import org.elasticsearch.action.support.RefCountingListener;
import org.elasticsearch.action.support.RefCountingRunnable;
import org.elasticsearch.action.support.SubscribableListener;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.common.util.concurrent.CountDown;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.compute.data.BlockFactory;
import org.elasticsearch.compute.data.Page;
import org.elasticsearch.compute.operator.Driver;
import org.elasticsearch.compute.operator.DriverTaskRunner;
import org.elasticsearch.compute.operator.exchange.ExchangeResponse;
import org.elasticsearch.compute.operator.exchange.ExchangeService;
import org.elasticsearch.compute.operator.exchange.ExchangeSinkHandler;
import org.elasticsearch.compute.operator.exchange.ExchangeSourceHandler;
import org.elasticsearch.compute.operator.exchange.RemoteSink;
import org.elasticsearch.core.IOUtils;
import org.elasticsearch.core.Releasable;
import org.elasticsearch.core.Releasables;
import org.elasticsearch.core.Tuple;
import org.elasticsearch.index.Index;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.shard.IndexShard;
import org.elasticsearch.index.shard.ShardId;
import org.elasticsearch.index.shard.ShardNotFoundException;
import org.elasticsearch.logging.LogManager;
import org.elasticsearch.logging.Logger;
import org.elasticsearch.search.DefaultSearchContext;
import org.elasticsearch.search.SearchService;
import org.elasticsearch.search.internal.AliasFilter;
import org.elasticsearch.search.internal.SearchContext;
import org.elasticsearch.search.internal.ShardSearchRequest;
import org.elasticsearch.tasks.CancellableTask;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.tasks.TaskCancelledException;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.TransportChannel;
import org.elasticsearch.transport.TransportRequest;
import org.elasticsearch.transport.TransportRequestHandler;
import org.elasticsearch.transport.TransportRequestOptions;
import org.elasticsearch.transport.TransportResponse;
import org.elasticsearch.transport.TransportResponseHandler;
import org.elasticsearch.transport.TransportService;
import org.elasticsearch.xpack.esql.enrich.EnrichLookupService;
import org.elasticsearch.xpack.esql.plan.physical.OutputExec;
import org.elasticsearch.xpack.esql.plan.physical.PhysicalPlan;
import org.elasticsearch.xpack.esql.planner.EsPhysicalOperationProviders;
import org.elasticsearch.xpack.esql.planner.LocalExecutionPlanner;
import org.elasticsearch.xpack.esql.planner.PlannerUtils;
import org.elasticsearch.xpack.esql.plugin.DataNodeRequest;
import org.elasticsearch.xpack.esql.plugin.QueryPragmas;
import org.elasticsearch.xpack.esql.session.EsqlConfiguration;

public class ComputeService {
    /** Name of the transport action handled by {@code DataNodeRequestHandler}. */
    public static final String DATA_ACTION_NAME = "indices:data/read/esql/data";

    private static final Logger LOGGER = LogManager.getLogger(ComputeService.class);

    // Collaborators handed in through the constructor.
    private final SearchService searchService;
    private final TransportService transportService;
    private final ExchangeService exchangeService;
    private final EnrichLookupService enrichLookupService;
    private final BlockFactory blockFactory;

    // Derived in the constructor from the collaborators above.
    /** The circuit-breaking view of the {@link BigArrays} passed to the constructor. */
    private final BigArrays bigArrays;
    /** Executor of the {@code "esql"} thread pool; used for transport handlers and response handling. */
    private final Executor esqlExecutor;
    /** Runs the drivers produced by local planning. */
    private final DriverTaskRunner driverRunner;

    /**
     * Creates the service and registers the handler for {@link #DATA_ACTION_NAME} on the given transport service.
     *
     * @param searchService       used to resolve shards and create search contexts on data nodes
     * @param transportService    used to register the data-node handler and to send child requests
     * @param exchangeService     provides exchange source/sink handlers for a session
     * @param enrichLookupService passed through to the local execution planner
     * @param threadPool          source of the {@code "esql"} executor
     * @param bigArrays           wrapped with {@code withCircuitBreaking()} before use
     * @param blockFactory        passed through to the local execution planner
     */
    public ComputeService(SearchService searchService, TransportService transportService, ExchangeService exchangeService, EnrichLookupService enrichLookupService, ThreadPool threadPool, BigArrays bigArrays, BlockFactory blockFactory) {
        this.searchService = searchService;
        this.transportService = transportService;
        this.exchangeService = exchangeService;
        this.enrichLookupService = enrichLookupService;
        this.bigArrays = bigArrays.withCircuitBreaking();
        this.blockFactory = blockFactory;
        this.esqlExecutor = threadPool.executor("esql");
        this.driverRunner = new DriverTaskRunner(transportService, this.esqlExecutor);
        // Register the handler only after every field is assigned: DataNodeRequestHandler is an inner class
        // that reads exchangeService (and, through runCompute, driverRunner and enrichLookupService) from this
        // instance, so publishing it earlier would expose a partially constructed ComputeService.
        transportService.registerRequestHandler(DATA_ACTION_NAME, this.esqlExecutor, DataNodeRequest::new, new DataNodeRequestHandler());
    }

    /**
     * Executes a physical plan for the given session and delivers the collected result pages.
     * <p>
     * The plan is split into a coordinator part and a data-node part. Pages emitted by the coordinator part are
     * gathered in a synchronized list; if the computation fails, the blocks of every page gathered so far are
     * released before the failure is forwarded. When the plan references no concrete indices, only the
     * coordinator part is run. Otherwise the target nodes are resolved first, then the coordinator part is run
     * locally on top of a new exchange source while the data-node part is sent to each target node. A failure of
     * the local computation or of any remote one cancels {@code rootTask} and its descendants (at most once).
     *
     * @param sessionId     identifier shared by all exchanges and requests of this computation
     * @param rootTask      task under which all child requests are sent
     * @param physicalPlan  the full plan to execute
     * @param configuration session configuration, including the query pragmas
     * @param listener      receives the collected pages, or the failure
     */
    public void execute(String sessionId, CancellableTask rootTask, PhysicalPlan physicalPlan, EsqlConfiguration configuration, ActionListener<List<Page>> listener) {
        final Tuple<PhysicalPlan, PhysicalPlan> planParts = PlannerUtils.breakPlanBetweenCoordinatorAndDataNode(physicalPlan, configuration);
        final List<Page> collectedPages = Collections.synchronizedList(new ArrayList<>());
        // On failure nobody will consume the pages, so release their blocks before reporting the error.
        final ActionListener<List<Page>> releasingListener = listener.delegateResponse((delegate, failure) -> {
            collectedPages.forEach(page -> Releasables.closeExpectNoException(page::releaseBlocks));
            delegate.onFailure(failure);
        });
        final PhysicalPlan coordinatorPlan = new OutputExec(planParts.v1(), collectedPages::add);
        final PhysicalPlan dataNodePlan = planParts.v2();
        final Set<String> concreteIndices = PlannerUtils.planConcreteIndices(physicalPlan);
        final QueryPragmas pragmas = configuration.pragmas();

        if (concreteIndices.isEmpty()) {
            // Nothing to read from any shard: run the coordinator plan alone, without any exchange.
            final ComputeContext localOnlyContext = new ComputeContext(sessionId, List.of(), configuration, null, null);
            runCompute(rootTask, localOnlyContext, coordinatorPlan, releasingListener.map(unused -> collectedPages));
            return;
        }

        final QueryBuilder requestFilter = PlannerUtils.requestFilter(dataNodePlan);
        LOGGER.debug("Sending data node plan\n{}\n with filter [{}]", dataNodePlan, requestFilter);
        final String[] originalIndices = PlannerUtils.planOriginalIndices(physicalPlan);
        computeTargetNodes(rootTask, requestFilter, concreteIndices, originalIndices, releasingListener.delegateFailureAndWrap((delegate, targetNodes) -> {
            final ExchangeSourceHandler exchangeSource = exchangeService.createSourceHandler(sessionId, pragmas.exchangeBufferSize(), "esql");
            try (
                Releasable ignored = exchangeSource::decRef;
                RefCountingListener requestRefs = new RefCountingListener(delegate.map(unused -> collectedPages))
            ) {
                final AtomicBoolean cancelled = new AtomicBoolean();
                // One reference for the completion of the exchange source itself.
                exchangeSource.addCompletionListener(requestRefs.acquire());
                // One reference for the coordinator-side computation.
                final ComputeContext coordinatorContext = new ComputeContext(sessionId, List.of(), configuration, exchangeSource, null);
                runCompute(rootTask, coordinatorContext, coordinatorPlan, cancelOnFailure(rootTask, cancelled, requestRefs.acquire()));
                // One reference per listener requested for a target node.
                runComputeOnRemoteNodes(
                    sessionId,
                    rootTask,
                    configuration,
                    dataNodePlan,
                    exchangeSource,
                    targetNodes,
                    () -> cancelOnFailure(rootTask, cancelled, requestRefs.acquire()).map(unused -> null)
                );
            }
        }));
    }

    /**
     * Starts the data-node plan on every target node and links each node's exchange sink to {@code exchangeSource}.
     * <p>
     * A placeholder sink is added to the exchange source first. Its page request is answered (with an empty,
     * finished response) only once {@code allNodesDone} completes, which happens when every reference taken from
     * {@code exchangeRefs} has been released: the one held by the try-with-resources block and one per target
     * node, released after that node's listener has been completed.
     *
     * @param sessionId      identifier of the exchange and of the data-node requests
     * @param rootTask       parent task of the child requests
     * @param configuration  session configuration forwarded to the data nodes
     * @param dataNodePlan   plan fragment executed on each data node
     * @param exchangeSource source the remote sinks are attached to
     * @param targetNodes    nodes, with their shards and alias filters, to run the plan on
     * @param listener       supplies one listener per target node
     */
    private void runComputeOnRemoteNodes(String sessionId, CancellableTask rootTask, EsqlConfiguration configuration, PhysicalPlan dataNodePlan, ExchangeSourceHandler exchangeSource, List<TargetNode> targetNodes, Supplier<ActionListener<DataNodeResponse>> listener) {
        final SubscribableListener<Void> allNodesDone = new SubscribableListener<>();
        final RemoteSink placeholderSink = (sourceFinished, pageListener) -> allNodesDone.addListener(
            pageListener.map(ignored -> new ExchangeResponse(null, true))
        );
        exchangeSource.addRemoteSink(placeholderSink, 1);
        try (RefCountingRunnable exchangeRefs = new RefCountingRunnable(() -> allNodesDone.onResponse(null))) {
            for (TargetNode targetNode : targetNodes) {
                final ActionListener<DataNodeResponse> nodeListener = ActionListener.releaseAfter(listener.get(), exchangeRefs.acquire());
                startComputeOnNode(sessionId, rootTask, configuration, dataNodePlan, exchangeSource, targetNode, nodeListener);
            }
        }
    }

    /**
     * Opens the exchange on one target node; once that succeeds, attaches a remote sink for the node to
     * {@code exchangeSource} and sends the {@link DataNodeRequest}. A failure to open the exchange goes straight
     * to {@code nodeListener}.
     */
    private void startComputeOnNode(String sessionId, CancellableTask rootTask, EsqlConfiguration configuration, PhysicalPlan dataNodePlan, ExchangeSourceHandler exchangeSource, TargetNode targetNode, ActionListener<DataNodeResponse> nodeListener) {
        final QueryPragmas pragmas = configuration.pragmas();
        final DiscoveryNode node = targetNode.node();
        ExchangeService.openExchange(transportService, node, sessionId, pragmas.exchangeBufferSize(), esqlExecutor, nodeListener.delegateFailureAndWrap((delegate, unused) -> {
            final RemoteSink remoteSink = exchangeService.newRemoteSink(rootTask, sessionId, transportService, targetNode.node());
            exchangeSource.addRemoteSink(remoteSink, pragmas.concurrentExchangeClients());
            final DataNodeRequest request = new DataNodeRequest(sessionId, configuration, targetNode.shardIds(), targetNode.aliasFilters(), dataNodePlan);
            transportService.sendChildRequest(
                targetNode.node(),
                DATA_ACTION_NAME,
                request,
                rootTask,
                TransportRequestOptions.EMPTY,
                new ActionListenerResponseHandler<>(delegate, DataNodeResponse::new, esqlExecutor)
            );
        }));
    }

    /**
     * Wraps {@code listener} so that a failure is first forwarded to it and then, the first time any listener
     * sharing {@code cancelled} fails, {@code task} and its descendants are cancelled.
     *
     * @param task      the task to cancel on failure
     * @param cancelled flag shared between wrapped listeners so that cancellation is requested only once
     * @param listener  the listener to wrap; successful responses pass through untouched
     * @return the wrapping listener
     */
    private ActionListener<Void> cancelOnFailure(CancellableTask task, AtomicBoolean cancelled, ActionListener<Void> listener) {
        return listener.delegateResponse((delegate, failure) -> {
            delegate.onFailure(failure);
            if (cancelled.compareAndSet(false, true) == false) {
                // Another failure already triggered the cancellation.
                return;
            }
            LOGGER.debug("cancelling ESQL task {} on failure", task);
            transportService.getTaskManager().cancelTaskAndDescendants(task, "cancelled", false, ActionListener.noop());
        });
    }

    /**
     * Plans {@code plan} for local execution, creates its drivers and runs them on the {@code "esql_worker"}
     * executor.
     * <p>
     * The search contexts of {@code context} are closed via {@code ActionListener.runAfter} on every outcome,
     * including a planning failure. The drivers are closed via {@code ActionListener.releaseAfter} once they
     * have been handed to the driver runner.
     *
     * @param task     task the drivers run under
     * @param context  session id, search contexts, configuration and exchange endpoints for this computation
     * @param plan     the plan fragment to execute on this node
     * @param listener notified when the drivers finish, or with the planning/execution failure
     */
    void runCompute(CancellableTask task, ComputeContext context, PhysicalPlan plan, ActionListener<Void> listener) {
        final ActionListener<Void> contextClosingListener = ActionListener.runAfter(listener, () -> Releasables.close(context.searchContexts()));
        final List<Driver> drivers;
        try {
            final LocalExecutionPlanner planner = new LocalExecutionPlanner(
                context.sessionId(),
                task,
                bigArrays,
                blockFactory,
                context.configuration(),
                context.exchangeSource(),
                context.exchangeSink(),
                enrichLookupService,
                new EsPhysicalOperationProviders(context.searchContexts())
            );
            LOGGER.debug("Received physical plan:\n{}", plan);
            final PhysicalPlan localPlan = PlannerUtils.localPlan(context.searchContexts(), context.configuration(), plan);
            final LocalExecutionPlanner.LocalExecutionPlan executionPlan = planner.plan(localPlan);
            if (LOGGER.isDebugEnabled()) {
                LOGGER.debug("Local execution plan:\n{}", executionPlan.describe());
            }
            drivers = executionPlan.createDrivers(context.sessionId());
            if (drivers.isEmpty()) {
                throw new IllegalStateException("no drivers created");
            }
            LOGGER.debug("using {} drivers", drivers.size());
        } catch (Exception e) {
            // Planning failed before anything was started; report it (this also closes the search contexts).
            contextClosingListener.onFailure(e);
            return;
        }
        final Executor workerExecutor = transportService.getThreadPool().executor("esql_worker");
        final ActionListener<Void> driverClosingListener = ActionListener.releaseAfter(contextClosingListener, () -> Releasables.close(drivers));
        driverRunner.executeDrivers(task, drivers, workerExecutor, driverClosingListener);
    }

    /**
     * Resolves the given shards on this node, calls {@code ensureShardSearchActive} on each of them and, when
     * the last of those callbacks fires, creates and pre-processes one search context per shard.
     * <p>
     * An empty shard list completes the listener immediately with an empty list. Any exception thrown while
     * resolving shards or registering the callbacks is reported through {@code listener.onFailure}.
     *
     * @param shardIds     shards to open search contexts for
     * @param aliasFilters alias filter per index; indices without an entry use {@link AliasFilter#EMPTY}
     * @param listener     receives the search contexts, in shard order, or the failure
     */
    private void acquireSearchContexts(List<ShardId> shardIds, Map<Index, AliasFilter> aliasFilters, ActionListener<List<SearchContext>> listener) {
        try {
            final List<IndexShard> targetShards = new ArrayList<>();
            for (ShardId shardId : shardIds) {
                targetShards.add(searchService.getIndicesService().indexServiceSafe(shardId.getIndex()).getShard(shardId.id()));
            }
            if (targetShards.isEmpty()) {
                listener.onResponse(List.of());
                return;
            }
            final CountDown pendingShards = new CountDown(targetShards.size());
            for (IndexShard targetShard : targetShards) {
                targetShard.ensureShardSearchActive(ignored -> {
                    if (pendingShards.countDown() == false) {
                        // Still waiting for other shards.
                        return;
                    }
                    ActionListener.completeWith(listener, () -> openSearchContexts(targetShards, aliasFilters));
                });
            }
        } catch (Exception e) {
            listener.onFailure(e);
        }
    }

    /**
     * Creates a search context for every shard and then pre-processes all of them. If anything fails along the
     * way, every context created so far is closed before the exception propagates.
     */
    private List<SearchContext> openSearchContexts(List<IndexShard> targetShards, Map<Index, AliasFilter> aliasFilters) throws IOException {
        final List<SearchContext> opened = new ArrayList<>(targetShards.size());
        boolean success = false;
        try {
            for (IndexShard shard : targetShards) {
                final AliasFilter aliasFilter = aliasFilters.getOrDefault(shard.shardId().getIndex(), AliasFilter.EMPTY);
                final ShardSearchRequest shardRequest = new ShardSearchRequest(shard.shardId(), 0L, aliasFilter);
                opened.add(searchService.createSearchContext(shardRequest, SearchService.NO_TIMEOUT));
            }
            for (SearchContext searchContext : opened) {
                searchContext.preProcess();
            }
            success = true;
            return opened;
        } finally {
            if (success == false) {
                IOUtils.close(opened);
            }
        }
    }

    /**
     * Resolves which nodes hold the shards to query, by sending a search-shards request to the local node.
     * <p>
     * Fails immediately with an {@link IllegalArgumentException} if {@code originalIndices} contains any remote
     * index expression. The request itself is sent under a system thread context, while the listener is wrapped
     * beforehand so that it is completed with the caller's context restored.
     *
     * @param parentTask      parent task of the search-shards request
     * @param filter          query used by the search-shards request
     * @param concreteIndices names of the indices the plan reads; shards of other indices are ignored
     * @param originalIndices index expressions as written in the query
     * @param listener        receives one {@link TargetNode} per node that holds at least one relevant shard
     */
    private void computeTargetNodes(Task parentTask, QueryBuilder filter, Set<String> concreteIndices, String[] originalIndices, ActionListener<List<TargetNode>> listener) {
        final Map<String, ?> remoteIndices = transportService.getRemoteClusterService().groupIndices(SearchRequest.DEFAULT_INDICES_OPTIONS, originalIndices);
        // The empty key is the group of local indices; whatever remains refers to remote clusters.
        remoteIndices.remove("");
        if (remoteIndices.isEmpty() == false) {
            listener.onFailure(new IllegalArgumentException("ES|QL does not yet support querying remote indices " + Arrays.toString(originalIndices)));
            return;
        }
        final ThreadContext threadContext = transportService.getThreadPool().getThreadContext();
        // Wrap before switching to the system context so the caller's context is the one that gets preserved.
        final ActionListener<SearchShardsResponse> shardsListener = listener.map(resp -> toTargetNodes(resp, concreteIndices));
        final ActionListener<SearchShardsResponse> preservingContextListener = ContextPreservingActionListener.wrapPreservingContext(shardsListener, threadContext);
        try (ThreadContext.StoredContext ignored = threadContext.newStoredContextPreservingResponseHeaders()) {
            threadContext.markAsSystemContext();
            final SearchShardsRequest searchShardsRequest = new SearchShardsRequest(originalIndices, SearchRequest.DEFAULT_INDICES_OPTIONS, filter, null, null, false, null);
            transportService.sendChildRequest(
                transportService.getLocalNode(),
                "indices:admin/search/search_shards",
                searchShardsRequest,
                parentTask,
                TransportRequestOptions.EMPTY,
                new ActionListenerResponseHandler<>(preservingContextListener, SearchShardsResponse::new, esqlExecutor)
            );
        }
    }

    /**
     * Groups the non-skipped shards of {@code resp} that belong to {@code concreteIndices} by the first node in
     * each group's allocated-nodes list, together with the alias filters of their indices.
     *
     * @throws ShardNotFoundException if a non-skipped shard group has no allocated node
     */
    private static List<TargetNode> toTargetNodes(SearchShardsResponse resp, Set<String> concreteIndices) {
        final Map<String, DiscoveryNode> nodesById = new HashMap<>();
        for (DiscoveryNode node : resp.getNodes()) {
            nodesById.put(node.getId(), node);
        }
        final Map<String, List<ShardId>> shardsByNode = new HashMap<>();
        final Map<String, Map<Index, AliasFilter>> aliasFiltersByNode = new HashMap<>();
        for (SearchShardsGroup group : resp.getGroups()) {
            final ShardId shardId = group.shardId();
            if (group.skipped()) {
                continue;
            }
            if (group.allocatedNodes().isEmpty()) {
                throw new ShardNotFoundException(group.shardId(), "no shard copies found {}", group.shardId());
            }
            if (concreteIndices.contains(shardId.getIndexName()) == false) {
                continue;
            }
            final String nodeId = group.allocatedNodes().get(0);
            shardsByNode.computeIfAbsent(nodeId, k -> new ArrayList<>()).add(shardId);
            final AliasFilter aliasFilter = resp.getAliasFilters().get(shardId.getIndex().getUUID());
            if (aliasFilter != null) {
                aliasFiltersByNode.computeIfAbsent(nodeId, k -> new HashMap<>()).put(shardId.getIndex(), aliasFilter);
            }
        }
        final List<TargetNode> targetNodes = new ArrayList<>(shardsByNode.size());
        for (Map.Entry<String, List<ShardId>> entry : shardsByNode.entrySet()) {
            final DiscoveryNode node = nodesById.get(entry.getKey());
            final Map<Index, AliasFilter> aliasFilters = aliasFiltersByNode.getOrDefault(entry.getKey(), Map.of());
            targetNodes.add(new TargetNode(node, entry.getValue(), aliasFilters));
        }
        return targetNodes;
    }

    private class DataNodeRequestHandler
    implements TransportRequestHandler<DataNodeRequest> {
        private DataNodeRequestHandler() {
        }

        public void messageReceived(DataNodeRequest request, TransportChannel channel, Task task) {
            CancellableTask parentTask = (CancellableTask)task;
            String sessionId = request.sessionId();
            ExchangeSinkHandler exchangeSink = ComputeService.this.exchangeService.getSinkHandler(sessionId);
            parentTask.addListener(() -> ComputeService.this.exchangeService.finishSinkHandler(sessionId, (Exception)new TaskCancelledException("task cancelled")));
            ActionListener listener = new ChannelActionListener(channel).map(nullValue -> new DataNodeResponse());
            ComputeService.this.acquireSearchContexts(request.shardIds(), request.aliasFilters(), (ActionListener<List<SearchContext>>)ActionListener.wrap(searchContexts -> {
                ComputeContext computeContext = new ComputeContext(sessionId, (List<SearchContext>)searchContexts, request.configuration(), null, exchangeSink);
                ComputeService.this.runCompute(parentTask, computeContext, request.plan(), (ActionListener<Void>)ActionListener.wrap(unused -> exchangeSink.addCompletionListener(ActionListener.releaseAfter((ActionListener)listener, () -> ComputeService.this.exchangeService.finishSinkHandler(sessionId, null))), e -> {
                    ComputeService.this.exchangeService.finishSinkHandler(sessionId, e);
                    listener.onFailure(e);
                }));
            }, e -> {
                ComputeService.this.exchangeService.finishSinkHandler(sessionId, e);
                listener.onFailure(e);
            }));
        }
    }

    /**
     * Inputs of one local computation, as consumed by {@code runCompute}.
     *
     * @param sessionId      identifier of the session the computation belongs to
     * @param searchContexts search contexts the plan reads from; closed by {@code runCompute} when it completes
     * @param configuration  session configuration
     * @param exchangeSource exchange source handed to the local planner; {@code null} in some call sites
     * @param exchangeSink   exchange sink handed to the local planner; {@code null} in some call sites
     */
    record ComputeContext(
        String sessionId,
        List<SearchContext> searchContexts,
        EsqlConfiguration configuration,
        ExchangeSourceHandler exchangeSource,
        ExchangeSinkHandler exchangeSink
    ) {}

    /**
     * A node selected to run the data-node plan.
     *
     * @param node         the node the request is sent to
     * @param shardIds     shards on that node to run the plan against
     * @param aliasFilters alias filter per index, for the indices that have one
     */
    record TargetNode(
        DiscoveryNode node,
        List<ShardId> shardIds,
        Map<Index, AliasFilter> aliasFilters
    ) {}

    private static class DataNodeResponse
    extends TransportResponse {
        DataNodeResponse() {
        }

        DataNodeResponse(StreamInput in) throws IOException {
            super(in);
        }

        public void writeTo(StreamOutput out) {
        }
    }
}

