/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.ml.packageloader.action;

import java.io.IOException;
import java.io.InputStream;
import java.net.URI;
import java.net.URISyntaxException;
import java.util.Objects;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.action.ActionRequest;
import org.elasticsearch.action.ActionResponse;
import org.elasticsearch.action.ActionType;
import org.elasticsearch.client.internal.Client;
import org.elasticsearch.common.bytes.BytesArray;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.core.Strings;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.tasks.TaskCancelledException;
import org.elasticsearch.xpack.core.ml.action.PutTrainedModelDefinitionPartAction;
import org.elasticsearch.xpack.core.ml.action.PutTrainedModelVocabularyAction;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ModelPackageConfig;
import org.elasticsearch.xpack.ml.packageloader.action.ModelDownloadTask;
import org.elasticsearch.xpack.ml.packageloader.action.ModelLoaderUtils;

/**
 * Imports a packaged trained model into the cluster.
 *
 * <p>If the package config names a vocabulary file, it is uploaded first. The model definition is
 * then streamed from the model repository in fixed-size chunks. Each chunk is stored through
 * {@link PutTrainedModelDefinitionPartAction}. The final chunk is only stored after the SHA-256
 * digest and byte count of the whole stream match the package config.
 */
class ModelImporter {
    /** Size in bytes (1 MiB) of each definition part read from the repository stream. */
    private static final int DEFAULT_CHUNK_SIZE = 1024 * 1024;
    private static final Logger logger = LogManager.getLogger(ModelImporter.class);

    private final Client client;
    private final String modelId;
    private final ModelPackageConfig config;
    private final ModelDownloadTask task;

    /**
     * @param client        client used to execute the vocabulary and definition-part requests
     * @param modelId       id under which the model is stored; must not be null
     * @param packageConfig package metadata (repository, file names, size, sha256); must not be null
     * @param task          task used for progress reporting and cancellation checks; must not be null
     */
    ModelImporter(Client client, String modelId, ModelPackageConfig packageConfig, ModelDownloadTask task) {
        this.client = client;
        this.modelId = Objects.requireNonNull(modelId);
        this.config = Objects.requireNonNull(packageConfig);
        this.task = Objects.requireNonNull(task);
    }

    /**
     * Uploads the vocabulary (if configured) and the model definition.
     *
     * <p>All parts except the last are uploaded as they are read. The last part is withheld until
     * the whole stream has been checksummed and its length checked. A corrupt or truncated download
     * therefore never results in a complete set of parts being stored.
     *
     * @throws URISyntaxException if a package location cannot be turned into a URI
     * @throws IOException if opening, reading or closing the model stream fails
     * @throws ElasticsearchStatusException if the streamed model's sha256 or size differs from the config
     * @throws TaskCancelledException if the task is cancelled before a request is sent
     */
    public void doImport() throws URISyntaxException, IOException, ElasticsearchStatusException {
        long size = config.getSize();

        if (!org.elasticsearch.common.Strings.isNullOrEmpty(config.getVocabularyFile())) {
            uploadVocabulary();
            logger.debug(() -> Strings.format("[%s] imported model vocabulary [%s]", modelId, config.getVocabularyFile()));
        }

        URI uri = ModelLoaderUtils.resolvePackageLocation(
            config.getModelRepository(),
            config.getPackagedModelId() + ModelLoaderUtils.MODEL_FILE_EXTENSION
        );

        // Ceiling division (in long arithmetic) so a trailing partial chunk gets its own part.
        int totalParts = (int) ((size + DEFAULT_CHUNK_SIZE - 1L) / DEFAULT_CHUNK_SIZE);

        // try-with-resources closes the repository stream on success as well as on verification
        // failure or cancellation.
        try (InputStream modelInputStream = ModelLoaderUtils.getInputStreamFromModelRepository(uri)) {
            ModelLoaderUtils.InputStreamChunker chunkIterator = new ModelLoaderUtils.InputStreamChunker(
                modelInputStream,
                DEFAULT_CHUNK_SIZE
            );

            for (int part = 0; part < totalParts - 1; ++part) {
                task.setProgress(totalParts, part);
                BytesArray definition = chunkIterator.next();
                PutTrainedModelDefinitionPartAction.Request modelPartRequest = new PutTrainedModelDefinitionPartAction.Request(
                    modelId,
                    definition,
                    part,
                    size,
                    totalParts,
                    true
                );
                executeRequestIfNotCancelled(PutTrainedModelDefinitionPartAction.INSTANCE, modelPartRequest);
            }

            // Read the last chunk, but verify the complete stream before uploading it.
            BytesArray finalDefinition = chunkIterator.next();
            verifyChecksumAndSize(chunkIterator);

            PutTrainedModelDefinitionPartAction.Request finalModelPartRequest = new PutTrainedModelDefinitionPartAction.Request(
                modelId,
                finalDefinition,
                totalParts - 1,
                size,
                totalParts,
                true
            );
            executeRequestIfNotCancelled(PutTrainedModelDefinitionPartAction.INSTANCE, finalModelPartRequest);
        }

        // Supplier form so the message is only formatted when debug logging is enabled.
        logger.debug(() -> Strings.format("finished importing model [%s] using [%d] parts", modelId, totalParts));
    }

    /**
     * Compares the digest and byte count accumulated by the chunker against the package config.
     *
     * @param chunkIterator chunker that has consumed the entire model stream
     * @throws ElasticsearchStatusException (INTERNAL_SERVER_ERROR) on a sha256 or size mismatch
     */
    private void verifyChecksumAndSize(ModelLoaderUtils.InputStreamChunker chunkIterator) {
        if (!config.getSha256().equals(chunkIterator.getSha256())) {
            String message = Strings.format(
                "Model sha256 checksums do not match, expected [%s] but got [%s]",
                config.getSha256(),
                chunkIterator.getSha256()
            );
            throw new ElasticsearchStatusException(message, RestStatus.INTERNAL_SERVER_ERROR);
        }
        if (config.getSize() != chunkIterator.getTotalBytesRead()) {
            String message = Strings.format(
                "Model size does not match, expected [%d] but got [%d]",
                config.getSize(),
                chunkIterator.getTotalBytesRead()
            );
            throw new ElasticsearchStatusException(message, RestStatus.INTERNAL_SERVER_ERROR);
        }
    }

    /**
     * Loads the configured vocabulary file from the model repository and stores it for {@code modelId}.
     *
     * @throws URISyntaxException if the vocabulary location cannot be turned into a URI
     */
    private void uploadVocabulary() throws URISyntaxException {
        URI vocabularyLocation = ModelLoaderUtils.resolvePackageLocation(config.getModelRepository(), config.getVocabularyFile());
        ModelLoaderUtils.VocabularyParts vocabularyParts = ModelLoaderUtils.loadVocabulary(vocabularyLocation);
        PutTrainedModelVocabularyAction.Request request = new PutTrainedModelVocabularyAction.Request(
            modelId,
            vocabularyParts.vocab(),
            vocabularyParts.merges(),
            vocabularyParts.scores(),
            true
        );
        executeRequestIfNotCancelled(PutTrainedModelVocabularyAction.INSTANCE, request);
    }

    /**
     * Executes {@code request} and blocks on {@code actionGet()} until it completes, unless the task
     * has already been cancelled.
     *
     * @throws TaskCancelledException if the task is cancelled; the request is not sent in that case
     */
    private <Request extends ActionRequest, Response extends ActionResponse> void executeRequestIfNotCancelled(
        ActionType<Response> action,
        Request request
    ) {
        if (task.isCancelled()) {
            throw new TaskCancelledException(Strings.format("task cancelled with reason [%s]", task.getReasonCancelled()));
        }
        client.execute(action, request).actionGet();
    }
}

