# ------------------------ : new rwpse ----------------
from typing import Union, Any, Optional
import numpy as np
import torch
import torch.nn.functional as F
import torch_geometric as pyg
from torch_geometric.data import Data, HeteroData
from torch_geometric.transforms import BaseTransform
from torch_scatter import scatter, scatter_add, scatter_max

from torch_geometric.graphgym.config import cfg

from torch_geometric.utils import (
    get_laplacian,
    get_self_loop_attr,
    to_scipy_sparse_matrix,
)
import torch_sparse
from torch_sparse import SparseTensor


def add_node_attr(data: Data, value: Any,
                  attr_name: Optional[str] = None) -> Data:
    """Attach ``value`` to ``data`` and return the same ``data`` object.

    Args:
        data: Graph object that is modified in place.
        value: Payload to attach. When it is merged into ``data.x`` it must
            support ``.to(device, dtype)`` (i.e. behave like a tensor).
        attr_name: Key to store ``value`` under. When ``None`` the value is
            concatenated to ``data.x`` along the last dimension, or becomes
            ``data.x`` if the graph has no ``x`` yet.

    Returns:
        The ``data`` object that was passed in.
    """
    # Named attribute: plain item assignment, nothing to merge.
    if attr_name is not None:
        data[attr_name] = value
        return data

    # No existing node features: the value becomes the feature matrix as-is.
    if 'x' not in data:
        data.x = value
        return data

    features = data.x
    if features.dim() == 1:
        # Promote a 1-D feature vector to a single column so it can be
        # concatenated along the last dimension.
        features = features.view(-1, 1)

    # Align device and dtype with the existing features before merging.
    extra = value.to(features.device, features.dtype)
    data.x = torch.cat([features, extra], dim=-1)
    return data



def _walk_powers(adj: torch.Tensor, count: int):
    """Yield ``adj, adj @ adj, ...`` -- the first ``count`` powers of ``adj``."""
    out = adj
    yield out
    for _ in range(count - 1):
        out = out @ adj
        yield out


def _diag_normalize(mat: torch.Tensor) -> torch.Tensor:
    """Scale column ``j`` of ``mat`` by ``1 / mat[j, j]`` (0 where the diagonal is 0)."""
    diag = torch.diag(mat)
    diag_inv = torch.where(diag == 0, torch.zeros_like(diag), 1 / diag)
    return mat * diag_inv.view(1, -1)


@torch.no_grad()
def add_full_rrwp(data,
                  walk_length=8,
                  attr_name_abs="rrwp", # mode and node attribute name, e.g. 'rrwp'
                  attr_name_rel="rrwp", # prefix: ('rrwp_index', 'rrwp_val')
                  add_identity=True,
                  spd=False,
                  **kwargs
                  ):
    """Compute dense random-walk based positional encodings and store them on ``data``.

    The adjacency built from ``data.edge_index`` / ``data.edge_weight`` is
    scaled row-wise by the inverse of its row sums (``D^{-1} A``) and
    converted to a dense ``n x n`` matrix. A stack of ``k`` such matrices is
    then built, so memory grows as ``n * n * k``.

    Args:
        data: Graph object providing ``edge_index``, ``edge_weight`` and
            ``num_nodes``. It is modified in place.
        walk_length: Target number of channels ``k``. At least one power of
            the normalized adjacency is always emitted, so with
            ``add_identity=True`` and ``walk_length < 2`` two channels are
            produced.
        attr_name_abs: Selects the encoding *and* names the node-level
            attribute that is written:

            * ``"rrwp"`` -- channels are ``A, A^2, ...``.
            * ``"gkse"`` -- channels are the running sums
              ``I + A, I + A + A^2, ...``.
            * ``"mkse"`` -- the ``"gkse"`` running sums with every column
              divided by its diagonal entry.
        attr_name_rel: Prefix for the pairwise attributes, which are stored
            as ``f"{attr_name_rel}_index"`` and ``f"{attr_name_rel}_val"``.
        add_identity: If true, the identity matrix is prepended as the first
            channel.
        spd: If true, the pairwise values are replaced by a one-hot vector
            marking the first channel with a positive entry, and the
            node-level encoding is set to zeros.
        **kwargs: Accepted for call-site compatibility and ignored.

    Returns:
        ``data`` with these attributes set: ``attr_name_abs`` (``n x k``),
        ``f"{attr_name_rel}_index"`` (``2 x nnz``),
        ``f"{attr_name_rel}_val"``, ``log_deg`` and ``deg``.

    Raises:
        ValueError: If ``attr_name_abs`` is not one of the three modes.
    """
    # Fail before doing any O(n^2) work if the mode is unknown.
    if attr_name_abs not in ("rrwp", "gkse", "mkse"):
        raise ValueError(f"Invalid command: {attr_name_abs}. Valid commands are 'rrwp', 'gkse', and 'mkse'")

    num_nodes = data.num_nodes
    edge_index, edge_weight = data.edge_index, data.edge_weight

    adj = SparseTensor.from_edge_index(edge_index, edge_weight,
                                       sparse_sizes=(num_nodes, num_nodes),
                                       )

    # Compute D^{-1} A:
    deg = adj.sum(dim=1)
    deg_inv = 1.0 / deg
    deg_inv[deg_inv == float('inf')] = 0  # rows that sum to zero stay all-zero
    adj = adj * deg_inv.view(-1, 1)
    adj = adj.to_dense()

    # Every helper tensor must live where `adj` lives; otherwise the
    # additions and `torch.stack` below fail for graphs on a GPU.
    device = adj.device
    eye = torch.eye(num_nodes, dtype=torch.float, device=device)

    pe_list = [eye] if add_identity else []
    # One power of `adj` is always emitted; the rest fill up to `walk_length`
    # channels in total.
    num_powers = max(walk_length - len(pe_list), 1)

    if attr_name_abs == "rrwp":
        pe_list.extend(_walk_powers(adj, num_powers))
    else:
        # Running sum I + A + A^2 + ... ; never mutated in place, so sharing
        # `eye` with `pe_list` is safe.
        sum_out = eye
        for power in _walk_powers(adj, num_powers):
            sum_out = sum_out + power
            if attr_name_abs == "gkse":
                pe_list.append(sum_out)
            else:  # "mkse"
                pe_list.append(_diag_normalize(sum_out))

    if attr_name_abs == "mkse":
        # Node-level channels for "mkse" come from a shifted stack: a zero
        # channel followed by `adj @ m` for every pairwise channel but the last.
        pe_list_for_abs = [torch.zeros(num_nodes, num_nodes, device=device)]
        pe_list_for_abs.extend(adj @ m for m in pe_list[:-1])
    else:
        pe_list_for_abs = pe_list

    pe = torch.stack(pe_list, dim=-1) # n x n x k
    pe_for_abs = torch.stack(pe_list_for_abs, dim=-1)

    # Diagonal over the two node dimensions gives k x n; transpose to n x k.
    abs_pe = pe_for_abs.diagonal().transpose(0, 1) # n x k

    rel_pe = SparseTensor.from_dense(pe, has_value=True)
    rel_pe_row, rel_pe_col, rel_pe_val = rel_pe.coo()
    rel_pe_idx = torch.stack([rel_pe_row, rel_pe_col], dim=0)

    if spd:
        # Weights decrease with the channel index, so argmax selects the
        # first channel whose value is positive (channel 0 if none is).
        spd_idx = walk_length - torch.arange(walk_length,
                                             device=rel_pe_val.device)
        val = (rel_pe_val > 0).type(torch.float) * spd_idx.unsqueeze(0)
        val = torch.argmax(val, dim=-1)
        rel_pe_val = F.one_hot(val, walk_length).type(torch.float)
        abs_pe = torch.zeros_like(abs_pe)

    data = add_node_attr(data, abs_pe, attr_name=attr_name_abs)
    data = add_node_attr(data, rel_pe_idx, attr_name=f"{attr_name_rel}_index")
    data = add_node_attr(data, rel_pe_val, attr_name=f"{attr_name_rel}_val")
    data.log_deg = torch.log(deg + 1)
    data.deg = deg.type(torch.long)

    return data

