import math
import torch
import torch.nn.functional as F
from torch import autograd, nn

class AMMLoss(nn.Module):
    '''
    Affinity Metric Margin Loss, composed of a positive and a negative term.

    Positive term: for every anchor row i, the sum over its positives j of
    -log softmax(affinity[i])[j]; averaged over rows.
    Negative term: for every anchor row, a margin ranking loss
    max(0, neg - hardest_pos + margin) between the anchor's least-affine
    positive and each of its negatives; averaged over negatives, then rows.

    Args:
        part: which term(s) to return -- 'both' (pos + neg), 'pos' or 'neg'.
        margin: margin of the ranking loss used by the negative term.
        **kwargs: accepted and ignored.

    Raises:
        ValueError: if ``part`` is not one of 'both'/'pos'/'neg'.
    '''
    def __init__(self, part='both', margin=0.9, **kwargs):
        super(AMMLoss, self).__init__()
        self.name = 'amm_loss'
        # Explicit raise rather than `assert`, which is stripped under `python -O`.
        if part not in ('both', 'pos', 'neg'):
            raise ValueError("The value of 'part' must be the one of 'both'/'pos'/'neg'!")
        self.part = part
        self.margin = margin
        self.ranking_loss = nn.MarginRankingLoss(margin=self.margin, reduction='none')

    @staticmethod
    def _masks(affinity, labels):
        '''Return boolean (N, N) masks ``(m_pos, m_neg)``.

        m_pos[i, j] is True when j has the same label as i and j != i;
        m_neg[i, j] is True when j has a different label from i.
        '''
        N = affinity.shape[0]
        Y = labels.unsqueeze(1)
        same = Y == Y.t()
        # Build the identity on the same device as the label comparison
        # (previously hard-coded `.cuda()`, which broke CPU inputs and
        # inputs living on a non-default GPU).
        not_self = ~torch.eye(N, N, dtype=torch.bool, device=same.device)
        return same & not_self, ~same

    def _pos_loss(self, affinity, m_pos):
        '''Mean over rows of the summed -log softmax probability of positives.'''
        N = affinity.shape[0]
        # log_softmax is the numerically stable form of log(softmax(x)); it does
        # not underflow to log(0) = -inf for strongly negative affinities.
        # The softmax runs over the whole row, including the self-affinity.
        log_prob = F.log_softmax(affinity, dim=-1)
        pos_log_prob = log_prob[m_pos].view(N, -1)
        return torch.mean(torch.sum(-pos_log_prob, dim=-1))

    def _neg_loss(self, affinity, m_pos, m_neg):
        '''Margin ranking loss between each row's hardest positive and its negatives.'''
        N = affinity.shape[0]
        pos_affi = affinity[m_pos].view(N, -1)
        neg_affi = affinity[m_neg].view(N, -1)
        # Least-affine ("hardest") positive per row, kept as (N, 1) so it
        # broadcasts against every negative of that row.
        hardest_pos_affi, _ = torch.min(pos_affi, dim=-1, keepdim=True)
        rank_targets = torch.ones_like(hardest_pos_affi)
        # With target 1: loss(x1, x2) = max(0, -(x1 - x2) + margin) per pair.
        per_pair = self.ranking_loss(hardest_pos_affi, neg_affi, rank_targets)
        return torch.mean(torch.mean(per_pair, dim=-1))

    def forward(self, affinity, labels):
        '''
        Compute the selected part of the loss.

        Args:
            affinity: (N, N) pairwise affinity matrix.
            labels: (N,) class labels.

        Returns:
            Scalar tensor: pos term, neg term, or their sum depending on ``part``.

        Every row must have the same number of positives (same label, excluding
        itself) and the same number of negatives, because the masked entries are
        regrouped per row with ``.view(N, -1)``.
        '''
        m_pos, m_neg = self._masks(affinity, labels)
        # Only compute the term(s) actually requested.
        if self.part == 'pos':
            return self._pos_loss(affinity, m_pos)
        if self.part == 'neg':
            return self._neg_loss(affinity, m_pos, m_neg)
        return self._pos_loss(affinity, m_pos) + self._neg_loss(affinity, m_pos, m_neg)