import os
import sys
from typing import Tuple, Union, Optional, Callable, List, Dict
import torch
from torch_sparse import SparseTensor
import torch_geometric
import torch.nn.functional as F
from copy import deepcopy
from torch import sigmoid, Tensor
from torch_scatter import scatter
from torch_geometric.utils import to_dense_batch, to_dense_adj
from torch_geometric.nn import global_add_pool, GINConv, GINEConv
from ogb.graphproppred.mol_encoder import BondEncoder
from torch_geometric.nn.aggr import SumAggregation
from torch.nn import Linear
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.inits import reset
from torch_geometric.typing import Adj, OptPairTensor, OptTensor, Size
from torch_geometric.graphgym.register import register_node_encoder
import numpy as np
import math
import torch.nn as nn
from torch_geometric.graphgym.config import cfg

@register_node_encoder('DerivativePreprocessingEncoder')
class DerivativePreprocessingEncoder(nn.Module):
    """Node encoder that appends two learned embeddings to ``batch.x``.

    The encoder reads two attributes that are expected to already be
    attached to the batch, ``batch.x_0`` and
    ``batch.x_intermediate_node_to_node_derivatives``, passes each through
    its own BatchNorm/Linear MLP and concatenates both results onto
    ``batch.x`` along dimension 1.

    Hyper-parameters are taken from the global GraphGym ``cfg``
    (``cfg.derivative_preprocessing``, ``cfg.derivative_encoder`` and
    ``cfg.gnn.dim_inner``).

    Args:
        x_0_embedding_dim: Requested output width of the ``x_0`` embedding.
            It is overridden when the widths do not add up to
            ``cfg.gnn.dim_inner`` (see ``__init__``).
        derivate_embedding_dim: Output width of the derivative embedding.

    Raises:
        ValueError: If the recomputed ``x_0`` width is negative, or if the
            configured activation name is not supported.
    """

    def __init__(self, x_0_embedding_dim, derivate_embedding_dim):
        super().__init__()

        # Settings of the preprocessing GNN that produced the stored features.
        prep_cfg = cfg.derivative_preprocessing
        num_layers = prep_cfg.num_layers
        emb_dim = prep_cfg.emb_dim
        hidden_dim = prep_cfg.hidden_dim
        derivative_hidden_dim = prep_cfg.derivative_hidden_dim
        track_running_stats = prep_cfg.track_running_stats
        activation_name = prep_cfg.activation
        # Setting of the gradient extractor.
        max_degree = prep_cfg.max_degree

        self.num_layers = num_layers
        if torch.cuda.is_available():
            self.device = "cuda"
        else:
            self.device = "cpu"
        activation_cls = self.get_activation(activation_name)

        # The three concatenated parts must add up to the GNN's inner width;
        # otherwise the x_0 embedding absorbs the difference.
        first_dim = cfg.derivative_encoder.first_embedding_dim
        inner_dim = cfg.gnn.dim_inner
        requested_total = x_0_embedding_dim + derivate_embedding_dim + first_dim
        if requested_total != inner_dim:
            x_0_embedding_dim = inner_dim - derivate_embedding_dim - first_dim
            print("x_0_embedding_dim+derivate_embedding_dim + cfg.derivative_encoder.first_embedding_dim != cfg.gnn.dim_inner")
            print(f"new x_0_embedding_dim: {x_0_embedding_dim}")
            if x_0_embedding_dim < 0:
                raise ValueError("x_0_embedding_dim is less than 0")

        self.init_embedding_layers(
            hidden_dim=hidden_dim,
            x_0_embedding_dim=x_0_embedding_dim,
            derivate_embedding_dim=derivate_embedding_dim,
            activation=activation_cls,
            max_degree=max_degree,
            emb_dim=emb_dim,
            num_layers=num_layers,
            derivative_hidden_dim=derivative_hidden_dim,
            track_running_stats=track_running_stats,
        )

    def forward(self, batch):
        """Embed the stored features and append them to ``batch.x``.

        Besides extending ``batch.x``, the embedded derivatives are also
        stored on ``batch.x_node_to_node_derivatives``.

        Args:
            batch: Object exposing ``x``, ``x_0`` and
                ``x_intermediate_node_to_node_derivatives``.

        Returns:
            The same ``batch`` object, modified in place.
        """
        flat_derivatives, x_0_raw = self.get_processed_data(batch)

        # Embedding of x_0.
        x_0_emb = self.first_to_second_embed(x_0_raw)

        # Embedding of the flattened derivatives.
        derivative_emb = self.derivative_feature_embed(flat_derivatives)

        batch.x = torch.cat([batch.x, x_0_emb, derivative_emb], dim=1)
        batch.x_node_to_node_derivatives = derivative_emb
        return batch

    def get_processed_data(self, batch):
        """Return the flattened derivatives and ``x_0`` read from ``batch``.

        The derivative tensor keeps its first dimension and has all remaining
        dimensions collapsed into one; ``x_0`` is returned unchanged.

        Returns:
            Tuple ``(flattened_derivatives, x_0)``.
        """
        derivatives = batch.x_intermediate_node_to_node_derivatives
        leading = derivatives.size(0)
        return derivatives.reshape(leading, -1), batch.x_0

    # NOTE: disabled normalisation step, kept for reference.
    # def compute_derivatives(self, data):
    #     #TODO: add output derivatives
    #     # data = self.gradient_extractor.compute_node_to_node_derivatives(data)
    #     num_layers_first = self.num_layers
    #     centrality_normalization = torch.tensor([math.factorial(i) for i in range(1, num_layers_first+1)],
    #                                              dtype=torch.float32, device=self.device).reshape(1,1,1,1, num_layers_first)
    #     data.x_intermediate_node_to_node_derivatives = data.x_intermediate_node_to_node_derivatives / centrality_normalization
    #     return data

    def init_embedding_layers(self, hidden_dim,
                               x_0_embedding_dim,
                               derivate_embedding_dim,
                               activation,
                               max_degree,
                               emb_dim,
                               num_layers,
                               derivative_hidden_dim=4,
                               track_running_stats:bool = True):
        """Create the two embedding MLPs used by ``forward``.

        Both MLPs share the layout
        ``BatchNorm1d -> Linear -> BatchNorm1d -> activation -> Linear``.

        Args:
            hidden_dim: Input and hidden width of the ``x_0`` MLP.
            x_0_embedding_dim: Output width of the ``x_0`` MLP.
            derivate_embedding_dim: Output width of the derivative MLP.
            activation: Zero-argument callable returning an activation module.
            max_degree: Factor of the flattened derivative width.
            emb_dim: Factor of the flattened derivative width.
            num_layers: Factor of the flattened derivative width.
            derivative_hidden_dim: Hidden width of the derivative MLP.
            track_running_stats: Forwarded to every ``BatchNorm1d``.
        """

        def build_mlp(in_features, mid_features, out_features):
            # Modules are created in the listed order so that parameter
            # initialisation consumes the RNG in a fixed sequence.
            stages = [
                nn.BatchNorm1d(in_features, track_running_stats=track_running_stats),
                nn.Linear(in_features, mid_features),
                nn.BatchNorm1d(mid_features, track_running_stats=track_running_stats),
                activation(),
                nn.Linear(mid_features, out_features),
            ]
            return nn.Sequential(*stages)

        self.first_to_second_embed = build_mlp(hidden_dim, hidden_dim, x_0_embedding_dim)

        # Width each row of the flattened derivative tensor must have.
        derivative_dim = emb_dim * max_degree * hidden_dim * num_layers
        self.derivative_feature_embed = build_mlp(derivative_dim,
                                                  derivative_hidden_dim,
                                                  derivate_embedding_dim)

    def get_activation(self, activation:str):
        """Map an activation name to its ``torch.nn`` module class.

        Args:
            activation: Either ``"relu"`` or ``"silu"``.

        Returns:
            The module class (not an instance).

        Raises:
            ValueError: For any other name.
        """
        known = (("relu", nn.ReLU), ("silu", nn.SiLU))
        for name, module_cls in known:
            if activation == name:
                return module_cls
        raise ValueError(f"Unsupported activation: {activation}")