/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.core.ml.dataframe.evaluation.common;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.Objects;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.search.aggregations.metrics.Percentiles;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.ToXContent;
import org.elasticsearch.xcontent.ToXContentObject;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;

/**
 * Base class for evaluation metrics that compute the area under the ROC curve (AUC ROC).
 *
 * <p>The curve is approximated from two arrays holding the 1st..99th percentiles of two
 * distributions: a "true positive" side and a "false positive" side. Presumably these are the
 * predicted scores of actual positives and actual negatives; confirm this against the
 * subclasses. Each percentile of one side is used as a candidate threshold, and the rate of
 * the other side at that threshold is linearly interpolated. The resulting points are then
 * sorted, collapsed per threshold, and integrated with the trapezoidal rule.
 */
public abstract class AbstractAucRoc implements EvaluationMetric {

    /** Name of this metric, returned by {@link #getName()} and {@link Result#getMetricName()}. */
    public static final ParseField NAME = new ParseField("auc_roc");

    /** Number of percentiles (1st..99th) each side of the curve is described by. */
    private static final int NUM_PERCENTILES = 99;

    protected AbstractAucRoc() {
    }

    @Override
    public String getName() {
        return NAME.getPreferredName();
    }

    /**
     * Copies the given percentiles into an array indexed by {@code percent - 1}.
     *
     * <p>The method assumes {@code percentiles} contains the integral percents 1..99. Any other
     * percent would land at a truncated index or fall out of bounds. Slots with no matching
     * percent keep the value {@code 0.0}.
     *
     * <p>If any percentile value is NaN or infinite, the method throws the exception built by
     * {@code ExceptionsHelper.badRequestException}.
     *
     * @param percentiles the percentiles aggregation result
     * @return a 99-element array whose element {@code i} is the value of percentile {@code i + 1}
     */
    protected static double[] percentilesArray(Percentiles percentiles) {
        double[] result = new double[NUM_PERCENTILES];
        percentiles.forEach(percentile -> {
            double value = percentile.getValue();
            // Infinite values are rejected as well as NaN. For +Infinity, getThreshold() would
            // compute Infinity - ulp(Infinity) == NaN, which breaks the binary search and sort.
            if (!Double.isFinite(value)) {
                throw ExceptionsHelper.badRequestException(
                    "[{}] requires at all the percentiles values to be finite numbers",
                    NAME.getPreferredName()
                );
            }
            result[(int) percentile.getPercent() - 1] = value;
        });
        return result;
    }

    /**
     * Builds the ROC curve from the percentiles of both sides.
     *
     * <p>Every percentile of either side becomes a point, with the other side's rate
     * interpolated at that threshold. Points are sorted by descending threshold, then ascending
     * fpr and tpr, and points sharing a threshold are averaged. The result is framed by the
     * fixed end points {@code (tpr=0, fpr=0, threshold=1)} and {@code (tpr=1, fpr=1, threshold=0)}.
     *
     * @param tpPercentiles 99 ascending percentiles of the true-positive side
     * @param fpPercentiles 99 ascending percentiles of the false-positive side
     * @return the curve points, ordered from (0, 0) to (1, 1)
     */
    protected static List<AucRocPoint> buildAucRocCurve(double[] tpPercentiles, double[] fpPercentiles) {
        assert tpPercentiles.length == fpPercentiles.length;
        assert tpPercentiles.length == NUM_PERCENTILES;

        RateThresholdCurve tpCurve = new RateThresholdCurve(tpPercentiles, true);
        RateThresholdCurve fpCurve = new RateThresholdCurve(fpPercentiles, false);

        List<AucRocPoint> points = new ArrayList<>(tpPercentiles.length + fpPercentiles.length);
        points.addAll(tpCurve.scanPoints(fpCurve));
        points.addAll(fpCurve.scanPoints(tpCurve));
        Collections.sort(points);
        points = collapseEqualThresholdPoints(points);

        List<AucRocPoint> aucRocCurve = new ArrayList<>(points.size() + 2);
        aucRocCurve.add(new AucRocPoint(0.0, 0.0, 1.0));
        aucRocCurve.addAll(points);
        aucRocCurve.add(new AucRocPoint(1.0, 1.0, 0.0));
        return aucRocCurve;
    }

    /**
     * Replaces every run of consecutive points with exactly the same threshold by their
     * average point.
     *
     * <p>Only adjacent points are grouped, so the input is expected to be sorted by threshold.
     *
     * @param points the points to collapse, sorted so that equal thresholds are adjacent
     * @return a new list with one point per run of equal thresholds
     */
    static List<AucRocPoint> collapseEqualThresholdPoints(List<AucRocPoint> points) {
        List<AucRocPoint> collapsed = new ArrayList<>();
        List<AucRocPoint> equalThresholdPoints = new ArrayList<>();
        for (AucRocPoint point : points) {
            if (!equalThresholdPoints.isEmpty() && equalThresholdPoints.get(0).threshold != point.threshold) {
                collapsed.add(calculateAveragePoint(equalThresholdPoints));
                // Safe to reuse: calculateAveragePoint does not retain the list.
                equalThresholdPoints.clear();
            }
            equalThresholdPoints.add(point);
        }
        if (!equalThresholdPoints.isEmpty()) {
            collapsed.add(calculateAveragePoint(equalThresholdPoints));
        }
        return collapsed;
    }

    /**
     * Returns the component-wise mean of the given points. A single point is returned as-is.
     *
     * @throws IllegalArgumentException if {@code points} is empty
     */
    private static AucRocPoint calculateAveragePoint(List<AucRocPoint> points) {
        if (points.isEmpty()) {
            throw new IllegalArgumentException("points must not be empty");
        }
        if (points.size() == 1) {
            return points.get(0);
        }
        double sumTpr = 0.0;
        double sumFpr = 0.0;
        double sumThreshold = 0.0;
        for (AucRocPoint sameThresholdPoint : points) {
            sumTpr += sameThresholdPoint.tpr;
            sumFpr += sameThresholdPoint.fpr;
            sumThreshold += sameThresholdPoint.threshold;
        }
        int n = points.size();
        return new AucRocPoint(sumTpr / n, sumFpr / n, sumThreshold / n);
    }

    /**
     * Integrates tpr over fpr along the given curve with the trapezoidal rule.
     *
     * @param rocCurve curve points in traversal order (as produced by {@link #buildAucRocCurve})
     * @return the area under the curve; {@code 0.0} for fewer than two points
     */
    protected static double calculateAucScore(List<AucRocPoint> rocCurve) {
        double aucRoc = 0.0;
        for (int i = 1; i < rocCurve.size(); ++i) {
            AucRocPoint left = rocCurve.get(i - 1);
            AucRocPoint right = rocCurve.get(i);
            aucRoc += (right.fpr - left.fpr) * (right.tpr + left.tpr) / 2.0;
        }
        return aucRoc;
    }

    /** Linearly interpolates y at {@code x} on the line through (x1, y1) and (x2, y2). */
    private static double interpolate(double x, double x1, double y1, double x2, double y2) {
        return y1 + (x - x1) * (y2 - y1) / (x2 - x1);
    }

    /**
     * One side of the ROC curve, described by its 1st..99th percentiles.
     *
     * <p>The percentile array is searched with {@link Arrays#binarySearch(double[], double)}, so
     * it must be sorted in ascending order.
     */
    private static class RateThresholdCurve {
        private final double[] percentiles;
        // true when this curve supplies the tpr of generated points, false when it supplies fpr
        private final boolean isTp;

        private RateThresholdCurve(double[] percentiles, boolean isTp) {
            this.percentiles = percentiles;
            this.isTp = isTp;
        }

        /** Returns {@code 1 - 0.01 * (index + 1)}, the share of the distribution above percentile {@code index + 1}. */
        private double getRate(int index) {
            return 1.0 - 0.01 * (index + 1);
        }

        /**
         * Returns the percentile at {@code index} lowered by one ulp and clamped at 0.
         *
         * <p>Presumably the ulp nudge puts values equal to the percentile above the threshold;
         * TODO confirm.
         */
        private double getThreshold(int index) {
            return Math.max(0.0, percentiles[index] - Math.ulp(percentiles[index]));
        }

        /**
         * Returns this curve's rate at {@code threshold}, as follows:
         * <ul>
         *   <li>an exact percentile match gives that percentile's rate;</li>
         *   <li>a threshold below the first percentile gives 1.0;</li>
         *   <li>a threshold above the last percentile gives 0.0;</li>
         *   <li>anything else is linearly interpolated between the neighbouring percentiles.</li>
         * </ul>
         */
        private double interpolateRate(double threshold) {
            int binarySearchResult = Arrays.binarySearch(percentiles, threshold);
            if (binarySearchResult >= 0) {
                return getRate(binarySearchResult);
            }
            // binarySearch encodes a miss as -(insertionPoint) - 1
            int right = -binarySearchResult - 1;
            int left = right - 1;
            if (right >= percentiles.length) {
                return 0.0;
            }
            if (left < 0) {
                return 1.0;
            }
            return interpolate(threshold, percentiles[left], getRate(left), percentiles[right], getRate(right));
        }

        /**
         * Creates one point per percentile of this curve.
         *
         * <p>Each point combines this curve's rate with {@code againstCurve}'s interpolated rate
         * at the same threshold. The tp curve's rate is always placed in {@code tpr} and the
         * fp curve's rate in {@code fpr}.
         */
        private List<AucRocPoint> scanPoints(RateThresholdCurve againstCurve) {
            List<AucRocPoint> points = new ArrayList<>(percentiles.length);
            for (int index = 0; index < percentiles.length; ++index) {
                double rate = getRate(index);
                double scannedThreshold = getThreshold(index);
                double againstRate = againstCurve.interpolateRate(scannedThreshold);
                points.add(
                    isTp
                        ? new AucRocPoint(rate, againstRate, scannedThreshold)
                        : new AucRocPoint(againstRate, rate, scannedThreshold)
                );
            }
            return points;
        }
    }

    /**
     * A single ROC curve point: the true positive rate, false positive rate and threshold.
     *
     * <p>The natural ordering is by descending threshold, then ascending fpr, then ascending tpr.
     * It is consistent with {@link #equals(Object)}.
     */
    public static final class AucRocPoint implements Comparable<AucRocPoint>, ToXContentObject, Writeable {
        private static final String TPR = "tpr";
        private static final String FPR = "fpr";
        private static final String THRESHOLD = "threshold";

        // Built once. The explicit type witness is needed for the chained calls to type-check.
        private static final Comparator<AucRocPoint> ORDER = Comparator.<AucRocPoint>comparingDouble(p -> p.threshold)
            .reversed()
            .thenComparingDouble(p -> p.fpr)
            .thenComparingDouble(p -> p.tpr);

        final double tpr;
        final double fpr;
        final double threshold;

        public AucRocPoint(double tpr, double fpr, double threshold) {
            this.tpr = tpr;
            this.fpr = fpr;
            this.threshold = threshold;
        }

        /** Reads tpr, fpr and threshold in the order written by {@link #writeTo}. */
        private AucRocPoint(StreamInput in) throws IOException {
            this.tpr = in.readDouble();
            this.fpr = in.readDouble();
            this.threshold = in.readDouble();
        }

        @Override
        public int compareTo(AucRocPoint o) {
            return ORDER.compare(this, o);
        }

        @Override
        public void writeTo(StreamOutput out) throws IOException {
            out.writeDouble(tpr);
            out.writeDouble(fpr);
            out.writeDouble(threshold);
        }

        @Override
        public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
            builder.startObject();
            builder.field(TPR, tpr);
            builder.field(FPR, fpr);
            builder.field(THRESHOLD, threshold);
            builder.endObject();
            return builder;
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (o == null || getClass() != o.getClass()) {
                return false;
            }
            AucRocPoint that = (AucRocPoint) o;
            // Double.compare (rather than ==) agrees with the boxed-Double hashing in hashCode()
            // for NaN and for 0.0 vs -0.0.
            return Double.compare(tpr, that.tpr) == 0
                && Double.compare(fpr, that.fpr) == 0
                && Double.compare(threshold, that.threshold) == 0;
        }

        @Override
        public int hashCode() {
            return Objects.hash(tpr, fpr, threshold);
        }

        @Override
        public String toString() {
            return Strings.toString((ToXContent) this);
        }
    }

    /** Result of the AUC ROC metric: the area score and, optionally, the curve it came from. */
    public static class Result implements EvaluationMetricResult {
        public static final String NAME = "auc_roc_result";
        private static final String VALUE = "value";
        private static final String CURVE = "curve";
        private final double value;
        private final List<AucRocPoint> curve;

        /**
         * @param value the AUC score
         * @param curve the curve points, stored as given (not copied); an empty list omits the
         *              curve from the XContent output
         * @throws NullPointerException if {@code curve} is null
         */
        public Result(double value, List<AucRocPoint> curve) {
            this.value = value;
            this.curve = Objects.requireNonNull(curve);
        }

        /** Reads the value and curve in the order written by {@link #writeTo}. */
        public Result(StreamInput in) throws IOException {
            this.value = in.readDouble();
            this.curve = in.readList(AucRocPoint::new);
        }

        public double getValue() {
            return value;
        }

        /** Returns an unmodifiable view of the curve points. */
        public List<AucRocPoint> getCurve() {
            return Collections.unmodifiableList(curve);
        }

        @Override
        public String getWriteableName() {
            return NAME;
        }

        @Override
        public String getMetricName() {
            // The outer ParseField is meant here; the unqualified NAME is this class's String constant.
            return AbstractAucRoc.NAME.getPreferredName();
        }

        @Override
        public void writeTo(StreamOutput out) throws IOException {
            out.writeDouble(value);
            out.writeList(curve);
        }

        @Override
        public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
            builder.startObject();
            builder.field(VALUE, value);
            if (!curve.isEmpty()) {
                builder.field(CURVE, curve);
            }
            builder.endObject();
            return builder;
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (o == null || getClass() != o.getClass()) {
                return false;
            }
            Result that = (Result) o;
            // Double.compare keeps equals consistent with the boxed hashing in hashCode().
            return Double.compare(value, that.value) == 0 && Objects.equals(curve, that.curve);
        }

        @Override
        public int hashCode() {
            return Objects.hash(value, curve);
        }
    }
}

