/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import org.apache.lucene.util.RamUsageEstimator;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.xcontent.ConstructingObjectParser;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.ToXContent;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.LenientlyParsedOutputAggregator;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.StrictlyParsedOutputAggregator;
import org.elasticsearch.xpack.core.ml.inference.utils.Statistics;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;

public class WeightedMode
implements StrictlyParsedOutputAggregator,
LenientlyParsedOutputAggregator {
    private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(WeightedMode.class);
    public static final ParseField NAME = new ParseField("weighted_mode", new String[0]);
    public static final ParseField WEIGHTS = new ParseField("weights", new String[0]);
    public static final ParseField NUM_CLASSES = new ParseField("num_classes", new String[0]);
    private static final ConstructingObjectParser<WeightedMode, Void> LENIENT_PARSER = WeightedMode.createParser(true);
    private static final ConstructingObjectParser<WeightedMode, Void> STRICT_PARSER = WeightedMode.createParser(false);
    private final double[] weights;
    private final int numClasses;

    private static ConstructingObjectParser<WeightedMode, Void> createParser(boolean lenient) {
        ConstructingObjectParser parser = new ConstructingObjectParser(NAME.getPreferredName(), lenient, a -> new WeightedMode((Integer)a[0], (List)a[1]));
        parser.declareInt(ConstructingObjectParser.constructorArg(), NUM_CLASSES);
        parser.declareDoubleArray(ConstructingObjectParser.optionalConstructorArg(), WEIGHTS);
        return parser;
    }

    public static WeightedMode fromXContentStrict(XContentParser parser) {
        return (WeightedMode)STRICT_PARSER.apply(parser, null);
    }

    public static WeightedMode fromXContentLenient(XContentParser parser) {
        return (WeightedMode)LENIENT_PARSER.apply(parser, null);
    }

    WeightedMode(int numClasses) {
        this(numClasses, null);
    }

    private WeightedMode(Integer numClasses, List<Double> weights) {
        this(weights == null ? null : weights.stream().mapToDouble(Double::valueOf).toArray(), numClasses);
    }

    public WeightedMode(double[] weights, Integer numClasses) {
        this.weights = weights;
        this.numClasses = ExceptionsHelper.requireNonNull(numClasses, NUM_CLASSES);
        if (this.numClasses <= 1) {
            throw new IllegalArgumentException("[" + NUM_CLASSES.getPreferredName() + "] must be greater than 1.");
        }
    }

    public WeightedMode(StreamInput in) throws IOException {
        this.weights = (double[])(in.readBoolean() ? in.readDoubleArray() : null);
        this.numClasses = in.readVInt();
    }

    @Override
    public Integer expectedValueSize() {
        return this.weights == null ? null : Integer.valueOf(this.weights.length);
    }

    @Override
    public double[] processValues(double[][] values) {
        Objects.requireNonNull(values, "values must not be null");
        if (this.weights != null && values.length != this.weights.length) {
            throw new IllegalArgumentException("values must be the same length as weights.");
        }
        if (values[0].length > 1) {
            double[] sumOnAxis1 = new double[values[0].length];
            for (int j = 0; j < values.length; ++j) {
                double[] value = values[j];
                double weight = this.weights == null ? 1.0 : this.weights[j];
                for (int i = 0; i < value.length; ++i) {
                    if (i >= sumOnAxis1.length) {
                        throw new IllegalArgumentException("value entries must have the same dimensions");
                    }
                    int n = i;
                    sumOnAxis1[n] = sumOnAxis1[n] + value[i] * weight;
                }
            }
            return Statistics.softMax(sumOnAxis1);
        }
        ArrayList<Integer> freqArray = new ArrayList<Integer>();
        int maxVal = 0;
        for (double[] value : values) {
            if (value.length != 1) {
                throw new IllegalArgumentException("value entries must have the same dimensions");
            }
            if (Double.isNaN(value[0]) || Double.isInfinite(value[0]) || value[0] < 0.0 || value[0] != Math.rint(value[0])) {
                throw new IllegalArgumentException("values must be whole, non-infinite, and positive");
            }
            int integerValue = Double.valueOf(value[0]).intValue();
            freqArray.add(integerValue);
            if (integerValue <= maxVal) continue;
            maxVal = integerValue;
        }
        if (maxVal >= this.numClasses) {
            throw new IllegalArgumentException("values contain entries larger than expected max of [" + (this.numClasses - 1) + "]");
        }
        double[] frequencies = Collections.nCopies(this.numClasses, Double.NEGATIVE_INFINITY).stream().mapToDouble(Double::doubleValue).toArray();
        for (int i = 0; i < freqArray.size(); ++i) {
            double frequency;
            double weight = this.weights == null ? 1.0 : this.weights[i];
            int value = (Integer)freqArray.get(i);
            frequencies[value] = frequency = frequencies[value] == Double.NEGATIVE_INFINITY ? weight : frequencies[value] + weight;
        }
        return Statistics.softMax(frequencies);
    }

    @Override
    public double aggregate(double[] values) {
        Objects.requireNonNull(values, "values must not be null");
        int bestValue = 0;
        double bestFreq = Double.NEGATIVE_INFINITY;
        for (int i = 0; i < values.length; ++i) {
            if (!(values[i] > bestFreq)) continue;
            bestFreq = values[i];
            bestValue = i;
        }
        return bestValue;
    }

    @Override
    public String getName() {
        return NAME.getPreferredName();
    }

    @Override
    public boolean compatibleWith(TargetType targetType) {
        return targetType.equals((Object)TargetType.CLASSIFICATION);
    }

    public String getWriteableName() {
        return NAME.getPreferredName();
    }

    public void writeTo(StreamOutput out) throws IOException {
        out.writeBoolean(this.weights != null);
        if (this.weights != null) {
            out.writeDoubleArray(this.weights);
        }
        out.writeVInt(this.numClasses);
    }

    public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
        builder.startObject();
        if (this.weights != null) {
            builder.field(WEIGHTS.getPreferredName(), (Object)this.weights);
        }
        builder.field(NUM_CLASSES.getPreferredName(), this.numClasses);
        builder.endObject();
        return builder;
    }

    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (o == null || this.getClass() != o.getClass()) {
            return false;
        }
        WeightedMode that = (WeightedMode)o;
        return Arrays.equals(this.weights, that.weights) && this.numClasses == that.numClasses;
    }

    public int hashCode() {
        return Objects.hash(Arrays.hashCode(this.weights), this.numClasses);
    }

    public long ramBytesUsed() {
        long weightSize = this.weights == null ? 0L : RamUsageEstimator.sizeOf((double[])this.weights);
        return SHALLOW_SIZE + weightSize;
    }
}

