/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.ml.inference.allocation;

import java.util.Collections;
import java.util.Comparator;
import java.util.HashSet;
import java.util.Locale;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.TreeMap;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.logging.log4j.message.ParameterizedMessage;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.ResourceAlreadyExistsException;
import org.elasticsearch.ResourceNotFoundException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.support.master.AcknowledgedResponse;
import org.elasticsearch.cluster.ClusterChangedEvent;
import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.cluster.ClusterStateListener;
import org.elasticsearch.cluster.ClusterStateTaskConfig;
import org.elasticsearch.cluster.ClusterStateTaskExecutor;
import org.elasticsearch.cluster.ClusterStateUpdateTask;
import org.elasticsearch.cluster.metadata.Metadata;
import org.elasticsearch.cluster.metadata.NodesShutdownMetadata;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.cluster.node.DiscoveryNodes;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.common.util.set.Sets;
import org.elasticsearch.gateway.GatewayService;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xcontent.ToXContent;
import org.elasticsearch.xpack.core.ml.MlMetadata;
import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction;
import org.elasticsearch.xpack.core.ml.action.UpdateTrainedModelAllocationStateAction;
import org.elasticsearch.xpack.core.ml.inference.allocation.AllocationState;
import org.elasticsearch.xpack.core.ml.inference.allocation.RoutingState;
import org.elasticsearch.xpack.core.ml.inference.allocation.TrainedModelAllocation;
import org.elasticsearch.xpack.ml.MachineLearning;
import org.elasticsearch.xpack.ml.inference.allocation.TrainedModelAllocationMetadata;
import org.elasticsearch.xpack.ml.job.NodeLoad;
import org.elasticsearch.xpack.ml.job.NodeLoadDetector;

/**
 * Maintains the {@code trained_model_allocation} cluster-state metadata.
 *
 * <p>Exposes cluster-state update tasks that create, stop, re-route and remove trained model
 * allocations. On the elected master it also reacts to node membership and node-shutdown changes
 * by re-routing allocations to eligible nodes.
 */
public class TrainedModelAllocationClusterService implements ClusterStateListener {
    private static final Logger logger = LogManager.getLogger(TrainedModelAllocationClusterService.class);
    private final ClusterService clusterService;
    private final NodeLoadDetector nodeLoadDetector;
    // Dynamic settings; written by cluster-settings update consumers and read when computing node load.
    private volatile int maxMemoryPercentage;
    private volatile boolean useAuto;
    private volatile int maxOpenJobs;

    public TrainedModelAllocationClusterService(Settings settings, ClusterService clusterService, NodeLoadDetector nodeLoadDetector) {
        this.clusterService = clusterService;
        this.nodeLoadDetector = nodeLoadDetector;
        this.maxMemoryPercentage = MachineLearning.MAX_MACHINE_MEMORY_PERCENT.get(settings);
        this.useAuto = MachineLearning.USE_AUTO_MACHINE_MEMORY_PERCENT.get(settings);
        this.maxOpenJobs = MachineLearning.MAX_OPEN_JOBS_PER_NODE.get(settings);
        // Only master-eligible nodes register as a listener and track setting updates.
        if (DiscoveryNode.isMasterNode(settings)) {
            clusterService.addListener(this);
            clusterService.getClusterSettings().addSettingsUpdateConsumer(MachineLearning.MAX_MACHINE_MEMORY_PERCENT, this::setMaxMemoryPercentage);
            clusterService.getClusterSettings().addSettingsUpdateConsumer(MachineLearning.USE_AUTO_MACHINE_MEMORY_PERCENT, this::setUseAuto);
            clusterService.getClusterSettings().addSettingsUpdateConsumer(MachineLearning.MAX_OPEN_JOBS_PER_NODE, this::setMaxOpenJobs);
        }
    }

    private void setMaxMemoryPercentage(int maxMemoryPercentage) {
        this.maxMemoryPercentage = maxMemoryPercentage;
    }

    private void setUseAuto(boolean useAuto) {
        this.useAuto = useAuto;
    }

    private void setMaxOpenJobs(int maxOpenJobs) {
        this.maxOpenJobs = maxOpenJobs;
    }

    /**
     * Submits a re-routing task when this node is master and the change could affect model routing.
     *
     * <p>Does nothing while the cluster state is still blocked by
     * {@link GatewayService#STATE_NOT_RECOVERED_BLOCK}.
     */
    @Override
    public void clusterChanged(ClusterChangedEvent event) {
        if (event.state().blocks().hasGlobalBlock(GatewayService.STATE_NOT_RECOVERED_BLOCK)) {
            return;
        }
        if (event.localNodeMaster() && shouldAllocateModels(event)) {
            clusterService.submitStateUpdateTask("allocating models to nodes", new ClusterStateUpdateTask() {
                @Override
                public ClusterState execute(ClusterState currentState) {
                    return TrainedModelAllocationClusterService.this.addRemoveAllocationNodes(currentState);
                }

                @Override
                public void onFailure(Exception e) {
                    logger.warn("failed to allocate models", e);
                }

                @Override
                public void clusterStateProcessed(ClusterState oldState, ClusterState newState) {
                    logger.trace(
                        () -> new ParameterizedMessage(
                            "updated model allocations based on node changes in the cluster; new metadata [{}]",
                            Strings.toString(TrainedModelAllocationMetadata.fromState(newState), false, true)
                        )
                    );
                }
            }, ClusterStateTaskExecutor.unbatched());
        }
    }

    /**
     * Submits an unbatched cluster-state update.
     *
     * <p>The listener receives {@link AcknowledgedResponse#TRUE} once the new state is processed,
     * or the failure raised by {@code update}.
     *
     * @param source   task description shown in the cluster-state task queue
     * @param update   pure function computing the new state from the current one
     * @param listener notified on completion
     */
    private void submitAcknowledgedUpdate(
        String source,
        Function<ClusterState, ClusterState> update,
        ActionListener<AcknowledgedResponse> listener
    ) {
        clusterService.submitStateUpdateTask(source, new ClusterStateUpdateTask() {
            @Override
            public ClusterState execute(ClusterState currentState) {
                return update.apply(currentState);
            }

            @Override
            public void onFailure(Exception e) {
                listener.onFailure(e);
            }

            @Override
            public void clusterStateProcessed(ClusterState oldState, ClusterState newState) {
                listener.onResponse(AcknowledgedResponse.TRUE);
            }
        }, ClusterStateTaskExecutor.unbatched());
    }

    /** Applies a node's routing-state report for a model allocation to the cluster state. */
    public void updateModelRoutingTable(final UpdateTrainedModelAllocationStateAction.Request request, final ActionListener<AcknowledgedResponse> listener) {
        submitAcknowledgedUpdate(
            "updating model routing for node allocation",
            state -> TrainedModelAllocationClusterService.updateModelRoutingTable(state, request),
            listener
        );
    }

    /**
     * Creates a new allocation for the model described by {@code params}.
     *
     * <p>The listener receives the allocation as stored in the resulting cluster state.
     */
    public void createNewModelAllocation(final StartTrainedModelDeploymentAction.TaskParams params, final ActionListener<TrainedModelAllocation> listener) {
        clusterService.submitStateUpdateTask("create model allocation", new ClusterStateUpdateTask() {
            @Override
            public ClusterState execute(ClusterState currentState) {
                return TrainedModelAllocationClusterService.this.createModelAllocation(currentState, params);
            }

            @Override
            public void onFailure(Exception e) {
                listener.onFailure(e);
            }

            @Override
            public void clusterStateProcessed(ClusterState oldState, ClusterState newState) {
                listener.onResponse(TrainedModelAllocationMetadata.fromState(newState).getModelAllocation(params.getModelId()));
            }
        }, ClusterStateTaskExecutor.unbatched());
    }

    /** Moves the allocation for {@code modelId} into the stopping state. */
    public void setModelAllocationToStopping(final String modelId, final ActionListener<AcknowledgedResponse> listener) {
        submitAcknowledgedUpdate(
            "set model allocation stopping",
            state -> TrainedModelAllocationClusterService.setToStopping(state, modelId, "client API call"),
            listener
        );
    }

    /** Deletes the allocation for {@code modelId}. */
    public void removeModelAllocation(final String modelId, final ActionListener<AcknowledgedResponse> listener) {
        submitAcknowledgedUpdate("delete model allocation", state -> TrainedModelAllocationClusterService.removeAllocation(state, modelId), listener);
    }

    /** Deletes every model allocation. */
    public void removeAllModelAllocations(final ActionListener<AcknowledgedResponse> listener) {
        submitAcknowledgedUpdate("delete all model allocations", TrainedModelAllocationClusterService::removeAllAllocations, listener);
    }

    /**
     * Returns {@code currentState} with the builder's metadata installed.
     *
     * <p>If the builder reports no change, the same instance is returned unchanged.
     */
    private static ClusterState update(ClusterState currentState, TrainedModelAllocationMetadata.Builder modelAllocations) {
        if (modelAllocations.isChanged()) {
            return ClusterState.builder(currentState)
                .metadata(Metadata.builder(currentState.metadata()).putCustom("trained_model_allocation", modelAllocations.build()))
                .build();
        }
        return currentState;
    }

    /**
     * Adds a new allocation for {@code params}.
     *
     * <p>The allocation is routed to every ML-eligible, non-shutting-down node that has capacity.
     * Rejected nodes and their reasons are recorded on the allocation.
     *
     * @throws ElasticsearchStatusException  (CONFLICT) while ML feature reset is in progress
     * @throws ResourceAlreadyExistsException if the model already has an allocation
     */
    ClusterState createModelAllocation(ClusterState currentState, StartTrainedModelDeploymentAction.TaskParams params) {
        if (MlMetadata.getMlMetadata(currentState).isResetMode()) {
            throw new ElasticsearchStatusException(
                "cannot create new allocation for model [{}] while feature reset is in progress.",
                RestStatus.CONFLICT,
                params.getModelId()
            );
        }
        TrainedModelAllocationMetadata.Builder builder = TrainedModelAllocationMetadata.builder(currentState);
        if (builder.hasModel(params.getModelId())) {
            throw new ResourceAlreadyExistsException("allocation for model with id [{}] already exist", params.getModelId());
        }
        TrainedModelAllocation.Builder allocationBuilder = TrainedModelAllocation.Builder.empty(params);
        Set<String> shuttingDownNodes = nodesShuttingDown(currentState);
        // TreeMap so the combined reason string lists nodes in a stable, sorted order.
        Map<String, String> nodeToReason = new TreeMap<>();
        for (DiscoveryNode node : currentState.getNodes().getAllNodes()) {
            if (!StartTrainedModelDeploymentAction.TaskParams.mayAllocateToNode(node) || shuttingDownNodes.contains(node.getId())) {
                continue;
            }
            Optional<String> maybeError = nodeHasCapacity(currentState, params, node);
            if (maybeError.isPresent()) {
                nodeToReason.put(node.getName(), maybeError.get());
            } else {
                allocationBuilder.addNewRoutingEntry(node.getId());
            }
        }
        if (!nodeToReason.isEmpty()) {
            allocationBuilder.setReason(formatNodeReasons(nodeToReason));
        }
        builder.addNewAllocation(params.getModelId(), allocationBuilder);
        return update(currentState, builder);
    }

    /**
     * Marks the allocation for {@code modelId} as stopping.
     *
     * <p>Returns the state unchanged if the allocation is already stopping.
     *
     * @throws ResourceNotFoundException if no allocation exists for {@code modelId}
     */
    static ClusterState setToStopping(ClusterState clusterState, String modelId, String reason) {
        TrainedModelAllocationMetadata metadata = TrainedModelAllocationMetadata.fromState(clusterState);
        TrainedModelAllocation existingAllocation = metadata.getModelAllocation(modelId);
        if (existingAllocation == null) {
            throw new ResourceNotFoundException("allocation for model with id [{}] not found", modelId);
        }
        if (existingAllocation.getAllocationState().equals(AllocationState.STOPPING)) {
            return clusterState;
        }
        TrainedModelAllocationMetadata.Builder builder = TrainedModelAllocationMetadata.builder(clusterState);
        builder.getAllocation(modelId).stopAllocation(reason);
        return update(clusterState, builder);
    }

    /**
     * Applies a node's routing-state update to the model's allocation.
     *
     * <p>A {@code STOPPED} report removes the node's route; it is a no-op if the allocation or the
     * route is already gone. Other reports are ignored while the allocation is stopping; otherwise
     * they update the existing route.
     *
     * @throws ResourceNotFoundException for a non-STOPPED report when the allocation is missing or
     *                                   not routed to the reporting node
     */
    static ClusterState updateModelRoutingTable(ClusterState currentState, UpdateTrainedModelAllocationStateAction.Request request) {
        String modelId = request.getModelId();
        String nodeId = request.getNodeId();
        TrainedModelAllocationMetadata metadata = TrainedModelAllocationMetadata.fromState(currentState);
        logger.trace(
            () -> new ParameterizedMessage(
                "[{}] [{}] current metadata before update {}",
                new Object[] { modelId, nodeId, Strings.toString(metadata) }
            )
        );
        TrainedModelAllocation existingAllocation = metadata.getModelAllocation(modelId);
        TrainedModelAllocationMetadata.Builder builder = TrainedModelAllocationMetadata.builder(currentState);
        if (request.getRoutingState().getState().equals(RoutingState.STOPPED)) {
            if (existingAllocation == null || !existingAllocation.isRoutedToNode(nodeId)) {
                return currentState;
            }
            builder.getAllocation(modelId).removeRoutingEntry(nodeId).calculateAndSetAllocationState();
            return update(currentState, builder);
        }
        if (existingAllocation == null) {
            throw new ResourceNotFoundException("allocation for model with id [{}] not found", modelId);
        }
        if (existingAllocation.getAllocationState().equals(AllocationState.STOPPING)) {
            logger.debug(
                () -> new ParameterizedMessage(
                    "[{}] requested update from node [{}] to update route state to [{}]",
                    new Object[] { modelId, nodeId, request.getRoutingState() }
                )
            );
            return currentState;
        }
        if (!existingAllocation.isRoutedToNode(nodeId)) {
            throw new ResourceNotFoundException("allocation for model with id [{}] is not routed to node [{}]", modelId, nodeId);
        }
        builder.getAllocation(modelId).updateExistingRoutingEntry(nodeId, request.getRoutingState()).calculateAndSetAllocationState();
        return update(currentState, builder);
    }

    /**
     * Removes the allocation for {@code modelId}.
     *
     * @throws ResourceNotFoundException if no allocation exists for {@code modelId}
     */
    static ClusterState removeAllocation(ClusterState currentState, String modelId) {
        TrainedModelAllocationMetadata.Builder builder = TrainedModelAllocationMetadata.builder(currentState);
        if (!builder.hasModel(modelId)) {
            throw new ResourceNotFoundException("allocation for model with id [{}] not found", modelId);
        }
        return update(currentState, builder.removeAllocation(modelId));
    }

    /**
     * Replaces the allocation metadata with an empty one.
     *
     * <p>Returns the state unchanged if there are no allocations.
     */
    static ClusterState removeAllAllocations(ClusterState currentState) {
        if (TrainedModelAllocationMetadata.fromState(currentState).modelAllocations().isEmpty()) {
            return currentState;
        }
        return ClusterState.builder(currentState)
            .metadata(
                Metadata.builder(currentState.metadata())
                    .putCustom("trained_model_allocation", TrainedModelAllocationMetadata.Builder.empty().build())
                    .build()
            )
            .build();
    }

    /**
     * Re-routes every non-stopping allocation against the currently eligible nodes.
     *
     * <p>Eligible nodes are ML-allocatable and not shutting down.
     */
    ClusterState addRemoveAllocationNodes(ClusterState currentState) {
        TrainedModelAllocationMetadata previousState = TrainedModelAllocationMetadata.fromState(currentState);
        TrainedModelAllocationMetadata.Builder builder = TrainedModelAllocationMetadata.builder(currentState);
        Set<String> shuttingDownNodes = nodesShuttingDown(currentState);
        Map<String, DiscoveryNode> currentEligibleNodes = currentState.getNodes()
            .getAllNodes()
            .stream()
            .filter(
                node -> !shuttingDownNodes.contains(node.getId()) && StartTrainedModelDeploymentAction.TaskParams.mayAllocateToNode(node)
            )
            .collect(Collectors.toMap(DiscoveryNode::getId, Function.identity()));
        // Allocations routed to the fewest nodes are processed first. Capacity checks later in the
        // pass see the routes added earlier, so these allocations get the first chance at capacity.
        previousState.modelAllocations()
            .entrySet()
            .stream()
            .filter(entry -> !entry.getValue().getAllocationState().equals(AllocationState.STOPPING))
            .sorted(
                Comparator.comparingInt(
                    (Map.Entry<String, TrainedModelAllocation> entry) -> entry.getValue().getNodeRoutingTable().size()
                )
            )
            .forEach(entry -> updateAllocationRoutes(currentState, builder, currentEligibleNodes, entry.getKey(), entry.getValue()));
        return update(currentState, builder);
    }

    /**
     * Reconciles one allocation's routes with the eligible nodes.
     *
     * <ul>
     *   <li>Adds a route to each eligible node that has capacity and is not yet routed.</li>
     *   <li>Records why the remaining eligible nodes were rejected, or clears the reason if none were.</li>
     *   <li>Drops routes to nodes that are no longer eligible.</li>
     *   <li>Recomputes the allocation state.</li>
     * </ul>
     */
    private void updateAllocationRoutes(
        ClusterState currentState,
        TrainedModelAllocationMetadata.Builder builder,
        Map<String, DiscoveryNode> eligibleNodes,
        String modelId,
        TrainedModelAllocation existingAllocation
    ) {
        Map<String, String> nodeToReason = new TreeMap<>();
        for (DiscoveryNode node : eligibleNodes.values()) {
            if (existingAllocation.isRoutedToNode(node.getId())) {
                continue;
            }
            // Once the builder holds pending changes, it no longer matches currentState, so node
            // load must be computed from the pending metadata instead.
            Optional<String> failure = builder.isChanged()
                ? nodeHasCapacity(currentState, builder, existingAllocation.getTaskParams(), node)
                : nodeHasCapacity(currentState, existingAllocation.getTaskParams(), node);
            if (failure.isPresent()) {
                nodeToReason.put(node.getName(), failure.get());
            } else {
                builder.getAllocation(modelId).addNewRoutingEntry(node.getId());
            }
        }
        if (nodeToReason.isEmpty()) {
            builder.getAllocation(modelId).clearReason();
        } else {
            builder.getAllocation(modelId).setReason(formatNodeReasons(nodeToReason));
        }
        for (String nodeId : existingAllocation.getNodeRoutingTable().keySet()) {
            if (!eligibleNodes.containsKey(nodeId)) {
                builder.getAllocation(modelId).removeRoutingEntry(nodeId);
            }
        }
        builder.getAllocation(modelId).calculateAndSetAllocationState();
    }

    /**
     * Renders per-node rejection reasons into one string.
     *
     * <p>Each entry reads {@code Not allocating on node [name]. Reason: reason}; entries are joined with {@code |}.
     */
    private static String formatNodeReasons(Map<String, String> nodeToReason) {
        return nodeToReason.entrySet()
            .stream()
            .map(entry -> String.format(Locale.ROOT, "Not allocating on node [%s]. Reason: %s", entry.getKey(), entry.getValue()))
            .collect(Collectors.joining("|"));
    }

    /**
     * Decides whether a cluster change warrants re-routing model allocations.
     *
     * <p>Only node membership changes and node-shutdown metadata changes are considered. Returns
     * {@code true} if some non-stopping allocation:
     * <ul>
     *   <li>is routed to a node that just began shutting down,</li>
     *   <li>is routed to a removed node that is not shutting down, or</li>
     *   <li>could use a newly added or returning node that is ML-allocatable and not shutting down.</li>
     * </ul>
     */
    static boolean shouldAllocateModels(ClusterChangedEvent event) {
        TrainedModelAllocationMetadata newMetadata = (TrainedModelAllocationMetadata) event.state()
            .getMetadata()
            .custom("trained_model_allocation");
        if (newMetadata == null) {
            return false;
        }
        boolean nodesShutdownChanged = event.changedCustomMetadataSet().contains("node_shutdown");
        if (!event.nodesChanged() && !nodesShutdownChanged) {
            return false;
        }
        Set<String> shuttingDownNodes = nodesShuttingDown(event.state());
        DiscoveryNodes.Delta nodesDelta = event.nodesDelta();
        // Explicitly mutable sets: both are extended below with nodes entering or leaving shutdown.
        Set<String> removedNodes = nodesDelta.removedNodes()
            .stream()
            .map(DiscoveryNode::getId)
            .collect(Collectors.toCollection(HashSet::new));
        Set<String> addedNodes = nodesDelta.addedNodes()
            .stream()
            .map(DiscoveryNode::getId)
            .collect(Collectors.toCollection(HashSet::new));
        Set<String> exitingShutDownNodes = Collections.emptySet();
        if (nodesShutdownChanged) {
            Set<String> previousShuttingDownNodes = nodesShuttingDown(event.previousState());
            addedNodes.addAll(Sets.difference(previousShuttingDownNodes, shuttingDownNodes));
            exitingShutDownNodes = Sets.difference(shuttingDownNodes, previousShuttingDownNodes);
            removedNodes.addAll(exitingShutDownNodes);
        }
        for (TrainedModelAllocation trainedModelAllocation : newMetadata.modelAllocations().values()) {
            if (trainedModelAllocation.getAllocationState().equals(AllocationState.STOPPING)) {
                continue;
            }
            for (String nodeId : exitingShutDownNodes) {
                if (trainedModelAllocation.isRoutedToNode(nodeId)) {
                    return true;
                }
            }
            for (String nodeId : removedNodes) {
                if (trainedModelAllocation.isRoutedToNode(nodeId) && !shuttingDownNodes.contains(nodeId)) {
                    return true;
                }
            }
            for (String nodeId : addedNodes) {
                // A node whose shutdown record was removed may already have left the cluster, in
                // which case the lookup returns null and the node cannot receive an allocation.
                DiscoveryNode addedNode = event.state().nodes().get(nodeId);
                if (addedNode != null
                    && StartTrainedModelDeploymentAction.TaskParams.mayAllocateToNode(addedNode)
                    && !shuttingDownNodes.contains(nodeId)) {
                    return true;
                }
            }
        }
        return false;
    }

    /**
     * Checks whether {@code node} can host {@code params}, based on its load in {@code state}.
     *
     * @return a human-readable rejection reason, or empty if the node has capacity
     */
    Optional<String> nodeHasCapacity(ClusterState state, StartTrainedModelDeploymentAction.TaskParams params, DiscoveryNode node) {
        NodeLoad load = nodeLoadDetector.detectNodeLoad(state, node, maxOpenJobs, maxMemoryPercentage, useAuto);
        return handleNodeLoad(load, node.getId(), params);
    }

    /**
     * Like {@link #nodeHasCapacity(ClusterState, StartTrainedModelDeploymentAction.TaskParams, DiscoveryNode)},
     * but measures load against the pending allocation metadata in {@code builder}.
     */
    Optional<String> nodeHasCapacity(
        ClusterState state,
        TrainedModelAllocationMetadata.Builder builder,
        StartTrainedModelDeploymentAction.TaskParams params,
        DiscoveryNode node
    ) {
        NodeLoad load = nodeLoadDetector.detectNodeLoad(state, builder.build(), node, maxOpenJobs, maxMemoryPercentage, useAuto);
        return handleNodeLoad(load, node.getId(), params);
    }

    /**
     * Converts a node load into a rejection reason.
     *
     * <p>A reason is returned when the load calculation failed, when the node has no job slots
     * left, or when the node's free memory is below the model's estimate.
     *
     * @return the reason, or empty if the node can take the model
     */
    Optional<String> handleNodeLoad(NodeLoad load, String nodeId, StartTrainedModelDeploymentAction.TaskParams params) {
        if (!Strings.isNullOrEmpty(load.getError())) {
            logger.warn(
                "[{}] failed to calculate current load of node [{}] with error [{}]",
                params.getModelId(),
                nodeId,
                load.getError()
            );
            return Optional.of(load.getError());
        }
        // <= rather than ==: if the open-jobs limit is lowered below the number of jobs already
        // assigned, the remaining count can go negative, and such a node is still full.
        if (load.remainingJobs() <= 0) {
            return Optional.of(
                ParameterizedMessage.format(
                    "This node is full. Number of opened jobs and allocated native inference processes [{}], {} [{}].",
                    new Object[] { load.getNumAssignedJobs(), MachineLearning.MAX_OPEN_JOBS_PER_NODE.getKey(), maxOpenJobs }
                )
            );
        }
        if (load.getFreeMemory() < params.estimateMemoryUsageBytes()) {
            return Optional.of(
                ParameterizedMessage.format(
                    "This node has insufficient available memory. Available memory for ML [{} ({})], memory required by existing jobs and models [{} ({})], estimated memory required for this model [{} ({})].",
                    new Object[] {
                        load.getMaxMlMemory(),
                        ByteSizeValue.ofBytes(load.getMaxMlMemory()).toString(),
                        load.getAssignedJobMemory(),
                        ByteSizeValue.ofBytes(load.getAssignedJobMemory()).toString(),
                        params.estimateMemoryUsageBytes(),
                        ByteSizeValue.ofBytes(params.estimateMemoryUsageBytes()).toString() }
                )
            );
        }
        return Optional.empty();
    }

    /**
     * Returns the node ids that have a node-shutdown record in {@code state}.
     *
     * <p>Returns an empty set if there is no shutdown metadata.
     */
    static Set<String> nodesShuttingDown(ClusterState state) {
        return NodesShutdownMetadata.getShutdowns(state)
            .map(NodesShutdownMetadata::getAllNodeMetadataMap)
            .map(Map::keySet)
            .orElse(Collections.emptySet());
    }
}

