/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.ml.inference.pytorch.process;

import java.io.IOException;
import java.io.OutputStream;
import java.nio.ByteBuffer;
import java.util.Locale;
import java.util.Objects;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.atomic.AtomicInteger;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.client.internal.Client;
import org.elasticsearch.client.internal.OriginSettingClient;
import org.elasticsearch.core.CheckedFunction;
import org.elasticsearch.xcontent.NamedXContentRegistry;
import org.elasticsearch.xpack.ml.inference.persistence.ChunkedTrainedModelRestorer;
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelDefinitionDoc;

/**
 * Streams a PyTorch model definition, stored as a sequence of {@link TrainedModelDefinitionDoc}
 * chunks, to an {@link OutputStream}.
 * <p>
 * Output format: a 4-byte prelude holding the total definition length (written with
 * {@link ByteBuffer#putInt}, i.e. big-endian), followed by the raw bytes of every chunk in the
 * order the {@link ChunkedTrainedModelRestorer} delivers them.
 * <p>
 * The expected size and the running byte count are instance state and are never reset. A second
 * {@link #writeStateToStream} call on the same instance would therefore skip the prelude and keep
 * accumulating bytes, so use one instance per restore.
 */
public class PyTorchStateStreamer {
    private static final Logger logger = LogManager.getLogger(PyTorchStateStreamer.class);

    /** Size of the length prelude written ahead of the model bytes: a single {@code int}. */
    private static final int NUM_BYTES_IN_PRELUDE = Integer.BYTES;

    private final OriginSettingClient client;
    private final ExecutorService executorService;
    private final NamedXContentRegistry xContentRegistry;

    // volatile so that a cancel() issued on another thread is visible to writeChunk
    private volatile boolean isCancelled;

    // -1 until the first chunk is processed and the prelude has been written
    private volatile int modelSize = -1;

    private final AtomicInteger modelBytesWritten = new AtomicInteger();

    /**
     * @param client           client wrapped in an {@link OriginSettingClient} with origin {@code "ml"}
     *                         and handed to the restorer
     * @param executorService  executor handed to the {@link ChunkedTrainedModelRestorer}
     * @param xContentRegistry registry handed to the {@link ChunkedTrainedModelRestorer}
     * @throws NullPointerException if any argument is {@code null}
     */
    public PyTorchStateStreamer(Client client, ExecutorService executorService, NamedXContentRegistry xContentRegistry) {
        this.client = new OriginSettingClient(Objects.requireNonNull(client), "ml");
        this.executorService = Objects.requireNonNull(executorService);
        this.xContentRegistry = Objects.requireNonNull(xContentRegistry);
    }

    /**
     * Requests cancellation.
     * <p>
     * Any chunk delivered after this call is not written, and the chunk consumer returns
     * {@code false} for it.
     */
    public void cancel() {
        isCancelled = true;
    }

    /**
     * Restores the model definition docs for {@code modelId} from {@code index} and writes them to
     * {@code restoreStream}.
     * <p>
     * Docs are requested one at a time (search size 1). On completion, {@code listener} receives
     * the restorer's success flag. {@code false} is logged as a cancellation. If the byte count
     * written does not match the size declared in the first doc, an error is logged but the
     * listener still receives the success flag. Restore errors are passed to
     * {@link ActionListener#onFailure}.
     *
     * @param modelId       id of the model whose definition is restored
     * @param index         index searched for the definition docs
     * @param restoreStream destination for the prelude and model bytes
     * @param listener      notified with the success flag or the failure
     */
    public void writeStateToStream(String modelId, String index, OutputStream restoreStream, ActionListener<Boolean> listener) {
        ChunkedTrainedModelRestorer restorer = new ChunkedTrainedModelRestorer(modelId, client, executorService, xContentRegistry);
        restorer.setSearchIndex(index);
        restorer.setSearchSize(1);
        restorer.restoreModelDefinition(doc -> writeChunk(doc, restoreStream), success -> {
            logger.debug("model [{}] state restored in [{}] documents from index [{}]", modelId, restorer.getNumDocsWritten(), index);
            if (success) {
                // Read the counter once so the comparison and the log report the same value
                int bytesWritten = modelBytesWritten.get();
                if (bytesWritten != modelSize) {
                    logger.error(
                        "model [{}] restored state size [{}] does not equal the expected model size [{}]",
                        modelId,
                        bytesWritten,
                        modelSize
                    );
                }
            } else {
                logger.info("[{}] loading model state cancelled", modelId);
            }
            listener.onResponse(success);
        }, listener::onFailure);
    }

    /**
     * Writes one definition chunk to {@code outputStream}.
     * <p>
     * The size prelude is written first if this is the first chunk.
     *
     * @return {@code false} if cancellation was requested (nothing is written), {@code true} otherwise
     * @throws IOException if writing to the stream fails
     * @throws IllegalStateException if the first doc's total definition length is invalid
     */
    private boolean writeChunk(TrainedModelDefinitionDoc doc, OutputStream outputStream) throws IOException {
        if (isCancelled) {
            return false;
        }
        if (modelSize == -1) {
            modelSize = writeModelSize(doc.getModelId(), doc.getTotalDefinitionLength(), outputStream);
        }
        // Write only the slice of the backing array this doc refers to (offset + length),
        // not the whole array
        outputStream.write(doc.getBinaryData().array(), doc.getBinaryData().arrayOffset(), doc.getBinaryData().length());
        modelBytesWritten.addAndGet(doc.getBinaryData().length());
        return true;
    }

    /**
     * Validates the declared model size and writes it to {@code outputStream} as a 4-byte
     * big-endian {@code int} prelude.
     *
     * @return the model size as an {@code int}
     * @throws IllegalStateException if the size is {@code null}, not positive, or larger than
     *                               {@link Integer#MAX_VALUE}
     * @throws IOException           if writing to the stream fails
     */
    private int writeModelSize(String modelId, Long modelSizeBytes, OutputStream outputStream) throws IOException {
        if (modelSizeBytes == null) {
            throw logAndCreateException(
                String.format(
                    Locale.ROOT,
                    "The definition doc for model [%s] has a null value for field [%s]",
                    modelId,
                    TrainedModelDefinitionDoc.TOTAL_DEFINITION_LENGTH.getPreferredName()
                )
            );
        }
        if (modelSizeBytes <= 0L) {
            // Zero is rejected too, so the message says "non-positive" rather than "negative"
            throw logAndCreateException(
                String.format(
                    Locale.ROOT,
                    "The definition doc for model [%s] has a non-positive value [%s] for field [%s]",
                    modelId,
                    modelSizeBytes,
                    TrainedModelDefinitionDoc.TOTAL_DEFINITION_LENGTH.getPreferredName()
                )
            );
        }
        if (modelSizeBytes > Integer.MAX_VALUE) {
            // The prelude is a single 4-byte int, so larger sizes cannot be represented
            throw logAndCreateException(
                String.format(
                    Locale.ROOT,
                    "model [%s] has a size [%s] larger than the max size [%s]",
                    modelId,
                    modelSizeBytes,
                    Integer.MAX_VALUE
                )
            );
        }
        int size = modelSizeBytes.intValue();
        outputStream.write(ByteBuffer.allocate(NUM_BYTES_IN_PRELUDE).putInt(size).array());
        return size;
    }

    /** Logs {@code message} at error level and returns an {@link IllegalStateException} carrying it, for the caller to throw. */
    private static IllegalStateException logAndCreateException(String message) {
        logger.error(message);
        return new IllegalStateException(message);
    }
}

