/*
 * Decompiled with CFR 0.152.
 */
package cz.cvut.fel.ida.neural.networks.computation.training.strategies;

import cz.cvut.fel.ida.algebra.values.ScalarValue;
import cz.cvut.fel.ida.algebra.values.Value;
import cz.cvut.fel.ida.algebra.values.inits.ValueInitializer;
import cz.cvut.fel.ida.algebra.weights.Weight;
import cz.cvut.fel.ida.learning.LearningSample;
import cz.cvut.fel.ida.learning.crossvalidation.splitting.Splitter;
import cz.cvut.fel.ida.learning.results.DetailedClassificationResults;
import cz.cvut.fel.ida.learning.results.Progress;
import cz.cvut.fel.ida.learning.results.Result;
import cz.cvut.fel.ida.learning.results.Results;
import cz.cvut.fel.ida.neural.networks.computation.iteration.actions.Accumulating;
import cz.cvut.fel.ida.neural.networks.computation.iteration.visitors.states.neurons.SaturationChecker;
import cz.cvut.fel.ida.neural.networks.computation.training.NeuralModel;
import cz.cvut.fel.ida.neural.networks.computation.training.NeuralSample;
import cz.cvut.fel.ida.neural.networks.computation.training.optimizers.Optimizer;
import cz.cvut.fel.ida.neural.networks.computation.training.strategies.Hyperparameters.LearnRateDecayStrategy;
import cz.cvut.fel.ida.neural.networks.computation.training.strategies.Hyperparameters.RestartingStrategy;
import cz.cvut.fel.ida.neural.networks.computation.training.strategies.TrainingStrategy;
import cz.cvut.fel.ida.neural.networks.computation.training.strategies.debugging.NeuralDebugging;
import cz.cvut.fel.ida.neural.networks.computation.training.strategies.trainers.AsyncParallelTrainer;
import cz.cvut.fel.ida.neural.networks.computation.training.strategies.trainers.ListTrainer;
import cz.cvut.fel.ida.neural.networks.computation.training.strategies.trainers.MiniBatchTrainer;
import cz.cvut.fel.ida.neural.networks.computation.training.strategies.trainers.SequentialTrainer;
import cz.cvut.fel.ida.setup.Settings;
import cz.cvut.fel.ida.utils.exporting.Exportable;
import cz.cvut.fel.ida.utils.exporting.TextExporter;
import cz.cvut.fel.ida.utils.generic.Pair;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.logging.Level;
import java.util.logging.Logger;
import java.util.stream.Collectors;

public class IterativeTrainingStrategy
extends TrainingStrategy {
    private static final Logger LOG = Logger.getLogger(IterativeTrainingStrategy.class.getName());
    transient NeuralModel bestModel;
    transient List<NeuralSample> trainingSet;
    transient List<NeuralSample> validationSet;
    Progress progress;
    RestartingStrategy restartingStrategy;
    LearnRateDecayStrategy learnRateDecayStrategy;
    transient ListTrainer trainer;
    ValueInitializer valueInitializer;
    private final int resultsRecalculationEpochae;

    public IterativeTrainingStrategy(Settings settings, NeuralModel model, List<NeuralSample> sampleList) {
        super(settings, model);
        this.trainer = this.getTrainerFrom(settings);
        this.bestModel = this.currentModel;
        this.valueInitializer = ValueInitializer.getInitializer(settings);
        Pair<List<NeuralSample>, List<NeuralSample>> trainVal = this.trainingValidationSplit(sampleList);
        this.trainingSet = (List)trainVal.r;
        this.validationSet = (List)trainVal.s;
        this.learnRateDecayStrategy = LearnRateDecayStrategy.getFrom(settings, this.learningRate);
        this.restartingStrategy = RestartingStrategy.getFrom(settings);
        this.resultsRecalculationEpochae = settings.resultsRecalculationEpochae;
    }

    private Pair<List<NeuralSample>, List<NeuralSample>> trainingValidationSplit(List<NeuralSample> sampleList) {
        List extraValidation = sampleList.stream().filter(s -> s.type == LearningSample.Split.VALIDATION).collect(Collectors.toList());
        if (!extraValidation.isEmpty()) {
            LOG.fine("Splitting back the train-validation dataset according to the given input splits");
            ArrayList<NeuralSample> trainOnly = new ArrayList<NeuralSample>(sampleList);
            trainOnly.removeAll(extraValidation);
            LOG.fine("Train-set size=" + trainOnly.size() + ", Validation-set size=" + extraValidation.size());
            return new Pair<List<NeuralSample>, List<NeuralSample>>(trainOnly, extraValidation);
        }
        LOG.info("Preparing the train-validation dataset split with percentage: " + this.settings.trainValidationPercentage);
        Splitter<NeuralSample> sampleSplitter = Splitter.getSplitter(this.settings);
        Pair<List<NeuralSample>, List<NeuralSample>> partition = sampleSplitter.partition(sampleList, this.settings.trainValidationPercentage);
        LOG.info("Train-set size=" + ((List)partition.r).size() + ", Validation-set size=" + ((List)partition.s).size());
        return new Pair<Object, Object>(partition.r, partition.s);
    }

    private ListTrainer getTrainerFrom(Settings settings) {
        if (settings.asyncParallelTraining) {
            return new AsyncParallelTrainer(settings, Optimizer.getFrom(settings, this.learningRate), this.currentModel).new AsyncParallelTrainer.AsyncListTrainer();
        }
        if (settings.minibatchSize > 1) {
            return new MiniBatchTrainer(settings, Optimizer.getFrom(settings, this.learningRate), this.currentModel, settings.minibatchSize).new MiniBatchTrainer.MinibatchListTrainer();
        }
        return new SequentialTrainer(settings, Optimizer.getFrom(settings, this.learningRate), this.currentModel).new SequentialTrainer.SequentialListTrainer();
    }

    @Override
    public Pair<NeuralModel, Progress> train() {
        this.timing.tic();
        LOG.finer("Starting with iterative mode neural training.");
        this.initTraining();
        int epochae = 0;
        for (int restart = 0; restart < this.settings.restartCount; ++restart) {
            this.initRestart();
            while (this.restartingStrategy.continueRestart(this.progress) && epochae++ < this.settings.maxCumEpochCount) {
                this.initEpoch(epochae);
                List<Result> onlineEvaluations = this.trainer.learnEpoch(this.currentModel, this.trainingSet);
                this.endEpoch(epochae, onlineEvaluations);
            }
            this.endRestart();
        }
        this.timing.toc();
        this.timing.finish();
        return this.finish();
    }

    @Override
    public void setupDebugger(NeuralDebugging neuralDebugger) {
        this.trainer.setupDebugger(neuralDebugger);
    }

    protected void initTraining() {
        LOG.info("Initializing training (shuffling examples etc.)");
        if (this.settings.shuffleBeforeTraining) {
            Collections.shuffle(this.trainingSet, this.settings.random);
        }
        this.progress = new Progress();
    }

    protected void initRestart() {
        LOG.info("Initializing new restart (resetting weights).");
        ++this.restart;
        this.setupExporter();
        this.trainer.restart(this.settings);
        this.currentModel.resetWeights(this.valueInitializer);
        this.progress.nextRestart();
        this.recalculateResults();
        LOG.finer("New restart has been initialized");
    }

    protected void initEpoch(int epochNumber) {
        if (this.settings.shuffleEachEpoch) {
            Collections.shuffle(this.trainingSet, this.settings.random);
        }
        if (this.settings.islearnRateDecay) {
            this.learnRateDecayStrategy.decay(epochNumber);
        }
        if (this.settings.dropoutMode == Settings.DropoutMode.LIFTED_DROPCONNECT && this.settings.dropoutRate > 0.0) {
            this.currentModel.dropoutWeights();
        }
    }

    protected void endEpoch(int count, List<Result> onlineEvaluations) {
        Results onlineResults = this.resultsFactory.createFrom(onlineEvaluations);
        this.progress.addOnlineResults(onlineResults);
        this.exportProgress(onlineResults);
        LOG.info("epoch: " + count + " : online results : " + onlineResults.toString(this.settings));
        if (count % this.settings.resultsRecalculationEpochae == 0) {
            this.recalculateResults();
            if (this.settings.debugTemplateTraining && this.trainingDebugCallback != null) {
                Map<Integer, Weight> integerWeightMap = this.currentModel.mapWeightsToIds();
                this.trainingDebugCallback.accept(integerWeightMap);
            }
        }
    }

    protected void endRestart() {
        LOG.info("Finished restart, recalculating last true results.");
        this.recalculateResults();
        this.exporter.delimitEnd();
        this.exporter.finish();
        this.restartingStrategy.nextRestart();
    }

    protected Pair<NeuralModel, Progress> finish() {
        LOG.info("Finished training, loading best model so far.");
        this.evaluateModel(this.bestModel);
        this.logSampleOutputs();
        super.endTrainingStrategy();
        return new Pair<NeuralModel, Progress>(this.bestModel, this.progress);
    }

    private void evaluateModel(NeuralModel neuralModel) {
        this.currentModel.loadWeightValues(neuralModel);
        this.recalculateResults();
    }

    private TrainingStrategy.TrainVal evaluateModel() {
        List<Result> trainingResults = this.trainer.evaluate(this.trainingSet);
        List<Result> validationResults = this.trainer.evaluate(this.validationSet);
        return new TrainingStrategy.TrainVal(trainingResults, validationResults);
    }

    private void recalculateResults() {
        TrainingStrategy.TrainVal trueEvaluations = this.evaluateModel();
        Results trainingResults = this.resultsFactory.createFrom(trueEvaluations.training);
        Results validationResults = this.resultsFactory.createFrom(trueEvaluations.validation);
        if (this.settings.calculateBestThreshold && validationResults instanceof DetailedClassificationResults) {
            Value threshold = ((DetailedClassificationResults)trainingResults).computeBestAccuracyThreshold(trainingResults.evaluations);
            ((DetailedClassificationResults)validationResults).computeBestAccuracy(validationResults.evaluations, threshold);
        }
        this.progress.addTrueResults(trainingResults, validationResults);
        if (LOG.isLoggable(Level.FINE)) {
            String msg = "true results :- train: " + trainingResults.toString(this.settings);
            if (!validationResults.isEmpty()) {
                msg = msg + ", val: " + validationResults.toString(this.settings);
            }
            LOG.fine(msg);
        }
        if (this.settings.debugSampleOutputs) {
            this.logSampleOutputs();
        }
        if (this.settings.checkNeuronSaturation) {
            this.saturationCheck(this.trainingSet);
            this.saturationCheck(this.validationSet);
        }
        Progress.TrainVal trainVal = new Progress.TrainVal(trainingResults, validationResults);
        this.exportProgress(trainVal);
        this.saveIfBest(trainVal);
    }

    private void saveIfBest(Progress.TrainVal trainVal) {
        if (this.progress.bestResults == null || trainVal.betterThan(this.progress.bestResults, this.settings.preferBestTrainingNotvalidation, this.settings.modelSelection)) {
            LOG.fine("Improvement of best " + (this.settings.preferBestTrainingNotvalidation || this.settings.trainValidationPercentage == 1.0 ? "training " : "validation ") + this.settings.modelSelection.name() + " stored so far...");
            this.bestModel = this.currentModel.cloneWeights();
            if (this.settings.calculateBestThreshold && trainVal.training instanceof DetailedClassificationResults) {
                this.bestModel.threshold = ((DetailedClassificationResults)trainVal.training).bestThreshold;
            }
            this.progress.bestResults = trainVal;
        }
    }

    private void logSampleOutputs() {
        LOG.finer("Training outputs");
        LOG.finer(this.progress.getLastTrueResults().training.printOutputs(false).toString());
        TextExporter.exportString(this.progress.getLastTrueResults().training.printOutputs(true).toString(), Paths.get(this.settings.exportDir, "outputs/train" + counter + ".txt"));
        if (!this.progress.getLastTrueResults().validation.isEmpty()) {
            LOG.finer("Validation outputs:");
            LOG.finer(this.progress.getLastTrueResults().validation.printOutputs(false).toString());
            TextExporter.exportString(this.progress.getLastTrueResults().validation.printOutputs(true).toString(), Paths.get(this.settings.exportDir, "outputs/val" + counter + ".txt"));
        }
    }

    private void saturationCheck(List<NeuralSample> samples) {
        ScalarValue percentage = new ScalarValue(this.settings.saturationPercentage);
        Accumulating accumulating = new Accumulating(this.settings, new SaturationChecker());
        List<Pair<Value, Value>> pairs = accumulating.accumulateStats(samples);
        int saturatedNetworks = 0;
        for (Pair<Value, Value> pair : pairs) {
            Value saturated = (Value)pair.r;
            Value all = (Value)pair.s;
            if (!saturated.greaterThan(all.times((Value)percentage))) continue;
            ++saturatedNetworks;
        }
        if (saturatedNetworks > 0) {
            LOG.warning("There are saturated networks: #" + saturatedNetworks + " / " + samples.size());
        }
    }

    private void exportProgress(Exportable results) {
        results.export(this.exporter);
        this.exporter.delimitNext();
    }
}

