'''
Date: 2022-08-06 22:33:13
LastEditors: yuhhong
LastEditTime: 2022-09-27 13:16:40
'''
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

from .utils import FCResDecoder, TRNet



class AtomSharedMultiheadAttention(nn.Module):
    """Multi-head self-attention over a 4-D ``[batch, channels, atom_num, k_num]`` map.

    The query/key/value/output projections are 1x1 ``Conv2d`` layers, so the same
    linear map is applied at every (atom, neighbor) position. For attention, the
    ``atom_num * k_num`` positions of each sample are flattened into one sequence.
    The result is folded back to ``[batch, embed_dim, atom_num, k_num]``.

    Args:
        input_dim: number of input channels of ``q``/``k``/``v``.
        embed_dim: number of projected channels; must be divisible by ``num_heads``.
        num_heads: number of attention heads.
    """

    def __init__(self, input_dim, embed_dim, num_heads):
        super().__init__()
        assert embed_dim % num_heads == 0, "Embedding dimension must be 0 modulo number of heads."

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        # 1x1 convolutions act as per-position linear layers shared by all positions
        self.q_proj = nn.Conv2d(input_dim, embed_dim, kernel_size=1, bias=False)
        self.k_proj = nn.Conv2d(input_dim, embed_dim, kernel_size=1, bias=False)
        self.v_proj = nn.Conv2d(input_dim, embed_dim, kernel_size=1, bias=False)

        self.o_proj = nn.Conv2d(embed_dim, embed_dim, kernel_size=1, bias=False)

        self._reset_parameters()

    def _reset_parameters(self):
        # NOTE: o_proj keeps PyTorch's default Conv2d initialization here.
        nn.init.kaiming_normal_(self.q_proj.weight, a=0.2, mode='fan_in', nonlinearity='leaky_relu')
        nn.init.kaiming_normal_(self.k_proj.weight, a=0.2, mode='fan_in', nonlinearity='leaky_relu')
        nn.init.kaiming_normal_(self.v_proj.weight, a=0.2, mode='fan_in', nonlinearity='leaky_relu')

    def _scaled_dot_product(self, q, k, v, mask=None):
        """Scaled dot-product attention.

        Args:
            q, k, v: ``[batch, num_heads, seq_length, head_dim]``.
            mask: optional tensor broadcastable to the
                ``[batch, num_heads, seq_length, seq_length]`` logits. Entries where
                ``mask == 0`` are set to -9e15 before the softmax.

        Returns:
            ``(values, attention)`` of shapes ``[batch, num_heads, seq_length, head_dim]``
            and ``[batch, num_heads, seq_length, seq_length]``.
        """
        d_k = q.size()[-1]
        attn_logits = torch.matmul(q, k.transpose(-2, -1))
        attn_logits = attn_logits / math.sqrt(d_k)
        if mask is not None:
            attn_logits = attn_logits.masked_fill(mask == 0, -9e15)
        attention = F.softmax(attn_logits, dim=-1)
        values = torch.matmul(attention, v)
        return values, attention

    def _split_heads(self, x):
        """Unfold ``[B, embed_dim, A, K]`` into ``[B, num_heads, A*K, head_dim]``."""
        batch_size = x.size(0)
        x = torch.flatten(x, start_dim=2, end_dim=3).permute(0, 2, 1)  # [B, A*K, embed_dim]
        seq_length = x.size(1)
        return x.reshape(batch_size, seq_length, self.num_heads, self.head_dim).permute(0, 2, 1, 3)

    def _merge_heads(self, x, atom_num, k_num):
        """Exact inverse of ``_split_heads``: ``[B, H, A*K, D]`` -> ``[B, embed_dim, A, K]``."""
        batch_size, _, seq_length, _ = x.size()
        x = x.permute(0, 2, 1, 3).reshape(batch_size, seq_length, self.embed_dim)  # [B, S, E]
        # Channels must be moved in front of the sequence axis *before* restoring (A, K);
        # reshaping [B, S, E] straight to [B, E, A, K] would mix positions with channels.
        return x.permute(0, 2, 1).reshape(batch_size, self.embed_dim, atom_num, k_num)

    def forward(self, q, k, v, mask=None, return_attention=False):
        """Attend across all (atom, neighbor) positions of each sample.

        Args:
            q, k, v: tensors of identical shape ``[batch, input_dim, atom_num, k_num]``.
            mask: optional attention mask, see ``_scaled_dot_product``.
            return_attention: if True, also return the attention weights.

        Returns:
            ``o`` of shape ``[batch, embed_dim, atom_num, k_num]``, or ``(o, attention)``
            when ``return_attention`` is True.
        """
        assert q.size() == k.size() == v.size()

        # shared projections: [batch_size, embed_dim, atom_num, k_num]
        q = self.q_proj(q)
        k = self.k_proj(k)
        v = self.v_proj(v)

        _, _, atom_num, k_num = q.size()

        values, attention = self._scaled_dot_product(
            self._split_heads(q), self._split_heads(k), self._split_heads(v), mask=mask)
        o = self.o_proj(self._merge_heads(values, atom_num, k_num))

        if return_attention:
            return o, attention
        return o



class MolConv(nn.Module):
    """Graph convolution over k-nearest-neighbor point features.

    For every point, its ``k`` nearest neighbors are found by squared Euclidean
    distance over all ``in_dim`` channels. Their features are concatenated with
    the point's own features, passed through a 1x1 conv, refined by
    ``AtomSharedMultiheadAttention`` (residual + BatchNorm) and summed over the
    neighbors. The result is concatenated with a center-only branch and fused by
    ``all_ff``.

    Args:
        in_dim: input channels.
        out_dim: output channels.
        k: number of neighbors per point.
        device: stored on the module; graph construction uses ``x.device``.
        skip_connection: if True, the input (each channel duplicated by
            nearest-neighbor interpolation) is added to the output before a final
            LeakyReLU. This requires ``out_dim == 2 * in_dim``.
    """

    def __init__(self, in_dim, out_dim, k, device, skip_connection=False):
        super(MolConv, self).__init__()
        self.k = k
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.device = device
        self.skip_connection = skip_connection

        self.neighbor_ff = nn.Sequential(nn.Conv2d(in_dim*2, out_dim, kernel_size=1, bias=False),
                                   nn.BatchNorm2d(out_dim),
                                   nn.LeakyReLU(negative_slope=0.02))
        self.neighbor_ma = AtomSharedMultiheadAttention(input_dim=out_dim, embed_dim=out_dim, num_heads=4)
        self.neighbor_bn = nn.BatchNorm2d(out_dim)

        self.center_ff = nn.Sequential(nn.Conv1d(in_dim, out_dim, kernel_size=1, bias=False),
                                        nn.BatchNorm1d(out_dim),
                                        nn.LeakyReLU(negative_slope=0.02),
                                        nn.Conv1d(out_dim, out_dim, kernel_size=1, bias=False),
                                        nn.BatchNorm1d(out_dim))

        self.all_ff = nn.Sequential(nn.Conv1d(in_channels=out_dim*2, out_channels=out_dim, kernel_size=1, bias=False),
                                        nn.BatchNorm1d(out_dim),
                                        nn.LeakyReLU(negative_slope=0.02))
        self.activation_f = nn.LeakyReLU(negative_slope=0.02)

        self._reset_parameters()

    def _reset_parameters(self):
        # Re-initializes every conv/linear/norm in this module, including those
        # inside neighbor_ma.
        for m in self.modules():
            if isinstance(m, (nn.Conv1d, nn.Conv2d, nn.Linear)):
                nn.init.kaiming_normal_(m.weight, a=0.2, mode='fan_in', nonlinearity='leaky_relu')

            elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Map ``[batch_size, in_dim, num_points]`` to ``[batch_size, out_dim, num_points]``."""
        x_pre = x

        # neighbor features: [batch_size, 2*in_dim, num_points, k]
        feat = self._get_graph_feature(x, device=self.device, k=self.k)
        feat = self.neighbor_ff(feat)  # [batch_size, out_dim, num_points, k]
        # attention across (point, neighbor) positions, residual, then BatchNorm
        attn_feat = self.neighbor_ma(feat, feat, feat)
        feat = self.neighbor_bn(feat + attn_feat)
        # merge the k neighbors by summation
        feat = feat.sum(dim=-1, keepdim=False)  # [batch_size, out_dim, num_points]

        # center-only branch
        cent = self.center_ff(x)

        x = self.all_ff(torch.cat((feat, cent), dim=1))
        if self.skip_connection:
            # Interpolating the channel axis (scale 2, default nearest) duplicates
            # every input channel so x_pre has 2*in_dim channels.
            x_pre = F.interpolate(x_pre.permute(0, 2, 1), scale_factor=2).permute(0, 2, 1)
            x = self.activation_f(x + x_pre)
        return x

    def _knn(self, x, k):
        """Indices of the ``k`` nearest points of every point.

        ``x`` is ``[batch_size, num_dims, num_points]``. The result is
        ``[batch_size, num_points, k]``. The top-1 match, normally the point
        itself at distance 0, is dropped.
        """
        inner = -2*torch.matmul(x.transpose(2, 1), x)
        xx = torch.sum(x**2, dim=1, keepdim=True)
        # -(|xi|^2 - 2 xi.xj + |xj|^2) == negative squared distance
        pairwise_distance = -xx - inner - xx.transpose(2, 1)

        idx = pairwise_distance.topk(k=k+1, dim=-1)[1]  # (batch_size, num_points, k+1)
        idx = idx[:, :, 1:]  # except the center point
        return idx

    def _get_graph_feature(self, x, device=None, k=20):
        """Gather neighbor and center features for every point.

        Args:
            x: ``[batch_size, num_dims, num_points]``.
            device: unused; kept for backward compatibility. Index offsets are
                created on ``x.device``, so CPU and any CUDA device both work.
            k: neighbors per point.

        Returns:
            ``[batch_size, 2 * num_dims, num_points, k]``. The first ``num_dims``
            channels hold neighbor features; the rest hold the center features
            repeated ``k`` times.
        """
        batch_size, num_dims, num_points = x.size()

        idx = self._knn(x, k=k)  # (batch_size, num_points, k)
        # offset per-sample indices so they address rows of the flattened batch
        idx_base = torch.arange(0, batch_size, device=x.device).view(-1, 1, 1) * num_points
        idx = (idx + idx_base).view(-1)

        x = x.transpose(2, 1).contiguous()  # (batch_size, num_points, num_dims)
        feature = x.view(batch_size * num_points, -1)[idx, :]
        feature = feature.view(batch_size, num_points, k, num_dims)

        center = x.view(batch_size, num_points, 1, num_dims).expand(-1, -1, k, -1)

        # e.g. torch.Size([32, 42, 300, 5])
        return torch.cat((feature, center), dim=3).permute(0, 3, 1, 2).contiguous()

    def __repr__(self):
        return self.__class__.__name__ + ' k = ' + str(self.k) + ' (' + str(self.in_dim) + ' -> ' + str(self.out_dim) + ')'



class Encoder(nn.Module):
    """Five stacked ``MolConv`` layers followed by global max/avg pooling and an MLP.

    Args:
        in_dim: channels of the input point features.
        emb_dim: size of the output embedding.
        k: number of neighbors used by every ``MolConv``.
        device: forwarded to every ``MolConv``.
    """

    def __init__(self, in_dim, emb_dim, k, device):
        super(Encoder, self).__init__()

        self.conv1 = MolConv(in_dim=in_dim, out_dim=64, k=k, device=device, skip_connection=False)
        self.conv2 = MolConv(in_dim=64, out_dim=64, k=k, device=device, skip_connection=False)
        self.conv3 = MolConv(in_dim=64, out_dim=128, k=k, device=device, skip_connection=True)
        self.conv4 = MolConv(in_dim=128, out_dim=256, k=k, device=device, skip_connection=True)
        self.conv5 = MolConv(in_dim=256, out_dim=512, k=k, device=device, skip_connection=True)

        # 64 + 64 + 128 + 256 + 512 = 1024 concatenated channels, doubled by
        # concatenating the max- and avg-pooled vectors -> 2048 inputs.
        self.merge = nn.Sequential(nn.Linear(2048, emb_dim),
                                   nn.BatchNorm1d(emb_dim),
                                   nn.LeakyReLU(negative_slope=0.2))
        self._reset_parameters()

    def _reset_parameters(self):
        # Only the merge MLP; each MolConv initializes its own parameters.
        for m in self.merge:
            if isinstance(m, (nn.Conv1d, nn.Conv2d, nn.Linear)):
                nn.init.kaiming_normal_(m.weight, a=0.2, mode='fan_in', nonlinearity='leaky_relu')

            elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x, mask):
        '''
        x:      set of points, torch.Size([32, 21, 300])
        mask:   mask of real atom number (without padding), torch.Size([32, 300])

        Returns a [batch_size, emb_dim] embedding (also for batch_size == 1).
        '''
        # [batch_size, 1, num_points]: broadcast the mask over all channels
        mask = mask.unsqueeze(1)

        feats = []
        h = x
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4, self.conv5):
            h = conv(h) * mask  # zero out padded points after every layer
            feats.append(h)
        x = torch.cat(feats, dim=1)

        # squeeze only the pooled axis; a bare squeeze() would also drop a batch dim of 1
        p1 = F.adaptive_max_pool1d(x, 1).squeeze(-1)
        p2 = F.adaptive_avg_pool1d(x, 1).squeeze(-1)

        x = torch.cat((p1, p2), 1)
        x = self.merge(x)
        return x



class MolNet(nn.Module):
    """Input alignment (TRNet) + point encoder + fully-connected residual decoder.

    Args:
        args: config dict; reads 'num_add', 'num_atoms', 'in_channels', 'emb_dim',
            'k', 'decoder_layers', 'out_channels' and 'dropout'.
        device: device index forwarded to TRNet and the encoder.
    """

    def __init__(self, args, device):
        super(MolNet, self).__init__()
        self.num_add = args['num_add']
        self.num_atoms = args['num_atoms']
        self.device = torch.device("cuda:" + str(device)) if torch.cuda.is_available() else torch.device("cpu")

        self.tr_net = TRNet(device)
        self.encoder = Encoder(in_dim=args['in_channels'],
                                    emb_dim=args['emb_dim'],
                                    k=args['k'],
                                    device=device)
        self.decoder = FCResDecoder(in_dim=args['emb_dim']+args['num_add'],
                                    layers=args['decoder_layers'],
                                    out_dim=args['out_channels'],
                                    dropout=args['dropout'])

    def forward(self, x, mask, env):
        """Predict outputs for a batch of point sets.

        Args:
            x: ``[batch_size, in_channels, num_points]``; channels 0-2 are xyz.
            mask: point mask passed to the encoder.
            env: extra per-sample features appended to the embedding when
                ``num_add > 0``; ``[batch_size]`` or ``[batch_size, num_add]``.

        Returns:
            Decoder output passed through ``torch.squeeze``. This removes every
            size-1 dimension, including the batch dim when batch_size == 1; it is
            kept for compatibility with existing callers.
        """
        batch_size, _, num_points = x.size()

        # homogeneous coordinates: append w = 1, created on x's device/dtype
        w = x.new_ones(batch_size, 1, num_points)
        xyz = x[:, :3, :]
        xyzw = torch.cat((xyz, w), dim=1)

        # predict transformation matrix by TRNet (presumably [batch_size, 4, 4] — TODO confirm)
        tr_matrix = self.tr_net(xyz)

        # translation and rotation transform
        xyzw = torch.bmm(tr_matrix, xyzw)

        # back to Cartesian: divide xyz by w (broadcast over the 3 coordinates)
        xyz = torch.div(xyzw[:, :3, :], xyzw[:, 3:4, :])

        # concat transformed xyz to the remaining input channels
        x = torch.cat((xyz, x[:, 3:, :]), dim=1)

        # encoder
        x = self.encoder(x, mask)  # torch.Size([batch_size, emb_dim])

        if self.num_add > 0:
            if env.dim() == 1:
                env = torch.unsqueeze(env, 1)  # [batch_size] -> [batch_size, 1]
            x = torch.cat((x, env), 1)

        # decoder
        x = self.decoder(x)
        return torch.squeeze(x)

