/*
 * Decompiled with CFR 0.152.
 */
package networks.computation.training.strategies.trainers;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;
import java.util.NoSuchElementException;
import java.util.logging.Logger;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;
import java.util.stream.StreamSupport;
import networks.computation.evaluation.results.Result;
import networks.computation.training.NeuralModel;
import networks.computation.training.NeuralSample;
import networks.computation.training.optimizers.Optimizer;
import networks.computation.training.strategies.trainers.ListTrainer;
import networks.computation.training.strategies.trainers.SequentialTrainer;
import networks.computation.training.strategies.trainers.StreamTrainer;
import networks.computation.training.strategies.trainers.Trainer;
import settings.Settings;
import utils.Utilities;

/**
 * Trainer that processes samples in minibatches, handling the samples of each minibatch in parallel.
 *
 * <p>One {@link SequentialTrainer} is created per minibatch slot. The i-th sample of a batch is
 * processed with the dropout/invalidation/evaluation/backpropagation components of the i-th
 * trainer, so each parallel task of a batch uses the components of a distinct
 * {@code SequentialTrainer}.
 *
 * <p>NOTE(review): all parallel learning tasks of a batch receive the same {@link NeuralModel}
 * instance. This assumes {@code learnFromSample} tolerates concurrent calls on a shared model —
 * confirm.
 */
public class MiniBatchTrainer extends Trainer {
    private static final Logger LOG = Logger.getLogger(MiniBatchTrainer.class.getName());

    /** Maximal number of samples handled per parallel batch; equals the size of {@link #trainers}. */
    int minibatchSize;

    /** One sequential trainer per minibatch slot; slot i serves the i-th sample of every batch. */
    List<SequentialTrainer> trainers;

    private MiniBatchTrainer() {
    }

    /**
     * Creates a minibatch trainer with {@code minibatchSize} per-slot sequential trainers.
     *
     * @param settings      global settings, forwarded to the base trainer and each slot trainer
     * @param optimizer     optimizer shared by the base trainer and all slot trainers
     * @param neuralModel   model forwarded to each slot trainer's constructor
     * @param minibatchSize number of parallel slots (and hence trainers) to create
     */
    public MiniBatchTrainer(Settings settings, Optimizer optimizer, NeuralModel neuralModel, int minibatchSize) {
        super(settings, optimizer);
        this.minibatchSize = minibatchSize;
        this.trainers = new ArrayList<>(minibatchSize);
        for (int slot = 0; slot < minibatchSize; ++slot) {
            this.trainers.add(new SequentialTrainer(settings, optimizer, neuralModel, slot));
        }
    }

    /**
     * Pairs each sample of {@code batch} with the trainer of the same slot, as a parallel stream.
     *
     * <p>Only the first {@code min(batch.size(), minibatchSize)} samples are paired. A batch
     * smaller than {@link #minibatchSize} (such as the trailing batch of an epoch) is processed in
     * full. Samples beyond {@link #minibatchSize} have no trainer, so they are reported as an error
     * and skipped.
     */
    private Stream<Task> parallelTasks(List<NeuralSample> batch) {
        if (batch.size() > this.minibatchSize) {
            LOG.severe("Minibatch size mismatch");
        }
        // Bounding by the batch size too avoids indexing past the end of a short (trailing) batch.
        int taskCount = Math.min(batch.size(), this.minibatchSize);
        return IntStream.range(0, taskCount)
                .parallel()
                .mapToObj(slot -> new Task(this.trainers.get(slot), batch.get(slot)));
    }

    /**
     * Runs a learning step on every sample of one minibatch in parallel.
     *
     * @return one result per processed sample, in the same order as {@code sampleList}
     */
    private List<Result> minibatchParallelLearn(NeuralModel neuralModel, List<NeuralSample> sampleList) {
        return parallelTasks(sampleList)
                .map(task -> task.runLearning(neuralModel))
                .collect(Collectors.toList());
    }

    /**
     * Invalidates and evaluates every sample of one minibatch in parallel (no learning step).
     *
     * @return one result per processed sample, in the same order as {@code minibatch}
     */
    private List<Result> minibatchParallelEvaluate(List<NeuralSample> minibatch) {
        return parallelTasks(minibatch)
                .map(Task::runEvaluation)
                .collect(Collectors.toList());
    }

    /**
     * Iterates over consecutive chunks of {@code sampleList} of at most {@code settings.minibatchSize}
     * samples each. The last chunk may be shorter.
     *
     * <p>Chunks are {@link List#subList} views of the backing list, so structurally modifying the
     * backing list invalidates previously returned chunks.
     *
     * <p>NOTE(review): the chunk size comes from {@code settings.minibatchSize}, not from the
     * enclosing trainer's {@code minibatchSize}. If the two differ, oversized chunks are truncated
     * (with an error log) by the parallel methods — confirm they are meant to match.
     */
    public class MiniBatchIterator implements Iterator<List<NeuralSample>> {
        List<NeuralSample> sampleList;
        int i = 0;
        int minibatchSize;

        public MiniBatchIterator(Settings settings, List<NeuralSample> sampleList) {
            this.minibatchSize = settings.minibatchSize;
            this.sampleList = sampleList;
        }

        @Override
        public boolean hasNext() {
            return this.i < this.sampleList.size();
        }

        /**
         * Returns the next chunk of samples.
         *
         * @throws NoSuchElementException if all samples have already been returned
         * @throws IllegalStateException  if the chunk size is not positive (iteration would never end)
         */
        @Override
        public List<NeuralSample> next() {
            if (!hasNext()) {
                throw new NoSuchElementException("No more minibatches: all " + this.sampleList.size() + " samples consumed");
            }
            if (this.minibatchSize <= 0) {
                throw new IllegalStateException("minibatchSize must be positive, got " + this.minibatchSize);
            }
            int end = Math.min(this.i + this.minibatchSize, this.sampleList.size());
            List<NeuralSample> neuralSamples = this.sampleList.subList(this.i, end);
            this.i += this.minibatchSize;
            return neuralSamples;
        }
    }

    /** A single sample bound to the slot trainer whose components will process it. */
    private class Task {
        SequentialTrainer trainer;
        NeuralSample sample;

        public Task(SequentialTrainer trainer, NeuralSample sample) {
            this.trainer = trainer;
            this.sample = sample;
        }

        /** Runs the base trainer's learning routine on this sample with the slot trainer's components. */
        public Result runLearning(NeuralModel neuralModel) {
            return MiniBatchTrainer.this.learnFromSample(neuralModel, this.sample, this.trainer.dropout, this.trainer.invalidation, this.trainer.evaluation, this.trainer.backpropagation);
        }

        /** Invalidates, then evaluates this sample with the slot trainer's components. */
        public Result runEvaluation() {
            MiniBatchTrainer.this.invalidateSample(this.trainer.invalidation, this.sample);
            return MiniBatchTrainer.this.evaluateSample(this.trainer.evaluation, this.sample);
        }
    }

    /** Stream-based epoch learning over consecutive minibatches. */
    public class MinibatchStreamTrainer implements StreamTrainer {
        /**
         * Groups {@code sampleStream} into minibatches and learns from each batch in parallel.
         *
         * <p>The returned stream is lazy: learning happens only as it is consumed.
         * Batches are processed sequentially, one after another.
         *
         * @return the per-sample results, flattened across batches
         */
        @Override
        public Stream<Result> learnEpoch(NeuralModel neuralModel, Stream<NeuralSample> sampleStream) {
            if (sampleStream.isParallel()) {
                LOG.severe("The input sampleStream is parallel, but the training must perform sequential gradient steps!");
            }
            // Raw List kept: the element type produced by Utilities.BatchSpliterator is not visible here.
            Stream<List> minibatchStream = StreamSupport.stream(new Utilities.BatchSpliterator(sampleStream.spliterator(), MiniBatchTrainer.this.minibatchSize), false);
            return minibatchStream
                    .map(batch -> MiniBatchTrainer.this.minibatchParallelLearn(neuralModel, batch))
                    .flatMap(Collection::stream);
        }
    }

    /** List-based epoch learning and evaluation over consecutive minibatches. */
    public class MinibatchListTrainer implements ListTrainer {
        /**
         * Learns from {@code sampleList} batch by batch, running each batch in parallel.
         *
         * @return the per-sample results, concatenated in batch order
         */
        @Override
        public List<Result> learnEpoch(NeuralModel neuralModel, List<NeuralSample> sampleList) {
            List<Result> resultList = new ArrayList<>(sampleList.size());
            MiniBatchIterator miniBatchIterator = new MiniBatchIterator(MiniBatchTrainer.this.settings, sampleList);
            while (miniBatchIterator.hasNext()) {
                List<NeuralSample> minibatch = miniBatchIterator.next();
                resultList.addAll(MiniBatchTrainer.this.minibatchParallelLearn(neuralModel, minibatch));
            }
            return resultList;
        }

        /**
         * Evaluates {@code trainingSet} batch by batch, running each batch in parallel.
         *
         * @return the per-sample evaluation results, concatenated in batch order
         */
        @Override
        public List<Result> evaluate(List<NeuralSample> trainingSet) {
            List<Result> resultList = new ArrayList<>(trainingSet.size());
            MiniBatchIterator miniBatchIterator = new MiniBatchIterator(MiniBatchTrainer.this.settings, trainingSet);
            while (miniBatchIterator.hasNext()) {
                List<NeuralSample> minibatch = miniBatchIterator.next();
                resultList.addAll(MiniBatchTrainer.this.minibatchParallelEvaluate(minibatch));
            }
            return resultList;
        }

        /** Restarts the shared optimizer with the given settings. */
        @Override
        public void restart(Settings settings) {
            MiniBatchTrainer.this.optimizer.restart(settings);
        }
    }
}

