/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.ml.inference.persistence;

import java.io.IOException;
import java.io.InputStream;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.function.Consumer;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.ResourceNotFoundException;
import org.elasticsearch.action.search.SearchPhaseExecutionException;
import org.elasticsearch.action.search.SearchRequest;
import org.elasticsearch.action.search.SearchRequestBuilder;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.client.internal.Client;
import org.elasticsearch.client.internal.OriginSettingClient;
import org.elasticsearch.common.breaker.CircuitBreakingException;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.xcontent.LoggingDeprecationHandler;
import org.elasticsearch.core.CheckedFunction;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.core.Strings;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.sort.FieldSortBuilder;
import org.elasticsearch.search.sort.SortBuilder;
import org.elasticsearch.search.sort.SortBuilders;
import org.elasticsearch.search.sort.SortOrder;
import org.elasticsearch.tasks.TaskId;
import org.elasticsearch.xcontent.DeprecationHandler;
import org.elasticsearch.xcontent.NamedXContentRegistry;
import org.elasticsearch.xcontent.XContentFactory;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants;
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelDefinitionDoc;

/**
 * Restores a trained model definition that is stored as a sequence of {@link TrainedModelDefinitionDoc}
 * chunks, paging through them with {@code search_after} and handing each chunk to a caller-supplied consumer.
 * <p>
 * Each page is fetched by a task submitted to the supplied {@link ExecutorService}; the task for the next
 * page is only submitted once the current page has been fully processed, so the model consumer is never
 * invoked concurrently. {@link #getNumDocsWritten()} accumulates over the lifetime of the instance and is
 * never reset.
 */
public class ChunkedTrainedModelRestorer {
    private static final Logger logger = LogManager.getLogger(ChunkedTrainedModelRestorer.class);

    /** Largest page size accepted by {@link #setSearchSize(int)}. */
    private static final int MAX_NUM_DEFINITION_DOCS = 20;
    /** Number of retries after the initial attempt for transient search failures. */
    private static final int SEARCH_RETRY_LIMIT = 5;
    /** Wait between retries of a failed search. */
    private static final TimeValue SEARCH_FAILURE_RETRY_WAIT_TIME = new TimeValue(5L, TimeUnit.SECONDS);
    // Substrings of the thread names doSearch is expected (and asserted) to run on.
    private static final String NATIVE_INFERENCE_COMMS_THREAD_POOL = "ml_native_inference_comms";
    private static final String UTILITY_THREAD_POOL = "ml_utility";
    private static final String MODEL_DEFINITION_NOT_FOUND = "Could not find trained model definition [{0}]";

    private final Client client;
    private final NamedXContentRegistry xContentRegistry;
    private final ExecutorService executorService;
    private final String modelId;
    private String index = ".ml-inference-*";
    private int searchSize = 10;
    private int numDocsWritten = 0;

    /**
     * @param modelId          id of the model whose definition docs are restored
     * @param client           client used for searching; it is wrapped so requests carry the {@code "ml"} origin
     * @param executorService  executor on which each search page is fetched and processed
     * @param xContentRegistry registry used when parsing the definition docs
     */
    public ChunkedTrainedModelRestorer(String modelId, Client client, ExecutorService executorService, NamedXContentRegistry xContentRegistry) {
        this.client = new OriginSettingClient(client, "ml");
        this.executorService = executorService;
        this.xContentRegistry = xContentRegistry;
        this.modelId = modelId;
    }

    /**
     * Sets how many definition docs are fetched per search page.
     *
     * @param searchSize page size; must be greater than 0 and at most {@value #MAX_NUM_DEFINITION_DOCS}
     * @throws IllegalArgumentException if {@code searchSize} is outside that range
     */
    public void setSearchSize(int searchSize) {
        if (searchSize > MAX_NUM_DEFINITION_DOCS) {
            throw new IllegalArgumentException("search size [" + searchSize + "] cannot be bigger than [" + MAX_NUM_DEFINITION_DOCS + "]");
        }
        if (searchSize <= 0) {
            throw new IllegalArgumentException("search size [" + searchSize + "] must be greater than 0");
        }
        this.searchSize = searchSize;
    }

    /** Overrides the index name or pattern searched for definition docs (default {@code .ml-inference-*}). */
    public void setSearchIndex(String indexNameOrPattern) {
        this.index = indexNameOrPattern;
    }

    /** Returns the total number of definition docs handed to model consumers so far by this instance. */
    public int getNumDocsWritten() {
        return this.numDocsWritten;
    }

    /**
     * Asynchronously streams the model definition docs, in search order, to {@code modelConsumer}.
     * Exactly one of the completion callbacks is invoked per restore.
     *
     * @param modelConsumer   receives each parsed doc; returning {@code false} stops the restore early
     * @param successConsumer called with {@code true} when every doc was consumed, or {@code false} when
     *                        {@code modelConsumer} asked to stop early
     * @param errorConsumer   called on failure; a {@link ResourceNotFoundException} is reported when no
     *                        definition docs (or no matching index) could be found
     */
    public void restoreModelDefinition(
        CheckedFunction<TrainedModelDefinitionDoc, Boolean, IOException> modelConsumer,
        Consumer<Boolean> successConsumer,
        Consumer<Exception> errorConsumer
    ) {
        logger.debug("[{}] restoring model", modelId);
        SearchRequest searchRequest = buildSearch(client, modelId, index, searchSize, null);
        executorService.execute(() -> doSearch(searchRequest, modelConsumer, successConsumer, errorConsumer));
    }

    /**
     * Runs one page of the restore: searches, feeds every hit to {@code modelConsumer}, then either completes
     * or schedules the next page on the executor using {@code search_after}.
     */
    private void doSearch(
        SearchRequest searchRequest,
        CheckedFunction<TrainedModelDefinitionDoc, Boolean, IOException> modelConsumer,
        Consumer<Boolean> successConsumer,
        Consumer<Exception> errorConsumer
    ) {
        assert Thread.currentThread().getName().contains(NATIVE_INFERENCE_COMMS_THREAD_POOL)
            || Thread.currentThread().getName().contains(UTILITY_THREAD_POOL)
            : Strings.format(
                "Must execute from [%s] or [%s] but thread is [%s]",
                NATIVE_INFERENCE_COMMS_THREAD_POOL,
                UTILITY_THREAD_POOL,
                Thread.currentThread().getName()
            );
        try {
            SearchResponse searchResponse = retryingSearch(
                client,
                modelId,
                searchRequest,
                SEARCH_RETRY_LIMIT,
                SEARCH_FAILURE_RETRY_WAIT_TIME
            );
            SearchHit[] hits = searchResponse.getHits().getHits();
            if (hits.length == 0) {
                errorConsumer.accept(modelDefinitionNotFound());
                return;
            }

            // hits is non-empty, so this is always overwritten before it is used for search_after
            int lastDocNum = -1;
            for (SearchHit hit : hits) {
                logger.debug(() -> Strings.format("[%s] Restoring model definition doc with id [%s]", modelId, hit.getId()));
                try {
                    TrainedModelDefinitionDoc doc = parseModelDefinitionDocLenientlyFromSource(hit.getSourceRef(), modelId, xContentRegistry);
                    lastDocNum = doc.getDocNum();
                    boolean continueSearching = modelConsumer.apply(doc);
                    if (continueSearching == false) {
                        successConsumer.accept(Boolean.FALSE);
                        return;
                    }
                } catch (IOException e) {
                    logger.error(() -> "[" + modelId + "] error writing model definition", e);
                    errorConsumer.accept(e);
                    return;
                }
            }

            numDocsWritten += hits.length;
            boolean endOfSearch = hits.length < searchSize || searchResponse.getHits().getTotalHits().value == numDocsWritten;
            if (endOfSearch) {
                successConsumer.accept(Boolean.TRUE);
                return;
            }

            // Continue after the last hit using the same [_index, doc_num] sort values as buildSearchBuilder.
            SearchHit lastHit = hits[hits.length - 1];
            SearchRequestBuilder nextPageBuilder = buildSearchBuilder(client, modelId, index, searchSize);
            nextPageBuilder.searchAfter(new Object[] { lastHit.getIndex(), lastDocNum });
            // Built here rather than inside the task so that build failures reach errorConsumer below.
            SearchRequest nextPageRequest = nextPageBuilder.request();
            executorService.execute(() -> doSearch(nextPageRequest, modelConsumer, successConsumer, errorConsumer));
        } catch (Exception e) {
            if (e instanceof InterruptedException) {
                // Preserve the interrupt status for the executing thread instead of silently clearing it.
                Thread.currentThread().interrupt();
            }
            if (ExceptionsHelper.unwrapCause(e) instanceof ResourceNotFoundException) {
                errorConsumer.accept(modelDefinitionNotFound());
            } else {
                errorConsumer.accept(e);
            }
        }
    }

    /** Creates the exception reported when no definition docs can be found for this restorer's model. */
    private ResourceNotFoundException modelDefinitionNotFound() {
        return new ResourceNotFoundException(Messages.getMessage(MODEL_DEFINITION_NOT_FOUND, modelId));
    }

    /**
     * Executes {@code searchRequest} synchronously, retrying failures whose unwrapped cause is a
     * {@link SearchPhaseExecutionException} or {@link CircuitBreakingException}. Any other failure is
     * rethrown unchanged.
     *
     * @param client        client to search with
     * @param modelId       model id, used only for log and error messages
     * @param searchRequest request to execute
     * @param retries       maximum number of retries after the initial attempt
     * @param sleep         wait between attempts (millisecond precision)
     * @return the first successful response
     * @throws ElasticsearchException wrapping the last retryable failure once {@code retries} retries are exhausted
     * @throws InterruptedException   if interrupted while waiting between attempts
     */
    static SearchResponse retryingSearch(Client client, String modelId, SearchRequest searchRequest, int retries, TimeValue sleep)
        throws InterruptedException {
        int failureCount = 0;
        while (true) {
            try {
                return client.search(searchRequest).actionGet();
            } catch (Exception e) {
                Throwable cause = ExceptionsHelper.unwrapCause(e);
                boolean retryable = cause instanceof SearchPhaseExecutionException || cause instanceof CircuitBreakingException;
                if (retryable == false) {
                    throw e;
                }
                if (failureCount >= retries) {
                    logger.warn(Strings.format("[%s] searching for model part failed %s times, returning failure", modelId, retries));
                    throw new ElasticsearchException(
                        Strings.format(
                            "loading model [%s] failed after [%s] retries. The deployment is now in a failed state, "
                                + "the error may be transient please stop the deployment and restart",
                            modelId,
                            retries
                        ),
                        e
                    );
                }
                failureCount++;
                final int failuresSoFar = failureCount;
                logger.debug(() -> Strings.format("[%s] searching for model part failed %s times, retrying", modelId, failuresSoFar));
                // Millisecond precision so sub-second waits are not truncated to zero.
                TimeUnit.MILLISECONDS.sleep(sleep.millis());
            }
        }
    }

    /**
     * Builds the definition-doc search: filters on the model id and definition doc type, tracks total hits,
     * and sorts by {@code _index} descending then {@code doc_num} ascending. The two sort values are what
     * {@code search_after} uses to page through the docs.
     */
    private static SearchRequestBuilder buildSearchBuilder(Client client, String modelId, String index, int searchSize) {
        QueryBuilder query = QueryBuilders.constantScoreQuery(
            QueryBuilders.boolQuery()
                .filter(QueryBuilders.termQuery(TrainedModelConfig.MODEL_ID.getPreferredName(), modelId))
                .filter(QueryBuilders.termQuery(InferenceIndexConstants.DOC_TYPE.getPreferredName(), "trained_model_definition_doc"))
        );
        // unmappedType keeps the sort valid on matching indices that have no doc_num mapping.
        FieldSortBuilder docNumSort = SortBuilders.fieldSort(TrainedModelDefinitionDoc.DOC_NUM.getPreferredName())
            .order(SortOrder.ASC)
            .unmappedType("long");
        return client.prepareSearch(index)
            .setQuery(query)
            .setSize(searchSize)
            .setTrackTotalHits(true)
            .addSort("_index", SortOrder.DESC)
            .addSort(docNumSort);
    }

    /**
     * Builds the first-page search request for a model's definition docs.
     *
     * @param parentTaskId optional parent task to attach the request to; ignored when {@code null}
     */
    public static SearchRequest buildSearch(Client client, String modelId, String index, int searchSize, @Nullable TaskId parentTaskId) {
        SearchRequest searchRequest = buildSearchBuilder(client, modelId, index, searchSize).request();
        if (parentTaskId != null) {
            searchRequest.setParentTask(parentTaskId);
        }
        return searchRequest;
    }

    /**
     * Parses a JSON definition doc source, passing {@code true} (lenient) to
     * {@link TrainedModelDefinitionDoc#fromXContent}. Parse failures are logged and rethrown.
     *
     * @throws IOException if the source cannot be read or parsed
     */
    public static TrainedModelDefinitionDoc parseModelDefinitionDocLenientlyFromSource(
        BytesReference source,
        String modelId,
        NamedXContentRegistry xContentRegistry
    ) throws IOException {
        try (
            StreamInput stream = source.streamInput();
            XContentParser parser = XContentFactory.xContent(XContentType.JSON)
                .createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, stream)
        ) {
            return TrainedModelDefinitionDoc.fromXContent(parser, true).build();
        } catch (IOException e) {
            logger.error(() -> "[" + modelId + "] failed to parse model definition", e);
            throw e;
        }
    }
}

