/*
 * Decompiled with CFR 0.152.
 */
package org.ojalgo.ann;

import java.util.List;
import java.util.function.Supplier;
import org.ojalgo.ann.ArtificialNeuralNetwork;
import org.ojalgo.ann.TrainingConfiguration;
import org.ojalgo.data.DataBatch;
import org.ojalgo.matrix.store.MatrixStore;
import org.ojalgo.matrix.store.PhysicalStore;
import org.ojalgo.structure.Access1D;
import org.ojalgo.structure.Structure2D;

/**
 * Package-private base class that wraps an {@link ArtificialNeuralNetwork} together with a fixed batch size
 * and one preallocated output store per layer. Most methods simply delegate to the wrapped network.
 * <p>
 * Equality and hash code are based solely on the wrapped network; the batch size and the input/output
 * stores are not considered.
 */
abstract class WrappedANN
implements Supplier<ArtificialNeuralNetwork> {
    private final int myBatchSize;
    /**
     * The store used as input by the most recent {@link #invoke(Access1D, TrainingConfiguration)} call; either
     * a store supplied by the caller (used directly) or {@link #myInputBuffer}. Remains {@code null} until the
     * first invocation.
     */
    private PhysicalStore<Double> myInput;
    /**
     * Lazily allocated input store owned by this wrapper. It is the only input store this class ever writes
     * to, so stores handed in by callers are never overwritten.
     */
    private PhysicalStore<Double> myInputBuffer;
    private final ArtificialNeuralNetwork myNetwork;
    private final PhysicalStore<Double>[] myOutputs;

    /**
     * Wraps the given network and allocates one output store per layer, each created via
     * {@code network.newStore(batchSize, network.countOutputNodes(layer))}.
     *
     * @param network   the network to wrap; dereferenced immediately, so must not be {@code null}
     * @param batchSize the row count requested for every store allocated by this wrapper
     */
    WrappedANN(ArtificialNeuralNetwork network, int batchSize) {
        this.myNetwork = network;
        this.myBatchSize = batchSize;
        // Java cannot create arrays of a parameterised type; the raw array only ever receives
        // PhysicalStore<Double> instances (assigned in the loop below), so the conversion is safe.
        @SuppressWarnings("unchecked")
        PhysicalStore<Double>[] outputs = new PhysicalStore[network.depth()];
        for (int l = 0; l < outputs.length; ++l) {
            outputs[l] = network.newStore(batchSize, network.countOutputNodes(l));
        }
        this.myOutputs = outputs;
    }

    /**
     * Two wrappers are equal when their wrapped networks are equal (or both {@code null}). Any
     * {@code WrappedANN} instance is accepted as a candidate; batch size is not compared.
     */
    @Override
    public boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }
        if (!(obj instanceof WrappedANN)) {
            return false;
        }
        ArtificialNeuralNetwork mine = this.myNetwork;
        ArtificialNeuralNetwork theirs = ((WrappedANN) obj).myNetwork;
        return mine == null ? theirs == null : mine.equals(theirs);
    }

    /**
     * @return the wrapped network instance (not a copy)
     */
    @Override
    public ArtificialNeuralNetwork get() {
        return this.myNetwork;
    }

    /**
     * Hash code derived from the wrapped network only, consistent with {@link #equals(Object)}.
     */
    @Override
    public int hashCode() {
        final int prime = 31;
        int result = 1;
        result = prime * result + (this.myNetwork == null ? 0 : this.myNetwork.hashCode());
        return result;
    }

    /**
     * @return a new batch obtained from the network, requested with this wrapper's batch size and the
     *         network's input node count
     */
    public DataBatch newInputBatch() {
        return this.myNetwork.newBatch(this.myBatchSize, this.myNetwork.countInputNodes());
    }

    /**
     * Selects the store that the next forward pass reads from.
     * <p>
     * A {@link PhysicalStore} whose row dimension equals the batch size is used directly, without copying.
     * Anything else is copied, via {@code fillMatching}, into a buffer owned by this wrapper. The copy never
     * targets a store previously supplied by a caller, so earlier inputs are left untouched.
     *
     * @param input the input values for the forward pass
     */
    private void setInput(Access1D<Double> input) {
        if (input instanceof PhysicalStore && ((PhysicalStore<?>) input).getRowDim() == this.myBatchSize) {
            this.myInput = (PhysicalStore<Double>) input;
        } else {
            if (this.myInputBuffer == null) {
                this.myInputBuffer = this.myNetwork.newStore(this.myBatchSize, this.myNetwork.countInputNodes());
            }
            this.myInputBuffer.fillMatching(input);
            this.myInput = this.myInputBuffer;
        }
    }

    /**
     * Delegates to {@code ArtificialNeuralNetwork.adjust} with the same arguments.
     */
    void adjust(int layer, PhysicalStore<Double> input, PhysicalStore<Double> output, PhysicalStore<Double> upstreamGradient, PhysicalStore<Double> downstreamGradient) {
        this.myNetwork.adjust(layer, input, output, upstreamGradient, downstreamGradient);
    }

    /**
     * @return the depth reported by the wrapped network
     */
    int depth() {
        return this.myNetwork.depth();
    }

    /**
     * @return the activator the wrapped network reports for the given layer
     */
    ArtificialNeuralNetwork.Activator getActivator(int layer) {
        return this.myNetwork.getActivator(layer);
    }

    /**
     * @return the batch size given at construction
     */
    int getBatchSize() {
        return this.myBatchSize;
    }

    /**
     * @return the bias the wrapped network reports for the given layer and output index
     */
    double getBias(int layer, int output) {
        return this.myNetwork.getBias(layer, output);
    }

    /**
     * @return the store used as input by the most recent invocation, or {@code null} if
     *         {@link #invoke(Access1D, TrainingConfiguration)} has not been called yet
     */
    PhysicalStore<Double> getInput() {
        return this.myInput;
    }

    /**
     * @param layer the layer index
     * @return the current input store when {@code layer <= 0}, otherwise the output store of layer
     *         {@code layer - 1}
     */
    PhysicalStore<Double> getInput(int layer) {
        return layer <= 0 ? this.myInput : this.myOutputs[layer - 1];
    }

    /**
     * @return the output store of the last layer
     */
    PhysicalStore<Double> getOutput() {
        return this.myOutputs[this.myOutputs.length - 1];
    }

    /**
     * @return the preallocated output store of the given layer
     */
    PhysicalStore<Double> getOutput(int layer) {
        return this.myOutputs[layer];
    }

    /**
     * @return the output activator reported by the wrapped network
     */
    ArtificialNeuralNetwork.Activator getOutputActivator() {
        return this.myNetwork.getOutputActivator();
    }

    /**
     * @return the weight the wrapped network reports for the given layer, input index and output index
     */
    double getWeight(int layer, int input, int output) {
        return this.myNetwork.getWeight(layer, input, output);
    }

    /**
     * @return the weights list as returned by the wrapped network
     */
    List<MatrixStore<Double>> getWeights() {
        return this.myNetwork.getWeights();
    }

    /**
     * Runs the input through every layer of the wrapped network.
     * <p>
     * The input is installed via {@link #setInput(Access1D)}, the configuration is set on the network, and
     * then each layer is invoked in order with the previous layer's result and that layer's preallocated
     * output store.
     *
     * @param input         the input values
     * @param configuration the configuration passed to the network before the layers are invoked
     * @return whatever the last layer invocation returned, or the input store itself if the network has no
     *         layers
     */
    MatrixStore<Double> invoke(Access1D<Double> input, TrainingConfiguration configuration) {
        this.setInput(input);
        this.myNetwork.setConfiguration(configuration);
        PhysicalStore<Double> retVal = this.myInput;
        int limit = this.depth();
        for (int l = 0; l < limit; ++l) {
            retVal = this.myNetwork.invoke(l, retVal, this.myOutputs[l]);
        }
        return retVal;
    }

    /**
     * @return a new batch obtained from the network, requested with this wrapper's batch size and the
     *         network's output node count
     */
    DataBatch newOutputBatch() {
        return this.myNetwork.newBatch(this.myBatchSize, this.myNetwork.countOutputNodes());
    }

    /**
     * Delegates to {@code ArtificialNeuralNetwork.randomise()}.
     */
    void randomise() {
        this.myNetwork.randomise();
    }

    /**
     * Delegates to {@code ArtificialNeuralNetwork.setActivator} with the same arguments.
     */
    void setActivator(int layer, ArtificialNeuralNetwork.Activator activator) {
        this.myNetwork.setActivator(layer, activator);
    }

    /**
     * Delegates to {@code ArtificialNeuralNetwork.setBias} with the same arguments.
     */
    void setBias(int layer, int output, double bias) {
        this.myNetwork.setBias(layer, output, bias);
    }

    /**
     * Delegates to {@code ArtificialNeuralNetwork.setWeight} with the same arguments.
     */
    void setWeight(int layer, int input, int output, double weight) {
        this.myNetwork.setWeight(layer, input, output, weight);
    }

    /**
     * @return the structure array reported by the wrapped network
     */
    Structure2D[] structure() {
        return this.myNetwork.structure();
    }
}

