import os
import sys
from typing import Tuple, Union, Optional, Callable, List, Dict
import torch
import torch_geometric
from torch import sigmoid, Tensor
from torch_scatter import scatter
from torch_geometric.utils import to_dense_batch, to_dense_adj
from torch_geometric.nn import global_add_pool
from ogb.graphproppred.mol_encoder import BondEncoder
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.inits import reset
from torch_geometric.typing import Adj, OptPairTensor, OptTensor, Size
from time import time


# Module-level debug switch: when True, DenseGradientExtractor prints wall-clock
# timings for the individual stages of the derivative computation.
print_time = False

class DenseGradientExtractor:
    def __init__(self, 
                 model: torch.nn.Module, 
                 max_degree: int=2, 
                 max_variables_per_node: Optional[int]=None,
                 device:str="cpu",):
        self.model = model.to(device)
        # self.model.eval()
        self.max_degree = max_degree
        # TODO: implement use of max_variables_per_node
        self.max_variables_per_node = max_variables_per_node
        self.device = device

    def compute_derivatives(self, graph: torch_geometric.data.Data) -> torch_geometric.data.Data:
        graph = self.compute_node_to_node_derivatives(graph)
        graph = self.aggregation_update(graph)
        return graph

    def compute_node_to_node_derivatives(self, graph: torch_geometric.data.Data) -> torch_geometric.data.Data:
        """
        Computes the gradient of the model with respect to the input graph. 
        Assumes model is of the form embed_layer -> first_layers -> gnn_layers -> batchnorm_layers -> activation_layers -> sum_aggregation -> final_layers.
        """
        graph = graph.to(self.device)

        full_start_time = time()

        start_time = time()
        graph = self.instantiate(graph)
        if print_time:
            print(f"Instantiate time: {time() - start_time}")

        # #### for debugging
        # debug_node_to_node_derivatives_0 = self.get_node_to_node_derivatives(graph)
        # x_0_dense, mask = to_dense_batch(graph.x_0, graph.batch)
        # graph.debug_node_to_node_derivatives_0 = debug_node_to_node_derivatives_0[mask]
        # #### for debugging

        start_time = time()
        graph = self.mlp_update(graph, self.model.first_layers)
        if print_time:
            print(f"first MLP update time: {time() - start_time}")

        # #### for debugging
        # debug_node_to_node_derivatives_1 = self.get_node_to_node_derivatives(graph)
        # graph.debug_node_to_node_derivatives_1 = debug_node_to_node_derivatives_1[mask]
        # #### for debugging

        intermediate_node_to_node_derivatives = []
        
        for gnn, bn, act in zip(self.model.gnn_layers, self.model.bn_layers, self.model.act_layers):
            old_graph = graph.clone()
            
            graph = self.gin_layer_update(graph, gnn)
            graph = self.graph_batchnorm_update(graph, bn)
            graph = self.graph_activation_update(graph, act)

            if self.model.add_residual:
                graph = self.residual_update(graph, old_graph)

            start_time = time()
            intermediate_node_to_node_derivatives.append(self.get_node_to_node_derivatives(graph))
            if print_time:
                print(f"get node to node derivatives time: {time() - start_time}")
        
        #convert intermediate_node_to_node_derivatives into sparse representation
        intermediate_node_to_node_derivatives = torch.stack(intermediate_node_to_node_derivatives, dim=-1)
        start_time = time()
        _, mask = to_dense_batch(graph.x_0, graph.batch)
        if print_time:
                print(f"to dense batch time: {time() - start_time}")
        
        #TODO: check if mask needs to be manipulated to account for all other axes (maybe flatten and then reshape intermediate_node_to_node_derivatives back)
        intermediate_node_to_node_derivatives = intermediate_node_to_node_derivatives[mask] #shape should be (num_nodes_in_batch, variable_dim, polynomial_dim, embedding_dim, num_layers)

        # graph.x_node_to_node_derivatives = intermediate_node_to_node_derivatives[:,:,:,:,-1,]
        graph.x_intermediate_node_to_node_derivatives =  intermediate_node_to_node_derivatives
        if print_time:
            print(f"full time: {time() - full_start_time} \n\n\n\n")
        return graph

    
    def aggregation_update(self, graph: torch_geometric.data.Data) -> torch_geometric.data.Data:
        """
        Reduces edge_dim to number of nodes by summing over all different gradient values.
        used to compute derivatives based on sum aggregation.
        """

        x_0, derivatives = graph.x_0, graph.derivatives
        pooled_derivatives = derivatives.sum(dim=1) #TODO: Check if dim=1 or 2
        pooled_x_0 = global_add_pool(x_0, graph.batch)
        graph.x_0 = pooled_x_0

        graph_out = graph.clone()
        graph_out.x_0 = pooled_x_0
        graph_out.derivatives = pooled_derivatives
        #TODO: check mlp_update works after pooling one of the node dimensions, potentially add 1 dim.
        graph_out = self.final_mlp_update(graph_out, self.model.final_layers)

        graph.x_out = graph_out.x_0
        graph.x_node_to_out_derivatives = graph_out.derivatives
        return graph


    def residual_update(self, graph: torch_geometric.data.Data, old_graph: torch_geometric.data.Data) -> torch_geometric.data.Data:
        """
        Adds the gradients from the old graph to the current graph, implementing the residual connection.
        This combines the gradient information from both paths: the transformed path and the skip connection path.
        """
        x_0_new, derivatives_new = graph.x_0, graph.derivatives
        x_0_old, derivatives_old = old_graph.x_0, old_graph.derivatives
        x_0 = x_0_new + x_0_old
        derivatives = derivatives_new + derivatives_old
        graph.x_0 = x_0
        graph.derivatives = derivatives
        return graph

    def gin_layer_update(self, graph: torch_geometric.data.Data, GNNConv) -> torch_geometric.data.Data:
        """
        Derivative update for a single GINE-style convolution.

        Assumes the GNNConv is a GINEConv with no activation on edge attributes,
        i.e. the pre-MLP value is (1 + eps) * x + sum of neighbours + sum of edge embeddings.

        Args:
            graph: graph carrying ``x_0``, ``derivatives``, ``adjacency``, ``x``,
                ``edge_index`` and ``edge_attr``. Its ``edge_attr`` is replaced by a
                2-D view when it is 1-D; its values/derivatives are not modified.
            GNNConv: wrapper exposing ``layer.nn``, ``layer.eps`` and ``edge_embedding``.

        Returns:
            A clone of ``graph`` holding the updated ``x_0`` and ``derivatives``.
        """
        mlp = GNNConv.layer.nn
        eps =  GNNConv.layer.eps
        # Work on a clone so the input values stay available for the (1 + eps) self term.
        adj_update_graph = graph.clone()    
        start_time = time()
        adj_update_graph = self.adjacency_update(adj_update_graph)
        

        # Self term plus neighbour sum, for both the values and the derivatives.
        x_0 = graph.x_0 * (1+eps) + adj_update_graph.x_0
        derivatives=  adj_update_graph.derivatives + (1+eps)*graph.derivatives
        # The reported time covers the adjacency update and the combination above.
        if print_time:
            print(f"adjacency update time: {time() - start_time}")
        
        #TODO: add filtering here
        # derivatives= self.filter_derivatives(derivatives_new=derivatives, derivatives_old=graph.derivatives)
        
        adj_update_graph.x_0 = x_0
        adj_update_graph.derivatives = derivatives

        # The edge encoder is fed a 2-D attribute tensor; note this rewrites the caller's graph.
        if len(graph.edge_attr.shape) == 1:
            graph.edge_attr = graph.edge_attr.unsqueeze(1)
        edge_embedding_vec = GNNConv.edge_embedding(graph.edge_attr)
        start_time = time()
        # Edge embeddings shift the values only; the derivative tensor is left unchanged by this term.
        adj_update_graph.x_0 += EdgeAggregation()(graph.x, graph.edge_index, edge_embedding_vec)
        if print_time:
            print(f"edge aggregation time: {time() - start_time}")

        start_time = time()
        # Finally push values and derivatives through the convolution's internal MLP.
        adj_update_graph = self.mlp_update(adj_update_graph, mlp)
        if print_time:
            print(f"mlp update time: {time() - start_time}")
        
        return adj_update_graph 
     
    def adjacency_update(self, graph: torch_geometric.data.Data) -> torch_geometric.data.Data:
        """
        Updates the gradient edge attributes based on summation over neighbors in graph. Assumes sum aggregation.
        """
        edge_index = graph.edge_index 
        x_0 = graph.x_0
        x_0 = scatter(x_0[edge_index[0]], edge_index[1], dim=0, dim_size=x_0.size(0), reduce='sum')
        graph.x_0 = x_0
        derivatives, adjacency = graph.derivatives, graph.adjacency
        derivatives = torch.einsum('bni, bimvpe->bnmvpe',  adjacency, derivatives,) #TODO: check if this is correct
        graph.derivatives = derivatives

        return graph
    
    def final_mlp_update(self, graph: torch_geometric.data.Data, mlp: torch.nn.Sequential) -> torch_geometric.data.Data:
        """
        Updates the gradient edge attributes based on a multi-layer perceptron.
          Assumes the mlp is a sequence of linear layer, activation functions and batchnorm layer where the activation is 
          the one specified in the initialization of the GradientExtractor. assumes also the MLP starts and ends with a linear layer.
        """ 
        x_0_out, derivatives  = graph.x_0, graph.derivatives
        x_0_out = x_0_out.unsqueeze(1)
        derivatives = derivatives.unsqueeze(1) #check if unsqueeze(1) or unsqueeze(2)

        for layer in mlp:
            if isinstance(layer, torch.nn.Linear):
                x_0_out, derivatives = self.linear_update(x_0_out, derivatives, layer)
            elif isinstance(layer, torch.nn.BatchNorm1d):
                x_0_out, derivatives = self.batchnorm_update(x_0_out, derivatives, layer)
            elif isinstance(layer, torch.nn.SiLU) or isinstance(layer, torch.nn.ReLU):
                x_0_out, derivatives = self.activation_update(x_0_out, derivatives, layer)
            elif isinstance(layer, torch.nn.Identity):
                pass
            else:
                raise ValueError(f"Unsupported layer type: {type(layer)}")
        graph.x_0 = x_0_out
        graph.derivatives = derivatives
        return graph
    
    
    def mlp_update(self, graph: torch_geometric.data.Data, mlp: torch.nn.Sequential) -> torch_geometric.data.Data:
        """
        Updates the gradient edge attributes based on a multi-layer perceptron.
          Assumes the mlp is a sequence of linear layer, activation functions and batchnorm layer where the activation is 
          the one specified in the initialization of the GradientExtractor. assumes also the MLP starts and ends with a linear layer.
        """ 
        x_0_dense, mask = to_dense_batch(graph.x_0, graph.batch)
        derivatives = graph.derivatives
        for layer in mlp:
            if isinstance(layer, torch.nn.Linear):
                x_0_dense, derivatives = self.linear_update(x_0_dense, derivatives, layer)
            elif isinstance(layer, torch.nn.BatchNorm1d):
                x_0_dense, derivatives = self.batchnorm_update(x_0_dense, derivatives, layer)
            elif isinstance(layer, torch.nn.SiLU) or isinstance(layer, torch.nn.ReLU):
                x_0_dense, derivatives = self.activation_update(x_0_dense, derivatives, layer)
            elif isinstance(layer, torch.nn.Identity):
                pass
            else:
                raise ValueError(f"Unsupported layer type: {type(layer)}")
            
            x_0 = x_0_dense[mask]
            graph.x_0 = x_0
            graph.derivatives = derivatives
        return graph
    
    def graph_batchnorm_update(self, graph: torch_geometric.data.Data, batchnorm_layer: torch.nn.BatchNorm1d) -> torch_geometric.data.Data:
        """
        Graph-level batch-norm update. Currently a no-op: the graph is returned
        unchanged and ``batchnorm_layer`` is ignored, so neither the values nor
        the derivatives reflect the layer. The intended implementation is kept
        below as commented-out code.
        """
        return graph
        # TODO: add back in
        # x_0, derivatives = graph.x_0, graph.derivatives

        # sigma = batchnorm_layer.running_var
        # gamma = batchnorm_layer.weight
        # eps = batchnorm_layer.eps        
        # x_0 = batchnorm_layer(x_0)
        # derivatives = derivatives/ torch.sqrt(sigma + eps)
        # derivatives = derivatives * gamma 

        # graph.x_0 = x_0
        # graph.derivatives = derivatives
        # return graph

    def batchnorm_update(self, x_0_dense: torch.Tensor, derivatives: torch.Tensor, batchnorm_layer: torch.nn.BatchNorm1d) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Tensor-level batch-norm update. Currently a no-op: both inputs are
        returned unchanged and ``batchnorm_layer`` is ignored. The intended
        implementation is kept below as commented-out code.
        """
        #TODO: add back in
        return x_0_dense, derivatives
        # sigma = batchnorm_layer.running_var
        # gamma = batchnorm_layer.weight
        # eps = batchnorm_layer.eps

        # batch_size, num_nodes, embed_dim = x_0_dense.shape
        # x_0_dense = x_0_dense.reshape(-1, embed_dim)
        # x_0_dense = batchnorm_layer(x_0_dense)
        # x_0_dense = x_0_dense.reshape(batch_size, num_nodes, embed_dim)

        # derivatives = derivatives/ torch.sqrt(sigma + eps)
        # derivatives = derivatives * gamma 

        # return x_0_dense, derivatives
    
    def linear_update(self, x_0_dense: torch.Tensor, derivatives: torch.Tensor, linear_layer: torch.nn.Linear) -> torch_geometric.data.Data:
        weight = linear_layer.weight
        x_0_dense = linear_layer(x_0_dense)
        derivatives = torch.matmul(derivatives, weight.T)
        return x_0_dense, derivatives
    
    def graph_activation_update(self, graph: torch_geometric.data.Data, activation_layer: torch.nn.Module) -> torch_geometric.data.Data:
        #TODO make more efficient, you dont have to recompute  x_0_dense, mask
        x_0, derivatives, batch = graph.x_0, graph.derivatives, graph.batch
        x_0_dense, mask = to_dense_batch(x_0, batch)
        x_0_dense, derivatives = self.activation_update(x_0_dense, derivatives, activation_layer)
        graph.x_0 = x_0_dense[mask]
        graph.derivatives = derivatives
        return graph
        

    def activation_update(self, x_0_dense: torch.Tensor, derivatives: torch.Tensor, activation_layer: torch.nn.Module) -> torch_geometric.data.Data:
        if isinstance(activation_layer, torch.nn.SiLU):
            x_0_dense, derivatives = self.silu_derivative_update(x_0_dense, derivatives)
        elif isinstance(activation_layer, torch.nn.ReLU):
            x_0_dense, derivatives = self.relu_derivative_update(x_0_dense, derivatives)
        else:
            print(f" got activation function: {type(activation_layer)}, will not compute derivatives")
        return x_0_dense, derivatives


    def instantiate(self, graph: torch_geometric.data.Data) -> torch_geometric.data.Data:
        # gets a pyg graph batch and for each node feature vector obtained after feature embedding,
        # creates a 1-degree polynomial encoding the derivative information so far.  
        # the tensor graph.x_0 represent the zero coefficietns and is of shape (num_nodes, embed_dim)
        # the tensor graph.derivatives represents the higher order coefficients and is of shape (batch_size, num_nodes,num_nodes, variable_dim, polynomial_Dim, embedding_dim).
        # the tensor graph.adjacency represents the adjacency matrix and is of shape (batch_size, num_nodes,num_nodes).

        # Convert edge_index to dense adjacency (without edge attributes)
        adj_dense = to_dense_adj(graph.edge_index, graph.batch)
        graph.adjacency = adj_dense
        
        # get embedding of node features
        if not hasattr(graph, 'x_0'):
            graph.x_0 = graph.x
        x_0 = self.model.feature_encoder(graph.x_0.long().squeeze())
        graph.x_0 = x_0

        # get derivative tensor dimensions dimensions
        batch_size, num_nodes = adj_dense.shape[0], adj_dense.shape[1]
        emb_dim = x_0.size(1)
        variable_dim = emb_dim
        polynomial_dim = self.max_degree

        # initialize derivatives tensor
        i2 = torch.arange(num_nodes, device=self.device).view(1, num_nodes, 1, 1, 1)
        i3 = torch.arange(num_nodes, device=self.device).view(1, 1, num_nodes, 1,  1)
        i4 = torch.arange(variable_dim, device=self.device).view(1, 1, 1, variable_dim, 1)
        i6 = torch.arange(variable_dim, device=self.device).view(1, 1, 1, 1, variable_dim)

        # Use broadcasting to get a mask where i2 == i3 and i4 == i6 and i5 == 0
        mask = ((i2 == i3) & (i4 == i6)).float()
        
        derivatives = torch.zeros(batch_size, num_nodes, num_nodes, variable_dim, polynomial_dim, emb_dim, device=self.device)
        # derivatives[:,torch.arange(num_nodes, device=self.device),torch.arange(num_nodes, device=self.device),torch.arange(variable_dim, device=self.device), 0, torch.arange(emb_dim, device=self.device)] = 1
        derivatives[:,:,:,:, 0, :] = mask
        graph.derivatives = derivatives
        return graph
    

    def filter_derivatives(self, old_derivatives: torch.Tensor, new_derivatives: torch.Tensor) -> torch.sparse_coo_tensor:
       """
       Placeholder for derivative filtering; not implemented and returns None.

       NOTE(review): the commented-out call site in ``gin_layer_update`` passes
       ``derivatives_new=`` / ``derivatives_old=``, which does not match these
       parameter names — align them when implementing.
       """
       pass
    

    @staticmethod
    def get_node_to_node_derivatives(graph: torch_geometric.data.Data, sort:bool=True) -> torch.Tensor:
        """
        Extracts the node-to-node derivatives from the gradient edge attributes.    
        Returns:
            Tensor of shape (num_nodes, variable_dim, polynomial_dim, embedding_dim) containing
            the derivative information of each node
        """
        derivatives = graph.derivatives
        num_nodes  = derivatives.shape[1]
        node_to_node_derivatives = derivatives[:,torch.arange(num_nodes), torch.arange(num_nodes)]
        #TODO: potentially move making this sparse here

        return node_to_node_derivatives
    
    
    @staticmethod
    def compute_composition_derivatives(out_der, in_der, polynomial_dim):
        """
        Computes the derivatives of the composition of functions  f(g(x)).
        out_der a list of tensors and is the derivatives of the outer function at point g(x). each tensor is of shape (batch, num_nodes, embd_dim) and the list is of length polynomial_dim 
        in_der is the derivatives of the inner function at point x and is of shape (batch, num_nodes,num_nodes variable_dim, polynomial dim,  embd_dim) 
        """
        out_der = [der.unsqueeze(2).unsqueeze(2) for der in out_der] #TODO: check if unsqueeze(2) or unsqueeze(1)
        composition_derivatives = []
        for i in range(polynomial_dim):
            if i == 0:
                composition_derivatives.append(out_der[0] * in_der[:,:,:,:,0]) 
            elif i == 1:
                composition_derivatives.append(out_der[1] * in_der[:,:,:,:,0]**2 +  out_der[0] * in_der[:,:,:,:,1])
            elif i == 2:
                composition_derivatives.append(out_der[2] * in_der[:,:,:,:,0]**3 + 3 * out_der[1] * in_der[:,:,:,:,0] * in_der[:,:,:,:,1] + out_der[0] * in_der[:,:,:,:,2])
            elif i == 3:
                composition_derivatives.append(out_der[3] * in_der[:,:,:,:,0]**4 +6*out_der[2] * in_der[:,:,:,:,0]**2 * in_der[:,:,:,:,1] + 3 * out_der[1] * in_der[:,:,:,:,1]**2  + 4 * out_der[1] * in_der[:,:,:,:,0] * in_der[:,:,:,:,2] + out_der[0] * in_der[:,:,:,:,3])
            else:
                raise ValueError(f"Currently only supporting up to 4th order derivatives, got {i}")
            
            derivatives = torch.stack(composition_derivatives, dim=4) #TODO: check if this is the correct dimension
    
        return derivatives


    def relu_derivative_update(self, x_0_dense: torch.Tensor, derivatives: torch.Tensor):
        """
        Computes the derivatives of the ReLU function. Assumes X is a tensor of shape (b,) where b is the batch size and k is the number of derivatives to compute.
        """
        x_0_dense = torch.nn.functional.relu(x_0_dense)
        # TODO: check if this works, try replacing with sigmoid with high temperture or custom activation function
        batch_size, num_nodes, embed_dim = x_0_dense.shape
        relu_derivatives = (x_0_dense>0).float()
        # TODO: check if this is correct or (batch_size, 1, num_nodes, 1, 1, embed_dim)
        relu_derivatives = relu_derivatives.reshape(batch_size, num_nodes, 1, 1, 1, embed_dim)

        derivatives = derivatives * relu_derivatives
        return x_0_dense, derivatives
    
    
    def silu_derivative_update(self, x_0_dense: torch.Tensor, derivatives: torch.Tensor):
        """
        Computes the derivatives of the ReLU function. Assumes X is a tensor of shape (b,) where b is the batch size and k is the number of derivatives to compute.
        """
        silu_derivatives = self.compute_silu_derivatives(x_0_dense, self.max_degree)
        x_0_dense = torch.nn.functional.silu(x_0_dense)
        derivatives = self.compute_composition_derivatives(in_der=derivatives, out_der=silu_derivatives, polynomial_dim=self.max_degree)
        return x_0_dense, derivatives
    
    @staticmethod
    def compute_silu_derivatives(x, k):
        """
        Computes the derivatives of the SiLU function. Assumes X is a tensor of shape (b,) where b is the batch size and k is the number of derivatives to compute.
        we first compute the sigmoid derivatives where  sigmoid_derivatives[i,j] = d/dx^j sigmoid(x[i])
        We then use the chain rule to compute the silu derivatives:  = d/dx^j silu(x[i]) = x[i] * d/dx^j sigmoid(x[i]) + j * d/dx^(j-1) sigmoid(x[i]) 
        """
        sigmoid_derivatives = [sigmoid(x)]
        silu_derivatives = [x * sigmoid(x)]

        for i in range(1, k+1):
            sigmoid_derivatives.append(DenseGradientExtractor.compute_sigmoid_derivatives(sigmoid_derivatives, i))
            silu_derivatives.append(x * sigmoid_derivatives[-1] + i * sigmoid_derivatives[-2])
            
        # silu_derivatives = torch.stack(silu_derivatives[1:], dim=1)   
        silu_derivatives = silu_derivatives[1:] 
        return silu_derivatives
    
    
    @staticmethod
    def compute_sigmoid_derivatives(X, k):
        if k==1:
            return X[0] * (1-X[0])
        elif k==2:
            return X[1] * (1 - 2*X[0])
        elif k==3:
            return X[2] * (1 - 2*X[0]) - 2*X[1]**2
        elif k==4:
            return X[3] * (1 - 2*X[0]) - 6*X[1]*X[2]
        else:
            raise ValueError("currently only supporting k up to 4")
    














class FirstGNNnetwork(torch.nn.Module):
    def __init__(
        self,
        num_layers,
        in_dim,
        emb_dim,
        hidden_dim,
        add_residual=False,
        track_running_stats=True,
        num_tasks: int = 1,
        activation=None,
        use_GINE_activation=False,
        num_edge_emb=4,
        use_features=True,
    ):
        """
        Build the network: embedding -> first MLP -> ``num_layers`` blocks of
        (CustomGINE, BatchNorm1d, activation) -> sum pooling -> final MLP.

        Args:
            num_layers: number of GNN blocks.
            in_dim: number of distinct node feature ids for the embedding table.
            emb_dim: embedding width.
            hidden_dim: width of the hidden node representation.
            add_residual: add a skip connection around every GNN block.
            track_running_stats: forwarded to every BatchNorm1d.
            num_tasks: output width of the final MLP; ``None`` disables the final MLP.
            activation: ``None``/'silu' or 'relu' (see ``get_activation_class``).
            use_GINE_activation: forwarded to ``CustomGINE``.
            num_edge_emb: forwarded to ``CustomGINE``.
            use_features: when False, node/edge features are replaced by constants.
        """
        super().__init__()
        self.emb_dim = emb_dim
        self.hidden_dim = hidden_dim
        self.activation_class = self.get_activation_class(activation)
        self.use_features = use_features
        self.num_layers = num_layers
        make_activation = self.activation_class

        self.feature_encoder = torch.nn.Embedding(num_embeddings=in_dim, embedding_dim=emb_dim)
        self.first_layers = torch.nn.Sequential(
            torch.nn.Linear(in_features=emb_dim, out_features=hidden_dim),
            make_activation(),
            torch.nn.Linear(in_features=hidden_dim, out_features=hidden_dim),
        )

        self.gnn_layers = torch.nn.ModuleList()
        self.bn_layers = torch.nn.ModuleList()
        self.act_layers = torch.nn.ModuleList()
        for _ in range(num_layers):
            conv = CustomGINE(
                hidden_dim,
                hidden_dim,
                track_running_stats=track_running_stats,
                activation_class=make_activation,
                use_GINE_activation=use_GINE_activation,
                num_edge_emb=num_edge_emb,
            )
            self.gnn_layers.append(conv)
            norm = torch.nn.BatchNorm1d(hidden_dim, track_running_stats=track_running_stats)
            self.bn_layers.append(norm)
            self.act_layers.append(make_activation())

        self.add_residual = add_residual
        self.pool = global_add_pool

        self.final_layers = None
        if num_tasks is not None:
            wide_dim = 2 * hidden_dim
            self.final_layers = torch.nn.Sequential(
                torch.nn.Linear(in_features=hidden_dim, out_features=wide_dim),
                make_activation(),
                torch.nn.Linear(in_features=wide_dim, out_features=wide_dim),
                make_activation(),
                torch.nn.Linear(in_features=wide_dim, out_features=num_tasks),
            )

    def forward(self, batched_data):
        x, edge_index, edge_attr, batch = (
            batched_data.x,
            batched_data.edge_index,
            batched_data.edge_attr,
            batched_data.batch,
        )
        x, edge_attr = self.get_initial_embeddings(x=x, edge_attr=edge_attr)
        out = self.get_output(x, edge_index, edge_attr, batch)
        return out
    
    def get_output_no_embedding(self, x, edge_index, edge_attr, batch):
        x = self.get_final_node_embeddings_no_embedding(x, edge_index, edge_attr)
        x = self.pool(x, batch)
        out = self.final_layers(x)
        return out
    
    def get_final_node_embeddings_no_embedding(self, x, edge_index, edge_attr):
        for i in range(len(self.gnn_layers)):
            x = self.intermediate_updates(x, edge_index, edge_attr, i, scale=None, shift=None)
        return x
    
    def get_output(self, x, edge_index, edge_attr, batch):
        x = self.get_final_node_embeddings(x, edge_index, edge_attr)
        x = self.pool(x, batch)
        out = self.final_layers(x)
        return out
    
    def get_final_node_embeddings(self, x, edge_index, edge_attr):
        x = self.first_layers(x)
        for i in range(len(self.gnn_layers)):
            x = self.intermediate_updates(x, edge_index, edge_attr, i, scale=None, shift=None)
        return x
    

    def intermediate_updates(self, x, edge_index, edge_attr, index, scale=None, shift=None):
        gnn, bn, act  =  self.gnn_layers[index], self.bn_layers[index], self.act_layers[index]
        h = act(bn(gnn(x, edge_index, edge_attr)))
        if self.add_residual:
            x = h + x
        else:
            x = h
        scale = 1 if scale is None else scale
        shift = 0 if shift is None else shift
        x = x * scale + shift
        return x
        
        
    
    def get_initial_embeddings(self, x, edge_attr):
        if not self.use_features:
            x = torch.ones_like(x)
            edge_attr = torch.zeros_like(edge_attr)
        x = x.long()
        edge_attr = edge_attr.long()
        x = self.feature_encoder(x.squeeze())

        return x, edge_attr
    
    def get_activation_class(self, activation):
        if activation is None or activation == 'silu':
            return torch.nn.SiLU
        elif activation == 'relu':
            return torch.nn.ReLU
        elif activation == 'silu':
            return torch.nn.SiLU
        else:
            raise activation


class CustomGINE(torch.nn.Module):
    """
    GINE-style convolution block: a bond encoder for the edge attributes
    followed by an ``AnalyticGINEConv`` whose update network is
    Linear -> BatchNorm1d -> activation -> Linear.
    """

    def __init__(self, in_dim, emb_dim, track_running_stats, num_edge_emb=4, activation_class=None, use_GINE_activation=False):
        """
        Args:
            in_dim: input feature width.
            emb_dim: hidden/output width; also the width of the edge embedding.
            track_running_stats: forwarded to the internal BatchNorm1d.
            num_edge_emb: currently unused (only relevant to the disabled
                ``torch.nn.Embedding`` edge encoder alternative).
            activation_class: activation class to instantiate; defaults to SiLU.
            use_GINE_activation: apply the activation inside the message
                function; otherwise messages pass through an Identity.
        """
        super().__init__()
        self.activation_class = activation_class if activation_class is not None else torch.nn.SiLU

        # Use the resolved class: the raw argument may be None, which used to
        # crash here when use_GINE_activation=True.
        GINE_activation = self.activation_class() if use_GINE_activation else torch.nn.Identity()

        mlp = torch.nn.Sequential(
            torch.nn.Linear(in_dim, emb_dim),
            torch.nn.BatchNorm1d(emb_dim, track_running_stats=track_running_stats),
            self.activation_class(),
            torch.nn.Linear(emb_dim, emb_dim),
        )
        self.layer = AnalyticGINEConv(nn=mlp, train_eps=True, activation=GINE_activation)
        self.edge_embedding = BondEncoder(emb_dim=emb_dim)

        # Alternative edge encoder, currently disabled:
        # self.edge_embedding = torch.nn.Embedding(
        #     num_embeddings=num_edge_emb, embedding_dim=in_dim
        # )

    def forward(self, x, edge_index, edge_attr):
        """Embed the edge attributes (made 2-D if needed) and run the convolution."""
        if len(edge_attr.shape) == 1:
            edge_attr = edge_attr.unsqueeze(1)
        return self.layer(x, edge_index, self.edge_embedding(edge_attr))



class AnalyticGINEConv(MessagePassing):
    """
    GINE convolution whose message activation is configurable and defaults to
    the identity: ``nn((1 + eps) * x_i + sum_j activation(x_j + e_ji))``.
    """

    def __init__(
        self,
        nn: torch.nn.Module,
        eps: float = 0.0,
        train_eps: bool = False,
        edge_dim: Optional[int] = None,
        activation: Optional[torch.nn.Module] = None,
        **kwargs,
    ):
        """
        Args:
            nn: update network; must be indexable, and its first element must
                expose ``in_features`` or ``in_channels``.
            eps: initial value of epsilon.
            train_eps: make epsilon a trainable parameter instead of a buffer.
            edge_dim: accepted but not used by this implementation.
            activation: module applied to every message; Identity when None.
            **kwargs: forwarded to ``MessagePassing`` (``aggr`` defaults to "add").

        Raises:
            ValueError: if the input width cannot be inferred from ``nn``.
        """
        kwargs.setdefault("aggr", "add")
        super().__init__(**kwargs)
        self.nn = nn
        self.initial_eps = eps

        eps_storage = torch.empty(1)  # filled in reset_parameters()
        if train_eps:
            self.eps = torch.nn.Parameter(eps_storage)
        else:
            self.register_buffer("eps", eps_storage)

        # Only validates that the input width is discoverable; the value itself
        # was needed by a linear edge projection that is no longer part of the layer.
        first_module = self.nn[0]
        if not (hasattr(first_module, "in_features") or hasattr(first_module, "in_channels")):
            raise ValueError("Could not infer input channels from `nn`.")

        self.activation = torch.nn.Identity() if activation is None else activation
        self.reset_parameters()

    def reset_parameters(self):
        """Reset the update network and restore epsilon to its initial value."""
        reset(self.nn)
        self.eps.data.fill_(self.initial_eps)

    def forward(
        self,
        x: Union[Tensor, OptPairTensor],
        edge_index: Adj,
        edge_attr: OptTensor = None,
        size: Size = None,
    ) -> Tensor:
        """Aggregate the messages, add the scaled root features, apply ``nn``."""
        pair = (x, x) if isinstance(x, Tensor) else x

        # propagate_type: (x: OptPairTensor, edge_attr: OptTensor)
        aggregated = self.propagate(edge_index, x=pair, edge_attr=edge_attr, size=size)

        root = pair[1]
        if root is not None:
            aggregated = aggregated + (1 + self.eps) * root

        return self.nn(aggregated)

    def message(self, x_j: Tensor, edge_attr: Tensor) -> Tensor:
        """Message of an edge: activation of source features plus edge features."""
        return self.activation(x_j + edge_attr)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(nn={self.nn})"



class EdgeAggregation(MessagePassing):
    """
    Message-passing layer that ignores node features and returns, per node,
    the aggregate (sum by default) of the edge feature vectors it receives.
    """

    def __init__(
        self,
        **kwargs,
    ):
        """Forward ``kwargs`` to ``MessagePassing``; ``aggr`` defaults to "add"."""
        kwargs.setdefault("aggr", "add")
        super().__init__(**kwargs)

    def forward(
        self,
        x: Union[Tensor, OptPairTensor],
        edge_index: Adj,
        edge_attr: OptTensor = None,
        size: Size = None,
    ) -> Tensor:
        """
        Aggregate ``edge_attr`` per node. ``x`` is passed to ``propagate`` but
        its values are not used by ``message``. Edge tensors with more than two
        axes have axis 0 and axis -2 swapped before propagation, and the swap
        is undone on the result.
        """
        pair = (x, x) if isinstance(x, Tensor) else x

        has_extra_axes = len(edge_attr.shape) > 2
        if has_extra_axes:
            edge_attr = edge_attr.swapaxes(0, -2)

        aggregated = self.propagate(edge_index, x=pair, edge_attr=edge_attr, size=size)

        return aggregated.swapaxes(0, -2) if has_extra_axes else aggregated

    def message(self, x_j: Tensor, edge_attr: Tensor) -> Tensor:
        """Each edge contributes only its own feature vector."""
        return edge_attr

    





    