import torch
import torch.nn.functional as F


class DRASearchLoss(torch.nn.Module):
    """Cross-entropy loss used for DRA search and for regular training.

    In search mode ``pred`` is a pair ``(aug_logits, ori_logits)``, and either
    entry may be ``None``:

    * the augmented logits get a soft-target cross-entropy, label-smoothed when
      ``lb_smooth > 0``;
    * the original logits get plain ``CrossEntropyLoss`` (no smoothing);
    * the two terms are summed.

    Outside search mode only ``pred[0]`` is used, with plain cross-entropy.
    """

    def __init__(self):
        super().__init__()
        self.ce = torch.nn.CrossEntropyLoss()

    @staticmethod
    def _soft_target_ce(logits, target, lb_smooth):
        """Return the mean cross-entropy of ``logits`` against (smoothed) one-hot targets.

        Classes lie along the last dimension of ``logits``. ``target`` holds
        integer class indices of shape ``logits.shape[:-1]``.
        """
        num_classes = logits.shape[-1]
        # Allocate directly on the logits' device rather than on CPU followed by a copy.
        y = torch.zeros(logits.shape, device=logits.device).scatter_(-1, target[..., None], 1.0)
        if lb_smooth > 0:
            # Spread the smoothing mass over the class axis. That is the last
            # dimension, the same axis used by scatter_ above and by the sum below.
            y = (1 - lb_smooth) * y + lb_smooth / num_classes
        return -(F.log_softmax(logits, dim=-1) * y).sum(dim=-1).mean()

    def forward(self, pred, target, search=False, lb_smooth=0.0):
        """Compute the loss.

        Args:
            pred: In search mode, a pair ``(aug_logits, ori_logits)`` where either
                entry may be ``None``. Otherwise, a sequence whose first entry
                holds the logits.
            target: Integer class indices.
            search: If True, use the search-mode loss described on the class.
            lb_smooth: Label-smoothing factor. It is applied to the augmented
                term only.

        Returns:
            The loss as a tensor.

        Raises:
            ValueError: In search mode, when both ``pred[0]`` and ``pred[1]`` are ``None``.
        """
        if not search:
            return self.ce(pred[0], target)

        aug_pred, ori_pred = pred[0], pred[1]
        if aug_pred is None and ori_pred is None:
            raise ValueError("DRASearchLoss: pred[0] and pred[1] cannot both be None in search mode")

        aug_loss = 0
        ori_loss = 0
        if aug_pred is not None:
            aug_loss = self._soft_target_ce(aug_pred, target, lb_smooth)
        if ori_pred is not None:
            # The original-prediction term deliberately uses hard targets (no smoothing).
            ori_loss = self.ce(ori_pred, target)
        return aug_loss + ori_loss


class DRASearchLossTied(torch.nn.Module):
    """Search loss for one or two tied predictions, with a feature-consistency penalty.

    Each prediction is a sequence indexed by position:

    * index 0 holds the logits;
    * index 4 holds a tensor compared between the two predictions. Presumably
      this is a feature/embedding output of the model; confirm against the
      caller.

    In search mode:

    * the augmentation loss is the soft-target cross-entropy of ``pred1[0]``
      (label-smoothed when ``lb_smooth > 0``);
    * if ``pred2`` is given, the augmentation loss is averaged with that of
      ``pred2[0]``, and ``l2_ratio`` times the MSE between the normalized
      ``pred1[4]`` and ``pred2[4]`` is added.

    Outside search mode the loss is plain cross-entropy on ``pred1[0]``.
    """

    # Lower bound on the norm used for normalization, so an all-zero tensor
    # yields zeros instead of NaN (0/0).
    norm_eps = 1e-12

    def __init__(self, l2_ratio=20.):
        super().__init__()
        self.ce = torch.nn.CrossEntropyLoss()
        self.mse = torch.nn.MSELoss(reduction="mean")
        self.l2_ratio = l2_ratio

    @staticmethod
    def _soft_targets(logits, target, lb_smooth):
        """Return one-hot targets for ``target``, label-smoothed when ``lb_smooth > 0``.

        Classes lie along the last dimension of ``logits``.
        """
        num_classes = logits.shape[-1]
        # Allocate directly on the logits' device rather than on CPU followed by a copy.
        y = torch.zeros(logits.shape, device=logits.device).scatter_(-1, target[..., None], 1.0)
        if lb_smooth > 0:
            # The class axis is the last dimension, the same axis used by scatter_ above.
            y = (1 - lb_smooth) * y + lb_smooth / num_classes
        return y

    @staticmethod
    def _soft_ce(logits, y):
        """Return the mean cross-entropy of ``logits`` against soft targets ``y``."""
        return -(F.log_softmax(logits, dim=-1) * y).sum(dim=-1).mean()

    def _global_normalize(self, x):
        """Divide ``x`` by the 2-norm of the whole flattened tensor.

        The norm is taken over the whole tensor, not per sample.
        NOTE(review): if per-sample normalization was intended, this should be
        ``F.normalize(x, dim=-1)``; confirm.
        """
        return x / torch.linalg.norm(x).clamp_min(self.norm_eps)

    def forward(self, pred1, pred2, target, search=False, lb_smooth=0.0):
        """Compute the loss.

        Args:
            pred1: Output sequence of the first pass (index 0: logits, index 4:
                tensor used in the consistency term).
            pred2: Output sequence of the second pass, or ``None``.
            target: Integer class indices.
            search: If True, use the search-mode loss described on the class.
            lb_smooth: Label-smoothing factor for the augmentation loss.

        Returns:
            The loss as a tensor.
        """
        if not search:
            return self.ce(pred1[0], target)

        # Targets are built once from pred1[0]'s shape and reused for pred2[0].
        y = self._soft_targets(pred1[0], target, lb_smooth)
        aug_loss = self._soft_ce(pred1[0], y)
        extra_loss = 0.

        if pred2 is not None:
            aug_loss = (aug_loss + self._soft_ce(pred2[0], y)) / 2.
            extra_loss = self.l2_ratio * self.mse(self._global_normalize(pred1[4]),
                                                  self._global_normalize(pred2[4]))

        return aug_loss + extra_loss
