# Copyright 2023 solo-learn development team.

# Permission is hereby granted, free of charge, to any person obtaining a copy of
# this software and associated documentation files (the "Software"), to deal in
# the Software without restriction, including without limitation the rights to use,
# copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the
# Software, and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all copies
# or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
# INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR
# PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE
# FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
# OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.

import torch
import torch.nn.functional as F
from solo.utils.misc import gather



# Notation: p denotes a prediction and z a projection. The online network is trained by backpropagation, while the momentum network is an exponential moving average (EMA) of the online network, i.e. a slowly moving trailing average of its weights.

# Original BYOL loss function
def byol_similarity_loss_func(p: torch.Tensor, z: torch.Tensor, simplified: bool = True) -> torch.Tensor:
    """Return BYOL's similarity loss between predictions p and momentum projections z.

    Both inputs are L2-normalised along the last dimension and the loss is
    ``2 - 2 * mean(cosine similarity)``: 0 when p and z point the same way,
    2 when they are orthogonal and 4 when they point in opposite directions.

    Args:
        p (torch.Tensor): NxD Tensor containing predicted features from view 1.
        z (torch.Tensor): NxD Tensor containing projected momentum features from view 2.
            It is detached inside this function, so no gradient flows into it.
        simplified (bool): selects between two code paths that give the same result
            for NxD inputs; they only differ in the axis index used for the
            reduction (-1 vs 1). Defaults to True.

    Returns:
        torch.Tensor: BYOL's similarity loss (scalar).
    """
    unit_p = F.normalize(p, dim=-1)
    # The momentum branch is a fixed target: cut it out of the autograd graph.
    unit_z = F.normalize(z.detach(), dim=-1)

    feature_axis = -1 if simplified else 1
    cosine = (unit_p * unit_z).sum(dim=feature_axis)
    return 2 - 2 * cosine.mean()

# During training we expect the similarity loss to start high and slowly decrease towards 0. Note that with the
# 2 - 2 * cos form above, orthogonal vectors give a loss of 2 (not 1), and perfectly aligned vectors give 0.
# That is roughly what we observe: it starts at about 0.7 and drops to about 0.23, which means the EMA (momentum) projections point roughly in the same direction as the online predictions.


# Copied from solo/losses/vicreg.py (and radialvicreg.py) for variance and covariance
def variance_loss(z1: torch.Tensor, z2: torch.Tensor) -> torch.Tensor:
    """Return the variance regularization term for two batches of projections.

    For each view, the per-dimension standard deviation over the batch is
    computed and a hinge ``relu(1 - std)`` is averaged over dimensions; the
    two per-view terms are summed. Intended for the ONLINE projections only.

    Args:
        z1 (torch.Tensor): NxD Tensor containing projected features from view 1.
        z2 (torch.Tensor): NxD Tensor containing projected features from view 2.

    Returns:
        torch.Tensor: variance regularization loss (scalar).
    """
    eps = 1e-4  # keeps the sqrt (and its gradient) finite when a dimension has zero variance

    def _hinge_on_std(z: torch.Tensor) -> torch.Tensor:
        # Statistics are taken over dim 0, i.e. over whatever batch is passed in
        # (the gathered batch if the caller gathered beforehand).
        std = (z.var(dim=0) + eps).sqrt()
        return F.relu(1 - std).mean()

    return _hinge_on_std(z1) + _hinge_on_std(z2)


def covariance_loss(z1: torch.Tensor, z2: torch.Tensor) -> torch.Tensor:
    """Return the covariance regularization term for two batches of projections.

    For each view, the DxD sample covariance matrix (normalised by N - 1) is
    built, its off-diagonal entries are squared and summed, and the sum is
    divided by D. The two per-view terms are added. Intended for the ONLINE
    projections only. N and D are read from z1 and reused for z2.

    Args:
        z1 (torch.Tensor): NxD Tensor containing projected features from view 1.
        z2 (torch.Tensor): NxD Tensor containing projected features from view 2.

    Returns:
        torch.Tensor: covariance regularization loss (scalar).
    """
    n_samples, n_dims = z1.size()
    # True everywhere except on the diagonal: variances are not penalised here.
    off_diagonal = ~torch.eye(n_dims, dtype=torch.bool, device=z1.device)

    def _off_diagonal_penalty(z: torch.Tensor) -> torch.Tensor:
        # Statistics are taken over dim 0, i.e. over whatever batch is passed in
        # (the gathered batch if the caller gathered beforehand).
        centered = z - z.mean(dim=0)
        cov = (centered.T @ centered) / (n_samples - 1)
        return cov[off_diagonal].pow(2).sum() / n_dims

    return _off_diagonal_penalty(z1) + _off_diagonal_penalty(z2)

# Copied from solo/losses/radialvicreg.py for radial loss
def _empirical_minimal_chi_nll_without_constant(d: int, device: torch.device) -> torch.Tensor:
    """
    Returns the minimum over r of ``0.5 * r**2 - (d - 1) * log(r)``, i.e. the
    smallest value of the r-dependent part of the Chi(d) negative log-likelihood.
    The minimum is attained at r = sqrt(d - 1) and equals
    ``0.5 * (d - 1) - (d - 1) * log(sqrt(d - 1))``.
    Args:
        d (int): Dimension of the feature vector (degrees of freedom).
        device (torch.device): Device on which the result is created.
    Returns:
        torch.Tensor: Minimal NLL value as a float32 scalar tensor; 0.0 when d <= 1.
    """
    if d <= 1:
        # With d == 1 the log term would be log(0) = -inf, and the formula does not
        # apply at all for d < 1. Feature dimensions are normally >> 1, so returning
        # a zero offset here is a pragmatic guard against NaN/Inf, not a derived value.
        return torch.tensor(0.0, dtype=torch.float32, device=device)

    dof = torch.tensor(d - 1.0, dtype=torch.float32, device=device)
    # Algebraically this is 0.5 * (d-1) * (1 - log(d-1)); the log(sqrt(.)) form is
    # kept so the value matches the reference implementation exactly.
    return 0.5 * dof - dof * torch.log(torch.sqrt(dof))


def chi2_radial_nll_loss(z1: torch.Tensor, z2: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """Computes the summed radial NLL loss for the online projections z1 and z2.

    For each view the row norms r = ||z||_2 are scored with
    ``0.5 * r**2 - (D - 1) * log(r)`` (the r-dependent part of a Chi(D)
    negative log-likelihood), averaged over the batch, and shifted by the
    minimal achievable value so that the best possible per-view loss is 0.
    The momentum (EMA) projections are not used here.
    Args:
        z1 (torch.Tensor): NxD Tensor containing projected features from view 1.
        z2 (torch.Tensor): NxD Tensor containing projected features from view 2.
        eps (float): Lower bound applied to the norms so that log(r) stays finite.
    Returns:
        torch.Tensor: Sum of the shifted NLL losses of the two views (scalar).
    """
    _, n_dims = z1.size()

    # Offset that makes the per-view loss zero at its minimum; created on the
    # same device as the inputs. D is taken from z1 and reused for z2.
    floor = _empirical_minimal_chi_nll_without_constant(n_dims, device=z1.device)

    def _shifted_nll(z: torch.Tensor) -> torch.Tensor:
        radius = torch.norm(z, dim=1).clamp(min=eps)  # never exactly 0, so log is safe
        per_sample = 0.5 * radius.pow(2) - (n_dims - 1) * torch.log(radius)
        return per_sample.mean() - floor

    return _shifted_nll(z1) + _shifted_nll(z2)


def uniform_loss(x: torch.Tensor, t: int = 2) -> torch.Tensor:
    """Computes the uniform loss for a batch of features x.

    The rows of x are projected onto the unit sphere and the loss is
    ``log(mean(exp(-t * ||x_i - x_j||^2)))`` over all N*(N-1)/2 unordered
    pairs i < j. Lower values mean the features are spread more uniformly
    over the sphere.

    Args:
        x (torch.Tensor): NxD Tensor of features.
        t (int): Temperature parameter scaling the squared distances.
    Returns:
        torch.Tensor: Uniform loss (scalar). With fewer than two rows there are
            no pairs, so a zero scalar with the dtype and device of x is returned.
    """
    x_norm = F.normalize(x, dim=1)  # Normalize to unit sphere
    if x_norm.size(0) <= 1:
        # pdist would be empty here and mean() of an empty tensor is NaN.
        # The zero follows the input dtype so this path returns the same dtype
        # as the regular path below.
        return torch.tensor(0.0, device=x_norm.device, dtype=x_norm.dtype)

    # pdist returns the pairwise Euclidean distances as a 1D vector of
    # N*(N-1)/2 elements.
    squared_distances = torch.pdist(x_norm, p=2).pow(2)
    return squared_distances.mul(-t).exp().mean().log()


def anisotropy_loss(x: torch.Tensor) -> torch.Tensor:
    """Computes the anisotropy of a batch of features.

    The value is the mean cosine similarity over all ordered pairs of distinct
    rows of x.
    0 means isotropic (average cosine similarity ~ 0).
    1 means collapsed (all vectors point the same way, average cosine similarity ~ 1).

    Args:
        x (torch.Tensor): NxD Tensor of features.

    Returns:
        torch.Tensor: Anisotropy value (scalar). With fewer than two rows there
            are no pairs, so a zero scalar with the dtype and device of x is returned.
    """
    n_samples, _ = x.shape
    if n_samples <= 1:
        return torch.tensor(0.0, device=x.device, dtype=x.dtype)

    x_norm = F.normalize(x, dim=1)
    cosine_sim_matrix = x_norm @ x_norm.T

    # Drop the diagonal (each vector's similarity with itself). Because
    # n_samples >= 2 here, the selection always holds n_samples * (n_samples - 1)
    # >= 2 entries, so no extra emptiness check is needed before the mean.
    off_diagonal = ~torch.eye(n_samples, dtype=torch.bool, device=x.device)
    return cosine_sim_matrix[off_diagonal].mean()


def radial_byol_loss_suite(
    p1: torch.Tensor,  # Prediction from view 1
    p2: torch.Tensor,  # Prediction from view 2
    z1_online: torch.Tensor,  # Online projection from view 1
    z2_online: torch.Tensor,  # Online projection from view 2
    z1_momentum: torch.Tensor,  # Momentum projection from view 1
    z2_momentum: torch.Tensor,  # Momentum projection from view 2
) -> tuple:
    """
    Computes the individual loss terms used by RadialBYOL (unweighted).
    - BYOL's similarity loss: each view's prediction against the detached
      momentum projection of the OTHER view, averaged over the two pairings.
      With the 2 - 2 * cos form, 0 means aligned, 2 orthogonal, 4 opposite.
    - Variance loss on the online projections.
    - Covariance loss on the online projections.
    - Radial Chi NLL loss on the online projections.

    Args:
        p1: Prediction from online model, view 1.
        p2: Prediction from online model, view 2.
        z1_online: Online projection, view 1.
        z2_online: Online projection, view 2.
        z1_momentum: Momentum projection, view 1.
        z2_momentum: Momentum projection, view 2.

    Returns:
        A tuple containing:
            (byol_sim_loss, var_loss, cov_loss, radial_loss)
    """
    # Cross-view pairing: p1 <-> z2_momentum and p2 <-> z1_momentum. The momentum
    # targets are detached so no gradient reaches the EMA branch.
    cross_view_pairs = ((p1, z2_momentum), (p2, z1_momentum))
    sim_loss_1, sim_loss_2 = (
        byol_similarity_loss_func(prediction, target.detach())
        for prediction, target in cross_view_pairs
    )
    total_sim_loss = (sim_loss_1 + sim_loss_2) * 0.5  # mean of the two pairings

    # TODO (DDP, unresolved - ask Yillun): the regularizers below only see the
    # tensors passed in, i.e. the local per-process batch. It may be necessary to
    # `gather` z1_online / z2_online (and possibly z1_momentum / z2_momentum)
    # across processes before computing them; this has not been verified.
    online_views = (z1_online, z2_online)
    var_loss_val = variance_loss(*online_views)
    cov_loss_val = covariance_loss(*online_views)
    radial_loss_val = chi2_radial_nll_loss(*online_views)

    return total_sim_loss, var_loss_val, cov_loss_val, radial_loss_val
