import os.path

from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST, FashionMNIST, CIFAR10, CIFAR100, ImageNet


# Per-channel (mean, std) normalization constants for the small torchvision
# datasets handled by `get_loaders`. A dataset name is treated as "small"
# exactly when it is a key of this table.
_SMALL_DATASET_STATS = {
    "MNIST": ((0.1307,), (0.3081,)),
    "FashionMNIST": ((0.2860,), (0.3530,)),
    "CIFAR10": ((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
    "CIFAR100": ((0.5071, 0.4865, 0.4409), (0.2673, 0.2564, 0.2762)),
}


def _build_small_datasets(dataset, data_root, analysis_only):
    """Return ``(train_set, val_set)`` for a key of ``_SMALL_DATASET_STATS``.

    ``train_set`` is ``None`` when ``analysis_only`` is true.
    """
    dataset_cls = {
        "MNIST": MNIST,
        "FashionMNIST": FashionMNIST,
        "CIFAR10": CIFAR10,
        "CIFAR100": CIFAR100,
    }[dataset]
    mean, std = _SMALL_DATASET_STATS[dataset]
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])
    # download=True on BOTH splits: when analysis_only is set the train split
    # is never constructed, so the validation split must be able to fetch the
    # data on its own. torchvision skips the download if the files are present.
    train_set = None if analysis_only else dataset_cls(
        data_root, train=True, download=True, transform=transform)
    val_set = dataset_cls(data_root, train=False, download=True, transform=transform)
    return train_set, val_set


def _default_imagenet_root():
    """Return the first existing ImageNet location, falling back to the second."""
    root = "/scratch/xc429/ILSVRC12"
    if not os.path.isdir(root):
        # Fallback is used even if it does not exist; torchvision's ImageNet
        # constructor then reports the missing data.
        root = "/home/xc429/datasets/ILSVRC12_torch"
    return root


def _build_imagenet(root, analysis_only):
    """Return ``(train_set, val_set)`` for ImageNet stored under ``root``.

    Train uses random-resized-crop + horizontal-flip augmentation; val uses
    resize-256 + center-crop-224. ``train_set`` is ``None`` when
    ``analysis_only`` is true.
    """
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225])
    train_set = None if analysis_only else ImageNet(
        root=root,
        split="train",
        transform=transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]),
    )
    val_set = ImageNet(
        root=root,
        split="val",
        transform=transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ]),
    )
    return train_set, val_set


def get_loaders(dataset, batch_size, analysis_only=False, verbose=True,
                data_root="~/data", imagenet_root=None, num_workers=4):
    """Build train and validation DataLoaders for a named dataset.

    Args:
        dataset: One of "MNIST", "FashionMNIST", "CIFAR10", "CIFAR100",
            "ImageNet".
        batch_size: Batch size for both loaders.
        analysis_only: If true, skip building the training set/loader and
            return ``None`` in its place.
        verbose: If true, print a progress line before constructing datasets.
        data_root: Root directory passed to the MNIST/FashionMNIST/CIFAR
            dataset constructors (data is downloaded there if missing).
        imagenet_root: ImageNet root directory. If ``None``, the historical
            default locations are probed.
        num_workers: Worker processes for each DataLoader.

    Returns:
        ``(train_loader, val_loader)``. The train loader shuffles and the val
        loader does not. ``train_loader`` is ``None`` when ``analysis_only``.

    Raises:
        NotImplementedError: If ``dataset`` is not a supported name.
    """
    if verbose:
        print(f"Constructing {dataset} dataset ...")
    if dataset in _SMALL_DATASET_STATS:
        train_set, val_set = _build_small_datasets(dataset, data_root, analysis_only)
    elif dataset == "ImageNet":
        root = imagenet_root if imagenet_root is not None else _default_imagenet_root()
        train_set, val_set = _build_imagenet(root, analysis_only)
    else:
        supported = sorted([*_SMALL_DATASET_STATS, "ImageNet"])
        raise NotImplementedError(
            f"Unsupported dataset {dataset!r}; expected one of {supported}")

    train_loader = None if analysis_only else DataLoader(
        train_set, batch_size=batch_size, num_workers=num_workers,
        pin_memory=True, shuffle=True)
    val_loader = DataLoader(
        val_set, batch_size=batch_size, num_workers=num_workers,
        pin_memory=True, shuffle=False)
    return train_loader, val_loader