import sys
import numpy as np
import scipy
import torch
import os
from torch.utils.data.distributed import DistributedSampler

import torchvision.transforms as transforms
import torchvision.datasets as dsets
import torchvision.models as models
import torch.nn as nn
from torch.autograd import Variable
import torch.nn.functional as F
import warnings
import argparse

# Fix the initial seed of every RNG used here -- NumPy's global generator,
# PyTorch's CPU generator, and every CUDA device's generator -- at import time.
# (torch.cuda.manual_seed_all is safe to call without a GPU; it is deferred.)
for _seed_fn in (np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    _seed_fn(0)
del _seed_fn



def prepare_CIFAR100_data(batch_size, *, test_batch_size=512,
                          root='./data/CIFAR100', num_workers=1):
    """Build CIFAR-100 train/test loaders for distributed training.

    The training split is sharded across processes by a ``DistributedSampler``
    and augmented with 4-pixel reflect padding, a random 32x32 crop and a
    random horizontal flip. The test split is neither augmented nor sharded,
    so every process iterates over the full test set in a fixed order.
    Both splits are downloaded into ``root`` if not already present.

    Args:
        batch_size: Per-process batch size of the training loader.
        test_batch_size: Batch size of the test loader (default 512).
        root: Directory where the CIFAR-100 files are stored / downloaded.
        num_workers: Number of worker processes for the training loader.

    Returns:
        ``(train_loader, test_loader, train_sampler)``. The sampler is returned
        so the caller can call ``train_sampler.set_epoch(epoch)`` each epoch;
        without that, ``DistributedSampler`` reuses the same shuffle order.

    Note:
        The sampler is created without explicit ``num_replicas``/``rank``, so
        a default ``torch.distributed`` process group must already be
        initialized or construction raises.
    """
    # NOTE(review): these per-channel statistics look like the values commonly
    # quoted for CIFAR-10; CIFAR-100's own statistics differ slightly. Left
    # unchanged to keep results comparable -- confirm this is intentional.
    normalize = transforms.Normalize(
        mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
        std=[x / 255.0 for x in [63.0, 62.1, 66.7]],
    )
    transform_test = transforms.Compose([transforms.ToTensor(), normalize])
    transform_train = transforms.Compose([
        # Reflect-pad the PIL image directly (same reflect semantics as
        # F.pad(mode='reflect')). This avoids a tensor->PIL round trip that
        # re-quantizes pixels, and avoids a lambda, which cannot be pickled
        # into DataLoader workers started with the 'spawn' method.
        transforms.Pad(4, padding_mode='reflect'),
        transforms.RandomCrop(32),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])

    train_dataset = dsets.CIFAR100(root=root, train=True,
                                   transform=transform_train, download=True)
    train_sampler = DistributedSampler(train_dataset, shuffle=True)
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        sampler=train_sampler,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=True,
    )

    test_dataset = dsets.CIFAR100(root=root, train=False,
                                  transform=transform_test, download=True)
    test_loader = torch.utils.data.DataLoader(
        dataset=test_dataset, batch_size=test_batch_size, shuffle=False)

    return train_loader, test_loader, train_sampler
