from graphgps.derivative_computation import DenseGradientExtractor, FirstGNNnetwork
import torch
import math
import torch.nn.functional as F
import gc
from torch_geometric.utils import to_dense_adj





def derivative_initialization(batch, cfg, max_nodes):
    """Attach an identity-like seed derivative tensor to ``batch``.

    The seed is built on the CPU and has shape
    ``(num_nodes, max_nodes, emb_dim, max_degree, emb_dim)`` where
    ``num_nodes = batch.x.shape[0]`` and ``emb_dim`` / ``max_degree`` are read
    from ``cfg.derivative_encoder``. Entry ``[i, i, v, 0, v]`` is ``1.0`` for
    every node ``i`` and channel ``v``; every other entry is ``0.0``.

    Args:
        batch: Object exposing a 2-D-indexable ``x`` tensor; it is modified in
            place (``x_derivatives`` and ``x_derivative_mask`` are set on it).
        cfg: Config exposing ``derivative_encoder.emb_dim`` and
            ``derivative_encoder.max_degree``.
        max_nodes: Size of the second axis of the seed tensor. Must be at
            least ``num_nodes``, otherwise the diagonal assignment indexes out
            of bounds and raises.

    Returns:
        The same ``batch`` object, with ``x_derivatives`` set to the seed
        tensor and ``x_derivative_mask`` set to the column vector
        ``[[0], [1], ..., [num_nodes - 1]]``.
    """
    target = "cpu"
    encoder_cfg = cfg.derivative_encoder
    emb_dim = encoder_cfg.emb_dim
    max_degree = encoder_cfg.max_degree

    node_count = batch.x.shape[0]
    node_ids = torch.arange(node_count, device=target)

    # Row i selects column i: each node is paired with the slot of the same index.
    node_selector = torch.zeros(node_count, max_nodes, device=target)
    node_selector[node_ids, node_ids] = 1.0

    # Per-pair block: identity over the two channel axes, placed in degree slot 0.
    channel_ids = torch.arange(emb_dim, device=target)
    channel_seed = torch.zeros((emb_dim, max_degree, emb_dim), device=target)
    channel_seed[channel_ids, 0, channel_ids] = 1.0

    # Broadcast (N, max_nodes, 1, 1, 1) against (emb_dim, max_degree, emb_dim).
    batch.x_derivatives = node_selector[:, :, None, None, None] * channel_seed
    batch.x_derivative_mask = node_ids.reshape(-1, 1)
    return batch




def compute_derivatives(data, cfg, derivative_extractor):
    """Run the derivative extractor and scale its output by inverse factorials.

    ``derivative_extractor.compute_node_to_node_derivatives(data)`` is called
    under ``torch.no_grad()``. The resulting
    ``data.x_intermediate_node_to_node_derivatives`` is then divided along its
    last axis by ``[1!, 2!, ..., L!]`` with
    ``L = cfg.derivative_preprocessing.num_layers``.

    Args:
        data: Graph object handed to the extractor; the extractor is expected
            to set ``x_intermediate_node_to_node_derivatives`` on what it
            returns.
        cfg: Config exposing ``derivative_preprocessing.num_layers``.
        derivative_extractor: Object providing
            ``compute_node_to_node_derivatives(data)``.

    Returns:
        The object returned by the extractor, with
        ``x_intermediate_node_to_node_derivatives`` replaced by the scaled
        tensor.
    """
    with torch.no_grad():
        data = derivative_extractor.compute_node_to_node_derivatives(data)
        num_layers_first = cfg.derivative_preprocessing.num_layers
        derivatives = data.x_intermediate_node_to_node_derivatives

        # Build the divisor on the tensor's own device. Guessing the device
        # from torch.cuda.is_available() breaks with a device-mismatch error
        # when the extractor keeps its output on the CPU on a CUDA machine.
        factorials = torch.tensor(
            [math.factorial(order) for order in range(1, num_layers_first + 1)],
            dtype=torch.float32,
            device=derivatives.device,
        ).reshape(1, 1, 1, 1, num_layers_first)

        # The (1, 1, 1, 1, L) shape broadcasts over the last axis of a 5-D tensor.
        data.x_intermediate_node_to_node_derivatives = derivatives / factorials
    return data


def derivative_preprocessing(data, cfg, derivative_extractor):
    """Compute derivative features for one graph and attach sampling metadata.

    Steps, in order:
      1. Run ``compute_derivatives`` on ``data``.
      2. Score by summing ``data.derivatives`` over every axis except axis 1
         (presumably one score per node -- confirm against the extractor),
         and store the ascending ``argsort`` of the scores as
         ``data.community``.
      3. Build ``data.c_samples`` of shape ``(num_nodes, 2)``: column 0 is all
         zeros (the unmarked, original graph) and column 1 is a one-hot
         marker of the lowest-scoring index.
      4. Flatten ``data.x_intermediate_node_to_node_derivatives`` from 5-D to
         ``(n, -1)``, keeping its first axis.
      5. Delete the ``derivatives`` and ``adjacency`` attributes and move all
         tensor entries to the CPU via ``to_cpu``.

    Args:
        data: Graph object; modified in place.
        cfg: Config forwarded to ``compute_derivatives``.
        derivative_extractor: Extractor forwarded to ``compute_derivatives``.

    Returns:
        The processed ``data`` object.
    """
    # Prefer an explicitly stored node count (e.g. ogbg-ppa); otherwise count
    # feature rows, which also covers disconnected nodes.
    num_nodes = data.num_nodes if hasattr(data, 'num_nodes') else data.x.shape[0]

    data = compute_derivatives(data, cfg, derivative_extractor)

    # Node-to-output centrality: collapse every axis of `derivatives` but axis 1.
    # (Alternative, currently unused: node-to-node centrality obtained by
    # summing x_intermediate_node_to_node_derivatives over dims 1-4.)
    node_scores = data.derivatives.sum(dim=(0, 2, 3, 4, 5))
    data.community = torch.argsort(node_scores)

    # One marked sample plus the unmarked original graph. With num_samples == 2
    # exactly one index is marked: the first entry of the ascending ordering.
    num_samples = 2
    take = min(num_nodes, num_samples - 1)
    picked = torch.zeros([num_samples - 1], dtype=torch.int64)
    if take > 0:
        picked[:take] = data.community[:take]
    marked = F.one_hot(picked.to(torch.int64), num_classes=int(num_nodes)).T
    # Leading all-zero column stands for the original graph.
    unmarked = torch.zeros(num_nodes).unsqueeze(dim=1)
    data.c_samples = torch.cat((unmarked, marked), 1)

    # Flatten everything after the first axis; the 5-way unpack also checks rank.
    pairwise = data.x_intermediate_node_to_node_derivatives
    leading, _, _, _, _ = pairwise.shape
    data.x_intermediate_node_to_node_derivatives = pairwise.reshape(leading, -1)

    # Drop intermediates that are no longer needed.
    for stale_name in ('derivatives', 'adjacency'):
        delattr(data, stale_name)

    return to_cpu(data)

def to_cpu(data):
    """Move every tensor-valued entry of ``data`` to the CPU, in place.

    Args:
        data: Mapping-like object supporting ``keys()``, ``data[key]`` and
            ``data[key] = value``.

    Returns:
        The same ``data`` object; non-tensor entries are left untouched.
    """
    for name in data.keys():
        entry = data[name]
        if not torch.is_tensor(entry):
            continue
        data[name] = entry.to("cpu")
    return data



# def derivative_initialization(batch, cfg):
    
#     # Get number of nodes in batch
#     device = "cpu"

#     # device = "cuda" if torch.cuda.is_available() else "cpu"
#     num_nodes = batch.x.shape[0]
#     variable_dim = cfg.derivative_encoder.emb_dim
#     polynomial_dim = cfg.derivative_encoder.max_degree
    
#     # Calculate max nodes per graph
#     _, mask = torch.unique(batch.batch, return_counts=True)
#     max_nodes = mask.max().item()


#     M = torch.zeros(num_nodes, max_nodes, device=device)

#     _, counts = torch.unique(batch.batch, return_counts=True)
#     cumsum = torch.cat([torch.tensor([0]), counts.cumsum(0)[:-1]])
#     node_indices = torch.arange(batch.x.shape[0]) - cumsum[batch.batch]

#     i1 = torch.arange(batch.x.shape[0], device=device)
#     i2 = node_indices
#     M[i1, i2] = 1.0


#     diag_mask = torch.zeros((variable_dim, polynomial_dim, variable_dim), device=device)
#     diag_mask[torch.arange(variable_dim, device=device), 0, torch.arange(variable_dim, device=device)] = 1.

#     x_derivatives = M.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) * diag_mask

#     batch.x_derivatives =  x_derivatives
#     batch.x_derivative_index = node_indices
#     return batch
