/*
 * Decompiled with CFR 0.152.
 */
package learning.crossvalidation;

import com.sun.istack.internal.Nullable;
import java.util.ArrayList;
import java.util.List;
import java.util.stream.Stream;
import learning.LearningSample;
import learning.crossvalidation.Fold;
import learning.crossvalidation.TrainTestResults;
import learning.crossvalidation.splitting.Splitter;
import networks.computation.evaluation.results.ClassificationResults;
import networks.computation.evaluation.results.Progress;
import networks.computation.evaluation.results.Result;
import networks.computation.evaluation.results.Results;
import settings.Settings;

/**
 * State of a k-fold crossvalidation run: the folds produced by a {@link Splitter}, the pooled list
 * of samples taken from the folds' test parts, and a helper that merges per-fold train/test results
 * into one aggregate result.
 *
 * @param <T> the type of learning sample being distributed into folds
 */
public class Crossvalidation<T extends LearningSample> {
    Settings settings;
    public int foldCount;
    Splitter<T> splitter;
    public List<Fold<T>> folds;
    List<T> samples;
    @Nullable
    Results results;

    /**
     * Creates a crossvalidation whose fold count is read from {@code settings.foldsCount}. No folds
     * are created here; call one of the {@code loadFolds} methods to populate {@link #folds}.
     *
     * @param settings settings used by {@link #aggregateResults} and to select the splitter
     */
    public Crossvalidation(Settings settings) {
        this.settings = settings;
        this.foldCount = settings.foldsCount;
        this.splitter = Splitter.getSplitter(settings);
    }

    /**
     * Creates a crossvalidation with an explicit fold count and fills {@link #folds} with one
     * {@code new Fold(i)} per index {@code i} in {@code [0, foldCount)}.
     *
     * @param settings settings used by {@link #aggregateResults} and to select the splitter
     * @param foldCount number of folds to create
     */
    public Crossvalidation(Settings settings, int foldCount) {
        // Settings must be stored here too; otherwise aggregateResults() dereferences a null field.
        this.settings = settings;
        this.folds = new ArrayList<>(foldCount);
        for (int i = 0; i < foldCount; ++i) {
            this.folds.add(new Fold<>(i));
        }
        this.foldCount = foldCount;
        this.splitter = Splitter.getSplitter(settings);
    }

    /**
     * Splits a stream of samples into {@link #foldCount} folds using the configured splitter, then
     * records the concatenation of all folds' test samples in {@link #samples}.
     *
     * @param samples the samples to distribute into folds
     */
    public void loadFolds(Stream<T> samples) {
        this.folds = this.splitter.splitIntoFolds(samples, this.foldCount);
        this.samples = this.extractSamples(this.folds);
    }

    /**
     * Builds folds from already-partitioned sample streams via the configured splitter, then
     * records the concatenation of all folds' test samples in {@link #samples}.
     * NOTE(review): presumably one stream per fold — confirm against {@code Splitter}.
     *
     * @param folds pre-partitioned sample streams
     */
    public void loadFolds(List<Stream<T>> folds) {
        this.folds = this.splitter.splitIntoFolds(folds);
        this.samples = this.extractSamples(this.folds);
    }

    /**
     * Merges per-fold results into a single {@link TrainTestResults}.
     *
     * <p>The best-result training, validation and test evaluations of every fold are concatenated
     * and re-wrapped using the {@link Results.Factory} selected by the settings. Restarts are
     * concatenated as well. When {@code settings.calculateBestThreshold} is set, the test accuracy
     * under each fold's own training threshold is converted to a count of correct samples, and the
     * pooled ratio is stored as the merged test {@code bestAccuracy}.
     *
     * @param foldRunStatsList results of the individual folds
     * @return aggregated training progress and test results
     */
    public TrainTestResults aggregateResults(List<TrainTestResults> foldRunStatsList) {
        ArrayList<Result> allTrain = new ArrayList<>();
        ArrayList<Result> allValidation = new ArrayList<>();
        ArrayList<Result> allTest = new ArrayList<>();
        ArrayList<Progress.Restart> allRestarts = new ArrayList<>();
        // long accumulators: no truncating int casts when summing per-fold counts.
        long correctClassCount = 0;
        long allCount = 0;
        for (TrainTestResults foldResults : foldRunStatsList) {
            Results training = foldResults.training.bestResults.training;
            Results testing = foldResults.testing;
            allTrain.addAll(training.evaluations);
            allValidation.addAll(foldResults.training.bestResults.validation.evaluations);
            allTest.addAll(testing.evaluations);
            allRestarts.addAll(foldResults.training.restarts);
            // Both sides must be classification results: the threshold comes from training,
            // the accuracy is computed on testing (previously testing was cast unchecked).
            if (!this.settings.calculateBestThreshold
                    || !(training instanceof ClassificationResults)
                    || !(testing instanceof ClassificationResults)) {
                continue;
            }
            int testSize = testing.evaluations.size();
            Double bestAccuracy = ((ClassificationResults) testing)
                    .getBestAccuracy(testing.evaluations, ((ClassificationResults) training).bestThreshold);
            correctClassCount += Math.round(bestAccuracy * (double) testSize);
            allCount += testSize;
        }
        Results.Factory factory = Results.Factory.getFrom(this.settings);
        Results allTrainResults = factory.createFrom(allTrain);
        Results allValidationResults = factory.createFrom(allValidation);
        Results allTestResults = factory.createFrom(allTest);
        // Skip when no fold contributed: 0.0 / 0.0 would store NaN as the merged accuracy.
        if (this.settings.calculateBestThreshold
                && allTestResults instanceof ClassificationResults
                && allCount > 0) {
            double mergedBestAccuracy = (double) correctClassCount / (double) allCount;
            ((ClassificationResults) allTestResults).bestAccuracy = mergedBestAccuracy;
        }
        Progress allTrainingMerged = new Progress();
        allTrainingMerged.restarts = allRestarts;
        allTrainingMerged.bestResults = new Progress.TrainVal(allTrainResults, allValidationResults);
        return new TrainTestResults(allTrainingMerged, allTestResults);
    }

    /**
     * Concatenates the {@code test} samples of every fold, in fold order.
     *
     * @param folds folds to collect from
     * @return a new mutable list holding all folds' test samples
     */
    private List<T> extractSamples(List<Fold<T>> folds) {
        List<T> collected = new ArrayList<>();
        for (Fold<T> fold : folds) {
            collected.addAll(fold.test);
        }
        return collected;
    }
}

