import dgl
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
class HGTEdgeLayer(nn.Module):
    """One Heterogeneous Graph Transformer (HGT) layer that also attends over edge features.

    For every canonical edge type ``(source, edge, target)`` the layer projects
    target-node features into queries, source-node features into keys/values and
    edge features into edge-keys/edge-values, scores them with per-relation
    matrices, aggregates the messages on the target nodes and produces a new
    feature vector for every edge.

    Args:
        in_dim: mapping from node-type *and* edge-type name to input feature size.
        out_dim: mapping from node-type *and* edge-type name to output feature size.
            The edge-type entries only feed the attention scaling factor ``sqrt_dk``.
        nodes: node-type names whose features ``forward`` returns.
        edges: iterable of ``(source, edge, target)`` canonical edge types.
        n_heads: number of attention heads.
        dropout: currently unused; kept for interface compatibility.
        use_norm: if True, apply a per-edge-type ``LayerNorm`` to the node outputs.
    """

    def __init__(self, in_dim, out_dim, nodes, edges, n_heads, dropout = 0.2, use_norm = False):
        super(HGTEdgeLayer, self).__init__()

        self.in_dim  = in_dim
        self.out_dim = out_dim
        self.n_heads = n_heads
        # NOTE(review): the projections below emit out_dim[target] features *per
        # head*, but the scale uses out_dim[key] // n_heads — confirm intent.
        self.d_k     = {key: max((out_dim[key] // n_heads), 1) for key in out_dim}
        self.sqrt_dk = {key: math.sqrt(value) for key, value in self.d_k.items()}

        self.nodes = nodes
        self.edges = edges

        # Per-edge-type projections, each to out_dim[target] * n_heads features.
        # Creation and re-initialisation happen in the same order as before so
        # seeded runs produce identical parameters.
        self.k_linears = self._init_linears(nn.ModuleDict({
            edge: nn.Linear(in_dim[source], out_dim[target] * n_heads) for source, edge, target in edges
        }))
        self.q_linears = self._init_linears(nn.ModuleDict({
            edge: nn.Linear(in_dim[target], out_dim[target] * n_heads) for source, edge, target in edges
        }))
        self.v_linears = self._init_linears(nn.ModuleDict({
            edge: nn.Linear(in_dim[source], out_dim[target] * n_heads) for source, edge, target in edges
        }))
        # Edge-feature counterpart of the key projection.
        self.e_linears = self._init_linears(nn.ModuleDict({
            edge: nn.Linear(in_dim[edge], out_dim[target] * n_heads) for source, edge, target in edges
        }))
        # Edge-feature counterpart of the value projection.
        self.ev_linears = self._init_linears(nn.ModuleDict({
            edge: nn.Linear(in_dim[edge], out_dim[target] * n_heads) for source, edge, target in edges
        }))
        # Output projection: concatenated heads -> out_dim[target].
        self.a_linears = self._init_linears(nn.ModuleDict({
            edge: nn.Linear(out_dim[target] * n_heads, out_dim[target]) for source, edge, target in edges
        }))
        # NOTE(review): ea_linears is not used by forward(); kept so existing
        # state_dicts still load.
        self.ea_linears = self._init_linears(nn.ModuleDict({
            edge: nn.Linear(out_dim[target] * n_heads, out_dim[target]) for source, edge, target in edges
        }))

        if use_norm:
            self.norms = nn.ModuleDict({
                edge: nn.LayerNorm(out_dim[target]) for source, edge, target in edges
            })
        self.use_norm = use_norm

        # Learnable per-edge-type, per-head relation parameters.
        self.relation_pri = nn.ParameterDict({
            edge: nn.Parameter(torch.ones(1, n_heads)) for _, edge, _ in edges
        })
        self.relation_att = nn.ParameterDict({
            edge: nn.Parameter(torch.empty(1, n_heads, out_dim[target], out_dim[target])) for _, edge, target in edges
        })
        self.relation_msg = nn.ParameterDict({
            edge: nn.Parameter(torch.empty(1, n_heads, out_dim[target], out_dim[target])) for _, edge, target in edges
        })
        # Edge keys/values are projected to out_dim[target] per head (e_linears /
        # ev_linears), so these matrices must be square in out_dim[target].
        # Sizing the first axis with out_dim[edge] made torch.bmm fail whenever
        # the two sizes differed.
        self.relation_edge = nn.ParameterDict({
            edge: nn.Parameter(torch.empty(1, n_heads, out_dim[target], out_dim[target])) for _, edge, target in edges
        })
        self.relation_edge_msg = nn.ParameterDict({
            edge: nn.Parameter(torch.empty(1, n_heads, out_dim[target], out_dim[target])) for _, edge, target in edges
        })

        # NOTE(review): this also overwrites the ones-initialised relation_pri
        # with Xavier-normal values — confirm intent.
        for params in (self.relation_pri, self.relation_att, self.relation_msg,
                       self.relation_edge, self.relation_edge_msg):
            for param in params.values():
                nn.init.xavier_normal_(param)

        self.activation = F.leaky_relu

    @staticmethod
    def _init_linears(linears):
        """Xavier-normal-initialise the weights and zero the biases of every Linear in *linears*; return it."""
        for linear in linears.values():
            nn.init.xavier_normal_(linear.weight)
            nn.init.zeros_(linear.bias)
        return linears

    def edge_attention(self, edges):
        """Compute per-edge attention logits and messages for one edge type.

        Reads ``src['k']``, ``src['v']``, ``dst['q']``, ``data['ek']`` and
        ``data['ev']``. ``forward`` views each of these as
        (-1, n_heads, out_dim[target]).

        Returns a dict of edge fields:
            ``a`` / ``ea``: (num_edges, n_heads) logits for the source-node and
                edge-feature messages.
            ``v`` / ``ev``: (num_edges, n_heads, out_dim[target]) transformed
                messages.
            ``h_edge``: (num_edges, out_dim[target]) updated edge features.
        """
        etype = edges._etype[1]
        relation_pri = self.relation_pri[etype]
        scale = self.sqrt_dk[etype]
        query = edges.dst['q']

        # Source-node key/value, mixed per head by the relation matrices
        # (transpose to head-major so bmm batches over heads).
        key = torch.bmm(edges.src['k'].transpose(1, 0), self.relation_att[etype].squeeze(0)).transpose(1, 0)
        att = (query * key).sum(dim=-1) * relation_pri / scale
        val = torch.bmm(edges.src['v'].transpose(1, 0), self.relation_msg[etype].squeeze(0)).transpose(1, 0)

        # Edge-feature key/value, transformed the same way.
        edge_key = torch.bmm(edges.data['ek'].transpose(1, 0), self.relation_edge[etype].squeeze(0)).transpose(1, 0)
        att_edge = (query * edge_key).sum(dim=-1) * relation_pri / scale
        edge_val = torch.bmm(edges.data['ev'].transpose(1, 0), self.relation_edge_msg[etype].squeeze(0)).transpose(1, 0)

        # New edge features: heads combined with a softmax over the head axis.
        a_e = F.softmax(att_edge, dim=-1)
        h_edge = torch.sum(a_e.unsqueeze(dim=-1) * edge_val, dim=1)

        # Return h_edge instead of assigning edges.data inside the UDF, so that
        # apply_edges stores it on the graph like the other fields.
        return {'a': att, 'v': val, 'ea': att_edge, 'ev': edge_val, 'h_edge': h_edge}

    def message_func(self, edges):
        """Forward the fields produced by ``edge_attention`` to the destination mailbox."""
        return {'v': edges.data['v'], 'a': edges.data['a'], 'ea': edges.data['ea'], 'ev': edges.data['ev']}

    def reduce_func(self, nodes):
        """Aggregate node and edge messages on destination nodes of one edge type.

        Returns ``t`` of shape (num_nodes, n_heads * out_dim[ntype]).
        """
        dsttype = nodes._ntype
        # NOTE(review): both softmaxes normalise over the last (head) axis. The
        # usual HGT formulation normalises over incoming edges (mailbox dim=1)
        # instead — confirm this is intended.
        att = F.softmax(nodes.mailbox['a'], dim=-1)
        att_edge = F.softmax(nodes.mailbox['ea'], dim=-1)
        # Sum over mailbox dim=1 (incoming messages) of the weighted node and
        # edge messages.
        h = (torch.sum(att.unsqueeze(dim=-1) * nodes.mailbox['v'], dim=1)
             + torch.sum(att_edge.unsqueeze(dim=-1) * nodes.mailbox['ev'], dim=1))
        # Previously tried alternatives: one softmax over (att + att_edge) shared
        # by both message kinds, and F.gumbel_softmax(att + att_edge, tau=1, hard=True).
        # TODO: continue experimenting here.
        return {'t': h.view(-1, self.n_heads * self.out_dim[dsttype])}

    def forward(self, G, node_dict, edge_dict):
        """Run one layer over the DGL heterograph *G*.

        Writes node fields 'q', 'k', 'v', 't', 'h' and edge fields 'ek', 'ev',
        'a', 'ea', 'h_edge' on *G* as a side effect.

        Args:
            G: heterograph containing every edge type in ``self.edges``.
            node_dict: node type -> input node features.
            edge_dict: edge type -> input edge features.

        Returns:
            ``(node_dict, edge_dict)`` of activated outputs. An edge type is
            skipped (and left out of the returned edge dict) when its target
            node type is missing from *node_dict* or the edge type is missing
            from *edge_dict*. Every type in ``self.nodes`` must still end up
            with an 'h' field.
        """
        active = []
        for srctype, etype, dsttype in self.edges:
            if dsttype not in node_dict or etype not in edge_dict:
                continue
            out = self.out_dim[dsttype]
            G.nodes[dsttype].data['q'] = self.q_linears[etype](node_dict[dsttype]).view(-1, self.n_heads, out)
            G.nodes[srctype].data['k'] = self.k_linears[etype](node_dict[srctype]).view(-1, self.n_heads, out)
            G.nodes[srctype].data['v'] = self.v_linears[etype](node_dict[srctype]).view(-1, self.n_heads, out)
            edge_feat = edge_dict[etype]
            if edge_feat.shape[0] == 0:
                # No edges of this type: only reshape the empty tensor.
                G.edges[etype].data['ek'] = edge_feat.view(-1, self.n_heads, out)
                G.edges[etype].data['ev'] = edge_feat.view(-1, self.n_heads, out)
            else:
                G.edges[etype].data['ek'] = self.e_linears[etype](edge_feat).view(-1, self.n_heads, out)
                G.edges[etype].data['ev'] = self.ev_linears[etype](edge_feat).view(-1, self.n_heads, out)

            G.apply_edges(func=self.edge_attention, etype=etype)
            active.append((etype, dsttype))

        # Only edge types that went through edge_attention carry 'a'/'v'/'ea'/'ev',
        # so only they may take part in message passing.
        G.multi_update_all({etype: (self.message_func, self.reduce_func) for etype, _ in active},
                           cross_reducer = 'mean')

        # NOTE(review): when several edge types share a target node type, the
        # last one's a_linears/norms determines that type's 'h'.
        for etype, dsttype in active:
            # 't' is already 2-D (num_nodes, n_heads * out_dim[dsttype]).
            trans_out = self.a_linears[etype](G.nodes[dsttype].data['t'])
            if self.use_norm:
                G.nodes[dsttype].data['h'] = self.norms[etype](trans_out)
            else:
                G.nodes[dsttype].data['h'] = trans_out

        node_dicts = {ntype: self.activation(G.nodes[ntype].data['h']) for ntype in self.nodes}
        edge_out = {etype: self.activation(G.edges[etype].data['h_edge']) for etype, _ in active}

        return node_dicts, edge_out

    def __repr__(self):
        return '{}(in_dim={}, out_dim={}, n_heads={}, num_node_types={}, num_edge_types={})'.format(
            self.__class__.__name__, self.in_dim, self.out_dim, self.n_heads,
            len(tuple(self.nodes)), len(tuple(self.edges)))

class HGTEdge(nn.Module):
    """A stack of ``HGTEdgeLayer`` modules.

    With ``n_layers == 1`` the single layer maps ``in_dim -> out_dim``.
    Otherwise the first layer maps ``in_dim -> hid_dim``, ``n_layers - 2``
    layers map ``hid_dim -> hid_dim``, and the last maps ``hid_dim -> out_dim``.

    Raises:
        ValueError: if ``n_layers`` is less than 1. Previously ``0`` silently
            built two layers.
    """

    def __init__(self, nodes, edges, in_dim, hid_dim, out_dim, n_layers, n_heads, use_norm = False):
        super(HGTEdge, self).__init__()
        if n_layers < 1:
            raise ValueError("n_layers must be >= 1, got {}".format(n_layers))

        self.in_dim = in_dim
        self.hid_dim = hid_dim
        self.out_dim = out_dim
        # Stored so __repr__ can report it.
        self.n_layers = n_layers

        # Consecutive (input, output) dimension pairs, one per layer.
        dims = [in_dim] + [hid_dim] * (n_layers - 1) + [out_dim]
        self.layers = nn.ModuleList([
            HGTEdgeLayer(d_in, d_out, nodes, edges, n_heads, use_norm = use_norm)
            for d_in, d_out in zip(dims[:-1], dims[1:])
        ])

    def forward(self, G, node_dict, edge_dict, output_node):
        """Apply every layer in order, feeding each one's outputs to the next.

        ``output_node`` is currently unused; it is kept for interface compatibility.
        Returns the final ``(node_dict, edge_dict)``.
        """
        for layer in self.layers:
            node_dict, edge_dict = layer(G, node_dict, edge_dict)
        return node_dict, edge_dict

    def __repr__(self):
        return '{}(n_inp={}, n_hid={}, n_out={}, n_layers={})'.format(
            self.__class__.__name__, self.in_dim, self.hid_dim,
            self.out_dim, self.n_layers)