/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.core.ml.search;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import org.apache.lucene.search.MatchNoDocsQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.util.SetOnce;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.common.ParsingException;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.index.mapper.MappedFieldType;
import org.elasticsearch.index.query.AbstractQueryBuilder;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryRewriteContext;
import org.elasticsearch.index.query.SearchExecutionContext;
import org.elasticsearch.inference.InferenceResults;
import org.elasticsearch.xcontent.ConstructingObjectParser;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.ToXContent;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ClientHelper;
import org.elasticsearch.xpack.core.ml.action.CoordinatedInferenceAction;
import org.elasticsearch.xpack.core.ml.action.InferModelAction;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelPrefixStrings;
import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults;
import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextExpansionConfigUpdate;
import org.elasticsearch.xpack.core.ml.search.TokenPruningConfig;
import org.elasticsearch.xpack.core.ml.search.WeightedToken;
import org.elasticsearch.xpack.core.ml.search.WeightedTokensUtils;

/**
 * Query builder for the {@code sparse_vector} query.
 *
 * <p>A query is described either by an explicit list of weighted tokens ({@code query_vector}) or by a
 * query text ({@code query}) that is expanded into weighted tokens during rewrite by calling the
 * inference endpoint named by {@code inference_id}. Tokens may optionally be pruned before the Lucene
 * query is built.
 */
public class SparseVectorQueryBuilder extends AbstractQueryBuilder<SparseVectorQueryBuilder> {
    public static final String NAME = "sparse_vector";
    public static final String ALLOWED_FIELD_TYPE = "sparse_vector";
    public static final ParseField FIELD_FIELD = new ParseField("field");
    public static final ParseField QUERY_VECTOR_FIELD = new ParseField("query_vector");
    public static final ParseField INFERENCE_ID_FIELD = new ParseField("inference_id");
    public static final ParseField QUERY_FIELD = new ParseField("query");
    public static final ParseField PRUNE_FIELD = new ParseField("prune");
    public static final ParseField PRUNING_CONFIG_FIELD = new ParseField("pruning_config");

    /** Pruning is off unless the request explicitly enables it. */
    private static final boolean DEFAULT_PRUNE = false;

    private static final ConstructingObjectParser<SparseVectorQueryBuilder, Void> PARSER = new ConstructingObjectParser<>(NAME, a -> {
        String fieldName = (String) a[0];
        // The query_vector object is parsed with XContentParser#map(), so the keys are strings.
        @SuppressWarnings("unchecked")
        List<WeightedToken> weightedTokens = parseWeightedTokens((Map<String, Object>) a[1]);
        String inferenceId = (String) a[2];
        String text = (String) a[3];
        Boolean shouldPruneTokens = (Boolean) a[4];
        TokenPruningConfig tokenPruningConfig = (TokenPruningConfig) a[5];
        return new SparseVectorQueryBuilder(fieldName, weightedTokens, inferenceId, text, shouldPruneTokens, tokenPruningConfig);
    });

    static {
        // The declaration order must match the argument indices used by the PARSER lambda above.
        PARSER.declareString(ConstructingObjectParser.constructorArg(), FIELD_FIELD);
        PARSER.declareObject(ConstructingObjectParser.optionalConstructorArg(), (p, c) -> p.map(), QUERY_VECTOR_FIELD);
        PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), INFERENCE_ID_FIELD);
        PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), QUERY_FIELD);
        PARSER.declareBoolean(ConstructingObjectParser.optionalConstructorArg(), PRUNE_FIELD);
        PARSER.declareObject(
            ConstructingObjectParser.optionalConstructorArg(),
            (p, c) -> TokenPruningConfig.fromXContent(p),
            PRUNING_CONFIG_FIELD
        );
        declareStandardFields(PARSER);
    }

    private final String fieldName;
    private final List<WeightedToken> queryVectors;
    private final String inferenceId;
    private final String query;
    private final boolean shouldPruneTokens;
    /** Holder for the in-flight inference result; only non-null on the intermediate builder returned by {@link #doRewrite}. */
    private final SetOnce<TextExpansionResults> weightedTokensSupplier;
    @Nullable
    private final TokenPruningConfig tokenPruningConfig;

    /**
     * Creates a builder that expands {@code query} through the given inference endpoint, without pruning.
     *
     * @param fieldName   the field to search, must not be {@code null}
     * @param inferenceId the inference endpoint used to expand the query text
     * @param query       the query text
     */
    public SparseVectorQueryBuilder(String fieldName, String inferenceId, String query) {
        this(fieldName, null, inferenceId, query, DEFAULT_PRUNE, null);
    }

    /**
     * Creates a builder from either explicit query vectors or a query text.
     *
     * @param fieldName          the field to search, must not be {@code null}
     * @param queryVectors       explicit weighted tokens, or {@code null} when {@code query} is given
     * @param inferenceId        the inference endpoint used to expand {@code query}; must be {@code null}
     *                           when {@code queryVectors} is given
     * @param query              the query text, or {@code null} when {@code queryVectors} is given
     * @param shouldPruneTokens  whether to prune tokens; {@code null} means no pruning
     * @param tokenPruningConfig pruning settings; when {@code null} and pruning is enabled a default
     *                           {@link TokenPruningConfig} is used
     * @throws NullPointerException     if {@code fieldName} is {@code null}
     * @throws IllegalArgumentException if both {@code queryVectors} and {@code inferenceId} are set, or if
     *                                  {@code queryVectors} and {@code query} are both set or both missing
     */
    public SparseVectorQueryBuilder(
        String fieldName,
        @Nullable List<WeightedToken> queryVectors,
        @Nullable String inferenceId,
        @Nullable String query,
        @Nullable Boolean shouldPruneTokens,
        @Nullable TokenPruningConfig tokenPruningConfig
    ) {
        this.fieldName = Objects.requireNonNull(fieldName, "[" + NAME + "] requires a [" + FIELD_FIELD.getPreferredName() + "]");
        this.shouldPruneTokens = shouldPruneTokens != null ? shouldPruneTokens : DEFAULT_PRUNE;
        this.queryVectors = queryVectors;
        this.inferenceId = inferenceId;
        this.query = query;
        if (tokenPruningConfig != null) {
            this.tokenPruningConfig = tokenPruningConfig;
        } else {
            this.tokenPruningConfig = this.shouldPruneTokens ? new TokenPruningConfig() : null;
        }
        this.weightedTokensSupplier = null;

        // Both checks intentionally report the same message.
        if (queryVectors != null && inferenceId != null) {
            throw new IllegalArgumentException(requiresOneOfMessage());
        }
        // Exactly one of query vectors / query text must be present.
        if ((queryVectors == null) == (query == null)) {
            throw new IllegalArgumentException(requiresOneOfMessage());
        }
    }

    /**
     * Reads a builder from the wire; the field order mirrors {@link #doWriteTo(StreamOutput)}.
     *
     * @param in the stream to read from
     * @throws IOException if reading from the stream fails
     */
    public SparseVectorQueryBuilder(StreamInput in) throws IOException {
        super(in);
        this.fieldName = in.readString();
        this.shouldPruneTokens = in.readBoolean();
        this.queryVectors = in.readOptionalCollectionAsList(WeightedToken::new);
        this.inferenceId = in.readOptionalString();
        this.query = in.readOptionalString();
        this.tokenPruningConfig = in.readOptionalWriteable(TokenPruningConfig::new);
        this.weightedTokensSupplier = null;
    }

    /** Copies {@code other} and attaches the holder that will receive the inference result. */
    private SparseVectorQueryBuilder(SparseVectorQueryBuilder other, SetOnce<TextExpansionResults> weightedTokensSupplier) {
        this.fieldName = other.fieldName;
        this.shouldPruneTokens = other.shouldPruneTokens;
        this.queryVectors = other.queryVectors;
        this.inferenceId = other.inferenceId;
        this.query = other.query;
        this.tokenPruningConfig = other.tokenPruningConfig;
        this.weightedTokensSupplier = weightedTokensSupplier;
    }

    /** Builds the validation message shared by both constructor argument checks. */
    private static String requiresOneOfMessage() {
        return "["
            + NAME
            + "] requires one of ["
            + QUERY_VECTOR_FIELD.getPreferredName()
            + "] or ["
            + INFERENCE_ID_FIELD.getPreferredName()
            + "] for "
            + ALLOWED_FIELD_TYPE
            + " fields";
    }

    /** @return the name of the field being searched */
    public String getFieldName() {
        return this.fieldName;
    }

    /** @return the explicit weighted tokens, or {@code null} if the query is defined by text */
    public List<WeightedToken> getQueryVectors() {
        return this.queryVectors;
    }

    /** @return the inference endpoint id, or {@code null} if none was given */
    public String getInferenceId() {
        return this.inferenceId;
    }

    /** @return the query text, or {@code null} if the query is defined by explicit vectors */
    public String getQuery() {
        return this.query;
    }

    /** @return {@code true} if tokens are pruned when the Lucene query is built */
    public boolean shouldPruneTokens() {
        return this.shouldPruneTokens;
    }

    /** @return the pruning settings, or {@code null} if none apply */
    public TokenPruningConfig getTokenPruningConfig() {
        return this.tokenPruningConfig;
    }

    /**
     * Writes this builder to the wire.
     *
     * @throws IllegalStateException if this is the intermediate rewrite result that still holds an
     *                               inference result holder, which cannot be serialized
     */
    @Override
    protected void doWriteTo(StreamOutput out) throws IOException {
        if (this.weightedTokensSupplier != null) {
            throw new IllegalStateException("weighted tokens supplier must be null, can't serialize suppliers, missing a rewriteAndFetch?");
        }
        out.writeString(this.fieldName);
        out.writeBoolean(this.shouldPruneTokens);
        out.writeOptionalCollection(this.queryVectors);
        out.writeOptionalString(this.inferenceId);
        out.writeOptionalString(this.query);
        out.writeOptionalWriteable(this.tokenPruningConfig);
    }

    /**
     * Renders this query as XContent. Explicit query vectors are written as an object of tokens;
     * otherwise the optional inference id and the query text are written.
     */
    @Override
    protected void doXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
        builder.startObject(NAME);
        builder.field(FIELD_FIELD.getPreferredName(), this.fieldName);
        if (this.queryVectors != null) {
            builder.startObject(QUERY_VECTOR_FIELD.getPreferredName());
            for (WeightedToken token : this.queryVectors) {
                token.toXContent(builder, params);
            }
            builder.endObject();
        } else {
            if (this.inferenceId != null) {
                builder.field(INFERENCE_ID_FIELD.getPreferredName(), this.inferenceId);
            }
            builder.field(QUERY_FIELD.getPreferredName(), this.query);
        }
        builder.field(PRUNE_FIELD.getPreferredName(), this.shouldPruneTokens);
        if (this.tokenPruningConfig != null) {
            builder.field(PRUNING_CONFIG_FIELD.getPreferredName(), this.tokenPruningConfig);
        }
        this.boostAndQueryNameToXContent(builder);
        builder.endObject();
    }

    /**
     * Builds the Lucene query from the weighted tokens.
     *
     * @return a {@link MatchNoDocsQuery} when there are no query vectors or the field is unmapped,
     *         otherwise the query produced by {@link WeightedTokensUtils}
     * @throws IllegalArgumentException if the target field is not of type {@code sparse_vector}
     */
    @Override
    protected Query doToQuery(SearchExecutionContext context) throws IOException {
        if (this.queryVectors == null) {
            return new MatchNoDocsQuery("Empty query vectors");
        }
        MappedFieldType ft = context.getFieldType(this.fieldName);
        if (ft == null) {
            return new MatchNoDocsQuery("The \"" + this.getName() + "\" query is against a field that does not exist");
        }
        String fieldTypeName = ft.typeName();
        if (ALLOWED_FIELD_TYPE.equals(fieldTypeName) == false) {
            throw new IllegalArgumentException(
                "field [" + this.fieldName + "] must be type [" + ALLOWED_FIELD_TYPE + "] but is type [" + fieldTypeName + "]"
            );
        }
        if (this.shouldPruneTokens) {
            return WeightedTokensUtils.queryBuilderWithPrunedTokens(this.fieldName, this.tokenPruningConfig, this.queryVectors, ft, context);
        }
        return WeightedTokensUtils.queryBuilderWithAllTokens(this.fieldName, this.queryVectors, ft, context);
    }

    /**
     * Rewrites a text-based query into one carrying explicit weighted tokens.
     *
     * <ul>
     *   <li>If query vectors are already present, returns {@code this}.</li>
     *   <li>If an inference call was registered earlier, returns {@code this} until the result is
     *       available, then a new builder holding the resulting weighted tokens.</li>
     *   <li>Otherwise registers an asynchronous inference call on the rewrite context and returns a
     *       copy of this builder that holds the result holder.</li>
     * </ul>
     *
     * @throws IllegalArgumentException if no inference id is set when an inference call is needed
     */
    @Override
    protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) {
        if (this.queryVectors != null) {
            return this;
        }
        if (this.weightedTokensSupplier != null) {
            TextExpansionResults textExpansionResults = this.weightedTokensSupplier.get();
            if (textExpansionResults == null) {
                // The async inference call has not completed yet.
                return this;
            }
            return new SparseVectorQueryBuilder(
                this.fieldName,
                textExpansionResults.getWeightedTokens(),
                null,
                null,
                this.shouldPruneTokens,
                this.tokenPruningConfig
            );
        }
        if (this.inferenceId == null) {
            throw new IllegalArgumentException("inference_id required to perform vector search on query string");
        }

        CoordinatedInferenceAction.Request inferRequest = CoordinatedInferenceAction.Request.forTextInput(
            this.inferenceId,
            List.of(this.query),
            TextExpansionConfigUpdate.EMPTY_UPDATE,
            false,
            InferModelAction.Request.DEFAULT_TIMEOUT_FOR_API
        );
        inferRequest.setHighPriority(true);
        inferRequest.setPrefixType(TrainedModelPrefixStrings.PrefixType.SEARCH);

        SetOnce<TextExpansionResults> textExpansionResultsSupplier = new SetOnce<>();
        queryRewriteContext.registerAsyncAction(
            (client, listener) -> ClientHelper.executeAsyncWithOrigin(
                client,
                "ml",
                CoordinatedInferenceAction.INSTANCE,
                inferRequest,
                ActionListener.wrap(inferenceResponse -> {
                    List<InferenceResults> inferenceResults = inferenceResponse.getInferenceResults();
                    if (inferenceResults.isEmpty()) {
                        listener.onFailure(new IllegalStateException("inference response contain no results"));
                        return;
                    }
                    if (inferenceResults.size() > 1) {
                        listener.onFailure(new IllegalStateException("inference response should contain only one result"));
                        return;
                    }

                    InferenceResults result = inferenceResults.get(0);
                    if (result instanceof TextExpansionResults) {
                        textExpansionResultsSupplier.set((TextExpansionResults) result);
                        listener.onResponse(null);
                    } else if (result instanceof WarningInferenceResults) {
                        WarningInferenceResults warning = (WarningInferenceResults) result;
                        listener.onFailure(new IllegalStateException(warning.getWarning()));
                    } else {
                        listener.onFailure(
                            new IllegalArgumentException(
                                "expected a result of type [text_expansion_result] received ["
                                    + result.getWriteableName()
                                    + "]. Is ["
                                    + this.inferenceId
                                    + "] a compatible model?"
                            )
                        );
                    }
                }, listener::onFailure)
            )
        );
        return new SparseVectorQueryBuilder(this, textExpansionResultsSupplier);
    }

    /** Compares all query-defining fields; the transient inference result holder is not part of equality. */
    @Override
    protected boolean doEquals(SparseVectorQueryBuilder other) {
        return this.shouldPruneTokens == other.shouldPruneTokens
            && Objects.equals(this.fieldName, other.fieldName)
            && Objects.equals(this.tokenPruningConfig, other.tokenPruningConfig)
            && Objects.equals(this.queryVectors, other.queryVectors)
            && Objects.equals(this.inferenceId, other.inferenceId)
            && Objects.equals(this.query, other.query);
    }

    /** Hashes the same fields that {@link #doEquals(SparseVectorQueryBuilder)} compares. */
    @Override
    protected int doHashCode() {
        return Objects.hash(this.fieldName, this.queryVectors, this.tokenPruningConfig, this.shouldPruneTokens, this.inferenceId, this.query);
    }

    @Override
    public String getWriteableName() {
        return NAME;
    }

    @Override
    public TransportVersion getMinimalSupportedVersion() {
        return TransportVersions.V_8_15_0;
    }

    /**
     * Converts a parsed {@code token -> weight} map into weighted tokens.
     *
     * @param weightedTokenMap the parsed map, may be {@code null}
     * @return the weighted tokens in the map's iteration order, or {@code null} if the map is {@code null}
     * @throws IllegalArgumentException if any weight is not a {@link Number}
     */
    private static List<WeightedToken> parseWeightedTokens(Map<String, Object> weightedTokenMap) {
        if (weightedTokenMap == null) {
            return null;
        }
        List<WeightedToken> weightedTokens = new ArrayList<>();
        for (Map.Entry<String, Object> entry : weightedTokenMap.entrySet()) {
            Object weight = entry.getValue();
            if (weight instanceof Number) {
                weightedTokens.add(new WeightedToken(entry.getKey(), ((Number) weight).floatValue()));
            } else {
                throw new IllegalArgumentException("weight must be a number, was [" + weight + "]");
            }
        }
        return weightedTokens;
    }

    /**
     * Parses a {@code sparse_vector} query from XContent.
     *
     * @param parser the parser positioned at the query body
     * @return the parsed builder
     * @throws ParsingException if parsing raises an {@link IllegalArgumentException}; the original
     *                          exception is kept as the cause
     */
    public static SparseVectorQueryBuilder fromXContent(XContentParser parser) {
        try {
            return PARSER.apply(parser, null);
        } catch (IllegalArgumentException e) {
            throw new ParsingException(parser.getTokenLocation(), e.getMessage(), e);
        }
    }
}

