/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.inference.queries;

import java.io.IOException;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import org.apache.lucene.search.Query;
import org.apache.lucene.util.SetOnce;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.ActionRequest;
import org.elasticsearch.action.ActionType;
import org.elasticsearch.action.ResolvedIndices;
import org.elasticsearch.client.internal.Client;
import org.elasticsearch.cluster.metadata.IndexMetadata;
import org.elasticsearch.cluster.metadata.InferenceFieldMetadata;
import org.elasticsearch.common.io.stream.NamedWriteable;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.features.NodeFeature;
import org.elasticsearch.index.mapper.MappedFieldType;
import org.elasticsearch.index.query.AbstractQueryBuilder;
import org.elasticsearch.index.query.MatchNoneQueryBuilder;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryRewriteContext;
import org.elasticsearch.index.query.SearchExecutionContext;
import org.elasticsearch.inference.InferenceResults;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.inference.InputType;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.xcontent.ConstructingObjectParser;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.ToXContent;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ClientHelper;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
import org.elasticsearch.xpack.core.ml.action.InferModelAction;
import org.elasticsearch.xpack.core.ml.inference.results.ErrorInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults;
import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults;
import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper;

/**
 * The {@code semantic} query: runs inference on the query text with the search inference endpoint configured for a
 * field, then rewrites into the query built by that field's {@code SemanticTextFieldType}.
 * <p>
 * Rewriting happens in two phases:
 * <ol>
 *   <li>Without a {@link SearchExecutionContext} (coordinator rewrite):
 *     <ul>
 *       <li>The field's inference endpoint is resolved from the local index metadata.</li>
 *       <li>An asynchronous inference request is registered with the {@link QueryRewriteContext}.</li>
 *       <li>On a later rewrite pass, the response captured in a {@link SetOnce} is validated and stored on a copy
 *           of this builder.</li>
 *     </ul>
 *   </li>
 *   <li>With a {@link SearchExecutionContext}, the builder is replaced by the query produced by
 *       {@code SemanticTextFieldType#semanticQuery}. If the field is missing, or is not a {@code semantic_text}
 *       field and {@code lenient} is {@code true}, a {@link MatchNoneQueryBuilder} is used instead.</li>
 * </ol>
 * This builder is never converted to a Lucene query directly; {@link #doToQuery} always throws.
 */
public class SemanticQueryBuilder extends AbstractQueryBuilder<SemanticQueryBuilder> {
    public static final String NAME = "semantic";

    public static final NodeFeature SEMANTIC_QUERY_FILTER_FIELD_CAPS_FIX = new NodeFeature("semantic_query.filter_field_caps_fix");

    private static final ParseField FIELD_FIELD = new ParseField("field");
    private static final ParseField QUERY_FIELD = new ParseField("query");
    private static final ParseField LENIENT_FIELD = new ParseField("lenient");

    private static final ConstructingObjectParser<SemanticQueryBuilder, Void> PARSER = new ConstructingObjectParser<>(
        NAME,
        false,
        args -> new SemanticQueryBuilder((String) args[0], (String) args[1], (Boolean) args[2])
    );

    // Runs after PARSER is initialized because static initializers execute in textual order.
    static {
        PARSER.declareString(ConstructingObjectParser.constructorArg(), FIELD_FIELD);
        PARSER.declareString(ConstructingObjectParser.constructorArg(), QUERY_FIELD);
        PARSER.declareBoolean(ConstructingObjectParser.optionalConstructorArg(), LENIENT_FIELD);
        declareStandardFields(PARSER);
    }

    private final String fieldName;
    private final String query;
    /** Filled by the async inference call registered during coordinator rewrite; {@code null} otherwise. Never serialized. */
    private final SetOnce<InferenceServiceResults> inferenceResultsSupplier;
    /** Validated inference result for {@code query}, or {@code null} if it has not been obtained yet. */
    private final InferenceResults inferenceResults;
    /** {@code true} when no resolved local index defines a search inference endpoint for {@code fieldName}. */
    private final boolean noInferenceResults;
    /** If {@code true}, a field that is not {@code semantic_text} matches nothing instead of failing; may be {@code null}. */
    private final Boolean lenient;

    /**
     * Creates a non-lenient query.
     *
     * @param fieldName the field to query; must not be {@code null}
     * @param query     the text to run inference on; must not be {@code null}
     */
    public SemanticQueryBuilder(String fieldName, String query) {
        this(fieldName, query, null);
    }

    /**
     * Creates a query.
     *
     * @param fieldName the field to query; must not be {@code null}
     * @param query     the text to run inference on; must not be {@code null}
     * @param lenient   whether a non-{@code semantic_text} field should match nothing instead of throwing; may be {@code null}
     * @throws IllegalArgumentException if {@code fieldName} or {@code query} is {@code null}
     */
    public SemanticQueryBuilder(String fieldName, String query, Boolean lenient) {
        if (fieldName == null) {
            throw new IllegalArgumentException("[" + NAME + "] requires a " + FIELD_FIELD.getPreferredName() + " value");
        }
        if (query == null) {
            throw new IllegalArgumentException("[" + NAME + "] requires a " + QUERY_FIELD.getPreferredName() + " value");
        }
        this.fieldName = fieldName;
        this.query = query;
        this.inferenceResults = null;
        this.inferenceResultsSupplier = null;
        this.noInferenceResults = false;
        this.lenient = lenient;
    }

    /**
     * Deserializes a query. The read order must mirror {@link #doWriteTo(StreamOutput)}.
     */
    public SemanticQueryBuilder(StreamInput in) throws IOException {
        super(in);
        this.fieldName = in.readString();
        this.query = in.readString();
        this.inferenceResults = in.readOptionalNamedWriteable(InferenceResults.class);
        this.noInferenceResults = in.readBoolean();
        this.inferenceResultsSupplier = null;
        // The lenient flag is only on the wire from 8.18.0 onwards.
        this.lenient = in.getTransportVersion().supports(TransportVersions.V_8_18_0) ? in.readOptionalBoolean() : null;
    }

    /**
     * Copy constructor used by the rewrite phases. It carries over field, query, boost, query name and lenient, and
     * replaces the inference state.
     */
    private SemanticQueryBuilder(
        SemanticQueryBuilder other,
        SetOnce<InferenceServiceResults> inferenceResultsSupplier,
        InferenceResults inferenceResults,
        boolean noInferenceResults
    ) {
        this.fieldName = other.fieldName;
        this.query = other.query;
        this.boost = other.boost;
        this.queryName = other.queryName;
        this.inferenceResultsSupplier = inferenceResultsSupplier;
        this.inferenceResults = inferenceResults;
        this.noInferenceResults = noInferenceResults;
        this.lenient = other.lenient;
    }

    /**
     * Serializes the query.
     *
     * @throws IllegalStateException if a pending inference supplier is still attached, which means the builder was not
     *                               fully rewritten before being sent over the wire
     */
    @Override
    protected void doWriteTo(StreamOutput out) throws IOException {
        if (inferenceResultsSupplier != null) {
            throw new IllegalStateException("Inference results supplier is set. Missing a rewriteAndFetch?");
        }
        out.writeString(fieldName);
        out.writeString(query);
        out.writeOptionalNamedWriteable(inferenceResults);
        out.writeBoolean(noInferenceResults);
        if (out.getTransportVersion().supports(TransportVersions.V_8_18_0)) {
            out.writeOptionalBoolean(lenient);
        }
    }

    @Override
    public String getWriteableName() {
        return NAME;
    }

    public String getFieldName() {
        return fieldName;
    }

    public String getQuery() {
        return query;
    }

    @Override
    public TransportVersion getMinimalSupportedVersion() {
        return TransportVersions.V_8_15_0;
    }

    /** Parses a {@code semantic} query body. */
    public static SemanticQueryBuilder fromXContent(XContentParser parser) throws IOException {
        return PARSER.apply(parser, null);
    }

    @Override
    protected void doXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
        builder.startObject(NAME);
        builder.field(FIELD_FIELD.getPreferredName(), fieldName);
        builder.field(QUERY_FIELD.getPreferredName(), query);
        if (lenient != null) {
            builder.field(LENIENT_FIELD.getPreferredName(), lenient);
        }
        boostAndQueryNameToXContent(builder);
        builder.endObject();
    }

    /**
     * Builds the concrete semantic query when a {@link SearchExecutionContext} is available. Otherwise it performs
     * the coordinator-side inference step.
     */
    @Override
    protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) {
        SearchExecutionContext searchExecutionContext = queryRewriteContext.convertToSearchExecutionContext();
        if (searchExecutionContext != null) {
            return doRewriteBuildSemanticQuery(searchExecutionContext);
        }
        return doRewriteGetInferenceResults(queryRewriteContext);
    }

    /** Replaces this builder with the field type's semantic query, or with a match-none query where allowed. */
    private QueryBuilder doRewriteBuildSemanticQuery(SearchExecutionContext searchExecutionContext) {
        MappedFieldType fieldType = searchExecutionContext.getFieldType(fieldName);
        if (fieldType == null) {
            return new MatchNoneQueryBuilder();
        }
        if (fieldType instanceof SemanticTextFieldMapper.SemanticTextFieldType semanticTextFieldType) {
            if (inferenceResults == null) {
                // The coordinator rewrite should have attached results for any semantic_text field.
                throw new IllegalStateException(
                    "No inference results set for [" + semanticTextFieldType.typeName() + "] field [" + fieldName + "]"
                );
            }
            return semanticTextFieldType.semanticQuery(inferenceResults, searchExecutionContext.requestSize(), boost(), queryName());
        }
        if (Boolean.TRUE.equals(lenient)) {
            return new MatchNoneQueryBuilder();
        }
        throw new IllegalArgumentException(
            "Field [" + fieldName + "] of type [" + fieldType.typeName() + "] does not support semantic queries"
        );
    }

    /**
     * Coordinator-side rewrite. It proceeds in this order:
     * <ul>
     *   <li>If inference is already resolved (results present, or known to be absent), returns {@code this}.</li>
     *   <li>If an inference call is pending, returns {@code this} until the response arrives. It then returns a copy
     *       holding the validated result.</li>
     *   <li>Otherwise, resolves the inference endpoint for the field.
     *     <ul>
     *       <li>If an endpoint is found, registers an async inference request and returns a copy that tracks it.</li>
     *       <li>If none is found, returns a copy marked as having no inference results.</li>
     *     </ul>
     *   </li>
     * </ul>
     */
    private SemanticQueryBuilder doRewriteGetInferenceResults(QueryRewriteContext queryRewriteContext) {
        if (inferenceResults != null || noInferenceResults) {
            return this;
        }

        if (inferenceResultsSupplier != null) {
            InferenceResults results = validateAndConvertInferenceResults(inferenceResultsSupplier, fieldName);
            return results != null ? new SemanticQueryBuilder(this, null, results, noInferenceResults) : this;
        }

        ResolvedIndices resolvedIndices = queryRewriteContext.getResolvedIndices();
        if (resolvedIndices == null) {
            throw new IllegalStateException(
                "Rewriting on the coordinator node requires a query rewrite context with non-null resolved indices"
            );
        }
        if (resolvedIndices.getRemoteClusterIndices().isEmpty() == false) {
            throw new IllegalArgumentException("semantic query does not support cross-cluster search");
        }

        String inferenceId = getInferenceIdForField(resolvedIndices.getConcreteLocalIndicesMetadata().values(), fieldName);
        if (inferenceId == null) {
            // No resolved index defines an inference endpoint for this field, so there is nothing to infer.
            return new SemanticQueryBuilder(this, null, null, true);
        }

        SetOnce<InferenceServiceResults> resultsSupplier = new SetOnce<>();
        InferenceAction.Request inferenceRequest = new InferenceAction.Request(
            TaskType.ANY,
            inferenceId,
            null,
            null,
            null,
            List.of(query),
            Map.of(),
            InputType.INTERNAL_SEARCH,
            InferModelAction.Request.DEFAULT_TIMEOUT_FOR_API,
            false
        );
        queryRewriteContext.registerAsyncAction(
            (client, listener) -> ClientHelper.executeAsyncWithOrigin(
                client,
                ClientHelper.ML_ORIGIN,
                InferenceAction.INSTANCE,
                inferenceRequest,
                listener.delegateFailureAndWrap((l, inferenceResponse) -> {
                    resultsSupplier.set(inferenceResponse.getResults());
                    l.onResponse(null);
                })
            )
        );
        return new SemanticQueryBuilder(this, resultsSupplier, null, false);
    }

    /**
     * Extracts the single coordination-format inference result from the supplier.
     *
     * @return the result, or {@code null} if the inference response has not arrived yet
     * @throws IllegalArgumentException if there are no results, or the result is of an unsupported type
     * @throws IllegalStateException    if there is more than one result, or the result is an error or warning result
     */
    private static InferenceResults validateAndConvertInferenceResults(
        SetOnce<InferenceServiceResults> inferenceResultsSupplier,
        String fieldName
    ) {
        InferenceServiceResults inferenceServiceResults = inferenceResultsSupplier.get();
        if (inferenceServiceResults == null) {
            return null;
        }

        List<? extends InferenceResults> inferenceResultsList = inferenceServiceResults.transformToCoordinationFormat();
        if (inferenceResultsList.isEmpty()) {
            throw new IllegalArgumentException("No inference results retrieved for field [" + fieldName + "]");
        }
        if (inferenceResultsList.size() > 1) {
            // The request is sent with a single input, so exactly one result is expected.
            throw new IllegalStateException(inferenceResultsList.size() + " inference results retrieved for field [" + fieldName + "]");
        }

        InferenceResults inferenceResults = inferenceResultsList.get(0);
        if (inferenceResults instanceof ErrorInferenceResults errorInferenceResults) {
            throw new IllegalStateException(
                "Field [" + fieldName + "] query inference error: " + errorInferenceResults.getException().getMessage(),
                errorInferenceResults.getException()
            );
        }
        if (inferenceResults instanceof WarningInferenceResults warningInferenceResults) {
            throw new IllegalStateException("Field [" + fieldName + "] query inference warning: " + warningInferenceResults.getWarning());
        }
        if ((inferenceResults instanceof TextExpansionResults || inferenceResults instanceof MlTextEmbeddingResults) == false) {
            throw new IllegalArgumentException(
                "Field ["
                    + fieldName
                    + "] expected query inference results to be of type ["
                    + TextExpansionResults.NAME
                    + "] or ["
                    + MlTextEmbeddingResults.NAME
                    + "], got ["
                    + inferenceResults.getWriteableName()
                    + "]. Has the inference endpoint configuration changed?"
            );
        }
        return inferenceResults;
    }

    /** Always throws: this builder must be rewritten into another query before a Lucene query is built. */
    @Override
    protected Query doToQuery(SearchExecutionContext context) throws IOException {
        throw new IllegalStateException(NAME + " should have been rewritten to another query type");
    }

    /**
     * Returns the search inference ID configured for {@code fieldName}, or {@code null} if no index defines one.
     *
     * @throws IllegalArgumentException if different indices map the field to different inference IDs
     */
    private static String getInferenceIdForField(Collection<IndexMetadata> indexMetadataCollection, String fieldName) {
        String inferenceId = null;
        for (IndexMetadata indexMetadata : indexMetadataCollection) {
            InferenceFieldMetadata inferenceFieldMetadata = indexMetadata.getInferenceFields().get(fieldName);
            String indexInferenceId = inferenceFieldMetadata != null ? inferenceFieldMetadata.getSearchInferenceId() : null;
            if (indexInferenceId == null) {
                continue;
            }
            if (inferenceId != null && inferenceId.equals(indexInferenceId) == false) {
                throw new IllegalArgumentException("Field [" + fieldName + "] has multiple inference IDs associated with it");
            }
            inferenceId = indexInferenceId;
        }
        return inferenceId;
    }

    /**
     * Compares every field that affects rewriting, including {@code noInferenceResults} and {@code lenient}. The
     * pending supplier is compared by identity.
     */
    @Override
    protected boolean doEquals(SemanticQueryBuilder other) {
        return Objects.equals(fieldName, other.fieldName)
            && Objects.equals(query, other.query)
            && Objects.equals(inferenceResults, other.inferenceResults)
            && Objects.equals(inferenceResultsSupplier, other.inferenceResultsSupplier)
            && noInferenceResults == other.noInferenceResults
            && Objects.equals(lenient, other.lenient);
    }

    /** Hashes the same fields that {@link #doEquals} compares. */
    @Override
    protected int doHashCode() {
        return Objects.hash(fieldName, query, inferenceResults, inferenceResultsSupplier, noInferenceResults, lenient);
    }
}

