/*
 * Decompiled with CFR 0.152.
 */
package networks.computation.iteration.actions;

import java.util.logging.Logger;
import networks.computation.evaluation.results.Result;
import networks.computation.evaluation.values.ScalarValue;
import networks.computation.evaluation.values.Value;
import networks.computation.iteration.TopDown;
import networks.computation.iteration.modes.BFS;
import networks.computation.iteration.modes.DFSrecursion;
import networks.computation.iteration.modes.DFSstack;
import networks.computation.iteration.modes.Topologic;
import networks.computation.iteration.visitors.neurons.CrossDown;
import networks.computation.iteration.visitors.neurons.NeuronVisitor;
import networks.computation.iteration.visitors.neurons.StandardNeuronVisitors;
import networks.computation.iteration.visitors.states.neurons.Backproper;
import networks.computation.iteration.visitors.weights.WeightUpdater;
import networks.computation.training.NeuralModel;
import networks.computation.training.NeuralSample;
import networks.structure.components.NeuralNetwork;
import networks.structure.components.neurons.Neurons;
import networks.structure.components.neurons.QueryNeuron;
import networks.structure.components.neurons.types.AtomNeurons;
import networks.structure.components.types.TopologicNetwork;
import networks.structure.metadata.states.State;
import settings.Settings;

public class Backpropagation {
    private static final Logger LOG = Logger.getLogger(Backpropagation.class.getName());
    private final Settings settings;
    WeightUpdater weightUpdater;
    Backproper backproper;

    public Backpropagation(Settings settings, NeuralModel model, int index) {
        this.settings = settings;
        this.backproper = Backproper.getFrom(settings, index);
        this.weightUpdater = new WeightUpdater(model.weights);
    }

    public Backpropagation(Settings settings, NeuralModel neuralModel) {
        this(settings, neuralModel, -1);
    }

    public TopDown getTopDownPropagator(NeuralNetwork<State.Structure> network, Neurons outputNeuron) {
        if (network instanceof TopologicNetwork && !network.containsInputMasking) {
            NeuronVisitor.Weighted down = new StandardNeuronVisitors.Down(network, this.backproper, this.weightUpdater);
            if (network.containsCrossProducts) {
                down = new CrossDown(down);
            }
            Topologic topologic = new Topologic((TopologicNetwork)network);
            topologic.getClass();
            return topologic.new Topologic.TDownVisitor(outputNeuron, down);
        }
        if (this.settings.iterationMode == Settings.IterationMode.DFS_RECURSIVE) {
            DFSrecursion dFSrecursion = new DFSrecursion();
            dFSrecursion.getClass();
            return dFSrecursion.new DFSrecursion.TDownVisitor(network, outputNeuron, this.backproper, this.weightUpdater);
        }
        if (this.settings.iterationMode == Settings.IterationMode.DFS_STACK) {
            DFSstack dFSstack = new DFSstack();
            dFSstack.getClass();
            return dFSstack.new DFSstack.TDownVisitor(network, outputNeuron, this.backproper, this.weightUpdater);
        }
        BFS bFS = new BFS();
        bFS.getClass();
        return bFS.new BFS.TDownVisitor(network, outputNeuron, this.backproper, this.weightUpdater);
    }

    public WeightUpdater backpropagate(NeuralSample neuralSample, Result evaluatedResult) {
        NeuralNetwork neuralNetwork = (NeuralNetwork)((QueryNeuron)neuralSample.query).evidence;
        AtomNeurons outputNeuron = ((QueryNeuron)neuralSample.query).neuron;
        Value errorGradient = evaluatedResult.errorGradient();
        errorGradient = errorGradient.times((Value)new ScalarValue(this.settings.initLearningRate));
        this.weightUpdater.clearUpdates();
        outputNeuron.getComputationView(this.backproper.stateIndex).storeGradient(errorGradient);
        TopDown topDownPropagator = this.getTopDownPropagator(neuralNetwork, outputNeuron);
        topDownPropagator.topdown();
        return this.weightUpdater;
    }
}

