/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.ml.inference.assignment.planning;

import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.TreeMap;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.elasticsearch.core.Strings;
import org.elasticsearch.logging.LogManager;
import org.elasticsearch.logging.Logger;
import org.elasticsearch.xpack.ml.inference.assignment.planning.AssignmentPlan;
import org.elasticsearch.xpack.ml.inference.assignment.planning.AssignmentPlanner;
import org.elasticsearch.xpack.ml.inference.assignment.planning.LinearProgrammingPlanSolver;
import org.elasticsearch.xpack.ml.inference.assignment.planning.PreserveAllAllocations;

/**
 * Computes an {@link AssignmentPlan} for nodes that are grouped into zones.
 *
 * <p>Each zone is identified by a list of zone attribute values. When there is more than one zone, the zones are
 * planned one at a time, each receiving an even share (rounded up) of the allocations that earlier zones left
 * unassigned. The per-zone results are then merged and used as the existing allocations of a final plan that is
 * computed across all nodes.
 */
public class ZoneAwareAssignmentPlanner {
    private static final Logger logger = LogManager.getLogger(ZoneAwareAssignmentPlanner.class);

    /** Null-tolerant natural string order, used to break ties between zone keys whose concatenations are equal. */
    private static final Comparator<String> ZONE_ATTRIBUTE_ORDER = Comparator.nullsFirst(Comparator.<String>naturalOrder());

    private final Map<List<String>, List<AssignmentPlan.Node>> nodesByZone;
    private final List<AssignmentPlan.Deployment> deployments;

    /**
     * @param nodesByZone nodes grouped by their zone attribute values; must not be null
     * @param deployments the deployments to plan; must not be null
     * @throws NullPointerException if either argument is null
     */
    public ZoneAwareAssignmentPlanner(Map<List<String>, List<AssignmentPlan.Node>> nodesByZone, List<AssignmentPlan.Deployment> deployments) {
        this.nodesByZone = sortByZone(Objects.requireNonNull(nodesByZone));
        this.deployments = Objects.requireNonNull(deployments);
    }

    /**
     * Copies {@code nodesByZone} into a map sorted by zone key, so zones are always visited in the same order
     * regardless of the iteration order of the caller's map. The order matters because each zone claims a share of
     * whatever allocations the previously visited zones left unassigned.
     */
    private static Map<List<String>, List<AssignmentPlan.Node>> sortByZone(Map<List<String>, List<AssignmentPlan.Node>> nodesByZone) {
        Map<List<String>, List<AssignmentPlan.Node>> sortedByZone = new TreeMap<>(ZoneAwareAssignmentPlanner::compareZones);
        sortedByZone.putAll(nodesByZone);
        return sortedByZone;
    }

    /**
     * Orders zone keys primarily by the concatenation of their attribute values, then element by element, then by
     * length.
     *
     * <p>The tie-breaks are required because concatenation alone is not injective: {@code ["a", "bc"]} and
     * {@code ["ab", "c"]} both concatenate to {@code "abc"}. A {@link TreeMap} treats keys that compare as 0 as the
     * same key, so without the tie-breaks {@code putAll} would silently drop one of those zones' nodes. Keys that
     * were already distinguishable by concatenation keep their previous relative order.
     */
    private static int compareZones(List<String> left, List<String> right) {
        int byConcatenation = String.join("", left).compareTo(String.join("", right));
        if (byConcatenation != 0) {
            return byConcatenation;
        }
        int common = Math.min(left.size(), right.size());
        for (int i = 0; i < common; i++) {
            int byElement = ZONE_ATTRIBUTE_ORDER.compare(left.get(i), right.get(i));
            if (byElement != 0) {
                return byElement;
            }
        }
        return Integer.compare(left.size(), right.size());
    }

    /**
     * Computes the assignment plan.
     *
     * <p>With exactly one zone, this delegates to {@link AssignmentPlanner} over that zone's nodes, calling its
     * {@code computePlan(true)}. Otherwise it first plans without trying to assign previously assigned models. It
     * re-plans with that enabled only if the first plan leaves some previously assigned model unassigned.
     *
     * @return the computed plan
     */
    public AssignmentPlan computePlan() {
        if (nodesByZone.size() == 1) {
            return new AssignmentPlanner(nodesByZone.values().iterator().next(), deployments).computePlan(true);
        }
        AssignmentPlan plan = computePlan(false);
        if (!plan.arePreviouslyAssignedModelsAssigned()) {
            plan = computePlan(true);
        }
        return plan;
    }

    /**
     * Plans each zone in turn, then combines the zone plans into a plan across all nodes.
     *
     * @param tryAssigningPreviouslyAssignedModels forwarded to each per-zone planning step
     * @return the combined plan
     */
    private AssignmentPlan computePlan(boolean tryAssigningPreviouslyAssignedModels) {
        logger.debug(
            () -> Strings.format(
                "computing plan%s trying to assign previously assigned models",
                tryAssigningPreviouslyAssignedModels ? "" : " without"
            )
        );
        // Copied into a HashMap because it is decremented in place below, and Collectors.toMap makes no guarantee
        // that its result is mutable. Collectors.toMap still rejects duplicate deployment ids.
        Map<String, Integer> modelIdToRemainingAllocations = new HashMap<>(
            deployments.stream().collect(Collectors.toMap(AssignmentPlan.Deployment::id, AssignmentPlan.Deployment::allocations))
        );
        List<AssignmentPlan> plans = new ArrayList<>();
        int remainingZones = nodesByZone.size();
        for (Map.Entry<List<String>, List<AssignmentPlan.Node>> zoneToNodes : nodesByZone.entrySet()) {
            logger.debug(() -> Strings.format("computing plan for availability zone %s", zoneToNodes.getKey()));
            AssignmentPlan zonePlan = computeZonePlan(
                zoneToNodes.getValue(),
                modelIdToRemainingAllocations,
                remainingZones,
                tryAssigningPreviouslyAssignedModels
            );
            // Subtract what this zone assigned, so later zones only split what is still left.
            for (AssignmentPlan.Deployment m : zonePlan.models()) {
                modelIdToRemainingAllocations.computeIfPresent(m.id(), (modelId, remaining) -> remaining - zonePlan.totalAllocations(m));
            }
            plans.add(zonePlan);
            remainingZones--;
        }
        AssignmentPlan plan = computePlanAcrossAllNodes(plans);
        logger.debug(() -> "Zone aware plan =\n" + plan.prettyPrint());
        return plan;
    }

    /**
     * Plans a single zone.
     *
     * <p>Every deployment with allocations still remaining targets an even share, rounded up, of those remaining
     * allocations across the zones not yet planned.
     *
     * @param nodes the zone's nodes
     * @param modelIdToRemainingAllocations allocations still unassigned per deployment id; not modified here
     * @param remainingZones number of zones not yet planned, including this one
     * @param tryAssigningPreviouslyAssignedModels forwarded to {@link AssignmentPlanner}
     * @return the plan for this zone
     */
    private AssignmentPlan computeZonePlan(
        List<AssignmentPlan.Node> nodes,
        Map<String, Integer> modelIdToRemainingAllocations,
        int remainingZones,
        boolean tryAssigningPreviouslyAssignedModels
    ) {
        // (remaining - 1) / remainingZones + 1 == ceil(remaining / remainingZones) for remaining > 0.
        Map<String, Integer> modelIdToTargetAllocations = modelIdToRemainingAllocations.entrySet()
            .stream()
            .filter(e -> e.getValue() > 0)
            .collect(Collectors.toMap(Map.Entry::getKey, e -> (e.getValue() - 1) / remainingZones + 1));

        List<AssignmentPlan.Deployment> modifiedDeployments = deployments.stream()
            .filter(m -> modelIdToTargetAllocations.getOrDefault(m.id(), 0) > 0)
            .map(
                m -> new AssignmentPlan.Deployment(
                    m.id(),
                    m.memoryBytes(),
                    modelIdToTargetAllocations.get(m.id()),
                    m.threadsPerAllocation(),
                    m.currentAllocationsByNodeId(),
                    // maxAssignedAllocations is only carried over while the remaining count still equals the
                    // deployment's total, i.e. while no earlier zone has assigned any of its allocations.
                    tryAssigningPreviouslyAssignedModels && modelIdToRemainingAllocations.get(m.id()).intValue() == m.allocations()
                        ? m.maxAssignedAllocations()
                        : 0,
                    m.perDeploymentMemoryBytes(),
                    m.perAllocationMemoryBytes()
                )
            )
            .toList();
        return new AssignmentPlanner(nodes, modifiedDeployments).computePlan(tryAssigningPreviouslyAssignedModels);
    }

    /**
     * Solves a plan over all nodes, using the merged zone plans as the deployments' current allocations.
     *
     * <p>Those allocations are preserved through {@link PreserveAllAllocations} around a
     * {@link LinearProgrammingPlanSolver} run.
     *
     * @param plans the per-zone plans
     * @return a plan expressed in terms of the original deployments and nodes
     */
    private AssignmentPlan computePlanAcrossAllNodes(List<AssignmentPlan> plans) {
        logger.debug(() -> "computing plan across all nodes");
        List<AssignmentPlan.Node> allNodes = new ArrayList<>();
        nodesByZone.values().forEach(allNodes::addAll);
        Map<String, Map<String, Integer>> allocationsByNodeIdByModelId = mergeAllocationsByNodeIdByModelId(plans);
        // Same deployments, but with the zone plans' assignments as their current allocations.
        List<AssignmentPlan.Deployment> modelsAccountingPlans = deployments.stream()
            .map(
                m -> new AssignmentPlan.Deployment(
                    m.id(),
                    m.memoryBytes(),
                    m.allocations(),
                    m.threadsPerAllocation(),
                    allocationsByNodeIdByModelId.get(m.id()),
                    m.maxAssignedAllocations(),
                    m.perDeploymentMemoryBytes(),
                    m.perAllocationMemoryBytes()
                )
            )
            .toList();
        PreserveAllAllocations preserveAllAllocations = new PreserveAllAllocations(allNodes, modelsAccountingPlans);
        List<AssignmentPlan.Node> planNodes = preserveAllAllocations.nodesPreservingAllocations();
        List<AssignmentPlan.Deployment> planDeployments = preserveAllAllocations.modelsPreservingAllocations();
        AssignmentPlan plan = new LinearProgrammingPlanSolver(planNodes, planDeployments).solvePlan(false);
        plan = preserveAllAllocations.mergePreservedAllocations(plan);
        return swapOriginalModelsInPlan(plan, allNodes, modelsAccountingPlans);
    }

    /**
     * Rebuilds {@code plan} so that it refers to the original deployment and node objects, looked up by id, instead
     * of the copies created during planning.
     *
     * @param plan the plan to translate
     * @param allNodes the original nodes
     * @param planDeployments the deployments whose assignments are read from {@code plan}
     * @return an equivalent plan built over the original deployments and nodes
     */
    private AssignmentPlan swapOriginalModelsInPlan(
        AssignmentPlan plan,
        List<AssignmentPlan.Node> allNodes,
        List<AssignmentPlan.Deployment> planDeployments
    ) {
        Map<String, AssignmentPlan.Deployment> originalModelById = deployments.stream()
            .collect(Collectors.toMap(AssignmentPlan.Deployment::id, Function.identity()));
        Map<String, AssignmentPlan.Node> originalNodeById = allNodes.stream()
            .collect(Collectors.toMap(AssignmentPlan.Node::id, Function.identity()));
        AssignmentPlan.Builder planBuilder = AssignmentPlan.builder(allNodes, deployments);
        for (AssignmentPlan.Deployment m : planDeployments) {
            AssignmentPlan.Deployment originalDeployment = originalModelById.get(m.id());
            Map<AssignmentPlan.Node, Integer> nodeAssignments = plan.assignments(m).orElse(Map.of());
            for (Map.Entry<AssignmentPlan.Node, Integer> assignment : nodeAssignments.entrySet()) {
                AssignmentPlan.Node originalNode = originalNodeById.get(assignment.getKey().id());
                planBuilder.assignModelToNode(originalDeployment, originalNode, assignment.getValue());
                // NOTE(review): memory is accounted explicitly after each assignment; presumably
                // assignModelToNode does not do this itself here -- confirm against AssignmentPlan.Builder.
                planBuilder.accountMemory(originalDeployment, originalNode);
            }
        }
        return planBuilder.build();
    }

    /**
     * Sums, per deployment id and node id, the allocations assigned across all {@code plans}.
     *
     * @param plans the per-zone plans
     * @return deployment id -> (node id -> total allocations); every deployment in this planner has an entry, possibly
     *     empty
     */
    private Map<String, Map<String, Integer>> mergeAllocationsByNodeIdByModelId(List<AssignmentPlan> plans) {
        Map<String, Map<String, Integer>> allocationsByNodeIdByModelId = new HashMap<>();
        deployments.forEach(m -> allocationsByNodeIdByModelId.put(m.id(), new HashMap<>()));
        for (AssignmentPlan plan : plans) {
            for (AssignmentPlan.Deployment m : plan.models()) {
                Map<AssignmentPlan.Node, Integer> assignments = plan.assignments(m).orElse(Map.of());
                if (assignments.isEmpty()) {
                    continue;
                }
                Map<String, Integer> nodeIdToAllocations = allocationsByNodeIdByModelId.computeIfAbsent(m.id(), id -> new HashMap<>());
                for (Map.Entry<AssignmentPlan.Node, Integer> nodeAssignment : assignments.entrySet()) {
                    nodeIdToAllocations.merge(nodeAssignment.getKey().id(), nodeAssignment.getValue(), Integer::sum);
                }
            }
        }
        return allocationsByNodeIdByModelId;
    }
}

