/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.rank.vectors.script;

import java.io.IOException;
import java.util.HexFormat;
import java.util.List;
import org.elasticsearch.ExceptionsHelper;
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
import org.elasticsearch.script.ScoreScript;
import org.elasticsearch.script.field.vectors.DenseVector;
import org.elasticsearch.script.field.vectors.RankVectorsDocValuesField;

public class RankVectorsScoreScriptUtils {
    /**
     * Normalizes a script-supplied query vector for byte and bit fields.
     *
     * <p>Two shapes are accepted:
     * <ul>
     *   <li>a list of numeric lists, returned untouched in {@link BytesOrList#list()} so the caller's
     *       constructor can validate it;</li>
     *   <li>a list of hex strings, decoded into {@link BytesOrList#bytes()}; every string must decode to
     *       the same number of bytes.</li>
     * </ul>
     * An empty list is passed through as a numeric list so that the consuming constructors can report it
     * with their own "The query vector is empty." message rather than failing here on {@code get(0)}.
     *
     * @param queryVector the raw query object handed over by the script
     * @return the parsed query, with exactly one of its components non-null
     * @throws IllegalArgumentException if the query is not a list, contains unsupported element types,
     *     contains invalid hex, or its hex strings decode to different lengths
     */
    @SuppressWarnings("unchecked")
    private static BytesOrList parseBytes(Object queryVector) {
        if (queryVector instanceof List<?> list) {
            if (list.isEmpty() || list.get(0) instanceof List) {
                return new BytesOrList(null, (List<List<Number>>) list);
            }
            if (list.get(0) instanceof String) {
                HexFormat hexFormat = HexFormat.of();
                byte[][] parsedQueryVector = new byte[list.size()][];
                int lastSize = -1;
                for (int i = 0; i < list.size(); i++) {
                    Object element = list.get(i);
                    // Mixed lists (e.g. a hex string followed by a number) are rejected explicitly
                    // instead of surfacing as a ClassCastException.
                    if (!(element instanceof String hex)) {
                        throw new IllegalArgumentException(
                            "Unsupported input object for byte vectors: " + (element == null ? "null" : element.getClass().getName())
                        );
                    }
                    parsedQueryVector[i] = hexFormat.parseHex(hex);
                    if (lastSize != -1 && lastSize != parsedQueryVector[i].length) {
                        throw new IllegalArgumentException("The query vector contains inner vectors which have inconsistent number of dimensions.");
                    }
                    lastSize = parsedQueryVector[i].length;
                }
                return new BytesOrList(parsedQueryVector, null);
            }
        }
        throw new IllegalArgumentException(
            "Unsupported input object for byte vectors: " + (queryVector == null ? "null" : queryVector.getClass().getName())
        );
    }

    /**
     * Result of parsing a byte/bit query: {@code bytes} holds hex-decoded inner vectors, {@code list}
     * holds the query as nested numbers. The parser in this class fills exactly one of the two.
     */
    private record BytesOrList(byte[][] bytes, List<List<Number>> list) {}

    public static final class MaxSimDotProduct {
        private final MaxSimDotProductInterface function;

        public MaxSimDotProduct(ScoreScript scoreScript, Object queryVector, String fieldName) {
            RankVectorsDocValuesField field = (RankVectorsDocValuesField)scoreScript.field(fieldName);
            this.function = switch (field.getElementType()) {
                default -> throw new MatchException(null, null);
                case DenseVectorFieldMapper.ElementType.BIT -> {
                    BytesOrList bytesOrList = RankVectorsScoreScriptUtils.parseBytes(queryVector);
                    if (bytesOrList.bytes != null) {
                        yield new MaxSimBitDotProduct(scoreScript, field, bytesOrList.bytes);
                    }
                    yield new MaxSimBitDotProduct(scoreScript, field, bytesOrList.list);
                }
                case DenseVectorFieldMapper.ElementType.BYTE -> {
                    BytesOrList bytesOrList = RankVectorsScoreScriptUtils.parseBytes(queryVector);
                    if (bytesOrList.bytes != null) {
                        yield new MaxSimByteDotProduct(scoreScript, field, bytesOrList.bytes);
                    }
                    yield new MaxSimByteDotProduct(scoreScript, field, bytesOrList.list);
                }
                case DenseVectorFieldMapper.ElementType.FLOAT -> {
                    if (queryVector instanceof List) {
                        yield new MaxSimFloatDotProduct(scoreScript, field, (List)queryVector);
                    }
                    throw new IllegalArgumentException("Unsupported input object for float vectors: " + queryVector.getClass().getName());
                }
            };
        }

        public double maxSimDotProduct() {
            return this.function.maxSimDotProduct();
        }
    }

    /** Max-sim dot product of a float query against a float rank-vectors field. */
    public static class MaxSimFloatDotProduct extends FloatRankVectorsFunction implements MaxSimDotProductInterface {

        public MaxSimFloatDotProduct(ScoreScript scoreScript, RankVectorsDocValuesField field, List<List<Number>> queryVector) {
            super(scoreScript, field, queryVector);
        }

        /** Positions the field on the current document, then scores it against the float query. */
        @Override
        public double maxSimDotProduct() {
            setNextVector();
            var docVectors = field.get();
            return docVectors.maxSimDotProduct(queryVector);
        }
    }

    /** Max-sim dot product of a byte query against a byte rank-vectors field. */
    public static class MaxSimByteDotProduct extends ByteRankVectorsFunction implements MaxSimDotProductInterface {

        /** Builds the query from nested numeric lists. */
        public MaxSimByteDotProduct(ScoreScript scoreScript, RankVectorsDocValuesField field, List<List<Number>> queryVector) {
            super(scoreScript, field, queryVector);
        }

        /** Uses an already-decoded byte query. */
        public MaxSimByteDotProduct(ScoreScript scoreScript, RankVectorsDocValuesField field, byte[][] queryVector) {
            super(scoreScript, field, queryVector);
        }

        /** Positions the field on the current document, then scores it against the byte query. */
        @Override
        public double maxSimDotProduct() {
            setNextVector();
            var docVectors = field.get();
            return docVectors.maxSimDotProduct(queryVector);
        }
    }

    /**
     * Max-sim dot product against a {@code bit} rank-vectors field.
     *
     * <p>The query is either raw bytes (e.g. decoded from hex) or nested numeric lists. A numeric query
     * whose values are all whole numbers in the signed byte range is treated as a byte query; a single
     * fractional or out-of-range value makes the whole query a float query.
     */
    public static class MaxSimBitDotProduct extends RankVectorsFunction implements MaxSimDotProductInterface {
        /** Byte form of the query; {@code null} when the query is a float query. */
        private final byte[][] byteQueryVector;
        /** Float form of the query; {@code null} when the query is a byte query. */
        private final float[][] floatQueryVector;

        /**
         * Creates the function from a pre-decoded byte query.
         *
         * @throws IllegalArgumentException if the field is not a bit field, the query is empty, its inner
         *     vectors differ in length, or the inner length matches neither {@code dims / 8} (bitwise)
         *     nor {@code dims} (byte-wise)
         */
        public MaxSimBitDotProduct(ScoreScript scoreScript, RankVectorsDocValuesField field, byte[][] queryVector) {
            super(scoreScript, field);
            if (field.getElementType() != DenseVectorFieldMapper.ElementType.BIT) {
                throw new IllegalArgumentException("Cannot calculate bit dot product for non-bit vectors");
            }
            if (queryVector.length == 0) {
                throw new IllegalArgumentException("The query vector is empty.");
            }
            int innerDims = queryVector[0].length;
            for (byte[] inner : queryVector) {
                if (inner.length != innerDims) {
                    throw new IllegalArgumentException("The query vector contains inner vectors which have inconsistent number of dimensions.");
                }
            }
            int fieldDims = field.get().getDims();
            // The field's dims describe ONE vector, so compare against the inner-vector length,
            // not against the number of inner vectors.
            if (fieldDims != innerDims * Byte.SIZE && fieldDims != innerDims) {
                throw new IllegalArgumentException(
                    "The query vector has an incorrect number of dimensions. Must be ["
                        + fieldDims / Byte.SIZE
                        + "] for bitwise operations, or ["
                        + fieldDims
                        + "] for byte wise operations: provided ["
                        + innerDims
                        + "]."
                );
            }
            this.byteQueryVector = queryVector;
            this.floatQueryVector = null;
        }

        /**
         * Creates the function from nested numeric lists, deciding between a byte and a float query.
         *
         * @throws IllegalArgumentException if the query is empty, the field is not a bit field, the inner
         *     vectors differ in length, or the inner length does not fit the field's dimensions
         */
        public MaxSimBitDotProduct(ScoreScript scoreScript, RankVectorsDocValuesField field, List<List<Number>> queryVector) {
            super(scoreScript, field);
            if (queryVector.isEmpty()) {
                throw new IllegalArgumentException("The query vector is empty.");
            }
            if (field.getElementType() != DenseVectorFieldMapper.ElementType.BIT) {
                throw new IllegalArgumentException("cannot calculate bit dot product for non-bit vectors");
            }
            float[][] floats = new float[queryVector.size()][];
            byte[][] bytes = new byte[queryVector.size()][];
            boolean isFloat = false;
            int lastSize = -1;
            for (int i = 0; i < queryVector.size(); i++) {
                List<Number> inner = queryVector.get(i);
                int size = inner.size();
                if (lastSize != -1 && lastSize != size) {
                    throw new IllegalArgumentException("The query vector contains inner vectors which have inconsistent number of dimensions.");
                }
                lastSize = size;
                floats[i] = new float[size];
                if (isFloat == false) {
                    bytes[i] = new byte[size];
                }
                for (int j = 0; j < size; j++) {
                    Number number = inner.get(j);
                    float value = number.floatValue();
                    floats[i][j] = value;
                    if (isFloat == false) {
                        bytes[i][j] = number.byteValue();
                        // One fractional or out-of-byte-range value turns the whole query into a float query;
                        // once set, the byte copy is no longer maintained.
                        isFloat = isWholeByte(value) == false;
                    }
                }
            }
            int fieldDims = field.get().getDims();
            if (isFloat) {
                if (fieldDims != floats[0].length) {
                    throw new IllegalArgumentException(
                        "The query vector contains inner vectors which have incorrect number of dimensions. Must be ["
                            + fieldDims
                            + "] for float wise operations: provided ["
                            + floats[0].length
                            + "]."
                    );
                }
                this.floatQueryVector = floats;
                this.byteQueryVector = null;
            } else {
                if (fieldDims != bytes[0].length * Byte.SIZE && fieldDims != bytes[0].length) {
                    throw new IllegalArgumentException(
                        "The query vector contains inner vectors which have incorrect number of dimensions. Must be ["
                            + fieldDims / Byte.SIZE
                            + "] for bitwise operations, or ["
                            + fieldDims
                            + "] for byte wise operations: provided ["
                            + bytes[0].length
                            + "]."
                    );
                }
                this.floatQueryVector = null;
                this.byteQueryVector = bytes;
            }
        }

        /** Returns whether {@code value} is a whole number that fits in a signed byte (NaN and infinities do not). */
        private static boolean isWholeByte(float value) {
            return value % 1.0f == 0.0f && value >= Byte.MIN_VALUE && value <= Byte.MAX_VALUE;
        }

        /** Positions the field on the current document and scores it with whichever query form was built. */
        @Override
        public double maxSimDotProduct() {
            setNextVector();
            if (byteQueryVector != null) {
                return field.get().maxSimDotProduct(byteQueryVector);
            }
            return field.get().maxSimDotProduct(floatQueryVector);
        }
    }

    /** A function that computes a max-sim dot product score for the current document. */
    public interface MaxSimDotProductInterface {
        /** Returns the max-sim dot product score. */
        double maxSimDotProduct();
    }

    /**
     * Script entry point for a max-sim inverse Hamming score against a byte or bit rank-vectors field.
     * The query may be hex strings or nested numeric lists.
     */
    public static final class MaxSimInvHamming {
        private final MaxSimInvHammingDistanceInterface function;

        /**
         * @throws IllegalArgumentException if the field holds float vectors or the query shape is unsupported
         */
        public MaxSimInvHamming(ScoreScript scoreScript, Object queryVector, String fieldName) {
            RankVectorsDocValuesField field = (RankVectorsDocValuesField) scoreScript.field(fieldName);
            if (field.getElementType() == DenseVectorFieldMapper.ElementType.FLOAT) {
                throw new IllegalArgumentException("hamming distance is only supported for byte or bit vectors");
            }
            BytesOrList parsed = parseBytes(queryVector);
            byte[][] decoded = parsed.bytes();
            if (decoded != null) {
                function = new ByteMaxSimInvHammingDistance(scoreScript, field, decoded);
            } else {
                function = new ByteMaxSimInvHammingDistance(scoreScript, field, parsed.list());
            }
        }

        /** Scores the current document with the configured implementation. */
        public double maxSimInvHamming() {
            return function.maxSimInvHamming();
        }
    }

    /** Max-sim inverse Hamming score of a byte query against a byte or bit rank-vectors field. */
    public static class ByteMaxSimInvHammingDistance extends ByteRankVectorsFunction implements MaxSimInvHammingDistanceInterface {

        /** Builds the query from nested numeric lists. */
        public ByteMaxSimInvHammingDistance(ScoreScript scoreScript, RankVectorsDocValuesField field, List<List<Number>> queryVector) {
            super(scoreScript, field, queryVector);
        }

        /** Uses an already-decoded byte query. */
        public ByteMaxSimInvHammingDistance(ScoreScript scoreScript, RankVectorsDocValuesField field, byte[][] queryVector) {
            super(scoreScript, field, queryVector);
        }

        /** Positions the field on the current document, then scores it against the byte query. */
        @Override
        public float maxSimInvHamming() {
            setNextVector();
            var docVectors = field.get();
            return docVectors.maxSimInvHamming(queryVector);
        }
    }

    /** A function that computes a max-sim inverse Hamming score for the current document. */
    public interface MaxSimInvHammingDistanceInterface {
        /** Returns the max-sim inverse Hamming score. */
        float maxSimInvHamming();
    }

    /** Base class for functions that score a float rank-vectors field against a float query. */
    public static class FloatRankVectorsFunction extends RankVectorsFunction {
        /** The query, one float array per inner vector. */
        protected final float[][] queryVector;

        /**
         * Converts the nested numeric query to floats, checking the first inner vector's length against
         * the field's dimensions, that all inner vectors share one length, and each inner vector's bounds.
         *
         * @throws IllegalArgumentException if the query is empty or the inner vectors differ in length
         */
        public FloatRankVectorsFunction(ScoreScript scoreScript, RankVectorsDocValuesField field, List<List<Number>> queryVector) {
            super(scoreScript, field);
            if (queryVector.isEmpty()) {
                throw new IllegalArgumentException("The query vector is empty.");
            }
            DenseVector.checkDimensions(field.get().getDims(), queryVector.get(0).size());
            final int rows = queryVector.size();
            this.queryVector = new float[rows][queryVector.get(0).size()];
            int previousWidth = -1;
            for (int row = 0; row < rows; row++) {
                List<Number> values = queryVector.get(row);
                int width = values.size();
                if (previousWidth != -1 && previousWidth != width) {
                    throw new IllegalArgumentException("The query vector contains inner vectors which have inconsistent number of dimensions.");
                }
                previousWidth = width;
                float[] target = this.queryVector[row];
                for (int col = 0; col < width; col++) {
                    target[col] = values.get(col).floatValue();
                }
                field.getElementType().checkVectorBounds(target);
            }
        }
    }

    /** Base class for functions that score a byte- or bit-typed rank-vectors field against a byte query. */
    public static class ByteRankVectorsFunction extends RankVectorsFunction {
        /** The query, one byte array per inner vector. */
        protected final byte[][] queryVector;

        /**
         * Builds the byte query from nested numeric lists.
         *
         * <p>The first inner vector's length is checked against the field's dimensions by the element type;
         * every other inner vector must have that same length. Each inner vector's full-precision values are
         * passed to the element type's bounds check before being narrowed to bytes.
         *
         * @throws IllegalArgumentException if the query is empty or the inner vectors differ in length
         */
        public ByteRankVectorsFunction(ScoreScript scoreScript, RankVectorsDocValuesField field, List<List<Number>> queryVector) {
            super(scoreScript, field);
            if (queryVector.isEmpty()) {
                throw new IllegalArgumentException("The query vector is empty.");
            }
            int dims = queryVector.get(0).size();
            field.getElementType().checkDimensions(field.get().getDims(), dims);
            this.queryVector = new byte[queryVector.size()][dims];
            for (int i = 0; i < queryVector.size(); i++) {
                List<Number> inner = queryVector.get(i);
                if (inner.size() != dims) {
                    throw new IllegalArgumentException("The query vector contains inner vectors which have inconsistent number of dimensions.");
                }
                // One slot per dimension of THIS inner vector: every value must be bounds-checked, because
                // byteValue() would otherwise silently wrap out-of-range input.
                float[] validateValues = new float[dims];
                for (int j = 0; j < dims; j++) {
                    Number number = inner.get(j);
                    this.queryVector[i][j] = number.byteValue();
                    validateValues[j] = number.floatValue();
                }
                field.getElementType().checkVectorBounds(validateValues);
            }
        }

        /** Wraps an already-decoded byte query as-is; no dimension or bounds validation happens here. */
        public ByteRankVectorsFunction(ScoreScript scoreScript, RankVectorsDocValuesField field, byte[][] queryVector) {
            super(scoreScript, field);
            this.queryVector = queryVector;
        }
    }

    /**
     * Common base for rank-vectors script functions: holds the script and the field, and moves the field
     * to the script's current document before scoring.
     */
    public static class RankVectorsFunction {
        protected final ScoreScript scoreScript;
        protected final RankVectorsDocValuesField field;

        public RankVectorsFunction(ScoreScript scoreScript, RankVectorsDocValuesField field) {
            this.field = field;
            this.scoreScript = scoreScript;
        }

        /**
         * Positions {@link #field} on the document the script is currently scoring.
         *
         * @throws IllegalArgumentException if that document has no value for the field
         */
        void setNextVector() {
            try {
                field.setNextDocId(scoreScript._getDocId());
            } catch (IOException ioException) {
                // Scoring methods declare no checked exceptions, so rethrow in converted form.
                throw ExceptionsHelper.convertToElastic(ioException);
            }
            boolean missingValue = field.isEmpty();
            if (missingValue) {
                throw new IllegalArgumentException("A document doesn't have a value for a multi-vector field!");
            }
        }
    }
}

