import torch
import torch.nn as nn
import torch_geometric.graphgym.register as register
import torch_geometric.nn as pygnn
import torch_scatter
import torch_geometric.graphgym.register as register
from torch_geometric.graphgym.config import cfg
from torch_geometric.graphgym.models.gnn import GNNPreMP
from torch_geometric.graphgym.models.layer import (new_layer_config,
                                                   BatchNorm1dNode)
from torch_geometric.graphgym.register import register_network
from torch_geometric.nn import Linear as Linear_pyg

from dirgt.layer.sat_layer import SATLayer
from dirgt.network.feature_encoder import FeatureEncoder



@register_network('SATModel')
class SATModel(torch.nn.Module):
    """Structure-Aware Transformer (SAT) model.

    Pipeline: ``FeatureEncoder`` -> optional ``GNNPreMP`` (when
    ``cfg.gnn.layers_pre_mp > 0``) -> a stack of ``cfg.gt.layers``
    ``SATLayer`` blocks -> optional jumping-knowledge aggregation over the
    per-layer node features (when ``cfg.gt.jk``) -> the GraphGym head
    registered under ``cfg.gnn.head``.

    All hyper-parameters are read from the global GraphGym ``cfg``.

    Args:
        dim_in: Input feature dimension handed to ``FeatureEncoder``.
        dim_out: Output dimension of the prediction head.

    Raises:
        ValueError: If ``cfg.gt.layer_type`` is not of the form
            ``'<local_gnn_type>+<global_model_type>'``.
        AssertionError: If the feature width entering the SAT layers does
            not equal ``cfg.gt.dim_hidden`` / ``cfg.gnn.dim_inner``.
    """

    def __init__(self, dim_in, dim_out):
        super().__init__()
        self.encoder = FeatureEncoder(dim_in)
        # The encoder reports its own output width; continue from that.
        dim_in = self.encoder.dim_in

        if cfg.gnn.layers_pre_mp > 0:
            self.pre_mp = GNNPreMP(
                dim_in, cfg.gnn.dim_inner, cfg.gnn.layers_pre_mp)
            dim_in = cfg.gnn.dim_inner

        # SATLayer is configured with a single width (dim_h), so the width
        # of the features fed into the stack must already equal it.
        assert cfg.gt.dim_hidden == cfg.gnn.dim_inner == dim_in, \
            "The inner and hidden dims must match."

        try:
            local_gnn_type, global_model_type = cfg.gt.layer_type.split('+')
        except (AttributeError, ValueError) as err:
            # ValueError: the string does not contain exactly one '+';
            # AttributeError: layer_type is not a string at all.
            raise ValueError(
                f"Unexpected layer type: {cfg.gt.layer_type}") from err

        layers = []
        for _ in range(cfg.gt.layers):
            layers.append(SATLayer(
                dim_h=cfg.gt.dim_hidden,
                local_gnn_type=local_gnn_type,
                global_model_type=global_model_type,
                num_heads=cfg.gt.n_heads,
                act=cfg.gnn.act,
                pna_degrees=cfg.gt.pna_degrees,
                equivstable_pe=cfg.posenc_EquivStableLapPE.enable,
                dropout=cfg.gt.dropout,
                attn_dropout=cfg.gt.attn_dropout,
                layer_norm=cfg.gt.layer_norm,
                batch_norm=cfg.gt.batch_norm,
                log_attn_weights=cfg.train.mode == 'log-attn-weights',
                alpha=cfg.gnn.alpha,
                edge_dim=cfg.gnn.dim_edge,
            ))
        self.layers = torch.nn.Sequential(*layers)

        # In 'cat' mode JumpingKnowledge concatenates the outputs of all
        # layers along the feature axis, so the head must accept
        # dim_hidden * layers features instead of dim_inner.
        jk_concat = bool(cfg.gt.jk) and cfg.gt.jk_mode == 'cat'
        if jk_concat:
            head_dim_in = cfg.gt.dim_hidden * cfg.gt.layers
        else:
            head_dim_in = cfg.gnn.dim_inner

        # The head is created before the JK module (as in the original code)
        # so that parameter initialisation consumes the RNG in the same order.
        GNNHead = register.head_dict[cfg.gnn.head]
        self.post_mp = GNNHead(dim_in=head_dim_in, dim_out=dim_out)

        self.jk = None
        if cfg.gt.jk:
            self.jk = pygnn.JumpingKnowledge(
                mode=cfg.gt.jk_mode,
                channels=cfg.gt.dim_hidden,
                num_layers=cfg.gt.layers,
            )

    def forward(self, batch):
        """Run encoder, optional pre-MP, the SAT stack, optional JK and head.

        Args:
            batch: Graph batch object whose node features live in
                ``batch.x`` (presumably a ``torch_geometric`` ``Batch`` —
                confirm against the data loader).

        Returns:
            Whatever the configured GraphGym head returns for the batch.
        """
        batch = self.encoder(batch)
        if cfg.gnn.layers_pre_mp > 0:
            batch = self.pre_mp(batch)

        # Per-layer node features; only retained when JK needs them, so
        # intermediate tensors are not kept alive needlessly.
        layer_outputs = []
        for layer in self.layers:
            batch = layer(batch)
            if self.jk is not None:
                layer_outputs.append(batch.x)

        if self.jk is not None:
            batch.x = self.jk(layer_outputs)
        return self.post_mp(batch)

