/*
 * Decompiled with CFR 0.152.
 */
package networks.computation.training.strategies;

import java.util.List;
import java.util.logging.Logger;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import networks.computation.evaluation.results.Progress;
import networks.computation.evaluation.results.Result;
import networks.computation.evaluation.results.Results;
import networks.computation.training.NeuralModel;
import networks.computation.training.NeuralSample;
import networks.computation.training.optimizers.Optimizer;
import networks.computation.training.strategies.TrainingStrategy;
import networks.computation.training.strategies.trainers.MiniBatchTrainer;
import networks.computation.training.strategies.trainers.SequentialTrainer;
import networks.computation.training.strategies.trainers.StreamTrainer;
import settings.Settings;
import utils.generic.Pair;

/**
 * Training strategy that runs a {@link StreamTrainer} over a {@link Stream} of
 * {@link NeuralSample}s for one epoch per {@link #train()} call.
 *
 * <p>The concrete trainer is chosen once, at construction time, from
 * {@code settings.minibatchSize}: values greater than 1 select a minibatch trainer,
 * anything else selects a sequential trainer.
 *
 * <p>NOTE(review): {@link #samplesStream} is handed to the trainer on every
 * {@link #train()} call, but a {@link Stream} can only be traversed once. A second
 * call will therefore most likely fail with an {@code IllegalStateException}.
 * Confirm against the {@link StreamTrainer#learnEpoch} implementation.
 */
public class StreamTrainingStrategy
extends TrainingStrategy {
    private static final Logger LOG = Logger.getLogger(StreamTrainingStrategy.class.getName());

    /** Samples fed to the trainer by {@link #train()}; see the single-use note on the class. */
    Stream<NeuralSample> samplesStream;

    /** Trainer selected in the constructor from {@code settings.minibatchSize}. */
    StreamTrainer trainer;

    /**
     * Creates the strategy and selects its trainer.
     *
     * @param settings     training settings; {@code minibatchSize} decides the trainer type
     * @param model        model to train, passed to the superclass constructor
     * @param sampleStream samples to train on
     */
    public StreamTrainingStrategy(Settings settings, NeuralModel model, Stream<NeuralSample> sampleStream) {
        super(settings, model);
        samplesStream = sampleStream;
        trainer = getTrainerFrom(settings);
    }

    /**
     * Builds the stream trainer matching the configured minibatch size.
     *
     * <p>An optimizer is obtained from {@code Optimizer.getFrom(settings, learningRate)}
     * and handed to the wrapped trainer together with the current model.
     *
     * @param settings training settings
     * @return a sequential stream trainer when {@code minibatchSize <= 1}, otherwise a
     *         minibatch stream trainer using {@code minibatchSize}
     */
    private StreamTrainer getTrainerFrom(Settings settings) {
        if (settings.minibatchSize <= 1) {
            // One sample at a time.
            return new SequentialTrainer.SequentialStreamTrainer(
                    new SequentialTrainer(settings, Optimizer.getFrom(settings, this.learningRate), this.currentModel));
        }
        // Group samples into batches of settings.minibatchSize.
        return new MiniBatchTrainer.MinibatchStreamTrainer(
                new MiniBatchTrainer(settings, Optimizer.getFrom(settings, this.learningRate), this.currentModel,
                        settings.minibatchSize));
    }

    /**
     * Runs one training epoch over {@link #samplesStream}.
     *
     * <p>All per-sample results produced by the trainer are collected eagerly, turned into
     * a {@link Results} object via {@code resultsFactory}, and recorded as online results
     * in a fresh {@link Progress}.
     *
     * @return the current model paired with the progress of this epoch
     */
    @Override
    public Pair<NeuralModel, Progress> train() {
        List<Result> epochResults = trainer.learnEpoch(currentModel, samplesStream)
                .collect(Collectors.toList());
        Progress epochProgress = new Progress();
        epochProgress.addOnlineResults(resultsFactory.createFrom(epochResults));
        return new Pair<>(currentModel, epochProgress);
    }

    /**
     * Returns the model held by this strategy.
     *
     * <p>No separate "best" snapshot is kept, so this is simply the current model.
     *
     * @return the current model
     */
    public NeuralModel getBestModel() {
        return currentModel;
    }
}

