import dgl
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class HetGATLayer(nn.Module):
    """Heterogeneous graph-attention layer that also consumes edge features.

    For each canonical relation ``(srctype, etype, dsttype)`` in ``edges``:

    * source nodes are projected to per-head keys and values
      (``k_linears[etype]`` / ``v_linears[etype]``),
    * destination nodes are projected to per-head queries
      (``node_linears[dsttype]``),
    * edge features are projected to per-head edge embeddings
      (``edge_linears[etype]``),

    and relation-specific per-head matrices turn these into attention logits
    and messages. Each destination node softmax-normalises the logits over its
    incoming edges (separately per head), per-relation results are averaged
    across relations, and a per-node-type projection maps the concatenated
    heads back to ``out_dim[ntype]``.

    Args:
        in_dim: dict ``ntype -> input node feature size``.
        out_dim: dict giving the per-head output size for every node type and
            for every edge type name (``out_dim[etype]`` is the size of the
            per-head edge projection).
        nodes: node type names.
        edges: canonical ``(srctype, etype, dsttype)`` triples. Edge type
            names are used as dictionary keys, so they must be unique.
        edge_features: dict ``etype -> input edge feature size``; every edge
            type in ``edges`` needs an entry because attention uses edge
            features on every relation.
        attention_features: stored as an attribute; not used by this layer.
        n_heads: number of attention heads.
        dropout: dropout probability applied to each node type's output.
        use_norm: if True, apply ``LayerNorm`` to each node type's output.
    """

    def __init__(self, in_dim, out_dim, nodes, edges, edge_features, attention_features, n_heads, dropout = 0.2, use_norm = False):
        super(HetGATLayer, self).__init__()

        self.in_dim = in_dim
        self.out_dim = out_dim
        self.n_heads = n_heads
        # Materialise once: both are iterated several times here and in forward().
        self.nodes = list(nodes)
        self.edges = list(edges)
        self.edge_features = edge_features
        self.attention_features = attention_features
        self.use_norm = use_norm
        self.num_types = len(self.nodes)
        self.num_relations = len(self.edges)

        # Per node type: query projection, in_dim[ntype] -> n_heads * out_dim[ntype].
        self.node_linears = nn.ModuleDict(
            {
                node: nn.Linear(in_dim[node], out_dim[node] * n_heads) for node in self.nodes
            }
        )
        # Per edge type: edge-feature projection, edge_features[etype] -> n_heads * out_dim[etype].
        self.edge_linears = nn.ModuleDict(
            {
                edge: nn.Linear(edge_features[edge], out_dim[edge] * n_heads) for edge in edge_features
            }
        )
        # Per relation: key / value projections of the source type, sized for the
        # destination type so they can be compared with that type's queries.
        self.k_linears = nn.ModuleDict(
            {
                edge: nn.Linear(in_dim[src], out_dim[dst] * n_heads) for src, edge, dst in self.edges
            }
        )
        self.v_linears = nn.ModuleDict(
            {
                edge: nn.Linear(in_dim[src], out_dim[dst] * n_heads) for src, edge, dst in self.edges
            }
        )

        # Per relation: per-head prior and per-head (d_in, d_out) transforms.
        self.relation_pri = nn.ParameterDict(
            {
                edge: nn.Parameter(torch.ones(n_heads)) for _, edge, _ in self.edges
            }
        )
        self.relation_att = nn.ParameterDict(
            {
                edge: self._relation_weight(n_heads, out_dim[dst], out_dim[dst]) for _, edge, dst in self.edges
            }
        )
        self.relation_msg = nn.ParameterDict(
            {
                edge: self._relation_weight(n_heads, out_dim[dst], out_dim[dst]) for _, edge, dst in self.edges
            }
        )
        # Maps the per-head edge embedding (out_dim[etype]) into the destination space.
        self.relation_edge = nn.ParameterDict(
            {
                edge: self._relation_weight(n_heads, out_dim[edge], out_dim[dst]) for _, edge, dst in self.edges
            }
        )
        # Scaled dot-product temperature per relation (plain floats, not parameters).
        self.sqrt_dk = {edge: math.sqrt(out_dim[dst]) for _, edge, dst in self.edges}

        # Per node type: merge the concatenated heads back to out_dim[ntype].
        self.a_linears = nn.ModuleDict(
            {
                node: nn.Linear(out_dim[node] * n_heads, out_dim[node]) for node in self.nodes
            }
        )
        if use_norm:
            self.norms = nn.ModuleDict(
                {
                    node: nn.LayerNorm(out_dim[node]) for node in self.nodes
                }
            )
        self.drop = nn.Dropout(dropout)

        # Not used by forward(); kept so existing state_dict keys still match.
        self.fc_nodes = nn.ModuleDict(
            {
                edge: nn.Linear(out_dim[src], out_dim[dst] * n_heads) for src, edge, dst in self.edges
            }
        )
        self.fc_edges = nn.ModuleDict(
            {
                edge: nn.Linear(out_dim[edge], out_dim[dst] * n_heads) for src, edge, dst in self.edges if edge in self.edge_features
            }
        )

        self.activation = F.leaky_relu

    @staticmethod
    def _relation_weight(n_heads, rows, cols):
        """Return an ``(n_heads, rows, cols)`` parameter, Xavier-initialised per head."""
        weight = torch.empty(n_heads, rows, cols)
        for head_weight in weight:
            # Each slice is a view of `weight`, so this initialises it in place.
            nn.init.xavier_uniform_(head_weight)
        return nn.Parameter(weight)

    def edge_attention(self, edges):
        """Score every edge of one relation (used with ``G.apply_edges``).

        Expects ``edges.src['k']`` / ``edges.src['v']`` / ``edges.dst['q']`` of
        shape ``(E, n_heads, d)`` with ``d = out_dim[dsttype]`` and
        ``edges.data['e']`` of shape ``(E, n_heads, out_dim[etype])``.

        Returns a dict that becomes edge data:
            'a': ``(E, n_heads)`` unnormalised node-attention logits,
            'v': ``(E, n_heads, d)`` relation-transformed values,
            'e' and 'h_edge': ``(E, d)`` edge embedding pooled over heads with
                a softmax-over-heads attention. 'e' replaces the projected edge
                features so ``message_func`` sends the pooled embedding.
        """
        etype = edges.canonical_etype[1]
        relation_pri = self.relation_pri[etype]
        sqrt_dk = self.sqrt_dk[etype]

        # (E, H, d) -> (H, E, d) so bmm applies one matrix per head, then back.
        key = torch.bmm(edges.src['k'].transpose(1, 0), self.relation_att[etype]).transpose(1, 0)
        val = torch.bmm(edges.src['v'].transpose(1, 0), self.relation_msg[etype]).transpose(1, 0)
        edg = torch.bmm(edges.data['e'].transpose(1, 0), self.relation_edge[etype]).transpose(1, 0)

        query = edges.dst['q']
        att = (query * key).sum(dim=-1) * relation_pri / sqrt_dk
        att_edge = (query * edg).sum(dim=-1) * relation_pri / sqrt_dk

        # Pool the heads of each edge embedding with attention over heads.
        a_e = F.softmax(att_edge, dim=-1)
        h_edge = torch.sum(a_e.unsqueeze(dim=-1) * edg, dim=1)

        # Return h_edge instead of writing through edges.data, so DGL stores it
        # as regular edge data for forward() to read.
        return {'a': att, 'v': val, 'e': h_edge, 'h_edge': h_edge}

    def message_func(self, edges):
        """Send the attention logit, value and pooled edge embedding along each edge."""
        return {'v': edges.data['v'], 'a': edges.data['a'], 'e': edges.data['e']}

    def reduce_func(self, nodes):
        """Aggregate one relation's messages into ``'t'`` of shape ``(N, n_heads * d)``."""
        dsttype = nodes.ntype
        # Mailbox tensors are (N, in_degree, n_heads[, d]); normalise over the
        # incoming edges (dim=1), independently for each head.
        att = F.softmax(nodes.mailbox['a'], dim=1)
        h = torch.sum(att.unsqueeze(dim=-1) * nodes.mailbox['v'], dim=1)
        # Summed edge embeddings (N, d) are broadcast across all heads.
        h = h + torch.sum(nodes.mailbox['e'], dim=1).unsqueeze(1)
        return {'t': h.reshape(-1, self.n_heads * self.out_dim[dsttype])}

    def forward(self, G, node_dict, edge_dict):
        """Run one round of attention-based message passing on ``G``.

        Args:
            G: heterogeneous DGL graph containing every relation in ``edges``.
            node_dict: dict ``ntype -> (num_nodes, in_dim[ntype])`` features.
            edge_dict: dict ``etype -> (num_edges, edge_features[etype])``
                features; must cover every edge type in ``edges``.

        Returns:
            ``(node_out, edge_out)``. ``node_out[ntype]`` has shape
            ``(num_nodes, out_dim[ntype])`` and is produced only for node types
            that are the destination of at least one relation. ``edge_out[etype]``
            has shape ``(num_edges, out_dim[dsttype])``.
        """
        dst_types = {dsttype for _, _, dsttype in self.edges}

        # local_scope keeps the scratch features ('k', 'q', 'v', 'e', 't', ...)
        # from being left attached to the caller's graph.
        with G.local_scope():
            for etype, feat in edge_dict.items():
                G.edges[etype].data['e'] = self.edge_linears[etype](feat).view(-1, self.n_heads, self.out_dim[etype])

            for srctype, etype, dsttype in self.edges:
                d_k = self.out_dim[dsttype]
                src_feat = node_dict[srctype]
                G.nodes[srctype].data['k'] = self.k_linears[etype](src_feat).view(-1, self.n_heads, d_k)
                G.nodes[srctype].data['v'] = self.v_linears[etype](src_feat).view(-1, self.n_heads, d_k)
                G.nodes[dsttype].data['q'] = self.node_linears[dsttype](node_dict[dsttype]).view(-1, self.n_heads, d_k)
                # 'k'/'v'/'q' are overwritten by the next relation, so score now.
                G.apply_edges(func=self.edge_attention, etype=(srctype, etype, dsttype))

            G.multi_update_all(
                {
                    (srctype, etype, dsttype): (self.message_func, self.reduce_func)
                    for srctype, etype, dsttype in self.edges
                },
                cross_reducer='mean',
            )

            node_out = {}
            for ntype in self.nodes:
                if ntype not in dst_types:
                    # No relation writes 't' for this type, so there is nothing to project.
                    continue
                trans_out = self.a_linears[ntype](G.nodes[ntype].data['t'])
                if self.use_norm:
                    trans_out = self.norms[ntype](trans_out)
                node_out[ntype] = self.activation(self.drop(trans_out))

            edge_out = {
                etype: self.activation(G.edges[(srctype, etype, dsttype)].data['h_edge'])
                for srctype, etype, dsttype in self.edges
            }

        return node_out, edge_out

    def __repr__(self):
        return '{}(in_dim={}, out_dim={}, num_types={}, num_relations={})'.format(
            self.__class__.__name__, self.in_dim, self.out_dim,
            self.num_types, self.num_relations)

class HetGAT(nn.Module):
    """Stack of ``HetGATLayer`` modules applied to a heterogeneous graph.

    Layer ``i`` maps ``dims[i] -> dims[i + 1]`` where
    ``dims = [in_dim] + [hid_dim] * (n_layers - 1) + [out_dim]``, so exactly
    ``n_layers`` layers are built.

    Args:
        nodes: node type names.
        edges: canonical ``(srctype, etype, dsttype)`` triples.
        in_dim, hid_dim, out_dim: per-type feature-size dicts for the input,
            hidden and output layers. They are passed straight to
            ``HetGATLayer`` as its ``in_dim`` / ``out_dim``, so each must also
            carry an entry per edge type for the edge projection.
        n_layers: number of layers (must be >= 1).
        n_heads: attention heads per layer.
        use_norm: forwarded to every layer.
        edge_features: dict ``etype -> raw edge feature size`` for the first
            layer. Defaults to ``in_dim[etype]`` for each edge type.
        attention_features: forwarded to every layer.
        dropout: forwarded to every layer.

    Raises:
        ValueError: if ``n_layers < 1``.
    """

    def __init__(self, nodes, edges, in_dim, hid_dim, out_dim, n_layers, n_heads, use_norm = False,
                 edge_features = None, attention_features = None, dropout = 0.2):
        super(HetGAT, self).__init__()
        if n_layers < 1:
            raise ValueError('n_layers must be >= 1, got {}'.format(n_layers))

        self.in_dim = in_dim
        self.hid_dim = hid_dim
        self.out_dim = out_dim
        self.n_layers = n_layers

        nodes = list(nodes)
        edges = list(edges)
        if edge_features is None:
            edge_features = {etype: in_dim[etype] for _, etype, _ in edges}

        dims = [in_dim] + [hid_dim] * (n_layers - 1) + [out_dim]

        self.layers = nn.ModuleList()
        layer_edge_features = edge_features
        for layer_in, layer_out in zip(dims[:-1], dims[1:]):
            self.layers.append(
                HetGATLayer(layer_in, layer_out, nodes, edges, layer_edge_features,
                            attention_features, n_heads, dropout=dropout, use_norm=use_norm)
            )
            # A layer emits edge features of size out_dim[dsttype], and the next
            # layer consumes those features.
            layer_edge_features = {etype: layer_out[dsttype] for _, etype, dsttype in edges}

    def forward(self, G, node_dict, edge_dict, output_node = None):
        """Apply every layer in order and return the final ``(node_dict, edge_dict)``.

        ``output_node`` is not used; it is accepted for API compatibility.
        """
        for layer in self.layers:
            node_dict, edge_dict = layer(G, node_dict, edge_dict)
        return node_dict, edge_dict

    def __repr__(self):
        return '{}(n_inp={}, n_hid={}, n_out={}, n_layers={})'.format(
            self.__class__.__name__, self.in_dim, self.hid_dim,
            self.out_dim, self.n_layers)