import torch 

from utils import get_diffusion, torch_adj_to_nx

from GraphRicciCurvature.OllivierRicci import OllivierRicci
from count_cycles import count_cycles

from tqdm import tqdm

# from torch_sparse import spspmm

def get_node_degree_target(data):
    """Return the (weighted) degree of every node as an N x 1 tensor.

    The degree of node i is the sum of row i of the sparse adjacency
    ``data.adj``. The result keeps the dtype of the adjacency values.

    Args:
        data: object with ``data.adj``, an N x N sparse COO adjacency tensor.

    Returns:
        N x 1 dense tensor of row sums (0 for nodes with no stored entries).
    """
    # coalesce() first: Tensor.values() raises on an uncoalesced COO tensor,
    # which is what torch.sparse_coo_tensor returns by default.
    A = data.adj.coalesce()
    deg = torch.sparse.sum(A, dim=1).to_dense().to(A.values().dtype)
    return deg.unsqueeze(-1)


def get_clustering_coeff_target(data, symmetrize=True):
    """
    Compute the local clustering coefficient for each node of an (undirected) graph.

    Args:
        data: a PyG Data-like object with data.adj = NxN sparse COO adjacency.
        symmetrize (bool): if True, binarize A to 0/1, make it symmetric
            (an edge stored in either direction counts in both) and drop
            self-loops before counting. If False, A is used as given and is
            assumed to already be a symmetric 0/1 matrix without self-loops.

    Returns:
        N x 1 dense floating-point tensor with
        C_i = (A^3)_{ii} / (k_i * (k_i - 1)), and C_i = 0 when k_i < 2.
    """
    A = data.adj.coalesce()  # indices()/values() require a coalesced tensor
    N = A.size(0)
    device = A.device
    dtype = A.values().dtype
    if not dtype.is_floating_point:
        # Integer/bool adjacencies would otherwise truncate the ratio below.
        dtype = torch.get_default_dtype()

    if symmetrize:
        row, col = A.indices()
        # Keep stored non-zero, off-diagonal entries only (drops self-loops).
        keep = (A.values() != 0) & (row != col)
        row, col = row[keep], col[keep]
        # Insert both directions; duplicates are merged by coalesce() below.
        idx = torch.cat([torch.stack([row, col]), torch.stack([col, row])], dim=1)
        ones = torch.ones(idx.size(1), device=device, dtype=dtype)
        pattern = torch.sparse_coo_tensor(idx, ones, (N, N)).coalesce()
        # coalesce() sums duplicates (e.g. an edge given in both directions);
        # reset every stored value to exactly 1 to get a 0/1 matrix.
        A = torch.sparse_coo_tensor(
            pattern.indices(), torch.ones_like(pattern.values()), (N, N)
        ).coalesce()

    # degree k_i
    deg = torch.sparse.sum(A, dim=1).to_dense().to(dtype)

    # (A^3)_{ii} counts closed 3-walks from i: each triangle through i
    # contributes 2 in an undirected graph, so (A^3)_{ii} = 2 * t_i.
    A3 = torch.sparse.mm(torch.sparse.mm(A, A), A).to_dense()
    tri_diag = torch.diagonal(A3, dim1=-2, dim2=-1).to(dtype)

    # C_i = 2 t_i / (k_i (k_i - 1)) = (A^3)_{ii} / (k_i (k_i - 1)); nodes with
    # k_i < 2 have a non-positive denominator and keep C_i = 0.
    denom = deg * (deg - 1)
    C = torch.zeros_like(deg)
    nz = denom > 0
    C[nz] = tri_diag[nz] / denom[nz]

    return C.unsqueeze(-1)

def get_rwse_target(data, config):
    """Random-walk structural encoding (RWSE) target, as in Dwivedi et al.

    For every node i and every step count k = 2 .. max_steps (inclusive),
    returns the k-step self-return value (P^k)_{ii}, where
    P = get_diffusion(data.adj). Step k = 1 is not included.

    Args:
        data: object with ``data.adj``; passed to ``get_diffusion``.
        config: object with ``rwse_target_settings["max_steps"]``.

    Returns:
        Tensor of shape N x (max_steps - 1); N x 0 when max_steps < 2
        (previously this case crashed in ``torch.stack``).
    """
    # NOTE(review): P is assumed to be a dense square matrix (matrix_power
    # and diagonal are applied to it) — confirm get_diffusion's output.
    P = get_diffusion(data.adj)
    max_ksteps = config.rwse_target_settings["max_steps"]

    if max_ksteps < 2:
        return P.new_zeros(P.size(-1), 0)

    Pk = P.detach().matrix_power(2)
    rws = [torch.diagonal(Pk, dim1=-2, dim2=-1)]
    # Multiply only while more steps remain: avoids a wasted O(N^3) product
    # after the last diagonal has been recorded.
    for _ in range(3, max_ksteps + 1):
        Pk = Pk @ P
        rws.append(torch.diagonal(Pk, dim1=-2, dim2=-1))

    # Stacking along dim 1 equals stack(...).transpose(0, 1): one column per k.
    return torch.stack(rws, dim=1)

def get_curvature_target(data, config):
    """Per-node aggregate of Ollivier-Ricci edge curvature.

    Builds a networkx graph from ``data.adj``, runs ``OllivierRicci`` with
    ``alpha = config.curvature_target_settings["alpha"]``, and for every node
    averages the ``'ricciCurvature'`` attribute over its incident edges.
    Nodes with no incident edges get 0.

    Returns:
        1-D tensor of length N (not N x 1), on the adjacency's device and
        with the adjacency's value dtype.

    NOTE(review): assumes the networkx graph labels nodes 0..N-1 and
    contains every one of them; a missing node raises KeyError.
    """
    adj = data.adj
    num_nodes = adj.size(0)
    alpha = config.curvature_target_settings["alpha"]

    orc = OllivierRicci(torch_adj_to_nx(adj), alpha=alpha, verbose="INFO")
    orc.compute_ricci_curvature()
    graph = orc.G

    out = torch.zeros(num_nodes, device=adj.device, dtype=adj.values().dtype)
    for node in range(num_nodes):
        # Each value of graph[node] is the attribute dict of one incident edge.
        edge_curvatures = [attrs['ricciCurvature'] for attrs in graph[node].values()]
        if edge_curvatures:
            out[node] = sum(edge_curvatures) / len(edge_curvatures)

    return out

def get_cycle_target(data, config):
    """Return ``count_cycles`` evaluated for k = 2 .. max_length - 1.

    ``max_length`` comes from ``config.cycle_target_settings``.

    NOTE(review): the k list excludes ``max_length`` itself (half-open range);
    confirm the bound is meant to be exclusive.
    """
    upper = config.cycle_target_settings["max_length"]
    lengths = [*range(2, upper)]
    return count_cycles(lengths, data)
  


def get_lap_eval_target(data, config):
    """Graph-level target: the first ``num_evals`` stored Laplacian eigenvalues.

    Args:
        data: object with a precomputed ``data.eigvals`` tensor.
        config: object with ``lap_eval_target_settings["num_evals"]``.

    Returns:
        ``data.eigvals[:num_evals]`` with a leading dimension of 1 added.

    NOTE(review): a graph with fewer than ``num_evals`` stored eigenvalues
    yields a shorter row, which differs in width from other graphs.
    """
    num_evals = config.lap_eval_target_settings["num_evals"]
    evals = data.eigvals[:num_evals]
    return evals.unsqueeze(0)


def normalize_alternative_targets(dataset, config, eps=1e-8):
    """Standardize each alternative target to zero mean / unit variance per feature.

    For every key in ``config.lambda_alt_targets``, statistics are pooled over
    all rows (first dimension) of that attribute across every data object in
    ``dataset``. Each attribute is then replaced in place by
    ``(x - mean) / std``, using the population (biased) std.

    Args:
        dataset: indexable, re-iterable collection of data objects, each with
            one tensor attribute per key of ``config.lambda_alt_targets``.
        config: object whose ``lambda_alt_targets`` mapping names the targets.
        eps: features whose std is <= eps (e.g. constant features) are only
            centred, not scaled, instead of producing inf/NaN.

    Returns:
        List of the (mutated) data objects in dataset order; ``[]`` if
        ``dataset`` is empty.
    """
    keys = list(config.lambda_alt_targets.keys())
    if len(dataset) == 0:
        return []

    # Pass 1: gather raw rows in float64 on CPU. Statistics computed from the
    # full column avoid the cancellation of E[x^2] - E[x]^2.
    print("computing means and variances...")
    raw_rows = {key: [] for key in keys}
    for data in dataset:
        for key in keys:
            raw_rows[key].append(getattr(data, key).detach().to("cpu", torch.float64))

    print("Normalizing...")
    stat_dtype = torch.get_default_dtype()  # matches the original float32 stats
    stats = {}
    for key in keys:
        rows = torch.cat(raw_rows[key], dim=0)
        var, mean = torch.var_mean(rows, dim=0, unbiased=False)
        std = var.sqrt()
        # Constant (zero-variance) features would divide by zero.
        std = torch.where(std > eps, std, torch.ones_like(std))
        stats[key] = (mean.to(stat_dtype), std.to(stat_dtype))
        print(key, "means before", stats[key][0])
        print(key, "variances before", var.to(stat_dtype))
    del raw_rows

    # Pass 2: normalize in place. Collect rows for diagnostics and concatenate
    # once, rather than re-concatenating on every iteration.
    new_data_list = []
    normalized_rows = {key: [] for key in keys}
    for data in dataset:
        for key in keys:
            mean, std = stats[key]
            orig_val = getattr(data, key)
            new_val = (orig_val - mean.to(orig_val.device)) / std.to(orig_val.device)
            setattr(data, key, new_val)
            normalized_rows[key].append(new_val)
        new_data_list.append(data)

    for key in keys:
        combined = torch.cat(normalized_rows[key], dim=0)
        print(key, "mean after", combined.mean(dim=0))
        print(key, "std after", combined.std(dim=0))

    return new_data_list

def compute_alternative_targets(data, config):
    """Attach every alternative target to ``data`` and return the same object.

    Targets are computed and assigned one after another in the order listed
    in ``target_builders``. Each is stored under the attribute name paired
    with it.

    The curvature target (``get_curvature_target``) is currently disabled.
    """
    target_builders = (
        ("node_degree_target", lambda: get_node_degree_target(data)),
        ("clustering_coeff_target", lambda: get_clustering_coeff_target(data)),
        ("rwse_target", lambda: get_rwse_target(data, config)),
        ("cycle_target", lambda: get_cycle_target(data, config)),
        ("lap_eval_target", lambda: get_lap_eval_target(data, config)),
    )
    for attr_name, build in target_builders:
        setattr(data, attr_name, build())

    return data
    
