/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.ml.action;

import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.concurrent.Executor;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.ActionListenerResponseHandler;
import org.elasticsearch.action.support.ActionFilters;
import org.elasticsearch.action.support.ContextPreservingActionListener;
import org.elasticsearch.action.support.HandledTransportAction;
import org.elasticsearch.client.internal.Client;
import org.elasticsearch.client.internal.ParentTaskAssigningClient;
import org.elasticsearch.client.internal.node.NodeClient;
import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.cluster.node.DiscoveryNodes;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.inject.Inject;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.common.util.concurrent.EsExecutors;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.core.Tuple;
import org.elasticsearch.license.LicenseUtils;
import org.elasticsearch.license.XPackLicenseState;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.tasks.TaskId;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.TransportRequest;
import org.elasticsearch.transport.TransportResponseHandler;
import org.elasticsearch.transport.TransportService;
import org.elasticsearch.xpack.core.ClientHelper;
import org.elasticsearch.xpack.core.XPackSettings;
import org.elasticsearch.xpack.core.ml.MachineLearningField;
import org.elasticsearch.xpack.core.ml.action.ExplainDataFrameAnalyticsAction;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
import org.elasticsearch.xpack.core.ml.dataframe.explain.FieldSelection;
import org.elasticsearch.xpack.core.ml.dataframe.explain.MemoryEstimation;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.core.security.SecurityContext;
import org.elasticsearch.xpack.ml.MachineLearning;
import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractorFactory;
import org.elasticsearch.xpack.ml.dataframe.extractor.ExtractedFieldsDetector;
import org.elasticsearch.xpack.ml.dataframe.extractor.ExtractedFieldsDetectorFactory;
import org.elasticsearch.xpack.ml.dataframe.process.MemoryUsageEstimationProcessManager;
import org.elasticsearch.xpack.ml.dataframe.process.results.MemoryUsageEstimationResult;
import org.elasticsearch.xpack.ml.extractor.ExtractedFields;
import org.elasticsearch.xpack.ml.utils.SecondaryAuthorizationUtils;

/**
 * Transport handler for the {@code cluster:admin/xpack/ml/data_frame/analytics/explain} action.
 *
 * <p>For a supplied {@link DataFrameAnalyticsConfig} the handler detects which fields would be
 * extracted and, when at least one field is found, asks the
 * {@link MemoryUsageEstimationProcessManager} for a memory estimate. The work is done on the
 * receiving node when that node is an ML, master-eligible, data or ingest node; otherwise the
 * request is forwarded to a node chosen by {@link #findSuitableNode(ClusterState)}.
 */
public class TransportExplainDataFrameAnalyticsAction
    extends HandledTransportAction<ExplainDataFrameAnalyticsAction.Request, ExplainDataFrameAnalyticsAction.Response> {

    private static final Logger logger = LogManager.getLogger(TransportExplainDataFrameAnalyticsAction.class);

    private final XPackLicenseState licenseState;
    private final TransportService transportService;
    private final ClusterService clusterService;
    private final NodeClient client;
    private final MemoryUsageEstimationProcessManager processManager;
    /** {@code null} when the security setting is disabled (see the constructor). */
    private final SecurityContext securityContext;
    private final ThreadPool threadPool;
    private final Settings settings;

    /**
     * Creates the action and registers it under its action name on the direct executor.
     *
     * @param transportService used to forward requests to another node
     * @param actionFilters    filters passed through to the superclass
     * @param clusterService   source of the local node and cluster state; must not be {@code null}
     * @param client           node client used for all searches; must not be {@code null}
     * @param licenseState     license state checked before every execution
     * @param processManager   runs the memory usage estimation; must not be {@code null}
     * @param settings         node settings, consulted for {@link XPackSettings#SECURITY_ENABLED}
     * @param threadPool       supplies the thread context
     */
    @Inject
    public TransportExplainDataFrameAnalyticsAction(
        TransportService transportService,
        ActionFilters actionFilters,
        ClusterService clusterService,
        NodeClient client,
        XPackLicenseState licenseState,
        MemoryUsageEstimationProcessManager processManager,
        Settings settings,
        ThreadPool threadPool
    ) {
        super(
            "cluster:admin/xpack/ml/data_frame/analytics/explain",
            transportService,
            actionFilters,
            ExplainDataFrameAnalyticsAction.Request::new,
            EsExecutors.DIRECT_EXECUTOR_SERVICE
        );
        this.transportService = transportService;
        this.clusterService = Objects.requireNonNull(clusterService);
        this.client = Objects.requireNonNull(client);
        this.licenseState = licenseState;
        this.processManager = Objects.requireNonNull(processManager);
        this.threadPool = threadPool;
        this.settings = settings;
        // A security context only exists when security is switched on; explain(Task, ...) checks the
        // same setting before using it.
        this.securityContext = XPackSettings.SECURITY_ENABLED.get(settings)
            ? new SecurityContext(settings, threadPool.getThreadContext())
            : null;
    }

    /**
     * Entry point for the action.
     *
     * <p>Fails the listener with a license compliance exception when the ML API feature is not
     * allowed. Otherwise the request is handled locally if the local node is an ML, master-eligible,
     * data or ingest node, and forwarded to another node if it is none of those.
     *
     * @param task     the task created for this request
     * @param request  the explain request
     * @param listener notified with the response or the failure
     */
    protected void doExecute(
        Task task,
        ExplainDataFrameAnalyticsAction.Request request,
        ActionListener<ExplainDataFrameAnalyticsAction.Response> listener
    ) {
        if (MachineLearningField.ML_API_FEATURE.check(licenseState) == false) {
            listener.onFailure(LicenseUtils.newComplianceException("ml"));
            return;
        }

        DiscoveryNode localNode = clusterService.localNode();
        boolean isMlNode = MachineLearning.isMlNode(localNode);
        boolean canRunLocally = isMlNode || localNode.isMasterNode() || localNode.canContainData() || localNode.isIngestNode();
        if (canRunLocally == false) {
            redirectToSuitableNode(request, listener);
            return;
        }

        if (isMlNode == false) {
            logger.debug("estimating data frame analytics memory on non-ML node");
        }
        explain(task, request, listener);
    }

    /**
     * Builds the field detector for the requested config and continues with
     * {@link #explain(TaskId, DataFrameAnalyticsConfig, ExtractedFieldsDetector, ActionListener)}.
     *
     * <p>With security enabled, the config is rebuilt with the persistable security headers taken
     * from the thread context, inside {@code useSecondaryAuthIfAvailable}. With security disabled,
     * the request's config is used as is and the listener is wrapped so that the current thread
     * context is preserved for the callback.
     */
    private void explain(
        Task task,
        ExplainDataFrameAnalyticsAction.Request request,
        ActionListener<ExplainDataFrameAnalyticsAction.Response> listener
    ) {
        TaskId parentTaskId = new TaskId(clusterService.localNode().getId(), task.getId());
        // All requests issued by the detector are tagged with this action's task as their parent.
        Client parentTaskClient = new ParentTaskAssigningClient(client, parentTaskId);
        ExtractedFieldsDetectorFactory extractedFieldsDetectorFactory = new ExtractedFieldsDetectorFactory(parentTaskClient);

        if (XPackSettings.SECURITY_ENABLED.get(settings) == false) {
            ContextPreservingActionListener<ExplainDataFrameAnalyticsAction.Response> responseHeaderPreservingListener =
                ContextPreservingActionListener.wrapPreservingContext(listener, threadPool.getThreadContext());
            extractedFieldsDetectorFactory.createFromSource(
                request.getConfig(),
                ActionListener.wrap(
                    detector -> explain(parentTaskId, request.getConfig(), detector, responseHeaderPreservingListener),
                    responseHeaderPreservingListener::onFailure
                )
            );
            return;
        }

        SecondaryAuthorizationUtils.useSecondaryAuthIfAvailable(securityContext, () -> {
            DataFrameAnalyticsConfig config = new DataFrameAnalyticsConfig.Builder(request.getConfig())
                .setHeaders(ClientHelper.getPersistableSafeSecurityHeaders(threadPool.getThreadContext(), clusterService.state()))
                .build();
            extractedFieldsDetectorFactory.createFromSource(
                config,
                ActionListener.wrap(detector -> explain(parentTaskId, config, detector, listener), listener::onFailure)
            );
        });
    }

    /**
     * Runs field detection and, if any fields were extracted, memory estimation.
     *
     * <p>When no fields are extracted the listener receives the field selection together with a
     * zero/zero {@link MemoryEstimation} and no estimation is started.
     */
    private void explain(
        TaskId parentTaskId,
        DataFrameAnalyticsConfig config,
        ExtractedFieldsDetector extractedFieldsDetector,
        ActionListener<ExplainDataFrameAnalyticsAction.Response> listener
    ) {
        Tuple<ExtractedFields, List<FieldSelection>> fieldExtraction = extractedFieldsDetector.detect();
        ExtractedFields extractedFields = fieldExtraction.v1();
        List<FieldSelection> fieldSelection = fieldExtraction.v2();

        if (extractedFields.getAllFields().isEmpty()) {
            MemoryEstimation noMemory = new MemoryEstimation(ByteSizeValue.ZERO, ByteSizeValue.ZERO);
            listener.onResponse(new ExplainDataFrameAnalyticsAction.Response(fieldSelection, noMemory));
            return;
        }

        ActionListener<MemoryEstimation> memoryEstimationListener = ActionListener.wrap(
            memoryEstimation -> listener.onResponse(new ExplainDataFrameAnalyticsAction.Response(fieldSelection, memoryEstimation)),
            listener::onFailure
        );
        estimateMemoryUsage(parentTaskId, config, extractedFields, memoryEstimationListener);
    }

    /**
     * Starts an asynchronous memory usage estimation job and converts its result into a
     * {@link MemoryEstimation}.
     *
     * <p>The job id is {@code "memory_usage_estimation_"} followed by the numeric id of the parent task.
     */
    private void estimateMemoryUsage(
        TaskId parentTaskId,
        DataFrameAnalyticsConfig config,
        ExtractedFields extractedFields,
        ActionListener<MemoryEstimation> listener
    ) {
        String estimateMemoryTaskId = "memory_usage_estimation_" + parentTaskId.getId();
        Client parentTaskClient = new ParentTaskAssigningClient(client, parentTaskId);
        DataFrameDataExtractorFactory extractorFactory = DataFrameDataExtractorFactory.createForSourceIndices(
            parentTaskClient,
            estimateMemoryTaskId,
            config,
            extractedFields
        );
        ActionListener<MemoryUsageEstimationResult> resultListener = ActionListener.wrap(
            result -> listener.onResponse(
                new MemoryEstimation(result.getExpectedMemoryWithoutDisk(), result.getExpectedMemoryWithDisk())
            ),
            listener::onFailure
        );
        processManager.runJobAsync(estimateMemoryTaskId, config, extractorFactory, resultListener);
    }

    /**
     * Forwards the request to the node picked by {@link #findSuitableNode(ClusterState)}, or fails
     * the listener with a bad-request exception when no node qualifies.
     */
    private void redirectToSuitableNode(
        ExplainDataFrameAnalyticsAction.Request request,
        ActionListener<ExplainDataFrameAnalyticsAction.Response> listener
    ) {
        Optional<DiscoveryNode> target = findSuitableNode(clusterService.state());
        if (target.isPresent() == false) {
            // NOTE(review): the message names data and ingest nodes, whereas findSuitableNode only
            // looks for ML and master-eligible nodes - confirm the intended wording.
            listener.onFailure(ExceptionsHelper.badRequestException("No ML, data or ingest node to run on"));
            return;
        }
        transportService.sendRequest(
            target.get(),
            actionName,
            request,
            new ActionListenerResponseHandler<>(
                listener,
                ExplainDataFrameAnalyticsAction.Response::new,
                TransportResponseHandler.TRANSPORT_WORKER
            )
        );
    }

    /**
     * Picks a node to run the explain request on, in this order of preference:
     * <ol>
     *   <li>the first ML node encountered;</li>
     *   <li>the first master-eligible node encountered whose id differs from the current master's;</li>
     *   <li>the current master node.</li>
     * </ol>
     *
     * @param clusterState the cluster state whose nodes are searched
     * @return the chosen node, or an empty {@code Optional} when none of the above exists
     */
    private static Optional<DiscoveryNode> findSuitableNode(ClusterState clusterState) {
        DiscoveryNodes nodes = clusterState.getNodes();

        for (DiscoveryNode candidate : nodes) {
            if (MachineLearning.isMlNode(candidate)) {
                return Optional.of(candidate);
            }
        }

        DiscoveryNode currentMaster = null;
        for (DiscoveryNode candidate : nodes) {
            if (candidate.isMasterNode() == false) {
                continue;
            }
            if (candidate.getId().equals(nodes.getMasterNodeId()) == false) {
                return Optional.of(candidate);
            }
            // Remember the active master but keep looking for another master-eligible node.
            currentMaster = candidate;
        }
        return Optional.ofNullable(currentMaster);
    }
}

