/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.ml.aggs.categorization;

import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.PriorityQueue;
import java.util.stream.Collectors;
import org.apache.lucene.util.Accountable;
import org.apache.lucene.util.RamUsageEstimator;
import org.elasticsearch.core.Tuple;
import org.elasticsearch.search.aggregations.AggregationExecutionException;
import org.elasticsearch.xpack.ml.aggs.categorization.CategorizationTokenTree;
import org.elasticsearch.xpack.ml.aggs.categorization.TextCategorization;

abstract class TreeNode
implements Accountable {
    private long count;

    TreeNode(long count) {
        this.count = count;
    }

    abstract void mergeWith(TreeNode var1);

    abstract boolean isLeaf();

    final void incCount(long count) {
        this.count += count;
    }

    final long getCount() {
        return this.count;
    }

    abstract TextCategorization addText(int[] var1, long var2, CategorizationTokenTree var4);

    abstract TextCategorization getCategorization(int[] var1);

    abstract List<TextCategorization> getAllChildrenTextCategorizations();

    abstract void collapseTinyChildren();

    private static class NativeIntLongPair {
        private final int tokenId;
        private final long count;

        static NativeIntLongPair of(int tokenId, long count) {
            return new NativeIntLongPair(tokenId, count);
        }

        NativeIntLongPair(int tokenId, long count) {
            this.tokenId = tokenId;
            this.count = count;
        }

        public long count() {
            return this.count;
        }
    }

    static class InnerTreeNode
    extends TreeNode {
        private final Map<Integer, TreeNode> children = new HashMap<Integer, TreeNode>();
        private final int childrenTokenPos;
        private final int maxChildren;
        private final PriorityQueue<NativeIntLongPair> smallestChild;

        InnerTreeNode(long count, int childrenTokenPos, int maxChildren) {
            super(count);
            this.childrenTokenPos = childrenTokenPos;
            this.maxChildren = maxChildren;
            this.smallestChild = new PriorityQueue<NativeIntLongPair>(maxChildren, Comparator.comparing(NativeIntLongPair::count));
        }

        @Override
        boolean isLeaf() {
            return false;
        }

        @Override
        public TextCategorization getCategorization(int[] tokenIds) {
            Optional<TreeNode> maybeChild = this.getChild(tokenIds[this.childrenTokenPos]);
            if (!maybeChild.isPresent()) {
                maybeChild = this.getChild(-1);
            }
            return maybeChild.map(node -> node.getCategorization(tokenIds)).orElse(null);
        }

        public long ramBytesUsed() {
            return (long)(8 + RamUsageEstimator.NUM_BYTES_OBJECT_REF + 4 + 4 + RamUsageEstimator.NUM_BYTES_OBJECT_REF) + RamUsageEstimator.sizeOfMap(this.children, (long)RamUsageEstimator.NUM_BYTES_OBJECT_REF) + (long)this.smallestChild.size() * (long)(RamUsageEstimator.NUM_BYTES_OBJECT_REF + 4 + 8);
        }

        @Override
        public TextCategorization addText(int[] tokenIds, long docCount, CategorizationTokenTree treeNodeFactory) {
            int currentToken = tokenIds[this.childrenTokenPos];
            TreeNode child = this.getChild(currentToken).map(node -> {
                node.incCount(docCount);
                if (!this.smallestChild.isEmpty() && this.smallestChild.peek().tokenId == currentToken) {
                    this.smallestChild.add(this.smallestChild.poll());
                }
                return node;
            }).orElseGet(() -> {
                TreeNode newNode = treeNodeFactory.newNode(docCount, this.childrenTokenPos + 1, tokenIds);
                treeNodeFactory.incSize(newNode.ramBytesUsed() + RamUsageEstimator.HASHTABLE_RAM_BYTES_PER_ENTRY + (long)RamUsageEstimator.NUM_BYTES_OBJECT_REF);
                return this.addChild(currentToken, newNode);
            });
            return child.addText(tokenIds, docCount, treeNodeFactory);
        }

        @Override
        void collapseTinyChildren() {
            if (this.isLeaf()) {
                return;
            }
            if (this.children.size() <= 1) {
                return;
            }
            Optional<TreeNode> maybeWildChild = this.getChild(-1);
            if (!maybeWildChild.isPresent() && this.smallestChild.size() > 0 && (double)this.smallestChild.peek().count / (double)this.getCount() <= 1.0 / (double)this.maxChildren) {
                TreeNode tinyChild = this.children.remove(this.smallestChild.poll().tokenId);
                maybeWildChild = Optional.of(this.addChild(-1, tinyChild));
            }
            if (maybeWildChild.isPresent()) {
                NativeIntLongPair tinyNode;
                TreeNode wildChild = maybeWildChild.get();
                while ((tinyNode = this.smallestChild.poll()) != null) {
                    if ((double)tinyNode.count / (double)this.getCount() > 1.0 / (double)this.maxChildren) {
                        this.smallestChild.add(tinyNode);
                        break;
                    }
                    wildChild.mergeWith(this.children.remove(tinyNode.tokenId));
                }
            }
            this.children.values().forEach(TreeNode::collapseTinyChildren);
        }

        @Override
        void mergeWith(TreeNode treeNode) {
            NativeIntLongPair siblingChild;
            if (treeNode == null) {
                return;
            }
            this.incCount(treeNode.count);
            if (treeNode.isLeaf()) {
                throw new UnsupportedOperationException("cannot merge non-leaf node with leaf node in categorization tree \n[" + this + "]\n[" + treeNode + "]");
            }
            InnerTreeNode innerTreeNode = (InnerTreeNode)treeNode;
            TreeNode siblingWildChild = innerTreeNode.children.remove(-1);
            this.addChild(-1, siblingWildChild);
            while ((siblingChild = innerTreeNode.smallestChild.poll()) != null) {
                TreeNode nephewNode = innerTreeNode.children.remove(siblingChild.tokenId);
                this.addChild(siblingChild.tokenId, nephewNode);
            }
        }

        private TreeNode addChild(int tokenId, TreeNode node) {
            if (node == null) {
                return null;
            }
            Optional<TreeNode> existingChild = this.getChild(tokenId).map(existingNode -> {
                existingNode.mergeWith(node);
                if (!this.smallestChild.isEmpty() && this.smallestChild.peek().tokenId == tokenId) {
                    this.smallestChild.poll();
                    this.smallestChild.add(NativeIntLongPair.of(tokenId, existingNode.getCount()));
                }
                return existingNode;
            });
            if (existingChild.isPresent()) {
                return existingChild.get();
            }
            if (this.children.size() == this.maxChildren) {
                return this.getChild(-1).map(wildChild -> {
                    TreeNode toReturn;
                    TreeNode toMerge;
                    if (!this.smallestChild.isEmpty() && node.getCount() > this.smallestChild.peek().count) {
                        toMerge = this.children.remove(this.smallestChild.poll().tokenId);
                        this.addChildAndUpdateSmallest(tokenId, node);
                        toReturn = node;
                    } else {
                        toMerge = node;
                        toReturn = wildChild;
                    }
                    wildChild.mergeWith(toMerge);
                    return toReturn;
                }).orElseThrow(() -> new AggregationExecutionException("Missing wild_card child even though maximum children reached"));
            }
            if (this.children.size() == this.maxChildren - 1) {
                if (this.children.containsKey(-1)) {
                    this.addChildAndUpdateSmallest(tokenId, node);
                } else if (tokenId == -1) {
                    this.addChildAndUpdateSmallest(tokenId, node);
                } else if (!this.smallestChild.isEmpty() && node.count > this.smallestChild.peek().count) {
                    this.addChildAndUpdateSmallest(-1, this.children.remove(this.smallestChild.poll().tokenId));
                    this.addChildAndUpdateSmallest(tokenId, node);
                } else {
                    this.addChildAndUpdateSmallest(-1, node);
                }
            } else {
                this.addChildAndUpdateSmallest(tokenId, node);
            }
            return node;
        }

        private void addChildAndUpdateSmallest(int tokenId, TreeNode node) {
            this.children.put(tokenId, node);
            if (tokenId != -1) {
                this.smallestChild.add(NativeIntLongPair.of(tokenId, node.count));
            }
        }

        private Optional<TreeNode> getChild(int tokenId) {
            return Optional.ofNullable(this.children.get(tokenId));
        }

        @Override
        public List<TextCategorization> getAllChildrenTextCategorizations() {
            return this.children.values().stream().flatMap(c -> c.getAllChildrenTextCategorizations().stream()).collect(Collectors.toList());
        }

        boolean hasChild(int tokenId) {
            return this.children.containsKey(tokenId);
        }

        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (o == null || this.getClass() != o.getClass()) {
                return false;
            }
            InnerTreeNode treeNode = (InnerTreeNode)o;
            return this.childrenTokenPos == treeNode.childrenTokenPos && this.getCount() == treeNode.getCount() && Objects.equals(this.children, treeNode.children) && Objects.equals(this.smallestChild, treeNode.smallestChild);
        }

        public int hashCode() {
            return Objects.hash(this.children, this.childrenTokenPos, this.smallestChild, this.getCount());
        }
    }

    static class LeafTreeNode
    extends TreeNode {
        private final List<TextCategorization> textCategorizations = new ArrayList<TextCategorization>();
        private final int similarityThreshold;

        LeafTreeNode(long count, int similarityThreshold) {
            super(count);
            this.similarityThreshold = similarityThreshold;
            if (similarityThreshold < 1 || similarityThreshold > 100) {
                throw new IllegalArgumentException("similarityThreshold must be between 1 and 100");
            }
        }

        @Override
        public boolean isLeaf() {
            return true;
        }

        @Override
        void mergeWith(TreeNode treeNode) {
            if (treeNode == null) {
                return;
            }
            if (!treeNode.isLeaf()) {
                throw new UnsupportedOperationException("cannot merge leaf node with non-leaf node in categorization tree \n[" + this + "]\n[" + treeNode + "]");
            }
            this.incCount(treeNode.getCount());
            LeafTreeNode otherLeaf = (LeafTreeNode)treeNode;
            for (TextCategorization group : otherLeaf.textCategorizations) {
                if (this.getAndUpdateTextCategorization(group.getCategorization(), group.getCount()).isPresent()) continue;
                this.putNewTextCategorization(group);
            }
        }

        public long ramBytesUsed() {
            return (long)(8 + RamUsageEstimator.NUM_BYTES_OBJECT_REF + 4) + RamUsageEstimator.sizeOfCollection(this.textCategorizations);
        }

        @Override
        public TextCategorization addText(int[] tokenIds, long docCount, CategorizationTokenTree treeNodeFactory) {
            return this.getAndUpdateTextCategorization(tokenIds, docCount).orElseGet(() -> {
                TextCategorization categorization = this.putNewTextCategorization(treeNodeFactory.newCategorization(docCount, tokenIds));
                treeNodeFactory.incSize(categorization.ramBytesUsed() + (long)RamUsageEstimator.NUM_BYTES_OBJECT_REF);
                return categorization;
            });
        }

        @Override
        List<TextCategorization> getAllChildrenTextCategorizations() {
            return this.textCategorizations;
        }

        @Override
        void collapseTinyChildren() {
        }

        private Optional<TextCategorization> getAndUpdateTextCategorization(int[] tokenIds, long docCount) {
            return this.getBestCategorization(tokenIds).map(bestGroupAndSimilarity -> {
                if ((Double)bestGroupAndSimilarity.v2() * 100.0 >= (double)this.similarityThreshold) {
                    ((TextCategorization)bestGroupAndSimilarity.v1()).addTokens(tokenIds, docCount);
                    return (TextCategorization)bestGroupAndSimilarity.v1();
                }
                return null;
            });
        }

        TextCategorization putNewTextCategorization(TextCategorization categorization) {
            this.textCategorizations.add(categorization);
            return categorization;
        }

        private Optional<Tuple<TextCategorization, Double>> getBestCategorization(int[] tokenIds) {
            if (this.textCategorizations.isEmpty()) {
                return Optional.empty();
            }
            if (this.textCategorizations.size() == 1) {
                return Optional.of(new Tuple((Object)this.textCategorizations.get(0), (Object)this.textCategorizations.get(0).calculateSimilarity(tokenIds).getSimilarity()));
            }
            TextCategorization.Similarity maxSimilarity = null;
            TextCategorization bestGroup = null;
            for (TextCategorization textCategorization : this.textCategorizations) {
                TextCategorization.Similarity groupSimilarity = textCategorization.calculateSimilarity(tokenIds);
                if (maxSimilarity != null && groupSimilarity.compareTo(maxSimilarity) <= 0) continue;
                maxSimilarity = groupSimilarity;
                bestGroup = textCategorization;
            }
            return Optional.of(new Tuple(bestGroup, (Object)maxSimilarity.getSimilarity()));
        }

        @Override
        public TextCategorization getCategorization(int[] tokenIds) {
            return this.getBestCategorization(tokenIds).map(Tuple::v1).orElse(null);
        }

        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (o == null || this.getClass() != o.getClass()) {
                return false;
            }
            LeafTreeNode that = (LeafTreeNode)o;
            return that.similarityThreshold == this.similarityThreshold && Objects.equals(this.textCategorizations, that.textCategorizations);
        }

        public int hashCode() {
            return Objects.hash(this.textCategorizations, this.similarityThreshold);
        }
    }
}

