import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from typing import Tuple
from argparse import Namespace
from torch_geometric.utils import degree, to_dense_adj
from torch_geometric.transforms import BaseTransform
from utils import add_pos_information
from torch_geometric.utils import add_self_loops
import scipy.linalg


class RandomWalkTransform(BaseTransform):
    """Attach random-walk return-probability features to a graph.

    With dense adjacency ``A`` and per-node degree vector ``d``, the
    transform forms ``P = A @ diag(1 / d)`` and stacks the diagonals of
    ``P, P^2, ..., P^max_walk_len`` into a ``(num_nodes, max_walk_len)``
    feature matrix. For an undirected (symmetric) ``edge_index``, column
    ``k - 1`` holds, for each node ``i``, the probability that a ``k``-step
    random walk starting at ``i`` ends back at ``i``.

    Nodes with degree 0 get all-zero features instead of making the degree
    matrix singular.
    """

    def __init__(self, max_walk_len: int):
        """
        Args:
            max_walk_len: Number of walk lengths (1..max_walk_len) to encode;
                must be at least 1.

        Raises:
            ValueError: If ``max_walk_len`` is smaller than 1.
        """
        if max_walk_len < 1:
            raise ValueError(f"max_walk_len must be >= 1, got {max_walk_len}")
        self.max_walk_len = max_walk_len

    def __call__(self, data: Data) -> Data:
        """Compute random-walk features for ``data`` and attach them.

        Mutates ``data`` in place by setting ``data['norm']`` (a float32
        ``(num_nodes, 1)`` tensor filled with ``1 / num_nodes``) and
        ``data['nodes_total']`` (the node count). It then returns whatever
        ``add_pos_information(data, rw_features)`` returns; presumably that
        helper stores ``rw_features`` on ``data`` (verify in ``utils``).
        """
        num_nodes = data.num_nodes
        data['norm'] = torch.full((num_nodes, 1), 1. / float(num_nodes), dtype=torch.float)
        data['nodes_total'] = num_nodes

        edge_index = data.edge_index
        # Number of edges leaving each node; equals the node degree when
        # edge_index stores both directions of every undirected edge.
        degree_vec = degree(edge_index[0], num_nodes)
        # Pin the dense size to num_nodes so trailing isolated nodes are not
        # dropped, and take [0] instead of squeeze() so a 1-node graph keeps
        # its (1, 1) shape.
        adj_mat = to_dense_adj(edge_index, max_num_nodes=num_nodes)[0]

        # A @ inv(D) just scales column j by 1 / d_j. Doing it elementwise
        # avoids an O(N^3) inverse and the singular-matrix error that
        # torch.linalg.inv raises when any node has degree 0; such columns
        # are set to zero instead.
        deg_inv = degree_vec.reciprocal().masked_fill(degree_vec == 0, 0.)
        ran_walk = adj_mat * deg_inv.unsqueeze(0)

        # diag(P^k) for k = 1..max_walk_len, without computing an unused
        # extra power after the last step.
        features = []
        walk_power = ran_walk
        for step in range(self.max_walk_len):
            if step > 0:
                walk_power = walk_power @ ran_walk
            features.append(torch.diagonal(walk_power))

        rw_features = torch.stack(features, dim=1)
        return add_pos_information(data, rw_features)
