import numpy as np
from scipy.linalg import eigvals
import scipy.sparse as sp
from scipy.sparse.linalg import eigsh
import torch

def _gso_degrees(adj):
    """Return the row sums of ``adj`` (SciPy sparse or dense) as a 1-D float array."""
    # np.asarray(...).ravel() flattens both the (n, 1) np.matrix produced by
    # scipy sparse matrices / np.matrix and the 1-D result of ndarrays/sparse arrays.
    degrees = np.asarray(adj.sum(axis=1)).ravel()
    # NumPy refuses to raise integer arrays to negative integer powers.
    if not np.issubdtype(degrees.dtype, np.inexact):
        degrees = degrees.astype(np.float64)
    return degrees


def _gso_inv_power(degrees, exponent):
    """Return ``degrees ** exponent`` with infinite entries (zero degrees) set to 0."""
    with np.errstate(divide='ignore'):
        scale = np.power(degrees, exponent)
    scale[np.isinf(scale)] = 0.
    return scale


def _gso_diag_scale(adj, row_scale, col_scale=None):
    """Return ``diag(row_scale) @ adj``, optionally right-multiplied by ``diag(col_scale)``."""
    if sp.issparse(adj):
        scaled = sp.diags(row_scale, format='csc').dot(adj)
        if col_scale is not None:
            scaled = scaled.dot(sp.diags(col_scale, format='csc'))
        return scaled
    # Dense: scale by broadcasting rather than multiplying by dense diagonal
    # matrices (O(n^2) instead of O(n^3)); np.multiply stays element-wise even
    # when adj is an np.matrix.
    scaled = np.multiply(row_scale[:, np.newaxis], adj)
    if col_scale is not None:
        scaled = np.multiply(scaled, col_scale[np.newaxis, :])
    return scaled


def calc_gso(dir_adj, gso_type):
    """Build a graph shift operator (GSO) from a possibly directed adjacency matrix.

    The adjacency matrix is symmetrized with the element-wise maximum
    ``A = max(A_dir, A_dir^T)``. Then, depending on ``gso_type``:

    * ``*_renorm_*`` variants add self-loops (``A + I``) before normalizing;
    * ``sym_*`` variants compute ``D^{-1/2} A D^{-1/2}``;
    * ``rw_*`` variants compute ``D^{-1} A``;
    * ``*_lap`` variants return ``I - normalized_A``, ``*_adj`` variants
      return the normalized adjacency itself.

    ``D`` is the diagonal matrix of row sums of ``A``. Rows with zero degree
    get a scaling factor of 0 instead of ``inf``.

    Args:
        dir_adj: square adjacency matrix, either a SciPy sparse matrix/array
            or a dense NumPy array/matrix. Integer dtypes are accepted.
        gso_type: one of ``'sym_norm_adj'``, ``'sym_renorm_adj'``,
            ``'sym_norm_lap'``, ``'sym_renorm_lap'``, ``'rw_norm_adj'``,
            ``'rw_renorm_adj'``, ``'rw_norm_lap'``, ``'rw_renorm_lap'``.

    Returns:
        The GSO: a SciPy sparse matrix if ``dir_adj`` is sparse, otherwise a
        dense NumPy array (or np.matrix for np.matrix input).

    Raises:
        ValueError: if ``gso_type`` is not one of the supported names.
    """
    sym_types = ('sym_norm_adj', 'sym_renorm_adj', 'sym_norm_lap', 'sym_renorm_lap')
    rw_types = ('rw_norm_adj', 'rw_renorm_adj', 'rw_norm_lap', 'rw_renorm_lap')
    if gso_type not in sym_types and gso_type not in rw_types:
        raise ValueError(f'{gso_type} is not defined.')

    n_nodes = dir_adj.shape[0]
    if sp.issparse(dir_adj):
        eye = sp.identity(n_nodes, format='csc')
        # Symmetrize: element-wise maximum of A and A^T.
        adj = dir_adj.maximum(dir_adj.T)
    else:
        eye = np.identity(n_nodes)
        adj = np.maximum(dir_adj, dir_adj.T)

    # "renorm" variants add self-loops before normalizing.
    if gso_type in ('sym_renorm_adj', 'rw_renorm_adj', 'sym_renorm_lap', 'rw_renorm_lap'):
        adj = adj + eye

    degrees = _gso_degrees(adj)
    if gso_type in sym_types:
        # A_{sym} = D^{-0.5} * A * D^{-0.5}
        inv_sqrt = _gso_inv_power(degrees, -0.5)
        norm_adj = _gso_diag_scale(adj, inv_sqrt, inv_sqrt)
    else:
        # A_{rw} = D^{-1} * A
        norm_adj = _gso_diag_scale(adj, _gso_inv_power(degrees, -1))

    if gso_type.endswith('_lap'):
        return eye - norm_adj
    return norm_adj

def calc_chebynet_gso(gso):
    """Rescale a graph shift operator for Chebyshev polynomial filters (ChebyNet).

    Computes ``2 * gso / lambda_max - I``, where ``lambda_max`` is the largest
    real part among the eigenvalues of ``gso``. For a GSO whose eigenvalues
    are real and non-negative (e.g. a normalized Laplacian), this maps the
    spectrum into ``[-1, 1]``.

    Note: ``torch.norm(gso, p=2)`` on a matrix returns the Frobenius norm,
    which overestimates ``lambda_max``. An explicit eigenvalue solve is used
    instead.

    Args:
        gso: square matrix given as a SciPy sparse matrix/array, a dense NumPy
            array, or a ``torch.Tensor``.

    Returns:
        Dense ``torch.Tensor``. SciPy/NumPy inputs and integer tensors are
        converted to ``float32``; floating/complex tensors keep their dtype
        and device.

    Raises:
        ValueError: if the largest eigenvalue is not positive (the rescaling
            would divide by zero or flip the spectrum).
    """
    # Bring every supported input into a dense floating-point torch tensor.
    if sp.issparse(gso):
        gso = torch.tensor(gso.toarray(), dtype=torch.float32)
    elif isinstance(gso, np.ndarray):
        gso = torch.tensor(np.asarray(gso), dtype=torch.float32)
    if gso.layout != torch.strided:
        gso = gso.to_dense()
    if not (gso.is_floating_point() or gso.is_complex()):
        gso = gso.to(torch.float32)

    # Largest eigenvalue. The GSO need not be symmetric (e.g. random-walk
    # normalization), so use the general solver and take real parts.
    eigval_max = float(torch.linalg.eigvals(gso).real.max())
    if not eigval_max > 0:
        raise ValueError(
            f'Cannot rescale GSO: largest eigenvalue is {eigval_max}, expected a positive value.')

    eye = torch.eye(gso.shape[0], dtype=gso.dtype, device=gso.device)
    return 2 * gso / eigval_max - eye

def cnv_sparse_mat_to_coo_tensor(sp_mat, device):
    """Convert a matrix into a torch sparse COO tensor on ``device``.

    Supported inputs:

    * any SciPy sparse matrix/array (CSR, CSC, COO, ...) or a 2-D NumPy array
      with dtype float32/float64 (returned as ``float32``) or
      complex64/complex128 (returned as ``complex64``);
    * a ``torch.Tensor``. Sparse COO tensors are moved to ``device`` keeping
      their dtype. Any other layout (dense, sparse CSR, ...) is converted to
      COO and cast to ``float32``.

    Args:
        sp_mat: the matrix to convert.
        device: target torch device.

    Returns:
        A ``torch.sparse_coo_tensor`` on ``device``.

    Raises:
        TypeError: if the input type, or the dtype of a SciPy/NumPy input, is
            not supported.
    """
    # Dense 2-D NumPy arrays go through the SciPy path below.
    if isinstance(sp_mat, np.ndarray) and sp_mat.ndim == 2:
        sp_mat = sp.coo_matrix(sp_mat)

    if sp.issparse(sp_mat):
        if sp_mat.dtype in (np.complex64, np.complex128):
            dtype = torch.complex64
        elif sp_mat.dtype in (np.float32, np.float64):
            dtype = torch.float32
        else:
            raise TypeError(f'ERROR: The dtype of {sp_mat} is {sp_mat.dtype}, not been applied in implemented models.')

        sp_coo_mat = sp_mat.tocoo()
        indices = torch.from_numpy(np.vstack((sp_coo_mat.row, sp_coo_mat.col)).astype(np.int64))
        values = torch.from_numpy(sp_coo_mat.data)
        return torch.sparse_coo_tensor(indices=indices, values=values, size=torch.Size(sp_coo_mat.shape),
                                       dtype=dtype, device=device, requires_grad=False)

    if isinstance(sp_mat, torch.Tensor):
        if sp_mat.is_sparse:
            # Already a sparse COO tensor: only move it to the target device.
            return sp_mat.to_sparse().to(device)
        # Dense or non-COO sparse layout: convert to COO first, then cast, so
        # that values are selected at their original precision.
        return sp_mat.detach().to_sparse().to(device=device, dtype=torch.float32)

    raise TypeError(f"Unsupported input type: {type(sp_mat)}")

def calc_accuracy(output, labels):
    """Return the fraction of rows whose arg-max along dim 1 equals the label.

    Args:
        output: score tensor; the predicted class of each row is the index of
            its maximum along dimension 1.
        labels: tensor of true class indices, one per row of ``output``.

    Returns:
        A 0-dim float64 tensor holding ``#correct / len(labels)``.
    """
    predictions = output.argmax(dim=1).type_as(labels)
    hits = (predictions == labels).double().sum()
    return hits / len(labels)
