import torch
import torch.nn.functional as F

def diversity_loss(model, coef=0.1, layer_name='fc_layer_1', **kwargs):
    """Return the scaled diversity penalty for one layer's weight.

    Args:
        model: Object exposing an attribute named ``layer_name`` that has a
            ``.weight`` tensor (e.g. a submodule of an ``nn.Module``).
        coef: Multiplier applied to the raw penalty.
        layer_name: Attribute name of the layer whose weight is regularized.
        **kwargs: Forwarded unchanged to ``_diversity_loss`` (``tau``,
            ``abs_sim``, ``reduction``, ``eps``).

    Returns:
        ``coef * _diversity_loss(weight, **kwargs)``.
    """
    layer = getattr(model, layer_name)
    return coef * _diversity_loss(layer.weight, **kwargs)


def _diversity_loss(W, tau=0.8, abs_sim=False, reduction='sum', eps=1e-8):
    """Penalize pairs of weight rows whose cosine similarity exceeds ``tau``.

    Each row of ``W`` is L2-normalized and the pairwise cosine-similarity
    matrix is formed. Every off-diagonal entry ``s`` contributes
    ``max(s - tau, 0) ** 2``; both ``(i, j)`` and ``(j, i)`` are counted.

    Args:
        W: 2-D tensor (required by ``torch.mm``); each row is one weight
            vector.
        tau: Similarity threshold; pairs at or below it are not penalized.
        abs_sim: If True, use ``|cos|`` so anti-parallel rows are penalized
            as well.
        reduction: ``'sum'`` or ``'mean'`` over the off-diagonal entries.
        eps: Lower bound on row norms during normalization (avoids division
            by zero for all-zero rows).

    Returns:
        Scalar tensor. It is zero when ``W`` has fewer than two rows, for
        either reduction.

    Raises:
        ValueError: If ``reduction`` is not ``'sum'`` or ``'mean'``.
    """
    # Validate up front so a bad argument fails before any work is done.
    if reduction not in ('sum', 'mean'):
        raise ValueError("reduction must be 'sum' or 'mean'")

    W_norm = F.normalize(W, p=2, dim=1, eps=eps)
    sim_matrix = torch.mm(W_norm, W_norm.t())  # (n, n) cosine similarities
    if abs_sim:
        sim_matrix = sim_matrix.abs()

    # Drop the diagonal: self-similarity is always ~1 and must not count.
    n = W.shape[0]
    off_diag = ~torch.eye(n, dtype=torch.bool, device=W.device)
    sim_vals = sim_matrix[off_diag]

    penalty = torch.clamp(sim_vals - tau, min=0) ** 2

    # With fewer than two rows there are no pairs. ``mean`` of an empty
    # tensor is NaN, whereas ``sum`` gives 0 while keeping the result on
    # W's graph, device and dtype.
    if reduction == 'sum' or penalty.numel() == 0:
        return penalty.sum()
    return penalty.mean()


def orthogonality_loss(model, coef=0.1, layer_name='fc_layer_1', **kwargs):
    """Return the scaled orthogonality penalty for one layer's weight.

    Args:
        model: Object exposing an attribute named ``layer_name`` that has a
            ``.weight`` tensor (e.g. a submodule of an ``nn.Module``).
        coef: Multiplier applied to the raw penalty.
        layer_name: Attribute name of the layer whose weight is regularized.
        **kwargs: Forwarded unchanged to ``_orthogonality_loss``
            (``reduction``).

    Returns:
        ``coef * _orthogonality_loss(weight, **kwargs)``.
    """
    layer = getattr(model, layer_name)
    return coef * _orthogonality_loss(layer.weight, **kwargs)


def _orthogonality_loss(W, reduction='sum'):
    """Squared Frobenius norm of the off-diagonal part of ``W @ W^T``.

    The Gram matrix ``G = W W^T`` holds inner products between rows of ``W``.
    Its diagonal (squared row norms) is ignored, so only cross-row inner
    products are penalized. The penalty is zero exactly when the rows are
    mutually orthogonal.

    Args:
        W: Tensor of shape ``(n, d)`` or batched ``(..., n, d)``.
        reduction: ``'sum'`` or ``'mean'``. It combines the per-matrix
            penalties of a batched input. For a 2-D input there is a single
            penalty, so both choices give the same value.

    Returns:
        Scalar tensor.

    Raises:
        ValueError: If ``reduction`` is not ``'sum'`` or ``'mean'`` (now
            checked for 2-D input too, not only batched input).
    """
    if reduction not in ('sum', 'mean'):
        raise ValueError("reduction must be 'sum' or 'mean'")

    G = torch.matmul(W, W.transpose(-1, -2))  # (..., n, n)

    # Zero the diagonal; diag_embed handles both 2-D and batched inputs.
    G_off = G - torch.diag_embed(torch.diagonal(G, dim1=-2, dim2=-1))

    # Sum of squares == squared Frobenius norm, without a sqrt/square round
    # trip (and without the deprecated torch.norm).
    penalty = G_off.pow(2).sum(dim=(-2, -1))  # (...) or 0-d

    # For a 0-d penalty, sum() and mean() both return it unchanged.
    return penalty.sum() if reduction == 'sum' else penalty.mean()


def fold_plane_alignment_loss(model, target_point, folding_point, coef=0.1, layer_name='fc_layer_1', new_param_index=-1, eps=1e-8, **kwargs):
    """Pull one unit's hyperplane through two points, containing their segment.

    The unit at row ``new_param_index`` of ``model.<layer_name>`` defines the
    hyperplane ``w . x + b = 0``. The loss sums three terms:

    * the distance from ``target_point`` to the hyperplane,
    * the distance from ``folding_point`` to the hyperplane,
    * ``|cos|`` of the angle between the normal ``w`` and
      ``target_point - folding_point``. This is 0 when the segment is
      parallel to the plane.

    Args:
        model: Object exposing an attribute ``layer_name`` with 2-D ``weight``
            and 1-D ``bias`` tensors (``torch.dot`` requires ``w`` to be 1-D).
        target_point: 1-D point of the same length as ``w``; tensors, lists
            and arrays are accepted and cast to ``w``'s dtype/device.
        folding_point: Same as ``target_point``.
        coef: Multiplier applied to the summed loss.
        layer_name: Attribute name of the layer.
        new_param_index: Row/bias index of the unit being aligned.
        eps: Floor for the norms used as denominators. It prevents NaN when
            ``w`` is zero or the two points coincide. The default leaves
            ordinary inputs unchanged.
        **kwargs: Accepted and ignored (presumably for a call signature
            uniform with the other loss helpers; confirm with callers).

    Returns:
        Scalar tensor ``coef * (d_target + d_fold + |cos|)``.
    """
    layer = getattr(model, layer_name)
    w = layer.weight[new_param_index]
    b = layer.bias[new_param_index]

    target_point = torch.as_tensor(target_point, dtype=w.dtype, device=w.device)
    folding_point = torch.as_tensor(folding_point, dtype=w.dtype, device=w.device)

    # Computed once; clamped so a zero weight row does not divide by zero.
    w_norm = w.norm().clamp_min(eps)

    distance_target = torch.abs(torch.dot(w, target_point) + b) / w_norm
    distance_fold = torch.abs(torch.dot(w, folding_point) + b) / w_norm

    # Clamp the segment length too: identical points previously gave 0/0.
    direction = target_point - folding_point
    direction_norm = direction.norm().clamp_min(eps)
    directional_loss = torch.abs(torch.dot(w, direction)) / (w_norm * direction_norm)

    return coef * (distance_target + distance_fold + directional_loss)