/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.ml.aggs.categorization;

import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.TreeMap;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.lucene.analysis.TokenStream;
import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
import org.apache.lucene.util.Accountable;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.RamUsageEstimator;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.core.Releasable;
import org.elasticsearch.core.Releasables;
import org.elasticsearch.search.aggregations.AggregationReduceContext;
import org.elasticsearch.search.aggregations.InternalAggregations;
import org.elasticsearch.xpack.ml.aggs.categorization.CategorizationBytesRefHash;
import org.elasticsearch.xpack.ml.aggs.categorization.CategorizationPartOfSpeechDictionary;
import org.elasticsearch.xpack.ml.aggs.categorization.InternalCategorizationAggregation;
import org.elasticsearch.xpack.ml.aggs.categorization.SerializableTokenListCategory;
import org.elasticsearch.xpack.ml.aggs.categorization.TokenListCategory;
import org.elasticsearch.xpack.ml.aggs.categorization.TokenListSimilarityTester;
import org.elasticsearch.xpack.ml.job.categorization.CategorizationAnalyzer;

public class TokenListCategorizer
implements Accountable {
    /**
     * Maximum number of tokens per message. The token-collecting loop in
     * {@code computeCategory(TokenStream, int, long)} is capped with a literal of the same value.
     */
    public static final int MAX_TOKENS = 100;
    /** Shallow size of an instance of this class, as reported by {@link RamUsageEstimator}. */
    private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(TokenListCategorizer.class);
    /** Shallow size of an {@link ArrayList} instance; used when estimating the size of the category list. */
    private static final long SHALLOW_SIZE_OF_ARRAY_LIST = RamUsageEstimator.shallowSizeOfInstance(ArrayList.class);
    /**
     * Tolerance of 1e-6. Not referenced by name in this class as shown; {@code minMatchingWeight} and
     * {@code maxMatchingWeight} use a literal with the same value.
     */
    private static final float EPSILON = 1.0E-6f;
    private static final Logger logger = LogManager.getLogger(TokenListCategorizer.class);
    /** The similarity threshold passed to the constructor; a candidate must score strictly above it to be a best match. */
    private final float lowerThreshold;
    /** Midpoint between {@link #lowerThreshold} and 1.0; a similarity above it is accepted without scanning further. */
    private final float upperThreshold;
    /** Maps token bytes to the integer token ids stored in categories. */
    private final CategorizationBytesRefHash bytesRefHash;
    /** May be null; the {@code TokenStream} overload of {@code computeCategory} asserts that it is not. */
    @Nullable
    private final CategorizationPartOfSpeechDictionary partOfSpeechDictionary;
    /** Categories in creation order; only ever appended to. */
    private final List<TokenListCategory> categoriesById;
    /** The same categories, re-ordered on every update so that higher match counts come first. */
    private final List<TokenListCategory> categoriesByNumMatches;
    /** Memory estimate returned by {@link #ramBytesUsed()}; refreshed by {@code cacheRamUsage}. */
    private long cachedSizeInBytes;
    /** Running sum of the {@code ramBytesUsed()} values of the categories in {@link #categoriesByNumMatches}. */
    private long categoriesByNumMatchesContentsSize;

    /**
     * Creates an empty categorizer.
     *
     * @param bytesRefHash           hash used to translate token bytes to token ids
     * @param partOfSpeechDictionary dictionary used to weight tokens
     * @param threshold              similarity threshold; must lie in the range 0.01 to 1.0 inclusive
     * @throws IllegalArgumentException if {@code threshold} is below 0.01 or above 1.0
     */
    public TokenListCategorizer(CategorizationBytesRefHash bytesRefHash, CategorizationPartOfSpeechDictionary partOfSpeechDictionary, float threshold) {
        if (threshold > 1.0f || threshold < 0.01f) {
            throw new IllegalArgumentException("threshold must be between 0.01 and 1.0: got " + threshold);
        }

        // Matching thresholds: the upper one is halfway between the supplied threshold and 1.0.
        this.lowerThreshold = threshold;
        this.upperThreshold = (1.0f + threshold) / 2.0f;

        // Collaborators.
        this.bytesRefHash = bytesRefHash;
        this.partOfSpeechDictionary = partOfSpeechDictionary;

        // Both views of the category set start out empty.
        this.categoriesById = new ArrayList<>();
        this.categoriesByNumMatches = new ArrayList<>();

        // Seed the cached memory estimate for the empty state.
        this.cacheRamUsage(0L);
    }

    /**
     * Tokenizes a single message with the supplied analyzer and assigns it to a category.
     * <p>
     * The message is analyzed under the field name {@code "text"} and counted as one document.
     * The token stream is always closed, whether categorization succeeds or fails.
     *
     * @param s        the message to categorize; its length is recorded as the unfiltered string length
     * @param analyzer the analyzer that produces the token stream
     * @return the category the message was added to, or {@code null} if the delegate overload returns {@code null}
     * @throws RuntimeException wrapping any {@link IOException} raised while reading or closing the token stream
     */
    @Nullable
    public TokenListCategory computeCategory(String s, CategorizationAnalyzer analyzer) {
        // try-with-resources closes the stream on every path and attaches a failure
        // from close() as a suppressed exception when the body has already thrown.
        try (TokenStream ts = analyzer.tokenStream("text", s)) {
            return this.computeCategory(ts, s.length(), 1L);
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Reads tokens from {@code ts}, weights each one with a fresh {@link WeightCalculator} and
     * assigns the resulting token list to a category.
     * <p>
     * The stream is reset here but not closed; closing is left to the caller.
     *
     * @param ts                  token stream to consume
     * @param unfilteredStringLen length of the original, untokenized string
     * @param numDocs             number of documents the token list represents
     * @return the matched or newly created category, or {@code null} when {@code numDocs} is not
     *         positive or the stream yields no non-empty tokens
     * @throws IOException if the token stream fails
     */
    @Nullable
    public TokenListCategory computeCategory(TokenStream ts, int unfilteredStringLen, long numDocs) throws IOException {
        assert (this.partOfSpeechDictionary != null) : "This version of computeCategory should only be used when a part-of-speech dictionary is available";
        if (numDocs <= 0L) {
            assert (numDocs == 0L) : "number of documents was negative: " + numDocs;
            return null;
        }

        List<TokenListCategory.TokenAndWeight> collected = new ArrayList<>();
        CharTermAttribute termAttribute = ts.addAttribute(CharTermAttribute.class);
        ts.reset();
        WeightCalculator calculator = new WeightCalculator(this.partOfSpeechDictionary);

        // Stop once 100 tokens have been collected (the same value as MAX_TOKENS).
        while (ts.incrementToken() && collected.size() < 100) {
            if (termAttribute.length() > 0) {
                String term = termAttribute.toString();
                int weight = calculator.calculateWeight(term);
                BytesRef termBytes = new BytesRef(term.getBytes(StandardCharsets.UTF_8));
                collected.add(new TokenListCategory.TokenAndWeight(this.bytesRefHash.put(termBytes), weight));
            }
        }

        return collected.isEmpty() ? null : this.computeCategory(collected, unfilteredStringLen, numDocs);
    }

    /**
     * Assigns an already tokenized and weighted message to a category.
     * <p>
     * Sums the token weights, sums the per-token minimum and maximum matching weights, and builds
     * a list of unique token ids (ascending by id) in which repeated tokens have their weights added.
     *
     * @param weightedTokenIds    tokens of the message, in order, with their weights
     * @param unfilteredStringLen length of the original string; also used as the maximum unfiltered length
     * @param numDocs             number of documents the token list represents
     * @return the matched or newly created category
     */
    public TokenListCategory computeCategory(List<TokenListCategory.TokenAndWeight> weightedTokenIds, int unfilteredStringLen, long numDocs) {
        int totalWeight = 0;
        int minReweighted = 0;
        int maxReweighted = 0;
        // Keyed by token id; a TreeMap so that values() comes out ordered by id.
        TreeMap<Integer, TokenListCategory.TokenAndWeight> byTokenId = new TreeMap<>();

        for (TokenListCategory.TokenAndWeight entry : weightedTokenIds) {
            int id = entry.getTokenId();
            int entryWeight = entry.getWeight();

            totalWeight += entryWeight;
            minReweighted += WeightCalculator.getMinMatchingWeight(entryWeight);
            maxReweighted += WeightCalculator.getMaxMatchingWeight(entryWeight);

            // The first occurrence is stored as-is; later occurrences accumulate weight in a new entry.
            TokenListCategory.TokenAndWeight seen = byTokenId.get(id);
            if (seen == null) {
                byTokenId.put(id, entry);
            } else {
                byTokenId.put(id, new TokenListCategory.TokenAndWeight(id, seen.getWeight() + entryWeight));
            }
        }

        List<TokenListCategory.TokenAndWeight> uniqueTokens = new ArrayList<>(byTokenId.values());
        return this.computeCategory(weightedTokenIds, uniqueTokens, totalWeight, minReweighted, maxReweighted, unfilteredStringLen, unfilteredStringLen, numDocs);
    }

    /**
     * Merges a category received in serialized form into this categorizer.
     * <p>
     * The serialized category is rebuilt against this categorizer's hash and fed through the same
     * matching routine used for individual messages, so it either joins an existing category or
     * becomes a new one.
     *
     * @param serializableCategory the incoming category
     * @return the category the incoming one was merged into, or the newly created category
     */
    public TokenListCategory mergeWireCategory(SerializableTokenListCategory serializableCategory) {
        final int categoryCountBefore = this.categoriesByNumMatches.size();
        final TokenListCategory incoming = new TokenListCategory(0, serializableCategory, this.bytesRefHash);
        final int incomingBaseWeight = incoming.getBaseWeight();

        final TokenListCategory result = this.computeCategory(
            incoming.getBaseWeightedTokenIds(),
            incoming.getCommonUniqueTokenIds(),
            incomingBaseWeight,
            WeightCalculator.getMinMatchingWeight(incomingBaseWeight),
            WeightCalculator.getMaxMatchingWeight(incomingBaseWeight),
            incoming.getBaseUnfilteredLength(),
            incoming.getMaxUnfilteredStringLength(),
            incoming.getNumMatches()
        );

        // An unchanged category count means the incoming category was absorbed by an existing one.
        final boolean absorbed = this.categoriesByNumMatches.size() == categoryCountBefore;
        if (absorbed && logger.isDebugEnabled()) {
            logger.debug("Merged wire category [{}] into existing category to form [{}]", serializableCategory, new SerializableTokenListCategory(result, this.bytesRefHash));
        }
        return result;
    }

    /**
     * Core matching routine: finds the existing category that best fits the supplied token list
     * and adds the tokens to it, or creates a new category when nothing fits.
     * <p>
     * Categories are scanned in {@code categoriesByNumMatches} order. A category is accepted
     * immediately when its {@code matchesSearchForCategory} check passes or its similarity exceeds
     * the upper threshold. Otherwise the category with the highest similarity strictly above the
     * lower threshold is used. The method is {@code synchronized} on this categorizer.
     *
     * @param weightedTokenIds         tokens of the message, in order, with their weights
     * @param workTokenUniqueIds       unique tokens of the message with combined weights
     * @param workWeight               total weight of the message's tokens
     * @param minReweightedTotalWeight total weight used to derive the lower bound on candidate base weights
     * @param maxReweightedTotalWeight total weight used to derive the upper bound on candidate base weights
     * @param unfilteredStringLen      unfiltered length passed to a newly created category
     * @param maxUnfilteredStringLen   maximum unfiltered length passed to the search check and to whichever
     *                                 category is updated or created
     * @param numDocs                  number of documents the token list represents
     * @return the category that was updated or created
     */
    private synchronized TokenListCategory computeCategory(List<TokenListCategory.TokenAndWeight> weightedTokenIds, List<TokenListCategory.TokenAndWeight> workTokenUniqueIds, int workWeight, int minReweightedTotalWeight, int maxReweightedTotalWeight, int unfilteredStringLen, int maxUnfilteredStringLen, long numDocs) {
        // Candidate categories whose base weight falls outside [minWeight, maxWeight] are skipped
        // (unless their search check matches). The bounds start out derived from the lower threshold.
        int minWeight = TokenListCategorizer.minMatchingWeight(minReweightedTotalWeight, this.lowerThreshold);
        int maxWeight = TokenListCategorizer.maxMatchingWeight(maxReweightedTotalWeight, this.lowerThreshold);
        int bestSoFarIndex = -1;
        // Starting at the lower threshold means only strictly better similarities can become "best so far".
        float bestSoFarSimilarity = this.lowerThreshold;
        for (int index = 0; index < this.categoriesByNumMatches.size(); ++index) {
            TokenListCategory compCategory = this.categoriesByNumMatches.get(index);
            List<TokenListCategory.TokenAndWeight> baseTokenIds = compCategory.getBaseWeightedTokenIds();
            int baseWeight = compCategory.getBaseWeight();
            boolean matchesSearch = compCategory.matchesSearchForCategory(workWeight, maxUnfilteredStringLen, workTokenUniqueIds, weightedTokenIds);
            if (!matchesSearch) {
                // Cheap pre-filter on weight before paying for the similarity calculation.
                if (baseWeight < minWeight || baseWeight > maxWeight) {
                    // An identical token list should never be excluded by the weight bounds.
                    assert (!baseTokenIds.equals(weightedTokenIds)) : "Min [" + minWeight + "] and/or max [" + maxWeight + "] weights calculated incorrectly " + baseTokenIds;
                    continue;
                }
                // Second pre-filter: skip when too much of the category's common token weight is
                // missing from this message, measured as a proportion of the original unique token weight.
                int missingCommonTokenWeight = compCategory.missingCommonTokenWeight(workTokenUniqueIds);
                if (missingCommonTokenWeight > 0) {
                    int origUniqueTokenWeight = compCategory.getOrigUniqueTokenWeight();
                    int commonUniqueTokenWeight = compCategory.getCommonUniqueTokenWeight();
                    float proportionOfOrig = (float)(commonUniqueTokenWeight - missingCommonTokenWeight) / (float)origUniqueTokenWeight;
                    if (proportionOfOrig < this.lowerThreshold) continue;
                }
            }
            float similarity = TokenListCategorizer.similarity(weightedTokenIds, workWeight, baseTokenIds, baseWeight);
            // A search match or a similarity above the upper threshold ends the scan immediately.
            if (matchesSearch || similarity > this.upperThreshold) {
                if (similarity <= this.lowerThreshold) {
                    // Only reachable through matchesSearch; the match is still accepted, just traced.
                    logger.trace("Reverse search match below threshold [{}]: orig tokens {} new tokens {}", (Object)Float.valueOf(similarity), compCategory.getBaseWeightedTokenIds(), weightedTokenIds);
                }
                return this.addCategoryMatch(maxUnfilteredStringLen, weightedTokenIds, workTokenUniqueIds, numDocs, index);
            }
            if (!(similarity > bestSoFarSimilarity)) continue;
            // New best candidate: remember it and recompute the weight bounds from its similarity.
            bestSoFarIndex = index;
            bestSoFarSimilarity = similarity;
            minWeight = TokenListCategorizer.minMatchingWeight(minReweightedTotalWeight, similarity);
            maxWeight = TokenListCategorizer.maxMatchingWeight(maxReweightedTotalWeight, similarity);
        }
        if (bestSoFarIndex >= 0) {
            return this.addCategoryMatch(maxUnfilteredStringLen, weightedTokenIds, workTokenUniqueIds, numDocs, bestSoFarIndex);
        }
        // Nothing matched: create a new category. Its first constructor argument is the current
        // category count, which is also its index in categoriesById.
        int newIndex = this.categoriesByNumMatches.size();
        TokenListCategory newCategory = new TokenListCategory(newIndex, unfilteredStringLen, weightedTokenIds, workTokenUniqueIds, maxUnfilteredStringLen, numDocs);
        this.categoriesById.add(newCategory);
        this.categoriesByNumMatches.add(newCategory);
        this.cacheRamUsage(newCategory.ramBytesUsed());
        // The new category was appended at the end; move it to where its match count belongs.
        return this.repositionCategory(newCategory, newIndex);
    }

    /**
     * Returns the cached memory estimate maintained by {@code cacheRamUsage}; nothing is
     * recomputed here.
     *
     * @return the cached estimate in bytes
     */
    public long ramBytesUsed() {
        return this.cachedSizeInBytes;
    }

    /**
     * Computes the memory estimate from scratch as the shallow size of this object plus
     * {@link RamUsageEstimator#sizeOfCollection} of the category list, bypassing the cached value.
     * NOTE(review): presumably intended for cross-checking {@link #ramBytesUsed()} - confirm against callers.
     *
     * @return the freshly computed estimate in bytes
     */
    long ramBytesUsedSlow() {
        return SHALLOW_SIZE + RamUsageEstimator.sizeOfCollection(this.categoriesByNumMatches);
    }

    /**
     * Applies a change in the total size of the stored categories and refreshes the cached
     * memory estimate returned by {@link #ramBytesUsed()}.
     * <p>
     * The estimate is the shallow size of this object plus the aligned size of the category list:
     * list object, array header, one reference per category, and the categories themselves.
     *
     * @param contentsSizeDiff change, in bytes, in the summed size of the categories (may be negative)
     */
    private synchronized void cacheRamUsage(long contentsSizeDiff) {
        this.categoriesByNumMatchesContentsSize += contentsSizeDiff;
        // Widen before multiplying so the product is computed in long arithmetic and cannot overflow int.
        long referencesSize = (long) this.categoriesByNumMatches.size() * RamUsageEstimator.NUM_BYTES_OBJECT_REF;
        long listSize = SHALLOW_SIZE_OF_ARRAY_LIST + (long) RamUsageEstimator.NUM_BYTES_ARRAY_HEADER + referencesSize + this.categoriesByNumMatchesContentsSize;
        this.cachedSizeInBytes = SHALLOW_SIZE + RamUsageEstimator.alignObjectSize(listSize);
    }

    /**
     * Returns the number of categories currently held by this categorizer.
     *
     * @return the size of the category list
     */
    public int getCategoryCount() {
        return this.categoriesByNumMatches.size();
    }

    /**
     * Adds a token list to the category at {@code matchIndex} of the ordered list, updates the
     * cached memory estimate by the category's change in size, and restores the ordering.
     *
     * @param unfilteredLength unfiltered length to record against the category
     * @param weightedTokenIds tokens of the message, in order, with their weights
     * @param uniqueTokenIds   unique tokens of the message with combined weights
     * @param numDocs          number of documents being added
     * @param matchIndex       index of the matched category in {@code categoriesByNumMatches}
     * @return the updated category
     */
    private TokenListCategory addCategoryMatch(int unfilteredLength, List<TokenListCategory.TokenAndWeight> weightedTokenIds, List<TokenListCategory.TokenAndWeight> uniqueTokenIds, long numDocs, int matchIndex) {
        final TokenListCategory matched = this.categoriesByNumMatches.get(matchIndex);

        final long sizeBefore = matched.ramBytesUsed();
        matched.addString(unfilteredLength, weightedTokenIds, uniqueTokenIds, numDocs);
        final long sizeAfter = matched.ramBytesUsed();
        this.cacheRamUsage(sizeAfter - sizeBefore);

        if (numDocs != 1L) {
            // The count may have jumped by any amount, so re-sort the whole list (descending by matches).
            this.categoriesByNumMatches.sort(Comparator.comparing(TokenListCategory::getNumMatches).reversed());
            return matched;
        }
        // A single-document increment only needs the one category moved.
        return this.repositionCategory(matched, matchIndex);
    }

    /**
     * Moves {@code category} towards the front of {@code categoriesByNumMatches} so that the list
     * stays ordered by descending number of matches.
     * <p>
     * The category moves as little as possible: it stops behind any category whose count is
     * greater than or equal to its own. Entries before {@code currentIndex} are assumed to
     * already be in descending order.
     *
     * @param category     the category whose match count has just been set or increased
     * @param currentIndex the category's current index in {@code categoriesByNumMatches}
     * @return {@code category}, for call chaining
     */
    private TokenListCategory repositionCategory(TokenListCategory category, int currentIndex) {
        long newNumMatches = category.getNumMatches();

        // Walk backwards past every category that now has strictly fewer matches.
        int targetIndex = currentIndex;
        while (targetIndex > 0 && newNumMatches > this.categoriesByNumMatches.get(targetIndex - 1).getNumMatches()) {
            --targetIndex;
        }
        if (targetIndex == currentIndex) {
            return category;
        }

        long firstOvertaken = this.categoriesByNumMatches.get(targetIndex).getNumMatches();
        long lastOvertaken = this.categoriesByNumMatches.get(currentIndex - 1).getNumMatches();
        if (firstOvertaken == lastOvertaken) {
            // Every overtaken category has the same count (always the case after a single
            // increment), so exchanging the two ends keeps the list ordered.
            Collections.swap(this.categoriesByNumMatches, currentIndex, targetIndex);
        } else {
            // The overtaken categories have differing counts, which happens when a category enters
            // with a large count. A swap would drop the first overtaken category behind smaller
            // ones, so shift the overtaken run down by one slot instead.
            Collections.rotate(this.categoriesByNumMatches.subList(targetIndex, currentIndex + 1), 1);
        }
        return category;
    }

    /**
     * Computes the smallest integer weight strictly greater than {@code weight * threshold},
     * allowing a tolerance of 1e-6 for floating point error.
     *
     * @param weight    the weight to scale
     * @param threshold the similarity threshold
     * @return 0 when {@code weight} is 0, otherwise {@code floor(weight * threshold + 1e-6) + 1}
     */
    static int minMatchingWeight(int weight, float threshold) {
        if (weight != 0) {
            // Single precision arithmetic throughout, widened only for the floor.
            float scaled = (float) weight * threshold + 1.0E-6f;
            return (int) Math.floor(scaled) + 1;
        }
        return 0;
    }

    /**
     * Computes the largest integer weight strictly less than {@code weight / threshold},
     * allowing a tolerance of 1e-6 for floating point error.
     *
     * @param weight    the weight to scale
     * @param threshold the similarity threshold
     * @return 0 when {@code weight} is 0, otherwise {@code ceil(weight / threshold - 1e-6) - 1}
     */
    static int maxMatchingWeight(int weight, float threshold) {
        if (weight != 0) {
            // Single precision arithmetic throughout, widened only for the ceil.
            float scaled = (float) weight / threshold - 1.0E-6f;
            return (int) Math.ceil(scaled) - 1;
        }
        return 0;
    }

    /**
     * Computes the similarity of two weighted token lists as one minus the weighted edit distance
     * divided by the larger of the two total weights.
     *
     * @param left        first token list
     * @param leftWeight  total weight of {@code left}
     * @param right       second token list
     * @param rightWeight total weight of {@code right}
     * @return the similarity, or 1.0 when neither total weight is positive
     */
    static float similarity(List<TokenListCategory.TokenAndWeight> left, int leftWeight, List<TokenListCategory.TokenAndWeight> right, int rightWeight) {
        int largerWeight = Math.max(leftWeight, rightWeight);
        if (largerWeight <= 0) {
            // Nothing to compare: treat the lists as identical.
            return 1.0f;
        }
        float distance = (float) TokenListSimilarityTester.weightedEditDistance(left, right);
        return 1.0f - distance / (float) largerWeight;
    }

    /**
     * Returns serializable copies of at most {@code size} categories, taken from the front of
     * the list ordered by number of matches.
     *
     * @param size maximum number of categories to return
     * @return an unmodifiable list of serializable categories
     */
    public List<SerializableTokenListCategory> toCategories(int size) {
        return this.categoriesByNumMatches.stream()
            .limit(size)
            .map(tokenListCategory -> new SerializableTokenListCategory(tokenListCategory, this.bytesRefHash))
            .toList();
    }

    /**
     * Returns serializable copies of all categories in creation order.
     *
     * @return an unmodifiable list of serializable categories
     */
    public List<SerializableTokenListCategory> toCategoriesById() {
        return this.categoriesById.stream()
            .map(tokenListCategory -> new SerializableTokenListCategory(tokenListCategory, this.bytesRefHash))
            .toList();
    }

    /**
     * Builds buckets for at most {@code size} categories, taken from the front of the list
     * ordered by number of matches.
     *
     * @param size maximum number of buckets to return
     * @return the buckets, in list order
     */
    public InternalCategorizationAggregation.Bucket[] toOrderedBuckets(int size) {
        return this.categoriesByNumMatches.stream()
            .limit(size)
            .map(tokenListCategory -> {
                SerializableTokenListCategory serializable = new SerializableTokenListCategory(tokenListCategory, this.bytesRefHash);
                return new InternalCategorizationAggregation.Bucket(serializable, tokenListCategory.getBucketOrd());
            })
            .toArray(InternalCategorizationAggregation.Bucket[]::new);
    }

    /**
     * Builds buckets, with reduced sub-aggregations, for the leading categories of the list
     * ordered by number of matches.
     * <p>
     * At most {@code size} categories are considered, and the scan stops at the first one with
     * fewer than {@code minNumMatches} matches.
     *
     * @param size          maximum number of categories to consider
     * @param minNumMatches minimum number of matches a category needs to be included
     * @param reduceContext context used to reduce each category's sub-aggregations
     * @return the buckets, in list order
     */
    public InternalCategorizationAggregation.Bucket[] toOrderedBuckets(int size, long minNumMatches, AggregationReduceContext reduceContext) {
        return this.categoriesByNumMatches.stream()
            .limit(size)
            .takeWhile(tokenListCategory -> tokenListCategory.getNumMatches() >= minNumMatches)
            .map(tokenListCategory -> {
                SerializableTokenListCategory serializable = new SerializableTokenListCategory(tokenListCategory, this.bytesRefHash);
                var bucketOrd = tokenListCategory.getBucketOrd();
                // Categories without sub-aggregations share the EMPTY instance rather than being reduced.
                InternalAggregations reducedSubAggs;
                if (tokenListCategory.getSubAggs().isEmpty()) {
                    reducedSubAggs = InternalAggregations.EMPTY;
                } else {
                    reducedSubAggs = InternalAggregations.reduce(tokenListCategory.getSubAggs(), reduceContext);
                }
                return new InternalCategorizationAggregation.Bucket(serializable, bucketOrd, reducedSubAggs);
            })
            .toArray(InternalCategorizationAggregation.Bucket[]::new);
    }

    /**
     * Assigns a weight to each token of a message based on its part of speech and on how many
     * dictionary words have appeared consecutively.
     * <p>
     * An instance is stateful: it counts consecutive dictionary words, so tokens must be passed
     * to {@link #calculateWeight(String)} in message order and a new instance used per message.
     */
    static class WeightCalculator {
        /** Terms shorter than this are never looked up in the dictionary. */
        private static final int MIN_DICTIONARY_LENGTH = 2;
        /** Number of consecutive dictionary words from which the adjacency boost applies. */
        private static final int CONSECUTIVE_DICTIONARY_WORDS_FOR_EXTRA_WEIGHT = 3;
        /** Extra weight, on top of the base weight of 1, for a dictionary word that is a verb. */
        private static final int EXTRA_VERB_WEIGHT = 5;
        /** Extra weight, on top of the base weight of 1, for any other dictionary word. */
        private static final int EXTRA_OTHER_DICTIONARY_WEIGHT = 2;
        /** Factor applied to the extra weight once enough consecutive dictionary words have been seen. */
        private static final int ADJACENCY_BOOST_MULTIPLIER = 6;
        private final CategorizationPartOfSpeechDictionary partOfSpeechDictionary;
        /** Length of the current run of dictionary words. */
        private int consecutiveHighWeights;

        WeightCalculator(CategorizationPartOfSpeechDictionary partOfSpeechDictionary) {
            this.partOfSpeechDictionary = partOfSpeechDictionary;
        }

        /**
         * Returns the weight of the next token in the message.
         *
         * @param term the token text
         * @return 1 for terms that are too short or not in the dictionary (which also resets the run
         *         of dictionary words); otherwise 1 plus the part-of-speech weight, the latter
         *         multiplied by the adjacency boost when the run is long enough
         */
        int calculateWeight(String term) {
            if (term.length() < MIN_DICTIONARY_LENGTH) {
                this.consecutiveHighWeights = 0;
                return 1;
            }
            CategorizationPartOfSpeechDictionary.PartOfSpeech pos = this.partOfSpeechDictionary.getPartOfSpeech(term);
            if (pos == CategorizationPartOfSpeechDictionary.PartOfSpeech.NOT_IN_DICTIONARY) {
                this.consecutiveHighWeights = 0;
                return 1;
            }
            int posWeight = pos == CategorizationPartOfSpeechDictionary.PartOfSpeech.VERB ? EXTRA_VERB_WEIGHT : EXTRA_OTHER_DICTIONARY_WEIGHT;
            int adjacencyBoost = ++this.consecutiveHighWeights >= CONSECUTIVE_DICTIONARY_WORDS_FOR_EXTRA_WEIGHT ? ADJACENCY_BOOST_MULTIPLIER : 1;
            return 1 + posWeight * adjacencyBoost;
        }

        /**
         * Returns the smallest weight the same token could have had in another message: a weight
         * large enough to have been boosted is returned with the boost divided out.
         *
         * @param weight a weight produced by {@link #calculateWeight(String)}
         * @return {@code weight} itself when it does not exceed the boost multiplier, otherwise the
         *         weight with its extra portion divided by the multiplier
         */
        static int getMinMatchingWeight(int weight) {
            return weight <= ADJACENCY_BOOST_MULTIPLIER ? weight : 1 + (weight - 1) / ADJACENCY_BOOST_MULTIPLIER;
        }

        /**
         * Returns the largest weight the same token could have had in another message: a weight
         * in the unboosted dictionary-word range is returned as if it had been boosted.
         *
         * @param weight a weight produced by {@link #calculateWeight(String)}
         * @return {@code weight} itself when it is outside the unboosted dictionary-word range,
         *         otherwise the weight with its extra portion multiplied by the boost multiplier
         */
        static int getMaxMatchingWeight(int weight) {
            // Unboosted dictionary words weigh 1 + EXTRA_VERB_WEIGHT or 1 + EXTRA_OTHER_DICTIONARY_WEIGHT.
            boolean outsideUnboostedRange = weight <= Math.min(EXTRA_VERB_WEIGHT, EXTRA_OTHER_DICTIONARY_WEIGHT)
                || weight > Math.max(EXTRA_VERB_WEIGHT + 1, EXTRA_OTHER_DICTIONARY_WEIGHT + 1);
            return outsideUnboostedRange ? weight : 1 + (weight - 1) * ADJACENCY_BOOST_MULTIPLIER;
        }
    }

    /**
     * A {@link TokenListCategorizer} that releases its {@code CategorizationBytesRefHash} when
     * closed.
     */
    public static class CloseableTokenListCategorizer
    extends TokenListCategorizer
    implements Releasable {
        public CloseableTokenListCategorizer(CategorizationBytesRefHash bytesRefHash, CategorizationPartOfSpeechDictionary partOfSpeechDictionary, float threshold) {
            super(bytesRefHash, partOfSpeechDictionary, threshold);
        }

        /** Releases the hash that was passed to the constructor. */
        public void close() {
            // The field is private to the enclosing superclass and therefore not inherited,
            // so it has to be reached through super rather than this.
            Releasables.close((Releasable)super.bytesRefHash);
        }
    }
}

