from torchvision.datasets import CIFAR100, CIFAR10
import torch.distributed as dist
from torch.utils.data import DataLoader, DistributedSampler
from torchvision import transforms


def cifar100(batch_size, workers, distributed=False):
    """Build the CIFAR-100 training and validation data loaders.

    The dataset is read from (and downloaded to, if missing) the current
    working directory.

    Args:
        batch_size: Batch size passed to each ``DataLoader``.
        workers: Number of ``DataLoader`` worker processes. ``0`` loads data
            in the calling process.
        distributed: When true, both splits are sharded across processes
            with a ``DistributedSampler`` built from
            ``dist.get_world_size()`` / ``dist.get_rank()``.

    Returns:
        A ``(train_loader, val_loader)`` tuple.
    """
    # NOTE(review): the mean triple is identical to the one used by
    # ``cifar10`` in this file -- confirm these statistics are the ones
    # intended for CIFAR-100 before changing them (checkpoints may depend
    # on the current values).
    normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261))

    # Training images get random crop + flip augmentation; evaluation images
    # are only converted to tensors and normalised.
    train_set = CIFAR100(
        '.', train=True, download=True,
        transform=transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]),
    )
    test_set = CIFAR100(
        '.', train=False, download=True,
        transform=transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ]),
    )

    train_sampler = None
    val_sampler = None
    if distributed:
        world_size = dist.get_world_size()
        rank = dist.get_rank()
        train_sampler = DistributedSampler(
            dataset=train_set,
            num_replicas=world_size,
            rank=rank,
            shuffle=True,
            drop_last=True,
        )
        val_sampler = DistributedSampler(
            dataset=test_set,
            num_replicas=world_size,
            rank=rank,
            shuffle=False,
            drop_last=False,
        )

    loader_options = dict(
        batch_size=batch_size,
        num_workers=workers,
        pin_memory=True,
        # DataLoader rejects persistent_workers=True when num_workers == 0,
        # so only keep workers alive when there are workers to keep.
        persistent_workers=workers > 0,
    )

    # A sampler and shuffle=True are mutually exclusive in DataLoader, so the
    # loader only shuffles itself when no sampler is in charge of ordering.
    train_loader = DataLoader(train_set, sampler=train_sampler,
                              shuffle=(train_sampler is None), **loader_options)
    val_loader = DataLoader(test_set, sampler=val_sampler, **loader_options)
    return train_loader, val_loader


def cifar10(batch_size, workers, distributed=False):
    """Build the CIFAR-10 training and validation data loaders.

    The dataset is read from (and downloaded to, if missing) the current
    working directory.

    Args:
        batch_size: Batch size passed to each ``DataLoader``.
        workers: Number of ``DataLoader`` worker processes. ``0`` loads data
            in the calling process.
        distributed: When true, both splits are sharded across processes
            with a ``DistributedSampler`` built from
            ``dist.get_world_size()`` / ``dist.get_rank()``.

    Returns:
        A ``(train_loader, val_loader)`` tuple.
    """
    normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))

    # Training images get random crop + flip augmentation; evaluation images
    # are only converted to tensors and normalised.
    train_set = CIFAR10(
        '.', train=True, download=True,
        transform=transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]),
    )
    test_set = CIFAR10(
        '.', train=False, download=True,
        transform=transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ]),
    )

    train_sampler = None
    val_sampler = None
    if distributed:
        world_size = dist.get_world_size()
        rank = dist.get_rank()
        train_sampler = DistributedSampler(
            dataset=train_set,
            num_replicas=world_size,
            rank=rank,
            shuffle=True,
            drop_last=True,
        )
        val_sampler = DistributedSampler(
            dataset=test_set,
            num_replicas=world_size,
            rank=rank,
            shuffle=False,
            drop_last=False,
        )

    loader_options = dict(
        batch_size=batch_size,
        num_workers=workers,
        pin_memory=True,
        # DataLoader rejects persistent_workers=True when num_workers == 0,
        # so only keep workers alive when there are workers to keep.
        persistent_workers=workers > 0,
    )

    # A sampler and shuffle=True are mutually exclusive in DataLoader, so the
    # loader only shuffles itself when no sampler is in charge of ordering.
    train_loader = DataLoader(train_set, sampler=train_sampler,
                              shuffle=(train_sampler is None), **loader_options)
    val_loader = DataLoader(test_set, sampler=val_sampler, **loader_options)
    return train_loader, val_loader
