# Copyright 2023 solo-learn development team.

# Permission is hereby granted, free of charge, to any person obtaining a copy of
# this software and associated documentation files (the "Software"), to deal in
# the Software without restriction, including without limitation the rights to use,
# copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the
# Software, and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all copies
# or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
# INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR
# PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE
# FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
# OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.

import logging
from functools import partial
from typing import Any, Callable, Dict, List, Sequence, Tuple, Union
import numpy as np
import math

import lightning.pytorch as pl
import omegaconf
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim.lr_scheduler import MultiStepLR

from solo.backbones import (
    convnext_base,
    convnext_large,
    convnext_small,
    convnext_tiny,
    poolformer_m36,
    poolformer_m48,
    poolformer_s12,
    poolformer_s24,
    poolformer_s36,
    resnet18,
    resnet18_with_final_linear,
    resnet18_from_scratch,
    resnet50,
    swin_base,
    swin_large,
    swin_small,
    swin_tiny,
    vit_base,
    vit_large,
    vit_small,
    vit_tiny,
    wide_resnet28w2,
    wide_resnet28w8,
)
from solo.utils.knn import WeightedKNNClassifier
from solo.utils.lars import LARS
from solo.utils.lr_scheduler import LinearWarmupCosineAnnealingLR
from solo.utils.metrics import accuracy_at_k, weighted_mean, multi_label_metrics, compute_balanced_accuracy, multi_attribute_ce_loss
from solo.utils.misc import omegaconf_select, remove_bias_and_norm_from_weight_decay, gather
from solo.utils.momentum import MomentumUpdater, initialize_momentum_params
from solo.losses.radialvicreg import w1_distance_to_chi


def static_lr(
    get_lr: Callable,
    param_group_indexes: Sequence[int],
    lrs_to_replace: Sequence[float],
):
    """Return the scheduler's learning rates with some entries pinned to fixed values.

    Args:
        get_lr (Callable): zero-argument callable returning the list of learning rates.
        param_group_indexes (Sequence[int]): positions in that list to overwrite.
        lrs_to_replace (Sequence[float]): value to write at each position; paired
            with ``param_group_indexes`` element by element (extra items on either
            side are ignored, as with ``zip``).

    Returns:
        the very list produced by ``get_lr`` (modified in place).
    """
    scheduled = get_lr()
    pinned = zip(param_group_indexes, lrs_to_replace)
    for group_position, fixed_value in pinned:
        scheduled[group_position] = fixed_value
    return scheduled


class BaseMethod(pl.LightningModule):
    # Registry mapping the `cfg.backbone.name` string to the backbone constructor.
    # `__init__` asserts that the configured name is one of these keys.
    _BACKBONES = {
        "resnet18": resnet18,
        "resnet18_with_final_linear": resnet18_with_final_linear,
        "resnet18_from_scratch": resnet18_from_scratch,
        "resnet50": resnet50,
        "vit_tiny": vit_tiny,
        "vit_small": vit_small,
        "vit_base": vit_base,
        "vit_large": vit_large,
        "swin_tiny": swin_tiny,
        "swin_small": swin_small,
        "swin_base": swin_base,
        "swin_large": swin_large,
        "poolformer_s12": poolformer_s12,
        "poolformer_s24": poolformer_s24,
        "poolformer_s36": poolformer_s36,
        "poolformer_m36": poolformer_m36,
        "poolformer_m48": poolformer_m48,
        "convnext_tiny": convnext_tiny,
        "convnext_small": convnext_small,
        "convnext_base": convnext_base,
        "convnext_large": convnext_large,
        "wide_resnet28w2": wide_resnet28w2,
        "wide_resnet28w8": wide_resnet28w8,
    }
    # Registry mapping the `cfg.optimizer.name` string to the optimizer class.
    # `configure_optimizers` asserts that the configured name is one of these keys.
    _OPTIMIZERS = {
        "sgd": torch.optim.SGD,
        "lars": LARS,
        "adam": torch.optim.Adam,
        "adamw": torch.optim.AdamW,
    }
    # Scheduler names declared for this class.
    # NOTE(review): the `configure_optimizers` visible in this file only builds
    # "warmup_cosine", "step" and "none"; "reduce" and "exponential" fall through
    # to its ValueError branch — confirm whether subclasses handle them.
    _SCHEDULERS = [
        "reduce",
        "warmup_cosine",
        "step",
        "exponential",
        "none",
    ]

    def __init__(self, cfg: omegaconf.DictConfig):
        """Base model that implements all basic operations for all self-supervised methods.
        It adds shared arguments, extracts basic learnable parameters, creates optimizers
        and schedulers, implements basic training_step for any number of crops,
        trains the online classifier and implements validation_step.

        .. note:: Cfg defaults are set in init by calling `cfg = add_and_assert_specific_cfg(cfg)`

        Cfg basic structure:
            method (str): name of the method; forwarded as first argument to the
                backbone constructor.
            backbone:
                name (str): architecture of the base backbone.
                kwargs (dict): extra backbone kwargs.
            data:
                dataset (str): name of the dataset.
                num_classes (int): number of classes.
                num_large_crops (int): number of big crops.
                num_small_crops (int): number of small crops.
            max_epochs (int): number of training epochs.

            backbone_params (dict): dict containing extra backbone args, namely:
                #! only for resnet
                zero_init_residual (bool): change the initialization of the resnet backbone.
                #! only for vit
                patch_size (int): size of the patches for ViT.
            optimizer:
                name (str): name of the optimizer.
                batch_size (int): number of samples in the batch.
                lr (float): learning rate.
                weight_decay (float): weight decay for optimizer.
                classifier_lr (float): learning rate for the online linear classifier.
                kwargs (Dict): extra named arguments for the optimizer.
            scheduler:
                name (str): name of the scheduler.
                min_lr (float): minimum learning rate for warmup scheduler. Defaults to 0.0.
                warmup_start_lr (float): initial learning rate for warmup scheduler.
                    Defaults to 0.00003.
                warmup_epochs (float): number of warmup epochs. Defaults to 10.
                lr_decay_steps (Sequence, optional): steps to decay the learning rate if
                    scheduler is step. Defaults to None.
                interval (str): interval to update the lr scheduler. Defaults to 'step'.
            knn_eval:
                enabled (bool): enables online knn evaluation while training.
                k (int): the number of neighbors to use for knn.
                distance_func (str): distance function handed to the knn classifier.
                max_distance_matrix_size (int): size limit handed to the knn classifier.
            radius_hist:
                enabled (bool): enables the radius histogram snapshots.
                max_snapshots (int): maximum number of epochs at which a snapshot is taken.
            mlp_probe:
                enabled (bool): enables the MLP probes. Defaults to True.
                num_layers (int): number of linear layers in each probe. Defaults to 3.
            method_kwargs:
                add_projector_classifier (bool): adds a linear classifier on the
                    projector output; requires proj_output_dim.
                proj_output_dim (int, optional): projector output dimension.
            performance:
                disable_channel_last (bool). Disables channel last conversion operation which
                speeds up training considerably. Defaults to False.
                https://pytorch.org/tutorials/intermediate/memory_format_tutorial.html#converting-existing-models
            accumulate_grad_batches (Union[int, None]): number of batches for gradient accumulation.

        .. note::
            When using distributed data parallel, the batch size and the number of workers are
            specified on a per process basis. Therefore, the total batch size (number of workers)
            is calculated as the product of the number of GPUs with the batch size (number of
            workers).

        .. note::
            The learning rate (base, min and warmup) is automatically scaled linearly
            if using gradient accumulation.

        .. note::
            For CIFAR10/100, the first convolutional and maxpooling layers of the ResNet backbone
            are slightly adjusted to handle lower resolution images (32x32 instead of 224x224).

        """

        super().__init__()

        # add default values and assert that config has the basic needed settings
        cfg = self.add_and_assert_specific_cfg(cfg)

        self.cfg: omegaconf.DictConfig = cfg

        ##############################
        # Backbone
        self.backbone_args: Dict[str, Any] = cfg.backbone.kwargs
        assert cfg.backbone.name in BaseMethod._BACKBONES
        self.base_model: Callable = self._BACKBONES[cfg.backbone.name]
        self.backbone_name: str = cfg.backbone.name
        # initialize backbone
        kwargs = self.backbone_args.copy()

        method: str = cfg.method
        self.backbone: nn.Module = self.base_model(method, **kwargs)
        if self.backbone_name.startswith("resnet"):
            self.features_dim: int = self.backbone.inplanes
            # remove fc layer
            self.backbone.fc = nn.Identity()
            cifar = cfg.data.dataset in ["cifar10", "cifar100"]
            if cifar:
                # low-resolution inputs: smaller first convolution and no max-pooling
                self.backbone.conv1 = nn.Conv2d(
                    3, 64, kernel_size=3, stride=1, padding=2, bias=False
                )
                self.backbone.maxpool = nn.Identity()
        else:
            # non-resnet backbones expose their feature width as `num_features`
            self.features_dim: int = self.backbone.num_features
        ##############################

        # dataset name attribute
        self.dataset_name = cfg.data.dataset

        # Output size of each 3dshapes head, in the order:
        # floor hue, wall hue, object hue, scale, shape, orientation.
        shapes3d_head_dims = (10, 10, 10, 8, 4, 15)
        is_3dshapes = self.dataset_name == "3dshapes"

        if self.dataset_name == "CelebA":
            self.dataset_attr_names = ['5_o_Clock_Shadow', 'Arched_Eyebrows', 'Attractive', 'Bags_Under_Eyes', 'Bald', 'Bangs', 'Big_Lips', 'Big_Nose', 'Black_Hair', 'Blond_Hair', 'Blurry', 'Brown_Hair', 'Bushy_Eyebrows', 'Chubby', 'Double_Chin', 'Eyeglasses', 'Goatee', 'Gray_Hair', 'Heavy_Makeup', 'High_Cheekbones', 'Male', 'Mouth_Slightly_Open', 'Mustache', 'Narrow_Eyes', 'No_Beard', 'Oval_Face', 'Pale_Skin', 'Pointy_Nose', 'Receding_Hairline', 'Rosy_Cheeks', 'Sideburns', 'Smiling', 'Straight_Hair', 'Wavy_Hair', 'Wearing_Earrings', 'Wearing_Hat', 'Wearing_Lipstick', 'Wearing_Necklace', 'Wearing_Necktie', 'Young']
        elif is_3dshapes:
            self.dataset_attr_names = ["floor_hue", "wall_hue", "object_hue", "scale", "shape", "orientation"]
        else:
            self.dataset_attr_names = None

        # online linear classifier
        self.num_classes: int = cfg.data.num_classes

        if is_3dshapes:
            # one linear head per attribute
            self.classifier = nn.ModuleList(
                [nn.Linear(self.features_dim, out_dim) for out_dim in shapes3d_head_dims]
            )
        else:
            self.classifier: nn.Module = nn.Linear(self.features_dim, self.num_classes)

        # Expressive MLP probes (encoder and projector); enabled by default.
        self.mlp_probe_enabled: bool = omegaconf_select(cfg, "mlp_probe.enabled", True)
        self.mlp_probe_layers: int = int(omegaconf_select(cfg, "mlp_probe.num_layers", 3))
        self.mlp_probe_encoder = None
        self.mlp_probe_projector = None
        if self.mlp_probe_enabled:
            if is_3dshapes:
                # One probe per attribute - consider using one probe for all attributes
                self.mlp_probe_encoder = nn.ModuleList([
                    self._build_mlp_classifier(self.features_dim, out_dim, self.mlp_probe_layers)
                    for out_dim in shapes3d_head_dims
                ])
            else:
                self.mlp_probe_encoder = self._build_mlp_classifier(
                    self.features_dim, self.num_classes, self.mlp_probe_layers
                )

            # If projector output dim is known in cfg, prebuild projector MLP probe;
            # otherwise it stays None here.
            proj_out_dim_cfg = omegaconf_select(cfg, "method_kwargs.proj_output_dim", None)
            if proj_out_dim_cfg is not None:
                if is_3dshapes:
                    self.mlp_probe_projector = nn.ModuleList([
                        self._build_mlp_classifier(int(proj_out_dim_cfg), out_dim, self.mlp_probe_layers)
                        for out_dim in shapes3d_head_dims
                    ])
                else:
                    self.mlp_probe_projector = self._build_mlp_classifier(
                        int(proj_out_dim_cfg), self.num_classes, self.mlp_probe_layers
                    )

        # Optional Projector Classifier
        self.add_projector_classifier = omegaconf_select(cfg, "method_kwargs.add_projector_classifier", False)
        self.projector_classifier = None
        if self.add_projector_classifier:
            proj_output_dim = omegaconf_select(cfg, "method_kwargs.proj_output_dim", None)
            assert (
                proj_output_dim is not None
            ), "method_kwargs.proj_output_dim must be set when add_projector_classifier is True."

            if is_3dshapes:
                # one linear head per attribute, mirroring the encoder classifier
                self.projector_classifier = nn.ModuleList(
                    [nn.Linear(proj_output_dim, out_dim) for out_dim in shapes3d_head_dims]
                )
            else:
                self.projector_classifier = nn.Linear(proj_output_dim, self.num_classes)

        # training related
        self.max_epochs: int = cfg.max_epochs
        self.accumulate_grad_batches: Union[int, None] = cfg.accumulate_grad_batches

        # optimizer related
        self.optimizer: str = cfg.optimizer.name
        self.batch_size: int = cfg.optimizer.batch_size
        self.lr: float = cfg.optimizer.lr
        self.weight_decay: float = cfg.optimizer.weight_decay
        self.classifier_lr: float = cfg.optimizer.classifier_lr
        self.extra_optimizer_args: Dict[str, Any] = cfg.optimizer.kwargs
        self.exclude_bias_n_norm_wd: bool = cfg.optimizer.exclude_bias_n_norm_wd

        # scheduler related
        self.scheduler: str = cfg.scheduler.name
        self.lr_decay_steps: Union[List[int], None] = cfg.scheduler.lr_decay_steps
        self.min_lr: float = cfg.scheduler.min_lr
        self.warmup_start_lr: float = cfg.scheduler.warmup_start_lr
        self.warmup_epochs: int = cfg.scheduler.warmup_epochs
        self.scheduler_interval: str = cfg.scheduler.interval
        assert self.scheduler_interval in ["step", "epoch"]
        if self.scheduler_interval == "step":
            # `logging.warn` is a deprecated alias that no longer exists in
            # Python 3.13; `logging.warning` is the supported spelling.
            logging.warning(
                f"Using scheduler_interval={self.scheduler_interval} might generate "
                "issues when resuming a checkpoint."
            )

        # if accumulating gradient then scale lr
        if self.accumulate_grad_batches:
            self.lr = self.lr * self.accumulate_grad_batches
            self.classifier_lr = self.classifier_lr * self.accumulate_grad_batches
            self.min_lr = self.min_lr * self.accumulate_grad_batches
            self.warmup_start_lr = self.warmup_start_lr * self.accumulate_grad_batches

        # data-related
        self.num_large_crops: int = cfg.data.num_large_crops
        self.num_small_crops: int = cfg.data.num_small_crops
        self.num_crops: int = self.num_large_crops + self.num_small_crops
        # turn on multicrop if there are small crops
        self.multicrop: bool = self.num_small_crops != 0

        # knn online evaluation
        self.knn_eval: bool = cfg.knn_eval.enabled
        self.knn_k: int = cfg.knn_eval.k
        self.knn_encoder = None
        self.knn_projector = None
        if self.knn_eval:
            knn_params = {
                "k": self.knn_k,
                "distance_fx": cfg.knn_eval.distance_func,
                "max_distance_matrix_size": cfg.knn_eval.max_distance_matrix_size,
            }
            self.knn_encoder = WeightedKNNClassifier(**knn_params)
            # Initialize projector KNN; it will only be used if projector features are available
            self.knn_projector = WeightedKNNClassifier(**knn_params)

        # ---------------------------------------------
        # Radius histogram configuration (encoder & projector)
        # ---------------------------------------------
        self.radius_hist_enabled: bool = omegaconf_select(cfg, "radius_hist.enabled", False)
        self._hist_max_snapshots: int = omegaconf_select(cfg, "radius_hist.max_snapshots", 50)
        # `self.max_epochs` was assigned above, so it can be read directly.
        total_epochs: int = self.max_epochs
        num_snaps: int = min(self._hist_max_snapshots, total_epochs) if total_epochs > 0 else 0
        if num_snaps > 0:
            # evenly spaced epoch indexes between the first and the last epoch
            idxs = np.linspace(0, total_epochs - 1, num_snaps, dtype=int).tolist()
            self._hist_snapshot_epochs = set(int(i) for i in idxs)
        else:
            self._hist_snapshot_epochs = set()
        # buffers used within a validation epoch
        self._val_encoder_radii_buffer: List[torch.Tensor] = []
        self._val_projector_radii_buffer: List[torch.Tensor] = []
        self._val_current_encoder_D: int = 0
        self._val_current_projector_D: int = 0
        # snapshots saved across training
        self._hist_encoder_radii_snapshots: List[tuple] = []  # (epoch, radii_np, D)
        self._hist_projector_radii_snapshots: List[tuple] = []  # (epoch, radii_np, D)

        # for performance
        self.no_channel_last = cfg.performance.disable_channel_last

        # keep track of training metrics; the quantity order is only defined for CelebA
        self.training_step_outputs_for_balanced_acc = {}
        self.validation_step_outputs_for_balanced_acc = {}
        if self.dataset_name == "CelebA":
            self.balanced_acc_quantities_order = ["tps", "fns", "tns", "fps"]
        else:
            self.balanced_acc_quantities_order = None

        # keep track of validation metrics
        self.validation_step_outputs = []

    @staticmethod
    def add_and_assert_specific_cfg(cfg: omegaconf.DictConfig) -> omegaconf.DictConfig:
        """Fills in the default values shared by every method.

        Each entry below is looked up with ``omegaconf_select``; the value found
        (or the listed fallback) is written back to ``cfg`` under the same key.

        Args:
            cfg (omegaconf.DictConfig): DictConfig object.

        Returns:
            omegaconf.DictConfig: same as the argument, used to avoid errors.
        """

        # (dotted key, fallback) pairs, processed strictly in this order: a section
        # such as "knn_eval" must exist before its own fields are written.
        fallbacks = (
            # extra backbone kwargs (use pytorch's default if not available)
            ("backbone.kwargs", {}),
            # optimizer
            ("optimizer.exclude_bias_n_norm_wd", False),
            ("optimizer.kwargs", {}),
            # gradient accumulation
            ("accumulate_grad_batches", 1),
            # scheduler
            ("scheduler.lr_decay_steps", None),
            ("scheduler.min_lr", 0.0),
            ("scheduler.warmup_start_lr", 3e-5),
            ("scheduler.warmup_epochs", 10),
            ("scheduler.interval", "step"),
            # online knn evaluation
            ("knn_eval", {}),
            ("knn_eval.enabled", False),
            ("knn_eval.k", 20),
            ("knn_eval.distance_func", "cosine"),
            ("knn_eval.max_distance_matrix_size", int(1e6)),
            # radius histogram plotting
            ("radius_hist", {}),
            ("radius_hist.enabled", False),
            ("radius_hist.max_snapshots", 50),
            # MLP probe
            ("mlp_probe", {}),
            ("mlp_probe.enabled", True),
            ("mlp_probe.num_layers", 3),
            # performance optimization
            ("performance", {}),
            ("performance.disable_channel_last", False),
            # method-specific kwargs; method_kwargs.proj_output_dim is expected to be
            # set by the method if add_projector_classifier is true.
            ("method_kwargs", {}),
            ("method_kwargs.add_projector_classifier", False),
        )

        for dotted_key, fallback in fallbacks:
            section_name, _, field_name = dotted_key.rpartition(".")
            # keys without a dot live directly on the root config
            owner = getattr(cfg, section_name) if section_name else cfg
            setattr(owner, field_name, omegaconf_select(cfg, dotted_key, fallback))

        return cfg

    @property
    def learnable_params(self) -> List[Dict[str, Any]]:
        """Defines learnable parameters for the base class.

        Returns:
            List[Dict[str, Any]]:
                list of dicts containing learnable parameters and possible settings.
        """

        params = [
            {"name": "backbone", "params": self.backbone.parameters()},
            {
                "name": "classifier",
                "params": self.classifier.parameters(),
                "lr": self.classifier_lr,
                "weight_decay": 0,
            },
        ]
        if self.mlp_probe_enabled and self.mlp_probe_encoder is not None:
            params.append({
                "name": "mlp_probe_encoder",
                "params": self.mlp_probe_encoder.parameters(),
                "lr": self.classifier_lr,
                "weight_decay": 0,
            })
        if self.add_projector_classifier and self.projector_classifier is not None:
            params.append(
                {
                    "name": "projector_classifier",
                    "params": self.projector_classifier.parameters(),
                    "lr": self.classifier_lr,
                    "weight_decay": 0,
                }
            )
        if self.mlp_probe_enabled and self.mlp_probe_projector is not None:
            params.append({
                "name": "mlp_probe_projector",
                "params": self.mlp_probe_projector.parameters(),
                "lr": self.classifier_lr,
                "weight_decay": 0,
            })
        return params

    def configure_optimizers(self) -> Tuple[List, List]:
        """Collects learnable parameters and configures the optimizer and learning rate scheduler.

        Returns:
            Tuple[List, List]: two lists containing the optimizer and the scheduler.
        """

        learnable_params = self.learnable_params

        # exclude bias and norm from weight decay
        if self.exclude_bias_n_norm_wd:
            learnable_params = remove_bias_and_norm_from_weight_decay(learnable_params)

        # indexes of parameters without lr scheduler
        idxs_no_scheduler = [i for i, m in enumerate(learnable_params) if m.pop("static_lr", False)]

        assert self.optimizer in self._OPTIMIZERS
        optimizer = self._OPTIMIZERS[self.optimizer]

        # create optimizer
        optimizer = optimizer(
            learnable_params,
            lr=self.lr,
            weight_decay=self.weight_decay,
            **self.extra_optimizer_args,
        )

        if self.scheduler.lower() == "none":
            return [optimizer], []

        if self._trainer is None:
            # returning optimizer only, scheduler will be created later by the real trainer
            return [optimizer], []

        if self.scheduler == "warmup_cosine":
            max_warmup_steps = (
                self.warmup_epochs * (self.trainer.estimated_stepping_batches / self.max_epochs)
                if self.scheduler_interval == "step"
                else self.warmup_epochs
            )
            max_scheduler_steps = (
                self.trainer.estimated_stepping_batches
                if self.scheduler_interval == "step"
                else self.max_epochs
            )
            scheduler = {
                "scheduler": LinearWarmupCosineAnnealingLR(
                    optimizer,
                    warmup_epochs=max_warmup_steps,
                    max_epochs=max_scheduler_steps,
                    warmup_start_lr=self.warmup_start_lr if self.warmup_epochs > 0 else self.lr,
                    eta_min=self.min_lr,
                ),
                "interval": self.scheduler_interval,
                "frequency": 1,
            }
        elif self.scheduler == "step":
            scheduler = MultiStepLR(optimizer, self.lr_decay_steps)
        else:
            raise ValueError(f"{self.scheduler} not in (warmup_cosine, cosine, step)")

        if idxs_no_scheduler:
            partial_fn = partial(
                static_lr,
                get_lr=scheduler["scheduler"].get_lr
                if isinstance(scheduler, dict)
                else scheduler.get_lr,
                param_group_indexes=idxs_no_scheduler,
                lrs_to_replace=[self.lr] * len(idxs_no_scheduler),
            )
            if isinstance(scheduler, dict):
                scheduler["scheduler"].get_lr = partial_fn
            else:
                scheduler.get_lr = partial_fn

        return [optimizer], [scheduler]

    def optimizer_zero_grad(self, epoch, batch_idx, optimizer, *_):
        """
        This improves performance marginally. It should be fine
        since we are not affected by any of the downsides descrited in
        https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html#torch.optim.Optimizer.zero_grad

        Implemented as in here
        https://lightning.ai/docs/pytorch/latest/advanced/speed.html?highlight=set%20grads%20none
        """
        try:
            optimizer.zero_grad(set_to_none=True)
        except:
            optimizer.zero_grad()

    def forward(self, X) -> Dict:
        """Basic forward method. Children methods should call this function,
        modify the ouputs (without deleting anything) and return it.
        This base forward compute features and main classifier logits.
        Subclasses will override this to add projector outputs and projector_classifier_logits if enabled.

        Args:
            X (torch.Tensor): batch of images in tensor format.

        Returns:
            Dict: dict of logits and features.
        """

        if not self.no_channel_last:
            X = X.to(memory_format=torch.channels_last)
        feats = self.backbone(X)

        if self.dataset_name == "3dshapes":
            logits = [classifier_i(feats.detach()) for classifier_i in self.classifier]
        else:
            logits = self.classifier(feats.detach())
        
        return {"logits": logits, "feats": feats}

    def multicrop_forward(self, X: torch.tensor) -> Dict[str, Any]:
        """Basic multicrop forward method that performs the forward pass
        for the multicrop views. Children classes can override this method to
        add new outputs but should still call this function. Make sure
        that this method and its overrides always return a dict.

        Args:
            X (torch.Tensor): batch of images in tensor format.

        Returns:
            Dict: dict of features.
        """

        if not self.no_channel_last:
            X = X.to(memory_format=torch.channels_last)
        feats = self.backbone(X)
        return {"feats": feats}

    def _projector_classifier_step(self, z: torch.Tensor, targets: torch.Tensor) -> Dict:
        """Helper to compute projector classifier loss and acc.
        Assumes z is the output of a projector for a single view
        IMPORTANT: if x and x augmented than send it twice and average in the SSL method class
        """
        if self.projector_classifier is None:
            return {}

        # Detach z to ensure gradients only flow to the projector_classifier, not further back.
        if self.dataset_name == "3dshapes":
            projector_logits = [projector_classifier_i(z.detach()) for projector_classifier_i in self.projector_classifier]
        else:
            projector_logits = self.projector_classifier(z.detach())

        if self.dataset_name == "CelebA":
            loss = F.binary_cross_entropy_with_logits(projector_logits, targets.float())

            # no logging in the training step
            # proj_exact_match_ratio, proj_hamming_score, proj_average_jaccard_index, separate_metrics = multi_label_metrics(projector_logits, targets, self.dataset_attr_names, nn_type="proj")

            return {
                "proj_loss": loss, 
                # "aggregate_metrics/proj_exact_match_ratio": proj_exact_match_ratio, 
                # "aggregate_metrics/proj_hamming_score": proj_hamming_score,
                # "aggregate_metrics/proj_average_jaccard_index": proj_average_jaccard_index,
                # **separate_metrics
            }
        elif self.dataset_name == "3dshapes":
            loss, _ = multi_attribute_ce_loss(projector_logits, targets, self.dataset_attr_names, nn_type="proj")
            return {
                "proj_loss": loss, 
            }
        else:

            loss = F.cross_entropy(projector_logits, targets, ignore_index=-1)
            top_k_max = min(5, projector_logits.size(1))
            acc1, acc5 = accuracy_at_k(projector_logits, targets, top_k=(1, top_k_max))

            return {"proj_loss": loss, "proj_acc1": acc1, "proj_acc5": acc5}

    def _build_mlp_classifier(self, in_dim: int, num_classes: int, num_layers: int) -> nn.Sequential:
        """Builds an MLP classifier with num_layers linear layers (>=1).
        Hidden size: min(int(1.5 * in_dim), 8192). Uses ReLU and BatchNorm between hidden layers.
        """
        num_layers = max(1, int(num_layers))
        hidden_dim = int(min(max(1, round(1.5 * in_dim)), 8192)) # hidden dimension is at most 8192 and 1.5 * input dimension

        layers: List[nn.Module] = []
        if num_layers == 1:
            layers.append(nn.Linear(in_dim, num_classes))
        else:
            # first hidden
            layers.extend([nn.Linear(in_dim, hidden_dim), nn.BatchNorm1d(hidden_dim), nn.ReLU()])
            # middle hidden layers (num_layers - 2)
            for _ in range(num_layers - 2):
                layers.extend([nn.Linear(hidden_dim, hidden_dim), nn.BatchNorm1d(hidden_dim), nn.ReLU()])
            # final layer
            layers.append(nn.Linear(hidden_dim, num_classes))
        return nn.Sequential(*layers)

    def _base_shared_step(self, X: torch.Tensor, targets: torch.Tensor) -> Dict:
        """Forwards a batch of images X and computes the online classification losses and
        metrics of the encoder classifier and, when available, of the projector classifier
        and of the MLP probes.

        The scoring depends on ``self.dataset_name``: binary cross-entropy plus the outputs of
        ``multi_label_metrics`` for "CelebA", ``multi_attribute_ce_loss`` for "3dshapes", and
        cross-entropy with acc@1/acc@5 for any other dataset. The MLP-probe losses are added to
        ``out["loss"]``; the projector classifier loss is only stored under "proj_loss".

        Args:
            X (torch.Tensor): batch of images in tensor format.
            targets (torch.Tensor): batch of labels for X.

        Returns:
            Dict: the dict returned by ``self(X)`` updated with "loss" and the dataset-specific
                metrics (e.g. "acc1"/"acc5", "proj_loss", "mlp_loss", "proj_mlp_loss").
        """

        out = self(X)  # calls the (possibly overridden) forward method of the subclass

        feats = out.get("feats")
        # logits of the main encoder classifier (always present)
        logits = out["logits"]

        if self.dataset_name == "CelebA":
            encoder_classifier_loss = F.binary_cross_entropy_with_logits(logits, targets.float())
            (
                encoder_exact_match_ratio,
                encoder_hamming_score,
                encoder_average_jaccard_index,
                separate_metrics,
            ) = multi_label_metrics(logits, targets, self.dataset_attr_names, nn_type="encoder")

            out.update({
                "loss": encoder_classifier_loss,
                "aggregate_metrics/encoder_exact_match_ratio": encoder_exact_match_ratio,
                "aggregate_metrics/encoder_hamming_score": encoder_hamming_score,
                "aggregate_metrics/encoder_average_jaccard_index": encoder_average_jaccard_index,
                **separate_metrics,
            })
        elif self.dataset_name == "3dshapes":
            encoder_classifier_loss, encoder_separate_metrics = multi_attribute_ce_loss(
                logits, targets, self.dataset_attr_names, nn_type="encoder"
            )

            out.update({
                "loss": encoder_classifier_loss,
                "encoder_separate_metrics": encoder_separate_metrics,
            })
        else:
            encoder_classifier_loss = F.cross_entropy(logits, targets, ignore_index=-1)

            # acc@5 degrades to acc@k when there are fewer than 5 classes
            top_k_max = min(5, logits.size(1))
            encoder_acc1, encoder_acc5 = accuracy_at_k(logits, targets, top_k=(1, top_k_max))

            out.update({
                "loss": encoder_classifier_loss,
                "acc1": encoder_acc1,
                "acc5": encoder_acc5,
            })

        # MLP probe on encoder features (optional); features are detached so the probe
        # never back-propagates into the backbone
        if self.mlp_probe_enabled and self.mlp_probe_encoder is not None:
            if self.dataset_name == "CelebA":
                mlp_logits = self.mlp_probe_encoder(feats.detach())
                mlp_loss = F.binary_cross_entropy_with_logits(mlp_logits, targets.float())
                out.update({"mlp_loss": mlp_loss})
            elif self.dataset_name == "3dshapes":
                # one probe per attribute
                mlp_logits = [probe(feats.detach()) for probe in self.mlp_probe_encoder]
                mlp_loss, _ = multi_attribute_ce_loss(
                    mlp_logits, targets, self.dataset_attr_names, nn_type="mlp"
                )
                out.update({"mlp_loss": mlp_loss})
            else:
                mlp_logits = self.mlp_probe_encoder(feats.detach())
                mlp_loss = F.cross_entropy(mlp_logits, targets, ignore_index=-1)
                top_k_max_mlp = min(5, mlp_logits.size(1))
                mlp_acc1, mlp_acc5 = accuracy_at_k(mlp_logits, targets, top_k=(1, top_k_max_mlp))
                out.update({"mlp_loss": mlp_loss, "mlp_acc1": mlp_acc1, "mlp_acc5": mlp_acc5})

            # train the probe jointly with the online linear classifier
            out["loss"] = out["loss"] + out["mlp_loss"]

        # Metrics for the projector classifier (if provided by overridden forward and enabled)
        if "projector_logits" in out and self.projector_classifier is not None:
            projector_logits = out["projector_logits"]

            if self.dataset_name == "CelebA":
                proj_classifier_loss = F.binary_cross_entropy_with_logits(
                    projector_logits, targets.float()
                )
                (
                    proj_exact_match_ratio,
                    proj_hamming_score,
                    proj_average_jaccard_index,
                    separate_metrics,
                ) = multi_label_metrics(
                    projector_logits, targets, self.dataset_attr_names, nn_type="proj"
                )

                out.update({
                    "proj_loss": proj_classifier_loss,
                    "aggregate_metrics/proj_exact_match_ratio": proj_exact_match_ratio,
                    "aggregate_metrics/proj_hamming_score": proj_hamming_score,
                    "aggregate_metrics/proj_average_jaccard_index": proj_average_jaccard_index,
                    **separate_metrics,
                })
            elif self.dataset_name == "3dshapes":
                proj_classifier_loss, proj_separate_metrics = multi_attribute_ce_loss(
                    projector_logits, targets, self.dataset_attr_names, nn_type="proj"
                )
                out.update({
                    "proj_loss": proj_classifier_loss,
                    "proj_separate_metrics": proj_separate_metrics,
                })
            else:
                proj_loss = F.cross_entropy(projector_logits, targets, ignore_index=-1)
                proj_top_k = min(5, projector_logits.size(1))
                proj_acc1, proj_acc5 = accuracy_at_k(
                    projector_logits, targets, top_k=(1, proj_top_k)
                )
                out.update({
                    "proj_loss": proj_loss,
                    "proj_acc1": proj_acc1,
                    "proj_acc5": proj_acc5,
                })

        # Optional MLP probe on projector features 'z' (requires child method to supply 'z')
        if self.mlp_probe_enabled and "z" in out and isinstance(out["z"], torch.Tensor):
            z_tensor = out["z"].detach()

            # Lazily build the projector probe on first use, once the width of 'z' is known
            if self.mlp_probe_projector is None:
                if self.dataset_name == "3dshapes":
                    # NOTE(review): hard-coded head sizes, presumably the number of values of
                    # each 3dshapes attribute -- confirm they match the encoder probes.
                    probe = nn.ModuleList([
                        self._build_mlp_classifier(z_tensor.size(1), out_dim, self.mlp_probe_layers)
                        for out_dim in [10, 10, 10, 8, 4, 15]
                    ])
                else:
                    probe = self._build_mlp_classifier(
                        z_tensor.size(1), self.num_classes, self.mlp_probe_layers
                    )
                # A freshly built module lives on CPU and in train mode; put it next to the
                # features and align it with the current train/eval mode of the model.
                probe = probe.to(z_tensor.device)
                probe.train(self.training)
                self.mlp_probe_projector = probe
                # NOTE(review): these parameters are created after the module was constructed;
                # an optimizer that was configured earlier will not contain them.

            if self.dataset_name == "CelebA":
                proj_mlp_logits = self.mlp_probe_projector(z_tensor)
                proj_mlp_loss = F.binary_cross_entropy_with_logits(
                    proj_mlp_logits, targets.float()
                )
                out.update({"proj_mlp_loss": proj_mlp_loss})
            elif self.dataset_name == "3dshapes":
                proj_mlp_logits = [probe(z_tensor) for probe in self.mlp_probe_projector]
                proj_mlp_loss, _ = multi_attribute_ce_loss(
                    proj_mlp_logits, targets, self.dataset_attr_names, nn_type="proj_mlp"
                )
                out.update({"proj_mlp_loss": proj_mlp_loss})
            else:
                proj_mlp_logits = self.mlp_probe_projector(z_tensor)
                proj_mlp_loss = F.cross_entropy(proj_mlp_logits, targets, ignore_index=-1)
                top_k_max_p = min(5, proj_mlp_logits.size(1))
                proj_mlp_acc1, proj_mlp_acc5 = accuracy_at_k(
                    proj_mlp_logits, targets, top_k=(1, top_k_max_p)
                )
                out.update({
                    "proj_mlp_loss": proj_mlp_loss,
                    "proj_mlp_acc1": proj_mlp_acc1,
                    "proj_mlp_acc5": proj_mlp_acc5,
                })

            # the projector probe is trained jointly with the other online heads
            out["loss"] = out["loss"] + proj_mlp_loss

        return out

    def base_training_step(self, X: torch.Tensor, targets: torch.Tensor) -> Dict:
        """Hook that lets subclasses customise the forward logic applied to each large crop
        in training_step. The default implementation delegates to _base_shared_step.

        The returned dict must contain "loss"; training_step additionally reads "acc1" and
        "acc5" from it for datasets other than "CelebA" and "3dshapes".

        Args:
            X (torch.Tensor): batch of images in tensor format.
            targets (torch.Tensor): batch of labels for X.

        Returns:
            Dict: whatever _base_shared_step returns for (X, targets).
        """

        step_outputs = self._base_shared_step(X, targets)
        return step_outputs

    def training_step(self, batch: List[Any], batch_idx: int) -> Dict[str, Any]:
        """Training step for pytorch lightning. It does all the shared operations, such as
        forwarding the crops, computing logits and computing statistics.

        Projector-classifier statistics are only aggregated and logged when the per-crop
        outputs actually contain them, since they are optional in the shared step.

        Args:
            batch (List[Any]): a batch of data in the format of [img_indexes, [X], Y], where
                [X] is a list of size self.num_crops containing batches of images.
            batch_idx (int): index of the batch.

        Returns:
            Dict[str, Any]: dict mapping every key of the per-crop outputs to the list of its
                per-crop values, except "loss" (and "acc1"/"acc5" for generic datasets), which
                are averaged over the large crops.

        Raises:
            NotImplementedError: if knn evaluation is enabled for "CelebA" or "3dshapes".
        """

        _, X, targets = batch

        X = [X] if isinstance(X, torch.Tensor) else X

        # check that we received the desired number of crops
        assert len(X) == self.num_crops

        # one output dict per large crop, then transposed into a dict of per-crop lists
        outs = [self.base_training_step(x, targets) for x in X[: self.num_large_crops]]
        outs = {k: [out[k] for out in outs] for k in outs[0].keys()}

        if self.multicrop:
            multicrop_outs = [self.multicrop_forward(x) for x in X[self.num_large_crops :]]
            for k in multicrop_outs[0].keys():
                outs[k] = outs.get(k, []) + [out[k] for out in multicrop_outs]

        # loss and stats
        outs["loss"] = sum(outs["loss"]) / self.num_large_crops

        if self.dataset_name == "CelebA":
            outs_for_logging = {
                f"train_{k}": sum(v) / self.num_large_crops
                for k, v in outs.items()
                if "metrics" in k
            }

            # Running tp/fn/tn/fp counts for the balanced accuracy; each step contributes the
            # mean over its views (identical to (a + b) / 2 for the usual two views).
            running = self.training_step_outputs_for_balanced_acc["separate_metrics"]
            has_proj = all(f"proj_{q}" in outs for q in self.balanced_acc_quantities_order)
            for outcome_type in self.balanced_acc_quantities_order:
                encoder_views = outs[f"encoder_{outcome_type}"]
                running["encoder"][outcome_type] += sum(encoder_views) / len(encoder_views)
                if has_proj:
                    proj_views = outs[f"proj_{outcome_type}"]
                    running["proj"][outcome_type] += sum(proj_views) / len(proj_views)

            encoder_balanced_acc = compute_balanced_accuracy(
                running["encoder"]["tps"],
                running["encoder"]["fns"],
                running["encoder"]["tns"],
                running["encoder"]["fps"],
            )
            encoder_balanced_acc_dict = dict(
                zip(
                    [f"train_separate_metrics/encoder_{k}" for k in self.dataset_attr_names],
                    encoder_balanced_acc,
                )
            )

            metrics = {
                "train_class_loss": outs["loss"],
                **outs_for_logging,
                **encoder_balanced_acc_dict,
            }

            # the projector classifier is optional, only report it when it produced counts
            if has_proj:
                proj_balanced_acc = compute_balanced_accuracy(
                    running["proj"]["tps"],
                    running["proj"]["fns"],
                    running["proj"]["tns"],
                    running["proj"]["fps"],
                )
                metrics.update(
                    zip(
                        [f"train_separate_metrics/proj_{k}" for k in self.dataset_attr_names],
                        proj_balanced_acc,
                    )
                )
        elif self.dataset_name == "3dshapes":
            encoder_metrics = outs["encoder_separate_metrics"]
            metrics = {"train_class_loss": outs["loss"]}
            metrics.update({
                f"train_encoder_separate_metrics/{k}": sum(d[k] for d in encoder_metrics)
                / len(encoder_metrics)
                for k in encoder_metrics[0]
            })

            # the projector classifier is optional, only report it when present
            proj_metrics = outs.get("proj_separate_metrics")
            if proj_metrics:
                metrics.update({
                    f"train_proj_separate_metrics/{k}": sum(d[k] for d in proj_metrics)
                    / len(proj_metrics)
                    for k in proj_metrics[0]
                })
        else:
            outs["acc1"] = sum(outs["acc1"]) / self.num_large_crops
            outs["acc5"] = sum(outs["acc5"]) / self.num_large_crops

            metrics = {
                "train_class_loss": outs["loss"],
                "train_acc1": outs["acc1"],
                "train_acc5": outs["acc5"],
            }

        self.log_dict(metrics, on_epoch=True, sync_dist=True)

        if self.knn_eval:
            if self.dataset_name == "CelebA":
                raise NotImplementedError("double check or implement knn eval for celebA")

            if self.dataset_name == "3dshapes":
                raise NotImplementedError("double check or implement knn eval for 3dshapes")

            # every large crop shares the same labels; samples labelled -1 are ignored
            targets_repeated = targets.repeat(self.num_large_crops)
            mask = targets_repeated != -1
            # keep the targets on the same device (CPU) as the stored features, as done for
            # the test features/targets in validation_step
            knn_targets = targets_repeated[mask].detach().cpu()

            # Encoder KNN
            if self.knn_encoder is not None:
                encoder_feats_cat = (
                    torch.cat(outs["feats"][: self.num_large_crops])[mask].detach().cpu()
                )
                if encoder_feats_cat.size(0) == knn_targets.size(0):
                    self.knn_encoder(
                        train_features=encoder_feats_cat,
                        train_targets=knn_targets,
                    )

            # Projector KNN - if projector features 'z' are available in outs
            if "z" in outs and self.knn_projector is not None:
                # outs["z"] should be a list of tensors, similar to outs["feats"]
                projector_feats_list = outs.get("z", [])[: self.num_large_crops]
                # Ensure the list is not empty and only contains tensors
                if projector_feats_list and all(
                    isinstance(pf, torch.Tensor) for pf in projector_feats_list
                ):
                    try:
                        projector_feats_cat = torch.cat(projector_feats_list)[mask].detach().cpu()
                        if projector_feats_cat.size(0) == knn_targets.size(0):
                            self.knn_projector(
                                train_features=projector_feats_cat,
                                train_targets=knn_targets,
                            )
                    except Exception as e:
                        logging.warning(f"Skipping projector KNN update in training_step due to error: {e}. We expect the features to be called 'z'-- if 'z' in outs. Projector feats list length: {len(projector_feats_list)}")
        return outs

    def base_validation_step(self, X: torch.Tensor, targets: torch.Tensor) -> Dict:
        """Hook that lets subclasses customise the forward logic used by validation_step.
        The default implementation delegates to _base_shared_step.

        The returned dict must contain "loss"; validation_step additionally reads "acc1" and
        "acc5" from it for datasets other than "CelebA" and "3dshapes".

        Args:
            X (torch.Tensor): batch of images in tensor format.
            targets (torch.Tensor): batch of labels for X.

        Returns:
            Dict: whatever _base_shared_step returns for (X, targets).
        """

        step_outputs = self._base_shared_step(X, targets)
        return step_outputs

    def validation_step(
        self,
        batch: List[torch.Tensor],
        batch_idx: int,
        dataloader_idx: int = None,
        update_validation_step_outputs: bool = True,
    ) -> Dict[str, Any]:
        """Validation step for pytorch lightning. It does all the shared operations, such as
        forwarding a batch of images, computing logits and computing metrics.

        Args:
            batch (List[torch.Tensor]): a batch of data that unpacks into (X, Y).
            batch_idx (int): index of the batch.
            dataloader_idx (int): index of the dataloader (unused here).
            update_validation_step_outputs (bool): whether or not to append the
                metrics to validation_step_outputs

        Returns:
            Dict[str, Any]: dict with the batch_size (used for averaging), the classification loss
                and the dataset-specific metrics.

        Raises:
            NotImplementedError: for the "3dshapes" dataset, and when knn evaluation is enabled
                for "CelebA" or "3dshapes".
        """

        X, targets = batch
        batch_size = targets.size(0)

        out = self.base_validation_step(X, targets)

        if self.knn_eval and not self.trainer.sanity_checking:
            if self.dataset_name == "CelebA":
                raise NotImplementedError("double check or implement knn eval for celebA")

            if self.dataset_name == "3dshapes":
                raise NotImplementedError("double check or implement knn eval for 3dshapes")

            # Encoder KNN
            if self.knn_encoder is not None:
                self.knn_encoder(
                    test_features=out["feats"].detach().cpu(),
                    test_targets=targets.detach().cpu(),
                )

            # Projector KNN - if projector features 'z' are available
            if "z" in out and out["z"] is not None and self.knn_projector is not None:
                try:
                    if out["z"].size(0) == targets.detach().size(0):  # batch size consistency
                        self.knn_projector(
                            test_features=out["z"].detach().cpu(),
                            test_targets=targets.detach().cpu(),
                        )
                except Exception as e:
                    logging.warning(f"Skipping projector KNN update in validation_step due to error: {e}")

        # Collect radii for histogram snapshots (encoder and projector). This is best effort:
        # a failure is reported but never interrupts validation.
        if self.radius_hist_enabled and not self.trainer.sanity_checking:
            try:
                feats = out["feats"]
                gathered_feats = gather(feats.contiguous())
                if self.current_epoch in getattr(self, "_hist_snapshot_epochs", set()):
                    r_enc = torch.norm(gathered_feats, dim=1)
                    self._val_encoder_radii_buffer.append(r_enc.detach().cpu())
                    if self._val_current_encoder_D == 0:
                        self._val_current_encoder_D = gathered_feats.size(1)
            except Exception as e:
                logging.warning(f"Skipping encoder radius collection in validation_step due to error: {e}")

            # Projector features (if available)
            try:
                if "z" in out and out["z"] is not None and isinstance(out["z"], torch.Tensor):
                    gathered_z = gather(out["z"].contiguous())
                    if self.current_epoch in getattr(self, "_hist_snapshot_epochs", set()):
                        r_proj = torch.norm(gathered_z, dim=1)
                        self._val_projector_radii_buffer.append(r_proj.detach().cpu())
                        if self._val_current_projector_D == 0:
                            self._val_current_projector_D = gathered_z.size(1)
            except Exception as e:
                logging.warning(f"Skipping projector radius collection in validation_step due to error: {e}")

        if self.dataset_name == "CelebA":
            out_for_logging = {f"val_{k}": v for k, v in out.items() if "metrics" in k}

            # Running tp/fn/tn/fp counts for the balanced accuracy. Validation has a single
            # view, so the per-batch counts in `out` are accumulated directly.
            running = self.validation_step_outputs_for_balanced_acc["separate_metrics"]
            has_proj = all(f"proj_{q}" in out for q in self.balanced_acc_quantities_order)
            for outcome_type in self.balanced_acc_quantities_order:
                running["encoder"][outcome_type] += out[f"encoder_{outcome_type}"].clone()
                if has_proj:
                    running["proj"][outcome_type] += out[f"proj_{outcome_type}"].clone()

            encoder_balanced_acc = compute_balanced_accuracy(
                running["encoder"]["tps"],
                running["encoder"]["fns"],
                running["encoder"]["tns"],
                running["encoder"]["fps"],
            )
            encoder_balanced_acc_dict = dict(
                zip(
                    [f"val_separate_metrics/encoder_{k}" for k in self.dataset_attr_names],
                    encoder_balanced_acc,
                )
            )

            metrics = {
                "batch_size": batch_size,
                "val_loss": out["loss"],
                **out_for_logging,
                **encoder_balanced_acc_dict,
            }

            # the projector classifier is optional, only report it when it produced counts
            if has_proj:
                proj_balanced_acc = compute_balanced_accuracy(
                    running["proj"]["tps"],
                    running["proj"]["fns"],
                    running["proj"]["tns"],
                    running["proj"]["fps"],
                )
                metrics.update(
                    zip(
                        [f"val_separate_metrics/proj_{k}" for k in self.dataset_attr_names],
                        proj_balanced_acc,
                    )
                )
            if "proj_loss" in out:
                metrics.update({"val_proj_loss": out["proj_loss"]})
        elif self.dataset_name == "3dshapes":
            # TODO: aggregate out["encoder_separate_metrics"] / out["proj_separate_metrics"]
            # into "val_*separate_metrics/..." entries (plus "val_proj_loss" when available).
            raise NotImplementedError
        else:
            metrics = {
                "batch_size": batch_size,
                "val_loss": out["loss"],
                "val_acc1": out["acc1"],
                "val_acc5": out["acc5"],
            }
            if "proj_loss" in out:
                metrics.update({
                    "val_proj_loss": out["proj_loss"],
                    "val_proj_acc1": out["proj_acc1"],
                    "val_proj_acc5": out["proj_acc5"],
                })
            # Add MLP probe validation metrics if available
            if "mlp_acc1" in out and "mlp_acc5" in out:
                metrics.update({
                    "val_mlp_acc1": out["mlp_acc1"],
                    "val_mlp_acc5": out["mlp_acc5"],
                })
            if "proj_mlp_acc1" in out and "proj_mlp_acc5" in out:
                metrics.update({
                    "val_proj_mlp_acc1": out["proj_mlp_acc1"],
                    "val_proj_mlp_acc5": out["proj_mlp_acc5"],
                })

        if update_validation_step_outputs:
            self.validation_step_outputs.append(metrics)
        return metrics

    def on_train_epoch_start(self):
        """Creates fresh zero-filled tp/fn/tn/fp accumulators (one entry per attribute name)
        for the encoder and the projector. Only done for the CelebA dataset."""
        if self.dataset_name != "CelebA":
            return

        num_attrs = len(self.dataset_attr_names)
        self.training_step_outputs_for_balanced_acc = {
            "separate_metrics": {
                head: {
                    quantity: torch.zeros(num_attrs, device=self.device)
                    for quantity in ("tps", "fns", "tns", "fps")
                }
                for head in ("encoder", "proj")
            }
        }

    def on_train_epoch_end(self):
        """Empties the CelebA balanced-accuracy accumulators at the end of a training epoch."""
        if self.dataset_name != "CelebA":
            return
        self.training_step_outputs_for_balanced_acc.clear()

    def on_validation_epoch_start(self):
        """Resets the per-epoch validation state: the CelebA tp/fn/tn/fp accumulators and,
        when radius histograms are enabled, the radius buffers and their dimensions."""
        if self.dataset_name == "CelebA":
            num_attrs = len(self.dataset_attr_names)
            self.validation_step_outputs_for_balanced_acc = {
                "separate_metrics": {
                    head: {
                        quantity: torch.zeros(num_attrs, device=self.device)
                        for quantity in ("tps", "fns", "tns", "fps")
                    }
                    for head in ("encoder", "proj")
                }
            }

        # Start this validation epoch with empty radius histogram buffers
        if self.radius_hist_enabled:
            self._val_encoder_radii_buffer, self._val_projector_radii_buffer = [], []
            self._val_current_encoder_D = self._val_current_projector_D = 0

    def on_validation_epoch_end(self):
        """Averages the losses and accuracies of all the validation batches.
        This is needed because the last batch can be smaller than the others,
        slightly skewing the metrics.

        Also logs the KNN accuracies when knn evaluation is enabled, clears the per-epoch
        validation buffers and stores the radius histogram snapshots.

        Raises:
            NotImplementedError: if knn evaluation is enabled for "CelebA" or "3dshapes".
        """

        step_outputs = self.validation_step_outputs
        val_loss = weighted_mean(step_outputs, "val_loss", "batch_size")

        if self.dataset_name in ("CelebA", "3dshapes"):
            # CelebA aggregates every "...metrics..." entry, 3dshapes only the separate ones
            marker = "metrics" if self.dataset_name == "CelebA" else "separate_metrics"
            log = {"val_loss": val_loss}
            log.update({
                key: weighted_mean(step_outputs, key, "batch_size")
                for key in step_outputs[0]
                if marker in key
            })
        else:
            val_acc1 = weighted_mean(step_outputs, "val_acc1", "batch_size")
            val_acc5 = weighted_mean(step_outputs, "val_acc5", "batch_size")

            log = {"val_loss": val_loss, "val_acc1": val_acc1, "val_acc5": val_acc5}

        # Projector classifier metrics (accuracies only exist for the generic datasets)
        if step_outputs and "val_proj_loss" in step_outputs[0]:
            log["val_proj_loss"] = weighted_mean(step_outputs, "val_proj_loss", "batch_size")
            if self.dataset_name not in ("CelebA", "3dshapes"):
                log["val_proj_acc1"] = weighted_mean(step_outputs, "val_proj_acc1", "batch_size")
                log["val_proj_acc5"] = weighted_mean(step_outputs, "val_proj_acc5", "batch_size")

        # Aggregate MLP probe metrics if present
        if step_outputs and "val_mlp_acc1" in step_outputs[0]:
            log["val_mlp_acc1"] = weighted_mean(step_outputs, "val_mlp_acc1", "batch_size")
            log["val_mlp_acc5"] = weighted_mean(step_outputs, "val_mlp_acc5", "batch_size")
        if step_outputs and "val_proj_mlp_acc1" in step_outputs[0]:
            log["val_proj_mlp_acc1"] = weighted_mean(
                step_outputs, "val_proj_mlp_acc1", "batch_size"
            )
            log["val_proj_mlp_acc5"] = weighted_mean(
                step_outputs, "val_proj_mlp_acc5", "batch_size"
            )

        if self.knn_eval and not self.trainer.sanity_checking:
            if self.dataset_name == "CelebA":
                raise NotImplementedError

            if self.dataset_name == "3dshapes":
                raise NotImplementedError

            # A KNN classifier needs both stored train features and test features, so it is
            # only evaluated when both sides were collected (e.g. not on validate-only runs).
            if (
                self.knn_encoder is not None
                and self.knn_encoder.train_features
                and self.knn_encoder.test_features
            ):
                val_knn_encoder_acc1, val_knn_encoder_acc5 = self.knn_encoder.compute()
                log.update({
                    "val_knn_encoder_acc1": float(val_knn_encoder_acc1),
                    "val_knn_encoder_acc5": float(val_knn_encoder_acc5),
                })

            # Non-empty buffers imply that projector features ('z') were present and collected
            if (
                self.knn_projector is not None
                and self.knn_projector.train_features
                and self.knn_projector.test_features
            ):
                val_knn_projector_acc1, val_knn_projector_acc5 = self.knn_projector.compute()
                log.update({
                    "val_knn_projector_acc1": float(val_knn_projector_acc1),
                    "val_knn_projector_acc5": float(val_knn_projector_acc5),
                })

        self.log_dict(log, sync_dist=True)

        self.validation_step_outputs.clear()

        if self.dataset_name == "CelebA":
            self.validation_step_outputs_for_balanced_acc.clear()

        # Save radius histogram snapshots if requested
        self._radius_hist_on_validation_epoch_end_common()

    # ------------------------------
    # Radius histogram helpers
    # ------------------------------
    def _radius_hist_on_validation_epoch_end_common(self):
        if not self.radius_hist_enabled:
            return
        try:
            if getattr(self, "_hist_snapshot_epochs", None) and self.current_epoch in self._hist_snapshot_epochs:
                # Encoder snapshots
                if self._val_encoder_radii_buffer:
                    try:
                        radii_epoch_enc = torch.cat(self._val_encoder_radii_buffer, dim=0).numpy()
                        D_epoch_enc = self._val_current_encoder_D if self._val_current_encoder_D > 0 else self.features_dim
                        self._hist_encoder_radii_snapshots.append((int(self.current_epoch), radii_epoch_enc, int(D_epoch_enc)))
                    except Exception:
                        pass
                # Projector snapshots
                if self._val_projector_radii_buffer:
                    try:
                        radii_epoch_proj = torch.cat(self._val_projector_radii_buffer, dim=0).numpy()
                        # If projector not used, fallback to features_dim to avoid errors
                        D_epoch_proj = self._val_current_projector_D if self._val_current_projector_D > 0 else self.features_dim
                        self._hist_projector_radii_snapshots.append((int(self.current_epoch), radii_epoch_proj, int(D_epoch_proj)))
                    except Exception:
                        pass
        except Exception:
            pass

    def _radius_hist_collect(self, gathered_feats: torch.Tensor = None, gathered_z: torch.Tensor = None):
        """Collect encoder/projector radii during validation for selected epochs.
        Call this from method-specific validation_step overrides after you compute gathered tensors.
        """
        if not self.radius_hist_enabled or self.trainer.sanity_checking:
            return
        try:
            if self.current_epoch in getattr(self, "_hist_snapshot_epochs", set()):
                if gathered_feats is not None:
                    try:
                        r_enc = torch.norm(gathered_feats, dim=1)
                        self._val_encoder_radii_buffer.append(r_enc.detach().cpu())
                        if self._val_current_encoder_D == 0:
                            self._val_current_encoder_D = gathered_feats.size(1)
                    except Exception:
                        pass
                if gathered_z is not None:
                    try:
                        r_proj = torch.norm(gathered_z, dim=1)
                        self._val_projector_radii_buffer.append(r_proj.detach().cpu())
                        if self._val_current_projector_D == 0:
                            self._val_current_projector_D = gathered_z.size(1)
                    except Exception:
                        pass
        except Exception:
            pass

    @staticmethod
    def _chi_pdf_np(r: np.ndarray, d: int) -> np.ndarray:
        """Robust Chi(df=d) PDF via SciPy if available, otherwise zeros."""
        r_pos = np.maximum(r, 0.0)
        try:
            from scipy.stats import chi as _chi
            pdf = _chi.pdf(r_pos, df=d)
            pdf = np.where(np.isfinite(pdf), pdf, 0.0)
            return pdf
        except Exception:
            return np.zeros_like(r_pos)

    def _plot_radius_hist_grid_common(self, snaps: List[tuple], hist_color: str, chi_color: str, title_prefix: str):
        if not snaps:
            return None
        # Use non-interactive backend in case of headless envs
        try:
            import matplotlib
            matplotlib.use("Agg", force=True)
        except Exception:
            pass
        import matplotlib.pyplot as plt

        n = len(snaps)
        cols = min(10, n)
        rows = math.ceil(n / cols)
        fig, axs = plt.subplots(rows, cols, figsize=(cols * 4, rows * 3))
        if rows == 1 and cols == 1:
            axs = np.array([[axs]])
        elif rows == 1:
            axs = np.array([axs])
        axs_flat = axs.flatten()

        for i, (epoch_i, radii_i, D_i) in enumerate(snaps):
            ax = axs_flat[i]
            try:
                p97 = float(np.percentile(radii_i, 97.5)) if radii_i.size > 0 else float(D_i)
            except Exception:
                p97 = float(D_i)
            radius_max = min(float(D_i), p97 + 20.0) if np.isfinite(p97) else float(D_i)
            radius_max = max(radius_max, 1e-6)
            ax.hist(radii_i, bins=50, density=True, alpha=0.6, range=(0.0, radius_max), color=hist_color, label="Empirical")

            try:
                r = np.linspace(0.0, radius_max, 800)
                pdf = self._chi_pdf_np(r, D_i)
                # Compute W1 distance for legend (best-effort)
                try:
                    import torch as _torch
                    r_tensor = _torch.from_numpy(radii_i.reshape(-1)).to(dtype=_torch.float32)
                    w1_dist = w1_distance_to_chi(r_tensor, int(D_i)).item()
                    legend_label = f"Chi(df={D_i})\nW1={w1_dist:.2f}"
                except Exception:
                    legend_label = f"Chi(df={D_i})"
                if np.any(np.isfinite(pdf)):
                    ax.plot(r, pdf, '-', color=chi_color, lw=2, label=legend_label)
            except Exception:
                pass
            ax.set_title(f"{title_prefix} Epoch {epoch_i}")
            ax.set_xlabel("Radius")
            ax.set_ylabel("Density")
            ax.grid(True)
            ax.legend(fontsize=8)

        for j in range(n, rows * cols):
            axs_flat[j].axis('off')

        plt.tight_layout()
        return fig

    def on_fit_end(self):
        """Plot the collected radius snapshots and log them once fitting has finished.

        One histogram grid is built per snapshot list (encoder and projector). When the logger
        exposes a non-None ``experiment``, the grids are logged as ``wandb.Image`` objects and
        the first/last snapshots are logged as summary values plus a table. If that logging
        raises, the figures are saved as PNG files in the working directory and their paths are
        logged instead. The figures are always closed before returning.
        """
        # Only plot if enabled
        if not self.radius_hist_enabled:
            return
        encoder_hist_color = "#4B8384"   # Green-ish for encoder
        projector_hist_color = "#A1656D" # Red-ish for projector
        chi_line_color = "#364377"       # Dark blue for Chi PDF

        fig_enc = self._plot_radius_hist_grid_common(self._hist_encoder_radii_snapshots, encoder_hist_color, chi_line_color, "Encoder")
        fig_proj = self._plot_radius_hist_grid_common(self._hist_projector_radii_snapshots, projector_hist_color, chi_line_color, "Projector")
        if fig_enc is None and fig_proj is None:
            return
        figures = (("encoder", fig_enc), ("projector", fig_proj))
        try:
            experiment = getattr(self.logger, "experiment", None)
            if experiment is None:
                return
            try:
                # Prefer logging as an image if wandb is active
                import wandb

                for branch, fig in figures:
                    if fig is not None:
                        experiment.log({branch + "_radius_histograms": wandb.Image(fig)})

                # Additionally log the first and last snapshot radii arrays and metadata (epoch, dim, W1)
                self._log_radii_snapshot_stats(experiment, wandb, self._hist_encoder_radii_snapshots, "encoder")
                self._log_radii_snapshot_stats(experiment, wandb, self._hist_projector_radii_snapshots, "projector")
            except Exception as e:
                logging.warning(f"Error logging radius histogram to wandb: {e}")
                # Fallback: save to disk and log path strings
                for branch, fig in figures:
                    if fig is not None:
                        fig_path = branch + "_radius_histograms_epoch_grid.png"
                        fig.savefig(fig_path, dpi=200)
                        experiment.log({branch + "_radius_histograms_path": fig_path})
        finally:
            try:
                import matplotlib.pyplot as plt
                for _, fig in figures:
                    if fig is not None:
                        plt.close(fig)
            except Exception as err:
                logging.warning("Could not close the radius histogram figures: %s", err)

    def _log_radii_snapshot_stats(self, experiment, wandb_module, snapshots: List[tuple], branch: str):
        """Best-effort logging of the first and last radius snapshot of one branch.

        Writes ``{initial,final}_<branch>_radii_{epoch,dim,w1_dist}`` to ``experiment.summary``
        and logs a ``<branch>_radii_values`` table with the initial and final radii side by
        side. Failures are reported through ``logging`` and never propagate.

        Args:
            experiment: object exposing ``summary`` and ``log`` (the logger's experiment).
            wandb_module: the imported ``wandb`` module, used to build the table.
            snapshots (List[tuple]): ``(epoch, radii, D)`` snapshots of the branch.
            branch (str): branch name used in the logged keys ("encoder" or "projector").
        """
        if not snapshots:
            return
        try:
            for tag, (epoch, radii, dim) in (("initial", snapshots[0]), ("final", snapshots[-1])):
                try:
                    w1 = w1_distance_to_chi(
                        torch.from_numpy(radii).to(dtype=torch.float32), int(dim)
                    ).item()
                except Exception as err:
                    logging.warning("Could not compute the %s %s W1 distance: %s", tag, branch, err)
                    w1 = None
                experiment.summary[f"{tag}_{branch}_radii_epoch"] = int(epoch)
                experiment.summary[f"{tag}_{branch}_radii_dim"] = int(dim)
                if w1 is not None:
                    experiment.summary[f"{tag}_{branch}_radii_w1_dist"] = float(w1)

            # Single table with both initial and final radii; the shorter column is padded with None
            first_radii, last_radii = snapshots[0][1], snapshots[-1][1]
            try:
                n_rows = max(len(first_radii), len(last_radii))
                rows = [
                    [
                        float(first_radii[i]) if i < len(first_radii) else None,
                        float(last_radii[i]) if i < len(last_radii) else None,
                    ]
                    for i in range(n_rows)
                ]
                table = wandb_module.Table(data=rows, columns=["radius_initial", "radius_final"])
                experiment.log({f"{branch}_radii_values": table})
            except Exception as err:
                logging.warning("Could not log the %s radii table: %s", branch, err)
        except Exception as err:
            logging.warning("Could not log the %s radii snapshot statistics: %s", branch, err)


class BaseMomentumMethod(BaseMethod):
    def __init__(
        self,
        cfg: omegaconf.DictConfig,
    ):
        """Base class for self-supervised methods that rely on a momentum backbone.

        On top of the base method it builds a second (momentum) backbone initialized from the
        online one, optionally adds a linear classifier on the momentum features, and creates
        the updater that performs the exponential-moving-average update with a tau schedule.

        Extra cfg settings:
            momentum:
                base_tau (float): base value of the weighting decrease coefficient in [0,1].
                final_tau (float): final value of the weighting decrease coefficient in [0,1].
                classifier (bool): whether or not to train a classifier on top of the
                    momentum backbone.
        """

        super().__init__(cfg)

        # build the momentum backbone from a copy of the online backbone's arguments
        method: str = cfg.method
        backbone_kwargs = self.backbone_args.copy()
        self.momentum_backbone: nn.Module = self.base_model(method, **backbone_kwargs)

        if self.backbone_name.startswith("resnet"):
            # the classification head is not needed
            self.momentum_backbone.fc = nn.Identity()
            if cfg.data.dataset in ["cifar10", "cifar100"]:
                # CIFAR variant: 3x3 stem convolution and no max pooling
                self.momentum_backbone.conv1 = nn.Conv2d(
                    3, 64, kernel_size=3, stride=1, padding=2, bias=False
                )
                self.momentum_backbone.maxpool = nn.Identity()

        initialize_momentum_params(self.backbone, self.momentum_backbone)

        # optional classifier on top of the momentum features
        self.momentum_classifier: Any = (
            nn.Linear(self.features_dim, self.num_classes) if cfg.momentum.classifier else None
        )

        # exponential-moving-average updater with its tau schedule
        self.momentum_updater = MomentumUpdater(cfg.momentum.base_tau, cfg.momentum.final_tau)

    @property
    def learnable_params(self) -> List[Dict[str, Any]]:
        """Extends the base class parameter groups with the momentum classifier, if any.

        Returns:
            List[Dict[str, Any]]:
                the parameter groups of the base class, followed by one group for the
                momentum classifier (classifier learning rate, no weight decay) when a
                momentum classifier exists.
        """

        extra_groups: List[Dict[str, Any]] = []
        if self.momentum_classifier is not None:
            classifier_group = {
                "name": "momentum_classifier",
                "params": self.momentum_classifier.parameters(),
                "lr": self.classifier_lr,
                "weight_decay": 0,
            }
            extra_groups = [classifier_group]
        return super().learnable_params + extra_groups

    @property
    def momentum_pairs(self) -> List[Tuple[Any, Any]]:
        """Lists the (online, momentum) module pairs updated by exponential moving average.

        Returns:
            List[Tuple[Any, Any]]: a single pair made of the backbone and the momentum backbone.
        """

        backbone_pair = (self.backbone, self.momentum_backbone)
        return [backbone_pair]

    @staticmethod
    def add_and_assert_specific_cfg(cfg: omegaconf.DictConfig) -> omegaconf.DictConfig:
        """Fills in the momentum-specific config defaults after the base class ones.

        Args:
            cfg (omegaconf.DictConfig): DictConfig object.

        Returns:
            omegaconf.DictConfig: same as the argument, used to avoid errors.
        """

        cfg = super(BaseMomentumMethod, BaseMomentumMethod).add_and_assert_specific_cfg(cfg)

        # (attribute under cfg.momentum, lookup key, value used when the key is not set)
        momentum_defaults = (
            ("base_tau", "momentum.base_tau", 0.99),
            ("final_tau", "momentum.final_tau", 1.0),
            ("classifier", "momentum.classifier", False),
        )
        for attribute, lookup_key, fallback in momentum_defaults:
            setattr(cfg.momentum, attribute, omegaconf_select(cfg, lookup_key, fallback))

        return cfg

    def on_train_start(self):
        """Sets ``last_step`` to zero when training starts.

        ``on_train_batch_end`` compares this value against the trainer's global step.
        """
        self.last_step: int = 0

    @torch.no_grad()
    def momentum_forward(self, X: torch.Tensor) -> Dict[str, Any]:
        """Forward pass through the momentum backbone, with gradients disabled.

        Children methods should call this function, extend the outputs (without deleting
        anything) and return them.

        Args:
            X (torch.Tensor): batch of images in tensor format.

        Returns:
            Dict: dict holding the momentum backbone features under "feats".
        """

        # convert to channels-last memory format unless it was explicitly disabled
        inputs = X if self.no_channel_last else X.to(memory_format=torch.channels_last)
        return {"feats": self.momentum_backbone(inputs)}

    def _shared_step_momentum(self, X: torch.Tensor, targets: torch.Tensor) -> Dict[str, Any]:
        """Forwards a batch of images X in the momentum backbone and optionally computes the
        classification loss, the logits, the features, acc@1 and acc@5 for of momentum classifier.

        Args:
            X (torch.Tensor): batch of images in tensor format.
            targets (torch.Tensor): batch of labels for X.

        Returns:
            Dict[str, Any]:
                a dict containing the classification loss, logits, features, acc@1 and
                acc@5 of the momentum backbone / classifier.
        """

        out = self.momentum_forward(X)

        if self.momentum_classifier is not None:
            feats = out["feats"]
            logits = self.momentum_classifier(feats)

            loss = F.cross_entropy(logits, targets, ignore_index=-1)
            acc1, acc5 = accuracy_at_k(logits, targets, top_k=(1, 5))
            out.update({"logits": logits, "loss": loss, "acc1": acc1, "acc5": acc5})

        return out

    def training_step(self, batch: List[Any], batch_idx: int) -> Dict[str, Any]:
        """Training step for pytorch lightning. After the base class step, the large crops are
        forwarded through the momentum backbone (and classifier, if any) and the resulting
        statistics are merged into the outputs.

        Args:
            batch (List[Any]): a batch of data in the format of [img_indexes, [X], Y], where
                [X] is a list of size self.num_crops containing batches of images.
            batch_idx (int): index of the batch.

        Returns:
            Dict[str, Any]: the outputs of the base class step, extended with the per-crop
                momentum outputs under "momentum_"-prefixed keys; with a momentum classifier,
                its loss and accuracies are averaged over the large crops and the loss is
                added to the general loss.
        """

        outs = super().training_step(batch, batch_idx)

        _, X, targets = batch
        crops = [X] if isinstance(X, torch.Tensor) else X

        # only the large crops go through the momentum backbone
        large_crops = crops[: self.num_large_crops]

        per_crop = [self._shared_step_momentum(crop, targets) for crop in large_crops]
        # regroup the per-crop dicts into one dict of per-crop lists
        momentum_outs = {
            "momentum_" + key: [result[key] for result in per_crop] for key in per_crop[0].keys()
        }

        if self.momentum_classifier is not None:
            # average the momentum loss and stats over the large crops
            for stat_key in ("momentum_loss", "momentum_acc1", "momentum_acc5"):
                momentum_outs[stat_key] = sum(momentum_outs[stat_key]) / self.num_large_crops

            metrics = {
                "train_momentum_class_loss": momentum_outs["momentum_loss"],
                "train_momentum_acc1": momentum_outs["momentum_acc1"],
                "train_momentum_acc5": momentum_outs["momentum_acc5"],
            }
            self.log_dict(metrics, on_epoch=True, sync_dist=True)

            # the momentum classifier loss is trained together with the general loss
            outs["loss"] += momentum_outs["momentum_loss"]

        outs.update(momentum_outs)
        return outs

    def on_train_batch_end(self, outputs: Dict[str, Any], batch: Sequence[Any], batch_idx: int):
        """Runs the exponential-moving-average update of every momentum pair when the global
        step has advanced since the previous call, i.e. when an optimizer step was performed.

        Args:
            outputs (Dict[str, Any]): the outputs of the training step.
            batch (Sequence[Any]): a batch of data in the format of [img_indexes, [X], Y], where
                [X] is a list of size self.num_crops containing batches of images.
            batch_idx (int): index of the batch.
        """

        current_step = self.trainer.global_step
        if current_step > self.last_step:
            # move each momentum module towards its online counterpart
            for pair in self.momentum_pairs:
                self.momentum_updater.update(*pair)
            # log the tau that was just used, then advance its schedule
            self.log("tau", self.momentum_updater.cur_tau)
            self.momentum_updater.update_tau(
                cur_step=current_step,
                max_steps=self.trainer.estimated_stepping_batches,
            )
        self.last_step = current_step

    def validation_step(
        self,
        batch: List[torch.Tensor],
        batch_idx: int,
        dataloader_idx: int = None,
        update_validation_step_outputs: bool = True,
    ) -> Dict[str, Any]:
        """Validation step for pytorch lightning. It performs all the shared operations for the
        momentum backbone and classifier, such as forwarding a batch of images in the momentum
        backbone and classifier and computing statistics.

        Args:
            batch (List[torch.Tensor]): a batch of data in the format of [X, Y].
            batch_idx (int): index of the batch.
            dataloader_idx (int): index of the dataloader; accepted but not used here.
            update_validation_step_outputs (bool): whether or not to append the
                metrics to validation_step_outputs

        Returns:
            Dict[str, Any]: a single dict with the metrics returned by the base class
                validation step, extended with "momentum_val_loss", "momentum_val_acc1" and
                "momentum_val_acc5" when a momentum classifier exists.
        """

        # the base class must not store the metrics itself: they are stored below, once the
        # momentum metrics have been added
        metrics = super().validation_step(batch, batch_idx, update_validation_step_outputs=False)

        X, targets = batch

        out = self._shared_step_momentum(X, targets)

        if self.momentum_classifier is not None:
            metrics.update(
                {
                    "momentum_val_loss": out["loss"],
                    "momentum_val_acc1": out["acc1"],
                    "momentum_val_acc5": out["acc5"],
                }
            )

        if update_validation_step_outputs:
            self.validation_step_outputs.append(metrics)

        return metrics

    def on_validation_epoch_end(self):
        """Aggregates and logs the validation metrics of the epoch, for both the online and the
        momentum classifiers. Metrics are averaged with the batch size as weight because the
        last batch can be smaller than the others, which would slightly skew a plain mean.
        """
        # NOTE(review): this override does not call super().on_validation_epoch_end();
        # confirm that no base-class epoch-end work is expected to run for momentum methods.

        outputs = self.validation_step_outputs

        # base method metrics
        log = {
            key: weighted_mean(outputs, key, "batch_size")
            for key in ("val_loss", "val_acc1", "val_acc5")
        }

        # projector metrics are only logged when the validation steps produced them
        if outputs and "val_proj_loss" in outputs[0]:
            log.update(
                {
                    key: weighted_mean(outputs, key, "batch_size")
                    for key in ("val_proj_loss", "val_proj_acc1", "val_proj_acc5")
                }
            )

        if self.knn_eval and not self.trainer.sanity_checking:
            # (attribute holding the kNN classifier, acc@1 key, acc@5 key)
            knn_specs = (
                ("knn_encoder", "val_knn_encoder_acc1", "val_knn_encoder_acc5"),
                ("knn_projector", "val_knn_projector_acc1", "val_knn_projector_acc5"),
            )
            for attribute, acc1_key, acc5_key in knn_specs:
                knn = getattr(self, attribute)
                # skip classifiers that are missing or did not collect any features
                if knn is not None and (knn.train_features or knn.test_features):
                    knn_acc1, knn_acc5 = knn.compute()
                    log.update({acc1_key: float(knn_acc1), acc5_key: float(knn_acc5)})

        self.log_dict(log, sync_dist=True)

        # momentum method metrics
        if self.momentum_classifier is not None:
            momentum_log = {
                key: weighted_mean(outputs, key, "batch_size")
                for key in ("momentum_val_loss", "momentum_val_acc1", "momentum_val_acc5")
            }
            self.log_dict(momentum_log, sync_dist=True)

        self.validation_step_outputs.clear()
