# Copyright 2023 solo-learn development team.

# Permission is hereby granted, free of charge, to any person obtaining a copy of
# this software and associated documentation files (the "Software"), to deal in
# the Software without restriction, including without limitation the rights to use,
# copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the
# Software, and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all copies
# or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
# INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR
# PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE
# FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
# OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.

import math
import torch
import torch.nn.functional as F
from solo.utils.misc import gather
import mpmath as mp

from torch.distributions.exponential import Exponential
from torch.distributions.multivariate_normal import MultivariateNormal
from torch.distributions.laplace import Laplace
from solo.utils.metrics import (
    batch_sparsity_metric,
    embedding_sparsity_metric,
    count_avg_nonzero_elements_per_dimension,
    count_avg_nonzero_elements_per_sample,
    active_feature_fraction,
)
# from ipdb import ipdb

# =========================
# VICReg-specific pieces. NOTE: iter-dist methods never optimize the variance and
# covariance losses; these terms are only logged. Iter-dist methods optimize only
# sim_loss and one_d_dist_loss.
# =========================
def invariance_loss(z1: torch.Tensor, z2: torch.Tensor) -> torch.Tensor:
    """Return the mean squared error between the projections of two views.

    Args:
        z1 (torch.Tensor): NxD projected features of the first view.
        z2 (torch.Tensor): NxD projected features of the second view.

    Returns:
        torch.Tensor: scalar MSE averaged over every element.
    """
    return F.mse_loss(z1, z2, reduction="mean")


def variance_loss(z1: torch.Tensor, z2: torch.Tensor) -> torch.Tensor:
    """Hinge penalty that pushes the per-dimension std of each view up to 1.

    For each view, the std of every column (computed with a small epsilon
    inside the square root) is compared with 1. The mean of
    ``relu(1 - std)`` is taken over dimensions, and the two views' penalties
    are summed.

    Args:
        z1 (torch.Tensor): NxD projected features of the first view.
        z2 (torch.Tensor): NxD projected features of the second view.

    Returns:
        torch.Tensor: scalar variance-regularization loss.
    """
    eps = 1e-4

    def _hinge_on_std(z: torch.Tensor) -> torch.Tensor:
        # eps keeps sqrt differentiable when a column has zero variance
        std = torch.sqrt(z.var(dim=0) + eps)
        return torch.mean(F.relu(1 - std))

    return _hinge_on_std(z1) + _hinge_on_std(z2)


def covariance_loss(z1: torch.Tensor, z2: torch.Tensor) -> torch.Tensor:
    """Penalize off-diagonal entries of each view's sample covariance matrix.

    For each view, the unbiased (N - 1) covariance matrix is formed. Its
    squared off-diagonal entries are summed and divided by D. The two views'
    penalties are added.

    Args:
        z1 (torch.Tensor): NxD projected features of the first view.
        z2 (torch.Tensor): NxD projected features of the second view.

    Returns:
        torch.Tensor: scalar covariance-regularization loss.
    """
    n_samples, n_dims = z1.size()
    off_diagonal = ~torch.eye(n_dims, device=z1.device).bool()

    def _off_diag_penalty(z: torch.Tensor) -> torch.Tensor:
        centered = z - z.mean(dim=0)
        cov = (centered.T @ centered) / (n_samples - 1)
        return cov[off_diagonal].pow(2).sum() / n_dims

    return _off_diag_penalty(z1) + _off_diag_penalty(z2)




# =========================
# Gaussian-specific pieces
# =========================
def cvm_exact_loss_for_one_view(x):
    """Cramér-von Mises statistic of every column of ``x`` against N(0, 1), averaged.

    For a sample of size n with sorted values x_(i), the per-column statistic is
    ``1 / (12 n) + sum_i (Phi(x_(i)) - (2 i - 1) / (2 n))^2``.

    Args:
        x: (B, D) tensor; each column is tested independently.

    Returns:
        Scalar tensor: mean of the D per-column statistics.
    """
    n, _ = x.shape

    # order statistics of each column, mapped through the standard normal CDF
    order_stats = torch.sort(x, dim=0).values
    cdf = 0.5 * (1 + torch.erf(order_stats / math.sqrt(2.0)))

    # plotting positions (2i - 1) / (2n), shape (n, 1), broadcast over columns
    ranks = torch.arange(1, n + 1, device=x.device, dtype=x.dtype).unsqueeze(1)
    plotting_positions = (2 * ranks - 1) / (2.0 * n)

    per_dim = (1.0 / (12.0 * n)) + torch.sum((cdf - plotting_positions) ** 2, dim=0)
    return per_dim.mean()

def cvm_exact_loss(z1, z2, orthogonal_transform):
    """Average Cramér-von Mises statistic of both views' 1-D projections vs. N(0, 1).

    Args:
        z1: (B, D) features of the first view.
        z2: (B, D) features of the second view.
        orthogonal_transform: (K, D) matrix whose rows are the projection
            directions. This is the same convention used by ``jarque_bera_loss``,
            ``sigreg_loss`` and ``sliced_wasserstein_distance``, which all receive
            the same transform from ``one_d_dist_loss``.

    Returns:
        Scalar tensor: mean of the two views' statistics.
    """
    # Project onto the ROWS of the transform (x @ Q.T). The previous x @ Q used the
    # columns, which disagrees with every sibling loss and fails when K != D.
    z1_transformed = z1 @ orthogonal_transform.T
    z2_transformed = z2 @ orthogonal_transform.T

    omega2_z1 = cvm_exact_loss_for_one_view(z1_transformed)
    omega2_z2 = cvm_exact_loss_for_one_view(z2_transformed)
    return (omega2_z1 + omega2_z2) / 2

def jarque_bera_loss_for_one_view(x: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """Moment-based normality statistic per feature dimension, averaged over D.

    For each column the statistic is the sum of:
      * mean term     ``mean^2 / (var / n)``,
      * variance term ``(var - 1)^2 / (2 / (n - 1))``,
      * Jarque-Bera   ``n / 6 * (skew^2 + (kurt - 3)^2 / 4)``.

    Args:
        x: (B, D) tensor; statistics are computed over dim 0. Requires B >= 2
            (the variance term divides by ``n - 1``).
        eps: lower bound applied to the std (for standardization) and to the
            variance (in the mean term). This keeps constant features finite
            instead of producing NaN/inf.

    Returns:
        Scalar tensor: mean of the per-dimension statistics.
    """
    n = x.shape[0]
    mean = x.mean(dim=0)
    var = x.var(dim=0, unbiased=True)
    # honour the eps argument (it used to be ignored in favour of a literal 1e-8)
    std = var.sqrt().clamp(min=eps)

    standardized = (x - mean) / std
    skewness = standardized.pow(3).mean(dim=0)
    kurtosis = standardized.pow(4).mean(dim=0)

    # 1. Mean: mean^2 / (var / n); clamp var so a zero-variance column gives a
    #    finite value rather than 0/0 = NaN or x/0 = inf.
    stat_mean = mean.square() / (var.clamp(min=eps) / n)
    # 2. Variance: (var - 1)^2 / (2 / (n - 1))
    stat_var = (var - 1).square() / (2 / (n - 1))
    # 3./4. Skewness and kurtosis: Jarque-Bera part
    stat_skew_kurt = n / 6 * (skewness.square() + 0.25 * (kurtosis - 3).square())

    return (stat_mean + stat_var + stat_skew_kurt).mean()


def jarque_bera_loss(z1, z2, orthogonal_transform):
    """Average moment-based normality statistic of both views after projection.

    Args:
        z1: (B, D) features of the first view.
        z2: (B, D) features of the second view.
        orthogonal_transform: matrix whose rows are projection directions
            (features are multiplied by its transpose).

    Returns:
        Scalar tensor: mean of the two views' statistics.
    """
    projection = orthogonal_transform.T
    per_view = [jarque_bera_loss_for_one_view(z @ projection) for z in (z1, z2)]
    return (per_view[0] + per_view[1]) / 2

# =========================
# SIGReg-specific pieces (LeJEPA)
# =========================
def sigreg_loss_for_one_view(x: torch.Tensor, A: torch.Tensor, knots: int = 17, t_max: float = 3.0) -> torch.Tensor:
    """SIGReg statistic for one view, computed by quadrature over the characteristic function.

    Reference: LeJEPA: Provable and Scalable Self-Supervised Learning Without
    the Heuristics (arXiv:2511.08544). The integration follows the authors'
    minimal reference implementation, which integrates over [0, t_max] and
    exploits symmetry, rather than the paper's [t_min, t_max] form.

    Args:
        x: (B, D) (gathered) features.
        A: (K, D) projection matrix; each row is one slicing direction.
        knots: number of quadrature nodes (must be >= 2).
        t_max: upper end of the integration range [0, t_max].

    Returns:
        torch.Tensor: scalar; batch size times the weighted squared ECF error,
        averaged over the K slices.
    """
    device, dtype = x.device, x.dtype

    # Quadrature nodes on [0, t_max]. Trapezoid weights are doubled for the
    # mirrored half on [-t_max, 0]; the two endpoints keep a single step.
    t = torch.linspace(0, t_max, knots, dtype=dtype, device=device)
    step = t_max / (knots - 1)
    quad_w = torch.full((knots,), 2 * step, dtype=dtype, device=device)
    quad_w[0] = step
    quad_w[-1] = step

    # CF of N(0, 1) at the nodes; it also serves as the Gaussian integration window.
    gauss_cf = torch.exp(-t.square() / 2.0)
    quad_w = quad_w * gauss_cf

    # (B, K, T): each projected sample multiplied by each node
    arg = (x @ A.T).unsqueeze(-1) * t

    # empirical CF over the batch, split into real and imaginary parts -> (K, T)
    ecf_real = arg.cos().mean(dim=0)
    ecf_imag = arg.sin().mean(dim=0)

    # target CF is real (symmetric distribution), so its imaginary part is 0
    sq_err = (ecf_real - gauss_cf).square() + ecf_imag.square()

    # integrate over nodes -> (K,), scale by batch size, average over slices
    return ((sq_err @ quad_w) * x.size(0)).mean()

def sigreg_loss(z1, z2, orthogonal_transform):
    """Average SIGReg statistic of two views.

    ``orthogonal_transform`` is either a single (K, D) tensor shared by both
    views, or a 2-element list ``[A_z1, A_z2]`` with one projection per view
    (the same convention as ``sliced_wasserstein_distance``).

    Raises:
        ValueError: for any other ``orthogonal_transform``.
    """
    if isinstance(orthogonal_transform, torch.Tensor):
        transform_z1 = transform_z2 = orthogonal_transform
    elif isinstance(orthogonal_transform, list) and len(orthogonal_transform) == 2:
        transform_z1, transform_z2 = orthogonal_transform
    else:
        raise ValueError("Invalid orthogonal_transform type for sigreg_loss")

    loss_z1 = sigreg_loss_for_one_view(z1, transform_z1)
    loss_z2 = sigreg_loss_for_one_view(z2, transform_z2)
    return (loss_z1 + loss_z2) / 2

# =========================
# Laplace-specific pieces
# =========================
def sample_symmetric_multivariate_laplace(d, Sigma, n_samples):  # elliptical Laplace
    '''
    Draw ``n_samples`` vectors from the symmetric (elliptical) multivariate
    Laplace distribution as ``sqrt(W) * X``, with ``W ~ Exp(1)`` and
    ``X ~ N(0, Sigma)``.

    Sampling is done in float32, which avoids half-precision Cholesky problems
    on GPU. The result is cast back to ``Sigma``'s dtype.

    Args:
        d: dimensionality (length of the zero mean vector).
        Sigma: (d, d) covariance matrix; also fixes the device and output dtype.
        n_samples: number of vectors to draw.

    Returns:
        (n_samples, d) tensor in ``Sigma.dtype``.
    '''
    target_device, target_dtype = Sigma.device, Sigma.dtype
    cov32 = Sigma.to(torch.float32)

    # mixing variable W ~ Exp(1), one per sample -> shape (n_samples, 1)
    rate = torch.tensor([1.0], device=target_device, dtype=torch.float32)
    mixing = Exponential(rate=rate).sample((n_samples,))

    # Gaussian component X ~ N(0, Sigma) -> shape (n_samples, d)
    mean = torch.zeros(d, device=target_device, dtype=torch.float32)
    gaussian = MultivariateNormal(loc=mean, covariance_matrix=cov32).sample((n_samples,))

    return (mixing.sqrt() * gaussian).to(dtype=target_dtype)

# =========================
# Product Laplace-specific pieces
# =========================
# def sample_product_laplace(d, n_samples, chosed_sigma=None):
#     assert chosed_sigma is not None, "chosed_sigma must be provided"
    
#     # scale = 1 / torch.sqrt(torch.tensor(2.0))
    
#     laplace_dist = Laplace(loc=0.0, scale=chosed_sigma)
#     samples = laplace_dist.sample((n_samples, d)) # we sample each distribution independently
    
#     return samples




def sample_product_laplace(d, n_samples, device, dtype, chosed_sigma=None):
    '''
    Sample i.i.d. Laplace(0, b) entries directly on the target device.

    Args:
        d: number of independent coordinates per sample.
        n_samples: number of samples.
        device: device of the returned tensor.
        dtype: dtype of the returned tensor.
        chosed_sigma: Laplace scale ``b`` (a Python number or a tensor). This is
            not the std: Var = 2 b^2. Defaults to ``1 / sqrt(2)``, which gives
            unit variance.

    Returns:
        (n_samples, d) tensor of samples (for a scalar scale).
    '''
    if chosed_sigma is None:
        chosed_sigma = 1.0 / math.sqrt(2.0)
    # as_tensor accepts numbers and existing tensors alike; torch.tensor(tensor)
    # would emit a copy-construct UserWarning on every call.
    scale = torch.as_tensor(chosed_sigma, device=device, dtype=dtype)
    loc = torch.zeros((), device=device, dtype=dtype)

    return Laplace(loc=loc, scale=scale).sample((n_samples, d))

# =========================
# Generalized Gaussian Distribution
# =========================
def determine_sigma_for_lp_dist(p):
    """Return sigma giving unit variance for the generalized Gaussian of order ``p``.

    Uses the density convention ``exp(-|x|^p / (p sigma^p))``.
    NOTE: this definition is shadowed by an identical redefinition later in
    this module.
    """
    numerator = math.gamma(1 / p) ** 0.5
    denominator = (p ** (1 / p)) * (math.gamma(3 / p) ** 0.5)
    return numerator / denominator

def sample_lp_distribution(shape, p, loc=0.0, scale=1.0):
    """CPU sampler for the generalized Gaussian ``loc + scale * S * (p * G)^(1/p)``.

    Here S is a random sign and G ~ Gamma(1/p, 1).
    NOTE: this definition is shadowed by the device-aware redefinition later
    in this module.
    """
    # random signs in {-1, +1}; drawn before the Gamma samples
    signs = torch.empty(shape).bernoulli_(0.5) * 2 - 1
    magnitude_gamma = torch.distributions.Gamma(concentration=1.0 / p, rate=1.0)
    magnitudes = (p * magnitude_gamma.sample(shape)).pow(1.0 / p)
    return loc + scale * (signs * magnitudes)

def rectified_gengaus_mean_var_unified(p, mu, sigma):
    """
    Closed-form mean and variance of Y = ReLU(X), where X ~ GN_p(mu, sigma)
    has pdf ∝ exp(-|x - mu|^p / (p * sigma^p)).

    A single formula covers mu < 0, mu = 0 and mu > 0 by folding sign(mu)
    into lower incomplete gamma terms. Evaluated with mpmath; returned as floats.

    Raises:
        ValueError: if sigma <= 0 or p <= 0 (sigma is checked first).

    Returns:
        (E[Y], Var[Y]) as Python floats.
    """
    p, mu, sigma = mp.mpf(p), mp.mpf(mu), mp.mpf(sigma)
    if sigma <= 0:
        raise ValueError("sigma must be > 0")
    if p <= 0:
        raise ValueError("p must be > 0")

    one, two, three = mp.mpf(1), mp.mpf(2), mp.mpf(3)
    sign_mu = mp.sign(mu)  # -1, 0 or +1
    a1, a2, a3 = one / p, two / p, three / p

    # |mu| expressed in the units of the exponent in the pdf
    t = (abs(mu) ** p) / (p * (sigma ** p))
    gamma_a1 = mp.gamma(a1)

    # ratios built from lower γ(a, t) / upper Γ(a, t) incomplete gammas
    frac_prob = (gamma_a1 + sign_mu * mp.gammainc(a1, 0, t)) / gamma_a1
    frac_tail = mp.gammainc(a2, t, mp.inf) / gamma_a1
    frac_second = (mp.gamma(a3) + sign_mu * mp.gammainc(a3, 0, t)) / gamma_a1

    scale1 = p ** (one / p)
    scale2 = p ** (two / p)
    half = mp.mpf('0.5')

    mean_y = half * (mu * frac_prob + scale1 * sigma * frac_tail)
    second_moment = half * (mu**2 * frac_prob + 2*mu*scale1*sigma*frac_tail + scale2 * sigma**2 * frac_second)
    return float(mean_y), float(second_moment - mean_y**2)

def var_rectified_gengaus_regularized(p, mu, sigma):
    """Return only Var[ReLU(X)] for X ~ GN_p(mu, sigma), as a float."""
    _, variance = rectified_gengaus_mean_var_unified(p, mu, sigma)
    return variance

def choose_sigma_for_unit_var(p, mu, target_var=1.0, rtol=1e-10, max_iter=2000):
    """
    Solve for sigma > 0 such that Var(ReLU(X)) = target_var, where X ~ GN_p(mu, sigma).

    The root of f(sigma) = Var - target_var is bracketed starting from
    [1e-8, 1], doubling the upper end up to 200 times. Bisection then runs
    until |f| <= rtol * (1 + |target_var|), the bracket can no longer shrink
    at working precision, or ``max_iter`` iterations have passed.

    Raises:
        RuntimeError: if no sign change is found while expanding the bracket.

    Returns:
        sigma as a Python float.
    """
    p = mp.mpf(p)
    mu = mp.mpf(mu)
    target_var = mp.mpf(target_var)

    def f(sig):
        return var_rectified_gengaus_regularized(p, mu, sig) - target_var

    # --- bracket a root ---
    lo = mp.mpf('1e-8')
    hi = mp.mpf('1.0')
    flo = f(lo)
    fhi = f(hi)

    # grow hi until the sign changes
    k = 0
    while flo * fhi > 0 and k < 200:
        hi *= 2
        fhi = f(hi)
        k += 1

    if flo * fhi > 0:
        raise RuntimeError("Failed to bracket a root for sigma. Try different initial range.")

    # --- bisection ---
    tol = rtol * (1 + abs(target_var))
    for _ in range(max_iter):
        mid = (lo + hi) / 2
        if mid == lo or mid == hi:
            # The bracket cannot shrink any further at working precision. Every
            # remaining iteration would re-evaluate the same expensive
            # incomplete-gamma expressions and leave the state unchanged, so stop.
            # The value returned below equals float(mid), exactly what the loop
            # would have returned.
            break
        fmid = f(mid)

        if abs(fmid) <= tol:
            return float(mid)

        if flo * fmid <= 0:
            hi, fhi = mid, fmid
        else:
            lo, flo = mid, fmid

    return float((lo + hi) / 2)


# =========================
# Generalized Gaussian Distribution
# =========================
def determine_sigma_for_lp_dist(p):
    """Return sigma such that GN_p(0, sigma) has unit variance.

    With density ∝ exp(-|x|^p / (p sigma^p)) the variance is
    ``p^(2/p) sigma^2 Γ(3/p) / Γ(1/p)``; setting it to 1 and solving gives
    ``sigma = sqrt(Γ(1/p)) / (p^(1/p) sqrt(Γ(3/p)))``.
    """
    numerator = math.gamma(1 / p) ** 0.5
    denominator = (p ** (1 / p)) * (math.gamma(3 / p) ** 0.5)
    return numerator / denominator

def sample_lp_distribution(shape, p, loc=0.0, scale=1.0):
    """Sample ``loc + scale * S * (p * G)^(1/p)``, a generalized Gaussian of order ``p``.

    S is a random sign in {-1, +1} and G ~ Gamma(1/p, 1), so for loc = 0 and
    scale = 1 the density is ∝ exp(-|x|^p / p).

    The output device is ``loc.device`` if ``loc`` is a tensor, otherwise
    ``scale.device`` if ``scale`` is a tensor, otherwise CPU.
    """
    if isinstance(loc, torch.Tensor):
        device = loc.device
    elif isinstance(scale, torch.Tensor):
        device = scale.device
    else:
        device = "cpu"

    signs = torch.empty(shape, device=device).bernoulli_(0.5)
    signs = 2 * signs - 1
    # Gamma draws come from the CPU generator and are then moved to `device`;
    # the 1/p power below is therefore evaluated on `device`, not on the CPU.
    gamma_draws = torch.distributions.Gamma(concentration=1.0 / p, rate=1.0).sample(shape).to(device)
    return loc + scale * (signs * (p * gamma_draws).pow(1.0 / p))


# ========================================================
# shared empirical cdf-based distribution matching losses
# ========================================================
def _sample_swd_targets(target_dist, B, D, device, dtype, mean_shift, p_norm, chosed_sigma):
    """Draw a (B, D) batch from the target distribution named by ``target_dist``.

    ``mean_shift`` is added before rectification; ``chosed_sigma`` scales the
    samples. Caveat: "laplace" ignores both ``chosed_sigma`` and ``mean_shift``.
    """
    if target_dist == "gauss":
        # sigma * N(0, 1) + mu
        scale = chosed_sigma if chosed_sigma is not None else 1.0
        return torch.randn(B, D, device=device, dtype=dtype) * scale + mean_shift
    if target_dist == "rectified_gauss":
        # relu(sigma * N(0, 1) + mu)
        scale = chosed_sigma if chosed_sigma is not None else 1.0
        return torch.relu(torch.randn(B, D, device=device, dtype=dtype) * scale + mean_shift)
    if target_dist == "laplace":
        return sample_symmetric_multivariate_laplace(
            D,
            torch.eye(D, device=device, dtype=torch.float32),
            n_samples=B,
        ).to(device=device, dtype=dtype)
    if target_dist in ("product_laplace", "rectified_product_laplace"):
        # Laplace(0, b) + mu, optionally rectified; default b = 1/sqrt(2) (unit variance)
        samples = sample_product_laplace(
            D,
            n_samples=B,
            device=device,
            dtype=dtype,
            chosed_sigma=chosed_sigma if chosed_sigma is not None else (1 / math.sqrt(2.0)),
        ).to(device=device, dtype=dtype) + mean_shift
        return torch.relu(samples) if target_dist == "rectified_product_laplace" else samples
    if target_dist in ("lp_distribution", "rectified_lp_distribution"):
        assert chosed_sigma is not None, f"chosed_sigma must be provided for {target_dist}"
        samples = sample_lp_distribution(
            shape=(B, D),
            p=p_norm,
            loc=mean_shift,
            scale=chosed_sigma,
        ).to(device=device, dtype=dtype)
        return torch.relu(samples) if target_dist == "rectified_lp_distribution" else samples
    raise ValueError(
        f"Unknown target_dist {target_dist!r} for sliced Wasserstein distance; expected one of "
        "'gauss', 'rectified_gauss', 'laplace', 'product_laplace', 'rectified_product_laplace', "
        "'lp_distribution', 'rectified_lp_distribution'"
    )


def sliced_wasserstein_distance_for_one_view(features, P_directions, target_dist, mean_shift_scalar_for_rectified_gauss=0.0, p_norm_for_rectified_lp_distribution=1.0, chosed_sigma=None):
    """
    Sliced Wasserstein-2 distance between ``features`` and a fresh batch drawn from ``target_dist``.

    Both the features and a target batch of the same size are projected onto
    the rows of ``P_directions``. The projections are sorted per direction, and
    the mean squared difference of the sorted values (the 1-D W2^2 between
    empirical distributions) is averaged over directions.

    Args:
        features (torch.Tensor): (B, D) network features.
        P_directions (torch.Tensor): (K, D) matrix whose rows are the slicing directions.
        target_dist (str): one of 'gauss', 'rectified_gauss', 'laplace',
            'product_laplace', 'rectified_product_laplace', 'lp_distribution',
            'rectified_lp_distribution'.
        mean_shift_scalar_for_rectified_gauss (float): shift mu added to target
            samples before any rectification (ignored by 'laplace').
        p_norm_for_rectified_lp_distribution (float): order p for the lp targets.
        chosed_sigma (float | None): target scale. Required for the lp targets;
            ignored by 'laplace'.

    Returns:
        torch.Tensor: scalar sliced Wasserstein loss.

    Raises:
        ValueError: if ``target_dist`` is not recognised.
    """
    B, D = features.shape

    # project features onto every slicing direction -> (B, K)
    projected_features = torch.matmul(features, P_directions.T)

    target_samples = _sample_swd_targets(
        target_dist,
        B,
        D,
        features.device,
        features.dtype,
        mean_shift_scalar_for_rectified_gauss,
        p_norm_for_rectified_lp_distribution,
        chosed_sigma,
    )
    projected_targets = torch.matmul(target_samples, P_directions.T)

    # sorting each column gives the empirical quantiles along every direction
    sorted_features, _ = torch.sort(projected_features, dim=0)
    sorted_targets, _ = torch.sort(projected_targets, dim=0)

    # 1-D Wasserstein-2 (squared) per direction, then average over directions
    wasserstein_1d = torch.mean((sorted_features - sorted_targets) ** 2, dim=0)
    return torch.mean(wasserstein_1d)

def sliced_wasserstein_distance(z1, z2, orthogonal_transform, target_dist, mean_shift_scalar_for_rectified_gauss=0.0, p_norm_for_rectified_lp_distribution=1.0, chosed_sigma=None):
    """Average sliced Wasserstein distance of both views to ``target_dist``.

    ``orthogonal_transform`` is either one (K, D) tensor shared by both views,
    or a 2-element list ``[P_z1, P_z2]`` with per-view directions (intended
    for SVD-based projections). Tensor subclasses such as ``nn.Parameter``
    are accepted.

    Raises:
        ValueError: if ``orthogonal_transform`` is neither of the above.
    """
    # isinstance (not type(...) ==) so tensor/list subclasses are accepted,
    # matching sigreg_loss
    if isinstance(orthogonal_transform, list) and len(orthogonal_transform) == 2:
        transform_z1, transform_z2 = orthogonal_transform
    elif isinstance(orthogonal_transform, torch.Tensor):
        transform_z1 = transform_z2 = orthogonal_transform
    else:
        raise ValueError(
            "orthogonal_transform must be a torch.Tensor or a list of two tensors, "
            f"got {type(orthogonal_transform).__name__}"
        )

    shared_args = (target_dist, mean_shift_scalar_for_rectified_gauss, p_norm_for_rectified_lp_distribution, chosed_sigma)
    swd_loss_z1 = sliced_wasserstein_distance_for_one_view(z1, transform_z1, *shared_args)
    swd_loss_z2 = sliced_wasserstein_distance_for_one_view(z2, transform_z2, *shared_args)
    return (swd_loss_z1 + swd_loss_z2) / 2

# =========================
# unified loss functions
# =========================
def one_d_dist_loss(z1, z2, orthogonal_transform, target_dist, loss_choice, mean_shift_scalar_for_rectified_gauss=0.0, p_norm_for_rectified_lp_distribution=1.0, chosed_sigma=None):
    """Dispatch to the 1-D distribution-matching loss selected by ``loss_choice``.

    'sliced_wasserstein_distance' works for every ``target_dist``. The
    distribution-specific losses ('cvm_exact_loss', 'jarque_bera_loss',
    'sigreg_loss') exist only for ``target_dist == 'gauss'`` and receive just
    ``(z1, z2, orthogonal_transform)``.

    Raises:
        ValueError: if ``loss_choice`` is not available for ``target_dist``.
    """
    # shared empirical-CDF-based loss: supports every target distribution
    if loss_choice == "sliced_wasserstein_distance":
        return sliced_wasserstein_distance(z1, z2, orthogonal_transform, target_dist, mean_shift_scalar_for_rectified_gauss, p_norm_for_rectified_lp_distribution, chosed_sigma)

    # Distinct, distribution-specific losses (TODO: somewhat deprecated). Only
    # the Gaussian target has any, so fail clearly instead of indexing None.
    if target_dist != "gauss":
        raise ValueError(
            f"loss_choice {loss_choice!r} is not available for target_dist {target_dist!r}; "
            "only 'sliced_wasserstein_distance' supports non-Gaussian targets"
        )

    gauss_losses = {
        "cvm_exact_loss": cvm_exact_loss,
        "jarque_bera_loss": jarque_bera_loss,
        "sigreg_loss": sigreg_loss,
    }
    if loss_choice not in gauss_losses:
        raise ValueError(
            f"Unknown loss_choice {loss_choice!r} for target_dist 'gauss'; expected "
            f"'sliced_wasserstein_distance' or one of {sorted(gauss_losses)}"
        )
    return gauss_losses[loss_choice](z1, z2, orthogonal_transform)

def iter_dist_loss_func(
    z1: torch.Tensor,
    z2: torch.Tensor,
    orthogonal_transform: torch.Tensor,
    target_distribution: str,
    sim_loss_weight: float,
    var_loss_weight: float,
    cov_loss_weight: float,
    one_d_dist_loss_weight: float,
    one_d_dist_loss_choice: str,
    mean_shift_scalar_for_rectified_gauss: float = 0.0,
    p_norm_for_rectified_lp_distribution: float = 1.0,
    chosed_sigma: float = None,
):
    """Weighted sum of the invariance, variance, covariance and 1-D distribution losses.

    Args:
        z1, z2: projected features of the two views.
        orthogonal_transform: projection directions forwarded to
            ``one_d_dist_loss``. A tensor, or (for the sliced Wasserstein /
            SIGReg losses) a 2-element list of per-view transforms.
        target_distribution: target name forwarded to ``one_d_dist_loss``.
        sim_loss_weight, var_loss_weight, cov_loss_weight, one_d_dist_loss_weight:
            coefficients of the four terms.
        one_d_dist_loss_choice: loss name forwarded to ``one_d_dist_loss``.
        mean_shift_scalar_for_rectified_gauss, p_norm_for_rectified_lp_distribution,
        chosed_sigma: target-distribution parameters forwarded unchanged.

    Returns:
        tuple: (total_loss, sim_loss, var_loss, cov_loss, marginal_dist_loss).
    """
    # Invariance is computed on the local, un-gathered features.
    sim_loss = invariance_loss(z1, z2)

    # As in solo-learn's VICReg: the remaining terms use gathered features
    # (presumably an all-gather across processes; see solo.utils.misc.gather).
    z1 = gather(z1)
    z2 = gather(z2)

    # Variance and covariance are computed on the un-projected features.
    var_loss = variance_loss(z1, z2)
    cov_loss = covariance_loss(z1, z2)

    # The marginal (1-D) distribution loss works on the projected features.
    marginal_dist_loss = one_d_dist_loss(
        z1,
        z2,
        orthogonal_transform,
        target_distribution,
        one_d_dist_loss_choice,
        mean_shift_scalar_for_rectified_gauss,
        p_norm_for_rectified_lp_distribution,
        chosed_sigma,
    )

    loss = (
        (sim_loss_weight * sim_loss)
        + (var_loss_weight * var_loss)
        + (cov_loss_weight * cov_loss)
        + (one_d_dist_loss_weight * marginal_dist_loss)
    )
    return loss, sim_loss, var_loss, cov_loss, marginal_dist_loss
