import torch
import torch.nn.functional as F
import numpy as np

def loss_diversity_from_S(S_assign_list, device=None, eps=1e-9):
    """
    Diversity loss: the sum over levels of the mean row entropy.

    For each matrix S in ``S_assign_list`` the entropy of every row,
    H(p) = - sum_k p_k log p_k, is computed along dim 1 and averaged over
    the rows; the per-matrix averages are then added together.

    Args:
        S_assign_list: iterable of matrices, each either a ``np.ndarray``
            (converted with ``torch.from_numpy``) or an object that already
            supports the tensor operations used here.
        device: if not None, every matrix is moved there via ``.to(device)``.
        eps: lower clamp applied to all entries before the log so that
            zero entries do not produce ``log(0)``.

    Returns:
        The accumulated loss. This is the plain float ``0.0`` when
        ``S_assign_list`` is empty, otherwise the result of adding the
        per-matrix means to ``0.0``.
    """

    def _mean_row_entropy(assign):
        # NumPy inputs are wrapped as tensors; anything else is used as-is.
        probs = torch.from_numpy(assign) if isinstance(assign, np.ndarray) else assign
        if device is not None:
            probs = probs.to(device)
        probs = probs.clamp_min(eps)
        per_row = torch.neg(torch.sum(probs * torch.log(probs), dim=1))  # [N_l]
        return per_row.mean()

    # Starting the fold at 0.0 keeps the empty-list result a plain float.
    return sum((_mean_row_entropy(S) for S in S_assign_list), 0.0)

def loss_reconstruction_from_treeG(treeG, device=None):
    """
    Reconstruction loss of the per-level features under U^T U.

    For every level ℓ that provides a usable 'u' entry this computes

        mse^(ℓ) = mean( (U^(ℓ)^T (U^(ℓ) H^(ℓ)) - H^(ℓ))^2 ),

    i.e. the squared Frobenius norm of the residual divided by the number
    of entries of H^(ℓ), and returns

        L_rec = ( sum_ℓ mse^(ℓ) ) / len(treeG).

    NOTE: the divisor is the total number of levels, including levels that
    were skipped. This matches the existing behaviour and is kept so the
    loss scale seen by callers does not change.

    Args:
        treeG: sequence of per-level mappings, indexed 0..len(treeG)-1.
            A level contributes only if it has a key 'u' holding a
            non-empty list of vectors with no ``None`` entries; the vectors
            are stacked row-wise into U^(ℓ). Such a level must also hold
            'features', an array H^(ℓ) exposing ``.astype`` whose row count
            equals the length of each 'u' vector.
        device: if not None, U and H are moved there before the matmuls.

    Returns:
        A 0-dim float32 tensor when at least one level contributes;
        otherwise the plain float ``0.0`` (this includes an empty
        ``treeG``, which previously raised ``ZeroDivisionError``).
    """
    num_levels = len(treeG)
    if num_levels == 0:
        # Nothing to reconstruct; avoid dividing by zero below.
        return 0.0

    L_rec = 0.0
    # Index-based access is kept on purpose so that any container that
    # supports len() and integer indexing keeps working.
    for lvl in range(num_levels):
        level = treeG[lvl]
        if 'u' not in level:
            continue
        u_list = level['u']
        # Skip levels without a usable basis: missing, empty, or containing
        # None placeholders. An empty list would make np.stack raise.
        if u_list is None or len(u_list) == 0 or any(v is None for v in u_list):
            continue

        U_np = np.stack(u_list, axis=0)  # one row per vector in 'u'
        H_np = level['features']         # rows must match U's columns

        U = torch.from_numpy(U_np.astype(np.float32))
        H = torch.from_numpy(H_np.astype(np.float32))
        if device is not None:
            U = U.to(device)
            H = H.to(device)

        H_hat = U.t() @ (U @ H)          # U^T U H
        # Element-wise mean of the squared residual (not the raw squared
        # Frobenius norm, which would be the un-normalised sum).
        L_rec = L_rec + F.mse_loss(H_hat, H, reduction='mean')

    return L_rec / num_levels
