/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.rank.rrf;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import org.apache.lucene.search.ScoreDoc;
import org.elasticsearch.action.ActionRequestValidationException;
import org.elasticsearch.action.ResolvedIndices;
import org.elasticsearch.common.ParsingException;
import org.elasticsearch.common.util.Maps;
import org.elasticsearch.features.NodeFeature;
import org.elasticsearch.index.query.MatchNoneQueryBuilder;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryRewriteContext;
import org.elasticsearch.license.LicenseUtils;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.search.rank.RankDoc;
import org.elasticsearch.search.retriever.CompoundRetrieverBuilder;
import org.elasticsearch.search.retriever.RetrieverBuilder;
import org.elasticsearch.search.retriever.RetrieverParserContext;
import org.elasticsearch.search.retriever.StandardRetrieverBuilder;
import org.elasticsearch.xcontent.ConstructingObjectParser;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.ToXContent;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xpack.core.XPackPlugin;
import org.elasticsearch.xpack.rank.MultiFieldsInnerRetrieverUtils;
import org.elasticsearch.xpack.rank.rrf.RRFRankDoc;
import org.elasticsearch.xpack.rank.rrf.RRFRankPlugin;

/**
 * Retriever that fuses the ranked results of several inner retrievers with Reciprocal Rank Fusion (RRF).
 *
 * <p>Inner retrievers are either supplied explicitly through {@code retrievers}, or, when {@code query} is set,
 * generated at rewrite time from {@code fields} and {@code query} via {@link MultiFieldsInnerRetrieverUtils}.
 * A document's fused score is the sum, over every inner result list it appears in, of
 * {@code 1 / (rank_constant + rank)}, where {@code rank} is the document's 1-based position in that list.
 */
public final class RRFRetrieverBuilder extends CompoundRetrieverBuilder<RRFRetrieverBuilder> {

    /** Node feature advertising support for the multi-fields ({@code fields}/{@code query}) format; not checked in this class. */
    public static final NodeFeature MULTI_FIELDS_QUERY_FORMAT_SUPPORT = new NodeFeature("rrf_retriever.multi_fields_query_format_support");
    /** Retriever name used in requests and in error messages. */
    public static final String NAME = "rrf";
    /** Must be supported by the cluster for {@link #fromXContent} to accept an {@code rrf} retriever at all. */
    public static final NodeFeature RRF_RETRIEVER_SUPPORTED = new NodeFeature("rrf_retriever_supported", true);
    /** Must be supported by the cluster for {@link #fromXContent} to accept retriever composition. */
    public static final NodeFeature RRF_RETRIEVER_COMPOSITION_SUPPORTED = new NodeFeature("rrf_retriever_composition_supported", true);
    public static final ParseField RETRIEVERS_FIELD = new ParseField("retrievers");
    public static final ParseField RANK_CONSTANT_FIELD = new ParseField("rank_constant");
    public static final ParseField FIELDS_FIELD = new ParseField("fields");
    public static final ParseField QUERY_FIELD = new ParseField("query");
    /** Value used for {@code rank_constant} when the request omits it. */
    public static final int DEFAULT_RANK_CONSTANT = 60;

    /**
     * Parser for the {@code rrf} retriever body.
     *
     * <p>Constructor args, in declaration order: retrievers, fields, query, rank_window_size, rank_constant.
     */
    @SuppressWarnings("unchecked") // args[] entries are typed by the declarations in the static block below
    static final ConstructingObjectParser<RRFRetrieverBuilder, RetrieverParserContext> PARSER = new ConstructingObjectParser<>(
        NAME,
        false,
        args -> {
            List<RetrieverBuilder> childRetrievers = (List<RetrieverBuilder>) args[0];
            List<String> fields = (List<String>) args[1];
            String query = (String) args[2];
            // rank_window_size falls back to 10 when absent
            int rankWindowSize = args[3] == null ? 10 : (Integer) args[3];
            int rankConstant = args[4] == null ? DEFAULT_RANK_CONSTANT : (Integer) args[4];
            List<CompoundRetrieverBuilder.RetrieverSource> innerRetrievers = childRetrievers == null
                ? List.of()
                : childRetrievers.stream().map(CompoundRetrieverBuilder.RetrieverSource::from).toList();
            return new RRFRetrieverBuilder(innerRetrievers, fields, query, rankWindowSize, rankConstant);
        }
    );

    static {
        PARSER.declareObjectArray(ConstructingObjectParser.optionalConstructorArg(), (p, c) -> {
            // Each array element is an object whose single key names the retriever type to parse.
            p.nextToken();
            String name = p.currentName();
            RetrieverBuilder retrieverBuilder = p.namedObject(RetrieverBuilder.class, name, c);
            c.trackRetrieverUsage(retrieverBuilder.getName());
            // step past the end of the wrapping object
            p.nextToken();
            return retrieverBuilder;
        }, RETRIEVERS_FIELD);
        PARSER.declareStringArray(ConstructingObjectParser.optionalConstructorArg(), FIELDS_FIELD);
        PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), QUERY_FIELD);
        PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), RANK_WINDOW_SIZE_FIELD);
        PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), RANK_CONSTANT_FIELD);
        RetrieverBuilder.declareBaseParserFields(NAME, PARSER);
    }

    /** Field names for the multi-fields format; {@code null} when not used. Stored as an immutable copy. */
    private final List<String> fields;
    /** Query text for the multi-fields format; {@code null} when not used. */
    private final String query;
    /** The {@code rank_constant} added to each 1-based rank in the RRF denominator. */
    private final int rankConstant;

    /**
     * Parses an {@code rrf} retriever after checking cluster features and the license.
     *
     * @throws ParsingException if the cluster does not support {@link #RRF_RETRIEVER_SUPPORTED}
     * @throws IllegalArgumentException if the cluster does not support {@link #RRF_RETRIEVER_COMPOSITION_SUPPORTED}
     *     (parse errors from {@link #PARSER} may also surface)
     */
    public static RRFRetrieverBuilder fromXContent(XContentParser parser, RetrieverParserContext context) throws IOException {
        if (!context.clusterSupportsFeature(RRF_RETRIEVER_SUPPORTED)) {
            throw new ParsingException(parser.getTokenLocation(), "unknown retriever [rrf]");
        }
        if (!context.clusterSupportsFeature(RRF_RETRIEVER_COMPOSITION_SUPPORTED)) {
            throw new IllegalArgumentException("[rrf] retriever composition feature is not supported by all nodes in the cluster");
        }
        if (!RRFRankPlugin.RANK_RRF_FEATURE.check(XPackPlugin.getSharedLicenseState())) {
            // the exception type is whatever LicenseUtils.newComplianceException builds
            throw LicenseUtils.newComplianceException("Reciprocal Rank Fusion (RRF)");
        }
        return PARSER.apply(parser, context);
    }

    /** Creates a retriever over explicit child retrievers, without the multi-fields format. */
    public RRFRetrieverBuilder(List<CompoundRetrieverBuilder.RetrieverSource> childRetrievers, int rankWindowSize, int rankConstant) {
        this(childRetrievers, null, null, rankWindowSize, rankConstant);
    }

    /**
     * Full constructor.
     *
     * @param childRetrievers explicit inner retrievers; {@code null} is treated as empty. Copied into a new mutable list.
     * @param fields multi-fields format field names, or {@code null}
     * @param query multi-fields format query text, or {@code null}
     * @param rankWindowSize maximum number of fused results to return
     * @param rankConstant constant added to each 1-based rank in the RRF denominator
     */
    public RRFRetrieverBuilder(
        List<CompoundRetrieverBuilder.RetrieverSource> childRetrievers,
        List<String> fields,
        String query,
        int rankWindowSize,
        int rankConstant
    ) {
        super(
            childRetrievers == null
                ? new ArrayList<CompoundRetrieverBuilder.RetrieverSource>()
                : new ArrayList<CompoundRetrieverBuilder.RetrieverSource>(childRetrievers),
            rankWindowSize
        );
        this.fields = fields == null ? null : List.copyOf(fields);
        this.query = query;
        this.rankConstant = rankConstant;
    }

    /** @return the configured {@code rank_constant} */
    public int rankConstant() {
        return this.rankConstant;
    }

    @Override
    public String getName() {
        return NAME;
    }

    /** Runs the parent validation, then validates the retrievers/fields/query combination. */
    @Override
    public ActionRequestValidationException validate(
        SearchSourceBuilder source,
        ActionRequestValidationException validationException,
        boolean isScroll,
        boolean allowPartialSearchResults
    ) {
        validationException = super.validate(source, validationException, isScroll, allowPartialSearchResults);
        return MultiFieldsInnerRetrieverUtils.validateParams(
            this.innerRetrievers,
            this.fields,
            this.query,
            this.getName(),
            RETRIEVERS_FIELD.getPreferredName(),
            FIELDS_FIELD.getPreferredName(),
            QUERY_FIELD.getPreferredName(),
            validationException
        );
    }

    /** Copies this retriever with new inner retrievers and pre-filters, keeping settings and {@code retrieverName}. */
    @Override
    protected RRFRetrieverBuilder clone(
        List<CompoundRetrieverBuilder.RetrieverSource> newRetrievers,
        List<QueryBuilder> newPreFilterQueryBuilders
    ) {
        RRFRetrieverBuilder clone = new RRFRetrieverBuilder(newRetrievers, this.fields, this.query, this.rankWindowSize, this.rankConstant);
        clone.preFilterQueryBuilders = newPreFilterQueryBuilders;
        clone.retrieverName = this.retrieverName;
        return clone;
    }

    /**
     * Fuses the inner result lists with RRF.
     *
     * <p>Documents are keyed by (doc, shardIndex). Each occurrence at 1-based position {@code rank} adds
     * {@code 1 / (rankConstant + rank)} to that document's score. The fused documents are sorted with
     * {@link RRFRankDoc}'s natural ordering. NOTE(review): presumably highest score first; confirm in
     * {@code RRFRankDoc#compareTo}. At most {@code rankWindowSize} documents are kept, with ranks 1..n assigned.
     *
     * @param rankResults one ranked {@link ScoreDoc} array per inner retriever
     * @param explain when true, records each document's 0-based position and original score per inner retriever
     * @return the top fused documents, in final order
     */
    @Override
    protected RRFRankDoc[] combineInnerRetrieverResults(List<ScoreDoc[]> rankResults, boolean explain) {
        final int queries = rankResults.size();
        Map<RankDoc.RankKey, RRFRankDoc> docsToRankResults = Maps.newMapWithExpectedSize(this.rankWindowSize);
        int index = 0;
        for (ScoreDoc[] rrfRankResult : rankResults) {
            final int findex = index;
            int rank = 1;
            for (ScoreDoc scoreDoc : rrfRankResult) {
                final int frank = rank++;
                docsToRankResults.compute(new RankDoc.RankKey(scoreDoc.doc, scoreDoc.shardIndex), (key, value) -> {
                    if (value == null) {
                        value = explain
                            ? new RRFRankDoc(scoreDoc.doc, scoreDoc.shardIndex, queries, this.rankConstant)
                            : new RRFRankDoc(scoreDoc.doc, scoreDoc.shardIndex);
                    }
                    // Widen to long so a huge rank_constant cannot overflow int into a negative denominator;
                    // for non-overflowing values this yields exactly the same float as int arithmetic.
                    value.score += 1.0f / ((long) this.rankConstant + frank);
                    if (explain && value.positions != null && value.scores != null) {
                        value.positions[findex] = frank - 1;
                        value.scores[findex] = scoreDoc.score;
                    }
                    return value;
                });
            }
            index++;
        }
        RRFRankDoc[] sortedResults = docsToRankResults.values().toArray(RRFRankDoc[]::new);
        Arrays.sort(sortedResults);
        RRFRankDoc[] topResults = Arrays.copyOf(sortedResults, Math.min(this.rankWindowSize, sortedResults.length));
        for (int i = 0; i < topResults.length; i++) {
            topResults[i].rank = i + 1;
        }
        return topResults;
    }

    /**
     * Expands the multi-fields format into concrete inner retrievers.
     *
     * <p>Returns {@code this} unchanged unless both resolved indices and {@code query} are available. Otherwise:
     * <ul>
     *   <li>rejects more than one concrete local index;</li>
     *   <li>rejects remote cluster indices;</li>
     *   <li>rejects any per-field weight other than 1.0;</li>
     *   <li>returns a match-none standard retriever when no inner retrievers are generated.</li>
     * </ul>
     *
     * @throws IllegalArgumentException for the rejected cases above
     */
    @Override
    protected RetrieverBuilder doRewrite(QueryRewriteContext ctx) {
        ResolvedIndices resolvedIndices = ctx.getResolvedIndices();
        if (resolvedIndices == null || this.query == null) {
            return this;
        }
        var localIndicesMetadata = resolvedIndices.getConcreteLocalIndicesMetadata();
        if (localIndicesMetadata.size() > 1) {
            throw new IllegalArgumentException("[rrf] cannot specify [" + QUERY_FIELD.getPreferredName() + "] when querying multiple indices");
        }
        if (!resolvedIndices.getRemoteClusterIndices().isEmpty()) {
            throw new IllegalArgumentException("[rrf] cannot specify [" + QUERY_FIELD.getPreferredName() + "] when querying remote indices");
        }
        List<CompoundRetrieverBuilder.RetrieverSource> fieldsInnerRetrievers = MultiFieldsInnerRetrieverUtils.generateInnerRetrievers(
            this.fields,
            this.query,
            localIndicesMetadata.values(),
            r -> {
                List<CompoundRetrieverBuilder.RetrieverSource> retrievers = r.stream()
                    .map(MultiFieldsInnerRetrieverUtils.WeightedRetrieverSource::retrieverSource)
                    .toList();
                return new RRFRetrieverBuilder(retrievers, this.rankWindowSize, this.rankConstant);
            },
            w -> {
                if (w.floatValue() != 1.0f) {
                    throw new IllegalArgumentException("[rrf] does not support per-field weights in [" + FIELDS_FIELD.getPreferredName() + "]");
                }
            }
        ).stream().map(CompoundRetrieverBuilder.RetrieverSource::from).toList();

        if (fieldsInnerRetrievers.isEmpty()) {
            return new StandardRetrieverBuilder(new MatchNoneQueryBuilder());
        }
        // The 3-arg constructor leaves fields/query null, so the rewritten builder will not be expanded again.
        // NOTE(review): unlike clone(), retrieverName is not carried over here; confirm whether that is intended.
        RRFRetrieverBuilder rewritten = new RRFRetrieverBuilder(fieldsInnerRetrievers, this.rankWindowSize, this.rankConstant);
        rewritten.getPreFilterQueryBuilders().addAll(this.preFilterQueryBuilders);
        return rewritten;
    }

    /** Compares the parent state plus fields, query and rankConstant; the caller is assumed to have checked the class. */
    @Override
    public boolean doEquals(Object o) {
        RRFRetrieverBuilder that = (RRFRetrieverBuilder) o;
        return super.doEquals(o)
            && Objects.equals(this.fields, that.fields)
            && Objects.equals(this.query, that.query)
            && this.rankConstant == that.rankConstant;
    }

    @Override
    public int doHashCode() {
        return Objects.hash(super.doHashCode(), this.fields, this.query, this.rankConstant);
    }

    /** Writes retrievers (if any), fields/query (if set), rank_window_size and rank_constant. */
    @Override
    public void doToXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
        if (!this.innerRetrievers.isEmpty()) {
            builder.startArray(RETRIEVERS_FIELD.getPreferredName());
            for (CompoundRetrieverBuilder.RetrieverSource entry : this.innerRetrievers) {
                entry.retriever().toXContent(builder, params);
            }
            builder.endArray();
        }
        if (this.fields != null) {
            builder.startArray(FIELDS_FIELD.getPreferredName());
            for (String field : this.fields) {
                builder.value(field);
            }
            builder.endArray();
        }
        if (this.query != null) {
            builder.field(QUERY_FIELD.getPreferredName(), this.query);
        }
        builder.field(RANK_WINDOW_SIZE_FIELD.getPreferredName(), this.rankWindowSize);
        builder.field(RANK_CONSTANT_FIELD.getPreferredName(), this.rankConstant);
    }
}

