/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.index.codec.vectors.cluster;

import java.io.IOException;
import java.util.Arrays;
import java.util.Objects;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.search.TaskExecutor;
import org.elasticsearch.index.codec.vectors.cluster.FloatVectorValuesSlice;
import org.elasticsearch.index.codec.vectors.cluster.KMeansIntermediate;
import org.elasticsearch.index.codec.vectors.cluster.KMeansLocal;
import org.elasticsearch.index.codec.vectors.cluster.KMeansLocalConcurrent;
import org.elasticsearch.index.codec.vectors.cluster.KMeansLocalSerial;
import org.elasticsearch.index.codec.vectors.cluster.KMeansResult;

/**
 * Partitions a set of float vectors into clusters of roughly {@code targetSize} vectors by running
 * a local k-means pass and then recursively re-clustering every partition that is still too large.
 *
 * <p>Instances are created through the {@code ofSerial} / {@code ofConcurrent} factories. A serial
 * instance has no executor; a concurrent instance hands the executor and a worker count to
 * {@link KMeansLocalConcurrent} whenever the input is large enough (see {@link #buildKmeansLocal}).
 */
public class HierarchicalKMeans {
    /** Upper bound on the number of centroids requested from a single (non-recursive) k-means pass. */
    public static final int MAXK = 128;
    /** Default iteration limit handed to the {@link KMeansLocal} implementations. */
    public static final int MAX_ITERATIONS_DEFAULT = 6;
    /** Default number of sampled vectors per centroid used to size the k-means sample. */
    public static final int SAMPLES_PER_CLUSTER_DEFAULT = 64;
    /** Default lambda forwarded to {@link KMeansLocal} in the final neighborhood-aware pass. */
    public static final float DEFAULT_SOAR_LAMBDA = 1.0f;
    /**
     * Sentinel value; not referenced inside this class.
     * NOTE(review): presumably marks "no secondary (SOAR) assignment" in the result - confirm against its users.
     */
    public static final int NO_SOAR_ASSIGNMENT = -1;
    /** Minimum number of vectors each worker must receive before concurrency is used. */
    private static final int MIN_VECTORS_PRE_THREAD = 64;

    final int dimension;
    final int maxIterations;
    final int samplesPerCluster;
    final int clustersPerNeighborhood;
    final float soarLambda;
    /** {@code null} for serial instances. */
    private final TaskExecutor executor;
    private final int numWorkers;

    private HierarchicalKMeans(
        int dimension,
        TaskExecutor executor,
        int numWorkers,
        int maxIterations,
        int samplesPerCluster,
        int clustersPerNeighborhood,
        float soarLambda
    ) {
        this.dimension = dimension;
        this.executor = executor;
        this.numWorkers = numWorkers;
        this.maxIterations = maxIterations;
        this.samplesPerCluster = samplesPerCluster;
        this.clustersPerNeighborhood = clustersPerNeighborhood;
        this.soarLambda = soarLambda;
    }

    /**
     * Creates a single-threaded instance using the default tuning constants of this class.
     *
     * @param dimension number of components in every vector
     */
    public static HierarchicalKMeans ofSerial(int dimension) {
        return ofSerial(dimension, MAX_ITERATIONS_DEFAULT, SAMPLES_PER_CLUSTER_DEFAULT, MAXK, DEFAULT_SOAR_LAMBDA);
    }

    /**
     * Creates a single-threaded instance (no executor, one worker).
     *
     * @param dimension               number of components in every vector
     * @param maxIterations           iteration limit passed to the local k-means implementation
     * @param samplesPerCluster       multiplier used to size the k-means sample from the centroid count
     * @param clustersPerNeighborhood forwarded to the final {@link KMeansLocal} pass in {@link #cluster}
     * @param soarLambda              forwarded to the final {@link KMeansLocal} pass in {@link #cluster}
     */
    public static HierarchicalKMeans ofSerial(
        int dimension,
        int maxIterations,
        int samplesPerCluster,
        int clustersPerNeighborhood,
        float soarLambda
    ) {
        return new HierarchicalKMeans(dimension, null, 1, maxIterations, samplesPerCluster, clustersPerNeighborhood, soarLambda);
    }

    /**
     * Creates a concurrent instance using the default tuning constants of this class.
     *
     * @param dimension  number of components in every vector
     * @param executor   executor handed to {@link KMeansLocalConcurrent}
     * @param numWorkers maximum number of workers to use
     */
    public static HierarchicalKMeans ofConcurrent(int dimension, TaskExecutor executor, int numWorkers) {
        return ofConcurrent(
            dimension,
            executor,
            numWorkers,
            MAX_ITERATIONS_DEFAULT,
            SAMPLES_PER_CLUSTER_DEFAULT,
            MAXK,
            DEFAULT_SOAR_LAMBDA
        );
    }

    /**
     * Creates a concurrent instance.
     *
     * @param dimension               number of components in every vector
     * @param executor                executor handed to {@link KMeansLocalConcurrent}
     * @param numWorkers              maximum number of workers to use
     * @param maxIterations           iteration limit passed to the local k-means implementation
     * @param samplesPerCluster       multiplier used to size the k-means sample from the centroid count
     * @param clustersPerNeighborhood forwarded to the final {@link KMeansLocal} pass in {@link #cluster}
     * @param soarLambda              forwarded to the final {@link KMeansLocal} pass in {@link #cluster}
     */
    public static HierarchicalKMeans ofConcurrent(
        int dimension,
        TaskExecutor executor,
        int numWorkers,
        int maxIterations,
        int samplesPerCluster,
        int clustersPerNeighborhood,
        float soarLambda
    ) {
        return new HierarchicalKMeans(dimension, executor, numWorkers, maxIterations, samplesPerCluster, clustersPerNeighborhood, soarLambda);
    }

    /**
     * Clusters {@code vectors} into partitions of roughly {@code targetSize} vectors.
     *
     * <ul>
     *   <li>No vectors: returns an empty {@link KMeansIntermediate}.</li>
     *   <li>At most {@code targetSize} vectors: returns a single centroid (the component-wise mean)
     *       with every vector assigned to ordinal 0.</li>
     *   <li>Otherwise: partitions recursively via {@link #clusterAndSplit} and, when that produced
     *       more than one but fewer than {@code vectors.size()} centroids, runs one more
     *       {@link KMeansLocal} pass over the flattened result using
     *       {@code clustersPerNeighborhood} and {@code soarLambda}.</li>
     * </ul>
     *
     * @param vectors    the vectors to cluster
     * @param targetSize desired number of vectors per cluster
     * @return the centroids and per-vector assignments
     * @throws IOException if reading a vector fails
     */
    public KMeansResult cluster(FloatVectorValues vectors, int targetSize) throws IOException {
        final int numVectors = vectors.size();
        if (numVectors == 0) {
            return new KMeansIntermediate();
        }

        // few enough vectors for a single cluster: its centroid is simply their mean
        if (numVectors <= targetSize) {
            // a freshly allocated int[] is all zeros, i.e. every vector belongs to centroid 0
            return new KMeansIntermediate(new float[][] { meanVector(vectors) }, new int[numVectors]);
        }

        // partition the space
        final KMeansIntermediate kMeansIntermediate = clusterAndSplit(vectors, targetSize);
        final int numCentroids = kMeansIntermediate.centroids().length;
        if (numCentroids > 1 && numCentroids < numVectors) {
            // widen to long so a large centroid count times samplesPerCluster cannot wrap negative
            final int localSampleSize = (int) Math.min((long) numCentroids * samplesPerCluster / 2, numVectors);
            final KMeansLocal kMeansLocal = buildKmeansLocal(numVectors, localSampleSize);
            kMeansLocal.cluster(vectors, kMeansIntermediate, clustersPerNeighborhood, soarLambda);
        }
        return kMeansIntermediate;
    }

    /** Returns the component-wise mean of all {@code vectors} over the first {@code dimension} components. */
    private float[] meanVector(FloatVectorValues vectors) throws IOException {
        final int numVectors = vectors.size();
        final float[] mean = new float[dimension];
        for (int ord = 0; ord < numVectors; ord++) {
            final float[] vector = vectors.vectorValue(ord);
            for (int d = 0; d < dimension; d++) {
                mean[d] += vector[d];
            }
        }
        for (int d = 0; d < dimension; d++) {
            mean[d] /= numVectors;
        }
        return mean;
    }

    /**
     * Runs one k-means pass over {@code vectors} and recursively splits every resulting cluster that
     * holds more than 1.34 times {@code targetSize} vectors; clusters that received no vectors are
     * dropped. The result is flattened into a single set of centroids and assignments.
     *
     * @param vectors    the vectors to partition
     * @param targetSize desired number of vectors per cluster
     * @return an empty {@link KMeansIntermediate} when {@code vectors} already fits in one cluster,
     *         otherwise the flattened partitioning
     * @throws IOException if reading a vector fails
     */
    private KMeansIntermediate clusterAndSplit(FloatVectorValues vectors, int targetSize) throws IOException {
        final int numVectors = vectors.size();
        if (numVectors <= targetSize) {
            return new KMeansIntermediate();
        }

        // number of clusters is size / targetSize rounded to nearest, bounded to [2, MAXK]
        final int k = Math.clamp((int) ((numVectors + targetSize / 2.0f) / (float) targetSize), 2, MAXK);
        // widen to long so k * samplesPerCluster cannot wrap for very large sample settings
        final int sampleSize = (int) Math.min((long) k * samplesPerCluster, numVectors);

        final int[] assignments = new int[numVectors];
        // start from "unassigned" instead of the implicit 0 so no vector looks pre-assigned to centroid 0
        Arrays.fill(assignments, -1);
        final float[][] centroids = KMeansLocal.pickInitialCentroids(vectors, k);
        final KMeansIntermediate kMeansIntermediate = new KMeansIntermediate(centroids, assignments, vectors::ordToDoc);
        final KMeansLocal kMeansLocal = buildKmeansLocal(numVectors, sampleSize);
        kMeansLocal.cluster(vectors, kMeansIntermediate);

        // Count vectors per centroid and remember how many centroids actually received any.
        // NOTE(review): relies on KMeansLocal#cluster having assigned every vector; a leftover -1 would throw here.
        final int[] centroidVectorCount = new int[centroids.length];
        int effectiveCluster = -1;
        int effectiveK = 0;
        for (int assignment : assignments) {
            centroidVectorCount[assignment]++;
            // first vector seen for this centroid: it is now effective, but only count it once
            if (centroidVectorCount[assignment] == 1) {
                effectiveK++;
                effectiveCluster = assignment;
            }
        }

        // everything collapsed into one centroid: keep just that one and point every vector at ordinal 0
        if (effectiveK == 1) {
            kMeansIntermediate.setCentroids(new float[][] { centroids[effectiveCluster] });
            Arrays.fill(kMeansIntermediate.assignments(), 0);
            return kMeansIntermediate;
        }

        int removedElements = 0;
        for (int c = 0; c < centroidVectorCount.length; c++) {
            final int count = centroidVectorCount[c];
            // ordinal of centroid c after the empty centroids removed so far have shifted it down
            final int adjustedCentroid = c - removedElements;
            // Split clusters above 1.34x the target size. The comparison is done in long arithmetic:
            // in int, 100 * count wraps for counts above ~21.4M and would silently skip the split.
            if (100L * count > 134L * targetSize) {
                final FloatVectorValues sample = createClusterSlice(count, adjustedCentroid, vectors, assignments);
                // recurse on this cluster, then merge its sub-partitions back into the flat structure
                updateAssignmentsWithRecursiveSplit(kMeansIntermediate, adjustedCentroid, clusterAndSplit(sample, targetSize));
            } else if (count == 0) {
                removeCentroid(kMeansIntermediate, adjustedCentroid);
                removedElements++;
            }
        }
        return kMeansIntermediate;
    }

    /** Drops the centroid at {@code ordinal} and shifts every assignment above it down by one. */
    private static void removeCentroid(KMeansIntermediate kMeansIntermediate, int ordinal) {
        final float[][] oldCentroids = kMeansIntermediate.centroids();
        final int newSize = oldCentroids.length - 1;
        final float[][] newCentroids = new float[newSize][];
        System.arraycopy(oldCentroids, 0, newCentroids, 0, ordinal);
        System.arraycopy(oldCentroids, ordinal + 1, newCentroids, ordinal, newSize - ordinal);
        // assignments must follow the centroids that moved down one slot
        final int[] assignments = kMeansIntermediate.assignments();
        for (int i = 0; i < assignments.length; i++) {
            if (assignments[i] > ordinal) {
                assignments[i]--;
            }
        }
        kMeansIntermediate.setCentroids(newCentroids);
    }

    /**
     * Chooses the local k-means implementation: concurrent only when an executor is present and
     * there are enough vectors to give more than one worker at least
     * {@code MIN_VECTORS_PRE_THREAD} vectors each; serial otherwise.
     *
     * @param numVectors      number of vectors that will be clustered
     * @param localSampleSize sample size handed to the implementation
     */
    private KMeansLocal buildKmeansLocal(int numVectors, int localSampleSize) {
        final int workers = Math.min(numWorkers, numVectors / MIN_VECTORS_PRE_THREAD);
        if (executor == null || workers <= 1) {
            return new KMeansLocalSerial(localSampleSize, maxIterations);
        }
        return new KMeansLocalConcurrent(executor, workers, localSampleSize, maxIterations);
    }

    /**
     * Builds a view over the vectors currently assigned to {@code cluster}.
     *
     * @param clusterSize number of entries in {@code assignments} equal to {@code cluster}; used to size the slice
     * @param cluster     centroid ordinal to select
     * @param vectors     the backing vectors
     * @param assignments centroid ordinal per vector ordinal
     * @return a {@link FloatVectorValuesSlice} over the selected vector ordinals, in ascending order
     */
    static FloatVectorValues createClusterSlice(int clusterSize, int cluster, FloatVectorValues vectors, int[] assignments) {
        final int[] slice = new int[clusterSize];
        int next = 0;
        for (int ord = 0; ord < assignments.length; ord++) {
            if (assignments[ord] == cluster) {
                slice[next++] = ord;
            }
        }
        return new FloatVectorValuesSlice(vectors, slice);
    }

    /**
     * Merges the result of recursively splitting {@code cluster} back into {@code current}.
     *
     * <p>The first sub-centroid replaces the centroid at {@code cluster}; the remaining sub-centroids
     * are appended after the existing centroids. Vectors assigned to sub-centroid 0 keep their
     * assignment ({@code cluster}); vectors assigned to sub-centroid {@code s > 0} are re-pointed to
     * the appended ordinal {@code s + originalCentroidCount - 1}. Does nothing when
     * {@code subPartitions} has no centroids.
     *
     * @param current       the flattened partitioning being built; its centroids and assignments are updated
     * @param cluster       ordinal in {@code current} of the cluster that was split
     * @param subPartitions result of clustering the vectors of {@code cluster}
     */
    void updateAssignmentsWithRecursiveSplit(KMeansIntermediate current, int cluster, KMeansIntermediate subPartitions) {
        final float[][] subCentroids = subPartitions.centroids();
        if (subCentroids.length == 0) {
            return; // nothing to do, sub-partitions is empty
        }
        final float[][] currentCentroids = current.centroids();
        final int orgCentroidsSize = currentCentroids.length;
        final float[][] newCentroids = new float[orgCentroidsSize + subCentroids.length - 1][];
        System.arraycopy(currentCentroids, 0, newCentroids, 0, orgCentroidsSize);

        // sub-centroid 0 takes over the slot of the cluster that was split ...
        final int origCentroidOrd = 0;
        newCentroids[cluster] = subCentroids[origCentroidOrd];
        // ... and the rest are appended behind the pre-existing centroids
        System.arraycopy(subCentroids, 1, newCentroids, orgCentroidsSize, subCentroids.length - 1);
        assert Arrays.stream(newCentroids).allMatch(Objects::nonNull);
        current.setCentroids(newCentroids);

        final int[] subAssignments = subPartitions.assignments();
        final int[] currentAssignments = current.assignments();
        for (int i = 0; i < subAssignments.length; i++) {
            // vectors on sub-centroid 0 already carry the right ordinal (cluster)
            if (subAssignments[i] != origCentroidOrd) {
                // ordToDoc maps the sub-partition ordinal back to the ordinal in the parent's vectors
                final int parentOrd = subPartitions.ordToDoc(i);
                assert currentAssignments[parentOrd] == cluster;
                currentAssignments[parentOrd] = subAssignments[i] + orgCentroidsSize - 1;
            }
        }
    }
}

