import os
import sys
from typing import Tuple, Union, Optional, Callable, List, Dict
import torch
from torch_sparse import SparseTensor
import torch_geometric
import torch.nn.functional as F
from copy import deepcopy
from torch import sigmoid, Tensor
from torch_scatter import scatter
from torch_geometric.utils import to_dense_batch, to_dense_adj
from torch_geometric.nn import global_add_pool, GINConv, GINEConv
from ogb.graphproppred.mol_encoder import BondEncoder
from torch_geometric.nn.aggr import SumAggregation
from torch.nn import Linear
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.inits import reset
from torch_geometric.typing import Adj, OptPairTensor, OptTensor, Size
from torch_geometric.graphgym.register import register_node_encoder
import numpy as np
import math
import torch.nn as nn
from torch_geometric.graphgym.config import cfg

@register_node_encoder('DerivativeEncoder')
class DerivativeEncoder(nn.Module):
    def __init__(self, x_0_embedding_dim, derivate_embedding_dim):
        """Build the embedding heads, the first GNN and its gradient extractor.

        Every hyper-parameter other than the two embedding widths is read from
        ``cfg.derivative_encoder``.  When the three embedding widths do not add
        up to ``cfg.gnn.dim_inner``, ``x_0_embedding_dim`` is recomputed so
        that they do.

        Args:
            x_0_embedding_dim: requested output width of the ``x_0`` head.
            derivate_embedding_dim: output width of the derivative head.

        Raises:
            ValueError: if the recomputed ``x_0_embedding_dim`` is negative, or
                the configured activation name is not supported.
        """
        super(DerivativeEncoder, self).__init__()

        enc_cfg = cfg.derivative_encoder

        # Settings forwarded to the first GNN.
        gnn_settings = dict(
            num_layers=enc_cfg.num_layers,
            in_dim=enc_cfg.in_dim,
            emb_dim=enc_cfg.emb_dim,
            hidden_dim=enc_cfg.hidden_dim,
            add_residual=enc_cfg.add_residual,
            track_running_stats=enc_cfg.track_running_stats,
            num_tasks=enc_cfg.num_tasks,
            use_GINE_activation=enc_cfg.use_GINE_activation,
            num_edge_emb=enc_cfg.num_edge_emb,
            use_features=enc_cfg.use_features,
            centrality_init=enc_cfg.centrality_init,
        )
        derivative_hidden_dim = enc_cfg.derivative_hidden_dim
        dropout = enc_cfg.dropout

        # Settings forwarded to the gradient extractor.
        extractor_settings = dict(
            max_degree=enc_cfg.max_degree,
            max_variables_per_node=enc_cfg.max_variables_per_node,
            sparse=enc_cfg.sparse,
            derivative_batchnorm=enc_cfg.derivative_batchnorm,
        )

        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        activation = self.get_activation(enc_cfg.activation)

        # Keep the concatenated feature width equal to the GNN's inner width.
        first_embedding_dim = enc_cfg.first_embedding_dim
        total_dim = x_0_embedding_dim + derivate_embedding_dim + first_embedding_dim
        if total_dim != cfg.gnn.dim_inner:
            x_0_embedding_dim = cfg.gnn.dim_inner - derivate_embedding_dim - first_embedding_dim
            print(f"x_0_embedding_dim+derivate_embedding_dim + cfg.derivative_encoder.first_embedding_dim != cfg.gnn.dim_inner")
            print(f"new x_0_embedding_dim: {x_0_embedding_dim}")
            if x_0_embedding_dim < 0:
                raise ValueError(f"x_0_embedding_dim is less than 0")

        self.init_embedding_layers(
            hidden_dim=gnn_settings["hidden_dim"],
            x_0_embedding_dim=x_0_embedding_dim,
            derivate_embedding_dim=derivate_embedding_dim,
            activation=activation,
            max_degree=extractor_settings["max_degree"],
            emb_dim=gnn_settings["emb_dim"],
            num_layers=gnn_settings["num_layers"],
            derivative_hidden_dim=derivative_hidden_dim,
            track_running_stats=gnn_settings["track_running_stats"],
            dropout=dropout,
        )

        self.init_first_gnn(activation=activation, **gnn_settings)

        self.init_gradient_extractor(**extractor_settings)
        
       
    
    def forward(self, batch):
        """Append the ``x_0`` and derivative embeddings to ``batch.x``.

        The gradient extractor is run on ``batch`` first.  The resulting
        ``x_0`` and flattened derivatives are embedded by their heads and
        concatenated to ``batch.x`` along dim 1.  The derivative embedding is
        also stored on ``batch.x_node_to_node_derivatives``.
        """
        # Extract gradients using the first GNN.
        batch = self.compute_derivatives(batch)
        flat_derivatives, x_0 = self.get_processed_data(batch)

        x_0_embedding = self.first_to_second_embed(x_0)
        derivative_embedding = self.derivative_feature_embed(flat_derivatives)

        batch.x = torch.cat((batch.x, x_0_embedding, derivative_embedding), 1)
        batch.x_node_to_node_derivatives = derivative_embedding
        return batch
    
    def get_processed_data(self, batch):
        """Return ``(flattened derivatives, x_0)`` read from ``batch``.

        ``batch.x_intermediate_node_to_node_derivatives`` is reshaped to two
        dimensions, keeping its leading dimension and flattening the rest.
        ``batch.x_0`` is returned untouched.
        """
        derivatives = batch.x_intermediate_node_to_node_derivatives
        leading = derivatives.shape[0]
        return derivatives.reshape(leading, -1), batch.x_0
        

    def compute_derivatives(self, data):
        """Run the gradient extractor's node-to-node pass on ``data``."""
        # TODO: add output derivatives
        return self.gradient_extractor.compute_node_to_node_derivatives(data)
    # TODO: add back the factorial normalisation of the per-layer derivatives:
        # num_layers_first = self.first_gnn.num_layers
        # centrality_normalization = torch.tensor([math.factorial(i) for i in range(1, num_layers_first+1)],
        #                                          dtype=torch.float32, device=self.device).reshape(1,1,1,1, num_layers_first)
        # data.x_intermediate_node_to_node_derivatives = data.x_intermediate_node_to_node_derivatives / centrality_normalization
        # return data
    
   
    def init_embedding_layers(self, hidden_dim,
                               x_0_embedding_dim,
                               derivate_embedding_dim,
                               activation,
                               max_degree,
                               emb_dim,
                               num_layers,
                               derivative_hidden_dim=4,
                               track_running_stats:bool = True,
                               dropout:float = 0.0):
        """Create the two embedding heads used by ``forward``.

        Both heads share one layout:
        BatchNorm -> Linear -> BatchNorm -> activation -> Dropout -> Linear.

        * ``self.first_to_second_embed`` maps ``hidden_dim`` features to
          ``x_0_embedding_dim`` through a ``hidden_dim`` wide hidden layer.
        * ``self.derivative_feature_embed`` maps
          ``emb_dim * max_degree * hidden_dim * num_layers`` flattened
          derivative features to ``derivate_embedding_dim`` through a
          ``derivative_hidden_dim`` wide hidden layer.

        ``activation`` is a module class (it is instantiated here).
        """

        def build_head(in_features, mid_features, out_features):
            # Layers are created in forward order, once per head.
            return torch.nn.Sequential(
                torch.nn.BatchNorm1d(in_features, track_running_stats=track_running_stats),
                torch.nn.Linear(in_features, mid_features),
                torch.nn.BatchNorm1d(mid_features, track_running_stats=track_running_stats),
                activation(),
                torch.nn.Dropout(dropout),
                torch.nn.Linear(mid_features, out_features),
            )

        self.first_to_second_embed = build_head(hidden_dim, hidden_dim, x_0_embedding_dim)

        derivative_dim = emb_dim * max_degree * hidden_dim * num_layers
        self.derivative_feature_embed = build_head(
            derivative_dim, derivative_hidden_dim, derivate_embedding_dim
        )
        
        
        
    def init_first_gnn(self,
                       num_layers,
                       in_dim,
                       emb_dim,
                       hidden_dim,
                       num_tasks,
                       add_residual,
                       track_running_stats,
                       activation,
                       use_GINE_activation,
                       num_edge_emb,
                       use_features,
                       centrality_init):
        """Create ``self.first_gnn`` and strip its prediction head.

        The network is built from the given hyper-parameters and moved to
        ``self.device``.  Its ``final_layers`` are replaced by an identity
        because only the node vectors are used.  When ``centrality_init`` is
        truthy, ``init_first_gnn_centrality`` is applied afterwards.
        """
        gnn_kwargs = {
            "in_dim": in_dim,
            "emb_dim": emb_dim,
            "hidden_dim": hidden_dim,
            "num_layers": num_layers,
            "num_tasks": num_tasks,
            "use_features": use_features,
            "add_residual": add_residual,
            "track_running_stats": track_running_stats,
            "activation": activation,
            "use_GINE_activation": use_GINE_activation,
            "num_edge_emb": num_edge_emb,
        }
        network = FirstGNNnetwork(**gnn_kwargs)
        self.first_gnn = network.to(self.device)

        # Remove final layer as we only use the node vectors.
        with torch.no_grad():
            identity_head = torch.nn.Sequential(torch.nn.Identity())
            self.first_gnn.final_layers = identity_head

            # if final_layer_dim_first is not None:
            #     self.first_gnn.gnn_layers[-1].layer.nn[2]= torch.nn.Linear(2*hidden_dim_first, final_layer_dim_first)

        # Initialise the first MPNN so that its first derivatives are a
        # centrality encoding.
        if centrality_init:
            self.init_first_gnn_centrality()
        
    def init_gradient_extractor(self, max_degree, max_variables_per_node, sparse, derivative_batchnorm):
        """Attach a dense gradient extractor wrapping ``self.first_gnn``.

        Raises:
            NotImplementedError: if ``sparse`` is truthy (no sparse extractor
                exists yet).
        """
        if sparse:
            # TODO: write batched sparse extractor
            raise NotImplementedError("Sparse gradient extractor not implemented")

        extractor_kwargs = dict(
            model=self.first_gnn,
            device=self.device,
            max_degree=max_degree,
            max_variables_per_node=max_variables_per_node,
            derivative_batchnorm=derivative_batchnorm,
        )
        self.gradient_extractor = DenseGradientExtractor(**extractor_kwargs)
        
    
    def init_first_gnn_centrality(self):
        """Overwrite the first GNN's weights with a fixed, hand-picked init.

        * feature embedding: all ones;
        * ``first_layers[0]``: as many stacked identity blocks as fit into the
          output width (remaining rows stay zero), zero bias;
        * ``first_layers[-1]`` and both linears of every GIN MLP: identity
          weight, zero bias;
        * every batch-norm: unit scale, zero shift, zero running mean, unit
          running variance;
        * GIN ``eps``: -1, so the ``(1 + eps)`` self term vanishes;
        * bond embeddings: all zeros.

        The intent stated at the call site is that the first derivatives of
        the network become a centrality encoding.

        Every replacement tensor is created on the device and with the dtype
        of the tensor it replaces.  The network has already been moved to
        ``self.device`` when this runs, so assigning freshly built CPU/float32
        tensors through ``.data`` would silently move the parameters back to
        the CPU (and downcast a non-float32 model).
        """
        hidden_dim = self.first_gnn.hidden_dim

        def like(reference, value):
            # Place `value` wherever `reference` lives, with the same dtype.
            return value.to(device=reference.device, dtype=reference.dtype)

        def set_identity_linear(linear, size):
            linear.weight.data = like(linear.weight, torch.eye(size))
            linear.bias.data = like(linear.bias, torch.zeros(size))

        def reset_batchnorm(bn):
            # The affine weight is the reference for the buffers because the
            # running statistics may be None (track_running_stats=False).
            reference = bn.weight
            bn.weight.data = like(reference, torch.ones(hidden_dim))
            bn.bias.data = like(bn.bias, torch.zeros(hidden_dim))
            bn.running_mean = like(reference, torch.zeros(hidden_dim))
            bn.running_var = like(reference, torch.ones(hidden_dim))
            # bn.eps = 0. #TODO: check if this matters

        with torch.no_grad():
            # Embedding layer: all ones.
            encoder_weight = self.first_gnn.feature_encoder.weight
            encoder_weight.data = torch.ones_like(encoder_weight.data)

            # First linear layer: identity blocks stacked along the rows.
            first_linear = self.first_gnn.first_layers[0]
            out_dim, in_dim = first_linear.weight.data.shape
            stacked = torch.zeros(out_dim, in_dim)
            for block in range(out_dim // in_dim):
                stacked[block * in_dim:(block + 1) * in_dim, :] = torch.eye(in_dim)
            first_linear.weight.data = like(first_linear.weight, stacked)
            first_linear.bias.data = like(first_linear.bias, torch.zeros(out_dim))

            # Last layer of the input MLP: identity with zero bias.
            set_identity_linear(self.first_gnn.first_layers[-1], out_dim)

            for gnn in self.first_gnn.gnn_layers:
                mlp = gnn.layer.nn
                set_identity_linear(mlp[0], hidden_dim)
                reset_batchnorm(mlp[1])
                set_identity_linear(mlp[3], hidden_dim)

                eps = gnn.layer.eps
                eps.data = like(eps, torch.tensor([-1.]))

                for embedding in gnn.edge_embedding.bond_embedding_list:
                    embedding.weight.data = torch.zeros_like(embedding.weight.data)

            for bn in self.first_gnn.bn_layers:
                reset_batchnorm(bn)
    

    def get_activation(self, activation:str):
        """Map an activation name to its ``torch.nn`` module class.

        Supported names are ``"relu"`` and ``"silu"``.

        Raises:
            ValueError: for any other name.
        """
        supported = (
            ("relu", torch.nn.ReLU),
            ("silu", torch.nn.SiLU),
        )
        for name, module_class in supported:
            if activation == name:
                return module_class
        raise ValueError(f"Unsupported activation: {activation}")




class DenseGradientExtractor:
    def __init__(self, 
                 model: torch.nn.Module, 
                 max_degree: int=2, 
                 max_variables_per_node: Optional[int]=None,
                 device:str="cpu",
                 derivative_batchnorm:bool = False):
        self.model = model.to(device)
        # self.model.eval()
        self.max_degree = max_degree
        # TODO: implement use of max_variables_per_node
        self.max_variables_per_node = max_variables_per_node
        self.device = device
        self.derivative_batchnorm = derivative_batchnorm

    def compute_derivatives(self, graph: torch_geometric.data.Data) -> torch_geometric.data.Data:
        """Run the node-to-node pass, then the pooled (graph-level) pass."""
        with_node_derivatives = self.compute_node_to_node_derivatives(graph)
        return self.aggregation_update(with_node_derivatives)

    def compute_node_to_node_derivatives(self, graph: torch_geometric.data.Data) -> torch_geometric.data.Data:
        """
        Propagates node values and their derivatives through the model and
        records the per-layer self-derivatives of every node.

        Assumes model is of the form embed_layer -> first_layers -> gnn_layers -> batchnorm_layers -> activation_layers -> sum_aggregation -> final_layers.

        After each GNN block (GIN layer, batch-norm, activation and the
        optional residual) the node-to-node diagonal of the derivative tensor
        is stored.  The per-layer results are stacked on a new last axis,
        padding nodes are dropped, and the result is written to
        ``graph.x_intermediate_node_to_node_derivatives``, expected shape
        (num_nodes_in_batch, variable_dim, polynomial_dim, embedding_dim,
        num_layers).

        The pre-layer state is cloned only when ``model.add_residual`` is set.
        Without a residual connection the copy is never read, and cloning
        would duplicate the full dense derivative tensor on every layer.
        """
        graph = graph.to(self.device)
        graph = self.instantiate(graph)
        graph = self.mlp_update(graph, self.model.first_layers)

        use_residual = self.model.add_residual
        per_layer_derivatives = []

        for gnn, bn, act in zip(self.model.gnn_layers, self.model.bn_layers, self.model.act_layers):
            # Snapshot of the layer input, needed only for the skip path.
            old_graph = graph.clone() if use_residual else None

            graph = self.gin_layer_update(graph, gnn)
            graph = self.graph_batchnorm_update(graph, bn)
            graph = self.graph_activation_update(graph, act)

            if use_residual:
                graph = self.residual_update(graph, old_graph)

            per_layer_derivatives.append(self.get_node_to_node_derivatives(graph))

        # Stack layers on a trailing axis: (batch, max_nodes, ..., num_layers).
        stacked = torch.stack(per_layer_derivatives, dim=-1)

        # The (batch, max_nodes) boolean mask indexes the two leading axes, so
        # padding nodes are dropped and the trailing axes are kept as they are.
        _, mask = to_dense_batch(graph.x_0, graph.batch)
        graph.x_intermediate_node_to_node_derivatives = stacked[mask]
        return graph

    
    def aggregation_update(self, graph: torch_geometric.data.Data) -> torch_geometric.data.Data:
        """
        Pools node values and derivatives to graph level (sum aggregation) and
        pushes the pooled result through ``self.model.final_layers``.

        Writes ``graph.x_out`` and ``graph.x_node_to_out_derivatives`` and
        returns ``graph``.

        NOTE(review): ``graph.x_0`` is overwritten with the pooled per-graph
        values, so the node-level ``x_0`` is no longer available on the
        returned graph; confirm this is intended.
        """

        x_0, derivatives = graph.x_0, graph.derivatives
        # Sum over axis 1 of the dense derivative tensor; presumably this is
        # the output-node axis (it is the axis adjacency_update writes to).
        pooled_derivatives = derivatives.sum(dim=1) #TODO: Check if dim=1 or 2
        pooled_x_0 = global_add_pool(x_0, graph.batch)
        graph.x_0 = pooled_x_0

        # Work on a copy so the final-layer pass does not touch the
        # derivatives stored on `graph`.
        graph_out = graph.clone()
        # Re-attach the original pooled tensor in place of its cloned copy.
        graph_out.x_0 = pooled_x_0
        graph_out.derivatives = pooled_derivatives
        #TODO: check mlp_update works after pooling one of the node dimensions, potentially add 1 dim.
        graph_out = self.final_mlp_update(graph_out, self.model.final_layers)

        graph.x_out = graph_out.x_0
        graph.x_node_to_out_derivatives = graph_out.derivatives
        return graph


    def residual_update(self, graph: torch_geometric.data.Data, old_graph: torch_geometric.data.Data) -> torch_geometric.data.Data:
        """
        Adds the gradients from the old graph to the current graph, implementing the residual connection.
        This combines the gradient information from both paths: the transformed path and the skip connection path.
        """
        x_0_new, derivatives_new = graph.x_0, graph.derivatives
        x_0_old, derivatives_old = old_graph.x_0, old_graph.derivatives
        x_0 = x_0_new + x_0_old
        derivatives = derivatives_new + derivatives_old
        graph.x_0 = x_0
        graph.derivatives = derivatives
        return graph

    def gin_layer_update(self, graph: torch_geometric.data.Data, GNNConv) -> torch_geometric.data.Data:
        """
        Derivative-aware version of one GIN(E) layer.

        Combines the neighbour sum with the ``(1 + eps)``-weighted self term
        for both values and derivatives, adds the aggregated edge embeddings
        to the values only, then applies the layer's MLP.  Assumes the GNNConv
        is a GINEConv with no activation on edge attributes.
        """
        conv = GNNConv.layer
        self_weight = 1 + conv.eps

        # Neighbour aggregation is done on a copy so `graph` keeps the inputs.
        updated = self.adjacency_update(graph.clone())

        combined_x_0 = graph.x_0 * self_weight + updated.x_0
        combined_derivatives = updated.derivatives + self_weight * graph.derivatives

        #TODO: add filtering here
        # derivatives= self.filter_derivatives(derivatives_new=derivatives, derivatives_old=graph.derivatives)

        updated.x_0 = combined_x_0
        updated.derivatives = combined_derivatives

        # The edge encoder is given 2-D attributes.
        if len(graph.edge_attr.shape) == 1:
            graph.edge_attr = graph.edge_attr.unsqueeze(1)
        edge_embedding_vec = GNNConv.edge_embedding(graph.edge_attr)
        updated.x_0 += EdgeAggregation()(graph.x, graph.edge_index, edge_embedding_vec)

        return self.mlp_update(updated, conv.nn)
     
    def adjacency_update(self, graph: torch_geometric.data.Data) -> torch_geometric.data.Data:
        """
        Sum-aggregates node values and their derivatives over incoming edges.

        Values: ``x_0[n] <- sum of x_0[i]`` over edges ``i -> n``
        (``edge_index[0]`` is gathered, ``edge_index[1]`` is the scatter
        target).

        Derivatives: the same sum applied to axis 1 of ``graph.derivatives``.
        The dense adjacency built by ``to_dense_adj`` is indexed
        ``[batch, source, target]``, so the matching contraction is
        ``sum_i adjacency[b, i, n] * derivatives[b, i, ...]``.  Contracting
        ``adjacency[b, n, i]`` instead would aggregate over outgoing edges and
        disagree with the value update on directed graphs; for symmetric
        adjacencies both forms give the same result.

        ``graph`` is updated in place and returned.
        """
        edge_index = graph.edge_index
        source, target = edge_index[0], edge_index[1]

        x_0 = graph.x_0
        graph.x_0 = scatter(x_0[source], target, dim=0, dim_size=x_0.size(0), reduce='sum')

        # b: graph, i: source node, n: target node, m: variable node,
        # v: variable dim, p: polynomial order, e: embedding dim.
        graph.derivatives = torch.einsum('bin,bimvpe->bnmvpe', graph.adjacency, graph.derivatives)
        return graph
    
    def final_mlp_update(self, graph: torch_geometric.data.Data, mlp: torch.nn.Sequential) -> torch_geometric.data.Data:
        """
        Updates the gradient edge attributes based on a multi-layer perceptron.
          Assumes the mlp is a sequence of linear layer, activation functions and batchnorm layer where the activation is 
          the one specified in the initialization of the GradientExtractor. assumes also the MLP starts and ends with a linear layer.
        """ 
        x_0_out, derivatives  = graph.x_0, graph.derivatives
        x_0_out = x_0_out.unsqueeze(1)
        derivatives = derivatives.unsqueeze(1) #check if unsqueeze(1) or unsqueeze(2)

        for layer in mlp:
            if isinstance(layer, torch.nn.Linear):
                x_0_out, derivatives = self.linear_update(x_0_out, derivatives, layer)
            elif isinstance(layer, torch.nn.BatchNorm1d):
                x_0_out, derivatives = self.batchnorm_update(x_0_out, derivatives, layer)
            elif isinstance(layer, torch.nn.SiLU) or isinstance(layer, torch.nn.ReLU):
                x_0_out, derivatives = self.activation_update(x_0_out, derivatives, layer)
            elif isinstance(layer, torch.nn.Identity):
                pass
            else:
                raise ValueError(f"Unsupported layer type: {type(layer)}")
        graph.x_0 = x_0_out
        graph.derivatives = derivatives
        return graph
    
    
    def mlp_update(self, graph: torch_geometric.data.Data, mlp: torch.nn.Sequential) -> torch_geometric.data.Data:
        """
        Pushes node values and derivatives through ``mlp``, layer by layer.

        ``graph.x_0`` is densified to (batch, max_nodes, dim) so it lines up
        with the dense derivative tensor.  Supported layers are Linear,
        BatchNorm1d, SiLU, ReLU and Identity; any other layer raises
        ``ValueError``.

        Batch-norm layers are applied to the real nodes only.  The dense
        layout contains padding rows, and letting them take part in the
        normalisation would bias the batch statistics (and the running
        averages) by an amount that depends on how much padding the batch
        happens to contain.

        The sparse values and the derivatives are written back to ``graph``
        once, after the last layer.
        """
        x_0_dense, mask = to_dense_batch(graph.x_0, graph.batch)
        derivatives = graph.derivatives
        activation_types = (torch.nn.SiLU, torch.nn.ReLU)

        for layer in mlp:
            if isinstance(layer, torch.nn.Linear):
                x_0_dense, derivatives = self.linear_update(x_0_dense, derivatives, layer)
            elif isinstance(layer, torch.nn.BatchNorm1d):
                # Present the valid rows as one pseudo-batch (1, num_nodes, dim).
                valid_rows = x_0_dense[mask].unsqueeze(0)
                valid_rows, derivatives = self.batchnorm_update(valid_rows, derivatives, layer)
                # Scatter back into the dense layout. Padding rows become zero;
                # they are discarded by the final masking below.
                refreshed = torch.zeros_like(x_0_dense)
                refreshed[mask] = valid_rows.squeeze(0)
                x_0_dense = refreshed
            elif isinstance(layer, activation_types):
                x_0_dense, derivatives = self.activation_update(x_0_dense, derivatives, layer)
            elif isinstance(layer, torch.nn.Identity):
                pass
            else:
                raise ValueError(f"Unsupported layer type: {type(layer)}")

        graph.x_0 = x_0_dense[mask]
        graph.derivatives = derivatives
        return graph
    
    def graph_batchnorm_update(self, graph: torch_geometric.data.Data, batchnorm_layer: torch.nn.BatchNorm1d) -> torch_geometric.data.Data:
        if self.derivative_batchnorm:
            x_0, derivatives = graph.x_0, graph.derivatives

            sigma = batchnorm_layer.running_var
            gamma = batchnorm_layer.weight
            eps = batchnorm_layer.eps        
            x_0 = batchnorm_layer(x_0)
            derivatives = derivatives/ torch.sqrt(sigma + eps)
            derivatives = derivatives * gamma 

            graph.x_0 = x_0
            graph.derivatives = derivatives

        return graph
        

    def batchnorm_update(self, x_0_dense: torch.Tensor, derivatives: torch.Tensor, batchnorm_layer: torch.nn.BatchNorm1d) -> torch_geometric.data.Data:
        if self.derivative_batchnorm:
            sigma = batchnorm_layer.running_var
            gamma = batchnorm_layer.weight
            eps = batchnorm_layer.eps

            batch_size, num_nodes, embed_dim = x_0_dense.shape
            x_0_dense = x_0_dense.reshape(-1, embed_dim)
            x_0_dense = batchnorm_layer(x_0_dense)
            x_0_dense = x_0_dense.reshape(batch_size, num_nodes, embed_dim)

            derivatives = derivatives/ torch.sqrt(sigma + eps)
            derivatives = derivatives * gamma 

        return x_0_dense, derivatives
        

    
    def linear_update(self, x_0_dense: torch.Tensor, derivatives: torch.Tensor, linear_layer: torch.nn.Linear) -> torch_geometric.data.Data:
        weight = linear_layer.weight
        x_0_dense = linear_layer(x_0_dense)
        derivatives = torch.matmul(derivatives, weight.T)
        return x_0_dense, derivatives
    
    def graph_activation_update(self, graph: torch_geometric.data.Data, activation_layer: torch.nn.Module) -> torch_geometric.data.Data:
        """
        Applies ``activation_layer`` to ``graph.x_0`` and propagates it through
        ``graph.derivatives``.  Values are densified for the update and
        converted back to the sparse node layout afterwards.
        """
        #TODO make more efficient, you dont have to recompute  x_0_dense, mask
        dense_values, mask = to_dense_batch(graph.x_0, graph.batch)
        dense_values, updated_derivatives = self.activation_update(
            dense_values, graph.derivatives, activation_layer
        )
        graph.x_0 = dense_values[mask]
        graph.derivatives = updated_derivatives
        return graph
        

    def activation_update(self, x_0_dense: torch.Tensor, derivatives: torch.Tensor, activation_layer: torch.nn.Module) -> torch_geometric.data.Data:
        if isinstance(activation_layer, torch.nn.SiLU):
            x_0_dense, derivatives = self.silu_derivative_update(x_0_dense, derivatives)
        elif isinstance(activation_layer, torch.nn.ReLU):
            x_0_dense, derivatives = self.relu_derivative_update(x_0_dense, derivatives)
        else:
            print(f" got activation function: {type(activation_layer)}, will not compute derivatives")
        return x_0_dense, derivatives


    def instantiate(self, graph: torch_geometric.data.Data) -> torch_geometric.data.Data:
        """
        Prepares a PyG batch for derivative propagation.

        Sets three attributes on ``graph`` and returns it:

        * ``graph.adjacency``: dense adjacency (without edge attributes),
          shape (batch_size, num_nodes, num_nodes);
        * ``graph.x_0``: the zero-order coefficients, i.e. the feature
          encoder's embedding of the incoming ``graph.x_0`` indices,
          shape (total_nodes, embed_dim);
        * ``graph.derivatives``: the higher-order coefficients, shape
          (batch_size, num_nodes, num_nodes, variable_dim, polynomial_dim,
          embedding_dim).  Only the first polynomial slot is non-zero: it is 1
          where both node indices coincide and the variable index equals the
          embedding index (each embedding is the identity function of its own
          variables).

        The index tensor is kept one-dimensional even when the batch holds a
        single node; a bare ``squeeze()`` would turn it into a scalar and the
        embedding would lose its node axis.  The derivative tensor is created
        on the device and with the dtype of the embedded ``x_0``.
        """
        # Convert edge_index to dense adjacency (without edge attributes).
        adjacency = to_dense_adj(graph.edge_index, graph.batch)
        graph.adjacency = adjacency

        # Embed the node features.
        node_ids = graph.x_0.long().squeeze()
        if node_ids.dim() == 0:
            node_ids = node_ids.unsqueeze(0)
        x_0 = self.model.feature_encoder(node_ids)
        graph.x_0 = x_0

        # Derivative tensor dimensions.
        batch_size, num_nodes = adjacency.shape[0], adjacency.shape[1]
        emb_dim = x_0.size(1)
        variable_dim = emb_dim
        polynomial_dim = self.max_degree

        # seed[n, m, v, e] = 1 iff n == m and v == e.
        node_eye = torch.eye(num_nodes, device=x_0.device, dtype=x_0.dtype)
        variable_eye = torch.eye(variable_dim, device=x_0.device, dtype=x_0.dtype)
        seed = node_eye[:, :, None, None] * variable_eye[None, None, :, :]

        derivatives = torch.zeros(
            batch_size, num_nodes, num_nodes, variable_dim, polynomial_dim, emb_dim,
            device=x_0.device, dtype=x_0.dtype,
        )
        # Broadcast over the batch axis; higher polynomial slots stay zero.
        derivatives[:, :, :, :, 0, :] = seed
        graph.derivatives = derivatives
        return graph
    

    def filter_derivatives(self, old_derivatives: torch.Tensor, new_derivatives: torch.Tensor) -> torch.sparse_coo_tensor:
       # Placeholder: derivative filtering is not implemented yet, so this
       # currently returns None.  The only reference in this class is a
       # commented-out call in gin_layer_update, which uses different keyword
       # names (derivatives_new / derivatives_old) - reconcile before enabling.
       pass
    

    @staticmethod
    def get_node_to_node_derivatives(graph: torch_geometric.data.Data, sort:bool=True) -> torch.Tensor:
        """
        Extracts the node-to-node derivatives from the gradient edge attributes.    
        Returns:
            Tensor of shape (num_nodes, variable_dim, polynomial_dim, embedding_dim) containing
            the derivative information of each node
        """
        derivatives = graph.derivatives
        num_nodes  = derivatives.shape[1]
        node_to_node_derivatives = derivatives[:,torch.arange(num_nodes), torch.arange(num_nodes)]
        #TODO: potentially move making this sparse here

        return node_to_node_derivatives
    
    
    @staticmethod
    def compute_composition_derivatives(out_der, in_der, polynomial_dim):
        """
        Computes the derivatives of the composition of functions  f(g(x)).
        out_der a list of tensors and is the derivatives of the outer function at point g(x). each tensor is of shape (batch, num_nodes, embd_dim) and the list is of length polynomial_dim 
        in_der is the derivatives of the inner function at point x and is of shape (batch, num_nodes,num_nodes variable_dim, polynomial dim,  embd_dim) 
        """
        out_der = [der.unsqueeze(2).unsqueeze(2) for der in out_der] #TODO: check if unsqueeze(2) or unsqueeze(1)
        composition_derivatives = []
        for i in range(polynomial_dim):
            if i == 0:
                composition_derivatives.append(out_der[0] * in_der[:,:,:,:,0]) 
            elif i == 1:
                composition_derivatives.append(out_der[1] * in_der[:,:,:,:,0]**2 +  out_der[0] * in_der[:,:,:,:,1])
            elif i == 2:
                composition_derivatives.append(out_der[2] * in_der[:,:,:,:,0]**3 + 3 * out_der[1] * in_der[:,:,:,:,0] * in_der[:,:,:,:,1] + out_der[0] * in_der[:,:,:,:,2])
            elif i == 3:
                composition_derivatives.append(out_der[3] * in_der[:,:,:,:,0]**4 +6*out_der[2] * in_der[:,:,:,:,0]**2 * in_der[:,:,:,:,1] + 3 * out_der[1] * in_der[:,:,:,:,1]**2  + 4 * out_der[1] * in_der[:,:,:,:,0] * in_der[:,:,:,:,2] + out_der[0] * in_der[:,:,:,:,3])
            else:
                raise ValueError(f"Currently only supporting up to 4th order derivatives, got {i}")
            
            derivatives = torch.stack(composition_derivatives, dim=4) #TODO: check if this is the correct dimension
    
        return derivatives


    def relu_derivative_update(self, x_0_dense: torch.Tensor, derivatives: torch.Tensor):
        """
        Computes the derivatives of the ReLU function. Assumes X is a tensor of shape (b,) where b is the batch size and k is the number of derivatives to compute.
        """
        x_0_dense = torch.nn.functional.relu(x_0_dense)
        # TODO: check if this works, try replacing with sigmoid with high temperture or custom activation function
        batch_size, num_nodes, embed_dim = x_0_dense.shape
        relu_derivatives = (x_0_dense>0).float()
        # TODO: check if this is correct or (batch_size, 1, num_nodes, 1, 1, embed_dim)
        relu_derivatives = relu_derivatives.reshape(batch_size, num_nodes, 1, 1, 1, embed_dim)

        derivatives = derivatives * relu_derivatives
        return x_0_dense, derivatives
    
    
    def silu_derivative_update(self, x_0_dense: torch.Tensor, derivatives: torch.Tensor):
        """
        Computes the derivatives of the ReLU function. Assumes X is a tensor of shape (b,) where b is the batch size and k is the number of derivatives to compute.
        """
        silu_derivatives = self.compute_silu_derivatives(x_0_dense, self.max_degree)
        x_0_dense = torch.nn.functional.silu(x_0_dense)
        derivatives = self.compute_composition_derivatives(in_der=derivatives, out_der=silu_derivatives, polynomial_dim=self.max_degree)
        return x_0_dense, derivatives
    
    @staticmethod
    def compute_silu_derivatives(x, k):
        """
        Computes the derivatives of the SiLU function. Assumes X is a tensor of shape (b,) where b is the batch size and k is the number of derivatives to compute.
        we first compute the sigmoid derivatives where  sigmoid_derivatives[i,j] = d/dx^j sigmoid(x[i])
        We then use the chain rule to compute the silu derivatives:  = d/dx^j silu(x[i]) = x[i] * d/dx^j sigmoid(x[i]) + j * d/dx^(j-1) sigmoid(x[i]) 
        """
        sigmoid_derivatives = [sigmoid(x)]
        silu_derivatives = [x * sigmoid(x)]

        for i in range(1, k+1):
            sigmoid_derivatives.append(DenseGradientExtractor.compute_sigmoid_derivatives(sigmoid_derivatives, i))
            silu_derivatives.append(x * sigmoid_derivatives[-1] + i * sigmoid_derivatives[-2])
            
        # silu_derivatives = torch.stack(silu_derivatives[1:], dim=1)   
        silu_derivatives = silu_derivatives[1:] 
        return silu_derivatives
    
    
    @staticmethod
    def compute_sigmoid_derivatives(X, k):
        if k==1:
            return X[0] * (1-X[0])
        elif k==2:
            return X[1] * (1 - 2*X[0])
        elif k==3:
            return X[2] * (1 - 2*X[0]) - 2*X[1]**2
        elif k==4:
            return X[3] * (1 - 2*X[0]) - 6*X[1]*X[2]
        else:
            raise ValueError("currently only supporting k up to 4")
    














class FirstGNNnetwork(torch.nn.Module):
    """
    Graph network built from a node-feature embedding, a two-layer MLP, a stack of
    (``CustomGINE`` -> ``BatchNorm1d`` -> activation) blocks, a pooling step
    (``global_add_pool``) and an optional MLP prediction head.

    Args:
        num_layers: number of message-passing blocks.
        in_dim: number of rows of the node-feature embedding table.
        emb_dim: width of the node-feature embedding.
        hidden_dim: width used by every layer after the embedding.
        add_residual: if True, every block adds its input to its output.
        track_running_stats: forwarded to each ``BatchNorm1d`` and to ``CustomGINE``.
        num_tasks: output width of the prediction head; ``None`` builds no head, in which
            case the ``get_output*`` methods return the pooled embedding.
        activation: activation *class* (instantiated without arguments);
            defaults to ``torch.nn.SiLU``.
        use_GINE_activation: forwarded to ``CustomGINE``.
        num_edge_emb: forwarded to ``CustomGINE``.
        use_features: if False, node features are replaced by ones and edge features by zeros
            before embedding.
    """

    def __init__(
        self,
        num_layers,
        in_dim,
        emb_dim,
        hidden_dim,
        add_residual=False,
        track_running_stats=True,
        num_tasks: Optional[int] = 1,
        activation=None,
        use_GINE_activation=False,
        num_edge_emb=4,
        use_features=True,
    ):
        super().__init__()
        self.emb_dim = emb_dim
        self.hidden_dim = hidden_dim
        self.activation_class = activation if activation is not None else torch.nn.SiLU
        self.use_features = use_features
        self.num_layers = num_layers

        self.feature_encoder = torch.nn.Embedding(num_embeddings=in_dim, embedding_dim=emb_dim)
        self.first_layers = torch.nn.Sequential(
            torch.nn.Linear(in_features=emb_dim, out_features=hidden_dim),
            self.activation_class(),
            torch.nn.Linear(in_features=hidden_dim, out_features=hidden_dim),
        )

        # One (conv, batch-norm, activation) triple per message-passing block.
        self.gnn_layers = torch.nn.ModuleList()
        self.bn_layers = torch.nn.ModuleList()
        self.act_layers = torch.nn.ModuleList()
        for _ in range(num_layers):
            self.gnn_layers.append(
                CustomGINE(
                    hidden_dim,
                    hidden_dim,
                    track_running_stats=track_running_stats,
                    activation_class=self.activation_class,
                    use_GINE_activation=use_GINE_activation,
                    num_edge_emb=num_edge_emb,
                )
            )
            self.bn_layers.append(
                torch.nn.BatchNorm1d(hidden_dim, track_running_stats=track_running_stats)
            )
            self.act_layers.append(self.activation_class())

        self.add_residual = add_residual
        self.pool = global_add_pool

        # The prediction head is optional; without it the pooled embedding is the output.
        self.final_layers = None
        if num_tasks is not None:
            self.final_layers = torch.nn.Sequential(
                torch.nn.Linear(in_features=hidden_dim, out_features=2 * hidden_dim),
                self.activation_class(),
                torch.nn.Linear(in_features=2 * hidden_dim, out_features=2 * hidden_dim),
                self.activation_class(),
                torch.nn.Linear(in_features=2 * hidden_dim, out_features=num_tasks),
            )

    def forward(self, batched_data):
        """Runs the full pipeline on an object exposing ``x``, ``edge_index``, ``edge_attr`` and ``batch``."""
        x, edge_index, edge_attr, batch = (
            batched_data.x,
            batched_data.edge_index,
            batched_data.edge_attr,
            batched_data.batch,
        )
        x, edge_attr = self.get_initial_embeddings(x=x, edge_attr=edge_attr)
        return self.get_output(x, edge_index, edge_attr, batch)

    def get_output_no_embedding(self, x, edge_index, edge_attr, batch):
        """Like ``get_output`` but skips ``first_layers``: ``x`` goes straight into the blocks."""
        x = self.get_final_node_embeddings_no_embedding(x, edge_index, edge_attr)
        return self._readout(x, batch)

    def get_final_node_embeddings_no_embedding(self, x, edge_index, edge_attr):
        """Applies only the message-passing blocks to ``x`` and returns the node embeddings."""
        return self._run_gnn_stack(x, edge_index, edge_attr)

    def get_output(self, x, edge_index, edge_attr, batch):
        """Applies ``first_layers``, the message-passing blocks, pooling and the head (if any)."""
        x = self.get_final_node_embeddings(x, edge_index, edge_attr)
        return self._readout(x, batch)

    def get_final_node_embeddings(self, x, edge_index, edge_attr):
        """Applies ``first_layers`` followed by the message-passing blocks; returns node embeddings."""
        x = self.first_layers(x)
        return self._run_gnn_stack(x, edge_index, edge_attr)

    def _run_gnn_stack(self, x, edge_index, edge_attr):
        """Feeds ``x`` through every message-passing block in order."""
        for index in range(len(self.gnn_layers)):
            x = self.intermediate_updates(x, edge_index, edge_attr, index, scale=None, shift=None)
        return x

    def _readout(self, x, batch):
        """Pools node embeddings and applies the prediction head when one was built."""
        pooled = self.pool(x, batch)
        if self.final_layers is None:
            # num_tasks=None: there is no head, so the pooled embedding is the result.
            return pooled
        return self.final_layers(pooled)

    def intermediate_updates(self, x, edge_index, edge_attr, index, scale=None, shift=None):
        """
        Applies block ``index``: ``act(bn(gnn(x)))``, an optional residual connection, and
        finally the affine map ``x * scale + shift`` (identity when both are ``None``).
        """
        gnn, bn, act = self.gnn_layers[index], self.bn_layers[index], self.act_layers[index]
        h = act(bn(gnn(x, edge_index, edge_attr)))
        x = h + x if self.add_residual else h
        scale = 1 if scale is None else scale
        shift = 0 if shift is None else shift
        return x * scale + shift

    def get_initial_embeddings(self, x, edge_attr):
        """
        Casts node and edge features to ``long`` and embeds the node features.

        Returns:
            ``(embedded_x, edge_attr)``; the edge features are only cast, not embedded.
        """
        if not self.use_features:
            x = torch.ones_like(x)
            edge_attr = torch.zeros_like(edge_attr)
        x = x.long()
        edge_attr = edge_attr.long()
        # Drop singleton feature axes but always keep the leading node axis: a bare
        # ``squeeze()`` would also remove it for a single-node input and the embedding
        # would come back without its node dimension.
        if x.dim() > 1:
            x = x.reshape([x.shape[0]] + [size for size in x.shape[1:] if size != 1])
        x = self.feature_encoder(x)

        return x, edge_attr


class CustomGINE(torch.nn.Module):
    """
    Wraps an ``AnalyticGINEConv`` (trainable eps, MLP update) together with a
    ``BondEncoder`` that embeds the raw edge attributes before message passing.

    Args:
        in_dim: width of the incoming node features.
        emb_dim: width of the outgoing node features.
        track_running_stats: forwarded to the ``BatchNorm1d`` inside the MLP.
        num_edge_emb: currently unused; kept so existing callers keep working.
        activation_class: activation *class* (instantiated without arguments);
            defaults to ``torch.nn.SiLU``.
        use_GINE_activation: if True the activation is also applied to every message,
            otherwise messages pass through ``torch.nn.Identity``.
    """

    def __init__(self, in_dim, emb_dim, track_running_stats, num_edge_emb=4, activation_class=None, use_GINE_activation=False):
        super().__init__()
        self.activation_class = activation_class if activation_class is not None else torch.nn.SiLU

        # Instantiate the *resolved* class so the SiLU default also applies here;
        # calling the raw argument would fail when activation_class is None.
        if use_GINE_activation:
            message_activation = self.activation_class()
        else:
            message_activation = torch.nn.Identity()

        mlp = torch.nn.Sequential(
            torch.nn.Linear(in_dim, emb_dim),
            torch.nn.BatchNorm1d(emb_dim, track_running_stats=track_running_stats),
            self.activation_class(),
            torch.nn.Linear(emb_dim, emb_dim),
        )
        self.layer = AnalyticGINEConv(nn=mlp, train_eps=True, activation=message_activation)
        # Edge embeddings are added to the incoming node features inside the message
        # function, so their width has to be in_dim (same value whenever in_dim == emb_dim).
        self.edge_embedding = BondEncoder(emb_dim=in_dim)

    def forward(self, x, edge_index, edge_attr):
        """Embeds ``edge_attr`` (1-D input is first given a trailing axis) and runs the convolution."""
        if len(edge_attr.shape) == 1:
            edge_attr = edge_attr.unsqueeze(1)
        return self.layer(x, edge_index, self.edge_embedding(edge_attr))



class AnalyticGINEConv(MessagePassing):
    """
    GINE-style convolution with a configurable message activation.

    Each message is ``activation(x_j + edge_attr)``; the aggregated messages are combined
    with ``(1 + eps) * x`` and passed through ``nn``.

    Args:
        nn: module applied to the combined node representation. Either the module itself or,
            for a ``torch.nn.Sequential``, its first layer must expose ``in_features`` or
            ``in_channels``.
        eps: initial value of eps.
        train_eps: if True eps is a trainable parameter, otherwise a buffer.
        edge_dim: accepted for interface compatibility; not used by this implementation.
        activation: module applied to every message; defaults to ``torch.nn.Identity``.
        **kwargs: forwarded to ``MessagePassing`` (``aggr`` defaults to ``"add"``).

    Raises:
        ValueError: if the input width of ``nn`` cannot be inferred.
    """

    def __init__(
        self,
        nn: torch.nn.Module,
        eps: float = 0.0,
        train_eps: bool = False,
        edge_dim: Optional[int] = None,
        activation: Optional[torch.nn.Module] = None,
        **kwargs,
    ):
        kwargs.setdefault("aggr", "add")
        super().__init__(**kwargs)
        self.nn = nn
        self.initial_eps = eps
        if train_eps:
            self.eps = torch.nn.Parameter(torch.empty(1))
        else:
            self.register_buffer("eps", torch.empty(1))

        # Sanity-check that `nn` advertises an input width. Only a Sequential is indexed;
        # any other module is inspected directly instead of failing on `nn[0]`.
        first_layer = self.nn[0] if isinstance(self.nn, torch.nn.Sequential) else self.nn
        if not (hasattr(first_layer, "in_features") or hasattr(first_layer, "in_channels")):
            raise ValueError("Could not infer input channels from `nn`.")

        self.activation = activation if activation is not None else torch.nn.Identity()
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialises ``nn`` and restores eps to its initial value."""
        reset(self.nn)
        self.eps.data.fill_(self.initial_eps)

    def forward(
        self,
        x: Union[Tensor, OptPairTensor],
        edge_index: Adj,
        edge_attr: OptTensor = None,
        size: Size = None,
    ) -> Tensor:
        """Aggregates the edge-aware messages, adds ``(1 + eps) * x`` and applies ``nn``."""
        if isinstance(x, Tensor):
            x = (x, x)

        # propagate_type: (x: OptPairTensor, edge_attr: OptTensor)
        out = self.propagate(edge_index, x=x, edge_attr=edge_attr, size=size)

        x_r = x[1]
        if x_r is not None:
            out = out + (1 + self.eps) * x_r

        return self.nn(out)

    def message(self, x_j: Tensor, edge_attr: Tensor) -> Tensor:
        """Message for one edge: activation of neighbour features plus edge features."""
        return self.activation(x_j + edge_attr)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(nn={self.nn})"



class EdgeAggregation(MessagePassing):
    """
    Message-passing layer whose messages are the edge attributes themselves, so the
    output is the per-node aggregate (sum by default) of edge embedding vectors.
    """

    def __init__(
        self,
        **kwargs,
    ):
        # Sum aggregation unless the caller asks for something else.
        kwargs.setdefault("aggr", "add")
        super().__init__(**kwargs)

    def forward(
        self,
        x: Union[Tensor, OptPairTensor],
        edge_index: Adj,
        edge_attr: OptTensor = None,
        size: Size = None,
    ) -> Tensor:
        """
        Aggregates ``edge_attr`` per node.

        When ``edge_attr`` has more than two axes, axis 0 and axis -2 are swapped before
        propagation and swapped back on the result.
        """
        if isinstance(x, Tensor):
            x = (x, x)

        # The swap preserves the rank, so one check covers both directions.
        needs_swap = len(edge_attr.shape) > 2
        if needs_swap:
            edge_attr = edge_attr.swapaxes(0, -2)

        aggregated = self.propagate(edge_index, x=x, edge_attr=edge_attr, size=size)

        if needs_swap:
            aggregated = aggregated.swapaxes(0, -2)
        return aggregated

    def message(self, x_j: Tensor, edge_attr: Tensor) -> Tensor:
        # Node features are ignored; only the edge attributes are aggregated.
        return edge_attr






    