/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.ml.inference.pytorch.process;

import java.time.Instant;
import java.util.Iterator;
import java.util.LongSummaryStatistics;
import java.util.Objects;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.function.Consumer;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.logging.log4j.message.Message;
import org.apache.logging.log4j.message.ParameterizedMessage;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.xpack.ml.inference.pytorch.process.NativePyTorchProcess;
import org.elasticsearch.xpack.ml.inference.pytorch.results.PyTorchInferenceResult;
import org.elasticsearch.xpack.ml.inference.pytorch.results.PyTorchResult;
import org.elasticsearch.xpack.ml.inference.pytorch.results.ThreadSettings;

public class PyTorchResultProcessor {
    private static final Logger logger = LogManager.getLogger(PyTorchResultProcessor.class);
    private final ConcurrentMap<String, PendingResult> pendingResults = new ConcurrentHashMap<String, PendingResult>();
    private final String deploymentId;
    private volatile boolean isStopping;
    private final LongSummaryStatistics timingStats;
    private final Consumer<ThreadSettings> threadSettingsConsumer;
    private int errorCount;
    private Instant lastUsed;

    public PyTorchResultProcessor(String deploymentId, Consumer<ThreadSettings> threadSettingsConsumer) {
        this.deploymentId = Objects.requireNonNull(deploymentId);
        this.timingStats = new LongSummaryStatistics();
        this.threadSettingsConsumer = Objects.requireNonNull(threadSettingsConsumer);
    }

    public void registerRequest(String requestId, ActionListener<PyTorchInferenceResult> listener) {
        this.pendingResults.computeIfAbsent(requestId, k -> new PendingResult(listener));
    }

    public void ignoreResponseWithoutNotifying(String requestId) {
        this.pendingResults.remove(requestId);
    }

    /*
     * WARNING - Removed try catching itself - possible behaviour change.
     */
    public void process(NativePyTorchProcess process) {
        try {
            Iterator<PyTorchResult> iterator = process.readResults();
            while (iterator.hasNext()) {
                ThreadSettings threadSettings;
                PyTorchResult result = iterator.next();
                PyTorchInferenceResult inferenceResult = result.inferenceResult();
                if (inferenceResult != null) {
                    this.processInferenceResult(inferenceResult);
                }
                if ((threadSettings = result.threadSettings()) == null) continue;
                this.threadSettingsConsumer.accept(threadSettings);
            }
        }
        catch (Exception e) {
            if (!this.isStopping) {
                logger.error((Message)new ParameterizedMessage("[{}] Error processing results", (Object)this.deploymentId), (Throwable)e);
            }
            this.pendingResults.forEach((id, pendingResult) -> pendingResult.listener.onResponse((Object)new PyTorchInferenceResult((String)id, null, null, (String)(this.isStopping ? "inference canceled as process is stopping" : "inference native process died unexpectedly with failure [" + e.getMessage() + "]"))));
            this.pendingResults.clear();
        }
        finally {
            this.pendingResults.forEach((id, pendingResult) -> pendingResult.listener.onResponse((Object)new PyTorchInferenceResult((String)id, null, null, "inference canceled as process is stopping")));
            this.pendingResults.clear();
        }
        logger.debug(() -> new ParameterizedMessage("[{}] Results processing finished", (Object)this.deploymentId));
    }

    private void processInferenceResult(PyTorchInferenceResult inferenceResult) {
        logger.trace(() -> new ParameterizedMessage("[{}] Parsed result with id [{}]", (Object)this.deploymentId, (Object)inferenceResult.getRequestId()));
        this.processResult(inferenceResult);
        PendingResult pendingResult = (PendingResult)this.pendingResults.remove(inferenceResult.getRequestId());
        if (pendingResult == null) {
            logger.debug(() -> new ParameterizedMessage("[{}] no pending result for [{}]", (Object)this.deploymentId, (Object)inferenceResult.getRequestId()));
        } else {
            pendingResult.listener.onResponse((Object)inferenceResult);
        }
    }

    public synchronized ResultStats getResultStats() {
        return new ResultStats(new LongSummaryStatistics(this.timingStats.getCount(), this.timingStats.getMin(), this.timingStats.getMax(), this.timingStats.getSum()), this.errorCount, this.pendingResults.size(), this.lastUsed);
    }

    private synchronized void processResult(PyTorchInferenceResult result) {
        if (!result.isError()) {
            this.timingStats.accept(result.getTimeMs());
            this.lastUsed = Instant.now();
        } else {
            ++this.errorCount;
        }
    }

    public void stop() {
        this.isStopping = true;
    }

    public static class PendingResult {
        public final ActionListener<PyTorchInferenceResult> listener;

        public PendingResult(ActionListener<PyTorchInferenceResult> listener) {
            this.listener = Objects.requireNonNull(listener);
        }
    }

    public record ResultStats(LongSummaryStatistics timingStats, int errorCount, int numberOfPendingResults, Instant lastUsed) {
    }
}

