% Load a specified dataset from a given path.

function [trainX,trainT,trainL,valX,valT,valL,testX,testT,testL,inSize] ...
    = loadDataset(dataset,dataPath)
    % loadDataset - load a named dataset stored below <dataPath>/datasets.
    %
    % Inputs:
    %    dataset - name of the dataset; one of 'mnist', 'f-mnist', 'svhn',
    %       'cifar10', 'tinyImageNet'
    %    dataPath - directory that contains the 'datasets' folder
    %
    % Outputs:
    %    trainX,trainT,trainL - training inputs, targets, and labels
    %    valX,valT,valL - validation inputs, targets, and labels
    %    testX,testT,testL - test inputs, targets, and labels
    %    inSize - input size as reported by the dataset-specific loader
    %
    % All outputs are passed through unchanged from the dataset-specific
    % loader. An error is raised for an unknown dataset name.
    %
    % The function temporarily changes the working directory (the loaders
    % use relative paths and some of them cd further into a subfolder).
    % The original directory is restored when the function exits, whether
    % it returns normally or fails with an error.

    % Remember where we were called from.
    currPath = pwd;
    % Restore the caller's directory on every exit path; without this an
    % invalid dataset name or a failing loader would leave the caller
    % inside the dataset folder.
    restoreDir = onCleanup(@() cd(currPath)); %#ok<NASGU>
    % Go into dataset directory.
    cd(sprintf('%s/datasets',dataPath));
    fprintf('Loading dataset "%s"...',dataset);
    switch dataset
        case 'mnist'
            [trainX,trainT,trainL,valX,valT,valL,...
                testX,testT,testL,inSize] = aux_loadMNIST();
        case 'f-mnist'
            % NOTE(review): aux_loadFMNIST is not defined in the visible
            % part of this file - presumably provided elsewhere; confirm.
            [trainX,trainT,trainL,valX,valT,valL,...
                testX,testT,testL,inSize] = aux_loadFMNIST();
        case 'svhn'
            [trainX,trainT,trainL,valX,valT,valL,...
                testX,testT,testL,inSize] = aux_loadSVHN();
        case 'cifar10'
            [trainX,trainT,trainL,valX,valT,valL,...
                testX,testT,testL,inSize] = aux_loadCIFAR10();
        case 'tinyImageNet'
            [trainX,trainT,trainL,valX,valT,valL,...
                testX,testT,testL,inSize] = aux_loadTinyImageNet();
        otherwise
            error('Invalid name for dataset: %s',dataset)
    end
    fprintf(' done\n');
    % The working directory is restored by restoreDir when it goes out of
    % scope here.
end

%% Auxiliary Functions. ---------------------------------------------------

function [trainX,trainT,trainL,valX,valT,valL,testX,testT,testL,inSize] ...
    = aux_loadMNIST()
    % aux_loadMNIST - read the MNIST archives from the 'mnist' subfolder.
    %
    % Side effect: changes the working directory to 'mnist' and does not
    % change it back.
    %
    % The validation set is a random 5% subset of the test images; the
    % test set itself is returned in full, so both sets overlap.
    % Inputs are returned as single matrices with one column per image,
    % targets as single one-hot matrices with one column per image
    % (classes 0..9).

    % Enter the folder that holds the MNIST archives.
    cd('mnist')

    % Training split.
    trainXImg = processImagesMNIST('train-images-idx3-ubyte.gz');
    trainL = processLabelsMNIST('train-labels-idx1-ubyte.gz')';

    % Test split.
    testXImg = processImagesMNIST('t10k-images-idx3-ubyte.gz');
    testL = processLabelsMNIST('t10k-labels-idx1-ubyte.gz')';

    % Draw the validation indices (first group returned by dividerand,
    % i.e. 5% of the test images). The other two groups are not used.
    numTest = size(testXImg,4);
    [valIdx,~,~] = dividerand(numTest,0.05,0.05,0.9);
    valXImg = testXImg(:,:,:,valIdx);
    valL = testL(valIdx);

    % Flatten every image into one column.
    trainX = single(reshape(trainXImg,[],size(trainXImg,4)));
    valX = single(reshape(valXImg,[],size(valXImg,4)));
    testX = single(reshape(testXImg,[],numTest));

    % Group indices start at 1; shift them to 0..9 and one-hot encode.
    classIds = (0:9);
    trainT = single(onehotencode(grp2idx(trainL) - 1,2,...
        'ClassNames',classIds))';
    valT = single(onehotencode(grp2idx(valL) - 1,2,...
        'ClassNames',classIds))';
    testT = single(onehotencode(grp2idx(testL') - 1,2,...
        'ClassNames',classIds))';

    % Input size = all image dimensions except the trailing sample one.
    inSize = size(trainXImg,1:(ndims(trainXImg) - 1));
end

function [trainX,trainT,trainL,valX,valT,valL,testX,testT,testL,inSize] ...
    = aux_loadSVHN()
    % aux_loadSVHN - read the SVHN .mat files from the 'svhn' subfolder.
    %
    % Expects 'svhn/train_32x32.mat' and 'svhn/test_32x32.mat' relative to
    % the current directory, each providing the variables X (images) and
    % y (labels).
    %
    % Pixel values are divided by 255. Labels are binned with the edges
    % 1:11 into ten categories named '1'..'10'.
    % The validation set is a random ~1.9% subset of the test images; the
    % test set itself is returned in full, so both sets overlap.
    % Inputs are returned as single matrices with one column per image,
    % targets as single one-hot matrices with one column per image.

    % Shared label description for all splits.
    classNames = {'1','2','3','4','5','6','7','8','9','10'};
    binEdges = 1:11;
    classIds = (1:10);

    % Load into structs instead of injecting X and y into the function
    % workspace; this keeps the two splits apart and needs no clearvars.
    trainData = load('svhn/train_32x32.mat');
    testData = load('svhn/test_32x32.mat');

    % --- training split ---
    trainXImg = double(trainData.X)/255;
    trainL = discretize(trainData.y,binEdges,'categorical',classNames)';
    % categorical labels to class indices
    trainTInt = grp2idx(trainL);
    % convert inputs and targets to vectors
    trainX = single(reshape(trainXImg,[],size(trainXImg,4)));
    trainT = single(onehotencode(trainTInt,2,'ClassNames',classIds))';

    % --- test split ---
    testXImg = double(testData.X)/255;
    testL = discretize(testData.y,binEdges,'categorical',classNames)';
    testTInt = grp2idx(testL);
    testX = single(reshape(testXImg,[],size(testXImg,4)));
    testT = single(onehotencode(testTInt,2,'ClassNames',classIds))';

    % --- validation split ---
    % Use the first index group returned by dividerand as validation
    % indices; the remaining two groups are not needed.
    [valIdx,~,~] = dividerand(size(testXImg,4),0.01924,0.01924,0.96152);
    valXImg = testXImg(:,:,:,valIdx);
    valL = testL(valIdx);
    valTInt = grp2idx(valL);
    valX = single(reshape(valXImg,[],size(valXImg,4)));
    valT = single(onehotencode(valTInt,2,'ClassNames',classIds))';

    % Input size = all image dimensions except the trailing sample one.
    inSize = size(trainXImg,1:ndims(trainXImg) - 1);
end

function [trainX,trainT,trainL,valX,valT,valL,testX,testT,testL,inSize] ...
    = aux_loadCIFAR10()
    % aux_loadCIFAR10 - read CIFAR-10 from the 'cifar10' subfolder.
    %
    % Side effect: changes the working directory to 'cifar10' and does not
    % change it back.
    %
    % Pixel values are divided by 255. Training images are additionally
    % surrounded by a 2-pixel border of zeros; validation and test images
    % are not padded.
    % The validation set is a random 5% subset of the test images; the
    % test set itself is returned in full, so both sets overlap.
    % Inputs and targets are returned with one column per image; targets
    % are one-hot encoded for the classes 0..9.
    %
    % NOTE(review): inSize is fixed to [32 32 3] although the padded
    % training images are larger - presumably they are cropped by the
    % caller; confirm.

    % Enter the folder that holds the CIFAR-10 files.
    cd('cifar10')

    % Read the raw images and labels.
    [rawTrain,trainL,rawTest,testL] = ...
        helperCIFAR10Data.load('.');

    % Scale the test images to [0,1].
    testXImg = double(rawTest)/255;

    % Place the scaled training images in the centre of a zero canvas
    % that is 4 pixels larger in both spatial dimensions.
    trainXImg = zeros([4 4 0 0] + size(rawTrain));
    trainXImg(3:34,3:34,:,:) = double(rawTrain)/255;

    % Draw the validation indices (first group returned by dividerand,
    % i.e. 5% of the test images). The other two groups are not used.
    numTest = size(testXImg,4);
    [valIdx,~,~] = dividerand(numTest,0.05,0.05,0.9);
    valXImg = testXImg(:,:,:,valIdx);
    valL = testL(valIdx)';

    % Flatten every image into one column.
    trainX = reshape(trainXImg,[],size(trainXImg,4));
    valX = reshape(valXImg,[],size(valXImg,4));
    testX = reshape(testXImg,[],numTest);

    % Group indices start at 1; shift them to 0..9 and one-hot encode.
    classIds = (0:9);
    trainT = onehotencode(grp2idx(trainL) - 1,2,'ClassNames',classIds)';
    valT = onehotencode(grp2idx(valL) - 1,2,'ClassNames',classIds)';
    testT = onehotencode(grp2idx(testL) - 1,2,'ClassNames',classIds)';

    % Size of an unpadded image.
    inSize = [32 32 3];
end

function [trainX,trainT,trainL,valX,valT,valL,testX,testT,testL,inSize] ...
    = aux_loadTinyImageNet()
    % aux_loadTinyImageNet - read TinyImageNet from the 'tinyimagenet'
    % subfolder.
    %
    % Side effect: changes the working directory to 'tinyimagenet' and
    % does not change it back.
    %
    % Pixel values are divided by 255. Validation images are center-cropped
    % to 56x56x3; training images are returned uncropped.
    % There are no separate test images: the full (cropped) validation
    % data is returned as test data, and a random 5% subset of it is
    % returned as validation data.
    % Inputs are returned with one column per image, targets as single
    % one-hot matrices with one column per image (classes 1..200).
    %
    % NOTE(review): inSize is set to the crop size although the training
    % images are not cropped here - presumably they are cropped by the
    % caller; confirm.

    % Go to TinyImageNet path.
    cd('tinyimagenet')

    % Specify the directory containing the dataset.
    datasetDir = '.';
    % Load TinyImageNet dataset. The input size reported by the reader is
    % not needed because the crop size is returned instead.
    [trainXImg,trainL,valXImg,valL,~] = processTinyImageNet(datasetDir);
    % Normalize images to [0,1].
    trainXImg = single(trainXImg)/255;
    valXImg = single(valXImg)/255;

    % Center-crop the validation images.
    cropSize = [56 56 3];
    valXImg = applyCropsAndFlips(valXImg,cropSize,'center',false);

    % Convert categorical targets to one-hot encoded-vectors.
    trainTInt = grp2idx(trainL);
    trainT = single(onehotencode(trainTInt,2,'ClassNames',(1:200)))';
    valTInt = grp2idx(valL);
    valT = single(onehotencode(valTInt,2,'ClassNames',(1:200)))';

    % There are no test images; thus, use validation data for testing.
    testXImg = valXImg;
    testT = valT;
    testL = valL;

    % Only use a subset of the validation data. Images, targets, and
    % labels must all be reduced with the same indices; otherwise valT
    % would keep one column per original validation image and no longer
    % match valX.
    [valIdx,~,~] = dividerand(size(valXImg,4),0.05,0.05,0.9);
    valXImg = valXImg(:,:,:,valIdx);
    valT = valT(:,valIdx);
    valL = valL(valIdx)';

    % Convert images to vectors.
    trainX = reshape(trainXImg,[],size(trainXImg,4));
    valX = reshape(valXImg,[],size(valXImg,4));
    testX = reshape(testXImg,[],size(testXImg,4));

    % Report the crop size as input size.
    inSize = cropSize;
end