/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.search.vectors;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.join.BitSetProducer;
import org.apache.lucene.search.join.ToChildBlockJoinQuery;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.lucene.search.Queries;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.index.mapper.MappedFieldType;
import org.elasticsearch.index.mapper.NestedObjectMapper;
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
import org.elasticsearch.index.query.AbstractQueryBuilder;
import org.elasticsearch.index.query.MatchNoneQueryBuilder;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryRewriteContext;
import org.elasticsearch.index.query.SearchExecutionContext;
import org.elasticsearch.xcontent.ConstructingObjectParser;
import org.elasticsearch.xcontent.ObjectParser;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.ToXContent;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentParser;

/**
 * Builder for the {@code knn} query: a k-nearest-neighbour search over a {@code dense_vector} field.
 *
 * <p>The query can optionally be restricted by filter queries (applied as non-scoring filters,
 * together with the context's alias filter, if any) and by a similarity threshold.
 */
public class KnnVectorQueryBuilder extends AbstractQueryBuilder<KnnVectorQueryBuilder> {
    public static final String NAME = "knn";

    /** Maximum value accepted for {@code num_candidates}; enforced by the constructor. */
    private static final int NUM_CANDS_LIMIT = 10000;

    public static final ParseField FIELD_FIELD = new ParseField("field");
    public static final ParseField NUM_CANDS_FIELD = new ParseField("num_candidates");
    public static final ParseField QUERY_VECTOR_FIELD = new ParseField("query_vector");
    public static final ParseField VECTOR_SIMILARITY_FIELD = new ParseField("similarity");
    public static final ParseField FILTER_FIELD = new ParseField("filter");

    /**
     * Parser for the query body. Constructor arguments are positional, in the order they are
     * declared in the static block below: field, query_vector, num_candidates, similarity.
     */
    public static final ConstructingObjectParser<KnnVectorQueryBuilder, Void> PARSER = new ConstructingObjectParser<>(NAME, args -> {
        // query_vector is declared with declareFloatArray, so the parsed value is a List<Float>.
        @SuppressWarnings("unchecked")
        List<Float> vector = (List<Float>) args[1];
        float[] vectorArray = null;
        if (vector != null) {
            vectorArray = new float[vector.size()];
            for (int i = 0; i < vector.size(); i++) {
                vectorArray[i] = vector.get(i);
            }
        }
        return new KnnVectorQueryBuilder((String) args[0], vectorArray, (Integer) args[2], (Float) args[3]);
    });

    static {
        // Declaration order defines the positional indexes used by the PARSER lambda above.
        PARSER.declareString(ConstructingObjectParser.constructorArg(), FIELD_FIELD);
        PARSER.declareFloatArray(ConstructingObjectParser.constructorArg(), QUERY_VECTOR_FIELD);
        PARSER.declareInt(ConstructingObjectParser.constructorArg(), NUM_CANDS_FIELD);
        PARSER.declareFloat(ConstructingObjectParser.optionalConstructorArg(), VECTOR_SIMILARITY_FIELD);
        PARSER.declareFieldArray(
            KnnVectorQueryBuilder::addFilterQueries,
            (p, c) -> AbstractQueryBuilder.parseTopLevelQuery(p),
            FILTER_FIELD,
            ObjectParser.ValueType.OBJECT_ARRAY
        );
        KnnVectorQueryBuilder.declareStandardFields(PARSER);
    }

    private final String fieldName;
    private final float[] queryVector;
    private final int numCands;
    private final List<QueryBuilder> filterQueries = new ArrayList<>();
    private final Float vectorSimilarity;

    /**
     * Parses a {@code knn} query body with {@link #PARSER}.
     */
    public static KnnVectorQueryBuilder fromXContent(XContentParser parser) {
        return PARSER.apply(parser, null);
    }

    /**
     * Creates a kNN query builder.
     *
     * @param fieldName        name of the vector field to search
     * @param queryVector      the query vector; must not be {@code null}
     * @param numCands         value for {@code num_candidates}; must not exceed {@value #NUM_CANDS_LIMIT}
     * @param vectorSimilarity optional similarity threshold, may be {@code null}
     * @throws IllegalArgumentException if {@code numCands} exceeds the limit or {@code queryVector} is {@code null}
     */
    public KnnVectorQueryBuilder(String fieldName, float[] queryVector, int numCands, Float vectorSimilarity) {
        if (numCands > NUM_CANDS_LIMIT) {
            throw new IllegalArgumentException(
                "[" + NUM_CANDS_FIELD.getPreferredName() + "] cannot exceed [" + NUM_CANDS_LIMIT + "]"
            );
        }
        if (queryVector == null) {
            throw new IllegalArgumentException("[" + QUERY_VECTOR_FIELD.getPreferredName() + "] must be provided");
        }
        this.fieldName = fieldName;
        this.queryVector = queryVector;
        this.numCands = numCands;
        this.vectorSimilarity = vectorSimilarity;
    }

    /**
     * Deserializes a builder written by {@link #doWriteTo}. The layout depends on the sender's
     * transport version and must stay in sync with {@link #doWriteTo}.
     */
    public KnnVectorQueryBuilder(StreamInput in) throws IOException {
        super(in);
        this.fieldName = in.readString();
        this.numCands = in.readVInt();
        if (in.getTransportVersion().before(TransportVersions.V_8_7_0)
            || in.getTransportVersion().onOrAfter(TransportVersions.KNN_AS_QUERY_ADDED)) {
            this.queryVector = in.readFloatArray();
        } else {
            // Versions in [8.7.0, KNN_AS_QUERY_ADDED) frame the vector with a boolean before and after it.
            // NOTE(review): presumably presence flags of an older wire format; the values are ignored here.
            in.readBoolean();
            this.queryVector = in.readFloatArray();
            in.readBoolean();
        }
        if (in.getTransportVersion().onOrAfter(TransportVersions.V_8_2_0)) {
            this.filterQueries.addAll(KnnVectorQueryBuilder.readQueries(in));
        }
        this.vectorSimilarity = in.getTransportVersion().onOrAfter(TransportVersions.V_8_8_0) ? in.readOptionalFloat() : null;
    }

    public String getFieldName() {
        return this.fieldName;
    }

    @Nullable
    public float[] queryVector() {
        return this.queryVector;
    }

    @Nullable
    public Float getVectorSimilarity() {
        return this.vectorSimilarity;
    }

    public int numCands() {
        return this.numCands;
    }

    /** Returns the builder's own (mutable) filter list, not a copy. */
    public List<QueryBuilder> filterQueries() {
        return this.filterQueries;
    }

    /**
     * Appends a filter query.
     *
     * @throws NullPointerException if {@code filterQuery} is {@code null}
     */
    public KnnVectorQueryBuilder addFilterQuery(QueryBuilder filterQuery) {
        Objects.requireNonNull(filterQuery);
        this.filterQueries.add(filterQuery);
        return this;
    }

    /**
     * Appends all given filter queries.
     *
     * @throws NullPointerException if {@code filterQueries} is {@code null}
     */
    public KnnVectorQueryBuilder addFilterQueries(List<QueryBuilder> filterQueries) {
        Objects.requireNonNull(filterQueries);
        this.filterQueries.addAll(filterQueries);
        return this;
    }

    @Override
    protected void doWriteTo(StreamOutput out) throws IOException {
        out.writeString(this.fieldName);
        out.writeVInt(this.numCands);
        if (out.getTransportVersion().before(TransportVersions.V_8_7_0)
            || out.getTransportVersion().onOrAfter(TransportVersions.KNN_AS_QUERY_ADDED)) {
            out.writeFloatArray(this.queryVector);
        } else {
            // Mirror of the framing read by the StreamInput constructor for [8.7.0, KNN_AS_QUERY_ADDED).
            out.writeBoolean(true);
            out.writeFloatArray(this.queryVector);
            out.writeBoolean(false);
        }
        if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_2_0)) {
            KnnVectorQueryBuilder.writeQueries(out, this.filterQueries);
        }
        if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_8_0)) {
            out.writeOptionalFloat(this.vectorSimilarity);
        }
    }

    @Override
    protected void doXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
        builder.startObject(NAME);
        builder.field(FIELD_FIELD.getPreferredName(), this.fieldName);
        // The explicit cast keeps the call bound to the field(String, Object) overload.
        builder.field(QUERY_VECTOR_FIELD.getPreferredName(), (Object) this.queryVector);
        builder.field(NUM_CANDS_FIELD.getPreferredName(), this.numCands);
        if (this.vectorSimilarity != null) {
            builder.field(VECTOR_SIMILARITY_FIELD.getPreferredName(), this.vectorSimilarity);
        }
        if (!this.filterQueries.isEmpty()) {
            builder.startArray(FILTER_FIELD.getPreferredName());
            for (QueryBuilder filterQuery : this.filterQueries) {
                filterQuery.toXContent(builder, params);
            }
            builder.endArray();
        }
        this.boostAndQueryNameToXContent(builder);
        builder.endObject();
    }

    @Override
    public String getWriteableName() {
        return NAME;
    }

    /**
     * Rewrites each filter query. If any filter rewrites to a {@link MatchNoneQueryBuilder}, that
     * builder is returned in place of this query. If any filter changed, a new builder with the same
     * settings, boost and name is returned; otherwise {@code this}.
     */
    @Override
    protected QueryBuilder doRewrite(QueryRewriteContext ctx) throws IOException {
        boolean changed = false;
        List<QueryBuilder> rewrittenQueries = new ArrayList<>(this.filterQueries.size());
        for (QueryBuilder query : this.filterQueries) {
            QueryBuilder rewrittenQuery = query.rewrite(ctx);
            if (rewrittenQuery instanceof MatchNoneQueryBuilder) {
                return rewrittenQuery;
            }
            if (rewrittenQuery != query) {
                changed = true;
            }
            rewrittenQueries.add(rewrittenQuery);
        }
        if (!changed) {
            return this;
        }
        return new KnnVectorQueryBuilder(this.fieldName, this.queryVector, this.numCands, this.vectorSimilarity)
            .boost(this.boost)
            .queryName(this.queryName)
            .addFilterQueries(rewrittenQueries);
    }

    /**
     * Builds the Lucene kNN query for {@link #fieldName}.
     *
     * <p>When the vector field lives under a nested object, the filters are evaluated against the
     * parent documents and joined down to their children with {@link ToChildBlockJoinQuery}.
     *
     * @throws IllegalArgumentException if the field is unmapped or is not a {@code dense_vector} field
     */
    @Override
    protected Query doToQuery(SearchExecutionContext context) throws IOException {
        MappedFieldType fieldType = context.getFieldType(this.fieldName);
        if (fieldType == null) {
            throw new IllegalArgumentException("field [" + this.fieldName + "] does not exist in the mapping");
        }
        if (!(fieldType instanceof DenseVectorFieldMapper.DenseVectorFieldType)) {
            throw new IllegalArgumentException("[knn] queries are only supported on [dense_vector] fields");
        }
        Query filterQuery = buildFilterQuery(context);
        DenseVectorFieldMapper.DenseVectorFieldType vectorFieldType = (DenseVectorFieldMapper.DenseVectorFieldType) fieldType;

        String parentPath = context.nestedLookup().getNestedParent(this.fieldName);
        if (parentPath == null) {
            return vectorFieldType.createKnnQuery(this.queryVector, this.numCands, filterQuery, this.vectorSimilarity, null);
        }

        Query parentFilter = resolveParentFilter(context);
        BitSetProducer parentBitSet = context.bitsetFilter(parentFilter);
        if (filterQuery != null) {
            // ToChildBlockJoinQuery requires its parent query to match only documents accepted by the
            // parent filter; a filter that can also match nested documents (e.g. match_all or a
            // must_not) would otherwise fail at search time. Restrict it to parent documents first.
            Query parentsOnly = new BooleanQuery.Builder()
                .add(filterQuery, BooleanClause.Occur.FILTER)
                .add(parentFilter, BooleanClause.Occur.FILTER)
                .build();
            filterQuery = new ToChildBlockJoinQuery(parentsOnly, parentBitSet);
        }
        return vectorFieldType.createKnnQuery(this.queryVector, this.numCands, filterQuery, this.vectorSimilarity, parentBitSet);
    }

    /**
     * Combines the user filters and the context's alias filter (if any) into one conjunction of
     * FILTER clauses; returns {@code null} when there is nothing to filter on.
     */
    private Query buildFilterQuery(SearchExecutionContext context) throws IOException {
        BooleanQuery.Builder builder = new BooleanQuery.Builder();
        for (QueryBuilder query : this.filterQueries) {
            builder.add(query.toQuery(context), BooleanClause.Occur.FILTER);
        }
        if (context.getAliasFilter() != null) {
            builder.add(context.getAliasFilter().toQuery(context), BooleanClause.Occur.FILTER);
        }
        BooleanQuery filters = builder.build();
        return filters.clauses().isEmpty() ? null : filters;
    }

    /**
     * Returns the query that identifies the parent documents of the current nesting level.
     *
     * <p>Outside a nested scope, this is the non-nested (top-level) document filter. Inside one, it
     * steps up a level to read the enclosing object mapper, then restores the original scope.
     */
    private static Query resolveParentFilter(SearchExecutionContext context) {
        NestedObjectMapper originalObjectMapper = context.nestedScope().getObjectMapper();
        if (originalObjectMapper == null) {
            return Queries.newNonNestedFilter(context.indexVersionCreated());
        }
        try {
            context.nestedScope().previousLevel();
            NestedObjectMapper objectMapper = context.nestedScope().getObjectMapper();
            return objectMapper == null
                ? Queries.newNonNestedFilter(context.indexVersionCreated())
                : objectMapper.nestedTypeFilter();
        } finally {
            // Always restore the scope we entered with, even if resolving the mapper failed.
            context.nestedScope().nextLevel(originalObjectMapper);
        }
    }

    @Override
    protected int doHashCode() {
        return Objects.hash(this.fieldName, Arrays.hashCode(this.queryVector), this.numCands, this.filterQueries, this.vectorSimilarity);
    }

    @Override
    protected boolean doEquals(KnnVectorQueryBuilder other) {
        return Objects.equals(this.fieldName, other.fieldName)
            && Arrays.equals(this.queryVector, other.queryVector)
            && this.numCands == other.numCands
            && Objects.equals(this.filterQueries, other.filterQueries)
            && Objects.equals(this.vectorSimilarity, other.vectorSimilarity);
    }

    @Override
    public TransportVersion getMinimalSupportedVersion() {
        return TransportVersions.V_8_0_0;
    }
}

