import torch
import torch.nn as nn
from sentence_transformers import util
import torch.nn.functional as F
import ot

class TripletLoss(nn.Module):
    """Bidirectional hinge triplet loss with hardest-negative mining.

    Scores every audio/text pair by cosine similarity and treats the diagonal
    (audio ``i`` with text ``i``) as the positive pairs. For each anchor it keeps
    only the single largest margin violation among its negatives. Pairs whose
    labels are equal (including the positives themselves) are never penalised.

    Args:
        margin: hinge margin by which a negative must stay below its positive.
    """

    def __init__(self, margin=0.2):
        super(TripletLoss, self).__init__()
        self.margin = margin

    def forward(self, audio_embeds, text_embeds, labels):
        """Return the summed hardest-negative hinge losses divided by batch size.

        Args:
            audio_embeds: audio embeddings, one row per sample.
            text_embeds: text embeddings aligned row-for-row with ``audio_embeds``.
            labels: per-sample labels; assumed to be a 1-D tensor whose length
                is the batch size (TODO confirm against callers).

        Returns:
            Scalar (0-d) loss tensor.
        """
        batch = audio_embeds.size(0)
        scores = util.cos_sim(audio_embeds, text_embeds)
        positive = torch.diag(scores).view(batch, 1)

        # [i, j] = margin + s(a_i, t_j) - s(a_i, t_i): text j competing for audio i.
        text_violation = F.relu(self.margin + scores - positive.expand_as(scores))
        # [i, j] = margin + s(a_i, t_j) - s(a_j, t_j): audio i competing for text j.
        audio_violation = F.relu(self.margin + scores - positive.t().expand_as(scores))

        # Same-label pairs (the diagonal included) are not negatives.
        label_grid = labels.expand(batch, batch)
        same_label = label_grid.eq(label_grid.t()).to(audio_violation.device)
        text_violation = text_violation.masked_fill(same_label, 0)
        audio_violation = audio_violation.masked_fill(same_label, 0)

        # Keep only the hardest violation per audio row / per text column.
        hardest_text = text_violation.max(dim=1).values
        hardest_audio = audio_violation.max(dim=0).values
        return (hardest_text.sum() + hardest_audio.sum()) / batch
    
class NTXent(nn.Module):
    """Symmetric NT-Xent contrastive loss between paired audio and text embeddings.

    Cosine similarities divided by ``temperature`` act as logits, and the
    matching pair on the diagonal is the positive for each row. Off-diagonal
    pairs that share a label are removed from the softmax denominator, so they
    are not pushed apart as negatives. The audio->text and text->audio losses
    are averaged with equal weight.

    Args:
        temperature: value the cosine similarities are divided by.
    """

    def __init__(self, temperature=0.07):
        super(NTXent, self).__init__()
        self.loss = nn.LogSoftmax(dim=1)
        self.tau = temperature

    def forward(self, audio_embeds, text_embeds, labels):
        """Return ``0.5 * loss(audio->text) + 0.5 * loss(text->audio)``.

        Args:
            audio_embeds: audio embeddings, one row per sample.
            text_embeds: text embeddings aligned row-for-row with ``audio_embeds``.
            labels: per-sample labels; assumed to be a 1-D tensor whose length
                is the batch size (TODO confirm against callers).

        Returns:
            Scalar (0-d) loss tensor.
        """
        batch = audio_embeds.shape[0]
        logits_a2t = util.cos_sim(audio_embeds, text_embeds) / self.tau
        logits_t2a = util.cos_sim(text_embeds, audio_embeds) / self.tau

        label_grid = labels.expand(batch, batch)
        same_label = label_grid.eq(label_grid.t()).to(logits_a2t.device)
        # Keep the diagonal (the true positives) visible in the softmax and
        # hide only the other pairs that share a label.
        off_diagonal = ~torch.eye(batch, dtype=torch.bool, device=same_label.device)
        hidden = same_label & off_diagonal

        def directional_loss(logits):
            # Mean negative log-probability of the diagonal entry of each row.
            log_probs = self.loss(logits.masked_fill(hidden, -float('inf')))
            return -log_probs.diag().mean()

        return 0.5 * directional_loss(logits_a2t) + 0.5 * directional_loss(logits_t2a)

class WeightTriplet(nn.Module):
    """Polynomial-weighted triplet loss with hard-negative mining.

    For each anchor row ``i`` of a similarity matrix, the positive is
    ``sim[i, i]`` and the candidate negatives are the entries whose label
    differs from ``label[i]``. Only negatives with ``sim + margin > positive``
    are kept, and of those only the largest contributes. The positive and the
    hardest negative are turned into penalties by fixed, clamped quadratics.

    Args:
        margin: slack used when selecting hard negatives.
    """

    def __init__(self, margin=0.2):
        super(WeightTriplet, self).__init__()
        self.margin = margin

    def polyloss(self, sim_mat, label):
        """Compute the polynomial triplet loss from a similarity matrix.

        Args:
            sim_mat: similarity matrix with positive pairs on its diagonal;
                assumed square (n, n) -- TODO confirm for any non-square caller.
            label: 1-D tensor of labels indexed along the columns of ``sim_mat``.

        Returns:
            0-d tensor equal to the sum of the per-anchor losses divided by
            ``n``. Skipped anchors still count in the denominator. When every
            anchor is skipped, it is a zero tensor with ``requires_grad=True``
            on the device and dtype of ``sim_mat``.
        """
        epsilon = 1e-5
        size = sim_mat.size(0)

        losses = []
        for i in range(size):
            pos = sim_mat[i, i]
            # Skip near-perfect positives (>= 1 - epsilon, or NaN). This guard
            # must run before ``pos`` is used as the hard-negative threshold
            # below, otherwise there is no positive to compare against.
            if not bool(pos < 1 - epsilon):
                continue

            negs = sim_mat[i, label != label[i]]
            # Hard negatives: those whose similarity comes within ``margin``
            # of the positive.
            hard_negs = negs[negs + self.margin > pos]
            if hard_negs.numel() == 0:
                continue

            pos_loss = torch.clamp(0.2 * torch.pow(pos, 2) - 0.7 * pos + 0.5, min=0)
            hardest = hard_negs.max()
            neg_loss = torch.clamp(
                0.9 * torch.pow(hardest, 2) - 0.4 * hardest + 0.03, min=0
            )
            losses.append(pos_loss + neg_loss)

        if not losses:
            return torch.zeros(
                (), dtype=sim_mat.dtype, device=sim_mat.device, requires_grad=True
            )

        return torch.stack(losses).sum() / size

    def forward(self, audio_embeds, text_embeds, labels):
        """Score audio against text by cosine similarity and apply ``polyloss``."""
        scores = util.cos_sim(audio_embeds, text_embeds)
        return self.polyloss(scores, labels)

class MahalalobisLoss(nn.Module):
    """Optimal-transport alignment loss under a weighted (Mahalanobis-style) distance.

    The cost between audio ``i`` and text ``j`` is
    ``sqrt(sum(M * (a_i - t_j) ** 2))``, scaled by its maximum plus 0.1. A
    transport plan with uniform marginals is solved with POT: log-domain
    Sinkhorn by default, or entropic partial OT when ``pot=True``. The loss is
    the cross-entropy between the ideal plan ``eye(n) / n`` (audio ``i`` matched
    to text ``i``) and the solved plan.

    Args:
        epsilon: entropic regularisation strength passed to the OT solver.
        reg: weight of the ``sum(|M|)`` penalty; only used when ``use_reg``.
        m: mass to transport in the partial-OT variant.
        pot: use entropic partial OT instead of balanced Sinkhorn.
        use_reg: add ``reg * sum(|M|)`` to the loss. Off by default, so the
            objective stays the plain OT cross-entropy.
    """

    def __init__(self, epsilon=0.05, reg=0.1, m=0.95, pot=False, use_reg=False):
        super(MahalalobisLoss, self).__init__()
        self.epsilon = epsilon
        self.reg = reg
        self.m = m
        self.POT = pot
        self.use_reg = use_reg

    @staticmethod
    def _mahalanobis_cost(audio_emb, text_emb, M):
        """Return the pairwise weighted distances divided by (their max + 0.1).

        ``M`` is multiplied elementwise with the squared differences along the
        last embedding axis; it is presumably a per-dimension weight vector --
        TODO confirm with callers.
        """
        diff = audio_emb[:, None, :] - text_emb[None, :, :]
        sq_dist = torch.sum(diff ** 2 * M, dim=-1)
        # Clamp before sqrt. Its gradient is infinite at 0, which becomes NaN
        # once multiplied by a zero difference. Negative weights in M could
        # also push sq_dist below zero.
        dist = torch.sqrt(sq_dist.clamp_min(1e-12))
        return dist / (dist.max() + 0.1)

    def forward(self, audio_emb, text_emb, M):
        """Return the OT cross-entropy loss for a batch of paired embeddings.

        Args:
            audio_emb: audio embeddings, one row per sample.
            text_emb: text embeddings aligned row-for-row with ``audio_emb``.
            M: metric weights broadcast against the embedding axis; NaNs are
                replaced via ``torch.nan_to_num`` before use.

        Returns:
            Scalar (0-d) loss tensor.
        """
        batch_size = audio_emb.size(0)
        device = audio_emb.device
        # Uniform marginals: every audio and every text sample carries mass 1/n.
        a = torch.ones(batch_size, device=device) / batch_size
        b = torch.ones(batch_size, device=device) / batch_size

        M = torch.nan_to_num(M)
        cost = self._mahalanobis_cost(audio_emb, text_emb, M)

        if self.POT:
            pi = ot.partial.entropic_partial_wasserstein(a, b, cost, reg=self.epsilon, m=self.m)
        else:
            pi = ot.sinkhorn(a, b, cost, method='sinkhorn_log', reg=self.epsilon)

        # The target plan eye(n)/n is non-zero only on the diagonal, so the
        # cross-entropy needs only diag(pi). Forming -target * log(pi) over the
        # whole matrix turns any exact zero in pi into 0 * log(0) = NaN.
        loss = -torch.log(torch.diagonal(pi)).sum() / batch_size

        if self.use_reg:
            # sum(|M|) equals the sum of singular values (nuclear norm) of diag(M).
            loss = loss + self.reg * M.abs().sum()
        return loss
    

