/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.action.search;

import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import org.apache.lucene.search.TotalHits;
import org.elasticsearch.action.search.SearchProgressListener;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.action.search.SearchShard;
import org.elasticsearch.action.search.ShardSearchFailure;
import org.elasticsearch.action.search.TransportSearchAction;
import org.elasticsearch.common.util.CollectionUtils;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.search.SearchShardTarget;
import org.elasticsearch.search.aggregations.InternalAggregations;
import org.elasticsearch.search.query.QuerySearchResult;

/**
 * Search progress listener that maintains per-cluster progress for a cross-cluster search running with
 * {@code minimize_roundtrips=false} (asserted in {@link #onListShards}).
 *
 * <p>Each callback translates shard-level progress (listed/skipped shards, query timeouts, query failures,
 * partial and final reduces) into an updated {@link SearchResponse.Cluster} entry, swapped into the
 * {@link SearchResponse.Clusters} instance received in {@link #onListShards}. The entry holds:
 * <ul>
 * <li>shard counts;</li>
 * <li>failures;</li>
 * <li>the timed-out flag;</li>
 * <li>the status;</li>
 * <li>the "took" time.</li>
 * </ul>
 *
 * <p>Shards whose cluster alias is {@code null} are attributed to the cluster key {@code ""}.
 */
public class CCSSingleCoordinatorSearchProgressListener extends SearchProgressListener {
    /** Per-cluster state updated by every callback; assigned in {@link #onListShards}. */
    private SearchResponse.Clusters clusters;
    /** Source of the "took" value stamped on a cluster when it reaches a non-RUNNING status; assigned in {@link #onListShards}. */
    private TransportSearchAction.SearchTimeProvider timeProvider;

    /**
     * Initializes every cluster that owns at least one listed or skipped shard.
     *
     * <p>For each such cluster this records the total shard count. It seeds the successful and skipped
     * counts with the number of skipped shards, resets failed shards to 0 and clears the timed-out flag.
     * A cluster whose shards were all skipped is immediately marked
     * {@link SearchResponse.Cluster.Status#SUCCESSFUL} with its "took" time.
     *
     * <p>Also captures {@code clusters} and {@code timeProvider} for the later callbacks.
     */
    @Override
    public void onListShards(List<SearchShard> shards, List<SearchShard> skipped, SearchResponse.Clusters clusters, boolean fetchPhase, TransportSearchAction.SearchTimeProvider timeProvider) {
        assert !clusters.isCcsMinimizeRoundtrips() : "minimize_roundtrips must be false to use this SearchListener";
        this.clusters = clusters;
        this.timeProvider = timeProvider;
        Map<String, Integer> skippedByClusterAlias = partitionCountsByClusterAlias(skipped);
        // Totals are searched + skipped counts; this assumes 'shards' and 'skipped' are disjoint lists.
        Map<String, Integer> totalByClusterAlias = partitionCountsByClusterAlias(shards);
        skippedByClusterAlias.forEach((cluster, count) -> totalByClusterAlias.merge(cluster, count, Integer::sum));
        for (Map.Entry<String, Integer> entry : totalByClusterAlias.entrySet()) {
            int totalCount = entry.getValue();
            clusters.swapCluster(entry.getKey(), (k, v) -> {
                assert v.getTotalShards() == null : "total shards should not be set on a Cluster before onListShards";
                int skippedCount = skippedByClusterAlias.getOrDefault(k, 0);
                SearchResponse.Cluster.Status status = v.getStatus();
                assert status == SearchResponse.Cluster.Status.RUNNING : "should have RUNNING status during onListShards but has " + status;
                TimeValue took = null;
                if (skippedCount == totalCount) {
                    // Nothing will be searched on this cluster, so it is already done.
                    took = new TimeValue(timeProvider.buildTookInMillis());
                    status = SearchResponse.Cluster.Status.SUCCESSFUL;
                }
                return new SearchResponse.Cluster.Builder(v)
                    .setStatus(status)
                    .setTotalShards(totalCount)
                    .setSuccessfulShards(skippedCount) // skipped shards are counted as successful
                    .setSkippedShards(skippedCount)
                    .setFailedShards(0)
                    .setTook(took)
                    .setTimedOut(false)
                    .build();
            });
        }
    }

    /**
     * Flags the shard's cluster as timed out when the query result reports a search timeout.
     *
     * <p>Clusters that are already flagged, or whose status is FAILED or SKIPPED, are left unchanged.
     * Does nothing when {@code clusters} holds no Cluster objects.
     */
    @Override
    public void onQueryResult(int shardIndex, QuerySearchResult queryResult) {
        if (!queryResult.searchTimedOut() || !clusters.hasClusterObjects()) {
            return;
        }
        String clusterAlias = clusterAliasOf(queryResult.getSearchShardTarget());
        clusters.swapCluster(clusterAlias, (k, v) -> {
            if (v.isTimedOut()) {
                return v;
            }
            SearchResponse.Cluster.Status status = v.getStatus();
            if (status == SearchResponse.Cluster.Status.FAILED || status == SearchResponse.Cluster.Status.SKIPPED) {
                return v;
            }
            return new SearchResponse.Cluster.Builder(v).setTimedOut(true).build();
        });
    }

    /**
     * Records one more failed shard, and its {@link ShardSearchFailure}, on the shard's cluster, then
     * recomputes the cluster status:
     * <ul>
     * <li>all shards failed: SKIPPED if {@code isSkipUnavailable()} is true, otherwise FAILED (no "took");</li>
     * <li>failed plus recorded successful shards cover all shards: PARTIAL, with a "took" time;</li>
     * <li>otherwise: RUNNING.</li>
     * </ul>
     * Does nothing when {@code clusters} holds no Cluster objects.
     */
    @Override
    public void onQueryFailure(int shardIndex, SearchShardTarget shardTarget, Exception e) {
        if (!clusters.hasClusterObjects()) {
            return;
        }
        clusters.swapCluster(clusterAliasOf(shardTarget), (k, v) -> {
            int numFailedShards = v.getFailedShards() == null ? 1 : v.getFailedShards() + 1;
            assert v.getTotalShards() != null : "total shards should be set on the Cluster but not for " + k;
            // Unbox once so every comparison below is numeric, never an Integer reference comparison.
            int totalShards = v.getTotalShards();
            SearchResponse.Cluster.Status status;
            TimeValue took;
            if (totalShards == numFailedShards) {
                took = null;
                status = v.isSkipUnavailable() ? SearchResponse.Cluster.Status.SKIPPED : SearchResponse.Cluster.Status.FAILED;
            } else if (totalShards == numFailedShards + v.getSuccessfulShards()) {
                status = SearchResponse.Cluster.Status.PARTIAL;
                took = new TimeValue(timeProvider.buildTookInMillis());
            } else {
                took = null;
                status = SearchResponse.Cluster.Status.RUNNING;
            }
            return new SearchResponse.Cluster.Builder(v)
                .setStatus(status)
                .setFailedShards(numFailedShards)
                .setFailures(CollectionUtils.appendToCopy(v.getFailures(), new ShardSearchFailure(e, shardTarget)))
                .setTook(took)
                .build();
        });
    }

    /**
     * Updates the successful-shard count of each still-RUNNING cluster represented in {@code shards}.
     * The count is the number of listed shards for that cluster plus its skipped shards.
     *
     * <p>The cluster status then changes as follows:
     * <ul>
     * <li>all shards successful: SUCCESSFUL, or PARTIAL if the cluster timed out;</li>
     * <li>successful plus failed shards cover all shards: PARTIAL.</li>
     * </ul>
     * Either transition stamps a "took" time. Clusters that are not RUNNING are left unchanged.
     *
     * <p>Like the other shard callbacks, this does nothing when {@code clusters} holds no Cluster objects.
     * Without that check, {@code swapCluster} would run the remapping function for aliases that have no
     * Cluster entry.
     */
    @Override
    public void onPartialReduce(List<SearchShard> shards, TotalHits totalHits, InternalAggregations aggs, int reducePhase) {
        if (!clusters.hasClusterObjects()) {
            return;
        }
        Map<String, Integer> successfulByClusterAlias = partitionCountsByClusterAlias(shards);
        for (Map.Entry<String, Integer> entry : successfulByClusterAlias.entrySet()) {
            int successfulCount = entry.getValue();
            clusters.swapCluster(entry.getKey(), (k, v) -> {
                SearchResponse.Cluster.Status status = v.getStatus();
                if (status != SearchResponse.Cluster.Status.RUNNING) {
                    return v;
                }
                int successfulShards = successfulCount + v.getSkippedShards();
                int totalShards = v.getTotalShards();
                TimeValue took = null;
                if (successfulShards == totalShards) {
                    status = v.isTimedOut() ? SearchResponse.Cluster.Status.PARTIAL : SearchResponse.Cluster.Status.SUCCESSFUL;
                    took = new TimeValue(timeProvider.buildTookInMillis());
                } else if (successfulShards + v.getFailedShards() == totalShards) {
                    status = SearchResponse.Cluster.Status.PARTIAL;
                    took = new TimeValue(timeProvider.buildTookInMillis());
                }
                return new SearchResponse.Cluster.Builder(v).setStatus(status).setSuccessfulShards(successfulShards).setTook(took).build();
            });
        }
    }

    /**
     * Finalizes every still-RUNNING cluster represented in {@code shards}. It records the successful-shard
     * count (listed shards plus skipped shards) and a "took" time.
     *
     * <p>The final status is PARTIAL when the cluster timed out or some shards did not succeed, and
     * SUCCESSFUL otherwise. Clusters that are not RUNNING are left unchanged. Does nothing when
     * {@code clusters} holds no Cluster objects.
     */
    @Override
    public void onFinalReduce(List<SearchShard> shards, TotalHits totalHits, InternalAggregations aggs, int reducePhase) {
        if (!clusters.hasClusterObjects()) {
            return;
        }
        Map<String, Integer> successfulByClusterAlias = partitionCountsByClusterAlias(shards);
        for (Map.Entry<String, Integer> entry : successfulByClusterAlias.entrySet()) {
            int successfulCount = entry.getValue();
            clusters.swapCluster(entry.getKey(), (k, v) -> {
                SearchResponse.Cluster.Status status = v.getStatus();
                if (status != SearchResponse.Cluster.Status.RUNNING) {
                    return v;
                }
                TimeValue took = new TimeValue(timeProvider.buildTookInMillis());
                int successfulShards = successfulCount + v.getSkippedShards();
                int totalShards = v.getTotalShards();
                assert successfulShards + v.getFailedShards() == totalShards : "successfulShards(" + successfulShards + ") + failedShards(" + v.getFailedShards() + ") != totalShards (" + totalShards + ")";
                if (v.isTimedOut() || successfulShards < totalShards) {
                    status = SearchResponse.Cluster.Status.PARTIAL;
                } else {
                    assert successfulShards == totalShards : "successful (" + successfulShards + ") should equal total(" + totalShards + ") if get here";
                    status = SearchResponse.Cluster.Status.SUCCESSFUL;
                }
                return new SearchResponse.Cluster.Builder(v).setStatus(status).setSuccessfulShards(successfulShards).setTook(took).build();
            });
        }
    }

    /** Intentionally a no-op: this listener does not update cluster state on fetch-phase results. */
    @Override
    public void onFetchResult(int shardIndex) {
    }

    /** Intentionally a no-op: this listener does not update cluster state on fetch-phase failures. */
    @Override
    public void onFetchFailure(int shardIndex, SearchShardTarget shardTarget, Exception exc) {
    }

    /**
     * Counts shards per cluster alias, mapping a {@code null} alias to {@code ""}.
     *
     * @param shards the shards to count
     * @return a new mutable map from cluster alias to shard count; callers rely on it being mutable
     */
    private Map<String, Integer> partitionCountsByClusterAlias(List<SearchShard> shards) {
        Map<String, Integer> res = new HashMap<>();
        for (SearchShard shard : shards) {
            res.merge(Objects.requireNonNullElse(shard.clusterAlias(), ""), 1, Integer::sum);
        }
        return res;
    }

    /** Returns the cluster alias of {@code shardTarget}, mapping a {@code null} alias to {@code ""}. */
    private static String clusterAliasOf(SearchShardTarget shardTarget) {
        return Objects.requireNonNullElse(shardTarget.getClusterAlias(), "");
    }
}

