# Copyright 2023 solo-learn development team.

# Permission is hereby granted, free of charge, to any person obtaining a copy of
# this software and associated documentation files (the "Software"), to deal in
# the Software without restriction, including without limitation the rights to use,
# copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the
# Software, and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all copies
# or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
# INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR
# PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE
# FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
# OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.

import logging
from functools import partial
from typing import Any, Callable, Dict, List, Sequence, Tuple, Union

import lightning.pytorch as pl
import omegaconf
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim.lr_scheduler import MultiStepLR
try:
    import wandb  # optional
except Exception:
    wandb = None

from solo.backbones import (
    convnext_base,
    convnext_large,
    convnext_small,
    convnext_tiny,
    poolformer_m36,
    poolformer_m48,
    poolformer_s12,
    poolformer_s24,
    poolformer_s36,
    resnet18,
    resnet50,
    swin_base,
    swin_large,
    swin_small,
    swin_tiny,
    vit_base,
    vit_large,
    vit_small,
    vit_small_linear,
    vit_tiny,
    wide_resnet28w2,
    wide_resnet28w8,
)
from solo.utils.knn import WeightedKNNClassifier
from solo.utils.lars import LARS
from solo.utils.lr_scheduler import LinearWarmupCosineAnnealingLR
from solo.utils.metrics import (
    accuracy_at_k,
    weighted_mean,
    batch_sparsity_metric,
    embedding_sparsity_metric,
    active_feature_fraction,
    count_avg_nonzero_elements_per_dimension,
    count_avg_nonzero_elements_per_sample,
)
from solo.utils.misc import omegaconf_select, remove_bias_and_norm_from_weight_decay
from solo.utils.momentum import MomentumUpdater, initialize_momentum_params


def static_lr(
    get_lr: Callable,
    param_group_indexes: Sequence[int],
    lrs_to_replace: Sequence[float],
):
    """Return the learning rates from ``get_lr`` with some entries overwritten.

    Args:
        get_lr (Callable): zero-argument callable returning an indexable,
            mutable collection of learning rates.
        param_group_indexes (Sequence[int]): positions to overwrite.
        lrs_to_replace (Sequence[float]): values written at those positions,
            paired element-wise with ``param_group_indexes``.

    Returns:
        The very object returned by ``get_lr()``, modified in place.
    """
    scheduled = get_lr()
    overrides = zip(param_group_indexes, lrs_to_replace)
    for group_idx, fixed_lr in overrides:
        scheduled[group_idx] = fixed_lr
    return scheduled


class BaseMethod(pl.LightningModule):
    # Maps the `cfg.backbone.name` string to the backbone factory imported from
    # `solo.backbones`; `__init__` asserts that the configured name is a key here.
    _BACKBONES = {
        "resnet18": resnet18,
        "resnet50": resnet50,
        "vit_tiny": vit_tiny,
        "vit_small": vit_small,
        "vit_small_linear": vit_small_linear,
        "vit_base": vit_base,
        "vit_large": vit_large,
        "swin_tiny": swin_tiny,
        "swin_small": swin_small,
        "swin_base": swin_base,
        "swin_large": swin_large,
        "poolformer_s12": poolformer_s12,
        "poolformer_s24": poolformer_s24,
        "poolformer_s36": poolformer_s36,
        "poolformer_m36": poolformer_m36,
        "poolformer_m48": poolformer_m48,
        "convnext_tiny": convnext_tiny,
        "convnext_small": convnext_small,
        "convnext_base": convnext_base,
        "convnext_large": convnext_large,
        "wide_resnet28w2": wide_resnet28w2,
        "wide_resnet28w8": wide_resnet28w8,
    }
    # Maps the `cfg.optimizer.name` string to the optimizer class instantiated in
    # `configure_optimizers`.
    _OPTIMIZERS = {
        "sgd": torch.optim.SGD,
        "lars": LARS,
        "adam": torch.optim.Adam,
        "adamw": torch.optim.AdamW,
    }
    # Known scheduler names.
    # NOTE(review): `configure_optimizers` in this class only handles "none",
    # "warmup_cosine" and "step"; any other name raises ValueError there.
    _SCHEDULERS = [
        "reduce",
        "warmup_cosine",
        "step",
        "exponential",
        "none",
    ]

    def __init__(self, cfg: omegaconf.DictConfig):
        """Base model that implements all basic operations for all self-supervised methods.
        It adds shared arguments, extracts basic learnable parameters, creates optimizers
        and schedulers, implements a basic training_step for any number of crops,
        trains the online classifier and implements validation_step.

        .. note:: Cfg defaults are set in init by calling `cfg = add_and_assert_specific_cfg(cfg)`

        Cfg basic structure:
            method (str): name of the method; passed as the first positional argument
                to the backbone factory.
            backbone:
                name (str): architecture of the base backbone.
                kwargs (dict): extra backbone kwargs.
            data:
                dataset (str): name of the dataset.
                num_classes (int): number of classes.
                num_large_crops (int): number of big crops.
                num_small_crops (int): number of small crops.
            max_epochs (int): number of training epochs.

            backbone_params (dict): dict containing extra backbone args, namely:
                #! only for resnet
                zero_init_residual (bool): change the initialization of the resnet backbone.
                #! only for vit
                patch_size (int): size of the patches for ViT.
            optimizer:
                name (str): name of the optimizer.
                batch_size (int): number of samples in the batch.
                lr (float): learning rate.
                weight_decay (float): weight decay for optimizer.
                classifier_lr (float): learning rate for the online linear classifier.
                kwargs (Dict): extra named arguments for the optimizer.
                exclude_bias_n_norm_wd (bool): exclude bias and norm parameters from
                    weight decay. Defaults to False.
            scheduler:
                name (str): name of the scheduler.
                min_lr (float): minimum learning rate for warmup scheduler. Defaults to 0.0.
                warmup_start_lr (float): initial learning rate for warmup scheduler.
                    Defaults to 0.00003.
                warmup_epochs (float): number of warmup epochs. Defaults to 10.
                lr_decay_steps (Sequence, optional): steps to decay the learning rate if
                    scheduler is step. Defaults to None.
                interval (str): interval to update the lr scheduler, either 'step' or
                    'epoch'. Defaults to 'step'.
            knn_eval:
                enabled (bool): enables online knn evaluation while training.
                k (int): the number of neighbors to use for knn.
                distance_func (str): distance function handed to the knn classifier.
                    Defaults to 'euclidean'.
            mlp_probe:
                enabled (bool): builds MLP probes on top of the encoder (and projector)
                    features. Defaults to False.
                num_layers (int): number of linear layers in each MLP probe. Defaults to 3.
            method_kwargs:
                add_projector_classifier (bool): adds a linear classifier on the projector
                    output; requires `proj_output_dim`. Defaults to False.
                proj_output_dim (int, optional): projector output dimension.
                num_prototypes (int, optional): used instead of `proj_output_dim` to size
                    the projector MLP probe when the method is "dino".
                probe_l1_penalty_weight (float): weight of the L1 penalty on the probe
                    weights. Defaults to 0.0.
                logging_interval (int): logging interval. Defaults to 1.
                active_feature_threshold (float): threshold passed to the active feature
                    fraction metric. Defaults to 1e-3.
            performance:
                disable_channel_last (bool). Disables channel last conversion operation which
                speeds up training considerably. Defaults to False.
                https://pytorch.org/tutorials/intermediate/memory_format_tutorial.html#converting-existing-models
            accumulate_grad_batches (Union[int, None]): number of batches for gradient accumulation.

        .. note::
            When using distributed data parallel, the batch size and the number of workers are
            specified on a per process basis. Therefore, the total batch size (number of workers)
            is calculated as the product of the number of GPUs with the batch size (number of
            workers).

        .. note::
            The learning rate (base, classifier, min and warmup) is automatically scaled
            linearly if using gradient accumulation.

        .. note::
            For CIFAR10/100, the first convolutional and maxpooling layers of the ResNet backbone
            are slightly adjusted to handle lower resolution images (32x32 instead of 224x224).

        """

        super().__init__()

        # add default values and assert that config has the basic needed settings
        cfg = self.add_and_assert_specific_cfg(cfg)

        self.cfg: omegaconf.DictConfig = cfg

        ##############################
        # Backbone
        self.backbone_args: Dict[str, Any] = cfg.backbone.kwargs
        assert cfg.backbone.name in BaseMethod._BACKBONES
        self.base_model: Callable = self._BACKBONES[cfg.backbone.name]
        self.backbone_name: str = cfg.backbone.name
        # initialize backbone
        kwargs = self.backbone_args.copy()

        # the backbone factories take the method name as their first positional argument
        method: str = cfg.method
        self.backbone: nn.Module = self.base_model(method, **kwargs)
        if self.backbone_name.startswith("resnet"):
            self.features_dim: int = self.backbone.inplanes
            # remove fc layer
            self.backbone.fc = nn.Identity()
            cifar = cfg.data.dataset in ["cifar10", "cifar100"]
            if cifar:
                # low-resolution inputs: 3x3 stride-1 stem and no max pooling
                self.backbone.conv1 = nn.Conv2d(
                    3, 64, kernel_size=3, stride=1, padding=2, bias=False
                )
                self.backbone.maxpool = nn.Identity()
        else:
            self.features_dim: int = self.backbone.num_features
        ##############################

        # online linear classifier
        self.num_classes: int = cfg.data.num_classes
        self.classifier: nn.Module = nn.Linear(self.features_dim, self.num_classes)

        # Expressive MLP probes (encoder and projector) - November 23, 2025.
        # Copied from solo-learn gaussianization.
        self.mlp_probe_enabled: bool = omegaconf_select(cfg, "mlp_probe.enabled", False) 
        self.mlp_probe_layers: int = int(omegaconf_select(cfg, "mlp_probe.num_layers", 3)) 
        self.mlp_probe_encoder = None
        self.mlp_probe_projector = None
        if self.mlp_probe_enabled:
            self.mlp_probe_encoder = self._build_mlp_classifier(
                self.features_dim, self.num_classes, self.mlp_probe_layers
            )

            # If the projector output dim is known in cfg, prebuild the projector MLP probe;
            # otherwise it stays None here and is lazily created in `_base_shared_step`.
            proj_out_dim_cfg = omegaconf_select(cfg, "method_kwargs.proj_output_dim", None)
            if cfg.method == "dino": # for dino, num_prototypes takes precedence when it is set
                proj_out_dim_cfg = omegaconf_select(cfg, "method_kwargs.num_prototypes", proj_out_dim_cfg)

            if proj_out_dim_cfg is not None:
                self.mlp_probe_projector = self._build_mlp_classifier(int(proj_out_dim_cfg), self.num_classes, self.mlp_probe_layers)

        # Optional linear classifier on top of the projector output
        self.add_projector_classifier = omegaconf_select(cfg, "method_kwargs.add_projector_classifier", False)
        self.projector_classifier = None
        if self.add_projector_classifier:
            proj_output_dim = omegaconf_select(cfg, "method_kwargs.proj_output_dim", None)
            assert (
                proj_output_dim is not None
            ), "method_kwargs.proj_output_dim must be set when add_projector_classifier is True."
    
            self.projector_classifier = nn.Linear(proj_output_dim, self.num_classes)

        # weight of the L1 penalty applied to the encoder and projector classifier weights
        self.probe_l1_penalty_weight: float = omegaconf_select(cfg, "method_kwargs.probe_l1_penalty_weight", 0.0)
        
        # training related
        self.max_epochs: int = cfg.max_epochs
        self.accumulate_grad_batches: Union[int, None] = cfg.accumulate_grad_batches

        # optimizer related
        self.optimizer: str = cfg.optimizer.name
        self.batch_size: int = cfg.optimizer.batch_size
        self.lr: float = cfg.optimizer.lr
        self.weight_decay: float = cfg.optimizer.weight_decay
        self.classifier_lr: float = cfg.optimizer.classifier_lr
        self.extra_optimizer_args: Dict[str, Any] = cfg.optimizer.kwargs
        self.exclude_bias_n_norm_wd: bool = cfg.optimizer.exclude_bias_n_norm_wd

        # scheduler related
        self.scheduler: str = cfg.scheduler.name
        self.lr_decay_steps: Union[List[int], None] = cfg.scheduler.lr_decay_steps
        self.min_lr: float = cfg.scheduler.min_lr
        self.warmup_start_lr: float = cfg.scheduler.warmup_start_lr
        self.warmup_epochs: int = cfg.scheduler.warmup_epochs
        self.scheduler_interval: str = cfg.scheduler.interval
        assert self.scheduler_interval in ["step", "epoch"]
        if self.scheduler_interval == "step":
            logging.warning(
                f"Using scheduler_interval={self.scheduler_interval} might generate "
                "issues when resuming a checkpoint."
            )

        # if accumulating gradient then scale lr
        # (the default of 1 set by add_and_assert_specific_cfg leaves the values unchanged)
        if self.accumulate_grad_batches:
            self.lr = self.lr * self.accumulate_grad_batches
            self.classifier_lr = self.classifier_lr * self.accumulate_grad_batches
            self.min_lr = self.min_lr * self.accumulate_grad_batches
            self.warmup_start_lr = self.warmup_start_lr * self.accumulate_grad_batches

        # data-related
        self.num_large_crops: int = cfg.data.num_large_crops
        self.num_small_crops: int = cfg.data.num_small_crops
        self.num_crops: int = self.num_large_crops + self.num_small_crops
        # turn on multicrop if there are small crops
        self.multicrop: bool = self.num_small_crops != 0

        # knn online evaluation
        self.knn_eval: bool = cfg.knn_eval.enabled
        self.knn_k: int = cfg.knn_eval.k
        if self.knn_eval:
            self.knn = WeightedKNNClassifier(k=self.knn_k, distance_fx=cfg.knn_eval.distance_func)

        # for performance
        self.no_channel_last = cfg.performance.disable_channel_last

        # keep track of validation metrics
        self.validation_step_outputs = []

        # logging interval
        self.logging_interval = omegaconf_select(cfg, "method_kwargs.logging_interval", 1) # defaults to 1 when not set in cfg
        # threshold handed to `active_feature_fraction` when logging sparsity metrics
        self.active_feature_threshold = omegaconf_select(cfg, "method_kwargs.active_feature_threshold", 1e-3)

    @staticmethod
    def add_and_assert_specific_cfg(cfg: omegaconf.DictConfig) -> omegaconf.DictConfig:
        """Fills in the default values shared by every method.

        Each entry of the table below is looked up with `omegaconf_select` and the
        result (existing value or fallback) is written back to the same location,
        in the listed order. Parent nodes such as `knn_eval` are listed before their
        children so that the children can be assigned afterwards.

        Args:
            cfg (omegaconf.DictConfig): DictConfig object.

        Returns:
            omegaconf.DictConfig: same as the argument, used to avoid errors.
        """

        shared_defaults = (
            # extra backbone kwargs (use pytorch's default if not available)
            ("backbone.kwargs", {}),
            # optimizer
            ("optimizer.exclude_bias_n_norm_wd", False),
            ("optimizer.kwargs", {}),
            # gradient accumulation
            ("accumulate_grad_batches", 1),
            # scheduler
            ("scheduler.lr_decay_steps", None),
            ("scheduler.min_lr", 0.0),
            ("scheduler.warmup_start_lr", 3e-5),
            ("scheduler.warmup_epochs", 10),
            ("scheduler.interval", "step"),
            # knn evaluation
            ("knn_eval", {}),
            ("knn_eval.enabled", False),
            ("knn_eval.k", 20),
            ("knn_eval.distance_func", "euclidean"),
            # performance optimization
            ("performance", {}),
            ("performance.disable_channel_last", False),
            # method-specific kwargs; method_kwargs.proj_output_dim is expected to be
            # set by the method if add_projector_classifier is true
            ("method_kwargs", {}),
            ("method_kwargs.add_projector_classifier", False),
            # MLP probe
            ("mlp_probe", {}),
            ("mlp_probe.enabled", False),
            ("mlp_probe.num_layers", 3),
            # logging
            ("method_kwargs.logging_interval", 1),
            ("method_kwargs.active_feature_threshold", 1e-3),
        )

        for dotted_key, fallback in shared_defaults:
            # resolve the value first, then walk down to the node that owns the leaf
            resolved = omegaconf_select(cfg, dotted_key, fallback)
            *parent_path, leaf = dotted_key.split(".")
            owner = cfg
            for part in parent_path:
                owner = getattr(owner, part)
            setattr(owner, leaf, resolved)

        return cfg

    @property
    def learnable_params(self) -> List[Dict[str, Any]]:
        """Defines learnable parameters for the base class.

        Returns:
            List[Dict[str, Any]]:
                list of dicts containing learnable parameters and possible settings.
        """

        params = [
            {"name": "backbone", "params": self.backbone.parameters()},
            {
                "name": "classifier",
                "params": self.classifier.parameters(),
                "lr": self.classifier_lr,
                "weight_decay": 0,
            },
        ]
        if self.mlp_probe_enabled and self.mlp_probe_encoder is not None:
            params.append({
                "name": "mlp_probe_encoder",
                "params": self.mlp_probe_encoder.parameters(),
                "lr": self.classifier_lr,
                "weight_decay": 0,
            })
         # @TODO: if and when we do continued-pretraining, cpt, we should uptdate/copy this function 
        
        if self.add_projector_classifier and self.projector_classifier is not None:
            params.append(
                {
                    "name": "projector_classifier",
                    "params": self.projector_classifier.parameters(),
                    "lr": self.classifier_lr,
                    "weight_decay": 0,
                }
            )
        
        if self.mlp_probe_enabled and self.mlp_probe_projector is not None:
            params.append({
                "name": "mlp_probe_projector",
                "params": self.mlp_probe_projector.parameters(),
                "lr": self.classifier_lr,
                "weight_decay": 0,
            })

        return params

    def configure_optimizers(self) -> Tuple[List, List]:
        """Collects learnable parameters and configures the optimizer and learning rate scheduler.

        Returns:
            Tuple[List, List]: two lists containing the optimizer and the scheduler.
        """

        learnable_params = self.learnable_params

        # exclude bias and norm from weight decay
        if self.exclude_bias_n_norm_wd:
            learnable_params = remove_bias_and_norm_from_weight_decay(learnable_params)

        # indexes of parameters without lr scheduler
        idxs_no_scheduler = [i for i, m in enumerate(learnable_params) if m.pop("static_lr", False)]

        assert self.optimizer in self._OPTIMIZERS
        optimizer = self._OPTIMIZERS[self.optimizer]

        # create optimizer
        optimizer = optimizer(
            learnable_params,
            lr=self.lr,
            weight_decay=self.weight_decay,
            **self.extra_optimizer_args,
        )

        if self.scheduler.lower() == "none":
            return optimizer

        if self.scheduler == "warmup_cosine":
            max_warmup_steps = (
                self.warmup_epochs * (self.trainer.estimated_stepping_batches / self.max_epochs)
                if self.scheduler_interval == "step"
                else self.warmup_epochs
            )
            max_scheduler_steps = (
                self.trainer.estimated_stepping_batches
                if self.scheduler_interval == "step"
                else self.max_epochs
            )
            scheduler = {
                "scheduler": LinearWarmupCosineAnnealingLR(
                    optimizer,
                    warmup_epochs=max_warmup_steps,
                    max_epochs=max_scheduler_steps,
                    warmup_start_lr=self.warmup_start_lr if self.warmup_epochs > 0 else self.lr,
                    eta_min=self.min_lr,
                ),
                "interval": self.scheduler_interval,
                "frequency": 1,
            }
        elif self.scheduler == "step":
            scheduler = MultiStepLR(optimizer, self.lr_decay_steps)
        else:
            raise ValueError(f"{self.scheduler} not in (warmup_cosine, cosine, step)")

        if idxs_no_scheduler:
            partial_fn = partial(
                static_lr,
                get_lr=scheduler["scheduler"].get_lr
                if isinstance(scheduler, dict)
                else scheduler.get_lr,
                param_group_indexes=idxs_no_scheduler,
                lrs_to_replace=[self.lr] * len(idxs_no_scheduler),
            )
            if isinstance(scheduler, dict):
                scheduler["scheduler"].get_lr = partial_fn
            else:
                scheduler.get_lr = partial_fn

        return [optimizer], [scheduler]

    def optimizer_zero_grad(self, epoch, batch_idx, optimizer, *_):
        """
        This improves performance marginally. It should be fine
        since we are not affected by any of the downsides descrited in
        https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html#torch.optim.Optimizer.zero_grad

        Implemented as in here
        https://lightning.ai/docs/pytorch/latest/advanced/speed.html?highlight=set%20grads%20none
        """
        try:
            optimizer.zero_grad(set_to_none=True)
        except:
            optimizer.zero_grad()

    def forward(self, X) -> Dict:
        """Basic forward method. Children methods should call this function,
        modify the ouputs (without deleting anything) and return it.

        Args:
            X (torch.Tensor): batch of images in tensor format.

        Returns:
            Dict: dict of logits and features.
        """

        if not self.no_channel_last:
            X = X.to(memory_format=torch.channels_last)
        feats = self.backbone(X)
        logits = self.classifier(feats.detach())
        return {"logits": logits, "feats": feats}

    def multicrop_forward(self, X: torch.Tensor) -> Dict[str, Any]:
        """Basic multicrop forward method that performs the forward pass
        for the multicrop views. Children classes can override this method to
        add new outputs but should still call this function. Make sure
        that this method and its overrides always return a dict.

        Args:
            X (torch.Tensor): batch of images in tensor format.

        Returns:
            Dict: dict of features.
        """

        if not self.no_channel_last:
            X = X.to(memory_format=torch.channels_last)
        feats = self.backbone(X)
        return {"feats": feats}

    def _projector_classifier_step(self, z: torch.Tensor, targets: torch.Tensor) -> Dict:
        """Helper to compute projector classifier loss and acc.
        Assumes z is the output of a projector for a single view
        IMPORTANT: if x and x augmented than send it twice and average in the SSL method class
        """
        if self.projector_classifier is None:
            return {}

        # Detach z to ensure gradients only flow to the projector_classifier, not further back.
        projector_logits = self.projector_classifier(z.detach())

        # cross entropy loss
        loss = F.cross_entropy(projector_logits, targets, ignore_index=-1)

        # adding l1 penalty to the loss
        l1_penalty = self.projector_classifier.weight.abs().sum()
        loss += self.probe_l1_penalty_weight * l1_penalty

        top_k_max = min(5, projector_logits.size(1))
        if self.cfg.data.dataset != "celeba":
            acc1, acc5 = accuracy_at_k(projector_logits, targets, top_k=(1, top_k_max))
        else:
            acc1 = torch.tensor(0.0, device=self.device)
            acc5 = torch.tensor(0.0, device=self.device)

        return {"proj_loss": loss, "proj_acc1": acc1, "proj_acc5": acc5}

    def _build_mlp_classifier(self, in_dim: int, num_classes: int, num_layers: int) -> nn.Sequential:
        """Builds an MLP classifier with num_layers linear layers (>=1).
        Hidden size: min(int(1.5 * in_dim), 8192). Uses ReLU and BatchNorm between hidden layers.
        """
        num_layers = max(1, int(num_layers))
        hidden_dim = int(min(max(1, round(1.5 * in_dim)), 8192))  # hidden dim capped at 8192 and 1.5 * input; revisit if needed

        layers: List[nn.Module] = []
        if num_layers == 1:
            layers.append(nn.Linear(in_dim, num_classes))
        else:
            # first hidden
            layers.extend([nn.Linear(in_dim, hidden_dim), nn.BatchNorm1d(hidden_dim), nn.ReLU()])
            # middle hidden layers (num_layers - 2)
            for _ in range(num_layers - 2):
                layers.extend([nn.Linear(hidden_dim, hidden_dim), nn.BatchNorm1d(hidden_dim), nn.ReLU()])
            # final layer
            layers.append(nn.Linear(hidden_dim, num_classes))
        return nn.Sequential(*layers)

    def _base_shared_step(self, X: torch.Tensor, targets: torch.Tensor) -> Dict:
        """Forwards a batch of images X and computes the classification loss, the logits, the
        features, acc@1 and acc@5.

        Args:
            X (torch.Tensor): batch of images in tensor format.
            targets (torch.Tensor): batch of labels for X.

        Returns:
            Dict: dict containing the classification loss, logits, features, acc@1 and acc@5.
        """

        out = self(X)
        logits = out["logits"]
        feats = out["feats"]
        
        # encoder classifier loss
        encoder_classifier_loss = F.cross_entropy(logits, targets, ignore_index=-1)

        # l1 penalty for encoder classifier
        l1_penalty = self.classifier.weight.abs().sum()
        encoder_classifier_loss += self.probe_l1_penalty_weight * l1_penalty
        
        # handle when the number of classes is smaller than 5
        top_k_max = min(5, logits.size(1))
        if self.cfg.data.dataset != "celeba":
            encoder_acc1, encoder_acc5 = accuracy_at_k(logits, targets, top_k=(1, top_k_max))
        else:
            encoder_acc1 = torch.tensor(0.0, device=self.device)
            encoder_acc5 = torch.tensor(0.0, device=self.device)

        out.update({"loss": encoder_classifier_loss, "acc1": encoder_acc1, "acc5": encoder_acc5})

        # MLP probe on encoder features (optional)
        if self.mlp_probe_enabled and self.mlp_probe_encoder is not None:
            mlp_logits = self.mlp_probe_encoder(feats.detach())
            mlp_loss = F.cross_entropy(mlp_logits, targets, ignore_index=-1)
            top_k_max_mlp = min(5, mlp_logits.size(1))
            if self.cfg.data.dataset != "celeba":
                mlp_acc1, mlp_acc5 = accuracy_at_k(mlp_logits, targets, top_k=(1, top_k_max_mlp))
            else:
                mlp_acc1 = torch.tensor(0.0, device=self.device)
                mlp_acc5 = torch.tensor(0.0, device=self.device)
            out.update({"mlp_loss": mlp_loss, "mlp_acc1": mlp_acc1, "mlp_acc5": mlp_acc5})

            # train the probe jointly with the online linear classifier
            out["loss"] = out["loss"] + out["mlp_loss"]

        # Metrics for the projector classifier (if provided by overridden forward and enabled)
        if "projector_logits" in out and self.projector_classifier is not None:

            # get projector logits
            projector_logits = out["projector_logits"]

            proj_loss = F.cross_entropy(projector_logits, targets, ignore_index=-1)
            proj_top_k = min(5, projector_logits.size(1))
            if self.cfg.data.dataset != "celeba":
                proj_acc1, proj_acc5 = accuracy_at_k(projector_logits, targets, top_k=(1, proj_top_k))
            else:
                proj_acc1 = torch.tensor(0.0, device=self.device)
                proj_acc5 = torch.tensor(0.0, device=self.device)
            out.update({
                "proj_loss": proj_loss, 
                "proj_acc1": proj_acc1, 
                "proj_acc5": proj_acc5
            })
        # now let's try MLP projector    
        # Optional MLP probe on projector features 'z' (requires child method to supply 'z')
        if self.mlp_probe_enabled and "z" in out and isinstance(out["z"], torch.Tensor):
            z_tensor = out["z"].detach()
            # Lazy initialize projector MLP probe on first use
            if self.mlp_probe_projector is None:
                self.mlp_probe_projector = self._build_mlp_classifier(z_tensor.size(1), self.num_classes, self.mlp_probe_layers)
            
            proj_mlp_logits = self.mlp_probe_projector(z_tensor)
            proj_mlp_loss = F.cross_entropy(proj_mlp_logits, targets, ignore_index=-1)
            top_k_max_p = min(5, proj_mlp_logits.size(1))
            if self.cfg.data.dataset != "celeba":
                proj_mlp_acc1, proj_mlp_acc5 = accuracy_at_k(proj_mlp_logits, targets, top_k=(1, top_k_max_p))
            else:
                proj_mlp_acc1 = torch.tensor(0.0, device=self.device)
                proj_mlp_acc5 = torch.tensor(0.0, device=self.device)
            out.update({
                "proj_mlp_loss": proj_mlp_loss,
                "proj_mlp_acc1": proj_mlp_acc1,
                "proj_mlp_acc5": proj_mlp_acc5,
            })
            out["loss"] = out["loss"] + proj_mlp_loss

        return out

    def _log_sparsity_metrics(self, z1, z2, feats1, feats2):
        with torch.no_grad():
            # logging per-dimension variance of encoder & projector outputs
            feats1_avg_per_dim_var = feats1.var(dim=0).mean()
            feats2_avg_per_dim_var = feats2.var(dim=0).mean()

            z1_avg_per_dim_var = z1.var(dim=0).mean()
            z2_avg_per_dim_var = z2.var(dim=0).mean()

            self.log("encoder_metrics/train_feats1_avg_per_dim_var", feats1_avg_per_dim_var, on_epoch=True, sync_dist=True)
            self.log("encoder_metrics/train_feats2_avg_per_dim_var", feats2_avg_per_dim_var, on_epoch=True, sync_dist=True)

            self.log("projector_metrics/train_z1_avg_per_dim_var", z1_avg_per_dim_var, on_epoch=True, sync_dist=True)
            self.log("projector_metrics/train_z2_avg_per_dim_var", z2_avg_per_dim_var, on_epoch=True, sync_dist=True)

            # Encoder | Logging batch sparsity metrics
            feat1_batch_sparse_max, feat1_batch_sparse_mean, feat1_batch_sparse_min = batch_sparsity_metric(feats1)
            feat2_batch_sparse_max, feat2_batch_sparse_mean, feat2_batch_sparse_min = batch_sparsity_metric(feats2)
            self.log("encoder_metrics/feat1_encoder_batch_sparse_max", feat1_batch_sparse_max, on_epoch=True, sync_dist=True)
            self.log("encoder_metrics/feat1_encoder_batch_sparse_mean", feat1_batch_sparse_mean, on_epoch=True, sync_dist=True)
            self.log("encoder_metrics/feat1_encoder_batch_sparse_min", feat1_batch_sparse_min, on_epoch=True, sync_dist=True)
            self.log("encoder_metrics/feat2_encoder_batch_sparse_max", feat2_batch_sparse_max, on_epoch=True, sync_dist=True)
            self.log("encoder_metrics/feat2_encoder_batch_sparse_mean", feat2_batch_sparse_mean, on_epoch=True, sync_dist=True)
            self.log("encoder_metrics/feat2_encoder_batch_sparse_min", feat2_batch_sparse_min, on_epoch=True, sync_dist=True)

            # Encoder | Logging embedding sparsity metrics
            feat1_embed_sparse_max, feat1_embed_sparse_mean, feat1_embed_sparse_min = embedding_sparsity_metric(feats1)
            feat2_embed_sparse_max, feat2_embed_sparse_mean, feat2_embed_sparse_min = embedding_sparsity_metric(feats2)
            self.log("encoder_metrics/feat1_encoder_embed_sparse_max", feat1_embed_sparse_max, on_epoch=True, sync_dist=True)
            self.log("encoder_metrics/feat1_encoder_embed_sparse_mean", feat1_embed_sparse_mean, on_epoch=True, sync_dist=True)
            self.log("encoder_metrics/feat1_encoder_embed_sparse_min", feat1_embed_sparse_min, on_epoch=True, sync_dist=True)
            self.log("encoder_metrics/feat2_encoder_embed_sparse_max", feat2_embed_sparse_max, on_epoch=True, sync_dist=True)
            self.log("encoder_metrics/feat2_encoder_embed_sparse_mean", feat2_embed_sparse_mean, on_epoch=True, sync_dist=True)
            self.log("encoder_metrics/feat2_encoder_embed_sparse_min", feat2_embed_sparse_min, on_epoch=True, sync_dist=True)
            
            # Encoder | Logging active feature fraction
            feat1_active_feature_fraction = active_feature_fraction(feats1, self.active_feature_threshold)
            feat2_active_feature_fraction = active_feature_fraction(feats2, self.active_feature_threshold)
            self.log("encoder_metrics/feat1_encoder_active_feature_fraction", feat1_active_feature_fraction, on_epoch=True, sync_dist=True)
            self.log("encoder_metrics/feat2_encoder_active_feature_fraction", feat2_active_feature_fraction, on_epoch=True, sync_dist=True)

            # Encoder | Logging avg nonzero elements per dimension
            feat1_avg_nonzero_elements_per_dimension = count_avg_nonzero_elements_per_dimension(feats1)
            feat2_avg_nonzero_elements_per_dimension = count_avg_nonzero_elements_per_dimension(feats2)
            self.log("encoder_metrics/feat1_encoder_avg_nonzero_elements_per_dimension", feat1_avg_nonzero_elements_per_dimension, on_epoch=True, sync_dist=True)
            self.log("encoder_metrics/feat2_encoder_avg_nonzero_elements_per_dimension", feat2_avg_nonzero_elements_per_dimension, on_epoch=True, sync_dist=True)

            # Encoder | Logging avg nonzero elements per sample
            feat1_avg_nonzero_elements_per_sample = count_avg_nonzero_elements_per_sample(feats1)
            feat2_avg_nonzero_elements_per_sample = count_avg_nonzero_elements_per_sample(feats2)
            self.log("encoder_metrics/feat1_encoder_avg_nonzero_elements_per_sample", feat1_avg_nonzero_elements_per_sample, on_epoch=True, sync_dist=True)
            self.log("encoder_metrics/feat2_encoder_avg_nonzero_elements_per_sample", feat2_avg_nonzero_elements_per_sample, on_epoch=True, sync_dist=True)

            # Projector | Logging batch sparsity metrics
            z1_batch_sparse_max, z1_batch_sparse_mean, z1_batch_sparse_min = batch_sparsity_metric(z1)
            z2_batch_sparse_max, z2_batch_sparse_mean, z2_batch_sparse_min = batch_sparsity_metric(z2)
            self.log("projector_metrics/z1_projector_batch_sparse_max", z1_batch_sparse_max, on_epoch=True, sync_dist=True)
            self.log("projector_metrics/z1_projector_batch_sparse_mean", z1_batch_sparse_mean, on_epoch=True, sync_dist=True)
            self.log("projector_metrics/z1_projector_batch_sparse_min", z1_batch_sparse_min, on_epoch=True, sync_dist=True)
            self.log("projector_metrics/z2_projector_batch_sparse_max", z2_batch_sparse_max, on_epoch=True, sync_dist=True)
            self.log("projector_metrics/z2_projector_batch_sparse_mean", z2_batch_sparse_mean, on_epoch=True, sync_dist=True)
            self.log("projector_metrics/z2_projector_batch_sparse_min", z2_batch_sparse_min, on_epoch=True, sync_dist=True)

            # Projector | Logging embedding sparsity metrics
            z1_embed_sparse_max, z1_embed_sparse_mean, z1_embed_sparse_min = embedding_sparsity_metric(z1)
            z2_embed_sparse_max, z2_embed_sparse_mean, z2_embed_sparse_min = embedding_sparsity_metric(z2)
            self.log("projector_metrics/z1_projector_embed_sparse_max", z1_embed_sparse_max, on_epoch=True, sync_dist=True)
            self.log("projector_metrics/z1_projector_embed_sparse_mean", z1_embed_sparse_mean, on_epoch=True, sync_dist=True)
            self.log("projector_metrics/z1_projector_embed_sparse_min", z1_embed_sparse_min, on_epoch=True, sync_dist=True)
            self.log("projector_metrics/z2_projector_embed_sparse_max", z2_embed_sparse_max, on_epoch=True, sync_dist=True)
            self.log("projector_metrics/z2_projector_embed_sparse_mean", z2_embed_sparse_mean, on_epoch=True, sync_dist=True)
            self.log("projector_metrics/z2_projector_embed_sparse_min", z2_embed_sparse_min, on_epoch=True, sync_dist=True)
            
            # Projector | Logging active feature fraction
            z1_active_feature_fraction = active_feature_fraction(z1, self.active_feature_threshold)
            z2_active_feature_fraction = active_feature_fraction(z2, self.active_feature_threshold)
            self.log("projector_metrics/z1_projector_active_feature_fraction", z1_active_feature_fraction, on_epoch=True, sync_dist=True)
            self.log("projector_metrics/z2_projector_active_feature_fraction", z2_active_feature_fraction, on_epoch=True, sync_dist=True)

            # Projector | Logging avg nonzero elements per dimension
            z1_avg_nonzero_elements_per_dimension = count_avg_nonzero_elements_per_dimension(z1)
            z2_avg_nonzero_elements_per_dimension = count_avg_nonzero_elements_per_dimension(z2)
            self.log("projector_metrics/z1_projector_avg_nonzero_elements_per_dimension", z1_avg_nonzero_elements_per_dimension, on_epoch=True, sync_dist=True)
            self.log("projector_metrics/z2_projector_avg_nonzero_elements_per_dimension", z2_avg_nonzero_elements_per_dimension, on_epoch=True, sync_dist=True)

            # Projector | Logging avg nonzero elements per sample
            z1_avg_nonzero_elements_per_sample = count_avg_nonzero_elements_per_sample(z1)
            z2_avg_nonzero_elements_per_sample = count_avg_nonzero_elements_per_sample(z2)
            self.log("projector_metrics/z1_projector_avg_nonzero_elements_per_sample", z1_avg_nonzero_elements_per_sample, on_epoch=True, sync_dist=True)
            self.log("projector_metrics/z2_projector_avg_nonzero_elements_per_sample", z2_avg_nonzero_elements_per_sample, on_epoch=True, sync_dist=True)

    def base_training_step(self, X: torch.Tensor, targets: torch.Tensor) -> Dict:
        """Overridable hook defining the forward pass used by training_step.

        The default implementation simply delegates to _base_shared_step. Overrides
        should keep returning a dict that holds, at least, "loss", "acc1" and "acc5",
        since training_step averages those entries.

        Args:
            X (torch.Tensor): batch of images in tensor format.
            targets (torch.Tensor): batch of labels for X.

        Returns:
            Dict: whatever _base_shared_step returns for (X, targets).
        """

        step_out = self._base_shared_step(X, targets)
        return step_out

    def training_step(self, batch: List[Any], batch_idx: int) -> Dict[str, Any]:
        """Training step for pytorch lightning. Forwards every crop, averages the
        classification statistics over the large crops, logs them and optionally feeds
        the knn evaluator.

        Args:
            batch (List[Any]): a batch of data in the format of [img_indexes, [X], Y], where
                [X] is a list of size self.num_crops containing batches of images (a single
                tensor is also accepted and treated as a one-element list).
            batch_idx (int): index of the batch.

        Returns:
            Dict[str, Any]: per-crop outputs grouped by key (each value is a list with one
                entry per crop), except for "loss", "acc1" and "acc5" which are replaced
                by their average over the large crops.
        """

        _, crops, targets = batch

        if isinstance(crops, torch.Tensor):
            crops = [crops]

        # check that we received the desired number of crops
        assert len(crops) == self.num_crops

        n_large = self.num_large_crops

        # forward the large crops, then regroup the list of dicts into a dict of lists
        large_outs = [self.base_training_step(crop, targets) for crop in crops[:n_large]]
        outs = {}
        for key in large_outs[0].keys():
            outs[key] = [crop_out[key] for crop_out in large_outs]

        if self.multicrop:
            # small crops go through the multicrop forward and are appended after the
            # large-crop entries of the same key
            small_outs = [self.multicrop_forward(crop) for crop in crops[n_large:]]
            for key in small_outs[0].keys():
                collected = outs.get(key, [])
                outs[key] = collected + [crop_out[key] for crop_out in small_outs]

        # loss and stats
        for stat in ("loss", "acc1", "acc5"):
            outs[stat] = sum(outs[stat]) / n_large

        # only log every `logging_interval` global steps
        if self.global_step % self.logging_interval == 0:
            self.log_dict(
                {
                    "train_class_loss": outs["loss"],
                    "train_acc1": outs["acc1"],
                    "train_acc5": outs["acc5"],
                },
                on_epoch=True,
                sync_dist=True,
            )

        if self.knn_eval:
            # labels are repeated once per large crop to line up with the concatenated
            # features; entries labelled -1 are filtered out
            repeated_targets = targets.repeat(n_large)
            labelled = repeated_targets != -1
            large_feats = torch.cat(outs["feats"][:n_large])
            self.knn(
                train_features=large_feats[labelled].detach(),
                train_targets=repeated_targets[labelled],
            )

        return outs

    def base_validation_step(self, X: torch.Tensor, targets: torch.Tensor) -> Dict:
        """Overridable hook defining the forward pass used by validation_step.

        The default implementation simply delegates to _base_shared_step. Overrides
        should keep returning a dict that holds, at least, "loss", "acc1" and "acc5",
        since validation_step reads those entries.

        Args:
            X (torch.Tensor): batch of images in tensor format.
            targets (torch.Tensor): batch of labels for X.

        Returns:
            Dict: whatever _base_shared_step returns for (X, targets).
        """

        step_out = self._base_shared_step(X, targets)
        return step_out

    def validation_step(
        self,
        batch: List[torch.Tensor],
        batch_idx: int,
        dataloader_idx: int = None,
        update_validation_step_outputs: bool = True,
    ) -> Dict[str, Any]:
        """Validation step for pytorch lightning. Forwards a batch of images through
        base_validation_step and gathers the resulting metrics.

        Args:
            batch (List[torch.Tensor]): a batch of data, unpacked here as [X, Y].
            batch_idx (int): index of the batch.
            dataloader_idx (int): index of the dataloader; not used by this method.
            update_validation_step_outputs (bool): whether or not to append the
                metrics to validation_step_outputs

        Returns:
            Dict[str, Any]: dict with the batch_size (used for averaging), the classification
                loss and accuracies, plus the projector / MLP probe metrics when the
                forward step produced them.
        """

        X, targets = batch
        batch_size = targets.size(0)

        out = self.base_validation_step(X, targets)

        # on celeba, additionally log sparsity metrics; only one view exists here, so the
        # same projector output / feature tensor is passed for both slots
        if self.cfg.data.dataset == "celeba" and "z" in out and "feats" in out:
            proj_out = out["z"]
            enc_out = out["feats"]
            if isinstance(proj_out, (list, tuple)):
                proj_out = proj_out[0]
            if isinstance(enc_out, (list, tuple)):
                enc_out = enc_out[0]
            self._log_sparsity_metrics(proj_out, proj_out, enc_out, enc_out)

        if self.knn_eval and not self.trainer.sanity_checking:
            # "feats" is removed from the outputs once handed to the knn evaluator
            val_feats = out.pop("feats")
            self.knn(test_features=val_feats.detach(), test_targets=targets.detach())

        metrics = {
            "batch_size": batch_size,
            "val_loss": out["loss"],
            "val_acc1": out["acc1"],
            "val_acc5": out["acc5"],
        }

        # optional metric groups: (keys that must be present, metric name -> output key)
        optional_groups = (
            (
                ("proj_loss",),
                (
                    ("val_proj_loss", "proj_loss"),
                    ("val_proj_acc1", "proj_acc1"),
                    ("val_proj_acc5", "proj_acc5"),
                ),
            ),
            # MLP probe validation metrics, if available
            (
                ("mlp_acc1", "mlp_acc5"),
                (
                    ("val_mlp_acc1", "mlp_acc1"),
                    ("val_mlp_acc5", "mlp_acc5"),
                ),
            ),
            (
                ("proj_mlp_acc1", "proj_mlp_acc5"),
                (
                    ("val_proj_mlp_acc1", "proj_mlp_acc1"),
                    ("val_proj_mlp_acc5", "proj_mlp_acc5"),
                ),
            ),
        )
        for required, mapping in optional_groups:
            if all(key in out for key in required):
                for metric_name, out_key in mapping:
                    metrics[metric_name] = out[out_key]

        if update_validation_step_outputs:
            self.validation_step_outputs.append(metrics)
        return metrics

    def on_validation_epoch_end(self):
        """Aggregates the metrics of all validation batches into batch-size-weighted
        means and logs them. Weighting is needed because the last batch can be smaller
        than the others, slightly skewing the metrics.

        Optional metric groups (projector classifier, MLP probes) are aggregated only
        when the first collected batch contains them. The collected batch metrics are
        cleared afterwards.
        """

        step_outputs = self.validation_step_outputs

        def _avg(name):
            return weighted_mean(step_outputs, name, "batch_size")

        log = {name: _avg(name) for name in ("val_loss", "val_acc1", "val_acc5")}

        # presence of the first key of each group in the first batch decides whether
        # the whole group is aggregated
        first_batch = step_outputs[0] if step_outputs else {}
        optional_groups = (
            ("val_proj_loss", "val_proj_acc1", "val_proj_acc5"),
            # MLP probe metrics
            ("val_mlp_acc1", "val_mlp_acc5"),
            ("val_proj_mlp_acc1", "val_proj_mlp_acc5"),
        )
        for group in optional_groups:
            if group[0] in first_batch:
                log.update({name: _avg(name) for name in group})

        if self.knn_eval and not self.trainer.sanity_checking:
            knn_acc1, knn_acc5 = self.knn.compute()
            log["val_knn_acc1"] = knn_acc1
            log["val_knn_acc5"] = knn_acc5

        self.log_dict(log, sync_dist=True)

        step_outputs.clear()


class BaseMomentumMethod(BaseMethod):
    def __init__(
        self,
        cfg: omegaconf.DictConfig,
    ):
        """Base momentum model that implements all basic operations for all self-supervised methods
        that use a momentum backbone. It adds shared momentum arguments, adds basic learnable
        parameters, implements basic training and validation steps for the momentum backbone and
        classifier. Also implements momentum update using exponential moving average and cosine
        annealing of the weighting decrease coefficient.

        Extra cfg settings:
            momentum:
                base_tau (float): base value of the weighting decrease coefficient in [0,1].
                final_tau (float): final value of the weighting decrease coefficient in [0,1].
                classifier (bool): whether or not to train a classifier on top of the
                    momentum backbone.
        """

        super().__init__(cfg)

        # initialize momentum backbone with the same constructor arguments as the online one
        kwargs = self.backbone_args.copy()

        method: str = cfg.method
        self.momentum_backbone: nn.Module = self.base_model(method, **kwargs)
        if self.backbone_name.startswith("resnet"):
            # remove fc layer
            self.momentum_backbone.fc = nn.Identity()
            cifar = cfg.data.dataset in ["cifar10", "cifar100"]
            if cifar:
                # cifar variant: 3x3 stride-1 stem convolution and no max-pooling
                self.momentum_backbone.conv1 = nn.Conv2d(
                    3, 64, kernel_size=3, stride=1, padding=2, bias=False
                )
                self.momentum_backbone.maxpool = nn.Identity()

        # NOTE(review): presumably syncs the momentum weights with the online backbone
        # and disables their gradients -- confirm against initialize_momentum_params.
        initialize_momentum_params(self.backbone, self.momentum_backbone)

        # momentum classifier
        if cfg.momentum.classifier:
            self.momentum_classifier: Any = nn.Linear(self.features_dim, self.num_classes)
        else:
            self.momentum_classifier = None

        # momentum updater
        self.momentum_updater = MomentumUpdater(cfg.momentum.base_tau, cfg.momentum.final_tau)

    @property
    def learnable_params(self) -> List[Dict[str, Any]]:
        """Adds momentum classifier parameters to the parameters of the base class.

        Returns:
            List[Dict[str, Any]]:
                list of dicts containing learnable parameters and possible settings.
        """

        momentum_learnable_parameters = []
        if self.momentum_classifier is not None:
            momentum_learnable_parameters.append(
                {
                    "name": "momentum_classifier",
                    "params": self.momentum_classifier.parameters(),
                    "lr": self.classifier_lr,
                    "weight_decay": 0,
                }
            )
        return super().learnable_params + momentum_learnable_parameters

    @property
    def momentum_pairs(self) -> List[Tuple[Any, Any]]:
        """Defines base momentum pairs that will be updated using exponential moving average.

        Returns:
            List[Tuple[Any, Any]]: list of momentum pairs (two element tuples), each in
                the order (online module, momentum module).
        """

        return [(self.backbone, self.momentum_backbone)]

    @staticmethod
    def add_and_assert_specific_cfg(cfg: omegaconf.DictConfig) -> omegaconf.DictConfig:
        """Adds method specific default values/checks for config.

        Args:
            cfg (omegaconf.DictConfig): DictConfig object.

        Returns:
            omegaconf.DictConfig: same as the argument, used to avoid errors.
        """

        # two-argument super() because this is a staticmethod (no implicit cls/self)
        cfg = super(BaseMomentumMethod, BaseMomentumMethod).add_and_assert_specific_cfg(cfg)

        cfg.momentum.base_tau = omegaconf_select(cfg, "momentum.base_tau", 0.99)
        cfg.momentum.final_tau = omegaconf_select(cfg, "momentum.final_tau", 1.0)
        cfg.momentum.classifier = omegaconf_select(cfg, "momentum.classifier", False)

        return cfg

    def on_train_start(self):
        """Resets the step counter at the beginning of training."""
        self.last_step = 0

    @torch.no_grad()
    def momentum_forward(self, X: torch.Tensor) -> Dict[str, Any]:
        """Momentum forward method. Children methods should call this function,
        modify the outputs (without deleting anything) and return it.

        Runs without gradient tracking.

        Args:
            X (torch.Tensor): batch of images in tensor format.

        Returns:
            Dict: dict holding the momentum backbone features under "feats".
        """

        if not self.no_channel_last:
            X = X.to(memory_format=torch.channels_last)
        feats = self.momentum_backbone(X)
        return {"feats": feats}

    def _shared_step_momentum(self, X: torch.Tensor, targets: torch.Tensor) -> Dict[str, Any]:
        """Forwards a batch of images X in the momentum backbone and optionally computes the
        classification loss, the logits, acc@1 and acc@5 of the momentum classifier.

        Args:
            X (torch.Tensor): batch of images in tensor format.
            targets (torch.Tensor): batch of labels for X.

        Returns:
            Dict[str, Any]:
                a dict containing the momentum features and, when a momentum classifier
                exists, the classification loss, logits, acc@1 and acc@5 of the
                momentum backbone / classifier.
        """

        out = self.momentum_forward(X)

        if self.momentum_classifier is not None:
            feats = out["feats"]
            logits = self.momentum_classifier(feats)

            # targets equal to -1 are ignored by the loss
            loss = F.cross_entropy(logits, targets, ignore_index=-1)
            if self.cfg.data.dataset != "celeba":
                acc1, acc5 = accuracy_at_k(logits, targets, top_k=(1, 5))
            else:
                # accuracies are not computed for celeba; zeros are reported instead
                acc1 = torch.tensor(0.0, device=self.device)
                acc5 = torch.tensor(0.0, device=self.device)
            out.update({"logits": logits, "loss": loss, "acc1": acc1, "acc5": acc5})

        return out

    def training_step(self, batch: List[Any], batch_idx: int) -> Dict[str, Any]:
        """Training step for pytorch lightning. It performs all the shared operations for the
        momentum backbone and classifier, such as forwarding the crops in the momentum backbone
        and classifier, and computing statistics.

        Args:
            batch (List[Any]): a batch of data in the format of [img_indexes, [X], Y], where
                [X] is a list of size self.num_crops containing batches of images.
            batch_idx (int): index of the batch.

        Returns:
            Dict[str, Any]: the outputs of the base training step extended with the
                "momentum_"-prefixed features of the momentum backbone and, when a momentum
                classifier exists, its classification loss, logits and accuracies. In that
                case "loss" also includes the momentum classifier loss.
        """

        outs = super().training_step(batch, batch_idx)

        _, X, targets = batch
        X = [X] if isinstance(X, torch.Tensor) else X

        # remove small crops
        X = X[: self.num_large_crops]

        momentum_outs = [self._shared_step_momentum(x, targets) for x in X]
        momentum_outs = {
            "momentum_" + k: [out[k] for out in momentum_outs] for k in momentum_outs[0].keys()
        }

        if self.momentum_classifier is not None:
            # momentum loss and stats, averaged over the large crops
            for stat in ("momentum_loss", "momentum_acc1", "momentum_acc5"):
                momentum_outs[stat] = sum(momentum_outs[stat]) / self.num_large_crops

            metrics = {
                "train_momentum_class_loss": momentum_outs["momentum_loss"],
                "train_momentum_acc1": momentum_outs["momentum_acc1"],
                "train_momentum_acc5": momentum_outs["momentum_acc5"],
            }
            self.log_dict(metrics, on_epoch=True, sync_dist=True)

            # adds the momentum classifier loss together with the general loss;
            # done out-of-place so the tensor object the base step already passed to
            # log_dict as "train_class_loss" is not mutated
            outs["loss"] = outs["loss"] + momentum_outs["momentum_loss"]

        outs.update(momentum_outs)
        return outs

    def on_train_batch_end(self, outputs: Dict[str, Any], batch: Sequence[Any], batch_idx: int):
        """Performs the momentum update of momentum pairs using exponential moving average at the
        end of the current training step if an optimizer step was performed.

        Args:
            outputs (Dict[str, Any]): the outputs of the training step.
            batch (Sequence[Any]): a batch of data in the format of [img_indexes, [X], Y], where
                [X] is a list of size self.num_crops containing batches of images.
            batch_idx (int): index of the batch.
        """

        # global_step only advances when the optimizer stepped, so this skips batches
        # that merely accumulated gradients
        if self.trainer.global_step > self.last_step:
            # update momentum backbone and projector
            momentum_pairs = self.momentum_pairs
            for mp in momentum_pairs:
                self.momentum_updater.update(*mp)
            # log tau momentum
            self.log("tau", self.momentum_updater.cur_tau)
            # update tau
            self.momentum_updater.update_tau(
                cur_step=self.trainer.global_step,
                max_steps=self.trainer.estimated_stepping_batches,
            )
        self.last_step = self.trainer.global_step

    def validation_step(
        self,
        batch: List[torch.Tensor],
        batch_idx: int,
        dataloader_idx: int = None,
        update_validation_step_outputs: bool = True,
    ) -> Dict[str, Any]:
        """Validation step for pytorch lightning. It performs all the shared operations for the
        momentum backbone and classifier, such as forwarding a batch of images in the momentum
        backbone and classifier and computing statistics.

        Args:
            batch (List[torch.Tensor]): a batch of data in the format of [X, Y].
            batch_idx (int): index of the batch.
            dataloader_idx (int): index of the dataloader; not used by this method.
            update_validation_step_outputs (bool): whether or not to append the
                metrics to validation_step_outputs

        Returns:
            Dict[str, Any]: a single dict containing the batch_size (used for averaging) and
                the metrics of the online model, extended with the "momentum_val_" loss and
                accuracies when a momentum classifier exists.
        """

        # the base step must not append: the metrics are appended once, below, after the
        # momentum entries have been added
        metrics = super().validation_step(batch, batch_idx, update_validation_step_outputs=False)

        X, targets = batch

        out = self._shared_step_momentum(X, targets)

        if self.momentum_classifier is not None:
            metrics.update(
                {
                    "momentum_val_loss": out["loss"],
                    "momentum_val_acc1": out["acc1"],
                    "momentum_val_acc5": out["acc5"],
                }
            )

        if update_validation_step_outputs:
            self.validation_step_outputs.append(metrics)

        return metrics

    def on_validation_epoch_end(self):
        """Averages the losses and accuracies of the momentum backbone / classifier for all the
        validation batches, then lets the base class aggregate the online metrics. Weighted
        averaging is needed because the last batch can be smaller than the others, slightly
        skewing the metrics.

        Delegating to the base class (instead of duplicating its logic) keeps the optional
        projector / MLP probe validation metrics collected by validation_step in the log.
        """

        # momentum method metrics; these must be aggregated before delegating because the
        # base implementation clears validation_step_outputs when it finishes
        if self.momentum_classifier is not None:
            log = {
                name: weighted_mean(self.validation_step_outputs, name, "batch_size")
                for name in ("momentum_val_loss", "momentum_val_acc1", "momentum_val_acc5")
            }
            self.log_dict(log, sync_dist=True)

        # base method metrics (including knn when enabled); also clears the outputs
        super().on_validation_epoch_end()
