from copy import deepcopy
import torch
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F
from torch.autograd import Variable
from utils.buffer import Buffer
from info_nce import InfoNCE


def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: bool = True):
    """Randomly zero whole samples of ``x`` (stochastic depth).

    One Bernoulli draw is made per entry along dim 0 and broadcast over all
    remaining dims, so each sample is either kept or dropped entirely.

    Args:
        x: input tensor; dim 0 is treated as the per-sample axis.
        drop_prob: probability of dropping a sample.
        training: when False the input is returned untouched.
        scale_by_keep: when True (and the keep probability is positive) the
            surviving samples are divided by the keep probability.

    Returns:
        ``x`` itself when nothing is dropped, otherwise ``x`` multiplied by the
        per-sample mask.
    """
    # Fast path: nothing to drop, hand back the very same tensor object.
    if drop_prob == 0. or not training:
        return x

    survival = 1 - drop_prob
    # One mask value per sample, singleton in every other dim so it broadcasts.
    trailing_ones = (1,) * (x.ndim - 1)
    sample_mask = x.new_empty((x.shape[0],) + trailing_ones).bernoulli_(survival)
    if scale_by_keep and survival > 0.0:
        sample_mask.div_(survival)
    return x * sample_mask

class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    Args:
        drop_prob: probability of dropping a sample; ``None`` (the default) and
            ``0`` both disable dropping.
        scale_by_keep: forwarded to ``drop_path``; rescales surviving samples.
    """
    def __init__(self, drop_prob=None, scale_by_keep=True):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob
        self.scale_by_keep = scale_by_keep

    def forward(self, x):
        """Return ``x`` unchanged in eval mode or when dropping is disabled, else apply ``drop_path``."""
        # The default drop_prob=None used to reach `1 - None` inside drop_path
        # and raise TypeError in training mode; treat it as "no dropping".
        if not self.training or self.drop_prob is None or self.drop_prob == 0.:
            return x
        return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)
    
class Mlp(nn.Module):
    """Two-layer perceptron: Linear -> activation -> Dropout -> Linear -> Dropout.

    Args:
        in_features: size of the last input dimension.
        hidden_features: width of the hidden layer; falls back to ``in_features``.
        out_features: size of the output; falls back to ``in_features``.
        act_layer: activation class, instantiated with no arguments.
        drop: dropout probability used after both linear layers.
    """
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        hidden_width = hidden_features if hidden_features else in_features
        output_width = out_features if out_features else in_features

        # Registration order is kept (fc1, act, drop1, fc2, drop2) so parameter
        # initialisation and state_dict layout stay the same.
        self.fc1 = nn.Linear(in_features, hidden_width)
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop)
        self.fc2 = nn.Linear(hidden_width, output_width)
        self.drop2 = nn.Dropout(drop)

    def forward(self, x):
        """Run ``x`` through the five stages in order and return the result."""
        for stage in (self.fc1, self.act, self.drop1, self.fc2, self.drop2):
            x = stage(x)
        return x
    
class Attention(nn.Module):
    """Multi-head scaled dot-product self-attention that also returns its weights.

    Args:
        dim: channel size of the input; must be divisible by ``num_heads``.
        num_heads: number of attention heads.
        qkv_bias: whether the fused q/k/v projection has a bias.
        attn_drop: dropout probability applied to the attention weights before
            they are multiplied with the values.
        proj_drop: dropout probability applied after the output projection.
    """
    def __init__(self, dim, num_heads=4, qkv_bias=True, attn_drop=0., proj_drop=0.):
        super().__init__()
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
        self.num_heads = num_heads
        self.scale = (dim // num_heads) ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        # Not used in forward; kept so the module's parameters/state_dict are unchanged.
        self.norm = nn.LayerNorm(dim)

    def forward(self, x, mask=None):
        """Attend over dim 1 of ``x``.

        Args:
            x: tensor unpacked as ``(B, N, C)``.
            mask: optional tensor; a head axis is inserted at dim 1 and every
                position where it equals 0 is filled with ``-1e9`` before the
                softmax.

        Returns:
            ``(out, attn)`` where ``out`` has the shape of ``x`` and ``attn`` is
            the softmax output (before attention dropout) with shape
            ``(B, num_heads, N, N)``.
        """
        batch, tokens, channels = x.shape
        heads = self.num_heads

        # Fused projection, then split into per-head q, k, v of shape (B, heads, N, C/heads).
        fused = self.qkv(x).reshape(batch, tokens, 3, heads, channels // heads)
        # unbind keeps torchscript happy (a tensor cannot be unpacked as a tuple)
        q, k, v = fused.permute(2, 0, 3, 1, 4).unbind(0)

        scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale
        if mask is not None:
            scores = scores.masked_fill(mask.unsqueeze(1) == 0, -1e9)
        attn = torch.softmax(scores, dim=-1)

        mixed = torch.matmul(self.attn_drop(attn), v)
        merged = mixed.transpose(1, 2).reshape(batch, tokens, channels)
        out = self.proj_drop(self.proj(merged))
        return out, attn

class ContextLearner(nn.Module):
    """Encode a set of (relation, entity) neighbours into one context vector.

    Relation and entity embeddings are concatenated per neighbour, passed
    through masked self-attention, pooled with the query-averaged attention
    weights and finally projected by a linear layer.

    Args:
        dim: channel size of the concatenated (relation, entity) embedding.
        num_heads: number of attention heads.
        qkv_bias: forwarded to ``Attention``.
        drop: projection dropout of the attention block.
        attn_drop: attention dropout of the attention block.
        drop_path: stochastic-depth probability; ``0`` disables it.
    """
    def __init__(self, dim, num_heads, qkv_bias=True, drop=0., attn_drop=0., drop_path=0.):
        super(ContextLearner, self).__init__()
        self.dim = dim
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
        self.fc = nn.Linear(dim, dim)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def forward(self, entity_embeds, rel_embeds, mask):
        """Return the attention-pooled context of the given neighbours.

        Args:
            entity_embeds: neighbour entity embeddings, ``[bs, N, d_e]``
                (the original comment gives ``[bs, 100, 20]``).
            rel_embeds: neighbour relation embeddings, ``[bs, N, d_r]``;
                ``d_r + d_e`` must equal ``dim``.
            mask: pairwise validity mask reshaped to ``[-1, N, N]``; zeros mark
                pairs that must not attend to each other.

        Returns:
            Tensor of shape ``[bs, heads, dim]`` with dim 1 squeezed away when
            it has size 1 (i.e. ``[bs, dim]`` for a single head).
        """
        neighbor_embeds = torch.cat((rel_embeds, entity_embeds), dim=2)
        # Derive the neighbour count from the input instead of hard-coding 100,
        # so the module works for any padded neighbourhood size.
        num_neighbors = neighbor_embeds.shape[1]
        mask = mask.reshape(-1, num_neighbors, num_neighbors)

        # No residual connection here: the attention output replaces the input.
        neighbor_embeds, attn = self.attn(self.norm1(neighbor_embeds), mask)
        neighbor_embeds = self.drop_path(neighbor_embeds)

        # attn is [bs, heads, N, N]; averaging over the query axis gives one
        # weight per neighbour and head, used to pool the neighbour embeddings.
        # (neighbor_embeds is already 3-D, so it is passed to bmm as is; the
        # former squeeze(1) was a no-op except for N == 1, where it broke bmm.)
        weighted_context = torch.bmm(attn.mean(dim=2), neighbor_embeds)
        weighted_context = self.drop_path(self.fc(self.norm2(weighted_context)))
        return weighted_context.squeeze(1)
    

class TransE(nn.Module):
    """TransE scorer with a contrastive context loss and per-relation sample selection.

    Batches are laid out in groups of ``neg_ratio + 1`` rows: the first row of
    every group is a true triplet, the remaining rows are its negatives.

    ``forward`` returns the training losses plus a set of ``[k, 3]``
    (head, relation, tail) index tensors picked per relation in ``urels``;
    ``predict`` returns the raw TransE distances as a NumPy array.
    """

    def __init__(self, num_ents, num_rels, hidden_size, margin, neg_ratio, batch_size, topK, device, epad_id, rpad_id,
                 coeff_info=1., coeff_l2=1.):
        """Build the embedding tables, attention modules and loss.

        Args:
            num_ents: number of entities (one extra row is allocated).
            num_rels: number of relations (one extra row is allocated).
            hidden_size: embedding width.
            margin: margin of the ranking loss.
            neg_ratio: negatives per true triplet in a batch.
            batch_size: stored on the instance; not used by this class.
            topK: maximum number of triplets selected per relation.
            device: device for the embedding tables, the loss and helper tensors.
            epad_id: padding index of the entity table.
            rpad_id: padding index of the relation tables.
            coeff_info: weight of the InfoNCE term in the total loss.
            coeff_l2: weight of each L2 term in the total loss.
        """
        super(TransE, self).__init__()
        self.num_ents = num_ents
        self.num_rels = num_rels
        self.epad_id = epad_id
        self.rpad_id = rpad_id
        self.neg_ratio = neg_ratio
        self.batch_size = batch_size
        self.topK = topK
        self.device = device
        self.ent_embeddings = nn.Embedding(self.num_ents+1, hidden_size, padding_idx=epad_id).to(device)
        self.rel_embeddings = nn.Embedding(self.num_rels+1, hidden_size, padding_idx=rpad_id).to(device)
        self.attn = Attention(hidden_size*2, num_heads=1, qkv_bias=True, attn_drop=0.2, proj_drop=0.)
        self.context_learner = ContextLearner(dim=hidden_size*2, num_heads=1, drop=0.2, attn_drop=0.2, drop_path=0.2)
        self.meta_learner = nn.Linear(hidden_size*2, hidden_size)
        self.criterion = nn.MarginRankingLoss(margin, reduction="sum").to(device)
        # Frozen snapshot of relation embeddings, written in forward() and used
        # as the target of the L2 term.
        self.learned_rel_embedding = nn.Embedding(self.num_rels+1, hidden_size, padding_idx=rpad_id).to(device)
        self.learned_rel_embedding.weight.requires_grad=False
        self.init_weights()
        self.coeff_info = coeff_info
        self.coeff_l2 = coeff_l2

    def init_weights(self):
        """Xavier-initialise the entity and relation embedding tables."""
        # NOTE(review): this fills the whole table, so the padding_idx rows are
        # overwritten with non-zero values as well - confirm that is intended.
        nn.init.xavier_uniform_(self.ent_embeddings.weight.data)
        nn.init.xavier_uniform_(self.rel_embeddings.weight.data)

    def _calc(self, h, r, t):
        """Return the L1 norm of ``h + r - t`` after L2-normalising each input on the last dim."""
        h = nn.functional.normalize(h, p=2, dim=-1)
        r = nn.functional.normalize(r, p=2, dim=-1)
        t = nn.functional.normalize(t, p=2, dim=-1)
        return torch.norm(h + r - t, p=1, dim=-1)

    def loss(self, p_score, n_score):
        """Margin ranking loss with target -1, i.e. ``sum(max(0, p - n + margin))``."""
        target = torch.tensor([-1.0], device=self.device)
        return self.criterion(p_score, n_score, target)

    def _degree_mask(self, degrees, batch_size, width):
        """Return a ``[batch_size, width]`` 0/1 mask whose row ``i`` has its first ``degrees[i]`` entries set."""
        degrees = torch.as_tensor(degrees, device=self.device)[:batch_size].reshape(-1, 1)
        positions = torch.arange(width, device=self.device).unsqueeze(0)
        return (positions < degrees).to(torch.get_default_dtype())

    def build_context(self, meta):
        """Merge left/right neighbour lists and build their pairwise validity mask.

        Args:
            meta: ``(left_connections, left_degrees, right_connections,
                right_degrees)``; per the original comment the connections are
                ``[bs, 50, 2]`` and the degrees ``[bs]``. Degrees are expected
                to be non-negative integers.

        Returns:
            ``(connections, mask)`` where ``connections`` is the concatenation
            of both sides along dim 1 and ``mask[b, i, j]`` is 1 only when
            neighbours ``i`` and ``j`` of sample ``b`` are both real (not padding).
        """
        left_connections, left_degrees, right_connections, right_degrees = meta

        batch_size, max_neighbor = left_connections.shape[0], left_connections.shape[1]
        left_digits = self._degree_mask(left_degrees, batch_size, max_neighbor)
        right_digits = self._degree_mask(right_degrees, batch_size, max_neighbor)

        connections = torch.cat((left_connections, right_connections), dim=1)
        valid = torch.cat((left_digits, right_digits), dim=1)
        # Outer product per sample: a pair is valid only if both members are.
        mask = valid.unsqueeze(2) * valid.unsqueeze(1)

        return connections, mask

    def _encode_context(self, meta):
        """Embed the neighbours described by ``meta`` and reduce them with the context learner."""
        connections, mask = self.build_context(meta)
        relations, entities = connections[:, :, 0], connections[:, :, 1]
        rel_emb = self.rel_embeddings(relations)
        entity_emb = self.ent_embeddings(entities)
        return self.context_learner(entity_emb, rel_emb, mask)

    def _select_triplets(self, urels, triplet_emb, true_batch_h, true_batch_r, true_batch_t):
        """Pick, for every relation in ``urels``, the true triplets receiving the most attention.

        Returns a set of ``[k, 3]`` index tensors (head, relation, tail), one per
        relation, with ``k`` at most ``topK`` and at most the number of true
        triplets of that relation in the batch.
        """
        sel_triplets = set()
        # Selection only yields index tensors, so no autograd graph is needed.
        with torch.no_grad():
            _, attn = self.attn(triplet_emb.unsqueeze(0))
            attn = attn[0, 0]  # single batch entry, single head -> [n, n]

            # attn_mask[u, i, j] is True only when triplets i and j both carry
            # relation urels[u]. (Comparing r_i * r_j with u**2, as before,
            # also matched unrelated pairs such as 1 * 4 == 2**2 or any pair
            # involving id 0.)
            is_rel = true_batch_r.unsqueeze(0) == urels.view(-1, 1)
            attn_mask = is_rel.unsqueeze(2) & is_rel.unsqueeze(1)
            rel_attns = (attn.unsqueeze(0) * attn_mask).sum(dim=1)  # [U, n]

            try:
                # topk raises when k exceeds the number of candidates; clamp it
                # so small batches still produce a selection.
                k = min(self.topK, rel_attns.shape[1])
                max_attn_idx = torch.topk(rel_attns, k, dim=1)[1]
                rel_tri_cnt = torch.bincount(true_batch_r, minlength=self.num_rels+1)
                for i in range(urels.shape[0]):
                    # Never keep more indices than there are triplets of this relation.
                    valid_topK = min(k, int(rel_tri_cnt[urels[i]]))
                    keep = max_attn_idx[i][:valid_topK]
                    sel_triplets.add(torch.stack((true_batch_h[keep],
                                                  true_batch_r[keep],
                                                  true_batch_t[keep]), dim=1).detach())
            except Exception:
                # Best effort: selection is skipped for this batch, training goes on.
                print('Warning: No urels.')
        return sel_triplets

    def forward(self, urels, batch_h, batch_r, batch_t, meta, negative_meta):
        """Compute the training losses and select representative triplets.

        Args:
            urels: 1-D tensor of the relation ids unique to the current task.
            batch_h, batch_r, batch_t: 1-D index tensors of heads, relations
                and tails in groups of ``neg_ratio + 1`` (true triplet first).
            meta: neighbourhood description of the true triplets, see
                ``build_context``.
            negative_meta: iterable of such descriptions used as negatives of
                the contrastive loss.

        Returns:
            ``((loss, loss_margin, loss_info, loss_l2), sel_triplets)`` where
            ``loss`` is the weighted total, ``loss_margin`` the translational
            margin loss alone, ``loss_info`` the InfoNCE loss and ``loss_l2``
            the L2 term of the last previously learned relation in the batch
            (``None`` when there is none).
        """
        h = self.ent_embeddings(batch_h)
        r = self.rel_embeddings(batch_r)
        t = self.ent_embeddings(batch_t)
        score = self._calc(h, r, t)

        # translational loss at triplet level
        loss_margin = self.loss(self.get_positive_score(score), self.get_negative_score(score))

        # continual learning representative sample selection:
        # pick out the true triplets and their embeddings
        step = self.neg_ratio + 1
        true_batch_h = batch_h[::step]
        true_batch_r = batch_r[::step]
        true_batch_t = batch_t[::step]
        triplet_emb = torch.cat((h[::step], t[::step]), dim=1)

        sel_triplets = self._select_triplets(urels, triplet_emb, true_batch_h, true_batch_r, true_batch_t)

        # contrastive learning: the true triplets' context is the positive,
        # the contexts built from negative_meta are the negatives
        positive_context = self._encode_context(meta)
        negative_context = [self._encode_context(negative_m) for negative_m in negative_meta]
        loss_infonce = InfoNCE(negative_mode='unpaired')
        loss_info = loss_infonce(triplet_emb, positive_context, torch.stack(negative_context, dim=1).squeeze(1))

        # Out-of-place sum: an in-place `+=` on a tensor shared with loss_margin
        # would make the reported margin loss equal to the total loss.
        loss = loss_margin + self.coeff_info * loss_info

        # store the rel embeddings if the rel is the unique rel in current task
        with torch.no_grad():
            self.learned_rel_embedding.weight[urels.long()] = self.rel_embeddings(urels)

        # rels in the batch that are not unique to this task should stay fixed:
        # pull their embeddings towards the stored snapshot with an L2 loss
        batch_rels = true_batch_r.unique()
        is_previous = (batch_rels.unsqueeze(1) != urels.view(1, -1)).all(dim=1)
        previous_rels = batch_rels[is_previous]

        loss_l2 = None
        l2loss = nn.MSELoss()
        for rel in previous_rels:
            loss_l2 = l2loss(self.learned_rel_embedding(rel), self.rel_embeddings(rel))
            loss = loss + self.coeff_l2 * loss_l2

        return (loss, loss_margin, loss_info, loss_l2), sel_triplets

    def predict(self, batch_h, batch_r, batch_t):
        """Return the TransE distance of every given triplet as a NumPy array."""
        with torch.no_grad():
            h = self.ent_embeddings(batch_h)
            r = self.rel_embeddings(batch_r)
            t = self.ent_embeddings(batch_t)
            score = self._calc(h, r, t)
        return score.cpu().numpy()

    def get_positive_score(self, score):
        """Return the score of the true triplet of every group (every ``neg_ratio + 1``-th entry)."""
        return score[::self.neg_ratio + 1]

    def get_negative_score(self, score):
        """Return, per group, the mean score of the negatives following the true triplet."""
        group = self.neg_ratio + 1
        total = len(score)
        if score.dim() == 1 and total % group == 0:
            # Regular layout: one row per group, column 0 is the positive.
            return score.reshape(total // group, group)[:, 1:].mean(dim=1)

        # Ragged or non-1-D input: average group by group (the last group may
        # be shorter), concatenating once at the end.
        negs = [torch.tensor([], dtype=torch.float32).to(self.device)]
        negs.extend(torch.mean(score[idx + 1:idx + group], 0, keepdim=True)
                    for idx in range(0, total, group))
        return torch.cat(negs)