/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.ml.inference.assignment.planning;

import java.security.AccessController;
import java.security.PrivilegedAction;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.core.SuppressForbidden;
import org.elasticsearch.core.Tuple;
import org.elasticsearch.xpack.ml.inference.assignment.planning.AssignmentPlan;
import org.elasticsearch.xpack.ml.inference.assignment.planning.RandomizedAssignmentRounding;
import org.ojalgo.optimisation.Expression;
import org.ojalgo.optimisation.ExpressionsBasedModel;
import org.ojalgo.optimisation.Optimisation;
import org.ojalgo.optimisation.Variable;
import org.ojalgo.structure.Access1D;
import org.ojalgo.type.CalendarDateDuration;
import org.ojalgo.type.CalendarDateUnit;

class LinearProgrammingPlanSolver {
    private static final Logger logger = LogManager.getLogger(LinearProgrammingPlanSolver.class);
    private static final long RANDOMIZATION_SEED = 738921734L;
    private static final double L1 = 0.9;
    private static final double INITIAL_W = 0.2;
    private static final int RANDOMIZED_ROUNDING_ROUNDS = 20;
    private static final int MEMORY_COMPLEXITY_SPARSE_THRESHOLD = 4000000;
    private static final int MEMORY_COMPLEXITY_LIMIT = 10000000;
    private final Random random = new Random(738921734L);
    private final List<AssignmentPlan.Node> nodes;
    private final List<AssignmentPlan.Deployment> deployments;
    private final Map<AssignmentPlan.Node, Double> normalizedMemoryPerNode;
    private final Map<AssignmentPlan.Node, Integer> coresPerNode;
    private final Map<AssignmentPlan.Deployment, Double> normalizedMemoryPerModel;
    private final Map<AssignmentPlan.Deployment, Double> normalizedMemoryPerAllocation;
    private final Map<AssignmentPlan.Deployment, Double> normalizedMinimumDeploymentMemoryRequired;
    private final int maxNodeCores;
    private final long maxModelMemoryBytes;

    LinearProgrammingPlanSolver(List<AssignmentPlan.Node> nodes, List<AssignmentPlan.Deployment> deployments) {
        this.nodes = nodes;
        this.maxNodeCores = this.nodes.stream().map(AssignmentPlan.Node::cores).max(Integer::compareTo).orElse(0);
        long maxNodeMemory = nodes.stream().map(AssignmentPlan.Node::availableMemoryBytes).max(Long::compareTo).orElse(0L);
        this.deployments = deployments.stream().filter(m -> !m.currentAllocationsByNodeId().isEmpty() || m.memoryBytes() <= maxNodeMemory).filter(m -> m.threadsPerAllocation() <= this.maxNodeCores).toList();
        this.maxModelMemoryBytes = this.deployments.stream().map(m -> m.minimumMemoryRequiredBytes()).max(Long::compareTo).orElse(1L);
        this.normalizedMemoryPerNode = this.nodes.stream().collect(Collectors.toMap(Function.identity(), n -> (double)n.availableMemoryBytes() / (double)this.maxModelMemoryBytes));
        this.coresPerNode = this.nodes.stream().collect(Collectors.toMap(Function.identity(), AssignmentPlan.Node::cores));
        this.normalizedMemoryPerModel = this.deployments.stream().collect(Collectors.toMap(Function.identity(), m -> (double)m.estimateMemoryUsageBytes(0) / (double)this.maxModelMemoryBytes));
        this.normalizedMemoryPerAllocation = this.deployments.stream().collect(Collectors.toMap(Function.identity(), m -> (double)m.perAllocationMemoryBytes() / (double)this.maxModelMemoryBytes));
        this.normalizedMinimumDeploymentMemoryRequired = this.deployments.stream().collect(Collectors.toMap(Function.identity(), m -> (double)m.minimumMemoryRequiredBytes() / (double)this.maxModelMemoryBytes));
    }

    AssignmentPlan solvePlan(boolean useBinPackingOnly) {
        if (this.deployments.isEmpty() || this.maxNodeCores == 0) {
            return AssignmentPlan.builder(this.nodes, this.deployments).build();
        }
        Tuple<Map<Tuple<AssignmentPlan.Deployment, AssignmentPlan.Node>, Double>, AssignmentPlan> weightsAndBinPackingPlan = this.calculateWeightsAndBinPackingPlan();
        if (useBinPackingOnly) {
            return (AssignmentPlan)weightsAndBinPackingPlan.v2();
        }
        HashMap<Tuple<AssignmentPlan.Deployment, AssignmentPlan.Node>, Double> allocationValues = new HashMap<Tuple<AssignmentPlan.Deployment, AssignmentPlan.Node>, Double>();
        HashMap<Tuple<AssignmentPlan.Deployment, AssignmentPlan.Node>, Double> assignmentValues = new HashMap<Tuple<AssignmentPlan.Deployment, AssignmentPlan.Node>, Double>();
        if (!this.solveLinearProgram((Map)weightsAndBinPackingPlan.v1(), allocationValues, assignmentValues)) {
            return (AssignmentPlan)weightsAndBinPackingPlan.v2();
        }
        RandomizedAssignmentRounding randomizedRounding = new RandomizedAssignmentRounding(this.random, 20, this.nodes, this.deployments);
        AssignmentPlan assignmentPlan = randomizedRounding.computePlan(allocationValues, assignmentValues);
        AssignmentPlan binPackingPlan = (AssignmentPlan)weightsAndBinPackingPlan.v2();
        if (binPackingPlan.compareTo(assignmentPlan) > 0) {
            assignmentPlan = binPackingPlan;
            logger.debug(() -> "Best plan is from bin packing");
        } else {
            logger.debug(() -> "Best plan is from LP solver");
        }
        return assignmentPlan;
    }

    private double weightForAllocationVar(AssignmentPlan.Deployment m, AssignmentPlan.Node n, Map<Tuple<AssignmentPlan.Deployment, AssignmentPlan.Node>, Double> weights) {
        return 1.0 + weights.get(Tuple.tuple((Object)m, (Object)n)) - (double)(m.minimumMemoryRequiredBytes() > n.availableMemoryBytes() ? 10 : 0) - 0.9 * this.normalizedMemoryPerModel.get(m) / (double)this.maxNodeCores;
    }

    private Tuple<Map<Tuple<AssignmentPlan.Deployment, AssignmentPlan.Node>, Double>, AssignmentPlan> calculateWeightsAndBinPackingPlan() {
        logger.debug(() -> "Calculating weights and bin packing plan");
        double w = 0.2;
        double dw = w / (double)this.nodes.size() / (double)this.deployments.size();
        HashMap<Tuple, Double> weights = new HashMap<Tuple, Double>();
        AssignmentPlan.Builder assignmentPlan = AssignmentPlan.builder(this.nodes, this.deployments);
        for (AssignmentPlan.Deployment m : this.deployments.stream().sorted(Comparator.comparingDouble(this::descendingSizeAnyFitsModelOrder)).toList()) {
            double lastW;
            block1: do {
                lastW = w;
                List<AssignmentPlan.Node> orderedNodes = this.nodes.stream().sorted(Comparator.comparingDouble(n -> this.descendingSizeAnyFitsNodeOrder((AssignmentPlan.Node)n, m, assignmentPlan))).toList();
                for (AssignmentPlan.Node n2 : orderedNodes) {
                    int allocations = m.findOptimalAllocations(Math.min(assignmentPlan.getRemainingCores(n2) / m.threadsPerAllocation(), assignmentPlan.getRemainingAllocations(m)), assignmentPlan.getRemainingMemory(n2));
                    if (allocations <= 0 || !assignmentPlan.canAssign(m, n2, allocations)) continue;
                    assignmentPlan.assignModelToNode(m, n2, allocations);
                    weights.put(Tuple.tuple((Object)m, (Object)n2), w);
                    w -= dw;
                    continue block1;
                }
            } while (lastW != w && assignmentPlan.getRemainingAllocations(m) > 0);
        }
        double finalW = w;
        for (AssignmentPlan.Deployment m : this.deployments) {
            for (AssignmentPlan.Node n3 : this.nodes) {
                weights.computeIfAbsent(Tuple.tuple((Object)m, (Object)n3), key -> this.random.nextDouble(LinearProgrammingPlanSolver.minWeight(m, n3, finalW), LinearProgrammingPlanSolver.maxWeight(m, n3, finalW)));
            }
        }
        logger.trace(() -> "Weights = " + weights);
        AssignmentPlan binPackingPlan = assignmentPlan.build();
        logger.debug(() -> "Bin packing plan =\n" + binPackingPlan.prettyPrint());
        return Tuple.tuple(weights, (Object)binPackingPlan);
    }

    private double descendingSizeAnyFitsModelOrder(AssignmentPlan.Deployment m) {
        return (double)(m.currentAllocationsByNodeId().isEmpty() ? 1 : 2) * -this.normalizedMinimumDeploymentMemoryRequired.get(m).doubleValue() * (double)m.threadsPerAllocation();
    }

    private double descendingSizeAnyFitsNodeOrder(AssignmentPlan.Node n, AssignmentPlan.Deployment m, AssignmentPlan.Builder assignmentPlan) {
        return (double)((m.currentAllocationsByNodeId().containsKey(n.id()) ? 0 : 1) + (assignmentPlan.getRemainingCores(n) >= assignmentPlan.getRemainingThreads(m) ? 0 : 1)) + 0.01 * (double)LinearProgrammingPlanSolver.distance(assignmentPlan.getRemainingCores(n), assignmentPlan.getRemainingThreads(m)) - 0.01 * this.normalizedMemoryPerNode.get(n);
    }

    @SuppressForbidden(reason="Math#abs(int) is safe here as we protect against MIN_VALUE")
    private static int distance(int x, int y) {
        int distance = x - y;
        return distance == Integer.MIN_VALUE ? Integer.MAX_VALUE : Math.abs(distance);
    }

    private static double minWeight(AssignmentPlan.Deployment m, AssignmentPlan.Node n, double w) {
        return m.currentAllocationsByNodeId().containsKey(n.id()) ? w / 2.0 : 0.0;
    }

    private static double maxWeight(AssignmentPlan.Deployment m, AssignmentPlan.Node n, double w) {
        return m.currentAllocationsByNodeId().containsKey(n.id()) ? w : w / 2.0;
    }

    private boolean solveLinearProgram(Map<Tuple<AssignmentPlan.Deployment, AssignmentPlan.Node>, Double> weights, Map<Tuple<AssignmentPlan.Deployment, AssignmentPlan.Node>, Double> allocationValues, Map<Tuple<AssignmentPlan.Deployment, AssignmentPlan.Node>, Double> assignmentValues) {
        if (this.memoryComplexity() > 10000000) {
            logger.debug(() -> "Problem size to big to solve with linear programming; falling back to bin packing solution");
            return false;
        }
        Optimisation.Options options = new Optimisation.Options().abort(new CalendarDateDuration(10.0, CalendarDateUnit.SECOND));
        if (this.memoryComplexity() > 4000000) {
            logger.debug(() -> "Problem size is large enough to switch to sparse solver");
            options.sparse = true;
        }
        ExpressionsBasedModel model = new ExpressionsBasedModel(options);
        HashMap<Tuple<AssignmentPlan.Deployment, AssignmentPlan.Node>, Variable> allocationVars = new HashMap<Tuple<AssignmentPlan.Deployment, AssignmentPlan.Node>, Variable>();
        for (AssignmentPlan.Deployment deployment : this.deployments) {
            for (AssignmentPlan.Node n : this.nodes) {
                Variable allocationVar = (Variable)((Variable)model.addVariable("allocations_of_model_" + deployment.id() + "_on_node_" + n.id()).integer(false).lower(0.0)).weight(this.weightForAllocationVar(deployment, n, weights));
                allocationVars.put((Tuple<AssignmentPlan.Deployment, AssignmentPlan.Node>)Tuple.tuple((Object)deployment, (Object)n), allocationVar);
            }
        }
        for (AssignmentPlan.Deployment deployment : this.deployments) {
            ((Expression)((Expression)model.addExpression("allocations_of_model_" + deployment.id() + "_not_more_than_required").lower((long)deployment.getCurrentAssignedAllocations())).upper((long)deployment.allocations())).setLinearFactorsSimple(this.varsForModel(deployment, allocationVars));
        }
        double[] threadsPerAllocationPerModel = this.deployments.stream().mapToDouble(m -> m.threadsPerAllocation()).toArray();
        for (AssignmentPlan.Node n : this.nodes) {
            ((Expression)model.addExpression("threads_on_node_" + n.id() + "_not_more_than_cores").upper((Comparable)this.coresPerNode.get(n))).setLinearFactors(this.varsForNode(n, allocationVars), Access1D.wrap((double[])threadsPerAllocationPerModel));
        }
        for (AssignmentPlan.Node n : this.nodes) {
            ArrayList allocations = new ArrayList();
            ArrayList modelMemories = new ArrayList();
            this.deployments.stream().filter(m -> !m.currentAllocationsByNodeId().containsKey(n.id())).forEach(m -> {
                allocations.add((Variable)allocationVars.get(Tuple.tuple((Object)m, (Object)n)));
                modelMemories.add((this.normalizedMemoryPerModel.get(m) / (double)this.coresPerNode.get(n).intValue() + this.normalizedMemoryPerAllocation.get(m)) * (double)m.threadsPerAllocation());
            });
            ((Expression)model.addExpression("used_memory_on_node_" + n.id() + "_not_more_than_available").upper((Comparable)this.normalizedMemoryPerNode.get(n))).setLinearFactors(allocations, Access1D.wrap(modelMemories));
        }
        Optimisation.Result result = LinearProgrammingPlanSolver.privilegedModelMaximise(model);
        if (!result.getState().isFeasible()) {
            logger.debug("Linear programming solution state [{}] is not feasible", (Object)result.getState());
            return false;
        }
        for (AssignmentPlan.Deployment m3 : this.deployments) {
            for (AssignmentPlan.Node n : this.nodes) {
                Tuple assignment = Tuple.tuple((Object)m3, (Object)n);
                allocationValues.put((Tuple<AssignmentPlan.Deployment, AssignmentPlan.Node>)assignment, ((Variable)allocationVars.get(assignment)).getValue().doubleValue());
                assignmentValues.put((Tuple<AssignmentPlan.Deployment, AssignmentPlan.Node>)assignment, ((Variable)allocationVars.get(assignment)).getValue().doubleValue() * (double)m3.threadsPerAllocation() / (double)this.coresPerNode.get(n).intValue());
            }
        }
        logger.debug(() -> "LP solver result =\n" + this.prettyPrintSolverResult(assignmentValues, allocationValues));
        return true;
    }

    private static Optimisation.Result privilegedModelMaximise(ExpressionsBasedModel model) {
        return AccessController.doPrivileged(() -> model.maximise());
    }

    private int memoryComplexity() {
        return (this.nodes.size() + this.deployments.size()) * this.nodes.size() * this.deployments.size();
    }

    private List<Variable> varsForModel(AssignmentPlan.Deployment m, Map<Tuple<AssignmentPlan.Deployment, AssignmentPlan.Node>, Variable> vars) {
        return this.nodes.stream().map(n -> (Variable)vars.get(Tuple.tuple((Object)m, (Object)n))).toList();
    }

    private List<Variable> varsForNode(AssignmentPlan.Node n, Map<Tuple<AssignmentPlan.Deployment, AssignmentPlan.Node>, Variable> vars) {
        return this.deployments.stream().map(m -> (Variable)vars.get(Tuple.tuple((Object)m, (Object)n))).toList();
    }

    private String prettyPrintSolverResult(Map<Tuple<AssignmentPlan.Deployment, AssignmentPlan.Node>, Double> assignmentValues, Map<Tuple<AssignmentPlan.Deployment, AssignmentPlan.Node>, Double> threadValues) {
        StringBuilder msg = new StringBuilder();
        for (int i = 0; i < this.nodes.size(); ++i) {
            AssignmentPlan.Node n = this.nodes.get(i);
            msg.append(n + " ->");
            for (AssignmentPlan.Deployment m : this.deployments) {
                if (!(threadValues.get(Tuple.tuple((Object)m, (Object)n)) > 0.0)) continue;
                msg.append(" ");
                msg.append(m.id());
                msg.append(" (mem = ");
                msg.append(ByteSizeValue.ofBytes((long)m.memoryBytes()));
                msg.append(") (allocations = ");
                msg.append(threadValues.get(Tuple.tuple((Object)m, (Object)n)));
                msg.append("/");
                msg.append(m.allocations());
                msg.append(") (y = ");
                msg.append(assignmentValues.get(Tuple.tuple((Object)m, (Object)n)));
                msg.append(")");
            }
            if (i >= this.nodes.size() - 1) continue;
            msg.append('\n');
        }
        return msg.toString();
    }
}

