import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR, MultiStepLR, CosineAnnealingLR
from torchvision import datasets, transforms
from torch.autograd import Variable
from torch.utils.data import Dataset
from models.vgg import vgg11_bn, vgg11_bn_cifar100
from models.wideresnet import WideResNet
from models.resnet import resnet20, resnet56
from models.lenet import LeNet5


# class data_properties:
#     def __init__(self, dataset='cifar10'):
#         self.dataset = dataset
#         self.num_classes = self._num_classes()
#     def _num_classes(self):
#         num_classes = {'cifar10':10, 'mnist':10, 'cifar100':100}
#         return num_classes[self.dataset]
        
import math
from torch.optim.lr_scheduler import _LRScheduler
# this is from https://gaussian37.github.io/dl-pytorch-lr_scheduler/
class CosineAnnealingWarmUpRestarts(_LRScheduler):
    """Cosine annealing with a linear warm-up and warm restarts.

    Within each cycle of length ``T_i`` the learning rate rises linearly from the
    optimizer's initial lr (``base_lrs``) to ``eta_max`` over the first ``T_up``
    steps, then follows a half cosine back down towards ``base_lrs``. When a cycle
    ends, the decay part of the cycle is stretched by ``T_mult``
    (``T_i <- (T_i - T_up) * T_mult + T_up``) and the peak becomes
    ``base_eta_max * gamma ** cycle``.

    Adapted from https://gaussian37.github.io/dl-pytorch-lr_scheduler/

    Args:
        optimizer: wrapped optimizer; its initial lr is the floor of the schedule.
        T_0: length (in steps) of the first cycle; positive int.
        T_mult: cycle-length multiplier; int >= 1.
        eta_max: peak learning rate of the first cycle.
        T_up: number of warm-up steps at the start of each cycle; int >= 0.
        gamma: multiplicative decay applied to the peak per completed cycle.
        last_epoch: index of the last epoch (-1 to start fresh).

    Raises:
        ValueError: if ``T_0``, ``T_mult`` or ``T_up`` is not a valid integer.
    """

    def __init__(self, optimizer, T_0, T_mult=1, eta_max=0.1, T_up=0, gamma=1., last_epoch=-1):
        # Check the type before comparing so that e.g. ``None`` raises the
        # ValueError below rather than a TypeError from ``<=``.
        if not isinstance(T_0, int) or T_0 <= 0:
            raise ValueError("Expected positive integer T_0, but got {}".format(T_0))
        if not isinstance(T_mult, int) or T_mult < 1:
            raise ValueError("Expected integer T_mult >= 1, but got {}".format(T_mult))
        if not isinstance(T_up, int) or T_up < 0:
            raise ValueError("Expected non-negative integer T_up, but got {}".format(T_up))
        self.T_0 = T_0
        self.T_mult = T_mult
        self.base_eta_max = eta_max  # peak lr of the first cycle
        self.eta_max = eta_max       # peak lr of the current cycle (decayed by gamma)
        self.T_up = T_up
        self.T_i = T_0               # length of the current cycle
        self.gamma = gamma
        self.cycle = 0               # number of completed cycles
        self.T_cur = last_epoch      # position within the current cycle
        # The base constructor records ``base_lrs`` from the optimizer and, in
        # PyTorch, also calls ``step()`` once.
        super(CosineAnnealingWarmUpRestarts, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        """Return the learning rate of every param group at position ``T_cur``."""
        if self.T_cur == -1:
            # Before the first step: report the optimizer's initial learning rates.
            return self.base_lrs
        elif self.T_cur < self.T_up:
            # Linear warm-up from base_lr towards eta_max.
            return [(self.eta_max - base_lr)*self.T_cur / self.T_up + base_lr for base_lr in self.base_lrs]
        else:
            # Half-cosine decay from eta_max back towards base_lr over the rest of the cycle.
            return [base_lr + (self.eta_max - base_lr) * (1 + math.cos(math.pi * (self.T_cur-self.T_up) / (self.T_i - self.T_up))) / 2
                    for base_lr in self.base_lrs]

    def step(self, epoch=None):
        """Advance the schedule and write the new learning rates into the optimizer.

        With ``epoch=None`` the position advances by one step and a new cycle is
        started when the current one is exhausted. With an explicit ``epoch`` the
        cycle index and position are recomputed from ``epoch`` in closed form.
        """
        if epoch is None:
            epoch = self.last_epoch + 1
            self.T_cur = self.T_cur + 1
            if self.T_cur >= self.T_i:
                # Cycle finished: restart and stretch the decay part by T_mult.
                self.cycle += 1
                self.T_cur = self.T_cur - self.T_i
                self.T_i = (self.T_i - self.T_up) * self.T_mult + self.T_up
        else:
            if epoch >= self.T_0:
                if self.T_mult == 1:
                    self.T_cur = epoch % self.T_0
                    self.cycle = epoch // self.T_0
                else:
                    # Number n of completed cycles: the largest n with
                    # T_0 * (T_mult**n - 1) / (T_mult - 1) <= epoch.
                    # NOTE(review): this closed form ignores T_up, so for T_up > 0 it
                    # disagrees with the incremental update in the epoch=None branch.
                    n = int(math.log((epoch / self.T_0 * (self.T_mult - 1) + 1), self.T_mult))
                    self.cycle = n
                    self.T_cur = epoch - self.T_0 * (self.T_mult ** n - 1) / (self.T_mult - 1)
                    self.T_i = self.T_0 * self.T_mult ** (n)
            else:
                self.T_i = self.T_0
                self.T_cur = epoch

        self.eta_max = self.base_eta_max * (self.gamma**self.cycle)
        self.last_epoch = math.floor(epoch)
        for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
            param_group['lr'] = lr
        # The base-class step() normally records this; this override must do it
        # too, otherwise get_last_lr() raises AttributeError.
        self._last_lr = [group['lr'] for group in self.optimizer.param_groups]
def smooth_crossentropy(pred, gold, smoothing=0.1, reduction='mean'):
    """Cross-entropy against label-smoothed targets.

    The target distribution puts ``1 - smoothing`` on the gold class and spreads
    ``smoothing`` uniformly over the remaining ``n_class - 1`` classes. The loss is
    the KL divergence from that distribution to ``softmax(pred)``, summed over
    classes.

    Args:
        pred: logits with classes on dim 1 (assumes shape (batch, n_class)).
        gold: integer class indices, one per row of ``pred``.
        smoothing: probability mass moved off the gold class.
        reduction: 'mean' (default), 'sum', or 'none' for per-sample losses.

    Returns:
        A scalar tensor, or a per-sample tensor when ``reduction='none'``.

    Raises:
        ValueError: for an unsupported ``reduction``.
    """
    # Previously any unknown value silently fell back to 'mean' (including 'sum').
    if reduction not in ('mean', 'sum', 'none'):
        raise ValueError("Unsupported reduction: {}".format(reduction))
    n_class = pred.size(1)
    one_hot = torch.full_like(pred, fill_value=smoothing / (n_class - 1))
    one_hot.scatter_(dim=1, index=gold.unsqueeze(1), value=1.0 - smoothing)
    log_prob = F.log_softmax(pred, dim=1)
    per_sample = F.kl_div(input=log_prob, target=one_hot, reduction='none').sum(-1)
    if reduction == 'none':
        return per_sample
    if reduction == 'sum':
        return per_sample.sum()
    return per_sample.mean()

def _one_hot(tensor: torch.Tensor, num_classes: int, default=0):
    M = F.one_hot(tensor, num_classes)
    M[M == 0] = default
    return M.float()    

class SquaredLoss(nn.Module):
    """Half squared error against one-hot targets, averaged over the batch.

    ``loss = 0.5 * sum((input - one_hot(target)) ** 2) / len(input)``, where the
    number of classes is taken from ``input.size(1)``.
    """

    def forward(self, input: torch.Tensor, target: torch.Tensor):
        # Derive the class count from the output width instead of hard-coding 10,
        # so the loss also works for models with e.g. 100 outputs.
        one_hot = F.one_hot(target, num_classes=input.size(1)).float()
        return 0.5 * ((input - one_hot) ** 2).sum() / len(input)

class smooth_CrossEntropyLoss(nn.Module):
    """``nn.Module`` wrapper around ``smooth_crossentropy``.

    Args:
        smoothing: label-smoothing mass forwarded to ``smooth_crossentropy``.
        reduction: reduction mode forwarded to ``smooth_crossentropy``.
    """

    def __init__(self, smoothing, reduction = 'mean'):
        super().__init__()
        self.smoothing, self.reduction = smoothing, reduction

    def forward(self, pred: torch.Tensor, gold: torch.Tensor):
        """Return the label-smoothed loss of logits ``pred`` against labels ``gold``."""
        return smooth_crossentropy(pred, gold, smoothing=self.smoothing, reduction=self.reduction)
    
def disable_running_stats(model):
    """Stop every BatchNorm layer in ``model`` from updating its running statistics.

    Each layer's ``momentum`` is saved in ``backup_momentum`` and set to 0, so the
    running mean/variance are left unchanged by subsequent forward passes. All
    BatchNorm variants (1d/2d/3d/Sync) are handled. Calling this again before
    ``enable_running_stats`` is a no-op, so the saved momentum is never
    overwritten by the 0 set here.
    """
    def _disable(module):
        if isinstance(module, nn.modules.batchnorm._BatchNorm) and not hasattr(module, "backup_momentum"):
            module.backup_momentum = module.momentum
            module.momentum = 0

    model.apply(_disable)


def enable_running_stats(model):
    """Undo ``disable_running_stats``: restore each BatchNorm layer's saved momentum.

    Layers without a saved momentum (never disabled) are left untouched.
    """
    def _enable(module):
        if isinstance(module, nn.modules.batchnorm._BatchNorm) and hasattr(module, "backup_momentum"):
            module.momentum = module.backup_momentum
            # Drop the saved value so the next disable call records a fresh backup.
            del module.backup_momentum

    model.apply(_enable)
    
def count_parameters(model):
    """Return the total number of elements in ``model``'s trainable parameters."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

def get_model(model_name, dataset='cifar10',num_classes=10):
    """Build a model by name, print it and its trainable-parameter count, return it.

    Args:
        model_name: one of '3FCN', 'lenet', 'vgg11_bn', 'resnet20', 'resnet56',
            'WRN164', 'WRN168'.
        dataset: only consulted by '3FCN', 'lenet' and 'vgg11_bn'. For 'lenet' and
            'vgg11_bn' the output width follows the dataset (100 for 'cifar100',
            otherwise 10); '3FCN' uses 10 outputs on 'mnist'.
        num_classes: output width for '3FCN' on non-MNIST data and for the
            ResNet / WideResNet models.

    Raises:
        ValueError: for an unknown ``model_name``.
    """
    if model_name == '3FCN':
        # Flattened input size: 784 = 28*28 for MNIST, 3072 = 3*32*32 for CIFAR.
        if dataset == 'mnist':
            in_features, out_features = 784, 10
        else:  # cifar10, cifar100
            in_features, out_features = 3072, num_classes
        model = nn.Sequential(
            nn.Flatten(),
            nn.Linear(in_features, 200, bias=True),
            nn.ReLU(),
            nn.Linear(200, 200, bias=True),
            nn.ReLU(),
            nn.Linear(200, out_features, bias=True),
        )
    elif model_name == 'lenet':
        if dataset == 'mnist':
            model = LeNet5(10, input_channel = 1)
        else:
            model = LeNet5(100 if dataset == 'cifar100' else 10)
    elif model_name == 'vgg11_bn':
        model = vgg11_bn_cifar100() if dataset == 'cifar100' else vgg11_bn()
    elif model_name in ('resnet20', 'resnet56'):
        factory = resnet20 if model_name == 'resnet20' else resnet56
        model = factory(num_classes=num_classes)
    elif model_name in ('WRN164', 'WRN168'):
        widen_factor = 4 if model_name == 'WRN164' else 8
        model = WideResNet(16, num_classes, widen_factor)
    else:
        raise ValueError("Unknown model")

    print(model)
    print('# params:', count_parameters(model))
    return model
    
class Cutout:
    """Randomly zero a square patch of an image tensor, in place.

    Assumes a (C, H, W) tensor (dims 1 and 2 are treated as spatial). With
    probability ``p`` a ``size`` x ``size`` window is placed so that its top-left
    corner can lie up to ``size // 2`` pixels above/left of the image; the part of
    the window inside the image is set to 0.

    Args:
        size: side length of the square window, in pixels.
        p: probability of applying the cutout.
    """

    def __init__(self, size=16, p=0.5):
        self.size = size
        self.half_size = size // 2
        self.p = p

    def __call__(self, image):
        skip = torch.rand([1]).item() > self.p
        if skip:
            return image
        height, width = image.size(1), image.size(2)
        # Window start along each spatial dim is drawn from [-half_size, dim - half_size).
        row_start = torch.randint(-self.half_size, height - self.half_size, [1]).item()
        col_start = torch.randint(-self.half_size, width - self.half_size, [1]).item()
        row_end = min(height, row_start + self.size)
        col_end = min(width, col_start + self.size)
        image[:, max(0, row_start):row_end, max(0, col_start):col_end] = 0
        return image
    

def get_data_stats(dataset):
    """Compute normalization statistics from the training split of ``dataset``.

    The whole training set (converted with ``ToTensor``) is loaded into a single
    tensor before reducing.

    Returns:
        ``(mean, std)``. For 'cifar10'/'cifar100' these are per-channel (reduced
        over dims 0, 2, 3); for 'mnist' they are reduced over the sample dim only,
        i.e. one value per pixel.

    Raises:
        ValueError: for an unknown ``dataset``.
    """
    if dataset == 'cifar10':
        dataset_cls = datasets.CIFAR10
    elif dataset == 'mnist':
        dataset_cls = datasets.MNIST
    elif dataset == 'cifar100':
        dataset_cls = datasets.CIFAR100
    else:
        raise ValueError("Unknown dataset")
    train_set = dataset_cls(root='./data', train=True, download=True, transform=transforms.ToTensor())
    # Batching only speeds up loading: the loader is sequential, so the
    # concatenated tensor is identical to loading one image per batch.
    loader = torch.utils.data.DataLoader(train_set, batch_size=1000)
    data = torch.cat([images for images, _ in loader])
    if dataset == 'mnist':
        mean, std = data.mean(dim=[0]), data.std(dim=[0])
    else:
        mean, std = data.mean(dim=[0, 2, 3]), data.std(dim=[0, 2, 3])
    return mean, std

def _cifar_transforms(data_augmentation, normalization, cutout, mean, std):
    """Return ``(transform_train, transform_test)`` for CIFAR-10/100.

    Train: [RandomCrop(32, padding=4), RandomHorizontalFlip] -> ToTensor -> [Normalize] -> [Cutout]
    Test:  ToTensor -> [Normalize]
    Bracketed steps are included only when the matching flag is set.
    """
    train_steps = []
    if data_augmentation:
        train_steps += [transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip()]
    train_steps.append(transforms.ToTensor())
    test_steps = [transforms.ToTensor()]
    if normalization:
        normalize = transforms.Normalize(mean, std)
        train_steps.append(normalize)
        test_steps.append(normalize)
    if cutout:
        # Cutout runs last, i.e. after any normalization.
        train_steps.append(Cutout())
    return transforms.Compose(train_steps), transforms.Compose(test_steps)


def _make_loaders(trainset, testset, train_bs, test_bs, shuffle):
    """Wrap train/test datasets in DataLoaders; the test loader is never shuffled."""
    common = dict(num_workers=4, pin_memory=True)
    train_loader = torch.utils.data.DataLoader(trainset, batch_size=train_bs, shuffle=shuffle, **common)
    test_loader = torch.utils.data.DataLoader(testset, batch_size=test_bs, shuffle=False, **common)
    return train_loader, test_loader


def get_data(dataset='cifar10', train_bs=128, test_bs=1000, data_augmentation = True, normalization= False, shuffle = True, cutout=False, model = None):
    """Build the train and test DataLoaders for ``dataset``.

    Args:
        dataset: 'cifar10', 'cifar100' or 'mnist'.
        train_bs: batch size of the training loader.
        test_bs: batch size of the test loader.
        data_augmentation: CIFAR only - random crop (padding 4) and horizontal flip
            on training images.
        normalization: CIFAR only - normalize with statistics from ``get_data_stats``.
        shuffle: shuffle the training loader (the test loader is never shuffled).
        cutout: CIFAR only - apply ``Cutout`` to training images.
        model: MNIST only - when 'lenet', images are resized to 32x32.

    Returns:
        ``(train_loader, test_loader)``.

    Raises:
        ValueError: for an unknown ``dataset``.
    """
    if dataset in ('cifar10', 'cifar100'):
        # get_data_stats makes a full pass over the training set, so only pay for
        # it when the statistics are actually used.
        mean, std = get_data_stats(dataset) if normalization else (None, None)
        transform_train, transform_test = _cifar_transforms(
            data_augmentation, normalization, cutout, mean, std)
        dataset_cls = datasets.CIFAR10 if dataset == 'cifar10' else datasets.CIFAR100
        trainset = dataset_cls(root='./data', train=True, download=True, transform=transform_train)
        testset = dataset_cls(root='./data', train=False, download=True, transform=transform_test)
    elif dataset == 'mnist':
        if model == 'lenet':
            transform = transforms.Compose([transforms.Resize((32, 32)), transforms.ToTensor()])
        else:
            transform = transforms.ToTensor()
        trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
        testset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
    else:
        raise ValueError("Unknown dataset")
    return _make_loaders(trainset, testset, train_bs, test_bs, shuffle)

        
        
def get_criterion(criterion, smoothing=0):
    """Return the loss module named by ``criterion``.

    Args:
        criterion: 'cross-entropy' (``nn.CrossEntropyLoss``), 'mse'
            (``SquaredLoss``) or 'label_smoothing' (``smooth_CrossEntropyLoss``).
        smoothing: label-smoothing mass; only used by 'label_smoothing'.

    Raises:
        ValueError: for an unknown ``criterion``.
    """
    if criterion == 'cross-entropy':
        return nn.CrossEntropyLoss()
    if criterion == 'mse':
        return SquaredLoss()
    if criterion == 'label_smoothing':
        return smooth_CrossEntropyLoss(smoothing=smoothing)
    raise ValueError("Unknown criterion: {}".format(criterion))

 
def get_lr_scheduler(lr_scheduler, optimizer, milestones, gamma, epochs, 
                     T_0=None, T_mult=1, T_up=0, eta_max=0.1): 
    """Build the learning-rate scheduler named by ``lr_scheduler``.

    Args:
        lr_scheduler: 'multistep' (``MultiStepLR`` with ``milestones`` and
            ``gamma``), 'cosine' (``CosineAnnealingLR`` with ``T_max=epochs``) or
            'cosine_warmup' (``CosineAnnealingWarmUpRestarts``, where ``gamma``
            decays the peak lr per cycle).
        optimizer: optimizer whose learning rate is scheduled.
        milestones: epochs at which 'multistep' multiplies the lr by ``gamma``.
        gamma: decay factor (meaning depends on the scheduler, see above).
        epochs: total epochs; ``T_max`` for 'cosine', default ``T_0`` for 'cosine_warmup'.
        T_0, T_mult, T_up, eta_max: forwarded to ``CosineAnnealingWarmUpRestarts``.

    Raises:
        ValueError: for an unknown ``lr_scheduler``.
    """
    if lr_scheduler == 'multistep':
        return MultiStepLR(optimizer, milestones, gamma=gamma)
    if lr_scheduler == 'cosine':
        return CosineAnnealingLR(optimizer, epochs)
    if lr_scheduler == 'cosine_warmup':
        if T_0 is None:
            T_0 = epochs
        return CosineAnnealingWarmUpRestarts(optimizer, T_0, T_mult=T_mult, eta_max=eta_max,
                                             T_up=T_up, gamma=gamma)
    raise ValueError("Unknown lr_scheduler: {}".format(lr_scheduler))

def test(model, test_loader, cuda=True, print_opt=False):
    """
    Get the test performance
    """
    model.eval()
    correct = 0
    total_num = 0
    test_loss = 0
    
    criterion = nn.CrossEntropyLoss()
    for data, target in test_loader:
        if cuda:
            data, target = data.cuda(), target.cuda()
        output = model((data))
        
        loss = criterion(output, target)
        test_loss += loss.item() * target.size()[0]
        
        pred = output.data.max(
            1, keepdim=True)[1]  # get the index of the max log-probability
        correct += pred.eq(target.data.view_as(pred)).cpu().sum().item()
        total_num += len(data)
    if print_opt:
        print('testing_correct: ', correct / total_num, '\n')
    return correct / total_num, test_loss / total_num
