/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.ml.aggs.categorization;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.apache.lucene.util.Accountable;
import org.apache.lucene.util.RamUsageEstimator;
import org.elasticsearch.common.util.set.Sets;
import org.elasticsearch.search.aggregations.InternalAggregations;
import org.elasticsearch.xpack.ml.aggs.categorization.CategorizationBytesRefHash;
import org.elasticsearch.xpack.ml.aggs.categorization.SerializableTokenListCategory;

public class TokenListCategory
implements Accountable {
    private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(TokenListCategory.class);
    private static final long SHALLOW_SIZE_OF_ARRAY_LIST = RamUsageEstimator.shallowSizeOfInstance(ArrayList.class);
    private final int id;
    private final List<TokenAndWeight> baseWeightedTokenIds;
    private final int baseWeight;
    private final int baseUnfilteredLength;
    private int maxUnfilteredStringLength;
    private int orderedCommonTokenBeginIndex;
    private int orderedCommonTokenEndIndex;
    private final List<TokenAndWeight> commonUniqueTokenIds;
    private int commonUniqueTokenWeight;
    private final int origUniqueTokenWeight;
    private long numMatches;
    private long bucketOrd = -1L;
    private List<InternalAggregations> subAggs = List.of();
    private long cachedSizeInBytes;

    public TokenListCategory(int id, int unfilteredLength, List<TokenAndWeight> baseWeightedTokenIds, List<TokenAndWeight> uniqueTokenIds, long numMatches) {
        this(id, unfilteredLength, baseWeightedTokenIds, uniqueTokenIds, unfilteredLength, numMatches);
    }

    public TokenListCategory(int id, int unfilteredLength, List<TokenAndWeight> baseWeightedTokenIds, List<TokenAndWeight> uniqueTokenIds, int maxUnfilteredStringLength, long numMatches) {
        this.id = id;
        this.baseWeightedTokenIds = List.copyOf(baseWeightedTokenIds);
        this.baseWeight = baseWeightedTokenIds.stream().mapToInt(TokenAndWeight::getWeight).sum();
        assert (unfilteredLength > 0) : "unfiltered length must be positive, got " + unfilteredLength;
        this.baseUnfilteredLength = unfilteredLength;
        assert (maxUnfilteredStringLength >= this.baseUnfilteredLength) : "max unfiltered length, " + maxUnfilteredStringLength + ", is smaller than base unfiltered length, " + this.baseUnfilteredLength;
        this.maxUnfilteredStringLength = maxUnfilteredStringLength;
        this.orderedCommonTokenBeginIndex = 0;
        this.orderedCommonTokenEndIndex = baseWeightedTokenIds.size();
        assert (uniqueTokenIds.stream().map(TokenAndWeight::getTokenId).distinct().count() == (long)uniqueTokenIds.size()) : "Unique token IDs contains duplicates " + uniqueTokenIds;
        assert (TokenListCategory.isSorted(uniqueTokenIds)) : "Unique token IDs is not sorted " + uniqueTokenIds;
        assert (Sets.intersection(uniqueTokenIds.stream().map(TokenAndWeight::getTokenId).collect(Collectors.toSet()), baseWeightedTokenIds.stream().map(TokenAndWeight::getTokenId).collect(Collectors.toSet())).size() == uniqueTokenIds.size()) : "Some unique token IDs " + uniqueTokenIds + " are not base token IDs " + baseWeightedTokenIds;
        this.commonUniqueTokenIds = new ArrayList<TokenAndWeight>(uniqueTokenIds);
        this.origUniqueTokenWeight = this.commonUniqueTokenWeight = this.commonUniqueTokenIds.stream().mapToInt(TokenAndWeight::getWeight).sum();
        assert (numMatches > 0L) : "number of matches must be positive, got " + numMatches;
        assert (numMatches > 1L || maxUnfilteredStringLength == this.baseUnfilteredLength) : "max unfiltered length, " + maxUnfilteredStringLength + ", is different to base unfiltered length, " + this.baseUnfilteredLength + ", for a category with a single match";
        this.numMatches = numMatches;
        this.cacheRamUsage();
    }

    public TokenListCategory(int id, SerializableTokenListCategory serializable, CategorizationBytesRefHash bytesRefHash) {
        this.id = id;
        this.baseWeightedTokenIds = IntStream.range(0, serializable.baseTokens.length).mapToObj(index -> new TokenAndWeight(bytesRefHash.put(serializable.baseTokens[index]), serializable.baseTokenWeights[index])).collect(Collectors.toList());
        this.baseWeight = this.baseWeightedTokenIds.stream().mapToInt(TokenAndWeight::getWeight).sum();
        this.baseUnfilteredLength = serializable.baseUnfilteredLength;
        this.maxUnfilteredStringLength = serializable.maxUnfilteredStringLength;
        this.orderedCommonTokenBeginIndex = serializable.orderedCommonTokenBeginIndex;
        this.orderedCommonTokenEndIndex = serializable.orderedCommonTokenEndIndex;
        this.commonUniqueTokenIds = IntStream.range(0, serializable.commonUniqueTokenIndexes.length).mapToObj(index -> new TokenAndWeight(this.baseWeightedTokenIds.get(serializable.commonUniqueTokenIndexes[index]).getTokenId(), serializable.commonUniqueTokenWeights[index])).sorted().collect(Collectors.toCollection(ArrayList::new));
        this.commonUniqueTokenWeight = this.commonUniqueTokenIds.stream().mapToInt(TokenAndWeight::getWeight).sum();
        this.origUniqueTokenWeight = serializable.origUniqueTokenWeight;
        this.numMatches = serializable.numMatches;
        this.cacheRamUsage();
    }

    public void addString(int unfilteredLength, List<TokenAndWeight> weightedTokenIds, List<TokenAndWeight> uniqueTokenIds, long numMatches) {
        assert (TokenListCategory.isSorted(uniqueTokenIds)) : "Unique token IDs is not sorted " + uniqueTokenIds;
        assert (numMatches > 0L) : "number of matches must be positive, got " + numMatches;
        this.mergeWith(unfilteredLength, weightedTokenIds, 0, weightedTokenIds.size(), uniqueTokenIds, numMatches);
    }

    public void mergeWith(TokenListCategory other) {
        this.mergeWith(other.maxUnfilteredStringLength, other.baseWeightedTokenIds, other.orderedCommonTokenBeginIndex, other.orderedCommonTokenEndIndex, other.commonUniqueTokenIds, other.numMatches);
    }

    private void mergeWith(int unfilteredLength, List<TokenAndWeight> weightedTokenIds, int orderedCommonTokenBeginIndex, int orderedCommonTokenEndIndex, List<TokenAndWeight> uniqueTokenIds, long numMatches) {
        this.updateCommonUniqueTokenIds(uniqueTokenIds);
        this.updateOrderedCommonTokenIds(weightedTokenIds, orderedCommonTokenBeginIndex, orderedCommonTokenEndIndex);
        if (unfilteredLength > this.maxUnfilteredStringLength) {
            this.maxUnfilteredStringLength = unfilteredLength;
        }
        this.numMatches += numMatches;
    }

    public void addSubAggs(InternalAggregations aggs) {
        if (this.subAggs.isEmpty()) {
            this.subAggs = new ArrayList<InternalAggregations>();
        }
        this.subAggs.add(aggs);
    }

    public List<InternalAggregations> getSubAggs() {
        return this.subAggs;
    }

    private void updateCommonUniqueTokenIds(List<TokenAndWeight> newUniqueTokenIds) {
        assert (this.commonUniqueTokenWeight == this.commonUniqueTokenIds.stream().mapToInt(TokenAndWeight::getWeight).sum()) : "commonUniqueTokenWeight not up to date";
        this.commonUniqueTokenWeight = 0;
        int initialSize = this.commonUniqueTokenIds.size();
        int commonIndex = 0;
        int newIndex = 0;
        int outputIndex = 0;
        while (commonIndex < initialSize) {
            if (newIndex >= newUniqueTokenIds.size()) {
                ++commonIndex;
                continue;
            }
            TokenAndWeight commonTokenAndWeight = this.commonUniqueTokenIds.get(commonIndex);
            int cmp = commonTokenAndWeight.compareTo(newUniqueTokenIds.get(newIndex));
            if (cmp < 0) {
                ++commonIndex;
                continue;
            }
            if (cmp == 0) {
                this.commonUniqueTokenIds.set(outputIndex++, commonTokenAndWeight);
                this.commonUniqueTokenWeight += commonTokenAndWeight.getWeight();
                ++commonIndex;
            }
            ++newIndex;
        }
        if (outputIndex < initialSize) {
            this.commonUniqueTokenIds.subList(outputIndex, initialSize).clear();
            this.cacheRamUsage();
        } else assert (outputIndex == initialSize) : "should be impossible for output index to exceed initial size, but got " + outputIndex + " > " + initialSize;
        assert (this.commonUniqueTokenWeight == this.commonUniqueTokenIds.stream().mapToInt(TokenAndWeight::getWeight).sum()) : "commonUniqueTokenWeight not up to date";
    }

    void updateOrderedCommonTokenIds(List<TokenAndWeight> newTokenIds, int newBeginIndex, int newEndIndex) {
        while (this.orderedCommonTokenEndIndex > this.orderedCommonTokenBeginIndex && !this.isTokenIdCommon(this.baseWeightedTokenIds.get(this.orderedCommonTokenEndIndex - 1))) {
            --this.orderedCommonTokenEndIndex;
        }
        while (this.orderedCommonTokenBeginIndex < this.orderedCommonTokenEndIndex && !this.isTokenIdCommon(this.baseWeightedTokenIds.get(this.orderedCommonTokenBeginIndex))) {
            ++this.orderedCommonTokenBeginIndex;
        }
        int bestOrderedCommonTokenBeginIndex = this.orderedCommonTokenEndIndex;
        int bestOrderedCommonTokenEndIndex = this.orderedCommonTokenEndIndex;
        int bestWeight = 0;
        for (int tryOrderedCommonTokenBeginIndex = this.orderedCommonTokenBeginIndex; tryOrderedCommonTokenBeginIndex < this.orderedCommonTokenEndIndex; ++tryOrderedCommonTokenBeginIndex) {
            int newIndex = newBeginIndex;
            int tryWeight = 0;
            for (int commonIndex = tryOrderedCommonTokenBeginIndex; commonIndex < this.orderedCommonTokenEndIndex; ++commonIndex) {
                if (!this.isTokenIdCommon(this.baseWeightedTokenIds.get(commonIndex))) continue;
                while (newIndex < newEndIndex) {
                    TokenAndWeight baseToken = this.baseWeightedTokenIds.get(commonIndex);
                    TokenAndWeight newToken = newTokenIds.get(newIndex);
                    if (newToken.getTokenId() != baseToken.getTokenId()) {
                        ++newIndex;
                        continue;
                    }
                    tryWeight += baseToken.getWeight();
                    break;
                }
                if (newIndex < newEndIndex) continue;
                if (tryWeight <= bestWeight) break;
                bestWeight = tryWeight;
                bestOrderedCommonTokenBeginIndex = tryOrderedCommonTokenBeginIndex;
                bestOrderedCommonTokenEndIndex = commonIndex;
                break;
            }
            if (newIndex >= newEndIndex) continue;
            if (tryWeight <= bestWeight) break;
            bestWeight = tryWeight;
            bestOrderedCommonTokenBeginIndex = tryOrderedCommonTokenBeginIndex;
            bestOrderedCommonTokenEndIndex = this.orderedCommonTokenEndIndex;
            break;
        }
        if (this.orderedCommonTokenBeginIndex != bestOrderedCommonTokenBeginIndex) {
            this.orderedCommonTokenBeginIndex = bestOrderedCommonTokenBeginIndex;
        }
        if (this.orderedCommonTokenEndIndex != bestOrderedCommonTokenEndIndex) {
            this.orderedCommonTokenEndIndex = bestOrderedCommonTokenEndIndex;
        }
    }

    boolean isTokenIdCommon(TokenAndWeight token) {
        return Collections.binarySearch(this.commonUniqueTokenIds, token) >= 0;
    }

    public int getId() {
        return this.id;
    }

    public List<TokenAndWeight> getBaseWeightedTokenIds() {
        return this.baseWeightedTokenIds;
    }

    public int getBaseWeight() {
        return this.baseWeight;
    }

    public int getBaseUnfilteredLength() {
        return this.baseUnfilteredLength;
    }

    public int getMaxUnfilteredStringLength() {
        return this.maxUnfilteredStringLength;
    }

    public int getOrderedCommonTokenBeginIndex() {
        return this.orderedCommonTokenBeginIndex;
    }

    public int getOrderedCommonTokenEndIndex() {
        return this.orderedCommonTokenEndIndex;
    }

    public List<TokenAndWeight> getCommonUniqueTokenIds() {
        return List.copyOf(this.commonUniqueTokenIds);
    }

    public int getCommonUniqueTokenWeight() {
        return this.commonUniqueTokenWeight;
    }

    public int getOrigUniqueTokenWeight() {
        return this.origUniqueTokenWeight;
    }

    public long getNumMatches() {
        return this.numMatches;
    }

    public int maxMatchingStringLen() {
        return TokenListCategory.maxMatchingStringLen(this.baseUnfilteredLength, this.maxUnfilteredStringLength, this.commonUniqueTokenIds.size());
    }

    static int maxMatchingStringLen(int baseUnfilteredLength, int maxUnfilteredStringLength, int numCommonUniqueTokenIds) {
        int extendedLength = Math.min(maxUnfilteredStringLength * 11 / 10, (int)((float)baseUnfilteredLength * Math.max((float)numCommonUniqueTokenIds / 1.5f, 2.0f)));
        return Math.max(maxUnfilteredStringLength, extendedLength);
    }

    void setBucketOrd(long bucketOrd) {
        assert (bucketOrd >= 0L) : "Attempt to set bucketOrd to negative number " + bucketOrd;
        assert (this.bucketOrd == -1L || this.bucketOrd == bucketOrd) : "Attempt to change bucketOrd from " + this.bucketOrd + " to " + bucketOrd;
        this.bucketOrd = bucketOrd;
    }

    long getBucketOrd() {
        return this.bucketOrd;
    }

    public int missingCommonTokenWeight(List<TokenAndWeight> uniqueTokenIds) {
        assert (TokenListCategory.isSorted(uniqueTokenIds)) : "Unique token IDs is not sorted " + uniqueTokenIds;
        int presentWeight = 0;
        int commonIndex = 0;
        int testIndex = 0;
        while (commonIndex < this.commonUniqueTokenIds.size() && testIndex < uniqueTokenIds.size()) {
            TokenAndWeight commonTokenAndWeight = this.commonUniqueTokenIds.get(commonIndex);
            int cmp = commonTokenAndWeight.compareTo(uniqueTokenIds.get(testIndex));
            if (cmp < 0) {
                ++commonIndex;
                continue;
            }
            if (cmp == 0) {
                presentWeight += commonTokenAndWeight.getWeight();
                ++commonIndex;
            }
            ++testIndex;
        }
        return this.commonUniqueTokenWeight - presentWeight;
    }

    public boolean matchesSearchForCategory(TokenListCategory other) {
        return this.matchesSearchForCategory(other.baseWeight, other.maxUnfilteredStringLength, other.commonUniqueTokenIds, other.baseWeightedTokenIds);
    }

    public boolean matchesSearchForCategory(int otherBaseWeight, int otherUnfilteredStringLen, List<TokenAndWeight> otherUniqueTokenIds, List<TokenAndWeight> otherBaseTokenIds) {
        return this.baseWeight == 0 == (otherBaseWeight == 0) && this.maxMatchingStringLen() >= otherUnfilteredStringLen && this.isMissingCommonTokenWeightZero(otherUniqueTokenIds) && this.containsCommonInOrderTokensInOrder(otherBaseTokenIds);
    }

    public boolean isMissingCommonTokenWeightZero(List<TokenAndWeight> uniqueTokenIds) {
        assert (TokenListCategory.isSorted(uniqueTokenIds)) : "Unique token IDs is not sorted " + uniqueTokenIds;
        int uniqueTokenIdsSize = uniqueTokenIds.size();
        int testIndex = 0;
        for (TokenAndWeight commonTokenAndWeight : this.commonUniqueTokenIds) {
            TokenAndWeight testTokenAndWeight;
            if (testIndex >= uniqueTokenIdsSize) {
                return false;
            }
            while ((testTokenAndWeight = uniqueTokenIds.get(testIndex)).getTokenId() < commonTokenAndWeight.getTokenId()) {
                if (++testIndex < uniqueTokenIdsSize) continue;
                return false;
            }
            if (testTokenAndWeight.getTokenId() != commonTokenAndWeight.getTokenId()) {
                return false;
            }
            ++testIndex;
        }
        return true;
    }

    boolean containsCommonInOrderTokensInOrder(List<TokenAndWeight> tokenIds) {
        int testIndex = 0;
        for (int index = this.orderedCommonTokenBeginIndex; index < this.orderedCommonTokenEndIndex; ++index) {
            TokenAndWeight baseTokenAndWeight = this.baseWeightedTokenIds.get(index);
            if (!this.isTokenIdCommon(baseTokenAndWeight)) continue;
            do {
                if (testIndex < tokenIds.size()) continue;
                return false;
            } while (tokenIds.get(testIndex++).compareTo(baseTokenAndWeight) != 0);
        }
        return true;
    }

    public long ramBytesUsed() {
        return this.cachedSizeInBytes;
    }

    long ramBytesUsedSlow() {
        return SHALLOW_SIZE + RamUsageEstimator.sizeOfCollection(this.baseWeightedTokenIds) + RamUsageEstimator.sizeOfCollection(this.commonUniqueTokenIds);
    }

    private void cacheRamUsage() {
        this.cachedSizeInBytes = SHALLOW_SIZE + RamUsageEstimator.alignObjectSize((long)(SHALLOW_SIZE_OF_ARRAY_LIST + (long)RamUsageEstimator.NUM_BYTES_ARRAY_HEADER + (long)this.baseWeightedTokenIds.size() * (TokenAndWeight.SHALLOW_SIZE + (long)RamUsageEstimator.NUM_BYTES_OBJECT_REF))) + RamUsageEstimator.alignObjectSize((long)(SHALLOW_SIZE_OF_ARRAY_LIST + (long)RamUsageEstimator.NUM_BYTES_ARRAY_HEADER + (long)this.commonUniqueTokenIds.size() * (TokenAndWeight.SHALLOW_SIZE + (long)RamUsageEstimator.NUM_BYTES_OBJECT_REF)));
    }

    public int hashCode() {
        return Objects.hash(this.id, this.baseWeightedTokenIds, this.baseWeight, this.baseUnfilteredLength, this.maxUnfilteredStringLength, this.orderedCommonTokenBeginIndex, this.orderedCommonTokenEndIndex, this.commonUniqueTokenIds, this.commonUniqueTokenWeight, this.origUniqueTokenWeight, this.numMatches);
    }

    public boolean equals(Object other) {
        if (other == this) {
            return true;
        }
        if (other == null || this.getClass() != other.getClass()) {
            return false;
        }
        TokenListCategory that = (TokenListCategory)other;
        return this.id == that.id && Objects.equals(this.baseWeightedTokenIds, that.baseWeightedTokenIds) && this.baseWeight == that.baseWeight && this.baseUnfilteredLength == that.baseUnfilteredLength && this.maxUnfilteredStringLength == that.maxUnfilteredStringLength && this.orderedCommonTokenBeginIndex == that.orderedCommonTokenBeginIndex && this.orderedCommonTokenEndIndex == that.orderedCommonTokenEndIndex && Objects.equals(this.commonUniqueTokenIds, that.commonUniqueTokenIds) && this.commonUniqueTokenWeight == that.commonUniqueTokenWeight && this.origUniqueTokenWeight == that.origUniqueTokenWeight && this.numMatches == that.numMatches;
    }

    public String toString() {
        return "Category with base tokens " + this.baseWeightedTokenIds + " with [" + this.numMatches + "] matches";
    }

    static boolean isSorted(List<TokenAndWeight> list) {
        TokenAndWeight previousTokenAndWeight = null;
        for (TokenAndWeight tokenAndWeight : list) {
            if (previousTokenAndWeight != null && tokenAndWeight.compareTo(previousTokenAndWeight) < 0) {
                return false;
            }
            previousTokenAndWeight = tokenAndWeight;
        }
        return true;
    }

    public static class TokenAndWeight
    implements Comparable<TokenAndWeight>,
    Accountable {
        private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(TokenAndWeight.class);
        private final int tokenId;
        private final int weight;

        public TokenAndWeight(int tokenId, int weight) {
            assert (tokenId >= 0) : "token ID cannot be negative, got " + tokenId;
            this.tokenId = tokenId;
            assert (weight >= 0) : "weight cannot be negative, got " + weight;
            this.weight = weight;
        }

        public int getTokenId() {
            return this.tokenId;
        }

        public int getWeight() {
            return this.weight;
        }

        public long ramBytesUsed() {
            return SHALLOW_SIZE;
        }

        public int hashCode() {
            return Objects.hash(this.tokenId, this.weight);
        }

        public boolean equals(Object other) {
            if (other == null || this.getClass() != other.getClass()) {
                return false;
            }
            TokenAndWeight that = (TokenAndWeight)other;
            return this.tokenId == that.tokenId && this.weight == that.weight;
        }

        @Override
        public int compareTo(TokenAndWeight other) {
            return this.tokenId - other.tokenId;
        }

        public String toString() {
            return "{" + this.tokenId + ", " + this.weight + "}";
        }
    }
}

