import torch
from collections import OrderedDict
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from models.layers import GetSubnet
from utils.utils import rate_act_func
from utils.model import freeze_vars, unfreeze_vars


# Tiny constant added to norm denominators (used by diff_all and diff_partial)
# so that a zero-norm difference tensor never triggers a division by zero.
EPS = 1e-20

def trades_loss(
    model,
    x_natural,
    y,
    device,
    optimizer,
    step_size,
    epsilon,
    perturb_steps,
    beta,
    clip_min,
    clip_max,
    distance="l_inf",
    natural_criterion=nn.CrossEntropyLoss(),
):
    """Craft TRADES adversarial examples for a clean batch.

    Despite its name, this function does not compute a loss. It maximises the
    summed KL divergence between the model's predictions on the perturbed
    input and on the clean input, then returns the perturbed batch.

    Args:
        model: Network under attack. It is switched to ``eval()`` while the
            attack runs. Before returning it is switched to ``train()``,
            whatever mode it was in on entry.
        x_natural: Clean input batch; the first dimension is the batch.
        y: Unused; kept for interface compatibility.
        device: Device that the initial random noise is moved to.
        optimizer: Unused; kept for interface compatibility.
        step_size: Step size of the signed-gradient update (``"l_inf"`` only).
        epsilon: Radius of the perturbation ball.
        perturb_steps: Number of attack iterations.
        beta: Unused; kept for interface compatibility.
        clip_min: Lower bound of the valid input range.
        clip_max: Upper bound of the valid input range.
        distance: ``"l_inf"`` or ``"l_2"``. Any other value skips the attack
            and returns the randomly jittered clean input.
        natural_criterion: Unused; kept for interface compatibility.

    Returns:
        A detached tensor shaped like ``x_natural``, clamped to
        ``[clip_min, clip_max]``.
    """
    # reduction="sum" is the non-deprecated spelling of size_average=False.
    criterion_kl = nn.KLDivLoss(reduction="sum")
    model.eval()
    batch_size = len(x_natural)
    # Shape that broadcasts one scalar per sample over the remaining dims,
    # e.g. (B, 1, 1, 1) for 4-D input or (B, 1) for 2-D input.
    per_sample_shape = (-1,) + (1,) * (x_natural.dim() - 1)

    # Random start: tiny Gaussian jitter around the clean input.
    x_adv = (
        x_natural.detach() + 0.001 * torch.randn(x_natural.shape).to(device).detach()
    )

    # The prediction on the clean input does not depend on the perturbation,
    # so the KL target is computed once instead of once per attack step.
    with torch.no_grad():
        target_probs = F.softmax(model(x_natural), dim=1)

    if distance == "l_inf":
        for _ in range(perturb_steps):
            x_adv.requires_grad_()
            with torch.enable_grad():
                loss_kl = criterion_kl(
                    F.log_softmax(model(x_adv), dim=1), target_probs
                )
                grad = torch.autograd.grad(loss_kl, [x_adv])[0]
            # Signed-gradient ascent step, then project onto the l_inf ball
            # around x_natural and into the valid input range.
            x_adv = x_adv.detach() + step_size * torch.sign(grad.detach())
            x_adv = torch.min(
                torch.max(x_adv, x_natural - epsilon), x_natural + epsilon
            )
            x_adv = torch.clamp(x_adv, clip_min, clip_max)
    elif distance == "l_2":
        delta = 0.001 * torch.randn(x_natural.shape).to(device).detach()
        delta.requires_grad_()
        optimizer_delta = optim.SGD([delta], lr=epsilon / perturb_steps * 2)

        for _ in range(perturb_steps):
            optimizer_delta.zero_grad()
            # Build the graph under enable_grad so this also works when the
            # caller runs us inside torch.no_grad().
            with torch.enable_grad():
                adv = x_natural + delta
                # Negated so that SGD's descent step maximises the KL term.
                loss = (-1) * criterion_kl(
                    F.log_softmax(model(adv), dim=1), target_probs
                )
                # Differentiate w.r.t. delta only. Unlike loss.backward(),
                # this leaves the model parameters' .grad buffers untouched.
                grad = torch.autograd.grad(loss, [delta])[0]
            # Rescale each sample's gradient to unit l2 norm.
            grad_norms = grad.reshape(batch_size, -1).norm(p=2, dim=1)
            grad.div_(grad_norms.view(per_sample_shape))
            # A zero gradient yields NaN above; replace it with a random direction.
            if (grad_norms == 0).any():
                grad[grad_norms == 0] = torch.randn_like(grad[grad_norms == 0])
            delta.grad = grad
            optimizer_delta.step()

            # Projection: keep x_natural + delta in the valid range, then cap
            # each sample's perturbation at l2 norm epsilon.
            with torch.no_grad():
                delta.add_(x_natural)
                delta.clamp_(clip_min, clip_max).sub_(x_natural)
                delta.renorm_(p=2, dim=0, maxnorm=epsilon)
        x_adv = (x_natural + delta).detach()
    else:
        # Unknown distance: no attack is run; return the jittered input.
        x_adv = torch.clamp(x_adv, clip_min, clip_max)
    model.train()
    return torch.clamp(x_adv, clip_min, clip_max).detach()


def diff_all(model, proxy, param_name='popup_scores'):
    """Compute a norm-rescaled perturbation for every selected tensor.

    Let ``w`` be a selected tensor in ``model.state_dict()`` and ``v`` the
    proxy's tensor under the same key. A tensor is selected when it has at
    least two dimensions and its key contains ``param_name``. For each one,
    the result holds

        ||w|| * (v - w) / (||v - w|| + EPS)

    i.e. the step the proxy took, rescaled to the norm of the original tensor.

    Args:
        model (torch.nn.Module): Reference model holding the unperturbed values.
        proxy (torch.nn.Module): Model that has taken a further optimisation
            step. It must expose every selected key of ``model.state_dict()``.
        param_name (str): Substring that selects which state-dict entries are
            perturbed.

    Returns:
        OrderedDict: Maps state-dict key to a perturbation tensor of the same
        shape, in ``model.state_dict()`` order.

    Raises:
        KeyError: If ``proxy`` lacks a selected key.
    """
    diff_dict = OrderedDict()
    proxy_state_dict = proxy.state_dict()
    for key, old_w in model.state_dict().items():
        # Skip 0-D and 1-D entries (e.g. biases) and keys not matching param_name.
        if old_w.dim() <= 1 or param_name not in key:
            continue
        # Pair by key rather than by position so the two state dicts cannot
        # be silently misaligned.
        diff_w = proxy_state_dict[key] - old_w
        diff_dict[key] = old_w.norm() / (diff_w.norm() + EPS) * diff_w
    return diff_dict


def diff_partial(model, proxy, k, param_name='popup_scores', layer_uniformity=False):
    """Like ``diff_all``, but restricted to the slots kept by the score mask.

    For every state-dict entry with at least two dimensions whose key contains
    ``param_name``:

    1. Build a mask ``adj = GetSubnet.apply(|popup_scores|, layer_k, 'weight')``
       from the matching ``popup_scores`` entry of ``model``.
    2. Mask both the model tensor ``w`` and the proxy tensor ``v`` with it.
    3. Return ``||w*adj|| * d / (||d|| + EPS)``, where ``d = (v - w) * adj``.

    Args:
        model (torch.nn.Module): Reference model. For every selected key it
            must also hold the key with ``param_name`` replaced by
            ``'popup_scores'``. When ``layer_uniformity`` is False it must
            also hold the key with ``param_name`` replaced by ``'k_score'``.
        proxy (torch.nn.Module): Model that has taken a further optimisation
            step; it must expose every selected key.
        k: Global k passed to ``GetSubnet``. It is also used to derive each
            layer's lower bound when ``layer_uniformity`` is False.
        param_name (str): Substring that selects which entries are perturbed.
        layer_uniformity (bool): If True, every layer uses ``k``. Otherwise
            each layer uses ``rate_act_func(k_score, 0.1 * k)``.

    Returns:
        OrderedDict: Maps state-dict key to a perturbation tensor that is zero
        in masked-out slots.
    """
    diff_dict = OrderedDict()
    model_state_dict = model.state_dict()
    proxy_state_dict = proxy.state_dict()
    # Derived from the *global* k, once. The per-layer value goes into a
    # separate local so it cannot leak into the next layer's bound.
    k_min = k * 0.1
    for key, old_w in model_state_dict.items():
        if old_w.dim() <= 1 or param_name not in key:
            continue
        popup_scores = model_state_dict[key.replace(param_name, 'popup_scores')]

        layer_k = k
        # if harp: per-layer k from the layer's k_score
        if not layer_uniformity:
            k_score = model_state_dict[key.replace(param_name, 'k_score')]
            layer_k = rate_act_func(k_score, k_min)

        adj = GetSubnet.apply(popup_scores.abs(), layer_k, 'weight')
        # Masking both sides makes the difference exactly 0 in pruned slots.
        new_w = proxy_state_dict[key] * adj
        old_w = old_w * adj
        diff_w = new_w - old_w
        diff_dict[key] = old_w.norm() / (diff_w.norm() + EPS) * diff_w
    return diff_dict


def add_all(model, diff, coeff=1.0, param_name='popup_scores'):
    """Add ``coeff * diff[name]`` in place to every model parameter named in ``diff``.

    Args:
        model (torch.nn.Module): Module whose parameters are updated in place.
        diff (Mapping[str, torch.Tensor]): Maps parameter name to an update
            tensor of matching shape, e.g. the output of ``diff_all``.
        coeff (float): Scale applied to each update. A positive value perturbs;
            the negated value undoes the same perturbation.
        param_name (str): Kept for interface compatibility. Every parameter
            named in ``diff`` is updated, whether or not its name contains
            this substring.
    """
    with torch.no_grad():
        for name, param in model.named_parameters():
            if name in diff:
                # In-place, so the update actually lands in the model.
                param.add_(coeff * diff[name])


def add_partial(model, diff, k, coeff=1.0, param_name='popup_scores', layer_uniformity=False):
    """Add a score-masked ``coeff * diff[name]`` in place to the model's parameters.

    For each parameter named in ``diff`` whose name contains ``param_name``,
    the update is multiplied by ``GetSubnet.apply(|popup_scores|, layer_k,
    'weight')``, so masked-out slots are left unchanged. Parameters named in
    ``diff`` without ``param_name`` in their name receive the update unmasked.

    Args:
        model (torch.nn.Module): Module whose parameters are updated in place.
        diff (Mapping[str, torch.Tensor]): Maps parameter name to an update
            tensor of matching shape, e.g. the output of ``diff_partial``.
        k: Global k passed to ``GetSubnet``. It is also used to derive each
            layer's lower bound when ``layer_uniformity`` is False.
        coeff (float): Scale applied to each update (negate it to restore).
        param_name (str): Substring marking masked parameters. The matching
            ``popup_scores`` and ``k_score`` entries are found by replacing it
            in the parameter name.
        layer_uniformity (bool): If True, every layer uses ``k``. Otherwise
            each layer uses ``rate_act_func(k_score, 0.1 * k)``.

    NOTE(review): the mask is recomputed from the model's *current* scores.
    When ``param_name == 'popup_scores'``, a restore call therefore sees the
    perturbed scores and may not exactly undo the perturbation. Confirm
    whether that is acceptable.
    """
    # state_dict() returns detached tensors that share storage with the live
    # parameters, so one snapshot taken up front reflects the in-place
    # updates made below.
    state_dict = model.state_dict()
    # Derived from the *global* k, once; the per-layer value goes into a
    # separate local.
    k_min = k * 0.1
    with torch.no_grad():
        for name, param in model.named_parameters():
            if name not in diff:
                continue
            if param_name not in name:
                # No score mask applies to this entry: add it unmasked.
                param.add_(coeff * diff[name])
                continue

            popup_scores = state_dict[name.replace(param_name, 'popup_scores')]
            layer_k = k
            # if harp: per-layer k from the layer's k_score
            if not layer_uniformity:
                k_score = state_dict[name.replace(param_name, 'k_score')]
                layer_k = rate_act_func(k_score, k_min)

            adj = GetSubnet.apply(popup_scores.abs(), layer_k, 'weight')
            # multiplying by adj gives 0 to pruned values
            param.add_(coeff * diff[name] * adj)
                    



class TradesS2AP(object):
    """Compute and apply S2AP parameter perturbations driven by a TRADES objective.

    A ``proxy`` copy of ``model`` takes one optimiser step on the negated
    TRADES loss. The resulting norm-rescaled step (``diff``) is then added to
    ``model`` by :meth:`perturb` (scaled by ``gamma``) and subtracted again by
    :meth:`restore`.
    """

    # exp_mode groups that select a perturbation strategy.
    _FINETUNE_MODES = ('harp_finetune', 'harp_finetune_lwm', 'score_finetune')
    _PRUNE_MODES = ('harp_prune', 'score_prune', 'rate_prune')

    def __init__(self, model, proxy, proxy_optim, gamma, exp_mode, perturb_weights, freeze_bn, sparse_s2ap, k_s2ap, misalign_pert):
        super(TradesS2AP, self).__init__()
        self.model = model
        self.proxy = proxy
        self.proxy_optim = proxy_optim
        # Scale applied to diff in perturb()/restore().
        self.gamma = gamma
        self.freeze_bn = freeze_bn
        # In *_prune modes: True -> masked (top-k) perturbation, False -> dense.
        self.sparse_s2ap = sparse_s2ap
        self.perturb_weights = perturb_weights
        self.exp_mode = exp_mode
        # Used instead of the caller's k when misalign_pert is set.
        self.k_s2ap = k_s2ap
        self.misalign_pert = misalign_pert
        # True -> one k for every layer; False -> per-layer k from k_score.
        self.layer_uniformity = self.exp_mode in ['harp_finetune_lwm', 'score_finetune', 'score_prune']
        # Which state-dict entries are perturbed.
        if self.exp_mode in ['harp_prune', 'score_prune', 'rate_prune'] and not perturb_weights:
            self.param_name = 'popup_scores'
        else:
            self.param_name = 'weight'

    def _uses_partial(self, exp_mode):
        """Return True for a masked (top-k) perturbation, False for a dense one,
        or None if ``exp_mode`` is not supported."""
        if exp_mode == 'pretrain':
            return False
        if exp_mode in self._FINETUNE_MODES:
            return True
        if exp_mode in self._PRUNE_MODES:
            return bool(self.sparse_s2ap)
        return None

    def _compute_diff(self, exp_mode, k):
        """Build the perturbation from the model/proxy difference for ``exp_mode``."""
        partial = self._uses_partial(exp_mode)
        if partial is None:
            raise ValueError(f"Unsupported exp_mode for S2AP: {exp_mode!r}")
        if partial:
            return diff_partial(self.model, self.proxy, k=k, param_name=self.param_name,
                                layer_uniformity=self.layer_uniformity)
        return diff_all(self.model, self.proxy, param_name=self.param_name)

    def _apply(self, diff, exp_mode, k, coeff):
        """Add ``coeff * diff`` to the model; unknown ``exp_mode`` is a no-op."""
        # change k if you want to misalign perturbation
        if self.misalign_pert:
            k = self.k_s2ap
        partial = self._uses_partial(exp_mode)
        if partial is None:
            return
        if partial:
            add_partial(self.model, diff, coeff=coeff, k=k, param_name=self.param_name,
                        layer_uniformity=self.layer_uniformity)
        else:
            add_all(self.model, diff, coeff=coeff, param_name=self.param_name)

    def _train_proxy_weights(self):
        """Freeze the proxy's popup_scores/k_score and unfreeze its weight/bias."""
        freeze_vars(model=self.proxy, var_name='popup_scores', freeze_bn=self.freeze_bn)
        freeze_vars(model=self.proxy, var_name='k_score', freeze_bn=self.freeze_bn)
        unfreeze_vars(model=self.proxy, var_name='weight')
        unfreeze_vars(model=self.proxy, var_name='bias')

    def _train_proxy_scores(self):
        """Freeze the proxy's weight/bias and unfreeze its popup_scores/k_score."""
        freeze_vars(model=self.proxy, var_name='weight', freeze_bn=self.freeze_bn)
        freeze_vars(model=self.proxy, var_name='bias', freeze_bn=self.freeze_bn)
        unfreeze_vars(model=self.proxy, var_name='popup_scores')
        unfreeze_vars(model=self.proxy, var_name='k_score')

    def calc_s2ap(self, inputs_adv, inputs_clean, targets, beta, exp_mode, k):
        """Take one ascent step on the proxy and return the resulting perturbation.

        Args:
            inputs_adv: Adversarial inputs for the robust (KL) term.
            inputs_clean: Clean inputs for the natural term and the KL target.
            targets: Labels for the cross-entropy term.
            beta: Weight of the KL term.
            exp_mode: Experiment mode selecting dense vs. masked perturbation.
            k: k forwarded to diff_partial (replaced by k_s2ap if misalign_pert).

        Returns:
            OrderedDict mapping state-dict key to perturbation tensor.

        Raises:
            ValueError: If ``exp_mode`` is not supported.
        """
        # load original state_dict into proxy model
        self.proxy.load_state_dict(self.model.state_dict())
        self.proxy.train()

        # change k if you want to misalign perturbation
        if self.misalign_pert:
            k = self.k_s2ap

        if self.perturb_weights:
            # Must run BEFORE the forward pass. Autograd decides which leaves
            # receive gradients when the graph is built, so unfreezing the
            # weights after computing the loss would leave them without grads.
            self._train_proxy_weights()

        # Negated TRADES loss: minimising it ascends the TRADES objective.
        loss_natural = F.cross_entropy(self.proxy(inputs_clean), targets)
        loss_robust = F.kl_div(F.log_softmax(self.proxy(inputs_adv), dim=1),
                               F.softmax(self.proxy(inputs_clean), dim=1),
                               reduction='batchmean')
        loss = - 1.0 * (loss_natural + beta * loss_robust)

        self.proxy_optim.zero_grad()
        loss.backward()
        self.proxy_optim.step()

        if self.perturb_weights:
            # Put the proxy back into score-training configuration.
            self._train_proxy_scores()

        return self._compute_diff(exp_mode, k)

    def perturb(self, diff, exp_mode, k):
        """Add ``gamma * diff`` to the model."""
        self._apply(diff, exp_mode, k, coeff=1.0 * self.gamma)

    def restore(self, diff, exp_mode, k):
        """Subtract ``gamma * diff`` from the model, undoing :meth:`perturb`."""
        self._apply(diff, exp_mode, k, coeff=-1.0 * self.gamma)

    def calc_s2ap_fat(self, inputs_adv, inputs_clean, targets, beta, exp_mode, k):
        """Variant of :meth:`calc_s2ap` that ascends cross-entropy on ``inputs_adv``.

        ``inputs_clean`` and ``beta`` are unused. Unlike :meth:`calc_s2ap`,
        this variant neither applies ``misalign_pert`` nor toggles weight/score
        freezing.

        Raises:
            ValueError: If ``exp_mode`` is not supported.
        """
        # TODO: adapt for subnet...
        self.proxy.load_state_dict(self.model.state_dict())
        self.proxy.train()

        loss2 = - F.cross_entropy(self.proxy(inputs_adv), targets)

        self.proxy_optim.zero_grad()
        loss2.backward(retain_graph=True)
        self.proxy_optim.step()

        # Dense vs. masked selection reads self.sparse_s2ap (via _uses_partial).
        return self._compute_diff(exp_mode, k)