/*
 * Decompiled with CFR 0.152.
 */
package networks.computation.iteration.visitors.neurons;

import java.util.Iterator;
import java.util.logging.Logger;
import networks.computation.evaluation.values.Value;
import networks.computation.iteration.visitors.neurons.NeuronVisitor;
import networks.computation.iteration.visitors.states.StateVisiting;
import networks.computation.iteration.visitors.weights.WeightUpdater;
import networks.structure.components.NeuralNetwork;
import networks.structure.components.neurons.BaseNeuron;
import networks.structure.components.neurons.Neurons;
import networks.structure.components.neurons.WeightedNeuron;
import networks.structure.components.weights.Weight;
import networks.structure.metadata.states.State;
import utils.generic.Pair;

/**
 * Standard {@link NeuronVisitor} implementations that move values between a neuron and its inputs
 * in a {@link NeuralNetwork}.
 *
 * <p>{@link Up} accumulates input values into a neuron's computation state and then evaluates it.
 * {@link Down} evaluates a neuron's computation state and pushes the result back to the inputs. For
 * weighted neurons, {@link Down} also reports updates to the offset and the input weights through a
 * {@link WeightUpdater}. Both visitors read and write the computation view selected by
 * {@code stateVisitor.stateIndex}.
 */
public class StandardNeuronVisitors {
    private static final Logger LOG = Logger.getLogger(StandardNeuronVisitors.class.getName());

    /**
     * Top-down visitor. It evaluates a neuron's computation view with the state visitor and
     * distributes the resulting value to the neuron's inputs via {@code storeGradient}.
     * Presumably this is the backward (gradient) pass; the value is named {@code gradient} here.
     */
    public static class Down
    extends NeuronVisitor.Weighted {
        /**
         * @param network       network whose input structure is traversed
         * @param topDown       state visitor used to evaluate each neuron's computation view
         * @param weightUpdater receives the offset and per-weight updates for weighted neurons
         */
        public Down(NeuralNetwork<State.Structure> network, StateVisiting.Computation topDown, WeightUpdater weightUpdater) {
            super(network, topDown, weightUpdater);
        }

        /**
         * Evaluates {@code neuron}'s computation view. The resulting value is then stored,
         * unchanged, via {@code storeGradient} into the computation view of every input neuron.
         */
        @Override
        public <T extends Neurons, S extends State.Neural> void visit(BaseNeuron<T, S> neuron) {
            State.Neural.Computation state = neuron.getComputationView(this.stateVisitor.stateIndex);
            Value gradient = this.stateVisitor.visit(state);
            Iterator<T> inputs = this.network.getInputs(neuron);
            while (inputs.hasNext()) {
                Neurons input = inputs.next();
                input.getComputationView(this.stateVisitor.stateIndex).storeGradient(gradient);
            }
        }

        /**
         * Evaluates {@code neuron}'s computation view and propagates the result.
         *
         * <p>The result is reported to the weight updater for the neuron's offset. Then, for each
         * positionally paired (input, weight):
         * <ul>
         *   <li>{@code gradient.times(inputValue^T)} is reported to the weight updater for that weight;</li>
         *   <li>{@code gradient^T.times(weight.value)} is stored as a gradient in the input's view.</li>
         * </ul>
         */
        @Override
        public <T extends Neurons, S extends State.Neural> void visit(WeightedNeuron<T, S> neuron) {
            State.Neural.Computation state = neuron.getComputationView(this.stateVisitor.stateIndex);
            Value gradient = this.stateVisitor.visit(state);
            Pair<Iterator<T>, Iterator<Weight>> inputs = this.network.getInputs(neuron);
            this.weightUpdater.visit(neuron.offset, gradient);
            Iterator<T> inputNeurons = inputs.r;
            Iterator<Weight> inputWeights = inputs.s;
            // The transposed gradient is the same for every input, so compute it once.
            Value transpGradient = gradient.transposedView();
            // Inputs and weights are paired by position. The weight iterator is assumed to be at
            // least as long as the input iterator; otherwise next() throws NoSuchElementException.
            while (inputNeurons.hasNext()) {
                Neurons input = inputNeurons.next();
                Weight weight = inputWeights.next();
                State.Neural.Computation inputComputationView = input.getComputationView(this.stateVisitor.stateIndex);
                Value inputValue = inputComputationView.getValue().transposedView();
                this.weightUpdater.visit(weight, gradient.times(inputValue));
                inputComputationView.storeGradient(transpGradient.times(weight.value));
            }
        }
    }

    /**
     * Bottom-up visitor. It accumulates input contributions into a neuron's computation view via
     * {@code storeValue} and then evaluates that view with the state visitor. Presumably this is
     * the forward (value) pass.
     */
    public static class Up
    extends NeuronVisitor.Weighted {
        /**
         * @param network            network whose input structure is traversed
         * @param computationVisitor state visitor used to evaluate each neuron's computation view
         */
        public Up(NeuralNetwork<State.Structure> network, StateVisiting.Computation computationVisitor) {
            // This visitor never uses a weight updater, so none is supplied.
            super(network, computationVisitor, null);
        }

        /**
         * Stores each input neuron's current value into {@code neuron}'s computation view. It then
         * evaluates that view.
         */
        @Override
        public <T extends Neurons, S extends State.Neural> void visit(BaseNeuron<T, S> neuron) {
            State.Neural.Computation state = neuron.getComputationView(this.stateVisitor.stateIndex);
            Iterator<T> inputs = this.network.getInputs(neuron);
            while (inputs.hasNext()) {
                Neurons input = inputs.next();
                state.storeValue(input.getComputationView(this.stateVisitor.stateIndex).getValue());
            }
            // The returned Value is not used here. The call is kept because it may mutate the state.
            this.stateVisitor.visit(state);
        }

        /**
         * Stores contributions into {@code neuron}'s computation view and then evaluates it. The
         * contributions are the offset's value, followed by {@code weight.value.times(inputValue)}
         * for each positionally paired (input, weight).
         */
        @Override
        public <T extends Neurons, S extends State.Neural> void visit(WeightedNeuron<T, S> neuron) {
            State.Neural.Computation state = neuron.getComputationView(this.stateVisitor.stateIndex);
            Pair<Iterator<T>, Iterator<Weight>> inputs = this.network.getInputs(neuron);
            Iterator<T> inputNeurons = inputs.r;
            Iterator<Weight> inputWeights = inputs.s;
            state.storeValue(neuron.offset.value);
            // Inputs and weights are paired by position. The weight iterator is assumed to be at
            // least as long as the input iterator; otherwise next() throws NoSuchElementException.
            while (inputNeurons.hasNext()) {
                Neurons input = inputNeurons.next();
                Weight weight = inputWeights.next();
                state.storeValue(weight.value.times(input.getComputationView(this.stateVisitor.stateIndex).getValue()));
            }
            // The returned Value is not used here. The call is kept because it may mutate the state.
            this.stateVisitor.visit(state);
        }
    }
}

