/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.ml;

import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Locale;
import java.util.LongSummaryStatistics;
import java.util.Map;
import java.util.Set;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import org.apache.lucene.util.Counter;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.ActionRequest;
import org.elasticsearch.action.ActionType;
import org.elasticsearch.action.admin.cluster.node.stats.NodeStats;
import org.elasticsearch.action.admin.cluster.node.stats.NodesStatsAction;
import org.elasticsearch.action.admin.cluster.node.stats.NodesStatsRequest;
import org.elasticsearch.action.admin.cluster.node.stats.NodesStatsResponse;
import org.elasticsearch.action.support.ActionFilters;
import org.elasticsearch.client.Client;
import org.elasticsearch.client.OriginSettingClient;
import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.inject.Inject;
import org.elasticsearch.env.Environment;
import org.elasticsearch.ingest.IngestStats;
import org.elasticsearch.license.XPackLicenseState;
import org.elasticsearch.protocol.xpack.XPackUsageRequest;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.TransportService;
import org.elasticsearch.xpack.core.XPackFeatureSet;
import org.elasticsearch.xpack.core.XPackSettings;
import org.elasticsearch.xpack.core.action.XPackUsageFeatureAction;
import org.elasticsearch.xpack.core.action.XPackUsageFeatureResponse;
import org.elasticsearch.xpack.core.action.XPackUsageFeatureTransportAction;
import org.elasticsearch.xpack.core.action.util.PageParams;
import org.elasticsearch.xpack.core.action.util.QueryPage;
import org.elasticsearch.xpack.core.ml.MachineLearningFeatureSetUsage;
import org.elasticsearch.xpack.core.ml.MachineLearningField;
import org.elasticsearch.xpack.core.ml.action.GetDataFrameAnalyticsAction;
import org.elasticsearch.xpack.core.ml.action.GetDataFrameAnalyticsStatsAction;
import org.elasticsearch.xpack.core.ml.action.GetDatafeedsStatsAction;
import org.elasticsearch.xpack.core.ml.action.GetDeploymentStatsAction;
import org.elasticsearch.xpack.core.ml.action.GetJobsStatsAction;
import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction;
import org.elasticsearch.xpack.core.ml.datafeed.DatafeedState;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState;
import org.elasticsearch.xpack.core.ml.dataframe.stats.common.MemoryUsage;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.elasticsearch.xpack.core.ml.inference.allocation.AllocationStats;
import org.elasticsearch.xpack.core.ml.inference.allocation.TrainedModelAllocation;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
import org.elasticsearch.xpack.core.ml.job.config.Job;
import org.elasticsearch.xpack.core.ml.job.config.JobState;
import org.elasticsearch.xpack.core.ml.job.process.autodetect.state.ModelSizeStats;
import org.elasticsearch.xpack.core.ml.stats.ForecastStats;
import org.elasticsearch.xpack.core.ml.stats.StatsAccumulator;
import org.elasticsearch.xpack.ml.MachineLearning;
import org.elasticsearch.xpack.ml.inference.allocation.TrainedModelAllocationMetadata;
import org.elasticsearch.xpack.ml.job.JobManagerHolder;

public class MachineLearningUsageTransportAction
extends XPackUsageFeatureTransportAction {
    private final Client client;
    private final XPackLicenseState licenseState;
    private final JobManagerHolder jobManagerHolder;
    private final boolean enabled;

    @Inject
    public MachineLearningUsageTransportAction(TransportService transportService, ClusterService clusterService, ThreadPool threadPool, ActionFilters actionFilters, IndexNameExpressionResolver indexNameExpressionResolver, Environment environment, Client client, XPackLicenseState licenseState, JobManagerHolder jobManagerHolder) {
        super(XPackUsageFeatureAction.MACHINE_LEARNING.name(), transportService, clusterService, threadPool, actionFilters, indexNameExpressionResolver);
        this.client = new OriginSettingClient(client, "ml");
        this.licenseState = licenseState;
        this.jobManagerHolder = jobManagerHolder;
        this.enabled = (Boolean)XPackSettings.MACHINE_LEARNING_ENABLED.get(environment.settings());
    }

    protected void masterOperation(Task task, XPackUsageRequest request, ClusterState state, ActionListener<XPackUsageFeatureResponse> listener) {
        if (!this.enabled) {
            MachineLearningFeatureSetUsage usage = new MachineLearningFeatureSetUsage(MachineLearningField.ML_API_FEATURE.checkWithoutTracking(this.licenseState), this.enabled, Collections.emptyMap(), Collections.emptyMap(), Collections.emptyMap(), Collections.emptyMap(), 0);
            listener.onResponse((Object)new XPackUsageFeatureResponse((XPackFeatureSet.Usage)usage));
            return;
        }
        LinkedHashMap jobsUsage = new LinkedHashMap();
        LinkedHashMap datafeedsUsage = new LinkedHashMap();
        LinkedHashMap analyticsUsage = new LinkedHashMap();
        LinkedHashMap inferenceUsage = new LinkedHashMap();
        int nodeCount = MachineLearningUsageTransportAction.mlNodeCount(state);
        ActionListener trainedModelDeploymentsListener = ActionListener.wrap(response -> {
            this.addDeploymentStats((GetDeploymentStatsAction.Response)response, inferenceUsage);
            listener.onResponse((Object)new XPackUsageFeatureResponse((XPackFeatureSet.Usage)new MachineLearningFeatureSetUsage(MachineLearningField.ML_API_FEATURE.checkWithoutTracking(this.licenseState), this.enabled, jobsUsage, datafeedsUsage, analyticsUsage, inferenceUsage, nodeCount)));
        }, arg_0 -> listener.onFailure(arg_0));
        ActionListener trainedModelsListener = ActionListener.wrap(response -> {
            this.addTrainedModelStats((GetTrainedModelsAction.Response)response, inferenceUsage);
            this.client.execute((ActionType)GetDeploymentStatsAction.INSTANCE, (ActionRequest)new GetDeploymentStatsAction.Request("_all"), trainedModelDeploymentsListener);
        }, arg_0 -> listener.onFailure(arg_0));
        ActionListener nodesStatsListener = ActionListener.wrap(response -> {
            this.addInferenceIngestUsage((NodesStatsResponse)response, inferenceUsage);
            GetTrainedModelsAction.Request getModelsRequest = new GetTrainedModelsAction.Request("*", Collections.emptyList(), Collections.emptySet());
            getModelsRequest.setPageParams(new PageParams(0, 10000));
            this.client.execute((ActionType)GetTrainedModelsAction.INSTANCE, (ActionRequest)getModelsRequest, trainedModelsListener);
        }, arg_0 -> listener.onFailure(arg_0));
        ActionListener dataframeAnalyticsListener = ActionListener.wrap(response -> {
            this.addDataFrameAnalyticsUsage((GetDataFrameAnalyticsAction.Response)response, analyticsUsage);
            String[] ingestNodes = MachineLearningUsageTransportAction.ingestNodes(state);
            NodesStatsRequest nodesStatsRequest = new NodesStatsRequest(ingestNodes).clear().addMetric(NodesStatsRequest.Metric.INGEST.metricName());
            this.client.execute((ActionType)NodesStatsAction.INSTANCE, (ActionRequest)nodesStatsRequest, nodesStatsListener);
        }, arg_0 -> listener.onFailure(arg_0));
        ActionListener dataframeAnalyticsStatsListener = ActionListener.wrap(response -> {
            this.addDataFrameAnalyticsStatsUsage((GetDataFrameAnalyticsStatsAction.Response)response, analyticsUsage);
            GetDataFrameAnalyticsAction.Request getDfaRequest = new GetDataFrameAnalyticsAction.Request("_all");
            getDfaRequest.setPageParams(new PageParams(0, 10000));
            this.client.execute((ActionType)GetDataFrameAnalyticsAction.INSTANCE, (ActionRequest)getDfaRequest, dataframeAnalyticsListener);
        }, arg_0 -> listener.onFailure(arg_0));
        ActionListener datafeedStatsListener = ActionListener.wrap(response -> {
            this.addDatafeedsUsage((GetDatafeedsStatsAction.Response)response, datafeedsUsage);
            GetDataFrameAnalyticsStatsAction.Request dataframeAnalyticsStatsRequest = new GetDataFrameAnalyticsStatsAction.Request("_all");
            dataframeAnalyticsStatsRequest.setPageParams(new PageParams(0, 10000));
            this.client.execute((ActionType)GetDataFrameAnalyticsStatsAction.INSTANCE, (ActionRequest)dataframeAnalyticsStatsRequest, dataframeAnalyticsStatsListener);
        }, arg_0 -> listener.onFailure(arg_0));
        GetJobsStatsAction.Request jobStatsRequest = new GetJobsStatsAction.Request("_all");
        ActionListener jobStatsListener = ActionListener.wrap(response -> this.jobManagerHolder.getJobManager().expandJobs("_all", true, (ActionListener<QueryPage<Job>>)ActionListener.wrap(jobs -> {
            this.addJobsUsage((GetJobsStatsAction.Response)response, jobs.results(), jobsUsage);
            GetDatafeedsStatsAction.Request datafeedStatsRequest = new GetDatafeedsStatsAction.Request("_all");
            this.client.execute((ActionType)GetDatafeedsStatsAction.INSTANCE, (ActionRequest)datafeedStatsRequest, datafeedStatsListener);
        }, arg_0 -> ((ActionListener)listener).onFailure(arg_0))), arg_0 -> listener.onFailure(arg_0));
        this.client.execute((ActionType)GetJobsStatsAction.INSTANCE, (ActionRequest)jobStatsRequest, jobStatsListener);
    }

    private void addJobsUsage(GetJobsStatsAction.Response response, List<Job> jobs, Map<String, Object> jobsUsage) {
        StatsAccumulator allJobsDetectorsStats = new StatsAccumulator();
        StatsAccumulator allJobsModelSizeStats = new StatsAccumulator();
        ForecastStats allJobsForecastStats = new ForecastStats();
        HashMap<JobState, Counter> jobCountByState = new HashMap<JobState, Counter>();
        HashMap<JobState, StatsAccumulator> detectorStatsByState = new HashMap<JobState, StatsAccumulator>();
        HashMap<JobState, StatsAccumulator> modelSizeStatsByState = new HashMap<JobState, StatsAccumulator>();
        HashMap<JobState, ForecastStats> forecastStatsByState = new HashMap<JobState, ForecastStats>();
        HashMap<JobState, Map> createdByByState = new HashMap<JobState, Map>();
        List jobsStats = response.getResponse().results();
        Map<String, Job> jobMap = jobs.stream().collect(Collectors.toMap(Job::getId, item -> item));
        Map<String, Long> allJobsCreatedBy = jobs.stream().map(this::jobCreatedBy).collect(Collectors.groupingBy(item -> item, Collectors.counting()));
        for (GetJobsStatsAction.Response.JobStats jobStats : jobsStats) {
            Job job = jobMap.get(jobStats.getJobId());
            if (job == null) continue;
            int detectorsCount = job.getAnalysisConfig().getDetectors().size();
            ModelSizeStats modelSizeStats = jobStats.getModelSizeStats();
            double modelSize = modelSizeStats == null ? 0.0 : (double)jobStats.getModelSizeStats().getModelBytes();
            allJobsForecastStats.merge(jobStats.getForecastStats());
            allJobsDetectorsStats.add((double)detectorsCount);
            allJobsModelSizeStats.add(modelSize);
            JobState jobState = jobStats.getState();
            jobCountByState.computeIfAbsent(jobState, js -> Counter.newCounter()).addAndGet(1L);
            detectorStatsByState.computeIfAbsent(jobState, js -> new StatsAccumulator()).add((double)detectorsCount);
            modelSizeStatsByState.computeIfAbsent(jobState, js -> new StatsAccumulator()).add(modelSize);
            forecastStatsByState.merge(jobState, jobStats.getForecastStats(), ForecastStats::merge);
            createdByByState.computeIfAbsent(jobState, js -> new HashMap()).compute(this.jobCreatedBy(job), (k, v) -> v == null ? 1L : v + 1L);
        }
        jobsUsage.put("_all", this.createJobUsageEntry(jobs.size(), allJobsDetectorsStats, allJobsModelSizeStats, allJobsForecastStats, allJobsCreatedBy));
        for (JobState jobState : jobCountByState.keySet()) {
            jobsUsage.put(jobState.name().toLowerCase(Locale.ROOT), this.createJobUsageEntry(((Counter)jobCountByState.get(jobState)).get(), (StatsAccumulator)detectorStatsByState.get(jobState), (StatsAccumulator)modelSizeStatsByState.get(jobState), (ForecastStats)forecastStatsByState.get(jobState), (Map)createdByByState.get(jobState)));
        }
    }

    private String jobCreatedBy(Job job) {
        Map customSettings = job.getCustomSettings();
        if (customSettings == null || !customSettings.containsKey("created_by")) {
            return "unknown";
        }
        return customSettings.get("created_by").toString().replaceAll("\\W", "_");
    }

    private Map<String, Object> createJobUsageEntry(long count, StatsAccumulator detectorStats, StatsAccumulator modelSizeStats, ForecastStats forecastStats, Map<String, Long> createdBy) {
        HashMap<String, Object> usage = new HashMap<String, Object>();
        usage.put("count", count);
        usage.put("detectors", detectorStats.asMap());
        usage.put("model_size", modelSizeStats.asMap());
        usage.put("forecasts", forecastStats.asMap());
        usage.put("created_by", createdBy);
        return usage;
    }

    private void addDatafeedsUsage(GetDatafeedsStatsAction.Response response, Map<String, Object> datafeedsUsage) {
        HashMap<DatafeedState, Counter> datafeedCountByState = new HashMap<DatafeedState, Counter>();
        List datafeedsStats = response.getResponse().results();
        for (GetDatafeedsStatsAction.Response.DatafeedStats datafeedStats : datafeedsStats) {
            datafeedCountByState.computeIfAbsent(datafeedStats.getDatafeedState(), ds -> Counter.newCounter()).addAndGet(1L);
        }
        datafeedsUsage.put("_all", this.createCountUsageEntry(response.getResponse().count()));
        for (DatafeedState datafeedState : datafeedCountByState.keySet()) {
            datafeedsUsage.put(datafeedState.name().toLowerCase(Locale.ROOT), this.createCountUsageEntry(((Counter)datafeedCountByState.get(datafeedState)).get()));
        }
    }

    private Map<String, Object> createCountUsageEntry(long count) {
        HashMap<String, Object> usage = new HashMap<String, Object>();
        usage.put("count", count);
        return usage;
    }

    private void addDataFrameAnalyticsStatsUsage(GetDataFrameAnalyticsStatsAction.Response response, Map<String, Object> dataframeAnalyticsUsage) {
        HashMap<DataFrameAnalyticsState, Counter> dataFrameAnalyticsStateCounterMap = new HashMap<DataFrameAnalyticsState, Counter>();
        StatsAccumulator memoryUsagePeakBytesStats = new StatsAccumulator();
        for (GetDataFrameAnalyticsStatsAction.Response.Stats stats : response.getResponse().results()) {
            dataFrameAnalyticsStateCounterMap.computeIfAbsent(stats.getState(), ds -> Counter.newCounter()).addAndGet(1L);
            MemoryUsage memoryUsage = stats.getMemoryUsage();
            if (memoryUsage == null || memoryUsage.getPeakUsageBytes() <= 0L) continue;
            memoryUsagePeakBytesStats.add((double)memoryUsage.getPeakUsageBytes());
        }
        dataframeAnalyticsUsage.put("memory_usage", Collections.singletonMap(MemoryUsage.PEAK_USAGE_BYTES.getPreferredName(), memoryUsagePeakBytesStats.asMap()));
        dataframeAnalyticsUsage.put("_all", this.createCountUsageEntry(response.getResponse().count()));
        for (DataFrameAnalyticsState state : dataFrameAnalyticsStateCounterMap.keySet()) {
            dataframeAnalyticsUsage.put(state.name().toLowerCase(Locale.ROOT), this.createCountUsageEntry(((Counter)dataFrameAnalyticsStateCounterMap.get(state)).get()));
        }
    }

    private void addDataFrameAnalyticsUsage(GetDataFrameAnalyticsAction.Response response, Map<String, Object> dataframeAnalyticsUsage) {
        HashMap<String, Integer> perAnalysisTypeCounterMap = new HashMap<String, Integer>();
        for (DataFrameAnalyticsConfig config : response.getResources().results()) {
            int count = perAnalysisTypeCounterMap.computeIfAbsent(config.getAnalysis().getWriteableName(), k -> 0);
            perAnalysisTypeCounterMap.put(config.getAnalysis().getWriteableName(), ++count);
        }
        dataframeAnalyticsUsage.put("analysis_counts", perAnalysisTypeCounterMap);
    }

    private static void initializeStats(Map<String, Long> emptyStatsMap) {
        emptyStatsMap.put("sum", 0L);
        emptyStatsMap.put("min", 0L);
        emptyStatsMap.put("max", 0L);
    }

    private static void updateStats(Map<String, Long> statsMap, Long value) {
        statsMap.computeIfPresent("sum", (k, v) -> v + value);
        statsMap.computeIfPresent("min", (k, v) -> Math.min(v, value));
        statsMap.computeIfPresent("max", (k, v) -> Math.max(v, value));
    }

    private void addDeploymentStats(GetDeploymentStatsAction.Response response, Map<String, Object> inferenceUsage) {
        TrainedModelAllocationMetadata trainedModelAllocationMetadata = TrainedModelAllocationMetadata.fromState(this.clusterService.state());
        StatsAccumulator modelSizes = new StatsAccumulator();
        double avgTimeSum = 0.0;
        StatsAccumulator nodeDistribution = new StatsAccumulator();
        for (AllocationStats stats : response.getStats().results()) {
            TrainedModelAllocation allocation = trainedModelAllocationMetadata.getModelAllocation(stats.getModelId());
            if (allocation != null) {
                modelSizes.add((double)allocation.getTaskParams().getModelBytes());
            }
            for (AllocationStats.NodeStats nodeStats : stats.getNodeStats()) {
                long nodeInferenceCount = nodeStats.getInferenceCount().orElse(0L);
                avgTimeSum += nodeStats.getAvgInferenceTime().orElse(0.0) * (double)nodeInferenceCount;
                nodeDistribution.add((double)nodeInferenceCount);
            }
        }
        inferenceUsage.put("deployments", Map.of("count", response.getStats().count(), "time_ms", Map.of("avg", nodeDistribution.getTotal() == 0.0 ? 0.0 : avgTimeSum / nodeDistribution.getTotal()), "model_sizes_bytes", modelSizes.asMap(), "inference_counts", nodeDistribution.asMap()));
    }

    private void addTrainedModelStats(GetTrainedModelsAction.Response response, Map<String, Object> inferenceUsage) {
        List trainedModelConfigs = response.getResources().results();
        HashMap<String, Map> trainedModelsUsage = new HashMap<String, Map>();
        trainedModelsUsage.put("_all", this.createCountUsageEntry(trainedModelConfigs.size()));
        StatsAccumulator estimatedOperations = new StatsAccumulator();
        StatsAccumulator estimatedMemoryUsageBytes = new StatsAccumulator();
        int createdByAnalyticsCount = 0;
        LinkedHashMap<String, Counter> inferenceConfigCounts = new LinkedHashMap<String, Counter>();
        int prepackagedCount = 0;
        for (TrainedModelConfig trainedModelConfig : trainedModelConfigs) {
            if (trainedModelConfig.getTags().contains("prepackaged")) {
                ++prepackagedCount;
                continue;
            }
            InferenceConfig inferenceConfig = trainedModelConfig.getInferenceConfig();
            if (inferenceConfig != null) {
                inferenceConfigCounts.computeIfAbsent(inferenceConfig.getName(), s -> Counter.newCounter()).addAndGet(1L);
            }
            if (trainedModelConfig.getMetadata() != null && trainedModelConfig.getMetadata().containsKey("analytics_config")) {
                ++createdByAnalyticsCount;
            }
            estimatedOperations.add((double)trainedModelConfig.getEstimatedOperations());
            estimatedMemoryUsageBytes.add((double)trainedModelConfig.getModelSize());
        }
        HashMap<String, Integer> counts = new HashMap<String, Integer>();
        counts.put("total", trainedModelConfigs.size());
        inferenceConfigCounts.forEach((configName, count) -> counts.put((String)configName, (Integer)count.get()));
        counts.put("prepackaged", prepackagedCount);
        counts.put("other", trainedModelConfigs.size() - createdByAnalyticsCount - prepackagedCount);
        trainedModelsUsage.put("count", counts);
        trainedModelsUsage.put(TrainedModelConfig.ESTIMATED_OPERATIONS.getPreferredName(), estimatedOperations.asMap());
        trainedModelsUsage.put(TrainedModelConfig.MODEL_SIZE_BYTES.getPreferredName(), estimatedMemoryUsageBytes.asMap());
        inferenceUsage.put("trained_models", trainedModelsUsage);
    }

    private void addInferenceIngestUsage(NodesStatsResponse response, Map<String, Object> inferenceUsage) {
        HashSet pipelines = new HashSet();
        HashMap<String, Long> docCountStats = new HashMap<String, Long>(3);
        HashMap<String, Long> timeStats = new HashMap<String, Long>(3);
        HashMap<String, Long> failureStats = new HashMap<String, Long>(3);
        MachineLearningUsageTransportAction.initializeStats(docCountStats);
        MachineLearningUsageTransportAction.initializeStats(timeStats);
        MachineLearningUsageTransportAction.initializeStats(failureStats);
        response.getNodes().stream().map(NodeStats::getIngestStats).map(IngestStats::getProcessorStats).forEach(map -> map.forEach((pipelineId, processors) -> {
            boolean containsInference = false;
            for (IngestStats.ProcessorStat stats : processors) {
                if (!stats.getName().equals("inference")) continue;
                containsInference = true;
                long ingestCount = stats.getStats().getIngestCount();
                long ingestTime = stats.getStats().getIngestTimeInMillis();
                long failureCount = stats.getStats().getIngestFailedCount();
                MachineLearningUsageTransportAction.updateStats(docCountStats, ingestCount);
                MachineLearningUsageTransportAction.updateStats(timeStats, ingestTime);
                MachineLearningUsageTransportAction.updateStats(failureStats, failureCount);
            }
            if (containsInference) {
                pipelines.add(pipelineId);
            }
        }));
        HashMap<String, Map<String, Object>> ingestUsage = new HashMap<String, Map<String, Object>>(6);
        ingestUsage.put("pipelines", this.createCountUsageEntry(pipelines.size()));
        ingestUsage.put("num_docs_processed", docCountStats);
        ingestUsage.put("time_ms", timeStats);
        ingestUsage.put("num_failures", failureStats);
        inferenceUsage.put("ingest_processors", Collections.singletonMap("_all", ingestUsage));
    }

    private static int mlNodeCount(ClusterState clusterState) {
        int mlNodeCount = 0;
        for (DiscoveryNode node : clusterState.getNodes()) {
            if (!MachineLearning.isMlNode(node)) continue;
            ++mlNodeCount;
        }
        return mlNodeCount;
    }

    private static String[] ingestNodes(ClusterState clusterState) {
        return (String[])clusterState.nodes().getIngestNodes().keySet().toArray(String[]::new);
    }
}

