import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import torch.optim as optim



def squared_l2_norm(x):
    """Return the squared L2 norm of the whole of ``x`` as a 1-element tensor.

    The tensor is flattened into a single row and every element is summed, so
    this is a norm over the entire tensor, not one norm per sample.
    """
    single_row = x.view(1, -1)
    return single_row.pow(2).sum(dim=1)


def l2_norm(x):
    """Return the L2 norm of ``x``: the square root of ``squared_l2_norm(x)``."""
    return torch.sqrt(squared_l2_norm(x))

def accuracy(true, preds):
    """
    Computes multi-class accuracy.
    Arguments:
        true (torch.Tensor): true labels, one per sample.
        preds (torch.Tensor): per-class scores (e.g. logits), one row per
            sample with classes along dim 1.
    Returns:
        Fraction of samples whose highest-scoring class equals the label, as a
        Python float (nan for an empty batch).
    """
    # softmax is monotonic, so taking argmax of the raw scores selects the same
    # class without an extra pass and without ties introduced by rounding inside
    # softmax for nearly-equal scores.
    predicted = preds.argmax(dim=1)
    return (predicted == true).float().mean().item()

def _snr_ns2(signal, noise):
    """Mean of noise / (signal**2 + 1), elementwise."""
    return torch.mean(noise / (torch.pow(torch.abs(signal), 2) + 1))


def _snr_n2sv(signal, noise):
    """Mean over samples of var(noise) / var(signal), variances taken along dim 1."""
    return torch.mean(torch.var(noise, dim=1) / torch.var(signal, dim=1))


def _snr_ns2_n2sv(signal, noise):
    """'ns2' term plus half of the 'n2sv' term."""
    return _snr_ns2(signal, noise) + 0.5 * _snr_n2sv(signal, noise)


def _snr_n2(signal, noise):
    """Mean squared noise; the signal is not used."""
    return torch.mean(torch.pow(noise, 2))


def _snr_nsscale(signal, noise):
    """Mean noise divided by the signal's range over the whole batch (inf/nan if constant)."""
    return torch.mean(noise / (torch.max(signal) - torch.min(signal)))


def _snr_s(signal, noise):
    """Mean absolute signal; the noise is not used."""
    return torch.mean(torch.abs(signal))


# Per-layer penalty for each supported ``snr_type``.
_SNR_TERMS = {
    'ns2': _snr_ns2,
    'ns2+n2sv': _snr_ns2_n2sv,
    'n2': _snr_n2,
    'nsscale': _snr_nsscale,
    's': _snr_s,
    'n2sv': _snr_n2sv,
}

_LAYER_WEIGHT_TYPES = ("exp", "muln", "sum")


def _layer_weight(weight_type, index, is_last, base):
    """Weight applied to the term of 1-based layer ``index``."""
    if weight_type == "exp":
        return base ** index
    if weight_type == "muln":
        return index
    if weight_type == "sum":
        # Plain sum, except the last layer counts ten times.
        return 10 if is_last else 1
    raise NameError("layer_snr_weight_type is illegal")


def _centre_by_median(output):
    """Subtract each sample's median along dim 1 (broadcast over that dim)."""
    # NOTE(review): torch.median(..., dim=...) returns indices, which has no
    # deterministic CUDA implementation and raises while deterministic mode is
    # on, so determinism is switched off here. This setting is process-wide and
    # is not restored afterwards -- confirm that is acceptable.
    torch.use_deterministic_algorithms(False)
    median_values = torch.median(output, dim=1, keepdim=True)[0]
    return output - median_values


def get_snr_loss(layer_outputs_natural, layer_outputs_robust, snr_type, layer_snr_weight_type, base):
    """Compute a layer-weighted noise/signal penalty between clean and adversarial activations.

    For each pair of layer outputs, the clean ("signal") and adversarial
    activations are flattened per sample (dim 0 is treated as the batch),
    ``noise = |signal - robust|`` is formed, and a per-layer term is computed
    according to ``snr_type``:

    * ``'ns2'``      -- mean(noise / (signal**2 + 1))
    * ``'n2sv'``     -- mean over samples of var(noise) / var(signal)
    * ``'ns2+n2sv'`` -- ``'ns2'`` + 0.5 * ``'n2sv'``
    * ``'n2'``       -- mean(noise**2)
    * ``'nsscale'``  -- mean(noise / (max(signal) - min(signal))), range over the batch
    * ``'s'``        -- mean(|signal|)

    Before flattening, the clean output of the last layer is centred by
    subtracting its per-sample median along dim 1.

    Layer ``i`` (1-based) is weighted by ``base ** i`` (``'exp'``), by ``i``
    (``'muln'``), or by 1 with the last layer weighted 10 (``'sum'``).

    Arguments:
        layer_outputs_natural (list[torch.Tensor]): per-layer clean activations.
        layer_outputs_robust (list[torch.Tensor]): matching adversarial activations.
        snr_type (str): one of the keys listed above.
        layer_snr_weight_type (str): ``'exp'``, ``'muln'`` or ``'sum'``.
        base (float): exponent base for ``'exp'`` weighting.
    Returns:
        ``(snr_loss, nsr, print_signal_batch, print_noise_batch)``: the weighted
        sum (``0.0`` when there are no layers), the unweighted per-layer terms,
        and per-layer tensors of per-sample ``mean(signal)**2`` and
        ``mean(noise)**2``. None of these are detached from the autograd graph,
        so holding on to them keeps the graph alive.
    Raises:
        NameError: for an unknown ``snr_type`` or ``layer_snr_weight_type``
            (NameError rather than ValueError, for compatibility with callers).
    """
    # Validate up front so bad configuration is reported even with no layers.
    if snr_type not in _SNR_TERMS:
        raise NameError("snr_type is illegal")
    if layer_snr_weight_type not in _LAYER_WEIGHT_TYPES:
        raise NameError("layer_snr_weight_type is illegal")
    layer_term = _SNR_TERMS[snr_type]
    num_layers = len(layer_outputs_natural)

    snr_loss = 0.0
    nsr = []
    print_signal_batch = []
    print_noise_batch = []
    pairs = zip(layer_outputs_natural, layer_outputs_robust)
    for index, (output_natural, output_robust) in enumerate(pairs, start=1):
        is_last = index == num_layers
        if is_last:
            output_natural = _centre_by_median(output_natural)
        # reshape (not view) so non-contiguous hook outputs are accepted too.
        signal = output_natural.reshape(output_natural.size(0), -1)
        robust = output_robust.reshape(output_robust.size(0), -1)
        noise = torch.abs(signal - robust)
        print_signal_batch.append(torch.mean(signal, dim=-1).pow(2))
        print_noise_batch.append(torch.mean(noise, dim=-1).pow(2))
        current_snr_loss = layer_term(signal, noise)
        snr_loss += current_snr_loss * _layer_weight(layer_snr_weight_type, index, is_last, base)
        nsr.append(current_snr_loss)
    return snr_loss, nsr, print_signal_batch, print_noise_batch

def _forward_with_hooks(model, inputs, layer_names):
    """Run ``model(inputs)`` while recording outputs of submodules named in ``layer_names``.

    Returns ``(logits, outputs)``; ``outputs`` lists the captured module outputs
    in the order the hooks fired during the forward pass. The hooks are always
    removed, even if the forward pass raises.
    """
    outputs = []

    def hook_fn(module, input, output):
        outputs.append(output)

    hooks = [layer.register_forward_hook(hook_fn)
             for name, layer in model.named_modules() if name in layer_names]
    try:
        logits = model(inputs)
    finally:
        for hook in hooks:
            hook.remove()
    return logits, outputs


def trades_loss(model,
                x_natural,
                y,
                optimizer,
                step_size=0.003,
                epsilon=0.031,
                perturb_steps=10,
                beta=1.0,
                distance='l_inf', args=None):
    """Craft adversarial examples and compute the (optionally SNR-augmented) TRADES loss.

    Adversarial inputs are generated with the model in eval mode by maximising
    the KL divergence between predictions on ``x_adv`` and on ``x_natural``:
    signed-gradient steps projected onto an L-inf ball (``distance='l_inf'``),
    or SGD on a perturbation renormed to an L2 ball (``distance='l_2'``). Any
    other ``distance`` keeps only the small random start. The model is then put
    back in train mode and ``optimizer.zero_grad()`` is called.

    Arguments:
        model (nn.Module): classifier returning scores with classes along dim 1.
        x_natural (torch.Tensor): clean input batch; adversarial inputs are
            clamped to [0, 1], so inputs are assumed to lie in that range.
        y (torch.Tensor): targets for ``F.cross_entropy`` (presumably class indices).
        optimizer: training optimizer; only ``zero_grad()`` is called here.
        step_size, epsilon, perturb_steps: attack step size, radius and number
            of iterations (``perturb_steps`` must be > 0 for ``'l_2'``).
        beta (float): weight of the robust term relative to the natural loss.
        distance (str): ``'l_inf'`` or ``'l_2'`` threat model.
        args: configuration namespace. Always reads ``use_snr`` and ``use_adv``;
            with ``use_snr`` also ``snr_layers``, ``snr_type``,
            ``layer_snr_weight_type``, ``base``, ``trade_weight`` and
            ``snr_weight``. ``None`` behaves like ``use_snr=False, use_adv=False``.

    Returns:
        ``(loss, batch_metrics)`` when SNR is disabled, otherwise
        ``(loss, batch_metrics, print_signal_batch, print_noise_batch)``, where
        ``batch_metrics`` additionally holds ``snr_loss`` and one
        ``layer<i>_snr`` entry per hooked layer output.
    """
    # Summed KL, normalised by batch_size below; reduction='sum' is the
    # non-deprecated spelling of size_average=False.
    criterion_kl = nn.KLDivLoss(reduction='sum')
    use_snr = args is not None and args.use_snr
    use_adv = args is not None and args.use_adv

    model.eval()
    batch_size = len(x_natural)
    # Random start next to x_natural; randn_like keeps the noise on x_natural's device.
    x_adv = x_natural.detach() + 0.001 * torch.randn_like(x_natural)
    if distance in ('l_inf', 'l_2'):
        # The clean prediction is a fixed KL target during the attack (eval mode,
        # weights never stepped), so compute it once without building a graph.
        with torch.no_grad():
            target_natural = F.softmax(model(x_natural), dim=1)

    if distance == 'l_inf':
        for _ in range(perturb_steps):
            x_adv.requires_grad_()
            with torch.enable_grad():
                loss_kl = criterion_kl(F.log_softmax(model(x_adv), dim=1), target_natural)
            grad = torch.autograd.grad(loss_kl, [x_adv])[0]
            x_adv = x_adv.detach() + step_size * torch.sign(grad.detach())
            x_adv = torch.min(torch.max(x_adv, x_natural - epsilon), x_natural + epsilon)
            x_adv = torch.clamp(x_adv, 0.0, 1.0)
    elif distance == 'l_2':
        delta = (0.001 * torch.randn_like(x_natural)).requires_grad_()
        optimizer_delta = optim.SGD([delta], lr=epsilon / perturb_steps * 2)
        # (batch, 1, 1, ...) so per-sample norms broadcast over any input rank.
        norm_shape = (-1,) + (1,) * (x_natural.dim() - 1)

        for _ in range(perturb_steps):
            adv = x_natural + delta

            optimizer_delta.zero_grad()
            with torch.enable_grad():
                loss = (-1) * criterion_kl(F.log_softmax(model(adv), dim=1), target_natural)
            loss.backward()
            # Normalise each sample's gradient to unit L2 norm.
            grad_norms = delta.grad.view(batch_size, -1).norm(p=2, dim=1)
            delta.grad.div_(grad_norms.view(norm_shape))
            # Rows with a zero gradient became nan above; replace them with noise.
            if (grad_norms == 0).any():
                delta.grad[grad_norms == 0] = torch.randn_like(delta.grad[grad_norms == 0])
            optimizer_delta.step()

            # Project: keep x_natural + delta in [0, 1] and ||delta||_2 <= epsilon.
            delta.data.add_(x_natural)
            delta.data.clamp_(0, 1).sub_(x_natural)
            delta.data.renorm_(p=2, dim=0, maxnorm=epsilon)
        x_adv = (x_natural + delta).detach()
    else:
        x_adv = torch.clamp(x_adv, 0.0, 1.0)
    model.train()

    x_adv = torch.clamp(x_adv, 0.0, 1.0).detach()
    optimizer.zero_grad()

    if not use_snr:
        logits = model(x_natural)
        logits_adv = model(x_adv)
        loss_natural = F.cross_entropy(logits, y)
        if not use_adv:
            loss_robust = (1.0 / batch_size) * criterion_kl(F.log_softmax(logits_adv, dim=1),
                                                             F.softmax(logits, dim=1))
        else:
            # NOTE(review): cross_entropy already averages over the batch, so the
            # extra 1/batch_size scales it down a second time, and CE + 0.01 * CE
            # is simply 1.01 * CE. Kept unchanged to preserve training behaviour.
            loss_robust = (1.0 / batch_size) * (F.cross_entropy(logits_adv, y)
                                                + 0.01 * F.cross_entropy(logits_adv, y))
        loss = loss_natural + beta * loss_robust
        batch_metrics = {'loss': loss.item(), 'clean_acc': accuracy(y, logits.detach()),
                         'loss_robust': loss_robust.item(), 'loss_natural': loss_natural.item(),
                         'adversarial_acc': accuracy(y, logits_adv.detach())}
        return loss, batch_metrics

    # Adversarial pass first, then clean, each capturing the hooked layer outputs.
    logits_adv, layer_robust_outputs = _forward_with_hooks(model, x_adv, args.snr_layers)
    logits, layer_natural_outputs = _forward_with_hooks(model, x_natural, args.snr_layers)
    loss_natural = F.cross_entropy(logits, y)
    # Lower bound on the softmax KL target (presumably to keep log(target) finite).
    eps = 1e-8
    if not use_adv:
        loss_robust = (1.0 / batch_size) * criterion_kl(F.log_softmax(logits_adv, dim=1),
                                                         torch.clamp(F.softmax(logits, dim=1), min=eps))
    else:
        # NOTE(review): same double normalisation / 1.01 * CE as above -- confirm intent.
        loss_robust = (1.0 / batch_size) * (F.cross_entropy(logits_adv, y)
                                            + 0.01 * F.cross_entropy(logits_adv, y))
    snr_loss, nsr, print_signal_batch, print_noise_batch = get_snr_loss(
        layer_natural_outputs, layer_robust_outputs, args.snr_type,
        args.layer_snr_weight_type, args.base)
    loss = args.trade_weight * (loss_natural + beta * loss_robust) + args.snr_weight * snr_loss
    batch_metrics = {'loss': loss.item(), 'clean_acc': accuracy(y, logits.detach()),
                     'loss_robust': loss_robust.item(), 'loss_natural': loss_natural.item(),
                     # float() also handles the plain 0.0 returned when no layer was hooked.
                     'snr_loss': float(snr_loss),
                     'adversarial_acc': accuracy(y, logits_adv.detach())}
    # One entry per hooked layer instead of a fixed, IndexError-prone set of seven.
    for layer_index, layer_nsr in enumerate(nsr, start=1):
        batch_metrics['layer%d_snr' % layer_index] = layer_nsr.item()
    return loss, batch_metrics, print_signal_batch, print_noise_batch
    
    
