import dgl
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
class HGTELayer(nn.Module):
    """Heterogeneous Graph Transformer layer whose keys/values also see edge features.

    For every canonical relation ``(source, edge, target)`` in ``edges`` the layer owns:
      * key / value projections of ``[source node feats || edge feats]``,
      * a query projection of the target node feats,
      * relation-specific attention (``relation_att``) and message (``relation_msg``)
        transforms plus a per-head prior (``relation_pri``),
      * an output projection ``a_linears[edge]`` (and optional ``LayerNorm``).

    Each head works in a space of size ``out_dim[target]`` (projections output
    ``out_dim[target] * n_heads`` features). Messages are attention-weighted per head,
    summed over incoming edges, averaged across relations that share a target type
    (``cross_reducer='mean'``) and projected back to ``out_dim[target]``.

    Args:
        in_dim: mapping giving the input feature size for node types AND edge types
            (``in_dim[source]``, ``in_dim[edge]`` and ``in_dim[target]`` are read).
        out_dim: mapping giving the output feature size for (at least) every target
            node type.
        nodes: node type names whose outputs ``forward`` returns.
        edges: re-iterable collection of canonical ``(source, edge, target)`` triples.
        n_heads: number of attention heads.
        dropout: currently unused; kept for interface compatibility.
        use_norm: if True, apply a per-relation ``LayerNorm`` to the output.
    """

    def __init__(self, in_dim, out_dim, nodes, edges, n_heads, dropout = 0.2, use_norm = False):
        super(HGTELayer, self).__init__()

        self.in_dim  = in_dim
        self.out_dim = out_dim
        self.n_heads = n_heads
        # NOTE(review): d_k is out_dim // n_heads, while each head actually works in an
        # out_dim[target]-sized space (see the .view calls in edge_attention). The
        # scale is kept unchanged to preserve existing behaviour.
        self.d_k     = {key: max((out_dim[key] // n_heads), 1) for key in out_dim}
        self.sqrt_dk = {key: math.sqrt(value) for key, value in self.d_k.items()}
        self.nodes   = nodes
        self.edges   = edges

        # Submodules are created in the same order as before so parameter
        # initialisation consumes the RNG identically.
        self.k_linears = nn.ModuleDict({
            edge: nn.Linear(in_dim[source] + in_dim[edge], out_dim[target] * n_heads)
            for source, edge, target in edges
        })
        self.q_linears = nn.ModuleDict({
            edge: nn.Linear(in_dim[target], out_dim[target] * n_heads)
            for source, edge, target in edges
        })
        self.v_linears = nn.ModuleDict({
            edge: nn.Linear(in_dim[source] + in_dim[edge], out_dim[target] * n_heads)
            for source, edge, target in edges
        })
        self.a_linears = nn.ModuleDict({
            edge: nn.Linear(out_dim[target] * n_heads, out_dim[target])
            for source, edge, target in edges
        })
        if use_norm:
            self.norms = nn.ModuleDict({
                edge: nn.LayerNorm(out_dim[target]) for source, edge, target in edges
            })
        self.use_norm = use_norm

        # Per-relation, per-head prior that scales the attention logits.
        # NOTE(review): it is created as ones but then re-initialised by
        # xavier_normal_ below, so the ones-initialisation has no effect.
        self.relation_pri = nn.ParameterDict({
            edge: nn.Parameter(torch.ones(1, n_heads)) for _, edge, _ in edges
        })
        # Per-relation, per-head (D, D) transforms applied to keys and values.
        self.relation_att = nn.ParameterDict({
            edge: nn.Parameter(torch.Tensor(1, n_heads, out_dim[target], out_dim[target]))
            for _, edge, target in edges
        })
        self.relation_msg = nn.ParameterDict({
            edge: nn.Parameter(torch.Tensor(1, n_heads, out_dim[target], out_dim[target]))
            for _, edge, target in edges
        })

        for params in (self.relation_pri, self.relation_att, self.relation_msg):
            for param in params.values():
                nn.init.xavier_normal_(param)

        self.activation = F.leaky_relu

    def edge_attention(self, edges):
        """Compute per-edge attention logits ``'a'`` (E, H) and messages ``'v'`` (E, H, D).

        Reads ``edges.src['node_feats']``, ``edges.data['edge_feats']`` and
        ``edges.dst['node_feats']``; ``D`` is ``out_dim[target type]``.
        """
        srctype, etype, dsttype = edges._etype
        head_dim = self.out_dim[dsttype]

        # Keys and values share the same input: source node feats || edge feats.
        src_edge = torch.cat([edges.src['node_feats'], edges.data['edge_feats']], dim = -1)
        k = self.k_linears[etype](src_edge).view(-1, self.n_heads, head_dim)
        v = self.v_linears[etype](src_edge).view(-1, self.n_heads, head_dim)
        q = self.q_linears[etype](edges.dst['node_feats']).view(-1, self.n_heads, head_dim)

        # (E, H, D) -> (H, E, D) so each head is multiplied by its own (D, D) matrix,
        # then back to (E, H, D).
        key = torch.bmm(k.transpose(1, 0), self.relation_att[etype].squeeze(0)).transpose(1, 0)
        val = torch.bmm(v.transpose(1, 0), self.relation_msg[etype].squeeze(0)).transpose(1, 0)

        # sqrt_dk is keyed by out_dim's keys; fall back to the target node type when
        # out_dim carries no entry for the edge type.
        sqrt_dk = self.sqrt_dk[etype] if etype in self.sqrt_dk else self.sqrt_dk[dsttype]
        att = (q * key).sum(dim=-1) * self.relation_pri[etype] / sqrt_dk

        return {'a': att, 'v': val}

    def message_func(self, edges):
        """Forward the precomputed edge values ``'v'`` and logits ``'a'`` to the mailbox."""
        return {'v': edges.data['v'], 'a': edges.data['a']}

    def reduce_func(self, nodes):
        """Attention-weighted sum of incoming messages, flattened to (N, H * D).

        ``mailbox['a']`` is (N, deg, H) and ``mailbox['v']`` is (N, deg, H, D). The
        softmax runs over the incoming-edge axis (dim=1), so each head's weights
        over a node's neighbours sum to one.
        """
        dsttype = nodes._ntype
        att = F.softmax(nodes.mailbox['a'], dim=1)
        h = torch.sum(att.unsqueeze(dim = -1) * nodes.mailbox['v'], dim=1)
        return {'t': h.view(-1, self.n_heads * self.out_dim[dsttype])}

    def forward(self, G, node_dict, edge_dict):
        """Run one HGT pass over graph ``G``.

        Args:
            G: heterogeneous DGL graph; node/edge data fields ``'node_feats'``,
                ``'edge_feats'``, ``'a'``, ``'v'``, ``'t'`` and ``'h'`` are written to it.
            node_dict: node type -> input node feature tensor.
            edge_dict: edge type -> input edge feature tensor.

        Returns:
            ``(node_dicts, edge_dict)`` where ``node_dicts`` maps every type in
            ``self.nodes`` to its leaky-ReLU-activated output, and ``edge_dict`` is
            returned unchanged.
        """
        for srctype, etype, dsttype in self.edges:
            G.nodes[srctype].data['node_feats'] = node_dict[srctype]
            G.nodes[dsttype].data['node_feats'] = node_dict[dsttype]
            G.edges[etype].data['edge_feats'] = edge_dict[etype]
            G.apply_edges(func=self.edge_attention, etype=etype)

        G.multi_update_all(
            {etype: (self.message_func, self.reduce_func) for _, etype, _ in self.edges},
            cross_reducer = 'mean',
        )

        # NOTE(review): when several relations share a target type, each iteration
        # overwrites 'h', so only the last relation's a_linear/norm is applied to the
        # (already relation-averaged) 't'.
        for _, etype, dsttype in self.edges:
            trans_out = self.a_linears[etype](G.nodes[dsttype].data['t'])
            G.nodes[dsttype].data['h'] = self.norms[etype](trans_out) if self.use_norm else trans_out

        # Every type in self.nodes must be the target of some relation, otherwise
        # 'h' is never written for it.
        node_dicts = {ntype: self.activation(G.nodes[ntype].data['h']) for ntype in self.nodes}

        return node_dicts, edge_dict

    def __repr__(self):
        return '{}(in_dim={}, out_dim={}, n_heads={}, num_node_types={}, num_relations={})'.format(
            self.__class__.__name__, self.in_dim, self.out_dim, self.n_heads,
            len(self.nodes), len(self.edges))

class HGT(nn.Module):
    """Stack of ``HGTELayer`` modules: in_dim -> hid_dim -> ... -> hid_dim -> out_dim.

    Args:
        nodes: node type names passed to every layer.
        edges: canonical ``(source, edge, target)`` triples passed to every layer.
        in_dim: per-type input feature sizes for the first layer.
        hid_dim: per-type hidden feature sizes between layers.
        out_dim: per-type output feature sizes of the last layer.
        n_layers: number of layers; must be >= 1.
        n_heads: attention heads per layer.
        use_norm: forwarded to every layer.

    Raises:
        ValueError: if ``n_layers < 1``.
    """

    def __init__(self, nodes, edges, in_dim, hid_dim, out_dim, n_layers, n_heads, use_norm = False):
        super(HGT, self).__init__()
        if n_layers < 1:
            raise ValueError(f"n_layers must be >= 1, got {n_layers}")

        self.in_dim   = in_dim
        self.hid_dim  = hid_dim
        self.out_dim  = out_dim
        self.n_layers = n_layers  # read by __repr__

        # Layer i maps dims[i] -> dims[i + 1].
        dims = [in_dim] + [hid_dim] * (n_layers - 1) + [out_dim]
        self.layers = nn.ModuleList(
            HGTELayer(d_in, d_out, nodes, edges, n_heads, use_norm = use_norm)
            for d_in, d_out in zip(dims[:-1], dims[1:])
        )

    def forward(self, G, node_dict, edge_dict, output_node):
        """Apply all layers in order and return the final ``(node_dict, edge_dict)``.

        ``output_node`` is currently unused; it is kept for interface compatibility.
        """
        for layer in self.layers:
            node_dict, edge_dict = layer(G, node_dict, edge_dict)
        return node_dict, edge_dict

    def __repr__(self):
        return '{}(n_inp={}, n_hid={}, n_out={}, n_layers={})'.format(
            self.__class__.__name__, self.in_dim, self.hid_dim,
            self.out_dim, self.n_layers)