/*
 * Decompiled with CFR 0.152.
 */
package networks.computation.training.strategies;

import java.util.Collections;
import java.util.List;
import java.util.logging.Level;
import java.util.logging.Logger;
import learning.crossvalidation.splitting.Splitter;
import networks.computation.evaluation.results.ClassificationResults;
import networks.computation.evaluation.results.Progress;
import networks.computation.evaluation.results.Result;
import networks.computation.evaluation.results.Results;
import networks.computation.evaluation.values.ScalarValue;
import networks.computation.evaluation.values.Value;
import networks.computation.evaluation.values.distributions.ValueInitializer;
import networks.computation.iteration.actions.Accumulating;
import networks.computation.iteration.visitors.states.neurons.SaturationChecker;
import networks.computation.training.NeuralModel;
import networks.computation.training.NeuralSample;
import networks.computation.training.debugging.TrainingDebugger;
import networks.computation.training.optimizers.Optimizer;
import networks.computation.training.strategies.Hyperparameters.LearnRateDecayStrategy;
import networks.computation.training.strategies.Hyperparameters.RestartingStrategy;
import networks.computation.training.strategies.TrainingStrategy;
import networks.computation.training.strategies.trainers.AsyncParallelTrainer;
import networks.computation.training.strategies.trainers.ListTrainer;
import networks.computation.training.strategies.trainers.MiniBatchTrainer;
import networks.computation.training.strategies.trainers.SequentialTrainer;
import settings.Settings;
import utils.Utilities;
import utils.generic.Pair;

public class IterativeTrainingStrategy
extends TrainingStrategy {
    private static final Logger LOG = Logger.getLogger(IterativeTrainingStrategy.class.getName());
    transient NeuralModel bestModel;
    transient List<NeuralSample> trainingSet;
    transient List<NeuralSample> validationSet;
    transient Progress progress;
    RestartingStrategy restartingStrategy;
    LearnRateDecayStrategy learnRateDecayStrategy;
    ListTrainer trainer;
    ValueInitializer valueInitializer;
    transient TrainingDebugger trainingDebugger;

    public IterativeTrainingStrategy(Settings settings, NeuralModel model, List<NeuralSample> sampleList) {
        super(settings, model);
        this.trainer = this.getTrainerFrom(settings);
        this.bestModel = this.currentModel;
        this.valueInitializer = ValueInitializer.getInitializer(settings);
        Pair<List<NeuralSample>, List<NeuralSample>> trainVal = this.trainingValidationSplit(sampleList);
        this.trainingSet = (List)trainVal.r;
        this.validationSet = (List)trainVal.s;
        this.learnRateDecayStrategy = LearnRateDecayStrategy.getFrom(settings, this.learningRate);
        this.restartingStrategy = RestartingStrategy.getFrom(settings);
        this.trainingDebugger = new TrainingDebugger(settings);
    }

    private Pair<List<NeuralSample>, List<NeuralSample>> trainingValidationSplit(List<NeuralSample> sampleList) {
        LOG.info("Preparing the train-validation dataset split with percentage: " + this.settings.trainValidationPercentage);
        Splitter<NeuralSample> sampleSplitter = Splitter.getSplitter(this.settings);
        Pair<List<NeuralSample>, List<NeuralSample>> partition = sampleSplitter.partition(sampleList, this.settings.trainValidationPercentage);
        return new Pair<List<NeuralSample>, List<NeuralSample>>((List<NeuralSample>)partition.r, (List<NeuralSample>)partition.s);
    }

    private ListTrainer getTrainerFrom(Settings settings) {
        if (settings.asyncParallelTraining) {
            return new AsyncParallelTrainer(settings, Optimizer.getFrom(settings, this.learningRate), this.currentModel).new AsyncParallelTrainer.AsyncListTrainer();
        }
        if (settings.minibatchSize > 1) {
            return new MiniBatchTrainer(settings, Optimizer.getFrom(settings, this.learningRate), this.currentModel, settings.minibatchSize).new MiniBatchTrainer.MinibatchListTrainer();
        }
        return new SequentialTrainer(settings, Optimizer.getFrom(settings, this.learningRate), this.currentModel).new SequentialTrainer.SequentialListTrainer();
    }

    @Override
    public Pair<NeuralModel, Progress> train() {
        this.timing.tic();
        LOG.finer("Starting with iterative mode neural training.");
        this.initTraining();
        int epochae = 0;
        for (int restart = 0; restart < this.settings.restartCount; ++restart) {
            this.initRestart();
            while (this.restartingStrategy.continueRestart(this.progress) && epochae++ < this.settings.maxCumEpochCount) {
                this.initEpoch(epochae);
                List<Result> onlineEvaluations = this.trainer.learnEpoch(this.currentModel, this.trainingSet);
                this.endEpoch(epochae, onlineEvaluations);
            }
            this.endRestart();
        }
        this.timing.toc();
        this.timing.finish();
        return this.finish();
    }

    protected void initTraining() {
        LOG.info("Initializing training (shuffling examples etc.)");
        if (this.settings.shuffleBeforeTraining) {
            Collections.shuffle(this.trainingSet, this.settings.random);
        }
        this.progress = new Progress();
    }

    protected void initRestart() {
        LOG.info("Initializing new restart (resetting weights).");
        this.trainer.restart(this.settings);
        this.currentModel.resetWeights(this.valueInitializer);
        this.progress.nextRestart();
        this.recalculateResults();
    }

    protected void initEpoch(int epochNumber) {
        if (this.settings.shuffleEachEpoch) {
            Collections.shuffle(this.trainingSet, this.settings.random);
        }
        if (this.settings.islearnRateDecay) {
            this.learnRateDecayStrategy.decay();
        }
        if (this.settings.dropoutMode == Settings.DropoutMode.LIFTED_DROPCONNECT && this.settings.dropoutRate > 0.0) {
            this.currentModel.dropoutWeights();
        }
    }

    protected void endEpoch(int count, List<Result> onlineEvaluations) {
        Results onlineResults = this.resultsFactory.createFrom(onlineEvaluations);
        this.progress.addOnlineResults(onlineResults);
        LOG.info("epoch: " + count + " : online results : " + onlineResults.toString(this.settings));
        Utilities.logMemory();
        if (count % this.settings.resultsRecalculationEpochae == 0) {
            this.recalculateResults();
            if (this.settings.debugTemplateTraining) {
                this.currentModel.getTemplate().updateWeightsFrom(this.currentModel);
                this.trainingDebugger.debug(this.currentModel.getTemplate());
            }
        }
    }

    protected void endRestart() {
        this.recalculateResults();
        this.restartingStrategy.nextRestart();
        if (LOG.isLoggable(Level.FINER)) {
            this.logSampleOutputs();
        }
    }

    protected Pair<NeuralModel, Progress> finish() {
        this.evaluateModel(this.bestModel);
        super.endTrainingStrategy();
        return new Pair<NeuralModel, Progress>(this.bestModel, this.progress);
    }

    private TrainingStrategy.TrainVal evaluateModel(NeuralModel neuralModel) {
        this.currentModel.loadWeightValues(neuralModel);
        return this.evaluateModel();
    }

    private TrainingStrategy.TrainVal evaluateModel() {
        List<Result> trainingResults = this.trainer.evaluate(this.trainingSet);
        List<Result> validationResults = this.trainer.evaluate(this.validationSet);
        return new TrainingStrategy.TrainVal(trainingResults, validationResults);
    }

    private void recalculateResults() {
        TrainingStrategy.TrainVal trueEvaluations = this.evaluateModel();
        Results trainingResults = this.resultsFactory.createFrom(trueEvaluations.training);
        Results validationResults = this.resultsFactory.createFrom(trueEvaluations.validation);
        if (this.settings.calculateBestThreshold && validationResults instanceof ClassificationResults) {
            ((ClassificationResults)validationResults).getBestAccuracy(validationResults.evaluations, ((ClassificationResults)trainingResults).bestThreshold);
        }
        this.progress.addTrueResults(trainingResults, validationResults);
        if (LOG.isLoggable(Level.FINE)) {
            String msg = "true results :- train: " + trainingResults.toString(this.settings);
            if (!validationResults.isEmpty()) {
                msg = msg + ", val: " + validationResults.toString(this.settings);
            }
            LOG.fine(msg);
        }
        if (this.settings.debugSampleOutputs) {
            this.logSampleOutputs();
        }
        if (this.settings.checkNeuronSaturation) {
            this.saturationCheck(this.trainingSet);
            this.saturationCheck(this.validationSet);
        }
        Progress.TrainVal trainVal = new Progress.TrainVal(trainingResults, validationResults);
        this.saveIfBest(trainVal);
    }

    private void saveIfBest(Progress.TrainVal trainVal) {
        if (this.progress.bestResults == null || trainVal.betterThan(this.progress.bestResults)) {
            this.bestModel = this.currentModel.cloneValues();
            if (this.settings.calculateBestThreshold && trainVal.training instanceof ClassificationResults) {
                this.bestModel.threshold = ((ClassificationResults)trainVal.training).bestThreshold;
            }
            this.progress.bestResults = trainVal;
        }
    }

    private void logSampleOutputs() {
        LOG.finer("Training outputs");
        this.progress.getLastTrueResults().training.printOutputs();
        if (!this.progress.getLastTrueResults().validation.isEmpty()) {
            LOG.finer("Validation outputs:");
            this.progress.getLastTrueResults().validation.printOutputs();
        }
    }

    private void saturationCheck(List<NeuralSample> samples) {
        ScalarValue percentage = new ScalarValue(this.settings.saturationPercentage);
        Accumulating accumulating = new Accumulating(this.settings, new SaturationChecker());
        List<Pair<Value, Value>> pairs = accumulating.accumulateStats(samples);
        int saturatedNetworks = 0;
        for (Pair<Value, Value> pair : pairs) {
            Value saturated = (Value)pair.r;
            Value all = (Value)pair.s;
            if (!saturated.greaterThan(all.times((Value)percentage))) continue;
            ++saturatedNetworks;
        }
        if (saturatedNetworks > 0) {
            LOG.warning("There are saturated networks: #" + saturatedNetworks + " / " + samples.size());
        }
    }
}

