import dgl
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
        
import os
import sys
import inspect

# Put this file's parent directory at the front of the import path.
# Presumably this is what lets the project-local `graph.hgt_edge` import below
# resolve when the file is not run as part of an installed package.
_this_file = inspect.getfile(inspect.currentframe())
currentdir = os.path.dirname(os.path.abspath(_this_file))
parentdir = os.path.split(currentdir)[0]
sys.path.insert(0, parentdir)
del _this_file

from graph.hgt_edge import HGTEdgeLayer

class HGTEdgeRes(nn.Module):
    """Stack of ``HGTEdgeLayer`` modules with input skip-connections.

    After every layer except the last, each node/edge feature returned by the
    layer is concatenated along dim 1 with the matching ORIGINAL input
    feature. Every layer after the first therefore receives
    ``hid_dim[key] + in_dim[key]`` input features for each key.

    Args:
        nodes: Forwarded unchanged to every ``HGTEdgeLayer``.
        edges: Forwarded unchanged to every ``HGTEdgeLayer``.
        in_dim: Mapping key -> input feature width.
        hid_dim: Mapping key -> hidden feature width. Its keys define the
            residual widths, so every key must also be present in ``in_dim``.
        out_dim: Mapping key -> output feature width of the final layer.
        n_layers: Number of ``HGTEdgeLayer`` modules to stack; must be >= 1.
        n_heads: Forwarded unchanged to every ``HGTEdgeLayer``.
        use_norm: Forwarded to every ``HGTEdgeLayer`` as ``use_norm``.

    Raises:
        ValueError: If ``n_layers`` is less than 1 (previously this silently
            built two layers).
    """

    def __init__(self, nodes, edges, in_dim, hid_dim, out_dim, n_layers, n_heads, use_norm=False):
        super().__init__()
        if n_layers < 1:
            raise ValueError(f"n_layers must be >= 1, got {n_layers}")

        self.in_dim = in_dim
        self.hid_dim = hid_dim
        self.out_dim = out_dim
        # Kept so __repr__ can report it; it was previously never assigned,
        # which made repr() raise AttributeError.
        self.n_layers = n_layers

        # Input width of every layer after the first: that layer's hidden
        # output concatenated with the original inputs (see forward()).
        self.res_dim = {key: hid_dim[key] + in_dim[key] for key in hid_dim}

        def make_layer(layer_in, layer_out):
            # Every layer shares the same graph description and head settings.
            return HGTEdgeLayer(layer_in, layer_out, nodes, edges, n_heads, use_norm=use_norm)

        self.layers = nn.ModuleList()
        if n_layers == 1:
            self.layers.append(make_layer(in_dim, out_dim))
        else:
            self.layers.append(make_layer(in_dim, hid_dim))
            for _ in range(n_layers - 2):
                self.layers.append(make_layer(self.res_dim, hid_dim))
            self.layers.append(make_layer(self.res_dim, out_dim))

    def forward(self, G, node_dict, edge_dict, output_node):
        """Run every layer, re-attaching the original inputs between layers.

        Args:
            G: Passed unchanged to every layer.
            node_dict: Mapping key -> input node feature tensor. It is not
                modified by this method.
            edge_dict: Mapping key -> input edge feature tensor. It is not
                modified by this method.
            output_node: Unused; kept for interface compatibility.

        Returns:
            Tuple ``(nf, ef)``: the node and edge feature dicts returned by the
            last layer.
        """
        nf, ef = node_dict, edge_dict
        last = len(self.layers) - 1
        for i, layer in enumerate(self.layers):
            nf, ef = layer(G, nf, ef)
            if i < last:
                # Build NEW dicts instead of assigning into the ones the layer
                # returned. If a layer hands back (or aliases) its input dicts,
                # in-place assignment would overwrite the caller's
                # node_dict/edge_dict and corrupt later concatenations.
                nf = {key: torch.cat((feat, node_dict[key]), dim=1) for key, feat in nf.items()}
                ef = {key: torch.cat((feat, edge_dict[key]), dim=1) for key, feat in ef.items()}
        return nf, ef

    def __repr__(self):
        return '{}(n_inp={}, n_hid={}, n_out={}, n_layers={})'.format(
            self.__class__.__name__, self.in_dim, self.hid_dim,
            self.out_dim, self.n_layers)