import torch
import torch.nn as nn
from torch_geometric.graphgym.register import register_node_encoder
from torch_geometric.utils import to_dense_adj

def dense_random_walk_matrix_batch(edge_index, batch, num_nodes, reverse=False):
    """Return the dense random-walk transition matrix of every graph in a batch.

    The batched dense adjacency ``A`` is built with ``to_dense_adj`` and each
    row is scaled by the reciprocal of its row sum (plus ``1e-6`` so rows
    without edges do not divide by zero). With ``reverse=True`` the adjacency
    is transposed first, so rows are normalised by in-degree and the walk
    follows edges backwards.

    Args:
        edge_index: Edge list forwarded unchanged to ``to_dense_adj``.
        batch: Node-to-graph assignment forwarded unchanged to ``to_dense_adj``.
        num_nodes: Padded number of nodes per graph; cast to ``int`` and passed
            as ``max_num_nodes``.
        reverse: If true, normalise the transposed adjacency instead.

    Returns:
        A tensor ``P`` shaped like the dense adjacency with
        ``P[b, i, j] = A[b, i, j] * 1 / (sum_j A[b, i, j] + 1e-6)``.
        Rows whose entries are all zero stay all zero.
    """
    num_nodes = int(num_nodes)
    adj = to_dense_adj(edge_index, batch, max_num_nodes=num_nodes)

    if reverse:
        # Materialise the transpose so the returned tensor is contiguous,
        # matching the layout the former ``bmm``-based version produced.
        adj = adj.transpose(1, 2).contiguous()

    # Scaling row i by 1 / deg(i) is what left-multiplying by diag(1 / deg)
    # does; broadcasting avoids building dense diagonal matrices and the
    # O(N^3) batched matmul while yielding the same values.
    inv_degree = 1.0 / (adj.sum(dim=2, keepdim=True) + 1e-6)
    return adj * inv_degree


def k_step_random_walk_batch(edge_index, batch, num_nodes, k=3, ppr_restart_p=None, reverse=False):
    """Stack the first ``k`` powers of the batched random-walk matrix.

    Args:
        edge_index: Edge list forwarded to ``dense_random_walk_matrix_batch``.
        batch: Node-to-graph assignment forwarded to the same helper.
        num_nodes: Padded number of nodes per graph, forwarded to the helper.
        k: Number of walk lengths to return. Values below 2 still yield the
            single one-step matrix.
        ppr_restart_p: Accepted for interface compatibility; not used here.
        reverse: Forwarded to the helper to walk along reversed edges.

    Returns:
        Tensor with the powers ``P, P^2, ..., P^k`` stacked along a new last
        dimension.
    """
    transition = dense_random_walk_matrix_batch(
        edge_index, batch, num_nodes, reverse
    )

    # Entry i of the list holds the (i + 1)-step transition probabilities.
    walk_powers = [transition.clone()]
    for _step in range(1, k):
        walk_powers.append(torch.bmm(walk_powers[-1], transition))

    return torch.stack(walk_powers, dim=-1)


def batch_edge_index_to_adj_fixed(edge_index, batch, num_nodes):
    """Return the batched dense adjacency padded to ``num_nodes`` per graph.

    Thin wrapper around ``to_dense_adj`` that passes ``num_nodes`` as
    ``max_num_nodes``.

    Args:
        edge_index: Edge list forwarded to ``to_dense_adj``.
        batch: Node-to-graph assignment forwarded to ``to_dense_adj``.
        num_nodes: Padded number of nodes per graph.

    Returns:
        The tensor returned by ``to_dense_adj``.
    """
    # The former ``batch.max().item() + 1`` computation was never used; it only
    # forced a host synchronisation and raised on an empty or ``None`` batch.
    return to_dense_adj(edge_index, batch, max_num_nodes=num_nodes)


class RWSEEdgeEncoder(torch.nn.Module):
    """Edge encoder built from k-step random-walk probabilities.

    For every edge ``(u, v)`` the feature vector is
    ``[P[u, v], P^2[u, v], ..., P^k[u, v]]`` with ``P`` the random-walk matrix
    of the graph containing the edge and ``k = self.step``. The vector is
    layer-normalised and linearly projected to ``emb_dim``; the result is
    written to ``batch.edge_attr``.

    Args:
        emb_dim: Output size of the linear projection.
        cfg: Config object; ``cfg.posenc_RWSE.dim_pe`` is read and stored as
            ``rwse_dim``.
        num_steps: Number of random-walk steps used as edge features
            (defaults to the previously hard-coded 7).

    Raises:
        ValueError: If ``num_steps`` is smaller than 1.
    """

    def __init__(self, emb_dim, cfg, num_steps=7):
        super().__init__()
        if num_steps < 1:
            raise ValueError(f"num_steps must be >= 1, got {num_steps}")
        self.step = num_steps
        self.encoder = torch.nn.Linear(self.step, emb_dim)
        # Stored from the config but not used by this module's computation.
        self.rwse_dim = cfg.posenc_RWSE.dim_pe
        # Despite the attribute name this is a LayerNorm; the name is kept so
        # existing state_dict keys remain valid.
        self.bn = nn.LayerNorm(self.step)

    @staticmethod
    def _max_nodes_per_graph(ptr):
        """Return the largest graph size implied by cumulative offsets ``ptr``."""
        sizes = ptr[1:] - ptr[:-1]
        return int(sizes.max()) if sizes.numel() else 0

    @staticmethod
    def _edge_local_indices(edge_index, node_batch, ptr):
        """Return (graph index, local source, local target) for every edge."""
        src_nodes, dst_nodes = edge_index[0], edge_index[1]
        # The graph of an edge is the graph of its source node.
        edge_graph = node_batch[src_nodes]
        offsets = ptr[edge_graph]
        return edge_graph, src_nodes - offsets, dst_nodes - offsets

    def forward(self, batch):
        """Compute random-walk edge features and store them in ``batch.edge_attr``.

        Reads ``batch.edge_index``, ``batch.batch`` and ``batch.ptr``; returns
        the same ``batch`` object with ``edge_attr`` overwritten.
        """
        ptr = batch.ptr
        # Pad every graph to the largest graph in the batch. Padding rows and
        # columns are all zero, so any padding >= the largest graph gives the
        # same per-edge values; deriving it from ``ptr`` also handles batches
        # whose graphs differ in size.
        max_nodes = self._max_nodes_per_graph(ptr)
        pos_enc = k_step_random_walk_batch(
            batch.edge_index, batch.batch, num_nodes=max_nodes, k=self.step
        )

        edge_graph, local_src, local_dst = self._edge_local_indices(
            batch.edge_index, batch.batch, ptr
        )

        # One row per edge holding its 1..k step walk probabilities.
        edge_feat = pos_enc[edge_graph, local_src, local_dst, :]

        edge_feat = self.bn(edge_feat)
        batch.edge_attr = self.encoder(edge_feat)
        return batch


# Lookup table from encoder name to encoder class.
# NOTE(review): presumably consumed by registration code elsewhere — confirm.
edge_encoder_dict = {
    "RWSEEdge": RWSEEdgeEncoder,
}