import torch as th
from torch import nn
from dgl.nn.pytorch import edge_softmax
from dgl.utils import expand_as_pair
import torch
import dgl
import dgl.function as fn

class GCNLayer(nn.Module):
    """Bias-free linear projection followed by sum aggregation on a DGL block.

    Every destination node receives the sum of ``linear(feat_src)`` over its
    incoming edges (``copy_u`` message + ``sum`` reduce). When ``forward`` is
    called with ``twp=True`` it also returns edge scores normalised with
    DGL's ``edge_softmax``.

    Args:
        in_feats: Input feature size of the linear projection.
        out_feats: Output feature size of the linear projection.
        negative_slope: LeakyReLU slope used for the edge scores and for the
            initialisation gain in :meth:`reset_parameters`.

    Note:
        ``__init__`` does not call :meth:`reset_parameters`, so the weight
        starts from ``nn.Linear``'s default initialisation.
    """

    def __init__(self, in_feats, out_feats, negative_slope=0.2):
        super(GCNLayer, self).__init__()
        self.linear = nn.Linear(in_feats, out_feats, bias=False)
        self.negative_slope = negative_slope
        self.gcn_msg = fn.copy_u('h', 'm')
        self.gcn_reduce = fn.sum(msg='m', out='h')

    @staticmethod
    def _edge_scores(edges):
        """Edge UDF: ``sum(h_src * tanh(h_dst), dim=1)`` stored under ``'e'``."""
        return {'e': th.sum(th.mul(edges.src['h'], th.tanh(edges.dst['h'])), 1)}

    def forward(self, block, feat, twp=False):
        """Aggregate projected source features onto the destination nodes.

        Args:
            block: DGL graph/block to propagate over. Work happens on a
                ``local_var()`` copy moved to the device of the source
                features, so data written here is not stored on the caller's
                block.
            feat: Source-node features, or a ``(src, dst)`` pair as accepted
                by ``expand_as_pair`` (only the source part is used).
            twp: If true, also compute edge scores.

        Returns:
            ``(h, None)`` when ``twp`` is false. Otherwise ``(h, e_soft)``,
            where ``h`` is the aggregated destination-node features and
            ``e_soft`` is ``edge_softmax(block, leaky_relu(score))``.
        """
        feat_src, _ = expand_as_pair(feat)
        # Follow the tensor's own device: the previous 'cuda:{get_device()}'
        # string became the invalid 'cuda:-1' for CPU tensors.
        block = block.local_var().to(feat_src.device)
        block.srcdata['h'] = self.linear(feat_src)
        block.update_all(self.gcn_msg, self.gcn_reduce)
        h = block.dstdata['h']
        if not twp:
            return h, None

        # On a block, edges.src['h'] is the projected source feature and
        # edges.dst['h'] is the aggregated destination feature.
        block.apply_edges(self._edge_scores)
        e = nn.functional.leaky_relu(block.edata.pop('e'), self.negative_slope)
        return h, edge_softmax(block, e)

    def reset_parameters(self):
        """Reinitialize learnable parameters with Xavier-normal init.

        The gain uses this layer's ``negative_slope`` instead of a hard-coded
        0.2, so it stays consistent with the constructor argument.
        """
        gain = nn.init.calculate_gain('leaky_relu', self.negative_slope)
        nn.init.xavier_normal_(self.linear.weight, gain=gain)


class untrainedGCNLayer(nn.Module):
    """Bias-free linear projection followed by mean aggregation on a DGL block.

    Each destination node receives the mean of ``linear(feat_src)`` over its
    incoming edges (``copy_u`` message + ``mean`` reduce).

    Args:
        in_feats: Input feature size of the linear projection.
        out_feats: Output feature size of the linear projection.
        gain: Gain passed to ``nn.init.xavier_uniform_``. It is applied at
            construction time and again by :meth:`reset_parameters`.
    """

    def __init__(self, in_feats, out_feats, gain):
        super(untrainedGCNLayer, self).__init__()
        self.gcn_msg = fn.copy_u('h', 'm')
        self.gcn_reduce = fn.mean(msg='m', out='h')
        self.gain = gain
        self.linear = nn.Linear(in_feats, out_feats, bias=False)
        nn.init.xavier_uniform_(self.linear.weight, gain=self.gain)

    def forward(self, block, feat):
        """Mean-aggregate projected source features onto destination nodes.

        Args:
            block: DGL graph/block to propagate over. Work happens on a
                ``local_var()`` copy moved to the device of the source
                features.
            feat: Source-node features, or a ``(src, dst)`` pair as accepted
                by ``expand_as_pair`` (only the source part is used).

        Returns:
            The aggregated destination-node features (``block.dstdata['h']``).
        """
        feat_src, _ = expand_as_pair(feat)
        # Follow the tensor's own device: the previous 'cuda:{get_device()}'
        # string became the invalid 'cuda:-1' for CPU tensors.
        block = block.local_var().to(feat_src.device)
        block.srcdata['h'] = self.linear(feat_src)
        block.update_all(self.gcn_msg, self.gcn_reduce)
        return block.dstdata['h']

    def reset_parameters(self):
        """Reinitialize learnable parameters."""
        nn.init.xavier_uniform_(self.linear.weight, gain=self.gain)

class untrainedGCONVLayer(nn.Module):
    """Neighbour-sum plus self term on a DGL block.

    Computes ``linear2``-projected source features summed over each
    destination node's incoming edges (``copy_u`` + ``sum``), plus
    ``linear1`` applied to that destination node's own features. Both
    projections are bias-free.

    Args:
        in_feats: Input feature size of both projections.
        out_feats: Output feature size of both projections.
        gain: Gain passed to ``nn.init.xavier_uniform_`` for both weights,
            at construction time and again in :meth:`reset_parameters`.
    """

    def __init__(self, in_feats, out_feats, gain):
        super(untrainedGCONVLayer, self).__init__()
        self.gcn_msg = fn.copy_u('h', 'm')
        self.gcn_reduce = fn.sum(msg='m', out='h')
        self.gain = gain
        self.linear1 = nn.Linear(in_feats, out_feats, bias=False)
        self.linear2 = nn.Linear(in_feats, out_feats, bias=False)
        nn.init.xavier_uniform_(self.linear1.weight, gain=self.gain)
        nn.init.xavier_uniform_(self.linear2.weight, gain=self.gain)

    @staticmethod
    def _dst_positions_in_src(src_nid, dst_nid):
        """Return, for each ID in ``dst_nid``, its position in ``src_nid``.

        For unique IDs that are all present, this equals
        ``torch.where(src_nid == dst_nid.reshape(-1, 1))[1]``. It uses
        sort + ``searchsorted`` instead of building a
        ``len(dst) x len(src)`` boolean matrix. Both inputs are assumed to be
        1-D tensors of the same dtype.

        Raises:
            ValueError: If some destination ID does not occur in ``src_nid``.
        """
        if dst_nid.numel() == 0:
            return torch.zeros(0, dtype=torch.long, device=src_nid.device)
        if src_nid.numel() == 0:
            raise ValueError('destination node IDs are missing from the source nodes')
        sorted_nid, order = torch.sort(src_nid)
        pos = torch.searchsorted(sorted_nid, dst_nid).clamp(max=sorted_nid.numel() - 1)
        index = order[pos]
        # searchsorted returns insertion points; confirm each one is a real hit.
        if not torch.equal(src_nid[index], dst_nid):
            raise ValueError('destination node IDs are missing from the source nodes')
        return index

    def forward(self, block, feat):
        """Return ``sum_neighbours(linear2(x_src)) + linear1(x_dst)``.

        Args:
            block: DGL block to propagate over. Work happens on a
                ``local_var()`` copy moved to the device of the source
                features. For plain-tensor ``feat`` the block must carry
                ``dgl.NID`` in both ``srcdata`` and ``dstdata``; those IDs
                locate each destination node's own row in ``feat``.
            feat: Source-node features, or a ``(src, dst)`` pair. When given
                as a pair, the dst part is used directly for the self term.

        Returns:
            Destination-node features.

        Raises:
            ValueError: If a destination node ID is not among the source IDs.
        """
        feat_src, feat_dst = expand_as_pair(feat)
        # Follow the tensor's own device: the previous 'cuda:{get_device()}'
        # string became the invalid 'cuda:-1' for CPU tensors.
        block = block.local_var().to(feat_src.device)

        # Neighbour term.
        block.srcdata['h'] = self.linear2(feat_src)
        block.update_all(self.gcn_msg, self.gcn_reduce)
        h_n = block.dstdata['h']

        # Self term: each destination node's own input features.
        if isinstance(feat, tuple):
            self_feat = feat_dst
        else:
            dst_index = self._dst_positions_in_src(
                block.srcdata[dgl.NID], block.dstdata[dgl.NID])
            self_feat = feat_src[dst_index]
        h_s = self.linear1(self_feat)
        return torch.add(h_n, h_s)

    def reset_parameters(self):
        """Reinitialize learnable parameters."""
        nn.init.xavier_uniform_(self.linear1.weight, gain=self.gain)
        nn.init.xavier_uniform_(self.linear2.weight, gain=self.gain)
