import os
import torch
import pickle
from torchvision import datasets, transforms

def loader(dataset, path, batch_size, test_batch_size, device, *, use_dataset_stats=False):
    """Build the CIFAR training and test DataLoaders.

    Args:
        dataset: Dataset name. Must start with ``'cifar'``. ``'cifar10'``
            selects CIFAR-10; any other ``'cifar*'`` name selects CIFAR-100.
        path: Root directory handed to the torchvision dataset class. The
            training split is constructed with ``download=True``.
        batch_size: Batch size of the shuffled training loader.
        test_batch_size: Batch size of the unshuffled test loader.
        device: When truthy, both loaders use 4 worker processes and
            pinned memory; otherwise DataLoader defaults are used.
        use_dataset_stats: Keyword-only. If True, normalize with the
            per-dataset channel mean/std (CIFAR-10 or CIFAR-100). The default
            False keeps the legacy CIFAR-10 constants for both datasets, which
            is what this function has always applied.

    Returns:
        A ``(train_loader, test_loader)`` tuple.

    Raises:
        ValueError: If ``dataset`` does not name a CIFAR dataset.
    """
    if not dataset.startswith('cifar'):
        # Without this guard both loaders stay unbound and the return below
        # fails with an unhelpful UnboundLocalError.
        raise ValueError(f"unsupported dataset: {dataset!r}")

    # NOTE(review): ``device`` is only tested for truthiness; a CPU device
    # object or the string 'cpu' is also truthy and would enable pin_memory.
    # Confirm what callers pass.
    kwargs = {'num_workers': 4, 'pin_memory': True} if device else {}
    is_cifar10 = dataset == 'cifar10'
    CIFAR = datasets.CIFAR10 if is_cifar10 else datasets.CIFAR100

    if not use_dataset_stats:
        # Legacy constants, applied to CIFAR-10 and CIFAR-100 alike.
        mean, std = (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)
    elif is_cifar10:
        mean, std = (0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)
    else:
        mean, std = (0.5071, 0.4865, 0.4409), (0.2673, 0.2564, 0.2762)
    normalize = transforms.Normalize(mean, std)

    # Training augmentation: pad by 4, random 32x32 crop, random horizontal flip.
    train_transform = transforms.Compose([
        transforms.Pad(4),
        transforms.RandomCrop(32),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    # Evaluation uses no augmentation, only tensor conversion + normalization.
    test_transform = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])

    train_loader = torch.utils.data.DataLoader(
        CIFAR(path, train=True, download=True, transform=train_transform),
        batch_size=batch_size, shuffle=True, **kwargs)

    # No download requested for the test split; it must already be under ``path``.
    test_loader = torch.utils.data.DataLoader(
        CIFAR(path, train=False, transform=test_transform),
        batch_size=test_batch_size, shuffle=False, **kwargs)
    return train_loader, test_loader
