# eigvec_util.py 

import torch
import numpy as np
import scipy.sparse as sparse
from torch_geometric.utils import to_dense_adj, to_dense_batch

from typing import List

import torch.nn.functional as F


from torch_geometric.data import InMemoryDataset, Data

import matplotlib.pyplot as plt

def get_masked_laplacian(edge_index, batch_idx, keep):
    """Return the Laplacian of the subgraph induced by the nodes selected in ``keep``.

    Retained nodes (``keep[i]`` True) are relabelled 0..K-1 in their original order.
    Only edges whose two endpoints are both retained survive. The result is
    ``get_lap`` applied to the relabelled adjacency built by
    ``edge_index_to_sparse_adj``.

    Args:
        edge_index: (2, E) integer tensor of edges over all N nodes.
        batch_idx: (N,) tensor; only its length, dtype and device are used here.
        keep: (N,) bool mask of nodes to retain (e.g. the output of ``eigen_mask``).

    Returns:
        The (K, K) Laplacian of the induced subgraph, where K = number of kept nodes.
    """
    num_nodes = batch_idx.shape[0]
    device = edge_index.device

    # Original ids of the retained nodes, in ascending order.
    kept_nodes = torch.arange(num_nodes, dtype=batch_idx.dtype, device=batch_idx.device)[keep]

    # An edge survives only if both of its endpoints are retained.
    edge_mask = keep[edge_index[0]] & keep[edge_index[1]]
    kept_edges = edge_index[:, edge_mask]

    # Old id -> new contiguous id (dropped nodes map to -1).
    new_id = torch.full((num_nodes,), -1, dtype=torch.long, device=device)
    new_id[kept_nodes] = torch.arange(kept_nodes.numel(), device=device)
    edge_index_new = new_id[kept_edges]  # shape [2, E']

    # Internal invariant: surviving edges never touch a dropped node. The numel()
    # guard matters because .min() on an empty tensor raises, which used to crash
    # whenever the kept subgraph had no edges.
    assert edge_index_new.numel() == 0 or edge_index_new.min().item() >= 0

    adj = edge_index_to_sparse_adj(edge_index_new, kept_nodes.numel())
    return get_lap(adj)

def eigen_mask(min_nodes, batch_idx, max_nodes=None):
    """Build a per-node mask that drops nodes belonging to too-small or too-large graphs.

    Args:
        min_nodes: graphs with fewer than this many nodes are dropped.
        batch_idx: (N,) non-negative integer tensor giving the graph index of every node.
        max_nodes: optional; graphs with more than this many nodes are dropped.
            ``None`` disables the upper bound.

    Returns:
        (N,) bool tensor on ``batch_idx.device``; True where the node is kept.
    """
    graph_sizes = torch.bincount(batch_idx)

    keep = torch.ones(batch_idx.shape[0], dtype=torch.bool, device=batch_idx.device)

    # Large-graph filtering is meant to apply to EFFECTIVE graph sizes (incl. virtual
    # node), as this is just the input dim to the concat MLP.
    if max_nodes is not None:
        too_big_graphs = graph_sizes > max_nodes
        keep &= ~too_big_graphs[batch_idx]

    too_small_graphs = graph_sizes < min_nodes
    keep &= ~too_small_graphs[batch_idx]

    return keep


def eigvec_forward(linear, h, batch, max_nodes, device):
    """Run a whole-graph ``linear`` map over padded node features, return per-node rows.

    Node features are packed per graph with ``to_dense_batch`` (padded to
    ``max_nodes``). Each graph's padded block is flattened to one vector of length
    ``max_nodes * F`` and passed through ``linear``. The output is reshaped to
    (B, max_nodes, C), and only the rows of real (non-padding) nodes are kept.

    Args:
        linear: callable mapping (B, max_nodes * F) -> (B, out_dim); out_dim must be
            divisible by max_nodes for the final reshape.
        h: (total_nodes, F) node features.
        batch: (total_nodes,) graph index per node.
        max_nodes: padded node count per graph.
        device: unused; kept for call-site compatibility.

    Returns:
        (total_nodes, out_dim // max_nodes) tensor.
    """
    padded, node_mask = to_dense_batch(h, batch, max_num_nodes=max_nodes)
    num_graphs, nodes_per_graph, feat_dim = padded.size()

    graph_vectors = padded.view(num_graphs, nodes_per_graph * feat_dim)
    graph_out = linear(graph_vectors)

    per_node = graph_out.view(num_graphs, nodes_per_graph, -1)
    return per_node[node_mask]

def wavelet_positional_emb(data: Data, filters, num_scales=10, lazy_parameter=0.5, num_nodes=3):
    """Embed each node by its response to wavelet filters applied to random Dirac signals.

    ``num_nodes`` distinct anchor nodes are drawn uniformly at random with
    ``torch.randperm``. A one-hot (Dirac) signal is placed on each anchor, and each
    of the first ``num_scales`` filters is applied to all signals at once.

    Args:
        data: graph; only ``data.num_nodes`` and ``data.edge_index.device`` are read.
        filters: sequence of (N, N) filter matrices (e.g. from ``generate_wavelet_bank``).
        num_scales: number of leading filters to use.
        lazy_parameter: unused; accepted for signature compatibility.
        num_nodes: number of random anchor nodes (Dirac signals).

    Returns:
        Tensor of shape (N, S * num_nodes), where S = len(filters[:num_scales]).
        Columns are grouped by scale, then by anchor. If the graph has fewer than
        ``num_nodes`` nodes, every node becomes an anchor and the surplus columns
        stay zero, so the width is always S * num_nodes.
    """
    n = data.num_nodes
    device = data.edge_index.device

    # randperm(n)[:num_nodes] yields min(n, num_nodes) anchors. The column index
    # must be capped the same way, otherwise small graphs raise a shape mismatch.
    anchors = torch.randperm(n)[:num_nodes].to(device)
    signal = torch.zeros(n, num_nodes, device=device)
    signal[anchors, torch.arange(anchors.numel(), device=device)] = 1.0

    # One (N, num_nodes) response per scale, concatenated along columns.
    return torch.cat([W @ signal for W in filters[:num_scales]], dim=1)

def generate_wavelet_bank(data: Data, num_scales=10, lazy_parameter=0.5, abs_val = False):
    """Build dyadic diffusion-wavelet filters from the graph's lazy random-walk operator.

    With A the adjacency and D = diag(column sums of A):
        P = A D^{-1}                          (columns with zero sum stay zero)
        T = lazy_parameter * I + (1 - lazy_parameter) * P
        filter_i = T^(2^i) - T^(2^(i+1)),     i = 0 .. num_scales - 1

    Args:
        data: graph; ``data.edge_index`` and ``data.num_nodes`` are used.
        num_scales: number of filters to produce.
        lazy_parameter: weight of the identity in the lazy operator.
        abs_val: if True, return element-wise absolute values of the filters.

    Returns:
        Tuple of ``num_scales`` dense (N, N) tensors on the edge_index device.
    """
    adj = edge_index_to_sparse_adj(data.edge_index, data.num_nodes).to_dense()
    num_nodes = adj.size(0)

    # A @ D^{-1} == scale column j by 1/deg_j. This avoids an O(N^3) inverse, which
    # also raised for any node with zero column sum (singular D).
    deg = adj.sum(dim=0)
    deg_inv = torch.zeros_like(deg)
    has_edges = deg > 0
    deg_inv[has_edges] = 1.0 / deg[has_edges]
    diff_op = adj * deg_inv.unsqueeze(0)

    # Allocate the identity next to adj (the old torch.eye(N) was always on CPU).
    eye = torch.eye(num_nodes, dtype=adj.dtype, device=adj.device)
    power = lazy_parameter * eye + (1 - lazy_parameter) * diff_op

    filters = []
    for _ in range(num_scales):
        next_power = power @ power  # iterative squaring: T^(2^i) -> T^(2^(i+1))
        wavelet_filter = power - next_power
        if abs_val:
            wavelet_filter = torch.abs(wavelet_filter)
        filters.append(wavelet_filter)
        power = next_power

    return tuple(filters)

def random_walk_positional_encoding(adj: torch.Tensor,
                                    deg: torch.Tensor,
                                    num_hops: int) -> torch.Tensor:
    """
    Compute k-step Random-Walk Positional Encodings (RWPE).

    Args:
        adj (torch.Tensor):     Adjacency matrix of shape (N, N), dense float tensor.
        deg (torch.Tensor):     Degree vector of shape (N,), where deg[i] = sum_j adj[i,j].
        num_hops (int):         Number of RW hops (k).

    Returns:
        torch.Tensor:  Positional encodings of shape (N, num_hops),
                       where entry [i,h] = (P^(h+1))_{ii} for h = 0..k-1, and
                       P = D^{-1} A is the row-stochastic transition matrix.
                       Nodes with non-positive degree get an all-zero row of P,
                       hence zero return probability at every hop.
    """
    if num_hops == 0:
        return torch.zeros(adj.shape[0], 0, device=adj.device)

    # Invert degrees; nodes without positive degree keep 0 (no division by zero).
    deg_inv = torch.zeros_like(deg)
    positive = deg > 0
    deg_inv[positive] = 1.0 / deg[positive]

    # One-step transition matrix: P[i, j] = A[i, j] / deg[i].
    P = adj * deg_inv.unsqueeze(1)

    # Return probabilities diag(P^h) for h = 1..k. Only k-1 further products are
    # needed; the old loop computed one extra, unused power.
    M = P
    rwpe_list = [torch.diagonal(M)]
    for _ in range(num_hops - 1):
        M = M @ P
        rwpe_list.append(torch.diagonal(M))
    return torch.stack(rwpe_list, dim=1)

class PositionalWaveletTransform:
    """Dataset transform that sets ``data.post_positional`` from a wavelet embedding.

    The post-positional features are fed only to the MLP prediction head. NOTE:
    the embedding is currently replaced with zeros of the same shape (temporary
    sanity check), so only its shape reaches the model. The wavelet bank and
    random anchors are still computed.
    """

    def __init__(self, num_scales=5, num_nodes=3, lazy_parameter=0.5):
        self.num_scales = num_scales
        self.num_nodes = num_nodes
        self.lazy_parameter = lazy_parameter

    def __call__(self, data: Data) -> Data:
        bank = generate_wavelet_bank(
            data,
            num_scales=self.num_scales,
            lazy_parameter=self.lazy_parameter,
        )
        embedding = wavelet_positional_emb(
            data,
            bank,
            num_scales=self.num_scales,
            lazy_parameter=self.lazy_parameter,
            num_nodes=self.num_nodes,
        )
        # TEMP: sanity check -- keep only the shape of the wavelet embedding.
        data.post_positional = torch.zeros_like(embedding)
        return data


class PostTransformSanityCheck:
    """Transform that sets ``data.post_positional`` to zeros of shape (num_nodes, dim).

    The node count is taken from ``data.x.shape[0]``, and the tensor is placed on
    ``data.x``'s device with the default float dtype.
    """

    def __init__(self, dim=0):
        self.dim = dim

    def __call__(self, data: Data) -> Data:
        node_count = data.x.shape[0]
        # TEMP: sanity check
        data.post_positional = torch.zeros(node_count, self.dim, device=data.x.device)
        return data

class RandomWalkPETransform:
    """Transform that stores random-walk positional encodings in ``data.pre_positional``.

    The encodings are given to the GNN at its very first layer. When
    ``on_substruct`` is True, ``data.edge_index_substruct`` / ``data.x_substruct``
    are used instead of ``data.edge_index`` / ``data.x``.

    NOTE(review): a graph with no edges at all gets all-ones encodings, while an
    isolated node inside a graph that does have edges gets zeros (its transition
    row is zero in ``random_walk_positional_encoding``). Confirm this asymmetry is
    intended.
    """

    def __init__(self, walk_length=20, on_substruct=False):
        self.walk_length = walk_length
        self.on_substruct = on_substruct

    def __call__(self, data: Data) -> Data:
        if self.on_substruct:
            edge_index, x = data.edge_index_substruct, data.x_substruct
        else:
            edge_index, x = data.edge_index, data.x
        num_nodes = x.shape[0]

        if edge_index.numel() == 0:
            data.pre_positional = torch.ones(num_nodes, self.walk_length, device=x.device)
            return data

        dense_adj = to_dense_adj(edge_index, max_num_nodes=num_nodes)[0]
        data.pre_positional = random_walk_positional_encoding(
            dense_adj, dense_adj.sum(dim=1), self.walk_length
        )
        return data

# DATALOADING UTILS: 
class EigvecPretransform2:
    """Transform that attaches the first ``num_eigvecs`` Laplacian eigenpairs to a graph.

    Sets ``data.eigvecs`` with shape (num_nodes, num_eigvecs) and ``data.eigvals``
    with shape (1, num_eigvecs), taken in ``torch.linalg.eigh`` order (ascending
    eigenvalues). Graphs with fewer than ``num_eigvecs`` nodes are zero-padded:
    the missing eigenvalues are 0 and the missing eigenvector columns are 0.
    """

    def __init__(self, num_eigvecs=5):
        self.num_eigvecs = num_eigvecs

    def __call__(self, data: Data) -> Data:
        adj = edge_index_to_sparse_adj(data.edge_index, data.num_nodes)
        evals, evecs = torch.linalg.eigh(get_lap(adj))  # (n,), (n, n)

        k = self.num_eigvecs
        deficit = k - adj.shape[0]
        if deficit > 0:
            evals = F.pad(evals, (0, deficit))
            evecs = F.pad(evecs, (0, deficit))  # pads columns only

        data.eigvecs = evecs[:, :k]
        data.eigvals = evals[:k].unsqueeze(0)
        return data

class EigvecPretransform:
    """Transform that attaches the full, zero-padded Laplacian spectrum to a graph.

    Sets ``data.eigvecs`` (evec_len, evec_len) and ``data.eigvals`` (evec_len,)
    via ``get_padded_eigvecs``.

    Raises:
        ValueError: if the graph has more than ``evec_len`` nodes.
    """

    def __init__(self, evec_len):
        self.evec_len = evec_len

    def __call__(self, data: Data) -> Data:
        adj = edge_index_to_sparse_adj(data.edge_index, data.num_nodes)
        size = adj.shape[0]
        if size > self.evec_len:
            raise ValueError("num_nodes larger than evec_len")

        padded_vals, padded_vecs = get_padded_eigvecs(adj, self.evec_len)
        data.eigvecs, data.eigvals = padded_vecs, padded_vals
        return data

class EigvecPrefilter:
    """Dataset pre-filter that keeps graphs whose node count lies in [min_size, max_size]."""

    def __init__(self, min_size, max_size):
        self.min_size = min_size
        self.max_size = max_size

    def __call__(self, data: Data) -> bool:
        """Return True iff ``min_size <= data.num_nodes <= max_size`` (bounds inclusive).

        The return type is bool (the old annotation claimed ``Data``).
        """
        return self.min_size <= data.num_nodes <= self.max_size


def get_padded_eigvecs(adj: torch.Tensor, max_graph_size: int):
    """Return the Laplacian eigenpairs of ``adj`` resized to ``max_graph_size``.

    The spectrum comes from ``torch.linalg.eigh(get_lap(adj))``. For a graph with
    n <= max_graph_size nodes, eigenvalues are zero-padded at the end, and the
    eigenvector matrix is zero-padded with extra rows and columns. For a larger
    graph, only the leading entries / top-left block are kept.

    Returns:
      evals:  Tensor of shape (max_graph_size,)
      evecs: Tensor of shape (max_graph_size, max_graph_size)
    """
    evals, evecs = torch.linalg.eigh(get_lap(adj))  # (n,), (n, n)
    n = evals.size(0)

    if n > max_graph_size:
        return evals[:max_graph_size], evecs[:max_graph_size, :max_graph_size]

    # Copy the real spectrum into zero-initialised buffers of the target size.
    evals_padded = evals.new_zeros(max_graph_size)
    evals_padded[:n] = evals

    evecs_padded = evecs.new_zeros(max_graph_size, max_graph_size)
    evecs_padded[:n, :n] = evecs

    return evals_padded, evecs_padded


def edge_index_to_sparse_adj(edge_index: torch.LongTensor, num_nodes: int) -> torch.Tensor:
    """Build a coalesced (num_nodes, num_nodes) float32 sparse COO adjacency.

    Every column (src, dst) of ``edge_index`` contributes weight 1 at [src, dst].
    Edges are used exactly as given (no reverse edges are added), and duplicate
    edges are summed by ``coalesce()``.
    """
    src, dst = edge_index
    weights = torch.ones(src.size(0), dtype=torch.float32, device=edge_index.device)
    coords = torch.stack([src, dst], dim=0)
    shape = (num_nodes, num_nodes)
    return torch.sparse_coo_tensor(coords, weights, shape).coalesce()

def get_lap(adj):
    """Return the combinatorial Laplacian ``D - adj``.

    D is the diagonal matrix of column sums of ``adj`` (sparse or dense input).
    """
    col_sums = adj.to_dense().sum(dim=0)
    return torch.diag(col_sums) - adj


def normalize_by_batch(x, batch):
    """Scale each column of ``x`` to unit L2 norm separately within every batch group.

    For group g and column j the divisor is sqrt(sum_{i: batch[i]==g} x[i, j]^2 + 1e-6).
    The 1e-6 keeps all-zero columns finite.
    """
    n_groups = int(batch.max()) + 1
    sq_sums = torch.zeros(n_groups, x.shape[-1], dtype=x.dtype, device=x.device)
    sq_sums = sq_sums.index_add(0, batch, x.pow(2))
    group_norms = torch.sqrt(sq_sums + 1e-6)
    return x / group_norms[batch]

def orthogonalize_by_batch_slow(x, batch):
    """
    Orthonormalize the columns of each batch group's block of rows (per-group QR loop).

    x:     (m, k) tensor; rows belonging to the same group form one (m_g, k) block.
    batch: (m,) long tensor with values in {0, ..., num_groups-1}.
    returns: x_orth of the same shape as x. For each group g, the rows
             x[batch == g] are replaced by the reduced-QR factor Q of that block,
             whose columns are orthonormal vectors in R^{m_g}.
             If m_g < k, QR yields only m_g columns; the remaining k - m_g
             columns are filled with zeros so the output keeps x's shape.
    """
    if x.shape[0] == 0:
        return torch.empty_like(x)

    num_groups = int(batch.max().item()) + 1
    k = x.shape[-1]
    x_orth = torch.empty_like(x)

    for g in range(num_groups):
        mask = batch == g
        if not bool(mask.any()):
            # Group id with no rows: nothing to write (QR of a 0-row block
            # would not match the (0, k) target shape).
            continue
        Q, _ = torch.linalg.qr(x[mask])  # reduced QR: (m_g, min(m_g, k))
        if Q.shape[1] < k:
            padded = Q.new_zeros(Q.shape[0], k)
            padded[:, :Q.shape[1]] = Q
            Q = padded
        x_orth[mask] = Q

    return x_orth

def orthogonalize_by_batch(x, batch, max_nodes=50):
    """Orthonormalize the columns of each graph's block of rows with one batched QR.

    Rows are packed per graph with ``to_dense_batch`` into (B, m_max, k), presumably
    zero-padded (to_dense_batch's default fill). A reduced QR is taken per graph,
    and the rows of real nodes are unpacked back to (m, k). For a graph with at
    least k rows and full column rank, the padded rows do not affect its
    orthonormal basis.

    Args:
        x: (m, k) tensor of stacked per-graph vectors.
        batch: (m,) graph index per row.
        max_nodes: unused; kept for backward compatibility.

    Returns:
        (m, k) tensor. If every graph has fewer than k rows, reduced QR yields only
        m_max columns; these are zero-padded to k so the output keeps x's shape.
        For graphs with fewer than k rows, the columns are generally not
        orthonormal.
    """
    X_dense, mask = to_dense_batch(x, batch)  # (B, m_max, k), (B, m_max)

    Q, _ = torch.linalg.qr(X_dense)  # reduced: (B, m_max, min(m_max, k))

    k = x.shape[-1]
    if Q.shape[-1] < k:
        # Previously the output silently lost columns here.
        Q = F.pad(Q, (0, k - Q.shape[-1]))

    # Unpack the real-node rows back to (m, k).
    return Q[mask]

def ortho_loss_by_batch(eigvecs, batch):
    """Mean over graphs of the Frobenius norm of the off-diagonal Gram matrix.

    eigvecs: (total_nodes, k) tensor whose rows are grouped into graphs by ``batch``.
    Each graph's rows are packed with ``to_dense_batch``; padding rows are
    presumably zeros (to_dense_batch default fill) and so add nothing to its
    Gram matrix V^T V (k x k). The diagonal (column norms) is ignored.
    returns: scalar tensor, mean over graphs of ||offdiag(V^T V)||_F.
    """
    dense_vecs, _ = to_dense_batch(eigvecs, batch)  # (B, max_nodes, k)
    gram = dense_vecs.transpose(-2, -1) @ dense_vecs  # (B, k, k)

    n_cols = gram.size(-1)
    off_diag_mask = 1.0 - torch.eye(n_cols, device=gram.device, dtype=gram.dtype)
    off_diag = gram * off_diag_mask  # broadcast over the batch dimension

    per_graph = torch.norm(off_diag, dim=[-1, -2])
    return per_graph.mean()


# LOSS FUNCTIONS 

def EnergyLoss(eigvecs, adj, weights=None):
    """Weighted Dirichlet energy ``sum_j w_j * v_j^T L v_j`` of the columns of ``eigvecs``.

    Equals trace(V^T L V diag(w)), with L = D - A and D = diag(row sums of adj).

    Args:
        eigvecs: (N, k) dense tensor, one candidate vector per column.
        adj: (N, N) adjacency, either a sparse COO tensor or a dense tensor.
        weights: optional (k,) per-column weights; defaults to ones with
            ``eigvecs``' dtype on ``adj``'s device.

    Returns:
        0-dim tensor with the weighted energy.
    """
    device = adj.device

    # Row-sum degree vector; sparse and dense adjacencies are both accepted.
    if adj.is_sparse:
        deg_vec = torch.sparse.sum(adj, dim=1).to_dense()
    else:
        deg_vec = adj.sum(dim=1)

    # D is dense, so D - adj is a dense (N, N) Laplacian even for sparse adj.
    lap = torch.diag(deg_vec) - adj

    if weights is None:
        weights = torch.ones(eigvecs.shape[-1], device=device, dtype=eigvecs.dtype)

    # diag(V^T L V)_j = sum_i V[i, j] * (L V)[i, j]; avoids forming k x k matrices.
    per_column_energy = (eigvecs * (lap @ eigvecs)).sum(dim=-2)
    return per_column_energy @ weights


def SupervisedEigenvalueLoss(eigvecs_pred, edge_index, eigvals_gt, batch, max_nodes):
    """Residual losses of predicted eigenvectors against ground-truth eigenvalues.

    For each graph, with Laplacian L = D - A (D = row sums) and predicted vectors E:
        R = (L E - E diag(lambda)), with padded node rows zeroed.

    Args:
        eigvecs_pred: [total_nodes, num_eigvals] predicted eigenvectors.
        edge_index:   [2, total_edges].
        eigvals_gt:   [batch_size * max_nodes] ground-truth eigenvalues, padded per
                      graph; the first num_eigvals of each graph's row are used.
        batch:        [total_nodes] graph assignment for each node.
        max_nodes:    padded node count per graph.

    Returns:
        (energy_loss, eigval_loss):
          energy_loss -- mean over graphs and columns of e_j . r_j (signed, not
                         absolute).
          eigval_loss -- mean over graphs of the Frobenius norm of R.
    """
    device = eigvecs_pred.device
    num_graphs = int(batch.max().item()) + 1
    k = eigvecs_pred.size(-1)

    adj = to_dense_adj(edge_index, batch, max_num_nodes=max_nodes).to(device)  # [B, N, N]
    lap = torch.diag_embed(adj.sum(dim=-1)) - adj                              # [B, N, N]

    evecs, node_mask = to_dense_batch(eigvecs_pred, batch, max_num_nodes=max_nodes)  # [B, N, k]
    lambdas = eigvals_gt.view(num_graphs, max_nodes)[:, :k]                    # [B, k]

    lap_evecs = torch.bmm(lap, evecs)
    scaled_evecs = evecs * lambdas.unsqueeze(1)
    residual = (lap_evecs - scaled_evecs) * node_mask.unsqueeze(-1).type_as(evecs)

    eigval_loss = torch.norm(residual.view(num_graphs, -1), p='fro', dim=1).mean()

    # [B, k, 1, N] x [B, k, N, 1] -> [B, k, 1, 1]: per-column dot product e_j . r_j
    energy = torch.matmul(
        evecs.transpose(-2, -1).unsqueeze(-2),
        residual.transpose(-2, -1).unsqueeze(-1),
    )
    return energy.mean(), eigval_loss

def SupervisedEigenvalueLoss2(eigvecs_pred, eigvals_gt, batch, laplacian, true_eigvecs = None):
    """Eigen-residual, energy and orthogonality losses for predicted eigenvectors.

    The predictions are column-normalized per graph, and an orthogonality penalty
    is measured. They are then orthonormalized per graph (QR), giving U, and
    compared against ground truth via the residual R = L U - U * lambda.

    Args:
        eigvecs_pred: [total_nodes, k] predicted eigenvectors, rows grouped by ``batch``.
        eigvals_gt:   ground-truth eigenvalues. After ``.squeeze()`` they must broadcast
                      against [total_nodes, k] -- presumably one row of eigenvalues per
                      node (TODO confirm against caller).
        batch:        [total_nodes] graph assignment for each node.
        laplacian:    [total_nodes, total_nodes] Laplacian of the batched graph.
        true_eigvecs: optional [total_nodes, k]; if given, a sanity-check residual is printed.

    Returns:
        energy_loss:        mean of |u_j . r_j| over graphs and columns.
        eigval_loss:        mean of ||r_j|| over graphs and columns.
        ortho_loss / k:     pre-orthogonalization penalty from ``ortho_loss_by_batch``.
        energy_loss_by_idx: [k] per-column mean of |u_j . r_j|.
        eigval_loss_by_idx: [k] per-column mean of ||r_j||.

    NOTE(review): NaN eigenvalue entries (e.g. padding for graphs with fewer than k
    nodes) are not masked here and will propagate into the losses.
    """
    num_eigvecs = eigvecs_pred.shape[-1]

    # Normalize, measure orthogonality, then orthonormalize per graph.
    U = normalize_by_batch(eigvecs_pred, batch)
    ortho_loss = ortho_loss_by_batch(U, batch)
    U = orthogonalize_by_batch(U, batch)

    L = laplacian
    lambdas = eigvals_gt.squeeze()

    if true_eigvecs is not None:
        print("SANITY CHECK SHOULD BE ZERO:", (L @ true_eigvecs - true_eigvecs * lambdas).sum())

    diff = L @ U - U * lambdas

    batched_U, _ = to_dense_batch(U, batch)        # [B, max_nodes, k]
    batched_diff, _ = to_dense_batch(diff, batch)  # [B, max_nodes, k]

    eigval_loss_mat = torch.norm(batched_diff, dim=1)  # [B, k]: norm of each residual column
    eigval_loss_by_idx = eigval_loss_mat.mean(dim=0)
    eigval_loss = eigval_loss_mat.mean()

    # Per-graph, per-column |u_j . r_j|. Summing over nodes keeps shape [B, k] even
    # when B == 1 or k == 1 (the old .squeeze() collapsed those axes and made
    # energy_loss_by_idx wrong).
    energy_loss_mat = (batched_U * batched_diff).sum(dim=1).abs()
    energy_loss_by_idx = energy_loss_mat.mean(dim=0)
    energy_loss = energy_loss_mat.mean()

    return energy_loss, eigval_loss, ortho_loss / num_eigvecs, energy_loss_by_idx, eigval_loss_by_idx

def plot_loss_history(loss_hist, path):
    """Save a line plot of the per-epoch training loss to ``path``.

    Epochs are numbered from 1. The figure is always closed after saving, so
    repeated calls (e.g. once per epoch) do not accumulate open matplotlib figures.
    """
    epochs = range(1, len(loss_hist) + 1)
    fig, ax = plt.subplots()
    try:
        ax.plot(epochs, loss_hist, color='tab:orange', label='training')
        ax.set(xlabel='Epoch', ylabel='Loss', title="Loss History")
        ax.legend()
        fig.savefig(path)
    finally:
        plt.close(fig)



def OrthogonalityLoss(eigvecs):
    """Frobenius norm of the off-diagonal part of the Gram matrix V^T V.

    eigvecs: (N, k) tensor, or (..., N, k). For batched input, a single norm is
    taken over the off-diagonal entries of every batch element together.
    returns: scalar tensor; zero iff the columns are mutually orthogonal (column
    norms themselves are not penalized).
    """
    gram = torch.matmul(eigvecs.transpose(-2, -1), eigvecs)  # (..., k, k)
    n_cols = gram.size(-1)
    off_diagonal_mask = 1.0 - torch.eye(n_cols, device=gram.device, dtype=gram.dtype)
    return torch.norm(gram * off_diagonal_mask)