# Copyright 2023 solo-learn development team.

# Permission is hereby granted, free of charge, to any person obtaining a copy of
# this software and associated documentation files (the "Software"), to deal in
# the Software without restriction, including without limitation the rights to use,
# copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the
# Software, and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all copies
# or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
# INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR
# PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE
# FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
# OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.

from typing import Any, Dict, List, Sequence
import logging
import math
import numpy as np
import matplotlib.pyplot as plt
from math import lgamma

import omegaconf
import torch
import torch.nn as nn
from solo.losses.radialvicreg import radial_vicreg_loss_func, uniform_loss, \
    invariance_loss, variance_loss, covariance_loss, chi2_radial_nll_loss, batch_sparsity_metric, embedding_sparsity_metric, anisotropy_loss, chi2_radial_nll_loss_for_lightning_logging, m_spacing_entropy_loss, w1_radial_loss_two_views, w1_radial_loss_single
from solo.methods.base import BaseMethod
from solo.utils.misc import omegaconf_select, gather
from solo.utils.metrics import compute_balanced_accuracy


class RadialVICReg(BaseMethod):
    def __init__(self, cfg: omegaconf.DictConfig):
        """Implements VICReg (https://arxiv.org/abs/2105.04906) with Radial Gaussianization.

        Extra cfg settings:
            method_kwargs:
                proj_output_dim (int): number of dimensions of the projected features.
                proj_hidden_dim (int): number of neurons in the hidden layers of the projector.
                projector_type (str): projector architecture; one of "identity", "mlp1",
                    "mlp2", "mlp"/"mlp3", "mlp3_with_one_more_relu", "mlp4", "mlp5".
                sim_loss_weight (float): weight of the invariance term.
                var_loss_weight (float): weight of the variance term.
                cov_loss_weight (float): weight of the covariance term.
                radial_loss_weight (float): weight of the radial term.
                radial_ent_loss_weight (float): weight of the radial entropy term.
                pre_projector_radial_loss_weight (float): weight of the pre-projector radial term.
                pre_projector_w1_loss_weight (float): weight of the pre-projector W1 radial term
                    (read with a fallback of 0.0).
                w1_radial_loss_weight (float): weight of the W1 radial term
                    (read with a fallback of 0.0).
                w1_target_num_samples (int): number of samples to draw for W1 radial loss
                    (read with a fallback of 512).
                adaptive_scalar (bool): whether training rescales the encoder features to
                    norm sqrt(D - 1).
                lambda_strategy (str): "standard" or "self_tune" for loss weighting strategy.

        Derived (NOT INTENDED TO BE SET BY USERS):
            _w1_active (bool): True when either W1 weight is strictly positive.

        Raises:
            AssertionError: if ``add_projector_classifier`` is enabled while
                ``method_kwargs.proj_output_dim`` is missing.
            ValueError: if ``method_kwargs.projector_type`` is not a known architecture.
        """

        # proj_output_dim is needed by projector_classifier in BaseMethod if enabled.
        # Ensure it's available on cfg.method_kwargs if projector_classifier is true.
        if omegaconf_select(cfg, "method_kwargs.add_projector_classifier", False):
            assert not omegaconf.OmegaConf.is_missing(
                cfg, "method_kwargs.proj_output_dim"
            ), "RadialVICReg: method_kwargs.proj_output_dim must be set if add_projector_classifier is True."

        super().__init__(cfg)

        self.sim_loss_weight: float = cfg.method_kwargs.sim_loss_weight
        self.var_loss_weight: float = cfg.method_kwargs.var_loss_weight
        self.cov_loss_weight: float = cfg.method_kwargs.cov_loss_weight
        self.radial_loss_weight: float = cfg.method_kwargs.radial_loss_weight
        self.radial_ent_loss_weight: float = cfg.method_kwargs.radial_ent_loss_weight
        self.pre_projector_radial_loss_weight: float = (
            cfg.method_kwargs.pre_projector_radial_loss_weight
        )
        self.projector_type: str = cfg.method_kwargs.projector_type
        self.lambda_strategy: str = cfg.method_kwargs.lambda_strategy
        self.adaptive_scalar: bool = cfg.method_kwargs.adaptive_scalar
        self.w1_radial_loss_weight: float = omegaconf_select(cfg, "method_kwargs.w1_radial_loss_weight", 0.0)
        self.pre_projector_w1_loss_weight: float = omegaconf_select(cfg, "method_kwargs.pre_projector_w1_loss_weight", 0.0)
        self.w1_target_num_samples: int = omegaconf_select(cfg, "method_kwargs.w1_target_num_samples", 512)
        # Derived from the two W1 weights; never read from the config directly.
        self._w1_active: bool = (self.w1_radial_loss_weight > 0.0) or (self.pre_projector_w1_loss_weight > 0.0)

        proj_hidden_dim: int = cfg.method_kwargs.proj_hidden_dim
        proj_output_dim: int = cfg.method_kwargs.proj_output_dim

        # Output width of every Linear layer, in order, for each MLP projector variant.
        # All but the last Linear are followed by BatchNorm1d + ReLU.
        mlp_layer_widths = {
            "mlp1": (proj_output_dim,),
            "mlp2": (proj_hidden_dim, proj_output_dim),
            "mlp": (proj_hidden_dim, proj_hidden_dim, proj_output_dim),
            "mlp3": (proj_hidden_dim, proj_hidden_dim, proj_output_dim),
            "mlp3_with_one_more_relu": (proj_hidden_dim, proj_hidden_dim, proj_output_dim),
            "mlp4": (proj_hidden_dim, proj_hidden_dim, proj_output_dim, proj_output_dim),
            "mlp5": (
                proj_hidden_dim,
                proj_hidden_dim,
                proj_output_dim,
                proj_output_dim,
                proj_output_dim,
            ),
        }

        # projector
        if self.projector_type == "identity":
            # No projection head: z is the backbone feature itself
            # (original note: 512 x 512 for resnet18).
            self.projector = nn.Identity()
            if hasattr(self, "projector_classifier") and self.projector_classifier is not None:
                # The classifier on z must now take features_dim inputs.
                # NOTE(review): this installs a single nn.Linear, while forward() iterates
                # over projector_classifier for the "3dshapes" dataset -- confirm the
                # identity projector is not meant to be used with that dataset.
                self.projector_classifier = nn.Linear(self.features_dim, self.num_classes)
        elif self.projector_type in mlp_layer_widths:
            self.projector = self._make_mlp_projector(
                self.features_dim,
                mlp_layer_widths[self.projector_type],
                final_relu=self.projector_type == "mlp3_with_one_more_relu",
            )
        else:
            raise ValueError(
                f"Unknown projector_type: {self.projector_type!r}. "
                f"Expected 'identity' or one of {sorted(mlp_layer_widths)}."
            )

        # Radius histogram configuration moved to BaseMethod

    @staticmethod
    def _make_mlp_projector(
        in_dim: int, widths: Sequence[int], final_relu: bool = False
    ) -> nn.Sequential:
        """Builds a Linear(-BatchNorm1d-ReLU) stack ending in a bare Linear layer.

        Args:
            in_dim (int): input dimension of the first Linear layer.
            widths (Sequence[int]): output dimension of each Linear layer, in order.
            final_relu (bool): if True, append a ReLU (without BatchNorm) after the
                last Linear layer. Defaults to False.

        Returns:
            nn.Sequential: the projector; modules are added in forward order, so the
                Sequential indices match a hand-written stack of the same layers.
        """

        layers: List[nn.Module] = []
        last_idx = len(widths) - 1
        for idx, out_dim in enumerate(widths):
            layers.append(nn.Linear(in_dim, out_dim))
            if idx < last_idx:
                layers.append(nn.BatchNorm1d(out_dim))
                layers.append(nn.ReLU())
            in_dim = out_dim
        if final_relu:
            layers.append(nn.ReLU())
        return nn.Sequential(*layers)

    @staticmethod
    def add_and_assert_specific_cfg(cfg: omegaconf.DictConfig) -> omegaconf.DictConfig:
        """Fills in RadialVICReg defaults on ``cfg.method_kwargs`` and validates them.

        The parent's checks run first. ``proj_output_dim`` and ``proj_hidden_dim`` must be
        present. Every other option gets a default when absent: ``projector_type`` "mlp3",
        ``sim_loss_weight`` 25.0, ``var_loss_weight`` 25.0, ``cov_loss_weight`` 1.0, all
        radial / W1 weights 0.0, ``w1_target_num_samples`` 512, ``adaptive_scalar`` False
        and ``lambda_strategy`` "standard".

        Args:
            cfg (omegaconf.DictConfig): DictConfig object.

        Returns:
            omegaconf.DictConfig: same as the argument, used to avoid errors.

        Raises:
            AssertionError: on a missing projector dimension, an unknown projector type,
                W1 and CE/Entropy radial weights both being non-zero, ``adaptive_scalar``
                combined with a non-zero ``radial_loss_weight``, or an unknown
                ``lambda_strategy``.
        """

        cfg = super(RadialVICReg, RadialVICReg).add_and_assert_specific_cfg(cfg)

        for required_key in ("method_kwargs.proj_output_dim", "method_kwargs.proj_hidden_dim"):
            assert not omegaconf.OmegaConf.is_missing(cfg, required_key)

        # Ensure projector_type is valid and has a default value of "mlp3".
        known_projectors = (
            "identity",
            "mlp",
            "mlp1",
            "mlp2",
            "mlp3",
            "mlp3_with_one_more_relu",
            "mlp4",
            "mlp5",
        )
        chosen_projector = omegaconf_select(cfg, "method_kwargs.projector_type", "mlp3")
        assert chosen_projector in known_projectors, f"Invalid projector_type: {chosen_projector}. Must be one of ['identity', 'mlp', 'mlp1', 'mlp2', 'mlp3', 'mlp3_with_one_more_relu', 'mlp4', 'mlp5']"
        cfg.method_kwargs.projector_type = chosen_projector

        # (option name, fallback) pairs, applied in this order.
        option_fallbacks = (
            ("sim_loss_weight", 25.0),
            ("var_loss_weight", 25.0),
            ("cov_loss_weight", 1.0),
            ("radial_loss_weight", 0.0),
            ("radial_ent_loss_weight", 0.0),
            ("pre_projector_radial_loss_weight", 0.0),
            # Wasserstein-1 radial objective (optional)
            ("w1_radial_loss_weight", 0.0),
            ("pre_projector_w1_loss_weight", 0.0),
            ("w1_target_num_samples", 512),
        )
        for option, fallback in option_fallbacks:
            resolved = omegaconf_select(cfg, f"method_kwargs.{option}", fallback)
            setattr(cfg.method_kwargs, option, resolved)

        # Enforce mutual exclusivity between CE/Entropy-based radial terms and W1 terms
        uses_w1 = (
            cfg.method_kwargs.w1_radial_loss_weight > 0.0
            or cfg.method_kwargs.pre_projector_w1_loss_weight > 0.0
        )
        uses_ce = (
            cfg.method_kwargs.radial_loss_weight > 0.0
            or cfg.method_kwargs.radial_ent_loss_weight > 0.0
            or cfg.method_kwargs.pre_projector_radial_loss_weight > 0.0
        )
        assert (
            not uses_w1 or not uses_ce
        ), "W1 radial weights and CE/Entropy radial weights are mutually exclusive; set only one family non-zero."

        # Adaptive scalar: rescales embeddings to have norm sqrt(D-1)
        cfg.method_kwargs.adaptive_scalar = omegaconf_select(
            cfg, "method_kwargs.adaptive_scalar", False
        )
        if cfg.method_kwargs.adaptive_scalar:
            assert (
                cfg.method_kwargs.radial_loss_weight == 0.0
            ), "If adaptive_scalar is True, radial_loss_weight must be 0."

        # Lambda strategy: 'standard' for fixed weights, 'self_tune' for scaling by sim_loss
        cfg.method_kwargs.lambda_strategy = omegaconf_select(
            cfg, "method_kwargs.lambda_strategy", "standard"
        )
        assert cfg.method_kwargs.lambda_strategy in ("standard", "self_tune")

        return cfg

    @property
    def learnable_params(self) -> List[dict]:
        """Returns the parent's parameter groups followed by the projector's group.

        Returns:
            List[dict]: list of learnable parameters; the last entry is named
                "projector" and holds the projector's parameters.
        """

        projector_group = {"name": "projector", "params": self.projector.parameters()}
        return super().learnable_params + [projector_group]

    def forward(self, X: torch.Tensor) -> Dict[str, Any]:
        """Performs the forward pass of the backbone, projector, and both classifiers (if enabled).

        While in training mode, diagnostics of the backbone features (per-sample norm
        mean/variance, uniform loss, anisotropy loss) are logged under ``torch.no_grad()``.

        Args:
            X (torch.Tensor): a batch of images in the tensor format.

        Returns:
            Dict[str, Any]: the parent's outputs plus "z" (projected features) and, when a
                projector classifier exists, "projector_logits" (a list of logits, one per
                classifier, for the "3dshapes" dataset; a single tensor otherwise).
        """

        out = super().forward(X)
        feats = out["feats"]

        # to help with a cleaner computational graph
        if self.training:
            with torch.no_grad():
                # log embedding layer feature norm - this is the encoder
                # (the per-sample norms are computed once and shared by both statistics)
                feat_norms = torch.norm(feats, dim=1)
                self.log("encoder_norm/train_backbone_feature_norm_mean", feat_norms.mean(), on_epoch=True, sync_dist=True)
                self.log("encoder_norm/train_backbone_feature_norm_var", feat_norms.var(), on_epoch=True, sync_dist=True)

                # Logging uniform loss
                self.log("unif_loss/train_backbone_feature_unif_loss_embed", uniform_loss(feats), on_epoch=True, sync_dist=True)

                # Logging anisotropy loss for encoder features
                self.log("anisotropy_loss/train_encoder_feature_anisotropy_loss_embed", anisotropy_loss(feats), on_epoch=True, sync_dist=True)

        z = self.projector(feats)
        out["z"] = z

        # The projector classifier, when present, is expected to come from BaseMethod.
        # Read it defensively, mirroring the hasattr guard used in __init__.
        projector_classifier = getattr(self, "projector_classifier", None)
        if projector_classifier is not None:
            # Pass z.detach() to ensure gradients for this classifier only update itself,
            # not the projector or backbone through this path.
            detached_z = z.detach()
            if self.dataset_name == "3dshapes":
                # One classifier per element of the container; each gets the same input.
                projector_logits = [head(detached_z) for head in projector_classifier]
            else:
                projector_logits = projector_classifier(detached_z)

            out["projector_logits"] = projector_logits

        return out

    def validation_step(self, batch: Sequence[Any], batch_idx: int):
        """Validation step for RadialVICReg.

        Runs the base validation step, then (outside Lightning's sanity check) logs radial
        diagnostics for the gathered encoder and projector features, logs projector norm
        statistics, collects radii for histograms and feeds the KNN evaluators. Finally
        assembles the per-batch metrics dict, whose layout depends on the dataset name.

        Args:
            batch (Sequence[Any]): a pair ``(X, targets)``.
            batch_idx (int): index of the batch.

        Returns:
            dict: the metrics for this batch; the same dict is appended to
                ``self.validation_step_outputs``.

        Raises:
            NotImplementedError: if KNN evaluation is enabled for "CelebA" or "3dshapes".
        """
        X, targets = batch
        out = self.base_validation_step(X, targets)

        batch_size = targets.size(0)
        past_sanity_check = not self.trainer.sanity_checking

        def _log_radial_stats(embeddings, keys):
            """Logs W1, chi2-NLL, m-spacing entropy and (NLL - entropy) for one feature set."""
            w1_key, nll_key, entropy_key, kl_key = keys
            w1_value = w1_radial_loss_single(embeddings, target_num_samples=self.w1_target_num_samples)
            self.log(w1_key, w1_value, on_epoch=True, sync_dist=True)
            nll_value = chi2_radial_nll_loss_for_lightning_logging(embeddings)
            self.log(nll_key, nll_value, on_epoch=True, sync_dist=True)
            # KL (single set): cross-entropy minus entropy of the radii
            radii = torch.norm(embeddings, dim=1)
            entropy_value = m_spacing_entropy_loss(radii)
            kl_value = nll_value - entropy_value
            self.log(entropy_key, entropy_value, on_epoch=True, sync_dist=True)
            self.log(kl_key, kl_value, on_epoch=True, sync_dist=True)

        # Log norms and loss components for validation
        if past_sanity_check:
            feats = out["feats"]
            z = self.projector(feats)

            # Encoder features: always log W1 and CE/Entropy/KL
            gathered_feats = gather(feats.contiguous())
            _log_radial_stats(
                gathered_feats,
                (
                    "encoder/val_backbone_w1_radial_loss",
                    "encoder/val_backbone_radial_loss",
                    "encoder/val_backbone_ent_loss",
                    "encoder/val_backbone_kl_loss",
                ),
            )

            # Projector features: always log W1 and CE/Entropy/KL
            gathered_z = gather(z.contiguous())
            _log_radial_stats(
                gathered_z,
                (
                    "projector/val_w1_radial_loss",
                    "projector/val_radial_loss",
                    "projector/val_ent_loss",
                    "projector/val_kl_loss",
                ),
            )

            # Norm statistics of the (local, un-gathered) projector features
            z_norms = torch.norm(z, dim=1)
            z_norm_mean = z_norms.mean()
            z_norm_var = z_norms.var()
            self.log("projector_norm/val_z_norm_mean", z_norm_mean, on_epoch=True, sync_dist=True)
            self.log("projector_norm/val_z_norm_var", z_norm_var, on_epoch=True, sync_dist=True)

            # Collect radii for histogram snapshots via BaseMethod helper
            self._radius_hist_collect(gathered_feats=gathered_feats, gathered_z=gathered_z)

        # Ensure KNN test features/targets are populated (parent not called here)
        if self.knn_eval and past_sanity_check:
            knn_unsupported = {
                "CelebA": "double check or implement knn eval for celebA",
                "3dshapes": "double check or implement knn eval for 3dshapes",
            }
            if self.dataset_name in knn_unsupported:
                raise NotImplementedError(knn_unsupported[self.dataset_name])

            # Encoder KNN (use feats from base_validation_step); best effort only
            if self.knn_encoder is not None and "feats" in out:
                try:
                    self.knn_encoder(
                        test_features=out["feats"].detach().cpu(),
                        test_targets=targets.detach().cpu(),
                    )
                except Exception as err:
                    logging.warning(f"Skipping encoder KNN update in validation_step due to error: {err}")

            # Projector KNN (z exists here because this branch also requires
            # being past the sanity check); best effort only
            if self.knn_projector is not None:
                try:
                    # only update when z lines up with the targets
                    if z.size(0) == targets.detach().size(0):
                        self.knn_projector(
                            test_features=z.detach().cpu(),
                            test_targets=targets.detach().cpu(),
                        )
                except Exception as err:
                    logging.warning(f"Skipping projector KNN update in validation_step due to error: {err}")

        if self.dataset_name == "CelebA":
            out_for_logging = {f"val_{name}": value for name, value in out.items() if "metrics" in name}

            ### Balance_Acc ###
            running_counts = self.validation_step_outputs_for_balanced_acc["separate_metrics"]
            has_proj = "proj_loss" in out
            for outcome_type in self.balanced_acc_quantities_order:
                running_counts["encoder"][outcome_type] += out[f"encoder_{outcome_type}"].clone()

                if has_proj:
                    running_counts["proj"][outcome_type] += out[f"proj_{outcome_type}"].clone()

            encoder_counts = running_counts["encoder"]
            encoder_balanced_acc = compute_balanced_accuracy(
                encoder_counts["tps"],
                encoder_counts["fns"],
                encoder_counts["tns"],
                encoder_counts["fps"],
            )

            proj_counts = running_counts["proj"]
            proj_balanced_acc = compute_balanced_accuracy(
                proj_counts["tps"],
                proj_counts["fns"],
                proj_counts["tns"],
                proj_counts["fps"],
            )

            encoder_balanced_acc_dict = {
                f"val_separate_metrics/encoder_{attr}": acc
                for attr, acc in zip(self.dataset_attr_names, encoder_balanced_acc)
            }
            proj_balanced_acc_dict = {
                f"val_separate_metrics/proj_{attr}": acc
                for attr, acc in zip(self.dataset_attr_names, proj_balanced_acc)
            }

            metrics = {
                "batch_size": batch_size,
                "val_loss": out["loss"],
                **out_for_logging,
                **encoder_balanced_acc_dict,
                **proj_balanced_acc_dict,
            }
            if has_proj:
                metrics["val_proj_loss"] = out.get("proj_loss")
        elif self.dataset_name == "3dshapes":
            encoder_separate = out['encoder_separate_metrics']
            proj_separate = out['proj_separate_metrics']
            outs_for_logging_encoder = {
                f"val_encoder_separate_metrics/{name}": encoder_separate[name] for name in encoder_separate
            }
            outs_for_logging_proj = {
                f"val_proj_separate_metrics/{name}": proj_separate[name] for name in proj_separate
            }

            # Replicate the logic from BaseMethod's validation_step to prepare metrics for logging
            metrics = {
                "batch_size": batch_size,
                "val_loss": out.get("loss"),
                **outs_for_logging_encoder,
                **outs_for_logging_proj,
            }
            if "proj_loss" in out:
                metrics["val_proj_loss"] = out.get("proj_loss")
        else:
            # Replicate the logic from BaseMethod's validation_step to prepare metrics for logging
            metrics = {
                "batch_size": batch_size,
                "val_loss": out.get("loss"),
                "val_acc1": out.get("acc1"),
                "val_acc5": out.get("acc5"),
            }
            if "proj_loss" in out:
                metrics["val_proj_loss"] = out.get("proj_loss")
                metrics["val_proj_acc1"] = out.get("proj_acc1")
                metrics["val_proj_acc5"] = out.get("proj_acc5")

            # Include MLP probe validation metrics when available
            if all(key in out for key in ("mlp_acc1", "mlp_acc5")):
                metrics["val_mlp_acc1"] = out.get("mlp_acc1")
                metrics["val_mlp_acc5"] = out.get("mlp_acc5")
            if all(key in out for key in ("proj_mlp_acc1", "proj_mlp_acc5")):
                metrics["val_proj_mlp_acc1"] = out.get("proj_mlp_acc1")
                metrics["val_proj_mlp_acc5"] = out.get("proj_mlp_acc5")

        self.validation_step_outputs.append(metrics)
        return metrics

    def training_step(self, batch: Sequence[Any], batch_idx: int) -> torch.Tensor:
        """Training step for RadialVICReg reusing BaseMethod training step.

        Args:
            batch (Sequence[Any]): a batch of data in the format of [img_indexes, [X], Y], where
                [X] is a list of size num_crops containing batches of images.
            batch_idx (int): index of the batch.

        Returns:
            torch.Tensor: total loss composed of RadialVICReg loss and classification loss.
        """

        out_base = super().training_step(batch, batch_idx)
        class_loss = out_base["loss"]
        # Backbone features (encoder output)
        feats1_b, feats2_b = out_base["feats"]

        if self.adaptive_scalar:
            # Rescale ENCODER embeddings to have norm sqrt(D-1)
            _, D = feats1_b.size()  # encoder feature dimension
            target_norm = torch.sqrt(torch.tensor(D - 1.0, device=feats1_b.device))

            eps = 1e-6

            feats1_b_norm = torch.norm(feats1_b, dim=1, keepdim=True)
            feats1_b = feats1_b * (target_norm / (feats1_b_norm + eps))

            feats2_b_norm = torch.norm(feats2_b, dim=1, keepdim=True)
            feats2_b = feats2_b * (target_norm / (feats2_b_norm + eps))
        
        # Gather features across GPUs for losses that depend on batch statistics.
        gathered_feats1_b = gather(feats1_b.contiguous())
        gathered_feats2_b = gather(feats2_b.contiguous())

        # Calculate pre-projector radial losses (encoder) both ways for logging; pick the optimized term based on config
        radial_loss_encoder_w1 = w1_radial_loss_two_views(
            gathered_feats1_b, gathered_feats2_b, target_num_samples=self.w1_target_num_samples
        )
        radial_loss_encoder_ce = chi2_radial_nll_loss(gathered_feats1_b, gathered_feats2_b)
        if self._w1_active and self.pre_projector_w1_loss_weight > 0:
            radial_loss_encoder = radial_loss_encoder_w1
            enc_is_w1 = True
        else:
            radial_loss_encoder = radial_loss_encoder_ce
            enc_is_w1 = False
        
        # These are for logging only and do not contribute to the gradient of the main loss.
        # Invariance loss on backbone features
        with torch.no_grad():
            sim_loss_encoder = invariance_loss(feats1_b, feats2_b)
            self.log("encoder/train_backbone_sim_loss", sim_loss_encoder, on_epoch=True, sync_dist=True)

            var_loss_encoder = variance_loss(gathered_feats1_b, gathered_feats2_b)
            self.log("encoder/train_backbone_var_loss", var_loss_encoder / self.features_dim, on_epoch=True, sync_dist=True)

            cov_loss_encoder = covariance_loss(gathered_feats1_b, gathered_feats2_b)
            self.log("encoder/train_backbone_cov_loss", cov_loss_encoder, on_epoch=True, sync_dist=True)
    
            # Always log both W1 and CE/Entropy for encoder
            self.log("encoder/train_backbone_w1_radial_loss", radial_loss_encoder_w1, on_epoch=True, sync_dist=True)
            self.log("encoder/train_backbone_radial_loss", radial_loss_encoder_ce, on_epoch=True, sync_dist=True)
            # Encoder KL (two-view): CE - H, unscaled
            r1_enc = torch.norm(gathered_feats1_b, dim=1)
            r2_enc = torch.norm(gathered_feats2_b, dim=1)
            entropy_enc = m_spacing_entropy_loss(r1_enc) + m_spacing_entropy_loss(r2_enc)
            kl_enc = radial_loss_encoder_ce - entropy_enc
            self.log("encoder/train_backbone_ent_loss", entropy_enc, on_epoch=True, sync_dist=True)
            self.log("encoder/train_backbone_kl_loss", kl_enc, on_epoch=True, sync_dist=True)

            # log sparsity metrics
            # IMP: sparsity_metric was renamed to batch_sparsity_metric, and embedding_sparsity_metric was introduced
            feat1_batch_sparse_max, feat1_batch_sparse_mean, feat1_batch_sparse_min = batch_sparsity_metric(gathered_feats1_b)
            feat2_batch_sparse_max, feat2_batch_sparse_mean, feat2_batch_sparse_min = batch_sparsity_metric(gathered_feats2_b)
            self.log("batch_sparsity_metric/feat1_encoder_batch_sparse_max", feat1_batch_sparse_max, on_epoch=True, sync_dist=True)
            self.log("batch_sparsity_metric/feat1_encoder_batch_sparse_mean", feat1_batch_sparse_mean, on_epoch=True, sync_dist=True)
            self.log("batch_sparsity_metric/feat1_encoder_batch_sparse_min", feat1_batch_sparse_min, on_epoch=True, sync_dist=True)
            self.log("batch_sparsity_metric/feat2_encoder_batch_sparse_max", feat2_batch_sparse_max, on_epoch=True, sync_dist=True)
            self.log("batch_sparsity_metric/feat2_encoder_batch_sparse_mean", feat2_batch_sparse_mean, on_epoch=True, sync_dist=True)
            self.log("batch_sparsity_metric/feat2_encoder_batch_sparse_min", feat2_batch_sparse_min, on_epoch=True, sync_dist=True)

            feat1_embed_sparse_max, feat1_embed_sparse_mean, feat1_embed_sparse_min = embedding_sparsity_metric(gathered_feats1_b)
            feat2_embed_sparse_max, feat2_embed_sparse_mean, feat2_embed_sparse_min = embedding_sparsity_metric(gathered_feats2_b)
            self.log("embedding_sparsity_metric/feat1_encoder_embed_sparse_max", feat1_embed_sparse_max, on_epoch=True, sync_dist=True)
            self.log("embedding_sparsity_metric/feat1_encoder_embed_sparse_mean", feat1_embed_sparse_mean, on_epoch=True, sync_dist=True)
            self.log("embedding_sparsity_metric/feat1_encoder_embed_sparse_min", feat1_embed_sparse_min, on_epoch=True, sync_dist=True)
            self.log("embedding_sparsity_metric/feat2_encoder_embed_sparse_max", feat2_embed_sparse_max, on_epoch=True, sync_dist=True)
            self.log("embedding_sparsity_metric/feat2_encoder_embed_sparse_mean", feat2_embed_sparse_mean, on_epoch=True, sync_dist=True)
            self.log("embedding_sparsity_metric/feat2_encoder_embed_sparse_min", feat2_embed_sparse_min, on_epoch=True, sync_dist=True)

        # use embeddings from the base class instead of redundant computation
        z1, z2 = out_base["z"]

        # if self.adaptive_scalar:
        #     # Rescale embeddings to have norm sqrt(D-1)
        #     _, D = z1.size() #z1.size is batchsize x feature dimension
        #     target_norm = torch.sqrt(torch.tensor(D - 1.0, device=z1.device))

        #     z1_norm = torch.norm(z1, dim=1, keepdim=True)
        #     z2_norm = torch.norm(z2, dim=1, keepdim=True)

        #     # Add a small epsilon to avoid division by zero
        #     eps = 1e-6
        #     z1 = z1 * (target_norm / (z1_norm + eps))
        #     z2 = z2 * (target_norm / (z2_norm + eps))

        # radial loss on projector features (z1, z2): compute both diagnostics and pick the optimized term
        gz1, gz2 = gather(z1.contiguous()), gather(z2.contiguous())
        # Diagnostics
        radial_loss_projector_w1 = w1_radial_loss_two_views(gz1, gz2, target_num_samples=self.w1_target_num_samples)
        _, sim_loss_projector, var_loss_projector, cov_loss_projector, radial_loss_projector_ce, radial_ent_loss_projector, kl_loss_projector = radial_vicreg_loss_func(
            z1, z2,
            sim_loss_weight=self.sim_loss_weight,
            var_loss_weight=self.var_loss_weight,
            cov_loss_weight=self.cov_loss_weight,
            radial_loss_weight=self.radial_loss_weight,
            radial_ent_loss_weight=self.radial_ent_loss_weight,
        )
        # Choose optimization objective
        if self._w1_active and self.w1_radial_loss_weight > 0:
            radial_loss_projector = radial_loss_projector_w1
            main_loss_value = (
                self.sim_loss_weight * sim_loss_projector
                + self.var_loss_weight * var_loss_projector
                + self.cov_loss_weight * cov_loss_projector
                + self.w1_radial_loss_weight * radial_loss_projector
            )
            # For logging consistency, keep KL fields defined
            kl_loss_projector = kl_loss_projector
        else:
            radial_loss_projector = radial_loss_projector_ce
            main_loss_value = (
                self.sim_loss_weight * sim_loss_projector
                + self.var_loss_weight * var_loss_projector
                + self.cov_loss_weight * cov_loss_projector
                + self.radial_loss_weight * (radial_loss_projector / z1.size(1))
                - self.radial_ent_loss_weight * radial_ent_loss_projector
            )

        # Initialize total loss.
        total_loss = class_loss

        # Apply the "self_tune" weighting strategy if enabled
        if self.lambda_strategy == "self_tune":
            # Detach sim_loss so it serves as a constant scaling factor in 'self_tune' mode:
            # the similarity loss provides the reference magnitude that each auxiliary weight is scaled to.
            sim_loss_detached = sim_loss_projector.detach()
            
            # Recalculate auxiliary losses with self-tuned weights
            w_var = (self.var_loss_weight * sim_loss_detached) / (var_loss_projector.detach() + 1e-6)
            w_cov = (self.cov_loss_weight * sim_loss_detached) / (cov_loss_projector.detach() + 1e-6)
            if self._w1_active and self.w1_radial_loss_weight > 0:
                w_w1 = (self.w1_radial_loss_weight * sim_loss_detached) / (radial_loss_projector.detach() + 1e-6)
                w_rad = torch.tensor(0.0, device=self.device)
                w_rad_ent = torch.tensor(0.0, device=self.device)
            else:
                w_rad = (self.radial_loss_weight * sim_loss_detached) / (radial_loss_projector.detach() + 1e-6)
                w_rad_ent = (self.radial_ent_loss_weight * sim_loss_detached) / (radial_ent_loss_projector.detach() + 1e-6)

            w_pre_rad = torch.tensor(0.0, device=self.device)
            if self._w1_active and self.pre_projector_w1_loss_weight > 0:
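                # Self-tuned weight for encoder W1 term
                # NOTE(review): this weight is normalised by radial_loss_encoder_w1, but if
                # w1_radial_loss_weight is not > 0 the total-loss branch below multiplies w_pre_rad
                # with the encoder CE term instead — confirm this combination is intended.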
                w_pre_rad = (self.pre_projector_w1_loss_weight * sim_loss_detached) / (radial_loss_encoder_w1.detach() + 1e-6)
            elif (not self._w1_active) and self.pre_projector_radial_loss_weight > 0:
                # Self-tuned weight for encoder CE term (dimension-scaled elsewhere)
                w_pre_rad = (self.pre_projector_radial_loss_weight * sim_loss_detached) / (radial_loss_encoder_ce.detach() + 1e-6)
            
            # Recompute total loss with self-tuned weights
            if self._w1_active and self.w1_radial_loss_weight > 0:
                total_loss += (
                    self.sim_loss_weight * sim_loss_projector
                    + w_var * var_loss_projector
                    + w_cov * cov_loss_projector
                    + w_w1 * radial_loss_projector_w1
                    + w_pre_rad * (radial_loss_encoder_w1)
                )
            else:
                total_loss += (
                    self.sim_loss_weight * sim_loss_projector
                    + w_var * var_loss_projector
                    + w_cov * cov_loss_projector
                    + w_rad * (radial_loss_projector_ce / z1.size(1))
                    - w_rad_ent * radial_ent_loss_projector
                    + w_pre_rad * (radial_loss_encoder_ce / feats1_b.size(1))
                )
        else:
            # Standard loss calculation with fixed weights
            if self._w1_active and self.w1_radial_loss_weight > 0:
                total_loss += main_loss_value + self.pre_projector_w1_loss_weight * radial_loss_encoder_w1
            else:
                total_loss += main_loss_value + self.pre_projector_radial_loss_weight * (radial_loss_encoder_ce / feats1_b.size(1)) # CE is scaled by 1/D; entropy already handled in main_loss_value

        # Logging norms for projector features
        with torch.no_grad():
            # log sparsity metrics
            z1_batch_sparse_max, z1_batch_sparse_mean, z1_batch_sparse_min = batch_sparsity_metric(z1)
            z2_batch_sparse_max, z2_batch_sparse_mean, z2_batch_sparse_min = batch_sparsity_metric(z2)
            self.log("batch_sparsity_metric/z1_projector_batch_sparse_max", z1_batch_sparse_max, on_epoch=True, sync_dist=True)
            self.log("batch_sparsity_metric/z1_projector_batch_sparse_mean", z1_batch_sparse_mean, on_epoch=True, sync_dist=True)
            self.log("batch_sparsity_metric/z1_projector_batch_sparse_min", z1_batch_sparse_min, on_epoch=True, sync_dist=True)
            self.log("batch_sparsity_metric/z2_projector_batch_sparse_max", z2_batch_sparse_max, on_epoch=True, sync_dist=True)
            self.log("batch_sparsity_metric/z2_projector_batch_sparse_mean", z2_batch_sparse_mean, on_epoch=True, sync_dist=True)
            self.log("batch_sparsity_metric/z2_projector_batch_sparse_min", z2_batch_sparse_min, on_epoch=True, sync_dist=True)

            z1_embed_sparse_max, z1_embed_sparse_mean, z1_embed_sparse_min = embedding_sparsity_metric(z1)
            z2_embed_sparse_max, z2_embed_sparse_mean, z2_embed_sparse_min = embedding_sparsity_metric(z2)
            self.log("embedding_sparsity_metric/z1_projector_embed_sparse_max", z1_embed_sparse_max, on_epoch=True, sync_dist=True)
            self.log("embedding_sparsity_metric/z1_projector_embed_sparse_mean", z1_embed_sparse_mean, on_epoch=True, sync_dist=True)
            self.log("embedding_sparsity_metric/z1_projector_embed_sparse_min", z1_embed_sparse_min, on_epoch=True, sync_dist=True)
            self.log("embedding_sparsity_metric/z2_projector_embed_sparse_max", z2_embed_sparse_max, on_epoch=True, sync_dist=True)
            self.log("embedding_sparsity_metric/z2_projector_embed_sparse_mean", z2_embed_sparse_mean, on_epoch=True, sync_dist=True)
            self.log("embedding_sparsity_metric/z2_projector_embed_sparse_min", z2_embed_sparse_min, on_epoch=True, sync_dist=True)

            z1_norm_mean = torch.norm(z1, dim=1).mean()
            z2_norm_mean = torch.norm(z2, dim=1).mean()
            self.log("projector_norm/train_z1_norm_mean", z1_norm_mean, on_epoch=True, sync_dist=True)
            self.log("projector_norm/train_z2_norm_mean", z2_norm_mean, on_epoch=True, sync_dist=True)

            z1_norm_var = torch.norm(z1, dim=1).var()
            z2_norm_var = torch.norm(z2, dim=1).var()
            self.log("projector_norm/train_z1_norm_var", z1_norm_var, on_epoch=True, sync_dist=True)
            self.log("projector_norm/train_z2_norm_var", z2_norm_var, on_epoch=True, sync_dist=True)

            # Logging uniform loss for projector features
            z1_unif_loss = uniform_loss(z1)
            z2_unif_loss = uniform_loss(z2)
            self.log("unif_loss/train_z1_unif_loss_projector", z1_unif_loss, on_epoch=True, sync_dist=True)
            self.log("unif_loss/train_z2_unif_loss_projector", z2_unif_loss, on_epoch=True, sync_dist=True)

            # Logging anisotropy loss for projector features
            z1_anisotropy_loss = anisotropy_loss(z1)
            z2_anisotropy_loss = anisotropy_loss(z2)
            self.log("anisotropy_loss/train_z1_anisotropy_loss_projector", z1_anisotropy_loss, on_epoch=True, sync_dist=True)
            self.log("anisotropy_loss/train_z2_anisotropy_loss_projector", z2_anisotropy_loss, on_epoch=True, sync_dist=True)

        # Logging main loss components (from projector features) - using original names for backward compatibility
        self.log("train_radial_vicreg_loss", main_loss_value if self.lambda_strategy == "standard" else total_loss, on_epoch=True, sync_dist=True)
        self.log("train_sim_loss", sim_loss_projector, on_epoch=True, sync_dist=True)
        self.log("train_var_loss", var_loss_projector / z1.size(1), on_epoch=True, sync_dist=True)
        self.log("train_cov_loss", cov_loss_projector, on_epoch=True, sync_dist=True)
        # Always log both W1 and CE diagnostics for projector
        self.log("train_w1_radial_loss", radial_loss_projector_w1, on_epoch=True, sync_dist=True)
        self.log("train_radial_loss", radial_loss_projector_ce, on_epoch=True, sync_dist=True)
        self.log("train_radial_ent_loss", radial_ent_loss_projector, on_epoch=True, sync_dist=True)
        self.log("train_kl_loss", kl_loss_projector, on_epoch=True, sync_dist=True)
        self.log(
            "train_pre_projector_radial_loss",
            radial_loss_encoder_ce,
            on_epoch=True,
            sync_dist=True,
        )
        self.log(
            "train_pre_projector_w1_radial_loss",
            radial_loss_encoder_w1,
            on_epoch=True,
            sync_dist=True,
        )

        #  Projector Classifier Training
        projector_class_loss = torch.tensor(0.0, device=self.device)
        if self.projector_classifier is not None:
            _, _, targets = batch # Get targets for classification loss
            # _projector_classifier_step is a helper in BaseMethod
            # It expects a single view's z and targets
            proj_metrics1 = self._projector_classifier_step(z1, targets)
            proj_metrics2 = self._projector_classifier_step(z2, targets)

            if proj_metrics1 and proj_metrics2: # Check if metrics were computed
                projector_class_loss = (proj_metrics1["proj_loss"] + proj_metrics2["proj_loss"]) / 2
                self.log("train_proj_loss", projector_class_loss, on_epoch=True, sync_dist=True)

                if self.dataset_name == "CelebA":
                    pass # no logging in the training step
                elif self.dataset_name == "3dshapes":
                    pass # no logging in the training step
                else:
                    proj_acc1 = (proj_metrics1["proj_acc1"] + proj_metrics2["proj_acc1"]) / 2
                    proj_acc5 = (proj_metrics1["proj_acc5"] + proj_metrics2["proj_acc5"]) / 2

                    self.log("train_proj_acc1", proj_acc1, on_epoch=True, sync_dist=True)
                    self.log("train_proj_acc5", proj_acc5, on_epoch=True, sync_dist=True)
        
        return total_loss + projector_class_loss

    # Radius histogram hooks and plotting moved to BaseMethod
