import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl import function as fn
import numpy as np


class AGCNLayer(nn.Module):
    """Message-passing layer that weights each edge with a learned gate.

    For every edge, the hidden states of both endpoints are offset by a
    projection of their rows in ``se``, passed through a shared linear map,
    concatenated and reduced to a single ``tanh`` coefficient. That
    coefficient, scaled by the ``'d'`` values of both endpoints, multiplies
    the source node's hidden state during aggregation.

    The graph must carry the node features ``'id'`` (used to index ``se``)
    and ``'d'`` (presumably a degree-based normalisation term -- confirm
    against the code that builds the graph).

    Args:
        g: Graph object exposing ``ndata``, ``apply_edges`` and
            ``update_all`` (the DGL message-passing API).
        se: 2-D tensor indexed by node ``'id'``; its second dimension is the
            input width of the projection ``adj_w``.
        in_dim: Width of the node hidden state entering and leaving the layer.
        dropout: Dropout probability applied to the edge weights.
    """

    def __init__(self, g, se, in_dim, dropout):
        super().__init__()
        self.g = g
        # NOTE: ``se`` is kept as a plain attribute, so ``module.to(device)``
        # does not move it; the caller has to place it on the right device.
        self.se = se
        self.dropout = nn.Dropout(dropout)

        # Maps the concatenated (dst, src) representation to one scalar.
        self.gate = nn.Linear(2 * in_dim, 1)
        nn.init.xavier_normal_(self.gate.weight, gain=1.414)

        # Shared transform applied to each endpoint before gating.
        self.mlp = nn.Linear(in_dim, in_dim)
        nn.init.xavier_normal_(self.mlp.weight, gain=1.414)

        # Projects a row of ``se`` into the hidden space (default init).
        self.adj_w = nn.Linear(se.shape[1], in_dim)

    def edge_applying(self, edges):
        """Compute the gate and the propagation weight for a batch of edges.

        Args:
            edges: Edge batch exposing ``src`` and ``dst`` mappings that each
                contain the node features ``'id'``, ``'h'`` and ``'d'``.

        Returns:
            dict: ``'m'`` is the raw gate in ``[-1, 1]`` and ``'e'`` is the
            gate scaled by both endpoints' ``'d'`` and passed through
            dropout. Both have shape ``(num_edges,)``.
        """
        pe1 = F.relu(self.adj_w(self.se[edges.dst['id']]))
        pe2 = F.relu(self.adj_w(self.se[edges.src['id']]))

        h2 = torch.cat(
            [self.mlp(edges.dst['h'] + pe1), self.mlp(edges.src['h'] + pe2)],
            dim=1,
        )
        # Drop only the trailing unit dimension. A bare ``squeeze()`` would
        # also remove the edge dimension when the batch holds a single edge,
        # turning the per-edge feature into a 0-d scalar.
        g = torch.tanh(self.gate(h2)).squeeze(-1)

        e = g * edges.dst['d'] * edges.src['d']
        e = self.dropout(e)
        return {'e': e, 'm': g}

    def forward(self, h):
        """Propagate ``h`` over the graph once and return the aggregate.

        Writes ``h`` to ``g.ndata['h']``, stores the edge features produced
        by ``edge_applying`` on the graph, then sums ``h[src] * e`` over the
        incoming edges of every node into ``g.ndata['z']``.

        Args:
            h: Node hidden states, one row per node of ``g``.

        Returns:
            The aggregated node features read back from ``g.ndata['z']``.
        """
        self.g.ndata['h'] = h
        self.g.apply_edges(self.edge_applying)
        self.g.update_all(fn.u_mul_e('h', 'e', '_'), fn.sum('_', 'z'))

        return self.g.ndata['z']


class AGCN(nn.Module):
    """Stack of ``AGCNLayer`` blocks behind a linear input projection.

    Input features are projected to ``hidden_dim`` with ``t1`` and a ReLU.
    Each layer's output is then combined with ``eps`` times that projected
    representation before being passed to the next layer.

    Args:
        g: Graph object handed unchanged to every ``AGCNLayer``.
        se: Tensor handed unchanged to every ``AGCNLayer``.
        in_dim: Width of the raw input node features.
        hidden_dim: Width of the hidden representation used by all layers.
        dropout: Dropout probability, used both here and inside each layer.
        eps: Scale of the residual connection to the projected input.
        layer_num: Number of propagation layers to stack.
    """

    def __init__(self, g, se, in_dim, hidden_dim, dropout, eps, layer_num=2):
        super().__init__()
        self.g = g
        self.eps = eps
        self.layer_num = layer_num
        self.dropout = dropout

        # The layers are created before ``t1`` so that parameter
        # initialisation consumes random numbers in a fixed order.
        self.layers = nn.ModuleList(
            [AGCNLayer(self.g, se, hidden_dim, dropout)
             for _ in range(self.layer_num)]
        )

        self.t1 = nn.Linear(in_dim, hidden_dim)
        nn.init.xavier_normal_(self.t1.weight, gain=1.414)

    def forward(self, h):
        """Return the node representations after all propagation layers.

        Args:
            h: Raw node features whose last dimension is ``in_dim``.

        Returns:
            Tensor whose last dimension is ``hidden_dim``.
        """
        is_training = self.training
        rate = self.dropout

        projected = F.dropout(h, p=rate, training=is_training)
        projected = torch.relu(self.t1(projected))
        raw = F.dropout(projected, p=rate, training=is_training)

        h = raw
        for idx in range(self.layer_num):
            propagated = self.layers[idx](h)
            # Residual link back to the projected input, not to the
            # previous layer's output.
            h = self.eps * raw + propagated

        return h

        