/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.ml.autoscaling;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.PriorityQueue;
import java.util.stream.Collectors;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.admin.cluster.node.stats.NodeStats;
import org.elasticsearch.action.admin.cluster.node.stats.NodesStatsRequestBuilder;
import org.elasticsearch.client.internal.Client;
import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.cluster.node.DiscoveryNodeRole;
import org.elasticsearch.common.settings.ClusterSettings;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.core.Strings;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.core.Tuple;
import org.elasticsearch.monitor.os.OsStats;
import org.elasticsearch.persistent.PersistentTasksCustomMetadata;
import org.elasticsearch.xpack.core.ml.MlTasks;
import org.elasticsearch.xpack.core.ml.action.OpenJobAction;
import org.elasticsearch.xpack.core.ml.autoscaling.MlAutoscalingStats;
import org.elasticsearch.xpack.core.ml.inference.assignment.AssignmentState;
import org.elasticsearch.xpack.core.ml.inference.assignment.Priority;
import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignment;
import org.elasticsearch.xpack.ml.MachineLearning;
import org.elasticsearch.xpack.ml.autoscaling.MlAutoscalingContext;
import org.elasticsearch.xpack.ml.job.JobNodeSelector;
import org.elasticsearch.xpack.ml.process.MlMemoryTracker;
import org.elasticsearch.xpack.ml.utils.MlProcessors;
import org.elasticsearch.xpack.ml.utils.NativeMemoryCalculator;

public final class MlAutoscalingResourceTracker {
    private static final Logger logger = LogManager.getLogger(MlAutoscalingResourceTracker.class);

    private MlAutoscalingResourceTracker() {
    }

    public static void getMlAutoscalingStats(ClusterState clusterState, ClusterSettings clusterSettings, Client client, TimeValue timeout, MlMemoryTracker mlMemoryTracker, Settings settings, ActionListener<MlAutoscalingStats> listener) {
        String[] mlNodes = (String[])clusterState.nodes().stream().filter(node -> node.getRoles().contains(DiscoveryNodeRole.ML_ROLE)).map(DiscoveryNode::getId).toArray(String[]::new);
        long modelMemoryAvailableFirstNode = mlNodes.length > 0 ? NativeMemoryCalculator.allowedBytesForMl(clusterState.nodes().get(mlNodes[0]), settings).orElse(0L) : 0L;
        int processorsAvailableFirstNode = mlNodes.length > 0 ? MlProcessors.get(clusterState.nodes().get(mlNodes[0]), (Integer)clusterSettings.get(MachineLearning.ALLOCATED_PROCESSORS_SCALE)).roundDown() : 0;
        int maxOpenJobsPerNode = (Integer)MachineLearning.MAX_OPEN_JOBS_PER_NODE.get(settings);
        MlAutoscalingResourceTracker.getMlNodeStats(mlNodes, client, timeout, (ActionListener<Map<String, OsStats>>)ActionListener.wrap(osStatsPerNode -> MlAutoscalingResourceTracker.getMemoryAndProcessors(new MlAutoscalingContext(clusterState), mlMemoryTracker, osStatsPerNode, modelMemoryAvailableFirstNode, processorsAvailableFirstNode, maxOpenJobsPerNode, listener), arg_0 -> listener.onFailure(arg_0)));
    }

    static void getMlNodeStats(String[] mlNodes, Client client, TimeValue timeout, ActionListener<Map<String, OsStats>> listener) {
        if (mlNodes.length == 0) {
            listener.onResponse(Collections.emptyMap());
            return;
        }
        ((NodesStatsRequestBuilder)client.admin().cluster().prepareNodesStats(mlNodes).clear().setOs(true).setTimeout(timeout)).execute(ActionListener.wrap(nodesStatsResponse -> listener.onResponse(nodesStatsResponse.getNodes().stream().collect(Collectors.toMap(nodeStats -> nodeStats.getNode().getId(), NodeStats::getOs))), arg_0 -> listener.onFailure(arg_0)));
    }

    static void getMemoryAndProcessors(MlAutoscalingContext autoscalingContext, MlMemoryTracker mlMemoryTracker, Map<String, OsStats> osStatsPerNode, long perNodeAvailableModelMemoryInBytes, int perNodeAvailableProcessors, int maxOpenJobsPerNode, ActionListener<MlAutoscalingStats> listener) {
        Long jobMemory;
        String jobId;
        HashMap<String, List<MlJobRequirements>> perNodeModelMemoryInBytes = new HashMap<String, List<MlJobRequirements>>();
        long perNodeMemoryInBytes = osStatsPerNode.values().stream().map(s -> s.getMem().getAdjustedTotal().getBytes()).distinct().count() != 1L ? 0L : osStatsPerNode.values().iterator().next().getMem().getAdjustedTotal().getBytes();
        long modelMemoryBytesSum = 0L;
        long extraSingleNodeModelMemoryInBytes = 0L;
        long extraModelMemoryInBytes = 0L;
        int extraSingleNodeProcessors = 0;
        int extraProcessors = 0;
        int processorsSum = 0;
        logger.debug("getting ml resources, found [{}] ad jobs, [{}] dfa jobs and [{}] inference deployments", (Object)autoscalingContext.anomalyDetectionTasks.size(), (Object)autoscalingContext.dataframeAnalyticsTasks.size(), (Object)autoscalingContext.modelAssignments.size());
        int minNodes = autoscalingContext.anomalyDetectionTasks.isEmpty() && autoscalingContext.dataframeAnalyticsTasks.isEmpty() && autoscalingContext.modelAssignments.isEmpty() ? 0 : 1;
        for (PersistentTasksCustomMetadata.PersistentTask<?> persistentTask : autoscalingContext.anomalyDetectionTasks) {
            jobId = ((OpenJobAction.JobParams)persistentTask.getParams()).getJobId();
            jobMemory = mlMemoryTracker.getAnomalyDetectorJobMemoryRequirement(jobId);
            if (jobMemory == null) {
                logger.debug("could not find memory requirement for job [{}], skipping", (Object)jobId);
                continue;
            }
            if (JobNodeSelector.AWAITING_LAZY_ASSIGNMENT.equals((Object)persistentTask.getAssignment())) {
                logger.debug("job [{}] lacks assignment , memory required [{}]", (Object)jobId, (Object)jobMemory);
                extraSingleNodeModelMemoryInBytes = Math.max(extraSingleNodeModelMemoryInBytes, jobMemory);
                extraModelMemoryInBytes += jobMemory.longValue();
                continue;
            }
            logger.debug("job [{}] assigned to [{}], memory required [{}]", (Object)jobId, (Object)persistentTask.getAssignment(), (Object)jobMemory);
            modelMemoryBytesSum += jobMemory.longValue();
            perNodeModelMemoryInBytes.computeIfAbsent(persistentTask.getExecutorNode(), k -> new ArrayList()).add(MlJobRequirements.of(jobMemory, 0));
        }
        for (PersistentTasksCustomMetadata.PersistentTask<?> persistentTask : autoscalingContext.dataframeAnalyticsTasks) {
            jobId = MlTasks.dataFrameAnalyticsId((String)persistentTask.getId());
            jobMemory = mlMemoryTracker.getDataFrameAnalyticsJobMemoryRequirement(jobId);
            if (jobMemory == null) {
                logger.debug("could not find memory requirement for job [{}], skipping", (Object)jobId);
                continue;
            }
            if (JobNodeSelector.AWAITING_LAZY_ASSIGNMENT.equals((Object)persistentTask.getAssignment())) {
                logger.debug("dfa job [{}] lacks assignment , memory required [{}]", (Object)jobId, (Object)jobMemory);
                extraSingleNodeModelMemoryInBytes = Math.max(extraSingleNodeModelMemoryInBytes, jobMemory);
                extraModelMemoryInBytes += jobMemory.longValue();
                continue;
            }
            logger.debug("dfa job [{}] assigned to [{}], memory required [{}]", (Object)jobId, (Object)persistentTask.getAssignment(), (Object)jobMemory);
            modelMemoryBytesSum += jobMemory.longValue();
            perNodeModelMemoryInBytes.computeIfAbsent(persistentTask.getExecutorNode(), k -> new ArrayList()).add(MlJobRequirements.of(jobMemory, 0));
        }
        for (Map.Entry entry : autoscalingContext.modelAssignments.entrySet()) {
            int numberOfAllocations = ((TrainedModelAssignment)entry.getValue()).getTaskParams().getNumberOfAllocations();
            int numberOfThreadsPerAllocation = ((TrainedModelAssignment)entry.getValue()).getTaskParams().getThreadsPerAllocation();
            long estimatedMemoryUsage = ((TrainedModelAssignment)entry.getValue()).getTaskParams().estimateMemoryUsageBytes();
            if (AssignmentState.STARTING.equals((Object)((TrainedModelAssignment)entry.getValue()).getAssignmentState()) && ((TrainedModelAssignment)entry.getValue()).getNodeRoutingTable().isEmpty()) {
                logger.debug(() -> Strings.format((String)"trained model [%s] lacks assignment , memory required [%d]", (Object[])new Object[]{modelAssignment.getKey(), estimatedMemoryUsage}));
                extraSingleNodeModelMemoryInBytes = Math.max(extraSingleNodeModelMemoryInBytes, estimatedMemoryUsage);
                extraModelMemoryInBytes += estimatedMemoryUsage;
                if (Priority.LOW.equals((Object)((TrainedModelAssignment)entry.getValue()).getTaskParams().getPriority())) continue;
                extraSingleNodeProcessors = Math.max(extraSingleNodeProcessors, numberOfThreadsPerAllocation);
                extraProcessors += numberOfAllocations * numberOfThreadsPerAllocation;
                continue;
            }
            logger.debug(() -> Strings.format((String)"trained model [%s] assigned to [%s], memory required [%d]", (Object[])new Object[]{modelAssignment.getKey(), org.elasticsearch.common.Strings.arrayToCommaDelimitedString((Object[])((TrainedModelAssignment)modelAssignment.getValue()).getStartedNodes()), estimatedMemoryUsage}));
            modelMemoryBytesSum += estimatedMemoryUsage;
            processorsSum += numberOfAllocations * numberOfThreadsPerAllocation;
            minNodes = Math.min(3, Math.max(minNodes, numberOfAllocations));
            for (String node : ((TrainedModelAssignment)entry.getValue()).getNodeRoutingTable().keySet()) {
                perNodeModelMemoryInBytes.computeIfAbsent(node, k -> new ArrayList()).add(MlJobRequirements.of(estimatedMemoryUsage, Priority.LOW.equals((Object)((TrainedModelAssignment)entry.getValue()).getTaskParams().getPriority()) ? 0 : numberOfThreadsPerAllocation));
            }
        }
        long removeNodeMemoryInBytes = 0L;
        if (perNodeMemoryInBytes > 0L && perNodeAvailableModelMemoryInBytes > 0L && extraModelMemoryInBytes == 0L && extraProcessors == 0 && modelMemoryBytesSum < perNodeMemoryInBytes * (long)(osStatsPerNode.size() - 1) && (perNodeModelMemoryInBytes.size() < osStatsPerNode.size() || MlAutoscalingResourceTracker.checkIfOneNodeCouldBeRemoved(perNodeModelMemoryInBytes, perNodeAvailableModelMemoryInBytes, perNodeAvailableProcessors, maxOpenJobsPerNode))) {
            removeNodeMemoryInBytes = perNodeMemoryInBytes;
        }
        listener.onResponse((Object)new MlAutoscalingStats(osStatsPerNode.size(), perNodeMemoryInBytes, modelMemoryBytesSum, processorsSum, minNodes, extraSingleNodeModelMemoryInBytes, extraSingleNodeProcessors, extraModelMemoryInBytes, extraProcessors, removeNodeMemoryInBytes, MachineLearning.NATIVE_EXECUTABLE_CODE_OVERHEAD.getBytes()));
    }

    static boolean checkIfOneNodeCouldBeRemoved(Map<String, List<MlJobRequirements>> perNodeJobRequirements, long perNodeMemoryInBytes, int perNodeProcessors, int maxOpenJobsPerNode) {
        if (perNodeJobRequirements.size() <= 1) {
            return false;
        }
        Map<String, MlJobRequirements> perNodeMlJobRequirementSum = perNodeJobRequirements.entrySet().stream().map(entry -> Tuple.tuple((Object)((String)entry.getKey()), (Object)((List)entry.getValue()).stream().reduce(MlJobRequirements.of(0L, 0, 0), (subtotal, element) -> MlJobRequirements.of(subtotal.memory + element.memory, subtotal.processors + element.processors, subtotal.jobs + element.jobs)))).collect(Collectors.toMap(Tuple::v1, Tuple::v2));
        Optional<Map.Entry> leastLoadedNodeAndMemoryUsage = perNodeMlJobRequirementSum.entrySet().stream().min(Comparator.comparingLong(entry -> ((MlJobRequirements)entry.getValue()).memory));
        if (!leastLoadedNodeAndMemoryUsage.isPresent()) {
            return false;
        }
        assert (((MlJobRequirements)leastLoadedNodeAndMemoryUsage.get().getValue()).memory >= 0L);
        String candidateNode = (String)leastLoadedNodeAndMemoryUsage.get().getKey();
        List<MlJobRequirements> candidateJobRequirements = perNodeJobRequirements.get(candidateNode);
        perNodeMlJobRequirementSum.remove(candidateNode);
        return MlAutoscalingResourceTracker.checkIfJobsCanBeMovedInLeastEfficientWay(candidateJobRequirements, perNodeMlJobRequirementSum, perNodeMemoryInBytes, perNodeProcessors, maxOpenJobsPerNode) == 0L;
    }

    static long checkIfJobsCanBeMovedInLeastEfficientWay(List<MlJobRequirements> candidateJobRequirements, Map<String, MlJobRequirements> perNodeMlJobRequirementsSum, long perNodeMemoryInBytes, int perNodeProcessors, int maxOpenJobsPerNode) {
        if (candidateJobRequirements.size() == 0) {
            return 0L;
        }
        List<MlJobRequirements> candidateNodeMemoryListSorted = candidateJobRequirements.stream().sorted(Comparator.comparingLong(MlJobRequirements::memory)).toList();
        long candidateNodeMemorySum = candidateJobRequirements.stream().mapToLong(MlJobRequirements::memory).sum();
        if (perNodeMlJobRequirementsSum.size() == 0) {
            return candidateNodeMemorySum;
        }
        PriorityQueue nodesWithSpareCapacitySortedByMemory = perNodeMlJobRequirementsSum.values().stream().filter(e -> e.jobs < maxOpenJobsPerNode).collect(Collectors.toCollection(() -> new PriorityQueue(perNodeMlJobRequirementsSum.size(), (c1, c2) -> {
            if (c1.memory == c2.memory) {
                return Integer.compare(c1.processors, c2.processors);
            }
            return Long.compare(c1.memory, c2.memory);
        })));
        for (MlJobRequirements jobRequirement : candidateNodeMemoryListSorted) {
            assert (jobRequirement.jobs == 1);
            if (jobRequirement.processors == 0) {
                MlJobRequirements nodeWithSpareCapacity = (MlJobRequirements)nodesWithSpareCapacitySortedByMemory.poll();
                long memoryAfterAddingJobMemory = nodeWithSpareCapacity.memory + jobRequirement.memory;
                if (memoryAfterAddingJobMemory > perNodeMemoryInBytes) break;
                if (nodeWithSpareCapacity.jobs + jobRequirement.jobs < maxOpenJobsPerNode) {
                    nodesWithSpareCapacitySortedByMemory.add(MlJobRequirements.of(memoryAfterAddingJobMemory, nodeWithSpareCapacity.processors, nodeWithSpareCapacity.jobs + jobRequirement.jobs));
                }
                candidateNodeMemorySum -= jobRequirement.memory;
            } else {
                ArrayList<MlJobRequirements> stash = new ArrayList<MlJobRequirements>();
                boolean foundNodeThatCanTakeTheJob = false;
                while (!nodesWithSpareCapacitySortedByMemory.isEmpty()) {
                    MlJobRequirements nodeWithSpareCapacity = (MlJobRequirements)nodesWithSpareCapacitySortedByMemory.poll();
                    long memoryAfterAddingJobMemory = nodeWithSpareCapacity.memory + jobRequirement.memory;
                    if (memoryAfterAddingJobMemory > perNodeMemoryInBytes) break;
                    if (nodeWithSpareCapacity.processors + jobRequirement.processors <= perNodeProcessors) {
                        if (nodeWithSpareCapacity.jobs + jobRequirement.jobs < maxOpenJobsPerNode) {
                            nodesWithSpareCapacitySortedByMemory.add(MlJobRequirements.of(memoryAfterAddingJobMemory, nodeWithSpareCapacity.processors + jobRequirement.processors, nodeWithSpareCapacity.jobs + jobRequirement.jobs));
                        }
                        candidateNodeMemorySum -= jobRequirement.memory;
                        foundNodeThatCanTakeTheJob = true;
                        break;
                    }
                    stash.add(nodeWithSpareCapacity);
                }
                if (!foundNodeThatCanTakeTheJob) break;
                nodesWithSpareCapacitySortedByMemory.addAll(stash);
            }
            if (!nodesWithSpareCapacitySortedByMemory.isEmpty()) continue;
            break;
        }
        return candidateNodeMemorySum;
    }

    record MlJobRequirements(long memory, int processors, int jobs) {
        static MlJobRequirements of(long memory, int processors, int jobs) {
            return new MlJobRequirements(memory, processors, jobs);
        }

        static MlJobRequirements of(long memory, int processors) {
            return new MlJobRequirements(memory, processors, 1);
        }
    }
}

