/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.ml.inference.assignment.planning;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.common.util.Maps;
import org.elasticsearch.core.Tuple;
import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction;
import org.elasticsearch.xpack.core.ml.inference.assignment.AdaptiveAllocationsSettings;
import org.elasticsearch.xpack.core.ml.inference.assignment.Priority;

/**
 * Describes how many allocations of each {@link Deployment} are placed on each {@link Node},
 * together with the memory, cores and allocations that remain once those placements are made.
 * Plans are ordered by the quality computed in {@code computeQuality()}; instances are created
 * through {@link #builder(Collection, Collection)}.
 */
public class AssignmentPlan
implements Comparable<AssignmentPlan> {
    /** For every deployment in the plan, the number of allocations placed on each node. */
    private final Map<Deployment, Map<Node, Integer>> assignments;
    /** Memory (bytes) still available on each node, keyed by node id. */
    private final Map<String, Long> remainingNodeMemory;
    /** Cores still available on each node, keyed by node id. */
    private final Map<String, Integer> remainingNodeCores;
    /** Allocations each deployment requested that the plan has not assigned. */
    private final Map<Deployment, Integer> remainingModelAllocations;

    /**
     * Creates a plan from the state accumulated by a {@link Builder}.
     *
     * @param assignments allocations per node for every deployment; must not be {@code null}
     * @param remainingNodeMemory memory left on each node; stored re-keyed by node id
     * @param remainingNodeCores cores left on each node; stored re-keyed by node id
     * @param remainingModelAllocations allocations each deployment is still missing; must not be {@code null}
     * @throws NullPointerException if any argument is {@code null}
     */
    private AssignmentPlan(Map<Deployment, Map<Node, Integer>> assignments, Map<Node, Long> remainingNodeMemory, Map<Node, Integer> remainingNodeCores, Map<Deployment, Integer> remainingModelAllocations) {
        this.assignments = Objects.requireNonNull(assignments);
        this.remainingNodeMemory = remainingNodeMemory.entrySet()
            .stream()
            .collect(Collectors.toMap(e -> e.getKey().id(), Map.Entry::getValue));
        this.remainingNodeCores = remainingNodeCores.entrySet()
            .stream()
            .collect(Collectors.toMap(e -> e.getKey().id(), Map.Entry::getValue));
        // Snapshot instead of aliasing: the builder hands over its own live map, and any
        // further assignment made through that builder would otherwise alter this plan.
        this.remainingModelAllocations = new HashMap<>(Objects.requireNonNull(remainingModelAllocations));
    }

    /**
     * Returns the deployments covered by this plan.
     *
     * @return the key set of the plan's assignment map
     */
    public Set<Deployment> models() {
        return assignments.keySet();
    }

    /**
     * Looks up the per-node allocations planned for a deployment.
     *
     * @param deployment the deployment to look up
     * @return the node-to-allocations map, or an empty optional when the deployment is unknown
     *         to the plan or has no allocations assigned
     */
    public Optional<Map<Node, Integer>> assignments(Deployment deployment) {
        return Optional.ofNullable(assignments.get(deployment)).filter(perNode -> !perNode.isEmpty());
    }

    /**
     * Orders plans by their computed quality; a higher-quality plan compares greater.
     *
     * @param o the plan to compare against
     * @return the result of comparing this plan's quality with the other plan's quality
     */
    @Override
    public int compareTo(AssignmentPlan o) {
        Quality mine = computeQuality();
        Quality theirs = o.computeQuality();
        return mine.compareTo(theirs);
    }

    /**
     * Checks whether every deployment in the plan keeps at least as many allocations as it
     * currently has assigned.
     *
     * @return {@code true} if no deployment falls short of its current allocations
     */
    public boolean satisfiesCurrentAssignments() {
        for (Deployment deployment : models()) {
            if (!isSatisfyingCurrentAssignmentsForModel(deployment)) {
                return false;
            }
        }
        return true;
    }

    /**
     * Checks whether the plan gives a deployment at least as many allocations as it currently has.
     *
     * @param m the deployment to check
     * @return {@code true} when the deployment has no current allocations, or when the planned
     *         total is not below the currently assigned total
     */
    private boolean isSatisfyingCurrentAssignmentsForModel(Deployment m) {
        if (m.currentAllocationsByNodeId().isEmpty()) {
            // Nothing is currently assigned, so there is nothing to preserve.
            return true;
        }
        int plannedAllocations = 0;
        for (Integer allocationsOnNode : assignments.get(m).values()) {
            plannedAllocations += allocationsOnNode;
        }
        return plannedAllocations >= m.getCurrentAssignedAllocations();
    }

    /**
     * Checks whether the plan assigns every allocation the deployment asked for.
     *
     * @param m the deployment to check
     * @return {@code true} if the deployment has no remaining allocations, or is unknown to the plan
     */
    public boolean satisfiesAllocations(Deployment m) {
        int remaining = remainingModelAllocations.getOrDefault(m, 0);
        return remaining == 0;
    }

    /**
     * Checks whether every deployment in the plan has all of its allocations assigned.
     *
     * @return {@code true} if {@link #satisfiesAllocations(Deployment)} holds for each deployment
     */
    public boolean satisfiesAllModels() {
        for (Deployment deployment : models()) {
            if (!satisfiesAllocations(deployment)) {
                return false;
            }
        }
        return true;
    }

    /**
     * Checks whether every deployment that has ever been allocated receives at least one
     * allocation in this plan.
     *
     * @return {@code false} as soon as a previously allocated deployment has no allocations here
     */
    public boolean arePreviouslyAssignedModelsAssigned() {
        for (Deployment deployment : models()) {
            if (deployment.hasEverBeenAllocated() && totalAllocations(deployment) <= 0) {
                return false;
            }
        }
        return true;
    }

    /**
     * Counts the deployments that have been allocated before and still receive at least one
     * allocation in this plan.
     *
     * @return the number of such deployments
     */
    public long countPreviouslyAssignedModelsThatAreStillAssigned() {
        long stillAssigned = 0L;
        for (Deployment deployment : models()) {
            if (deployment.hasEverBeenAllocated() && totalAllocations(deployment) > 0) {
                stillAssigned++;
            }
        }
        return stillAssigned;
    }

    /**
     * Returns the cores left unassigned on a node.
     *
     * @param nodeId id of the node
     * @return the remaining cores, or {@code 0} when the node is not part of the plan
     */
    public int getRemainingNodeCores(String nodeId) {
        Integer cores = remainingNodeCores.get(nodeId);
        return cores == null ? 0 : cores;
    }

    /**
     * Returns the memory left unassigned on a node.
     *
     * @param nodeId id of the node
     * @return the remaining memory in bytes, or {@code 0} when the node is not part of the plan
     */
    public long getRemainingNodeMemory(String nodeId) {
        Long memory = remainingNodeMemory.get(nodeId);
        return memory == null ? 0L : memory;
    }

    /**
     * Sums the allocations the plan assigns to a deployment across all nodes.
     *
     * @param m the deployment to total
     * @return the total planned allocations, or {@code 0} when the deployment is not in the plan
     */
    public int totalAllocations(Deployment m) {
        if (!assignments.containsKey(m)) {
            return 0;
        }
        int total = 0;
        for (Integer allocationsOnNode : assignments.get(m).values()) {
            total += allocationsOnNode;
        }
        return total;
    }

    /**
     * Scores this plan so that plans can be ranked against each other.
     * <p>
     * The score has three parts: whether every deployment keeps at least its current
     * allocations, a weighted allocation count in which allocations on a node the deployment
     * already occupies are worth 10% more, and a memory score that subtracts the deployment's
     * memory once for each node hosting at least one of its allocations.
     *
     * @return the quality of this plan
     */
    private Quality computeQuality() {
        boolean isSatisfyingPreviousAssignments = true;
        double weighedAllocationsScore = 0.0;
        double memoryScore = 0.0;
        for (Map.Entry<Deployment, Map<Node, Integer>> entry : assignments.entrySet()) {
            Deployment m = entry.getKey();
            isSatisfyingPreviousAssignments = isSatisfyingPreviousAssignments && isSatisfyingCurrentAssignmentsForModel(m);
            Map<Node, Integer> modelAssignments = entry.getValue();
            if (modelAssignments == null) {
                continue;
            }
            for (Map.Entry<Node, Integer> nodeAllocations : modelAssignments.entrySet()) {
                int allocations = nodeAllocations.getValue();
                // Favour keeping allocations on nodes where the deployment is already running.
                int alreadyOnNode = m.currentAllocationsByNodeId().containsKey(nodeAllocations.getKey().id()) ? 1 : 0;
                weighedAllocationsScore += (1.0 + 0.1 * (double) alreadyOnNode) * (double) allocations;
                memoryScore -= (double) (allocations > 0 ? m.memoryBytes() : 0L);
            }
        }
        return new Quality(isSatisfyingPreviousAssignments, weighedAllocationsScore, memoryScore);
    }

    /**
     * Renders the plan as human-readable text, one line per node.
     * <p>
     * Nodes are listed in order of their id. Each line starts with the node followed by
     * {@code " ->"} and then, ordered by deployment id, every deployment that has at least one
     * allocation on that node along with its memory, allocation and thread figures.
     *
     * @return the formatted plan, or {@code "Empty plan"} when the plan contains no deployments
     */
    public String prettyPrint() {
        if (assignments.isEmpty()) {
            return "Empty plan";
        }

        // Invert deployment -> node -> allocations into node -> [(deployment, allocations)].
        Map<Node, List<Tuple<Deployment, Integer>>> nodeToModel = new HashMap<>();
        for (Map.Entry<Deployment, Map<Node, Integer>> deploymentEntry : assignments.entrySet()) {
            for (Map.Entry<Node, Integer> nodeEntry : deploymentEntry.getValue().entrySet()) {
                nodeToModel.computeIfAbsent(nodeEntry.getKey(), k -> new ArrayList<>())
                    .add(Tuple.tuple(deploymentEntry.getKey(), nodeEntry.getValue()));
            }
        }

        StringBuilder msg = new StringBuilder();
        List<Node> nodes = nodeToModel.keySet().stream().sorted(Comparator.comparing(Node::id)).toList();
        for (int i = 0; i < nodes.size(); i++) {
            Node n = nodes.get(i);
            msg.append(n);
            msg.append(" ->");
            List<Tuple<Deployment, Integer>> sortedAllocations = nodeToModel.get(n)
                .stream()
                .sorted(Comparator.comparing((Tuple<Deployment, Integer> t) -> t.v1().id()))
                .toList();
            for (Tuple<Deployment, Integer> modelAllocations : sortedAllocations) {
                if (modelAllocations.v2() <= 0) {
                    // Deployments without allocations on this node are not printed.
                    continue;
                }
                Deployment deployment = modelAllocations.v1();
                msg.append(" ");
                msg.append(deployment.id());
                msg.append(" (mem = ");
                msg.append(ByteSizeValue.ofBytes(deployment.memoryBytes()));
                msg.append(")");
                msg.append(" (allocations = ");
                msg.append(modelAllocations.v2());
                msg.append("/");
                msg.append(deployment.allocations());
                msg.append(")");
                msg.append(" (threads_per_allocation = ");
                msg.append(deployment.threadsPerAllocation());
                msg.append(")");
            }
            if (i < nodes.size() - 1) {
                msg.append('\n');
            }
        }
        return msg.toString();
    }

    /**
     * Starts a new plan over the given nodes and deployments with nothing assigned yet.
     *
     * @param nodes the nodes available for assignment
     * @param deployments the deployments to be assigned
     * @return a fresh {@link Builder}
     * @throws IllegalArgumentException if either collection contains duplicates
     */
    public static Builder builder(Collection<Node> nodes, Collection<Deployment> deployments) {
        return new Builder(nodes, deployments);
    }

    /**
     * A deployment to be placed by the planner.
     *
     * @param id deployment identifier
     * @param memoryBytes model memory in bytes, as passed to the memory usage estimate
     * @param allocations number of allocations the deployment asks for
     * @param threadsPerAllocation threads needed by each allocation
     * @param currentAllocationsByNodeId allocations the deployment currently has, keyed by node id
     * @param maxAssignedAllocations value used to tell whether the deployment has ever been allocated
     * @param adaptiveAllocationsSettings adaptive allocations settings of the deployment
     * @param priority deployment priority
     * @param perDeploymentMemoryBytes per-deployment memory in bytes used by the memory usage estimate
     * @param perAllocationMemoryBytes per-allocation memory in bytes used by the memory usage estimate
     */
    public record Deployment(String id, long memoryBytes, int allocations, int threadsPerAllocation, Map<String, Integer> currentAllocationsByNodeId, int maxAssignedAllocations, AdaptiveAllocationsSettings adaptiveAllocationsSettings, Priority priority, long perDeploymentMemoryBytes, long perAllocationMemoryBytes) {
        /**
         * Convenience constructor that creates a deployment with {@link Priority#NORMAL}.
         */
        public Deployment(String id, long modelBytes, int allocations, int threadsPerAllocation, Map<String, Integer> currentAllocationsByNodeId, int maxAssignedAllocations, AdaptiveAllocationsSettings adaptiveAllocationsSettings, long perDeploymentMemoryBytes, long perAllocationMemoryBytes) {
            this(
                id,
                modelBytes,
                allocations,
                threadsPerAllocation,
                currentAllocationsByNodeId,
                maxAssignedAllocations,
                adaptiveAllocationsSettings,
                Priority.NORMAL,
                perDeploymentMemoryBytes,
                perAllocationMemoryBytes
            );
        }

        /**
         * @return the adaptive allocations settings of this deployment
         */
        public AdaptiveAllocationsSettings getAdaptiveAllocationsSettings() {
            return adaptiveAllocationsSettings;
        }

        /**
         * @return the sum of the deployment's current allocations over all nodes
         */
        int getCurrentAssignedAllocations() {
            int total = 0;
            for (Integer allocationsOnNode : currentAllocationsByNodeId.values()) {
                total += allocationsOnNode;
            }
            return total;
        }

        /**
         * @return {@code true} if {@code maxAssignedAllocations} is positive
         */
        boolean hasEverBeenAllocated() {
            return maxAssignedAllocations > 0;
        }

        /**
         * Estimates the memory this deployment uses with the given number of allocations.
         *
         * @param allocations number of allocations to estimate for
         * @return the estimated memory usage in bytes
         */
        public long estimateMemoryUsageBytes(int allocations) {
            return StartTrainedModelDeploymentAction.estimateMemoryUsageBytes(
                id,
                memoryBytes,
                perDeploymentMemoryBytes,
                perAllocationMemoryBytes,
                allocations
            );
        }

        /**
         * Estimates how much more memory is needed when going from one allocation count to another.
         *
         * @param allocationsOld allocation count before the change
         * @param allocationsNew allocation count after the change
         * @return estimate for {@code allocationsNew} minus estimate for {@code allocationsOld}, in bytes
         */
        long estimateAdditionalMemoryUsageBytes(int allocationsOld, int allocationsNew) {
            long usageAfter = estimateMemoryUsageBytes(allocationsNew);
            long usageBefore = estimateMemoryUsageBytes(allocationsOld);
            return usageAfter - usageBefore;
        }

        /**
         * @return the estimated memory usage in bytes for a single allocation
         */
        long minimumMemoryRequiredBytes() {
            return estimateMemoryUsageBytes(1);
        }

        /**
         * Finds how many allocations fit into the given memory, capped at {@code maxAllocations}.
         * The memory estimate for zero allocations is deducted first and the rest is divided by
         * the per-allocation memory; the result is never negative. When either the
         * per-deployment or per-allocation memory is not positive, {@code maxAllocations} is
         * returned unchanged.
         *
         * @param maxAllocations upper bound on the result
         * @param availableMemoryBytes memory available, in bytes
         * @return the number of allocations that fit
         */
        int findOptimalAllocations(int maxAllocations, long availableMemoryBytes) {
            if (perDeploymentMemoryBytes <= 0L || perAllocationMemoryBytes <= 0L) {
                return maxAllocations;
            }
            long memoryForAllocations = availableMemoryBytes - estimateMemoryUsageBytes(0);
            long fittingAllocations = Math.floorDiv(memoryForAllocations, perAllocationMemoryBytes);
            long capped = Math.min((long) maxAllocations, fittingAllocations);
            return (int) Math.max(capped, 0L);
        }

        /**
         * Divides the given memory by the per-allocation memory and caps the result at
         * {@code maxAllocations}. When either the per-deployment or per-allocation memory is not
         * positive, {@code maxAllocations} is returned unchanged.
         *
         * @param maxAllocations upper bound on the result
         * @param availableMemoryBytes memory to divide, in bytes
         * @return the smaller of {@code maxAllocations} and the floored quotient
         */
        int findExcessAllocations(int maxAllocations, long availableMemoryBytes) {
            if (perDeploymentMemoryBytes <= 0L || perAllocationMemoryBytes <= 0L) {
                return maxAllocations;
            }
            long byMemory = Math.floorDiv(availableMemoryBytes, perAllocationMemoryBytes);
            return (int) Math.min((long) maxAllocations, byMemory);
        }

        @Override
        public String toString() {
            StringBuilder sb = new StringBuilder();
            sb.append(id);
            sb.append(" (mem = ").append(ByteSizeValue.ofBytes(memoryBytes));
            sb.append(") (allocations = ").append(allocations);
            sb.append(") (threads_per_allocation = ").append(threadsPerAllocation);
            sb.append(") (current_allocations = ").append(currentAllocationsByNodeId);
            sb.append(") (max_assigned_allocations = ").append(maxAssignedAllocations);
            sb.append(") (memory_usage = ").append(ByteSizeValue.ofBytes(estimateMemoryUsageBytes(allocations)));
            sb.append(")");
            return sb.toString();
        }
    }

    /**
     * A node that deployments can be assigned to.
     *
     * @param id node identifier
     * @param availableMemoryBytes memory available on the node, in bytes
     * @param cores number of cores on the node
     */
    public record Node(String id, long availableMemoryBytes, int cores) {
        @Override
        public String toString() {
            StringBuilder sb = new StringBuilder();
            sb.append(id);
            sb.append(" (mem = ").append(ByteSizeValue.ofBytes(availableMemoryBytes));
            sb.append(") (cores = ").append(cores);
            sb.append(")");
            return sb.toString();
        }
    }

    private record Quality(boolean satisfiesPreviousAssignments, double allocationsScore, double memoryScore) implements Comparable<Quality>
    {
        private int previousAssignmentScore() {
            return this.satisfiesPreviousAssignments ? 1 : 0;
        }

        @Override
        public int compareTo(Quality o) {
            return Comparator.comparingInt(Quality::previousAssignmentScore).thenComparingDouble(Quality::allocationsScore).thenComparingDouble(Quality::memoryScore).compare(this, o);
        }
    }

    /**
     * Mutable accumulator for an {@link AssignmentPlan}. It tracks, for a fixed set of nodes and
     * deployments, the allocations assigned so far and the memory, cores and allocations that
     * remain.
     */
    public static class Builder {
        private final Map<Deployment, Map<Node, Integer>> assignments;
        private final Map<Node, Long> remainingNodeMemory;
        private final Map<Node, Integer> remainingNodeCores;
        private final Map<Deployment, Integer> remainingModelAllocations;

        /**
         * Initialises the builder with every node at full capacity and every deployment with
         * zero allocations on each node.
         *
         * @param nodes the nodes available for assignment
         * @param deployments the deployments to be assigned
         * @throws IllegalArgumentException if either collection contains duplicates
         */
        private Builder(Collection<Node> nodes, Collection<Deployment> deployments) {
            if (new HashSet<Node>(nodes).size() != nodes.size()) {
                throw new IllegalArgumentException("there should be no duplicate nodes");
            }
            if (new HashSet<Deployment>(deployments).size() != deployments.size()) {
                throw new IllegalArgumentException("there should be no duplicate models");
            }
            // The outer map holds exactly one entry per deployment; the per-node counts live in
            // the inner maps, so it must not be sized by nodes * deployments.
            this.assignments = Maps.newHashMapWithExpectedSize(deployments.size());
            this.remainingNodeMemory = Maps.newHashMapWithExpectedSize(nodes.size());
            this.remainingNodeCores = Maps.newHashMapWithExpectedSize(nodes.size());
            this.remainingModelAllocations = Maps.newHashMapWithExpectedSize(deployments.size());
            for (Node n : nodes) {
                this.remainingNodeMemory.put(n, n.availableMemoryBytes());
                this.remainingNodeCores.put(n, n.cores());
            }
            for (Deployment m : deployments) {
                Map<Node, Integer> nodeAssignments = new HashMap<>();
                for (Node n : nodes) {
                    nodeAssignments.put(n, 0);
                }
                this.assignments.put(m, nodeAssignments);
                this.remainingModelAllocations.put(m, m.allocations());
            }
        }

        /**
         * @param n a node known to this builder
         * @return the cores still unassigned on the node
         */
        int getRemainingCores(Node n) {
            return remainingNodeCores.get(n);
        }

        /**
         * @param n a node known to this builder
         * @return the memory in bytes still unassigned on the node
         */
        long getRemainingMemory(Node n) {
            return remainingNodeMemory.get(n);
        }

        /**
         * @param m a deployment known to this builder
         * @return the deployment's remaining allocations multiplied by its threads per allocation
         */
        int getRemainingThreads(Deployment m) {
            return remainingModelAllocations.get(m) * m.threadsPerAllocation();
        }

        /**
         * @param m a deployment known to this builder
         * @return the allocations of the deployment that are not assigned yet
         */
        int getRemainingAllocations(Deployment m) {
            return remainingModelAllocations.get(m);
        }

        /**
         * Checks whether the given number of additional allocations fits on the node, using the
         * memory requirement computed by {@link #getDeploymentMemoryRequirement}.
         *
         * @param deployment the deployment to assign
         * @param node the target node
         * @param allocations number of allocations to add
         * @return {@code true} if the node has enough memory and, unless the deployment is low
         *         priority, enough cores
         */
        boolean canAssign(Deployment deployment, Node node, int allocations) {
            long requiredMemory = getDeploymentMemoryRequirement(deployment, node, allocations);
            return canAssign(deployment, node, allocations, requiredMemory);
        }

        /**
         * Checks whether the given number of additional allocations fits on the node for a
         * caller-supplied memory requirement.
         *
         * @param deployment the deployment to assign
         * @param node the target node
         * @param allocations number of allocations to add
         * @param requiredMemory memory in bytes the assignment needs
         * @return {@code true} if the node has enough memory and, unless the deployment is low
         *         priority, enough cores
         */
        boolean canAssign(Deployment deployment, Node node, int allocations, long requiredMemory) {
            if (requiredMemory > remainingNodeMemory.get(node)) {
                return false;
            }
            if (deployment.priority == Priority.LOW) {
                // Low priority deployments are not limited by the node's cores.
                return true;
            }
            return allocations * deployment.threadsPerAllocation() <= remainingNodeCores.get(node);
        }

        /**
         * Computes the memory needed to add allocations of a deployment to a node. When the
         * deployment already has allocations on the node (current plus planned), only the
         * additional memory for the extra allocations is returned; otherwise the full estimate
         * for the new allocations is returned.
         *
         * @param deployment the deployment to assign
         * @param node the target node
         * @param newAllocations number of allocations to add
         * @return the required memory in bytes
         */
        public long getDeploymentMemoryRequirement(Deployment deployment, Node node, int newAllocations) {
            int assignedAllocations = getAssignedAllocations(deployment, node);
            if (assignedAllocations > 0) {
                return deployment.estimateAdditionalMemoryUsageBytes(assignedAllocations, assignedAllocations + newAllocations);
            }
            return deployment.estimateMemoryUsageBytes(newAllocations);
        }

        /**
         * Assigns allocations of a deployment to a node, computing the memory requirement via
         * {@link #getDeploymentMemoryRequirement}.
         *
         * @param deployment the deployment to assign
         * @param node the target node
         * @param allocations number of allocations to add; non-positive values are ignored
         * @return this builder
         * @throws IllegalArgumentException if the node lacks the memory or cores required
         */
        public Builder assignModelToNode(Deployment deployment, Node node, int allocations) {
            long requiredMemory = getDeploymentMemoryRequirement(deployment, node, allocations);
            return assignModelToNode(deployment, node, allocations, requiredMemory);
        }

        /**
         * Assigns allocations of a deployment to a node for a caller-supplied memory requirement.
         * Cores are only checked and consumed for {@link Priority#NORMAL} deployments.
         *
         * @param deployment the deployment to assign
         * @param node the target node
         * @param allocations number of allocations to add; non-positive values are ignored
         * @param requiredMemory memory in bytes the assignment needs
         * @return this builder
         * @throws IllegalArgumentException if the node lacks the memory or cores required
         */
        public Builder assignModelToNode(Deployment deployment, Node node, int allocations, long requiredMemory) {
            if (allocations <= 0) {
                return this;
            }
            if (requiredMemory > remainingNodeMemory.get(node)) {
                throw new IllegalArgumentException("not enough memory on node [" + node.id() + "] to assign [" + allocations + "] allocations to deployment [" + deployment.id() + "]");
            }
            boolean consumesCores = deployment.priority == Priority.NORMAL;
            if (consumesCores && allocations * deployment.threadsPerAllocation() > remainingNodeCores.get(node)) {
                throw new IllegalArgumentException("not enough cores on node [" + node.id() + "] to assign [" + allocations + "] allocations to deployment [" + deployment.id() + "]; required threads per allocation [" + deployment.threadsPerAllocation() + "]");
            }
            assignments.get(deployment).compute(node, (n, assigned) -> assigned + allocations);
            accountMemory(deployment, node, requiredMemory);
            if (consumesCores) {
                remainingNodeCores.compute(node, (n, cores) -> cores - allocations * deployment.threadsPerAllocation());
            }
            remainingModelAllocations.compute(deployment, (m, remaining) -> remaining - allocations);
            return this;
        }

        /** Returns the deployment's current allocations on the node plus those planned so far. */
        private int getAssignedAllocations(Deployment deployment, Node node) {
            int plannedAllocations = assignments.get(deployment).get(node);
            return getCurrentAllocations(deployment, node) + plannedAllocations;
        }

        /** Returns the allocations the deployment currently has on the node, or 0 if none. */
        private static int getCurrentAllocations(Deployment m, Node n) {
            return m.currentAllocationsByNodeId.getOrDefault(n.id(), 0);
        }

        /**
         * Deducts from the node the memory returned by {@link #getDeploymentMemoryRequirement}
         * when called with the deployment's current allocation count on that node.
         * NOTE(review): the current allocations are passed as the number of new allocations -
         * confirm this is the intended amount to account for.
         *
         * @param m the deployment whose memory is accounted
         * @param n the node to charge
         * @throws IllegalArgumentException if the node does not have enough remaining memory
         */
        public void accountMemory(Deployment m, Node n) {
            long requiredMemory = getDeploymentMemoryRequirement(m, n, getCurrentAllocations(m, n));
            accountMemory(m, n, requiredMemory);
        }

        /**
         * Deducts the given memory from the node's remaining memory. Nodes unknown to this
         * builder are ignored. The deduction is validated before it is applied, so a rejected
         * call leaves the builder unchanged.
         *
         * @param m the deployment whose memory is accounted (used in the error message)
         * @param n the node to charge
         * @param requiredMemory memory in bytes to deduct
         * @throws IllegalArgumentException if the deduction would leave the node with negative memory
         */
        public void accountMemory(Deployment m, Node n, long requiredMemory) {
            Long remaining = remainingNodeMemory.get(n);
            if (remaining == null) {
                return;
            }
            long updated = remaining - requiredMemory;
            if (updated < 0L) {
                throw new IllegalArgumentException("not enough memory on node [" + n.id() + "] to assign model [" + m.id() + "]");
            }
            remainingNodeMemory.put(n, updated);
        }

        /**
         * Builds the plan from the current state. Only nodes with a positive number of
         * allocations are kept in each deployment's assignment map; every deployment keeps an
         * entry even when that map is empty.
         *
         * @return the assignment plan
         */
        public AssignmentPlan build() {
            Map<Deployment, Map<Node, Integer>> finalAssignments = new HashMap<>();
            for (Map.Entry<Deployment, Map<Node, Integer>> deploymentEntry : assignments.entrySet()) {
                Map<Node, Integer> allocationsPerNode = new HashMap<>();
                for (Map.Entry<Node, Integer> nodeEntry : deploymentEntry.getValue().entrySet()) {
                    if (nodeEntry.getValue() > 0) {
                        allocationsPerNode.put(nodeEntry.getKey(), nodeEntry.getValue());
                    }
                }
                finalAssignments.put(deploymentEntry.getKey(), allocationsPerNode);
            }
            return new AssignmentPlan(finalAssignments, remainingNodeMemory, remainingNodeCores, remainingModelAllocations);
        }
    }
}

