import sys
from torch.utils.data import DataLoader, DistributedSampler
from pyprojroot import here as project_root
from torchvision.datasets import MNIST, SVHN, CelebA
from torchvision.transforms import Compose, ToTensor, Normalize, CenterCrop, Resize

sys.path.insert(0, str(project_root()))

from src.data.imagenet_dataset import imagenet_dataset
from src.data.cifar10_dataset import cifar10_dataset

def mnist_dataset(split, root='../mnist_dataset'):
    """Build the MNIST dataset for ``split``.

    Args:
        split: ``'train'`` selects the training set; any other value selects
            the test set.
        root: Directory where MNIST is stored (and downloaded to if missing).
            Defaults to the previously hard-coded ``'../mnist_dataset'``,
            which is resolved relative to the current working directory.

    Returns:
        A ``torchvision.datasets.MNIST`` whose samples are converted to
        tensors and normalized.
    """
    train = split == 'train'
    # NOTE(review): 0.1307 / 0.3081 are presumably the MNIST training-set
    # mean / std; confirm before changing the preprocessing.
    transforms = Compose([ToTensor(), Normalize((0.1307,), (0.3081,))])
    return MNIST(root=root, train=train, transform=transforms, download=True)

# def svhn_dataset(split):
#     train = split == 'train'
#     transforms = Compose([ToTensor(), Normalize((0.4377, 0.4438, 0.4728), (0.1980, 0.2010, 0.1970))])
#     return SVHN(root='../svhn_dataset', split='train' if train else 'test', transform=transforms, download=True)

def celebA_dataset(split, root='../celeba_dataset'):
    """Build the CelebA dataset for ``split``, resized to 64x64.

    Args:
        split: ``'train'`` selects CelebA's ``'train'`` split; every other
            value (including ``'valid'``) selects the ``'test'`` split.
        root: Directory where CelebA is stored (and downloaded to if
            missing). Defaults to the previously hard-coded
            ``'../celeba_dataset'``, which is resolved relative to the current
            working directory.

    Returns:
        A ``torchvision.datasets.CelebA`` whose images are center-cropped to
        178x178, resized to 64x64, converted to tensors and normalized with
        mean 0.5 / std 0.5 per channel.
    """
    train = split == 'train'
    transforms = Compose([
        CenterCrop(178),  # Crop to keep face centered
        Resize((64, 64)),  # Downscale to 64x64
        ToTensor(),
        # (x - 0.5) / 0.5 maps ToTensor's [0, 1] output to [-1, 1].
        Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])
    return CelebA(root=root, split='train' if train else 'test', transform=transforms, download=True)

def get_dataloader(args, split, distributed=True, num_workers=10):
    """Create a DataLoader for the dataset named by ``args.dataset``.

    Args:
        args: Namespace-like object. Reads ``args.dataset`` (one of
            ``'imagenet'``, ``'cifar10'``, ``'mnist'``, ``'celeba'``) and
            ``args.batch_size``.
        split: Split name forwarded unchanged to the dataset builder.
        distributed: If True, sample through a shuffling
            ``DistributedSampler``. Otherwise, iterate in dataset order.
        num_workers: Number of DataLoader worker processes. The default of 10
            is the previously hard-coded value.

    Returns:
        A ``torch.utils.data.DataLoader`` over the selected dataset.

    Raises:
        ValueError: If ``args.dataset`` is not a supported name. This is a
            subclass of ``Exception``, so existing broad handlers still work.
    """
    name = args.dataset
    if name == 'imagenet':
        dataset = imagenet_dataset(split)
    elif name == 'cifar10':
        dataset = cifar10_dataset(split)
    elif name == 'mnist':
        dataset = mnist_dataset(split)
    # elif name == 'svhn':
    #     dataset = svhn_dataset(split)
    elif name == 'celeba':
        dataset = celebA_dataset(split)
    else:
        raise ValueError(f'{args.dataset} is not supported as a dataset class.')

    if distributed:
        # DistributedSampler (created without explicit num_replicas/rank) needs
        # an initialized process group. It reshuffles only when the training
        # loop calls loader.sampler.set_epoch(epoch) each epoch.
        # NOTE(review): shuffle=True applies to every split, including
        # evaluation splits. Confirm that this is intended.
        sampler = DistributedSampler(dataset, shuffle=True)
        # shuffle must stay False on the DataLoader when a sampler is given.
        return DataLoader(dataset, batch_size=args.batch_size, sampler=sampler,
                          shuffle=False, num_workers=num_workers)
    # NOTE: shuffle is turned off for single gpu
    return DataLoader(dataset, batch_size=args.batch_size, shuffle=False,
                      num_workers=num_workers)
