/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.index.codec.vectors.cluster;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Callable;
import org.apache.lucene.search.KnnCollector;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TaskExecutor;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.knn.KnnSearchStrategy;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.VectorUtil;
import org.apache.lucene.util.hnsw.HnswConcurrentMergeBuilder;
import org.apache.lucene.util.hnsw.HnswGraphBuilder;
import org.apache.lucene.util.hnsw.HnswGraphSearcher;
import org.apache.lucene.util.hnsw.OnHeapHnswGraph;
import org.apache.lucene.util.hnsw.RandomVectorScorer;
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
import org.apache.lucene.util.hnsw.UpdateableRandomVectorScorer;
import org.elasticsearch.index.codec.vectors.cluster.NeighborQueue;
import org.elasticsearch.simdvec.ESVectorUtil;

public record NeighborHood(int[] neighbors, float maxIntraDistance) {
    private static final int M = 8;
    private static final int EF_CONSTRUCTION = 150;
    static final NeighborHood EMPTY = new NeighborHood(new int[0], Float.POSITIVE_INFINITY);

    public static NeighborHood[] computeNeighborhoods(float[][] centers, int clustersPerNeighborhood) throws IOException {
        assert (centers.length > clustersPerNeighborhood);
        return NeighborHood.computeNeighborhoods(null, 1, centers, clustersPerNeighborhood);
    }

    public static NeighborHood[] computeNeighborhoods(TaskExecutor executor, int numWorkers, float[][] centers, int clustersPerNeighborhood) throws IOException {
        assert (centers.length > clustersPerNeighborhood);
        if (centers.length < 10000) {
            return NeighborHood.computeNeighborhoodsBruteForce(centers, clustersPerNeighborhood);
        }
        if (executor == null || numWorkers < 2) {
            return NeighborHood.computeNeighborhoodsGraph(centers, clustersPerNeighborhood);
        }
        return NeighborHood.computeNeighborhoodsGraph(executor, numWorkers, centers, clustersPerNeighborhood);
    }

    public static NeighborHood[] computeNeighborhoodsBruteForce(float[][] centers, int clustersPerNeighborhood) {
        int k = centers.length;
        NeighborQueue[] neighborQueues = new NeighborQueue[k];
        for (int i = 0; i < k; ++i) {
            neighborQueues[i] = new NeighborQueue(clustersPerNeighborhood, true);
        }
        float[] scores = new float[4];
        int limit = k - 3;
        for (int i = 0; i < k - 1; ++i) {
            int j;
            float[] center = centers[i];
            for (j = i + 1; j < limit; j += 4) {
                ESVectorUtil.squareDistanceBulk((float[])center, (float[])centers[j], (float[])centers[j + 1], (float[])centers[j + 2], (float[])centers[j + 3], (float[])scores);
                for (int h = 0; h < 4; ++h) {
                    neighborQueues[j + h].insertWithOverflow(i, scores[h]);
                    neighborQueues[i].insertWithOverflow(j + h, scores[h]);
                }
            }
            while (j < k) {
                float dsq = VectorUtil.squareDistance(center, centers[j]);
                neighborQueues[j].insertWithOverflow(i, dsq);
                neighborQueues[i].insertWithOverflow(j, dsq);
                ++j;
            }
        }
        NeighborHood[] neighborhoods = new NeighborHood[k];
        for (int i = 0; i < k; ++i) {
            NeighborQueue queue = neighborQueues[i];
            if (queue.size() == 0) {
                neighborhoods[i] = EMPTY;
                continue;
            }
            int[] neighbors = new int[queue.size()];
            float maxIntraDistance = queue.topScore();
            int iter = 0;
            while (queue.size() > 0) {
                neighbors[neighbors.length - ++iter] = queue.pop();
            }
            neighborhoods[i] = new NeighborHood(neighbors, maxIntraDistance);
        }
        return neighborhoods;
    }

    public static NeighborHood[] computeNeighborhoodsGraph(float[][] centers, int clustersPerNeighborhood) throws IOException {
        CentersScorerSupplier supplier = new CentersScorerSupplier(centers);
        OnHeapHnswGraph graph = HnswGraphBuilder.create(supplier, 8, 150, 42L, centers.length).build(centers.length);
        NeighborHood[] neighborhoods = new NeighborHood[centers.length];
        NeighborHood.populateNeighboursFromGraph(graph, clustersPerNeighborhood, neighborhoods, supplier, 0, centers.length);
        return neighborhoods;
    }

    public static NeighborHood[] computeNeighborhoodsGraph(TaskExecutor executor, int numWorkers, float[][] centers, int clustersPerNeighborhood) throws IOException {
        CentersScorerSupplier supplier = new CentersScorerSupplier(centers);
        OnHeapHnswGraph initGraph = HnswGraphBuilder.create(supplier, 8, 150, 42L, centers.length).build(0);
        OnHeapHnswGraph graph = new HnswConcurrentMergeBuilder(executor, numWorkers, supplier, 8, 150, initGraph, null).build(centers.length);
        NeighborHood[] neighborhoods = new NeighborHood[centers.length];
        int len = centers.length / numWorkers;
        ArrayList runners = new ArrayList(numWorkers);
        for (int i = 0; i < numWorkers; ++i) {
            int start = i * len;
            int end = i == numWorkers - 1 ? centers.length : (i + 1) * len;
            runners.add(() -> {
                NeighborHood.populateNeighboursFromGraph(graph, clustersPerNeighborhood, neighborhoods, supplier.copy(), start, end);
                return null;
            });
        }
        executor.invokeAll(runners);
        return neighborhoods;
    }

    private static void populateNeighboursFromGraph(OnHeapHnswGraph graph, int clustersPerNeighborhood, NeighborHood[] neighborhoods, RandomVectorScorerSupplier supplier, int start, int end) throws IOException {
        ReusableBits bits = new ReusableBits(graph.size());
        for (int i = start; i < end; ++i) {
            supplier.scorer().setScoringOrdinal(i);
            bits.currentOrd = i;
            KnnCollector collector = HnswGraphSearcher.search((RandomVectorScorer)supplier.scorer(), 2 * clustersPerNeighborhood, graph, (Bits)bits, Integer.MAX_VALUE);
            ScoreDoc[] scoreDocs = collector.topDocs().scoreDocs;
            int len = Math.min(clustersPerNeighborhood, scoreDocs.length);
            if (len == 0) {
                neighborhoods[i] = EMPTY;
                continue;
            }
            float minScore = scoreDocs[len - 1].score;
            int[] neighbors = new int[len];
            for (int j = 0; j < len; ++j) {
                neighbors[j] = scoreDocs[j].doc;
            }
            neighborhoods[i] = new NeighborHood(neighbors, 1.0f / minScore - 1.0f);
        }
    }

    private record CentersScorerSupplier(float[][] centers, UpdateableRandomVectorScorer scorer) implements RandomVectorScorerSupplier
    {
        CentersScorerSupplier(final float[][] centers) {
            this(centers, new UpdateableRandomVectorScorer(){
                private int scoringOrdinal;
                private final float[] distances = new float[4];

                @Override
                public float score(int node) {
                    return VectorUtil.normalizeDistanceToUnitInterval(VectorUtil.squareDistance(centers[this.scoringOrdinal], centers[node]));
                }

                @Override
                public void bulkScore(int[] nodes, float[] scores, int numNodes) {
                    int i;
                    int limit = numNodes - 3;
                    for (i = 0; i < limit; i += 4) {
                        ESVectorUtil.squareDistanceBulk((float[])centers[this.scoringOrdinal], (float[])centers[nodes[i]], (float[])centers[nodes[i + 1]], (float[])centers[nodes[i + 2]], (float[])centers[nodes[i + 3]], (float[])this.distances);
                        for (int j = 0; j < 4; ++j) {
                            scores[i + j] = VectorUtil.normalizeDistanceToUnitInterval(this.distances[j]);
                        }
                    }
                    while (i < numNodes) {
                        scores[i] = this.score(nodes[i]);
                        ++i;
                    }
                }

                @Override
                public int maxOrd() {
                    return centers.length;
                }

                @Override
                public void setScoringOrdinal(int node) {
                    this.scoringOrdinal = node;
                }
            });
        }

        @Override
        public RandomVectorScorerSupplier copy() {
            return new CentersScorerSupplier(this.centers);
        }
    }

    private static class ReusableBits
    implements Bits {
        final int size;
        int currentOrd;

        ReusableBits(int size) {
            this.size = size;
        }

        @Override
        public boolean get(int index) {
            return index != this.currentOrd;
        }

        @Override
        public int length() {
            return this.size;
        }
    }

    private static class ReusableKnnCollector
    implements KnnCollector {
        private final NeighborQueue queue;
        private final int k;
        int visitedCount;
        int currenOrd;

        ReusableKnnCollector(int k) {
            this.k = k;
            this.queue = new NeighborQueue(k, false);
        }

        void reset(int ord) {
            this.queue.clear();
            this.visitedCount = 0;
            this.currenOrd = ord;
        }

        @Override
        public boolean earlyTerminated() {
            return false;
        }

        @Override
        public void incVisitedCount(int count) {
            this.visitedCount += count;
        }

        @Override
        public long visitedCount() {
            return this.visitedCount;
        }

        @Override
        public long visitLimit() {
            return Integer.MAX_VALUE;
        }

        @Override
        public int k() {
            return this.k;
        }

        @Override
        public boolean collect(int docId, float similarity) {
            if (this.currenOrd != docId) {
                return this.queue.insertWithOverflow(docId, similarity);
            }
            return false;
        }

        @Override
        public float minCompetitiveSimilarity() {
            return this.queue.size() >= this.k() ? this.queue.topScore() : Float.NEGATIVE_INFINITY;
        }

        @Override
        public TopDocs topDocs() {
            throw new UnsupportedOperationException();
        }

        @Override
        public KnnSearchStrategy getSearchStrategy() {
            return null;
        }
    }
}

