/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.ml.aggs.changepoint;

import java.util.Arrays;
import java.util.HashSet;
import java.util.Random;
import java.util.Set;
import java.util.function.IntToDoubleFunction;
import java.util.stream.IntStream;
import org.apache.commons.math3.distribution.UniformRealDistribution;
import org.apache.commons.math3.random.RandomGeneratorFactory;
import org.apache.commons.math3.special.Beta;
import org.apache.commons.math3.stat.inference.KolmogorovSmirnovTest;
import org.apache.commons.math3.stat.regression.SimpleRegression;
import org.elasticsearch.common.util.set.Sets;
import org.elasticsearch.logging.LogManager;
import org.elasticsearch.logging.Logger;
import org.elasticsearch.xpack.ml.aggs.MlAggsHelper;
import org.elasticsearch.xpack.ml.aggs.changepoint.ChangeType;
import org.elasticsearch.xpack.ml.aggs.changepoint.LeastSquaresOnlineRegression;

/**
 * Tests a series of bucket values for a change and classifies the result.
 *
 * <p>The hypotheses considered are: stationary, non-stationary (a single linear trend), a step change,
 * a trend change and a distribution change. Hypotheses that model the series are compared with nested
 * F-tests on their residual variance. A distribution change is tested with a two-sample
 * Kolmogorov-Smirnov test. Tests that scan many candidate change points have their p-values corrected
 * for the number of trials.
 */
public class ChangeDetector {
    /** Series longer than this are randomly down-sampled to this many values before the KS test. */
    private static final int MAXIMUM_SAMPLE_SIZE_FOR_KS_TEST = 500;
    /** Interior ranges longer than this are thinned so roughly this many candidate change points remain. */
    private static final int MAXIMUM_CANDIDATE_CHANGE_POINTS = 1000;
    /** Number of parameters attributed to each model in the nested F-tests. */
    private static final double STATIONARY_PARAMS = 1.0;
    private static final double STEP_CHANGE_PARAMS = 2.0;
    private static final double TREND_PARAMS = 3.0;
    private static final double TREND_CHANGE_PARAMS = 6.0;
    private static final KolmogorovSmirnovTest KOLMOGOROV_SMIRNOV_TEST = new KolmogorovSmirnovTest();
    private static final Logger logger = LogManager.getLogger(ChangeDetector.class);

    private final MlAggsHelper.DoubleBucketValues bucketValues;
    private final double[] values;

    ChangeDetector(MlAggsHelper.DoubleBucketValues bucketValues) {
        this.bucketValues = bucketValues;
        this.values = bucketValues.getValues();
    }

    /**
     * Detects the most significant kind of change in the bucket values.
     *
     * @param minBucketsPValue the significance threshold for a series of 22 values. For longer series the
     *                         threshold shrinks as {@code exp(-0.04 * (n - 22))}; for shorter ones it grows.
     * @return the classified change
     */
    ChangeType detect(double minBucketsPValue) {
        double pValueThreshold = minBucketsPValue * Math.exp(-0.04 * (double) (values.length - 22));
        return testForChange(pValueThreshold).changeType(bucketValues, slope(values));
    }

    /**
     * Runs the hypothesis tests and returns the statistics of the best accepted hypothesis.
     *
     * <p>Requires at least one candidate change point, i.e. more than {@code 2 * minValues} values (see
     * {@link #computeCandidateChangePoints}). Otherwise indexing the first candidate throws.
     */
    private TestStats testForChange(double pValueThreshold) {
        int[] candidateChangePoints = computeCandidateChangePoints(values);
        logger.trace("candidatePoints: [{}]", Arrays.toString(candidateChangePoints));

        double[] valuesWeights = outlierWeights(values);
        logger.trace("values: [{}]", Arrays.toString(values));
        logger.trace("valuesWeights: [{}]", Arrays.toString(valuesWeights));

        RunningStats dataRunningStats = RunningStats.from(values, i -> valuesWeights[i]);
        DataStats dataStats = new DataStats(
            dataRunningStats.count(),
            dataRunningStats.mean(),
            dataRunningStats.variance(),
            candidateChangePoints.length
        );
        logger.trace("dataStats: [{}]", dataStats);

        TestStats stationary = stationaryStats(dataStats);
        if (dataStats.varianceZeroToWorkingPrecision()) {
            return stationary;
        }

        TestStats trendVsStationary = testTrendVs(stationary, values, valuesWeights);
        logger.trace("trend vs stationary: [{}]", trendVsStationary);

        TestStats best = stationary;
        // Every change point found by the tests below is later re-tested for a distribution change.
        Set<Integer> discoveredChangePoints = Sets.newHashSetWithExpectedSize(4);
        if (trendVsStationary.accept(pValueThreshold)) {
            TestStats trendChangeVsTrend = testTrendChangeVs(trendVsStationary, values, valuesWeights, candidateChangePoints);
            discoveredChangePoints.add(trendChangeVsTrend.changePoint());
            logger.trace("trend change vs trend: [{}]", trendChangeVsTrend);
            best = trendChangeVsTrend.accept(pValueThreshold)
                ? testVsStepChange(trendChangeVsTrend, values, valuesWeights, candidateChangePoints, pValueThreshold)
                : trendVsStationary;
        } else {
            TestStats stepChangeVsStationary = testStepChangeVs(stationary, values, valuesWeights, candidateChangePoints);
            discoveredChangePoints.add(stepChangeVsStationary.changePoint());
            logger.trace("step change vs stationary: [{}]", stepChangeVsStationary);
            if (stepChangeVsStationary.accept(pValueThreshold)) {
                TestStats trendChangeVsStepChange = testTrendChangeVs(
                    stepChangeVsStationary,
                    values,
                    valuesWeights,
                    candidateChangePoints
                );
                discoveredChangePoints.add(trendChangeVsStepChange.changePoint());
                logger.trace("trend change vs step change: [{}]", trendChangeVsStepChange);
                best = trendChangeVsStepChange.accept(pValueThreshold) ? trendChangeVsStepChange : stepChangeVsStationary;
            } else {
                TestStats trendChangeVsStationary = testTrendChangeVs(stationary, values, valuesWeights, candidateChangePoints);
                discoveredChangePoints.add(trendChangeVsStationary.changePoint());
                logger.trace("trend change vs stationary: [{}]", trendChangeVsStationary);
                if (trendChangeVsStationary.accept(pValueThreshold)) {
                    best = trendChangeVsStationary;
                }
            }
        }

        logger.trace("best: [{}]", best.pValueVsStationary());
        // Only look for a distribution change when the best model is not already highly significant,
        // and then require it to be clearly (10x) more significant than that model.
        if (best.pValueVsStationary() > 1.0E-5) {
            TestStats distChange = testDistributionChange(dataStats, values, valuesWeights, candidateChangePoints, discoveredChangePoints);
            logger.trace("distribution change: [{}]", distChange);
            if (distChange.pValue() < Math.min(pValueThreshold, 0.1 * best.pValueVsStationary())) {
                best = distChange;
            }
        }
        return best;
    }

    /** Statistics for the stationary (constant level) hypothesis. */
    private static TestStats stationaryStats(DataStats dataStats) {
        return new TestStats(Type.STATIONARY, 1.0, dataStats.var(), STATIONARY_PARAMS, dataStats);
    }

    /**
     * Returns the indices at which a change point may be placed.
     *
     * <p>Excludes {@code max(round(0.1 * n), 10)} values at each end. If more than
     * {@link #MAXIMUM_CANDIDATE_CHANGE_POINTS} interior indices remain, only those divisible by a step size
     * are kept, so roughly that many candidates remain.
     */
    private int[] computeCandidateChangePoints(double[] values) {
        int minValues = Math.max((int) (0.1 * (double) values.length + 0.5), 10);
        int interior = values.length - 2 * minValues;
        if (interior <= MAXIMUM_CANDIDATE_CHANGE_POINTS) {
            return IntStream.range(minValues, values.length - minValues).toArray();
        }
        int step = (int) Math.ceil((double) interior / (double) MAXIMUM_CANDIDATE_CHANGE_POINTS);
        return IntStream.range(minValues, values.length - minValues).filter(i -> i % step == 0).toArray();
    }

    /**
     * Computes per-value weights that down-weight outliers.
     *
     * <p>With {@code k = ceil(0.025 * n)}, values outside the closed range
     * {@code [sorted[k], sorted[n - k - 1]]} get weight 0.01; all others get weight 1. NaN values fail
     * both comparisons and also get 0.01.
     */
    private double[] outlierWeights(double[] values) {
        int k = (int) Math.ceil(0.025 * (double) values.length);
        double[] sorted = values.clone();
        Arrays.sort(sorted);
        double lower = sorted[k];
        double upper = sorted[values.length - k - 1];
        double[] weights = new double[values.length];
        for (int j = 0; j < values.length; ++j) {
            weights[j] = values[j] <= upper && values[j] >= lower ? 1.0 : 0.01;
        }
        return weights;
    }

    /** Returns the ordinary least-squares slope of the values against their index. */
    private double slope(double[] values) {
        SimpleRegression regression = new SimpleRegression();
        for (int i = 0; i < values.length; ++i) {
            regression.addData((double) i, values[i]);
        }
        return regression.getSlope();
    }

    /**
     * Probability that at least one of {@code nTrials} independent tests reaches {@code pValue}, i.e.
     * {@code 1 - (1 - p)^n}.
     *
     * <p>For {@code p <= 1e-10} the first-order approximation {@code n * p} is used. In that range
     * {@code 1 - p} would round and lose the value.
     */
    private static double independentTrialsPValue(double pValue, int nTrials) {
        return pValue > 1.0E-10 ? 1.0 - Math.pow(1.0 - pValue, nTrials) : (double) nTrials * pValue;
    }

    /** Tests a single weighted linear trend over the whole series against {@code nullHypothesis}. */
    private TestStats testTrendVs(TestStats nullHypothesis, double[] values, double[] weights) {
        LeastSquaresOnlineRegression allLeastSquares = new LeastSquaresOnlineRegression(2);
        for (int i = 0; i < values.length; ++i) {
            allLeastSquares.add(i, values[i], weights[i]);
        }
        double vTrend = nullHypothesis.dataStats().var() * (1.0 - allLeastSquares.rSquared());
        double pValue = fTestNestedPValue(
            nullHypothesis.dataStats().nValues(),
            nullHypothesis.var(),
            nullHypothesis.nParams(),
            vTrend,
            TREND_PARAMS
        );
        return new TestStats(Type.NON_STATIONARY, pValue, vTrend, TREND_PARAMS, nullHypothesis.dataStats());
    }

    /**
     * Finds the candidate split that minimises the pooled variance of a constant level on each side. It
     * then tests that step change against {@code nullHypothesis}, correcting the p-value for the number of
     * candidates scanned.
     */
    private TestStats testStepChangeVs(TestStats nullHypothesis, double[] values, double[] weights, int[] candidateChangePoints) {
        IntToDoubleFunction weight = i -> weights[i];
        int first = candidateChangePoints[0];
        RunningStats lowerRange = new RunningStats().addValues(values, weight, 0, first);
        RunningStats upperRange = new RunningStats().addValues(values, weight, first, values.length);

        double vStep = Double.MAX_VALUE;
        int changePoint = -1;
        int last = first;
        for (int cp : candidateChangePoints) {
            // Move [last, cp) from the upper range to the lower range so the split is at cp.
            lowerRange.addValues(values, weight, last, cp);
            upperRange.removeValues(values, weight, last, cp);
            last = cp;
            double nl = lowerRange.count();
            double nu = upperRange.count();
            double v = (nl * lowerRange.variance() + nu * upperRange.variance()) / (nl + nu);
            if (v < vStep) {
                vStep = v;
                changePoint = cp;
            }
        }
        double pValue = independentTrialsPValue(
            fTestNestedPValue(nullHypothesis.dataStats().nValues(), nullHypothesis.var(), nullHypothesis.nParams(), vStep, STEP_CHANGE_PARAMS),
            candidateChangePoints.length
        );
        return new TestStats(Type.STEP_CHANGE, pValue, vStep, STEP_CHANGE_PARAMS, changePoint, nullHypothesis.dataStats());
    }

    /**
     * Finds the candidate split that minimises the pooled residual variance of a separate linear fit on
     * each side. It then tests that trend change against {@code nullHypothesis}, correcting the p-value for
     * the number of candidates scanned.
     */
    private TestStats testTrendChangeVs(TestStats nullHypothesis, double[] values, double[] weights, int[] candidateChangePoints) {
        int first = candidateChangePoints[0];
        IntToDoubleFunction weight = i -> weights[i];
        RunningStats lowerRange = new RunningStats().addValues(values, weight, 0, first);
        RunningStats upperRange = new RunningStats().addValues(values, weight, first, values.length);
        LeastSquaresOnlineRegression lowerLeastSquares = new LeastSquaresOnlineRegression(2);
        LeastSquaresOnlineRegression upperLeastSquares = new LeastSquaresOnlineRegression(2);
        for (int i = 0; i < first; ++i) {
            lowerLeastSquares.add(i, values[i], weights[i]);
        }
        // The upper regression uses x offsets relative to the first candidate; removals below use the same offset.
        for (int i = first; i < values.length; ++i) {
            upperLeastSquares.add(i - first, values[i], weights[i]);
        }

        double vChange = Double.MAX_VALUE;
        int changePoint = -1;
        int last = first;
        for (int cp : candidateChangePoints) {
            for (int i = last; i < cp; ++i) {
                lowerRange.addValue(values[i], weights[i]);
                upperRange.removeValue(values[i], weights[i]);
                lowerLeastSquares.add(i, values[i], weights[i]);
                upperLeastSquares.remove(i - first, values[i], weights[i]);
            }
            last = cp;
            double nl = lowerRange.count();
            double nu = upperRange.count();
            double vl = lowerRange.variance() * (1.0 - lowerLeastSquares.rSquared());
            double vu = upperRange.variance() * (1.0 - upperLeastSquares.rSquared());
            double v = (nl * vl + nu * vu) / (nl + nu);
            if (v < vChange) {
                vChange = v;
                changePoint = cp;
            }
        }
        double pValue = independentTrialsPValue(
            fTestNestedPValue(nullHypothesis.dataStats().nValues(), nullHypothesis.var(), nullHypothesis.nParams(), vChange, TREND_CHANGE_PARAMS),
            candidateChangePoints.length
        );
        return new TestStats(Type.TREND_CHANGE, pValue, vChange, TREND_CHANGE_PARAMS, changePoint, nullHypothesis.dataStats());
    }

    /**
     * Decides between an accepted trend change and the best step change.
     *
     * <p>Returns {@code trendChange} if it is significantly better than the best step change under a nested
     * F-test; otherwise returns the step change.
     */
    private TestStats testVsStepChange(
        TestStats trendChange,
        double[] values,
        double[] weights,
        int[] candidateChangePoints,
        double pValueThreshold
    ) {
        DataStats dataStats = trendChange.dataStats();
        TestStats stepChange = testStepChangeVs(stationaryStats(dataStats), values, weights, candidateChangePoints);
        double pValue = fTestNestedPValue(dataStats.nValues(), stepChange.var(), STEP_CHANGE_PARAMS, trendChange.var(), TREND_CHANGE_PARAMS);
        return pValue < pValueThreshold ? trendChange : stepChange;
    }

    /**
     * P-value of a nested F-test comparing a null model (variance {@code vNull}, {@code pNull} parameters)
     * with an alternative model ({@code vAlt}, {@code pAlt}) on {@code n} values.
     *
     * <p>Returns 1 when the variances are equal and 0 when the alternative variance is exactly 0. Otherwise
     * the F-distribution survival function is doubled and capped at 1.
     */
    private static double fTestNestedPValue(double n, double vNull, double pNull, double vAlt, double pAlt) {
        if (vAlt == vNull) {
            return 1.0;
        }
        if (vAlt == 0.0) {
            return 0.0;
        }
        double f = (vNull - vAlt) / (pAlt - pNull) * (n - pAlt) / vAlt;
        double sf = fDistribSf(pAlt - pNull, n - pAlt, f);
        return Math.min(2.0 * sf, 1.0);
    }

    /**
     * Returns the index of the first element of sorted {@code x[start, end)} that is {@code >= xs}, or
     * {@code end} if there is none. This assumes the range is sorted ascending with distinct values.
     */
    private static int lowerBound(int[] x, int start, int end, int xs) {
        int retVal = Arrays.binarySearch(x, start, end, xs);
        return retVal < 0 ? -1 - retVal : retVal;
    }

    /**
     * Down-samples the series for the KS test when it has more than
     * {@link #MAXIMUM_SAMPLE_SIZE_FOR_KS_TEST} values.
     *
     * @param changePoints change points as indices into {@code values}
     * @return the (possibly sampled) values and weights, plus {@code changePoints} re-expressed as split
     *         positions in the sample, in the same order as the input. The value -1 is preserved.
     */
    private SampleData sample(double[] values, double[] weights, int[] changePoints) {
        if (values.length <= MAXIMUM_SAMPLE_SIZE_FOR_KS_TEST) {
            return new SampleData(values, weights, changePoints);
        }
        // Fixed seed: the same input always produces the same sample and therefore the same result.
        Random rng = new Random(126832678L);
        UniformRealDistribution uniform = new UniformRealDistribution(RandomGeneratorFactory.createRandomGenerator(rng), 0.0, 0.99999);
        int[] choice = IntStream.range(0, values.length).toArray();
        // Partial Fisher-Yates shuffle: the first MAXIMUM_SAMPLE_SIZE_FOR_KS_TEST slots receive a random subset of indices.
        for (int i = 0; i < MAXIMUM_SAMPLE_SIZE_FOR_KS_TEST; ++i) {
            int index = i + (int) Math.floor(uniform.sample() * (double) (values.length - i));
            int tmp = choice[i];
            choice[i] = choice[index];
            choice[index] = tmp;
        }
        // Restore time order within the sample.
        Arrays.sort(choice, 0, MAXIMUM_SAMPLE_SIZE_FOR_KS_TEST);
        double[] sample = new double[MAXIMUM_SAMPLE_SIZE_FOR_KS_TEST];
        double[] sampleWeights = new double[MAXIMUM_SAMPLE_SIZE_FOR_KS_TEST];
        for (int i = 0; i < MAXIMUM_SAMPLE_SIZE_FOR_KS_TEST; ++i) {
            sample[i] = values[choice[i]];
            sampleWeights[i] = weights[choice[i]];
        }
        int[] sampleChangePoints = new int[changePoints.length];
        for (int i = 0; i < changePoints.length; ++i) {
            sampleChangePoints[i] = changePoints[i] == -1 ? -1 : lowerBound(choice, 0, MAXIMUM_SAMPLE_SIZE_FOR_KS_TEST, changePoints[i]);
        }
        return new SampleData(sample, sampleWeights, sampleChangePoints);
    }

    /**
     * Tests for a change in the distribution of values.
     *
     * <p>First picks the candidate split that maximises {@code min(cp, n - cp) * (|Δmean| + |Δstd|)} and adds
     * it to {@code discoveredChangePoints}; note that this mutates the caller's set. Each discovered change
     * point is then tested with a two-sample KS test on the (possibly down-sampled) series. The returned
     * change point is always an index into the original series.
     */
    private TestStats testDistributionChange(
        DataStats stats,
        double[] values,
        double[] weights,
        int[] candidateChangePoints,
        Set<Integer> discoveredChangePoints
    ) {
        IntToDoubleFunction weight = i -> weights[i];
        int first = candidateChangePoints[0];
        RunningStats lowerRange = new RunningStats().addValues(values, weight, 0, first);
        RunningStats upperRange = new RunningStats().addValues(values, weight, first, values.length);

        double maxDiff = 0.0;
        int changePoint = -1;
        int last = first;
        for (int cp : candidateChangePoints) {
            lowerRange.addValues(values, weight, last, cp);
            upperRange.removeValues(values, weight, last, cp);
            last = cp;
            double scale = Math.min(cp, values.length - cp);
            double meanDiff = Math.abs(lowerRange.mean() - upperRange.mean());
            double stdDiff = Math.abs(lowerRange.std() - upperRange.std());
            double diff = scale * (meanDiff + stdDiff);
            if (diff >= maxDiff) {
                maxDiff = diff;
                changePoint = cp;
            }
        }
        discoveredChangePoints.add(changePoint);

        // Keep the original-space change points next to their sample-space splits, so the winner can be
        // reported as an index into the original series even when the data were down-sampled.
        int[] changePoints = discoveredChangePoints.stream().mapToInt(Integer::intValue).toArray();
        SampleData sampleData = sample(values, weights, changePoints);
        double[] sampleValues = sampleData.values();
        int[] sampleChangePoints = sampleData.changePoints();

        double pValue = 1.0;
        for (int k = 0; k < changePoints.length; ++k) {
            int split = sampleChangePoints[k];
            // -1 means "no change point found". The KS statistic also needs at least two values on each side.
            if (changePoints[k] == -1 || split < 2 || sampleValues.length - split < 2) {
                continue;
            }
            double[] x = Arrays.copyOfRange(sampleValues, 0, split);
            double[] y = Arrays.copyOfRange(sampleValues, split, sampleValues.length);
            double statistic = KOLMOGOROV_SMIRNOV_TEST.kolmogorovSmirnovStatistic(x, y);
            double ksTestPValue = KOLMOGOROV_SMIRNOV_TEST.exactP(statistic, x.length, y.length, false);
            if (ksTestPValue < pValue) {
                changePoint = changePoints[k];
                pValue = ksTestPValue;
            }
        }
        pValue = independentTrialsPValue(pValue, (sampleValues.length + 49) / 50);
        logger.trace("distribution change p-value: [{}]", pValue);
        return new TestStats(Type.DISTRIBUTION_CHANGE, pValue, changePoint, stats);
    }

    /**
     * Survival function {@code P(F > x)} of the F distribution, computed with the regularized incomplete beta
     * function. Returns 1 for {@code x <= 0} and 0 for infinite or NaN {@code x}.
     */
    private static double fDistribSf(double numeratorDegreesOfFreedom, double denominatorDegreesOfFreedom, double x) {
        if (x <= 0.0) {
            return 1.0;
        }
        if (Double.isInfinite(x) || Double.isNaN(x)) {
            return 0.0;
        }
        return Beta.regularizedBeta(
            denominatorDegreesOfFreedom / (denominatorDegreesOfFreedom + numeratorDegreesOfFreedom * x),
            0.5 * denominatorDegreesOfFreedom,
            0.5 * numeratorDegreesOfFreedom
        );
    }

    /**
     * Result of one hypothesis test.
     *
     * @param type        the hypothesis tested
     * @param pValue      p-value of the test against its null hypothesis
     * @param var         residual variance under this hypothesis
     * @param nParams     number of parameters attributed to this hypothesis
     * @param changePoint index into the original series, or -1 if not applicable
     * @param dataStats   statistics of the whole series
     */
    private record TestStats(Type type, double pValue, double var, double nParams, int changePoint, DataStats dataStats) {
        TestStats(Type type, double pValue, int changePoint, DataStats dataStats) {
            this(type, pValue, 0.0, 0.0, changePoint, dataStats);
        }

        TestStats(Type type, double pValue, double var, double nParams, DataStats dataStats) {
            this(type, pValue, var, nParams, -1, dataStats);
        }

        /** True if the test is significant and explains at least half of the data variance. */
        boolean accept(double pValueThreshold) {
            return pValue < pValueThreshold && rSquared() >= 0.5;
        }

        /** Fraction of the data variance explained by this hypothesis. */
        double rSquared() {
            return 1.0 - var / dataStats.var();
        }

        /**
         * P-value of this hypothesis against the stationary one, corrected for the number of candidate
         * change points.
         */
        double pValueVsStationary() {
            return independentTrialsPValue(
                fTestNestedPValue(dataStats.nValues(), dataStats.var(), STATIONARY_PARAMS, var, nParams),
                dataStats.nCandidateChangePoints()
            );
        }

        /** Converts these statistics into the public {@link ChangeType} result. */
        ChangeType changeType(MlAggsHelper.DoubleBucketValues bucketValues, double slope) {
            return switch (type) {
                case STATIONARY -> new ChangeType.Stationary();
                case NON_STATIONARY -> new ChangeType.NonStationary(
                    pValueVsStationary(),
                    rSquared(),
                    slope < 0.0 ? "decreasing" : "increasing"
                );
                case STEP_CHANGE -> new ChangeType.StepChange(pValueVsStationary(), bucketValues.getBucketIndex(changePoint));
                case TREND_CHANGE -> new ChangeType.TrendChange(pValueVsStationary(), rSquared(), bucketValues.getBucketIndex(changePoint));
                case DISTRIBUTION_CHANGE -> new ChangeType.DistributionChange(pValue, bucketValues.getBucketIndex(changePoint));
            };
        }

        @Override
        public String toString() {
            return "TestStats{type=" + type + ", dataStats=" + dataStats + ", var=" + var + ", rSquared=" + rSquared()
                + ", pValue=" + pValue + ", nParams=" + nParams + ", changePoint=" + changePoint + "}";
        }
    }

    /**
     * Weighted running sums with incremental add and remove, used to compute the mean and variance of a
     * sliding range. On removal, {@code sumOfSqrs} and {@code count} are clamped at 0.
     */
    private static class RunningStats {
        double sumOfSqrs;
        double sum;
        double count;

        static RunningStats from(double[] values, IntToDoubleFunction weightFunction) {
            return new RunningStats().addValues(values, weightFunction, 0, values.length);
        }

        RunningStats() {
        }

        double count() {
            return count;
        }

        double mean() {
            return sum / count;
        }

        /** Weighted population variance, {@code E[x^2] - E[x]^2}, clamped at 0 against rounding. */
        double variance() {
            return Math.max((sumOfSqrs - sum * sum / count) / count, 0.0);
        }

        double std() {
            return Math.sqrt(variance());
        }

        /** Adds {@code values[start, min(end, values.length))}, each weighted by {@code weightFunction(i)}. */
        RunningStats addValues(double[] values, IntToDoubleFunction weightFunction, int start, int end) {
            for (int i = start; i < values.length && i < end; ++i) {
                addValue(values[i], weightFunction.applyAsDouble(i));
            }
            return this;
        }

        RunningStats addValue(double value, double weight) {
            sumOfSqrs += value * value * weight;
            count += weight;
            sum += value * weight;
            return this;
        }

        RunningStats removeValue(double value, double weight) {
            sumOfSqrs = Math.max(sumOfSqrs - value * value * weight, 0.0);
            count = Math.max(count - weight, 0.0);
            sum -= value * weight;
            return this;
        }

        /** Removes {@code values[start, min(end, values.length))}, each weighted by {@code weightFunction(i)}. */
        RunningStats removeValues(double[] values, IntToDoubleFunction weightFunction, int start, int end) {
            for (int i = start; i < values.length && i < end; ++i) {
                removeValue(values[i], weightFunction.applyAsDouble(i));
            }
            return this;
        }
    }

    /** Weighted statistics of the whole series plus the number of candidate change points scanned. */
    private record DataStats(double nValues, double mean, double var, int nCandidateChangePoints) {
        /** True if the variance is below the square root of the ulp of {@code 2 * nValues * mean}. */
        boolean varianceZeroToWorkingPrecision() {
            return var < Math.sqrt(Math.ulp(2.0 * nValues * mean));
        }

        @Override
        public String toString() {
            return "DataStats{nValues=" + nValues + ", mean=" + mean + ", var=" + var + ", nCandidates=" + nCandidateChangePoints + "}";
        }
    }

    private enum Type {
        STATIONARY,
        NON_STATIONARY,
        STEP_CHANGE,
        TREND_CHANGE,
        DISTRIBUTION_CHANGE
    }

    /** Possibly down-sampled series together with change points expressed as split positions in it. */
    private record SampleData(double[] values, double[] weights, int[] changePoints) {
    }
}

