/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.index.codec.vectors;

import java.io.IOException;
import java.util.Map;
import java.util.function.IntPredicate;
import org.apache.lucene.codecs.hnsw.FlatVectorsReader;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.search.KnnCollector;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.util.ArrayUtil;
import org.apache.lucene.util.VectorUtil;
import org.apache.lucene.util.hnsw.NeighborQueue;
import org.elasticsearch.index.codec.vectors.BQSpaceUtils;
import org.elasticsearch.index.codec.vectors.BQVectorUtils;
import org.elasticsearch.index.codec.vectors.DocIdsWriter;
import org.elasticsearch.index.codec.vectors.IVFVectorsReader;
import org.elasticsearch.index.codec.vectors.OptimizedScalarQuantizer;
import org.elasticsearch.index.codec.vectors.reflect.OffHeapStats;
import org.elasticsearch.simdvec.ES91Int4VectorsScorer;
import org.elasticsearch.simdvec.ES91OSQVectorsScorer;
import org.elasticsearch.simdvec.ESVectorUtil;

/**
 * IVF vectors reader that ranks centroids with a 4-bit quantized query and then scores the vectors of the
 * selected posting lists with {@link ES91OSQVectorsScorer}. Posting-list vectors are stored with one bit per
 * (64-padded) dimension. Full blocks of {@code BULK_SIZE} vectors use a column-wise layout that can be
 * bulk scored; any remainder is stored vector by vector.
 */
public class DefaultIVFVectorsReader extends IVFVectorsReader implements OffHeapStats {
    /** Step width of a 4-bit quantization interval: the interval is split into 15 steps (~= 1/15). */
    private static final float FOUR_BIT_SCALE = 0.06666667f;
    /** Number of vectors per bulk-scored block in a posting list. */
    private static final int BULK_SIZE = 16;
    /** Bit width passed to the scalar quantizer when quantizing the query. */
    private static final byte QUERY_BITS = 4;
    /** Marker written into the doc-id scratch for documents the scoring predicate rejected. */
    private static final int FILTERED_DOC = -1;

    public DefaultIVFVectorsReader(SegmentReadState state, FlatVectorsReader rawVectorsReader) throws IOException {
        super(state, rawVectorsReader);
    }

    /**
     * Builds a scorer that ranks every centroid of the field against {@code targetQuery}.
     *
     * <p>The query is copied and quantized once, relative to the field's global centroid. The
     * {@code centroids} input is read as {@code numCentroids} quantized entries (quantized bytes, 3 corrective
     * floats, 1 unsigned short), followed by the raw float centroids.
     */
    @Override
    IVFVectorsReader.CentroidQueryScorer getCentroidScorer(
        final FieldInfo fieldInfo,
        final int numCentroids,
        final IndexInput centroids,
        float[] targetQuery
    ) throws IOException {
        final IVFVectorsReader.FieldEntry fieldEntry = (IVFVectorsReader.FieldEntry) this.fields.get(fieldInfo.number);
        final float globalCentroidDp = fieldEntry.globalCentroidDp();
        final VectorSimilarityFunction similarity = fieldInfo.getVectorSimilarityFunction();
        final int dims = fieldInfo.getVectorDimension();

        // Quantize a copy so the caller's query array is never modified.
        final OptimizedScalarQuantizer quantizer = new OptimizedScalarQuantizer(similarity);
        final byte[] quantizedQuery = new byte[targetQuery.length];
        final OptimizedScalarQuantizer.QuantizationResult queryCorrections = quantizer.scalarQuantize(
            ArrayUtil.copyArray(targetQuery),
            quantizedQuery,
            QUERY_BITS,
            fieldEntry.globalCentroid()
        );
        final ES91Int4VectorsScorer int4Scorer = ESVectorUtil.getES91Int4VectorsScorer(centroids, dims);

        return new IVFVectorsReader.CentroidQueryScorer() {
            // Raw float centroids start right after all quantized entries.
            // Each quantized entry is dims bytes + 3 floats + 1 short.
            private final long rawCentroidsOffset = (long) numCentroids * (dims + 3 * Float.BYTES + Short.BYTES);
            private final long rawCentroidsByteSize = (long) Float.BYTES * dims;
            private final float[] centroid = new float[dims];
            private final float[] centroidCorrectiveValues = new float[3];
            // Ordinal currently held in `centroid`; -1 means nothing has been loaded yet.
            private int currentCentroid = -1;

            @Override
            public int size() {
                return numCentroids;
            }

            /** Returns the raw float centroid for an ordinal. The array is reused and cached for repeated lookups. */
            @Override
            public float[] centroid(int centroidOrdinal) throws IOException {
                if (centroidOrdinal == currentCentroid) {
                    return centroid;
                }
                centroids.seek(rawCentroidsOffset + rawCentroidsByteSize * centroidOrdinal);
                centroids.readFloats(centroid, 0, centroid.length);
                currentCentroid = centroidOrdinal;
                return centroid;
            }

            /** Scores every centroid in ordinal order by reading the quantized entries sequentially from offset 0. */
            @Override
            public void bulkScore(NeighborQueue queue) throws IOException {
                centroids.seek(0L);
                for (int ordinal = 0; ordinal < numCentroids; ordinal++) {
                    // int4DotProduct presumably consumes this entry's quantized bytes, leaving the
                    // pointer on its corrective values -- the sequential read below depends on it.
                    float qcDist = int4Scorer.int4DotProduct(quantizedQuery);
                    centroids.readFloats(centroidCorrectiveValues, 0, 3);
                    int componentSum = Short.toUnsignedInt(centroids.readShort());
                    float score = int4CentroidScore(
                        qcDist,
                        queryCorrections,
                        dims,
                        centroidCorrectiveValues,
                        componentSum,
                        globalCentroidDp,
                        similarity
                    );
                    queue.add(ordinal, score);
                }
            }
        };
    }

    /**
     * Turns a raw int4 dot product between the quantized query and one quantized centroid into a similarity
     * score, using the correction terms stored for both sides.
     *
     * @param qcDist quantized dot product between query and centroid
     * @param queryCorrections corrections produced when quantizing the query
     * @param dims vector dimension
     * @param centroidCorrections the centroid's [lower interval, upper interval, additional correction]
     * @param centroidComponentSum the centroid's quantized component sum
     * @param globalCentroidDp field-level value subtracted for non-euclidean similarities
     * @param similarityFunction the field's similarity function
     * @return the score, clamped at 0 for EUCLIDEAN and for the dot-product/cosine mapping
     */
    private static float int4CentroidScore(
        float qcDist,
        OptimizedScalarQuantizer.QuantizationResult queryCorrections,
        int dims,
        float[] centroidCorrections,
        int centroidComponentSum,
        float globalCentroidDp,
        VectorSimilarityFunction similarityFunction
    ) {
        final float centroidLower = centroidCorrections[0];
        final float centroidStep = (centroidCorrections[1] - centroidLower) * FOUR_BIT_SCALE;
        final float queryLower = queryCorrections.lowerInterval();
        final float queryStep = (queryCorrections.upperInterval() - queryLower) * FOUR_BIT_SCALE;
        final float queryComponentSum = queryCorrections.quantizedComponentSum();
        // Term order is kept fixed so the float result is bit-for-bit stable.
        float score = centroidLower * queryLower * dims
            + queryLower * centroidStep * centroidComponentSum
            + centroidLower * queryStep * queryComponentSum
            + centroidStep * queryStep * qcDist;
        if (similarityFunction == VectorSimilarityFunction.EUCLIDEAN) {
            final float euclideanEstimate = queryCorrections.additionalCorrection() + centroidCorrections[2] - 2.0f * score;
            return Math.max(1.0f / (1.0f + euclideanEstimate), 0.0f);
        }
        score += queryCorrections.additionalCorrection() + centroidCorrections[2] - globalCentroidDp;
        return similarityFunction == VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT
            ? VectorUtil.scaleMaxInnerProductScore(score)
            : Math.max((1.0f + score) / 2.0f, 0.0f);
    }

    /**
     * Scores all centroids into a max-heap ({@code maxHeap = true}) queue. This implementation does not use
     * {@code knnCollector} or {@code nProbe}.
     */
    @Override
    NeighborQueue scorePostingLists(
        FieldInfo fieldInfo,
        KnnCollector knnCollector,
        IVFVectorsReader.CentroidQueryScorer centroidQueryScorer,
        int nProbe
    ) throws IOException {
        final NeighborQueue ranked = new NeighborQueue(centroidQueryScorer.size(), true);
        centroidQueryScorer.bulkScore(ranked);
        return ranked;
    }

    /** Creates a posting-list visitor over a private clone of {@code indexInput}. */
    @Override
    IVFVectorsReader.PostingVisitor getPostingVisitor(
        FieldInfo fieldInfo,
        IndexInput indexInput,
        float[] target,
        IntPredicate needsScoring
    ) throws IOException {
        final IVFVectorsReader.FieldEntry entry = (IVFVectorsReader.FieldEntry) this.fields.get(fieldInfo.number);
        // A clone gives the visitor its own file pointer, so its seeks do not move the caller's input.
        final IndexInput visitorInput = indexInput.clone();
        return new MemorySegmentPostingsVisitor(target, visitorInput, entry, fieldInfo, needsScoring);
    }

    /** This reader reports no off-heap usage. */
    @Override
    public Map<String, Long> getOffHeapByteSize(FieldInfo fieldInfo) {
        return Map.of();
    }

    /**
     * Reads and scores the vectors of one posting list at a time.
     *
     * <p>Posting-list layout, as read by {@link #resetPostingsScorer}:
     * <ol>
     *   <li>a vInt vector count;</li>
     *   <li>an int holding the centroid dot-product float bits;</li>
     *   <li>the doc ids, decoded by {@link DocIdsWriter};</li>
     *   <li>the vector data.</li>
     * </ol>
     *
     * <p>Each full block of {@code BULK_SIZE} vectors stores, in order:
     * <ol>
     *   <li>the quantized bits of all vectors;</li>
     *   <li>{@code BULK_SIZE} lower-interval floats;</li>
     *   <li>{@code BULK_SIZE} upper-interval floats;</li>
     *   <li>{@code BULK_SIZE} unsigned-short component sums;</li>
     *   <li>{@code BULK_SIZE} additional-correction floats.</li>
     * </ol>
     *
     * <p>Vectors after the last full block are stored one at a time: quantized bits, then 3 floats, then a short.
     */
    private static class MemorySegmentPostingsVisitor implements IVFVectorsReader.PostingVisitor {
        private final IndexInput indexInput;
        private final float[] target;
        private final IVFVectorsReader.FieldEntry entry;
        private final FieldInfo fieldInfo;
        private final IntPredicate needsScoring;
        private final ES91OSQVectorsScorer osqVectorsScorer;
        private final OptimizedScalarQuantizer quantizer;
        private final DocIdsWriter docIdsWriter = new DocIdsWriter();

        /** Bytes of quantized bits per vector (one bit per padded dimension). */
        private final long quantizedVectorByteSize;
        /** Bytes per vector including its 3 corrective floats and 1 component-sum short. */
        private final long quantizedByteLength;

        // Per-block scratch, indexed by position within the block.
        private final float[] scores = new float[BULK_SIZE];
        private final float[] correctionsLower = new float[BULK_SIZE];
        private final float[] correctionsUpper = new float[BULK_SIZE];
        private final int[] correctionsSum = new int[BULK_SIZE];
        private final float[] correctionsAdd = new float[BULK_SIZE];
        private final float[] correctiveValues = new float[3];

        // Query quantization scratch.
        private final float[] scratch;
        private final byte[] quantizationScratch;
        private final byte[] quantizedQueryScratch;

        // Per-posting-list state, set by resetPostingsScorer.
        private int[] docIdsScratch = new int[0];
        private int vectors;
        private float centroidDp;
        private float[] centroid;
        private long slicePos;
        // The query is quantized lazily, at most once per posting list.
        private boolean quantized = false;
        private OptimizedScalarQuantizer.QuantizationResult queryCorrections;

        MemorySegmentPostingsVisitor(
            float[] target,
            IndexInput indexInput,
            IVFVectorsReader.FieldEntry entry,
            FieldInfo fieldInfo,
            IntPredicate needsScoring
        ) throws IOException {
            this.target = target;
            this.indexInput = indexInput;
            this.entry = entry;
            this.fieldInfo = fieldInfo;
            this.needsScoring = needsScoring;
            this.scratch = new float[target.length];
            this.quantizationScratch = new byte[target.length];
            final int paddedDims = BQVectorUtils.discretize(fieldInfo.getVectorDimension(), 64);
            this.quantizedQueryScratch = new byte[QUERY_BITS * paddedDims / 8];
            this.quantizedVectorByteSize = paddedDims / 8;
            this.quantizedByteLength = this.quantizedVectorByteSize + 3 * Float.BYTES + Short.BYTES;
            this.quantizer = new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction());
            this.osqVectorsScorer = ESVectorUtil.getES91OSQVectorsScorer(indexInput, fieldInfo.getVectorDimension());
        }

        /**
         * Positions the visitor on the posting list of {@code centroidOrdinal} and decodes its header and doc ids.
         *
         * @return the number of vectors in the posting list
         */
        @Override
        public int resetPostingsScorer(int centroidOrdinal, float[] centroid) throws IOException {
            quantized = false;
            indexInput.seek(entry.postingListOffsets()[centroidOrdinal]);
            vectors = indexInput.readVInt();
            centroidDp = Float.intBitsToFloat(indexInput.readInt());
            this.centroid = centroid;
            if (docIdsScratch.length < vectors) {
                docIdsScratch = new int[vectors];
            }
            docIdsWriter.readInts(indexInput, vectors, docIdsScratch);
            // Vector data begins immediately after the doc ids.
            slicePos = indexInput.getFilePointer();
            return vectors;
        }

        /**
         * Scores only the non-filtered vectors of the block starting at vector {@code offset}.
         * Results are written into {@code scores}; entries for filtered docs are left untouched.
         */
        void scoreIndividually(int offset) throws IOException {
            final long blockStart = slicePos + offset * quantizedByteLength;
            // Pass 1: raw quantized score of each live vector.
            for (int k = 0; k < BULK_SIZE; k++) {
                if (docIdsScratch[offset + k] != FILTERED_DOC) {
                    indexInput.seek(blockStart + k * quantizedVectorByteSize);
                    scores[k] = (float) osqVectorsScorer.quantizeScore(quantizedQueryScratch);
                }
            }
            // Pass 2: read the block's column-wise corrections. The read order must match the on-disk layout.
            indexInput.seek(blockStart + BULK_SIZE * quantizedVectorByteSize);
            indexInput.readFloats(correctionsLower, 0, BULK_SIZE);
            indexInput.readFloats(correctionsUpper, 0, BULK_SIZE);
            for (int k = 0; k < BULK_SIZE; k++) {
                correctionsSum[k] = Short.toUnsignedInt(indexInput.readShort());
            }
            indexInput.readFloats(correctionsAdd, 0, BULK_SIZE);
            // Pass 3: apply the corrections to each live raw score.
            for (int k = 0; k < BULK_SIZE; k++) {
                if (docIdsScratch[offset + k] != FILTERED_DOC) {
                    scores[k] = osqVectorsScorer.score(
                        queryCorrections.lowerInterval(),
                        queryCorrections.upperInterval(),
                        queryCorrections.quantizedComponentSum(),
                        queryCorrections.additionalCorrection(),
                        fieldInfo.getVectorSimilarityFunction(),
                        centroidDp,
                        correctionsLower[k],
                        correctionsUpper[k],
                        correctionsSum[k],
                        correctionsAdd[k],
                        scores[k]
                    );
                }
            }
        }

        /**
         * Scores every vector whose doc passes {@code needsScoring} and feeds it to {@code knnCollector}.
         *
         * <p>Full blocks are bulk scored when at least half of their docs need scoring; otherwise they are scored
         * per vector. Rejected docs in full blocks are overwritten with {@code FILTERED_DOC} in the doc-id scratch.
         *
         * @return the number of documents collected
         */
        @Override
        public int visit(KnnCollector knnCollector) throws IOException {
            int collected = 0;
            int ordinal = 0;
            for (; ordinal + BULK_SIZE <= vectors; ordinal += BULK_SIZE) {
                final int live = maskRejectedDocs(ordinal);
                if (live == 0) {
                    continue;
                }
                quantizeQueryIfNecessary();
                indexInput.seek(slicePos + ordinal * quantizedByteLength);
                if (live >= BULK_SIZE / 2) {
                    osqVectorsScorer.scoreBulk(
                        quantizedQueryScratch,
                        queryCorrections.lowerInterval(),
                        queryCorrections.upperInterval(),
                        queryCorrections.quantizedComponentSum(),
                        queryCorrections.additionalCorrection(),
                        fieldInfo.getVectorSimilarityFunction(),
                        centroidDp,
                        scores
                    );
                } else {
                    scoreIndividually(ordinal);
                }
                collected += collectBlock(ordinal, knnCollector);
            }
            // Remaining vectors, fewer than a full block, are stored one after another.
            for (; ordinal < vectors; ordinal++) {
                final int doc = docIdsScratch[ordinal];
                if (needsScoring.test(doc)) {
                    final float score = scoreTailVector(ordinal);
                    collected++;
                    knnCollector.collect(doc, score);
                }
            }
            if (collected > 0) {
                knnCollector.incVisitedCount(collected);
            }
            return collected;
        }

        /** Replaces rejected doc ids of the block at {@code base} with {@code FILTERED_DOC}; returns the survivors. */
        private int maskRejectedDocs(int base) {
            int live = BULK_SIZE;
            for (int k = base; k < base + BULK_SIZE; k++) {
                if (needsScoring.test(docIdsScratch[k]) == false) {
                    docIdsScratch[k] = FILTERED_DOC;
                    live--;
                }
            }
            return live;
        }

        /** Sends each live doc of the block at {@code base} with its score to the collector; returns how many. */
        private int collectBlock(int base, KnnCollector knnCollector) {
            int count = 0;
            for (int k = 0; k < BULK_SIZE; k++) {
                final int doc = docIdsScratch[base + k];
                if (doc != FILTERED_DOC) {
                    count++;
                    knnCollector.collect(doc, scores[k]);
                }
            }
            return count;
        }

        /** Scores a single vector stored after the last full block, reading its inline corrections. */
        private float scoreTailVector(int ordinal) throws IOException {
            quantizeQueryIfNecessary();
            indexInput.seek(slicePos + ordinal * quantizedByteLength);
            final float qcDist = osqVectorsScorer.quantizeScore(quantizedQueryScratch);
            indexInput.readFloats(correctiveValues, 0, 3);
            final int componentSum = Short.toUnsignedInt(indexInput.readShort());
            return osqVectorsScorer.score(
                queryCorrections.lowerInterval(),
                queryCorrections.upperInterval(),
                queryCorrections.quantizedComponentSum(),
                queryCorrections.additionalCorrection(),
                fieldInfo.getVectorSimilarityFunction(),
                centroidDp,
                correctiveValues[0],
                correctiveValues[1],
                componentSum,
                correctiveValues[2],
                qcDist
            );
        }

        /**
         * Quantizes the query against the current posting list's centroid on first use.
         * For COSINE the query is L2-normalized first, and the quantized bits are then transposed.
         */
        private void quantizeQueryIfNecessary() {
            if (quantized) {
                return;
            }
            System.arraycopy(target, 0, scratch, 0, target.length);
            if (fieldInfo.getVectorSimilarityFunction() == VectorSimilarityFunction.COSINE) {
                VectorUtil.l2normalize(scratch);
            }
            queryCorrections = quantizer.scalarQuantize(scratch, quantizationScratch, QUERY_BITS, centroid);
            BQSpaceUtils.transposeHalfByte(quantizationScratch, quantizedQueryScratch);
            quantized = true;
        }
    }
}

