import torch
import torch.nn.functional as F

def l_cl(zi, zj, tau=0.1):
    """Base Contrastive Loss (Eq. 1).

    Each row r of ``zi`` is scored against every row of ``zj`` by dot
    product, scaled by ``1 / tau``. Row r of ``zj`` is treated as the
    positive (class index r) and the remaining rows as negatives. The
    mean cross-entropy over the rows of ``zi`` is returned.

    No normalization is applied here. If Eq. 1 is a cosine similarity,
    the embeddings are presumably L2-normalized by the caller; confirm.

    Args:
        zi: anchor embeddings. Its row count sets the number of targets.
        zj: candidate embeddings, with the same feature dimension as ``zi``.
        tau: temperature dividing the similarity scores.

    Returns:
        Scalar loss tensor.
    """
    n_rows = zi.size(0)
    similarity = zi @ zj.T
    # The positive for anchor r sits at column r, so targets are 0..n-1.
    targets = torch.arange(n_rows, device=zi.device)
    return F.cross_entropy(similarity / tau, targets)

def l_mv_cl(z_views, tau=0.1):
    """Multi-View Contrastive Loss (Eq. 2).

    Averages the base contrastive loss ``l_cl`` over every unordered pair
    of views (i < j in the mapping's iteration order). In each pair, view
    i is the anchor and view j is the target.

    Args:
        z_views: mapping of view name to embedding tensor; each value is
            passed to ``l_cl``.
        tau: temperature forwarded to ``l_cl``.

    Returns:
        Mean pairwise loss, or ``0`` when fewer than two views are given.
    """
    embeddings = list(z_views.values())
    count = len(embeddings)
    pair_losses = [
        l_cl(anchor, embeddings[other], tau)
        for idx, anchor in enumerate(embeddings)
        for other in range(idx + 1, count)
    ]
    if not pair_losses:
        return 0
    return sum(pair_losses) / len(pair_losses)

def l_global(e_k, z_views, tau=0.1):
    """Global Alignment Loss (Eq. 3).

    Contrasts ``e_k`` (as the anchor) against each view in ``z_views``
    with ``l_cl`` and averages the per-view losses.

    Args:
        e_k: anchor embedding tensor passed as the first argument to ``l_cl``.
        z_views: mapping of view name to embedding tensor.
        tau: temperature forwarded to ``l_cl``.

    Returns:
        Mean of ``l_cl(e_k, z, tau)`` over all views.

    Raises:
        ZeroDivisionError: if ``z_views`` is empty.
    """
    per_view = [l_cl(e_k, view_emb, tau) for view_emb in z_views.values()]
    return sum(per_view) / len(z_views)



def l_or(z_views):
    """
    Orthogonality Regularization (OR) Loss (Eq. 5).
    Penalizes cross-view redundancy by decorrelating representations [3, 4].

    For every unordered pair of views (i < j, in the mapping's iteration
    order), the cross-covariance C = Z_i^T Z_j / N (Eq. 4) of the
    batch-centered projections is formed. Its squared Frobenius norm is
    the penalty, and the result is the mean over all pairs.

    Args:
        z_views: mapping of view name to a 2-D projection tensor (N, D).
            Dim 0 is the batch dimension that is averaged out when
            centering. N is taken from the first view, and all views must
            share it for the matrix product to be valid.

    Returns:
        Scalar tensor with the mean pairwise penalty, or ``0`` when fewer
        than two views are given (including an empty mapping).
    """
    views = list(z_views.values())
    if len(views) < 2:
        # No cross-view pairs, so there is nothing to decorrelate. This is
        # checked before any indexing so an empty mapping returns 0 instead
        # of raising IndexError.
        return 0

    N = views[0].size(0)
    # Center each view once, rather than re-centering view j for every i < j.
    centered = [z - z.mean(dim=0, keepdim=True) for z in views]

    loss = 0
    pairs = 0
    for i, zi in enumerate(centered):
        for zj in centered[i + 1:]:
            # Cross-covariance matrix C (Eq. 4) [3, 4]
            C = (zi.T @ zj) / N
            # The squared Frobenius norm [5, 6] equals the sum of squared
            # entries; computing it directly avoids a sqrt followed by a square.
            loss += C.square().sum()
            pairs += 1

    return loss / pairs