/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.ml.inference.rescorer;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.Explanation;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
import org.elasticsearch.common.util.Maps;
import org.elasticsearch.core.Strings;
import org.elasticsearch.inference.InferenceResults;
import org.elasticsearch.search.rescore.RescoreContext;
import org.elasticsearch.search.rescore.Rescorer;
import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
import org.elasticsearch.xpack.ml.inference.loadingservice.LocalModel;
import org.elasticsearch.xpack.ml.inference.rescorer.FeatureExtractor;
import org.elasticsearch.xpack.ml.inference.rescorer.InferenceRescorerContext;

public class InferenceRescorer
implements Rescorer {
    public static final InferenceRescorer INSTANCE = new InferenceRescorer();
    private static final Logger logger = LogManager.getLogger(InferenceRescorer.class);
    private static final Comparator<ScoreDoc> SCORE_DOC_COMPARATOR = (o1, o2) -> {
        int cmp = Float.compare(o2.score, o1.score);
        return cmp == 0 ? Integer.compare(o1.doc, o2.doc) : cmp;
    };

    private InferenceRescorer() {
    }

    public TopDocs rescore(TopDocs topDocs, IndexSearcher searcher, RescoreContext rescoreContext) throws IOException {
        if (topDocs.scoreDocs.length == 0) {
            return topDocs;
        }
        InferenceRescorerContext ltrRescoreContext = (InferenceRescorerContext)rescoreContext;
        if (ltrRescoreContext.inferenceDefinition == null) {
            throw new IllegalStateException("local model reference is null, missing rewriteAndFetch before rescore phase?");
        }
        LocalModel definition = ltrRescoreContext.inferenceDefinition;
        TopDocs topNFirstPass = InferenceRescorer.topN(topDocs, rescoreContext.getWindowSize());
        Set topNDocIDs = Arrays.stream(topNFirstPass.scoreDocs).map(scoreDoc -> scoreDoc.doc).collect(Collectors.toUnmodifiableSet());
        rescoreContext.setRescoredDocs(topNDocIDs);
        ScoreDoc[] hitsToRescore = topNFirstPass.scoreDocs;
        Arrays.sort(hitsToRescore, Comparator.comparingInt(a -> a.doc));
        int readerUpto = -1;
        int endDoc = 0;
        int docBase = 0;
        List leaves = ltrRescoreContext.executionContext.searcher().getIndexReader().leaves();
        LeafReaderContext currentSegment = null;
        boolean changedSegment = true;
        List<FeatureExtractor> featureExtractors = ltrRescoreContext.buildFeatureExtractors(searcher);
        ArrayList<Map> docFeatures = new ArrayList<Map>(topNDocIDs.size());
        int featureSize = featureExtractors.stream().mapToInt(fe -> fe.featureNames().size()).sum();
        for (int hitUpto = 0; hitUpto < hitsToRescore.length; ++hitUpto) {
            ScoreDoc hit = hitsToRescore[hitUpto];
            int docID = hit.doc;
            while (docID >= endDoc) {
                currentSegment = (LeafReaderContext)leaves.get(++readerUpto);
                endDoc = currentSegment.docBase + currentSegment.reader().maxDoc();
                changedSegment = true;
            }
            assert (currentSegment != null) : "Unexpected null segment";
            if (changedSegment) {
                docBase = currentSegment.docBase;
                for (FeatureExtractor featureExtractor : featureExtractors) {
                    featureExtractor.setNextReader(currentSegment);
                }
                changedSegment = false;
            }
            int targetDoc = docID - docBase;
            Map features = Maps.newMapWithExpectedSize((int)featureSize);
            for (FeatureExtractor featureExtractor : featureExtractors) {
                featureExtractor.addFeatures(features, targetDoc);
            }
            logger.debug(() -> Strings.format((String)"doc [%d] has features [%s]", (Object[])new Object[]{targetDoc, features}));
            docFeatures.add(features);
        }
        for (int i = 0; i < hitsToRescore.length; ++i) {
            Map features = (Map)docFeatures.get(i);
            try {
                InferenceResults results = definition.inferLtr(features, (InferenceConfig)ltrRescoreContext.inferenceConfig);
                if (results instanceof WarningInferenceResults) {
                    WarningInferenceResults warningInferenceResults = (WarningInferenceResults)results;
                    logger.warn("Failure rescoring doc, warning returned [" + warningInferenceResults.getWarning() + "]");
                    continue;
                }
                Object object = results.predictedValue();
                if (object instanceof Number) {
                    Number prediction = (Number)object;
                    hitsToRescore[i].score = prediction.floatValue();
                    continue;
                }
                logger.warn("Failure rescoring doc, unexpected inference result of kind [" + results.getWriteableName() + "]");
                continue;
            }
            catch (Exception ex) {
                logger.warn("Failure rescoring doc...", (Throwable)ex);
            }
        }
        assert (rescoreContext.getWindowSize() >= hitsToRescore.length) : "unexpected, windows size [" + rescoreContext.getWindowSize() + "] should be gte [" + hitsToRescore.length + "]";
        Arrays.sort(topDocs.scoreDocs, SCORE_DOC_COMPARATOR);
        return topDocs;
    }

    public Explanation explain(int topLevelDocId, IndexSearcher searcher, RescoreContext rescoreContext, Explanation sourceExplanation) throws IOException {
        return null;
    }

    private static TopDocs topN(TopDocs in, int topN) {
        if (in.scoreDocs.length < topN) {
            return in;
        }
        ScoreDoc[] subset = new ScoreDoc[topN];
        System.arraycopy(in.scoreDocs, 0, subset, 0, topN);
        return new TopDocs(in.totalHits, subset);
    }
}

