import numpy as np
import torch
import torch.nn as nn
import torch_geometric.graphgym.register as register
import torch_geometric.nn as pygnn

from torch_geometric.data import Batch
from torch_geometric.nn import Linear as Linear_pyg

from dirgt.layer.dir_gat import DirGAT
from dirgt.layer.dir_gatedconv_layer import DirGatedGCNLayer
from dirgt.layer.dir_gcn import DirGCN
from dirgt.layer.dir_gine import DirGINE
from dirgt.layer.gatedgcn_layer import GatedGCNLayer
from dirgt.layer.gine_conv_layer import GINEConvESLapPE, GINConvLayer


class ResGNNLayer(nn.Module):
    """Local message-passing (MPNN) layer with a residual connection.

    Wraps one convolution, selected by ``local_gnn_type``, and applies
    normalization, dropout and a skip connection around it. Some wrapped
    layers perform these steps internally; see ``forward`` for which steps
    are applied per layer type.

    Args:
        dim_h: Node feature size; input and output size of the layer.
        local_gnn_type: One of ``'None'``, ``'GCN'``, ``'DirGCN'``,
            ``'GENConv'``, ``'GIN'``, ``'GINE'``, ``'DirGINE'``, ``'GAT'``,
            ``'DirGAT'``, ``'PNA'``, ``'CustomGatedGCN'`` or
            ``'DirGatedGCNLayer'``. ``'None'`` makes the layer an identity.
        num_heads: Number of attention heads for ``'GAT'`` / ``'DirGAT'``.
        act: Key into ``register.act_dict``; used for the ``'GINE'`` MLP and
            forwarded to the gated GCN layers.
        pna_degrees: Degree statistics converted to a tensor and passed as
            ``deg`` to ``PNAConv``; required for ``'PNA'``.
        equivstable_pe: Use the EquivStableLapPE-aware variants. This selects
            ``GINEConvESLapPE`` for ``'GINE'`` and is forwarded to the gated
            GCN layers.
        dropout: Dropout probability.
        edge_dim: Edge feature size for the edge-aware convolutions. For
            ``'GAT'`` it falls back to ``dim_h`` when not given.
        alpha: Forwarded to ``DirGatedGCNLayer`` only.
        norm_type: ``None`` / ``'None'`` (no normalization), ``'batch'`` or
            ``'layer'``. Also forwarded to the GIN, DirGINE and gated GCN
            layers.

    Raises:
        ValueError: For an unsupported ``local_gnn_type``, for ``'GAT'`` when
            ``dim_h`` is not divisible by ``num_heads``, or for ``'PNA'``
            without ``pna_degrees``.
    """

    def __init__(self, dim_h,
                 local_gnn_type, num_heads, act='relu',
                 pna_degrees=None, equivstable_pe=False, dropout=0.0,
                 edge_dim=None, alpha=0.5, norm_type=None):
        super().__init__()

        self.dim_h = dim_h
        self.num_heads = num_heads
        self.equivstable_pe = equivstable_pe
        self.activation = register.act_dict[act]

        # Local message-passing model.
        if local_gnn_type == 'None':
            self.local_model = None
        elif local_gnn_type == 'GCN':
            self.local_model = pygnn.GCNConv(dim_h, dim_h)
        elif local_gnn_type == 'DirGCN':
            self.local_model = DirGCN(dim_h, dim_h)
        elif local_gnn_type == 'GENConv':
            self.local_model = pygnn.GENConv(dim_h, dim_h)
        elif local_gnn_type == 'GIN':
            self.local_model = GINConvLayer(dim_h, dim_h, dropout, True,
                                            norm_type=norm_type)
        elif local_gnn_type == 'GINE':
            gin_nn = nn.Sequential(Linear_pyg(dim_h, dim_h),
                                   self.activation(),
                                   Linear_pyg(dim_h, dim_h))
            if self.equivstable_pe:
                # Specialised GINE layer that also consumes EquivStableLapPE.
                self.local_model = GINEConvESLapPE(gin_nn)
            else:
                self.local_model = pygnn.GINEConv(gin_nn, edge_dim=edge_dim)
        elif local_gnn_type == 'DirGINE':
            self.local_model = DirGINE(dim_h, dim_h, dropout, True,
                                       edge_dim=edge_dim, norm_type=norm_type)
        elif local_gnn_type == 'GAT':
            # Heads are concatenated, so dim_h must split evenly across them
            # for the output to match the skip connection in forward().
            if dim_h % num_heads != 0:
                raise ValueError(
                    f"dim_h ({dim_h}) must be divisible by num_heads "
                    f"({num_heads}) for GAT")
            self.local_model = pygnn.GATConv(
                in_channels=dim_h,
                out_channels=dim_h // num_heads,
                heads=num_heads,
                # Honour an explicit edge_dim; default keeps the old dim_h.
                edge_dim=dim_h if edge_dim is None else edge_dim,
                concat=True)
        elif local_gnn_type == 'DirGAT':
            self.local_model = DirGAT(dim_h, dim_h // num_heads,
                                      dropout=dropout,
                                      edge_dim=edge_dim,
                                      heads=num_heads)
        elif local_gnn_type == 'PNA':
            if pna_degrees is None:
                raise ValueError("local_gnn_type 'PNA' requires pna_degrees")
            # Paper defaults would be:
            #   aggregators = ['mean', 'min', 'max', 'std']
            #   scalers = ['identity', 'amplification', 'attenuation']
            aggregators = ['mean', 'max', 'sum']
            scalers = ['identity']
            deg = torch.from_numpy(np.array(pna_degrees))
            self.local_model = pygnn.PNAConv(dim_h, dim_h,
                                             aggregators=aggregators,
                                             scalers=scalers,
                                             deg=deg,
                                             edge_dim=min(128, dim_h),
                                             towers=1,
                                             pre_layers=1,
                                             post_layers=1,
                                             divide_input=False)
        elif local_gnn_type == 'CustomGatedGCN':
            self.local_model = GatedGCNLayer(dim_h, dim_h,
                                             dropout=dropout,
                                             residual=True,
                                             act=act,
                                             equivstable_pe=equivstable_pe,
                                             norm_type=norm_type)
        elif local_gnn_type == 'DirGatedGCNLayer':
            self.local_model = DirGatedGCNLayer(dim_h, dim_h,
                                                dropout=dropout,
                                                residual=True,
                                                act=act,
                                                equivstable_pe=equivstable_pe,
                                                alpha=alpha,
                                                norm_type=norm_type)
        else:
            raise ValueError(f"Unsupported local GNN model: {local_gnn_type}")
        self.local_gnn_type = local_gnn_type
        self.norm = self.build_norm(norm_type, dim_h)
        self.dropout_local = nn.Dropout(dropout)

    def build_norm(self, norm_type, dim_h=None):
        """Build the normalization module applied after the convolution.

        Args:
            norm_type: ``'batch'`` -> ``BatchNorm1d``, ``'layer'`` ->
                ``LayerNorm``; anything else (including ``None`` and
                ``'None'``) means no normalization.
            dim_h: Feature size of the normalized tensor.

        Returns:
            The normalization module, or ``None`` for no normalization.
        """
        # NOTE(review): unrecognized values are deliberately not rejected,
        # since norm_type is also forwarded to sub-layers, which may accept
        # other values.
        if norm_type == 'batch':
            return nn.BatchNorm1d(dim_h)
        if norm_type == 'layer':
            return nn.LayerNorm(dim_h)
        return None

    def norm_layer(self, x):
        """Apply ``self.norm`` to ``x``, or return ``x`` unchanged if unset."""
        if self.norm is None:
            return x
        return self.norm(x)

    def _norm_dropout_residual(self, h_in, h_local):
        """Normalize and drop out ``h_local``, then add the skip input ``h_in``."""
        h_local = self.norm_layer(h_local)
        h_local = self.dropout_local(h_local)
        return h_in + h_local

    def forward(self, batch):
        """Run the local MPNN on ``batch``.

        Args:
            batch: Graph batch. Reads ``x``, ``edge_index`` and ``edge_attr``.
                ``pe_EquivStableLapPE`` is also read when ``equivstable_pe``
                is set, for ``'GINE'`` and ``'CustomGatedGCN'``.

        Returns:
            Tuple ``(h, e)`` of node features and edge features. ``e`` is
            ``batch.edge_attr`` unless the wrapped layer returns updated edge
            features (CustomGatedGCN, DirGINE, DirGatedGCNLayer).
        """
        h_in = batch.x  # input to the skip connection
        e = batch.edge_attr

        if self.local_model is None:
            # local_gnn_type == 'None': no local model, so the layer is an
            # identity.
            return h_in, e

        if self.local_gnn_type == 'GCN':
            # GCN does not use edge features.
            h_local = self.local_model(h_in, batch.edge_index)
            h_local = self._norm_dropout_residual(h_in, h_local)
        elif self.local_gnn_type == 'CustomGatedGCN':
            es_data = None
            if self.equivstable_pe:
                es_data = batch.pe_EquivStableLapPE
            local_out = self.local_model(Batch(batch=batch,
                                               x=h_in,
                                               edge_index=batch.edge_index,
                                               edge_attr=batch.edge_attr,
                                               pe_EquivStableLapPE=es_data))
            # GatedGCN does skip connection, norm and dropout internally.
            h_local = local_out.x
            e = local_out.edge_attr
        elif self.local_gnn_type == 'GIN':
            # NOTE(review): confirm GINConvLayer.forward returns a node-feature
            # tensor (not a Batch); it is returned as-is as `h`.
            h_local = self.local_model(batch)
        elif self.local_gnn_type in ('DirGINE', 'DirGatedGCNLayer'):
            out = self.local_model(batch)
            h_local = out.x
            e = out.edge_attr
            # DirGINE and DirGatedGCNLayer do dropout, norm and a skip
            # connection internally in each direction; this skip connection
            # spans the whole layer.
            h_local = h_in + h_local
        elif self.local_gnn_type == 'DirGAT':
            h_local = self.local_model(batch)  # Applies dropout internally.
            h_local = self.norm_layer(h_local)
            h_local = h_in + h_local
        elif self.local_gnn_type == 'DirGCN':
            h_local = self.local_model(batch)
            h_local = self._norm_dropout_residual(h_in, h_local)
        else:
            # GINE, GAT, GENConv, PNA: called as (x, edge_index, edge_attr).
            if self.equivstable_pe and self.local_gnn_type == 'GINE':
                # Only GINEConvESLapPE takes the PE tensor; the other convs'
                # forward() signatures have no PE parameter.
                h_local = self.local_model(h_in, batch.edge_index,
                                           batch.edge_attr,
                                           batch.pe_EquivStableLapPE)
            else:
                h_local = self.local_model(h_in, batch.edge_index,
                                           batch.edge_attr)
            h_local = self._norm_dropout_residual(h_in, h_local)

        return h_local, e






