import numpy as np
import time

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR, MultiStepLR, CosineAnnealingLR, ConstantLR

from torchvision import datasets, transforms
from torch.autograd import Variable


from typing import List
from ffcv.fields import IntField, RGBImageField
from ffcv.fields.decoders import IntDecoder, SimpleRGBImageDecoder
from ffcv.loader import Loader, OrderOption
from ffcv.pipeline.operation import Operation
from ffcv.transforms import RandomHorizontalFlip, Cutout, \
    RandomTranslate, Convert, ToDevice, ToTensor, ToTorchImage
from ffcv.transforms.common import Squeeze
from ffcv.writer import DatasetWriter


from models.simple_net import simple_net, simpler_net, simple_cnn, simple_net_stl, simple_net_mnist
from models.vgg import vgg11, vgg11_bn, vgg11_cifar100
from models.wideresnet import WideResNet
from models.wideresnet_wo_bn import WideResNet_wo_bn
from models.pyramid import PyramidNet
from models.res9 import res9
from models.res9_wo_bn import res9_wo_bn

from homura.vision.models.cifar_resnet import wrn16_8, wrn28_2, wrn28_10, resnet20, resnet56, resnext29_32x4d
from homura.vision.models import cifar_resnet

import tin

def vectorize(top_eigenvectors):
    """Flatten each eigenvector into one column and place the columns side by side.

    Every entry of ``top_eigenvectors`` is an iterable of tensors.  Its tensors are
    reshaped to column vectors and concatenated top to bottom; the resulting
    columns are then concatenated along dimension 1, so the returned matrix has
    one column per entry of ``top_eigenvectors``.
    """
    columns = [
        torch.cat([block.reshape(-1, 1) for block in eigenvector])
        for eigenvector in top_eigenvectors
    ]
    return torch.cat(columns, dim=1)

def alt_eig_from_logit_derivs(model, output, y, num_classes):
    """Collect one parameter-gradient direction per class present in ``y``.

    With ``p`` the detached softmax of ``output`` along dimension 1, the tensor
    ``sqrt(p) * (output - sum_c p_c * output_c)`` is formed once.  For every class
    ``i`` that occurs in ``y`` its entries are averaged over the rows whose label
    is ``i`` (column ``i`` excluded) and the resulting scalar is back-propagated
    through ``model``.

    Args:
        model: module whose parameters produced ``output``.  Its ``.grad`` fields
            are overwritten during the call and cleared before returning.
        output: logits indexed as (sample, class); must still be attached to the
            autograd graph of ``model``.
        y: label tensor with one entry per row of ``output``.
        num_classes: number of class indices ``0 .. num_classes - 1`` to scan.

    Returns:
        A tuple ``(L_G1, deltas, wdeltas)``:

        * ``wdeltas`` holds, for every class present, the flattened gradient
          scaled by ``sqrt((num_classes - 1) * n_i / len(y))`` where ``n_i`` is
          the number of labels equal to ``i``.
        * ``deltas`` holds the per-parameter gradients divided by their joint
          norm, only for classes whose gradient norm exceeds ``1e-6``.
        * ``L_G1`` holds ``(num_classes - 1) * n_i / len(y) * norm ** 2`` as
          Python floats, aligned with ``deltas`` (presumably eigenvalue
          estimates paired with the directions -- confirm against the caller).
    """
    deltas = []
    wdeltas = []
    L_G1 = []

    # These tensors do not depend on the class index, so build them once.  Every
    # backward call below passes retain_graph=True, which keeps this shared graph
    # alive across iterations.
    p = F.softmax(output, dim=1).detach()                      # (b, C)
    weighted_mean = (output * p).sum(dim=1).unsqueeze(-1)      # (b, 1)
    centered = torch.sqrt(p) * (output - weighted_mean)        # (b, C)

    for i in range(num_classes):
        in_class = y == i
        n_in_class = torch.sum(in_class)
        if n_in_class == 0:
            # Class absent from this batch: nothing to differentiate.
            continue

        # Select the rows labelled i, then drop column i itself.
        mask = torch.zeros_like(centered)
        mask[in_class] = 1
        mask[:, i] = 0
        loss = torch.sum(centered * mask) / torch.sum(mask)

        model.zero_grad()
        loss.backward(retain_graph=True)

        delta = [
            torch.clone(param.grad)
            for param in model.parameters()
            if param.grad is not None
        ]

        # Flatten once; the copy is taken before the in-place normalisation below.
        flat = vectorize([delta])
        class_weight = (num_classes - 1) * n_in_class / len(y)
        wdeltas.append(torch.sqrt(class_weight) * flat)

        V_norm = torch.norm(flat)
        if V_norm > 1e-6:
            for delta_j in delta:
                delta_j /= V_norm
            deltas.append(delta)
            L_G1.append((class_weight * V_norm ** 2).item())
        else:
            print('small norm', i, V_norm)

    model.zero_grad()
    return L_G1, deltas, wdeltas

# class data_properties:
#     def __init__(self, dataset='cifar10'):
#         self.dataset = dataset
#         self.num_classes = self._num_classes()
#     def _num_classes(self):
#         num_classes = {'cifar10':10, 'Simple100':10, 'cifar100':100}
#         return num_classes[self.dataset]
        
              

    
def smooth_crossentropy(pred, gold, smoothing=0.1):
    """Label-smoothed classification loss computed as a KL divergence.

    The target distribution puts ``1 - smoothing`` on the labelled class and
    spreads ``smoothing`` evenly over the remaining ``pred.size(1) - 1`` classes.
    The element-wise KL divergence between that target and ``log_softmax(pred)``
    is summed over the last dimension and averaged over the rest.

    Args:
        pred: logits with the class axis at dimension 1.
        gold: integer class indices, one per row of ``pred``.
        smoothing: probability mass moved away from the labelled class.
    """
    num_classes = pred.size(1)
    off_value = smoothing / (num_classes - 1)
    on_value = 1.0 - smoothing

    target = torch.full_like(pred, fill_value=off_value)
    target.scatter_(dim=1, index=gold.unsqueeze(1), value=on_value)

    per_element = F.kl_div(
        input=F.log_softmax(pred, dim=1),
        target=target,
        reduction='none',
    )
    per_sample = per_element.sum(-1)
    return per_sample.mean()

def _one_hot(tensor: torch.Tensor, num_classes: int, default=0):
    """Return a float one-hot encoding of ``tensor`` with a configurable off value.

    Args:
        tensor: integer class indices accepted by ``F.one_hot``.
        num_classes: size of the trailing one-hot dimension.
        default: value written to every position that is not the labelled class.

    Returns:
        A float tensor with 1 at the labelled class and ``default`` elsewhere.
    """
    # Cast before filling: F.one_hot yields an integer tensor, and writing a
    # fractional ``default`` into it would silently truncate the value.
    encoded = F.one_hot(tensor, num_classes).float()
    encoded[encoded == 0] = default
    return encoded

class SquaredLoss(nn.Module):
    """Half the summed squared error against one-hot targets, divided by ``len(input)``."""

    def forward(self, input: torch.Tensor, target: torch.Tensor):
        """Compute ``0.5 * sum((input - one_hot(target)) ** 2) / len(input)``.

        Args:
            input: predictions with the class axis at dimension 1.
            target: integer class indices, one per row of ``input``.
        """
        # Take the class count from the predictions instead of assuming 10, so
        # the loss also works for heads with a different number of outputs.
        num_classes = input.size(1)
        one_hot = F.one_hot(target, num_classes).float()
        return 0.5 * ((input - one_hot) ** 2).sum() / len(input)

class smooth_CrossEntropyLoss(nn.Module):
    """``nn.Module`` wrapper that applies ``smooth_crossentropy`` with a fixed smoothing."""

    def __init__(self, smoothing):
        super().__init__()
        # Smoothing amount forwarded to smooth_crossentropy on every call.
        self.smoothing = smoothing

    def forward(self, pred: torch.Tensor, gold: torch.Tensor):
        """Return the smoothed loss of logits ``pred`` against labels ``gold``."""
        loss = smooth_crossentropy(pred, gold, self.smoothing)
        return loss

# def show_running_stats(model):
#     def _show(module):
#         if isinstance(module, nn.BatchNorm2d):
#             print(module.backup_momentum)
#             print(module.momentum)
#             print(module.__dict__.keys())
#     print('show')
#     model.apply(_show)
    
def disable_running_stats(model):
    """Freeze the running statistics of every batch-norm layer in ``model``.

    Each layer's current ``momentum`` is saved as ``backup_momentum`` and then
    set to 0, so subsequent forward passes leave the running mean/variance
    unchanged.  All batch-norm variants sharing the ``_BatchNorm`` base class
    (1d/2d/3d, SyncBatchNorm) are covered, not only ``nn.BatchNorm2d``.
    """
    def _disable(module):
        if isinstance(module, nn.modules.batchnorm._BatchNorm):
            module.backup_momentum = module.momentum
            module.momentum = 0
    model.apply(_disable)


def enable_running_stats(model):
    """Restore the momentum saved by ``disable_running_stats``.

    Batch-norm layers that carry no ``backup_momentum`` attribute are left
    untouched, so calling this before any ``disable_running_stats`` is a no-op.
    """
    def _enable(module):
        if isinstance(module, nn.modules.batchnorm._BatchNorm) and hasattr(module, "backup_momentum"):
            module.momentum = module.backup_momentum
    model.apply(_enable)
    
def count_parameters(model):
    """Return the total element count of the parameters of ``model`` that require gradients."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

def get_model(model_name, dataset='cifar10',num_classes=10):
    """Build the network registered under ``model_name``, print it and return it.

    Args:
        model_name: key selecting the architecture (e.g. ``'Simple100'``,
            ``'3FCN'``, ``'vgg11'``, ``'res9'``, ``'wrn28_10'``, ``'WRN2810'``).
        dataset: dataset name.  ``'mnist'``/``'submnist'`` select single-channel
            inputs, every other name three-channel inputs; a few architectures
            also pick a dataset-specific variant.
        num_classes: number of output classes passed to the constructor.  The
            VGG constructors are called without it.

    Returns:
        The constructed model.

    Raises:
        ValueError: if ``model_name`` is unknown, or if ``'3CNN'`` is requested
            for a dataset other than ``'mnist'``/``'submnist'``.
    """
    is_mnist = dataset in ('mnist', 'submnist')
    is_cifar100 = dataset in ('cifar100', 'subcifar100')
    in_channel = 1 if is_mnist else 3

    if model_name == 'Simple100': ### 6CNN
        if dataset in ['stl','substl']:
            model = simple_net_stl(in_channel=in_channel, widen_factor=1, n_fc=100, num_classes=num_classes)
        elif is_mnist:
            model = simple_net_mnist(in_channel=in_channel, widen_factor=1, n_fc=100, num_classes=num_classes)
        else:
            model = simple_net(in_channel=in_channel, widen_factor=1, n_fc=100, num_classes=num_classes)
    elif model_name == 'Simple128':
        model = simple_cnn(in_channel=in_channel, widen_factor=1, n_fc=128, num_classes=num_classes)
    elif model_name == 'Simple512':
        model = simple_net(in_channel=in_channel, widen_factor=1, n_fc=512, num_classes=num_classes)
    elif model_name == '3FCN':
        # 28*28*1 inputs for MNIST, 32*32*3 for the CIFAR-sized datasets.
        in_features = 784 if is_mnist else 3072
        model = nn.Sequential(
                    nn.Flatten(),
                    nn.Linear(in_features, 200, bias=True),
                    nn.ReLU(),
                    nn.Linear(200, 200, bias=True),
                    nn.ReLU(),
                    nn.Linear(200, num_classes, bias=True),
                    )
    elif model_name == '3CNN':
        if not is_mnist:
            # No layer sizes are defined for other datasets; fail with a clear
            # message instead of leaving ``model`` unbound.
            raise ValueError("3CNN is only defined for the mnist/submnist datasets")
        model = nn.Sequential(
                    nn.Conv2d(1,6,3,1),
                    nn.ReLU(),
                    nn.MaxPool2d(2, 2),
                    nn.Conv2d(6,16,3,1),
                    nn.ReLU(),
                    nn.Flatten(),
                    nn.Linear(1936, 120),
                    nn.ReLU(),
                    nn.Linear(120, 84),
                    nn.ReLU(),
                    nn.Linear(84, num_classes),
                    )
    elif model_name == 'vgg11':
        model = vgg11_cifar100() if is_cifar100 else vgg11()
    elif model_name == 'vgg11_bn':
        if is_cifar100:
            # This name is not among the file's top-level imports, so it is
            # imported here.  NOTE(review): assumes models.vgg provides
            # vgg11_bn_cifar100 alongside vgg11_cifar100 -- confirm.
            from models.vgg import vgg11_bn_cifar100
            model = vgg11_bn_cifar100()
        else:
            model = vgg11_bn()
    elif model_name == 'res9': ### ffcv
        model = res9(num_classes=num_classes)
    elif model_name == 'res9_wo_bn': ### ekfac
        model = res9_wo_bn(num_classes=num_classes)
    elif model_name == 'resnet20': ### homura
        model = resnet20(num_classes=num_classes)
    elif model_name == 'resnet56': ### homura
        model = resnet56(num_classes=num_classes)
    elif model_name == 'resnext29_32x4d': ### homura
        model = resnext29_32x4d(num_classes=num_classes)
    elif model_name == 'wrn10_1': ### homura
        model = cifar_resnet.wide_resnet(num_classes,10,1)
    elif model_name == 'wrn16_4': ### homura
        model = cifar_resnet.wide_resnet(num_classes,16,4)
    elif model_name == 'wrn16_8': ### homura
        model = wrn16_8(num_classes=num_classes)
    elif model_name == 'wrn28_2': ### homura
        model = wrn28_2(num_classes=num_classes)
    elif model_name == 'wrn28_10': ### homura
        model = wrn28_10(num_classes=num_classes)
    elif model_name == 'Pyramid272':
        model = PyramidNet(dataset, depth=272, alpha=200, num_classes=num_classes)
    elif model_name == 'WRN101':
        model = WideResNet(10, num_classes, 1)
    elif model_name == 'WRN168':
        model = WideResNet(16, num_classes, 8)
    elif model_name == 'WRN282':
        model = WideResNet(28, num_classes, 2)
    elif model_name == 'WRN2810':
        model = WideResNet(28, num_classes, 10)
    elif model_name == 'WRN282_wo_bn':
        model = WideResNet_wo_bn(28, num_classes, 2)
    else:
        raise ValueError("Unknown model")

    print(model)
    print('# params:',count_parameters(model))
    return model
    
class my_Cutout:
    """Transform that zeroes a random rectangular patch of an image tensor in place.

    The patch is cut along dimensions 1 and 2 of the image and is at most
    ``size`` long in each; it may be clipped by the image border.  The transform
    is skipped whenever a uniform random draw exceeds ``p``.
    """

    def __init__(self, size=16, p=0.5):
        self.size, self.half_size, self.p = size, size // 2, p

    def _span(self, extent):
        """Draw a random start along an axis of length ``extent``; return the clipped (start, stop)."""
        start = torch.randint(-self.half_size, extent - self.half_size, [1]).item()
        return max(0, start), min(extent, start + self.size)

    def __call__(self, image):
        draw = torch.rand([1]).item()
        if draw > self.p:
            return image

        # Draw order matters for reproducibility: dimension 1 first, then dimension 2.
        lo1, hi1 = self._span(image.size(1))
        lo2, hi2 = self._span(image.size(2))

        image[:, lo1:hi1, lo2:hi2] = 0
        return image
    

def get_data_stats(dataset):
    """Return ``(mean, std)`` of the training split used to normalise ``dataset``.

    For SVHN (``'svhn'``/``'subsvhn'``) fixed per-channel constants are returned
    after the training split has been instantiated and printed.  For every other
    supported dataset the whole training split is loaded as tensors and the
    statistics are computed from it.

    Args:
        dataset: one of ``cifar10``, ``cifar100``, ``stl``, ``svhn``,
            ``timagenet``, ``mnist`` or the same name prefixed with ``sub``.

    Returns:
        ``(mean, std)``.  Tuples of floats for SVHN; tensors reduced over
        dimensions ``[0, 2, 3]`` for the RGB datasets; tensors reduced over
        dimension ``[0]`` only for MNIST.

    Raises:
        ValueError: if ``dataset`` is not one of the supported names.
    """
    supported = (
        'cifar10', 'subcifar10', 'cifar100', 'subcifar100',
        'stl', 'substl', 'svhn', 'subsvhn',
        'timagenet', 'subtimagenet', 'mnist', 'submnist',
    )
    if dataset not in supported:
        # Fail early with a clear message instead of hitting an unbound
        # ``train_set`` further down.
        raise ValueError("No normalization statistics available for dataset '{}'".format(dataset))

    to_tensor = transforms.ToTensor()

    if dataset in ('svhn', 'subsvhn'):
        print('******svhn')
        train_set = datasets.SVHN(root='../data', split='train', download=True, transform=to_tensor)
        print(train_set)
        # Hard-coded statistics; the split itself is not scanned.
        mean = (0.4376821, 0.4437697, 0.47280442)
        std = (0.19803012, 0.20101562, 0.19703614)
        return mean, std

    if dataset in ('cifar10', 'subcifar10'):
        train_set = datasets.CIFAR10(root='../data', train=True, download=True, transform=to_tensor)
    elif dataset in ('stl', 'substl'):
        train_set = datasets.STL10(root='../data', split='train', download=True, transform=to_tensor)
    elif dataset in ('timagenet', 'subtimagenet'):
        print('******tiny imagenet')
        train_set = tin.TinyImageNetDataset(root_dir='../data/tiny-imagenet-200', mode='train', download=True, transform=to_tensor)
    elif dataset in ('mnist', 'submnist'):
        train_set = datasets.MNIST(root='../data', train=True, download=True, transform=to_tensor)
    else:  # cifar100 / subcifar100
        train_set = datasets.CIFAR100(root='../data', train=True, download=True, transform=to_tensor)

    # Read the split in large batches rather than one sample at a time; the
    # concatenated tensor is the same, only the number of iterations shrinks.
    loader = torch.utils.data.DataLoader(train_set, batch_size=1024)
    data = torch.cat([batch[0] for batch in loader])

    if dataset in ('mnist', 'submnist'):
        # NOTE(review): reducing over dim 0 only yields per-pixel tensors rather
        # than per-channel values -- confirm this is intended.
        reduce_dims = [0]
    else:
        reduce_dims = [0, 2, 3]
    mean, std = data.mean(dim=reduce_dims), data.std(dim=reduce_dims)
    return mean, std

def get_imagenet_data(dataset='timagenet', train_bs=128, test_bs=1000, data_augmentation = True, normalization= False, shuffle = True, cutout=False, drop_last=True, n_data=8192):
    """Build the Tiny-ImageNet train and validation data loaders.

    Args:
        dataset: ``'timagenet'`` for the full training split or
            ``'subtimagenet'`` for its first ``n_data`` samples.
        train_bs: batch size of the training loader.
        test_bs: batch size of the validation loader.
        data_augmentation: if true, prepend a padded 64x64 random crop and a
            random horizontal flip to the training transform.
        normalization: if true, normalise both splits with the statistics from
            ``get_data_stats``.  The statistics are only computed in this case.
        shuffle: whether the training loader shuffles.
        cutout: if true, append ``my_Cutout()`` to the training transform.
        drop_last: whether the training loader drops the last partial batch.
        n_data: number of leading training samples kept for ``'subtimagenet'``.

    Returns:
        ``(train_loader, test_loader)``.

    Raises:
        ValueError: if ``dataset`` is neither ``'timagenet'`` nor ``'subtimagenet'``.
    """
    if dataset not in ('timagenet', 'subtimagenet'):
        raise ValueError("get_imagenet_data only supports 'timagenet' and 'subtimagenet', got '{}'".format(dataset))

    # Transform order: augmentation (train only) -> ToTensor -> Normalize -> cutout (train only).
    train_steps = []
    test_steps = []
    if data_augmentation:
        train_steps += [transforms.RandomCrop(64, padding=4), transforms.RandomHorizontalFlip()]

    train_steps.append(transforms.ToTensor())
    test_steps.append(transforms.ToTensor())

    if normalization:
        # Computing the statistics scans the whole training split, so do it
        # only when the result is actually used.
        mean, std = get_data_stats(dataset)
        normalize = transforms.Normalize(mean, std)
        train_steps.append(normalize)
        test_steps.append(normalize)

    if cutout:
        train_steps.append(my_Cutout())

    trainset = tin.TinyImageNetDataset(root_dir='../data/tiny-imagenet-200',
                                       mode='train',
                                       download=True,
                                       transform=transforms.Compose(train_steps))
    testset = tin.TinyImageNetDataset(root_dir='../data/tiny-imagenet-200',
                                      mode='val',
                                      download=True,
                                      transform=transforms.Compose(test_steps))

    if dataset == 'subtimagenet':
        # Keep only the first n_data training samples.
        trainset = torch.utils.data.Subset(trainset, range(0, n_data))

    train_loader = torch.utils.data.DataLoader(trainset,
                                               batch_size=train_bs,
                                               shuffle=shuffle,
                                               num_workers=4,
                                               pin_memory=True,
                                               drop_last=drop_last)
    test_loader = torch.utils.data.DataLoader(testset,
                                              batch_size=test_bs,
                                              shuffle=False,
                                              num_workers=4,
                                              pin_memory=True)

    return train_loader, test_loader
    
    
def get_data(dataset='cifar10', train_bs=128, test_bs=1000, data_augmentation = True, normalization= False, shuffle = True, cutout=False, drop_last=True, n_data=8192):
    """
    Get the dataloader
    """
    
    (mean, std) = get_data_stats(dataset)
    
    if dataset == 'cifar10':
        transform_train_list = []
        transform_test_list = []
        if data_augmentation:        
            augmentation_list = [transforms.RandomCrop(32, padding = 4),transforms.RandomHorizontalFlip()]
            transform_train_list += augmentation_list
            
        transform_train_list += [transforms.ToTensor()]
        transform_test_list += [transforms.ToTensor()]
        
        if normalization:
            normalizaqtion_list = [transforms.Normalize(mean,std)]
            transform_train_list += normalizaqtion_list
            transform_test_list += normalizaqtion_list
            
        if cutout:
            cutout_list = [my_Cutout()]
            transform_train_list += cutout_list
            
        transform_train = transforms.Compose(transform_train_list)
        transform_test = transforms.Compose(transform_test_list)
        
        trainset = datasets.CIFAR10(root='../data',
                                    train=True,
                                    download=True,
                                    transform=transform_train)
        testset = datasets.CIFAR10(root='../data',
                                   train=False,
                                   download=True,
                                   transform=transform_test)
        
        train_loader = torch.utils.data.DataLoader(trainset,
                                                   batch_size=train_bs,
                                                   shuffle=shuffle, 
                                                   num_workers=4,
                                                   pin_memory=True,
                                                  drop_last=drop_last ################
                                                  )
        test_loader = torch.utils.data.DataLoader(testset,
                                                  batch_size=test_bs,
                                                  shuffle=False, 
                                                  num_workers=4,
                                                  pin_memory=True
                                                 )
        
    elif dataset == 'subcifar10':
        transform_train_list = []
        transform_test_list = []
        if data_augmentation:        
            augmentation_list = [transforms.RandomCrop(32, padding = 4),transforms.RandomHorizontalFlip()]
            transform_train_list += augmentation_list
            
        transform_train_list += [transforms.ToTensor()]
        transform_test_list += [transforms.ToTensor()]
        
        if normalization:
            normalizaqtion_list = [transforms.Normalize(mean,std)]
            transform_train_list += normalizaqtion_list
            transform_test_list += normalizaqtion_list
            
        if cutout:
            cutout_list = [my_Cutout()]
            transform_train_list += cutout_list
            
        transform_train = transforms.Compose(transform_train_list)
        transform_test = transforms.Compose(transform_test_list)
        
        trainset_ = datasets.CIFAR10(root='../data',
                                    train=True,
                                    download=True,
                                    transform=transform_train)
        testset = datasets.CIFAR10(root='../data',
                                   train=False,
                                   download=True,
                                   transform=transform_test)
        
        subset = range(0, n_data)
        trainset = torch.utils.data.Subset(trainset_, subset)

        
        train_loader = torch.utils.data.DataLoader(trainset,
                                                   batch_size=train_bs,
                                                   shuffle=shuffle, 
                                                   num_workers=4,
                                                   pin_memory=True,
                                                  drop_last=drop_last ################
                                                  )
        test_loader = torch.utils.data.DataLoader(testset,
                                                  batch_size=test_bs,
                                                  shuffle=False, 
                                                  num_workers=4,
                                                  pin_memory=True
                                                 )
        
    
    elif dataset == 'stl':
        transform_train_list = []
        transform_test_list = []
        if data_augmentation:        
            augmentation_list = [transforms.RandomCrop(96, padding = 4),transforms.RandomHorizontalFlip()]
            transform_train_list += augmentation_list
            
        transform_train_list += [transforms.ToTensor()]
        transform_test_list += [transforms.ToTensor()]
        
        if normalization:
            normalizaqtion_list = [transforms.Normalize(mean,std)]
            transform_train_list += normalizaqtion_list
            transform_test_list += normalizaqtion_list
            
        if cutout:
            cutout_list = [my_Cutout()]
            transform_train_list += cutout_list
            
        transform_train = transforms.Compose(transform_train_list)
        transform_test = transforms.Compose(transform_test_list)
        
        trainset = datasets.STL10(root='../data',
                                    split='train',
                                    download=True,
                                    transform=transform_train)
        testset = datasets.STL10(root='../data',
                                   split='test',
                                   download=True,
                                   transform=transform_test)
        
        train_loader = torch.utils.data.DataLoader(trainset,
                                                   batch_size=train_bs,
                                                   shuffle=shuffle, 
                                                   num_workers=4,
                                                   pin_memory=True,
                                                  drop_last=drop_last ################
                                                  )
        test_loader = torch.utils.data.DataLoader(testset,
                                                  batch_size=test_bs,
                                                  shuffle=False, 
                                                  num_workers=4,
                                                  pin_memory=True
                                                 )
        
    elif dataset == 'substl':
        transform_train_list = []
        transform_test_list = []
        if data_augmentation:        
            augmentation_list = [transforms.RandomCrop(96, padding = 4),transforms.RandomHorizontalFlip()]
            transform_train_list += augmentation_list
            
        transform_train_list += [transforms.ToTensor()]
        transform_test_list += [transforms.ToTensor()]
        
        if normalization:
            normalizaqtion_list = [transforms.Normalize(mean,std)]
            transform_train_list += normalizaqtion_list
            transform_test_list += normalizaqtion_list
            
        if cutout:
            cutout_list = [my_Cutout()]
            transform_train_list += cutout_list
            
        transform_train = transforms.Compose(transform_train_list)
        transform_test = transforms.Compose(transform_test_list)
        
        trainset_ = datasets.STL10(root='../data',
                                    split='train',
                                    download=True,
                                    transform=transform_train)
        testset = datasets.STL10(root='../data',
                                   split='test',
                                   download=True,
                                   transform=transform_test)
        
        subset = range(0, n_data)
        trainset = torch.utils.data.Subset(trainset_, subset)

        
        train_loader = torch.utils.data.DataLoader(trainset,
                                                   batch_size=train_bs,
                                                   shuffle=shuffle, 
                                                   pin_memory=True,
                                                  drop_last=drop_last ################
                                                  )
        test_loader = torch.utils.data.DataLoader(testset,
                                                  batch_size=test_bs,
                                                  shuffle=False, 
                                                  pin_memory=True
                                                 )
        
    elif dataset == 'svhn':
        transform_train_list = []
        transform_test_list = []
        if data_augmentation:        
            augmentation_list = [transforms.RandomCrop(32, padding = 4)]
            transform_train_list += augmentation_list
            
        transform_train_list += [transforms.ToTensor()]
        transform_test_list += [transforms.ToTensor()]
        
        if normalization:
            normalizaqtion_list = [transforms.Normalize(mean,std)]
            transform_train_list += normalizaqtion_list
            transform_test_list += normalizaqtion_list
            
        if cutout:
            cutout_list = [my_Cutout()]
            transform_train_list += cutout_list
            
        transform_train = transforms.Compose(transform_train_list)
        transform_test = transforms.Compose(transform_test_list)
        
        trainset = datasets.SVHN(root='../data',
                                    split='train',
                                    download=True,
                                    transform=transform_train)
        testset = datasets.SVHN(root='../data',
                                   split='test',
                                   download=True,
                                   transform=transform_test)
        
        train_loader = torch.utils.data.DataLoader(trainset,
                                                   batch_size=train_bs,
                                                   shuffle=shuffle, 
                                                   pin_memory=True,
                                                  drop_last=drop_last ################
                                                  )
        test_loader = torch.utils.data.DataLoader(testset,
                                                  batch_size=test_bs,
                                                  shuffle=False, 
                                                  pin_memory=True
                                                 )
        
    elif dataset == 'subsvhn':
        transform_train_list = []
        transform_test_list = []
        if data_augmentation:        
            augmentation_list = [transforms.RandomCrop(32, padding = 4)]
            transform_train_list += augmentation_list
            
        transform_train_list += [transforms.ToTensor()]
        transform_test_list += [transforms.ToTensor()]
        
        if normalization:
            normalizaqtion_list = [transforms.Normalize(mean,std)]
            transform_train_list += normalizaqtion_list
            transform_test_list += normalizaqtion_list
            
        if cutout:
            cutout_list = [my_Cutout()]
            transform_train_list += cutout_list
            
        transform_train = transforms.Compose(transform_train_list)
        transform_test = transforms.Compose(transform_test_list)
        
        trainset_ = datasets.SVHN(root='../data',
                                    split='train',
                                    download=True,
                                    transform=transform_train)
        testset = datasets.SVHN(root='../data',
                                   split='test',
                                   download=True,
                                   transform=transform_test)
        
        subset = range(0, n_data)
        trainset = torch.utils.data.Subset(trainset_, subset)

        
        train_loader = torch.utils.data.DataLoader(trainset,
                                                   batch_size=train_bs,
                                                   shuffle=shuffle, 
                                                   num_workers=4,
                                                   pin_memory=True,
                                                  drop_last=drop_last ################
                                                  )
        test_loader = torch.utils.data.DataLoader(testset,
                                                  batch_size=test_bs,
                                                  shuffle=False, 
                                                  num_workers=4,
                                                  pin_memory=True
                                                 )
        
        
    elif dataset == 'imagenet':
        transform_train_list = []
        transform_test_list = []
        if data_augmentation:        
            augmentation_list = [transforms.RandomCrop(32, padding = 4)]
            transform_train_list += augmentation_list
            
        transform_train_list += [transforms.ToTensor()]
        transform_test_list += [transforms.ToTensor()]
        
        if normalization:
            normalizaqtion_list = [transforms.Normalize(mean,std)]
            transform_train_list += normalizaqtion_list
            transform_test_list += normalizaqtion_list
            
        if cutout:
            cutout_list = [my_Cutout()]
            transform_train_list += cutout_list
            
        transform_train = transforms.Compose(transform_train_list)
        transform_test = transforms.Compose(transform_test_list)
        
        trainset = datasets.ImageNet(root='../data',
                                    split='train',
                                    download=True,
                                    transform=transform_train)
        testset = datasets.ImageNet(root='../data',
                                   split='test',
                                   download=True,
                                   transform=transform_test)
        
        train_loader = torch.utils.data.DataLoader(trainset,
                                                   batch_size=train_bs,
                                                   shuffle=shuffle, 
                                                   pin_memory=True,
                                                  drop_last=drop_last ################
                                                  )
        test_loader = torch.utils.data.DataLoader(testset,
                                                  batch_size=test_bs,
                                                  shuffle=False, 
                                                  pin_memory=True
                                                 )
        
    elif dataset == 'subimagenet':
        transform_train_list = []
        transform_test_list = []
        if data_augmentation:        
            augmentation_list = [transforms.RandomCrop(32, padding = 4)]
            transform_train_list += augmentation_list
            
        transform_train_list += [transforms.ToTensor()]
        transform_test_list += [transforms.ToTensor()]
        
        if normalization:
            normalizaqtion_list = [transforms.Normalize(mean,std)]
            transform_train_list += normalizaqtion_list
            transform_test_list += normalizaqtion_list
            
        if cutout:
            cutout_list = [my_Cutout()]
            transform_train_list += cutout_list
            
        transform_train = transforms.Compose(transform_train_list)
        transform_test = transforms.Compose(transform_test_list)
        
        trainset_ = datasets.ImageNet(root='../data',
                                    split='train',
                                    download=True,
                                    transform=transform_train)
        testset = datasets.ImageNet(root='../data',
                                   split='test',
                                   download=True,
                                   transform=transform_test)
        
        subset = range(0, n_data)
        trainset = torch.utils.data.Subset(trainset_, subset)

        
        train_loader = torch.utils.data.DataLoader(trainset,
                                                   batch_size=train_bs,
                                                   shuffle=shuffle, 
                                                   num_workers=4,
                                                   pin_memory=True,
                                                  drop_last=drop_last ################
                                                  )
        test_loader = torch.utils.data.DataLoader(testset,
                                                  batch_size=test_bs,
                                                  shuffle=False, 
                                                  num_workers=4,
                                                  pin_memory=True
                                                 )
        
    elif dataset == 'mnist':
        transform_train_list = []
        transform_test_list = []
        transform_train_list += [transforms.ToTensor()]
        transform_test_list += [transforms.ToTensor()]
        
        if normalization:
            normalizaqtion_list = [transforms.Normalize(mean,std)]
            transform_train_list += normalizaqtion_list
            transform_test_list += normalizaqtion_list
        
        transform_train = transforms.Compose(transform_train_list)
        transform_test = transforms.Compose(transform_test_list)
        
        trainset = datasets.MNIST(root='../data',
                                     train=True,
                                     download=True,
                                     transform=transform_train)
        testset = datasets.MNIST(root='../data',
                                    train=False,
                                    download=True,
                                    transform=transform_test)
        train_loader = torch.utils.data.DataLoader(trainset,
                                                   batch_size=train_bs,
                                                   shuffle=shuffle,
                                                   num_workers=4,
                                                   pin_memory=True
                                                  )
        test_loader = torch.utils.data.DataLoader(testset,
                                                  batch_size=test_bs,
                                                  shuffle=False,
                                                  num_workers=4,
                                                  pin_memory=True
                                                 )
        
    elif dataset == 'submnist':
        transform_train_list = []
        transform_test_list = []
        transform_train_list += [transforms.ToTensor()]
        transform_test_list += [transforms.ToTensor()]
        
        if normalization:
            normalizaqtion_list = [transforms.Normalize(mean,std)]
            transform_train_list += normalizaqtion_list
            transform_test_list += normalizaqtion_list
        
        transform_train = transforms.Compose(transform_train_list)
        transform_test = transforms.Compose(transform_test_list)
        
        
        trainset_ = datasets.MNIST(root='../data',
                                     train=True,
                                     download=True,
                                     transform=transforms.ToTensor())
        testset = datasets.MNIST(root='../data',
                                    train=False,
                                    download=True,
                                    transform=transforms.ToTensor())
        
        if normalization:
            normalizaqtion_list = [transforms.Normalize(mean,std)]
            transform_train_list += normalizaqtion_list
            transform_test_list += normalizaqtion_list
            
        subset = range(0, n_data)
        trainset = torch.utils.data.Subset(trainset_, subset)
        
        train_loader = torch.utils.data.DataLoader(trainset,
                                                   batch_size=train_bs,
                                                   shuffle=shuffle,
                                                   num_workers=4,
                                                   pin_memory=True
                                                  )
        test_loader = torch.utils.data.DataLoader(testset,
                                                  batch_size=test_bs,
                                                  shuffle=False,
                                                  num_workers=4,
                                                  pin_memory=True
                                                 )
    
    elif dataset == 'cifar100': ####### TODO: cutout 
        transform_train_list = []
        transform_test_list = []
        if data_augmentation:        
            augmentation_list = [transforms.RandomCrop(32, padding=4),
                                 transforms.RandomHorizontalFlip(),
#                                  transforms.RandomRotation(15)
                                ]
            transform_train_list += augmentation_list
            
        transform_train_list += [transforms.ToTensor()]
        transform_test_list += [transforms.ToTensor()]
        
        if normalization:
            normalizaqtion_list = [transforms.Normalize(mean,std)]
            transform_train_list += normalizaqtion_list
            transform_test_list += normalizaqtion_list
        transform_train = transforms.Compose(transform_train_list)
        transform_test = transforms.Compose(transform_test_list)
        
        trainset = datasets.CIFAR100(root='../data',
                                    train=True,
                                    download=True,
                                    transform=transform_train
                                    )
        testset = datasets.CIFAR100(root='../data',
                                   train=False,
                                   download=True,
                                   transform=transform_test
                                   )
        
        train_loader = torch.utils.data.DataLoader(trainset,
                                                   batch_size=train_bs,
                                                   shuffle=shuffle,  
                                                   num_workers=4,
                                                   pin_memory=True
                                                  )

        test_loader = torch.utils.data.DataLoader(testset,
                                                  batch_size=test_bs,
                                                  shuffle=False, 
                                                  num_workers=4,
                                                  pin_memory=True
                                                 )
    
    elif dataset == 'subcifar100': ####### TODO: cutout 
        transform_train_list = []
        transform_test_list = []
        if data_augmentation:        
            augmentation_list = [transforms.RandomCrop(32, padding=4),
                                 transforms.RandomHorizontalFlip(),
#                                  transforms.RandomRotation(15)
                                ]
            transform_train_list += augmentation_list
            
        transform_train_list += [transforms.ToTensor()]
        transform_test_list += [transforms.ToTensor()]
        
        if normalization:
            normalizaqtion_list = [transforms.Normalize(mean,std)]
            transform_train_list += normalizaqtion_list
            transform_test_list += normalizaqtion_list
        transform_train = transforms.Compose(transform_train_list)
        transform_test = transforms.Compose(transform_test_list)
        
        trainset_ = datasets.CIFAR100(root='../data',
                                    train=True,
                                    download=True,
                                    transform=transform_train
                                    )
        testset = datasets.CIFAR100(root='../data',
                                   train=False,
                                   download=True,
                                   transform=transform_test
                                   )
        
        subset = range(0, n_data)
        trainset = torch.utils.data.Subset(trainset_, subset)
        
        train_loader = torch.utils.data.DataLoader(trainset,
                                                   batch_size=train_bs,
                                                   shuffle=shuffle,  
                                                   num_workers=4,
                                                   pin_memory=True
                                                  )

        test_loader = torch.utils.data.DataLoader(testset,
                                                  batch_size=test_bs,
                                                  shuffle=False, 
                                                  num_workers=4,
                                                  pin_memory=True
                                                 )

    else:
        raise ValueError("Unknown dataset")
        
    return train_loader, test_loader


def get_data_ffcv(dataset='cifar10', train_bs=128, test_bs=1000, data_augmentation = True, normalization= False, shuffle = True, cutout=False):
    """
    Get the train/test dataloaders with ffcv
    https://github.com/libffcv/ffcv
    https://github.com/libffcv/ffcv/blob/main/examples/cifar/train_cifar.py

    The .beton files are read from the fixed paths '../data/cifar_train.beton'
    and '../data/cifar_test.beton'; `dataset` is only used to look up the
    (mean, std) statistics via get_data_stats.

    Args:
        dataset: name passed to get_data_stats to obtain (mean, std).
        train_bs: batch size of the train loader.
        test_bs: batch size of the test loader.
        data_augmentation: if True, the train pipeline gets a random horizontal
            flip, a random translation and an ffcv Cutout.
        normalization: accepted for interface compatibility. The pipeline
            always normalizes with (mean, std) exactly once, whatever the
            value of this flag.
        shuffle: if True, the train loader uses OrderOption.RANDOM; otherwise
            (and always for the test loader) OrderOption.SEQUENTIAL.
        cutout: if True, my_Cutout() is appended to both image pipelines.

    Returns:
        (train_loader, test_loader) as ffcv Loader objects. The train loader
        drops the last incomplete batch; the test loader keeps it.
    """
    paths = {
        'train': '../data/cifar_train.beton',
        'test': '../data/cifar_test.beton',
    }

    (mean, std) = get_data_stats(dataset)

    loaders = {}
    batch_sizes = {'train': train_bs, 'test': test_bs}
    num_workers = 4

    for name in ['train', 'test']:
        label_pipeline: List[Operation] = [IntDecoder(), ToTensor(), ToDevice('cuda:0'), Squeeze()]
        image_pipeline: List[Operation] = [SimpleRGBImageDecoder()]
        if name == 'train' and data_augmentation:
            # NOTE(review): the fill colour is the mean truncated to ints;
            # presumably this expects 0-255 scaled statistics -- confirm
            # against get_data_stats.
            fill = tuple(map(int, mean))
            image_pipeline.extend([
                RandomHorizontalFlip(),
                RandomTranslate(padding=2, fill=fill),
                Cutout(4, fill),
            ])
        image_pipeline.extend([
            ToTensor(),
            ToDevice('cuda:0', non_blocking=True),
            ToTorchImage(),
            Convert(torch.float16),
            # Normalize exactly once. A second Normalize used to be appended
            # when `normalization` was True, which normalized the already
            # normalized images again.
            transforms.Normalize(mean, std),
        ])
        if cutout:
            image_pipeline.extend([
                my_Cutout(),
            ])

        ordering = OrderOption.RANDOM if name == 'train' and shuffle else OrderOption.SEQUENTIAL

        loaders[name] = Loader(paths[name], batch_size=batch_sizes[name], num_workers=num_workers,
                               order=ordering, drop_last=(name == 'train'),
                               pipelines={'image': image_pipeline, 'label': label_pipeline})

    return loaders['train'], loaders['test']

                
        
def get_criterion(criterion, smoothing=0):
    """
    Build the loss module selected by name.

    Args:
        criterion: one of 'cross-entropy', 'mse' or 'label_smoothing'.
        smoothing: forwarded to smooth_CrossEntropyLoss; only used when
            criterion is 'label_smoothing'.

    Returns:
        The instantiated loss object.

    Raises:
        ValueError: if `criterion` is not one of the supported names.
    """
    if criterion == 'cross-entropy':
        return nn.CrossEntropyLoss()
    if criterion == 'mse':
        return SquaredLoss()
    if criterion == 'label_smoothing':
        return smooth_CrossEntropyLoss(smoothing=smoothing)
    # Fail with an explicit message instead of hitting an unbound local.
    raise ValueError(f"Unknown criterion: {criterion!r}")

 
def get_lr_scheduler(lr_scheduler, optimizer, milestones, gamma, epochs):
    """
    Build the learning-rate scheduler selected by name.

    Args:
        lr_scheduler: one of 'multistep', 'cosine' or 'constant'.
        optimizer: optimizer the scheduler is attached to.
        milestones: passed to MultiStepLR; only used for 'multistep'.
        gamma: passed to MultiStepLR as `gamma`; only used for 'multistep'.
        epochs: passed to CosineAnnealingLR as its second positional
            argument; only used for 'cosine'.

    Returns:
        The instantiated scheduler. 'constant' yields ConstantLR with
        factor=1.

    Raises:
        ValueError: if `lr_scheduler` is not one of the supported names.
    """
    if lr_scheduler == 'multistep':
        return MultiStepLR(optimizer, milestones, gamma=gamma)
    if lr_scheduler == 'cosine':
        return CosineAnnealingLR(optimizer, epochs)
    if lr_scheduler == 'constant':
        return ConstantLR(optimizer, factor=1)
    # Fail with an explicit message instead of hitting an unbound local.
    raise ValueError(f"Unknown lr_scheduler: {lr_scheduler!r}")

def test(model, test_loader, cuda=True, print_opt=True, cr='ce'):
    """
    Get the test performance

    Puts the model in eval mode and runs it over every batch of
    `test_loader` with gradient tracking disabled.

    Args:
        model: callable module mapping a batch of inputs to per-class scores.
        test_loader: iterable yielding (data, target) batches.
        cuda: if True, each batch is moved to the GPU with `.cuda()`.
        print_opt: if True, print the accuracy.
        cr: 'ce' for nn.CrossEntropyLoss or 'mse' for SquaredLoss.

    Returns:
        (accuracy, mean_loss), both averaged over the number of samples seen.

    Raises:
        ValueError: if `cr` is not 'ce' or 'mse'.
        ZeroDivisionError: if `test_loader` yields no samples.
    """
    if cr == 'ce':
        criterion = nn.CrossEntropyLoss()
    elif cr == 'mse':
        criterion = SquaredLoss()
    else:
        raise ValueError(f"Unknown criterion: {cr!r}")

    model.eval()
    correct = 0
    total_num = 0
    test_loss = 0
    # Evaluation never backpropagates, so skip building the autograd graph.
    with torch.no_grad():
        for data, target in test_loader:
            if cuda:
                data, target = data.cuda(), target.cuda()
            output = model(data)

            # The criterion returns a per-batch mean; weight it by the batch
            # size so the final division gives a per-sample mean.
            loss = criterion(output, target)
            test_loss += loss.item() * target.size(0)

            # index of the highest score along dim 1
            pred = output.max(1, keepdim=True)[1]
            correct += pred.eq(target.view_as(pred)).cpu().sum().item()
            total_num += len(data)
    if print_opt:
        print('testing_correct: ', correct / total_num, '\n')
    return correct / total_num, test_loss / total_num


# The name matches pytest's default discovery prefix; opt out so importing
# this helper into a test module does not get it collected as a test case.
test.__test__ = False

