/*
 * Decompiled with CFR 0.152.
 */
package networks.computation.training.strategies;

import java.util.List;
import java.util.Objects;
import java.util.logging.Logger;
import java.util.stream.Stream;
import networks.computation.evaluation.results.Progress;
import networks.computation.evaluation.results.Result;
import networks.computation.evaluation.results.Results;
import networks.computation.evaluation.values.ScalarValue;
import networks.computation.training.NeuralModel;
import networks.computation.training.NeuralSample;
import networks.computation.training.strategies.IterativeTrainingStrategy;
import networks.computation.training.strategies.StreamTrainingStrategy;
import settings.Settings;
import utils.Timing;
import utils.Utilities;
import utils.exporting.Exportable;
import utils.generic.Pair;

/**
 * Base class for strategies that train a {@link NeuralModel}.
 *
 * <p>On construction the strategy clones the parameter values of the supplied model. When
 * {@code settings.undoWeightTrainingChanges} is set, {@link #endTrainingStrategy()} loads those
 * cloned values back into the model. Use {@link #getFrom} to choose a concrete strategy from the
 * settings.
 */
public abstract class TrainingStrategy
implements Exportable {
    private static final Logger LOG = Logger.getLogger(TrainingStrategy.class.getName());
    transient Settings settings;
    /** Snapshot of the model's parameter values taken at construction time. */
    private transient NeuralModel initModel;
    /** The model being trained. */
    transient NeuralModel currentModel;
    /** Learning rate, initialised from {@code settings.initLearningRate}. */
    ScalarValue learningRate;
    transient Results.Factory resultsFactory;
    Timing timing;

    /**
     * Creates a strategy that will train {@code model}.
     *
     * @param settings configuration; {@code initLearningRate} seeds {@link #learningRate}, and
     *     the results factory is derived from it
     * @param model the model to train; its current parameter values are cloned so that they can
     *     later be restored by {@link #endTrainingStrategy()}
     * @throws NullPointerException if {@code settings} or {@code model} is {@code null}
     */
    public TrainingStrategy(Settings settings, NeuralModel model) {
        // Fail fast with a named message instead of an anonymous NPE from a field access below.
        this.settings = Objects.requireNonNull(settings, "settings");
        this.learningRate = new ScalarValue(settings.initLearningRate);
        this.currentModel = Objects.requireNonNull(model, "model");
        this.storeParametersState(model);
        this.resultsFactory = Results.Factory.getFrom(settings);
        this.timing = new Timing();
    }

    /** Saves a value-clone of {@code inputModel} as the snapshot used by {@link #loadParametersState()}. */
    private void storeParametersState(NeuralModel inputModel) {
        this.initModel = inputModel.cloneValues();
    }

    /** Loads the weight values from the construction-time snapshot back into {@link #currentModel}. */
    private void loadParametersState() {
        this.currentModel.loadWeightValues(this.initModel);
    }

    /**
     * Runs training.
     *
     * @return the trained model paired with the training progress
     */
    public abstract Pair<NeuralModel, Progress> train();

    /**
     * Selects a training strategy according to {@code settings.neuralStreaming}.
     *
     * <p>In streaming mode the stream is passed to a {@link StreamTrainingStrategy} as-is. Otherwise
     * the stream is terminated into a list via {@link Utilities#terminateSampleStream} and passed to
     * an {@link IterativeTrainingStrategy}, so the stream must not be reused afterwards.
     *
     * @param settings configuration that determines the strategy type
     * @param model the model to train
     * @param sampleStream the training samples
     * @return a new strategy instance
     */
    public static TrainingStrategy getFrom(Settings settings, NeuralModel model, Stream<NeuralSample> sampleStream) {
        if (settings.neuralStreaming) {
            return new StreamTrainingStrategy(settings, model, sampleStream);
        }
        List<NeuralSample> collect = Utilities.terminateSampleStream(sampleStream);
        return new IterativeTrainingStrategy(settings, model, collect);
    }

    /**
     * Finalisation hook for subclasses. If {@code settings.undoWeightTrainingChanges} is set,
     * restores the model's weights to the values captured at construction time.
     */
    protected void endTrainingStrategy() {
        if (this.settings.undoWeightTrainingChanges) {
            this.loadParametersState();
        }
    }

    /**
     * Pair of result lists for the training and validation sets.
     *
     * <p>Declared {@code static} because it never uses the enclosing strategy. A non-static inner
     * class would keep a hidden reference to the {@code TrainingStrategy} instance.
     */
    protected static class TrainVal {
        List<Result> training;
        List<Result> validation;

        /**
         * @param train results on the training set
         * @param val results on the validation set
         */
        public TrainVal(List<Result> train, List<Result> val) {
            this.training = train;
            this.validation = val;
        }
    }
}

