/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.index.codec.vectors.cluster;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Random;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.util.VectorUtil;
import org.elasticsearch.index.codec.vectors.cluster.KMeansIntermediate;
import org.elasticsearch.index.codec.vectors.cluster.NeighborQueue;
import org.elasticsearch.simdvec.ESVectorUtil;

/**
 * Lloyd-style k-means refinement of an already-chosen set of centroids.
 *
 * <p>The centroids held by a {@link KMeansIntermediate} are updated in place. In the optional
 * neighbor-aware mode each vector is only compared against its current centroid and that
 * centroid's nearest neighboring centroids, and a secondary ("spilled") assignment is computed
 * for every vector via {@link ESVectorUtil#soarDistance}.
 */
class KMeansLocal {
    /** Number of leading vectors visited by each sampled Lloyd iteration. */
    final int sampleSize;
    /** Upper bound on the number of sampled Lloyd iterations. */
    final int maxIterations;
    /** Capacity of each centroid's neighbor queue; values {@code <= 0} disable neighborhoods. */
    final int clustersPerNeighborhood;
    /** Lambda forwarded unchanged to {@link ESVectorUtil#soarDistance}. */
    final float soarLambda;

    /**
     * @param sampleSize              number of leading vectors used per sampled iteration
     * @param maxIterations           maximum number of sampled iterations
     * @param clustersPerNeighborhood neighbor-queue capacity per centroid ({@code <= 0} disables it)
     * @param soarLambda              lambda passed to the spilled-assignment distance
     */
    KMeansLocal(int sampleSize, int maxIterations, int clustersPerNeighborhood, float soarLambda) {
        this.sampleSize = sampleSize;
        this.maxIterations = maxIterations;
        this.clustersPerNeighborhood = clustersPerNeighborhood;
        this.soarLambda = soarLambda;
    }

    /**
     * Creates an instance with neighborhoods disabled ({@code clustersPerNeighborhood = -1},
     * {@code soarLambda = -1}).
     */
    KMeansLocal(int sampleSize, int maxIterations) {
        this(sampleSize, maxIterations, -1, -1.0f);
    }

    /**
     * Picks initial centroids with a fixed-seed, reservoir-style pass over {@code vectors}.
     *
     * <p>The first {@code centroidCount} vectors fill the reservoir; each later vector {@code i}
     * replaces a uniformly chosen slot with probability {@code centroidCount / i}. The seed is
     * constant, so the result is deterministic for a given input.
     *
     * @param vectors       source vectors
     * @param centroidCount requested number of centroids
     * @return {@code min(vectors.size(), centroidCount)} freshly allocated centroid arrays
     * @throws IOException if reading a vector fails
     */
    static float[][] pickInitialCentroids(FloatVectorValues vectors, int centroidCount) throws IOException {
        Random random = new Random(42L);
        int vectorCount = vectors.size();
        int centroidsSize = Math.min(vectorCount, centroidCount);
        float[][] centroids = new float[centroidsSize][vectors.dimension()];
        for (int i = 0; i < vectorCount; ++i) {
            if (i < centroidCount) {
                // Fill phase: copy, because the array returned by vectorValue may be reused by the reader.
                float[] vector = vectors.vectorValue(i);
                System.arraycopy(vector, 0, centroids[i], 0, vector.length);
            } else if (random.nextDouble() < centroidCount * (1.0 / i)) {
                // Replacement phase: overwrite a random reservoir slot.
                int slot = random.nextInt(centroidCount);
                float[] vector = vectors.vectorValue(i);
                System.arraycopy(vector, 0, centroids[slot], 0, vector.length);
            }
        }
        return centroids;
    }

    /**
     * Runs one Lloyd iteration over the first {@code sampleSize} vectors.
     *
     * <p>Each visited vector is (re)assigned via {@link #getBestCentroidOffset}; afterwards every
     * centroid that received at least one vector is overwritten with the mean of its vectors.
     * Centroids that received none keep their previous value.
     *
     * @param vectors       source vectors
     * @param centroids     current centroids, updated in place
     * @param nextCentroids scratch accumulator with the same shape as {@code centroids}; zeroed on entry
     * @param assignments   per-vector centroid index, read and updated in place
     * @param sampleSize    number of leading vectors to visit; must not exceed {@code assignments.length}
     * @param neighborhoods per-centroid neighbor lists, or {@code null} to compare against all centroids
     * @return {@code true} if at least one visited vector changed its assignment
     */
    private boolean stepLloyd(FloatVectorValues vectors, float[][] centroids, float[][] nextCentroids, int[] assignments, int sampleSize, List<int[]> neighborhoods) throws IOException {
        boolean changed = false;
        int dim = vectors.dimension();
        int[] centroidCounts = new int[centroids.length];
        for (float[] accumulator : nextCentroids) {
            Arrays.fill(accumulator, 0.0f);
        }

        for (int i = 0; i < sampleSize; ++i) {
            float[] vector = vectors.vectorValue(i);
            int[] neighborOffsets = null;
            int currentCentroid = -1;
            if (neighborhoods != null) {
                // Neighbor-aware mode: start from the current assignment and only look at its neighbors.
                currentCentroid = assignments[i];
                neighborOffsets = neighborhoods.get(currentCentroid);
            }
            int best = this.getBestCentroidOffset(centroids, vector, currentCentroid, neighborOffsets);
            if (assignments[i] != best) {
                changed = true;
            }
            assignments[i] = best;
            centroidCounts[best]++;
            float[] accumulator = nextCentroids[best];
            for (int d = 0; d < dim; ++d) {
                accumulator[d] += vector[d];
            }
        }

        for (int clusterIdx = 0; clusterIdx < centroids.length; ++clusterIdx) {
            int count = centroidCounts[clusterIdx];
            if (count <= 0) {
                continue;
            }
            float countF = count;
            float[] sum = nextCentroids[clusterIdx];
            float[] centroid = centroids[clusterIdx];
            for (int d = 0; d < dim; ++d) {
                centroid[d] = sum[d] / countF;
            }
        }
        return changed;
    }

    /**
     * Finds the centroid closest (by squared distance) to {@code vector}.
     *
     * <p>When {@code centroidIdx} is a valid index it is the starting candidate, and another
     * centroid only wins if it is strictly closer. When {@code centroidOffsets} is non-null only
     * those centroids are examined; otherwise every centroid is.
     *
     * @param centroids       all centroids
     * @param vector          the vector to assign
     * @param centroidIdx     index of the vector's current centroid, or an out-of-range value (e.g. -1) for none
     * @param centroidOffsets indices of candidate centroids to examine, or {@code null} to examine all
     * @return index of the winning centroid; {@code centroidIdx} is returned unchanged if no
     *         candidate is strictly closer than the starting distance
     */
    int getBestCentroidOffset(float[][] centroids, float[] vector, int centroidIdx, int[] centroidOffsets) {
        int bestCentroidOffset = centroidIdx;
        float minDsq = Float.MAX_VALUE;
        // Index 0 is a valid current centroid, so the lower bound is inclusive.
        if (centroidIdx >= 0 && centroidIdx < centroids.length) {
            minDsq = VectorUtil.squareDistance(vector, centroids[centroidIdx]);
        }

        if (centroidOffsets == null) {
            for (int j = 0; j < centroids.length; ++j) {
                float dsq = VectorUtil.squareDistance(vector, centroids[j]);
                if (dsq < minDsq) {
                    minDsq = dsq;
                    bestCentroidOffset = j;
                }
            }
        } else {
            // Visit every listed neighbor; the list is not assumed to be sorted.
            for (int offset : centroidOffsets) {
                float dsq = VectorUtil.squareDistance(vector, centroids[offset]);
                if (dsq < minDsq) {
                    minDsq = dsq;
                    bestCentroidOffset = offset;
                }
            }
        }
        return bestCentroidOffset;
    }

    /**
     * Fills {@code neighborhoods} with, for each center, the indices of other centers collected
     * in a {@link NeighborQueue} of capacity {@code clustersPerNeighborhood} keyed by squared
     * distance. Does nothing when the list is empty or {@code clustersPerNeighborhood <= 0}.
     *
     * @param centers                 centroid coordinates
     * @param neighborhoods           pre-sized list (one slot per center) whose entries are replaced
     * @param clustersPerNeighborhood capacity of each center's neighbor queue
     */
    private void computeNeighborhoods(float[][] centers, List<int[]> neighborhoods, int clustersPerNeighborhood) {
        int k = neighborhoods.size();
        if (k == 0 || clustersPerNeighborhood <= 0) {
            return;
        }
        List<NeighborQueue> neighborQueues = new ArrayList<>(k);
        for (int i = 0; i < k; ++i) {
            // NOTE(review): second argument presumably selects a max-heap so the farthest entry is
            // evicted on overflow - confirm against NeighborQueue.
            neighborQueues.add(new NeighborQueue(clustersPerNeighborhood, true));
        }
        // Each unordered pair is measured once and offered to both endpoints' queues.
        for (int i = 0; i < k - 1; ++i) {
            for (int j = i + 1; j < k; ++j) {
                float dsq = VectorUtil.squareDistance(centers[i], centers[j]);
                neighborQueues.get(j).insertWithOverflow(i, dsq);
                neighborQueues.get(i).insertWithOverflow(j, dsq);
            }
        }
        for (int i = 0; i < k; ++i) {
            NeighborQueue queue = neighborQueues.get(i);
            int[] neighbors = new int[queue.size()];
            queue.consumeNodes(neighbors);
            neighborhoods.set(i, neighbors);
        }
    }

    /**
     * Computes a secondary ("spilled") centroid for every vector.
     *
     * <p>For each vector, every neighbor of its primary centroid (excluding the primary itself)
     * is scored with {@link ESVectorUtil#soarDistance}; the neighbor with the lowest score wins.
     *
     * @param vectors       source vectors
     * @param neighborhoods per-centroid neighbor lists; the entry for every assigned centroid must be non-null
     * @param centroids     centroid coordinates
     * @param assignments   primary centroid index per vector
     * @return spilled centroid index per vector, or {@code -1} where no neighbor qualified
     */
    private int[] assignSpilled(FloatVectorValues vectors, List<int[]> neighborhoods, float[][] centroids, int[] assignments) throws IOException {
        int[] spilledAssignments = new int[assignments.length];
        int dim = vectors.dimension();
        int vectorCount = vectors.size();
        // Reused buffer holding (vector - primary centroid) for the current vector.
        float[] diffs = new float[dim];
        for (int i = 0; i < vectorCount; ++i) {
            float[] vector = vectors.vectorValue(i);
            int currAssignment = assignments[i];
            float[] currentCentroid = centroids[currAssignment];
            for (int j = 0; j < dim; ++j) {
                diffs[j] = vector[j] - currentCentroid[j];
            }
            float vectorCentroidDist = VectorUtil.squareDistance(vector, currentCentroid);

            int bestAssignment = -1;
            float minSoar = Float.MAX_VALUE;
            int[] neighbors = neighborhoods.get(currAssignment);
            assert (neighbors != null);
            for (int neighbor : neighbors) {
                if (neighbor == currAssignment) {
                    continue;
                }
                float soar = ESVectorUtil.soarDistance(vector, centroids[neighbor], diffs, this.soarLambda, vectorCentroidDist);
                if (soar < minSoar) {
                    bestAssignment = neighbor;
                    minSoar = soar;
                }
            }
            spilledAssignments[i] = bestAssignment;
        }
        return spilledAssignments;
    }

    /**
     * Refines the centroids of {@code kMeansIntermediate} without neighbor awareness.
     *
     * @throws IOException if reading a vector fails
     */
    void cluster(FloatVectorValues vectors, KMeansIntermediate kMeansIntermediate) throws IOException {
        this.cluster(vectors, kMeansIntermediate, false);
    }

    /**
     * Refines the centroids of {@code kMeansIntermediate}; when {@code neighborAware} is set,
     * candidate centroids are restricted to neighborhoods and spilled assignments are stored via
     * {@link KMeansIntermediate#setSoarAssignments}.
     *
     * <p>Spilled assignments are derived from {@code kMeansIntermediate.assignments()}, which must
     * be non-null and have one entry per vector.
     *
     * @param vectors            source vectors
     * @param kMeansIntermediate holder of the centroids (updated in place) and assignments
     * @param neighborAware      whether to use centroid neighborhoods
     * @throws IOException if reading a vector fails
     */
    void cluster(FloatVectorValues vectors, KMeansIntermediate kMeansIntermediate, boolean neighborAware) throws IOException {
        float[][] centroids = kMeansIntermediate.centroids();
        List<int[]> neighborhoods = null;
        if (neighborAware) {
            int k = centroids.length;
            neighborhoods = new ArrayList<>(k);
            // Pre-size with null slots; computeNeighborhoods replaces them when enabled.
            for (int i = 0; i < k; ++i) {
                neighborhoods.add(null);
            }
            this.computeNeighborhoods(centroids, neighborhoods, this.clustersPerNeighborhood);
        }
        this.cluster(vectors, kMeansIntermediate, neighborhoods);
        if (neighborAware && this.clustersPerNeighborhood > 0) {
            int[] assignments = kMeansIntermediate.assignments();
            assert (assignments != null);
            assert (assignments.length == vectors.size());
            kMeansIntermediate.setSoarAssignments(this.assignSpilled(vectors, neighborhoods, centroids, assignments));
        }
    }

    /**
     * Runs up to {@code maxIterations} sampled Lloyd iterations (stopping early once no sampled
     * vector changes assignment), followed by one pass over all vectors.
     *
     * <p>Returns immediately when there is a single centroid or at least as many centroids as
     * vectors. The assignment array used during iteration is local to this call and is not handed
     * to {@code kMeansIntermediate} here.
     * NOTE(review): it starts zero-filled, so in neighbor-aware mode every vector initially uses
     * centroid 0's neighborhood - confirm this is intended.
     *
     * @param vectors            source vectors
     * @param kMeansIntermediate holder of the centroids, updated in place
     * @param neighborhoods      per-centroid neighbor lists, or {@code null} to compare against all centroids
     * @throws IOException if reading a vector fails
     */
    void cluster(FloatVectorValues vectors, KMeansIntermediate kMeansIntermediate, List<int[]> neighborhoods) throws IOException {
        float[][] centroids = kMeansIntermediate.centroids();
        int k = centroids.length;
        int n = vectors.size();
        if (k == 1 || k >= n) {
            return;
        }
        int[] assignments = new int[n];
        float[][] nextCentroids = new float[k][vectors.dimension()];
        // Never sample past the end of the vectors / assignments array.
        int effectiveSampleSize = Math.min(this.sampleSize, n);
        for (int iteration = 0; iteration < this.maxIterations; ++iteration) {
            boolean changed = this.stepLloyd(vectors, centroids, nextCentroids, assignments, effectiveSampleSize, neighborhoods);
            if (!changed) {
                break;
            }
        }
        // Final pass over every vector so all of them contribute to the centroids.
        this.stepLloyd(vectors, centroids, nextCentroids, assignments, n, neighborhoods);
    }

    /**
     * Convenience entry point: refines {@code centroids} in place without neighbor awareness.
     *
     * @param vectors       source vectors
     * @param centroids     centroids to refine, updated in place
     * @param sampleSize    number of leading vectors used per sampled iteration
     * @param maxIterations maximum number of sampled iterations
     * @throws IOException if reading a vector fails
     */
    public static void cluster(FloatVectorValues vectors, float[][] centroids, int sampleSize, int maxIterations) throws IOException {
        KMeansIntermediate kMeansIntermediate = new KMeansIntermediate(centroids);
        KMeansLocal kMeans = new KMeansLocal(sampleSize, maxIterations);
        kMeans.cluster(vectors, kMeansIntermediate);
    }
}

