/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.inference.highlight;

import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.function.Function;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.Term;
import org.apache.lucene.index.Terms;
import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.KnnByteVectorQuery;
import org.apache.lucene.search.KnnFloatVectorQuery;
import org.apache.lucene.search.MatchAllDocsQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.QueryVisitor;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.Scorer;
import org.apache.lucene.search.Weight;
import org.apache.lucene.util.BitSet;
import org.elasticsearch.common.text.Text;
import org.elasticsearch.common.xcontent.support.XContentMapValues;
import org.elasticsearch.index.mapper.MappedFieldType;
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
import org.elasticsearch.index.mapper.vectors.SparseVectorFieldMapper;
import org.elasticsearch.index.query.SearchExecutionContext;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.search.fetch.FetchSubPhase;
import org.elasticsearch.search.fetch.subphase.highlight.DefaultHighlighter;
import org.elasticsearch.search.fetch.subphase.highlight.FieldHighlightContext;
import org.elasticsearch.search.fetch.subphase.highlight.HighlightField;
import org.elasticsearch.search.fetch.subphase.highlight.HighlightUtils;
import org.elasticsearch.search.fetch.subphase.highlight.Highlighter;
import org.elasticsearch.search.vectors.VectorData;
import org.elasticsearch.xpack.core.ml.search.SparseVectorQueryWrapper;
import org.elasticsearch.xpack.inference.mapper.OffsetSourceField;
import org.elasticsearch.xpack.inference.mapper.OffsetSourceFieldMapper;
import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper;

public class SemanticTextHighlighter
implements Highlighter {
    public static final String NAME = "semantic";

    public boolean canHighlight(MappedFieldType fieldType) {
        return fieldType instanceof SemanticTextFieldMapper.SemanticTextFieldType;
    }

    public HighlightField highlight(FieldHighlightContext fieldContext) throws IOException {
        Function<OffsetAndScore, String> offsetToContent;
        if (!this.canHighlight(fieldContext.fieldType)) {
            return null;
        }
        SemanticTextFieldMapper.SemanticTextFieldType fieldType = (SemanticTextFieldMapper.SemanticTextFieldType)fieldContext.fieldType;
        if (fieldType.getEmbeddingsField() == null) {
            return null;
        }
        List<Query> queries = switch (fieldType.getModelSettings().taskType()) {
            case TaskType.SPARSE_EMBEDDING -> this.extractSparseVectorQueries((SparseVectorFieldMapper.SparseVectorFieldType)fieldType.getEmbeddingsField().fieldType(), fieldContext.query);
            case TaskType.TEXT_EMBEDDING -> this.extractDenseVectorQueries((DenseVectorFieldMapper.DenseVectorFieldType)fieldType.getEmbeddingsField().fieldType(), fieldContext.query);
            default -> throw new IllegalStateException("Wrong task type for a semantic text field, got [" + fieldType.getModelSettings().taskType().name() + "]");
        };
        if (queries.isEmpty()) {
            return null;
        }
        int numberOfFragments = fieldContext.field.fieldOptions().numberOfFragments() <= 0 ? 1 : fieldContext.field.fieldOptions().numberOfFragments();
        List<OffsetAndScore> chunks = this.extractOffsetAndScores(fieldContext.context.getSearchExecutionContext(), fieldContext.hitContext.reader(), fieldType, fieldContext.hitContext.docId(), queries);
        if (chunks.size() == 0) {
            return null;
        }
        chunks.sort(Comparator.comparingDouble(OffsetAndScore::score).reversed());
        int size = Math.min(chunks.size(), numberOfFragments);
        if (!fieldContext.field.fieldOptions().scoreOrdered().booleanValue()) {
            chunks = chunks.subList(0, size);
            chunks.sort(Comparator.comparingInt(c -> c.index));
        }
        Text[] snippets = new Text[size];
        if (fieldType.useLegacyFormat()) {
            List nestedSources = XContentMapValues.extractNestedSources((String)fieldType.getChunksField().fullPath(), (Map)fieldContext.hitContext.source().source());
            offsetToContent = entry -> this.getContentFromLegacyNestedSources(fieldType.name(), (OffsetAndScore)entry, nestedSources);
        } else {
            HashMap fieldToContent = new HashMap();
            offsetToContent = entry -> {
                String content = fieldToContent.computeIfAbsent(entry.offset().field(), key -> {
                    try {
                        return this.extractFieldContent(fieldContext.context.getSearchExecutionContext(), fieldContext.hitContext, entry.offset.field());
                    }
                    catch (IOException e) {
                        throw new UncheckedIOException("Error extracting field content from field " + entry.offset.field(), e);
                    }
                });
                return content.substring(entry.offset().start(), entry.offset().end());
            };
        }
        for (int i = 0; i < size; ++i) {
            OffsetAndScore chunk = chunks.get(i);
            String content = offsetToContent.apply(chunk);
            if (content == null) {
                throw new IllegalStateException(String.format(Locale.ROOT, "Invalid content detected for field [%s]: missing text for the chunk at offset [%d].", fieldType.name(), chunk.offset));
            }
            snippets[i] = new Text(content);
        }
        return new HighlightField(fieldContext.fieldName, snippets);
    }

    private String extractFieldContent(SearchExecutionContext searchContext, FetchSubPhase.HitContext hitContext, String sourceField) throws IOException {
        MappedFieldType sourceFieldType = searchContext.getMappingLookup().getFieldType(sourceField);
        if (sourceFieldType == null) {
            return null;
        }
        List<Object> values = HighlightUtils.loadFieldValues((MappedFieldType)sourceFieldType, (SearchExecutionContext)searchContext, (FetchSubPhase.HitContext)hitContext).stream().map(s -> DefaultHighlighter.convertFieldValue((MappedFieldType)sourceFieldType, (Object)s)).toList();
        if (values.size() == 0) {
            return null;
        }
        return DefaultHighlighter.mergeFieldValues(values, (char)'\u0000');
    }

    private String getContentFromLegacyNestedSources(String fieldName, OffsetAndScore cand, List<Map<?, ?>> nestedSources) {
        if (nestedSources.size() <= cand.index) {
            throw new IllegalStateException(String.format(Locale.ROOT, "Invalid content detected for field [%s]: the chunks size is [%d], but a reference to offset [%d] was found in the result.", fieldName, nestedSources.size(), cand.index));
        }
        return (String)nestedSources.get(cand.index).get("text");
    }

    private List<OffsetAndScore> extractOffsetAndScores(SearchExecutionContext context, LeafReader reader, SemanticTextFieldMapper.SemanticTextFieldType fieldType, int docId, List<Query> leafQueries) throws IOException {
        BitSet bitSet = context.bitsetFilter(fieldType.getChunksField().parentTypeFilter()).getBitSet(reader.getContext());
        int previousParent = docId > 0 ? bitSet.prevSetBit(docId - 1) : -1;
        BooleanQuery.Builder bq = new BooleanQuery.Builder().add(fieldType.getChunksField().nestedTypeFilter(), BooleanClause.Occur.FILTER);
        leafQueries.stream().forEach(q -> bq.add(q, BooleanClause.Occur.SHOULD));
        Weight weight = new IndexSearcher((IndexReader)reader).createWeight((Query)bq.build(), ScoreMode.COMPLETE, 1.0f);
        Scorer scorer = weight.scorer(reader.getContext());
        if (previousParent != -1 ? scorer.iterator().advance(previousParent) == Integer.MAX_VALUE : scorer.iterator().nextDoc() == Integer.MAX_VALUE) {
            return List.of();
        }
        OffsetSourceField.OffsetSourceLoader offsetReader = null;
        if (!fieldType.useLegacyFormat()) {
            Terms terms = reader.terms(fieldType.getOffsetsField().fullPath());
            if (terms == null) {
                return List.of();
            }
            offsetReader = OffsetSourceField.loader(terms);
        }
        ArrayList<OffsetAndScore> results = new ArrayList<OffsetAndScore>();
        int index = 0;
        while (scorer.docID() < docId) {
            if (offsetReader != null) {
                OffsetSourceFieldMapper.OffsetSource offset = offsetReader.advanceTo(scorer.docID());
                if (offset == null) {
                    throw new IllegalStateException("Cannot highlight field [" + fieldType.name() + "], missing offsets for doc [" + docId + "]");
                }
                results.add(new OffsetAndScore(index++, offset, scorer.score()));
            } else {
                results.add(new OffsetAndScore(index++, null, scorer.score()));
            }
            if (scorer.iterator().nextDoc() != Integer.MAX_VALUE) continue;
            break;
        }
        return results;
    }

    private List<Query> extractDenseVectorQueries(final DenseVectorFieldMapper.DenseVectorFieldType fieldType, Query querySection) {
        final ArrayList<Query> queries = new ArrayList<Query>();
        querySection.visit(new QueryVisitor(){

            public boolean acceptField(String field) {
                return fieldType.name().equals(field);
            }

            public void consumeTerms(Query query, Term ... terms) {
                super.consumeTerms(query, terms);
            }

            public void visitLeaf(Query query) {
                if (query instanceof KnnFloatVectorQuery) {
                    KnnFloatVectorQuery knnQuery = (KnnFloatVectorQuery)query;
                    queries.add(fieldType.createExactKnnQuery(VectorData.fromFloats((float[])knnQuery.getTargetCopy()), null));
                } else if (query instanceof KnnByteVectorQuery) {
                    KnnByteVectorQuery knnQuery = (KnnByteVectorQuery)query;
                    queries.add(fieldType.createExactKnnQuery(VectorData.fromBytes((byte[])knnQuery.getTargetCopy()), null));
                } else if (query instanceof MatchAllDocsQuery) {
                    queries.add(new MatchAllDocsQuery());
                }
            }
        });
        return queries;
    }

    private List<Query> extractSparseVectorQueries(final SparseVectorFieldMapper.SparseVectorFieldType fieldType, Query querySection) {
        final ArrayList<Query> queries = new ArrayList<Query>();
        querySection.visit(new QueryVisitor(){

            public boolean acceptField(String field) {
                return fieldType.name().equals(field);
            }

            public void consumeTerms(Query query, Term ... terms) {
                super.consumeTerms(query, terms);
            }

            public QueryVisitor getSubVisitor(BooleanClause.Occur occur, Query parent) {
                if (parent instanceof SparseVectorQueryWrapper) {
                    SparseVectorQueryWrapper sparseVectorQuery = (SparseVectorQueryWrapper)parent;
                    queries.add(sparseVectorQuery.getTermsQuery());
                }
                return this;
            }

            public void visitLeaf(Query query) {
                if (query instanceof MatchAllDocsQuery) {
                    queries.add(new MatchAllDocsQuery());
                }
            }
        });
        return queries;
    }

    private record OffsetAndScore(int index, OffsetSourceFieldMapper.OffsetSource offset, float score) {
    }
}

