import os
import sys
from typing import Tuple, Union, Optional, Callable, List, Dict
import torch
from torch_sparse import SparseTensor
import torch_geometric
import torch.nn.functional as F
from copy import deepcopy
from torch import sigmoid, Tensor
from torch_scatter import scatter
from torch_geometric.utils import to_dense_batch, to_dense_adj
from torch_geometric.nn import global_add_pool, GINConv
from ogb.graphproppred.mol_encoder import BondEncoder
from torch_geometric.nn.aggr import SumAggregation
from torch.nn import Linear
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.inits import reset
from torch_geometric.typing import Adj, OptPairTensor, OptTensor, Size
from torch_geometric.graphgym.register import register_node_encoder
import numpy as np
import math
import torch.nn as nn
from torch_geometric.graphgym.config import cfg
from torch_geometric.utils import spmm
from torch_geometric.graphgym.models.encoder import AtomEncoder



@register_node_encoder('EfficientDerivativeEncoder')
class EfficientDerivativeEncoder(nn.Module):
    """Node encoder that appends derivative-based features to ``batch.x``.

    A :class:`DerivativeGNNnetwork` is run on ``batch.x_0`` /
    ``batch.x_derivatives``. Its final node embedding and its per-layer
    node-to-node derivatives are each passed through a small MLP and
    concatenated (in that order) after the existing columns of ``batch.x``.

    Apart from the two embedding widths passed to the constructor, every
    hyper-parameter is read from ``cfg.efficient_derivative_encoder``.
    """

    def __init__(self, x_0_embedding_dim, derivate_embedding_dim):
        """Read the config and build the embedding MLPs and the derivative GNN.

        Args:
            x_0_embedding_dim: requested output width of the embedded ``x_0``.
                If ``x_0_embedding_dim + derivate_embedding_dim +
                cfg.efficient_derivative_encoder.first_embedding_dim`` differs
                from ``cfg.gnn.dim_inner`` it is recomputed so that the sum
                matches.
            derivate_embedding_dim: output width of the embedded derivatives.

        Raises:
            ValueError: if the configured activation name is unsupported, or
                if the recomputed ``x_0_embedding_dim`` is negative.
        """
        super(EfficientDerivativeEncoder, self).__init__()

        enc_cfg = cfg.efficient_derivative_encoder
        num_layers = enc_cfg.num_layers
        in_dim = enc_cfg.in_dim
        emb_dim = enc_cfg.emb_dim
        hidden_dim = enc_cfg.hidden_dim
        derivative_hidden_dim = enc_cfg.derivative_hidden_dim
        add_residual = enc_cfg.add_residual
        track_running_stats = enc_cfg.track_running_stats
        activation = enc_cfg.activation
        num_edge_emb = enc_cfg.num_edge_emb
        centrality_init = enc_cfg.centrality_init
        max_degree = enc_cfg.max_degree
        max_variables_per_node = enc_cfg.max_variables_per_node
        sparse = enc_cfg.sparse
        dropout = enc_cfg.encoder_dropout
        derivative_batchnorm = enc_cfg.derivative_batchnorm

        # The first '+'-separated part of the node encoder name selects how the
        # derivative GNN encodes x_0 ("TypeDictNode" or "Atom").
        self.encoder_type = cfg.dataset.node_encoder_name.split('+')[0]
        self.num_layers = num_layers
        # Kept as a public attribute for compatibility; tensors created inside
        # this module use the device of the incoming data instead.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        activation = self.get_activation(activation)

        # The three column groups of batch.x must add up to cfg.gnn.dim_inner;
        # x_0_embedding_dim absorbs any mismatch.
        first_embedding_dim = enc_cfg.first_embedding_dim
        if x_0_embedding_dim + derivate_embedding_dim + first_embedding_dim != cfg.gnn.dim_inner:
            x_0_embedding_dim = cfg.gnn.dim_inner - derivate_embedding_dim - first_embedding_dim
            print("x_0_embedding_dim+derivate_embedding_dim + cfg.efficient_derivative_encoder.first_embedding_dim != cfg.gnn.dim_inner")
            print(f"new x_0_embedding_dim: {x_0_embedding_dim}")
            if x_0_embedding_dim < 0:
                raise ValueError("x_0_embedding_dim is less than 0")

        self.init_embedding_layers(hidden_dim=hidden_dim,
                                   x_0_embedding_dim=x_0_embedding_dim,
                                   derivate_embedding_dim=derivate_embedding_dim,
                                   activation=activation,
                                   max_degree=max_degree,
                                   emb_dim=emb_dim,
                                   num_layers=num_layers,
                                   derivative_hidden_dim=derivative_hidden_dim,
                                   track_running_stats=track_running_stats,
                                   dropout=dropout)

        self.init_derivative_gnn(num_layers=num_layers,
                                 in_dim=in_dim,
                                 emb_dim=emb_dim,
                                 hidden_dim=hidden_dim,
                                 add_residual=add_residual,
                                 track_running_stats=track_running_stats,
                                 activation=activation,
                                 num_edge_emb=num_edge_emb,
                                 centrality_init=centrality_init,
                                 max_degree=max_degree,
                                 max_variables_per_node=max_variables_per_node,
                                 sparse=sparse,
                                 derivative_batchnorm=derivative_batchnorm,
                                 encoder_type=self.encoder_type)

    def forward(self, batch):
        """Append the embedded ``x_0`` and derivative features to ``batch.x``.

        Args:
            batch: a PyG batch carrying ``x``, ``x_0``, ``x_derivatives``,
                ``x_derivative_mask``, ``edge_index`` and ``edge_attr``.

        Returns:
            The same batch, with ``batch.x`` widened by
            ``x_0_embedding_dim + derivate_embedding_dim`` columns and
            ``batch.edge_attr`` squeezed on its trailing axis.
        """
        # Run the derivative GNN (also rewrites batch.x_0 / batch.edge_attr).
        batch = self.compute_derivatives(batch)
        flat_derivatives, x_0 = self.get_processed_data(batch)

        x_0 = self.first_to_second_embed(x_0)
        derivative_features = self.derivative_feature_embed(flat_derivatives)

        batch.x = torch.cat((batch.x, x_0, derivative_features), 1)
        # The "TypeDictNode" path leaves edge_attr as (E, 1). Squeeze only that
        # trailing axis so a batch with a single edge keeps its edge dimension
        # (a bare .squeeze() would collapse (1, 1) to a 0-d tensor).
        batch.edge_attr = batch.edge_attr.squeeze(-1)  # TODO: make cleaner
        return batch

    def get_processed_data(self, batch):
        """Return ``(flattened derivatives, x_0)`` taken from ``batch``.

        Every trailing axis of ``batch.x_intermediate_node_to_node_derivatives``
        is flattened, giving one row per node.
        """
        # TODO: check if this is correct
        derivatives, x_0 = batch.x_intermediate_node_to_node_derivatives, batch.x_0
        num_rows = derivatives.shape[0]
        return derivatives.reshape(num_rows, -1), x_0

    def compute_derivatives(self, data):
        """Run the derivative GNN, then divide layer ``k`` (1-based) by ``k!``.

        The GNN stacks its per-layer derivatives on the last axis, so a 1-D
        tensor of factorials broadcasts against that axis for any leading rank.
        """
        # TODO: add output derivatives
        data = self.derivative_gnn(data)
        derivatives = data.x_intermediate_node_to_node_derivatives
        # Create the normaliser on the derivatives' own device. self.device is
        # "cuda" whenever CUDA is available, even when the model runs on CPU.
        factorials = torch.tensor(
            [math.factorial(k) for k in range(1, self.num_layers + 1)],
            dtype=torch.float32,
            device=derivatives.device,
        )
        data.x_intermediate_node_to_node_derivatives = derivatives / factorials
        return data

    def init_embedding_layers(self, hidden_dim,
                              x_0_embedding_dim,
                              derivate_embedding_dim,
                              activation,
                              max_degree,
                              emb_dim,
                              num_layers,
                              derivative_hidden_dim=4,
                              track_running_stats: bool = True,
                              dropout: float = 0.0):
        """Build the two output MLPs.

        ``first_to_second_embed`` maps the GNN's final ``hidden_dim`` node
        embedding to ``x_0_embedding_dim``. ``derivative_feature_embed`` maps
        the flattened derivatives to ``derivate_embedding_dim``. Its input
        width ``emb_dim * max_degree * hidden_dim * num_layers`` is assumed to
        equal the flattened per-node derivative size produced by
        ``get_processed_data`` — NOTE(review): confirm against the dataset's
        ``x_derivatives`` shape.
        """
        self.first_to_second_embed = torch.nn.Sequential(
            torch.nn.BatchNorm1d(hidden_dim, track_running_stats=track_running_stats),
            torch.nn.Linear(hidden_dim, hidden_dim),
            torch.nn.BatchNorm1d(hidden_dim, track_running_stats=track_running_stats),
            activation(),
            torch.nn.Dropout(dropout),
            torch.nn.Linear(hidden_dim, x_0_embedding_dim))

        derivative_dim = emb_dim * max_degree * hidden_dim * num_layers
        self.derivative_feature_embed = torch.nn.Sequential(
            torch.nn.BatchNorm1d(derivative_dim, track_running_stats=track_running_stats),
            torch.nn.Linear(derivative_dim, derivative_hidden_dim),
            torch.nn.BatchNorm1d(derivative_hidden_dim, track_running_stats=track_running_stats),
            activation(),
            torch.nn.Dropout(dropout),
            torch.nn.Linear(derivative_hidden_dim, derivate_embedding_dim))

    def init_derivative_gnn(self,
                            num_layers,
                            in_dim,
                            emb_dim,
                            hidden_dim,
                            add_residual,
                            track_running_stats,
                            activation,
                            num_edge_emb,
                            centrality_init,
                            max_degree,
                            max_variables_per_node,
                            sparse,
                            derivative_batchnorm,
                            encoder_type):
        """Create ``self.derivative_gnn`` with the given hyper-parameters."""
        # TODO: make right
        self.derivative_gnn = DerivativeGNNnetwork(
            in_dim=in_dim,
            emb_dim=emb_dim,
            hidden_dim=hidden_dim,
            num_layers=num_layers,
            add_residual=add_residual,
            track_running_stats=track_running_stats,
            activation=activation,
            num_edge_emb=num_edge_emb,
            centrality_init=centrality_init,
            max_degree=max_degree,
            max_variables_per_node=max_variables_per_node,
            sparse=sparse,
            derivative_batchnorm=derivative_batchnorm,
            encoder_type=encoder_type)

    def get_activation(self, activation: str):
        """Map an activation name to its ``torch.nn`` class (not an instance).

        Raises:
            ValueError: for any name other than ``"relu"`` or ``"silu"``.
        """
        if activation == "relu":
            return torch.nn.ReLU
        elif activation == "silu":
            return torch.nn.SiLU
        else:
            raise ValueError(f"Unsupported activation: {activation}")

        
    
        
    
        
class DerivativeGNNnetwork(torch.nn.Module):
    """Stack of :class:`DerivativeGNNlayer` that carries node-to-node derivatives.

    ``forward`` encodes ``data.x_0`` and pushes ``data.x_0`` and
    ``data.x_derivatives`` through ``first_mlp`` and then every layer. After
    each layer it keeps, for every node, the derivative slice selected by
    ``data.x_derivative_mask``. The per-layer slices are stacked on the last
    axis into ``data.x_intermediate_node_to_node_derivatives``.

    ``centrality_init``, ``max_degree``, ``max_variables_per_node``, ``sparse``
    and ``derivative_batchnorm`` are stored on the instance but not read by any
    method of this class. NOTE(review): ``init_with_centrality`` runs
    unconditionally, regardless of ``centrality_init`` — confirm intended.
    """

    def __init__(self,
                 in_dim,
                 emb_dim,
                 hidden_dim,
                 num_layers,
                 add_residual,
                 track_running_stats,
                 activation,
                 num_edge_emb,
                 centrality_init,
                 max_degree,
                 max_variables_per_node = None,
                 sparse = False,
                 derivative_batchnorm = False,
                 encoder_type = "TypeDictNode"):
        """Store hyper-parameters and build/initialise the sub-modules.

        Args:
            in_dim: vocabulary size of the ``"TypeDictNode"`` embedding.
            emb_dim: width of the x_0 encoder output.
            hidden_dim: width used by ``first_mlp`` and every GNN layer.
            num_layers: number of :class:`DerivativeGNNlayer` modules.
            activation: activation *class* (e.g. ``torch.nn.ReLU``).
            encoder_type: ``"TypeDictNode"`` or ``"Atom"``.

        Raises:
            ValueError: for an unsupported ``encoder_type``.
        """
        super(DerivativeGNNnetwork, self).__init__()
        # Kept for compatibility; forward() uses the data's device instead.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.in_dim = in_dim
        self.emb_dim = emb_dim
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.add_residual = add_residual
        self.track_running_stats = track_running_stats
        self.activation = activation
        self.num_edge_emb = num_edge_emb
        self.centrality_init = centrality_init
        self.max_degree = max_degree
        self.max_variables_per_node = max_variables_per_node
        self.sparse = sparse
        self.derivative_batchnorm = derivative_batchnorm
        self.encoder_type = encoder_type

        self.init_derivative_gnn()

    def forward(self, data):
        """Run all layers and collect each node's own derivative slice per layer.

        Returns:
            ``data`` with ``x_0``/``x_derivatives`` updated and
            ``x_intermediate_node_to_node_derivatives`` set to the per-layer
            slices stacked on a new last axis of size ``num_layers``.
        """
        data = self.get_initial_embeddings(data)
        intermediate_node_to_node_derivatives = []
        for layer in self.layers:
            data = layer(data)
            derivs = data.x_derivatives
            # Row index must live on the same device as the indexed tensor;
            # self.device would be "cuda" even for a CPU-resident run.
            rows = torch.arange(derivs.shape[0], device=derivs.device)
            # For row i, pick column x_derivative_mask[i] (presumably node i's
            # own position within its graph — confirm against the dataset).
            intermediate_node_to_node_derivatives.append(derivs[rows, data.x_derivative_mask.squeeze()])

        data.x_intermediate_node_to_node_derivatives = torch.stack(intermediate_node_to_node_derivatives, dim=-1)
        return data

    def get_initial_embeddings(self, data):
        """Encode integer ``data.x_0`` and apply ``first_mlp`` to x_0 and its derivatives.

        ``"TypeDictNode"``: single embedding lookup; ``edge_attr`` becomes a
        ``(E, 1)`` long tensor. ``"Atom"``: mean of the per-column atom
        embeddings; ``edge_attr`` is cast to long.

        Raises:
            ValueError: for an unsupported ``encoder_type``.
        """
        data.x_0 = data.x_0.long().squeeze()
        if self.encoder_type == "TypeDictNode":
            data.x_0 = self.encoder(data.x_0)
            data.edge_attr = data.edge_attr.long().reshape(-1, 1)

        elif self.encoder_type == "Atom":
            encoded_features = 0
            for i in range(data.x_0.shape[1]):
                encoded_features += self.encoder.atom_embedding_list[i](data.x_0[:, i])
            data.x_0 = encoded_features / data.x_0.shape[1]
            data.edge_attr = data.edge_attr.long()
        else:
            raise ValueError(f"Unsupported encoder type: {self.encoder_type}")
        data.x_0, data.x_derivatives = apply_mlp(data.x_0, data.x_derivatives, self.first_mlp)
        return data

    def init_derivative_gnn(self):
        """Create the x_0 encoder, ``first_mlp`` and the layer stack, then initialise them.

        Raises:
            ValueError: for an unsupported ``encoder_type``.
        """
        if self.encoder_type == "TypeDictNode":
            self.encoder = torch.nn.Embedding(num_embeddings=self.in_dim, embedding_dim=self.emb_dim)
        elif self.encoder_type == "Atom":
            self.encoder = AtomEncoder(emb_dim=self.emb_dim)
        else:
            raise ValueError(f"Unsupported encoder type: {self.encoder_type}")

        self.first_mlp = torch.nn.Sequential(
                torch.nn.Linear(in_features=self.emb_dim, out_features=self.hidden_dim),
                self.activation(),
                torch.nn.Linear(in_features=self.hidden_dim, out_features=self.hidden_dim)
                )

        layers = []

        # all intermediate layers end in activation
        for _ in range(self.num_layers-1):
            layers.append(DerivativeGNNlayer(in_dim=self.hidden_dim,
                                             out_dim=self.hidden_dim,
                                             add_residual=self.add_residual,
                                             track_running_stats=self.track_running_stats,
                                             activation=self.activation,
                                             num_edge_emb=self.num_edge_emb,
                                             end_linear=False))

        # last layer end in linear
        layers.append(DerivativeGNNlayer(in_dim=self.hidden_dim,
                                         out_dim=self.hidden_dim,
                                         add_residual=self.add_residual,
                                         track_running_stats=self.track_running_stats,
                                         activation=self.activation,
                                         num_edge_emb=self.num_edge_emb,
                                         end_linear=True))

        self.layers = torch.nn.ModuleList(layers)
        self.init_with_centrality()

    def init_with_centrality(self):
        """Overwrite weights so each layer starts as a plain neighbour sum.

        This function does the following:

        - Sets node embeddings to all ones.
        - Makes ``first_mlp`` stacked identities followed by an identity.
        - Sets ``eps = -1``, which zeroes the self term ``(1 + eps) * x_i``.
        - Sets edge embeddings to zero.
        - Makes MLP ``Linear`` layers identities with zero bias.
        - Makes BatchNorm affine parameters and running stats the identity.

        NOTE(review): presumably intended to start from walk-count /
        centrality-like features — confirm.
        """
        with torch.no_grad():

            # init embedding layer to 1
            if self.encoder_type == "TypeDictNode":
                self.encoder.weight.data = torch.ones(self.encoder.weight.data.shape)
            elif self.encoder_type == "Atom":
                for enc in self.encoder.atom_embedding_list:
                    enc.weight.data = torch.ones(enc.weight.data.shape)

            # init first layer to identity concatenated (rows beyond
            # repeat_factor * in_dim stay zero when out_dim % in_dim != 0)
            in_dim = self.first_mlp[0].weight.data.shape[1]
            out_dim = self.first_mlp[0].weight.data.shape[0]
            repeat_factor = out_dim // in_dim
            weight = torch.zeros(out_dim, in_dim)
            for i in range(repeat_factor):
                weight[i*in_dim:(i+1)*in_dim, :] = torch.eye(in_dim)
            self.first_mlp[0].weight.data = weight

            # init last layer to identity and biases to zero
            self.first_mlp[0].bias.data = torch.zeros(out_dim)
            self.first_mlp[-1].weight.data = torch.eye(out_dim)
            self.first_mlp[-1].bias.data = torch.zeros(out_dim)

            hidden_dim = self.hidden_dim
            for i, layer in enumerate(self.layers):
                # eps = -1 removes the self term; x_derivatives_conv shares
                # this Parameter, so it is updated as well.
                layer.x_0_conv.eps.data = torch.tensor([-1.])

                # edge embedding to zero
                for embedding in layer.edge_embedding.bond_embedding_list:
                    embedding.weight.data = torch.zeros(embedding.weight.data.shape)

                # first linear
                layer.mlp[0].weight.data = torch.eye(hidden_dim)  # Identity matrix
                layer.mlp[0].bias.data = torch.zeros(hidden_dim)  # Zero bias

                # first batchnorm
                layer.mlp[1].weight.data = torch.ones(hidden_dim)
                layer.mlp[1].bias.data = torch.zeros(hidden_dim)   # Zero bias
                layer.mlp[1].running_mean = torch.zeros(hidden_dim)  # Zero mean
                layer.mlp[1].running_var = torch.ones(hidden_dim)    # Unit variance
                # gnn.layer.nn[1].eps = 0. #TODO: check if this matters

                # second linear
                layer.mlp[3].weight.data = torch.eye(hidden_dim)  # Identity matrix
                layer.mlp[3].bias.data = torch.zeros(hidden_dim)  # Zero bias

                # second batchnorm only exists on non-final layers (end_linear=False)
                if i < len(self.layers) - 1:
                    layer.mlp[4].weight.data = torch.ones(hidden_dim)
                    layer.mlp[4].bias.data = torch.zeros(hidden_dim)   # Zero bias
                    layer.mlp[4].running_mean = torch.zeros(hidden_dim)  # Zero mean
                    layer.mlp[4].running_var = torch.ones(hidden_dim)    # Unit variance

    
        




class DerivativeGNNlayer(torch.nn.Module):
    """One message-passing step applied jointly to node embeddings and their derivatives.

    ``x_0`` is propagated with :class:`IdentityGINEConv`, whose neighbour
    messages include an edge embedding. ``x_derivatives`` is propagated with
    :class:`IdentityGINConv`, which shares the same ``eps`` parameter. Both
    results then go through ``self.mlp`` via :func:`apply_mlp`.

    Args:
        in_dim: width of the incoming node embedding (also the edge-embedding width).
        out_dim: width produced by the second ``Linear`` of the MLP.
        add_residual: stored on the instance; not read by this layer.
        track_running_stats: forwarded to every ``BatchNorm1d`` in the MLP.
        activation: activation *class* (e.g. ``torch.nn.ReLU``) instantiated in the MLP.
        num_edge_emb: accepted for interface compatibility; unused.
        end_linear: if True the MLP stops after its second ``Linear``; otherwise
            a ``BatchNorm1d`` and another activation follow.
    """

    def __init__(self,
                 in_dim,
                 out_dim,
                 add_residual,
                 track_running_stats,
                 activation,
                 num_edge_emb,
                 end_linear: bool = False):
        super().__init__()
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.add_residual = add_residual
        self.track_running_stats = track_running_stats
        self.activation = activation
        self.end_linear = end_linear

        # Sub-modules are created in a fixed order so seeded initialisation
        # and state_dict key order stay the same.
        self.x_0_conv = IdentityGINEConv(train_eps=True)
        # The derivative conv reuses the very same eps Parameter object.
        self.x_derivatives_conv = IdentityGINConv(eps=self.x_0_conv.eps)
        self.get_mlp()
        self.edge_embedding = BondEncoder(emb_dim=self.in_dim)

    def forward(self, data):
        """Advance ``data.x_0`` and ``data.x_derivatives`` by one layer and return ``data``."""
        edge_features = self.edge_embedding(data.edge_attr)
        propagated_x0 = self.x_0_conv(data.x_0, data.edge_index, edge_features)
        propagated_derivs = self.x_derivatives_conv(data.x_derivatives, data.edge_index)
        data.x_0, data.x_derivatives = apply_mlp(propagated_x0, propagated_derivs, self.mlp)
        return data

    def get_mlp(self):
        """Build ``self.mlp``: Linear -> BatchNorm1d -> act -> Linear [-> BatchNorm1d -> act]."""
        width_in, width_out = self.in_dim, self.out_dim
        modules = [
            torch.nn.Linear(width_in, width_in),
            torch.nn.BatchNorm1d(width_in, track_running_stats=self.track_running_stats),
            self.activation(),
            torch.nn.Linear(width_in, width_out),
        ]
        if not self.end_linear:
            modules.extend([
                torch.nn.BatchNorm1d(width_out, track_running_stats=self.track_running_stats),
                self.activation(),
            ])
        self.mlp = torch.nn.Sequential(*modules)







def apply_mlp(x_0, x_derivatives, mlp):
    """Run ``mlp`` on ``x_0`` while propagating ``x_derivatives`` through each layer.

    Every module of ``mlp`` is routed to the matching update rule:

    - ``Linear``: ``linear_update``.
    - ``BatchNorm1d``: ``batchnorm_update``.
    - ``SiLU`` / ``ReLU``: ``activation_update``.
    - ``Identity``: skipped.

    Args:
        x_0: node embeddings.
        x_derivatives: derivatives of ``x_0``, carried alongside it.
        mlp: an iterable of modules (typically ``torch.nn.Sequential``).

    Returns:
        The updated ``(x_0, x_derivatives)`` pair.

    Raises:
        ValueError: if ``mlp`` contains any other module type.
    """
    for module in mlp:
        if isinstance(module, torch.nn.Identity):
            continue
        if isinstance(module, torch.nn.Linear):
            handler = linear_update
        elif isinstance(module, torch.nn.BatchNorm1d):
            handler = batchnorm_update
        elif isinstance(module, (torch.nn.SiLU, torch.nn.ReLU)):
            handler = activation_update
        else:
            raise ValueError(f"Unsupported layer type: {type(module)}")
        x_0, x_derivatives = handler(x_0, x_derivatives, module)
    return x_0, x_derivatives
    
def linear_update(x_0, x_derivatives, layer):
    """Apply a ``torch.nn.Linear`` to ``x_0`` and its linear part to the derivatives.

    The bias is constant, so it drops out of every derivative; only the
    weight matrix multiplies the last axis of ``x_derivatives``.

    Returns:
        ``(layer(x_0), x_derivatives @ layer.weight.T)``.
    """
    transformed = layer(x_0)
    propagated = x_derivatives @ layer.weight.T
    return transformed, propagated

def batchnorm_update(x_0, x_derivatives, layer):
    """Stub for a ``BatchNorm1d`` slot: returns both inputs unchanged.

    Neither the normalisation nor its derivative rule is applied, so the
    parameters and running statistics of ``layer`` have no effect on the
    output.
    """
    # TODO: apply the batch norm to x_0 and propagate its scaling to x_derivatives.
    unchanged = (x_0, x_derivatives)
    return unchanged

def activation_update(x_0, x_derivatives, layer):
    """Dispatch an activation module to its derivative-aware update rule.

    ``SiLU`` goes to ``silu_derivative_update`` and ``ReLU`` to
    ``relu_derivative_update``. Any other module prints a notice and returns
    both inputs unchanged.
    """
    if isinstance(layer, torch.nn.SiLU):
        return silu_derivative_update(x_0, x_derivatives)
    if isinstance(layer, torch.nn.ReLU):
        return relu_derivative_update(x_0, x_derivatives)
    print(f" got activation function: {type(layer)}, will not compute derivatives")
    return x_0, x_derivatives

def silu_derivative_update(x_0, x_derivatives):
    """Placeholder for the SiLU activation and its derivative rule.

    Not implemented yet: both tensors are returned unchanged. SiLU is therefore
    *not* applied to ``x_0`` and behaves as the identity. A ``RuntimeWarning``
    is emitted so that a configuration using ``activation="silu"`` does not
    silently run a purely linear network.
    """
    import warnings

    # TODO: apply SiLU to x_0 and propagate its derivative to x_derivatives.
    warnings.warn(
        "silu_derivative_update is not implemented: SiLU is skipped for both "
        "x_0 and x_derivatives (treated as identity).",
        RuntimeWarning,
        stacklevel=2,
    )
    return x_0, x_derivatives

def relu_derivative_update(x_0, x_derivatives):
    """Apply ReLU to ``x_0`` and the chain rule to ``x_derivatives``.

    Each derivative is multiplied by ReLU'(x_0), which is 1 where the
    activation is positive and 0 elsewhere, broadcast over every axis between
    the node axis and the embedding axis.

    Args:
        x_0: ``(num_nodes, embed_dim)`` pre-activation node embeddings.
        x_derivatives: tensor whose first axis is ``num_nodes`` and last axis is
            ``embed_dim``; any number (>= 0) of axes may sit in between.

    Returns:
        ``(relu(x_0), gated x_derivatives)``. The derivatives keep their dtype.
    """
    x_0 = torch.nn.functional.relu(x_0)
    # Match the derivatives' dtype so half/double inputs are not silently promoted.
    gate = (x_0 > 0).to(x_derivatives.dtype)
    num_nodes_in_batch, embed_dim = x_0.shape
    # (num_nodes, 1, ..., 1, embed_dim) with one singleton per middle axis.
    broadcast_shape = (num_nodes_in_batch,) + (1,) * (x_derivatives.dim() - 2) + (embed_dim,)
    x_derivatives = x_derivatives * gate.reshape(broadcast_shape)
    return x_0, x_derivatives




class IdentityGINEConv(MessagePassing):
    """GINE-style convolution with no learnable transform (identity activation).

    Computes ``out_i = (1 + eps) * x_i + AGG_{j in N(i)} (x_j + e_ji)``, where
    ``AGG`` defaults to ``"add"``. When ``edge_attr`` is ``None`` the message
    is just ``x_j``.

    Args:
        eps: initial value of ``eps``.
        train_eps: if True, ``eps`` is a trainable ``Parameter``; otherwise a buffer.
        **kwargs: forwarded to ``MessagePassing`` (``aggr`` defaults to ``"add"``).
    """

    def __init__(
        self,
        eps: float = 0.0,
        train_eps: bool = False,
        **kwargs,
    ):
        kwargs.setdefault("aggr", "add")
        super().__init__(**kwargs)
        self.initial_eps = eps
        if train_eps:
            self.eps = torch.nn.Parameter(torch.empty(1))
        else:
            self.register_buffer("eps", torch.empty(1))

        # Kept as a module so the message function stays a plain identity map.
        self.activation = torch.nn.Identity()
        self.reset_parameters()

    def reset_parameters(self):
        """Reset ``eps`` to its initial value."""
        self.eps.data.fill_(self.initial_eps)

    def forward(
        self,
        x: Union[Tensor, OptPairTensor],
        edge_index: Adj,
        edge_attr: OptTensor = None,
        size: Size = None,
    ) -> Tensor:
        """Aggregate neighbour messages and add the ``(1 + eps)``-scaled self term."""
        if isinstance(x, Tensor):
            x = (x, x)

        # propagate_type: (x: OptPairTensor, edge_attr: OptTensor)
        out = self.propagate(edge_index, x=x, edge_attr=edge_attr, size=size)

        x_r = x[1]
        if x_r is not None:
            out = out + (1 + self.eps) * x_r

        return out

    def message(self, x_j: Tensor, edge_attr: OptTensor) -> Tensor:
        """Return ``x_j + edge_attr``, or ``x_j`` alone when ``edge_attr`` is None."""
        if edge_attr is None:
            return self.activation(x_j)
        return self.activation(x_j + edge_attr)

    def __repr__(self) -> str:
        # There is no inner ``nn`` module to show; report the initial eps.
        return f"{self.__class__.__name__}(eps={self.initial_eps})"












class IdentityGINConv(MessagePassing):
    """Parameter-free GIN propagation for node features of arbitrary trailing shape.

    Computes ``out_i = (1 + eps) * x_i + AGG_{j in N(i)} x_j``, where ``AGG``
    defaults to ``"add"``. Trailing axes are flattened for message passing and
    restored on the output.

    Args:
        eps: a tensor (typically another conv's ``eps`` ``Parameter``, which is
            then shared and also registered on this module) or a number.
        **kwargs: forwarded to ``MessagePassing`` (``aggr`` defaults to ``"add"``).
    """

    def __init__(self, eps, **kwargs):
        kwargs.setdefault('aggr', 'add')
        super().__init__(**kwargs)

        self.eps = eps

    def forward(
        self,
        x: Union[Tensor, OptPairTensor],
        edge_index: Adj,
        size: Size = None,
    ) -> Tensor:
        """Propagate ``x`` (a tensor or a ``(source, target)`` pair).

        Returns:
            A tensor shaped ``(num_target_nodes, *x_source.shape[1:])``.
        """
        if isinstance(x, Tensor):
            x = (x, x)
        x_src, x_dst = x

        # Message passing needs 2-D node features: flatten every trailing axis
        # and restore the original feature shape afterwards.
        feature_shape = x_src.shape[1:]
        x_src_flat = x_src.reshape(x_src.shape[0], -1)
        if x_dst is None:
            x_dst_flat = None
        elif x_dst is x_src:
            x_dst_flat = x_src_flat
        else:
            x_dst_flat = x_dst.reshape(x_dst.shape[0], -1)

        # propagate_type: (x: OptPairTensor)
        out = self.propagate(edge_index, x=(x_src_flat, x_dst_flat), size=size)

        if x_dst_flat is not None:
            out = out + (1 + self.eps) * x_dst_flat

        return out.reshape(out.shape[0], *feature_shape)

    def message(self, x_j: Tensor) -> Tensor:
        """Pass neighbour features through unchanged."""
        return x_j

    def message_and_aggregate(self, adj_t: Adj, x: OptPairTensor) -> Tensor:
        """Fused sparse path: ``adj_t @ x_source`` with the configured reduction."""
        if isinstance(adj_t, SparseTensor):
            adj_t = adj_t.set_value(None, layout=None)
        return spmm(adj_t, x[0], reduce=self.aggr)







# class IdentityGINConv(MessagePassing):
#     """
#     Gine with default analytic activation
#     """

#     def __init__(
#         self,
#         eps: torch.Tensor,
#         **kwargs,
#     ):
        
#         kwargs.setdefault("aggr", "add")
#         super().__init__(**kwargs)
#         self.initial_eps = eps
#         self.eps = eps
 
#     def forward(
#         self,
#         x: Union[Tensor, OptPairTensor],
#         edge_index: Adj,
#         size: Size = None,
#     ) -> Tensor:
        

#         if isinstance(x, Tensor):
#             #TODO: check if reshape is necessary
#             num_nodes, max_nodes, variable_dim, polynomial_dim, embedding_dim = x.shape
#             x = x.reshape(num_nodes ,-1)
#             x = (x, x)
        
        
#         # propagate_type: (x: OptPairTensor, edge_attr: OptTensor)
#         out = self.propagate(edge_index, x=x, size=size)

#         x_r = x[1]
#         if x_r is not None:
#             out = out + (1 + self.eps) * x_r
            
#         #TODO: check if reshape is necessary
#         out = out.reshape(num_nodes, max_nodes, variable_dim, polynomial_dim, embedding_dim)
#         return out

#     def message(self, x_j: Tensor) -> Tensor:
#         # return self.lin(torch.concat([x_j, edge_attr], dim=-1)).relu()
#         return self.activation(x_j)

   








