/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.search.rescore;

import java.io.IOException;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import org.apache.lucene.search.FieldDoc;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.SortField;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TopFieldDocs;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.common.lucene.search.TopDocsAndMaxScore;
import org.elasticsearch.common.util.Maps;
import org.elasticsearch.lucene.grouping.TopFieldGroups;
import org.elasticsearch.search.internal.ContextIndexSearcher;
import org.elasticsearch.search.internal.SearchContext;
import org.elasticsearch.search.query.SearchTimeoutException;
import org.elasticsearch.search.rescore.RescoreContext;
import org.elasticsearch.search.sort.ShardDocSortField;
import org.elasticsearch.search.sort.SortAndFormats;

/**
 * Applies the rescorers configured on a {@link SearchContext} to the shard's current top hits.
 * <p>
 * Rescoring is only allowed when hits are sorted by {@code _score} (see {@link #validateSort}). When the
 * top hits are {@link TopFieldDocs} or collapsed {@link TopFieldGroups}, their sort values and group values
 * are rebuilt after rescoring so they follow the rescored order and carry the rescored score.
 */
public class RescorePhase {
    private RescorePhase() {
        // static utility, no instances
    }

    /**
     * Runs every rescorer of {@code context} over the query result's top docs and stores the rescored
     * docs (with the first hit's score as max score) back into the query result.
     * <p>
     * Returns without doing anything when {@code size == 0}, when no rescorer is configured, or when the
     * query produced no hits. If a {@link ContextIndexSearcher.TimeExceededException} is raised while
     * rescoring, it is reported through {@link SearchTimeoutException#handleTimeout} (using the request's
     * {@code allowPartialSearchResults} setting) and no rescored docs are stored.
     *
     * @param context the search context whose query result is rescored
     * @throws IllegalStateException if the request sort is not compatible with rescoring
     * @throws ElasticsearchException if a rescorer fails with an {@link IOException}
     */
    public static void execute(SearchContext context) {
        if (context.size() == 0 || context.rescore() == null || context.rescore().isEmpty()) {
            return;
        }
        if (!validateSort(context.sort())) {
            throw new IllegalStateException("Cannot use [sort] option in conjunction with [rescore], missing a validate?");
        }
        TopDocs topDocs = context.queryResult().topDocs().topDocs;
        if (topDocs.scoreDocs.length == 0) {
            return;
        }
        // Sort slot 0 is _score (guaranteed by validateSort); copy it into ScoreDoc.score so rescorers
        // see it. NOTE(review): presumably needed because sorted collection may leave `score` unset.
        for (ScoreDoc scoreDoc : topDocs.scoreDocs) {
            if (scoreDoc instanceof FieldDoc) {
                FieldDoc fieldDoc = (FieldDoc) scoreDoc;
                fieldDoc.score = ((Float) fieldDoc.fields[0]).floatValue();
            }
        }
        // Remember the original container so sort values / collapse groups can be rebuilt afterwards.
        TopFieldGroups topGroups = null;
        TopFieldDocs topFields = null;
        if (topDocs instanceof TopFieldGroups) {
            topGroups = (TopFieldGroups) topDocs;
            assert (context.collapse() != null && validateSortFields(topGroups.fields));
        } else if (topDocs instanceof TopFieldDocs) {
            topFields = (TopFieldDocs) topDocs;
            assert (validateSortFields(topFields.fields));
        }
        try {
            Runnable cancellationCheck = getCancellationChecks(context);
            for (RescoreContext ctx : context.rescore()) {
                ctx.setCancellationChecker(cancellationCheck);
                try {
                    topDocs = ctx.rescorer().rescore(topDocs, context.searcher(), ctx);
                } finally {
                    // Detach the checker even when the rescorer throws or times out, so the rescore
                    // context does not keep a reference to this request's cancellation checks.
                    ctx.setCancellationChecker(null);
                }
                // Rescorers are responsible for sorting their output; only verify it here.
                assert (topDocsSortedByScore(topDocs)) : "topdocs should be sorted after rescore";
            }
            if (topGroups != null) {
                assert (context.collapse() != null);
                topDocs = rewriteTopFieldGroups(topGroups, topDocs);
            } else if (topFields != null) {
                topDocs = rewriteTopFieldDocs(topFields, topDocs);
            }
            // Hits are sorted by score, so the first one holds the max; guard against an empty result.
            float maxScore = topDocs.scoreDocs.length == 0 ? Float.NaN : topDocs.scoreDocs[0].score;
            context.queryResult().topDocs(new TopDocsAndMaxScore(topDocs, maxScore), context.queryResult().sortValueFormats());
        } catch (IOException e) {
            throw new ElasticsearchException("Rescore Phase Failed", e);
        } catch (ContextIndexSearcher.TimeExceededException timeExceededException) {
            SearchTimeoutException.handleTimeout(context.request().allowPartialSearchResults(), context.shardTarget(), context.queryResult());
        }
    }

    /**
     * Returns whether a request sort can be combined with rescoring.
     *
     * @param sortAndFormats the request sort, or {@code null} when the request has no explicit sort
     * @return {@code true} when there is no sort, or when the sort passes {@link #validateSortFields}
     */
    public static boolean validateSort(SortAndFormats sortAndFormats) {
        return sortAndFormats == null || validateSortFields(sortAndFormats.sort.getSort());
    }

    /**
     * Checks that the first sort field is {@code _score} and that the second one, if present, is an
     * ascending {@link ShardDocSortField}. Fields after the second are not inspected.
     *
     * @return {@code false} for an empty array or any other layout
     */
    private static boolean validateSortFields(SortField[] fields) {
        if (fields.length == 0 || !fields[0].equals(SortField.FIELD_SCORE)) {
            return false;
        }
        if (fields.length == 1) {
            return true;
        }
        return fields[1] instanceof ShardDocSortField && !fields[1].getReverse();
    }

    /**
     * Rebuilds sorted top docs in rescored order, keeping the original total hits and sort fields.
     */
    private static TopFieldDocs rewriteTopFieldDocs(TopFieldDocs originalTopFieldDocs, TopDocs rescoredTopDocs) {
        FieldDoc[] newFieldDocs = rewriteFieldDocs(originalTopFieldDocs.scoreDocs, rescoredTopDocs.scoreDocs);
        return new TopFieldDocs(originalTopFieldDocs.totalHits, newFieldDocs, originalTopFieldDocs.fields);
    }

    /**
     * Rebuilds collapsed top groups in rescored order: field docs and their group values are reordered
     * together so the two arrays stay parallel (one group value per hit).
     */
    private static TopFieldGroups rewriteTopFieldGroups(TopFieldGroups originalTopGroups, TopDocs rescoredTopDocs) {
        ScoreDoc[] originalDocs = originalTopGroups.scoreDocs;
        ScoreDoc[] rescoredDocs = rescoredTopDocs.scoreDocs;
        FieldDoc[] newFieldDocs = rewriteFieldDocs(originalDocs, rescoredDocs);
        Map<Integer, Object> docIdToGroupValue = Maps.newMapWithExpectedSize(originalDocs.length);
        for (int i = 0; i < originalDocs.length; ++i) {
            docIdToGroupValue.put(originalDocs[i].doc, originalTopGroups.groupValues[i]);
        }
        // Sized like newFieldDocs (not like the original groups) so both arrays have the same length.
        Object[] newGroupValues = new Object[rescoredDocs.length];
        for (int i = 0; i < rescoredDocs.length; ++i) {
            newGroupValues[i] = docIdToGroupValue.get(rescoredDocs[i].doc);
        }
        return new TopFieldGroups(originalTopGroups.field, originalTopGroups.totalHits, newFieldDocs, originalTopGroups.fields, newGroupValues);
    }

    /**
     * Returns the original {@link FieldDoc}s in rescored order, with {@code score} and the {@code _score}
     * sort value (slot 0) overwritten by the rescored score. The original FieldDoc objects are mutated.
     *
     * @param originalScoreDocs the hits before rescoring; every element must be a {@link FieldDoc}
     * @param rescoredScoreDocs the rescored hits, already in their final order
     * @throws IllegalStateException if a rescored doc id does not appear in {@code originalScoreDocs}
     */
    private static FieldDoc[] rewriteFieldDocs(ScoreDoc[] originalScoreDocs, ScoreDoc[] rescoredScoreDocs) {
        Map<Integer, FieldDoc> docIdToFieldDoc = Maps.newMapWithExpectedSize(originalScoreDocs.length);
        for (ScoreDoc original : originalScoreDocs) {
            docIdToFieldDoc.put(original.doc, (FieldDoc) original);
        }
        FieldDoc[] newFieldDocs = new FieldDoc[rescoredScoreDocs.length];
        for (int i = 0; i < rescoredScoreDocs.length; ++i) {
            ScoreDoc rescored = rescoredScoreDocs[i];
            FieldDoc fieldDoc = docIdToFieldDoc.get(rescored.doc);
            if (fieldDoc == null) {
                throw new IllegalStateException("rescored doc [" + rescored.doc + "] was not part of the original top docs");
            }
            fieldDoc.score = rescored.score;
            // Slot 0 holds the _score sort value; keep it consistent with the new score.
            fieldDoc.fields[0] = Float.valueOf(rescored.score);
            newFieldDocs[i] = fieldDoc;
        }
        return newFieldDocs;
    }

    /**
     * Returns {@code true} when scores are non-increasing from one hit to the next (ties allowed).
     * Null or fewer-than-two-hit inputs are trivially sorted. Used only in assertions.
     */
    private static boolean topDocsSortedByScore(TopDocs topDocs) {
        if (topDocs == null || topDocs.scoreDocs == null || topDocs.scoreDocs.length < 2) {
            return true;
        }
        float lastScore = topDocs.scoreDocs[0].score;
        for (int i = 1; i < topDocs.scoreDocs.length; ++i) {
            ScoreDoc doc = topDocs.scoreDocs[i];
            if (Float.compare(doc.score, lastScore) > 0) {
                return false;
            }
            lastScore = doc.score;
        }
        return true;
    }

    /**
     * Combines the context's cancellation checks into one {@link Runnable} that runs each of them in order.
     * The returned runnable reads the list obtained here each time it runs.
     */
    static Runnable getCancellationChecks(SearchContext context) {
        List<Runnable> cancellationChecks = context.getCancellationChecks();
        return () -> {
            for (Runnable check : cancellationChecks) {
                check.run();
            }
        };
    }
}

