import torch
import torch_geometric.graphgym.models.head  # noqa, register module
import torch_geometric.graphgym.register as register
import torch_geometric.nn as pygnn
from torch_geometric.graphgym.config import cfg
from torch_geometric.graphgym.models.gnn import FeatureEncoder, GNNPreMP
from torch_geometric.graphgym.register import register_network

from dirgt.layer.gatedgcn_layer import GatedGCNLayer
from dirgt.layer.gine_conv_layer import GINEConvLayer
from dirgt.layer.residual_gnn_layer import ResGNNLayer


@register_network('custom_gnn')
class CustomGNN(torch.nn.Module):
    """
    GNN model that customizes the torch_geometric.graphgym.models.gnn.GNN
    to support specific handling of new conv layers.

    Pipeline::

        FeatureEncoder -> [GNNPreMP] -> ResGNNLayer x cfg.gnn.layers
                       -> [JumpingKnowledge] -> registered head (post_mp)

    All hyper-parameters are read from the global GraphGym ``cfg``.

    Args:
        dim_in: Input feature dimension handed to ``FeatureEncoder``.
        dim_out: Output dimension of the registered prediction head.

    Raises:
        ValueError: If the feature width entering the message-passing stack
            differs from ``cfg.gnn.dim_inner``.
    """

    def __init__(self, dim_in, dim_out):
        super().__init__()
        self.encoder = FeatureEncoder(dim_in)
        dim_in = self.encoder.dim_in

        if cfg.gnn.layers_pre_mp > 0:
            self.pre_mp = GNNPreMP(
                dim_in, cfg.gnn.dim_inner, cfg.gnn.layers_pre_mp)
            dim_in = cfg.gnn.dim_inner

        # The layers below are all built with width dim_inner, so the incoming
        # features must already have that width. Raise explicitly (instead of
        # `assert`) so the check is not stripped under `python -O`.
        if cfg.gnn.dim_inner != dim_in:
            raise ValueError(
                "The inner and hidden dims must match "
                f"(cfg.gnn.dim_inner={cfg.gnn.dim_inner}, dim_in={dim_in}).")

        self.layers = torch.nn.Sequential(*[
            ResGNNLayer(cfg.gnn.dim_inner,
                        cfg.gnn.layer_type, cfg.gnn.n_heads, act=cfg.gnn.act,
                        pna_degrees=cfg.gnn.pna_degrees,
                        equivstable_pe=cfg.posenc_EquivStableLapPE.enable,
                        dropout=cfg.gnn.dropout, edge_dim=cfg.gnn.dim_inner,
                        alpha=cfg.gnn.alpha, norm_type=cfg.gnn.norm_type)
            for _ in range(cfg.gnn.layers)
        ])

        # Width of the node representation the head receives. JumpingKnowledge
        # in 'cat' mode concatenates every layer's output along the feature
        # axis (layers * dim_inner); 'max' and 'lstm' keep it at dim_inner.
        head_dim_in = cfg.gnn.dim_inner
        if cfg.gnn.jk and cfg.gnn.jk_mode == 'cat':
            head_dim_in = cfg.gnn.dim_inner * cfg.gnn.layers

        head_cls = register.head_dict[cfg.gnn.head]
        self.post_mp = head_cls(dim_in=head_dim_in, dim_out=dim_out)

        # Built after the head to keep the original module-construction (and
        # therefore seeded parameter-initialization) order.
        self.jk = None
        if cfg.gnn.jk:
            self.jk = pygnn.JumpingKnowledge(
                mode=cfg.gnn.jk_mode, channels=cfg.gnn.dim_inner,
                num_layers=cfg.gnn.layers)

    def forward(self, batch):
        """Encode ``batch``, run message passing, optional JK, then the head.

        Each layer is called on the whole batch and returns a pair of updated
        node and edge features, which are written back to ``batch.x`` and
        ``batch.edge_attr`` before the next layer runs.

        Args:
            batch: GraphGym batch object carrying at least ``x`` and
                ``edge_attr`` after encoding.

        Returns:
            Whatever the registered head (``self.post_mp``) returns.
        """
        batch = self.encoder(batch)
        if cfg.gnn.layers_pre_mp > 0:
            batch = self.pre_mp(batch)

        # Per-layer node features are only needed by JumpingKnowledge, so
        # only collect them when JK is enabled.
        layer_outputs = [] if self.jk is not None else None
        for layer in self.layers:
            batch.x, batch.edge_attr = layer(batch)
            if layer_outputs is not None:
                layer_outputs.append(batch.x)

        if self.jk is not None:
            batch.x = self.jk(layer_outputs)
        return self.post_mp(batch)
