/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.ml.action;

import java.time.Instant;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.TreeMap;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.FailedNodeException;
import org.elasticsearch.action.TaskOperationFailure;
import org.elasticsearch.action.support.ActionFilters;
import org.elasticsearch.action.support.tasks.BaseTasksRequest;
import org.elasticsearch.action.support.tasks.TransportTasksAction;
import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.cluster.node.DiscoveryNodes;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.inject.Inject;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.tasks.CancellableTask;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.transport.TransportService;
import org.elasticsearch.xpack.core.action.util.ExpandedIdsMatcher;
import org.elasticsearch.xpack.core.ml.action.GetDeploymentStatsAction;
import org.elasticsearch.xpack.core.ml.inference.assignment.AllocationStatus;
import org.elasticsearch.xpack.core.ml.inference.assignment.AssignmentState;
import org.elasticsearch.xpack.core.ml.inference.assignment.AssignmentStats;
import org.elasticsearch.xpack.core.ml.inference.assignment.RoutingInfo;
import org.elasticsearch.xpack.core.ml.inference.assignment.RoutingState;
import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignment;
import org.elasticsearch.xpack.ml.inference.assignment.TrainedModelAssignmentMetadata;
import org.elasticsearch.xpack.ml.inference.deployment.ModelStats;
import org.elasticsearch.xpack.ml.inference.deployment.TrainedModelDeploymentTask;

public class TransportGetDeploymentStatsAction
extends TransportTasksAction<TrainedModelDeploymentTask, GetDeploymentStatsAction.Request, GetDeploymentStatsAction.Response, AssignmentStats> {
    @Inject
    public TransportGetDeploymentStatsAction(TransportService transportService, ActionFilters actionFilters, ClusterService clusterService) {
        super("cluster:internal/xpack/ml/trained_models/deployments/stats/get", clusterService, transportService, actionFilters, GetDeploymentStatsAction.Request::new, GetDeploymentStatsAction.Response::new, AssignmentStats::new, "management");
    }

    protected GetDeploymentStatsAction.Response newResponse(GetDeploymentStatsAction.Request request, List<AssignmentStats> taskResponse, List<TaskOperationFailure> taskOperationFailures, List<FailedNodeException> failedNodeExceptions) {
        TreeMap mergedNodeStatsByDeployment = taskResponse.stream().collect(Collectors.toMap(AssignmentStats::getDeploymentId, Function.identity(), (l, r) -> {
            l.getNodeStats().addAll(r.getNodeStats());
            return l;
        }, TreeMap::new));
        ArrayList bunchedAndSorted = new ArrayList(mergedNodeStatsByDeployment.values());
        return new GetDeploymentStatsAction.Response(taskOperationFailures, failedNodeExceptions, bunchedAndSorted, (long)bunchedAndSorted.size());
    }

    protected void doExecute(Task task, GetDeploymentStatsAction.Request request, ActionListener<GetDeploymentStatsAction.Response> listener) {
        ClusterState clusterState = this.clusterService.state();
        TrainedModelAssignmentMetadata assignment = TrainedModelAssignmentMetadata.fromState(clusterState);
        String[] tokenizedRequestIds = Strings.tokenizeToStringArray((String)request.getDeploymentId(), (String)",");
        ExpandedIdsMatcher.SimpleIdsMatcher idsMatcher = new ExpandedIdsMatcher.SimpleIdsMatcher(tokenizedRequestIds);
        ArrayList<String> matchedIds = new ArrayList<String>();
        HashSet<String> taskNodes = new HashSet<String>();
        HashMap<TrainedModelAssignment, Map<String, RoutingInfo>> assignmentNonStartedRoutes = new HashMap<TrainedModelAssignment, Map<String, RoutingInfo>>();
        for (Map.Entry<String, TrainedModelAssignment> assignmentEntry : assignment.allAssignments().entrySet()) {
            String deploymentId = assignmentEntry.getKey();
            if (!idsMatcher.idMatches(deploymentId)) continue;
            matchedIds.add(deploymentId);
            taskNodes.addAll(Arrays.asList(assignmentEntry.getValue().getStartedNodes()));
            Map<String, RoutingInfo> routings = assignmentEntry.getValue().getNodeRoutingTable().entrySet().stream().filter(routingEntry -> !RoutingState.STARTED.equals((Object)((RoutingInfo)routingEntry.getValue()).getState())).collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
            assignmentNonStartedRoutes.put(assignmentEntry.getValue(), routings);
        }
        if (matchedIds.isEmpty()) {
            listener.onResponse((Object)new GetDeploymentStatsAction.Response(Collections.emptyList(), Collections.emptyList(), Collections.emptyList(), 0L));
            return;
        }
        request.setNodes((String[])taskNodes.toArray(String[]::new));
        request.setExpandedIds(matchedIds);
        ActionListener addFailedListener = listener.safeMap(response -> {
            GetDeploymentStatsAction.Response updatedResponse = TransportGetDeploymentStatsAction.addFailedRoutes(response, assignmentNonStartedRoutes, clusterState.nodes());
            for (AssignmentStats stats : updatedResponse.getStats().results()) {
                TrainedModelAssignment trainedModelAssignment = assignment.getDeploymentAssignment(stats.getDeploymentId());
                if (trainedModelAssignment == null) continue;
                stats.setState(trainedModelAssignment.getAssignmentState()).setReason((String)trainedModelAssignment.getReason().orElse(null));
                if (!trainedModelAssignment.getNodeRoutingTable().isEmpty() && trainedModelAssignment.getNodeRoutingTable().values().stream().allMatch(ri -> ri.getState().equals((Object)RoutingState.FAILED))) {
                    stats.setState(AssignmentState.FAILED);
                    if (stats.getReason() == null) {
                        stats.setReason("All node routes are failed; see node route reason for details");
                    }
                }
                if (!trainedModelAssignment.getAssignmentState().isAnyOf(new AssignmentState[]{AssignmentState.STARTED, AssignmentState.STARTING})) continue;
                stats.setAllocationStatus((AllocationStatus)trainedModelAssignment.calculateAllocationStatus().orElse(null));
            }
            return updatedResponse;
        });
        super.doExecute(task, (BaseTasksRequest)request, addFailedListener);
    }

    static GetDeploymentStatsAction.Response addFailedRoutes(GetDeploymentStatsAction.Response tasksResponse, Map<TrainedModelAssignment, Map<String, RoutingInfo>> assignmentNonStartedRoutes, DiscoveryNodes nodes) {
        Map deploymentToAssignmentWithNonStartedRoutes = assignmentNonStartedRoutes.keySet().stream().collect(Collectors.toMap(TrainedModelAssignment::getDeploymentId, Function.identity()));
        ArrayList<AssignmentStats> updatedAssignmentStats = new ArrayList<AssignmentStats>();
        for (AssignmentStats assignmentStats : tasksResponse.getStats().results()) {
            if (deploymentToAssignmentWithNonStartedRoutes.containsKey(assignmentStats.getDeploymentId())) {
                Map<String, RoutingInfo> nodeToRoutingStates = assignmentNonStartedRoutes.get(deploymentToAssignmentWithNonStartedRoutes.get(assignmentStats.getDeploymentId()));
                ArrayList<AssignmentStats.NodeStats> updatedNodeStats = new ArrayList<AssignmentStats.NodeStats>();
                HashSet<String> visitedNodes = new HashSet<String>();
                for (AssignmentStats.NodeStats nodeStats : assignmentStats.getNodeStats()) {
                    if (nodeToRoutingStates.containsKey(nodeStats.getNode().getId())) {
                        RoutingInfo routingInfo = nodeToRoutingStates.get(nodeStats.getNode().getId());
                        updatedNodeStats.add(AssignmentStats.NodeStats.forNotStartedState((DiscoveryNode)nodeStats.getNode(), (RoutingState)routingInfo.getState(), (String)routingInfo.getReason()));
                    } else {
                        updatedNodeStats.add(nodeStats);
                    }
                    visitedNodes.add(nodeStats.getNode().getId());
                }
                for (Map.Entry entry : nodeToRoutingStates.entrySet()) {
                    if (visitedNodes.contains(entry.getKey())) continue;
                    updatedNodeStats.add(AssignmentStats.NodeStats.forNotStartedState((DiscoveryNode)nodes.get((String)entry.getKey()), (RoutingState)((RoutingInfo)entry.getValue()).getState(), (String)((RoutingInfo)entry.getValue()).getReason()));
                }
                updatedNodeStats.sort(Comparator.comparing(n -> n.getNode().getId()));
                updatedAssignmentStats.add(new AssignmentStats(assignmentStats.getDeploymentId(), assignmentStats.getModelId(), assignmentStats.getThreadsPerAllocation(), assignmentStats.getNumberOfAllocations(), assignmentStats.getQueueCapacity(), assignmentStats.getCacheSize(), assignmentStats.getStartTime(), updatedNodeStats, assignmentStats.getPriority()));
                continue;
            }
            updatedAssignmentStats.add(assignmentStats);
        }
        for (Map.Entry entry : assignmentNonStartedRoutes.entrySet()) {
            TrainedModelAssignment assignment = (TrainedModelAssignment)entry.getKey();
            String deploymentId = assignment.getDeploymentId();
            if (tasksResponse.getStats().results().stream().anyMatch(e -> deploymentId.equals(e.getDeploymentId()))) continue;
            ArrayList<AssignmentStats.NodeStats> nodeStats = new ArrayList<AssignmentStats.NodeStats>();
            for (Map.Entry entry2 : ((Map)entry.getValue()).entrySet()) {
                nodeStats.add(AssignmentStats.NodeStats.forNotStartedState((DiscoveryNode)nodes.get((String)entry2.getKey()), (RoutingState)((RoutingInfo)entry2.getValue()).getState(), (String)((RoutingInfo)entry2.getValue()).getReason()));
            }
            nodeStats.sort(Comparator.comparing(n -> n.getNode().getId()));
            updatedAssignmentStats.add(new AssignmentStats(deploymentId, assignment.getModelId(), Integer.valueOf(assignment.getTaskParams().getThreadsPerAllocation()), Integer.valueOf(assignment.getTaskParams().getNumberOfAllocations()), Integer.valueOf(assignment.getTaskParams().getQueueCapacity()), (ByteSizeValue)assignment.getTaskParams().getCacheSize().orElse(null), assignment.getStartTime(), nodeStats, assignment.getTaskParams().getPriority()));
        }
        updatedAssignmentStats.sort(Comparator.comparing(AssignmentStats::getDeploymentId));
        return new GetDeploymentStatsAction.Response(tasksResponse.getTaskFailures(), tasksResponse.getNodeFailures(), updatedAssignmentStats, (long)updatedAssignmentStats.size());
    }

    protected void taskOperation(CancellableTask actionTask, GetDeploymentStatsAction.Request request, TrainedModelDeploymentTask task, ActionListener<AssignmentStats> listener) {
        Optional<ModelStats> stats = task.modelStats();
        ArrayList<AssignmentStats.NodeStats> nodeStats = new ArrayList<AssignmentStats.NodeStats>();
        if (stats.isPresent()) {
            ModelStats presentValue = stats.get();
            nodeStats.add(AssignmentStats.NodeStats.forStartedState((DiscoveryNode)this.clusterService.localNode(), (long)presentValue.inferenceCount(), (Double)presentValue.averageInferenceTime(), (Double)presentValue.averageInferenceTimeNoCacheHits(), (int)presentValue.pendingCount(), (int)presentValue.errorCount(), (long)presentValue.cacheHitCount(), (int)presentValue.rejectedExecutionCount(), (int)presentValue.timeoutCount(), (Instant)presentValue.lastUsed(), (Instant)presentValue.startTime(), (Integer)presentValue.threadsPerAllocation(), (Integer)presentValue.numberOfAllocations(), (long)presentValue.peakThroughput(), (long)presentValue.throughputLastPeriod(), (Double)presentValue.avgInferenceTimeLastPeriod(), (long)presentValue.cacheHitCountLastPeriod()));
        } else {
            nodeStats.add(AssignmentStats.NodeStats.forNotStartedState((DiscoveryNode)this.clusterService.localNode(), (RoutingState)RoutingState.STOPPED, (String)""));
        }
        TrainedModelAssignment assignment = TrainedModelAssignmentMetadata.fromState(this.clusterService.state()).getDeploymentAssignment(task.getDeploymentId());
        listener.onResponse((Object)new AssignmentStats(task.getDeploymentId(), task.getParams().getModelId(), Integer.valueOf(task.getParams().getThreadsPerAllocation()), Integer.valueOf(assignment == null ? task.getParams().getNumberOfAllocations() : assignment.getTaskParams().getNumberOfAllocations()), Integer.valueOf(task.getParams().getQueueCapacity()), (ByteSizeValue)task.getParams().getCacheSize().orElse(null), TrainedModelAssignmentMetadata.fromState(this.clusterService.state()).getDeploymentAssignment(task.getDeploymentId()).getStartTime(), nodeStats, task.getParams().getPriority()));
    }
}

