
import torch
import torch_geometric.graphgym.models.head  # noqa, register module
import torch_geometric.graphgym.register as register
from torch_geometric.graphgym.config import cfg
from torch_geometric.graphgym.models.gnn import FeatureEncoder, GNNPreMP
from torch_geometric.graphgym.register import register_network

from graphgps.layer.gatedgcn_layer import GatedGCNLayer
from graphgps.layer.gine_conv_layer import GINEConvLayer, GINEConvLayerColour
from graphgps.layer.gcn_conv_layer import GCNConvLayerColour
from torch_geometric.nn import global_add_pool, global_mean_pool, global_max_pool

from graphgps.layer.nn import SharedLinear, SharedMLP
from graphgps.network.readout import (SumReadout, MeanReadout, MaxReadout, SumJKReadout, MeanJKReadout, MaxJKReadout,
    MeanAveraging, SumAveraging, MaxAveraging, SharedJKReadout, ColourReadout, ColourJKReadout,
    AdaptiveMeanAveraging, AdaptiveSumAveraging, AdaptiveMaxAveraging)

import torch
import torch.nn.functional as F
from graphgps.layer.gps_layer import GPSLayer
from graphgps.layer.gcn_conv_layer import GCNConvLayer
from torch_geometric.nn.conv import MessagePassing


from typing import Tuple, Union, Optional, Callable, List, Dict
import torch
import torch_geometric
from torch import Tensor
from ogb.graphproppred.mol_encoder import BondEncoder
from torch_geometric.nn.inits import reset
from torch_geometric.typing import Adj, OptPairTensor, OptTensor, Size



@register_network('first_gnn')
class FirstGNNnetwork(torch.nn.Module):
    """GINE-style message-passing network registered as ``'first_gnn'``.

    Pipeline: embedding lookup on the node features -> two-layer MLP
    (``first_layers``) -> ``num_layers`` rounds of
    ``CustomGINE -> batch norm (or identity) -> activation`` with an optional
    residual connection -> prediction head looked up in
    ``register.head_dict``.

    Hyper-parameters are read from the global GraphGym ``cfg``:
    ``cfg.derivative_preprocessing`` (``emb_dim``, ``hidden_dim``,
    ``activation``, ``num_layers``, ``batchnorm``, ``centrality_init``),
    ``cfg.gnn`` (``residual``, ``head``) and
    ``cfg.dataset.node_encoder_num_types``.

    Args:
        dim_in: Accepted but not used by this class.
        dim_out: Output dimension handed to the prediction head.
    """

    def __init__(self, dim_in, dim_out):
        super().__init__()

        self.emb_dim = cfg.derivative_preprocessing.emb_dim
        # TODO: make work for other datasets
        self.feature_encoder = torch.nn.Embedding(
            num_embeddings=cfg.dataset.node_encoder_num_types,
            embedding_dim=self.emb_dim,
        )

        self.hidden_dim = cfg.derivative_preprocessing.hidden_dim
        self.activation_class = self.get_activation_class(
            cfg.derivative_preprocessing.activation
        )
        self.use_features = True  # TODO: add to config
        self.num_layers = cfg.derivative_preprocessing.num_layers

        # Node-wise MLP applied to the embedded features before message passing.
        self.first_layers = torch.nn.Sequential(
            torch.nn.Linear(in_features=self.emb_dim, out_features=self.hidden_dim),
            self.activation_class(),
            torch.nn.Linear(in_features=self.hidden_dim, out_features=self.hidden_dim),
        )

        use_batchnorm = cfg.derivative_preprocessing.batchnorm
        self.gnn_layers = torch.nn.ModuleList()
        self.bn_layers = torch.nn.ModuleList()
        self.act_layers = torch.nn.ModuleList()
        for _ in range(self.num_layers):
            self.gnn_layers.append(
                CustomGINE(
                    self.hidden_dim,
                    self.hidden_dim,
                    track_running_stats=True,  # TODO: maybe add to config
                    activation_class=self.activation_class,
                    batchnorm=use_batchnorm,
                )
            )
            if use_batchnorm:
                norm = torch.nn.BatchNorm1d(self.hidden_dim, track_running_stats=True)
            else:
                norm = torch.nn.Identity()
            self.bn_layers.append(norm)
            self.act_layers.append(self.activation_class())

        self.add_residual = cfg.gnn.residual

        GNNHead = register.head_dict[cfg.gnn.head]
        self.post_mp = GNNHead(dim_in=self.hidden_dim, dim_out=dim_out)

        if cfg.derivative_preprocessing.centrality_init:
            self.centrality_init()

    def forward(self, batched_data):
        """Embed the node features, run the GNN stack and apply the head.

        ``batched_data`` is modified in place (its ``x`` and ``edge_attr``
        attributes are overwritten) and the head's output is returned.
        """
        batched_data = self.get_initial_embeddings(batched_data)
        return self.get_output(batched_data)

    def get_output(self, batched_data):
        """Replace ``batched_data.x`` by the final node embeddings and apply the head."""
        batched_data.x = self.get_final_node_embeddings(
            batched_data.x, batched_data.edge_index, batched_data.edge_attr
        )
        return self.post_mp(batched_data)

    def get_final_node_embeddings(self, x, edge_index, edge_attr):
        """Apply ``first_layers`` followed by every message-passing round."""
        x = self.first_layers(x)
        for index, _ in enumerate(self.gnn_layers):
            x = self.intermediate_updates(
                x, edge_index, edge_attr, index, scale=None, shift=None
            )
        return x

    def intermediate_updates(self, x, edge_index, edge_attr, index, scale=None, shift=None):
        """Run message-passing round ``index`` and apply an affine rescaling.

        Computes ``act(bn(gnn(x)))``, adds ``x`` back when ``cfg.gnn.residual``
        was set, then returns ``result * scale + shift``. ``scale`` and
        ``shift`` default to 1 and 0 respectively when ``None``.
        """
        gnn, bn, act = self.gnn_layers[index], self.bn_layers[index], self.act_layers[index]
        h = act(bn(gnn(x, edge_index, edge_attr)))
        x = h + x if self.add_residual else h
        scale = 1 if scale is None else scale
        shift = 0 if shift is None else shift
        return x * scale + shift

    def get_initial_embeddings(self, batched_data):
        """Cast features to integer indices and embed the node features.

        When ``use_features`` is false the node features are replaced by ones
        and the edge attributes by zeros before the cast. ``batched_data`` is
        modified in place and returned.
        """
        if not self.use_features:
            batched_data.x = torch.ones_like(batched_data.x)
            batched_data.edge_attr = torch.zeros_like(batched_data.edge_attr)
        node_index = batched_data.x.long()
        batched_data.edge_attr = batched_data.edge_attr.long()
        # Drop only a trailing singleton feature axis. A bare ``squeeze()``
        # would also remove the node axis for a batch holding a single node,
        # so the embedding output would lose its leading node dimension.
        if node_index.dim() > 1:
            node_index = node_index.squeeze(-1)
        batched_data.x = self.feature_encoder(node_index)
        return batched_data

    def get_activation_class(self, activation):
        """Map an activation name to its ``torch.nn`` module class.

        Args:
            activation: ``None`` or ``'silu'`` (both give ``SiLU``) or ``'relu'``.

        Returns:
            The module class (not an instance).

        Raises:
            ValueError: If ``activation`` is not one of the supported names.
        """
        if activation is None or activation == 'silu':
            return torch.nn.SiLU
        if activation == 'relu':
            return torch.nn.ReLU
        raise ValueError(
            f"Unsupported activation {activation!r}; expected None, 'silu' or 'relu'."
        )

    @staticmethod
    def _init_linear_identity(linear):
        """Set a linear layer's weight to the identity and its bias to zero, in place."""
        torch.nn.init.eye_(linear.weight)
        torch.nn.init.zeros_(linear.bias)

    @staticmethod
    def _init_batchnorm_identity(bn):
        """Set a batch-norm layer to unit scale, zero shift, zero/one running stats and eps 0."""
        torch.nn.init.ones_(bn.weight)
        torch.nn.init.zeros_(bn.bias)
        # Running statistics are None when the layer does not track them.
        if bn.running_mean is not None:
            bn.running_mean.zero_()
        if bn.running_var is not None:
            bn.running_var.fill_(1.0)
        bn.eps = 0.

    def centrality_init(self):
        """Overwrite the freshly created weights with a fixed deterministic initialisation.

        * embedding table: all ones;
        * first linear layer of ``first_layers``: identity blocks stacked along
          the output axis (rows beyond the last full block stay zero), zero bias;
        * last linear layer of ``first_layers`` and both linear layers inside
          every GINE MLP: identity weight, zero bias;
        * every batch norm (if enabled): unit scale, zero shift, running mean 0,
          running variance 1 and ``eps = 0``;
        * GINE ``eps``: -1, which cancels the ``(1 + eps) * x`` self term;
        * bond embeddings: all zeros.

        All tensors are modified in place so their device and dtype are kept.
        """
        use_batchnorm = cfg.derivative_preprocessing.batchnorm

        with torch.no_grad():
            torch.nn.init.ones_(self.feature_encoder.weight)

            # First layer: as many stacked identity blocks as fit into out_dim.
            first_linear = self.first_layers[0]
            out_dim, in_dim = first_linear.weight.shape
            repeat_factor = out_dim // in_dim
            first_linear.weight.zero_()
            if repeat_factor > 0:
                identity = torch.eye(
                    in_dim,
                    dtype=first_linear.weight.dtype,
                    device=first_linear.weight.device,
                )
                first_linear.weight[:repeat_factor * in_dim] = identity.repeat(repeat_factor, 1)
            first_linear.bias.zero_()

            self._init_linear_identity(self.first_layers[-1])

            for gnn in self.gnn_layers:
                # MLP layout inside CustomGINE: [Linear, BatchNorm|Identity, act, Linear]
                mlp = gnn.layer.nn
                self._init_linear_identity(mlp[0])
                if use_batchnorm:
                    self._init_batchnorm_identity(mlp[1])
                self._init_linear_identity(mlp[3])
                gnn.layer.eps.fill_(-1.)

                for embedding in gnn.edge_embedding.bond_embedding_list:
                    torch.nn.init.zeros_(embedding.weight)

            if use_batchnorm:
                for bn in self.bn_layers:
                    self._init_batchnorm_identity(bn)


class CustomGINE(torch.nn.Module):
    """GINE convolution with a two-layer MLP and an OGB bond encoder for edges.

    The MLP is ``Linear -> BatchNorm1d | Identity -> activation -> Linear`` and
    is wrapped in an ``AnalyticGINEConv`` with a trainable ``eps``. Edge
    attributes are embedded with ``BondEncoder(emb_dim)`` and added to the
    neighbour features inside the convolution, so ``in_dim`` presumably has to
    equal ``emb_dim`` (or broadcast against it) -- confirm against callers.

    Args:
        in_dim: Input feature size of the first linear layer.
        emb_dim: Hidden/output feature size and edge-embedding size.
        track_running_stats: Forwarded to ``BatchNorm1d`` when ``batchnorm``
            is true.
        num_edge_emb: Unused; kept for interface compatibility.
        activation_class: Activation module class; ``torch.nn.SiLU`` if ``None``.
        use_GINE_activation: If true, the activation is also applied to each
            message; otherwise messages pass through ``Identity``.
        batchnorm: Whether the MLP contains a batch-norm layer.
    """

    def __init__(self, in_dim, emb_dim, track_running_stats, num_edge_emb=4, activation_class=None, use_GINE_activation=False, batchnorm:bool=False):
        super().__init__()
        self.activation_class = activation_class if activation_class is not None else torch.nn.SiLU

        # Use the resolved class so that ``activation_class=None`` together
        # with ``use_GINE_activation=True`` falls back to SiLU instead of
        # trying to call ``None``.
        if use_GINE_activation:
            message_activation = self.activation_class()
        else:
            message_activation = torch.nn.Identity()

        # Honour the caller's ``track_running_stats`` rather than reusing the
        # ``batchnorm`` flag for it.
        if batchnorm:
            norm_layer = torch.nn.BatchNorm1d(emb_dim, track_running_stats=track_running_stats)
        else:
            norm_layer = torch.nn.Identity()

        mlp = torch.nn.Sequential(
            torch.nn.Linear(in_dim, emb_dim),
            norm_layer,
            self.activation_class(),
            torch.nn.Linear(emb_dim, emb_dim),
        )
        self.layer = AnalyticGINEConv(nn=mlp, train_eps=True, activation=message_activation)
        self.edge_embedding = BondEncoder(emb_dim=emb_dim)

    def forward(self, x, edge_index, edge_attr):
        """Embed ``edge_attr`` and run the GINE convolution.

        A 1-D ``edge_attr`` gets a trailing feature axis added before it is
        passed to the bond encoder.
        """
        if edge_attr.dim() == 1:
            edge_attr = edge_attr.unsqueeze(1)
        return self.layer(x, edge_index, self.edge_embedding(edge_attr))



class AnalyticGINEConv(MessagePassing):
    """
    GINE-style convolution whose per-message activation is configurable.

    Each message is ``activation(x_j + edge_attr)``; messages are aggregated
    (sum by default), ``(1 + eps) * x`` of the target node is added and the
    result is passed through ``nn``. ``activation`` defaults to ``Identity``.

    ``nn`` must be indexable and its first element must expose either
    ``in_features`` or ``in_channels``; otherwise ``ValueError`` is raised.
    ``edge_dim`` is accepted but not used.
    """

    def __init__(
        self,
        nn: torch.nn.Module,
        eps: float = 0.0,
        train_eps: bool = False,
        edge_dim: Optional[int] = None,
        activation: Optional[torch.nn.Module] = None,
        **kwargs,
    ):
        kwargs.setdefault("aggr", "add")
        super().__init__(**kwargs)
        self.nn = nn
        self.initial_eps = eps

        # eps is learnable only on request; its value is set in reset_parameters().
        eps_storage = torch.empty(1)
        if train_eps:
            self.eps = torch.nn.Parameter(eps_storage)
        else:
            self.register_buffer("eps", eps_storage)

        # Sanity check: the first sub-module must declare its input width.
        first_module = self.nn[0]
        declares_width = hasattr(first_module, "in_features") or hasattr(
            first_module, "in_channels"
        )
        if not declares_width:
            raise ValueError("Could not infer input channels from `nn`.")

        self.activation = torch.nn.Identity() if activation is None else activation
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise ``nn`` and restore ``eps`` to its initial value."""
        reset(self.nn)
        self.eps.data.fill_(self.initial_eps)

    def forward(
        self,
        x: Union[Tensor, OptPairTensor],
        edge_index: Adj,
        edge_attr: OptTensor = None,
        size: Size = None,
    ) -> Tensor:
        """Aggregate neighbour messages, add the scaled root term and apply ``nn``."""
        # A single tensor serves as both source and target features.
        pair = (x, x) if isinstance(x, Tensor) else x

        # propagate_type: (x: OptPairTensor, edge_attr: OptTensor)
        aggregated = self.propagate(edge_index, x=pair, edge_attr=edge_attr, size=size)

        root = pair[1]
        if root is not None:
            aggregated = aggregated + (1 + self.eps) * root

        return self.nn(aggregated)

    def message(self, x_j: Tensor, edge_attr: Tensor) -> Tensor:
        """Combine neighbour features with edge features and apply the activation."""
        combined = x_j + edge_attr
        return self.activation(combined)

    def __repr__(self) -> str:
        return f"{type(self).__name__}(nn={self.nn})"



class EdgeAggregation(MessagePassing):
    """
    Aggregates (sums, by default) the edge embedding vectors per node.

    Messages are the edge attributes themselves; node features are not part
    of the message.
    """

    def __init__(self, **kwargs):
        kwargs.setdefault("aggr", "add")
        super().__init__(**kwargs)

    def forward(
        self,
        x: Union[Tensor, OptPairTensor],
        edge_index: Adj,
        edge_attr: OptTensor = None,
        size: Size = None,
    ) -> Tensor:
        """Return the per-node aggregate of ``edge_attr``.

        For ``edge_attr`` with more than two axes, axis 0 and axis -2 are
        swapped before propagation and swapped back on the result.
        """
        if isinstance(x, Tensor):
            x = (x, x)

        # Rank is unchanged by swapaxes, so one check covers both directions.
        has_extra_axes = len(edge_attr.shape) > 2
        if has_extra_axes:
            edge_attr = edge_attr.swapaxes(0, -2)

        # NOTE(review): `message` does not consume `x`; confirm the number of
        # output nodes is inferred as intended when `size` is None.
        aggregated = self.propagate(edge_index, x=x, edge_attr=edge_attr, size=size)

        return aggregated.swapaxes(0, -2) if has_extra_axes else aggregated

    def message(self, x_j: Tensor, edge_attr: Tensor) -> Tensor:
        """Use the edge attributes unchanged as messages."""
        return edge_attr

    





    