/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.inference.queries;

import java.io.IOException;
import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Collectors;
import org.apache.lucene.search.Query;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.ActionRequest;
import org.elasticsearch.action.ActionType;
import org.elasticsearch.action.ResolvedIndices;
import org.elasticsearch.client.internal.Client;
import org.elasticsearch.cluster.metadata.IndexMetadata;
import org.elasticsearch.cluster.metadata.InferenceFieldMetadata;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.features.NodeFeature;
import org.elasticsearch.index.mapper.MappedFieldType;
import org.elasticsearch.index.query.AbstractQueryBuilder;
import org.elasticsearch.index.query.MatchNoneQueryBuilder;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryRewriteContext;
import org.elasticsearch.index.query.SearchExecutionContext;
import org.elasticsearch.inference.InferenceResults;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.inference.InputType;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.xcontent.ConstructingObjectParser;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.ToXContent;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ClientHelper;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
import org.elasticsearch.xpack.core.ml.inference.results.ErrorInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults;
import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults;
import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
import org.elasticsearch.xpack.inference.InferenceException;
import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper;
import org.elasticsearch.xpack.inference.queries.FullyQualifiedInferenceId;

public class SemanticQueryBuilder
extends AbstractQueryBuilder<SemanticQueryBuilder> {
    public static final String NAME = "semantic";
    public static final NodeFeature SEMANTIC_QUERY_MULTIPLE_INFERENCE_IDS = new NodeFeature("semantic_query.multiple_inference_ids");
    public static final NodeFeature SEMANTIC_QUERY_FILTER_FIELD_CAPS_FIX = new NodeFeature("semantic_query.filter_field_caps_fix");
    static final TransportVersion SEMANTIC_QUERY_MULTIPLE_INFERENCE_IDS_TV = TransportVersion.fromName((String)"semantic_query_multiple_inference_ids");
    static final TransportVersion INFERENCE_RESULTS_MAP_WITH_CLUSTER_ALIAS = TransportVersion.fromName((String)"inference_results_map_with_cluster_alias");
    public static final TransportVersion SEMANTIC_SEARCH_CCS_SUPPORT = TransportVersion.fromName((String)"semantic_search_ccs_support");
    private static final String PLACEHOLDER_INFERENCE_ID = "$PLACEHOLDER";
    private static final ParseField FIELD_FIELD = new ParseField("field", new String[0]);
    private static final ParseField QUERY_FIELD = new ParseField("query", new String[0]);
    private static final ParseField LENIENT_FIELD = new ParseField("lenient", new String[0]);
    private static final ConstructingObjectParser<SemanticQueryBuilder, Void> PARSER = new ConstructingObjectParser("semantic", false, args -> new SemanticQueryBuilder((String)args[0], (String)args[1], (Boolean)args[2]));
    private final String fieldName;
    private final String query;
    private final Map<FullyQualifiedInferenceId, InferenceResults> inferenceResultsMap;
    private final Boolean lenient;
    private final boolean ccsRequest;

    public SemanticQueryBuilder(String fieldName, String query) {
        this(fieldName, query, null);
    }

    public SemanticQueryBuilder(String fieldName, String query, Boolean lenient) {
        this(fieldName, query, lenient, null, false);
    }

    protected SemanticQueryBuilder(String fieldName, String query, Boolean lenient, Map<FullyQualifiedInferenceId, InferenceResults> inferenceResultsMap) {
        this(fieldName, query, lenient, inferenceResultsMap, false);
    }

    protected SemanticQueryBuilder(String fieldName, String query, Boolean lenient, Map<FullyQualifiedInferenceId, InferenceResults> inferenceResultsMap, boolean ccsRequest) {
        if (fieldName == null) {
            throw new IllegalArgumentException("[semantic] requires a " + FIELD_FIELD.getPreferredName() + " value");
        }
        if (query == null) {
            throw new IllegalArgumentException("[semantic] requires a " + QUERY_FIELD.getPreferredName() + " value");
        }
        this.fieldName = fieldName;
        this.query = query;
        this.inferenceResultsMap = inferenceResultsMap != null ? Map.copyOf(inferenceResultsMap) : null;
        this.lenient = lenient;
        this.ccsRequest = ccsRequest;
    }

    public SemanticQueryBuilder(StreamInput in) throws IOException {
        super(in);
        this.fieldName = in.readString();
        this.query = in.readString();
        if (in.getTransportVersion().supports(INFERENCE_RESULTS_MAP_WITH_CLUSTER_ALIAS)) {
            this.inferenceResultsMap = (Map)in.readOptional(i1 -> i1.readImmutableMap(FullyQualifiedInferenceId::new, i2 -> (InferenceResults)i2.readNamedWriteable(InferenceResults.class)));
        } else if (in.getTransportVersion().supports(SEMANTIC_QUERY_MULTIPLE_INFERENCE_IDS_TV)) {
            this.inferenceResultsMap = SemanticQueryBuilder.convertFromBwcInferenceResultsMap((Map)in.readOptional(i1 -> i1.readImmutableMap(i2 -> (InferenceResults)i2.readNamedWriteable(InferenceResults.class))));
        } else {
            InferenceResults inferenceResults = (InferenceResults)in.readOptionalNamedWriteable(InferenceResults.class);
            this.inferenceResultsMap = inferenceResults != null ? SemanticQueryBuilder.buildSingleResultInferenceResultsMap(inferenceResults) : null;
            in.readBoolean();
        }
        this.lenient = in.getTransportVersion().supports(TransportVersions.V_8_18_0) ? in.readOptionalBoolean() : null;
        this.ccsRequest = in.getTransportVersion().supports(SEMANTIC_SEARCH_CCS_SUPPORT) ? in.readBoolean() : false;
    }

    protected void doWriteTo(StreamOutput out) throws IOException {
        out.writeString(this.fieldName);
        out.writeString(this.query);
        if (out.getTransportVersion().supports(INFERENCE_RESULTS_MAP_WITH_CLUSTER_ALIAS)) {
            out.writeOptional((o, v) -> o.writeMap(v, StreamOutput::writeWriteable, StreamOutput::writeNamedWriteable), this.inferenceResultsMap);
        } else if (out.getTransportVersion().supports(SEMANTIC_QUERY_MULTIPLE_INFERENCE_IDS_TV)) {
            out.writeOptional((o1, v) -> o1.writeMap(v, (o2, id) -> {
                if (!id.clusterAlias().equals("")) {
                    throw new IllegalArgumentException("Cannot serialize remote cluster inference results in a mixed-version cluster");
                }
                o2.writeString(id.inferenceId());
            }, StreamOutput::writeNamedWriteable), this.inferenceResultsMap);
        } else {
            InferenceResults inferenceResults = null;
            if (this.inferenceResultsMap != null) {
                if (this.inferenceResultsMap.size() > 1) {
                    throw new IllegalArgumentException("Cannot query multiple inference IDs in a mixed-version cluster");
                }
                if (this.inferenceResultsMap.size() == 1) {
                    inferenceResults = this.inferenceResultsMap.values().iterator().next();
                }
            }
            out.writeOptionalNamedWriteable(inferenceResults);
            out.writeBoolean(inferenceResults == null);
        }
        if (out.getTransportVersion().supports(TransportVersions.V_8_18_0)) {
            out.writeOptionalBoolean(this.lenient);
        }
        if (out.getTransportVersion().supports(SEMANTIC_SEARCH_CCS_SUPPORT)) {
            out.writeBoolean(this.ccsRequest);
        } else if (this.ccsRequest) {
            throw new IllegalArgumentException("One or more nodes does not support semantic query cross-cluster search. Please update all nodes to at least Elasticsearch " + SEMANTIC_SEARCH_CCS_SUPPORT.toReleaseVersion() + ".");
        }
    }

    private SemanticQueryBuilder(SemanticQueryBuilder other, Map<FullyQualifiedInferenceId, InferenceResults> inferenceResultsMap, boolean ccsRequest) {
        this.fieldName = other.fieldName;
        this.query = other.query;
        this.boost = other.boost;
        this.queryName = other.queryName;
        this.inferenceResultsMap = inferenceResultsMap;
        this.lenient = other.lenient;
        this.ccsRequest = ccsRequest;
    }

    public String getWriteableName() {
        return NAME;
    }

    public String getFieldName() {
        return this.fieldName;
    }

    public String getQuery() {
        return this.query;
    }

    public TransportVersion getMinimalSupportedVersion() {
        return TransportVersions.V_8_15_0;
    }

    public static SemanticQueryBuilder fromXContent(XContentParser parser) throws IOException {
        return (SemanticQueryBuilder)((Object)PARSER.apply(parser, null));
    }

    static Map<FullyQualifiedInferenceId, InferenceResults> getInferenceResults(QueryRewriteContext queryRewriteContext, Set<FullyQualifiedInferenceId> fullyQualifiedInferenceIds, @Nullable Map<FullyQualifiedInferenceId, InferenceResults> inferenceResultsMap, @Nullable String query) {
        Map<FullyQualifiedInferenceId, Object> currentInferenceResultsMap;
        boolean modifiedInferenceResultsMap = false;
        Map<FullyQualifiedInferenceId, Object> map = currentInferenceResultsMap = inferenceResultsMap != null ? inferenceResultsMap : Map.of();
        if (query != null) {
            for (FullyQualifiedInferenceId fullyQualifiedInferenceId : fullyQualifiedInferenceIds) {
                if (currentInferenceResultsMap.containsKey(fullyQualifiedInferenceId)) continue;
                if (!fullyQualifiedInferenceId.clusterAlias().equals(queryRewriteContext.getLocalClusterAlias())) {
                    throw new IllegalStateException("Cannot get inference results for inference endpoint [" + String.valueOf(fullyQualifiedInferenceId) + "] on cluster [" + queryRewriteContext.getLocalClusterAlias() + "]");
                }
                if (!modifiedInferenceResultsMap) {
                    currentInferenceResultsMap = new ConcurrentHashMap<FullyQualifiedInferenceId, Object>(currentInferenceResultsMap);
                    modifiedInferenceResultsMap = true;
                }
                SemanticQueryBuilder.registerInferenceAsyncAction(queryRewriteContext, (ConcurrentHashMap)currentInferenceResultsMap, query, fullyQualifiedInferenceId.inferenceId());
            }
        }
        return currentInferenceResultsMap;
    }

    static void registerInferenceAsyncAction(QueryRewriteContext queryRewriteContext, ConcurrentHashMap<FullyQualifiedInferenceId, InferenceResults> inferenceResultsMap, String query, String inferenceId) {
        InferenceAction.Request inferenceRequest = new InferenceAction.Request(TaskType.ANY, inferenceId, null, null, null, List.of(query), Map.of(), InputType.INTERNAL_SEARCH, null, false);
        queryRewriteContext.registerAsyncAction((client, listener) -> ClientHelper.executeAsyncWithOrigin((Client)client, (String)"ml", (ActionType)InferenceAction.INSTANCE, (ActionRequest)inferenceRequest, (ActionListener)listener.delegateFailureAndWrap((l, inferenceResponse) -> {
            inferenceResultsMap.put(new FullyQualifiedInferenceId(queryRewriteContext.getLocalClusterAlias(), inferenceId), SemanticQueryBuilder.validateAndConvertInferenceResults(inferenceResponse.getResults(), inferenceId));
            l.onResponse(null);
        })));
    }

    static Map<FullyQualifiedInferenceId, InferenceResults> convertFromBwcInferenceResultsMap(Map<String, InferenceResults> inferenceResultsMap) {
        Map<FullyQualifiedInferenceId, InferenceResults> converted = null;
        if (inferenceResultsMap != null) {
            converted = Collections.unmodifiableMap(inferenceResultsMap.entrySet().stream().collect(Collectors.toMap(e -> new FullyQualifiedInferenceId("", (String)e.getKey()), Map.Entry::getValue)));
        }
        return converted;
    }

    static Map<FullyQualifiedInferenceId, InferenceResults> buildSingleResultInferenceResultsMap(InferenceResults inferenceResults) {
        return Map.of(new FullyQualifiedInferenceId("", PLACEHOLDER_INFERENCE_ID), inferenceResults);
    }

    private static InferenceResults getSingleInferenceResult(Map<FullyQualifiedInferenceId, InferenceResults> inferenceResultsMap) {
        return inferenceResultsMap.get(new FullyQualifiedInferenceId("", PLACEHOLDER_INFERENCE_ID));
    }

    protected void doXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
        builder.startObject(NAME);
        builder.field(FIELD_FIELD.getPreferredName(), this.fieldName);
        builder.field(QUERY_FIELD.getPreferredName(), this.query);
        if (this.lenient != null) {
            builder.field(LENIENT_FIELD.getPreferredName(), this.lenient);
        }
        this.boostAndQueryNameToXContent(builder);
        builder.endObject();
    }

    protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) {
        SearchExecutionContext searchExecutionContext = queryRewriteContext.convertToSearchExecutionContext();
        if (searchExecutionContext != null) {
            return this.doRewriteBuildSemanticQuery(searchExecutionContext);
        }
        ResolvedIndices resolvedIndices = queryRewriteContext.getResolvedIndices();
        if (resolvedIndices != null) {
            return this.doRewriteGetInferenceResults(queryRewriteContext);
        }
        return this;
    }

    private QueryBuilder doRewriteBuildSemanticQuery(SearchExecutionContext searchExecutionContext) {
        MappedFieldType fieldType = searchExecutionContext.getFieldType(this.fieldName);
        if (fieldType == null) {
            return new MatchNoneQueryBuilder();
        }
        if (fieldType instanceof SemanticTextFieldMapper.SemanticTextFieldType) {
            SemanticTextFieldMapper.SemanticTextFieldType semanticTextFieldType = (SemanticTextFieldMapper.SemanticTextFieldType)fieldType;
            if (this.inferenceResultsMap == null) {
                throw new IllegalStateException("No inference results set for [" + semanticTextFieldType.typeName() + "] field [" + this.fieldName + "]");
            }
            String inferenceId = semanticTextFieldType.getSearchInferenceId();
            InferenceResults inferenceResults = SemanticQueryBuilder.getSingleInferenceResult(this.inferenceResultsMap);
            if (inferenceResults == null) {
                inferenceResults = this.inferenceResultsMap.get(new FullyQualifiedInferenceId(searchExecutionContext.getLocalClusterAlias(), inferenceId));
            }
            if (inferenceResults == null) {
                throw new IllegalStateException("No inference results set for [" + semanticTextFieldType.typeName() + "] field [" + this.fieldName + "] with inference ID [" + inferenceId + "]");
            }
            return semanticTextFieldType.semanticQuery(inferenceResults, searchExecutionContext.requestSize(), this.boost(), this.queryName());
        }
        if (this.lenient != null && this.lenient.booleanValue()) {
            return new MatchNoneQueryBuilder();
        }
        throw new IllegalArgumentException("Field [" + this.fieldName + "] of type [" + fieldType.typeName() + "] does not support semantic queries");
    }

    private SemanticQueryBuilder doRewriteGetInferenceResults(QueryRewriteContext queryRewriteContext) {
        boolean ccsRequest;
        ResolvedIndices resolvedIndices = queryRewriteContext.getResolvedIndices();
        boolean bl = ccsRequest = !resolvedIndices.getRemoteClusterIndices().isEmpty();
        if (ccsRequest && !queryRewriteContext.isCcsMinimizeRoundTrips().booleanValue()) {
            throw new IllegalArgumentException("semantic query does not support cross-cluster search when [ccs_minimize_roundtrips] is false");
        }
        SemanticQueryBuilder rewritten = this;
        if (!queryRewriteContext.hasAsyncActions()) {
            Set<FullyQualifiedInferenceId> fullyQualifiedInferenceIds = SemanticQueryBuilder.getInferenceIdsForField(resolvedIndices.getConcreteLocalIndicesMetadata().values(), queryRewriteContext.getLocalClusterAlias(), this.fieldName);
            Map<FullyQualifiedInferenceId, InferenceResults> modifiedInferenceResultsMap = SemanticQueryBuilder.getInferenceResults(queryRewriteContext, fullyQualifiedInferenceIds, this.inferenceResultsMap, this.query);
            if (modifiedInferenceResultsMap == this.inferenceResultsMap) {
                this.inferenceResultsErrorCheck(modifiedInferenceResultsMap);
            } else {
                rewritten = new SemanticQueryBuilder(this, modifiedInferenceResultsMap, ccsRequest);
            }
        }
        return rewritten;
    }

    private static InferenceResults validateAndConvertInferenceResults(InferenceServiceResults inferenceServiceResults, String inferenceId) {
        List inferenceResultsList = inferenceServiceResults.transformToCoordinationFormat();
        if (inferenceResultsList.isEmpty()) {
            return new ErrorInferenceResults((Exception)new IllegalArgumentException("No query inference results retrieved for inference ID [" + inferenceId + "]"));
        }
        if (inferenceResultsList.size() > 1) {
            return new ErrorInferenceResults((Exception)new IllegalStateException(inferenceResultsList.size() + " query inference results retrieved for inference ID [" + inferenceId + "]"));
        }
        InferenceResults inferenceResults = (InferenceResults)inferenceResultsList.getFirst();
        if (!(inferenceResults instanceof TextExpansionResults || inferenceResults instanceof MlTextEmbeddingResults || inferenceResults instanceof ErrorInferenceResults || inferenceResults instanceof WarningInferenceResults)) {
            return new ErrorInferenceResults((Exception)new IllegalArgumentException("Expected query inference results to be of type [text_expansion_result] or [text_embedding_result], got [" + inferenceResults.getWriteableName() + "]. Has the inference endpoint [" + inferenceId + "] configuration changed?"));
        }
        return inferenceResults;
    }

    private void inferenceResultsErrorCheck(Map<FullyQualifiedInferenceId, InferenceResults> inferenceResultsMap) {
        for (Map.Entry<FullyQualifiedInferenceId, InferenceResults> entry : inferenceResultsMap.entrySet()) {
            String inferenceId = entry.getKey().inferenceId();
            InferenceResults inferenceResults = entry.getValue();
            if (inferenceResults instanceof ErrorInferenceResults) {
                ErrorInferenceResults errorInferenceResults = (ErrorInferenceResults)inferenceResults;
                throw new InferenceException("Field [" + this.fieldName + "] with inference ID [" + inferenceId + "] query inference error", errorInferenceResults.getException(), new Object[0]);
            }
            if (!(inferenceResults instanceof WarningInferenceResults)) continue;
            WarningInferenceResults warningInferenceResults = (WarningInferenceResults)inferenceResults;
            throw new IllegalStateException("Field [" + this.fieldName + "] with inference ID [" + inferenceId + "] query inference warning: " + warningInferenceResults.getWarning());
        }
    }

    protected Query doToQuery(SearchExecutionContext context) throws IOException {
        throw new IllegalStateException("semantic should have been rewritten to another query type");
    }

    private static Set<FullyQualifiedInferenceId> getInferenceIdsForField(Collection<IndexMetadata> indexMetadataCollection, String clusterAlias, String fieldName) {
        HashSet<FullyQualifiedInferenceId> fullyQualifiedInferenceIds = new HashSet<FullyQualifiedInferenceId>();
        for (IndexMetadata indexMetadata : indexMetadataCollection) {
            InferenceFieldMetadata inferenceFieldMetadata = (InferenceFieldMetadata)indexMetadata.getInferenceFields().get(fieldName);
            String indexInferenceId = inferenceFieldMetadata != null ? inferenceFieldMetadata.getSearchInferenceId() : null;
            if (indexInferenceId == null) continue;
            fullyQualifiedInferenceIds.add(new FullyQualifiedInferenceId(clusterAlias, indexInferenceId));
        }
        return fullyQualifiedInferenceIds;
    }

    protected boolean doEquals(SemanticQueryBuilder other) {
        return Objects.equals(this.fieldName, other.fieldName) && Objects.equals(this.query, other.query) && Objects.equals(this.inferenceResultsMap, other.inferenceResultsMap) && Objects.equals(this.ccsRequest, other.ccsRequest);
    }

    protected int doHashCode() {
        return Objects.hash(this.fieldName, this.query, this.inferenceResultsMap, this.ccsRequest);
    }

    static {
        PARSER.declareString(ConstructingObjectParser.constructorArg(), FIELD_FIELD);
        PARSER.declareString(ConstructingObjectParser.constructorArg(), QUERY_FIELD);
        PARSER.declareBoolean(ConstructingObjectParser.optionalConstructorArg(), LENIENT_FIELD);
        SemanticQueryBuilder.declareStandardFields(PARSER);
    }
}

