/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.index.codec.vectors.es816;

import java.io.IOException;
import org.apache.lucene.codecs.hnsw.FlatVectorsScorer;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.util.ArrayUtil;
import org.apache.lucene.util.VectorUtil;
import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
import org.apache.lucene.util.hnsw.RandomVectorScorer;
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
import org.elasticsearch.index.codec.vectors.BQVectorUtils;
import org.elasticsearch.index.codec.vectors.es816.BinaryQuantizer;
import org.elasticsearch.index.codec.vectors.es816.RandomAccessBinarizedByteVectorValues;
import org.elasticsearch.simdvec.ESVectorUtil;

class ES816BinaryFlatVectorsScorer
implements FlatVectorsScorer {
    private final FlatVectorsScorer nonQuantizedDelegate;

    ES816BinaryFlatVectorsScorer(FlatVectorsScorer nonQuantizedDelegate) {
        this.nonQuantizedDelegate = nonQuantizedDelegate;
    }

    public RandomVectorScorerSupplier getRandomVectorScorerSupplier(VectorSimilarityFunction similarityFunction, RandomAccessVectorValues vectorValues) {
        throw new UnsupportedOperationException();
    }

    public RandomVectorScorer getRandomVectorScorer(VectorSimilarityFunction similarityFunction, RandomAccessVectorValues vectorValues, float[] target) throws IOException {
        if (vectorValues instanceof RandomAccessBinarizedByteVectorValues) {
            RandomAccessBinarizedByteVectorValues binarizedVectors = (RandomAccessBinarizedByteVectorValues)vectorValues;
            assert (binarizedVectors.getQuantizer() != null) : "BinarizedByteVectorValues must have a quantizer for ES816BinaryFlatVectorsScorer";
            assert (binarizedVectors.size() > 0) : "BinarizedByteVectorValues must have at least one vector for ES816BinaryFlatVectorsScorer";
            BinaryQuantizer quantizer = binarizedVectors.getQuantizer();
            float[] centroid = binarizedVectors.getCentroid();
            int discretizedDimensions = BQVectorUtils.discretize(target.length, 64);
            if (similarityFunction == VectorSimilarityFunction.COSINE) {
                float[] copy = ArrayUtil.copyOfSubArray((float[])target, (int)0, (int)target.length);
                VectorUtil.l2normalize((float[])copy);
                target = copy;
            }
            byte[] quantized = new byte[4 * discretizedDimensions / 8];
            BinaryQuantizer.QueryFactors factors = quantizer.quantizeForQuery(target, quantized, centroid);
            BinaryQueryVector queryVector = new BinaryQueryVector(quantized, factors);
            return new BinarizedRandomVectorScorer(queryVector, binarizedVectors, similarityFunction);
        }
        return this.nonQuantizedDelegate.getRandomVectorScorer(similarityFunction, vectorValues, target);
    }

    public RandomVectorScorer getRandomVectorScorer(VectorSimilarityFunction similarityFunction, RandomAccessVectorValues vectorValues, byte[] target) throws IOException {
        return this.nonQuantizedDelegate.getRandomVectorScorer(similarityFunction, vectorValues, target);
    }

    public String toString() {
        return "ES816BinaryFlatVectorsScorer(nonQuantizedDelegate=" + this.nonQuantizedDelegate + ")";
    }

    record BinaryQueryVector(byte[] vector, BinaryQuantizer.QueryFactors factors) {
    }

    static class BinarizedRandomVectorScorer
    extends RandomVectorScorer.AbstractRandomVectorScorer {
        private final BinaryQueryVector queryVector;
        private final RandomAccessBinarizedByteVectorValues targetVectors;
        private final VectorSimilarityFunction similarityFunction;
        private final float sqrtDimensions;
        private final float maxX1;

        BinarizedRandomVectorScorer(BinaryQueryVector queryVectors, RandomAccessBinarizedByteVectorValues targetVectors, VectorSimilarityFunction similarityFunction) {
            super((RandomAccessVectorValues)targetVectors);
            this.queryVector = queryVectors;
            this.targetVectors = targetVectors;
            this.similarityFunction = similarityFunction;
            this.sqrtDimensions = targetVectors.sqrtDimensions();
            this.maxX1 = targetVectors.maxX1();
        }

        public float score(int targetOrd) throws IOException {
            float score;
            float dist;
            byte[] quantizedQuery = this.queryVector.vector();
            int quantizedSum = this.queryVector.factors().quantizedSum();
            float lower = this.queryVector.factors().lower();
            float width = this.queryVector.factors().width();
            float distanceToCentroid = this.queryVector.factors().distToC();
            if (this.similarityFunction == VectorSimilarityFunction.EUCLIDEAN) {
                return this.euclideanScore(targetOrd, this.sqrtDimensions, quantizedQuery, distanceToCentroid, lower, quantizedSum, width);
            }
            float vmC = this.queryVector.factors().normVmC();
            float vDotC = this.queryVector.factors().vDotC();
            float cDotC = this.targetVectors.getCentroidDP();
            byte[] binaryCode = this.targetVectors.vectorValue(targetOrd);
            float ooq = this.targetVectors.getOOQ(targetOrd);
            float normOC = this.targetVectors.getNormOC(targetOrd);
            float oDotC = this.targetVectors.getODotC(targetOrd);
            float qcDist = ESVectorUtil.ipByteBinByte((byte[])quantizedQuery, (byte[])binaryCode);
            float xbSum = BQVectorUtils.popcount(binaryCode);
            if (normOC == 0.0f || ooq == 0.0f) {
                dist = oDotC + vDotC - cDotC;
            } else {
                assert (Float.isFinite(ooq));
                float estimatedDot = (2.0f * width / this.sqrtDimensions * qcDist + 2.0f * lower / this.sqrtDimensions * xbSum - width / this.sqrtDimensions * (float)quantizedSum - this.sqrtDimensions * lower) / ooq;
                dist = vmC * normOC * estimatedDot + oDotC + vDotC - cDotC;
            }
            assert (Float.isFinite(dist));
            float ooqSqr = (float)Math.pow(ooq, 2.0);
            float errorBound = (float)((double)(vmC * normOC) * ((double)this.maxX1 * Math.sqrt((1.0f - ooqSqr) / ooqSqr)));
            float f = score = Float.isFinite(errorBound) ? dist - errorBound : dist;
            if (this.similarityFunction == VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT) {
                return VectorUtil.scaleMaxInnerProductScore((float)score);
            }
            return Math.max((1.0f + score) / 2.0f, 0.0f);
        }

        private float euclideanScore(int targetOrd, float sqrtDimensions, byte[] quantizedQuery, float distanceToCentroid, float lower, int quantizedSum, float width) throws IOException {
            byte[] binaryCode = this.targetVectors.vectorValue(targetOrd);
            float targetDistToC = this.targetVectors.getCentroidDistance(targetOrd);
            float x0 = this.targetVectors.getVectorMagnitude(targetOrd);
            float sqrX = targetDistToC * targetDistToC;
            double xX0 = targetDistToC / x0;
            float xbSum = BQVectorUtils.popcount(binaryCode);
            float factorPPC = (float)(-2.0 / (double)sqrtDimensions * xX0 * ((double)xbSum * 2.0 - (double)this.targetVectors.dimension()));
            float factorIP = (float)(-2.0 / (double)sqrtDimensions * xX0);
            long qcDist = ESVectorUtil.ipByteBinByte((byte[])quantizedQuery, (byte[])binaryCode);
            float score = sqrX + distanceToCentroid + factorPPC * lower + (float)(qcDist * 2L - (long)quantizedSum) * factorIP * width;
            float projectionDist = (float)Math.sqrt(xX0 * xX0 - (double)(targetDistToC * targetDistToC));
            float error = 2.0f * this.maxX1 * projectionDist;
            float y = (float)Math.sqrt(distanceToCentroid);
            float errorBound = y * error;
            if (Float.isFinite(errorBound)) {
                score += errorBound;
            }
            return Math.max(1.0f / (1.0f + score), 0.0f);
        }
    }
}

