import torch
import torch.nn.functional as F
from torch_geometric.graphgym.config import cfg
from torch_geometric.graphgym.register import register_loss
from torch_geometric.utils import to_dense_batch



def normalize_by_batch(x, batch, eps=1e-5):
    """Scale every column of ``x`` to unit L2 norm separately within each graph.

    For each graph ``g`` (the rows with ``batch == g``) and each column ``j``,
    the entries ``x[batch == g, j]`` are divided by
    ``sqrt(sum(x[batch == g, j] ** 2) + eps)``.

    Args:
        x: Node-level tensor of shape ``[num_nodes, k]``.
        batch: Graph assignment of shape ``[num_nodes]`` with values in
            ``0 .. num_graphs - 1``.
        eps: Added under the square root so all-zero columns do not divide
            by zero.

    Returns:
        Tensor with the same shape as ``x``.
    """
    if batch.numel() == 0:
        # batch.max() is undefined on an empty tensor; nothing to normalize.
        return x.clone()
    num_groups = int(batch.max()) + 1
    # Per-(graph, column) sum of squares, accumulated by scattering rows.
    norm_sq = x.new_zeros(num_groups, x.shape[-1]).index_add(0, batch, x.pow(2))
    norms = (norm_sq + eps).sqrt()
    # Gather each node's own graph norms back to node level.
    return x / norms[batch]

def orthogonalize_by_batch(x, batch):
    """Orthonormalize the columns of ``x`` independently for every graph.

    The rows of ``x`` are scattered into a padded dense tensor (one slice per
    graph), a reduced QR decomposition is computed per slice, and the rows of
    the Q factor that correspond to real nodes are gathered back into the
    original node ordering.

    NOTE(review): for a graph with fewer nodes than columns the kept rows of Q
    are not guaranteed to be orthonormal — confirm this case cannot occur.

    Args:
        x: Node-level tensor, presumably ``[num_nodes, k]`` — TODO confirm.
        batch: Graph assignment for each node.

    Returns:
        Tensor with the same number of rows as ``x``, holding the per-graph
        Q factors.
    """
    padded, node_mask = to_dense_batch(x, batch)  # [B, max_nodes, k]
    q_factor = torch.linalg.qr(padded).Q  # reduced mode (the default)
    # Drop padding rows, restoring the flat node layout.
    return q_factor[node_mask]


def ortho_loss_by_batch(eigvecs, batch):
    """Penalize non-orthogonality between the columns of ``eigvecs`` per graph.

    For every graph the Gram matrix ``G = X^T X`` of its node rows is formed
    (shape ``[k, k]``). Its diagonal is then zeroed, leaving only the dot
    products between distinct columns.

    Args:
        eigvecs: Node-level tensor of shape ``[num_nodes, k]``.
        batch: Graph assignment for each node, shape ``[num_nodes]``.

    Returns:
        Scalar tensor. If ``cfg.posenc_LapPE.squared_ortho`` is set, this is
        the sum of the squared off-diagonal Gram entries over all graphs.
        Otherwise it is the mean over graphs of the Frobenius norm of the
        off-diagonal Gram part.
    """
    batched_eigvecs, _ = to_dense_batch(eigvecs, batch)  # [B, max_nodes, k]
    # Padding rows are zero-filled, so they add nothing to the Gram matrix.
    G = batched_eigvecs.transpose(-2, -1) @ batched_eigvecs  # [B, k, k]
    k = G.size(-1)
    eye = torch.eye(k, device=G.device, dtype=G.dtype)
    # Broadcasting the [k, k] mask over the batch dimension zeroes the diagonal.
    off_diag = G * (1.0 - eye)

    if cfg.posenc_LapPE.squared_ortho:
        return off_diag.pow(2).sum()
    # Frobenius norm over the last two dims, averaged over graphs.
    return torch.norm(off_diag, dim=[-1, -2]).mean()

# Do NOT register: this loss is only used as an add-on to the main loss during pretraining.
# @register_loss("eigval_loss") 
def eigval_losses(eigvecs_pred, eigvals_gt, batch, laplacian, true_eigvecs = None):
    """Eigen-equation losses for predicted Laplacian eigenvectors.

    The predicted vectors ``U`` are first normalized per graph and per column.
    The orthogonality penalty is taken on these normalized vectors. If
    ``cfg.posenc_LapPE.forced_ortho`` is set, they are then re-orthogonalized.
    The residual ``L @ U - U * lambda`` is scored per graph and per
    eigenvector.

    Graphs with fewer nodes than ``num_eigvals`` carry NaN eigenvalue entries.
    Those (node, eigenvector) entries are excluded: their residuals are zeroed
    before any arithmetic can propagate NaN, and they are left out of the
    averages. When no NaN is present, the results equal plain means.

    Args:
        eigvecs_pred: ``[total_nodes, num_eigvals]`` predicted eigenvectors.
        eigvals_gt: Ground-truth eigenvalues given per node. Any shape with
            ``total_nodes * num_eigvals`` elements (e.g.
            ``[total_nodes, num_eigvals, 1]``) is reshaped to match
            ``eigvecs_pred``. Other shapes are squeezed and broadcast as before.
        batch: ``[total_nodes]`` graph assignment for each node.
        laplacian: Left-multiplied with ``U``. Presumably the block-diagonal
            ``[total_nodes, total_nodes]`` Laplacian of the batch — TODO confirm.
        true_eigvecs: Optional ground-truth eigenvectors. When given, a
            residual sanity value is printed.

    Returns:
        Tuple ``(energy_loss, eigval_loss, ortho_loss / num_eigvals,
        energy_loss_by_idx, eigval_loss_by_idx)``. The ``*_by_idx`` tensors
        have shape ``[num_eigvals]``.

    Raises:
        NotImplementedError: if ``cfg.posenc_LapPE.eigen.skip_zero_freq`` is set.
    """
    num_eigvecs = eigvecs_pred.shape[-1]
    if cfg.posenc_LapPE.eigen.skip_zero_freq:
        raise NotImplementedError("Orthogonality constraints not implemented for skip_zero_freq = True")

    # Normalize, penalize non-orthogonality, then optionally force orthogonality.
    U = normalize_by_batch(eigvecs_pred, batch)  # [total_nodes, num_eigvals]
    ortho_loss = ortho_loss_by_batch(U, batch)
    if cfg.posenc_LapPE.forced_ortho:
        U = orthogonalize_by_batch(U, batch)

    L = laplacian

    if true_eigvecs is not None:
        print("SANITY CHECK SHOULD BE ZERO:", (L @ true_eigvecs - true_eigvecs * eigvals_gt.squeeze()).sum())

    # Align eigenvalues with U. A bare squeeze() would also drop a size-1
    # eigenvector (or node) axis and broadcast [N, 1] * [N] into [N, N].
    if eigvals_gt.numel() == U.numel():
        eigvals = eigvals_gt.reshape(U.shape)
    else:
        eigvals = eigvals_gt.squeeze().expand_as(U)

    # NaN eigenvalues mark eigenvectors that do not exist for small graphs.
    valid = ~torch.isnan(eigvals)  # [total_nodes, num_eigvals]
    # Replace NaN before multiplying so neither the forward pass nor the
    # gradients ever see NaN.
    eigvals = torch.where(valid, eigvals, torch.zeros_like(eigvals))
    diff = (L @ U - U * eigvals).masked_fill(~valid, 0.0)

    batched_U, _ = to_dense_batch(U, batch)  # [B, max_nodes, num_eigvals]
    batched_diff, _ = to_dense_batch(diff, batch)  # [B, max_nodes, num_eigvals]
    batched_valid, _ = to_dense_batch(valid.to(U.dtype), batch)
    # A (graph, eigenvector) pair counts if any of its nodes has a real eigenvalue.
    col_valid = batched_valid.amax(dim=1) > 0  # [B, num_eigvals]

    if cfg.posenc_LapPE.squared_eigval:
        eigval_loss_mat = batched_diff.pow(2).sum(dim=1)  # [B, num_eigvals]
    else:
        eigval_loss_mat = torch.norm(batched_diff, dim=1)  # norm of each residual vector

    # u_j^T (L u_j - lambda_j u_j) per graph and eigenvector. The elementwise
    # form keeps shape [B, num_eigvals] even when B == 1 or num_eigvals == 1.
    energy_loss_mat = (batched_U * batched_diff).sum(dim=1)
    if cfg.posenc_LapPE.energy_abs:
        energy_loss_mat = torch.abs(energy_loss_mat)

    eigval_loss_by_idx, eigval_loss = _masked_means(eigval_loss_mat, col_valid)
    energy_loss_by_idx, energy_loss = _masked_means(energy_loss_mat, col_valid)

    return energy_loss, eigval_loss, ortho_loss / num_eigvecs, energy_loss_by_idx, eigval_loss_by_idx


def _masked_means(mat, mask):
    """Return (per-column mean, overall mean) of ``mat`` over entries where ``mask`` is True.

    With an all-True ``mask`` this equals ``(mat.mean(dim=0), mat.mean())``.
    A column with no valid entry gets a mean of 0 instead of NaN.
    """
    masked = mat.masked_fill(~mask, 0.0)
    col_counts = mask.sum(dim=0)
    by_idx = masked.sum(dim=0) / col_counts.clamp(min=1)
    overall = masked.sum() / col_counts.sum().clamp(min=1)
    return by_idx, overall



