import torch
import torch.nn as nn
import torch.nn.functional as F


class LabelSmoothingCrossEntropy(nn.Module):
    """Cross-entropy with uniform label smoothing for integer class targets.

    The per-element loss is
    ``(1 - smoothing) * (-log p[target]) + smoothing * mean_c(-log p[c])``.
    By linearity this is the cross-entropy against a target distribution that
    mixes the one-hot label with a uniform distribution over the classes. The
    per-element losses are averaged over every element (all non-class dims).

    Args:
        smoothing: Weight of the uniform component; must be ``< 1.0``.

    Raises:
        ValueError: If ``smoothing`` is not ``< 1.0``.
    """

    def __init__(self, smoothing=0.1):
        super().__init__()
        # Explicit check rather than ``assert`` so validation survives ``python -O``.
        if not smoothing < 1.0:
            raise ValueError(f"smoothing must be < 1.0, got {smoothing}")
        self.smoothing = smoothing
        self.confidence = 1. - smoothing

    def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Return the mean label-smoothed cross-entropy.

        Args:
            x: Logits with the class dimension last, e.g. ``(N, C)`` or
                ``(N, ..., C)``.
            target: Integer (int64, as required by ``gather``) class indices
                shaped like ``x`` without its last dimension.

        Returns:
            Scalar tensor with the mean loss.
        """
        logprobs = F.log_softmax(x, dim=-1)
        # Insert/remove the index axis at the *last* position so it lines up
        # with the class axis used by log_softmax/gather; dim 1 is only the
        # class axis for 2-D logits.
        nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(-1)).squeeze(-1)
        smooth_loss = -logprobs.mean(dim=-1)
        loss = self.confidence * nll_loss + self.smoothing * smooth_loss
        return loss.mean()


class SoftTargetCrossEntropy(nn.Module):
    """Cross-entropy between logits and a soft (per-class weighted) target.

    For each position the loss is ``-sum_c target[..., c] * log_softmax(x)[..., c]``
    over the last dimension. The result is averaged over all remaining elements.
    ``target`` is presumably a probability distribution (e.g. mixup labels); this
    is not validated here.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        log_probs = F.log_softmax(x, dim=-1)
        weighted = target.neg() * log_probs
        per_position = weighted.sum(dim=-1)
        return per_position.mean()


class ConsistencyLoss(nn.Module):
    """Cosine-embedding loss that pulls ``input1`` and ``input2`` into alignment.

    Wraps ``F.cosine_embedding_loss`` with a constant target of ``1`` (every pair
    is treated as similar), so each pair contributes ``1 - cos(input1, input2)``.
    Per the PyTorch definition, ``margin`` only affects pairs with target ``-1``,
    so it has no effect with this fixed target. It is kept for interface
    compatibility.

    Args:
        margin: Forwarded to ``F.cosine_embedding_loss``.
        reduction: ``'mean'``, ``'sum'`` or ``'none'``; forwarded unchanged.
    """

    def __init__(self, margin=0., reduction='mean'):
        super().__init__()
        self.margin = margin
        # Created on CPU (not hard-coded to CUDA) so construction works on any
        # host; forward() moves it to the inputs' device. Shape (1,) broadcasts
        # over the batch.
        self.target = torch.ones(1)
        self.reduction = reduction

    def forward(self, input1: torch.Tensor, input2: torch.Tensor) -> torch.Tensor:
        """Return the cosine-embedding loss between ``input1`` and ``input2``.

        Args:
            input1: Batch of embeddings, assumed ``(N, D)``.
            input2: Batch of embeddings with the same shape as ``input1``.
        """
        # .to() is a no-op when already on the right device.
        target = self.target.to(device=input1.device)
        return F.cosine_embedding_loss(
            input1, input2, target,
            margin=self.margin, reduction=self.reduction)


class SPRegularization(nn.Module):
    """Squared-distance penalty pulling ``target_model`` towards ``source_model``.

    Computes ``0.5 * sum_k ||theta_target[k] - theta_source[k]||^2``. The sum
    runs over parameters whose names appear in both models; parameters present
    in only one model are ignored. This looks like the L2-SP regularizer —
    confirm against the training code.

    Args:
        source_model: Reference model whose parameters act as fixed anchors.
        target_model: Model being regularized.
    """

    def __init__(self, source_model: nn.Module, target_model: nn.Module):
        super().__init__()
        self.target_model = target_model
        self.source_model = source_model

    def forward(self) -> torch.Tensor:
        """Return the scalar penalty on the target model's parameter device."""
        # Anchors are detached so the penalty never sends gradients into the
        # source model.
        source_params = {
            name: param.detach()
            for name, param in self.source_model.named_parameters()
        }

        # Accumulate on the target model's device instead of hard-coding
        # "cuda", so this also works on CPU or a non-default GPU.
        first_param = next(self.target_model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")
        output = torch.zeros((), device=device)

        # Match by name rather than by position, so an extra or reordered
        # parameter does not silently drop every later term.
        for name, param in self.target_model.named_parameters():
            anchor = source_params.get(name)
            if anchor is None:
                continue
            # Sum of squares directly (same value as norm() ** 2) avoids the
            # sqrt, whose derivative is singular when param == anchor.
            output = output + 0.5 * (param - anchor).pow(2).sum()
        return output
