/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.ml.inference.allocation;

import java.util.Objects;
import java.util.function.Predicate;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.ExceptionsHelper;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.ActionRequest;
import org.elasticsearch.action.ActionType;
import org.elasticsearch.action.support.master.AcknowledgedResponse;
import org.elasticsearch.client.internal.Client;
import org.elasticsearch.client.internal.OriginSettingClient;
import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.cluster.ClusterStateObserver;
import org.elasticsearch.cluster.MasterNodeChangePredicate;
import org.elasticsearch.cluster.NotMasterException;
import org.elasticsearch.cluster.coordination.FailedToCommitClusterStateException;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.node.NodeClosedException;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.ConnectTransportException;
import org.elasticsearch.xpack.core.ml.action.CreateTrainedModelAllocationAction;
import org.elasticsearch.xpack.core.ml.action.DeleteTrainedModelAllocationAction;
import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction;
import org.elasticsearch.xpack.core.ml.action.UpdateTrainedModelAllocationStateAction;
import org.elasticsearch.xpack.core.ml.inference.allocation.TrainedModelAllocation;
import org.elasticsearch.xpack.ml.inference.allocation.TrainedModelAllocationMetadata;

/**
 * Sends trained model allocation requests (create, delete, state update) via the internal client
 * and offers helpers for waiting on allocation-related cluster state changes.
 *
 * <p>All requests go through an {@link OriginSettingClient} tagged with the {@code "ml"} origin.
 */
public class TrainedModelAllocationService {
    private static final Logger logger = LogManager.getLogger(TrainedModelAllocationService.class);

    /**
     * Exception types (possibly wrapped) that are treated as a problem talking to the current master.
     * When a state update fails with one of these, it is re-sent once the master node changes.
     */
    private static final Class<?>[] MASTER_CHANNEL_EXCEPTIONS = new Class<?>[] {
        NotMasterException.class,
        ConnectTransportException.class,
        FailedToCommitClusterStateException.class };

    private final Client client;
    private final ClusterService clusterService;
    private final ThreadPool threadPool;

    /**
     * @param client         client used to execute allocation actions; wrapped so requests carry the {@code "ml"} origin
     * @param clusterService source of the current cluster state and of cluster state observations; must not be null
     * @param threadPool     provides the thread context for cluster state observers; must not be null
     */
    public TrainedModelAllocationService(Client client, ClusterService clusterService, ThreadPool threadPool) {
        this.client = new OriginSettingClient(client, "ml");
        this.clusterService = Objects.requireNonNull(clusterService);
        this.threadPool = Objects.requireNonNull(threadPool);
    }

    /**
     * Sends an allocation state update for a model.
     *
     * <p>The update is not sent right away in two cases. If no master node is known in the current
     * cluster state, it is sent after the observed master changes. If the first attempt fails with a
     * master-channel exception (see {@link #MASTER_CHANNEL_EXCEPTIONS}), it is re-sent after the
     * master changes. The re-sent request reports directly to {@code listener}, so it is retried at
     * most once.
     *
     * @param request  the state update to send
     * @param listener notified with the action's response or its (non-retried) failure
     */
    public void updateModelAllocationState(
        UpdateTrainedModelAllocationStateAction.Request request,
        ActionListener<AcknowledgedResponse> listener
    ) {
        ClusterState currentState = clusterService.state();
        // No timeout is passed, so a retry waits for a master change without a deadline.
        ClusterStateObserver observer = new ClusterStateObserver(
            currentState,
            clusterService,
            null,
            logger,
            threadPool.getThreadContext()
        );
        Predicate<ClusterState> changePredicate = MasterNodeChangePredicate.build(currentState);
        DiscoveryNode masterNode = currentState.nodes().getMasterNode();
        if (masterNode == null) {
            logger.warn(
                "[{}] no master known for allocation state update [{}]",
                request.getModelId(),
                request.getRoutingState().getState()
            );
            waitForNewMasterAndRetry(observer, UpdateTrainedModelAllocationStateAction.INSTANCE, request, listener, changePredicate);
            return;
        }
        client.execute(UpdateTrainedModelAllocationStateAction.INSTANCE, request, ActionListener.wrap(listener::onResponse, failure -> {
            if (isMasterChannelException(failure)) {
                logger.info(
                    "[{}] master channel exception will retry on new master node for allocation state update [{}]",
                    request.getModelId(),
                    request.getRoutingState().getState()
                );
                waitForNewMasterAndRetry(observer, UpdateTrainedModelAllocationStateAction.INSTANCE, request, listener, changePredicate);
                return;
            }
            listener.onFailure(failure);
        }));
    }

    /**
     * Requests creation of a new model allocation for the given deployment parameters.
     *
     * @param taskParams parameters of the deployment to allocate
     * @param listener   notified with the create action's response or failure
     */
    public void createNewModelAllocation(
        StartTrainedModelDeploymentAction.TaskParams taskParams,
        ActionListener<CreateTrainedModelAllocationAction.Response> listener
    ) {
        client.execute(CreateTrainedModelAllocationAction.INSTANCE, new CreateTrainedModelAllocationAction.Request(taskParams), listener);
    }

    /**
     * Requests deletion of the allocation for the given model.
     *
     * @param modelId  id of the model whose allocation should be deleted
     * @param listener notified with the delete action's response or failure
     */
    public void deleteModelAllocation(String modelId, ActionListener<AcknowledgedResponse> listener) {
        client.execute(DeleteTrainedModelAllocationAction.INSTANCE, new DeleteTrainedModelAllocationAction.Request(modelId), listener);
    }

    /**
     * Waits until {@code predicate} holds for the cluster state, then reports the model's allocation.
     *
     * <p>If the predicate already holds for the observer's current state, {@code listener} is
     * completed immediately on the calling thread. Otherwise the method waits for a cluster state
     * that satisfies it.
     *
     * @param modelId   model whose allocation is reported
     * @param predicate condition on the cluster state to wait for
     * @param timeout   passed to the {@link ClusterStateObserver}; may be null (presumably meaning no
     *                  timeout; confirm against ClusterStateObserver)
     * @param listener  receives the model's allocation, or {@code null} when
     *                  {@link TrainedModelAllocationMetadata#allocationForModelId} finds none. On timeout,
     *                  {@link WaitForAllocationListener#onTimeout} is called. If the cluster service
     *                  closes, it receives a {@link NodeClosedException}.
     */
    public void waitForAllocationCondition(
        final String modelId,
        Predicate<ClusterState> predicate,
        @Nullable TimeValue timeout,
        final WaitForAllocationListener listener
    ) {
        ClusterStateObserver observer = new ClusterStateObserver(clusterService, timeout, logger, threadPool.getThreadContext());
        ClusterState clusterState = observer.setAndGetObservedState();
        if (predicate.test(clusterState)) {
            listener.onResponse(TrainedModelAllocationMetadata.allocationForModelId(clusterState, modelId).orElse(null));
            return;
        }
        observer.waitForNextChange(new ClusterStateObserver.Listener() {
            @Override
            public void onNewClusterState(ClusterState state) {
                listener.onResponse(TrainedModelAllocationMetadata.allocationForModelId(state, modelId).orElse(null));
            }

            @Override
            public void onClusterServiceClose() {
                listener.onFailure(new NodeClosedException(clusterService.localNode()));
            }

            @Override
            public void onTimeout(TimeValue timeout) {
                listener.onTimeout(timeout);
            }
        }, predicate);
    }

    /**
     * Waits for the next cluster state that matches {@code changePredicate}, then executes
     * {@code action} once, reporting directly to {@code listener}.
     *
     * @param observer        observer used to wait for the change
     * @param action          action to execute after the change
     * @param request         request to send with {@code action}
     * @param listener        notified with the action's outcome. It is also notified with a failure if
     *                        the cluster service closes or the observer times out.
     * @param changePredicate condition on the cluster state to wait for (e.g. a master node change)
     */
    protected void waitForNewMasterAndRetry(
        ClusterStateObserver observer,
        final ActionType<AcknowledgedResponse> action,
        final ActionRequest request,
        final ActionListener<AcknowledgedResponse> listener,
        Predicate<ClusterState> changePredicate
    ) {
        observer.waitForNextChange(new ClusterStateObserver.Listener() {
            @Override
            public void onNewClusterState(ClusterState state) {
                client.execute(action, request, listener);
            }

            @Override
            public void onClusterServiceClose() {
                logger.warn("node closed while execution action [{}] for request [{}]", action.name(), request);
                listener.onFailure(new NodeClosedException(clusterService.localNode()));
            }

            @Override
            public void onTimeout(TimeValue timeout) {
                // updateModelAllocationState builds its observer without a timeout, so this is not
                // expected there. The assert flags it in test builds. Because this method is protected
                // and accepts any observer, the listener is still failed so it is never left hanging
                // when assertions are disabled.
                assert false : "timed out waiting for a new master node";
                listener.onFailure(
                    new IllegalStateException(
                        "Timed out after [" + timeout + "] waiting for a new master node to execute action [" + action.name() + "]"
                    )
                );
            }
        }, changePredicate);
    }

    /** Returns true if {@code exp}, or one of its causes, is one of {@link #MASTER_CHANNEL_EXCEPTIONS}. */
    private static boolean isMasterChannelException(Exception exp) {
        return ExceptionsHelper.unwrap(exp, MASTER_CHANNEL_EXCEPTIONS) != null;
    }

    /**
     * Listener for {@link #waitForAllocationCondition}. It adds a timeout callback that, by default,
     * fails the listener with an {@link IllegalStateException}.
     */
    public interface WaitForAllocationListener extends ActionListener<TrainedModelAllocation> {
        /**
         * Called when waiting for the allocation condition times out.
         *
         * @param timeout the timeout value reported by the cluster state observer
         */
        default void onTimeout(TimeValue timeout) {
            onFailure(new IllegalStateException("Timed out when waiting for trained model allocation after " + timeout));
        }
    }
}

