import torch.nn as nn
import torch_geometric.nn as pyg_nn
from torch_geometric.graphgym import cfg
import torch_geometric.graphgym.register as register


class GATConvLayer(nn.Module):
    """Single-head Graph Attention Network (GAT) layer for GraphGym batches.

    Wraps :class:`torch_geometric.nn.GATConv` (one attention head, heads not
    concatenated, no self-loops added), followed by the activation selected by
    ``cfg.gnn.act`` and dropout. Only node features and ``edge_index`` are
    used; edge attributes are ignored.

    Args:
        dim_in: Size of the input node features.
        dim_out: Size of the output node features.
        dropout: Dropout probability applied after the activation.
        residual: If True, add the layer input to its output. This requires
            ``dim_in == dim_out`` so the two feature tensors can be summed.
    """

    def __init__(self, dim_in, dim_out, dropout, residual):
        super().__init__()
        self.dim_in = dim_in
        self.dim_out = dim_out
        self.dropout = dropout
        self.residual = residual

        # Activation is picked from the global GraphGym config, then dropout.
        self.act = nn.Sequential(
            register.act_dict[cfg.gnn.act](),
            nn.Dropout(self.dropout),
        )
        # heads=1 with concat=False produces exactly dim_out output features.
        # add_self_loops=False: nodes attend only to neighbours listed in
        # edge_index, so self-loops must already be present if they are wanted.
        self.model1 = pyg_nn.GATConv(
            dim_in,
            dim_out,
            bias=True,
            concat=False,
            heads=1,
            add_self_loops=False,
        )

    def forward(self, batch):
        """Apply the layer to ``batch.x`` in place and return the same batch.

        Args:
            batch: GraphGym batch exposing ``x`` (node features) and
                ``edge_index``.

        Returns:
            The input ``batch`` object, with ``batch.x`` replaced by the
            layer output.
        """
        x_in = batch.x

        batch.x = self.model1(batch.x, batch.edge_index)
        batch.x = self.act(batch.x)

        if self.residual:
            batch.x = x_in + batch.x  # residual connection

        return batch
