/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.script;

import java.io.IOException;
import java.util.HexFormat;
import java.util.List;
import org.elasticsearch.ExceptionsHelper;
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
import org.elasticsearch.script.ScoreScript;
import org.elasticsearch.script.field.vectors.DenseVector;
import org.elasticsearch.script.field.vectors.DenseVectorDocValuesField;

public class VectorScoreScriptUtils {

    public static final class CosineSimilarity {
        private final CosineSimilarityInterface function;

        public CosineSimilarity(ScoreScript scoreScript, Object queryVector, String fieldName) {
            DenseVectorDocValuesField field = (DenseVectorDocValuesField)scoreScript.field(fieldName);
            this.function = switch (field.getElementType()) {
                default -> throw new MatchException(null, null);
                case DenseVectorFieldMapper.ElementType.BYTE, DenseVectorFieldMapper.ElementType.BIT -> {
                    if (queryVector instanceof List) {
                        yield new ByteCosineSimilarity(scoreScript, field, (List)queryVector);
                    }
                    if (queryVector instanceof String) {
                        String s = (String)queryVector;
                        byte[] parsedQueryVector = HexFormat.of().parseHex(s);
                        yield new ByteCosineSimilarity(scoreScript, field, parsedQueryVector);
                    }
                    throw new IllegalArgumentException("Unsupported input object for byte vectors: " + queryVector.getClass().getName());
                }
                case DenseVectorFieldMapper.ElementType.FLOAT -> {
                    if (queryVector instanceof List) {
                        yield new FloatCosineSimilarity(scoreScript, field, (List)queryVector);
                    }
                    throw new IllegalArgumentException("Unsupported input object for float vectors: " + queryVector.getClass().getName());
                }
            };
        }

        public double cosineSimilarity() {
            return this.function.cosineSimilarity();
        }
    }

    /**
     * Cosine similarity between a float query vector and the float vector of the current document.
     * The query vector is normalized to unit length once, when the function is constructed.
     */
    public static class FloatCosineSimilarity extends FloatDenseVectorFunction implements CosineSimilarityInterface {

        public FloatCosineSimilarity(ScoreScript scoreScript, DenseVectorDocValuesField field, List<Number> queryVector) {
            // normalizeQuery = true: the stored query is pre-divided by its magnitude
            super(scoreScript, field, queryVector, true);
        }

        /** Advances to the current document and scores it against the pre-normalized query. */
        @Override
        public double cosineSimilarity() {
            setNextVector();
            return field.get().cosineSimilarity(queryVector, false);
        }
    }

    /**
     * Cosine similarity against a byte (or bit) vector field.
     *
     * <p>A {@code List} query may turn out to be byte-valued or float-valued (allowed types BYTE and FLOAT);
     * a float-valued query is normalized at construction. A {@code byte[]} query is always used as bytes.
     */
    public static class ByteCosineSimilarity extends ByteDenseVectorFunction implements CosineSimilarityInterface {

        public ByteCosineSimilarity(ScoreScript scoreScript, DenseVectorDocValuesField field, List<Number> queryVector) {
            super(
                scoreScript,
                field,
                queryVector,
                true,
                DenseVectorFieldMapper.ElementType.BYTE,
                DenseVectorFieldMapper.ElementType.FLOAT
            );
        }

        public ByteCosineSimilarity(ScoreScript scoreScript, DenseVectorDocValuesField field, byte[] queryVector) {
            super(scoreScript, field, queryVector);
        }

        /** Scores the current document using whichever query representation was built. */
        @Override
        public double cosineSimilarity() {
            setNextVector();
            if (floatQueryVector == null) {
                // byte query: the precomputed magnitude is passed alongside the raw bytes
                return field.get().cosineSimilarity(byteQueryVector, qvMagnitude);
            }
            return field.get().cosineSimilarity(floatQueryVector, false);
        }
    }

    /** A cosine similarity computation against the vector of the document currently being scored. */
    @FunctionalInterface
    public interface CosineSimilarityInterface {
        double cosineSimilarity();
    }

    public static final class DotProduct {
        private final DotProductInterface function;

        public DotProduct(ScoreScript scoreScript, Object queryVector, String fieldName) {
            DenseVectorDocValuesField field = (DenseVectorDocValuesField)scoreScript.field(fieldName);
            this.function = switch (field.getElementType()) {
                default -> throw new MatchException(null, null);
                case DenseVectorFieldMapper.ElementType.BIT -> {
                    if (queryVector instanceof List) {
                        yield new BitDotProduct(scoreScript, field, (List)queryVector);
                    }
                    if (queryVector instanceof String) {
                        String s = (String)queryVector;
                        byte[] parsedQueryVector = HexFormat.of().parseHex(s);
                        yield new BitDotProduct(scoreScript, field, parsedQueryVector);
                    }
                    throw new IllegalArgumentException("Unsupported input object for bit vectors: " + queryVector.getClass().getName());
                }
                case DenseVectorFieldMapper.ElementType.BYTE -> {
                    if (queryVector instanceof List) {
                        yield new ByteDotProduct(scoreScript, field, (List)queryVector);
                    }
                    if (queryVector instanceof String) {
                        String s = (String)queryVector;
                        byte[] parsedQueryVector = HexFormat.of().parseHex(s);
                        yield new ByteDotProduct(scoreScript, field, parsedQueryVector);
                    }
                    throw new IllegalArgumentException("Unsupported input object for byte vectors: " + queryVector.getClass().getName());
                }
                case DenseVectorFieldMapper.ElementType.FLOAT -> {
                    if (queryVector instanceof List) {
                        yield new FloatDotProduct(scoreScript, field, (List)queryVector);
                    }
                    throw new IllegalArgumentException("Unsupported input object for float vectors: " + queryVector.getClass().getName());
                }
            };
        }

        public double dotProduct() {
            return this.function.dotProduct();
        }
    }

    /** Dot product between a float query vector (not normalized) and the float vector of the current document. */
    public static class FloatDotProduct extends FloatDenseVectorFunction implements DotProductInterface {

        public FloatDotProduct(ScoreScript scoreScript, DenseVectorDocValuesField field, List<Number> queryVector) {
            // normalizeQuery = false: the raw query values are kept
            super(scoreScript, field, queryVector, false);
        }

        /** Advances to the current document and returns its dot product with the query. */
        @Override
        public double dotProduct() {
            setNextVector();
            return field.get().dotProduct(queryVector);
        }
    }

    /**
     * Dot product against a byte vector field.
     *
     * <p>A {@code List} query may resolve to byte or float values (allowed types BYTE and FLOAT) and is not
     * normalized. A {@code byte[]} query is used as-is.
     */
    public static class ByteDotProduct extends ByteDenseVectorFunction implements DotProductInterface {

        public ByteDotProduct(ScoreScript scoreScript, DenseVectorDocValuesField field, List<Number> queryVector) {
            super(
                scoreScript,
                field,
                queryVector,
                false,
                DenseVectorFieldMapper.ElementType.BYTE,
                DenseVectorFieldMapper.ElementType.FLOAT
            );
        }

        public ByteDotProduct(ScoreScript scoreScript, DenseVectorDocValuesField field, byte[] queryVector) {
            super(scoreScript, field, queryVector);
        }

        /** Scores the current document using whichever query representation was built. */
        @Override
        public double dotProduct() {
            setNextVector();
            if (floatQueryVector == null) {
                return field.get().dotProduct(byteQueryVector);
            }
            return field.get().dotProduct(floatQueryVector);
        }
    }

    /**
     * Dot product against a bit vector field.
     *
     * <p>Three query lengths are accepted:
     * <ul>
     *   <li>a byte query of {@code dims / 8} bytes (bitwise);</li>
     *   <li>a byte query of {@code dims} bytes (byte-wise);</li>
     *   <li>for {@code List} input that is not entirely whole numbers in [-128, 127], a float query of
     *       exactly {@code dims} values.</li>
     * </ul>
     */
    public static class BitDotProduct extends DenseVectorFunction implements DotProductInterface {
        private final byte[] byteQueryVector;
        private final float[] floatQueryVector;

        public BitDotProduct(ScoreScript scoreScript, DenseVectorDocValuesField field, byte[] queryVector) {
            super(scoreScript, field);
            requireBitField(field);
            checkByteQueryDims(field.get().getDims(), queryVector.length);
            this.byteQueryVector = queryVector;
            this.floatQueryVector = null;
        }

        public BitDotProduct(ScoreScript scoreScript, DenseVectorDocValuesField field, List<Number> queryVector) {
            super(scoreScript, field);
            requireBitField(field);
            int count = queryVector.size();
            float[] asFloats = new float[count];
            byte[] asBytes = new byte[count];
            // Once a single value cannot be represented exactly as a byte, the whole query is treated as floats.
            boolean needsFloats = false;
            for (int idx = 0; idx < count; idx++) {
                Number element = queryVector.get(idx);
                asFloats[idx] = element.floatValue();
                asBytes[idx] = element.byteValue();
                needsFloats = needsFloats || isWholeByte(asFloats[idx]) == false;
            }
            int fieldDims = field.get().getDims();
            if (needsFloats) {
                this.floatQueryVector = asFloats;
                this.byteQueryVector = null;
                if (fieldDims != count) {
                    throw new IllegalArgumentException(
                        "The query vector has an incorrect number of dimensions. Must be ["
                            + fieldDims
                            + "] for float wise operations: provided ["
                            + count
                            + "]."
                    );
                }
            } else {
                this.floatQueryVector = null;
                this.byteQueryVector = asBytes;
                checkByteQueryDims(fieldDims, count);
            }
        }

        /** Scores the current document with the byte query when one was built, otherwise with the float query. */
        @Override
        public double dotProduct() {
            setNextVector();
            if (byteQueryVector != null) {
                return field.get().dotProduct(byteQueryVector);
            }
            return field.get().dotProduct(floatQueryVector);
        }

        /** Rejects fields whose element type is not BIT. */
        private static void requireBitField(DenseVectorDocValuesField field) {
            if (field.getElementType() != DenseVectorFieldMapper.ElementType.BIT) {
                throw new IllegalArgumentException("cannot calculate bit dot product for non-bit vectors");
            }
        }

        /** Accepts a byte query of either {@code dims / 8} (bitwise) or {@code dims} (byte-wise) entries. */
        private static void checkByteQueryDims(int fieldDims, int length) {
            if (fieldDims != length * 8 && fieldDims != length) {
                throw new IllegalArgumentException(
                    "The query vector has an incorrect number of dimensions. Must be ["
                        + fieldDims / 8
                        + "] for bitwise operations, or ["
                        + fieldDims
                        + "] for byte wise operations: provided ["
                        + length
                        + "]."
                );
            }
        }

        /** True when the value is a whole number inside the signed byte range; NaN fails the modulo test. */
        private static boolean isWholeByte(float value) {
            return value % 1.0f == 0.0f && value >= -128.0f && value <= 127.0f;
        }
    }

    /** A dot product computation against the vector of the document currently being scored. */
    @FunctionalInterface
    public interface DotProductInterface {
        double dotProduct();
    }

    public static final class L2Norm {
        private final L2NormInterface function;

        public L2Norm(ScoreScript scoreScript, Object queryVector, String fieldName) {
            DenseVectorDocValuesField field = (DenseVectorDocValuesField)scoreScript.field(fieldName);
            this.function = switch (field.getElementType()) {
                default -> throw new MatchException(null, null);
                case DenseVectorFieldMapper.ElementType.BYTE, DenseVectorFieldMapper.ElementType.BIT -> {
                    if (queryVector instanceof List) {
                        yield new ByteL2Norm(scoreScript, field, (List)queryVector);
                    }
                    if (queryVector instanceof String) {
                        String s = (String)queryVector;
                        byte[] parsedQueryVector = HexFormat.of().parseHex(s);
                        yield new ByteL2Norm(scoreScript, field, parsedQueryVector);
                    }
                    throw new IllegalArgumentException("Unsupported input object for byte vectors: " + queryVector.getClass().getName());
                }
                case DenseVectorFieldMapper.ElementType.FLOAT -> {
                    if (queryVector instanceof List) {
                        yield new FloatL2Norm(scoreScript, field, (List)queryVector);
                    }
                    throw new IllegalArgumentException("Unsupported input object for float vectors: " + queryVector.getClass().getName());
                }
            };
        }

        public double l2norm() {
            return this.function.l2norm();
        }
    }

    /** L2 distance between a float query vector (not normalized) and the float vector of the current document. */
    public static class FloatL2Norm extends FloatDenseVectorFunction implements L2NormInterface {

        public FloatL2Norm(ScoreScript scoreScript, DenseVectorDocValuesField field, List<Number> queryVector) {
            super(scoreScript, field, queryVector, false);
        }

        /** Advances to the current document and returns its L2 distance to the query. */
        @Override
        public double l2norm() {
            setNextVector();
            return field.get().l2Norm(queryVector);
        }
    }

    /**
     * L2 distance against a byte (or bit) vector field. Only BYTE is listed as an allowed type for
     * {@code List} input, so presumably only byte-valued lists are accepted (enforced by
     * {@code ElementType.checkValidVector}).
     */
    public static class ByteL2Norm extends ByteDenseVectorFunction implements L2NormInterface {

        public ByteL2Norm(ScoreScript scoreScript, DenseVectorDocValuesField field, List<Number> queryVector) {
            super(scoreScript, field, queryVector, false, DenseVectorFieldMapper.ElementType.BYTE);
        }

        public ByteL2Norm(ScoreScript scoreScript, DenseVectorDocValuesField field, byte[] queryVector) {
            super(scoreScript, field, queryVector);
        }

        /** Advances to the current document and returns its L2 distance to the byte query. */
        @Override
        public double l2norm() {
            setNextVector();
            return field.get().l2Norm(byteQueryVector);
        }
    }

    /** An L2 distance computation against the vector of the document currently being scored. */
    @FunctionalInterface
    public interface L2NormInterface {
        double l2norm();
    }

    /**
     * Script-facing Hamming distance between a query vector and a byte or bit vector field.
     * The query may be a {@link List} of numbers or a hex-encoded {@link String}.
     */
    public static final class Hamming {
        private final HammingDistanceInterface function;

        /**
         * @param scoreScript the script in which this function is referenced
         * @param queryVector the query vector: a {@code List} of numbers or a hex string
         * @param fieldName   name of the dense vector field to compare against
         * @throws IllegalArgumentException if the field holds float vectors, or the query vector is {@code null}
         *                                  or of an unsupported type
         */
        @SuppressWarnings("unchecked")
        public Hamming(ScoreScript scoreScript, Object queryVector, String fieldName) {
            DenseVectorDocValuesField field = (DenseVectorDocValuesField) scoreScript.field(fieldName);
            if (field.getElementType() == DenseVectorFieldMapper.ElementType.FLOAT) {
                throw new IllegalArgumentException("hamming distance is only supported for byte or bit vectors");
            }
            if (queryVector instanceof List<?> list) {
                this.function = new ByteHammingDistance(scoreScript, field, (List<Number>) list);
            } else if (queryVector instanceof String hex) {
                this.function = new ByteHammingDistance(scoreScript, field, HexFormat.of().parseHex(hex));
            } else {
                // queryVector may be null here; getClass() on it would throw an NPE instead of this message
                String typeName = queryVector == null ? "null" : queryVector.getClass().getName();
                throw new IllegalArgumentException("Unsupported input object for byte vectors: " + typeName);
            }
        }

        /** Computes the Hamming distance for the document the script is currently scoring. */
        public double hamming() {
            return this.function.hamming();
        }
    }

    /**
     * Hamming distance against a byte or bit vector field. {@code List} input lists only BYTE as an allowed
     * type, so presumably it must be byte-valued.
     */
    public static class ByteHammingDistance extends ByteDenseVectorFunction implements HammingDistanceInterface {

        public ByteHammingDistance(ScoreScript scoreScript, DenseVectorDocValuesField field, List<Number> queryVector) {
            super(scoreScript, field, queryVector, false, DenseVectorFieldMapper.ElementType.BYTE);
        }

        public ByteHammingDistance(ScoreScript scoreScript, DenseVectorDocValuesField field, byte[] queryVector) {
            super(scoreScript, field, queryVector);
        }

        /** Advances to the current document and returns its Hamming distance to the byte query. */
        @Override
        public int hamming() {
            setNextVector();
            return field.get().hamming(byteQueryVector);
        }
    }

    /** A Hamming distance computation against the vector of the document currently being scored. */
    @FunctionalInterface
    public interface HammingDistanceInterface {
        int hamming();
    }

    public static final class L1Norm {
        private final L1NormInterface function;

        public L1Norm(ScoreScript scoreScript, Object queryVector, String fieldName) {
            DenseVectorDocValuesField field = (DenseVectorDocValuesField)scoreScript.field(fieldName);
            this.function = switch (field.getElementType()) {
                default -> throw new MatchException(null, null);
                case DenseVectorFieldMapper.ElementType.BYTE, DenseVectorFieldMapper.ElementType.BIT -> {
                    if (queryVector instanceof List) {
                        yield new ByteL1Norm(scoreScript, field, (List)queryVector);
                    }
                    if (queryVector instanceof String) {
                        String s = (String)queryVector;
                        byte[] parsedQueryVector = HexFormat.of().parseHex(s);
                        yield new ByteL1Norm(scoreScript, field, parsedQueryVector);
                    }
                    throw new IllegalArgumentException("Unsupported input object for byte vectors: " + queryVector.getClass().getName());
                }
                case DenseVectorFieldMapper.ElementType.FLOAT -> {
                    if (queryVector instanceof List) {
                        yield new FloatL1Norm(scoreScript, field, (List)queryVector);
                    }
                    throw new IllegalArgumentException("Unsupported input object for float vectors: " + queryVector.getClass().getName());
                }
            };
        }

        public double l1norm() {
            return this.function.l1norm();
        }
    }

    /** L1 distance between a float query vector (not normalized) and the float vector of the current document. */
    public static class FloatL1Norm extends FloatDenseVectorFunction implements L1NormInterface {

        public FloatL1Norm(ScoreScript scoreScript, DenseVectorDocValuesField field, List<Number> queryVector) {
            super(scoreScript, field, queryVector, false);
        }

        /** Advances to the current document and returns its L1 distance to the query. */
        @Override
        public double l1norm() {
            setNextVector();
            return field.get().l1Norm(queryVector);
        }
    }

    /**
     * L1 distance against a byte (or bit) vector field. {@code List} input lists only BYTE as an allowed
     * type, so presumably it must be byte-valued.
     */
    public static class ByteL1Norm extends ByteDenseVectorFunction implements L1NormInterface {

        public ByteL1Norm(ScoreScript scoreScript, DenseVectorDocValuesField field, List<Number> queryVector) {
            super(scoreScript, field, queryVector, false, DenseVectorFieldMapper.ElementType.BYTE);
        }

        public ByteL1Norm(ScoreScript scoreScript, DenseVectorDocValuesField field, byte[] queryVector) {
            super(scoreScript, field, queryVector);
        }

        /** Advances to the current document and returns its L1 distance to the byte query. */
        @Override
        public double l1norm() {
            setNextVector();
            return field.get().l1Norm(byteQueryVector);
        }
    }

    /** An L1 distance computation against the vector of the document currently being scored. */
    @FunctionalInterface
    public interface L1NormInterface {
        double l1norm();
    }

    /**
     * Base class for functions over float vector fields. It copies the query into a {@code float[]} after
     * checking its dimensions, then checks the values against the field's element-type bounds. Optionally it
     * divides every component by the query's Euclidean magnitude.
     *
     * <p>NOTE(review): with {@code normalizeQuery} set, an all-zero query divides 0 by 0 and yields NaN components.
     */
    public static class FloatDenseVectorFunction extends DenseVectorFunction {
        protected final float[] queryVector;

        public FloatDenseVectorFunction(
            ScoreScript scoreScript,
            DenseVectorDocValuesField field,
            List<Number> queryVector,
            boolean normalizeQuery
        ) {
            super(scoreScript, field);
            DenseVector.checkDimensions(field.get().getDims(), queryVector.size());

            float[] values = new float[queryVector.size()];
            double sumOfSquares = 0.0;
            for (int pos = 0; pos < queryVector.size(); pos++) {
                float component = queryVector.get(pos).floatValue();
                values[pos] = component;
                sumOfSquares += component * component; // float product, accumulated in double
            }
            double magnitude = Math.sqrt(sumOfSquares);

            field.getElementType().checkVectorBounds(values);

            if (normalizeQuery) {
                float divisor = (float) magnitude;
                for (int pos = 0; pos < values.length; pos++) {
                    values[pos] /= divisor;
                }
            }
            this.queryVector = values;
        }
    }

    /**
     * Base class for functions over byte (and bit) vector fields.
     *
     * <p>Exactly one of {@link #byteQueryVector} or {@link #floatQueryVector} is non-null after construction.
     * {@link #qvMagnitude} holds the Euclidean magnitude of a byte query. It is {@code -1} when the query
     * was resolved to floats.
     */
    public static class ByteDenseVectorFunction extends DenseVectorFunction {
        protected final byte[] byteQueryVector;
        protected final float[] floatQueryVector;
        protected final float qvMagnitude;

        /**
         * Builds the query from a list of numbers.
         *
         * <p>{@code ElementType.checkValidVector} decides whether the values are kept as bytes or as floats,
         * restricted to {@code allowedTypes}.
         *
         * @param normalizeFloatQuery when the query resolves to floats, divide it by its magnitude
         * @param allowedTypes        element types the list may resolve to
         */
        public ByteDenseVectorFunction(
            ScoreScript scoreScript,
            DenseVectorDocValuesField field,
            List<Number> queryVector,
            boolean normalizeFloatQuery,
            DenseVectorFieldMapper.ElementType... allowedTypes
        ) {
            super(scoreScript, field);
            field.getElementType().checkDimensions(field.get().getDims(), queryVector.size());

            float[] floatValues = new float[queryVector.size()];
            double queryMagnitude = 0.0;
            for (int i = 0; i < floatValues.length; i++) {
                float value = queryVector.get(i).floatValue();
                floatValues[i] = value;
                queryMagnitude += value * value;
            }
            queryMagnitude = Math.sqrt(queryMagnitude);

            switch (DenseVectorFieldMapper.ElementType.checkValidVector(floatValues, allowedTypes)) {
                case FLOAT -> {
                    this.byteQueryVector = null;
                    this.floatQueryVector = floatValues;
                    this.qvMagnitude = -1.0f; // sentinel: magnitude is not used for float queries
                    if (normalizeFloatQuery) {
                        float divisor = (float) queryMagnitude;
                        for (int i = 0; i < floatValues.length; i++) {
                            floatValues[i] /= divisor;
                        }
                    }
                }
                case BYTE -> {
                    this.floatQueryVector = null;
                    byte[] bytes = new byte[floatValues.length];
                    for (int i = 0; i < floatValues.length; i++) {
                        bytes[i] = (byte) floatValues[i];
                    }
                    this.byteQueryVector = bytes;
                    this.qvMagnitude = (float) queryMagnitude;
                }
                default -> throw new AssertionError("Unexpected element type");
            }
        }

        /**
         * Builds a byte query from raw bytes. The dispatchers in this file pass hex-decoded strings here.
         *
         * <p>The length is validated against the field's dimensions with the same element-type check as the
         * {@code List} constructor. A wrong-length hex query therefore fails here, up front, instead of
         * surfacing per document during scoring.
         */
        public ByteDenseVectorFunction(ScoreScript scoreScript, DenseVectorDocValuesField field, byte[] queryVector) {
            super(scoreScript, field);
            field.getElementType().checkDimensions(field.get().getDims(), queryVector.length);
            this.byteQueryVector = queryVector;
            this.floatQueryVector = null;
            double queryMagnitude = 0.0;
            for (byte value : queryVector) {
                queryMagnitude += value * value;
            }
            this.qvMagnitude = (float) Math.sqrt(queryMagnitude);
        }
    }

    /** Common base holding the owning script and the vector field that every vector function reads from. */
    public static class DenseVectorFunction {
        protected final ScoreScript scoreScript;
        protected final DenseVectorDocValuesField field;

        public DenseVectorFunction(ScoreScript scoreScript, DenseVectorDocValuesField field) {
            this.scoreScript = scoreScript;
            this.field = field;
        }

        /**
         * Moves {@link #field} to the document the script is currently scoring.
         *
         * <p>An {@link IOException} from the field is converted with {@code ExceptionsHelper.convertToElastic}.
         *
         * @throws IllegalArgumentException if that document has no value for the vector field
         */
        void setNextVector() {
            try {
                int docId = scoreScript._getDocId();
                field.setNextDocId(docId);
            } catch (IOException e) {
                throw ExceptionsHelper.convertToElastic(e);
            }
            boolean missing = field.isEmpty();
            if (missing) {
                throw new IllegalArgumentException("A document doesn't have a value for a vector field!");
            }
        }
    }
}

