/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.ml.action;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.Executor;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.ActionRequest;
import org.elasticsearch.action.ActionType;
import org.elasticsearch.action.FailedNodeException;
import org.elasticsearch.action.admin.cluster.node.stats.NodeStats;
import org.elasticsearch.action.admin.cluster.node.stats.NodesStatsRequestBuilder;
import org.elasticsearch.action.admin.cluster.node.stats.NodesStatsResponse;
import org.elasticsearch.action.support.ActionFilters;
import org.elasticsearch.action.support.master.TransportMasterNodeAction;
import org.elasticsearch.action.support.nodes.BaseNodeResponse;
import org.elasticsearch.client.internal.Client;
import org.elasticsearch.client.internal.OriginSettingClient;
import org.elasticsearch.client.internal.ParentTaskAssigningClient;
import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.cluster.block.ClusterBlockException;
import org.elasticsearch.cluster.block.ClusterBlockLevel;
import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.cluster.node.DiscoveryNodeRole;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.inject.Inject;
import org.elasticsearch.common.settings.ClusterSettings;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.common.util.concurrent.EsExecutors;
import org.elasticsearch.monitor.os.OsStats;
import org.elasticsearch.persistent.PersistentTasksCustomMetadata;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.TransportService;
import org.elasticsearch.xpack.core.ml.MachineLearningField;
import org.elasticsearch.xpack.core.ml.action.MlMemoryAction;
import org.elasticsearch.xpack.core.ml.action.TrainedModelCacheInfoAction;
import org.elasticsearch.xpack.ml.MachineLearning;
import org.elasticsearch.xpack.ml.job.NodeLoad;
import org.elasticsearch.xpack.ml.job.NodeLoadDetector;
import org.elasticsearch.xpack.ml.process.MlMemoryTracker;

/**
 * Master-node transport action that reports ML-related memory statistics for the requested nodes.
 *
 * <p>The action runs as a chain of asynchronous steps:
 * <ol>
 *   <li>refresh the {@link MlMemoryTracker} first if it has never been refreshed;</li>
 *   <li>fetch OS and JVM node stats for the resolved nodes;</li>
 *   <li>fetch trained model cache info from the nodes that returned node stats;</li>
 *   <li>combine both responses into one {@link MlMemoryAction.Response.MlMemoryStats} per node.</li>
 * </ol>
 * The requests sent in steps 2 and 3 are registered as children of this action's task.
 */
public class TransportMlMemoryAction
extends TransportMasterNodeAction<MlMemoryAction.Request, MlMemoryAction.Response> {
    private final Client client;
    private final MlMemoryTracker memoryTracker;

    @Inject
    public TransportMlMemoryAction(TransportService transportService, ClusterService clusterService, ThreadPool threadPool, ActionFilters actionFilters, IndexNameExpressionResolver indexNameExpressionResolver, Client client, MlMemoryTracker memoryTracker) {
        super("cluster:monitor/xpack/ml/memory/stats/get", transportService, clusterService, threadPool, actionFilters, MlMemoryAction.Request::new, indexNameExpressionResolver, MlMemoryAction.Response::new, (Executor) EsExecutors.DIRECT_EXECUTOR_SERVICE);
        // Internal requests issued by this action are sent with the "ml" origin.
        this.client = new OriginSettingClient(client, "ml");
        this.memoryTracker = memoryTracker;
    }

    /**
     * Gathers node stats and trained model cache info for the nodes selected by the request and
     * completes {@code listener} with the combined memory statistics, or with the first failure of
     * any step in the chain.
     *
     * @param task     the task of this action, used as the parent of the requests spawned here
     * @param request  carries the node specification and the timeout applied to spawned requests
     * @param state    the cluster state used to resolve nodes and to compute ML node load
     * @param listener notified with the response or with a failure
     */
    @Override
    protected void masterOperation(Task task, MlMemoryAction.Request request, ClusterState state, ActionListener<MlMemoryAction.Response> listener) throws Exception {
        ClusterSettings clusterSettings = this.clusterService.getClusterSettings();
        // Resolve the node specification in the request to concrete node IDs.
        String[] nodeIds = state.nodes().resolveNodes(request.getNodeId());
        // Make the spawned requests children of *this* task (not of this task's parent) so that
        // cancelling this action also cancels the node stats / cache info requests it issues.
        ParentTaskAssigningClient parentTaskClient = new ParentTaskAssigningClient(this.client, this.clusterService.localNode(), task);

        // Step 3: once node stats arrive, ask the same nodes for their trained model cache info.
        ActionListener<NodesStatsResponse> nodeStatsListener = ActionListener.wrap(nodesStatsResponse -> {
            DiscoveryNode[] respondingNodes = nodesStatsResponse.getNodes()
                .stream()
                .map(BaseNodeResponse::getNode)
                .toArray(DiscoveryNode[]::new);
            TrainedModelCacheInfoAction.Request cacheInfoRequest = new TrainedModelCacheInfoAction.Request(respondingNodes);
            cacheInfoRequest.timeout(request.timeout());
            parentTaskClient.execute(
                TrainedModelCacheInfoAction.INSTANCE,
                cacheInfoRequest,
                ActionListener.wrap(
                    cacheInfoResponse -> handleResponses(state, clusterSettings, nodesStatsResponse, cacheInfoResponse, listener),
                    listener::onFailure
                )
            );
        }, listener::onFailure);

        // Step 2: fetch only the OS and JVM sections of node stats for the resolved nodes.
        ActionListener<Void> memoryTrackerRefreshListener = ActionListener.wrap(
            ignored -> parentTaskClient.admin()
                .cluster()
                .prepareNodesStats(nodeIds)
                .clear()
                .setOs(true)
                .setJvm(true)
                .setTimeout(request.timeout())
                .execute(nodeStatsListener),
            listener::onFailure
        );

        // Step 1: refresh the memory tracker only if it has never been refreshed.
        if (this.memoryTracker.isEverRefreshed()) {
            memoryTrackerRefreshListener.onResponse(null);
        } else {
            this.memoryTracker.refresh(state.getMetadata().custom(PersistentTasksCustomMetadata.TYPE), memoryTrackerRefreshListener);
        }
    }

    /**
     * Combines node stats and trained model cache info into one memory-stats entry per node and
     * completes {@code listener} with the result.
     *
     * <p>A node whose trained model cache info request failed is not reported with stats; its
     * cache info failure is added to the response failures instead. Nodes without the ML role
     * report zero for every ML memory figure. Nodes without a cache info entry report zero JVM
     * inference memory.
     *
     * @param state                         cluster state used to compute ML node load
     * @param clusterSettings               source of the ML memory-related settings
     * @param nodesStatsResponse            OS/JVM stats, including per-node failures
     * @param trainedModelCacheInfoResponse cache info, including per-node failures
     * @param listener                      notified with the combined response
     */
    void handleResponses(ClusterState state, ClusterSettings clusterSettings, NodesStatsResponse nodesStatsResponse, TrainedModelCacheInfoAction.Response trainedModelCacheInfoResponse, ActionListener<MlMemoryAction.Response> listener) {
        List<MlMemoryAction.Response.MlMemoryStats> nodeResponses = new ArrayList<>(nodesStatsResponse.getNodes().size());
        int maxOpenJobsPerNode = clusterSettings.get(MachineLearning.MAX_OPEN_JOBS_PER_NODE);
        int maxMachineMemoryPercent = clusterSettings.get(MachineLearning.MAX_MACHINE_MEMORY_PERCENT);
        boolean useAutoMachineMemoryPercent = clusterSettings.get(MachineLearningField.USE_AUTO_MACHINE_MEMORY_PERCENT);
        NodeLoadDetector nodeLoadDetector = new NodeLoadDetector(this.memoryTracker);
        Map<String, TrainedModelCacheInfoAction.Response.CacheInfo> cacheInfoByNode = trainedModelCacheInfoResponse.getNodesMap();

        // Index cache info failures by node ID once (keeping the first failure per node) instead of
        // scanning the whole failure list for every node.
        Map<String, FailedNodeException> cacheInfoFailureByNode = new HashMap<>();
        for (FailedNodeException failure : trainedModelCacheInfoResponse.failures()) {
            cacheInfoFailureByNode.putIfAbsent(failure.nodeId(), failure);
        }

        List<FailedNodeException> failures = new ArrayList<>(nodesStatsResponse.failures());
        for (NodeStats nodeStats : nodesStatsResponse.getNodes()) {
            DiscoveryNode node = nodeStats.getNode();
            String nodeId = node.getId();

            FailedNodeException cacheInfoFailure = cacheInfoFailureByNode.get(nodeId);
            if (cacheInfoFailure != null) {
                failures.add(cacheInfoFailure);
                continue;
            }

            OsStats.Mem mem = nodeStats.getOs().getMem();

            // ML memory figures only apply to nodes with the ML role; other nodes report zero.
            ByteSizeValue mlMax = ByteSizeValue.ZERO;
            ByteSizeValue mlNativeCodeOverhead = ByteSizeValue.ZERO;
            ByteSizeValue mlAnomalyDetectors = ByteSizeValue.ZERO;
            ByteSizeValue mlDataFrameAnalytics = ByteSizeValue.ZERO;
            ByteSizeValue mlNativeInference = ByteSizeValue.ZERO;
            if (node.getRoles().contains(DiscoveryNodeRole.ML_ROLE)) {
                NodeLoad nodeLoad = nodeLoadDetector.detectNodeLoad(state, node, maxOpenJobsPerNode, maxMachineMemoryPercent, useAutoMachineMemoryPercent);
                mlMax = ByteSizeValue.ofBytes(nodeLoad.getMaxMlMemory());
                mlNativeCodeOverhead = ByteSizeValue.ofBytes(nodeLoad.getAssignedNativeCodeOverheadMemory());
                mlAnomalyDetectors = ByteSizeValue.ofBytes(nodeLoad.getAssignedAnomalyDetectorMemory());
                mlDataFrameAnalytics = ByteSizeValue.ofBytes(nodeLoad.getAssignedDataFrameAnalyticsMemory());
                mlNativeInference = ByteSizeValue.ofBytes(nodeLoad.getAssignedNativeInferenceMemory());
            }

            ByteSizeValue jvmHeapMax = nodeStats.getJvm().getMem().getHeapMax();

            // Nodes that returned no cache info entry report zero JVM inference memory.
            TrainedModelCacheInfoAction.Response.CacheInfo cacheInfoForNode = cacheInfoByNode.get(nodeId);
            ByteSizeValue jvmInferenceMax = cacheInfoForNode != null ? cacheInfoForNode.getJvmInferenceMax() : ByteSizeValue.ZERO;
            ByteSizeValue jvmInference = cacheInfoForNode != null ? cacheInfoForNode.getJvmInference() : ByteSizeValue.ZERO;

            nodeResponses.add(new MlMemoryAction.Response.MlMemoryStats(node, mem.getTotal(), mem.getAdjustedTotal(), mlMax, mlNativeCodeOverhead, mlAnomalyDetectors, mlDataFrameAnalytics, mlNativeInference, jvmHeapMax, jvmInferenceMax, jvmInference));
        }
        listener.onResponse(new MlMemoryAction.Response(state.getClusterName(), nodeResponses, failures));
    }

    /**
     * Rejects the request when a global cluster block applies at the METADATA_READ level.
     */
    @Override
    protected ClusterBlockException checkBlock(MlMemoryAction.Request request, ClusterState state) {
        return state.blocks().globalBlockedException(ClusterBlockLevel.METADATA_READ);
    }
}

