# Copyright 2023 solo-learn development team.

# Permission is hereby granted, free of charge, to any person obtaining a copy of
# this software and associated documentation files (the "Software"), to deal in
# the Software without restriction, including without limitation the rights to use,
# copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the
# Software, and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all copies
# or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
# INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR
# PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE
# FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
# OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.

from typing import Any, Dict, List, Sequence, Optional
from collections import defaultdict
import random

import omegaconf
import torch
import torch.nn as nn
import torch.nn.functional as F
from solo.losses.nat import hungarian_matching, nat_alignment_loss, nat_auxiliary_losses
from solo.methods.base import BaseMethod
from solo.utils.misc import omegaconf_select, gather
from solo.losses.radialbyol import variance_loss, covariance_loss, chi2_radial_nll_loss, uniform_loss, anisotropy_loss
from solo.losses.radialvicreg import batch_sparsity_metric, embedding_sparsity_metric
from torch.utils.data import DataLoader
import numpy as np


class NAT(BaseMethod):
    def __init__(self, cfg: omegaconf.DictConfig):
        """Implements NAT (Noise As Targets) from the paper
        "Unsupervised Learning by Predicting Noise" (https://arxiv.org/abs/1704.05310).

        Projected features are aligned to fixed random targets, one per training
        image, and Hungarian matching is used to re-assign targets inside a batch
        so that the assignment is permutation-invariant.

        Extra cfg settings:
            method_kwargs:
                proj_output_dim (int): number of dimensions of the projected features
                    (the random targets are sampled with this dimension as well).
                proj_hidden_dim (int): number of neurons in the hidden layers of the projector.
                target_dim (int): not read by this constructor; targets use proj_output_dim.
                update_freq (int): frequency of assignment updates (in epochs).
                shuffle_freq (int): frequency of target shuffling (in epochs).
                variance_lambda (float): weight for the variance loss; with the 'self_tune'
                    strategy it is instead the desired ratio between the variance term and
                    the alignment loss.
                optimize_variance (bool): whether to optimize for variance loss.
                covariance_lambda (float): weight for the covariance loss (same 'self_tune'
                    interpretation as variance_lambda).
                optimize_covariance (bool): whether to optimize for covariance loss.
                radial_lambda (float): weight for the radial loss (same 'self_tune'
                    interpretation as variance_lambda).
                optimize_radial (bool): whether to optimize for radial loss.
                lambda_strategy (str): 'standard' or 'self_tune'.
                add_projector_classifier (bool): whether to add projector classifier;
                    requires proj_output_dim to be set.
        """
        wants_proj_classifier = omegaconf_select(
            cfg, "method_kwargs.add_projector_classifier", False
        )
        if wants_proj_classifier:
            # Checked before the parent constructor runs, since the parent receives this cfg.
            assert not omegaconf.OmegaConf.is_missing(
                cfg, "method_kwargs.proj_output_dim"
            ), "NAT: method_kwargs.proj_output_dim must be set if add_projector_classifier is True."

        super().__init__(cfg)

        kwargs = cfg.method_kwargs
        self.proj_output_dim: int = kwargs.proj_output_dim
        self.update_freq: int = kwargs.update_freq
        self.shuffle_freq: int = kwargs.shuffle_freq

        # Auxiliary-loss settings (defaults are filled in by add_and_assert_specific_cfg).
        self.variance_lambda: float = kwargs.variance_lambda
        self.optimize_variance: bool = kwargs.optimize_variance
        self.covariance_lambda: float = kwargs.covariance_lambda
        self.optimize_covariance: bool = kwargs.optimize_covariance
        self.radial_lambda: float = kwargs.radial_lambda
        self.optimize_radial: bool = kwargs.optimize_radial
        self.lambda_strategy: str = kwargs.lambda_strategy

        # Projector: two Linear-BatchNorm-ReLU stages followed by a linear output layer.
        width: int = kwargs.proj_hidden_dim
        self.projector = nn.Sequential(
            nn.Linear(self.features_dim, width),
            nn.BatchNorm1d(width),
            nn.ReLU(),
            nn.Linear(width, width),
            nn.BatchNorm1d(width),
            nn.ReLU(),
            nn.Linear(width, self.proj_output_dim),
        )

        # Epoch bookkeeping stored as (persistent) buffers, so it is part of the state_dict.
        self.register_buffer("assignment_epoch", torch.zeros(1, dtype=torch.long))
        self.register_buffer("shuffle_epoch", torch.zeros(1, dtype=torch.long))
        self.batch_size = cfg.optimizer.batch_size

        # Image index -> assigned target vector; populated in setup().
        self.index_to_target = {}

        # Cache for the target tensor of the current batch.
        self._current_batch_targets = None

    def _init_targets(self, n_targets: int) -> torch.Tensor:
        """Draws ``n_targets`` random points on the unit sphere.

        Standard-normal samples of size ``proj_output_dim`` are L2-normalised
        row by row, which places each row on the unit sphere.

        Args:
            n_targets (int): Number of targets to generate.

        Returns:
            torch.Tensor: ``(n_targets, proj_output_dim)`` tensor of unit-norm rows,
                created on the default device.
        """
        gaussian_samples = torch.randn(n_targets, self.proj_output_dim)
        return F.normalize(gaussian_samples, dim=1)
        
    def setup(self, stage: Optional[str] = None) -> None:
        """Creates one random unit-norm target per training sample.

        Runs the parent ``setup`` first. For ``stage`` ``"fit"`` or ``None`` it then
        determines the training-set size and stores a target for every index in
        ``range(dataset_size)`` in ``self.index_to_target``.

        The dataset size comes from ``trainer.datamodule.dataset_len`` when the
        datamodule has that attribute (DALI), otherwise from
        ``len(trainer.train_dataloader.dataset)``. If that lookup raises, an integer
        ``cfg.data.dataset_size`` is used as a fallback.

        Args:
            stage (Optional[str]): either 'fit', 'validate', 'test', or 'predict'

        Raises:
            ValueError: if the size cannot be read from the trainer and
                ``data.dataset_size`` is missing or not an ``int``.
        """
        super().setup(stage)

        if stage not in ("fit", None):  # targets are only needed for training
            return

        try:
            datamodule = self.trainer.datamodule
            if hasattr(datamodule, "dataset_len"):
                # DALI dataloader case
                dataset_size = datamodule.dataset_len
            else:
                # Regular dataloader case.
                # NOTE(review): depending on the Lightning version, trainer.train_dataloader
                # may not be populated yet during setup(); the config fallback covers that.
                dataset_size = len(self.trainer.train_dataloader.dataset)
        except Exception as e:  # deliberate broad fallback to the config value
            dataset_size_cfg = omegaconf_select(self.cfg, "data.dataset_size", None)
            if dataset_size_cfg is not None and isinstance(dataset_size_cfg, int):
                dataset_size = dataset_size_cfg
                print(f"Using dataset_size from config: {dataset_size}")
            else:
                # Chain the original error so its traceback is not lost.
                raise ValueError(
                    f"Couldn't determine dataset size for NAT. Exception: {str(e)}. "
                    "Also, data.dataset_size not found or invalid in config."
                ) from e

        # NOTE(review): self.device may still be the CPU when setup() runs (before the
        # strategy moves the module); consumers move targets to the right device.
        all_targets = self._init_targets(dataset_size).to(self.device)

        # One entry per sample index. unbind() yields per-row views of the freshly created
        # all_targets (referenced nowhere else), avoiding one clone per sample.
        self.index_to_target.update(enumerate(all_targets.unbind(0)))

        print(f"Initialized {dataset_size} targets for NAT")
    
    def shuffle_targets(self):
        """Shuffles all target vectors randomly to avoid plateaus.

        The set of target vectors is unchanged; only which image index holds which
        vector is permuted, using a single ``torch.randperm`` draw. The original
        NAT algorithm uses such random shuffling to avoid getting stuck in local
        minima during training.

        Does nothing when no targets have been initialised yet (e.g. before
        ``setup``), instead of failing inside ``torch.stack``.
        """
        if not self.index_to_target:
            # torch.stack raises on an empty list; an empty mapping has nothing to shuffle.
            return

        keys = list(self.index_to_target)
        stacked = torch.stack([self.index_to_target[key] for key in keys], dim=0)
        shuffled = stacked[torch.randperm(len(keys))]

        # Advanced indexing already produced a fresh tensor, so per-row views of it are
        # independent of the previous assignment and need no extra clone.
        self.index_to_target.update(zip(keys, shuffled.unbind(0)))

        # NOTE(review): under DDP each process draws this permutation from its own global
        # RNG state; confirm ranks are seeded identically if their targets must stay in sync.
        print(f"Shuffled {len(keys)} target vectors")
        
    def get_targets_for_indices(self, indices: torch.Tensor) -> torch.Tensor:
        """Returns the target vectors for a batch of image indices.

        Args:
            indices (torch.Tensor): 1-D tensor of image indices (keys of
                ``self.index_to_target``).

        Returns:
            torch.Tensor: the assigned targets stacked along a new first dimension,
                placed on ``indices.device``.

        Raises:
            KeyError: if an index has no assigned target (e.g. ``setup`` has not run).
        """
        device = indices.device
        # Dictionary lookup needs Python ints; tolist() converts the whole batch at once.
        rows = [self.index_to_target[idx] for idx in indices.detach().cpu().tolist()]
        # Targets are created in setup() and may live on another device than the batch;
        # return them next to the indices so they can be combined with batch tensors.
        return torch.stack(rows, dim=0).to(device)

    @staticmethod
    def add_and_assert_specific_cfg(cfg: omegaconf.DictConfig) -> omegaconf.DictConfig:
        """Adds method specific default values/checks for config.

        Defaults: ``update_freq=3``, ``shuffle_freq=3``, every auxiliary lambda
        ``0.0`` with its ``optimize_*`` flag ``False``, and
        ``lambda_strategy="standard"``.

        Args:
            cfg (omegaconf.DictConfig): DictConfig object.

        Returns:
            omegaconf.DictConfig: same as the argument, used to avoid errors.

        Raises:
            AssertionError: if ``proj_output_dim``/``proj_hidden_dim`` are missing,
                ``update_freq`` is not positive, or ``lambda_strategy`` is unknown.
        """
        cfg = super(NAT, NAT).add_and_assert_specific_cfg(cfg)

        assert not omegaconf.OmegaConf.is_missing(cfg, "method_kwargs.proj_output_dim")
        assert not omegaconf.OmegaConf.is_missing(cfg, "method_kwargs.proj_hidden_dim")

        cfg.method_kwargs.update_freq = omegaconf_select(
            cfg, "method_kwargs.update_freq", 3
        )
        # training_step computes ``current_epoch % update_freq``; a value of 0 would only
        # surface as a ZeroDivisionError at epoch 1, so reject it up front.
        assert (
            cfg.method_kwargs.update_freq >= 1
        ), "NAT: method_kwargs.update_freq must be a positive integer."

        cfg.method_kwargs.shuffle_freq = omegaconf_select(
            cfg, "method_kwargs.shuffle_freq", 3
        )

        # Defaults for auxiliary losses
        cfg.method_kwargs.variance_lambda = omegaconf_select(cfg, "method_kwargs.variance_lambda", 0.0)
        cfg.method_kwargs.optimize_variance = omegaconf_select(cfg, "method_kwargs.optimize_variance", False)
        cfg.method_kwargs.covariance_lambda = omegaconf_select(cfg, "method_kwargs.covariance_lambda", 0.0)
        cfg.method_kwargs.optimize_covariance = omegaconf_select(cfg, "method_kwargs.optimize_covariance", False)
        cfg.method_kwargs.radial_lambda = omegaconf_select(cfg, "method_kwargs.radial_lambda", 0.0)
        cfg.method_kwargs.optimize_radial = omegaconf_select(cfg, "method_kwargs.optimize_radial", False)
        cfg.method_kwargs.lambda_strategy = omegaconf_select(cfg, "method_kwargs.lambda_strategy", "standard")
        assert cfg.method_kwargs.lambda_strategy in ["standard", "self_tune"]

        return cfg

    @property
    def learnable_params(self) -> List[Dict]:
        """Parent's learnable parameter groups followed by one group for the projector.

        Returns:
            List[Dict]: list of learnable parameters.
        """
        projector_group = {"name": "projector", "params": self.projector.parameters()}
        return [*super().learnable_params, projector_group]

    def forward(self, X: torch.Tensor) -> Dict[str, Any]:
        """Runs the backbone (via the parent) and then the projector.

        While the module is in training mode and the parent output holds a tensor
        under ``"feats"``, statistics of those backbone features are logged without
        tracking gradients: norm mean/variance, uniformity, anisotropy, and
        batch/embedding sparsity (max, mean, min).

        Args:
            X (torch.Tensor): a batch of images in the tensor format.

        Returns:
            Dict[str, Any]: the parent's outputs extended with the projected
                features under ``"z"`` and, when ``self.projector_classifier`` is
                set, its logits under ``"projector_logits"``.
        """
        out = super().forward(X)

        backbone_feats = out.get("feats") if self.training else None
        if isinstance(backbone_feats, torch.Tensor):
            with torch.no_grad():
                feat_norms = torch.norm(backbone_feats, dim=1)
                encoder_stats = {
                    "encoder_norm/train_encoder_feature_norm_mean": feat_norms.mean(),
                    "encoder_norm/train_encoder_feature_norm_var": feat_norms.var(),
                    "unif_loss/train_encoder_unif_loss": uniform_loss(backbone_feats),
                    "anisotropy_loss/train_encoder_anisotropy_loss": anisotropy_loss(backbone_feats),
                }

                batch_max, batch_mean, batch_min = batch_sparsity_metric(backbone_feats)
                encoder_stats["batch_sparsity_metric/encoder_batch_sparse_max"] = batch_max
                encoder_stats["batch_sparsity_metric/encoder_batch_sparse_mean"] = batch_mean
                encoder_stats["batch_sparsity_metric/encoder_batch_sparse_min"] = batch_min

                embed_max, embed_mean, embed_min = embedding_sparsity_metric(backbone_feats)
                encoder_stats["embedding_sparsity_metric/encoder_embed_sparse_max"] = embed_max
                encoder_stats["embedding_sparsity_metric/encoder_embed_sparse_mean"] = embed_mean
                encoder_stats["embedding_sparsity_metric/encoder_embed_sparse_min"] = embed_min

                for metric_name, metric_value in encoder_stats.items():
                    self.log(metric_name, metric_value, on_epoch=True, sync_dist=True)

        z = self.projector(out["feats"])
        out["z"] = z

        if self.projector_classifier is not None:
            # Feed a detached z so this classifier's loss only updates the classifier,
            # never the projector or the backbone.
            out["projector_logits"] = self.projector_classifier(z.detach())

        return out

    def training_step(self, batch: Sequence[Any], batch_idx: int) -> torch.Tensor:
        """Training step for NAT.

        The total loss is the sum of:

        - the NAT alignment loss between each crop's projection and the targets
          assigned to the batch's images, averaged over crops;
        - the parent's online-classifier loss;
        - the enabled auxiliary losses (variance, covariance, radial), each weighted
          according to ``lambda_strategy``;
        - the projector-classifier loss, when that classifier exists.

        Projector statistics are logged along the way.

        On every epoch that is a positive multiple of ``update_freq``, each batch's
        targets are re-assigned to its images by Hungarian matching on the first
        crop's projections. The new assignment is written back to
        ``index_to_target``, so it persists for later steps and epochs.

        Args:
            batch (Sequence[Any]): a batch of data in the format of [img_indexes, [X], Y], where
                [X] is a list of size num_crops containing batches of images.
            batch_idx (int): index of the batch.

        Returns:
            torch.Tensor: total loss.
        """
        indices, _, labels = batch

        # The parent step runs the forward passes and the online classifier.
        out = super().training_step(batch, batch_idx)
        class_loss = out["loss"]
        all_z_crops = out["z"]  # projected features, one tensor per crop

        # Stored targets may live on another device than the model (they are created in
        # setup()); bring them next to the projections before any arithmetic.
        batch_targets = self.get_targets_for_indices(indices).to(self.device)
        curr_epoch = self.trainer.current_epoch

        targets_for_loss = batch_targets
        # Gate on the epoch only, so *every* batch of an update epoch is re-assigned;
        # assignment_epoch is pure bookkeeping and must not block later batches.
        if curr_epoch > 0 and curr_epoch % self.update_freq == 0:
            targets_for_loss = self._reassign_batch_targets(indices, all_z_crops[0], batch_targets)
            if batch_idx == 0:  # record and log once per epoch
                self.assignment_epoch.fill_(curr_epoch)
                self.log("nat_reassignment", float(curr_epoch), on_step=False, on_epoch=True, sync_dist=True)

        # Alignment loss averaged over crops, all against the same batch targets.
        align_loss = sum(nat_alignment_loss(crop_z, targets_for_loss) for crop_z in all_z_crops) / len(all_z_crops)

        # Auxiliary losses default to 0 on the model device when no crop is available.
        var_loss_val = torch.tensor(0.0, device=self.device)
        cov_loss_val = torch.tensor(0.0, device=self.device)
        radial_loss_val = torch.tensor(0.0, device=self.device)

        if len(all_z_crops) >= 2:
            z1_online, z2_online = all_z_crops[0], all_z_crops[1]
            var_loss_val, cov_loss_val, radial_loss_val = nat_auxiliary_losses(z1_online, z2_online)
            self._log_projector_stats(z1_online, "z1")
            self._log_projector_stats(z2_online, "z2")
        elif len(all_z_crops) == 1:
            z_online = all_z_crops[0]
            var_loss_val, cov_loss_val, radial_loss_val = nat_auxiliary_losses(z_online, None)
            self._log_projector_stats(z_online, "z")

        # The variance loss is logged divided by the projection dimension.
        proj_dim = all_z_crops[0].size(1) if len(all_z_crops) > 0 else 0
        logged_var_loss = var_loss_val / proj_dim if proj_dim > 0 else var_loss_val
        self.log("train_variance_loss", logged_var_loss, on_epoch=True, sync_dist=True)
        self.log("train_covariance_loss", cov_loss_val, on_epoch=True, sync_dist=True)
        self.log("train_radial_loss", radial_loss_val, on_epoch=True, sync_dist=True)

        # Accumulate out of place: the tensor logged below must not be mutated afterwards.
        total_loss = align_loss + class_loss
        align_loss_detached = align_loss.detach()  # "main SSL loss" used by self_tune scaling
        for enabled, lam, aux_loss in (
            (self.optimize_variance, self.variance_lambda, var_loss_val),
            (self.optimize_covariance, self.covariance_lambda, cov_loss_val),
            (self.optimize_radial, self.radial_lambda, radial_loss_val),
        ):
            if enabled:
                total_loss = total_loss + self._weighted_aux_loss(aux_loss, lam, align_loss_detached)

        metrics = {
            "train_nat_align_loss": align_loss,
            "train_total_loss": total_loss,
        }
        # Only report the online-classifier loss when it is non-zero (classifier active).
        if isinstance(class_loss, torch.Tensor):
            has_class_loss = class_loss.item() != 0
        else:
            has_class_loss = isinstance(class_loss, (float, int)) and class_loss != 0
        if has_class_loss:
            metrics["train_class_loss"] = class_loss
        self.log_dict(metrics, on_epoch=True, sync_dist=True)

        # Projection-norm statistics averaged over all crops (monitoring only).
        if all_z_crops:
            with torch.no_grad():
                crop_norms = [torch.norm(crop_z, dim=1) for crop_z in all_z_crops]
                avg_z_norm_mean_all_crops = torch.stack([n.mean() for n in crop_norms]).mean()
                avg_z_norm_var_all_crops = torch.stack([n.var() for n in crop_norms]).mean()
            self.log("projector_norm/train_avg_z_norm_mean_all_crops", avg_z_norm_mean_all_crops, on_epoch=True, sync_dist=True)
            self.log("projector_norm/train_avg_z_norm_var_all_crops", avg_z_norm_var_all_crops, on_epoch=True, sync_dist=True)

        # Projector classifier training (helper provided by BaseMethod, one view at a time).
        projector_class_loss = torch.tensor(0.0, device=self.device)
        if self.projector_classifier is not None and len(all_z_crops) > 0:
            total_proj_loss = torch.tensor(0.0, device=self.device)
            # 1-element accumulators, matching the original shape of the accuracy metrics.
            total_proj_acc1 = torch.zeros(1, device=self.device)
            total_proj_acc5 = torch.zeros(1, device=self.device)

            for crop_z in all_z_crops:
                proj_metrics = self._projector_classifier_step(crop_z, labels)
                if proj_metrics:  # the helper may return nothing
                    total_proj_loss = total_proj_loss + proj_metrics["proj_loss"]
                    total_proj_acc1 = total_proj_acc1 + proj_metrics["proj_acc1"]
                    total_proj_acc5 = total_proj_acc5 + proj_metrics["proj_acc5"]

            n_crops = len(all_z_crops)
            projector_class_loss = total_proj_loss / n_crops
            self.log("train_proj_loss", projector_class_loss, on_epoch=True, sync_dist=True)
            self.log("train_proj_acc1", total_proj_acc1 / n_crops, on_epoch=True, sync_dist=True)
            self.log("train_proj_acc5", total_proj_acc5 / n_crops, on_epoch=True, sync_dist=True)

        return total_loss + projector_class_loss

    def _reassign_batch_targets(
        self, indices: torch.Tensor, features: torch.Tensor, batch_targets: torch.Tensor
    ) -> torch.Tensor:
        """Re-assigns a batch's targets to its images and persists the assignment.

        Runs ``hungarian_matching`` on the cosine-similarity matrix between
        ``features`` and ``batch_targets`` and applies the returned matrix to
        ``batch_targets``. Row ``i`` of the result is then stored as the target of
        image ``indices[i]`` in ``self.index_to_target``, keeping the stored device
        and dtype.

        Returns:
            torch.Tensor: the re-assigned targets, for use in this step's loss.
        """
        features_norm = F.normalize(features, dim=1)
        batch_targets_norm = F.normalize(batch_targets, dim=1)
        similarity = features_norm @ batch_targets_norm.t()
        perm_matrix = hungarian_matching(similarity)
        new_batch_targets = perm_matrix @ batch_targets

        index_list = indices.detach().cpu().tolist()
        if index_list:
            # Keep the mapping homogeneous with what setup() created (device and dtype),
            # and never store tensors that are attached to the autograd graph.
            reference = self.index_to_target[index_list[0]]
            stored = new_batch_targets.detach().to(device=reference.device, dtype=reference.dtype)
            # NOTE(review): under DDP each rank only updates the indices it processed;
            # confirm whether the mapping needs to be synchronised across ranks.
            self.index_to_target.update(zip(index_list, stored.unbind(0)))

        return new_batch_targets.clone()

    def _weighted_aux_loss(
        self, aux_loss: torch.Tensor, lam: float, main_loss_detached: torch.Tensor
    ) -> torch.Tensor:
        """Scales an auxiliary loss according to ``self.lambda_strategy``.

        "standard" returns ``lam * aux_loss``. "self_tune" uses the weight
        ``lam * main_loss / aux_loss`` (both detached, ``1e-6`` added to the
        denominator), so the weighted term is about ``lam`` times the main loss
        while the gradient still flows through ``aux_loss``. Any other strategy
        contributes nothing.
        """
        if self.lambda_strategy == "standard":
            return lam * aux_loss
        if self.lambda_strategy == "self_tune":
            weight = (lam * main_loss_detached) / (aux_loss.detach() + 1e-6)
            return weight * aux_loss
        return 0.0

    def _log_projector_stats(self, z: torch.Tensor, tag: str) -> None:
        """Logs norm, uniformity, anisotropy and sparsity statistics of projections ``z``.

        ``tag`` ("z1", "z2" or "z") is embedded in every metric name. The values are
        computed under ``torch.no_grad()`` because they are only logged.
        """
        with torch.no_grad():
            norms = torch.norm(z, dim=1)
            stats = {
                f"projector_norm/train_{tag}_norm_mean_projector": norms.mean(),
                f"projector_norm/train_{tag}_norm_var_projector": norms.var(),
                f"unif_loss/train_{tag}_unif_loss_projector": uniform_loss(z),
                f"anisotropy_loss/train_{tag}_anisotropy_loss_projector": anisotropy_loss(z),
            }
            batch_max, batch_mean, batch_min = batch_sparsity_metric(z)
            stats[f"batch_sparsity_metric/{tag}_projector_batch_sparse_max"] = batch_max
            stats[f"batch_sparsity_metric/{tag}_projector_batch_sparse_mean"] = batch_mean
            stats[f"batch_sparsity_metric/{tag}_projector_batch_sparse_min"] = batch_min
            embed_max, embed_mean, embed_min = embedding_sparsity_metric(z)
            stats[f"embedding_sparsity_metric/{tag}_projector_embed_sparse_max"] = embed_max
            stats[f"embedding_sparsity_metric/{tag}_projector_embed_sparse_mean"] = embed_mean
            stats[f"embedding_sparsity_metric/{tag}_projector_embed_sparse_min"] = embed_min

        for name, value in stats.items():
            self.log(name, value, on_epoch=True, sync_dist=True)