/*
 * Decompiled with CFR 0.152.
 */
package cz.cvut.fel.ida.neural.networks.computation.training.strategies;

import cz.cvut.fel.ida.algebra.values.ScalarValue;
import cz.cvut.fel.ida.algebra.weights.Weight;
import cz.cvut.fel.ida.learning.results.Progress;
import cz.cvut.fel.ida.learning.results.Result;
import cz.cvut.fel.ida.learning.results.Results;
import cz.cvut.fel.ida.neural.networks.computation.training.NeuralModel;
import cz.cvut.fel.ida.neural.networks.computation.training.NeuralSample;
import cz.cvut.fel.ida.neural.networks.computation.training.strategies.IterativeTrainingStrategy;
import cz.cvut.fel.ida.neural.networks.computation.training.strategies.StreamTrainingStrategy;
import cz.cvut.fel.ida.neural.networks.computation.training.strategies.debugging.NeuralDebugging;
import cz.cvut.fel.ida.setup.Settings;
import cz.cvut.fel.ida.utils.exporting.Exportable;
import cz.cvut.fel.ida.utils.exporting.Exporter;
import cz.cvut.fel.ida.utils.generic.Pair;
import cz.cvut.fel.ida.utils.generic.Timing;
import cz.cvut.fel.ida.utils.generic.Utilities;
import java.io.IOException;
import java.util.List;
import java.util.Map;
import java.util.function.Consumer;
import java.util.logging.Level;
import java.util.logging.Logger;
import java.util.stream.Stream;

/**
 * Abstract base class for neural-network training strategies.
 * <p>
 * Holds the state shared by concrete strategies:
 * <ul>
 *     <li>the settings and the model being trained;</li>
 *     <li>a snapshot of the model's initial weights, so training changes can optionally be undone;</li>
 *     <li>the learning rate, a results factory and timing information;</li>
 *     <li>an optional exporter and an optional external (Python) progress-plotter process.</li>
 * </ul>
 * Concrete strategies are created via {@link #getFrom(Settings, NeuralModel, Stream)}.
 */
public abstract class TrainingStrategy
implements Exportable {
    private static final Logger LOG = Logger.getLogger(TrainingStrategy.class.getName());
    transient Settings settings;
    /** Copy of the model weights taken at construction time; restored by {@link #endTrainingStrategy()} if requested. */
    private transient NeuralModel initModel;
    /** The model whose weights are being trained. */
    transient NeuralModel currentModel;
    ScalarValue learningRate;
    transient Results.Factory resultsFactory;
    Timing timing;
    transient Exporter exporter;
    /** Restart index; it becomes part of the exporter's output name. */
    protected int restart;
    /** External progress-plotter process, or {@code null} if none has been started. */
    transient Process progressPlotter;
    /** Shared counter that makes exporter names distinct (a plain static int; increments are not atomic). */
    static int counter = 0;
    /** Debug callback receiving weights, taken from the model's {@code templateDebugCallback}. */
    transient Consumer<Map<Integer, Weight>> trainingDebugCallback;

    /**
     * Initializes the shared training state.
     *
     * @param settings global settings; {@code initLearningRate} seeds the learning rate
     * @param model    the model to train; its weights are snapshotted so they can be restored later
     */
    public TrainingStrategy(Settings settings, NeuralModel model) {
        this.settings = settings;
        this.learningRate = new ScalarValue(settings.initLearningRate);
        this.currentModel = model;
        this.storeParametersState(model);
        this.resultsFactory = Results.Factory.getFrom(settings);
        this.timing = new Timing();
        this.trainingDebugCallback = model.templateDebugCallback;
    }

    /**
     * Creates the progress exporter and, if {@code settings.plotProgress > 0}, launches the external Python
     * progress plotter.
     * <p>
     * The exporter is named {@code "progress/training<counter>restart<restart>"} inside {@code settings.exportDir}.
     * Plotting is optional: if the plotter process cannot be started, a warning is logged and training continues
     * without it.
     */
    protected void setupExporter() {
        this.exporter = Exporter.getExporter(this.settings.exportDir, "progress/training" + counter++ + "restart" + this.restart, this.settings.exportType.name());
        this.exporter.delimitStart();
        if (this.settings.plotProgress > 0) {
            LOG.fine("Will try to call the Python progress plotter...");
            ProcessBuilder processBuilder = new ProcessBuilder(
                    this.settings.pythonPath,
                    this.settings.progressPlotterPath,
                    this.settings.exportDir,
                    String.valueOf(this.settings.plotProgress));
            try {
                this.progressPlotter = processBuilder.start();
            }
            catch (IOException e) {
                // Best effort: report through the class logger (keeping the cause) instead of printing to stderr.
                LOG.log(Level.WARNING, "Could not start the Python progress plotter.", e);
            }
        }
        LOG.finer("End of exporter setup");
    }

    /** Snapshots the weights of {@code inputModel} so they can be restored by {@link #loadParametersState()}. */
    private void storeParametersState(NeuralModel inputModel) {
        this.initModel = inputModel.cloneWeights();
    }

    /** Restores the weights captured at construction time into the current model. */
    private void loadParametersState() {
        this.currentModel.loadWeightValues(this.initModel);
    }

    /**
     * Runs the training.
     *
     * @return the trained model paired with the training progress
     */
    public abstract Pair<NeuralModel, Progress> train();

    /**
     * Installs a debugger into this strategy.
     *
     * @param debugging the neural debugging hook to use
     */
    public abstract void setupDebugger(NeuralDebugging debugging);

    /**
     * Selects the training strategy implied by the settings.
     *
     * @param settings     if {@code neuralStreaming} is set, a streaming strategy is returned
     * @param model        the model to train
     * @param sampleStream the training samples; collected into a list for the non-streaming strategy
     * @return a {@link StreamTrainingStrategy} or an {@link IterativeTrainingStrategy}
     */
    public static TrainingStrategy getFrom(Settings settings, NeuralModel model, Stream<NeuralSample> sampleStream) {
        if (settings.neuralStreaming) {
            return new StreamTrainingStrategy(settings, model, sampleStream);
        }
        List<NeuralSample> collect = Utilities.terminateSampleStream(sampleStream);
        return new IterativeTrainingStrategy(settings, model, collect);
    }

    /**
     * Finishes training.
     * <p>
     * If {@code settings.undoWeightTrainingChanges} is set, the initial weights are restored. Any running
     * progress-plotter process is destroyed.
     */
    protected void endTrainingStrategy() {
        if (this.settings.undoWeightTrainingChanges) {
            this.loadParametersState();
        }
        if (this.progressPlotter != null) {
            this.progressPlotter.destroy();
        }
    }

    /** A pair of training and validation result lists. */
    protected class TrainVal {
        List<Result> training;
        List<Result> validation;

        /**
         * @param train results on the training set
         * @param val   results on the validation set
         */
        public TrainVal(List<Result> train, List<Result> val) {
            this.training = train;
            this.validation = val;
        }
    }
}

