
import numpy as np
import copy
import scipy.sparse as sp
from scipy.sparse import linalg


# Absolute path of the `raw_data` folder. It must be edited before use: while it
# is still the placeholder 'none', importing this module raises immediately.
data_root_path = 'none'
if data_root_path == 'none':
    raise ValueError('Please set `data_root_path` to the absolute path of the `raw_data` folder!')


# Edge-list CSV file of each PeMS dataset, keyed by dataset name.
pems_graph_file_path_dict = {
    name: f'{data_root_path}/graph/{name}.csv'
    for name in ('PEMS04', 'PEMS08', 'PEMS07', 'PEMS03')
}
# Number of graph nodes in each PeMS dataset (size of the adjacency matrix).
num_nodes_dict = dict(PEMS04=307, PEMS08=170, PEMS07=883, PEMS03=358)
# Pre-computed adjacency matrices (.npy) for the non-PeMS datasets.
graph_file_path_dict = {
    'nrel_al': f'{data_root_path}/graph/nrel_al_adj_mx_0.9.npy',
    'SeaLoop': f'{data_root_path}/graph/SeaLoop_adj_mx_01.npy',
}
# Node-distance matrices (.npy) for the same datasets.
# NOTE(review): 'SeaLoop' points at its adjacency file rather than a separate
# distance file -- confirm this is intentional.
node_dist_file_path_dict = {
    'nrel_al': f'{data_root_path}/graph/nrel_al_dist.npy',
    'SeaLoop': graph_file_path_dict['SeaLoop'],
}


def sym_adj(adj):
    """Symmetrically normalize adjacency matrix.

    Computes ``(A D^-1/2)^T D^-1/2``, where ``D`` is the diagonal matrix of the
    row sums of ``A``. For a symmetric ``A`` this equals ``D^-1/2 A D^-1/2``.
    Rows whose sum is zero get a scaling factor of 0 instead of infinity.

    Returns:
        Dense float32 ``np.matrix`` of the same shape as ``adj``.
    """
    coo = sp.coo_matrix(adj)
    degree = np.array(coo.sum(axis=1)).flatten()
    inv_sqrt_degree = np.power(degree, -0.5)
    # Zero-degree rows produce inf above; neutralise them.
    inv_sqrt_degree[np.isinf(inv_sqrt_degree)] = 0.
    scale = sp.diags(inv_sqrt_degree)
    normalized = (coo @ scale).T @ scale
    return normalized.astype(np.float32).todense()


def asym_adj(adj):
    """Row-normalize an adjacency matrix: ``D^-1 A`` with ``D = diag(row sums)``.

    Each non-empty row is divided by its sum. Rows that sum to zero stay all
    zero instead of becoming inf/nan.

    Args:
        adj: square adjacency matrix (dense or sparse). Integer and bool
            matrices are accepted.

    Returns:
        Dense float32 ``np.matrix`` of the same shape as ``adj``.
    """
    adj = sp.coo_matrix(adj)
    rowsum = np.array(adj.sum(1)).flatten()
    if rowsum.dtype.kind in 'biu':
        # np.power raises ValueError for integer bases with a negative integer
        # exponent, so integer/bool row sums must be promoted first. Float
        # inputs are left in their own precision so their results are unchanged.
        rowsum = rowsum.astype(np.float64)
    with np.errstate(divide='ignore'):
        # Zero row sums give inf here; they are reset to 0 just below.
        d_inv = np.power(rowsum, -1)
    d_inv[np.isinf(d_inv)] = 0.
    d_mat = sp.diags(d_inv)
    return d_mat.dot(adj).astype(np.float32).todense()


def calculate_scaled_laplacian(adj_mx, lambda_max=2, undirected=True):
    """Return the rescaled normalized Laplacian ``2 L / lambda_max - I``.

    Args:
        adj_mx: dense square adjacency matrix.
        lambda_max: scaling eigenvalue. When ``None``, the largest-magnitude
            eigenvalue of ``L`` is computed with ``scipy.sparse.linalg.eigsh``.
        undirected: if True, the matrix is symmetrised first with the
            element-wise maximum of ``adj_mx`` and ``adj_mx.T``.

    Returns:
        Dense float32 ``np.matrix``.
    """
    if undirected:
        adj_mx = np.maximum.reduce([adj_mx, adj_mx.T])
    # calculate_normalized_laplacian returns (laplacian, isolated_point_num);
    # only the matrix is needed here.
    lap, _ = calculate_normalized_laplacian(adj_mx)
    if lambda_max is None:
        lambda_max, _ = linalg.eigsh(lap, 1, which='LM')
        lambda_max = lambda_max[0]
    lap = sp.csr_matrix(lap)
    m, _ = lap.shape
    identity = sp.identity(m, format='csr', dtype=lap.dtype)
    lap = (2 / lambda_max * lap) - identity
    return lap.astype(np.float32).todense()


def calculate_normalized_laplacian(adj):
    """Compute ``L = I - (A D^-1/2)^T D^-1/2`` with ``D = diag(row sums of A)``.

    For a symmetric ``A`` this is the usual ``I - D^-1/2 A D^-1/2``. Nodes
    whose row sums to zero get a scaling factor of 0, so their row and column
    of ``L`` reduce to the identity.

    Args:
        adj: square adjacency matrix (dense or sparse).

    Returns:
        Tuple ``(normalized_laplacian, isolated_point_num)``:
        - ``normalized_laplacian``: scipy sparse matrix.
        - ``isolated_point_num``: number of rows of ``adj`` with zero sum.
    """
    adj = sp.coo_matrix(adj)
    d = np.array(adj.sum(1))
    isolated_point_num = np.sum(np.where(d, 0, 1))
    with np.errstate(divide='ignore'):
        # Zero-degree rows give inf here; they are reset to 0 just below.
        d_inv_sqrt = np.power(d, -0.5).flatten()
    d_inv_sqrt[np.isinf(d_inv_sqrt)] = 0.
    d_mat_inv_sqrt = sp.diags(d_inv_sqrt)
    normalized_laplacian = sp.eye(adj.shape[0]) - adj.dot(d_mat_inv_sqrt).transpose().dot(d_mat_inv_sqrt).tocoo()
    return normalized_laplacian, isolated_point_num

def cal_lape(adj_mx, lape_dim=8):
    """Return Laplacian-eigenvector features for every node.

    The eigenvectors of the normalized Laplacian of ``adj_mx`` are sorted by
    ascending eigenvalue. The first ``isolated_point_num + 1`` are skipped, and
    the real parts of the next ``lape_dim`` columns are returned. The result has
    shape ``(num_nodes, k)`` with ``k <= lape_dim``; it is fewer than
    ``lape_dim`` columns if the graph is too small.

    NOTE(review): isolated nodes get an identity row/column in the normalized
    Laplacian (eigenvalue 1, not 0). Confirm that skipping
    ``isolated_point_num + 1`` eigenvectors drops exactly the trivial ones.
    """
    laplacian, n_isolated = calculate_normalized_laplacian(adj_mx)
    eigenvalues, eigenvectors = np.linalg.eig(laplacian.toarray())
    ascending = eigenvalues.argsort()
    eigenvectors = np.real(eigenvectors[:, ascending])
    first = n_isolated + 1
    return eigenvectors[:, first:first + lape_dim]

def compute_clustering_coefficient(A):
    """Compute a local clustering coefficient for every node.

    For node ``i`` with neighbour set ``N(i) = {j : A[i, j] > 0}`` and
    ``d = |N(i)|``, the value is ``sum(A[N(i), N(i)]) / (d * (d - 1))``.
    Nodes with fewer than two neighbours get 0. For an unweighted symmetric
    ``A`` without self-loops this is the standard local clustering
    coefficient. For weighted ``A``, edge weights are summed rather than counted.

    NOTE(review): a self-loop ``A[i, i] > 0`` puts ``i`` in its own neighbour
    set, which inflates the result relative to the textbook definition. Confirm
    whether callers pass matrices with self-loops added.

    Args:
        A: square adjacency matrix (ndarray, ``np.matrix`` or scipy sparse).

    Returns:
        ``(N, 1)`` float array of per-node coefficients.
    """
    # Work on a plain 2-D ndarray. With an ``np.matrix`` (which the ``todense``
    # helpers in this module return), ``A[i]`` stays 2-D, so
    # ``np.where(A[i] > 0)[0]`` would give row indices (all 0) instead of the
    # neighbour column indices.
    A = A.toarray() if sp.issparse(A) else np.asarray(A)
    N = A.shape[0]
    C = np.zeros((N, 1))
    for i in range(N):
        neighbors = np.where(A[i] > 0)[0]
        d = len(neighbors)
        if d < 2:
            continue  # C[i] stays 0.0
        sub = A[np.ix_(neighbors, neighbors)]
        C[i] = sub.sum() / (d * (d - 1))
    return C

def thresholded_gaussian_kernel(x, theta=None, threshold=None, threshold_on_input=False):
    """Apply a Gaussian kernel ``exp(-(x / theta)^2)`` element-wise.

    Args:
        x: array of values, e.g. distances.
        theta: kernel width. Defaults to the standard deviation of ``x``.
        threshold: optional cut-off. When given, some weights are zeroed:
            - ``threshold_on_input=True``: entries where ``x > threshold``;
            - otherwise: entries whose weight is ``< threshold``.
        threshold_on_input: selects which of the two rules above applies.

    Returns:
        Array of kernel weights with the same shape as ``x``.
    """
    width = np.std(x) if theta is None else theta
    weights = np.exp(-np.square(x / width))
    if threshold is None:
        return weights
    if threshold_on_input:
        drop = x > threshold
    else:
        drop = weights < threshold
    weights[drop] = 0.
    return weights



def get_adjacency_matrix(distance_df_filename, num_of_vertices, id_filename=None):
    """Load a graph from a ``.npy`` adjacency file or from an edge-list CSV.

    Args:
        distance_df_filename: path (str or path-like) to a ``.npy`` file holding
            the adjacency matrix, or to a CSV file. The CSV has one header line
            followed by ``from,to,distance`` rows; rows that do not have exactly
            three fields are skipped.
        num_of_vertices: number of nodes, i.e. the size of the returned matrices.
        id_filename: optional text file with one integer node id per line. The
            line order defines the 0-based matrix index of each id. When
            omitted, the CSV ids are used directly as matrix indices.

    Returns:
        ``(adj_mx, None)`` for ``.npy`` input. Otherwise ``(A, distanceA)``:
        float32 ``(n, n)`` matrices with ``A[i, j] = 1`` and
        ``distanceA[i, j] = distance`` for every listed edge ``i -> j``, and 0
        elsewhere.
    """
    import csv

    path = str(distance_df_filename)
    # Match on the extension, not a substring, so that a CSV stored under a path
    # that merely contains "npy" (e.g. a directory name) is not sent to np.load.
    if path.endswith('.npy'):
        adj_mx = np.load(path)
        return adj_mx, None

    n = int(num_of_vertices)
    A = np.zeros((n, n), dtype=np.float32)
    distanceA = np.zeros((n, n), dtype=np.float32)

    if id_filename:
        with open(id_filename, 'r') as f:
            # Map each node id to a 0-based index following the file order.
            id_dict = {int(i): idx for idx, i in enumerate(f.read().strip().split('\n'))}
        index_of = id_dict.__getitem__
    else:
        def index_of(node_id):
            # Without an id file the CSV ids already are matrix indices.
            return node_id

    with open(path, 'r') as f:
        f.readline()  # skip the header line
        for row in csv.reader(f):
            if len(row) != 3:
                continue
            i, j, distance = int(row[0]), int(row[1]), float(row[2])
            i, j = index_of(i), index_of(j)
            A[i, j] = 1
            distanceA[i, j] = distance
    return A, distanceA

def cal_adj(adj, adjtype='scalap'):
    """Build the list of graph matrices selected by ``adjtype``.

    Args:
        adj: square adjacency matrix (dense array or matrix).
        adjtype: one of
            - ``'scalap'``: rescaled normalized Laplacian;
            - ``'normlap'``: normalized Laplacian;
            - ``'symnadj'``: symmetrically normalized adjacency;
            - ``'transition'``: row-normalized adjacency;
            - ``'doubletransition'``: row-normalized ``adj`` and ``adj.T``;
            - ``'identity'``: identity matrix;
            - ``'auto'``: ``'transition'`` if ``adj`` equals its transpose,
              otherwise ``'doubletransition'``.

    Returns:
        List of dense float32 matrices: two for ``'doubletransition'``, one
        otherwise.

    Raises:
        AssertionError: if ``adjtype`` is not recognised.
    """
    if adjtype == 'auto':
        adjtype = 'transition' if (adj == adj.T).all() else 'doubletransition'

    if adjtype == "scalap":
        adj_mx = [calculate_scaled_laplacian(adj)]
    elif adjtype == "normlap":
        # calculate_normalized_laplacian returns (laplacian, isolated_point_num).
        lap, _ = calculate_normalized_laplacian(adj)
        adj_mx = [lap.astype(np.float32).todense()]
    elif adjtype == "symnadj":
        adj_mx = [sym_adj(adj)]
    elif adjtype == "transition":
        adj_mx = [asym_adj(adj)]
    elif adjtype == "doubletransition":
        adj_mx = [asym_adj(adj), asym_adj(np.transpose(adj))]
    elif adjtype == "identity":
        adj_mx = [np.diag(np.ones(adj.shape[0])).astype(np.float32)]
    else:
        # Raised explicitly rather than via `assert` so the check survives
        # `python -O`; the exception type is unchanged for existing callers.
        raise AssertionError("adj type not defined")
    return adj_mx

def load_adj_matrix(dataset_name):
    """Load the connectivity, distance-weighted and raw distance matrices of a dataset.

    Args:
        dataset_name: a key of ``pems_graph_file_path_dict``, ``'SeaLoop'``,
            ``'nrel_al'`` or ``'EPeMS'``.

    Returns:
        Tuple ``(A_link, A_Distance, full_distance)``:
        - ``A_link``: 0/1 connectivity matrix with self-loops on the diagonal.
        - ``A_Distance``: Gaussian-kernel (or pre-computed) edge weights with
          self-loops added.
        - ``full_distance``: the distance matrix as loaded from disk.

    Raises:
        NotImplementedError: for any other dataset name.
    """
    if dataset_name in pems_graph_file_path_dict:
        graph_file_path = pems_graph_file_path_dict[dataset_name]
        num_nodes = num_nodes_dict[dataset_name]
        A, full_distance = get_adjacency_matrix(
            distance_df_filename=graph_file_path,
            num_of_vertices=num_nodes)
        A_link = A + np.eye(A.shape[0])
        # The CSV lists only connected pairs, so every other entry of
        # full_distance is 0. Fed directly to the kernel, those zeros would get
        # the maximum weight 1, and the kernel width would be the std of a
        # mostly-zero matrix. Instead, treat unconnected pairs as infinitely far
        # (weight 0) and take the width from the edge distances only, as the
        # EPeMS branch below does.
        edge_distance = np.where(A > 0, full_distance, np.inf)
        finite_dist = edge_distance[np.isfinite(edge_distance)]
        # With no edges at all, every weight is 0 for any positive width.
        sigma = finite_dist.std() if finite_dist.size else 1.0
        A_Distance = thresholded_gaussian_kernel(edge_distance, theta=sigma, threshold=None)  # keep all weights
        A_Distance = A_Distance + np.eye(A.shape[0])

    elif dataset_name in ['SeaLoop', 'nrel_al']:
        A_Distance = np.load(graph_file_path_dict[dataset_name]).astype(float)
        full_distance = np.load(node_dist_file_path_dict[dataset_name])
        A_link = np.zeros_like(A_Distance)
        A_link[A_Distance > 0] = 1.0
        # Only the first diagonal entry is inspected to decide whether the
        # stored matrix already contains self-loops.
        if A_link[0, 0] < 1:
            A_link = A_link + np.eye(A_link.shape[0])
            A_Distance = A_Distance + np.eye(A_link.shape[0])

    elif dataset_name in ['EPeMS']:
        full_distance = np.load(f'{data_root_path}/graph/gla_dis.npy')
        # Unreachable pairs are stored as inf; exclude them from the kernel width.
        finite_dist = full_distance.reshape(-1)
        finite_dist = finite_dist[~np.isinf(finite_dist)]
        sigma = finite_dist.std()
        A_Distance = np.exp(-np.square(full_distance / sigma))
        A_Distance[A_Distance < 0.1] = 0.
        A_link = np.zeros_like(A_Distance)
        A_link[A_Distance > 0] = 1.0
        if A_link[0, 0] < 1:
            A_link = A_link + np.eye(A_link.shape[0])
            A_Distance = A_Distance + np.eye(A_link.shape[0])
    else:
        raise NotImplementedError(f'Unknown dataset: {dataset_name}')
    print('adj:', A_link.shape)
    return A_link, A_Distance, full_distance
