/*
 * Decompiled with CFR 0.152.
 */
package org.ojalgo.ann;

import java.util.Arrays;
import java.util.Iterator;
import org.ojalgo.ann.ArtificialNeuralNetwork;
import org.ojalgo.ann.TrainingConfiguration;
import org.ojalgo.ann.WrappedANN;
import org.ojalgo.data.DataBatch;
import org.ojalgo.matrix.store.MatrixStore;
import org.ojalgo.matrix.store.PhysicalStore;
import org.ojalgo.structure.Access1D;
import org.ojalgo.structure.Structure2D;

public final class NetworkTrainer
extends WrappedANN {
    private final TrainingConfiguration myConfiguration = new TrainingConfiguration();
    private final PhysicalStore<Double>[] myGradients;

    NetworkTrainer(ArtificialNeuralNetwork network, int batchSize) {
        super(network, batchSize);
        int depth = network.depth();
        this.myGradients = new PhysicalStore[depth];
        for (int l = 0; l < depth; ++l) {
            this.myGradients[l] = network.newStore(network.countOutputNodes(l), batchSize);
        }
    }

    @Deprecated
    public NetworkTrainer activator(int layer, ArtificialNeuralNetwork.Activator activator) {
        this.setActivator(layer, activator);
        return this;
    }

    @Deprecated
    public NetworkTrainer activators(ArtificialNeuralNetwork.Activator activator) {
        int limit = this.depth();
        for (int i = 0; i < limit; ++i) {
            this.activator(i, activator);
        }
        return this;
    }

    @Deprecated
    public NetworkTrainer activators(ArtificialNeuralNetwork.Activator ... activators) {
        int limit = activators.length;
        for (int i = 0; i < limit; ++i) {
            this.activator(i, activators[i]);
        }
        return this;
    }

    public NetworkTrainer bias(int layer, int output, double bias) {
        this.setBias(layer, output, bias);
        return this;
    }

    public NetworkTrainer dropouts() {
        this.myConfiguration.dropouts = true;
        return this;
    }

    @Override
    public boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }
        if (!super.equals(obj) || !(obj instanceof NetworkTrainer)) {
            return false;
        }
        NetworkTrainer other = (NetworkTrainer)obj;
        return this.myConfiguration.equals(other.myConfiguration) && Arrays.equals(this.myGradients, other.myGradients);
    }

    public NetworkTrainer error(ArtificialNeuralNetwork.Error error) {
        if (this.getOutputActivator() == ArtificialNeuralNetwork.Activator.SOFTMAX ? error != ArtificialNeuralNetwork.Error.CROSS_ENTROPY : error != ArtificialNeuralNetwork.Error.HALF_SQUARED_DIFFERENCE) {
            throw new IllegalArgumentException();
        }
        this.myConfiguration.error = error;
        return this;
    }

    @Override
    public int hashCode() {
        int prime = 31;
        int result = super.hashCode();
        result = 31 * result + this.myConfiguration.hashCode();
        result = 31 * result + Arrays.hashCode(this.myGradients);
        return result;
    }

    public NetworkTrainer lasso(double factor) {
        this.myConfiguration.regularisationL1 = true;
        this.myConfiguration.regularisationL1Factor = factor;
        return this;
    }

    @Override
    public DataBatch newOutputBatch() {
        return super.newOutputBatch();
    }

    public NetworkTrainer rate(double rate) {
        this.myConfiguration.learningRate = rate;
        return this;
    }

    public NetworkTrainer ridge(double factor) {
        this.myConfiguration.regularisationL2 = true;
        this.myConfiguration.regularisationL2Factor = factor;
        return this;
    }

    @Override
    public Structure2D[] structure() {
        return super.structure();
    }

    public String toString() {
        StringBuilder builder = new StringBuilder();
        builder.append("NetworkBuilder [structure()=");
        builder.append(Arrays.toString(this.structure()));
        builder.append(", Error=");
        builder.append(this.myConfiguration.error);
        builder.append(", LearningRate=");
        builder.append(this.myConfiguration.learningRate);
        builder.append("]");
        return builder.toString();
    }

    public void train(Access1D<Double> givenInput, Access1D<Double> targetOutput) {
        MatrixStore<Double> current = this.invoke(givenInput, this.myConfiguration);
        this.myGradients[this.myGradients.length - 1].regionByTransposing().fillMatching(targetOutput, this.myConfiguration.error.getDerivative(), current);
        for (int l = this.depth() - 1; l >= 0; --l) {
            PhysicalStore<Double> input = this.getInput(l);
            PhysicalStore<Double> output = this.getOutput(l);
            PhysicalStore<Double> upstreamGradient = l == 0 ? null : this.myGradients[l - 1];
            PhysicalStore<Double> downstreamGradient = this.myGradients[l];
            this.adjust(l, input, output, upstreamGradient, downstreamGradient);
        }
    }

    @Deprecated
    public void train(Iterable<? extends Access1D<Double>> givenInputs, Iterable<? extends Access1D<Double>> targetOutputs) {
        Iterator<? extends Access1D<Double>> iterI = givenInputs.iterator();
        Iterator<? extends Access1D<Double>> iterO = targetOutputs.iterator();
        while (iterI.hasNext() && iterO.hasNext()) {
            this.train(iterI.next(), iterO.next());
        }
    }

    public NetworkTrainer weight(int layer, int input, int output, double weight) {
        this.setWeight(layer, input, output, weight);
        return this;
    }

    double error(Access1D<?> target, Access1D<?> current) {
        return this.myConfiguration.error.invoke(target, current);
    }
}

