import torch
import torch.nn as nn
from torch_geometric.nn import GINEConv


class GINE(nn.Module):
    """Stack of PyG ``GINEConv`` layers, each followed by BatchNorm1d and ReLU."""

    def __init__(self, n_layers: int, d_hidden: int, d_edge: int) -> None:
        """Build ``n_layers`` (GINEConv, BatchNorm1d) pairs of constant width.

        :param n_layers: How many convolution / normalization pairs to stack.
        :param d_hidden: Node feature width; every layer maps d_hidden -> d_hidden.
        :param d_edge: Edge feature width, forwarded to ``GINEConv`` as ``edge_dim``.
        """
        super().__init__()

        # Keep the configuration around for introspection.
        self.n_layers = n_layers
        self.d_hidden = d_hidden
        self.d_edge = d_edge

        # Attribute names are kept stable because they define state_dict keys.
        self.layers = nn.ModuleList()
        self.batch_norms = nn.ModuleList()

        for _ in range(n_layers):
            # Each conv is created and registered before its norm. This
            # preserves submodule creation order and thus initialization
            # under a fixed seed.
            self.layers.append(GINEConv(self._mlp(d_hidden), edge_dim=d_edge))
            self.batch_norms.append(nn.BatchNorm1d(d_hidden))

        self.relu = nn.ReLU()

    @staticmethod
    def _mlp(width: int) -> nn.Sequential:
        """Return the Linear -> GELU -> Linear network wrapped by each GINEConv."""
        return nn.Sequential(
            nn.Linear(width, width),
            nn.GELU(),
            nn.Linear(width, width),
        )

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor, edge_attr: torch.Tensor) -> torch.Tensor:
        """Run every layer in order: conv, then batch norm, then ReLU.

        ReLU is applied after the final layer as well. With ``n_layers == 0``
        the input ``x`` is returned unchanged.

        :param x: Node feature matrix. It presumably has ``d_hidden`` features
            per node, since the wrapped MLP starts with
            ``Linear(d_hidden, d_hidden)`` (confirm against callers).
        :param edge_index: Graph connectivity, passed straight to each ``GINEConv``.
        :param edge_attr: Edge feature matrix, passed straight to each ``GINEConv``
            (built with ``edge_dim=d_edge``).
        :return: Node features after the last layer.
        """
        h = x
        for conv, norm in zip(self.layers, self.batch_norms):
            h = self.relu(norm(conv(h, edge_index, edge_attr)))
        return h
