/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.index.codec.vectors.es818;

import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.util.VectorUtil;

/**
 * Quantizes float vectors into small unsigned integer codes (1 to 8 bits per component) relative to a
 * centroid.
 *
 * <p>The input vector is first centered in place ({@code vector[i] -= centroid[i]}). For a given bit width,
 * an initial interval {@code [a, b]} is built from the matching {@code MINIMUM_MSE_GRID} row: it is scaled by
 * the standard deviation of the centered components, shifted by their mean, and clamped to their observed
 * {@code [min, max]}. The interval is then refined to lower the loss
 * {@code (1 - lambda) * xe^2 / |x|^2 + lambda * e}. Here {@code e} is the summed squared quantization error,
 * and {@code xe} is that error projected onto the centered vector. Finally every component is mapped to the
 * nearest of {@code 2^bits} evenly spaced points in {@code [a, b]}.
 */
class OptimizedScalarQuantizer {
    /**
     * Initial interval bounds, in units of standard deviation, indexed by {@code bits - 1}. Each row is
     * {@code {lower, upper}}.
     *
     * <p>NOTE(review): these appear to be the outermost reconstruction levels of the MSE-optimal uniform
     * quantizer for a standard normal distribution (for example, 0.798 is about sqrt(2/pi) for 1 bit). Confirm
     * this against the original reference.
     */
    static final float[][] MINIMUM_MSE_GRID = new float[][] {
        { -0.798f, 0.798f },
        { -1.493f, 1.493f },
        { -2.051f, 2.051f },
        { -2.514f, 2.514f },
        { -2.916f, 2.916f },
        { -3.278f, 3.278f },
        { -3.611f, 3.611f },
        { -3.922f, 3.922f } };
    private static final float DEFAULT_LAMBDA = 0.1f;
    private static final int DEFAULT_ITERS = 5;
    private final VectorSimilarityFunction similarityFunction;
    private final float lambda;
    private final int iters;

    /**
     * @param similarityFunction similarity of the vectors being quantized. It selects which correction term is
     *     reported in each {@link QuantizationResult}. For COSINE, it also enables unit-length assertions on
     *     inputs.
     * @param lambda weight of the squared quantization error in the loss. The remaining {@code 1 - lambda}
     *     weights the squared error projected onto the vector, normalized by the vector's squared norm.
     * @param iters maximum number of interval-refinement iterations
     */
    OptimizedScalarQuantizer(VectorSimilarityFunction similarityFunction, float lambda, int iters) {
        this.similarityFunction = similarityFunction;
        this.lambda = lambda;
        this.iters = iters;
    }

    /**
     * Creates a quantizer with {@code DEFAULT_LAMBDA} and {@code DEFAULT_ITERS}.
     *
     * @param similarityFunction similarity of the vectors being quantized
     */
    OptimizedScalarQuantizer(VectorSimilarityFunction similarityFunction) {
        this(similarityFunction, DEFAULT_LAMBDA, DEFAULT_ITERS);
    }

    /**
     * Quantizes one vector at several bit widths, sharing a single centering pass.
     *
     * <p>Side effect: {@code vector} is overwritten with {@code vector - centroid}.
     *
     * @param vector the vector to quantize; it is centered in place
     * @param destinations {@code destinations[i]} receives the codes for {@code bits[i]}. Each array must be
     *     at least {@code vector.length} long.
     * @param bits bit widths, each in {@code [1, 8]}
     * @param centroid centroid to center on; must be the same length as {@code vector}
     * @return one result per entry of {@code bits}, in the same order
     */
    public QuantizationResult[] multiScalarQuantize(float[] vector, byte[][] destinations, byte[] bits, float[] centroid) {
        assert similarityFunction != VectorSimilarityFunction.COSINE || VectorUtil.isUnitVector(vector);
        assert similarityFunction != VectorSimilarityFunction.COSINE || VectorUtil.isUnitVector(centroid);
        assert bits.length == destinations.length;
        CenteredStats stats = centerAndSummarize(vector, centroid);
        // Both slots are overwritten before use on every call, so one scratch array serves all bit widths.
        float[] intervalScratch = new float[2];
        QuantizationResult[] results = new QuantizationResult[bits.length];
        for (int i = 0; i < bits.length; i++) {
            results[i] = quantizeCentered(vector, destinations[i], bits[i], stats, intervalScratch);
        }
        return results;
    }

    /**
     * Quantizes one vector at a single bit width.
     *
     * <p>Side effect: {@code vector} is overwritten with {@code vector - centroid}.
     *
     * @param vector the vector to quantize; it is centered in place
     * @param destination receives one code per component; must be at least {@code vector.length} long
     * @param bits bit width, in {@code [1, 8]}
     * @param centroid centroid to center on; must be the same length as {@code vector}
     * @return the final interval, correction term and code sum
     */
    public QuantizationResult scalarQuantize(float[] vector, byte[] destination, byte bits, float[] centroid) {
        assert similarityFunction != VectorSimilarityFunction.COSINE || VectorUtil.isUnitVector(vector);
        assert similarityFunction != VectorSimilarityFunction.COSINE || VectorUtil.isUnitVector(centroid);
        assert vector.length <= destination.length;
        assert bits > 0 && bits <= 8;
        CenteredStats stats = centerAndSummarize(vector, centroid);
        return quantizeCentered(vector, destination, bits, stats, new float[2]);
    }

    /**
     * Centers {@code vector} in place on {@code centroid} and gathers the statistics that quantization needs.
     * The mean and variance use Welford's single-pass update.
     */
    private CenteredStats centerAndSummarize(float[] vector, float[] centroid) {
        boolean needCentroidDot = similarityFunction != VectorSimilarityFunction.EUCLIDEAN;
        double mean = 0.0;
        double sumSquaredDeviation = 0.0;
        float norm2 = 0.0f;
        float centroidDot = 0.0f;
        float min = Float.MAX_VALUE;
        float max = -Float.MAX_VALUE;
        for (int i = 0; i < vector.length; i++) {
            if (needCentroidDot) {
                // Uses the original, uncentered component.
                centroidDot += vector[i] * centroid[i];
            }
            float centered = vector[i] - centroid[i];
            vector[i] = centered;
            min = Math.min(min, centered);
            max = Math.max(max, centered);
            norm2 += centered * centered;
            double delta = centered - mean;
            mean += delta / (i + 1);
            sumSquaredDeviation += delta * (centered - mean);
        }
        // Population (not sample) standard deviation.
        double std = Math.sqrt(sumSquaredDeviation / vector.length);
        float correction = needCentroidDot ? centroidDot : norm2;
        return new CenteredStats(mean, std, min, max, norm2, correction);
    }

    /**
     * Chooses and refines the interval for {@code bits}, then writes one code per component of the already
     * centered {@code vector} into {@code destination}.
     *
     * @param interval two-element scratch array. On return it holds the final {@code {a, b}}.
     */
    private QuantizationResult quantizeCentered(float[] vector, byte[] destination, int bits, CenteredStats stats, float[] interval) {
        assert bits > 0 && bits <= 8;
        assert vector.length <= destination.length;
        int points = 1 << bits;
        float[] grid = MINIMUM_MSE_GRID[bits - 1];
        interval[0] = (float) clamp(grid[0] * stats.std() + stats.mean(), stats.min(), stats.max());
        interval[1] = (float) clamp(grid[1] * stats.std() + stats.mean(), stats.min(), stats.max());
        optimizeIntervals(interval, vector, stats.norm2(), points);
        float a = interval[0];
        float b = interval[1];
        float step = (b - a) / (points - 1);
        int codeSum = 0;
        for (int h = 0; h < vector.length; h++) {
            float xi = (float) clamp(vector[h], a, b);
            int code = Math.round((xi - a) / step);
            codeSum += code;
            // Codes above 127 (possible only for 8 bits) wrap to negative bytes.
            // Presumably readers treat them as unsigned; confirm at the call sites.
            destination[h] = (byte) code;
        }
        return new QuantizationResult(a, b, stats.additionalCorrection(), codeSum);
    }

    /**
     * Evaluates the quantization loss of {@code vector} on {@code points} evenly spaced levels in
     * {@code [lower, upper]}: {@code (1 - lambda) * xe^2 / norm2 + lambda * e}.
     */
    private double loss(float[] vector, float lower, float upper, int points, float norm2) {
        double a = lower;
        double b = upper;
        double step = (b - a) / (points - 1.0f);
        double stepInv = 1.0 / step;
        double xe = 0.0;
        double e = 0.0;
        for (float component : vector) {
            double xi = component;
            double xiq = a + step * Math.round((clamp(xi, a, b) - a) * stepInv);
            double err = xi - xiq;
            xe += xi * err;
            e += err * err;
        }
        return (1.0 - lambda) * xe * xe / norm2 + lambda * e;
    }

    /**
     * Refines {@code initInterval} in place for up to {@code iters} iterations.
     *
     * <p>Each iteration freezes the current code assignments. It then solves, by Cramer's rule, the 2x2
     * normal equations {@code [[m0, m1], [m1, m2]] * [a, b] = [dax, dbx]}, which minimize the loss for those
     * assignments. It stops early when the system is singular, the interval stops moving, or the loss would
     * not improve.
     */
    private void optimizeIntervals(float[] initInterval, float[] vector, float norm2, int points) {
        float scale = (1.0f - lambda) / norm2;
        if (!Float.isFinite(scale)) {
            // For example, norm2 == 0: the projected-error term is undefined, so keep the initial interval.
            return;
        }
        double initialLoss = loss(vector, initInterval[0], initInterval[1], points, norm2);
        for (int iter = 0; iter < iters; iter++) {
            float a = initInterval[0];
            float b = initInterval[1];
            float stepInv = (points - 1.0f) / (b - a);
            double daa = 0.0;
            double dab = 0.0;
            double dbb = 0.0;
            double dax = 0.0;
            double dbx = 0.0;
            for (float xi : vector) {
                float k = Math.round((clamp(xi, a, b) - a) * stepInv);
                // s is the position of xi's level within [a, b], in [0, 1].
                float s = k / (points - 1);
                daa += (1.0 - s) * (1.0 - s);
                dab += (1.0 - s) * s;
                dbb += s * s;
                dax += xi * (1.0 - s);
                dbx += xi * s;
            }
            double m0 = scale * dax * dax + lambda * daa;
            double m2 = scale * dbx * dbx + lambda * dbb;
            double m1 = scale * dax * dbx + lambda * dab;
            double det = m0 * m2 - m1 * m1;
            if (det == 0.0) {
                return;
            }
            float aOpt = (float) ((m2 * dax - m1 * dbx) / det);
            float bOpt = (float) ((m0 * dbx - m1 * dax) / det);
            if (Math.abs(a - aOpt) < 1e-8 && Math.abs(b - bOpt) < 1e-8) {
                return;
            }
            double newLoss = loss(vector, aOpt, bOpt, points, norm2);
            // This refinement is not guaranteed to converge, so stop as soon as it gets worse. A NaN loss
            // (e.g. non-finite bounds from a near-singular system) is rejected explicitly, because
            // `NaN > initialLoss` is false and would otherwise accept the bad interval.
            if (Double.isNaN(newLoss) || newLoss > initialLoss) {
                return;
            }
            initInterval[0] = aOpt;
            initInterval[1] = bOpt;
            initialLoss = newLoss;
        }
    }

    /** Returns {@code x} limited to {@code [a, b]}. */
    private static double clamp(double x, double a, double b) {
        return Math.min(Math.max(x, a), b);
    }

    /** Statistics of a vector after it has been centered on a centroid. */
    private record CenteredStats(double mean, double std, float min, float max, float norm2, float additionalCorrection) {
    }

    /**
     * Result of quantizing one vector at one bit width.
     *
     * @param lowerInterval lower bound {@code a} of the final quantization interval
     * @param upperInterval upper bound {@code b} of the final quantization interval
     * @param additionalCorrection for EUCLIDEAN, the squared L2 norm of the centered vector. Otherwise, the dot
     *     product of the original (uncentered) vector with the centroid.
     * @param quantizedComponentSum sum of the integer codes, taken before they are narrowed to bytes
     */
    public record QuantizationResult(float lowerInterval, float upperInterval, float additionalCorrection, int quantizedComponentSum) {
    }
}

