import torch
import torch.nn as nn
import torch.nn.functional as F

# Small constant added before logs / in denominators to avoid log(0) and
# division by zero. (Name kept as spelt because other modules may import it.)
EPISILON=1e-10


class NCELoss(torch.nn.Module):
    """Contrastive loss over cross-view similarities with hard and soft (PU) positives.

    For a batch of ``B`` samples, row ``i`` of ``f1`` is scored against every
    row of ``f2``. Pairs whose (ground-truth or predicted) labels agree are
    treated as positives:

    * hard positives -- both samples are labeled, plus every sample paired
      with itself;
    * soft positives -- exactly one of the two samples is labeled; these are
      weighted by a Gaussian of how far their similarity lies from the mean
      similarity of the hard positives.

    The loss is the mean, over all ``B x B`` pairs, of the masked (and, for
    soft positives, weighted) negative log-softmax of the similarity matrix.
    """

    def __init__(self, dataset='mosi', classification=True):
        super(NCELoss, self).__init__()
        # Kept for backward compatibility; forward does not use it.
        self.softmax = nn.Softmax(dim=1)
        self.dataset = dataset
        self.classification = classification

    def _discretize(self, x):
        """Round ``x`` to a multiple of 0.5 (ties follow ``torch.round``)."""
        return torch.round(x * 2) / 2

    def _prepare_labels(self, targets, predictions):
        """Convert raw targets / model outputs into comparable per-sample labels.

        The conversion depends on ``self.dataset`` (and, for 'mosi'/'mosei',
        on ``self.classification``). Unlisted datasets pass through unchanged.

        Returns:
            ``(targets, predictions)`` after conversion.
        """
        if self.dataset in ['mosi', 'mosei']:
            if self.classification:
                targets = targets.squeeze(-1)
                predictions = torch.argmax(predictions, dim=1).float()
            else:
                # Regression: bucket both to multiples of 0.5 so equality is meaningful.
                targets = self._discretize(targets).squeeze(-1)
                predictions = self._discretize(predictions).squeeze(-1)
        elif self.dataset in ['ucf101']:
            predictions = torch.argmax(predictions, dim=1).unsqueeze(-1)
            targets = targets.unsqueeze(-1)
        elif self.dataset in ['mmimdb']:
            # Multi-label: threshold each label's sigmoid probability at 0.5.
            predictions = (torch.sigmoid(predictions.squeeze(-1)) > 0.5).float()
            targets = targets.squeeze(-1)
        elif self.dataset in ['food101', 'hatememes']:
            predictions = torch.argmax(predictions, dim=1).float()
        return targets, predictions

    def forward(self, f1, f2, targets, predictions, target_mask, temperature=0.3):
        """Compute the loss for one batch.

        Args:
            f1, f2: feature matrices multiplied as ``f1 @ f2.T``; presumably
                (B, D) and L2-normalised by the caller -- TODO confirm.
            targets: ground-truth labels; expected shape depends on ``dataset``.
            predictions: model outputs, converted to labels per ``dataset``.
            target_mask: bool tensor of shape (B,); True where ``targets`` is
                used, False where the converted prediction is used instead.
            temperature: divisor applied to the similarities before softmax.

        Returns:
            Scalar tensor: mean of the per-pair loss over the B x B matrix.

        Raises:
            FloatingPointError: if the per-pair loss contains NaN.
        """
        B = f1.shape[0]
        targets, predictions = self._prepare_labels(targets, predictions)

        logits = torch.mm(f1, f2.t().contiguous()) / temperature  # [B, B]
        sim_matrix_joint_mod = torch.exp(logits)
        # Same value as -log(exp / exp.sum(-1) + eps), but torch.softmax shifts
        # by the row max first, so large logits cannot produce inf / inf = NaN.
        log_sm = -torch.log(torch.softmax(logits, dim=-1) + EPISILON)

        # Labeled samples keep their target; unlabeled ones use the prediction.
        # The cast avoids index_put's dtype-mismatch error (e.g. integer
        # targets combined with float argmax predictions).
        padded_targets = torch.zeros_like(targets)
        padded_targets[target_mask] = targets[target_mask]
        padded_targets[~target_mask] = predictions[~target_mask].to(padded_targets.dtype)

        # Pairwise label difference; zero means the two samples share a label.
        diff = (padded_targets.unsqueeze(1) - padded_targets).detach()
        if self.dataset in ['mmimdb']:
            # Multi-label vectors match only if every entry agrees. Summing the
            # signed differences would let e.g. [1, 0] vs [0, 1] cancel to 0.
            diff = diff.abs().sum(-1)
        same_label = (diff == 0).squeeze()

        both_labeled = target_mask.unsqueeze(1) & target_mask
        one_labeled = target_mask.unsqueeze(1) ^ target_mask
        eye = torch.eye(B, dtype=torch.bool, device=both_labeled.device)

        # Hard positives: same label and both labeled; each sample is its own positive.
        self_mask1 = (same_label & both_labeled) | eye
        # Soft (PU) positives: same label, exactly one side labeled, never the diagonal.
        self_mask2 = same_label & one_labeled & ~eye

        # Gaussian weight for soft positives, centred on the hard-positive similarities.
        positive_sims = sim_matrix_joint_mod[self_mask1]
        mean_pos = positive_sims.mean()
        if positive_sims.numel() > 1:
            var_pos = positive_sims.var()
        else:
            # Unbiased variance of a single value is NaN; with a one-sample batch
            # there are no soft positives, so a zero spread is harmless.
            var_pos = torch.zeros_like(mean_pos)
        pu_weight = torch.exp(-torch.pow(sim_matrix_joint_mod - mean_pos, 2) / (2 * var_pos + EPISILON))

        log_softmax_pl = log_sm * self_mask1.float()
        log_softmax_pu = log_sm * self_mask2.float() * pu_weight
        log_softmax = log_softmax_pl + log_softmax_pu
        if torch.isnan(log_softmax).any():
            raise FloatingPointError('NCELoss: NaN encountered in the per-pair loss')

        return torch.mean(log_softmax)

class IndividualLoss(torch.nn.Module):
    """Row-wise alignment / orthogonality losses between feature matrices.

    ``forward`` is minimised when matching rows are parallel or anti-parallel;
    ``forward_orthorgonal`` is minimised when matching rows are orthogonal.
    """

    def __init__(self, dataset='mosi'):
        super(IndividualLoss, self).__init__()
        # Neither attribute is read by the methods below; kept for compatibility.
        self.softmax = nn.Softmax(dim=1)
        self.dataset = dataset

    @staticmethod
    def _orthogonality_penalty(a, b):
        """Mean over rows ``i`` of ``-log(1 - |<a_i, b_i>| + EPISILON)``."""
        # Row-wise dot products: the diagonal of a @ b.T without the B x B product.
        cos = (a * b).sum(dim=1)
        return torch.mean(-torch.log(1 - cos.abs() + EPISILON))

    def forward_orthorgonal(self, f1, f2, fs=None):
        """Penalise non-orthogonality between matching rows of feature matrices.

        Args:
            f1, f2: feature matrices compared row by row; used only when
                ``fs`` is None.
            fs: optional sequence of feature matrices; when given, the penalty
                is averaged over every pair of distinct entries.

        NOTE(review): inputs are not normalised here. If |<a_i, b_i>| > 1 the
        log argument turns negative and the result is NaN, so callers
        presumably pass L2-normalised rows -- confirm.

        Returns:
            Scalar tensor. If ``fs`` has fewer than two entries there are no
            pairs to compare and a zero tensor is returned.
        """
        if fs is None:
            return self._orthogonality_penalty(f1, f2)

        fs = list(fs)
        if len(fs) < 2:
            return fs[0].new_zeros(()) if fs else torch.zeros(())

        # The penalty is symmetric in (a, b), so the mean over unordered pairs
        # equals the mean over ordered pairs (i != j) at half the cost.
        penalties = [
            self._orthogonality_penalty(a, b)
            for i, a in enumerate(fs)
            for b in fs[i + 1:]
        ]
        return torch.stack(penalties).mean()

    def forward(self, f1, f2):
        """Return the mean over rows of ``1 - cos^2`` between matching rows.

        Rows are L2-normalised first, so each row contributes 0 when the two
        rows are parallel or anti-parallel and 1 when they are orthogonal.
        """
        f1 = F.normalize(f1, dim=1)
        f2 = F.normalize(f2, dim=1)
        # Row-wise cosine: the diagonal of f1 @ f2.T without the B x B product.
        cos = (f1 * f2).sum(dim=1)
        return torch.mean(1 - cos ** 2)

if __name__ == '__main__':
    # Smoke run: random features, uniform targets/predictions in [-3, 3) and a
    # random labeled/unlabeled split, using the default NCELoss configuration.
    batch_size, feat_dim = 30, 100
    criterion = NCELoss()

    feats_a = torch.randn(batch_size, feat_dim)
    feats_b = torch.randn(batch_size, feat_dim)
    targets = torch.rand((batch_size, 1)) * 6 - 3
    predictions = torch.rand((batch_size, 1)) * 6 - 3
    labeled = torch.rand(batch_size) > 0.5

    result = criterion(feats_a, feats_b, targets, predictions, labeled)
    print(result)