import torch

def frobenius_loss_shared(h_shared, inf_factor=1e6, orthogonal_weight=1.0):
    """Return an orthogonality penalty on the columns of ``h_shared``.

    Forms the Gram matrix ``h_shared.T @ h_shared`` and returns the Frobenius
    norm of its strictly upper-triangular part, scaled by ``orthogonal_weight``.
    That part holds the inner product of every pair of distinct columns,
    counted once; the diagonal (squared column norms) is excluded.

    Non-finite Gram entries are sanitized before the norm is taken:
    NaN becomes 0, and +inf / -inf become +/-(inf_factor * largest finite
    magnitude in the Gram matrix). The value is clamped to the dtype's largest
    representable number. If there is no nonzero finite entry, the stand-in is 0.

    Args:
        h_shared: 2-D tensor whose columns are the vectors being compared.
        inf_factor: Multiplier on the largest finite magnitude, used as the
            stand-in value for infinite entries.
        orthogonal_weight: Scalar multiplier applied to the norm.

    Returns:
        0-dim tensor holding the weighted penalty.
    """
    wtw = h_shared.T @ h_shared
    finite = torch.isfinite(wtw)
    if not bool(finite.all()):
        # Take the max over finite entries only. Including inf/nan here would
        # make the stand-in value itself non-finite.
        finite_max = torch.where(finite, wtw.abs(), torch.zeros_like(wtw)).max()
        # A Python float acts as a constant with no gradient path. This keeps
        # the huge stand-in from pushing an inf_factor-scaled gradient into the
        # element that happens to be the max. Clamp so it stays representable.
        big = min(float(finite_max) * inf_factor, torch.finfo(wtw.dtype).max)
        wtw = torch.nan_to_num(wtw, nan=0.0, posinf=big, neginf=-big)
    wtw_tri = torch.triu(wtw, diagonal=1)
    return torch.norm(wtw_tri, p='fro') * orthogonal_weight

def frobenius_loss_crossmodal(h1, h2, h1_ranks, h2_ranks, inf_factor=1e6, orthogonal_weight=1.0):
    """Return a cross-modal orthogonality penalty between leading columns of h1 and h2.

    Takes the first ``h1_ranks`` columns of ``h1`` and the first ``h2_ranks``
    columns of ``h2`` and forms ``h1_sub.T @ h2_sub``. It then returns the
    Frobenius norm of that matrix's strictly upper-triangular part, multiplied
    by ``orthogonal_weight`` and divided by ``h1_ranks * h2_ranks``.

    Non-finite entries of the product are sanitized before the norm is taken:
    NaN becomes 0, and +inf / -inf become +/-(inf_factor * largest finite
    magnitude in the product). The value is clamped to the dtype's largest
    representable number. If there is no nonzero finite entry, the stand-in is 0.

    Args:
        h1: 2-D tensor; its first ``h1_ranks`` columns are used.
        h2: 2-D tensor with the same number of rows as ``h1``; its first
            ``h2_ranks`` columns are used.
        h1_ranks: Number of leading columns of ``h1`` to compare.
        h2_ranks: Number of leading columns of ``h2`` to compare.
        inf_factor: Multiplier on the largest finite magnitude, used as the
            stand-in value for infinite entries.
        orthogonal_weight: Scalar multiplier applied to the norm.

    Returns:
        0-dim tensor holding the normalized, weighted penalty. If
        ``h1_ranks * h2_ranks == 0``, the unnormalized value (0) is returned
        instead of dividing and producing 0/0 = nan.
    """
    wtw = h1[:, :h1_ranks].T @ h2[:, :h2_ranks]
    finite = torch.isfinite(wtw)
    if not bool(finite.all()):
        # Take the max over finite entries only. Including inf/nan here would
        # make the stand-in value itself non-finite.
        finite_max = torch.where(finite, wtw.abs(), torch.zeros_like(wtw)).max()
        # A Python float acts as a constant with no gradient path. Clamp it so
        # it stays representable in wtw's dtype.
        big = min(float(finite_max) * inf_factor, torch.finfo(wtw.dtype).max)
        wtw = torch.nan_to_num(wtw, nan=0.0, posinf=big, neginf=-big)
    # NOTE(review): h1.T @ h2 is generally not symmetric. triu(diagonal=1)
    # therefore discards the diagonal and every (i, j) with i > j, so the
    # result changes if h1 and h2 are swapped. Confirm that this subset,
    # rather than the full matrix, is the intended penalty.
    wtw_tri = torch.triu(wtw, diagonal=1)
    orthogonal_loss = torch.norm(wtw_tri, p='fro') * orthogonal_weight
    n_pairs = h1_ranks * h2_ranks
    if n_pairs == 0:
        # Empty slice: the norm is already 0, and dividing would yield nan.
        return orthogonal_loss
    return orthogonal_loss / n_pairs