import torch
import torch_geometric.graphgym.models.head  # noqa, register module
import torch_geometric.graphgym.register as register
from torch_geometric.graphgym.config import cfg
from torch_geometric.graphgym.models.gnn import FeatureEncoder, GNNPreMP
from torch_geometric.graphgym.register import register_network
from torch_geometric.nn.norm import LayerNorm, BatchNorm
from graphgps.layer.fsw_conv import FSW_conv
import torch.nn as nn
import torch.nn.functional as F



@register_network('fsw_gnn')
class FSWGNN(torch.nn.Module):
    """GraphGym network that stacks ``FSW_conv`` message-passing layers.

    The forward pass runs: node feature encoder -> optional pre-message-passing
    layers (``cfg.gnn.layers_pre_mp``) -> ``FSW_conv`` stack, optionally with
    residual connections and a separate edge encoder per layer -> prediction
    head selected by ``cfg.gnn.head``.

    All hyper-parameters other than the input/output widths are read from the
    global GraphGym ``cfg`` object at construction time.

    Args:
        dim_in: Input feature dimension handed to ``FeatureEncoder``.
        dim_out: Output dimension handed to the prediction head.
    """

    def __init__(self, dim_in, dim_out):
        super().__init__()
        dim_inner = cfg.gnn.dim_inner
        edgefeat_dim = cfg.gnn.edgefeat_dim
        self.skip_connections = cfg.gnn.residual

        # Keyword arguments common to every FSW_conv layer; only `in_channels`
        # is supplied per layer.
        # NOTE(review): the same two activation module instances are handed to
        # every layer. Whether a parametrised activation would end up sharing
        # its parameters across layers depends on FSW_conv -- confirm.
        conv_kwargs = dict(
            out_channels=dim_inner,
            edgefeat_dim=edgefeat_dim,
            embed_dim=cfg.gnn.fsw_embed_dim,
            learnable_embedding=cfg.gnn.learnable_embedding,
            encode_vertex_degrees=cfg.gnn.encode_vertex_degrees,
            vertex_degree_encoding_function=(
                cfg.gnn.vertex_degree_encoding_function),
            homog_degree_encoding=cfg.gnn.homog_degree_encoding,
            concat_self=cfg.gnn.concat_self,
            bias=cfg.gnn.bias,
            mlp_layers=cfg.gnn.conv_mlp_layers,
            mlp_hidden_dim=cfg.gnn.conv_mlp_hidden_dim,
            mlp_activation_final=register.act_dict[
                cfg.gnn.conv_mlp_activation_final](),
            mlp_activation_hidden=register.act_dict[
                cfg.gnn.conv_mlp_activation_hidden](),
            mlp_init=cfg.gnn.mlp_init,
            batchNorm_final=cfg.gnn.batchnorm_final,
            batchNorm_hidden=cfg.gnn.batchnorm_hidden,
            dropout_final=cfg.gnn.dropout_final,
            dropout_hidden=cfg.gnn.dropout_hidden,
            self_loop_weight=cfg.gnn.self_loop_weight,
            edge_weighting=cfg.gnn.edge_weighting,
        )

        self.node_encoder = FeatureEncoder(dim_in)
        # The encoder exposes the feature width it produces as `dim_in`.
        dim_in = self.node_encoder.dim_in

        if cfg.gnn.layers_pre_mp > 0:
            self.pre_mp = GNNPreMP(
                dim_in, cfg.gnn.dim_inner, cfg.gnn.layers_pre_mp)
            dim_in = cfg.gnn.dim_inner

        assert cfg.gnn.dim_inner == dim_in, \
            "The inner and hidden dims must match."

        # One conv is always created; `layers_mp - 1` further ones follow.
        layers = [FSW_conv(in_channels=dim_in, **conv_kwargs)]
        for _ in range(cfg.gnn.layers_mp - 1):
            layers.append(FSW_conv(in_channels=dim_inner, **conv_kwargs))
        self.convs = torch.nn.Sequential(*layers)

        # One independent edge encoder per message-passing layer, built only
        # when edge features are enabled (edgefeat_dim > 0).
        self.edge_encoders = torch.nn.ModuleList()
        if edgefeat_dim > 0:
            edge_encoder_cls = register.edge_encoder_dict[
                cfg.dataset.edge_encoder_name]
            for _ in range(cfg.gnn.layers_mp):
                self.edge_encoders.append(edge_encoder_cls(edgefeat_dim))

        GNNHead = register.head_dict[cfg.gnn.head]
        self.post_mp = GNNHead(dim_in=cfg.gnn.dim_inner, dim_out=dim_out)

    def forward(self, batch):
        """Run the network on a batch and return the head's output.

        Args:
            batch: Graph batch exposing ``x``, ``edge_index`` and
                ``edge_attr``. It is modified in place: ``x`` is overwritten
                with the final node features, and ``edge_attr`` is left as
                whatever the last edge encoder wrote.

        Returns:
            Whatever the configured head (``self.post_mp``) returns.
        """
        batch = self.node_encoder(batch)

        # Apply the pre-message-passing layers when they were configured.
        # They were previously built in __init__ but never invoked here, so
        # the first conv received features that had skipped them.
        pre_mp = getattr(self, 'pre_mp', None)
        if pre_mp is not None:
            batch = pre_mp(batch)

        x, edge_index = batch.x, batch.edge_index
        # Kept so that every layer's edge encoder starts from the raw edge
        # attributes rather than from the previous layer's encoding.
        orig_edge_attr = batch.edge_attr
        use_edge_encoders = len(self.edge_encoders) > 0

        for i, conv in enumerate(self.convs):
            res_x = x
            if use_edge_encoders:
                batch.edge_attr = orig_edge_attr
                # The encoder's result is read back from batch.edge_attr.
                self.edge_encoders[i](batch)
                edge_features = batch.edge_attr
            else:
                edge_features = None
            x = conv(x, edge_index, edge_features)
            # No residual on the first layer (i == 0); the original rationale
            # was to guarantee that x and res_x have the same dimensions.
            if self.skip_connections and i > 0:
                x = x + res_x

        batch.x = x
        return self.post_mp(batch)