import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from typing import Tuple
from torch import sparse_coo_tensor
from argparse import Namespace

from torch_geometric.utils import degree, to_dense_adj, get_laplacian, sparse
from torch_geometric.transforms import BaseTransform
from utils import add_pos_information
from torch_geometric.utils import add_self_loops
from scipy.linalg import fractional_matrix_power
import numpy as np
import scipy as sp


class LaplacianTransform(BaseTransform):
    """Attach Laplacian-eigenvector positional encodings to a graph.

    For each graph the symmetric normalized Laplacian
    ``L = I - D^{-1/2} A D^{-1/2}`` is built from ``edge_index`` (``D`` counts
    edges by their source index ``edge_index[0]``). Its eigenvectors are
    sorted by increasing eigenvalue, the first one is skipped, and the next
    ``num_eigen_vec`` are passed to ``add_pos_information`` as per-node
    features.

    Args:
        num_eigen_vec: Number of eigenvector columns produced per node.
            Graphs too small to supply that many eigenvectors after the
            skipped first one are zero-padded, so every graph yields the
            same feature width.
    """

    def __init__(self, num_eigen_vec: int):
        self.num_eigen_vec = num_eigen_vec

    def __call__(self, data: Data) -> Data:
        """Compute the Laplacian positional encoding for ``data``.

        Side effects on ``data``: sets ``data['norm']`` to an ``(N, 1)``
        float32 column filled with ``1 / N`` and ``data['nodes_total']`` to
        ``N``, where ``N = data.num_nodes``.

        Returns:
            The result of ``add_pos_information(data, lap_features)``, where
            ``lap_features`` is a float32 tensor of shape
            ``(N, num_eigen_vec)``.
        """
        num_nodes = data.num_nodes
        data['norm'] = torch.full((num_nodes, 1), 1. / float(num_nodes), dtype=torch.float)
        data['nodes_total'] = num_nodes

        row, col = data.edge_index[0], data.edge_index[1]

        # Dense adjacency sized by num_nodes (not by the largest index found
        # in edge_index) so trailing isolated nodes are kept and a 1-node
        # graph stays a 1x1 matrix. Duplicate edges are summed, as in
        # to_dense_adj.
        adj = torch.zeros(num_nodes, num_nodes)
        adj.index_put_((row, col), torch.ones(row.numel()), accumulate=True)

        # D^{-1/2} as a vector. Isolated nodes (degree 0) are clamped to 1 to
        # avoid inf/NaN; their adjacency row and column are all zero, so
        # their Laplacian diagonal entry is 1 either way.
        deg_inv_sqrt = degree(row, num_nodes).clamp(min=1).pow(-0.5)
        norm_adj = deg_inv_sqrt.view(-1, 1) * adj * deg_inv_sqrt.view(1, -1)
        lap = torch.eye(num_nodes) - norm_adj

        # NOTE(review): for undirected graphs lap is symmetric and
        # np.linalg.eigh would be faster/more stable; eig is kept so directed
        # inputs behave as before.
        eig_val, eig_vec = np.linalg.eig(lap.numpy())
        order = eig_val.argsort()  # increasing eigenvalue
        eig_vec = np.real(eig_vec[:, order])

        # Skip the first (smallest-eigenvalue) eigenvector.
        lap_features = torch.from_numpy(eig_vec[:, 1:self.num_eigen_vec + 1]).float()

        # Small graphs have fewer than num_eigen_vec eigenvectors left;
        # zero-pad so the feature width is constant across graphs.
        missing = self.num_eigen_vec - lap_features.size(1)
        if missing > 0:
            lap_features = F.pad(lap_features, (0, missing))

        return add_pos_information(data, lap_features)
