import torch
import torch.nn.functional as F

from pytorch_metric_learning import losses, reducers
from typing import Literal


def orthogonality_loss(
    vectors: torch.Tensor,
    eps: float = 1e-6,
    style: Literal[
        "cosine_offdiag", "cosine_offdiag_sq", "mse_identity"
    ] = "cosine_offdiag",
) -> torch.Tensor:
    """
    Orthogonality penalty over the rows of ``vectors``.

    Args:
        vectors: A 2-D tensor of shape (k, dim); each row is one vector.
        eps: Lower bound on the row norm used when L2-normalizing
            (only used by the cosine styles).
        style: Which penalty to compute:

            - "cosine_offdiag": sum of the *signed* off-diagonal cosine
              similarities of the L2-normalized rows. Caveat: for non-zero
              rows, sum_{i!=j} cos_ij = ||sum_i v_i/||v_i||||^2 - k, so this
              is minimized (at -k) whenever the normalized rows sum to zero
              (e.g. two anti-parallel vectors); on its own it does not drive
              the rows to be orthogonal.
            - "cosine_offdiag_sq": sum of *squared* off-diagonal cosine
              similarities; it is zero exactly when the non-zero rows are
              pairwise orthogonal.
            - "mse_identity": mean squared error between the (unnormalized)
              Gram matrix V V^T and the identity, so it also pushes each row
              towards unit norm.

    Returns:
        A scalar tensor with the same dtype as ``vectors``. An empty set of
        vectors (k == 0) yields 0 for every style.

    Raises:
        ValueError: If ``style`` is not one of the supported values.
    """
    k = vectors.shape[0]

    if style in ("cosine_offdiag", "cosine_offdiag_sq"):
        unit_rows = F.normalize(vectors, p=2, dim=-1, eps=eps)
        cosine = torch.matmul(unit_rows, unit_rows.t())
        off_diagonal = ~torch.eye(k, dtype=torch.bool, device=vectors.device)
        pairwise = cosine[off_diagonal]
        if style == "cosine_offdiag_sq":
            # Squaring makes positive and negative alignment equally costly,
            # so the only minimizer is pairwise orthogonality.
            pairwise = pairwise.pow(2)
        return pairwise.sum().to(dtype=vectors.dtype)

    if style == "mse_identity":
        gram = vectors @ vectors.t()
        eye = torch.eye(k, device=vectors.device, dtype=vectors.dtype)
        # Sum of squared errors normalized by k^2 (i.e. the mean over all
        # entries); max(..., 1) avoids 0/0 -> NaN when there are no vectors.
        return F.mse_loss(gram, eye, reduction="sum") / max(k * k, 1)

    raise ValueError(f"Unknown orthogonality style: {style}")


def magnitude_loss(activation_shifts: torch.Tensor, p=2, q=1) -> torch.Tensor:
    """
    Reward large activation shifts by returning a negative magnitude score.

    Computes ``-(sum_i ||s_i||_2 ** p) ** (1 / q) / activation_shifts.shape[0]``,
    where the sum runs over every vector ``s_i`` along the last dimension
    (i.e. over all leading indices). With the defaults ``p=2, q=1`` this is
    the negative sum of squared L2 norms divided by the size of the first
    dimension. Minimizing it therefore increases the shifts' magnitude.

    Args:
        activation_shifts: Expected shape (batch_size, n_vectors, hidden_dim);
            norms are taken over the last dimension and the result is divided
            by the size of the first.
        p: Exponent applied to each per-vector L2 norm before summing.
        q: The summed value is raised to the power ``1 / q``.

    Returns:
        A scalar tensor holding the (non-positive) magnitude loss.
    """
    batch_size = activation_shifts.shape[0]
    per_vector_norms = activation_shifts.norm(dim=-1)
    pooled = torch.pow(per_vector_norms, p).sum()
    magnitude = torch.pow(pooled, 1 / q)
    return -magnitude / batch_size


def diversity_loss(
    activation_shifts: torch.Tensor,
    vector_idxs: torch.Tensor,
    loss_type: str = "supcon",
    temperature: float = 0.1,
) -> torch.Tensor:
    """
    Encourage different steering vectors to produce distinguishable shifts.

    The shifts are L2-normalized and passed to a pytorch-metric-learning loss
    with ``vector_idxs`` as the class labels. Shifts produced by the same
    steering vector are treated as one class, and shifts from different
    vectors as different classes.

    Args:
        activation_shifts: Either a 2-D tensor of shape (n_shifts, hidden_dim),
            or a 3-D tensor of shape (batch_size, n_vectors, hidden_dim). A 3-D
            tensor is flattened to (batch_size * n_vectors, hidden_dim), since
            pytorch-metric-learning losses expect 2-D embeddings.
        vector_idxs: Steering-vector index (label) of each shift.
            - 2-D input: shape (n_shifts,).
            - 3-D input: shape (batch_size, n_vectors) or
              (batch_size * n_vectors,), or (n_vectors,), which is then
              broadcast across the batch.
        loss_type: One of "supcon" (default), "circle", "multisimilarity",
            or "ntxent".
        temperature: Temperature for "supcon" and "ntxent"; ignored by the
            "circle" and "multisimilarity" losses.

    Returns:
        A scalar loss tensor cast to ``activation_shifts.dtype``.

    Raises:
        ValueError: If ``loss_type`` is not recognized.
    """
    if loss_type == "supcon":
        loss_fn = losses.SupConLoss(temperature=temperature)
    elif loss_type == "circle":
        loss_fn = losses.CircleLoss(m=0.4, gamma=80)
    elif loss_type == "multisimilarity":
        loss_fn = losses.MultiSimilarityLoss(alpha=2.0, beta=50.0, base=0.5)
    elif loss_type == "ntxent":
        loss_fn = losses.NTXentLoss(temperature=temperature)
    else:
        raise ValueError(f"Unknown loss type: {loss_type}")

    embeddings = activation_shifts
    labels = vector_idxs
    if embeddings.ndim == 3:
        batch_size, n_vectors, hidden_dim = embeddings.shape
        embeddings = embeddings.reshape(batch_size * n_vectors, hidden_dim)
        if labels.ndim == 1 and labels.shape[0] == n_vectors:
            # One index per steering vector: reuse the same labels for every
            # batch row so each flattened shift gets its vector's label.
            labels = labels.unsqueeze(0).expand(batch_size, n_vectors)
        labels = labels.reshape(-1)

    # Give every shift unit L2 norm before handing it to the metric loss.
    normed_shifts = F.normalize(embeddings, p=2, dim=-1)

    loss = loss_fn(normed_shifts, labels)
    return loss.to(dtype=activation_shifts.dtype)
