# Copyright 2023 solo-learn development team.

# Permission is hereby granted, free of charge, to any person obtaining a copy of
# this software and associated documentation files (the "Software"), to deal in
# the Software without restriction, including without limitation the rights to use,
# copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the
# Software, and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all copies
# or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
# INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR
# PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE
# FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
# OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.

import torch
import torch.nn.functional as F
from torch import Tensor

# Import the individual loss components needed
from solo.losses.radialbyol import variance_loss as compute_variance_loss
from solo.losses.radialbyol import covariance_loss as compute_covariance_loss
from solo.losses.radialbyol import chi2_radial_nll_loss as compute_chi2_radial_nll_loss

def nat_alignment_loss(features: Tensor, targets: Tensor) -> Tensor:
    """
    Mean squared L2 distance between row-normalized features and targets.

    Both inputs are L2-normalized along ``dim=1`` inside this function, so
    callers do not need to normalize beforehand. For unit vectors the squared
    distance equals ``2 - 2 * cosine_similarity``.

    Args:
        features (torch.Tensor): NxD Tensor containing the features.
        targets (torch.Tensor): NxD Tensor containing the fixed target representations.

    Returns:
        torch.Tensor: Scalar tensor with the per-row squared distance averaged
            over the rows.
    """
    unit_features = F.normalize(features, dim=1)
    unit_targets = F.normalize(targets, dim=1)

    # Per-row squared Euclidean distance between the unit vectors.
    residual = unit_features - unit_targets
    per_row_sq_dist = residual.pow(2).sum(dim=1)

    return per_row_sq_dist.mean()

def hungarian_matching(similarity_matrix: Tensor) -> Tensor:
    """
    Computes the optimal matching between features and targets using
    the Hungarian algorithm.

    The assignment maximizes the total similarity of the selected
    (row, column) pairs. The solve runs on CPU through scipy on a detached
    copy, so no gradient flows through the returned matrix.

    Args:
        similarity_matrix (torch.Tensor): NxN Tensor containing the similarity between
                                        features and target representations. Any
                                        floating dtype (including half / bfloat16) and
                                        any device is accepted. A rectangular NxM
                                        matrix yields a partial assignment with
                                        min(N, M) selected pairs.

    Returns:
        torch.Tensor: Binary matrix with the same shape and device as
            ``similarity_matrix`` (an NxN permutation matrix for square input),
            using the default torch dtype; entry (i, j) is 1 when row i is
            matched to column j.

    Raises:
        ImportError: If scipy is not installed.
    """
    # Import lazily so scipy stays an optional dependency. Only the import
    # sits inside the try block so that unrelated errors are never relabelled.
    try:
        from scipy.optimize import linear_sum_assignment
    except ImportError as err:
        raise ImportError("scipy is required for hungarian matching. Please install it with pip install scipy.") from err

    # Move to CPU first (some accelerators lack float64), then upcast:
    # NumPy cannot represent bfloat16, and float64 is an exact widening of
    # every narrower float dtype, so the solution is unaffected.
    scores = similarity_matrix.detach().cpu().to(torch.float64).numpy()

    # linear_sum_assignment minimizes cost, so negate to maximize similarity.
    row_ind, col_ind = linear_sum_assignment(-scores)

    # Build the binary assignment matrix on the caller's device.
    device = similarity_matrix.device
    perm_matrix = torch.zeros(similarity_matrix.shape, device=device)
    rows = torch.as_tensor(row_ind, dtype=torch.long, device=device)
    cols = torch.as_tensor(col_ind, dtype=torch.long, device=device)
    perm_matrix[rows, cols] = 1.0

    return perm_matrix

def nat_auxiliary_losses(z1_online: Tensor, z2_online: Tensor = None) -> tuple:
    """Computes auxiliary losses (variance, covariance, radial) for NAT.

    Each term is delegated to the corresponding two-view loss imported from
    ``solo.losses.radialbyol``.

    Args:
        z1_online (torch.Tensor): NxD Tensor of projected features from the first view.
        z2_online (torch.Tensor, optional): NxD Tensor of projected features from the second view.
            If None, every loss is evaluated as ``loss(z1_online, z1_online) / 2.0``
            (this includes the covariance term, which is NOT forced to zero).

    Returns:
        tuple: (variance_loss_val, covariance_loss_val, radial_loss_val)
    """
    if z2_online is not None:
        # Two-view case: use the imported losses as-is.
        return (
            compute_variance_loss(z1_online, z2_online),
            compute_covariance_loss(z1_online, z2_online),
            compute_chi2_radial_nll_loss(z1_online, z2_online),
        )

    # Single-view case: feed the same view twice and halve the result.
    # NOTE(review): this equals a true single-view loss only if each imported
    # function is a sum of one independent term per view -- confirm against
    # solo.losses.radialbyol.
    return (
        compute_variance_loss(z1_online, z1_online) / 2.0,
        compute_covariance_loss(z1_online, z1_online) / 2.0,
        compute_chi2_radial_nll_loss(z1_online, z1_online) / 2.0,
    )