/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.inference.queries;

import java.io.IOException;
import java.util.Collection;
import java.util.Map;
import org.apache.lucene.search.join.ScoreMode;
import org.apache.lucene.util.SetOnce;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.action.ResolvedIndices;
import org.elasticsearch.cluster.metadata.IndexMetadata;
import org.elasticsearch.cluster.metadata.InferenceFieldMetadata;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.index.mapper.MappedFieldType;
import org.elasticsearch.index.query.MatchNoneQueryBuilder;
import org.elasticsearch.index.query.NestedQueryBuilder;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.index.query.QueryRewriteContext;
import org.elasticsearch.inference.InferenceResults;
import org.elasticsearch.inference.MinimalServiceSettings;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.plugins.internal.rewriter.QueryRewriteInterceptor;
import org.elasticsearch.search.vectors.KnnVectorQueryBuilder;
import org.elasticsearch.search.vectors.QueryVectorBuilder;
import org.elasticsearch.search.vectors.VectorData;
import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults;
import org.elasticsearch.xpack.core.ml.vectors.TextEmbeddingQueryVectorBuilder;
import org.elasticsearch.xpack.inference.mapper.SemanticTextField;
import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper;
import org.elasticsearch.xpack.inference.queries.FullyQualifiedInferenceId;
import org.elasticsearch.xpack.inference.queries.InterceptedInferenceQueryBuilder;
import org.elasticsearch.xpack.inference.queries.LegacySemanticKnnVectorQueryRewriteInterceptor;

/**
 * Intercepted form of a {@link KnnVectorQueryBuilder}.
 *
 * <p>When the targeted field is a {@code semantic_text} field, the original knn query is rewritten
 * into a nested query over the field's chunks with an inner knn query over the field's embeddings.
 * For any other mapped field, a plain knn query is rebuilt against the original field. In both cases
 * the query vector is taken from the original query when present, and otherwise from the
 * text-embedding inference results held by this builder.
 */
public class InterceptedInferenceKnnVectorQueryBuilder extends InterceptedInferenceQueryBuilder<KnnVectorQueryBuilder> {
    public static final String NAME = "intercepted_inference_knn";

    /** Interceptor used by {@link #doRewriteBwC} when the minimum transport version predates the new interceptors. */
    private static final QueryRewriteInterceptor BWC_INTERCEPTOR = new LegacySemanticKnnVectorQueryRewriteInterceptor();

    private static final TransportVersion NEW_SEMANTIC_QUERY_INTERCEPTORS = TransportVersion.fromName("new_semantic_query_interceptors");

    /**
     * Creates an intercepted query wrapping the given knn query.
     *
     * @param originalQuery the knn query being intercepted
     */
    public InterceptedInferenceKnnVectorQueryBuilder(KnnVectorQueryBuilder originalQuery) {
        super(originalQuery);
    }

    /**
     * Creates an intercepted query wrapping the given knn query with a supplied inference results map.
     *
     * @param originalQuery the knn query being intercepted
     * @param inferenceResultsMap inference results keyed by fully qualified inference ID
     */
    public InterceptedInferenceKnnVectorQueryBuilder(
        KnnVectorQueryBuilder originalQuery,
        Map<FullyQualifiedInferenceId, InferenceResults> inferenceResultsMap
    ) {
        super(originalQuery, inferenceResultsMap);
    }

    /**
     * Reads an intercepted query from the given stream.
     *
     * @param in the stream to read from
     * @throws IOException if reading from the stream fails
     */
    public InterceptedInferenceKnnVectorQueryBuilder(StreamInput in) throws IOException {
        super(in);
    }

    /** Copy constructor used by {@link #copy}; all arguments are forwarded to the superclass. */
    private InterceptedInferenceKnnVectorQueryBuilder(
        InterceptedInferenceQueryBuilder<KnnVectorQueryBuilder> other,
        Map<FullyQualifiedInferenceId, InferenceResults> inferenceResultsMap,
        SetOnce<Map<FullyQualifiedInferenceId, InferenceResults>> inferenceResultsMapSupplier,
        boolean ccsRequest
    ) {
        super(other, inferenceResultsMap, inferenceResultsMapSupplier, ccsRequest);
    }

    /**
     * @return a single-entry map from the original query's field name to a weight of {@code 1.0}
     */
    @Override
    protected Map<String, Float> getFields() {
        return Map.of(getField(), 1.0f);
    }

    /**
     * @return the model text of the original query's {@link TextEmbeddingQueryVectorBuilder}, or
     *         {@code null} when the original query has no such query vector builder
     */
    @Override
    protected String getQuery() {
        QueryVectorBuilder queryVectorBuilder = originalQuery.queryVectorBuilder();
        if (queryVectorBuilder instanceof TextEmbeddingQueryVectorBuilder) {
            return ((TextEmbeddingQueryVectorBuilder) queryVectorBuilder).getModelText();
        }
        return null;
    }

    /**
     * @return an inference ID built from the query vector builder's model ID and an empty cluster
     *         alias, or {@code null} when no model ID is specified
     */
    @Override
    protected FullyQualifiedInferenceId getInferenceIdOverride() {
        String modelId = getQueryVectorBuilderModelId();
        // NOTE(review): the empty alias presumably denotes the local cluster - confirm against FullyQualifiedInferenceId.
        return modelId != null ? new FullyQualifiedInferenceId("", modelId) : null;
    }

    /**
     * Validates the original query against the resolved local indices.
     *
     * @param resolvedIndices the indices the query resolves to
     * @throws IllegalStateException if the original query has neither a query vector nor a
     *         {@link TextEmbeddingQueryVectorBuilder}
     * @throws IllegalArgumentException if the text-embedding query vector builder has no model ID and
     *         at least one local index has no inference field metadata for the queried field
     */
    @Override
    protected void coordinatorNodeValidate(ResolvedIndices resolvedIndices) {
        QueryVectorBuilder queryVectorBuilder = originalQuery.queryVectorBuilder();
        boolean hasTextEmbeddingBuilder = queryVectorBuilder instanceof TextEmbeddingQueryVectorBuilder;
        if (originalQuery.queryVector() == null && hasTextEmbeddingBuilder == false) {
            throw new IllegalStateException("No [text_embedding] query vector builder or query vector specified");
        }

        // The model ID only becomes mandatory when a text-embedding builder is used without one; this does
        // not depend on the index, so it is decided once up front rather than per index.
        boolean modelIdMissing = hasTextEmbeddingBuilder && ((TextEmbeddingQueryVectorBuilder) queryVectorBuilder).getModelId() == null;
        if (modelIdMissing == false) {
            return;
        }

        String field = getField();
        Collection<IndexMetadata> indexMetadataCollection = resolvedIndices.getConcreteLocalIndicesMetadata().values();
        for (IndexMetadata indexMetadata : indexMetadataCollection) {
            InferenceFieldMetadata inferenceFieldMetadata = indexMetadata.getInferenceFields().get(field);
            if (inferenceFieldMetadata == null) {
                // The field is not an inference field in this index, so nothing can supply a model ID.
                throw new IllegalArgumentException("[model_id] must not be null.");
            }
        }
    }

    /**
     * Delegates to the legacy interceptor when the minimum transport version does not support the new
     * semantic query interceptors.
     *
     * @param queryRewriteContext the rewrite context providing the minimum transport version
     * @return the legacy interceptor's rewrite of the original query, or this builder unchanged
     */
    @Override
    protected QueryBuilder doRewriteBwC(QueryRewriteContext queryRewriteContext) {
        if (queryRewriteContext.getMinTransportVersion().supports(NEW_SEMANTIC_QUERY_INTERCEPTORS) == false) {
            return BWC_INTERCEPTOR.interceptAndRewrite(queryRewriteContext, originalQuery);
        }
        return this;
    }

    /** Creates a copy of this builder carrying the given inference results state and CCS flag. */
    @Override
    protected QueryBuilder copy(
        Map<FullyQualifiedInferenceId, InferenceResults> inferenceResultsMap,
        SetOnce<Map<FullyQualifiedInferenceId, InferenceResults>> inferenceResultsMapSupplier,
        boolean ccsRequest
    ) {
        return new InterceptedInferenceKnnVectorQueryBuilder(this, inferenceResultsMap, inferenceResultsMapSupplier, ccsRequest);
    }

    /**
     * Builds the query to run for the queried field based on its mapped type in the given context.
     * The {@code inferenceFields} and {@code nonInferenceFields} arguments are not consulted; the
     * field is looked up directly from the original query.
     *
     * @param inferenceFields unused by this implementation
     * @param nonInferenceFields unused by this implementation
     * @param indexMetadataContext context used to resolve the field type and local cluster alias
     * @return a match-none query when the field is unmapped, otherwise the semantic or
     *         non-semantic knn rewrite
     */
    @Override
    protected QueryBuilder queryFields(
        Map<String, Float> inferenceFields,
        Map<String, Float> nonInferenceFields,
        QueryRewriteContext indexMetadataContext
    ) {
        MappedFieldType fieldType = indexMetadataContext.getFieldType(getField());
        if (fieldType == null) {
            return new MatchNoneQueryBuilder();
        }
        if (fieldType instanceof SemanticTextFieldMapper.SemanticTextFieldType) {
            return querySemanticTextField(
                indexMetadataContext.getLocalClusterAlias(),
                (SemanticTextFieldMapper.SemanticTextFieldType) fieldType
            );
        }
        return queryNonSemanticTextField();
    }

    @Override
    protected boolean resolveWildcards() {
        return false;
    }

    @Override
    protected boolean useDefaultFields() {
        return false;
    }

    @Override
    public String getWriteableName() {
        return NAME;
    }

    /** @return the field name targeted by the original knn query */
    private String getField() {
        return originalQuery.getFieldName();
    }

    /**
     * @return the model ID of the original query's {@link TextEmbeddingQueryVectorBuilder}, or
     *         {@code null} when there is no such builder (the builder's own model ID may also be null)
     */
    private String getQueryVectorBuilderModelId() {
        QueryVectorBuilder queryVectorBuilder = originalQuery.queryVectorBuilder();
        if (queryVectorBuilder instanceof TextEmbeddingQueryVectorBuilder) {
            return ((TextEmbeddingQueryVectorBuilder) queryVectorBuilder).getModelId();
        }
        return null;
    }

    /**
     * Rewrites the original query for a {@code semantic_text} field.
     *
     * @param clusterAlias cluster alias used to qualify the field's search inference ID when the
     *        query does not override the inference ID
     * @param semanticTextFieldType the mapped type of the queried field
     * @return a match-none query when the field has no model settings, otherwise a nested query over
     *         the chunks field wrapping a knn query over the embeddings field
     * @throws IllegalArgumentException if the field's model is not a text-embedding model
     */
    private QueryBuilder querySemanticTextField(String clusterAlias, SemanticTextFieldMapper.SemanticTextFieldType semanticTextFieldType) {
        MinimalServiceSettings modelSettings = semanticTextFieldType.getModelSettings();
        if (modelSettings == null) {
            return new MatchNoneQueryBuilder();
        }
        if (modelSettings.taskType() != TaskType.TEXT_EMBEDDING) {
            throw new IllegalArgumentException(
                "Field [" + getField() + "] does not use a [" + TaskType.TEXT_EMBEDDING + "] model"
            );
        }

        VectorData queryVector = originalQuery.queryVector();
        if (queryVector == null) {
            FullyQualifiedInferenceId inferenceId = getInferenceIdOverride();
            if (inferenceId == null) {
                // No override in the query: fall back to the field's own search inference endpoint.
                inferenceId = new FullyQualifiedInferenceId(clusterAlias, semanticTextFieldType.getSearchInferenceId());
            }
            queryVector = embeddingAsQueryVector(inferenceId);
        }

        KnnVectorQueryBuilder innerKnnQuery = buildKnnQuery(SemanticTextField.getEmbeddingsFieldName(getField()), queryVector);
        // Boost and query name are applied to the outer nested query rather than the inner knn query.
        NestedQueryBuilder nestedQuery = QueryBuilders.nestedQuery(
            SemanticTextField.getChunksFieldName(getField()),
            innerKnnQuery,
            ScoreMode.Max
        );
        return nestedQuery.boost(originalQuery.boost()).queryName(originalQuery.queryName());
    }

    /**
     * Rewrites the original query for a field that is not a {@code semantic_text} field.
     *
     * @return a knn query against the original field carrying the original query's settings
     * @throws IllegalStateException if the original query has neither a query vector nor a query
     *         vector builder model ID
     */
    private QueryBuilder queryNonSemanticTextField() {
        VectorData queryVector = originalQuery.queryVector();
        if (queryVector == null) {
            FullyQualifiedInferenceId inferenceId = getInferenceIdOverride();
            if (inferenceId == null) {
                throw new IllegalStateException("No query vector or query vector builder model ID specified");
            }
            queryVector = embeddingAsQueryVector(inferenceId);
        }

        KnnVectorQueryBuilder knnQuery = buildKnnQuery(getField(), queryVector);
        return knnQuery.boost(originalQuery.boost()).queryName(originalQuery.queryName());
    }

    /**
     * Builds a knn query on {@code fieldName} that copies k, candidate count, visit percentage,
     * rescore settings, vector similarity and filter queries from the original query.
     */
    private KnnVectorQueryBuilder buildKnnQuery(String fieldName, VectorData queryVector) {
        KnnVectorQueryBuilder knnQuery = new KnnVectorQueryBuilder(
            fieldName,
            queryVector,
            originalQuery.k(),
            originalQuery.numCands(),
            originalQuery.visitPercentage(),
            originalQuery.rescoreVectorBuilder(),
            originalQuery.getVectorSimilarity()
        );
        knnQuery.addFilterQueries(originalQuery.filterQueries());
        return knnQuery;
    }

    /** Converts the text-embedding inference results for the given inference ID into a query vector. */
    private VectorData embeddingAsQueryVector(FullyQualifiedInferenceId fullyQualifiedInferenceId) {
        MlTextEmbeddingResults textEmbeddingResults = getTextEmbeddingResults(fullyQualifiedInferenceId);
        return new VectorData(textEmbeddingResults.getInferenceAsFloat());
    }

    /**
     * Looks up the text-embedding results for the given inference ID.
     *
     * @param fullyQualifiedInferenceId the inference ID to look up in the inference results map
     * @return the matching text-embedding results
     * @throws IllegalStateException if there are no results for the inference ID
     * @throws IllegalArgumentException if the results are not text-embedding results
     */
    private MlTextEmbeddingResults getTextEmbeddingResults(FullyQualifiedInferenceId fullyQualifiedInferenceId) {
        InferenceResults inferenceResults = inferenceResultsMap.get(fullyQualifiedInferenceId);
        if (inferenceResults == null) {
            throw new IllegalStateException(
                "Could not find inference results from inference endpoint [" + fullyQualifiedInferenceId + "]"
            );
        }
        if (inferenceResults instanceof MlTextEmbeddingResults) {
            return (MlTextEmbeddingResults) inferenceResults;
        }
        throw new IllegalArgumentException(
            "Expected query inference results to be of type [text_embedding_result], got ["
                + inferenceResults.getWriteableName()
                + "]. Are you specifying a compatible inference endpoint? Has the inference endpoint configuration changed?"
        );
    }
}

