"""
Multi-EPL

File: src/datasetting/usps.py
Contains the code for setting USPS dataset
"""

import scipy.io as sio
import numpy as np
import os

# Default root directory used by load_usps. It is a relative path, so it is
# resolved against the current working directory when the file is opened.
data_dir = '../../data/digits'
# Name of the .mat file that load_usps reads from inside the root directory.
file_name = 'usps_28x28.mat'


def _to_three_channel_uint8(images):
    """Convert a raw image array from the .mat file to 3-channel uint8.

    Moves axis 1 to the last position, stacks three copies of the result
    along that last axis, then multiplies by 255 and casts to ``np.uint8``.

    NOTE(review): assumes ``images`` is laid out as (N, 1, H, W) with values
    in [0, 1] -- confirm against the contents of the .mat file.
    """
    images = images.transpose(0, 2, 3, 1)
    images = np.concatenate([images, images, images], axis=3)
    return (images * 255).astype(np.uint8)


def load_usps(root=data_dir, data_num=-1):
    """Load the USPS dataset from ``<root>/<file_name>``.

    The training split is shuffled with ``np.random.permutation`` (NumPy's
    global random state, so call ``np.random.seed`` beforehand for a
    reproducible order) and optionally truncated. The test split is returned
    in file order and is never truncated.

    Args:
        root: Directory containing the .mat file.
        data_num: Number of shuffled training samples to keep. Any negative
            value (the default) keeps all of them.

    Returns:
        A tuple ``(usps_train, usps_test, train_label, test_label,
        data_per_label)`` where ``data_per_label`` maps each label in
        ``0..9`` to the training images carrying that label.

    Raises:
        ValueError: If the per-label groups do not add up to the number of
            training samples, i.e. some training label is outside ``0..9``.
    """
    usps_data = sio.loadmat(os.path.join(root, file_name))
    dataset = usps_data['dataset']

    # dataset[0] holds the training (images, labels) pair, dataset[1] the
    # test pair; labels are flattened to 1-D.
    usps_train = _to_three_channel_uint8(dataset[0][0])
    train_label = dataset[0][1].reshape(-1)
    usps_test = _to_three_channel_uint8(dataset[1][0])
    test_label = dataset[1][1].reshape(-1)

    # Shuffle images and labels with the same permutation so they stay paired.
    order = np.random.permutation(usps_train.shape[0])
    usps_train = usps_train[order]
    train_label = train_label[order]

    if data_num >= 0:
        usps_train = usps_train[:data_num]
        train_label = train_label[:data_num]

    data_per_label = {
        label: usps_train[train_label == label] for label in range(10)
    }

    # Explicit check instead of `assert`: asserts are removed under
    # `python -O`, which would let out-of-range labels be dropped silently.
    grouped_total = sum(len(images) for images in data_per_label.values())
    if grouped_total != usps_train.shape[0]:
        raise ValueError(
            'USPS training labels must all lie in 0..9: grouped {} of {} '
            'samples'.format(grouped_total, usps_train.shape[0])
        )

    print('*** USPS DATASET ***')
    print('Training data: {}, Training label: {}'.format(usps_train.shape, train_label.shape))
    print('Test data: {}, Test label: {}'.format(usps_test.shape, test_label.shape))

    return usps_train, usps_test, train_label, test_label, data_per_label
