# Copyright 2023 solo-learn development team.

# Permission is hereby granted, free of charge, to any person obtaining a copy of
# this software and associated documentation files (the "Software"), to deal in
# the Software without restriction, including without limitation the rights to use,
# copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the
# Software, and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all copies
# or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
# INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR
# PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE
# FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
# OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.

import math

import torch
from solo.losses.radialvicreg import (
    covariance_loss,
    invariance_loss,
    variance_loss,
)
from solo.utils.misc import gather


def m_spacings_estimator(x, N, m, epsilon):
    """Estimates the marginal differential entropies of ``x`` with the m-spacing estimator.

    Each slice along the last dimensions is treated as N scalar samples
    indexed by the first dimension (in this file ``x`` is an N x D tensor, so
    one entropy is estimated per column). With ``x_(i)`` the i-th smallest
    sample of a column, the estimate is::

        H = 1 / (N - m) * sum_{i=1}^{N-m} log((N + 1) / m * (x_(i+m) - x_(i)) + epsilon)

    Args:
        x (torch.Tensor): samples, with the N samples along dimension 0.
        N (int): number of samples, i.e. ``x.size(0)``.
        m (int): spacing order; must satisfy ``1 <= m < N``.
        epsilon (float): added inside the log so that zero spacings (tied
            samples) yield a finite value instead of ``-inf``.

    Returns:
        torch.Tensor: entropy estimates, one per slice along dimension 0
        (a length-D tensor for an N x D input).

    Raises:
        ValueError: if ``m`` is outside ``[1, N)``; such values previously
            produced NaN (``m == 0`` or ``m == N``) or meaningless numbers.
    """
    if not 1 <= m < N:
        raise ValueError(f"m must satisfy 1 <= m < N, got m={m}, N={N}.")

    x_sorted, _ = torch.sort(x, dim=0)
    # Row i holds x_(i+m) - x_(i): the gap spanning m consecutive order statistics.
    spacings = x_sorted[m:] - x_sorted[: N - m]
    spacings = spacings * (N + 1) / m
    marginal_ents = torch.log(spacings + epsilon).sum(dim=0) / (N - m)
    return marginal_ents


def _sigmoid_view_stats(z, N, m, epsilon):
    """Returns (marginal entropies, covariance matrix) of ``sigmoid(z)`` for one view."""
    x_hyper = torch.sigmoid(z)
    ent = m_spacings_estimator(x_hyper, N, m, epsilon)
    x_hyper_centered = x_hyper - x_hyper.mean(dim=0)
    cov = (x_hyper_centered.T @ x_hyper_centered) / (N - 1)
    return ent, cov


def entropy_and_hypercovariance_loss(z1, z2, epsilon):
    """Computes the entropy and hypercovariance terms of the E2MC loss.

    Both views are squashed with a sigmoid. For each view:

    - the marginal entropies are estimated with ``m_spacings_estimator``,
      using spacing order ``m = round(sqrt(N))``;
    - the (unbiased) covariance of the squashed features is computed.

    Args:
        z1 (torch.Tensor): N x D features of the first view.
        z2 (torch.Tensor): N x D features of the second view; its number of
            rows is assumed to equal that of ``z1``.
        epsilon (float): numerical-stability constant forwarded to
            ``m_spacings_estimator``.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]:
            - ``ent_loss``: mean marginal entropy averaged over both views. It
              is returned as a positive entropy; maximization is achieved by
              the caller subtracting it from the total loss.
            - ``hypercov_loss``: for each view, the sum of squared off-diagonal
              covariance entries divided by D, summed over the two views.

    Raises:
        ValueError: if fewer than 2 samples are given. Both the covariance
            (division by ``N - 1``) and the m-spacing estimator (division by
            ``N - m``) would otherwise silently yield NaN.
    """
    N, D = z1.size()
    if N < 2:
        raise ValueError(f"E2MC entropy/hypercovariance needs at least 2 samples, got N={N}.")
    # Customary m-spacing order; for N >= 2 this always satisfies 1 <= m < N.
    m = round(math.sqrt(N))

    ent1, cov_x1_hyper = _sigmoid_view_stats(z1, N, m, epsilon)
    ent2, cov_x2_hyper = _sigmoid_view_stats(z2, N, m, epsilon)

    # Positive entropy: the sign flip (maximize entropy) is applied by the caller.
    ent_loss = (ent1.mean() + ent2.mean()) / 2

    # Penalize only cross-feature covariance, so mask out the diagonal.
    off_diag = ~torch.eye(D, dtype=torch.bool, device=z1.device)
    hypercov_loss = (cov_x1_hyper[off_diag].pow(2).sum() / D) + (
        cov_x2_hyper[off_diag].pow(2).sum() / D
    )

    return ent_loss, hypercov_loss


def vicreg_e2mc_loss_func(
    z1: torch.Tensor,
    z2: torch.Tensor,
    sim_loss_weight: float = 25.0,
    var_loss_weight: float = 25.0,
    cov_loss_weight: float = 1.0,
    ent_loss_weight: float = 1000.0,
    hypercov_loss_weight: float = 100.0,
    epsilon: float = 1e-7,
):
    """Computes the VICReg-E2MC objective and its individual terms.

    The invariance term is computed on the local batch. The variance,
    covariance, entropy and hypercovariance terms are computed on the
    representations gathered with ``gather``. The total loss is the weighted
    sum of all terms, except the entropy term, which is subtracted so that
    minimizing the loss maximizes the entropy.

    Args:
        z1 (torch.Tensor): NxD Tensor containing projected features from view 1.
        z2 (torch.Tensor): NxD Tensor containing projected features from view 2.
        sim_loss_weight (float): weight of the invariance term.
        var_loss_weight (float): weight of the variance term.
        cov_loss_weight (float): weight of the covariance term.
        ent_loss_weight (float): weight of the entropy term (subtracted).
        hypercov_loss_weight (float): weight of the hypercovariance term.
        epsilon (float): small value for numerical stability in entropy calculation.

    Returns:
        Tuple of six torch.Tensor values, in this order:
            - total VICReg-E2MC loss;
            - invariance term;
            - variance term;
            - covariance term;
            - entropy term (positive; enters the total with a minus sign);
            - hypercovariance term.
    """
    # Invariance only compares paired samples, so the local batch suffices.
    invariance = invariance_loss(z1, z2)

    # Batch-statistic terms use the representations gathered from all GPUs.
    all_z1, all_z2 = gather(z1), gather(z2)
    variance = variance_loss(all_z1, all_z2)
    covariance = covariance_loss(all_z1, all_z2)
    entropy, hypercovariance = entropy_and_hypercovariance_loss(all_z1, all_z2, epsilon)

    # Accumulate left-to-right, matching the order of the original single expression.
    loss = sim_loss_weight * invariance
    loss = loss + var_loss_weight * variance
    loss = loss + cov_loss_weight * covariance
    loss = loss - ent_loss_weight * entropy
    loss = loss + hypercov_loss_weight * hypercovariance

    return loss, invariance, variance, covariance, entropy, hypercovariance