/*
 * Decompiled with CFR 0.152.
 */
package networks.computation.iteration.visitors.neurons;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import networks.computation.evaluation.functions.CrossProduct;
import networks.computation.evaluation.values.Value;
import networks.computation.iteration.visitors.neurons.NeuronVisitor;
import networks.computation.iteration.visitors.neurons.StandardNeuronVisitors;
import networks.computation.iteration.visitors.states.StateVisiting;
import networks.computation.iteration.visitors.weights.WeightUpdater;
import networks.structure.components.NeuralNetwork;
import networks.structure.components.neurons.BaseNeuron;
import networks.structure.components.neurons.Neurons;
import networks.structure.components.neurons.WeightedNeuron;
import networks.structure.components.weights.Weight;
import networks.structure.metadata.states.AggregationState;
import networks.structure.metadata.states.State;
import utils.generic.Pair;

/**
 * Backward-pass ("down") neuron visitor with dedicated handling for neurons whose aggregation is a
 * {@link CrossProduct}.
 *
 * <p>For such neurons, the per-output gradient is scattered back onto the inputs using the index
 * mapping stored in the neuron's {@link AggregationState.CrossProducState}. Neurons with any other
 * aggregation are delegated unchanged to {@link #classicDown}.
 */
public class CrossDown extends NeuronVisitor.Weighted {

    /** Visitor used for every neuron whose aggregation is not a {@link CrossProduct}. */
    NeuronVisitor.Weighted classicDown;

    /**
     * Creates a visitor that delegates non-cross-product neurons to a new
     * {@link StandardNeuronVisitors.Down} built from the same arguments.
     *
     * @param network            network whose input relations are traversed
     * @param computationVisitor state visitor used to obtain each neuron's gradient
     * @param weightUpdater      updater that receives the gradient for each neuron's offset
     */
    public CrossDown(NeuralNetwork<State.Structure> network, StateVisiting.Computation computationVisitor, WeightUpdater weightUpdater) {
        super(network, computationVisitor, weightUpdater);
        this.classicDown = new StandardNeuronVisitors.Down(network, computationVisitor, weightUpdater);
    }

    /**
     * Wraps an existing weighted visitor. This visitor reuses the wrapped visitor's network, state
     * visitor and weight updater, and delegates non-cross-product neurons to it.
     *
     * @param down visitor used for neurons whose aggregation is not a {@link CrossProduct}
     */
    public CrossDown(NeuronVisitor.Weighted down) {
        super(down.network, down.stateVisitor, down.weightUpdater);
        this.classicDown = down;
    }

    /**
     * Back-propagates through an unweighted neuron.
     *
     * <p>For a cross-product neuron, element {@code i} of the returned gradient is added to element
     * {@code mapping[i][j]} of the gradient of input {@code j}. Inputs whose gradient is
     * {@code null} are skipped.
     */
    @Override
    public <T extends Neurons, S extends State.Neural> void visit(BaseNeuron<T, S> neuron) {
        if (!(neuron.getAggregation() instanceof CrossProduct)) {
            ((NeuronVisitor) this.classicDown).visit(neuron);
            return;
        }
        State.Neural.Computation state = neuron.getComputationView(this.stateVisitor.stateIndex);
        Value gradient = this.stateVisitor.visit(state);
        // mapping[i][j] = element of input j that contributed to output element i
        int[][] mapping = ((AggregationState.CrossProducState) state.getAggregationState()).mapping;

        Iterator<T> inputs = this.network.getInputs(neuron);
        List<Value> inputGradients = new ArrayList<>();
        while (inputs.hasNext()) {
            T input = inputs.next();
            inputGradients.add(input.getComputationView(this.stateVisitor.stateIndex).getGradient());
        }

        for (int i = 0; i < mapping.length; i++) {
            double grad = gradient.get(i);
            int[] map = mapping[i];
            for (int j = 0; j < inputGradients.size(); j++) {
                Value inGrad = inputGradients.get(j);
                if (inGrad != null) {
                    inGrad.increment(map[j], grad);
                }
            }
        }
    }

    /**
     * Back-propagates through a weighted neuron.
     *
     * <p>For a cross-product neuron, the returned gradient is first passed to the weight updater
     * together with {@code neuron.offset}. Then, for every output element {@code i} and input
     * {@code j}:
     * <ul>
     *   <li>if the input's weight is learnable, {@code grad[i] * inputValue_j[mapping[i][j]]} is
     *       added to element {@code mapping[i][j]} of that weight's value;</li>
     *   <li>if the input's gradient is non-null, {@code grad[i] * weight.value.get(0)} is added to
     *       element {@code mapping[i][j]} of that gradient.</li>
     * </ul>
     */
    @Override
    public <T extends Neurons, S extends State.Neural> void visit(WeightedNeuron<T, S> neuron) {
        if (!(neuron.getAggregation() instanceof CrossProduct)) {
            this.classicDown.visit(neuron);
            return;
        }
        State.Neural.Computation state = neuron.getComputationView(this.stateVisitor.stateIndex);
        Value gradient = this.stateVisitor.visit(state);
        // mapping[i][j] = element of input j that contributed to output element i
        int[][] mapping = ((AggregationState.CrossProducState) state.getAggregationState()).mapping;
        Pair<Iterator<T>, Iterator<Weight>> inputs = this.network.getInputs(neuron);
        this.weightUpdater.visit(neuron.offset, gradient);

        // Materialise the (neuron, weight) pairs once; they are indexed repeatedly below.
        Iterator<T> inputNeurons = inputs.r;
        Iterator<Weight> inputWeightIterator = inputs.s;
        List<Value> inputGradients = new ArrayList<>();
        List<Value> inputOutputValues = new ArrayList<>();
        List<Weight> inputWeights = new ArrayList<>();
        while (inputNeurons.hasNext()) {
            T input = inputNeurons.next();
            inputWeights.add(inputWeightIterator.next());
            inputGradients.add(input.getComputationView(this.stateVisitor.stateIndex).getGradient());
            inputOutputValues.add(input.getComputationView(this.stateVisitor.stateIndex).getValue());
        }

        for (int i = 0; i < mapping.length; i++) {
            double grad = gradient.get(i);
            int[] map = mapping[i];
            for (int j = 0; j < inputGradients.size(); j++) {
                Weight weight = inputWeights.get(j);
                if (weight.isLearnable) {
                    // NOTE(review): this mutates weight.value directly instead of routing the update
                    // through weightUpdater (as is done for neuron.offset above). As a result, the
                    // weight.value.get(0) read below sees already-incremented values. It also writes
                    // index map[j] while that read uses index 0. Confirm both are intended.
                    weight.value.increment(map[j], grad * inputOutputValues.get(j).get(map[j]));
                }
                Value inGrad = inputGradients.get(j);
                if (inGrad != null) {
                    inGrad.increment(map[j], grad * weight.value.get(0));
                }
            }
        }
    }
}

