import torch

class CosFace(torch.nn.Module):
    """Additive cosine-margin (CosFace) logit adjustment.

    Subtracts ``m`` from the logit of each sample's target class and then
    multiplies every logit by ``s``. The incoming ``logits`` are presumably
    cosine similarities between normalized embeddings and class weights;
    this module does not check that.

    Args:
        s: Scale applied to all logits after the margin is subtracted.
        m: Additive margin subtracted from the target-class logit only.
    """

    def __init__(self, s=64.0, m=0.40):
        super(CosFace, self).__init__()
        self.s = s
        self.m = m

    def forward(self, logits, labels):
        """Apply the target-class margin and scale.

        Args:
            logits: 2-D ``(batch, num_classes)`` tensor of logits.
            labels: integer (int64) target class index for each row, as
                required by ``scatter_``.

        Returns:
            A new tensor with the same shape, dtype and device as ``logits``.
        """
        # Build the one-hot mask on logits' own device (zeros_like does this),
        # with no hard-coded .cuda(). The module then works on CPU and on any GPU.
        one_hot = torch.zeros_like(logits).scatter_(1, labels.view(-1, 1), 1.0)
        # Only the target-class entry of each row gets the margin.
        output = torch.where(one_hot == 1, logits - self.m, logits)
        return output * self.s



import torch
from torch import nn

class IDMMD(nn.Module):
    """Identity-center contrastive loss between two feature modalities.

    For every sample, the mean feature ("center") of all samples that share
    its identity label is computed separately in each modality. A contrastive
    loss is then applied to the pairwise distances between modality-1 centers
    and modality-2 centers. Same-identity pairs are penalized by their squared
    distance. Different-identity pairs are penalized only while inside
    ``margin``.

    Args:
        margin: Margin used by the negative (different-identity) term.
    """

    def __init__(self, margin=0.3):
        super(IDMMD, self).__init__()
        self.margin = margin

    def forward(self, modal1_inputs, modal2_inputs, targets):
        """Compute the loss for a batch of paired two-modality features.

        Args:
            modal1_inputs: 2-D ``(n, d)`` features from the first modality.
            modal2_inputs: 2-D ``(n, d)`` features from the second modality.
            targets: 1-D ``(n,)`` identity labels shared by both modalities,
                on the same device as the inputs.

        Returns:
            Scalar loss tensor on the inputs' device.
        """
        # mask[i, j] is True when samples i and j share an identity label.
        mask = targets.unsqueeze(0).eq(targets.unsqueeze(1))
        label = mask.float()

        # Row i of `weights` averages over every sample with i's identity.
        # Each row has at least the diagonal set, so the division is safe.
        # One matmul then yields the per-sample identity centers, which
        # replaces a Python loop of boolean-indexing passes. Everything stays
        # on the inputs' device; nothing is forced onto CUDA.
        weights = label / label.sum(dim=1, keepdim=True)
        centersR = weights.to(modal1_inputs) @ modal1_inputs
        centersT = weights.to(modal2_inputs) @ modal2_inputs

        dist = self.compute_dist(centersR, centersT)
        # NOTE(review): compute_dist already returns Euclidean distances, so
        # `dd` is the square root of a distance. The negative term therefore
        # compares margin against sqrt(distance). This is kept as-is to
        # preserve loss values; confirm whether it is intended.
        dd = torch.sqrt(dist + 1e-10)
        return self.compute_loss(dist, dd, label)

    def compute_loss(self, d2, d1, label):
        """Contrastive loss averaged over all pairs.

        Args:
            d2: pairwise values squared in the positive (same-label) term.
            d1: pairwise values compared against ``margin`` in the negative
                term.
            label: 1.0 for same-identity pairs, 0.0 otherwise.

        Returns:
            Scalar mean of ``label * d2**2 + (1 - label) * relu(margin - d1)**2``.
        """
        pos = label * torch.pow(d2, 2)
        neg = (1.0 - label) * torch.pow(torch.clamp(self.margin - d1, min=0.0), 2)
        return torch.mean(pos + neg)

    def compute_dist(self, inputs1, inputs2):
        """Pairwise Euclidean distances between rows of two 2-D tensors.

        Args:
            inputs1: ``(n, d)`` tensor.
            inputs2: ``(m, d)`` tensor. ``m`` may differ from ``n``.

        Returns:
            ``(n, m)`` tensor where entry ``[i, j]`` is ``||inputs1[i] - inputs2[j]||``.
        """
        sq1 = torch.pow(inputs1, 2).sum(dim=1, keepdim=True)      # (n, 1)
        sq2 = torch.pow(inputs2, 2).sum(dim=1, keepdim=True).t()  # (1, m)
        # ||a||^2 + ||b||^2 - 2 a.b, computed via broadcasting plus an in-place addmm.
        dist = sq1 + sq2
        dist.addmm_(mat1=inputs1, mat2=inputs2.t(), beta=1, alpha=-2)
        # Clamp before sqrt: round-off can make the expansion slightly negative.
        return dist.clamp(min=1e-12).sqrt()

