"""
GPU-accelerated NetLSD computation using PyTorch.

This module provides torch-based implementations of NetLSD heat and wave
kernel signatures that match the ground truth numpy implementations.
"""

import torch
import numpy as np
import networkx as nx
import netlsd  # CPU fallback


def compute_normalized_laplacian(A: torch.Tensor) -> torch.Tensor:
    """Return the symmetric normalized Laplacian ``D^-1/2 (D - A) D^-1/2``.

    ``D`` is the diagonal matrix of row sums of ``A``. Nodes whose row sum is
    not positive get a zero ``D^-1/2`` entry, so their row and column of the
    result are all zeros instead of inf/NaN.

    Args:
        A: Square ``(N, N)`` adjacency matrix. Integer and boolean inputs are
            promoted to the default floating-point dtype.

    Returns:
        The ``(N, N)`` normalized Laplacian, on the same device as ``A``.
    """
    # Integer/bool inputs would otherwise yield an integer degree vector, into
    # which the float inverse square roots below cannot be written.
    if not (A.is_floating_point() or A.is_complex()):
        A = A.to(torch.get_default_dtype())

    deg = A.sum(dim=1)

    # D^-1/2 with the convention 1/sqrt(0) := 0; masking avoids ever
    # evaluating 1/sqrt at a non-positive degree.
    sqrt_deg_inv = torch.zeros_like(deg)
    positive = deg > 0
    sqrt_deg_inv[positive] = 1.0 / torch.sqrt(deg[positive])

    L_comb = torch.diag(deg) - A
    # Same result as diag(s) @ L_comb @ diag(s): entry (i, j) is scaled by
    # s[i] and s[j]. Broadcasting does this in O(N^2) instead of two dense
    # O(N^3) matrix products.
    return sqrt_deg_inv[:, None] * L_comb * sqrt_deg_inv[None, :]


# Lazily built timescale grids. Each holds the (250, 1) tensor created for the
# most recently requested device, or None before the first request.
_HEAT_TIMESCALE = _WAVE_TIMESCALE = None


def _get_heat_timescale(device: torch.device) -> torch.Tensor:
    """Return the heat-kernel timescales as a (250, 1) column on ``device``.

    The 250 values are log-spaced from 1e-2 to 1e2. The tensor is cached at
    module level and rebuilt only when a different device is requested.
    """
    global _HEAT_TIMESCALE
    cached = _HEAT_TIMESCALE
    if cached is not None and cached.device == device:
        return cached
    _HEAT_TIMESCALE = torch.logspace(-2, 2, 250, device=device)[:, None]
    return _HEAT_TIMESCALE


def _get_wave_timescale(device: torch.device) -> torch.Tensor:
    """Return the wave-kernel timescales as a (250, 1) column on ``device``.

    The 250 values are evenly spaced from 0 to 2*pi. The tensor is cached at
    module level and rebuilt only when a different device is requested.
    """
    global _WAVE_TIMESCALE
    cached = _WAVE_TIMESCALE
    if cached is not None and cached.device == device:
        return cached
    _WAVE_TIMESCALE = torch.linspace(0, 2.0 * np.pi, 250, device=device)[:, None]
    return _WAVE_TIMESCALE


def _netlsd_heat_cpu_fallback(G: nx.Graph) -> np.ndarray:
    """Compute the heat signature of ``G`` with the reference ``netlsd`` package.

    Runs on the CPU and returns the signature as a float32 NumPy array.
    """
    signature = netlsd.heat(G)
    return np.asarray(signature, dtype=np.float32)


def _netlsd_wave_cpu_fallback(G: nx.Graph) -> np.ndarray:
    """Compute the wave signature of ``G`` with the reference ``netlsd`` package.

    Runs on the CPU and returns the signature as a float32 NumPy array.
    """
    signature = netlsd.wave(G)
    return np.asarray(signature, dtype=np.float32)


def netlsd_heat(A: torch.Tensor, G: nx.Graph = None) -> torch.Tensor:
    """Compute the NetLSD heat-kernel signature of a graph on ``A``'s device.

    The signature is ``h(t) = sum_j exp(-t * lambda_j) / N``. It is evaluated
    at the timescales from ``_get_heat_timescale``. Here ``lambda_j`` are the
    eigenvalues of the normalized Laplacian of ``A``, and ``N = A.shape[0]``.

    If ``torch.linalg.eigvalsh`` raises ``torch.linalg.LinAlgError``, the
    signature is computed on the CPU with the ``netlsd`` package instead. That
    path uses ``G`` when given, or otherwise a graph rebuilt from ``A`` by
    thresholding at 0.5. The fallback result is cast to the dtype the torch
    path would have produced, so both paths return the same dtype.

    Args:
        A: Square ``(N, N)`` adjacency matrix.
        G: Optional networkx graph equivalent to ``A``. It is only used by the
            CPU fallback.

    Returns:
        A 1-D signature tensor on ``A.device``.
    """
    device = A.device
    timescales = _get_heat_timescale(device)  # (250, 1)
    L = compute_normalized_laplacian(A)

    try:
        eivals = torch.linalg.eigvalsh(L)  # (N,)
    except torch.linalg.LinAlgError:
        # The eigensolver failed (e.g. did not converge); use the CPU reference.
        if G is None:
            # NOTE(review): thresholding at 0.5 drops edge weights, so for a
            # weighted A this fallback is not equivalent to the torch path.
            A_np = A.detach().cpu().numpy()  # detach: numpy() rejects grad-tracking tensors
            G = nx.from_numpy_array((A_np > 0.5).astype(int))
        sig = _netlsd_heat_cpu_fallback(G)
        # The torch path yields promote(timescales, eigenvalues) dtype; match it
        # instead of leaking the fallback's fixed float32.
        out_dtype = torch.promote_types(timescales.dtype, L.dtype)
        return torch.from_numpy(sig).to(device=device, dtype=out_dtype)

    # Broadcast (250, 1) * (1, N) -> (250, N), then sum over eigenvalues.
    hkt = torch.exp(-timescales * eivals.unsqueeze(0)).sum(dim=1)  # (250,)
    return hkt / A.shape[0]


def netlsd_wave(A: torch.Tensor, G: nx.Graph = None) -> torch.Tensor:
    """Compute the NetLSD wave-kernel signature of a graph on ``A``'s device.

    The signature is ``w(t) = Re(sum_j exp(-i * t * lambda_j)) / N``, which
    equals ``sum_j cos(t * lambda_j) / N``. It is evaluated at the timescales
    from ``_get_wave_timescale``. Here ``lambda_j`` are the eigenvalues of the
    normalized Laplacian of ``A``, and ``N = A.shape[0]``.

    If ``torch.linalg.eigvalsh`` raises ``torch.linalg.LinAlgError``, the
    signature is computed on the CPU with the ``netlsd`` package instead. That
    path uses ``G`` when given, or otherwise a graph rebuilt from ``A`` by
    thresholding at 0.5. The fallback result is cast to the dtype the torch
    path would have produced, so both paths return the same dtype.

    Args:
        A: Square ``(N, N)`` adjacency matrix.
        G: Optional networkx graph equivalent to ``A``. It is only used by the
            CPU fallback.

    Returns:
        A 1-D real signature tensor on ``A.device``.
    """
    device = A.device
    timescales = _get_wave_timescale(device)  # (250, 1)
    L = compute_normalized_laplacian(A)

    try:
        eivals = torch.linalg.eigvalsh(L)  # (N,)
    except torch.linalg.LinAlgError:
        # The eigensolver failed (e.g. did not converge); use the CPU reference.
        if G is None:
            # NOTE(review): thresholding at 0.5 drops edge weights, so for a
            # weighted A this fallback is not equivalent to the torch path.
            A_np = A.detach().cpu().numpy()  # detach: numpy() rejects grad-tracking tensors
            G = nx.from_numpy_array((A_np > 0.5).astype(int))
        sig = _netlsd_wave_cpu_fallback(G)
        # The torch path yields promote(timescales, eigenvalues) dtype; match it
        # instead of leaking the fallback's fixed float32.
        out_dtype = torch.promote_types(timescales.dtype, L.dtype)
        return torch.from_numpy(sig).to(device=device, dtype=out_dtype)

    # Re(exp(-i*t*lambda)) == cos(t*lambda). Computing cos directly avoids
    # building a complex (250, N) intermediate, and the result is the same
    # mathematically. Broadcast (250, 1) * (1, N) -> (250, N), then sum.
    wkt = torch.cos(timescales * eivals.unsqueeze(0)).sum(dim=1)  # (250,)
    return wkt / A.shape[0]


def nx_to_adjacency_tensor(G: nx.Graph, device: str = "cpu") -> torch.Tensor:
    """Return the dense adjacency matrix of ``G`` as a float32 tensor on ``device``.

    The matrix is built with ``nx.to_numpy_array`` and then moved to ``device``.
    """
    return torch.from_numpy(nx.to_numpy_array(G, dtype=np.float32)).to(device)

def netlsd_heat_batch(As: list[torch.Tensor], Gs: list[nx.Graph] = None) -> list[torch.Tensor]:
    """Apply ``netlsd_heat`` to each adjacency tensor in ``As``.

    Args:
        As: Adjacency tensors, one per graph.
        Gs: Optional graphs paired element-wise with ``As``. Each is passed to
            ``netlsd_heat`` as the graph for its CPU fallback.

    Returns:
        One signature tensor per entry of ``As``, in the same order.

    Raises:
        ValueError: If ``Gs`` is given and its length differs from ``As``.
    """
    if Gs is None:
        return [netlsd_heat(A) for A in As]
    As, Gs = list(As), list(Gs)
    # zip() would silently drop the unmatched tail and return fewer results.
    if len(As) != len(Gs):
        raise ValueError(
            f"netlsd_heat_batch: got {len(As)} adjacency tensors but {len(Gs)} graphs"
        )
    return [netlsd_heat(A, G) for A, G in zip(As, Gs)]


def netlsd_wave_batch(As: list[torch.Tensor], Gs: list[nx.Graph] = None) -> list[torch.Tensor]:
    """Apply ``netlsd_wave`` to each adjacency tensor in ``As``.

    Args:
        As: Adjacency tensors, one per graph.
        Gs: Optional graphs paired element-wise with ``As``. Each is passed to
            ``netlsd_wave`` as the graph for its CPU fallback.

    Returns:
        One signature tensor per entry of ``As``, in the same order.

    Raises:
        ValueError: If ``Gs`` is given and its length differs from ``As``.
    """
    if Gs is None:
        return [netlsd_wave(A) for A in As]
    As, Gs = list(As), list(Gs)
    # zip() would silently drop the unmatched tail and return fewer results.
    if len(As) != len(Gs):
        raise ValueError(
            f"netlsd_wave_batch: got {len(As)} adjacency tensors but {len(Gs)} graphs"
        )
    return [netlsd_wave(A, G) for A, G in zip(As, Gs)]
