/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.ml.inference.assignment.planning;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Random;
import java.util.Set;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.common.util.Maps;
import org.elasticsearch.core.SuppressForbidden;
import org.elasticsearch.core.Tuple;
import org.elasticsearch.xpack.ml.inference.assignment.planning.AssignmentPlan;

/**
 * Turns fractional per-(deployment, node) assignment weights and allocation counts into an integral
 * {@link AssignmentPlan} by randomized rounding.
 * <p>
 * A (deployment, node) pair is "soft" when its assignment weight lies strictly between 0 and 1, or when its
 * allocation count is not an integer (within {@link #EPS}). Soft pairs are resolved randomly in each of
 * {@code rounds} independent attempts, and the best resulting plan according to {@link AssignmentPlan#compareTo}
 * is returned.
 * <p>
 * NOTE(review): the fractional inputs presumably come from a relaxed (LP) solver; confirm with the caller.
 */
class RandomizedAssignmentRounding {

    private static final Logger logger = LogManager.getLogger(RandomizedAssignmentRounding.class);

    /** Tolerance used by {@link #isInteger(double)} to treat a double as an integer. */
    private static final double EPS = 1.0E-6;

    private final Random random;
    private final int rounds;
    private final Collection<AssignmentPlan.Node> nodes;
    private final Collection<AssignmentPlan.Deployment> deployments;
    private final AssignmentHolder assignmentHolder;

    /**
     * @param random      source of randomness for the rounding decisions
     * @param rounds      number of independent randomized rounding attempts; must be positive
     * @param nodes       nodes that deployments may be placed on
     * @param deployments deployments to place
     * @throws IllegalArgumentException if {@code rounds <= 0}
     * @throws NullPointerException     if {@code random}, {@code nodes} or {@code deployments} is null
     */
    RandomizedAssignmentRounding(
        Random random,
        int rounds,
        Collection<AssignmentPlan.Node> nodes,
        Collection<AssignmentPlan.Deployment> deployments
    ) {
        if (rounds <= 0) {
            throw new IllegalArgumentException("rounds must be > 0");
        }
        this.random = Objects.requireNonNull(random);
        this.rounds = rounds;
        this.nodes = Objects.requireNonNull(nodes);
        this.deployments = Objects.requireNonNull(deployments);
        this.assignmentHolder = new AssignmentHolder();
    }

    /**
     * Rounds the given fractional variables into an integral plan.
     *
     * @param allocationVars allocation count per (deployment, node) pair, possibly fractional. Must contain an
     *                       entry for every pair of {@code deployments x nodes} (values are unboxed).
     * @param assignmentVars assignment weight per (deployment, node) pair. Must contain an entry for every pair.
     * @return the best plan found. It is never worse than the baseline plan built before any variables are loaded.
     */
    AssignmentPlan computePlan(
        Map<Tuple<AssignmentPlan.Deployment, AssignmentPlan.Node>, Double> allocationVars,
        Map<Tuple<AssignmentPlan.Deployment, AssignmentPlan.Node>, Double> assignmentVars
    ) {
        // Baseline: the holder is still empty here, so this is only the greedy fill done by toPlan().
        AssignmentPlan bestPlan = assignmentHolder.toPlan();

        assignmentHolder.initializeAssignments(allocationVars, assignmentVars);
        assignmentHolder.assignUnderSubscribedNodes();
        List<Tuple<AssignmentPlan.Deployment, AssignmentPlan.Node>> softAssignmentQueue = assignmentHolder
            .createSoftAssignmentQueue();

        if (softAssignmentQueue.isEmpty()) {
            // Nothing left to round randomly; the deterministic snapping above is the whole answer.
            AssignmentPlan plan = assignmentHolder.toPlan();
            if (plan.compareTo(bestPlan) > 0) {
                bestPlan = plan;
            }
        } else {
            logger.debug(() -> "Random assignment rounding across [" + rounds + "] rounds");
            for (int i = 0; i < rounds; i++) {
                // Each round works on its own copy so rounds do not influence one another.
                AssignmentHolder randomizedAssignments = new AssignmentHolder(assignmentHolder);
                randomizedAssignments.doRandomizedRounding(softAssignmentQueue);
                AssignmentPlan randomizedPlan = randomizedAssignments.toPlan();
                if (randomizedPlan.compareTo(bestPlan) > 0) {
                    bestPlan = randomizedPlan;
                }
            }
        }
        return bestPlan;
    }

    /**
     * Returns {@code |x - y|}. A difference of exactly {@link Integer#MIN_VALUE} is mapped to
     * {@link Integer#MAX_VALUE}, so {@link Math#abs(int)} never yields a negative result.
     * Assumes {@code x - y} itself does not otherwise overflow.
     */
    @SuppressForbidden(reason = "Math#abs(int) is safe here as we protect against MIN_VALUE")
    private static int distance(int x, int y) {
        int distance = x - y;
        return distance == Integer.MIN_VALUE ? Integer.MAX_VALUE : Math.abs(distance);
    }

    /** Returns true when {@code value} is finite and within {@link #EPS} of its nearest integer. */
    private static boolean isInteger(double value) {
        return Double.isFinite(value) && Math.abs(value - Math.rint(value)) < EPS;
    }

    /**
     * Mutable working state for one rounding attempt. It holds the (possibly fractional) assignment
     * weight and allocation count for every (deployment, node) pair, plus the resources consumed by
     * hard assignments so far.
     */
    private class AssignmentHolder {
        private final Map<Tuple<AssignmentPlan.Deployment, AssignmentPlan.Node>, Double> assignments = new HashMap<>();
        private final Map<Tuple<AssignmentPlan.Deployment, AssignmentPlan.Node>, Double> allocations = new HashMap<>();
        private final ResourceTracker resourceTracker;

        /** Creates an empty holder with all node resources still available. */
        private AssignmentHolder() {
            this.resourceTracker = new ResourceTracker(nodes, deployments);
        }

        /** Copy constructor; the copy can be mutated without affecting {@code holder}. */
        private AssignmentHolder(AssignmentHolder holder) {
            this.assignments.putAll(holder.assignments);
            this.allocations.putAll(holder.allocations);
            this.resourceTracker = new ResourceTracker(holder.resourceTracker);
        }

        /**
         * Loads the fractional variables for every (deployment, node) pair. Pairs that are already hard
         * (weight exactly 1 and an integral allocation count) consume resources immediately.
         */
        private void initializeAssignments(
            Map<Tuple<AssignmentPlan.Deployment, AssignmentPlan.Node>, Double> allocationVars,
            Map<Tuple<AssignmentPlan.Deployment, AssignmentPlan.Node>, Double> assignmentVars
        ) {
            for (AssignmentPlan.Node n : nodes) {
                for (AssignmentPlan.Deployment m : deployments) {
                    Tuple<AssignmentPlan.Deployment, AssignmentPlan.Node> index = Tuple.tuple(m, n);
                    double assignment = assignmentVars.get(index);
                    double allocationCount = allocationVars.get(index);
                    if (assignment == 1.0 && isInteger(allocationCount)) {
                        resourceTracker.assign(m, n, (int) Math.rint(allocationCount));
                    }
                    assignments.put(index, assignment);
                    allocations.put(index, allocationCount);
                }
            }
        }

        /** Runs {@link #assignUnderSubscribedNodes(Collection)} over every node. */
        private void assignUnderSubscribedNodes() {
            assignUnderSubscribedNodes(nodes);
        }

        /**
         * For each selected node, checks whether every deployment with a positive weight on it fits when
         * its allocation count is rounded up. The check covers both the node's total memory and its cores.
         * If everything fits, the node's soft assignments (0 < weight < 1) are snapped to hard ones and any
         * excess cores are distributed.
         * Nodes are visited in ascending {@link #decreasingQualityNodeOrder} order.
         */
        private void assignUnderSubscribedNodes(Collection<AssignmentPlan.Node> nodeSelection) {
            List<AssignmentPlan.Node> orderedNodes = nodeSelection.stream()
                .sorted(Comparator.comparingDouble(this::decreasingQualityNodeOrder))
                .toList();
            for (AssignmentPlan.Node n : orderedNodes) {
                List<AssignmentPlan.Deployment> assignedDeployments = new ArrayList<>();
                long totalModelMemory = 0L;
                int maxTotalThreads = 0;
                for (AssignmentPlan.Deployment m : deployments) {
                    Tuple<AssignmentPlan.Deployment, AssignmentPlan.Node> assignment = Tuple.tuple(m, n);
                    if (assignments.get(assignment) > 0.0) {
                        int roundedAllocations = (int) Math.ceil(allocations.get(assignment));
                        totalModelMemory += m.estimateMemoryUsageBytes(roundedAllocations);
                        maxTotalThreads += roundedAllocations * m.threadsPerAllocation();
                        assignedDeployments.add(m);
                    }
                }
                if (totalModelMemory > n.availableMemoryBytes() || maxTotalThreads > n.cores()) {
                    // Over-subscribed: leave this node's soft assignments for randomized rounding.
                    continue;
                }
                for (AssignmentPlan.Deployment m : assignedDeployments) {
                    Tuple<AssignmentPlan.Deployment, AssignmentPlan.Node> assignment = Tuple.tuple(m, n);
                    double weight = assignments.get(assignment);
                    if (weight > 0.0 && weight < 1.0) {
                        assignModelToNode(m, n, allocationsToAssign(assignment));
                    }
                }
                assignExcessCores(n);
            }
        }

        /**
         * Integral allocation count for a pair. Near-integers are rounded to the nearest integer, so that a
         * value just above an integer is not bumped up by one. Any other value is rounded up.
         */
        private int allocationsToAssign(Tuple<AssignmentPlan.Deployment, AssignmentPlan.Node> assignment) {
            double allocationCount = allocations.get(assignment);
            if (isInteger(allocationCount)) {
                return (int) Math.rint(allocationCount);
            }
            return (int) Math.ceil(allocationCount);
        }

        /**
         * Hard-assigns {@code m} to {@code n}, capping the count at the deployment's remaining allocations,
         * and charges the resources to the tracker.
         */
        private void assignModelToNode(AssignmentPlan.Deployment m, AssignmentPlan.Node n, int allocations) {
            Tuple<AssignmentPlan.Deployment, AssignmentPlan.Node> assignment = Tuple.tuple(m, n);
            int assignedAllocations = Math.min(allocations, resourceTracker.remainingModelAllocations.get(m));
            assignments.put(assignment, 1.0);
            this.allocations.put(assignment, (double) assignedAllocations);
            resourceTracker.assign(m, n, assignedAllocations);
        }

        /**
         * Node sort key: the sum of allocations * threadsPerAllocation over the deployments on the node.
         * Deployments that already have allocations on the node count double.
         * NOTE(review): callers sort ascending, so despite the name nodes are visited in increasing quality
         * order. The other *Order helpers in this class negate their value to get a descending sort. Confirm
         * whether a negation is intended here before changing it.
         */
        private double decreasingQualityNodeOrder(AssignmentPlan.Node n) {
            double quality = 0.0;
            for (AssignmentPlan.Deployment m : deployments) {
                double allocationCount = allocations.get(Tuple.tuple(m, n));
                if (allocationCount > 0.0) {
                    int locality = m.currentAllocationsByNodeId().containsKey(n.id()) ? 2 : 1;
                    quality += locality * allocationCount * m.threadsPerAllocation();
                }
            }
            return quality;
        }

        /**
         * If {@code n} has spare cores and no soft assignments, gives extra allocations to deployments that
         * are hard-assigned to it and still need allocations. Deployments go in
         * {@link #remainingModelOrder} order. Afterwards, soft assignments of fully satisfied deployments
         * are cleared.
         */
        private void assignExcessCores(AssignmentPlan.Node n) {
            if (resourceTracker.remainingNodeCores.get(n) == 0) {
                return;
            }
            if (hasSoftAssignments(n)) {
                return;
            }
            List<AssignmentPlan.Deployment> candidates = deployments.stream()
                .filter(d -> assignments.get(Tuple.tuple(d, n)) == 1.0 && resourceTracker.remainingModelAllocations.get(d) > 0)
                .sorted(Comparator.comparingDouble(AssignmentHolder::remainingModelOrder))
                .toList();
            for (AssignmentPlan.Deployment m : candidates) {
                if (resourceTracker.remainingNodeCores.get(n) <= 0) {
                    break;
                }
                int maxExtra = Math.min(
                    resourceTracker.remainingNodeCores.get(n) / m.threadsPerAllocation(),
                    resourceTracker.remainingModelAllocations.get(m)
                );
                int extraAllocations = m.findExcessAllocations(maxExtra, resourceTracker.remainingNodeMemory.get(n));
                allocations.compute(Tuple.tuple(m, n), (k, v) -> v + extraAllocations);
                resourceTracker.assign(m, n, extraAllocations);
            }
            zeroSoftAssignmentsOfSatisfiedModels();
        }

        /**
         * Deployment sort key for ascending sorts. It orders by descending minimum memory; deployments that
         * already have allocations somewhere get their memory weighted by 2 and therefore come first.
         */
        private static double remainingModelOrder(AssignmentPlan.Deployment m) {
            return (m.currentAllocationsByNodeId().isEmpty() ? 1L : 2L) * -m.minimumMemoryRequiredBytes();
        }

        /** Returns true if any deployment is softly assigned to {@code n}. */
        private boolean hasSoftAssignments(AssignmentPlan.Node n) {
            return deployments.stream().anyMatch(m -> isSoftAssignment(m, n));
        }

        /**
         * A pair is soft when its weight is strictly between 0 and 1, or when its allocation count is not
         * an integer.
         */
        private boolean isSoftAssignment(AssignmentPlan.Deployment m, AssignmentPlan.Node n) {
            Tuple<AssignmentPlan.Deployment, AssignmentPlan.Node> index = Tuple.tuple(m, n);
            double weight = assignments.get(index);
            return (weight > 0.0 && weight < 1.0) || !isInteger(allocations.get(index));
        }

        /** Clears every soft assignment of deployments that have no allocations left to place. */
        private void zeroSoftAssignmentsOfSatisfiedModels() {
            for (AssignmentPlan.Deployment m : deployments) {
                if (resourceTracker.remainingModelAllocations.get(m) <= 0) {
                    for (AssignmentPlan.Node n : nodes) {
                        if (isSoftAssignment(m, n)) {
                            unassign(Tuple.tuple(m, n));
                        }
                    }
                }
            }
        }

        /** Sets both the weight and the allocation count of a pair to zero. Tracked resources are not released. */
        private void unassign(Tuple<AssignmentPlan.Deployment, AssignmentPlan.Node> assignment) {
            assignments.put(assignment, 0.0);
            allocations.put(assignment, 0.0);
        }

        /**
         * Collects all soft pairs. The most decided pairs come first: those with weight closest to 0 or 1.
         * Ties are broken by the most threads requested.
         */
        private List<Tuple<AssignmentPlan.Deployment, AssignmentPlan.Node>> createSoftAssignmentQueue() {
            List<Tuple<AssignmentPlan.Deployment, AssignmentPlan.Node>> queue = new ArrayList<>();
            for (AssignmentPlan.Deployment m : deployments) {
                for (AssignmentPlan.Node n : nodes) {
                    if (isSoftAssignment(m, n)) {
                        queue.add(Tuple.tuple(m, n));
                    }
                }
            }
            queue.sort(
                Comparator.comparingDouble(this::assignmentDistanceFromZeroOrOneOrder)
                    .thenComparingDouble(this::assignmentMostRemainingThreadsOrder)
            );
            return queue;
        }

        /** Distance of a pair's weight from the nearer of 0 and 1. */
        private double assignmentDistanceFromZeroOrOneOrder(Tuple<AssignmentPlan.Deployment, AssignmentPlan.Node> assignment) {
            double weight = assignments.get(assignment);
            return Math.min(weight, 1.0 - weight);
        }

        /** Negated threads requested by a pair, so that an ascending sort puts the largest requests first. */
        private double assignmentMostRemainingThreadsOrder(Tuple<AssignmentPlan.Deployment, AssignmentPlan.Node> assignment) {
            return -allocations.get(assignment) * (double) assignment.v1().threadsPerAllocation();
        }

        /**
         * Resolves every pair in the queue that is still soft. Pairs can stop being soft while earlier
         * entries are processed.
         * <p>
         * The allocation count is first rounded up with probability equal to its fractional part. The pair is
         * then dropped, and the node re-checked for under-subscription, if any of these hold:
         * <ul>
         *   <li>the rounded count does not fit the node's remaining memory or cores;</li>
         *   <li>the rounded count is zero;</li>
         *   <li>a second draw exceeds the pair's weight.</li>
         * </ul>
         * The second draw means a pair is kept with probability equal to its weight.
         * <p>
         * Otherwise the count is limited by the node's cores and memory, the pair is hard-assigned,
         * soft pairs that no longer fit the node are dropped, and excess cores are distributed.
         */
        private void doRandomizedRounding(List<Tuple<AssignmentPlan.Deployment, AssignmentPlan.Node>> softAssignmentQueue) {
            for (Tuple<AssignmentPlan.Deployment, AssignmentPlan.Node> assignment : softAssignmentQueue) {
                AssignmentPlan.Deployment m = assignment.v1();
                AssignmentPlan.Node n = assignment.v2();
                if (!isSoftAssignment(m, n)) {
                    continue;
                }
                double allocationCount = allocations.get(assignment);
                double roundUpProbability = allocationCount - Math.floor(allocationCount);
                int roundedAllocations = random.nextDouble() < roundUpProbability
                    ? (int) Math.ceil(allocationCount)
                    : (int) Math.floor(allocationCount);

                // Short-circuit order matters: the second random draw happens only when the pair is otherwise viable.
                if (m.estimateMemoryUsageBytes(roundedAllocations) > resourceTracker.remainingNodeMemory.get(n)
                    || m.threadsPerAllocation() > resourceTracker.remainingNodeCores.get(n)
                    || roundedAllocations == 0
                    || random.nextDouble() > assignments.get(assignment)) {
                    unassign(assignment);
                    assignUnderSubscribedNodes(Set.of(n));
                } else {
                    int cappedAllocations = m.findOptimalAllocations(
                        Math.min(roundedAllocations, resourceTracker.remainingNodeCores.get(n) / m.threadsPerAllocation()),
                        resourceTracker.remainingNodeMemory.get(n)
                    );
                    assignModelToNode(m, n, cappedAllocations);
                    unassignOversizedModels(n);
                    assignExcessCores(n);
                }
            }
        }

        /**
         * Drops every not-fully-assigned pair on {@code n} whose deployment's minimum memory now exceeds
         * the node's remaining memory.
         */
        private void unassignOversizedModels(AssignmentPlan.Node n) {
            for (AssignmentPlan.Deployment m : deployments) {
                Tuple<AssignmentPlan.Deployment, AssignmentPlan.Node> assignment = Tuple.tuple(m, n);
                if (assignments.get(assignment) < 1.0 && m.minimumMemoryRequiredBytes() > resourceTracker.remainingNodeMemory.get(n)) {
                    unassign(assignment);
                }
            }
        }

        /**
         * Builds a plan from the floored allocation counts plus a greedy fill of leftover capacity. Entries
         * the builder reports it cannot assign are skipped.
         */
        private AssignmentPlan toPlan() {
            AssignmentPlan.Builder builder = AssignmentPlan.builder(nodes, deployments);
            for (Map.Entry<Tuple<AssignmentPlan.Deployment, AssignmentPlan.Node>, Integer> entry : tryAssigningRemainingCores()
                .entrySet()) {
                AssignmentPlan.Deployment m = entry.getKey().v1();
                AssignmentPlan.Node n = entry.getKey().v2();
                if (builder.canAssign(m, n, entry.getValue())) {
                    builder.assignModelToNode(m, n, entry.getValue());
                }
            }
            return builder.build();
        }

        /**
         * Computes integral allocations for every (deployment, node) pair in two steps.
         * <ol>
         *   <li>Start from the floor of the current allocation counts.</li>
         *   <li>Greedily place each deployment's missing allocations. Deployments go in
         *       {@link #remainingModelOrder} order; for each one, nodes that have enough memory and cores
         *       and none of its allocations yet are tried in {@link #remainingNodeOrder} order.</li>
         * </ol>
         * Uses a fresh {@link ResourceTracker}, so this holder's own tracker is not modified.
         */
        private Map<Tuple<AssignmentPlan.Deployment, AssignmentPlan.Node>, Integer> tryAssigningRemainingCores() {
            Map<Tuple<AssignmentPlan.Deployment, AssignmentPlan.Node>, Integer> resultAllocations = new HashMap<>();
            ResourceTracker tracker = new ResourceTracker(nodes, deployments);

            for (AssignmentPlan.Deployment m : deployments) {
                for (AssignmentPlan.Node n : nodes) {
                    Tuple<AssignmentPlan.Deployment, AssignmentPlan.Node> assignment = Tuple.tuple(m, n);
                    int allocationCount = (int) Math.floor(allocations.getOrDefault(assignment, 0.0));
                    resultAllocations.put(assignment, allocationCount);
                    if (allocationCount > 0) {
                        tracker.assign(m, n, allocationCount);
                    }
                }
            }

            List<AssignmentPlan.Deployment> unsatisfied = deployments.stream()
                .filter(d -> tracker.remainingModelAllocations.get(d) > 0)
                .sorted(Comparator.comparingDouble(AssignmentHolder::remainingModelOrder))
                .toList();
            for (AssignmentPlan.Deployment m : unsatisfied) {
                // Candidates are filtered and ordered once, using the tracker state at this point.
                List<AssignmentPlan.Node> candidateNodes = nodes.stream()
                    .filter(
                        c -> tracker.remainingNodeMemory.get(c) >= m.minimumMemoryRequiredBytes()
                            && tracker.remainingNodeCores.get(c) >= m.threadsPerAllocation()
                            && resultAllocations.get(Tuple.tuple(m, c)) == 0
                    )
                    .sorted(
                        Comparator.comparingDouble(
                            c -> remainingNodeOrder(
                                c,
                                m,
                                tracker.remainingNodeCores.get(c),
                                tracker.remainingNodeMemory.get(c),
                                tracker.remainingModelAllocations.get(m)
                            )
                        )
                    )
                    .toList();
                for (AssignmentPlan.Node n : candidateNodes) {
                    int maxAllocationsByCores = tracker.remainingNodeCores.get(n) / m.threadsPerAllocation();
                    // The memory budget is the node's remaining memory, matching the findOptimalAllocations
                    // call in doRandomizedRounding. It was previously given the remaining allocation count.
                    int assigningAllocations = Math.min(
                        maxAllocationsByCores,
                        Math.min(
                            tracker.remainingModelAllocations.get(m),
                            m.findOptimalAllocations(maxAllocationsByCores, tracker.remainingNodeMemory.get(n))
                        )
                    );
                    if (assigningAllocations <= 0) {
                        // Nothing fits here. ResourceTracker.assign would still charge the node memory for an
                        // empty assignment, so skip it. resultAllocations already holds 0 for this pair.
                        continue;
                    }
                    tracker.assign(m, n, assigningAllocations);
                    resultAllocations.put(Tuple.tuple(m, n), assigningAllocations);
                    if (tracker.remainingModelAllocations.get(m) == 0) {
                        break;
                    }
                }
            }
            return resultAllocations;
        }

        /**
         * Node sort key for placing {@code m}; lower values are preferred. It is the sum of:
         * <ul>
         *   <li>+1 if the node does not already host {@code m};</li>
         *   <li>+0.5 if the node has more cores than the remaining allocations need;</li>
         *   <li>0.01 times the core mismatch;</li>
         *   <li>0.01 times the remaining memory in bytes.</li>
         * </ul>
         * NOTE(review): the memory term grows with free memory and, once a node has more than a few hundred
         * bytes free, outweighs all other terms. Nodes with less free memory therefore sort first. Confirm
         * this best-fit-by-memory preference is intended.
         */
        private static double remainingNodeOrder(
            AssignmentPlan.Node n,
            AssignmentPlan.Deployment m,
            int remainingNodeCores,
            long remainingNodeMemory,
            int remainingModelAllocations
        ) {
            int requiredThreads = remainingModelAllocations * m.threadsPerAllocation();
            double notAlreadyHostedPenalty = m.currentAllocationsByNodeId().containsKey(n.id()) ? 0.0 : 1.0;
            double spareCoresPenalty = remainingNodeCores <= requiredThreads ? 0.0 : 0.5;
            return notAlreadyHostedPenalty + spareCoresPenalty + 0.01 * distance(remainingNodeCores, requiredThreads)
                + 0.01 * remainingNodeMemory;
        }
    }

    /**
     * Tracks remaining node memory, remaining node cores, and remaining per-deployment allocations.
     * Node memory is charged once per (deployment, node) pair, the first time that pair is assigned.
     * Pairs where the deployment already has allocations are pre-registered and never charged.
     */
    private static class ResourceTracker {
        final Set<Tuple<AssignmentPlan.Deployment, AssignmentPlan.Node>> assignments = new HashSet<>();
        final Map<AssignmentPlan.Node, Long> remainingNodeMemory;
        final Map<AssignmentPlan.Node, Integer> remainingNodeCores;
        final Map<AssignmentPlan.Deployment, Integer> remainingModelAllocations;

        ResourceTracker(Collection<AssignmentPlan.Node> nodes, Collection<AssignmentPlan.Deployment> deployments) {
            remainingNodeMemory = Maps.newHashMapWithExpectedSize(nodes.size());
            remainingNodeCores = Maps.newHashMapWithExpectedSize(nodes.size());
            remainingModelAllocations = Maps.newHashMapWithExpectedSize(deployments.size());
            for (AssignmentPlan.Node n : nodes) {
                remainingNodeMemory.put(n, n.availableMemoryBytes());
                remainingNodeCores.put(n, n.cores());
            }
            for (AssignmentPlan.Deployment m : deployments) {
                for (AssignmentPlan.Node n : nodes) {
                    if (m.currentAllocationsByNodeId().containsKey(n.id())) {
                        assignments.add(Tuple.tuple(m, n));
                    }
                }
                remainingModelAllocations.put(m, m.allocations());
            }
        }

        /** Copy constructor producing an independent tracker. */
        ResourceTracker(ResourceTracker copy) {
            this.assignments.addAll(copy.assignments);
            this.remainingNodeMemory = new HashMap<>(copy.remainingNodeMemory);
            this.remainingNodeCores = new HashMap<>(copy.remainingNodeCores);
            this.remainingModelAllocations = new HashMap<>(copy.remainingModelAllocations);
        }

        /**
         * Charges {@code allocations} of {@code m} to {@code n}. Cores and remaining allocations are always
         * reduced. Memory is reduced only the first time the pair is seen.
         */
        void assign(AssignmentPlan.Deployment m, AssignmentPlan.Node n, int allocations) {
            if (assignments.add(Tuple.tuple(m, n))) {
                remainingNodeMemory.compute(n, (k, v) -> v - m.estimateMemoryUsageBytes(allocations));
            }
            remainingNodeCores.compute(n, (k, v) -> v - allocations * m.threadsPerAllocation());
            remainingModelAllocations.compute(m, (k, v) -> v - allocations);
        }
    }
}

