import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_geometric.graphgym.register as register
import torch_geometric.nn as pyg_nn
from torch_geometric.graphgym.models.layer import LayerConfig
from torch_geometric.graphgym.register import register_layer
from torch_scatter import scatter

from dirgt.layer.gatedgcn_layer import GatedGCNLayer


class DirGatedGCNLayer(pyg_nn.conv.MessagePassing):
    """Direction-aware GatedGCN layer.

    Runs two independent ``GatedGCNLayer`` instances on the same input graph:
    one over the original edges (src -> dst) and one over the reversed edges
    (dst -> src). Their node and edge outputs are then mixed as a convex
    combination::

        x_out = (1 - alpha) * x_src_to_dst + alpha * x_dst_to_src
        e_out = (1 - alpha) * e_src_to_dst + alpha * e_dst_to_src

    The underlying GatedGCN layer is from "Residual Gated Graph ConvNets"
    (https://arxiv.org/pdf/1711.07553.pdf).

    Args:
        in_dim: Input feature dimension, forwarded to both inner layers.
        out_dim: Output feature dimension, forwarded to both inner layers.
        dropout: Dropout setting, forwarded to both inner layers.
        residual: Residual-connection flag, forwarded to both inner layers.
        alpha: Weight of the reversed-direction (dst -> src) branch; the
            original-direction branch is weighted by ``1 - alpha``.
        act: Activation name, forwarded to both inner layers.
        equivstable_pe: Accepted for signature compatibility but currently
            NOT forwarded; both inner layers are built with
            ``equivstable_pe=False``.
        norm_type: Normalization type, forwarded to both inner layers.
        **kwargs: Passed to ``MessagePassing.__init__`` and to both inner
            layers.
    """

    def __init__(self, in_dim, out_dim, dropout, residual, alpha, act='relu',
                 equivstable_pe=False, norm_type=None, **kwargs):
        super().__init__(**kwargs)
        # NOTE(review): the `equivstable_pe` argument is ignored and both
        # branches are always built with equivstable_pe=False. Confirm whether
        # this is intentional before forwarding it.
        # Construction order (src->dst first) is kept so parameter init order
        # and state_dict keys stay the same.
        self.conv_src_to_dst = GatedGCNLayer(
            in_dim, out_dim, dropout, residual, act=act,
            equivstable_pe=False, norm_type=norm_type, **kwargs)
        self.conv_dst_to_src = GatedGCNLayer(
            in_dim, out_dim, dropout, residual, act=act,
            equivstable_pe=False, norm_type=norm_type, **kwargs)
        self.alpha = alpha

    def forward(self, batch):
        """Apply both directional branches and mix their outputs.

        Args:
            batch: Graph batch exposing ``x``, ``edge_attr`` and
                ``edge_index``; ``edge_index`` is indexed as
                ``edge_index[0]`` (source) / ``edge_index[1]`` (target).

        Returns:
            The same ``batch`` object, with ``x`` and ``edge_attr`` replaced
            by the alpha-weighted mix of the two branches and ``edge_index``
            set back to the original (un-reversed) edges.
        """
        x_in, e_in, edge_index = batch.x, batch.edge_attr, batch.edge_index
        # Swap source/target rows to reverse every edge.
        edge_index_rev = edge_index.flip(0)

        # Reversed-direction branch (dst -> src). Run first, as before, so the
        # dropout RNG consumption order is unchanged.
        batch.edge_index = edge_index_rev
        out = self.conv_dst_to_src(batch)
        x_rev, e_rev = out.x, out.edge_attr

        # The inner layer may write its outputs back into `batch` and return
        # it. Restore the original inputs so the second branch sees the same
        # features as the first. Otherwise it would consume the first
        # branch's outputs, and both results would alias the same tensors,
        # turning the alpha mix into a no-op.
        batch.x, batch.edge_attr, batch.edge_index = x_in, e_in, edge_index
        out = self.conv_src_to_dst(batch)
        x_fwd, e_fwd = out.x, out.edge_attr

        batch.x = (1. - self.alpha) * x_fwd + self.alpha * x_rev
        batch.edge_attr = (1. - self.alpha) * e_fwd + self.alpha * e_rev

        return batch
