import torch
from collections import OrderedDict
import torch.nn as nn
import torch.nn.functional as F
from collections import namedtuple
import torchvision
import argparse
import logging
import sys
import time
import math

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

import os
# Small constant that keeps the norm ratio in ``diff_in_weights`` finite.
EPS = 1E-20

# Per-channel normalisation constants. With mean 0 and std 1, ``normalize``
# is value-preserving; ``normalise`` then only rescales pixels by 1/255.
cifar10_mean = (0, 0, 0)
cifar10_std = (1, 1, 1)


# Pixel range that perturbed inputs are clamped to in the PGD attack.
upper_limit, lower_limit = 1,0
# On CUDA machines ``.to('cuda')`` targets the current device, just like
# ``.cuda()``; without a GPU fall back to CPU so that importing this module
# (e.g. for the NumPy-only helpers) does not crash.
_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
mu = torch.tensor(cifar10_mean).view(3, 1, 1).to(_device)
std = torch.tensor(cifar10_std).view(3, 1, 1).to(_device)

def normalize(X):
    """Standardise ``X`` channel-wise with the module-level ``mu`` and ``std``.

    ``mu`` and ``std`` have shape (3, 1, 1), so ``X`` must broadcast against
    them (e.g. a 3-channel image or a batch of them, channels before H, W).
    """
    centred = X - mu
    return centred / std

def diff_in_weights(model, proxy):
    """Return the layer-wise rescaled weight difference ``proxy - model``.

    The two state dicts are walked in lockstep (so both networks must have the
    same state-dict layout). Only entries whose name contains ``'weight'`` and
    that have more than one dimension are kept; each difference ``d`` is
    rescaled to ``||w_model|| / (||d|| + EPS) * d`` so its norm matches the
    model weight's norm.

    Returns:
        OrderedDict mapping state-dict key -> scaled difference tensor.
    """
    scaled_diffs = OrderedDict()
    paired = zip(model.state_dict().items(), proxy.state_dict().items())
    for (name, base_w), (_, moved_w) in paired:
        # Biases, BN statistics and scalar buffers (<= 1-D) are never perturbed.
        if base_w.dim() > 1 and 'weight' in name:
            delta = moved_w - base_w
            scaled_diffs[name] = base_w.norm() / (delta.norm() + EPS) * delta
    return scaled_diffs


def add_into_weights(model, diff, coeff=1.0):
    """Add ``coeff * diff[name]`` in place to each parameter of ``model`` named in ``diff``.

    Parameters whose names are absent from ``diff`` are left untouched. The
    update runs under ``torch.no_grad()`` so it is not recorded by autograd.
    """
    wanted = diff.keys()
    with torch.no_grad():
        for pname, tensor in model.named_parameters():
            if pname not in wanted:
                continue
            tensor.add_(coeff * diff[pname])


class AdvWeightPerturb(object):
    """Adversarial Weight Perturbation (AWP) helper.

    ``calc_awp`` copies ``model`` into ``proxy``, takes one ``proxy_optim``
    step on the *negated* cross-entropy of adversarial inputs (a step that
    increases the loss), and returns the layer-wise rescaled weight difference
    from ``diff_in_weights``. ``perturb`` adds ``gamma`` times that difference
    to ``model``'s parameters and ``restore`` subtracts it again.

    Args:
        model: network whose weights get perturbed.
        proxy: network with the same state-dict layout as ``model``; its
            weights are overwritten on every ``calc_awp`` call.
        proxy_optim: optimizer over ``proxy``'s parameters.
        gamma: scale applied to the difference in ``perturb``/``restore``.
    """

    def __init__(self, model, proxy, proxy_optim, gamma):
        super().__init__()
        self.model = model
        self.proxy = proxy
        self.proxy_optim = proxy_optim
        self.gamma = gamma

    def calc_awp(self, inputs_adv, targets):
        """Return the adversarial weight difference for one batch.

        Side effects: overwrites ``proxy``'s weights with ``model``'s, switches
        ``proxy`` to train mode and steps ``proxy_optim``.
        """
        proxy, optim = self.proxy, self.proxy_optim
        proxy.load_state_dict(self.model.state_dict())
        proxy.train()

        # Negated loss: the optimizer's descent step therefore *ascends* the
        # cross-entropy on the adversarial batch.
        ascent_loss = -F.cross_entropy(proxy(inputs_adv), targets)

        optim.zero_grad()
        ascent_loss.backward()
        optim.step()

        return diff_in_weights(self.model, proxy)

    def perturb(self, diff):
        """Shift ``model``'s weights by ``+gamma * diff``."""
        add_into_weights(self.model, diff, coeff=1.0 * self.gamma)

    def restore(self, diff):
        """Shift ``model``'s weights by ``-gamma * diff`` (reverses ``perturb``)."""
        add_into_weights(self.model, diff, coeff=-1.0 * self.gamma)









def normalise(x, mean=cifar10_mean, std=cifar10_std):
    """Return a float32 copy of ``x`` computed as ``(x - 255*mean) / (255*std)``.

    ``mean``/``std`` broadcast against the last axis of ``x``; presumably ``x``
    holds channels-last pixels in 0..255 and ``mean``/``std`` are given in
    0..1 units (TODO confirm against callers).
    """
    pixels = np.array(x, np.float32)
    mean_arr = np.array(mean, np.float32)
    std_arr = np.array(std, np.float32)
    pixels -= mean_arr * 255
    pixels *= 1.0 / (255 * std_arr)
    return pixels

def pad(x, border=4):
    """Reflect-pad axes 1 and 2 of a 4-D array by ``border`` on each side.

    Axes 0 and 3 are left unpadded (e.g. N and C of an NHWC batch).
    """
    untouched = (0, 0)
    spatial = (border, border)
    return np.pad(x, [untouched, spatial, spatial, untouched], mode='reflect')

def transpose(x, source='NHWC', target='NCHW'):
    """Permute the axes of array ``x`` from layout ``source`` to layout ``target``.

    Layouts are strings of single-character axis names; each character of
    ``target`` is looked up in ``source`` to build the permutation.
    """
    perm = list(map(source.index, target))
    return x.transpose(perm)

#####################
## data augmentation
#####################

class Crop(namedtuple('Crop', ('h', 'w'))):
    """Fixed-size crop of a (C, H, W) array to height ``h`` and width ``w``."""

    def __call__(self, x, x0, y0):
        rows = slice(y0, y0 + self.h)
        cols = slice(x0, x0 + self.w)
        return x[:, rows, cols]

    def options(self, x_shape):
        """Return all valid top-left corners as ``{'x0': range, 'y0': range}``."""
        _, height, width = x_shape
        return {'x0': range(width - self.w + 1), 'y0': range(height - self.h + 1)}

    def output_shape(self, x_shape):
        """Shape of the cropped array for an input of shape ``x_shape``."""
        channels, _, _ = x_shape
        return (channels, self.h, self.w)
    
class FlipLR(namedtuple('FlipLR', ())):
    """Optional horizontal flip of a (C, H, W) array (reverses the last axis)."""

    def __call__(self, x, choice):
        if not choice:
            return x
        return x[:, :, ::-1].copy()

    def options(self, x_shape):
        """Both outcomes are allowed regardless of ``x_shape``."""
        return {'choice': [True, False]}

class Cutout(namedtuple('Cutout', ('h', 'w'))):
    """Zero out an ``h`` x ``w`` patch of a (C, H, W) array (on a copy)."""

    def __call__(self, x, x0, y0):
        out = x.copy()
        patch = out[:, y0:y0 + self.h, x0:x0 + self.w]
        patch.fill(0.0)
        return out

    def options(self, x_shape):
        """Return all top-left corners at which the patch fits entirely."""
        _, height, width = x_shape
        return {'x0': range(width - self.w + 1), 'y0': range(height - self.h + 1)}
    
    
class Transform():
    """Dataset wrapper that applies ``transforms`` with per-index arguments.

    Each transform must provide ``options(x_shape)``, which returns a dict
    mapping argument name to candidate values. It may also provide
    ``output_shape(x_shape)``. ``set_random_choices`` draws one candidate per
    dataset index and argument. Until it has been called, ``self.choices`` is
    None and indexing raises ``TypeError``.
    """

    def __init__(self, dataset, transforms):
        self.dataset = dataset
        self.transforms = transforms
        self.choices = None

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        data, labels = self.dataset[index]
        for fn, per_arg in zip(self.transforms, self.choices):
            data = fn(data, **{name: drawn[index] for name, drawn in per_arg.items()})
        return data, labels

    def set_random_choices(self):
        """Redraw arguments for every index (uses the global NumPy RNG)."""
        self.choices = []
        shape = self.dataset[0][0].shape
        count = len(self)
        for tf in self.transforms:
            opts = tf.options(shape)
            # Later transforms see the shape produced by earlier ones.
            if hasattr(tf, 'output_shape'):
                shape = tf.output_shape(shape)
            self.choices.append({name: np.random.choice(cands, size=count)
                                 for name, cands in opts.items()})

#####################
## dataset
#####################

def cifar10(root):
    """Load (downloading under ``root`` if needed) both CIFAR-10 splits.

    Returns ``{'train': {...}, 'test': {...}}``. Each split holds the
    torchvision dataset's ``data`` and ``targets`` under the keys ``'data'``
    and ``'labels'``.
    """
    splits = {}
    for split_name, is_train in (('train', True), ('test', False)):
        ds = torchvision.datasets.CIFAR10(root=root, train=is_train, download=True)
        splits[split_name] = {'data': ds.data, 'labels': ds.targets}
    return splits

def cifar100(root):
    """Load (downloading under ``root`` if needed) both CIFAR-100 splits.

    Returns ``{'train': {...}, 'test': {...}}``. Each split holds the
    torchvision dataset's ``data`` and ``targets`` under the keys ``'data'``
    and ``'labels'``.
    """
    def _load(train):
        return torchvision.datasets.CIFAR100(root=root, train=train, download=True)

    train_set, test_set = _load(True), _load(False)
    return {
        split_name: {'data': ds.data, 'labels': ds.targets}
        for split_name, ds in (('train', train_set), ('test', test_set))
    }
#####################
## data loading
#####################

class Batches():
    """Iterable over a pinned-memory ``DataLoader`` that moves batches to CUDA.

    Iteration yields dicts ``{'input': float tensor, 'target': long tensor}``
    on the GPU. When ``set_random_choices`` is true, the wrapped dataset's
    ``set_random_choices()`` is called each time iteration starts.
    """

    def __init__(self, dataset, batch_size, shuffle, set_random_choices=False, num_workers=0, drop_last=False):
        self.dataset = dataset
        self.batch_size = batch_size
        self.set_random_choices = set_random_choices
        loader_kwargs = dict(batch_size=batch_size, num_workers=num_workers, pin_memory=True,
                             shuffle=shuffle, drop_last=drop_last)
        self.dataloader = torch.utils.data.DataLoader(dataset, **loader_kwargs)

    def __iter__(self):
        if self.set_random_choices:
            self.dataset.set_random_choices()
        return (
            {'input': inputs.cuda().float(), 'target': targets.cuda().long()}
            for inputs, targets in self.dataloader
        )

    def __len__(self):
        """Number of batches per pass (delegates to the DataLoader)."""
        return len(self.dataloader)
    
    
def attack_pgd_from_AWP(model, X, y, epsilon, alpha, attack_iters, restarts,
               norm, early_stop=False,
               mixup=False, y_a=None, y_b=None, lam=None):
    """Projected gradient (PGD) attack with random restarts.

    Searches for a perturbation ``delta`` inside an ``epsilon``-ball (L_inf or
    L_2) that increases the cross-entropy of ``model(normalize(X + delta))``.
    ``X + delta`` is kept inside ``[lower_limit, upper_limit]``.

    Args:
        model: classifier, called on ``normalize(X + delta)``.
        X: clean inputs; treated as 4-D (indexed as ``[index, :, :, :]``).
        y: labels, one per example in ``X``.
        epsilon: radius of the perturbation ball.
        alpha: step size per iteration.
        attack_iters: number of gradient steps per restart.
        restarts: number of random restarts. For each example, the delta with
            the highest final loss across restarts is kept.
        norm: ``"l_inf"`` or ``"l_2"``; anything else raises ``ValueError``.
        early_stop: if True, each step only updates examples the model still
            classifies correctly. The restart's step loop ends once none are left.
        mixup, y_a, y_b, lam: when ``mixup`` is True the loss comes from
            ``mixup_criterion`` with targets ``y_a``/``y_b`` and weight ``lam``.
            ``mixup_criterion`` is not defined in this section; presumably it
            is defined elsewhere in the file (TODO confirm).

    Returns:
        ``max_delta``: tensor shaped like ``X``, allocated with ``.cuda()``.
    """
    # Best per-example loss and the delta that achieved it, across restarts.
    max_loss = torch.zeros(y.shape[0]).cuda()
    max_delta = torch.zeros_like(X).cuda()
    for _ in range(restarts):
        delta = torch.zeros_like(X).cuda()
        if norm == "l_inf":
            # Uniform random start inside the L_inf ball.
            delta.uniform_(-epsilon, epsilon)
        elif norm == "l_2":
            # Gaussian direction, rescaled per example to radius r * epsilon
            # with r ~ U(0, 1).
            delta.normal_()
            d_flat = delta.view(delta.size(0),-1)
            n = d_flat.norm(p=2,dim=1).view(delta.size(0),1,1,1)
            r = torch.zeros_like(n).uniform_(0, 1)
            delta *= r/n*epsilon
        else:
            raise ValueError
        # Keep X + delta inside the valid pixel range.
        delta = clamp(delta, lower_limit-X, upper_limit-X)
        delta.requires_grad = True
        for _ in range(attack_iters):
            output = model(normalize(X + delta))
            if early_stop:
                # Only examples that are still classified correctly get updated.
                index = torch.where(output.max(1)[1] == y)[0]
            else:
                index = slice(None,None,None)
            if not isinstance(index, slice) and len(index) == 0:
                break
            if mixup:
                # NOTE(review): this runs a second forward pass instead of
                # reusing ``output``.
                criterion = nn.CrossEntropyLoss()
                loss = mixup_criterion(criterion, model(normalize(X+delta)), y_a, y_b, lam)
            else:
                loss = F.cross_entropy(output, y)
            # NOTE(review): backward() also accumulates gradients into the model
            # parameters; callers presumably zero them before their own update.
            loss.backward()
            grad = delta.grad.detach()
            d = delta[index, :, :, :]
            g = grad[index, :, :, :]
            x = X[index, :, :, :]
            if norm == "l_inf":
                # Signed-gradient step, then clip to the L_inf ball.
                d = torch.clamp(d + alpha * torch.sign(g), min=-epsilon, max=epsilon)
            elif norm == "l_2":
                # Step along the per-example L2-normalised gradient, then renorm
                # rescales any example whose L2 norm exceeds epsilon back to epsilon.
                g_norm = torch.norm(g.view(g.shape[0],-1),dim=1).view(-1,1,1,1)
                scaled_g = g/(g_norm + 1e-10)
                d = (d + scaled_g*alpha).view(d.size(0),-1).renorm(p=2,dim=0,maxnorm=epsilon).view_as(d)
            d = clamp(d, lower_limit - x, upper_limit - x)
            # Write through .data so the update is not tracked by autograd.
            delta.data[index, :, :, :] = d
            delta.grad.zero_()
        # Per-example loss of this restart's final delta.
        if mixup:
            criterion = nn.CrossEntropyLoss(reduction='none')
            all_loss = mixup_criterion(criterion, model(normalize(X+delta)), y_a, y_b, lam)
        else:
            all_loss = F.cross_entropy(model(normalize(X+delta)), y, reduction='none')
        # Keep the delta with the highest loss so far (ties go to the later restart).
        max_delta[all_loss >= max_loss] = delta.detach()[all_loss >= max_loss]
        max_loss = torch.max(max_loss, all_loss)
    return max_delta


    
def lr_schedule(t, epochs=200, lr_max=0.1):
    """Piecewise-constant learning rate for (possibly fractional) epoch ``t``.

    Returns ``lr_max`` for the first half of training, ``lr_max / 10`` until
    75% of ``epochs``, and ``lr_max / 100`` after that.
    """
    progress = t / epochs
    if progress < 0.5:
        return lr_max
    return lr_max / 10. if progress < 0.75 else lr_max / 100.
    
    
def clamp(X, lower_limit, upper_limit):
    """Element-wise clamp of tensor ``X`` between tensor bounds.

    The upper bound is applied first and the lower bound second, so where the
    bounds cross the result equals ``lower_limit``.
    """
    capped = torch.min(X, upper_limit)
    return torch.max(capped, lower_limit)






def BSL(labels, logits, sample_per_class):
    """Compute the Balanced Softmax Loss between ``logits`` and ``labels``.

    The log of each class's sample count is added to that class's logit, and
    the result goes into a standard (mean-reduced) cross-entropy.

    Args:
      labels: int tensor of shape [batch].
      logits: float tensor of shape [batch, num_classes].
      sample_per_class: per-class sample counts (list, array or tensor) of
        length num_classes. Counts ``<= 0`` are treated as 1 so their log is
        finite. The argument itself is not modified.

    Returns:
      Scalar tensor: the Balanced Softmax Loss.
    """
    spc = torch.as_tensor(sample_per_class).type_as(logits)
    # Replace non-positive counts on a copy. The previous loop overwrote the
    # caller's ``sample_per_class`` in place.
    spc = spc.masked_fill(spc <= 0, 1.0)
    # spc has shape [num_classes] and broadcasts across the batch dimension.
    logits = logits + spc.log()
    return F.cross_entropy(input=logits, target=labels)

def RBL(labels, logits, sample_pre_class, at_pre_class):
    """Class-reweighted cross-entropy, used as the attack loss in ``REAT``.

    For class ``i`` with ``n_i = sample_pre_class[i]`` and
    ``a_i = at_pre_class[i]``, it computes ``beta_i = (n_i - 1) / n_i`` and
    ``E_i = (1 - beta_i ** a_i) / (1 - beta_i)``. Class weights are
    ``1 / (E_i + 1e-5)``, rescaled to sum to the number of classes, and passed
    as ``weight`` to ``F.cross_entropy``.

    Args:
      labels: int tensor of shape [batch].
      logits: float tensor of shape [batch, num_classes].
      sample_pre_class: per-class sample counts, length num_classes. A count
        of 0 causes a division by zero in ``beta_i``.
      at_pre_class: per-class exponents ``a_i``, same length.

    Returns:
      Scalar loss tensor (mean reduction over the weighted samples).
    """
    num_classes = len(sample_pre_class)
    beta = np.zeros(num_classes).astype(np.float32)
    E = np.zeros(num_classes).astype(np.float32)
    for i in range(num_classes):
        beta[i] = (sample_pre_class[i] - 1.) / sample_pre_class[i]
        E[i] = (1. - beta[i]**at_pre_class[i]) / (1. - beta[i])
    weights = 1. / (E + 1e-5)
    weights = weights / np.sum(weights) * num_classes
    # Put the weights on the logits' device rather than hard-coding .cuda().
    # This is identical on a single GPU, also works on CPU, and avoids a
    # device mismatch when the logits are not on the current CUDA device.
    class_weight = torch.from_numpy(weights.astype(np.float32)).to(logits.device)
    return F.cross_entropy(logits, labels, weight=class_weight)

def REAT(model, x, y, optimizer, sample_per_class, at_per_class, args):
	"""Compute logits and the adversarial training loss for one batch.

	Crafts L_inf PGD adversarial examples by ascending ``RBL``. The settings are
	hard-coded: eps 8/255, 10 steps of size 2/255, uniform random start. It then
	evaluates ``model`` on them in train mode. The returned loss is ``BSL`` on
	the adversarial logits plus a ``TAIL`` term. ``TAIL`` is built from
	KLDivLoss values between the softmaxed ``f_adv`` outputs of tail-class
	anchor samples and those of every sample in the batch.

	Args:
	  model: network called as ``model(inputs, returnt='all')``; assumed to
	    return ``(logits, features)`` -- TODO confirm against the model code.
	  x: clean inputs. Adversarial inputs are clamped to [0, 1], so x is
	    presumably in that range. x must be on CUDA, because the noise and
	    helper tensors are created with ``.cuda()``.
	  y: int label tensor of shape [batch].
	  optimizer: its gradients are zeroed here (``optimizer.zero_grad()``).
	    Presumably the caller runs ``loss.backward()`` and ``optimizer.step()``.
	  sample_per_class: per-class training sample counts.
	  at_per_class: per-class values passed to ``RBL`` as ``at_pre_class``.
	  args: not used in this function (eps / steps / step size are hard-coded).

	Returns:
	  (logits, loss): adversarial logits and the scalar training loss.
	"""
	# NOTE(review): ``size_average`` is a deprecated PyTorch flag. The string
	# 'none' is truthy, so this resolves to reduction='mean' and ``kl(...)``
	# below returns a scalar, not per-element values. If per-element KL was
	# intended, this should be ``reduction='none'``. Confirm before changing,
	# because ``* idt * W`` below currently relies on scalar broadcasting.
	kl = nn.KLDivLoss(size_average='none').cuda()
	# Per-class weight sqrt(total / count_c): larger for classes with fewer samples.
	spc = torch.tensor(sample_per_class).type_as(x)
	weights = torch.sqrt(1. / (spc / spc.sum()))
	# Class indices from len//3*2 + 1 up to the last class. These are presumably
	# the rarest classes if classes are sorted by descending frequency (confirm).
	# Note that the "+ 1" also excludes index len//3*2 itself.
	tail_class = [i for i in range(len(sample_per_class)//3 * 2 + 1, len(sample_per_class))]
	# The attack is crafted in eval mode; train mode is restored afterwards.
	model.eval()
	# Hard-coded attack hyper-parameters; the trailing comments name the unused ``args`` fields.
	epsilon = 8./255. #args.eps
	num_steps = 10 # args.ns
	step_size = 2./255. # args.ss
	# Uniform random start in the L_inf ball (not clamped to [0, 1] until after the first step).
	x_adv = x.detach() + torch.FloatTensor(*x.shape).uniform_(-epsilon, epsilon).cuda()
	for _ in range(num_steps):
		x_adv.requires_grad_()
		with torch.enable_grad():
			logits_adv, f_adv  = model(x_adv, returnt='all')
			loss = RBL(y, logits_adv, sample_per_class, at_per_class)
		grad = torch.autograd.grad(loss, [x_adv])[0]
		# Signed-gradient ascent on RBL, then project onto the eps-ball around x and onto [0, 1].
		x_adv = x_adv.detach() + step_size * torch.sign(grad.detach())
		x_adv = torch.min(torch.max(x_adv, x - epsilon), x + epsilon)
		x_adv = torch.clamp(x_adv, 0.0, 1.0)
	# Switches to train mode unconditionally, whatever mode the model was in before.
	model.train()
	# Legacy ``Variable`` wrapper; the clamp repeats the last one in the loop.
	x_adv = Variable(torch.clamp(x_adv, 0.0, 1.0), requires_grad=False)
	# zero gradient
	optimizer.zero_grad()
	# calculate robust loss
	logits, f_adv = model(x_adv, returnt='all')
	TAIL = 0.0
	counter = 0.0
	# For each tail-class anchor ``bi``: idt[bj] is -1 for samples sharing the
	# anchor's label and +1 otherwise; W[bj] = weights[y_bi] + weights[y_bj].
	for bi in range(y.size(0)):
		if y[bi].item() in tail_class:
			idt = torch.tensor([-1. if y[bi].item()==y[bj].item() else 1. for bj in range(y.size(0))]).cuda()
			W = torch.tensor([weights[y[bi].item()] + weights[y[bj].item()] for bj in range(y.size(0))]).cuda()
			# KLDivLoss between log_softmax(f_adv) for every sample and the softmax of
			# the detached anchor row f_adv[bi], repeated batch times (shape [batch, -1]).
			TAIL += kl(F.log_softmax(f_adv, 1), F.softmax(f_adv[bi, :].clone().detach().view(1, -1).tile(y.size(0), ).view(y.size(0), -1), 1)) * idt * W
			counter += 1
	# Mean of the accumulated tensor divided by the number of tail anchors;
	# 0.0 when the batch contains no tail-class sample.
	TAIL = TAIL.mean() / counter if counter>0. else 0.0
	loss = BSL(y, logits, sample_per_class) + TAIL
	return logits, loss