/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.ml.aggs.correlation;

import java.io.IOException;
import java.util.Objects;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ParseField;
import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.search.aggregations.AggregationExecutionException;
import org.elasticsearch.search.aggregations.PipelineAggregationBuilder;
import org.elasticsearch.search.aggregations.pipeline.MovingFunctions;
import org.elasticsearch.xpack.ml.aggs.correlation.CorrelationFunction;
import org.elasticsearch.xpack.ml.aggs.correlation.CountCorrelationIndicator;

/**
 * {@link CorrelationFunction} that correlates a {@link CountCorrelationIndicator} (expected values,
 * optional per-value fractions and a document count) with per-bucket counts.
 *
 * <p>It is written to the wire as the wrapped indicator only, and rendered to XContent as
 * {@code {"indicator": {...}}}.
 */
public class CountCorrelationFunction implements CorrelationFunction {

    public static final ParseField NAME = new ParseField("count_correlation");
    public static final ParseField INDICATOR = new ParseField("indicator");

    /** Strict parser (unknown fields rejected) whose single required constructor arg is the indicator. */
    private static final ConstructingObjectParser<CountCorrelationFunction, Void> PARSER = new ConstructingObjectParser<>(
        "count_correlation_function",
        false,
        a -> new CountCorrelationFunction((CountCorrelationIndicator) a[0])
    );

    static {
        // INDICATOR is initialised above, so it is safe to reference here.
        PARSER.declareObject(ConstructingObjectParser.constructorArg(), (p, c) -> CountCorrelationIndicator.fromXContent(p), INDICATOR);
    }

    private final CountCorrelationIndicator indicator;

    /**
     * @param indicator the indicator that bucket values are correlated against
     */
    public CountCorrelationFunction(CountCorrelationIndicator indicator) {
        this.indicator = indicator;
    }

    /**
     * Reads an instance written by {@link #writeTo(StreamOutput)}.
     *
     * @param in stream positioned at a serialized indicator
     * @throws IOException if reading the indicator fails
     */
    public CountCorrelationFunction(StreamInput in) throws IOException {
        this.indicator = new CountCorrelationIndicator(in);
    }

    /**
     * Parses {@code {"indicator": {...}}} into a function instance.
     *
     * @param parser parser positioned at the function object
     * @return the parsed function
     */
    public static CountCorrelationFunction fromXContent(XContentParser parser) {
        return PARSER.apply(parser, null);
    }

    public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
        builder.startObject();
        // Explicit ToXContent view keeps overload resolution on XContentBuilder#field unambiguous.
        builder.field(INDICATOR.getPreferredName(), (ToXContent) indicator);
        builder.endObject();
        return builder;
    }

    public String getWriteableName() {
        return NAME.getPreferredName();
    }

    /** Serializes the wrapped indicator; this class has no other state. */
    public void writeTo(StreamOutput out) throws IOException {
        indicator.writeTo(out);
    }

    public String getName() {
        return NAME.getPreferredName();
    }

    /**
     * Returns the same value for every instance. This is consistent with {@link #equals(Object)},
     * but every instance lands in the same hash bucket.
     */
    @Override
    public int hashCode() {
        // NOTE(review): could be Objects.hash(indicator) once it is confirmed that
        // CountCorrelationIndicator overrides hashCode consistently with equals.
        return NAME.getPreferredName().hashCode();
    }

    @Override
    public boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }
        if (obj == null || getClass() != obj.getClass()) {
            return false;
        }
        CountCorrelationFunction other = (CountCorrelationFunction) obj;
        return Objects.equals(indicator, other.indicator);
    }

    /**
     * Computes {@code cov(X, Y) / sqrt(var(X) * var(Y))}, where:
     * <ul>
     *   <li>X is the indicator's expectations. If the indicator has no fractions, X's mean and
     *       variance come from {@link MovingFunctions#unweightedAvg} and {@link MovingFunctions#stdDev}.
     *       Otherwise each expectation is weighted by its fraction.</li>
     *   <li>Y has mean {@code w = sum(y.expectations) / indicator.docCount} and variance
     *       {@code (1 - w) * w^2 + w * (1 - w)^2}.</li>
     * </ul>
     *
     * @param y per-bucket counts to correlate; must have as many expectations as the indicator
     * @return the correlation. It is {@link Double#NaN} if X's mean or standard deviation is NaN,
     *     or if {@code var(X) * var(Y) == 0}.
     * @throws AggregationExecutionException if the lengths differ, or if the total of {@code y}
     *     exceeds the indicator's doc count
     */
    @Override
    public double execute(CountCorrelationIndicator y) {
        final double[] xValues = indicator.getExpectations();
        final double[] xFractions = indicator.getFractions();
        final double[] yCounts = y.getExpectations();

        if (xValues.length != yCounts.length) {
            throw new AggregationExecutionException(
                "value lengths do not match; indicator.expectations ["
                    + xValues.length
                    + "] and number of buckets ["
                    + yCounts.length
                    + "]. Unable to calculate correlation"
            );
        }

        // Moments of X. These must be computed before the doc-count check, so a NaN X yields NaN
        // rather than an exception.
        final double xMean;
        final double xVar;
        if (xFractions == null) {
            xMean = MovingFunctions.unweightedAvg(xValues);
            if (Double.isNaN(xMean)) {
                return Double.NaN;
            }
            double stdDev = MovingFunctions.stdDev(xValues, xMean);
            if (Double.isNaN(stdDev)) {
                return Double.NaN;
            }
            xVar = Math.pow(stdDev, 2.0);
        } else {
            double weightedSum = 0.0;
            for (int i = 0; i < xValues.length; ++i) {
                weightedSum += xValues[i] * xFractions[i];
            }
            if (Double.isNaN(weightedSum)) {
                return Double.NaN;
            }
            xMean = weightedSum;
            double weightedSquares = 0.0;
            for (int i = 0; i < xValues.length; ++i) {
                weightedSquares += Math.pow(xValues[i] - xMean, 2.0) * xFractions[i];
            }
            xVar = weightedSquares;
        }

        // Moments of Y, derived from the share of the indicator's docs that the buckets account for.
        final double docCount = indicator.getDocCount();
        final double yTotal = MovingFunctions.sum(yCounts);
        final double weight = yTotal / docCount;
        if (weight > 1.0) {
            throw new AggregationExecutionException(
                "doc_count of indicator must be larger than the total count of the correlating values indicator count ["
                    + indicator.getDocCount()
                    + "] correlating value total count ["
                    + yTotal
                    + "]"
            );
        }
        final double yMean = weight;
        final double yVar = (1.0 - weight) * yMean * yMean + weight * (1.0 - yMean) * (1.0 - yMean);

        // Without explicit fractions, every value gets the same share 1/n. The expression order
        // below is kept as-is so floating-point results are unchanged.
        final double uniformFraction = 1.0 / xValues.length;
        double xyCov = 0.0;
        for (int i = 0; i < xValues.length; ++i) {
            double fraction = xFractions == null ? uniformFraction : xFractions[i];
            double xDeviation = xValues[i] - xMean;
            double nX = yCounts[i];
            xyCov = xyCov - (docCount * fraction - nX) * xDeviation * yMean + nX * xDeviation * (1.0 - yMean);
        }

        final double varianceProduct = xVar * yVar;
        if (varianceProduct == 0.0) {
            return Double.NaN;
        }
        return (xyCov / docCount) / Math.sqrt(varianceProduct);
    }

    /**
     * Records a validation error unless {@code bucketPath} ends with {@code _count}.
     */
    @Override
    public void validate(PipelineAggregationBuilder.ValidationContext context, String bucketPath) {
        if (!bucketPath.endsWith("_count")) {
            context.addBucketPathValidationError("count correlation requires that bucket_path points to bucket [_count]");
        }
    }
}

