/*
 * Decompiled with CFR 0.152.
 */
package cz.cvut.fel.ida.neural.networks.computation.training.strategies.trainers;

import cz.cvut.fel.ida.learning.results.Result;
import cz.cvut.fel.ida.neural.networks.computation.training.NeuralModel;
import cz.cvut.fel.ida.neural.networks.computation.training.NeuralSample;
import cz.cvut.fel.ida.neural.networks.computation.training.optimizers.Optimizer;
import cz.cvut.fel.ida.neural.networks.computation.training.strategies.debugging.NeuralDebugging;
import cz.cvut.fel.ida.neural.networks.computation.training.strategies.trainers.ListTrainer;
import cz.cvut.fel.ida.neural.networks.computation.training.strategies.trainers.SequentialTrainer;
import cz.cvut.fel.ida.neural.networks.computation.training.strategies.trainers.StreamTrainer;
import cz.cvut.fel.ida.setup.Settings;
import java.util.List;
import java.util.logging.Logger;
import java.util.stream.Collectors;
import java.util.stream.Stream;

/**
 * Trainer that runs the per-sample training and evaluation steps inherited from
 * {@link SequentialTrainer} over Java parallel streams, so samples of one epoch are
 * processed concurrently rather than one after another.
 *
 * <p>This class adds no synchronization of its own: every worker invokes
 * {@code learnFromSample} against the same {@link NeuralModel} and the same enclosing
 * trainer state. NOTE(review): whether these concurrent updates are thread-safe (or
 * intentionally unsynchronized) depends on {@code learnFromSample} and the
 * {@link Optimizer}; confirm before relying on deterministic results.
 *
 * <p>The epoch logic lives in the inner {@link AsyncStreamTrainer} and
 * {@link AsyncListTrainer}, which read and write the enclosing instance's fields.
 */
public class AsyncParallelTrainer extends SequentialTrainer {
    private static final Logger LOG = Logger.getLogger(AsyncParallelTrainer.class.getName());

    /**
     * Creates a trainer by passing the given configuration to the
     * {@link SequentialTrainer} constructor.
     *
     * @param settings    training configuration
     * @param optimizer   optimizer used by the inherited training step
     * @param neuralModel model handed to the superclass
     */
    public AsyncParallelTrainer(Settings settings, Optimizer optimizer, NeuralModel neuralModel) {
        super(settings, optimizer, neuralModel);
    }

    /**
     * No-argument constructor for subclasses; initialisation is left to the
     * superclass's no-argument constructor.
     */
    protected AsyncParallelTrainer() {
    }

    /**
     * {@link StreamTrainer} that trains on each element of a sample stream in parallel
     * using the enclosing trainer's {@code learnFromSample}.
     */
    public class AsyncStreamTrainer implements StreamTrainer {

        /**
         * Builds a parallel pipeline that trains on every sample of {@code sampleStream}.
         *
         * <p>The pipeline is lazy: nothing is trained by this call itself.
         * {@code learnFromSample} only runs once the caller applies a terminal operation
         * to the returned stream.
         *
         * @param neuralModel  model passed to every {@code learnFromSample} call
         * @param sampleStream samples of the epoch; the stream is switched to parallel mode
         * @return stream of per-sample results
         */
        @Override
        public Stream<Result> learnEpoch(NeuralModel neuralModel, Stream<NeuralSample> sampleStream) {
            // Fully typed pipeline: the decompiled raw (Stream) cast caused an unchecked
            // conversion and forced a redundant (NeuralSample) cast on each element.
            return sampleStream.parallel()
                    .map(sample -> AsyncParallelTrainer.this.learnFromSample(
                            neuralModel,
                            sample,
                            AsyncParallelTrainer.this.dropout,
                            AsyncParallelTrainer.this.invalidation,
                            AsyncParallelTrainer.this.evaluation,
                            AsyncParallelTrainer.this.backpropagation));
        }

        /**
         * Stores {@code trainingDebugger} in the enclosing trainer's
         * {@code neuralDebugger} field.
         *
         * @param trainingDebugger debugger to install
         */
        @Override
        public void setupDebugger(NeuralDebugging trainingDebugger) {
            AsyncParallelTrainer.this.neuralDebugger = trainingDebugger;
        }
    }

    /**
     * {@link ListTrainer} that trains on and evaluates a list of samples in parallel
     * using the enclosing trainer's per-sample methods.
     */
    public class AsyncListTrainer implements ListTrainer {

        /**
         * Trains on every sample of {@code sampleList} in parallel.
         *
         * <p>Samples may be processed in any order. The returned list still follows the
         * encounter order of {@code sampleList}: element {@code i} is the result for
         * {@code sampleList.get(i)}.
         *
         * @param neuralModel model passed to every {@code learnFromSample} call
         * @param sampleList  samples of the epoch
         * @return per-sample results, in the same order as {@code sampleList}
         */
        @Override
        public List<Result> learnEpoch(NeuralModel neuralModel, List<NeuralSample> sampleList) {
            return sampleList.parallelStream()
                    .map(neuralSample -> AsyncParallelTrainer.this.learnFromSample(
                            neuralModel,
                            neuralSample,
                            AsyncParallelTrainer.this.dropout,
                            AsyncParallelTrainer.this.invalidation,
                            AsyncParallelTrainer.this.evaluation,
                            AsyncParallelTrainer.this.backpropagation))
                    .collect(Collectors.toList());
        }

        /**
         * Evaluates every sample of {@code sampleList} in parallel via the enclosing
         * trainer's {@code evaluateSample}, using its {@code evaluation} field.
         *
         * @param sampleList samples to evaluate
         * @return per-sample results, in the same order as {@code sampleList}
         */
        @Override
        public List<Result> evaluate(List<NeuralSample> sampleList) {
            return sampleList.parallelStream()
                    .map(neuralSample -> AsyncParallelTrainer.this.evaluateSample(
                            AsyncParallelTrainer.this.evaluation, neuralSample))
                    .collect(Collectors.toList());
        }

        /**
         * Restarts the enclosing trainer's optimizer with {@code settings}.
         *
         * @param settings configuration passed to {@code Optimizer.restart}
         */
        @Override
        public void restart(Settings settings) {
            AsyncParallelTrainer.this.optimizer.restart(settings);
        }

        /**
         * Stores {@code trainingDebugger} in the enclosing trainer's
         * {@code neuralDebugger} field.
         *
         * @param trainingDebugger debugger to install
         */
        @Override
        public void setupDebugger(NeuralDebugging trainingDebugger) {
            AsyncParallelTrainer.this.neuralDebugger = trainingDebugger;
        }
    }
}

