"""
Multi-EPL

File: src/datasetting/svhn.py
Contains the code for setting SVHN dataset
"""

import scipy.io as sio
import numpy as np
import os

# Default location of the SVHN .mat files. This is a relative path, so it is
# resolved against the process's current working directory, not this file.
data_dir = '../../data/digits'
train_file_name = 'svhn_train_32x32.mat'
test_file_name = 'svhn_test_32x32.mat'


def _load_split(path):
    """Read one SVHN .mat file and return its ``(images, labels)`` pair."""
    mat = sio.loadmat(path)
    # 'X' is stored with the sample axis last; move it to the front so that
    # images[i] is the i-th sample.
    images = mat['X'].transpose(3, 0, 1, 2).astype(np.uint8)
    # Flatten 'y' to 1-D and fold the labels into 0..9 (a stored 10 becomes 0;
    # presumably SVHN's encoding of the digit zero -- see the dataset docs).
    labels = mat['y'].reshape(-1) % 10
    return images, labels


def load_svhn(root=data_dir, data_num=-1, seed=None):
    """Load the SVHN train/test splits and group training images by label.

    The training split is shuffled (images and labels with the same
    permutation) and optionally truncated to its first ``data_num`` samples
    after shuffling. The test split is returned in file order, untouched.

    Args:
        root: Directory containing ``svhn_train_32x32.mat`` and
            ``svhn_test_32x32.mat``.
        data_num: Number of shuffled training samples to keep. ``None`` or
            any negative value keeps all of them; ``0`` yields an empty
            training set; a value larger than the split keeps everything.
        seed: Optional seed for the shuffle. When ``None`` (the default) the
            global NumPy random state is used, so callers that rely on
            ``np.random.seed`` keep working. Otherwise a private
            ``RandomState(seed)`` is used and the global state is not touched.

    Returns:
        A tuple ``(svhn_train, svhn_test, train_label, test_label,
        data_per_label)`` where ``data_per_label`` maps each label in
        ``0..9`` to the array of (shuffled, truncated) training images
        carrying that label.

    Side effects:
        Prints a short summary of the loaded array shapes to stdout.
    """
    svhn_train, train_label = _load_split(os.path.join(root, train_file_name))
    svhn_test, test_label = _load_split(os.path.join(root, test_file_name))

    # Shuffle images and labels with one shared permutation so pairs stay
    # aligned.
    num_train = svhn_train.shape[0]
    if seed is None:
        order = np.random.permutation(num_train)
    else:
        order = np.random.RandomState(seed).permutation(num_train)
    svhn_train = svhn_train[order]
    train_label = train_label[order]

    # None (previously a TypeError) and negative values both mean "no limit".
    if data_num is not None and data_num >= 0:
        svhn_train = svhn_train[:data_num]
        train_label = train_label[:data_num]

    data_per_label = {
        label: svhn_train[train_label == label] for label in range(10)
    }

    # Sanity check: every kept training sample landed in exactly one bucket.
    bucketed = sum(len(images) for images in data_per_label.values())
    assert bucketed == svhn_train.shape[0], \
        'per-label buckets do not cover the training set'

    print('*** SVHN DATASET ***')
    print('Training data: {}, Training label: {}'.format(svhn_train.shape, train_label.shape))
    print('Test data: {}, Test label: {}'.format(svhn_test.shape, test_label.shape))

    return svhn_train, svhn_test, train_label, test_label, data_per_label
