/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.ml.job;

import java.lang.invoke.CallSite;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Map;
import java.util.Optional;
import java.util.OptionalLong;
import java.util.stream.Collectors;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.common.Strings;
import org.elasticsearch.persistent.PersistentTasksCustomMetadata;
import org.elasticsearch.xpack.core.ml.MlTasks;
import org.elasticsearch.xpack.core.ml.inference.assignment.RoutingInfo;
import org.elasticsearch.xpack.core.ml.inference.assignment.RoutingState;
import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignment;
import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignmentMetadata;
import org.elasticsearch.xpack.core.ml.utils.MemoryTrackedTaskState;
import org.elasticsearch.xpack.core.ml.utils.MlTaskParams;
import org.elasticsearch.xpack.ml.MachineLearning;
import org.elasticsearch.xpack.ml.job.NodeLoad;
import org.elasticsearch.xpack.ml.process.MlMemoryTracker;
import org.elasticsearch.xpack.ml.utils.NativeMemoryCalculator;

public class NodeLoadDetector {
    private static final Logger logger = LogManager.getLogger(NodeLoadDetector.class);
    private final MlMemoryTracker mlMemoryTracker;

    public static OptionalLong getNodeSize(DiscoveryNode node) {
        String memoryString = (String)node.getAttributes().get("ml.machine_memory");
        try {
            return OptionalLong.of(Long.parseLong(memoryString));
        }
        catch (NumberFormatException e) {
            assert (e == null) : "ml.machine_memory should parse because we set it internally: invalid value was " + memoryString;
            return OptionalLong.empty();
        }
    }

    public NodeLoadDetector(MlMemoryTracker memoryTracker) {
        this.mlMemoryTracker = memoryTracker;
    }

    public MlMemoryTracker getMlMemoryTracker() {
        return this.mlMemoryTracker;
    }

    public NodeLoad detectNodeLoad(ClusterState clusterState, DiscoveryNode node, int dynamicMaxOpenJobs, int maxMachineMemoryPercent, boolean useAutoMachineMemoryCalculation) {
        return this.detectNodeLoad(clusterState, TrainedModelAssignmentMetadata.fromState((ClusterState)clusterState), node, dynamicMaxOpenJobs, maxMachineMemoryPercent, useAutoMachineMemoryCalculation);
    }

    public NodeLoad detectNodeLoad(ClusterState clusterState, TrainedModelAssignmentMetadata assignmentMetadata, DiscoveryNode node, int maxNumberOfOpenJobs, int maxMachineMemoryPercent, boolean useAutoMachineMemoryCalculation) {
        return this.detectNodeLoad((PersistentTasksCustomMetadata)clusterState.getMetadata().custom("persistent_tasks"), assignmentMetadata, node, maxNumberOfOpenJobs, maxMachineMemoryPercent, useAutoMachineMemoryCalculation);
    }

    public NodeLoad detectNodeLoad(PersistentTasksCustomMetadata persistentTasks, TrainedModelAssignmentMetadata assignmentMetadata, DiscoveryNode node, int maxNumberOfOpenJobs, int maxMachineMemoryPercent, boolean useAutoMachineMemoryCalculation) {
        Map nodeAttributes = node.getAttributes();
        ArrayList<CallSite> errors = new ArrayList<CallSite>();
        OptionalLong maxMlMemory = NativeMemoryCalculator.allowedBytesForMl(node, maxMachineMemoryPercent, useAutoMachineMemoryCalculation);
        if (maxMlMemory.isEmpty()) {
            errors.add((CallSite)((Object)("ml.machine_memory attribute [" + (String)nodeAttributes.get("ml.machine_memory") + "] is not a long")));
        }
        NodeLoad.Builder nodeLoad = NodeLoad.builder(node.getId()).setMaxMemory(maxMlMemory.orElse(-1L)).setMaxJobs(maxNumberOfOpenJobs).setUseMemory(true);
        if (!errors.isEmpty()) {
            String errorMsg = Strings.collectionToCommaDelimitedString(errors);
            logger.warn("error detecting load for node [{}]: {}", (Object)node.getId(), (Object)errorMsg);
            return nodeLoad.setError(errorMsg).build();
        }
        this.updateLoadGivenTasks(nodeLoad, persistentTasks);
        NodeLoadDetector.updateLoadGivenModelAssignments(nodeLoad, assignmentMetadata);
        if (nodeLoad.getNumAssignedJobs() > 0) {
            nodeLoad.incAssignedNativeCodeOverheadMemory(MachineLearning.NATIVE_EXECUTABLE_CODE_OVERHEAD.getBytes());
        }
        return nodeLoad.build();
    }

    private void updateLoadGivenTasks(NodeLoad.Builder nodeLoad, PersistentTasksCustomMetadata persistentTasks) {
        if (persistentTasks != null) {
            Collection<PersistentTasksCustomMetadata.PersistentTask<?>> memoryTrackedTasks = NodeLoadDetector.findAllMemoryTrackedTasks(persistentTasks, nodeLoad.getNodeId());
            for (PersistentTasksCustomMetadata.PersistentTask<?> task : memoryTrackedTasks) {
                MemoryTrackedTaskState state = MlTasks.getMemoryTrackedTaskState(task);
                assert (state != null) : "null MemoryTrackedTaskState for memory tracked task with params " + String.valueOf(task.getParams());
                if (state == null || !state.consumesMemory()) continue;
                MlTaskParams taskParams = (MlTaskParams)task.getParams();
                nodeLoad.addTask(task.getTaskName(), taskParams.getMlId(), state.isAllocating(), this.mlMemoryTracker);
            }
        }
    }

    private static void updateLoadGivenModelAssignments(NodeLoad.Builder nodeLoad, TrainedModelAssignmentMetadata trainedModelAssignmentMetadata) {
        if (trainedModelAssignmentMetadata != null && !trainedModelAssignmentMetadata.allAssignments().isEmpty()) {
            for (TrainedModelAssignment assignment : trainedModelAssignmentMetadata.allAssignments().values()) {
                if (!Optional.ofNullable((RoutingInfo)assignment.getNodeRoutingTable().get(nodeLoad.getNodeId())).map(RoutingInfo::getState).orElse(RoutingState.STOPPED).consumesMemory()) continue;
                nodeLoad.incNumAssignedNativeInferenceModels();
                nodeLoad.incAssignedNativeInferenceMemory(assignment.getTaskParams().estimateMemoryUsageBytes());
            }
        }
    }

    private static Collection<PersistentTasksCustomMetadata.PersistentTask<?>> findAllMemoryTrackedTasks(PersistentTasksCustomMetadata persistentTasks, String nodeId) {
        return persistentTasks.tasks().stream().filter(NodeLoadDetector::isMemoryTrackedTask).filter(task -> nodeId.equals(task.getExecutorNode())).collect(Collectors.toList());
    }

    private static boolean isMemoryTrackedTask(PersistentTasksCustomMetadata.PersistentTask<?> task) {
        return "xpack/ml/job".equals(task.getTaskName()) || "xpack/ml/job/snapshot/upgrade".equals(task.getTaskName()) || "xpack/ml/data_frame/analytics".equals(task.getTaskName());
    }
}

