/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.action.search;

import java.util.Arrays;
import java.util.List;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.lucene.search.ScoreDoc;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.search.AbstractSearchAsyncAction;
import org.elasticsearch.action.search.ArraySearchPhaseResults;
import org.elasticsearch.action.search.CountedCollector;
import org.elasticsearch.action.search.FetchSearchPhase;
import org.elasticsearch.action.search.SearchActionListener;
import org.elasticsearch.action.search.SearchPhase;
import org.elasticsearch.action.search.SearchPhaseController;
import org.elasticsearch.action.search.SearchPhaseResults;
import org.elasticsearch.action.search.SearchProgressListener;
import org.elasticsearch.action.support.ThreadedActionListener;
import org.elasticsearch.client.internal.Client;
import org.elasticsearch.common.util.concurrent.AbstractRunnable;
import org.elasticsearch.search.SearchPhaseResult;
import org.elasticsearch.search.SearchShardTarget;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.search.dfs.AggregatedDfs;
import org.elasticsearch.search.internal.ShardSearchContextId;
import org.elasticsearch.search.rank.context.RankFeaturePhaseRankCoordinatorContext;
import org.elasticsearch.search.rank.feature.RankFeatureDoc;
import org.elasticsearch.search.rank.feature.RankFeatureResult;
import org.elasticsearch.search.rank.feature.RankFeatureShardRequest;
import org.elasticsearch.transport.Transport;

/**
 * Coordinator-side search phase named {@value #NAME} that runs between the query phase and {@link FetchSearchPhase}.
 * <p>
 * The phase proceeds in four steps:
 * <ol>
 *   <li>It reduces the query-phase results.</li>
 *   <li>It sends a {@link RankFeatureShardRequest} to every shard that owns some of the reduced top documents.</li>
 *   <li>It passes the collected {@link RankFeatureDoc}s to the {@link RankFeaturePhaseRankCoordinatorContext}, which computes
 *   global scores, then ranks and paginates the result.</li>
 *   <li>It hands the re-ranked top documents to {@link FetchSearchPhase}.</li>
 * </ol>
 */
public class RankFeaturePhase extends SearchPhase {

    static final String NAME = "rank-feature";

    private static final Logger logger = LogManager.getLogger(RankFeaturePhase.class);

    private final AbstractSearchAsyncAction<?> context;
    /** Results of the preceding query phase; reduced at the start of {@link #innerRun}. */
    final SearchPhaseResults<SearchPhaseResult> queryPhaseResults;
    /** Per-shard rank-feature results gathered by this phase; registered with the context via {@code addReleasable}. */
    final SearchPhaseResults<SearchPhaseResult> rankPhaseResults;
    private final AggregatedDfs aggregatedDfs;
    private final SearchProgressListener progressListener;
    private final RankFeaturePhaseRankCoordinatorContext rankFeaturePhaseRankCoordinatorContext;

    /**
     * Creates the phase.
     *
     * @param queryPhaseResults results of the query phase; must report the same shard count as {@code context}
     * @param aggregatedDfs DFS statistics, forwarded unchanged to the {@link FetchSearchPhase}
     * @param context the owning async search action
     * @param rankFeaturePhaseRankCoordinatorContext coordinator-side ranking logic; asserted to be non-null
     * @throws IllegalStateException if {@code context} and {@code queryPhaseResults} disagree on the number of shards
     */
    RankFeaturePhase(
        SearchPhaseResults<SearchPhaseResult> queryPhaseResults,
        AggregatedDfs aggregatedDfs,
        AbstractSearchAsyncAction<?> context,
        RankFeaturePhaseRankCoordinatorContext rankFeaturePhaseRankCoordinatorContext
    ) {
        super(NAME);
        assert rankFeaturePhaseRankCoordinatorContext != null;
        final int expectedShards = context.getNumShards();
        final int actualShards = queryPhaseResults.getNumShards();
        if (expectedShards != actualShards) {
            throw new IllegalStateException(
                "number of shards must match the length of the query results but doesn't:" + expectedShards + "!=" + actualShards
            );
        }
        this.rankFeaturePhaseRankCoordinatorContext = rankFeaturePhaseRankCoordinatorContext;
        this.context = context;
        this.queryPhaseResults = queryPhaseResults;
        this.aggregatedDfs = aggregatedDfs;
        this.rankPhaseResults = new ArraySearchPhaseResults<>(expectedShards);
        context.addReleasable(this.rankPhaseResults);
        this.progressListener = context.getTask().getProgressListener();
    }

    /**
     * Submits {@link #innerRun} through {@code context.execute}.
     * <p>
     * Any exception escaping it is reported to the context as a failure of this phase.
     */
    @Override
    protected void run() {
        context.execute(new AbstractRunnable() {
            @Override
            protected void doRun() throws Exception {
                RankFeaturePhase.this.innerRun(RankFeaturePhase.this.rankFeaturePhaseRankCoordinatorContext);
            }

            @Override
            public void onFailure(Exception e) {
                RankFeaturePhase.this.context.onPhaseFailure(RankFeaturePhase.NAME, "", e);
            }
        });
    }

    /**
     * Reduces the query-phase results and fans out one rank-feature request per shard that owns reduced top documents.
     * <p>
     * A shard with no documents to load is counted down immediately. If it returned a query result, its search context
     * is released and progress is reported first. Once every shard has been counted, {@link #onPhaseDone} runs.
     *
     * @param rankFeaturePhaseRankCoordinatorContext coordinator context used once all shards have been accounted for
     */
    void innerRun(RankFeaturePhaseRankCoordinatorContext rankFeaturePhaseRankCoordinatorContext) throws Exception {
        final SearchPhaseController.ReducedQueryPhase reducedQueryPhase = queryPhaseResults.reduce();
        final ScoreDoc[] queryScoreDocs = reducedQueryPhase.sortedTopDocs().scoreDocs();
        final int numShards = context.getNumShards();
        final List<Integer>[] docIdsByShard = SearchPhaseController.fillDocIdsToLoad(numShards, queryScoreDocs);
        final CountedCollector<SearchPhaseResult> rankRequestCounter = new CountedCollector<>(
            rankPhaseResults,
            numShards,
            () -> onPhaseDone(rankFeaturePhaseRankCoordinatorContext, reducedQueryPhase),
            context
        );
        for (int shard = 0; shard < docIdsByShard.length; shard++) {
            final List<Integer> docIds = docIdsByShard[shard];
            final SearchPhaseResult queryResult = queryPhaseResults.getAtomicArray().get(shard);
            if (docIds != null && docIds.isEmpty() == false) {
                executeRankFeatureShardPhase(queryResult, rankRequestCounter, docIds);
            } else {
                // no top document maps to this shard: release its search context (if any) and mark it as done
                if (queryResult != null) {
                    releaseIrrelevantSearchContext(queryResult, context);
                    progressListener.notifyRankFeatureResult(shard);
                }
                rankRequestCounter.countDown();
            }
        }
    }

    /**
     * Builds the rank-feature coordinator context from the request's rank builder.
     *
     * @return {@code null} when {@code source} is {@code null} or has no rank builder
     */
    static RankFeaturePhaseRankCoordinatorContext coordinatorContext(SearchSourceBuilder source, Client client) {
        if (source == null) {
            return null;
        }
        var rankBuilder = source.rankBuilder();
        if (rankBuilder == null) {
            return null;
        }
        return rankBuilder.buildRankFeaturePhaseCoordinatorContext(source.size(), source.from(), client);
    }

    /**
     * Sends the rank-feature request for {@code entry} (doc ids on this shard) to the node holding {@code queryResult}'s
     * shard.
     * <p>
     * If no connection can be obtained, the failure is routed through the shard listener.
     */
    private void executeRankFeatureShardPhase(
        final SearchPhaseResult queryResult,
        final CountedCollector<SearchPhaseResult> rankRequestCounter,
        List<Integer> entry
    ) {
        final SearchShardTarget shardTarget = queryResult.queryResult().getSearchShardTarget();
        final ShardSearchContextId contextId = queryResult.queryResult().getContextId();
        final int shardIndex = queryResult.getShardIndex();
        final SearchActionListener<RankFeatureResult> listener = newShardListener(
            queryResult,
            rankRequestCounter,
            shardTarget,
            contextId,
            shardIndex
        );

        final Transport.Connection connection;
        try {
            connection = context.getConnection(shardTarget.getClusterAlias(), shardTarget.getNodeId());
        } catch (Exception e) {
            listener.onFailure(e);
            return;
        }
        context.getSearchTransport()
            .sendExecuteRankFeature(
                connection,
                new RankFeatureShardRequest(
                    context.getOriginalIndices(queryResult.getShardIndex()),
                    queryResult.getContextId(),
                    queryResult.getShardSearchRequest(),
                    entry
                ),
                context.getTask(),
                listener
            );
    }

    /**
     * Creates the per-shard listener.
     * <p>
     * On a response, it reports progress and feeds the result to the counter. On a failure, it logs, reports progress
     * and the shard failure, and always releases the shard's search context.
     */
    private SearchActionListener<RankFeatureResult> newShardListener(
        final SearchPhaseResult queryResult,
        final CountedCollector<SearchPhaseResult> rankRequestCounter,
        final SearchShardTarget shardTarget,
        final ShardSearchContextId contextId,
        final int shardIndex
    ) {
        return new SearchActionListener<RankFeatureResult>(shardTarget, shardIndex) {
            @Override
            protected void innerOnResponse(RankFeatureResult response) {
                try {
                    RankFeaturePhase.this.progressListener.notifyRankFeatureResult(shardIndex);
                    rankRequestCounter.onResult(response);
                } catch (Exception e) {
                    RankFeaturePhase.this.context.onPhaseFailure(RankFeaturePhase.NAME, "", e);
                }
            }

            @Override
            public void onFailure(Exception e) {
                try {
                    logger.debug(() -> "[" + contextId + "] Failed to execute rank phase", e);
                    RankFeaturePhase.this.progressListener.notifyRankFeatureFailure(shardIndex, shardTarget, e);
                    rankRequestCounter.onFailure(shardIndex, shardTarget, e);
                } finally {
                    // release even if the failure bookkeeping above throws
                    SearchPhase.releaseIrrelevantSearchContext(queryResult, RankFeaturePhase.this.context);
                }
            }
        };
    }

    /**
     * Runs once every shard has responded or failed.
     * <p>
     * It gathers the successful shards' {@link RankFeatureDoc}s that carry feature data and asks {@code coordinator} to
     * compute their global scores. The outcome is delivered through {@code context.execute}. On success the re-scored
     * docs are ranked and paginated. On failure, if {@code coordinator} allows failures, the query-phase ordering is
     * reused instead; otherwise the phase fails.
     */
    private void onPhaseDone(
        final RankFeaturePhaseRankCoordinatorContext coordinator,
        final SearchPhaseController.ReducedQueryPhase reducedQueryPhase
    ) {
        final ThreadedActionListener<RankFeatureDoc[]> rankResultListener = new ThreadedActionListener<>(
            context::execute,
            new ActionListener<RankFeatureDoc[]>() {
                @Override
                public void onResponse(RankFeatureDoc[] docsWithUpdatedScores) {
                    ScoreDoc[] topResults = coordinator.rankAndPaginate(docsWithUpdatedScores, true);
                    RankFeaturePhase.this.proceedToFetch(reducedQueryPhase, topResults);
                }

                @Override
                public void onFailure(Exception e) {
                    if (coordinator.failuresAllowed() == false) {
                        RankFeaturePhase.this.context.onPhaseFailure(
                            RankFeaturePhase.NAME,
                            "Computing updated ranks for results failed",
                            e
                        );
                        return;
                    }
                    logger.warn("Exception computing updated ranks, continuing with existing ranks: {}", e.toString());
                    RankFeatureDoc[] queryPhaseDocs = RankFeaturePhase.fallbackRankDocs(reducedQueryPhase.sortedTopDocs().scoreDocs());
                    ScoreDoc[] topResults = coordinator.rankAndPaginate(queryPhaseDocs, false);
                    RankFeaturePhase.this.proceedToFetch(reducedQueryPhase, topResults);
                }
            }
        );
        final RankFeatureDoc[] featureDocs = rankPhaseResults.getSuccessfulResults()
            .flatMap(result -> Arrays.stream(result.rankFeatureResult().shardResult().rankFeatureDocs))
            .filter(doc -> doc.featureData != null)
            .toArray(RankFeatureDoc[]::new);
        coordinator.computeRankScoresForGlobalResults(featureDocs, rankResultListener);
    }

    /**
     * Converts query-phase score docs into {@link RankFeatureDoc}s for the fallback path.
     * <p>
     * A doc whose score is NaN gets {@code 1 / (position + 1)}, so earlier positions receive higher scores.
     */
    private static RankFeatureDoc[] fallbackRankDocs(ScoreDoc[] inputDocs) {
        final RankFeatureDoc[] rankDocs = new RankFeatureDoc[inputDocs.length];
        for (int position = 0; position < inputDocs.length; position++) {
            final ScoreDoc scoreDoc = inputDocs[position];
            final float score = Float.isNaN(scoreDoc.score) ? 1.0f / (position + 1) : scoreDoc.score;
            rankDocs[position] = new RankFeatureDoc(scoreDoc.doc, score, scoreDoc.shardIndex);
        }
        return rankDocs;
    }

    /** Wraps {@code topResults} into a new reduced query phase and moves on to the next phase. */
    private void proceedToFetch(SearchPhaseController.ReducedQueryPhase reducedQueryPhase, ScoreDoc[] topResults) {
        final SearchPhaseController.ReducedQueryPhase reducedRankFeaturePhase = newReducedQueryPhaseResults(reducedQueryPhase, topResults);
        moveToNextPhase(rankPhaseResults, reducedRankFeaturePhase);
    }

    /**
     * Copies {@code reducedQueryPhase}, with two replacements.
     * <p>
     * The max score becomes the maximum of {@code scoreDocs}, and the sorted top docs become a new
     * {@code SortedTopDocs} built from {@code scoreDocs}. Every other component is carried over unchanged.
     */
    private SearchPhaseController.ReducedQueryPhase newReducedQueryPhaseResults(
        SearchPhaseController.ReducedQueryPhase reducedQueryPhase,
        ScoreDoc[] scoreDocs
    ) {
        return new SearchPhaseController.ReducedQueryPhase(
            reducedQueryPhase.totalHits(),
            reducedQueryPhase.fetchHits(),
            maxScore(scoreDocs),
            reducedQueryPhase.timedOut(),
            reducedQueryPhase.terminatedEarly(),
            reducedQueryPhase.suggest(),
            reducedQueryPhase.aggregations(),
            reducedQueryPhase.profileBuilder(),
            new SearchPhaseController.SortedTopDocs(scoreDocs, false, null, null, null, 0),
            reducedQueryPhase.sortValueFormats(),
            reducedQueryPhase.queryPhaseRankCoordinatorContext(),
            reducedQueryPhase.numReducePhases(),
            reducedQueryPhase.size(),
            reducedQueryPhase.from(),
            reducedQueryPhase.isEmptyResult()
        );
    }

    /**
     * Returns the highest score in {@code scoreDocs}, or {@link Float#NaN} if the array is empty.
     * <p>
     * While the running maximum is NaN, the next score replaces it.
     */
    private static float maxScore(ScoreDoc[] scoreDocs) {
        float best = Float.NaN;
        for (ScoreDoc candidate : scoreDocs) {
            if (Float.isNaN(best) || candidate.score > best) {
                best = candidate.score;
            }
        }
        return best;
    }

    /**
     * Transitions to {@link FetchSearchPhase}.
     * <p>
     * The fetch phase receives {@code phaseResults}, this phase's DFS statistics and the given reduced query phase.
     */
    void moveToNextPhase(SearchPhaseResults<SearchPhaseResult> phaseResults, SearchPhaseController.ReducedQueryPhase reducedQueryPhase) {
        context.executeNextPhase(NAME, () -> new FetchSearchPhase(phaseResults, aggregatedDfs, context, reducedQueryPhase));
    }
}

