"""
Multi-EPL

File: src/datasetting/dataloader.py
Contains the code for setting up datasets and dataloaders
"""

import numpy as np
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader

from datasetting.mnist import load_mnist
from datasetting.mnist_m import load_mnist_m
from datasetting.svhn import load_svhn
from datasetting.synthdigits import load_synthdigits
from datasetting.usps import load_usps

# Default dataset directory handed to the load_* functions. It is a relative
# path, so it is interpreted relative to the process's working directory.
digits_data_dir = './../../data/digits/'

# Default per-image preprocessing: convert to a PIL image, resize to 32x32,
# convert to a tensor, then normalize with a single mean/std pair.
# NOTE(review): 0.1307 / 0.3081 look like the usual MNIST statistics - confirm
# they are intended for every digit dataset loaded through this module.
digits_transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])


class GeneralDataset(Dataset):
    """In-memory image dataset yielding ``{'image', 'label', 'index'}`` samples.

    Args:
        images: Indexable collection of images; ``images[idx]`` is handed to
            ``transform`` when a sample is requested.
        labels: Either one label per image, or a single integer (Python
            ``int`` or any NumPy integer scalar) that is assigned to every
            image.
        target: When true, the given labels are discarded and every sample is
            labelled -1, meaning "no label given".
        transform: Callable applied to each image on access. A falsy value
            (e.g. ``None``) leaves the image untouched.
        batch_size: Stored on the instance; the dataset itself does not use it.
    """

    def __init__(self, images, labels, target, transform=digits_transform, batch_size=-1):
        self.images = images

        # Accept every integer scalar type. An exact-type comparison against
        # ``int`` / ``np.uint8`` misses e.g. ``np.int64`` class ids, which would
        # then be stored as a bare scalar and break ``len(self.labels)``.
        is_single_label = isinstance(labels, (int, np.integer))

        if target:
            # Target-domain training data: hide the real labels behind -1.
            # A scalar label has no length, so fall back to the image count.
            count = len(self.images) if is_single_label else len(labels)
            self.labels = np.full(count, -1, dtype=np.int8)
        elif is_single_label:
            # One class id shared by all images (stored as int8).
            self.labels = np.full(len(self.images), labels, dtype=np.int8)
        else:
            self.labels = labels

        self.num_data = len(self.labels)
        self.transform = transform
        self.batch_size = batch_size

    def __len__(self):
        """Return the number of samples (one per label)."""
        return self.num_data

    def __getitem__(self, idx):
        """Return the sample at ``idx`` as a dict.

        The dict has the keys ``'image'`` (transformed when a transform is
        set), ``'label'`` (a Python ``int``) and ``'index'`` (``idx`` itself).
        """
        image = self.images[idx]
        label = int(self.labels[idx])
        if self.transform:
            image = self.transform(image)
        return {'image': image, 'label': label, 'index': idx}


def get_digits_dataloader(name='MNIST', target=True, transform=digits_transform, batch_size=64, data_num=-1, data_dir=digits_data_dir):
    """Load one digit dataset and wrap it in datasets and dataloaders.

    Args:
        name: Dataset to load; one of 'MNIST', 'MNIST-M', 'SVHN', 'SYN', 'USPS'.
        target: Forwarded to the training ``GeneralDataset``; when true its
            labels are replaced by -1. The test and per-label datasets always
            keep their labels.
        transform: Transform given to every ``GeneralDataset`` created here.
        batch_size: Batch size of every dataloader created here.
        data_num: Forwarded unchanged to the dataset's load function.
        data_dir: Forwarded unchanged to the dataset's load function.

    Returns:
        A tuple ``(train_dataset, train_dataloader, test_dataloader,
        dataloader_per_label)``. ``dataloader_per_label`` maps each key of the
        loader's per-label data to a shuffled dataloader over that key's data.
        Every dataloader uses ``drop_last=True``, so an incomplete final batch
        is dropped - including in the test dataloader.

    Raises:
        ValueError: If ``name`` is not one of the supported dataset names.
    """
    # (dataset name, progress message, load function) for each supported dataset.
    supported = (
        ('MNIST', 'Load MNIST data', load_mnist),
        ('MNIST-M', 'Load MNIST-M data', load_mnist_m),
        ('SVHN', 'Load SVHN data', load_svhn),
        ('SYN', 'Load SYN data', load_synthdigits),
        ('USPS', 'Load USPS data', load_usps),
    )
    for dataset_name, message, load_fn in supported:
        if name == dataset_name:
            break
    else:
        # The loop finished without a match.
        raise ValueError('Name should be one of MNIST, MNIST-M, SVHN, SYN, and USPS')

    print(message)
    train_data, test_data, train_label, test_label, data_per_label = load_fn(data_dir, data_num)

    def build_loader(dataset, shuffle):
        # All dataloaders share everything except the shuffle flag.
        return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, drop_last=True, num_workers=0)

    train_dataset = GeneralDataset(train_data, train_label, target, transform, batch_size)
    test_dataset = GeneralDataset(test_data, test_label, False, transform, batch_size)

    # One shuffled loader per key; the key itself is passed as the label.
    dataloader_per_label = {
        key: build_loader(GeneralDataset(samples, key, False, transform, batch_size), True)
        for key, samples in data_per_label.items()
    }
    train_dataloader = build_loader(train_dataset, True)
    test_dataloader = build_loader(test_dataset, False)

    return train_dataset, train_dataloader, test_dataloader, dataloader_per_label

