from datasets import IMBALANCECIFAR100, IMBALANCECIFAR10, \
    SMALLCIFAR10, SMALLCIFAR100
from torchvision import transforms
import torchvision


transform_train = transforms.Compose([
    # transforms.Pad(4, padding_mode="reflect"),
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(), # no normalization?
])

transform_strong = transforms.Compose([
    # transforms.Pad(4, padding_mode="reflect"),
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(), # no normalization?
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
])

def build_datasets(name='CIFAR10', mode='train', num_classes=10, imbalance_ratio=0.0, existing_ratio=1, root='../data', transform=None):
    
    if transform is None:
        if mode == 'train':
            transform = transform_train
            train = True
        elif mode == 'test':
            transform = transform_test
            train = False
        else:
            raise NameError('Invalid mode: {}'.format(mode))

    if imbalance_ratio > 0:
        print('Building imbalance dataset')
        if name == 'cifar10':
            dataset = IMBALANCECIFAR10(imbalance_ratio=imbalance_ratio,
                                        root=root, train=train, download=True,
                                        transform=transform)
        elif name == 'cifar100':
            dataset = IMBALANCECIFAR100(imbalance_ratio=imbalance_ratio,
                                         root=root, train=train, download=True,
                                         transform=transform)
        else:
            raise NameError('Invalid dataset name: {}'.format(name))
        
        samples_per_cls = dataset.img_num_list
    
    elif existing_ratio < 1:
        print('Building smaller dataset')
        if name == 'cifar10':
            dataset = SMALLCIFAR10(existing_ratio=existing_ratio,
                                       root=root, train=train, download=True,
                                       transform=transform)
        else:
            dataset = SMALLCIFAR100(existing_ratio=existing_ratio,
                                   root=root, train=train, download=True,
                                   transform=transform)

        samples_per_cls = dataset.img_num_list
    else:
        if name == 'cifar10':
            dataset = torchvision.datasets.CIFAR10(root=root, train=train, 
                                                   download=True, 
                                                   transform=transform)
        elif name == 'cifar100':
            dataset = torchvision.datasets.CIFAR100(root=root, train=train, 
                                                    download=True, 
                                                    transform=transform)
        elif name == 'ImageNet':
            dataset = torchvision.datasets.ImageNet(root=root, train=train,
                                                     download=True,
                                                     transform=transform)
        else:
            raise NameError('Invalid dataset name: {}'.format(name))
        samples_per_cls = [5000] * num_classes

    return dataset, samples_per_cls