/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.ml.action;

import java.time.Instant;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.TreeMap;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.FailedNodeException;
import org.elasticsearch.action.TaskOperationFailure;
import org.elasticsearch.action.support.ActionFilters;
import org.elasticsearch.action.support.tasks.BaseTasksRequest;
import org.elasticsearch.action.support.tasks.TransportTasksAction;
import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.cluster.node.DiscoveryNodes;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.inject.Inject;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.transport.TransportService;
import org.elasticsearch.xpack.core.action.util.ExpandedIdsMatcher;
import org.elasticsearch.xpack.core.ml.action.GetDeploymentStatsAction;
import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction;
import org.elasticsearch.xpack.core.ml.inference.allocation.AllocationState;
import org.elasticsearch.xpack.core.ml.inference.allocation.AllocationStats;
import org.elasticsearch.xpack.core.ml.inference.allocation.AllocationStatus;
import org.elasticsearch.xpack.core.ml.inference.allocation.RoutingState;
import org.elasticsearch.xpack.core.ml.inference.allocation.RoutingStateAndReason;
import org.elasticsearch.xpack.core.ml.inference.allocation.TrainedModelAllocation;
import org.elasticsearch.xpack.ml.action.TransportStartTrainedModelDeploymentAction;
import org.elasticsearch.xpack.ml.inference.allocation.TrainedModelAllocationMetadata;
import org.elasticsearch.xpack.ml.inference.deployment.ModelStats;
import org.elasticsearch.xpack.ml.inference.deployment.TrainedModelDeploymentTask;

public class TransportGetDeploymentStatsAction
extends TransportTasksAction<TrainedModelDeploymentTask, GetDeploymentStatsAction.Request, GetDeploymentStatsAction.Response, AllocationStats> {
    @Inject
    public TransportGetDeploymentStatsAction(TransportService transportService, ActionFilters actionFilters, ClusterService clusterService) {
        super("cluster:internal/xpack/ml/trained_models/deployments/stats/get", clusterService, transportService, actionFilters, GetDeploymentStatsAction.Request::new, GetDeploymentStatsAction.Response::new, AllocationStats::new, "management");
    }

    protected GetDeploymentStatsAction.Response newResponse(GetDeploymentStatsAction.Request request, List<AllocationStats> taskResponse, List<TaskOperationFailure> taskOperationFailures, List<FailedNodeException> failedNodeExceptions) {
        TreeMap mergedNodeStatsByModel = taskResponse.stream().collect(Collectors.toMap(AllocationStats::getModelId, Function.identity(), (l, r) -> {
            l.getNodeStats().addAll(r.getNodeStats());
            return l;
        }, TreeMap::new));
        ArrayList bunchedAndSorted = new ArrayList(mergedNodeStatsByModel.values());
        return new GetDeploymentStatsAction.Response(taskOperationFailures, failedNodeExceptions, bunchedAndSorted, (long)bunchedAndSorted.size());
    }

    protected void doExecute(Task task, GetDeploymentStatsAction.Request request, ActionListener<GetDeploymentStatsAction.Response> listener) {
        ClusterState clusterState = this.clusterService.state();
        TrainedModelAllocationMetadata allocation = TrainedModelAllocationMetadata.fromState(clusterState);
        String[] tokenizedRequestIds = Strings.tokenizeToStringArray((String)request.getDeploymentId(), (String)",");
        ExpandedIdsMatcher.SimpleIdsMatcher idsMatcher = new ExpandedIdsMatcher.SimpleIdsMatcher(tokenizedRequestIds);
        ArrayList<String> matchedDeploymentIds = new ArrayList<String>();
        HashSet<String> taskNodes = new HashSet<String>();
        HashMap<TrainedModelAllocation, Map<String, RoutingStateAndReason>> allocationNonStartedRoutes = new HashMap<TrainedModelAllocation, Map<String, RoutingStateAndReason>>();
        for (Map.Entry<String, TrainedModelAllocation> allocationEntry : allocation.modelAllocations().entrySet()) {
            String modelId = allocationEntry.getKey();
            if (!idsMatcher.idMatches(modelId)) continue;
            matchedDeploymentIds.add(modelId);
            taskNodes.addAll(Arrays.asList(allocationEntry.getValue().getStartedNodes()));
            Map<String, RoutingStateAndReason> routings = allocationEntry.getValue().getNodeRoutingTable().entrySet().stream().filter(routingEntry -> !RoutingState.STARTED.equals((Object)((RoutingStateAndReason)routingEntry.getValue()).getState())).collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
            allocationNonStartedRoutes.put(allocationEntry.getValue(), routings);
        }
        if (matchedDeploymentIds.isEmpty()) {
            listener.onResponse((Object)new GetDeploymentStatsAction.Response(Collections.emptyList(), Collections.emptyList(), Collections.emptyList(), 0L));
            return;
        }
        request.setNodes((String[])taskNodes.toArray(String[]::new));
        request.setExpandedIds(matchedDeploymentIds);
        ActionListener addFailedListener = listener.delegateFailure((l, response) -> {
            GetDeploymentStatsAction.Response updatedResponse = TransportGetDeploymentStatsAction.addFailedRoutes(response, allocationNonStartedRoutes, clusterState.nodes());
            ClusterState latestState = this.clusterService.state();
            Set<String> nodesShuttingDown = TransportStartTrainedModelDeploymentAction.nodesShuttingDown(latestState);
            List nodes = latestState.getNodes().getAllNodes().stream().filter(d -> !nodesShuttingDown.contains(d.getId())).filter(StartTrainedModelDeploymentAction.TaskParams::mayAllocateToNode).collect(Collectors.toList());
            for (AllocationStats stats : updatedResponse.getStats().results()) {
                TrainedModelAllocation trainedModelAllocation = allocation.getModelAllocation(stats.getModelId());
                if (trainedModelAllocation == null) continue;
                stats.setState(trainedModelAllocation.getAllocationState()).setReason((String)trainedModelAllocation.getReason().orElse(null));
                if (!trainedModelAllocation.getAllocationState().isAnyOf(new AllocationState[]{AllocationState.STARTED, AllocationState.STARTING})) continue;
                stats.setAllocationStatus((AllocationStatus)trainedModelAllocation.calculateAllocationStatus(nodes).orElse(null));
            }
            l.onResponse((Object)updatedResponse);
        });
        super.doExecute(task, (BaseTasksRequest)request, addFailedListener);
    }

    static GetDeploymentStatsAction.Response addFailedRoutes(GetDeploymentStatsAction.Response tasksResponse, Map<TrainedModelAllocation, Map<String, RoutingStateAndReason>> allocationNonStartedRoutes, DiscoveryNodes nodes) {
        Map modelToAllocationWithNonStartedRoutes = allocationNonStartedRoutes.keySet().stream().collect(Collectors.toMap(TrainedModelAllocation::getModelId, Function.identity()));
        ArrayList<AllocationStats> updatedAllocationStats = new ArrayList<AllocationStats>();
        for (AllocationStats allocationStats : tasksResponse.getStats().results()) {
            if (modelToAllocationWithNonStartedRoutes.containsKey(allocationStats.getModelId())) {
                Map<String, RoutingStateAndReason> nodeToRoutingStates = allocationNonStartedRoutes.get(modelToAllocationWithNonStartedRoutes.get(allocationStats.getModelId()));
                ArrayList<AllocationStats.NodeStats> updatedNodeStats = new ArrayList<AllocationStats.NodeStats>();
                HashSet<String> visitedNodes = new HashSet<String>();
                for (AllocationStats.NodeStats nodeStats : allocationStats.getNodeStats()) {
                    if (nodeToRoutingStates.containsKey(nodeStats.getNode().getId())) {
                        RoutingStateAndReason stateAndReason = nodeToRoutingStates.get(nodeStats.getNode().getId());
                        updatedNodeStats.add(AllocationStats.NodeStats.forNotStartedState((DiscoveryNode)nodeStats.getNode(), (RoutingState)stateAndReason.getState(), (String)stateAndReason.getReason()));
                    } else {
                        updatedNodeStats.add(nodeStats);
                    }
                    visitedNodes.add(nodeStats.getNode().getId());
                }
                for (Map.Entry entry : nodeToRoutingStates.entrySet()) {
                    if (visitedNodes.contains(entry.getKey())) continue;
                    updatedNodeStats.add(AllocationStats.NodeStats.forNotStartedState((DiscoveryNode)nodes.get((String)entry.getKey()), (RoutingState)((RoutingStateAndReason)entry.getValue()).getState(), (String)((RoutingStateAndReason)entry.getValue()).getReason()));
                }
                updatedNodeStats.sort(Comparator.comparing(n -> n.getNode().getId()));
                updatedAllocationStats.add(new AllocationStats(allocationStats.getModelId(), allocationStats.getInferenceThreads(), allocationStats.getModelThreads(), allocationStats.getQueueCapacity(), allocationStats.getStartTime(), updatedNodeStats));
                continue;
            }
            updatedAllocationStats.add(allocationStats);
        }
        for (Map.Entry entry : allocationNonStartedRoutes.entrySet()) {
            TrainedModelAllocation allocation = (TrainedModelAllocation)entry.getKey();
            String modelId = allocation.getTaskParams().getModelId();
            if (tasksResponse.getStats().results().stream().anyMatch(e -> modelId.equals(e.getModelId()))) continue;
            ArrayList<AllocationStats.NodeStats> nodeStats = new ArrayList<AllocationStats.NodeStats>();
            for (Map.Entry entry2 : ((Map)entry.getValue()).entrySet()) {
                nodeStats.add(AllocationStats.NodeStats.forNotStartedState((DiscoveryNode)nodes.get((String)entry2.getKey()), (RoutingState)((RoutingStateAndReason)entry2.getValue()).getState(), (String)((RoutingStateAndReason)entry2.getValue()).getReason()));
            }
            nodeStats.sort(Comparator.comparing(n -> n.getNode().getId()));
            updatedAllocationStats.add(new AllocationStats(modelId, null, null, null, allocation.getStartTime(), nodeStats));
        }
        updatedAllocationStats.sort(Comparator.comparing(AllocationStats::getModelId));
        return new GetDeploymentStatsAction.Response(tasksResponse.getTaskFailures(), tasksResponse.getNodeFailures(), updatedAllocationStats, (long)updatedAllocationStats.size());
    }

    protected void taskOperation(GetDeploymentStatsAction.Request request, TrainedModelDeploymentTask task, ActionListener<AllocationStats> listener) {
        Optional<ModelStats> stats = task.modelStats();
        ArrayList<AllocationStats.NodeStats> nodeStats = new ArrayList<AllocationStats.NodeStats>();
        if (stats.isPresent()) {
            nodeStats.add(AllocationStats.NodeStats.forStartedState((DiscoveryNode)this.clusterService.localNode(), (long)stats.get().timingStats().getCount(), (Double)stats.get().timingStats().getAverage(), (int)stats.get().pendingCount(), (int)stats.get().errorCount(), (int)stats.get().rejectedExecutionCount(), (int)stats.get().timeoutCount(), (Instant)stats.get().lastUsed(), (Instant)stats.get().startTime(), (Integer)stats.get().inferenceThreads(), (Integer)stats.get().modelThreads()));
        } else {
            nodeStats.add(AllocationStats.NodeStats.forNotStartedState((DiscoveryNode)this.clusterService.localNode(), (RoutingState)RoutingState.STOPPED, (String)""));
        }
        listener.onResponse((Object)new AllocationStats(task.getModelId(), Integer.valueOf(task.getParams().getInferenceThreads()), Integer.valueOf(task.getParams().getModelThreads()), Integer.valueOf(task.getParams().getQueueCapacity()), TrainedModelAllocationMetadata.fromState(this.clusterService.state()).getModelAllocation(task.getModelId()).getStartTime(), nodeStats));
    }
}

