/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.ml.action;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.ActionRequest;
import org.elasticsearch.action.ActionType;
import org.elasticsearch.action.admin.cluster.node.stats.NodeStats;
import org.elasticsearch.action.admin.cluster.node.stats.NodesStatsAction;
import org.elasticsearch.action.admin.cluster.node.stats.NodesStatsRequest;
import org.elasticsearch.action.admin.cluster.node.stats.NodesStatsResponse;
import org.elasticsearch.action.search.SearchAction;
import org.elasticsearch.action.search.SearchRequest;
import org.elasticsearch.action.support.ActionFilters;
import org.elasticsearch.action.support.HandledTransportAction;
import org.elasticsearch.client.internal.Client;
import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.document.DocumentField;
import org.elasticsearch.common.inject.Inject;
import org.elasticsearch.common.metrics.CounterMetric;
import org.elasticsearch.common.util.set.Sets;
import org.elasticsearch.core.Tuple;
import org.elasticsearch.index.query.BoolQueryBuilder;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.ingest.IngestMetadata;
import org.elasticsearch.ingest.IngestService;
import org.elasticsearch.ingest.IngestStats;
import org.elasticsearch.ingest.Pipeline;
import org.elasticsearch.script.ScriptService;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.sort.SortOrder;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.transport.TransportService;
import org.elasticsearch.xpack.core.ClientHelper;
import org.elasticsearch.xpack.core.ml.action.GetDeploymentStatsAction;
import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction;
import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsStatsAction;
import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelType;
import org.elasticsearch.xpack.core.ml.inference.allocation.AllocationStats;
import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceStats;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModelSizeStats;
import org.elasticsearch.xpack.ml.inference.ModelAliasMetadata;
import org.elasticsearch.xpack.ml.inference.ingest.InferenceProcessor;
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelDefinitionDoc;
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;

public class TransportGetTrainedModelsStatsAction
extends HandledTransportAction<GetTrainedModelsStatsAction.Request, GetTrainedModelsStatsAction.Response> {
    /** Client used to run the nodes-stats, deployment-stats and search requests under the "ml" origin. */
    private final Client client;
    /** Source of the current cluster state (model aliases, ingest nodes, ingest metadata). */
    private final ClusterService clusterService;
    /** Supplies the processor factories and script service needed to instantiate pipelines. */
    private final IngestService ingestService;
    /** Used to expand model ids and to load inference stats and model configurations. */
    private final TrainedModelProvider trainedModelProvider;

    /**
     * Registers this handler under the action name
     * {@code cluster:monitor/xpack/ml/inference/stats/get} and stores its collaborators.
     *
     * @param transportService     passed to the parent transport action
     * @param actionFilters        passed to the parent transport action
     * @param clusterService       provider of the current cluster state
     * @param ingestService        provider of processor factories and the script service
     * @param trainedModelProvider accessor for trained model ids, configs and inference stats
     * @param client               client used for the follow-up requests this action issues
     */
    @Inject
    public TransportGetTrainedModelsStatsAction(
        TransportService transportService,
        ActionFilters actionFilters,
        ClusterService clusterService,
        IngestService ingestService,
        TrainedModelProvider trainedModelProvider,
        Client client
    ) {
        super("cluster:monitor/xpack/ml/inference/stats/get", transportService, actionFilters, GetTrainedModelsStatsAction.Request::new);
        this.trainedModelProvider = trainedModelProvider;
        this.ingestService = ingestService;
        this.clusterService = clusterService;
        this.client = client;
    }

    /**
     * Collects the statistics for the requested trained models and delivers them to {@code listener}.
     *
     * <p>The work is a chain of asynchronous steps, each started by the previous step's listener:
     * <ol>
     *   <li>expand the requested resource id into model ids and their aliases;</li>
     *   <li>fetch ingest stats from the ingest nodes and attribute them to models via the pipelines
     *       that reference those models (or their aliases);</li>
     *   <li>load the stored inference stats for the expanded model ids;</li>
     *   <li>fetch deployment stats for the requested resource id;</li>
     *   <li>compute model size stats and build the response.</li>
     * </ol>
     * A failure in any step is forwarded unchanged to {@code listener}.
     *
     * @param task     the task this request runs under (not used directly here)
     * @param request  the stats request (resource id, paging, allow-no-resources flag)
     * @param listener receives the assembled response or the first failure
     */
    @Override
    protected void doExecute(
        Task task,
        GetTrainedModelsStatsAction.Request request,
        ActionListener<GetTrainedModelsStatsAction.Response> listener
    ) {
        final ModelAliasMetadata currentMetadata = ModelAliasMetadata.fromState(this.clusterService.state());
        final GetTrainedModelsStatsAction.Response.Builder responseBuilder = new GetTrainedModelsStatsAction.Response.Builder();

        // Step 5: size stats are the last piece; once set, the response is complete.
        ActionListener<Map<String, TrainedModelSizeStats>> modelSizeStatsListener = ActionListener.wrap(modelSizeStatsByModelId -> {
            responseBuilder.setModelSizeStatsByModelId(modelSizeStatsByModelId);
            listener.onResponse(responseBuilder.build());
        }, listener::onFailure);

        // Step 4: index deployment stats by model id, then compute the size stats.
        ActionListener<GetDeploymentStatsAction.Response> deploymentStatsListener = ActionListener.wrap(deploymentStats -> {
            responseBuilder.setDeploymentStatsByModelId(
                deploymentStats.getStats()
                    .results()
                    .stream()
                    .collect(Collectors.toMap(AllocationStats::getModelId, Function.identity()))
            );
            this.modelSizeStats(responseBuilder.getExpandedIdsWithAliases(), request.isAllowNoResources(), modelSizeStatsListener);
        }, listener::onFailure);

        // Step 3: index inference stats by model id, then request the deployment stats.
        ActionListener<List<InferenceStats>> inferenceStatsListener = ActionListener.wrap(inferenceStats -> {
            responseBuilder.setInferenceStatsByModelId(
                inferenceStats.stream().collect(Collectors.toMap(InferenceStats::getModelId, Function.identity()))
            );
            ClientHelper.executeAsyncWithOrigin(
                this.client,
                "ml",
                GetDeploymentStatsAction.INSTANCE,
                new GetDeploymentStatsAction.Request(request.getResourceId()),
                deploymentStatsListener
            );
        }, listener::onFailure);

        // Step 2: map node-level ingest stats onto models, then load the inference stats.
        ActionListener<NodesStatsResponse> nodesStatsListener = ActionListener.wrap(nodesStatsResponse -> {
            // A pipeline may reference a model by its id or by any of its aliases, so collect both.
            Set<String> allPossiblePipelineReferences = responseBuilder.getExpandedIdsWithAliases()
                .entrySet()
                .stream()
                .flatMap(entry -> Stream.concat(entry.getValue().stream(), Stream.of(entry.getKey())))
                .collect(Collectors.toSet());
            Map<String, Set<String>> pipelineIdsByModelIdsOrAliases = TransportGetTrainedModelsStatsAction.pipelineIdsByModelIdsOrAliases(
                this.clusterService.state(),
                this.ingestService,
                allPossiblePipelineReferences
            );
            Map<String, IngestStats> modelIdIngestStats = TransportGetTrainedModelsStatsAction.inferenceIngestStatsByModelId(
                nodesStatsResponse,
                currentMetadata,
                pipelineIdsByModelIdsOrAliases
            );
            responseBuilder.setIngestStatsByModelId(modelIdIngestStats);
            this.trainedModelProvider.getInferenceStats(
                responseBuilder.getExpandedIdsWithAliases().keySet().toArray(new String[0]),
                inferenceStatsListener
            );
        }, listener::onFailure);

        // Step 1: record the expanded ids and total count, then ask the ingest nodes for ingest metrics only.
        ActionListener<Tuple<Long, Map<String, Set<String>>>> idsListener = ActionListener.wrap(tuple -> {
            responseBuilder.setExpandedIdsWithAliases(tuple.v2()).setTotalModelCount(tuple.v1());
            String[] ingestNodes = TransportGetTrainedModelsStatsAction.ingestNodes(this.clusterService.state());
            NodesStatsRequest nodesStatsRequest = new NodesStatsRequest(ingestNodes).clear()
                .addMetric(NodesStatsRequest.Metric.INGEST.metricName());
            ClientHelper.executeAsyncWithOrigin(this.client, "ml", NodesStatsAction.INSTANCE, nodesStatsRequest, nodesStatsListener);
        }, listener::onFailure);

        this.trainedModelProvider.expandIds(
            request.getResourceId(),
            request.isAllowNoResources(),
            request.getPageParams(),
            Collections.emptySet(),
            currentMetadata,
            idsListener
        );
    }

    /**
     * Computes size statistics for every model in {@code expandedIdsWithAliases}.
     *
     * <p>The model configurations are loaded first. For PyTorch models the total definition length is
     * looked up separately and the required memory is estimated from it (both are 0 when no length
     * was found). For every other model type the configured model size is used and the required
     * memory is reported as 0.
     *
     * @param expandedIdsWithAliases model ids mapped to their aliases
     * @param allowNoResources       forwarded to the model lookup
     * @param listener               receives the size stats keyed by model id, or the failure
     */
    private void modelSizeStats(
        Map<String, Set<String>> expandedIdsWithAliases,
        boolean allowNoResources,
        ActionListener<Map<String, TrainedModelSizeStats>> listener
    ) {
        ActionListener<List<TrainedModelConfig>> modelsListener = ActionListener.wrap(models -> {
            List<String> pytorchModelIds = models.stream()
                .filter(m -> m.getModelType() == TrainedModelType.PYTORCH)
                .map(TrainedModelConfig::getModelId)
                .toList();
            this.definitionLengths(pytorchModelIds, ActionListener.wrap(pytorchTotalDefinitionLengthsByModelId -> {
                Map<String, TrainedModelSizeStats> modelSizeStatsByModelId = new HashMap<>();
                for (TrainedModelConfig model : models) {
                    TrainedModelSizeStats sizeStats;
                    if (model.getModelType() == TrainedModelType.PYTORCH) {
                        long totalDefinitionLength = pytorchTotalDefinitionLengthsByModelId.getOrDefault(model.getModelId(), 0L);
                        // Only estimate memory when a definition length is actually known.
                        long requiredMemory = totalDefinitionLength > 0L
                            ? StartTrainedModelDeploymentAction.estimateMemoryUsageBytes(totalDefinitionLength)
                            : 0L;
                        sizeStats = new TrainedModelSizeStats(totalDefinitionLength, requiredMemory);
                    } else {
                        sizeStats = new TrainedModelSizeStats(model.getModelSize(), 0L);
                    }
                    modelSizeStatsByModelId.put(model.getModelId(), sizeStats);
                }
                listener.onResponse(modelSizeStatsByModelId);
            }, listener::onFailure));
        }, listener::onFailure);

        this.trainedModelProvider.getTrainedModels(
            expandedIdsWithAliases,
            GetTrainedModelsAction.Includes.empty(),
            allowNoResources,
            modelsListener
        );
    }

    /**
     * Looks up the total definition length of the given models.
     *
     * <p>Searches {@code .ml-inference-*} for the first definition document (doc number 0) of each
     * model and reads the model id and total definition length from doc values. Hits lacking either
     * field, or carrying a value of an unexpected type, are skipped.
     *
     * @param modelIds ids of the models to look up; when empty no search is issued
     * @param listener receives a mutable map of model id to total definition length, or the failure
     */
    private void definitionLengths(List<String> modelIds, ActionListener<Map<String, Long>> listener) {
        if (modelIds.isEmpty()) {
            // A terms filter over no ids cannot match anything, so skip the round trip.
            listener.onResponse(new HashMap<>());
            return;
        }

        BoolQueryBuilder query = QueryBuilders.boolQuery()
            .filter(QueryBuilders.termQuery(InferenceIndexConstants.DOC_TYPE.getPreferredName(), "trained_model_definition_doc"))
            .filter(QueryBuilders.termsQuery(TrainedModelConfig.MODEL_ID.getPreferredName(), modelIds))
            .filter(QueryBuilders.termQuery(TrainedModelDefinitionDoc.DOC_NUM.getPreferredName(), 0));

        // NOTE(review): no explicit size is set on this search; confirm the default page size is
        // large enough for the number of models that can be requested at once.
        SearchRequest searchRequest = this.client.prepareSearch(".ml-inference-*")
            .setQuery(QueryBuilders.constantScoreQuery(query))
            .setFetchSource(false)
            .addDocValueField(TrainedModelConfig.MODEL_ID.getPreferredName())
            .addDocValueField(TrainedModelDefinitionDoc.TOTAL_DEFINITION_LENGTH.getPreferredName())
            .addSort("_index", SortOrder.DESC)
            .request();

        ClientHelper.executeAsyncWithOrigin(this.client, "ml", SearchAction.INSTANCE, searchRequest, ActionListener.wrap(searchResponse -> {
            Map<String, Long> totalDefinitionLengthByModelId = new HashMap<>();
            // NOTE(review): hits are sorted by _index descending and a later hit overwrites an earlier
            // one for the same model id — confirm which index is meant to win.
            for (SearchHit hit : searchResponse.getHits().getHits()) {
                DocumentField modelIdField = hit.field(TrainedModelConfig.MODEL_ID.getPreferredName());
                if (modelIdField != null && modelIdField.getValue() instanceof String modelId) {
                    DocumentField totalDefinitionLengthField = hit.field(
                        TrainedModelDefinitionDoc.TOTAL_DEFINITION_LENGTH.getPreferredName()
                    );
                    if (totalDefinitionLengthField != null && totalDefinitionLengthField.getValue() instanceof Long totalDefinitionLength) {
                        totalDefinitionLengthByModelId.put(modelId, totalDefinitionLength);
                    }
                }
            }
            listener.onResponse(totalDefinitionLengthByModelId);
        }, listener::onFailure));
    }

    /**
     * Builds the ingest statistics attributable to each model.
     *
     * <p>Keys of {@code modelIdToPipelineId} may be model ids or aliases. Each key is first resolved
     * through {@code currentMetadata}; when the metadata returns no model id for it, the key itself
     * is used. Pipeline sets that resolve to the same model id are unioned. For each resulting model
     * the per-node stats are filtered down to its pipelines and merged across all nodes in
     * {@code response}.
     *
     * @param response            node stats carrying the ingest metrics
     * @param currentMetadata     alias metadata used to resolve aliases to model ids
     * @param modelIdToPipelineId pipeline ids keyed by the model id or alias they reference
     * @return merged ingest stats keyed by resolved model id
     */
    static Map<String, IngestStats> inferenceIngestStatsByModelId(
        NodesStatsResponse response,
        ModelAliasMetadata currentMetadata,
        Map<String, Set<String>> modelIdToPipelineId
    ) {
        Map<String, Set<String>> trueModelIdToPipelines = modelIdToPipelineId.entrySet().stream().collect(Collectors.toMap(entry -> {
            String maybeModelId = currentMetadata.getModelId(entry.getKey());
            return maybeModelId == null ? entry.getKey() : maybeModelId;
        }, Map.Entry::getValue, Sets::union));

        Map<String, IngestStats> ingestStatsMap = new HashMap<>();
        for (Map.Entry<String, Set<String>> modelAndPipelines : trueModelIdToPipelines.entrySet()) {
            Set<String> pipelineIds = modelAndPipelines.getValue();
            List<IngestStats> collectedStats = response.getNodes()
                .stream()
                .map(nodeStats -> TransportGetTrainedModelsStatsAction.ingestStatsForPipelineIds(nodeStats, pipelineIds))
                .collect(Collectors.toList());
            ingestStatsMap.put(modelAndPipelines.getKey(), TransportGetTrainedModelsStatsAction.mergeStats(collectedStats));
        }
        return ingestStatsMap;
    }

    /**
     * Returns the keys of the cluster state's ingest-node map as an array.
     *
     * @param clusterState the cluster state to read the ingest nodes from
     * @return the ingest-node map keys
     */
    static String[] ingestNodes(ClusterState clusterState) {
        Set<String> ingestNodeIds = clusterState.nodes().getIngestNodes().keySet();
        return ingestNodeIds.toArray(String[]::new);
    }

    /**
     * Finds, for each given model id or alias, the ingest pipelines containing an inference
     * processor that references it.
     *
     * <p>Every pipeline in the cluster state's ingest metadata is instantiated from its stored
     * configuration so that its processors can be inspected. Only keys that are actually referenced
     * by some pipeline appear in the result.
     *
     * @param state         cluster state holding the ingest metadata
     * @param ingestService supplies the processor factories and script service used to build pipelines
     * @param modelIds      model ids and aliases to look for
     * @return pipeline ids (in encounter order) keyed by the referenced model id or alias; empty when
     *         the cluster state has no ingest metadata
     * @throws ElasticsearchException if a pipeline cannot be instantiated or inspected
     */
    static Map<String, Set<String>> pipelineIdsByModelIdsOrAliases(ClusterState state, IngestService ingestService, Set<String> modelIds) {
        Map<String, Set<String>> pipelineIdsByModelIds = new HashMap<>();
        IngestMetadata ingestMetadata = state.metadata().custom("ingest");
        if (ingestMetadata == null) {
            return pipelineIdsByModelIds;
        }

        ingestMetadata.getPipelines().forEach((pipelineId, pipelineConfiguration) -> {
            try {
                Pipeline pipeline = Pipeline.create(
                    pipelineId,
                    pipelineConfiguration.getConfigAsMap(),
                    ingestService.getProcessorFactories(),
                    ingestService.getScriptService()
                );
                pipeline.getProcessors().forEach(processor -> {
                    if (processor instanceof InferenceProcessor inferenceProcessor && modelIds.contains(inferenceProcessor.getModelId())) {
                        pipelineIdsByModelIds.computeIfAbsent(inferenceProcessor.getModelId(), m -> new LinkedHashSet<>()).add(pipelineId);
                    }
                });
            } catch (Exception ex) {
                throw new ElasticsearchException("unexpected failure gathering pipeline information", ex);
            }
        });
        return pipelineIdsByModelIds;
    }

    /**
     * Restricts one node's ingest stats to the given pipelines.
     *
     * <p>Pipeline and processor stats for other pipelines are dropped, and the total stats are
     * recomputed as the sum over the retained pipelines.
     *
     * @param nodeStats   stats of a single node; its ingest stats are read
     * @param pipelineIds ids of the pipelines to keep
     * @return ingest stats covering only {@code pipelineIds}
     */
    static IngestStats ingestStatsForPipelineIds(NodeStats nodeStats, Set<String> pipelineIds) {
        IngestStats fullNodeStats = nodeStats.getIngestStats();

        // Copy before filtering so the node's own processor-stats map is left untouched.
        Map<String, List<IngestStats.ProcessorStat>> filteredProcessorStats = new HashMap<>(fullNodeStats.getProcessorStats());
        filteredProcessorStats.keySet().retainAll(pipelineIds);

        List<IngestStats.PipelineStat> filteredPipelineStats = fullNodeStats.getPipelineStats()
            .stream()
            .filter(pipelineStat -> pipelineIds.contains(pipelineStat.getPipelineId()))
            .collect(Collectors.toList());

        IngestStatsAccumulator totals = new IngestStatsAccumulator();
        for (IngestStats.PipelineStat pipelineStat : filteredPipelineStats) {
            totals.inc(pipelineStat.getStats());
        }
        return new IngestStats(totals.build(), filteredPipelineStats, filteredProcessorStats);
    }

    /**
     * Merges several ingest stats (for example one per node) into a single one.
     *
     * <p>Total stats are summed. Pipeline stats are summed per pipeline id. Processor stats are
     * summed per pipeline id and processor name; the processor type recorded is the one seen on the
     * first occurrence of that name. First-seen order is preserved for pipelines and processors.
     *
     * @param ingestStatsList the stats to merge
     * @return the merged stats
     */
    private static IngestStats mergeStats(List<IngestStats> ingestStatsList) {
        Map<String, IngestStatsAccumulator> pipelineStatsAcc = new LinkedHashMap<>(ingestStatsList.size());
        Map<String, Map<String, IngestStatsAccumulator>> processorStatsAcc = new LinkedHashMap<>(ingestStatsList.size());
        IngestStatsAccumulator totalStats = new IngestStatsAccumulator();

        for (IngestStats ingestStats : ingestStatsList) {
            for (IngestStats.PipelineStat pipelineStat : ingestStats.getPipelineStats()) {
                pipelineStatsAcc.computeIfAbsent(pipelineStat.getPipelineId(), p -> new IngestStatsAccumulator())
                    .inc(pipelineStat.getStats());
            }
            ingestStats.getProcessorStats().forEach((pipelineId, processorStat) -> {
                Map<String, IngestStatsAccumulator> processorAcc = processorStatsAcc.computeIfAbsent(pipelineId, k -> new LinkedHashMap<>());
                for (IngestStats.ProcessorStat p : processorStat) {
                    processorAcc.computeIfAbsent(p.getName(), k -> new IngestStatsAccumulator(p.getType())).inc(p.getStats());
                }
            });
            totalStats.inc(ingestStats.getTotalStats());
        }

        List<IngestStats.PipelineStat> pipelineStatList = new ArrayList<>(pipelineStatsAcc.size());
        pipelineStatsAcc.forEach(
            (pipelineId, accumulator) -> pipelineStatList.add(new IngestStats.PipelineStat(pipelineId, accumulator.build()))
        );

        Map<String, List<IngestStats.ProcessorStat>> processorStatList = new LinkedHashMap<>(processorStatsAcc.size());
        processorStatsAcc.forEach((pipelineId, accumulatorMap) -> {
            List<IngestStats.ProcessorStat> processorStats = new ArrayList<>(accumulatorMap.size());
            accumulatorMap.forEach(
                (processorName, acc) -> processorStats.add(new IngestStats.ProcessorStat(processorName, acc.type, acc.build()))
            );
            processorStatList.put(pipelineId, processorStats);
        });

        return new IngestStats(totalStats.build(), pipelineStatList, processorStatList);
    }

    /**
     * Running sum of {@link IngestStats.Stats} values, optionally tagged with a processor type.
     */
    private static class IngestStatsAccumulator {
        private final CounterMetric ingestCount = new CounterMetric();
        private final CounterMetric ingestTimeInMillis = new CounterMetric();
        private final CounterMetric ingestCurrent = new CounterMetric();
        private final CounterMetric ingestFailedCount = new CounterMetric();
        /** Processor type this accumulator was created for, or {@code null} when none was given. */
        final String type;

        /** Creates an accumulator with no processor type. */
        IngestStatsAccumulator() {
            this(null);
        }

        /**
         * Creates an accumulator tagged with a processor type.
         *
         * @param type the processor type to remember
         */
        IngestStatsAccumulator(String type) {
            this.type = type;
        }

        /**
         * Adds the four counters of {@code s} to the running sums.
         *
         * @param s the stats to add
         * @return this accumulator, to allow chaining
         */
        IngestStatsAccumulator inc(IngestStats.Stats s) {
            this.ingestCount.inc(s.getIngestCount());
            this.ingestTimeInMillis.inc(s.getIngestTimeInMillis());
            this.ingestCurrent.inc(s.getIngestCurrent());
            this.ingestFailedCount.inc(s.getIngestFailedCount());
            return this;
        }

        /**
         * Snapshots the current sums.
         *
         * @return stats holding the accumulated count, time, current and failed values
         */
        IngestStats.Stats build() {
            return new IngestStats.Stats(
                this.ingestCount.count(),
                this.ingestTimeInMillis.count(),
                this.ingestCurrent.count(),
                this.ingestFailedCount.count()
            );
        }
    }
}

