import torch

# from torch.linalg import svd
from sklearn.decomposition import PCA
import torch.nn.functional as F
# from torch.linalg import svd
from numpy.linalg import svd
# from sklearn.decomposition import TruncatedSVD

from typing import Dict, List, Optional, Tuple
from torch import Tensor

import numpy as np


from torch_geometric.utils import degree, to_dense_adj, to_networkx

import torch_geometric.transforms as T

import os
import os.path as osp
import shutil
from typing import Callable, List, Optional
from random import sample
import tqdm
from sklearn import preprocessing
from .split import split_tu_dataset
import time

from .align import align_graph_feat

from sklearn.neighbors import kneighbors_graph
from sklearn.neighbors import radius_neighbors_graph
from sklearn.utils.extmath import randomized_svd

from typing import Optional, List

from random import sample
import networkx as nx
def dataset_info(dataset):
    """Print a summary of ``dataset`` and of its first graph to stdout.

    The object is expected to expose ``num_features``, ``num_classes``,
    ``sizes['num_node_attributes']``, ``x`` and ``y`` and to be indexable
    (presumably a torch_geometric TU-style dataset; verify against callers).

    Args:
        dataset: The dataset to describe.
    """
    # Dataset-level statistics.
    print()
    print(f'Dataset: {dataset}:')
    print('====================')
    print(f'Number of graphs: {len(dataset)}')
    print(f'Number of features: {dataset.num_features}')
    print(f"Number of node attributes: {dataset.sizes['num_node_attributes']}")
    print(f'Number of classes: {dataset.num_classes}')
    print(f"dataset.x.shape: {dataset.x.shape}")

    # Statistics of the first graph only.
    first_graph = dataset[0]
    print()
    print(first_graph)
    print('=============================================================')
    print(f'Number of nodes: {first_graph.num_nodes}')
    print(f'Number of edges: {first_graph.num_edges}')
    avg_degree = first_graph.num_edges / first_graph.num_nodes
    print(f'Average node degree: {avg_degree:.2f}')
    print(f'Has isolated nodes: {first_graph.has_isolated_nodes()}')
    print(f'Has self-loops: {first_graph.has_self_loops()}')
    print(f'Is undirected: {first_graph.is_undirected()}')

    # Label histogram over the whole dataset.
    for label in range(dataset.num_classes):
        n_obs = (dataset.y == label).sum().item()
        print(f"Class {label}:, number of observations: {n_obs}")


def cat(seq: List[Optional[Tensor]]) -> Optional[Tensor]:
    """Concatenate the usable tensors of ``seq`` along the last dimension.

    ``None`` entries and tensors without elements are skipped; 1-D tensors
    are turned into column vectors before concatenation.

    Args:
        seq: Tensors (or ``None`` placeholders) to join.

    Returns:
        The concatenated tensor, or ``None`` when nothing usable remains.
    """
    columns = []
    for item in seq:
        if item is None or item.numel() == 0:
            continue
        columns.append(item.unsqueeze(-1) if item.dim() == 1 else item)
    if not columns:
        return None
    return torch.cat(columns, dim=-1)



def decomp_kernel_mtx(adj: torch.Tensor, d: int, mode: str = 'U', return_uv: bool = False):
    """Factorise a kernel/adjacency matrix with a rank-``d`` SVD.

    Matrices with more than ``SVD_BOUND`` rows or columns are handled by
    sklearn's ``randomized_svd``; smaller ones by an exact numpy SVD that is
    truncated to the leading ``d`` components.

    Args:
        adj: 2-D matrix to decompose.
        d: Number of singular components to keep.
        mode: ``'V'`` returns ``(sqrt(S) @ V).T``; any other value returns
            ``U @ sqrt(S)``.
        return_uv: When true, also return the truncated factors.

    Returns:
        The feature matrix, or ``(features, U, S, V)`` when ``return_uv`` is
        set. ``S`` is a diagonal matrix. All outputs live on ``adj.device``.
    """
    SVD_BOUND = 2000
    if adj.shape[-1] < d:
        # NOTE(review): both dimensions are padded by the *column* deficit,
        # which is only symmetric for square inputs - confirm for callers
        # that pass rectangular matrices.
        deficit = d - adj.shape[-1]
        adj = F.pad(adj, (0, deficit, 0, deficit), mode='constant', value=0)

    device = adj.device
    # Both SVD back-ends are numpy based, so convert explicitly instead of
    # relying on implicit tensor->ndarray coercion (which fails for CUDA or
    # grad-tracking tensors and loses the device of the input).
    adj_np = adj.detach().cpu().numpy()

    if adj.shape[-1] > SVD_BOUND or adj.shape[0] > SVD_BOUND:
        U, S, V = randomized_svd(adj_np, n_components=d, n_oversamples=50, random_state=0)
    else:
        U, S, V = svd(adj_np, full_matrices=False)
        U, S, V = U[:, :d], S[:d], V[:d, :]

    U = torch.tensor(U, device=device)
    S = torch.diag_embed(torch.tensor(S, device=device))
    V = torch.tensor(V, device=device)

    sqrt_S = torch.sqrt(S)
    if mode == 'V':
        feat = torch.mm(sqrt_S, V).T
    else:
        feat = torch.mm(U, sqrt_S)

    if return_uv:
        return feat, U, S, V
    return feat


def get_gaussian_feature(
        node_features: torch.tensor, 
        scales: list, decom_dim: int, dist_mean: float=None, 
        src_feat: torch.tensor=None, 
        return_uv: bool = False
        ):
    """Build multi-scale Gaussian-kernel features via low-rank decomposition.

    For every entry of ``scales`` a Gaussian kernel between ``node_features``
    and ``src_feat`` is computed and factorised by ``decomp_kernel_mtx``.
    All work is done on the CPU.

    Args:
        node_features: Features of the nodes to embed.
        scales: Kernel scales; one feature matrix is produced per scale.
        decom_dim: Number of SVD components kept per scale.
        dist_mean: Reference distance for the kernel bandwidth; when ``None``
            the mean of the computed distance matrix is used.
        src_feat: Reference features. ``None`` means the kernel is computed
            of ``node_features`` against itself (train-set features).
        return_uv: When true, also return the stacked SVD factors.

    Returns:
        The per-scale features stacked along dim 0, or
        ``(features, Us, Ss, Vs, dist_mean)`` when ``return_uv`` is set.

    Raises:
        ValueError: If ``scales`` is empty.
    """
    MAX_NODE_NUM = 9_0000

    if len(scales) == 0:
        raise ValueError("scales must contain at least one kernel scale")

    device = 'cpu'
    node_features = node_features.to(device)
    node_num = node_features.shape[0]

    has_src = src_feat is not None
    is_large = node_num > MAX_NODE_NUM

    # Very large self-kernels are approximated from sampled rows (Nystrom
    # style); in that case the features are read from the V factor.
    dist_mode = 'nystrom' if is_large and not has_src else 'normal'
    mode = 'V' if dist_mode == 'nystrom' else 'U'
    # Large kernels against an explicit reference set are filled in chunks.
    chunk = is_large and has_src

    if has_src:
        # Keep both operands of cdist on the same device.
        src_feat = src_feat.to(device)
    else:
        # src_feat is None means getting train-set feature
        src_feat = node_features

    distmtx = compute_dist_mtx(node_features, src_feat, mode=dist_mode)
    if dist_mean is None:
        dist_mean = distmtx.mean()

    gau_feats = []
    Us, Ss, Vs = [], [], []
    for scale in scales:
        if chunk:
            adj = compute_kernel_mtx_chunked(
                node_features, src_feat, scale, dist_mean, dist_mtx=distmtx, chunk_size=100
            )
        else:
            adj = compute_kernel_mtx(
                node_features, src_feat,
                scale=scale, dist_mean=dist_mean,
                dist_mtx=distmtx)

        if return_uv:
            feat, U, S, V = decomp_kernel_mtx(adj, decom_dim, mode, return_uv=True)
            Us.append(U)
            Ss.append(S)
            Vs.append(V)
        else:
            feat = decomp_kernel_mtx(adj, decom_dim, mode, return_uv=False)
        gau_feats.append(feat)
        # Release the kernel before the next scale allocates a new one.
        del adj

    gau_feats = torch.stack(gau_feats, dim=0)
    if return_uv:
        Us = torch.stack(Us, dim=0)
        Ss = torch.stack(Ss, dim=0)
        Vs = torch.stack(Vs, dim=0)
        return gau_feats, Us, Ss, Vs, dist_mean
    return gau_feats



def compute_kernel_mtx(
        node_features: torch.tensor, src_feat: torch.tensor, 
        scale: float, dist_mean: float=None, 
        mode='normal', dist_mtx=None):
    """Return the Gaussian kernel ``exp(-dist^2 / (2 * dist_mean^2 * scale))``.

    Args:
        node_features: Target features; only used when ``dist_mtx`` is None.
        src_feat: Source features; only used when ``dist_mtx`` is None.
        scale: Multiplier applied to the squared reference distance.
        dist_mean: Reference distance; defaults to the mean of the distances.
        mode: Forwarded to ``compute_dist_mtx`` when distances are computed.
        dist_mtx: Optional precomputed distance matrix.

    Returns:
        The kernel matrix, or an all-zero matrix of the same shape when every
        distance is zero.
    """
    distances = dist_mtx
    if distances is None:
        distances = compute_dist_mtx(node_features, src_feat, mode)

    # Degenerate case: no non-zero distance at all.
    if not torch.any(distances != 0):
        return torch.zeros_like(distances)

    reference = distances.mean() if dist_mean is None else dist_mean
    bandwidth = (reference ** 2) * scale
    return torch.exp(-distances.pow(2) / (2 * bandwidth))


def compute_dist_mtx(tar_feat: torch.tensor, src_feat: torch.tensor, mode: str):
    """Compute Euclidean distances between target and source features.

    Args:
        tar_feat: Target feature matrix (one row per item).
        src_feat: Source feature matrix (one row per item).
        mode: ``'nystrom'`` measures distances only from a random sample of
            at most ``M`` target rows; any other value uses all target rows.

    Returns:
        The distance matrix with one row per (sampled) target row and one
        column per source row.
    """
    M = 2000
    if mode == 'nystrom':
        n_rows = tar_feat.shape[0]
        # Never request more landmarks than there are rows; random.sample
        # raises ValueError when the sample exceeds the population.
        idx = sample(range(n_rows), min(M, n_rows))
        return torch.cdist(tar_feat[idx, :], src_feat, p=2)
    return torch.cdist(tar_feat, src_feat, p=2)

def compute_kernel_mtx_chunked(node_attributes, src_feat, scale, dist_mean, dist_mtx=None, chunk_size=1000):
    """Fill a Gaussian kernel matrix ``chunk_size`` rows at a time.

    Args:
        node_attributes: Target features, one row per node.
        src_feat: Source features, one row per node.
        scale: Multiplier applied to the squared reference distance.
        dist_mean: Reference distance for the bandwidth.
        dist_mtx: Optional precomputed distances; when absent each chunk's
            distances are computed with ``torch.cdist``.
        chunk_size: Number of target rows processed per step.

    Returns:
        A ``(n_target, n_source)`` kernel matrix allocated with
        ``torch.zeros`` on the device of ``node_attributes``.
    """
    num_rows, num_cols = node_attributes.shape[0], src_feat.shape[0]
    kernel = torch.zeros((num_rows, num_cols), device=node_attributes.device)
    have_dists = dist_mtx is not None

    for start in range(0, num_rows, chunk_size):
        stop = min(start + chunk_size, num_rows)
        if have_dists:
            block = dist_mtx[start:stop, :]
        else:
            block = torch.cdist(node_attributes[start:stop], src_feat, p=2)

        bandwidth = (dist_mean ** 2) * scale
        kernel[start:stop] = torch.exp(-(block ** 2) / (2 * bandwidth))

        # Drop the chunk and return cached CUDA blocks before the next step.
        del block
        torch.cuda.empty_cache()

    return kernel



def get_adj_feature(edge_index:torch.tensor, node_num: int, dim: int):
    """Derive node features from the self-looped dense adjacency matrix.

    Args:
        edge_index: Edge list passed to ``to_dense_adj``.
        node_num: Number of nodes used as ``max_num_nodes``.
        dim: Feature dimension requested from ``decompose_adj``.

    Returns:
        The decomposed adjacency features after ``sign_flip``.
    """
    dense = to_dense_adj(edge_index, max_num_nodes=node_num).squeeze(0)
    # Add self-loops before decomposing.
    dense += torch.eye(dense.shape[0])
    return sign_flip(decompose_adj(dense, dim))
        

def pad_attr(attr, num):
    """Zero-pad the last dimension of ``attr`` up to a multiple of ``num``.

    Args:
        attr: Tensor to pad.
        num: Required divisor of the last dimension's size.

    Returns:
        ``attr`` itself when no padding is needed, otherwise a padded copy.
    """
    remainder = attr.shape[-1] % num
    if remainder == 0:
        return attr
    return F.pad(attr, (0, num - remainder), mode='constant', value=0)

def sign_flip(col):
    """Flip the sign of columns whose entries are mostly non-positive.

    For every column the number of entries ``>= 0`` is compared with the
    number of entries ``<= 0``; a column is negated only when the second
    count is strictly larger. Ties leave the column unchanged (multiplying
    by ``sign(0) == 0`` would otherwise wipe the whole column).

    Args:
        col: 2-D tensor whose columns are sign-normalised along dim 0.

    Returns:
        ``col`` multiplied column-wise by ``+1`` or ``-1``.
    """
    non_negative = torch.sum(col >= 0, dim=0)
    non_positive = torch.sum(col <= 0, dim=0)
    sign = torch.sign(non_negative - non_positive)
    # A balanced column must be kept as is, not zeroed out.
    sign[sign == 0] = 1
    return col * sign
     


def permute_edges(data):
    """Randomly remove one tenth of the edges of ``data`` in place.

    Only removal is performed: random replacement endpoints are still drawn
    but never added to the graph.

    Args:
        data: Object exposing ``x`` and ``edge_index`` tensors.

    Returns:
        The same ``data`` object with ``edge_index`` replaced.
    """
    num_nodes, _ = data.x.size()
    _, num_edges = data.edge_index.size()
    num_removed = int(num_edges / 10)

    edge_pairs = data.edge_index.t().numpy()

    # Candidate endpoints for new edges are drawn but not used; the draw is
    # kept so the global numpy random stream advances as before.
    np.random.choice(num_nodes, (num_removed, 2))

    kept_rows = np.random.choice(num_edges, num_edges - num_removed, replace=False)
    data.edge_index = torch.tensor(edge_pairs[kept_rows]).transpose_(0, 1)

    return data

def subgraph(data):
    """Keep only the edges of a randomly grown connected node subset.

    Starting from a random node, neighbours reachable through outgoing edges
    are added one at a time until more than 20% of the nodes are selected or
    no unvisited neighbour is left. Edges touching any other node are
    removed; node features are left untouched.

    Args:
        data: Object exposing ``x`` and a CPU ``edge_index`` tensor.

    Returns:
        The same ``data`` object with ``edge_index`` replaced by the edges
        induced by the selected nodes (as produced by ``nonzero`` on a dense
        adjacency matrix).
    """
    node_num, _ = data.x.size()
    if node_num == 0:
        # Nothing to sample from.
        return data
    sub_num = int(node_num * 0.2)

    edge_index = data.edge_index.numpy()
    sources, targets = edge_index[0], edge_index[1]

    def _neighbours(node):
        # Targets of all edges leaving ``node``.
        return {int(n) for n in targets[sources == node]}

    start = int(np.random.randint(node_num, size=1)[0])
    selected = [start]
    selected_set = {start}
    frontier = _neighbours(start) - selected_set

    while len(selected) <= sub_num and frontier:
        node = int(np.random.choice(list(frontier)))
        selected.append(node)
        selected_set.add(node)
        # Grow the frontier in place (set.union would return a new set and
        # leave the frontier stuck at the start node's neighbours).
        frontier |= _neighbours(node)
        frontier -= selected_set

    keep = torch.zeros(node_num, dtype=torch.bool)
    keep[selected] = True

    # Dense adjacency, then blank every row/column of a dropped node.
    adj = torch.zeros((node_num, node_num))
    adj[data.edge_index[0], data.edge_index[1]] = 1
    adj[~keep, :] = 0
    adj[:, ~keep] = 0

    data.edge_index = adj.nonzero().t()
    return data


def mask_nodes(data):
    """Overwrite the features of one tenth of the nodes with Gaussian noise.

    The noise is drawn from a normal distribution with mean 0.5 and standard
    deviation 0.5 and written into ``data.x`` in place.

    Args:
        data: Object exposing a 2-D feature tensor ``x``.

    Returns:
        The same ``data`` object.
    """
    num_nodes, num_feats = data.x.size()
    num_masked = int(num_nodes / 10)

    masked_rows = np.random.choice(num_nodes, num_masked, replace=False)
    noise = np.random.normal(loc=0.5, scale=0.5, size=(num_masked, num_feats))
    data.x[masked_rows] = torch.tensor(noise, dtype=torch.float32)

    return data

def drop_nodes(data):
    """Remove every edge incident to a random tenth of the nodes.

    Node features are left untouched; only ``data.edge_index`` is rewritten.

    Args:
        data: Object exposing ``x`` and a CPU ``edge_index`` tensor.

    Returns:
        The same ``data`` object with ``edge_index`` replaced by the
        remaining edges (as produced by ``nonzero`` on a dense adjacency
        matrix).
    """
    node_num, _ = data.x.size()
    drop_num = int(node_num / 10)

    idx_drop = np.random.choice(node_num, drop_num, replace=False)

    edge_index = data.edge_index.numpy()

    # Dense adjacency, then blank every row/column of a dropped node.
    adj = torch.zeros((node_num, node_num))
    adj[edge_index[0], edge_index[1]] = 1
    adj[idx_drop, :] = 0
    adj[:, idx_drop] = 0

    data.edge_index = adj.nonzero().t()

    return data


def zstandardize(node_features: torch.tensor):
    """Standardise each feature column to zero mean and unit deviation.

    Args:
        node_features: Feature matrix with one row per node.

    Returns:
        The input itself when it has a single row; otherwise
        ``(x - mean) / (std + 1e-7)`` computed per column.
    """
    if node_features.shape[0] == 1:
        # A single row has no spread to normalise by.
        return node_features
    centered = node_features - node_features.mean(dim=0, keepdim=True)
    spread = node_features.std(dim=0, keepdim=True) + 1e-7
    return centered / spread

def get_eps(dataset, node_slice, gamma = 0.5):
    """Return ``gamma`` times the mean intra-graph pairwise node distance.

    Args:
        dataset: Node feature tensor that is sliced per graph.
        node_slice: Boundaries such that graph ``i`` owns rows
            ``node_slice[i]:node_slice[i + 1]``.
        gamma: Multiplier applied to the mean distance.

    Returns:
        The scaled mean Euclidean distance over all unordered node pairs
        that belong to the same graph, as a Python float.
    """
    distance_totals = []
    pair_counts = []
    for graph_idx in range(len(node_slice) - 1):
        lo, hi = node_slice[graph_idx], node_slice[graph_idx + 1]
        size = hi - lo
        pair_counts.append((size * (size - 1)) / 2)

        feats = dataset[lo:hi]
        pairwise = torch.norm(feats[:, None] - feats, dim=-1, p=2)
        # The matrix is symmetric, so every pair is counted twice.
        half_total = pairwise.sum(dim=-1).sum(-1) / 2
        distance_totals.append(half_total.item())

    total_distance = torch.tensor(distance_totals).float().sum()
    total_pairs = torch.tensor(pair_counts).float().sum()
    return (total_distance / total_pairs).item() * gamma


def decompose_adj(adj: torch.Tensor, d: int, mode: str = None):
    """Embed a symmetric matrix with a rank-``d`` Hermitian SVD.

    The matrix is decomposed with ``numpy.linalg.svd(..., hermitian=True)``
    and the leading ``min(n, d)`` components are kept. When the matrix has
    fewer than ``d`` rows the features are zero-padded to ``d`` columns.

    Args:
        adj: Square symmetric matrix.
        d: Requested feature dimension.
        mode: ``'V'`` returns ``(sqrt(S) @ V).T``; any other value returns
            ``U @ sqrt(S)``.

    Returns:
        An ``(n, d)`` feature matrix on ``adj.device``.
    """
    node_num = adj.shape[0]
    rank = min(node_num, d)
    device = adj.device

    # numpy's SVD needs a CPU ndarray; convert explicitly rather than relying
    # on implicit coercion, which fails for CUDA or grad-tracking tensors.
    U, S, V = svd(adj.detach().cpu().numpy(), full_matrices=False,
                  compute_uv=True, hermitian=True)
    sqrt_S = torch.sqrt(torch.diag_embed(torch.tensor(S[:rank], device=device)))

    if mode == 'V':
        feat = torch.mm(sqrt_S, torch.tensor(V[:rank, :], device=device)).T
    else:
        feat = torch.mm(torch.tensor(U[:, :rank], device=device), sqrt_S)

    if node_num < d:
        # Fewer components than requested: fill the remaining columns with 0.
        feat = F.pad(feat, (0, d - node_num), mode='constant', value=0)
    return feat


def get_knn_feature(node_features:torch.tensor, k: int, decom_dim: int):
    """Build features from the decomposed k-nearest-neighbour distance graph.

    Args:
        node_features: Feature matrix with one row per node.
        k: Requested number of neighbours; reduced to ``n_nodes - 1`` when
            the graph does not have more than ``k`` nodes.
        decom_dim: Number of components requested from ``decomp_kernel_mtx``.

    Returns:
        The output of ``decomp_kernel_mtx`` for the dense kNN distance
        matrix (a ``[[0.]]`` matrix for a single node).
    """
    n_nodes = node_features.shape[0]

    # A single node has no neighbours to search for.
    if n_nodes == 1:
        return decomp_kernel_mtx(torch.tensor([[0.]]), decom_dim)

    n_neighbor = n_nodes - 1 if n_nodes <= k else k
    knn_graph = kneighbors_graph(
        node_features.cpu().numpy(), n_neighbors=n_neighbor, mode='distance')

    # Densify the sparse result before the decomposition.
    dense_adj = torch.tensor(knn_graph.toarray())
    return decomp_kernel_mtx(dense_adj, decom_dim)


def get_eps_feature(node_features: torch.tensor, eps: float, decom_dim: int):
    """Build features from the decomposed radius-neighbour distance graph.

    Args:
        node_features: Feature matrix with one row per node.
        eps: Radius passed to ``radius_neighbors_graph``.
        decom_dim: Number of components requested from ``decomp_kernel_mtx``.

    Returns:
        The output of ``decomp_kernel_mtx`` for the dense distance matrix of
        the radius graph.
    """
    features_np = node_features.cpu().numpy()
    radius_graph = radius_neighbors_graph(features_np, radius=eps, mode='distance')

    # Densify the sparse result before the decomposition.
    dense_adj = torch.tensor(radius_graph.toarray())
    return decomp_kernel_mtx(dense_adj, decom_dim)

def compute_landing_probs(adj_matrix, steps):
    """Summarise powers of the self-looped random-walk transition matrix.

    Self-loops are added to the adjacency matrix, rows are normalised by
    their sums to obtain the transition matrix ``P``, and for every ``t`` in
    ``steps`` the row sums of ``P^t`` are stored as one output column. The
    input tensor is not modified.

    Args:
        adj_matrix: Floating-point adjacency matrix, optionally with a
            leading batch dimension of size 1.
        steps: Iterable of integer walk lengths.

    Returns:
        A ``(num_nodes, len(steps))`` float tensor.
    """
    # Work on a fresh tensor so the caller's adjacency keeps its shape and
    # values (no in-place squeeze / self-loop insertion).
    adj = adj_matrix.squeeze(0)
    num_nodes = adj.shape[0]
    adj = adj + torch.diag(torch.ones(num_nodes, dtype=adj.dtype, device=adj.device))

    # Row-normalising is equivalent to inv(D) @ A for the diagonal degree
    # matrix D, without building and inverting a dense matrix.
    row_sums = adj.sum(dim=1)
    P = adj / row_sums.unsqueeze(1)

    S = torch.zeros((num_nodes, len(steps)), device=adj.device)
    for step_idx, t in enumerate(steps):
        # NOTE(review): P is row-normalised, so every row of P^t sums to ~1
        # and these columns are constant. Confirm whether the diagonal
        # (return probability) or the column sums were intended.
        S[:, step_idx] = torch.matrix_power(P, t).sum(dim=1)
    return S

def get_closeness_centrality(graph, num_nodes):
    """Return the closeness centrality of nodes ``0 .. num_nodes - 1``.

    Args:
        graph: Graph object accepted by ``to_networkx``.
        num_nodes: Number of node ids to report.

    Returns:
        A ``(num_nodes, 1)`` float tensor; ids missing from the networkx
        result get ``0.0``.
    """
    # Convert to an undirected networkx graph.
    nx_graph = to_networkx(graph, to_undirected=True)

    # Compute closeness centrality for every node.
    scores = nx.closeness_centrality(nx_graph)

    per_node = [scores.get(node, 0.0) for node in range(num_nodes)]
    return torch.tensor(per_node, dtype=torch.float).unsqueeze(1)

def get_betweenness_centrality(graph, num_nodes):
    """Return the normalised betweenness centrality of nodes ``0 .. num_nodes - 1``.

    Args:
        graph: Graph object accepted by ``to_networkx``.
        num_nodes: Number of node ids to report.

    Returns:
        A ``(num_nodes, 1)`` float tensor; ids missing from the networkx
        result get ``0.0``.
    """
    # Convert to an undirected networkx graph.
    nx_graph = to_networkx(graph, to_undirected=True)

    # Compute betweenness centrality for every node.
    scores = nx.betweenness_centrality(nx_graph, normalized=True)

    per_node = [scores.get(node, 0.0) for node in range(num_nodes)]
    return torch.tensor(per_node, dtype=torch.float).unsqueeze(1)

def get_eigenvector_centrality(graph, num_nodes):
    """Return the eigenvector centrality of nodes ``0 .. num_nodes - 1``.

    The centrality is computed by ``nx.eigenvector_centrality`` with
    ``max_iter=1000``; any error raised by networkx (e.g. on
    non-convergence) propagates to the caller.

    Args:
        graph: Graph object accepted by ``to_networkx``.
        num_nodes: Number of node ids to report.

    Returns:
        A ``(num_nodes, 1)`` float tensor; ids missing from the networkx
        result get ``0.0``.
    """
    # Convert to an undirected networkx graph. (The former debug dump of the
    # full adjacency matrix was removed: it built and printed the matrix on
    # every call.)
    graph_nx = to_networkx(graph, to_undirected=True)
    eigen_dict = nx.eigenvector_centrality(graph_nx, max_iter=1000)

    eigen_cen_values = [eigen_dict.get(i, 0.0) for i in range(num_nodes)]
    return torch.tensor(eigen_cen_values, dtype=torch.float).unsqueeze(1)