import torch
import numpy as np
# import contextlib
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
# import numpy as np


class LabelSmoothingLoss(nn.Module):
    """Cross-entropy against a label-smoothed target distribution.

    The target class receives probability ``1 - alpha_sm``; the remaining
    ``alpha_sm`` mass is split evenly over the other ``n_cls - 1`` classes.
    """

    def __init__(self, args, dim=-1):
        """
        :param args: object exposing ``n_cls`` (number of classes; ``n_cls - 1``
            is used as a divisor, so it must not be 1) and ``alpha_sm``
            (total probability mass moved away from the target class)
        :param dim: dimension of ``pred`` that indexes the classes
        """
        super().__init__()
        self.n_cls = args.n_cls
        smoothing = args.alpha_sm
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing
        self.dim = dim

    def forward(self, pred, target):
        """
        :param pred: raw scores with the classes along ``self.dim``
        :param target: integer class indices; same shape as ``pred`` with the
            class dimension removed
        :return: scalar tensor, the smoothed cross-entropy averaged over all
            non-class positions
        """
        log_prob = pred.log_softmax(dim=self.dim)
        with torch.no_grad():
            # Every class starts with the off-target share ...
            true_dist = torch.full_like(log_prob, self.smoothing / (self.n_cls - 1))
            # ... and the target class is overwritten with the confidence.
            # Scatter along the same dimension used for the softmax and the
            # sum below, so inputs with more than two dimensions work too.
            true_dist.scatter_(self.dim, target.unsqueeze(self.dim), self.confidence)
        return torch.mean(torch.sum(-true_dist * log_prob, dim=self.dim))


class XentEC(nn.Module):
    """Loss built from a target-logit term and a cosine-similarity penalty.

    ``args.i_pos_fn`` picks the element-wise function applied to ``cossim``
    (0: ``torch.abs``, 1: ``torch.square``, 2: ``abs_smooth``) and
    ``args.i_neg_fn`` picks how the penalty is formed (0: ``neg1``,
    1: ``neg2``).
    """

    def __init__(self, args):
        """
        :param args: object exposing ``n_cls``, ``i_pos_fn`` and ``i_neg_fn``
        """
        super().__init__()
        self.n_cls = args.n_cls
        pos_choices = [torch.abs, torch.square, self.abs_smooth]
        neg_choices = [self.neg1, self.neg2]
        self.pos_fn = pos_choices[args.i_pos_fn]
        self.neg = neg_choices[args.i_neg_fn]

    def forward(self, logit, cossim, target):
        """
        :param logit: scores indexed as (sample, class)
        :param cossim: tensor handed to the selected penalty function
        :param target: integer class index per sample
        :return: mean of ``-(target logit - penalty)``
        """
        # NOTE(review): neg2 does not reduce over classes, so the subtraction
        # below broadcasts the (N, 1) target logit against it - confirm this
        # is intended.
        target_logit = logit.gather(1, target.unsqueeze(1))
        penalty = self.neg(logit, cossim)
        return torch.neg(target_logit - penalty).mean()

    def neg1(self, logit, cossim):
        """Log-sum-exp over dim 1 of the transformed ``cossim`` (``logit`` is unused)."""
        return torch.logsumexp(self.pos_fn(cossim), dim=1, keepdim=True)

    def neg2(self, logit, cossim):
        """Transformed ``cossim`` weighted by gradient-free softmax probabilities."""
        transformed = self.pos_fn(cossim)
        with torch.no_grad():
            # The probabilities act as constant weights: no gradient flows
            # back into ``logit`` through them.
            weights = F.softmax(logit, dim=1)
        return weights * transformed

    def abs_smooth(self, logit, beta = 1.0):
        """Smoothed absolute value: quadratic where ``|x| < beta``, linear elsewhere.

        :param logit: input tensor
        :param beta: width of the quadratic region
        :return: ``0.5 * x**2 / beta`` inside the region, ``|x| - 0.5 * beta`` outside
        """
        magnitude = torch.abs(logit)
        inside = magnitude < beta
        quadratic = inside * (0.5 * torch.square(logit) / beta)
        linear = (~inside) * (magnitude - 0.5 * beta)
        return quadratic + linear


def mixup_criterion(criterion, pred, y_a, y_b, lam, cond=True):
    """Combine the losses for the two label sets of a mixup batch.

    :param criterion: loss callable; called as ``criterion(pred, y)`` when
        ``cond`` is true, otherwise as ``criterion(pred, y, weight)``
    :param pred: model predictions for the mixed inputs
    :param y_a: first set of targets
    :param y_b: second set of targets
    :param lam: mixing coefficient; must provide ``.view`` when ``cond`` is false
    :param cond: if true, weight the two losses by ``lam`` and ``1 - lam``
        here; if false, hand the weights (reshaped to a column) to ``criterion``
    :return: mean of the combined loss
    """
    if not cond:
        # The criterion applies the weights itself.
        term_a = criterion(pred, y_a, lam.view(-1, 1))
        term_b = criterion(pred, y_b, (1 - lam).view(-1, 1))
        return (term_a + term_b).mean()

    term_a = lam * criterion(pred, y_a)
    term_b = (1 - lam) * criterion(pred, y_b)
    return (term_a + term_b).mean()


def get_adv_x(model, x_natural, step_size=0.003, epsilon=0.031, perturb_steps=10,
              adversarial=True, distance='l_inf',):
    """Return a perturbed copy of ``x_natural`` clamped to ``[0, 1]``.

    With ``adversarial=True`` (requires ``distance='l_inf'``) the perturbation
    is found by iterated signed-gradient ascent on the KL divergence between
    the model's predictions on the perturbed and on the natural input,
    projected after every step onto the ``epsilon`` box around ``x_natural``.
    With ``adversarial=False`` (requires ``distance='l_2'``) Gaussian noise
    scaled by ``epsilon`` is added instead.

    The model is switched to eval mode while the perturbation is computed and
    is always left in train mode on return.

    :param model: module mapping inputs to logits with classes along dim 1
    :param x_natural: clean input tensor
    :param step_size: size of each signed-gradient step
    :param epsilon: radius of the l_inf box, or noise scale for ``l_2``
    :param perturb_steps: number of gradient steps
    :param adversarial: choose the gradient-based or the random perturbation
    :param distance: ``'l_inf'`` or ``'l_2'``, see above
    :return: perturbed tensor on the same device as ``x_natural``, detached
        from any autograd graph
    :raises ValueError: if ``distance`` is not supported for the chosen mode
    """
    # Reject unsupported configurations before touching the model's mode, so
    # a bad call does not leave the model stuck in eval mode.
    if adversarial and distance != 'l_inf':
        raise ValueError('No support for distance %s in adversarial '
                         'training' % distance)
    if not adversarial and distance != 'l_2':
        raise ValueError('No support for distance %s in stability '
                         'training' % distance)

    with torch.autograd.set_detect_anomaly(True):
        # define KL-loss
        criterion_kl = nn.KLDivLoss(reduction='sum')
        model.eval()  # moving to eval mode to freeze batchnorm stats

        # Detach once so neither the projection nor the returned tensor can
        # be tied to a graph hanging off the caller's input.
        x_nat = x_natural.detach()
        x_adv = x_nat + 0.  # the + 0. is for copying the tensor

        if adversarial:
            # Small random start; randn_like keeps the noise on the input's
            # device instead of assuming CUDA is available.
            x_adv += 0.001 * torch.randn_like(x_adv)

            # The prediction on the clean input does not depend on x_adv:
            # compute it once, without building a graph.
            with torch.no_grad():
                p_natural = F.softmax(model(x_nat), dim=1)

            for _ in range(perturb_steps):
                x_adv.requires_grad_()
                with torch.enable_grad():
                    loss_kl = criterion_kl(F.log_softmax(model(x_adv), dim=1),
                                           p_natural)
                grad = torch.autograd.grad(loss_kl, [x_adv])[0]
                x_adv = x_adv.detach() + step_size * torch.sign(grad.detach())
                # Project back onto the epsilon box, then onto the valid range.
                x_adv = torch.min(torch.max(x_adv, x_nat - epsilon),
                                  x_nat + epsilon)
                x_adv = torch.clamp(x_adv, 0.0, 1.0)
        else:
            x_adv = x_adv + epsilon * torch.randn_like(x_adv)

        model.train()  # moving to train mode to update batchnorm stats
        return torch.clamp(x_adv, 0.0, 1.0).detach()


def xent_float_target(logit, target, n_cls, coef=1.0):
    """Cross-entropy between logits and a (possibly soft) target distribution.

    :param logit: scores with the classes along dim 1
    :param target: either ``torch.long`` class indices, which are one-hot
        encoded with ``n_cls`` classes, or a tensor already shaped like ``logit``
    :param n_cls: number of classes used for the one-hot encoding
    :param coef: accepted for interface compatibility; not used in the result
    :return: negated mean of ``target * log_softmax(logit)`` taken over every
        element, i.e. over both samples and classes
    """
    if target.dtype == torch.long:
        soft_target = F.one_hot(target, num_classes=n_cls).float()
    else:
        soft_target = target
    assert logit.size() == soft_target.size(), (logit.shape, soft_target.shape)
    log_prob = F.log_softmax(logit, dim=1)
    weighted = soft_target * log_prob
    return -weighted.mean()


@torch.no_grad()
def onehot(targets, num_classes):
    """Origin: https://github.com/moskomule/mixup.pytorch
    convert index tensor into onehot tensor
    :param targets: index tensor
    :param num_classes: number of classes
    :return: float tensor of shape ``(targets.size(0), num_classes)``, located
        on the same device as ``targets``
    """
    # Allocate on the targets' device instead of assuming CUDA, so the helper
    # also works on CPU-only hosts and with tensors on a non-default GPU.
    oh = torch.zeros(targets.size()[0], num_classes, device=targets.device)
    return oh.scatter_(1, targets.view(-1, 1), 1)


def entropy_loss(unlabeled_logits):
    """Mean entropy of the softmax distributions defined by the logits.

    :param unlabeled_logits: scores with the classes along dim 1
    :return: entropy summed over dim 1 and averaged over dim 0
    """
    probs = F.softmax(unlabeled_logits, dim=1)
    log_probs = F.log_softmax(unlabeled_logits, dim=1)
    neg_entropy = (probs * log_probs).sum(dim=1).mean(dim=0)
    return -neg_entropy


def noise_loss(model, x_natural, y, epsilon=0.25, clamp_x=True):
    """Augmenting the input with random noise as in Cohen et al.

    :param model: callable mapping inputs to logits
    :param x_natural: clean input tensor
    :param y: integer class targets; entries equal to -1 are ignored
    :param epsilon: scale applied to the standard-normal noise
    :param clamp_x: if true, clamp the noisy input to ``[0, 1]``
    :return: cross-entropy of the model's predictions on the noisy input
    """
    perturbation = torch.randn_like(x_natural)
    noisy = x_natural + epsilon * perturbation
    if clamp_x:
        noisy = torch.clamp(noisy, 0.0, 1.0)
    return F.cross_entropy(model(noisy), y, ignore_index=-1)
