/*
 * Decompiled with CFR 0.152.
 */
package cz.cvut.fel.ida.learning.crossvalidation;

import com.sun.istack.internal.Nullable;
import cz.cvut.fel.ida.learning.LearningSample;
import cz.cvut.fel.ida.learning.crossvalidation.Fold;
import cz.cvut.fel.ida.learning.crossvalidation.MeanStdResults;
import cz.cvut.fel.ida.learning.crossvalidation.TrainTestResults;
import cz.cvut.fel.ida.learning.crossvalidation.splitting.Splitter;
import cz.cvut.fel.ida.learning.results.ClassificationResults;
import cz.cvut.fel.ida.learning.results.DetailedClassificationResults;
import cz.cvut.fel.ida.learning.results.Progress;
import cz.cvut.fel.ida.learning.results.RegressionResults;
import cz.cvut.fel.ida.learning.results.Result;
import cz.cvut.fel.ida.learning.results.Results;
import cz.cvut.fel.ida.setup.Settings;
import cz.cvut.fel.ida.utils.generic.Pair;
import java.util.ArrayList;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.Stream;

/**
 * Cross-validation driver: splits learning samples into folds via a {@link Splitter} and
 * aggregates the per-fold train/validation/test results into merged and mean/std summaries.
 *
 * @param <T> the learning-sample type held by the folds
 */
public class Crossvalidation<T extends LearningSample> {
    /** Run settings; {@link #aggregateResults} reads {@code calculateBestThreshold} and builds the results factory from them. */
    Settings settings;
    /** Number of folds requested for splitting. */
    public int foldCount;
    /** Strategy that partitions samples into folds, obtained via {@code Splitter.getSplitter(settings)}. */
    Splitter<T> splitter;
    /** Folds produced by the last {@code loadFolds} call, or placeholder folds from the two-argument constructor. */
    public List<Fold<T>> folds;
    /** All samples, gathered by concatenating the test part of every fold (see {@link #extractSamples}). */
    List<T> samples;
    // NOTE(review): com.sun.istack.internal.Nullable is a JDK-internal annotation that is not
    // available on recent JDKs; consider switching to a supported @Nullable annotation.
    @Nullable
    Results results;

    /**
     * Creates a cross-validation whose fold count is taken from {@code settings.foldsCount}.
     * No folds exist until one of the {@code loadFolds} methods is called.
     *
     * @param settings run settings
     */
    public Crossvalidation(Settings settings) {
        this.settings = settings;
        this.foldCount = settings.foldsCount;
        this.splitter = Splitter.getSplitter(settings);
    }

    /**
     * Creates a cross-validation with an explicit fold count, pre-populated with
     * {@code foldCount} folds created as {@code new Fold<>(i)} for {@code i = 0..foldCount-1}.
     *
     * @param settings  run settings (stored so that {@link #aggregateResults} can read them)
     * @param foldCount number of folds to create
     */
    public Crossvalidation(Settings settings, int foldCount) {
        // Previously settings was never stored here, so aggregateResults threw an NPE
        // for instances built through this constructor.
        this.settings = settings;
        this.folds = new ArrayList<>(foldCount);
        for (int i = 0; i < foldCount; ++i) {
            this.folds.add(new Fold<>(i));
        }
        this.foldCount = foldCount;
        this.splitter = Splitter.getSplitter(settings);
    }

    /**
     * Splits a single stream of samples into {@link #foldCount} folds using the splitter,
     * then rebuilds {@link #samples} from the resulting folds.
     *
     * @param samples the samples to split
     */
    public void loadFolds(Stream<T> samples) {
        this.folds = this.splitter.splitIntoFolds(samples, this.foldCount);
        this.samples = this.extractSamples(this.folds);
    }

    /**
     * Builds folds from pre-partitioned sample streams using the splitter, then rebuilds
     * {@link #samples} from the resulting folds.
     *
     * @param folds one stream of samples per predefined fold
     */
    public void loadFolds(List<Stream<T>> folds) {
        this.folds = this.splitter.splitIntoFolds(folds);
        this.samples = this.extractSamples(this.folds);
    }

    /**
     * Aggregates per-fold results in two ways: a mean/std summary (via
     * {@link #aggregateResultsMeanStd}) and a single merged {@link TrainTestResults} built from
     * the concatenated evaluations and restarts of all folds.
     *
     * <p>When {@code settings.calculateBestThreshold} is set, each fold whose training and
     * testing results are both {@link DetailedClassificationResults} contributes its
     * best-threshold test accuracy (weighted by test-set size) to a merged
     * {@code bestAccuracy} on the merged test results. The merged value is only written when the
     * merged test results are {@link DetailedClassificationResults} and at least one test
     * evaluation contributed (avoids storing {@code NaN}).
     *
     * @param foldRunStatsList results of each fold's run; must not be empty
     * @return pair of (mean/std summary, merged train/validation/test results)
     * @throws IllegalArgumentException if {@code foldRunStatsList} is empty
     */
    public Pair<MeanStdResults.TrainValTest, TrainTestResults> aggregateResults(List<TrainTestResults> foldRunStatsList) {
        MeanStdResults.TrainValTest trainValTest = this.aggregateResultsMeanStd(foldRunStatsList);

        ArrayList<Result> allTrain = new ArrayList<>();
        ArrayList<Result> allValidation = new ArrayList<>();
        ArrayList<Result> allTest = new ArrayList<>();
        ArrayList<Progress.Restart> allRestarts = new ArrayList<>();
        // long accumulators: avoid the int narrowing the per-fold rounding previously went through
        long correctClassCount = 0;
        long allCount = 0;
        boolean useBestThreshold = this.settings.calculateBestThreshold;

        for (TrainTestResults foldResults : foldRunStatsList) {
            Results training = foldResults.training.bestResults.training;
            Results testing = foldResults.testing;
            allTrain.addAll(training.evaluations);
            allValidation.addAll(foldResults.training.bestResults.validation.evaluations);
            allTest.addAll(testing.evaluations);
            allRestarts.addAll(foldResults.training.restarts);

            // Both sides are cast below, so both must be checked (previously only training was).
            if (useBestThreshold
                    && training instanceof DetailedClassificationResults
                    && testing instanceof DetailedClassificationResults) {
                DetailedClassificationResults detailedTraining = (DetailedClassificationResults) training;
                DetailedClassificationResults detailedTesting = (DetailedClassificationResults) testing;
                int testSize = detailedTesting.evaluations.size();
                // Evaluate the test set at the threshold chosen on this fold's training data.
                Double bestAccuracy = detailedTesting.computeBestAccuracy(detailedTesting.evaluations, detailedTraining.bestThreshold);
                // Convert the fold's accuracy back to a count so folds are weighted by test-set size.
                correctClassCount += Math.round(bestAccuracy * (double) testSize);
                allCount += testSize;
            }
        }

        Results.Factory factory = Results.Factory.getFrom(this.settings);
        Results allTrainResults = factory.createFrom(allTrain);
        Results allValidationResults = factory.createFrom(allValidation);
        Results allTestResults = factory.createFrom(allTest);

        // Check the exact type being cast to (the old ClassificationResults check allowed a
        // ClassCastException) and skip the merge when nothing contributed (0/0 == NaN).
        if (useBestThreshold && allTestResults instanceof DetailedClassificationResults && allCount > 0) {
            double mergedBestAccuracy = (double) correctClassCount / (double) allCount;
            ((DetailedClassificationResults) allTestResults).bestAccuracy = mergedBestAccuracy;
        }

        Progress allTrainingMerged = new Progress();
        allTrainingMerged.restarts = allRestarts;
        allTrainingMerged.bestResults = new Progress.TrainVal(allTrainResults, allValidationResults);
        TrainTestResults finalTrainTestResults = new TrainTestResults(allTrainingMerged, allTestResults);
        return new Pair<>(trainValTest, finalTrainTestResults);
    }

    /**
     * Computes mean/std summaries of the per-fold training, validation and test results.
     * If the first fold's training results are {@link ClassificationResults}, all results are
     * aggregated as classifications; otherwise all are aggregated as regressions.
     *
     * @param foldRunStatsList results of each fold's run; must not be empty
     * @return mean/std summaries for train, validation and test
     * @throws IllegalArgumentException if {@code foldRunStatsList} is empty
     * @throws ClassCastException       if folds mix classification and regression results
     */
    public MeanStdResults.TrainValTest aggregateResultsMeanStd(List<TrainTestResults> foldRunStatsList) {
        if (foldRunStatsList.isEmpty()) {
            throw new IllegalArgumentException("Cannot aggregate crossvalidation results: no fold results given");
        }
        List<Results> training = new ArrayList<>(foldRunStatsList.size());
        List<Results> validation = new ArrayList<>(foldRunStatsList.size());
        List<Results> testing = new ArrayList<>(foldRunStatsList.size());
        for (TrainTestResults foldResults : foldRunStatsList) {
            training.add(foldResults.training.bestResults.training);
            validation.add(foldResults.training.bestResults.validation);
            testing.add(foldResults.testing);
        }

        // The first fold decides the task type for all folds.
        if (training.get(0) instanceof ClassificationResults) {
            MeanStdResults train = ClassificationResults.aggregateClassifications(castAll(training, ClassificationResults.class));
            MeanStdResults val = ClassificationResults.aggregateClassifications(castAll(validation, ClassificationResults.class));
            MeanStdResults test = ClassificationResults.aggregateClassifications(castAll(testing, ClassificationResults.class));
            return new MeanStdResults.TrainValTest(train, val, test);
        }
        MeanStdResults train = RegressionResults.aggregateRegressions(castAll(training, RegressionResults.class));
        MeanStdResults val = RegressionResults.aggregateRegressions(castAll(validation, RegressionResults.class));
        MeanStdResults test = RegressionResults.aggregateRegressions(castAll(testing, RegressionResults.class));
        return new MeanStdResults.TrainValTest(train, val, test);
    }

    /** Casts every element to {@code type}, throwing ClassCastException on the first mismatch. */
    private static <R> List<R> castAll(List<Results> results, Class<R> type) {
        return results.stream().map(type::cast).collect(Collectors.toList());
    }

    /**
     * Concatenates the test part of every fold, in fold order.
     *
     * @param folds the folds to read
     * @return a new mutable list with all folds' test samples
     */
    private List<T> extractSamples(List<Fold<T>> folds) {
        List<T> allSamples = new ArrayList<>();
        for (Fold<T> fold : folds) {
            allSamples.addAll(fold.test);
        }
        return allSamples;
    }
}

