import torch
import torch.nn as nn

class NCELoss(nn.Module):
    """InfoNCE-style contrastive loss between paired interface features.

    Row ``i`` of ``rec_iface_feats`` and row ``i`` of ``lig_iface_feats`` form a
    positive pair; every other sampled ligand row acts as a negative for that
    receptor row. Similarity is the negative Euclidean distance divided by the
    temperature ``nce_T``.

    Args:
        config: object exposing ``nce_num_samples`` (maximum number of pairs
            sampled per call) and ``nce_T`` (temperature).
    """

    def __init__(self, config):
        super().__init__()
        self.nce_num_samples = config.nce_num_samples
        self.nce_T = config.nce_T
        self.cross_entropy = nn.CrossEntropyLoss()

    def forward(self, lig_iface_feats, rec_iface_feats):
        """Return the contrastive loss over a random subset of paired rows.

        Args:
            lig_iface_feats: ligand features, row ``i`` paired with row ``i`` of
                ``rec_iface_feats`` (assumed 2-D ``(num_pairs, dim)`` — confirm).
            rec_iface_feats: receptor features, same layout.

        Returns:
            Scalar loss tensor. It is zero when there are no pairs to sample,
            instead of the NaN that cross-entropy over an empty batch produces.
        """
        # Only rows present in both tensors can form pairs. Drawing the
        # permutation over the ligand count alone could index past the end of
        # rec_iface_feats when the two lengths differ.
        num_pairs = min(lig_iface_feats.size(0), rec_iface_feats.size(0))
        num_samples = min(self.nce_num_samples, num_pairs)
        if num_samples <= 0:
            return rec_iface_feats.new_zeros(())

        # Permute on the CPU (same RNG stream as before), then move the indices.
        choices = torch.randperm(num_pairs)[:num_samples].to(rec_iface_feats.device)

        query = rec_iface_feats[choices]
        keys = lig_iface_feats[choices]

        # logits[i, j] = -||query_i - key_j|| / T; the positive for row i is
        # column i, hence labels = arange.
        logits = -torch.cdist(query, keys, p=2.0) / self.nce_T
        labels = torch.arange(num_samples, device=logits.device)
        return self.cross_entropy(logits, labels)


class FocalLoss(nn.Module):
    """Binary focal loss computed directly from raw logits.

    Per element the loss is ``alpha_t * (1 - p_t) ** gamma * BCE``, where
    ``BCE`` is the binary cross-entropy of the logit against its label and
    ``p_t = exp(-BCE)`` is the predicted probability of the true class. The
    result is averaged over all elements.

    NOTE: class weights follow this module's established convention. Label 0
    is weighted by ``alpha`` and label 1 by ``1 - alpha``, the reverse of the
    convention in Lin et al. (2017). It is kept so existing configs behave
    the same.

    Args:
        alpha: weight for label-0 elements; label-1 elements get ``1 - alpha``.
        gamma: focusing exponent; ``0`` reduces to alpha-weighted BCE.
    """

    def __init__(self, alpha=.25, gamma=2):
        super(FocalLoss, self).__init__()
        # Index 0 -> weight for label 0, index 1 -> weight for label 1.
        self.alpha = torch.tensor([alpha, 1-alpha])
        self.gamma = gamma
        # The logits variant fuses sigmoid + BCE (log-sum-exp trick): it is
        # numerically stable and, unlike BCELoss, is allowed under autocast.
        self.bce_loss = nn.BCEWithLogitsLoss(reduction='none')

    def forward(self, bsp, iface_label):
        """Return the mean focal loss.

        Args:
            bsp: raw logits (pre-sigmoid), any shape.
            iface_label: 0/1 labels with the same shape as ``bsp``.

        Returns:
            Scalar tensor: the element-wise focal loss averaged over all elements.
        """
        bce_loss = self.bce_loss(bsp, iface_label.to(bsp.dtype))
        # Indexing (rather than gather on a flattened view) keeps the label's
        # shape, so the product below cannot broadcast (N,) against (N, 1)
        # into an (N, N) matrix.
        at = self.alpha.to(bsp.device)[iface_label.long()]
        pt = torch.exp(-bce_loss)
        focal_loss = at * (1 - pt) ** self.gamma * bce_loss

        return focal_loss.mean()


class DockingLoss(nn.Module):
    """Aggregate the per-complex training losses of the docking model.

    For every complex in the batch this computes:

    * a binding-site (BSP) binary loss for the ligand and the receptor. The
      targets mark the vertices listed in column 0 of ``iface_p2p`` as 1.
    * an attention term for each side, when ``lig_dict['attn']`` is not
      ``None``. It is the mean attention mass from interface rows onto
      interface columns, divided by the partner's interface fraction.
    * an NCE contrastive loss between paired interface features, when
      ``config.nce_loss`` is truthy.

    Each term is averaged over the batch and returned as a 1-element tensor.

    Args:
        config: object exposing ``bsp_loss`` (``'bce'`` or ``'focal'``) and
            ``nce_loss``. It also needs ``focal_alpha``/``focal_gamma`` for
            focal loss, and whatever ``NCELoss`` reads when NCE is enabled.

    Raises:
        NotImplementedError: if ``config.bsp_loss`` is not ``'bce'`` or ``'focal'``.
    """

    def __init__(self, config):
        super().__init__()

        # BSP loss (both variants take raw logits)
        if config.bsp_loss == 'bce':
            self.binary_loss = nn.BCEWithLogitsLoss()
        elif config.bsp_loss == 'focal':
            self.binary_loss = FocalLoss(alpha=config.focal_alpha, gamma=config.focal_gamma)
        else:
            raise NotImplementedError

        # NCE loss (optional)
        self.nce_loss = NCELoss(config) if config.nce_loss else None

    def _bsp_loss(self, bsp_logits, iface_ids):
        """Return the binary loss of one structure's logits against its interface mask."""
        # reshape(-1) rather than squeeze(): squeeze would turn a single-vertex
        # structure into a 0-d tensor, and size(0) / indexing would then fail.
        logits = bsp_logits.reshape(-1)
        label = torch.zeros(logits.size(0), dtype=torch.float32, device=logits.device)
        label[iface_ids] = 1
        return self.binary_loss(logits, label)

    @staticmethod
    def _attn_term(attn, src_ids, dst_ids, dst_num_verts):
        """Return mean attention from src-interface rows onto dst-interface columns,
        divided by the fraction of dst vertices that are interface vertices."""
        src_rows = attn[src_ids]
        weight = torch.sum(src_rows[:, dst_ids]) / src_rows.size(0)
        dst_iface_ratio = dst_ids.size(0) / dst_num_verts
        return weight / dst_iface_ratio

    def forward(self, out_dict):
        """Compute the batch-averaged losses.

        Args:
            out_dict: dict with ``'lig_dict'`` and ``'rec_dict'``. Each one holds
                ``'num_verts'`` (per-complex vertex counts, used with
                ``torch.split``), ``'iface_p2p'`` (per-complex index tensors
                whose column 0 holds this side's interface vertex ids), and
                ``'h'`` / ``'bsp'`` (concatenated per-vertex features /
                logits). It also holds ``'attn'`` (per-complex attention
                matrices or ``None``). The receptor's ``iface_p2p`` column 1
                holds the matching ligand vertex.

        Returns:
            ``(lig_bsp_loss, rec_bsp_loss, lig_attn_loss, rec_attn_loss,
            nce_loss)``. Each is a 1-element tensor averaged over the batch.
        """
        # features
        lig_dict = out_dict['lig_dict']
        rec_dict = out_dict['rec_dict']
        lig_bid = lig_dict['num_verts']
        rec_bid = rec_dict['num_verts']
        lig_iface_p2p = lig_dict['iface_p2p']
        rec_iface_p2p = rec_dict['iface_p2p']
        lig_h_split = torch.split(lig_dict['h'], lig_bid)
        rec_h_split = torch.split(rec_dict['h'], rec_bid)
        lig_bsp_split = torch.split(lig_dict['bsp'], lig_bid)
        rec_bsp_split = torch.split(rec_dict['bsp'], rec_bid)
        lig_attn = lig_dict['attn']
        rec_attn = rec_dict['attn']
        bsize = len(lig_bid)

        # init loss accumulators
        device = lig_dict['bsp'].device
        lig_bsp_loss = torch.zeros(1, device=device)
        rec_bsp_loss = torch.zeros(1, device=device)
        lig_attn_loss = torch.zeros(1, device=device)
        rec_attn_loss = torch.zeros(1, device=device)
        nce_loss = torch.zeros(1, device=device)

        for bid in range(bsize):
            lig_iface_ids = lig_iface_p2p[bid][:, 0]
            rec_iface_ids = rec_iface_p2p[bid][:, 0]

            # BSP
            lig_bsp_loss += self._bsp_loss(lig_bsp_split[bid], lig_iface_ids)
            rec_bsp_loss += self._bsp_loss(rec_bsp_split[bid], rec_iface_ids)

            # Attention: skipped when either interface is empty, since the
            # term would divide by a zero count / zero ratio (NaN or inf).
            if lig_attn is not None and lig_iface_ids.numel() > 0 and rec_iface_ids.numel() > 0:
                lig_attn_loss += self._attn_term(
                    lig_attn[bid], lig_iface_ids, rec_iface_ids, rec_h_split[bid].size(0))
                rec_attn_loss += self._attn_term(
                    rec_attn[bid], rec_iface_ids, lig_iface_ids, lig_h_split[bid].size(0))

            # NCE: an empty mapping would make cross-entropy return NaN.
            map_rec2lig = rec_iface_p2p[bid]
            if self.nce_loss is not None and map_rec2lig.size(0) > 0:
                rec_iface_feats = rec_h_split[bid][map_rec2lig[:, 0]]
                lig_iface_feats = lig_h_split[bid][map_rec2lig[:, 1]]
                # Call the module (not .forward) so registered hooks run.
                nce_loss += self.nce_loss(lig_iface_feats, rec_iface_feats)

        # Guard against an empty batch (0 / 0 -> NaN).
        denom = max(bsize, 1)
        lig_bsp_loss /= denom
        rec_bsp_loss /= denom
        lig_attn_loss /= denom
        rec_attn_loss /= denom
        nce_loss /= denom

        return lig_bsp_loss, rec_bsp_loss, lig_attn_loss, rec_attn_loss, nce_loss


