/*
 * Decompiled with CFR 0.152.
 */
package networks.computation.training.strategies.trainers;

import java.util.logging.Logger;
import networks.computation.evaluation.results.Result;
import networks.computation.iteration.actions.Backpropagation;
import networks.computation.iteration.actions.Evaluation;
import networks.computation.iteration.actions.IndependentNeuronProcessing;
import networks.computation.iteration.visitors.weights.WeightUpdater;
import networks.computation.training.NeuralModel;
import networks.computation.training.NeuralSample;
import networks.computation.training.optimizers.Optimizer;
import networks.structure.building.debugging.NeuralDebugger;
import networks.structure.components.NeuralNetwork;
import networks.structure.components.neurons.QueryNeuron;
import settings.Settings;

/**
 * Runs one training step for a single {@link NeuralSample}, in this order:
 * optional dropout, state invalidation, evaluation, backpropagation and a weight update
 * through the configured {@link Optimizer}.
 *
 * <p>Each stage is a separate, package-visible method so subclasses can override individual
 * stages. The weight update ({@link #updateWeights}) is {@code synchronized} on this trainer
 * instance.
 */
public class Trainer {
    private static final Logger LOG = Logger.getLogger(Trainer.class.getName());
    protected Settings settings;
    int index;
    Optimizer optimizer;
    NeuralDebugger neuralDebugger;

    /**
     * Creates a trainer bound to the given settings and optimizer.
     *
     * <p>A {@link NeuralDebugger} is created from the same settings.
     *
     * @param settings configuration read by the training step (dropout, debugging)
     * @param optimizer optimizer used to apply gradient steps and to restart
     */
    public Trainer(Settings settings, Optimizer optimizer) {
        this.settings = settings;
        this.optimizer = optimizer;
        this.neuralDebugger = new NeuralDebugger(settings);
    }

    /**
     * Creates a trainer with no settings, optimizer or debugger assigned (all fields keep their
     * Java defaults). They must be set before the trainer is used.
     */
    public Trainer() {
    }

    /**
     * Performs a single training step on one sample and returns its evaluation result.
     *
     * @param neuralModel model whose weights are updated
     * @param neuralSample sample to train on
     * @param dropouter processing applied when dropout is enabled in the settings
     * @param invalidation processing that invalidates the sample's neuron states
     * @param evaluation forward evaluation of the sample
     * @param backpropagation backward pass producing the weight updater
     * @return the result produced by the evaluation stage
     */
    protected Result learnFromSample(NeuralModel neuralModel, NeuralSample neuralSample, IndependentNeuronProcessing dropouter, IndependentNeuronProcessing invalidation, Evaluation evaluation, Backpropagation backpropagation) {
        if (isDropoutEnabled()) {
            this.dropoutSample(dropouter, neuralSample);
        }
        this.invalidateSample(invalidation, neuralSample);

        Result evaluated = this.evaluateSample(evaluation, neuralSample);
        // The backward pass runs before the (synchronized) weight update is entered.
        this.updateWeights(neuralModel, this.backpropSample(backpropagation, evaluated, neuralSample));

        if (this.settings.debugSampleTraining) {
            this.neuralDebugger.debug(neuralSample);
        }
        return evaluated;
    }

    /** Returns true when the settings request DROPOUT mode with a strictly positive rate. */
    private boolean isDropoutEnabled() {
        boolean dropoutMode = this.settings.dropoutMode == Settings.DropoutMode.DROPOUT;
        return dropoutMode && this.settings.dropoutRate > 0.0;
    }

    /** Applies the dropout processing to the sample's query neuron within its evidence network. */
    void dropoutSample(IndependentNeuronProcessing dropouter, NeuralSample neuralSample) {
        QueryNeuron queryNeuron = (QueryNeuron) neuralSample.query;
        dropouter.process((NeuralNetwork) queryNeuron.evidence, queryNeuron.neuron);
    }

    /**
     * Initializes the states cache of the sample's evidence network using {@link #index}, then
     * runs the invalidation processing on the query neuron.
     */
    public void invalidateSample(IndependentNeuronProcessing invalidation, NeuralSample neuralSample) {
        QueryNeuron queryNeuron = (QueryNeuron) neuralSample.query;
        NeuralNetwork network = (NeuralNetwork) queryNeuron.evidence;
        network.initializeStatesCache(this.index);
        invalidation.process(network, queryNeuron.neuron);
    }

    /** Delegates to {@link Evaluation#evaluate} for the given sample. */
    Result evaluateSample(Evaluation evaluation, NeuralSample neuralSample) {
        return evaluation.evaluate(neuralSample);
    }

    /** Delegates to {@link Backpropagation#backpropagate} with the sample and its evaluation. */
    WeightUpdater backpropSample(Backpropagation backpropagation, Result evaluatedResult, NeuralSample neuralSample) {
        return backpropagation.backpropagate(neuralSample, evaluatedResult);
    }

    /** Asks the optimizer to perform a gradient step; serialized on this trainer's monitor. */
    synchronized void updateWeights(NeuralModel model, WeightUpdater weightUpdater) {
        this.optimizer.performGradientStep(model, weightUpdater);
    }

    /** Restarts the optimizer with the current settings. */
    public void restart() {
        this.optimizer.restart(this.settings);
    }
}

