/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.search.retriever;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
import java.util.Objects;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.util.SetOnce;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.ExceptionsHelper;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.ActionRequestValidationException;
import org.elasticsearch.action.ValidateActions;
import org.elasticsearch.action.search.MultiSearchRequest;
import org.elasticsearch.action.search.MultiSearchResponse;
import org.elasticsearch.action.search.SearchRequest;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.action.search.TransportMultiSearchAction;
import org.elasticsearch.features.NodeFeature;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryRewriteContext;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.builder.PointInTimeBuilder;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.search.fetch.StoredFieldsContext;
import org.elasticsearch.search.rank.RankDoc;
import org.elasticsearch.search.retriever.RankDocsRetrieverBuilder;
import org.elasticsearch.search.retriever.RetrieverBuilder;
import org.elasticsearch.search.sort.FieldSortBuilder;
import org.elasticsearch.search.sort.ScoreSortBuilder;
import org.elasticsearch.search.sort.ShardDocSortField;
import org.elasticsearch.search.sort.SortBuilder;
import org.elasticsearch.xcontent.ParseField;

/**
 * Base class for retrievers that combine the results of several inner retrievers into one ranking.
 *
 * <p>During {@link #rewrite(QueryRewriteContext)} every inner retriever is rewritten and paired with its
 * own {@link SearchSourceBuilder}. Once a rewrite round produces no further changes, all inner searches
 * are sent as a single multi-search request, the subclass merges the per-retriever results via
 * {@link #combineInnerRetrieverResults(List, boolean)}, and the rewrite returns a
 * {@link RankDocsRetrieverBuilder} backed by the merged result.
 *
 * @param <T> the concrete subclass type, returned by {@link #addChild} and {@link #clone}
 */
public abstract class CompoundRetrieverBuilder<T extends CompoundRetrieverBuilder<T>>
extends RetrieverBuilder {
    public static final NodeFeature INNER_RETRIEVERS_FILTER_SUPPORT = new NodeFeature("inner_retrievers_filter_support");
    public static final ParseField RANK_WINDOW_SIZE_FIELD = new ParseField("rank_window_size");

    /** Number of hits requested from each inner retriever's search. */
    protected final int rankWindowSize;
    /** Inner retrievers, each optionally paired with its already-built search source. */
    protected final List<RetrieverSource> innerRetrievers;

    /**
     * @param innerRetrievers the inner retrievers; the list is used as-is (not copied) and is
     *                        appended to by {@link #addChild}
     * @param rankWindowSize  number of hits requested from each inner retriever
     */
    protected CompoundRetrieverBuilder(List<RetrieverSource> innerRetrievers, int rankWindowSize) {
        this.rankWindowSize = rankWindowSize;
        this.innerRetrievers = innerRetrievers;
    }

    /**
     * Appends {@code retrieverBuilder} as a new inner retriever with no pre-built search source.
     *
     * @return this builder, typed as the concrete subclass
     */
    @SuppressWarnings("unchecked") // T is by contract the concrete type of this instance
    public T addChild(RetrieverBuilder retrieverBuilder) {
        this.innerRetrievers.add(new RetrieverSource(retrieverBuilder, null));
        return (T) this;
    }

    /**
     * Creates a copy of this retriever with the given inner retrievers and pre-filters.
     *
     * @param newChildRetrievers        the inner retrievers for the copy
     * @param newPreFilterQueryBuilders the pre-filter queries for the copy
     */
    protected abstract T clone(List<RetrieverSource> newChildRetrievers, List<QueryBuilder> newPreFilterQueryBuilders);

    /**
     * Merges the per-retriever results into a single ranking.
     *
     * @param rankResults one array per successful inner retriever, in inner-retriever order
     * @param explain     the value of {@link QueryRewriteContext#isExplain()} for this request
     * @return the combined ranking
     */
    protected abstract RankDoc[] combineInnerRetrieverResults(List<ScoreDoc[]> rankResults, boolean explain);

    @Override
    public final boolean isCompound() {
        return true;
    }

    /** Returns the field used to report this retriever's rank window size in messages. */
    public ParseField getRankWindowSizeField() {
        return RANK_WINDOW_SIZE_FIELD;
    }

    /**
     * Rewrites this retriever one step at a time.
     *
     * <p>Returns a clone whenever the pre-filters, an inner retriever, or an inner search source changed.
     * When nothing changes, it registers an async multi-search over all inner sources and returns a
     * {@link RankDocsRetrieverBuilder} whose rank docs are supplied once that search completes.
     *
     * @throws IllegalStateException if the context carries no point-in-time
     */
    @Override
    public final RetrieverBuilder rewrite(final QueryRewriteContext ctx) throws IOException {
        if (ctx.getPointInTimeBuilder() == null) {
            throw new IllegalStateException("PIT is required");
        }
        // Stabilize our own pre-filters first; a changed list means a fresh copy is rewritten next round.
        List<QueryBuilder> newPreFilters = this.rewritePreFilters(ctx);
        if (newPreFilters != this.preFilterQueryBuilders) {
            return this.clone(this.innerRetrievers, newPreFilters);
        }
        boolean hasChanged = false;
        List<RetrieverSource> newRetrievers = new ArrayList<>(this.innerRetrievers.size());
        for (RetrieverSource entry : this.innerRetrievers) {
            RetrieverBuilder retriever = entry.retriever();
            if (retriever.isCompound()) {
                // Compound children build their own sources, so they receive our filters directly.
                this.propagatePreFilters(retriever);
            }
            RetrieverBuilder newRetriever = retriever.rewrite(ctx);
            if (newRetriever != retriever) {
                // The retriever itself changed: discard any cached source; it is rebuilt in a later round.
                newRetrievers.add(new RetrieverSource(newRetriever, null));
                hasChanged = true;
                continue;
            }
            SearchSourceBuilder sourceBuilder = entry.source() != null
                ? entry.source()
                : this.createSearchSourceBuilder(ctx.getPointInTimeBuilder(), newRetriever);
            SearchSourceBuilder rewrittenSource = sourceBuilder.rewrite(ctx);
            newRetrievers.add(new RetrieverSource(newRetriever, rewrittenSource));
            hasChanged |= rewrittenSource != entry.source();
        }
        if (hasChanged) {
            return this.clone(newRetrievers, newPreFilters);
        }

        // Fully rewritten: every entry now has a source. Execute all of them as one multi-search.
        final SetOnce<RankDoc[]> results = new SetOnce<>();
        final MultiSearchRequest multiSearchRequest = new MultiSearchRequest();
        for (RetrieverSource entry : this.innerRetrievers) {
            SearchRequest searchRequest = new SearchRequest().source(entry.source());
            // NOTE(review): presumably makes the pre-filter shard threshold unreachable so that round is skipped — confirm.
            searchRequest.setPreFilterShardSize(Integer.MAX_VALUE);
            multiSearchRequest.add(searchRequest);
        }
        ctx.registerAsyncAction((client, listener) -> client.execute(TransportMultiSearchAction.TYPE, multiSearchRequest, new ActionListener<MultiSearchResponse>() {

            @Override
            public void onResponse(MultiSearchResponse items) {
                MultiSearchResponse.Item[] responses = items.getResponses();
                List<ScoreDoc[]> topDocs = new ArrayList<>();
                List<Exception> failures = new ArrayList<>();
                List<String> retrieversWithFailures = new ArrayList<>();
                int statusCode = RestStatus.OK.getStatus();
                for (int i = 0; i < responses.length; ++i) {
                    MultiSearchResponse.Item item = responses[i];
                    RetrieverBuilder inner = CompoundRetrieverBuilder.this.innerRetrievers.get(i).retriever();
                    if (item.isFailure()) {
                        failures.add(item.getFailure());
                        retrieversWithFailures.add(inner.getName());
                        // Report the highest status code seen among all failures.
                        statusCode = Math.max(statusCode, ExceptionsHelper.status(item.getFailure()).getStatus());
                        continue;
                    }
                    assert item.getResponse() != null;
                    RankDoc[] rankDocs = CompoundRetrieverBuilder.this.getRankDocs(item.getResponse());
                    inner.setRankDocs(rankDocs);
                    topDocs.add(rankDocs);
                }
                if (!failures.isEmpty()) {
                    assert statusCode != RestStatus.OK.getStatus();
                    String errMessage = "[" + CompoundRetrieverBuilder.this.getName() + "] search failed - retrievers '" + String.valueOf(retrieversWithFailures) + "' returned errors. All failures are attached as suppressed exceptions.";
                    ElasticsearchStatusException ex = new ElasticsearchStatusException(errMessage, RestStatus.fromCode(statusCode), new Object[0]);
                    failures.forEach(ex::addSuppressed);
                    listener.onFailure(ex);
                } else {
                    results.set(CompoundRetrieverBuilder.this.combineInnerRetrieverResults(topDocs, ctx.isExplain()));
                    listener.onResponse(null);
                }
            }

            @Override
            public void onFailure(Exception e) {
                listener.onFailure(e);
            }
        }));
        RankDocsRetrieverBuilder rankDocsRetrieverBuilder = new RankDocsRetrieverBuilder(
            this.rankWindowSize,
            newRetrievers.stream().map(RetrieverSource::retriever).toList(),
            results::get
        );
        rankDocsRetrieverBuilder.retrieverName(this.retrieverName());
        return rankDocsRetrieverBuilder;
    }

    /** Always throws: a compound retriever must be rewritten before it can produce a query. */
    @Override
    public final QueryBuilder topDocsQuery() {
        throw new IllegalStateException("Should not be called, missing a rewrite?");
    }

    /** Always throws: a compound retriever must be rewritten before it can produce a query. */
    @Override
    public final QueryBuilder explainQuery() {
        throw new IllegalStateException("Should not be called, missing a rewrite?");
    }

    /** Always throws: a compound retriever must be rewritten before it can be extracted. */
    @Override
    public final void extractToSearchSourceBuilder(SearchSourceBuilder searchSourceBuilder, boolean compoundUsed) {
        throw new IllegalStateException("Should not be called, missing a rewrite?");
    }

    /**
     * Validates this retriever and, recursively, its inner retrievers.
     *
     * <p>Adds errors when:
     * <ul>
     *   <li>the request size exceeds {@code rankWindowSize};</li>
     *   <li>partial results or scroll are requested;</li>
     *   <li>a compound child has a smaller rank window than this retriever.</li>
     * </ul>
     */
    @Override
    public ActionRequestValidationException validate(SearchSourceBuilder source, ActionRequestValidationException validationException, boolean isScroll, boolean allowPartialSearchResults) {
        validationException = super.validate(source, validationException, isScroll, allowPartialSearchResults);
        int size = source.size();
        if (size > this.rankWindowSize) {
            validationException = ValidateActions.addValidationError(String.format(Locale.ROOT, "[%s] requires [%s: %d] be greater than or equal to [size: %d]", this.getName(), this.getRankWindowSizeField().getPreferredName(), this.rankWindowSize, size), validationException);
        }
        if (allowPartialSearchResults) {
            validationException = ValidateActions.addValidationError("cannot specify [" + this.getName() + "] and [allow_partial_search_results]", validationException);
        }
        if (isScroll) {
            validationException = ValidateActions.addValidationError("cannot specify [" + this.getName() + "] and [scroll]", validationException);
        }
        for (RetrieverSource innerRetriever : this.innerRetrievers) {
            RetrieverBuilder child = innerRetriever.retriever();
            validationException = child.validate(source, validationException, isScroll, allowPartialSearchResults);
            if (child instanceof CompoundRetrieverBuilder<?> compoundChild && this.rankWindowSize > compoundChild.rankWindowSize) {
                String errorMessage = String.format(Locale.ROOT, "[%s] requires [%s: %d] to be smaller than or equal to its sub retriever's %s [%s: %d]", this.getName(), this.getRankWindowSizeField().getPreferredName(), this.rankWindowSize, compoundChild.getName(), compoundChild.getRankWindowSizeField().getPreferredName(), compoundChild.rankWindowSize);
                validationException = ValidateActions.addValidationError(errorMessage, validationException);
            }
        }
        return validationException;
    }

    /**
     * Compares rank window size and inner retrievers.
     * NOTE(review): assumes the caller already checked that {@code o} has the same class — confirm in RetrieverBuilder.
     */
    @Override
    public boolean doEquals(Object o) {
        CompoundRetrieverBuilder<?> that = (CompoundRetrieverBuilder<?>) o;
        return this.rankWindowSize == that.rankWindowSize && Objects.equals(this.innerRetrievers, that.innerRetrievers);
    }

    /** Hashes the same fields {@link #doEquals} compares. */
    @Override
    public int doHashCode() {
        return Objects.hash(this.rankWindowSize, this.innerRetrievers);
    }

    /**
     * Builds the search source used to run {@code retrieverBuilder} as one inner search.
     *
     * <p>The source targets {@code pit}, disables total-hit tracking and stored fields, and requests
     * {@code rankWindowSize} hits. Its sorts end with a {@code _shard_doc} tiebreaker; when the
     * retriever supplies no sort, a score sort is added first.
     */
    protected final SearchSourceBuilder createSearchSourceBuilder(PointInTimeBuilder pit, RetrieverBuilder retrieverBuilder) {
        SearchSourceBuilder sourceBuilder = new SearchSourceBuilder()
            .pointInTimeBuilder(pit)
            .trackTotalHits(false)
            .storedFields(new StoredFieldsContext(false))
            .size(this.rankWindowSize);
        // Attach our pre-filters before extraction so extractToSearchSourceBuilder sees them.
        this.propagatePreFilters(retrieverBuilder);
        retrieverBuilder.extractToSearchSourceBuilder(sourceBuilder, true);
        List<SortBuilder<?>> sortBuilders = sourceBuilder.sorts() != null ? new ArrayList<>(sourceBuilder.sorts()) : new ArrayList<>();
        if (sortBuilders.isEmpty()) {
            sortBuilders.add(new ScoreSortBuilder());
        }
        // Must stay last: getRankDocs decodes the doc id and shard index from the final sort value.
        sortBuilders.add(new FieldSortBuilder("_shard_doc"));
        sourceBuilder.sort(sortBuilders);
        sourceBuilder.skipInnerHits(true);
        return this.finalizeSourceBuilder(sourceBuilder);
    }

    /** Hook for subclasses to adjust an inner search source; the default returns it unchanged. */
    protected SearchSourceBuilder finalizeSourceBuilder(SearchSourceBuilder sourceBuilder) {
        return sourceBuilder;
    }

    /**
     * Appends this retriever's pre-filters to {@code target}'s pre-filter list.
     *
     * <p>Filters already present (by {@code equals}) are skipped. Without this, repeated rewrite rounds
     * would keep appending the same filters to a child that is cloned with the same list.
     */
    private void propagatePreFilters(RetrieverBuilder target) {
        if (this.preFilterQueryBuilders.isEmpty()) {
            return;
        }
        List<QueryBuilder> targetFilters = target.getPreFilterQueryBuilders();
        for (QueryBuilder filter : this.preFilterQueryBuilders) {
            if (!targetFilters.contains(filter)) {
                targetFilters.add(filter);
            }
        }
    }

    /**
     * Converts one inner search response into rank docs.
     *
     * <p>The doc id and shard request index are decoded from each hit's last raw sort value (the
     * {@code _shard_doc} tiebreaker). Ranks are 1-based, in hit order.
     */
    private RankDoc[] getRankDocs(SearchResponse searchResponse) {
        SearchHit[] hits = searchResponse.getHits().getHits();
        RankDoc[] docs = new RankDoc[hits.length];
        for (int i = 0; i < hits.length; ++i) {
            SearchHit hit = hits[i];
            Object[] rawSortValues = hit.getRawSortValues();
            long sortValue = (Long) rawSortValues[rawSortValues.length - 1];
            int doc = ShardDocSortField.decodeDoc(sortValue);
            int shardRequestIndex = ShardDocSortField.decodeShardRequestIndex(sortValue);
            docs[i] = new RankDoc(doc, hit.getScore(), shardRequestIndex);
            docs[i].rank = i + 1;
        }
        return docs;
    }

    /**
     * An inner retriever paired with its search source.
     *
     * @param retriever the inner retriever
     * @param source    the search source built for it, or {@code null} when not built yet
     */
    public record RetrieverSource(RetrieverBuilder retriever, SearchSourceBuilder source) {
    }
}

