"""
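This module provides factories for various positional encoding (PE) functions.
Each factory takes the encoding dimension ``dim`` and returns a callable that
maps a graph to a tensor of positional encodings.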

Functions:
    laplacian_pe_func(dim: int) -> Callable[[SparseGraph], torch.Tensor]:
        Generates Laplacian positional encodings for a given graph.
    
    laplacian_abs_pe_func(dim: int) -> Callable[[SparseGraph], torch.Tensor]:
        Generates absolute Laplacian positional encodings for a given graph.
    
    rw_pe_func(dim: int) -> Callable[[SparseGraph], torch.Tensor]:
        Generates random walk positional encodings for a given graph.
    
    gape_gatedgcn_18_pe_func(dim: int) -> Callable[[SparseGraph], torch.Tensor]:
        Generates GatedGCN positional encodings using a pre-trained model for a given graph.
    
    gape_gatedgcn_30_pe_func(dim: int) -> Callable[[SparseGraph], torch.Tensor]:
        Generates GatedGCN positional encodings using a different pre-trained model for a given graph.
    
    gape_gatedgcn_30_aqsol_pe_func(dim: int) -> Callable[[SparseGraph], torch.Tensor]:
        Generates GatedGCN positional encodings using another pre-trained model for a given graph.
    
    nope_pe_func(dim: int) -> Callable[[SparseGraph], torch.Tensor]:
        Generates zero positional encodings for a given graph.
"""

from ngab import SparseGraph
from ngab.models import LaplacianEmbeddings
from ngab.models import GatedGCN
from torch_geometric.transforms import AddRandomWalkPE
from torch_geometric.data import Data as PygData
from safetensors.torch import load_model
import torch

from ngab import BatchedSignals


def laplacian_pe_func(dim: int):
    """Build a Laplacian positional encoding function.

    A single ``LaplacianEmbeddings(dim)`` model is created here and reused by
    every call of the returned function.

    Args:
        dim: Value forwarded to the ``LaplacianEmbeddings`` constructor.

    Returns:
        A function that takes a ``SparseGraph`` and returns ``.x()`` of the
        model output computed for that graph.
    """
    encoder = LaplacianEmbeddings(dim)

    def pe_func(graph: SparseGraph) -> torch.Tensor:
        node_count = graph.order()
        # Placeholder input signal: a column of ones (one row per node) plus
        # an all-zero vector of the same length.
        # NOTE(review): the zero vector presumably assigns every node to
        # batch 0 -- confirm against BatchedSignals.
        ones_column = torch.ones((node_count, 1))
        zero_vector = torch.zeros((node_count,))
        signal = BatchedSignals(ones_column, zero_vector)
        encoded = encoder.forward(signal, graph.to_batch())
        return encoded.x()

    return pe_func

def laplacian_abs_pe_func(dim: int):
    """Build an absolute-value Laplacian positional encoding function.

    Works like the plain Laplacian variant, except that the element-wise
    absolute value of the model output is returned.

    Args:
        dim: Value forwarded to the ``LaplacianEmbeddings`` constructor.

    Returns:
        A function that takes a ``SparseGraph`` and returns
        ``torch.abs`` of ``.x()`` of the model output for that graph.
    """
    encoder = LaplacianEmbeddings(dim)

    def pe_func(graph: SparseGraph) -> torch.Tensor:
        node_count = graph.order()
        # Placeholder input signal: a column of ones (one row per node) plus
        # an all-zero vector of the same length.
        # NOTE(review): the zero vector presumably assigns every node to
        # batch 0 -- confirm against BatchedSignals.
        ones_column = torch.ones((node_count, 1))
        zero_vector = torch.zeros((node_count,))
        signal = BatchedSignals(ones_column, zero_vector)
        raw_encoding = encoder.forward(signal, graph.to_batch()).x()
        return torch.abs(raw_encoding)

    return pe_func

def rw_pe_func(dim: int):
    """Build a random walk positional encoding function.

    The returned function wraps the graph's edge index in a PyG ``Data``
    object and runs ``AddRandomWalkPE(walk_length=dim)`` on it.

    Args:
        dim: Walk length passed to ``AddRandomWalkPE``; also the expected
            number of columns of the resulting encoding.

    Returns:
        A function that takes a ``SparseGraph`` and returns the
        ``random_walk_pe`` tensor produced by the transform.
    """
    rwpe_model = AddRandomWalkPE(walk_length=dim)

    def pe_func(graph: SparseGraph) -> torch.Tensor:
        # Pass the node count explicitly. Without it PyG infers the number of
        # nodes from the largest index in ``edge_index``, which drops rows for
        # isolated highest-index nodes and yields zero nodes for edgeless
        # graphs, so the encoding would not line up with ``graph.order()``.
        data = PygData(
            edge_index=graph.edge_index(),
            num_nodes=graph.order(),
        )
        data = rwpe_model(data)
        # Internal sanity check on the transform's output width.
        assert data.random_walk_pe.shape[1] == dim
        return data.random_walk_pe

    return pe_func

def gape_gatedgcn_18_pe_func(dim: int):
    """Build a positional encoding function backed by a pre-trained GatedGCN.

    A ``GatedGCN(4, 48, 32)`` model is created, its weights are loaded from a
    safetensors file, and it is switched to eval mode with gradients disabled
    before being captured by the returned function.

    Args:
        dim: Must be 32; any other value fails the assertion.

    Returns:
        A function that takes a ``SparseGraph`` and returns ``.x()`` of the
        model output computed for that graph.
    """
    assert dim == 32
    weights_path = "path/to/gape_gatedgcn_18.safetensors"
    encoder = GatedGCN(4, 48, 32)
    load_model(encoder, weights_path)
    # Inference only: eval mode, no gradient tracking on the parameters.
    encoder = encoder.eval()
    encoder.requires_grad_(False)

    def pe_func(graph: SparseGraph) -> torch.Tensor:
        node_count = graph.order()
        # Placeholder input signal: a column of ones (one row per node) plus
        # an all-zero vector of the same length.
        ones_column = torch.ones((node_count, 1))
        zero_vector = torch.zeros((node_count,))
        signal = BatchedSignals(ones_column, zero_vector)
        encoded = encoder.forward(signal, graph.to_batch())
        return encoded.x()

    return pe_func

def nope_pe_func(dim: int):
    """Build a positional encoding function that returns all zeros.

    Args:
        dim: Number of columns of the returned tensor.

    Returns:
        A function that takes a ``SparseGraph`` and returns a zero tensor of
        shape ``(graph.order(), dim)``.
    """

    def pe_func(graph: SparseGraph) -> torch.Tensor:
        shape = (graph.order(), dim)
        return torch.zeros(shape)

    return pe_func
