/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.search.vectors;

import com.carrotsearch.hppc.IntHashSet;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.Callable;
import java.util.concurrent.atomic.LongAccumulator;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.AcceptDocs;
import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.FieldExistsQuery;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.MatchNoDocsQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.QueryVisitor;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.Scorer;
import org.apache.lucene.search.TaskExecutor;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TopDocsCollector;
import org.apache.lucene.search.Weight;
import org.apache.lucene.search.knn.KnnCollectorManager;
import org.apache.lucene.search.knn.KnnSearchStrategy;
import org.apache.lucene.util.Bits;
import org.elasticsearch.search.profile.query.QueryProfiler;
import org.elasticsearch.search.vectors.AbstractMaxScoreKnnCollector;
import org.elasticsearch.search.vectors.IVFKnnFloatVectorQuery;
import org.elasticsearch.search.vectors.KnnScoreDocQuery;
import org.elasticsearch.search.vectors.MaxScoreTopKnnCollector;
import org.elasticsearch.search.vectors.QueryProfilerProvider;

abstract class AbstractIVFKnnVectorQuery
extends Query
implements QueryProfilerProvider {
    static final TopDocs NO_RESULTS = TopDocsCollector.EMPTY_TOPDOCS;
    protected final String field;
    protected final float providedVisitRatio;
    protected final int k;
    protected final int numCands;
    protected final Query filter;
    protected int vectorOpsCount;

    protected AbstractIVFKnnVectorQuery(String field, float visitRatio, int k, int numCands, Query filter) {
        if (k < 1) {
            throw new IllegalArgumentException("k must be at least 1, got: " + k);
        }
        if (visitRatio < 0.0f || visitRatio > 1.0f) {
            throw new IllegalArgumentException("visitRatio must be between 0.0 and 1.0 (both inclusive), got: " + visitRatio);
        }
        if (numCands < k) {
            throw new IllegalArgumentException("numCands must be at least k, got: " + numCands);
        }
        this.field = field;
        this.providedVisitRatio = visitRatio;
        this.k = k;
        this.filter = filter;
        this.numCands = numCands;
    }

    public void visit(QueryVisitor visitor) {
        if (visitor.acceptField(this.field)) {
            visitor.visitLeaf((Query)this);
        }
    }

    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (o == null || this.getClass() != o.getClass()) {
            return false;
        }
        AbstractIVFKnnVectorQuery that = (AbstractIVFKnnVectorQuery)o;
        return this.k == that.k && Objects.equals(this.field, that.field) && Objects.equals(this.filter, that.filter) && Objects.equals(Float.valueOf(this.providedVisitRatio), Float.valueOf(that.providedVisitRatio));
    }

    public int hashCode() {
        return Objects.hash(this.field, this.k, this.filter, Float.valueOf(this.providedVisitRatio));
    }

    public Query rewrite(IndexSearcher indexSearcher) throws IOException {
        float visitRatio;
        Weight filterWeight;
        this.vectorOpsCount = 0;
        IndexReader reader = indexSearcher.getIndexReader();
        if (this.filter != null) {
            BooleanQuery booleanQuery = new BooleanQuery.Builder().add(this.filter, BooleanClause.Occur.FILTER).add((Query)new FieldExistsQuery(this.field), BooleanClause.Occur.FILTER).build();
            Query rewritten = indexSearcher.rewrite((Query)booleanQuery);
            if (rewritten.getClass() == MatchNoDocsQuery.class) {
                return rewritten;
            }
            filterWeight = indexSearcher.createWeight(rewritten, ScoreMode.COMPLETE_NO_SCORES, 1.0f);
        } else {
            filterWeight = null;
        }
        IVFCollectorManager knnCollectorManager = this.getKnnCollectorManager(Math.round(2.0f * (float)this.k), indexSearcher);
        TaskExecutor taskExecutor = indexSearcher.getTaskExecutor();
        List leafReaderContexts = reader.leaves();
        assert (this instanceof IVFKnnFloatVectorQuery);
        int totalVectors = 0;
        for (LeafReaderContext leafReaderContext : leafReaderContexts) {
            LeafReader leafReader = leafReaderContext.reader();
            FloatVectorValues floatVectorValues = leafReader.getFloatVectorValues(this.field);
            if (floatVectorValues == null) continue;
            totalVectors += floatVectorValues.size();
        }
        if (this.providedVisitRatio == 0.0f) {
            float expected = Math.round(Math.log10(totalVectors) * Math.log10(totalVectors) * (double)Math.min(10000, Math.max(this.numCands, 5 * this.k)));
            visitRatio = expected / (float)totalVectors;
        } else {
            visitRatio = this.providedVisitRatio;
        }
        ArrayList<Callable<TopDocs>> tasks = new ArrayList<Callable<TopDocs>>(leafReaderContexts.size());
        for (LeafReaderContext context : leafReaderContexts) {
            tasks.add(() -> this.searchLeaf(context, filterWeight, knnCollectorManager, visitRatio));
        }
        TopDocs[] perLeafResults = (TopDocs[])taskExecutor.invokeAll(tasks).toArray(TopDocs[]::new);
        TopDocs topK = TopDocs.merge((int)this.k, (TopDocs[])perLeafResults);
        this.vectorOpsCount = (int)topK.totalHits.value();
        if (topK.scoreDocs.length == 0) {
            return new MatchNoDocsQuery();
        }
        return new KnnScoreDocQuery(topK.scoreDocs, reader);
    }

    private TopDocs searchLeaf(LeafReaderContext ctx, Weight filterWeight, IVFCollectorManager knnCollectorManager, float visitRatio) throws IOException {
        TopDocs results = this.getLeafResults(ctx, filterWeight, knnCollectorManager, visitRatio);
        IntHashSet dedup = new IntHashSet(results.scoreDocs.length * 4 / 3);
        int deduplicateCount = 0;
        for (ScoreDoc scoreDoc : results.scoreDocs) {
            if (!dedup.add(scoreDoc.doc)) continue;
            ++deduplicateCount;
        }
        ScoreDoc[] deduplicatedScoreDocs = new ScoreDoc[deduplicateCount];
        dedup.clear();
        int index = 0;
        for (ScoreDoc scoreDoc : results.scoreDocs) {
            if (!dedup.add(scoreDoc.doc)) continue;
            scoreDoc.doc += ctx.docBase;
            deduplicatedScoreDocs[index++] = scoreDoc;
        }
        return new TopDocs(results.totalHits, deduplicatedScoreDocs);
    }

    TopDocs getLeafResults(LeafReaderContext ctx, Weight filterWeight, IVFCollectorManager knnCollectorManager, float visitRatio) throws IOException {
        LeafReader reader = ctx.reader();
        if (filterWeight == null) {
            AcceptDocs acceptDocs = AcceptDocs.fromLiveDocs((Bits)reader.getLiveDocs(), (int)reader.maxDoc());
            return this.approximateSearch(ctx, acceptDocs, Integer.MAX_VALUE, knnCollectorManager, visitRatio);
        }
        Scorer scorer = filterWeight.scorer(ctx);
        if (scorer == null) {
            return TopDocsCollector.EMPTY_TOPDOCS;
        }
        AcceptDocs acceptDocs = AcceptDocs.fromIteratorSupplier(() -> ((Scorer)scorer).iterator(), (Bits)reader.getLiveDocs(), (int)reader.maxDoc());
        int cost = acceptDocs.cost();
        return this.approximateSearch(ctx, acceptDocs, cost + 1, knnCollectorManager, visitRatio);
    }

    abstract TopDocs approximateSearch(LeafReaderContext var1, AcceptDocs var2, int var3, IVFCollectorManager var4, float var5) throws IOException;

    protected IVFCollectorManager getKnnCollectorManager(int k, IndexSearcher searcher) {
        return new IVFCollectorManager(k, searcher);
    }

    @Override
    public final void profile(QueryProfiler queryProfiler) {
        queryProfiler.addVectorOpsCount(this.vectorOpsCount);
    }

    static class IVFCollectorManager
    implements KnnCollectorManager {
        private final int k;
        final LongAccumulator longAccumulator;

        IVFCollectorManager(int k, IndexSearcher searcher) {
            this.k = k;
            this.longAccumulator = searcher.getIndexReader().leaves().size() > 1 ? new LongAccumulator(Long::max, AbstractMaxScoreKnnCollector.LEAST_COMPETITIVE) : null;
        }

        public AbstractMaxScoreKnnCollector newCollector(int visitedLimit, KnnSearchStrategy searchStrategy, LeafReaderContext context) throws IOException {
            return new MaxScoreTopKnnCollector(this.k, visitedLimit, searchStrategy);
        }
    }
}

