/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.ml.inference.pytorch.process;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction;
import org.elasticsearch.xpack.core.ml.inference.assignment.Priority;
import org.elasticsearch.xpack.ml.process.NativeController;
import org.elasticsearch.xpack.ml.process.ProcessPipes;

/**
 * Assembles the command line for the native {@code pytorch_inference} process
 * from the deployment task parameters and asks the {@link NativeController}
 * to start it.
 */
public class PyTorchBuilder {
    /** Name of the native inference executable. */
    public static final String PROCESS_NAME = "pytorch_inference";

    /** Path of the executable as placed at the head of the command. */
    private static final String PROCESS_PATH = "./" + PROCESS_NAME;

    // Command line switches; each one ending in '=' is followed directly by its value.
    private static final String LICENSE_KEY_VALIDATED_ARG = "--validElasticLicenseKeyConfirmed=";
    private static final String NUM_THREADS_PER_ALLOCATION_ARG = "--numThreadsPerAllocation=";
    private static final String NUM_ALLOCATIONS_ARG = "--numAllocations=";
    private static final String CACHE_MEMORY_LIMIT_BYTES_ARG = "--cacheMemorylimitBytes=";
    private static final String LOW_PRIORITY_ARG = "--lowPriority";

    private final NativeController nativeController;
    private final ProcessPipes processPipes;
    private final StartTrainedModelDeploymentAction.TaskParams taskParams;

    /**
     * Creates a builder for a single process launch.
     *
     * @param nativeController controller that is asked to start the process; must not be {@code null}
     * @param processPipes     pipes that receive the command via {@code addArgs}; must not be {@code null}
     * @param taskParams       deployment parameters the arguments are derived from; must not be {@code null}
     * @throws NullPointerException if any argument is {@code null}
     */
    public PyTorchBuilder(NativeController nativeController, ProcessPipes processPipes, StartTrainedModelDeploymentAction.TaskParams taskParams) {
        this.nativeController = Objects.requireNonNull(nativeController);
        this.processPipes = Objects.requireNonNull(processPipes);
        this.taskParams = Objects.requireNonNull(taskParams);
    }

    /**
     * Builds the command, hands it to {@code ProcessPipes.addArgs} (presumably so the
     * pipe related arguments get appended - confirm against {@code ProcessPipes}) and
     * then passes the same list to {@code NativeController.startProcess}.
     *
     * @throws IOException          declared for the calls made to the pipes and the native controller
     * @throws InterruptedException declared for the calls made to the pipes and the native controller
     */
    public void build() throws IOException, InterruptedException {
        final List<String> launchArgs = buildCommand();
        processPipes.addArgs(launchArgs);
        nativeController.startProcess(launchArgs);
    }

    /**
     * Produces the argument list: executable path, license confirmation, thread and
     * allocation counts, then the optional cache limit and low priority switches.
     *
     * @return a new mutable list holding the command
     */
    private List<String> buildCommand() {
        final List<String> args = new ArrayList<>();

        args.add(PROCESS_PATH);
        // Always emitted as "=true"; the expression is a compile-time constant.
        args.add(LICENSE_KEY_VALIDATED_ARG + true);
        args.add(NUM_THREADS_PER_ALLOCATION_ARG + taskParams.getThreadsPerAllocation());
        args.add(NUM_ALLOCATIONS_ARG + taskParams.getNumberOfAllocations());

        // The cache limit switch is only passed for a strictly positive size.
        if (taskParams.getCacheSizeBytes() > 0L) {
            args.add(CACHE_MEMORY_LIMIT_BYTES_ARG + taskParams.getCacheSizeBytes());
        }

        // The low priority switch is a bare flag without a value.
        if (Priority.LOW == taskParams.getPriority()) {
            args.add(LOW_PRIORITY_ARG);
        }

        return args;
    }
}

