/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.simdvec.internal.vectorization;

import java.io.IOException;
import java.lang.foreign.MemorySegment;
import java.nio.ByteOrder;
import jdk.incubator.vector.ByteVector;
import jdk.incubator.vector.IntVector;
import jdk.incubator.vector.ShortVector;
import jdk.incubator.vector.Vector;
import jdk.incubator.vector.VectorOperators;
import jdk.incubator.vector.VectorSpecies;
import org.apache.lucene.store.IndexInput;
import org.elasticsearch.simdvec.ES91Int4VectorsScorer;
import org.elasticsearch.simdvec.internal.vectorization.PanamaESVectorUtilSupport;

/**
 * Panama Vector API implementation of {@link ES91Int4VectorsScorer} that loads the stored
 * vector bytes straight from a {@link MemorySegment}.
 *
 * <p>Each vectorized kernel takes the input's current file pointer as the byte offset into
 * {@code memorySegment}, processes {@code limit} bytes from there, and then seeks the input
 * forward by {@code limit} so the scalar tail continues right after the vectorized part.
 *
 * <p>NOTE(review): this presumes {@code memorySegment} addresses the same bytes as {@code in},
 * with the file pointer usable as a segment offset - confirm against the code that creates
 * this scorer.
 */
public final class MemorySegmentES91Int4VectorsScorer extends ES91Int4VectorsScorer {
    private static final VectorSpecies<Byte> BYTE_SPECIES_64 = ByteVector.SPECIES_64;
    private static final VectorSpecies<Byte> BYTE_SPECIES_128 = ByteVector.SPECIES_128;
    private static final VectorSpecies<Short> SHORT_SPECIES_128 = ShortVector.SPECIES_128;
    private static final VectorSpecies<Short> SHORT_SPECIES_256 = ShortVector.SPECIES_256;
    private static final VectorSpecies<Integer> INT_SPECIES_128 = IntVector.SPECIES_128;
    private static final VectorSpecies<Integer> INT_SPECIES_256 = IntVector.SPECIES_256;
    private static final VectorSpecies<Integer> INT_SPECIES_512 = IntVector.SPECIES_512;

    /**
     * Number of bytes accumulated in 16-bit lanes before widening to 32 bits. With 16 bytes
     * consumed per inner iteration, each lane receives at most 1024 / 16 = 64 addends, each
     * masked to at most 255, so a lane never exceeds 64 * 255 = 16320 and cannot overflow a
     * signed short.
     */
    private static final int SHORT_ACC_CHUNK_BYTES = 1024;

    private final MemorySegment memorySegment;

    /**
     * Creates a scorer over the given input.
     *
     * @param in            input positioned over the stored vectors
     * @param dimensions    number of bytes (one per dimension) in each stored vector
     * @param memorySegment segment the vectorized kernels read stored bytes from
     */
    public MemorySegmentES91Int4VectorsScorer(IndexInput in, int dimensions, MemorySegment memorySegment) {
        super(in, dimensions);
        this.memorySegment = memorySegment;
    }

    /**
     * Computes the dot product between {@code q} and the next {@code dimensions} bytes of the
     * input, advancing the input past them.
     *
     * <p>256-bit and 512-bit (or wider) platforms are routed to {@link #dotProduct(byte[])}.
     * Otherwise a 128-bit kernel handles the largest multiple of 16 bytes when fast integer
     * vectors are available, and the remainder is bulk-read into {@code scratch} and
     * multiplied one element at a time.
     *
     * @param q query vector, one value per dimension
     * @return the dot product
     * @throws IOException if reading from the input fails
     */
    @Override
    public long int4DotProduct(byte[] q) throws IOException {
        if (PanamaESVectorUtilSupport.VECTOR_BITSIZE >= 512 || PanamaESVectorUtilSupport.VECTOR_BITSIZE == 256) {
            return dotProduct(q);
        }
        int i = 0;
        int res = 0;
        if (dimensions >= 32 && PanamaESVectorUtilSupport.HAS_FAST_INTEGER_VECTORS) {
            i = BYTE_SPECIES_128.loopBound(dimensions);
            res = int4DotProductBody128(q, i);
        }
        // Bulk-read whatever the vector kernel did not consume; the bytes land at the same
        // indices in scratch that they have in the vector, so scratch[k] lines up with q[k].
        in.readBytes(scratch, i, dimensions - i);
        for (; i < dimensions; i++) {
            res += scratch[i] * q[i];
        }
        return res;
    }

    /**
     * 128-bit kernel: multiplies in 8-bit lanes, accumulates in 16-bit lanes, and widens to
     * 32 bits once per {@link #SHORT_ACC_CHUNK_BYTES}-byte chunk.
     *
     * @param q     query vector
     * @param limit number of leading bytes to process; a multiple of 16
     * @return the dot product over the first {@code limit} bytes
     * @throws IOException if querying or moving the file pointer fails
     */
    private int int4DotProductBody128(byte[] q, int limit) throws IOException {
        int sum = 0;
        final long offset = in.getFilePointer();
        for (int i = 0; i < limit; i += SHORT_ACC_CHUNK_BYTES) {
            ShortVector acc0 = ShortVector.zero(SHORT_SPECIES_128);
            ShortVector acc1 = ShortVector.zero(SHORT_SPECIES_128);
            final int innerLimit = Math.min(limit - i, SHORT_ACC_CHUNK_BYTES);
            for (int j = 0; j < innerLimit; j += BYTE_SPECIES_128.length()) {
                // First 8 bytes of this 16-byte step.
                ByteVector va8 = ByteVector.fromArray(BYTE_SPECIES_64, q, i + j);
                ByteVector vb8 = ByteVector.fromMemorySegment(BYTE_SPECIES_64, memorySegment, offset + i + j, ByteOrder.LITTLE_ENDIAN);
                // The 8-bit multiply keeps only the low 8 bits of each product and B2S
                // sign-extends them; masking with 0xFF recovers the unsigned value. This is
                // exact as long as each product fits in 8 bits (true for int4 operands in
                // [0, 15], whose product is at most 225 - assumed from the class name).
                ByteVector prod8 = va8.mul(vb8);
                ShortVector prod16 = prod8.convertShape(VectorOperators.B2S, SHORT_SPECIES_128, 0).reinterpretAsShorts();
                acc0 = acc0.add(prod16.and((short) 0xFF));

                // Second 8 bytes of this 16-byte step, into an independent accumulator.
                va8 = ByteVector.fromArray(BYTE_SPECIES_64, q, i + j + 8);
                vb8 = ByteVector.fromMemorySegment(BYTE_SPECIES_64, memorySegment, offset + i + j + 8, ByteOrder.LITTLE_ENDIAN);
                prod8 = va8.mul(vb8);
                prod16 = prod8.convertShape(VectorOperators.B2S, SHORT_SPECIES_128, 0).reinterpretAsShorts();
                acc1 = acc1.add(prod16.and((short) 0xFF));
            }
            // Widen both halves (part 0 and part 1) of each 16-bit accumulator to 32 bits.
            IntVector intAcc0 = acc0.convertShape(VectorOperators.S2I, INT_SPECIES_128, 0).reinterpretAsInts();
            IntVector intAcc1 = acc0.convertShape(VectorOperators.S2I, INT_SPECIES_128, 1).reinterpretAsInts();
            IntVector intAcc2 = acc1.convertShape(VectorOperators.S2I, INT_SPECIES_128, 0).reinterpretAsInts();
            IntVector intAcc3 = acc1.convertShape(VectorOperators.S2I, INT_SPECIES_128, 1).reinterpretAsInts();
            sum += intAcc0.add(intAcc1).add(intAcc2).add(intAcc3).reduceLanes(VectorOperators.ADD);
        }
        // The kernel read through the memory segment, so move the input past those bytes.
        in.seek(offset + limit);
        return sum;
    }

    /**
     * Dot product that widens the operands before multiplying, with a scalar tail that reads
     * the remaining bytes one at a time from the input.
     *
     * <p>{@link #int4DotProduct(byte[])} only routes here for 256-bit and 512-bit-or-wider
     * vectors, so the final 128-bit branch is a fallback that is not reached through that
     * entry point.
     *
     * @param q query vector, one value per dimension
     * @return the dot product over {@code dimensions} elements
     * @throws IOException if reading from the input fails
     */
    private long dotProduct(byte[] q) throws IOException {
        int i = 0;
        int res = 0;
        // Only vectorize when the kernel loop runs at least once.
        if (dimensions >= 16 && PanamaESVectorUtilSupport.HAS_FAST_INTEGER_VECTORS) {
            if (PanamaESVectorUtilSupport.VECTOR_BITSIZE >= 512) {
                i = BYTE_SPECIES_128.loopBound(dimensions);
                res = dotProductBody512(q, i);
            } else if (PanamaESVectorUtilSupport.VECTOR_BITSIZE == 256) {
                i = BYTE_SPECIES_64.loopBound(dimensions);
                res = dotProductBody256(q, i);
            } else {
                // The 128-bit kernel loads 8 bytes but consumes 4 per step, so stop one
                // 8-byte vector early to keep the final load inside the stored vector.
                i = BYTE_SPECIES_64.loopBound(dimensions - BYTE_SPECIES_64.length());
                res = dotProductBody128(q, i);
            }
        }
        // Scalar tail, bounded by the stored vector length (same bound the kernels and
        // int4DotProduct use) so it never reads past this vector's bytes.
        for (; i < dimensions; i++) {
            res += in.readByte() * q[i];
        }
        return res;
    }

    /**
     * 512-bit kernel: 16 bytes per step, multiplied in 16-bit lanes and accumulated in
     * 32-bit lanes.
     *
     * @param q     query vector
     * @param limit number of leading bytes to process; a multiple of 16
     * @return the dot product over the first {@code limit} bytes
     * @throws IOException if querying or moving the file pointer fails
     */
    private int dotProductBody512(byte[] q, int limit) throws IOException {
        IntVector acc = IntVector.zero(INT_SPECIES_512);
        final long offset = in.getFilePointer();
        for (int i = 0; i < limit; i += BYTE_SPECIES_128.length()) {
            ByteVector va8 = ByteVector.fromArray(BYTE_SPECIES_128, q, i);
            ByteVector vb8 = ByteVector.fromMemorySegment(BYTE_SPECIES_128, memorySegment, offset + i, ByteOrder.LITTLE_ENDIAN);

            // Multiply in 16-bit lanes (256-bit vectors) ...
            Vector<Short> va16 = va8.convertShape(VectorOperators.B2S, SHORT_SPECIES_256, 0);
            Vector<Short> vb16 = vb8.convertShape(VectorOperators.B2S, SHORT_SPECIES_256, 0);
            Vector<Short> prod16 = va16.mul(vb16);

            // ... then accumulate in 32-bit lanes.
            Vector<Integer> prod32 = prod16.convertShape(VectorOperators.S2I, INT_SPECIES_512, 0);
            acc = acc.add(prod32);
        }
        // The kernel read through the memory segment, so move the input past those bytes.
        in.seek(offset + limit);
        return acc.reduceLanes(VectorOperators.ADD);
    }

    /**
     * 256-bit kernel: 8 bytes per step, widened directly to 32-bit lanes before the multiply.
     *
     * @param q     query vector
     * @param limit number of leading bytes to process; a multiple of 8
     * @return the dot product over the first {@code limit} bytes
     * @throws IOException if querying or moving the file pointer fails
     */
    private int dotProductBody256(byte[] q, int limit) throws IOException {
        IntVector acc = IntVector.zero(INT_SPECIES_256);
        final long offset = in.getFilePointer();
        for (int i = 0; i < limit; i += BYTE_SPECIES_64.length()) {
            ByteVector va8 = ByteVector.fromArray(BYTE_SPECIES_64, q, i);
            ByteVector vb8 = ByteVector.fromMemorySegment(BYTE_SPECIES_64, memorySegment, offset + i, ByteOrder.LITTLE_ENDIAN);

            Vector<Integer> va32 = va8.convertShape(VectorOperators.B2I, INT_SPECIES_256, 0);
            Vector<Integer> vb32 = vb8.convertShape(VectorOperators.B2I, INT_SPECIES_256, 0);
            acc = acc.add(va32.mul(vb32));
        }
        // The kernel read through the memory segment, so move the input past those bytes.
        in.seek(offset + limit);
        return acc.reduceLanes(VectorOperators.ADD);
    }

    /**
     * 128-bit kernel: loads 8 bytes per step but only uses the first 4 (part 0 of the
     * byte-to-short conversion), advancing 4 bytes at a time so consecutive loads overlap.
     *
     * @param q     query vector
     * @param limit number of leading bytes to process; the caller must leave at least
     *              4 further readable bytes after {@code limit} for the final 8-byte load
     * @return the dot product over the first {@code limit} bytes
     * @throws IOException if querying or moving the file pointer fails
     */
    private int dotProductBody128(byte[] q, int limit) throws IOException {
        IntVector acc = IntVector.zero(INT_SPECIES_128);
        final long offset = in.getFilePointer();
        final int step = BYTE_SPECIES_64.length() >> 1;
        for (int i = 0; i < limit; i += step) {
            ByteVector va8 = ByteVector.fromArray(BYTE_SPECIES_64, q, i);
            ByteVector vb8 = ByteVector.fromMemorySegment(BYTE_SPECIES_64, memorySegment, offset + i, ByteOrder.LITTLE_ENDIAN);

            // Same-shape conversion: part 0 yields the first four bytes as four shorts.
            Vector<Short> va16 = va8.convert(VectorOperators.B2S, 0);
            Vector<Short> vb16 = vb8.convert(VectorOperators.B2S, 0);
            Vector<Short> prod16 = va16.mul(vb16);

            acc = acc.add(prod16.convertShape(VectorOperators.S2I, INT_SPECIES_128, 0));
        }
        // The kernel read through the memory segment, so move the input past those bytes.
        in.seek(offset + limit);
        return acc.reduceLanes(VectorOperators.ADD);
    }
}

