# Copyright 2023 solo-learn development team.

# Permission is hereby granted, free of charge, to any person obtaining a copy of
# this software and associated documentation files (the "Software"), to deal in
# the Software without restriction, including without limitation the rights to use,
# copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the
# Software, and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all copies
# or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
# INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR
# PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE
# FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
# OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.

from typing import Any, Dict, List, Sequence, Tuple

import numpy as np
import omegaconf
import torch
import torch.nn as nn
import torch.nn.functional as F
from solo.losses.radialbyol import radial_byol_loss_suite, uniform_loss, anisotropy_loss # Updated import
from solo.losses.radialvicreg import batch_sparsity_metric, embedding_sparsity_metric, chi2_radial_nll_loss_for_lightning_logging # Added for sparsity metrics
from solo.methods.base import BaseMomentumMethod
from solo.utils.momentum import initialize_momentum_params
from solo.utils.misc import omegaconf_select # For default values
from solo.utils.misc import gather

class RadialBYOL(BaseMomentumMethod): # Renamed class
    def __init__(self, cfg: omegaconf.DictConfig):
        """Sets up BYOL (https://arxiv.org/abs/2006.07733) with optional
        radial, variance, and covariance losses for logging/optimization.

        Builds the online projector, the momentum projector and the predictor.
        With ``projector_type == "identity"`` both projectors are ``nn.Identity``
        and the predictor operates directly on the backbone features.

        Extra cfg settings:
            method_kwargs:
                proj_output_dim (int): number of dimensions of projected features.
                proj_hidden_dim (int): number of neurons of the hidden layers of the projector.
                pred_hidden_dim (int): number of neurons of the hidden layers of the predictor.
                projector_type (str): "mlp" or "identity".
                radial_lambda (float): lambda for the radial loss.
                optimize_radial (bool): whether to optimize for radial loss.
                variance_lambda (float): lambda for the variance loss.
                optimize_variance (bool): whether to optimize for variance loss.
                covariance_lambda (float): lambda for the covariance loss.
                optimize_covariance (bool): whether to optimize for covariance loss.
                lambda_strategy (str): "standard" or "self_tune".
                add_projector_classifier (bool): whether to add projector classifier.
        """

        # This check must run before super().__init__: a projector classifier
        # cannot be requested without a known projector output dimension.
        wants_projector_classifier = omegaconf_select(
            cfg, "method_kwargs.add_projector_classifier", False
        )
        if wants_projector_classifier:
            assert not omegaconf.OmegaConf.is_missing(
                cfg, "method_kwargs.proj_output_dim"
            ), "RadialBYOL: method_kwargs.proj_output_dim must be set if add_projector_classifier is True."

        super().__init__(cfg)

        method_kwargs = cfg.method_kwargs

        # Architecture sizes.
        hidden_dim: int = method_kwargs.proj_hidden_dim
        output_dim_cfg: int = method_kwargs.proj_output_dim
        predictor_hidden_dim: int = method_kwargs.pred_hidden_dim
        self.projector_type: str = method_kwargs.projector_type

        # Auxiliary-loss weights, on/off switches and weighting strategy.
        self.radial_lambda: float = method_kwargs.radial_lambda
        self.optimize_radial: bool = method_kwargs.optimize_radial
        self.variance_lambda: float = method_kwargs.variance_lambda
        self.optimize_variance: bool = method_kwargs.optimize_variance
        self.covariance_lambda: float = method_kwargs.covariance_lambda
        self.optimize_covariance: bool = method_kwargs.optimize_covariance
        self.lambda_strategy: str = method_kwargs.lambda_strategy

        # An identity projector leaves the backbone features untouched, so the
        # predictor has to consume (and emit) features_dim-sized vectors.
        embedding_dim = (
            self.features_dim if self.projector_type == "identity" else output_dim_cfg
        )

        def build_projection_mlp() -> nn.Sequential:
            """Linear -> BatchNorm -> ReLU -> Linear head on top of the backbone."""
            return nn.Sequential(
                nn.Linear(self.features_dim, hidden_dim),
                nn.BatchNorm1d(hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, embedding_dim),
            )

        if self.projector_type == "mlp":
            # Online projector first, then its momentum twin, which starts
            # from the online projector's parameters.
            self.projector = build_projection_mlp()
            self.momentum_projector = build_projection_mlp()
            initialize_momentum_params(self.projector, self.momentum_projector)
        else:  # identity
            self.projector = nn.Identity()
            self.momentum_projector = nn.Identity()

        # Predictor: input and output sizes both match the projector output,
        # as required to compare predictions against momentum projections.
        self.predictor = nn.Sequential(
            nn.Linear(embedding_dim, predictor_hidden_dim),
            nn.BatchNorm1d(predictor_hidden_dim),
            nn.ReLU(),
            nn.Linear(predictor_hidden_dim, embedding_dim),
        )

    @staticmethod
    def add_and_assert_specific_cfg(cfg: omegaconf.DictConfig) -> omegaconf.DictConfig:
        """Validates RadialBYOL-specific settings and fills in their defaults.

        Args:
            cfg (omegaconf.DictConfig): DictConfig object.

        Returns:
            omegaconf.DictConfig: the same cfg object with defaults written into
                ``method_kwargs``.
        """
        cfg = super(RadialBYOL, RadialBYOL).add_and_assert_specific_cfg(cfg)

        # Sizes without a sensible default must be provided by the user.
        for required_key in ("proj_hidden_dim", "proj_output_dim", "pred_hidden_dim"):
            assert not omegaconf.OmegaConf.is_missing(cfg, f"method_kwargs.{required_key}")

        def fill_default(key: str, default: Any) -> None:
            """Stores method_kwargs.<key>, falling back to `default` when unset."""
            value = omegaconf_select(cfg, f"method_kwargs.{key}", default)
            setattr(cfg.method_kwargs, key, value)

        fill_default("projector_type", "mlp")
        assert cfg.method_kwargs.projector_type in ["mlp", "identity"]

        # Every auxiliary loss gets a weight (off by default) and an opt-in flag.
        for aux_name in ("radial", "variance", "covariance"):
            fill_default(f"{aux_name}_lambda", 0.0)
            fill_default(f"optimize_{aux_name}", False)

        # 'standard': lambdas weight the aux losses directly;
        # 'self_tune': aux losses are rescaled relative to the similarity loss.
        fill_default("lambda_strategy", "standard")
        assert cfg.method_kwargs.lambda_strategy in ["standard", "self_tune"]
        return cfg

    @property
    def learnable_params(self) -> List[dict]:
        """Extends the parent's learnable parameters with the predictor and,
        when the projector is an MLP, the projector.

        Returns:
            List[dict]: list of learnable parameters.
        """
        predictor_group = {"name": "predictor", "params": self.predictor.parameters()}
        if self.projector_type != "mlp":
            # An identity projector has nothing to train.
            return super().learnable_params + [predictor_group]

        projector_group = {"name": "projector", "params": self.projector.parameters()}
        return super().learnable_params + [predictor_group, projector_group]

    @property
    def momentum_pairs(self) -> List[Tuple[Any, Any]]:
        """Extends the parent's momentum pairs with (projector, momentum_projector)
        when the projector is an MLP.

        Returns:
            List[Tuple[Any, Any]]: list of momentum pairs.
        """
        if self.projector_type == "mlp":
            own_pairs = [(self.projector, self.momentum_projector)]
        else:
            # Identity projectors carry no parameters to track.
            own_pairs = []
        return super().momentum_pairs + own_pairs

    def forward(self, X: torch.Tensor, *args, **kwargs) -> Dict[str, Any]:
        """Runs the online backbone, projector and predictor.

        While in training mode, diagnostics of the backbone features (norm
        statistics, uniformity, anisotropy and sparsity metrics) are logged.

        Args:
            X (torch.Tensor): batch of images in tensor format.

        Returns:
            Dict[str, Any]: the parent's outputs plus the projected features "z",
                the predictions "p" and, when a projector classifier is present,
                "projector_logits".

        Raises:
            NotImplementedError: in training mode, if the parent returns a
                non-empty list of tensors under "feats".
        """

        def log_epoch(key: str, value: Any) -> None:
            # All diagnostics share the same logging options.
            self.log(key, value, on_epoch=True, sync_dist=True)

        out = super().forward(X, *args, **kwargs)

        if self.training and "feats" in out:
            encoder_feats = out["feats"]

            # Diagnostics only: keep them out of the autograd graph.
            with torch.no_grad():
                if isinstance(encoder_feats, torch.Tensor):
                    # Norm statistics and uniformity of the backbone features.
                    row_norms = torch.norm(encoder_feats, dim=1)
                    norm_mean = row_norms.mean()
                    norm_var = row_norms.var()
                    encoder_unif = uniform_loss(encoder_feats)
                    log_epoch("encoder_norm/train_encoder_feature_norm_mean", norm_mean)
                    log_epoch("encoder_norm/train_encoder_feature_norm_var", norm_var)
                    log_epoch("unif_loss/train_encoder_unif_loss", encoder_unif)

                    encoder_anisotropy = anisotropy_loss(encoder_feats)
                    log_epoch("anisotropy_loss/train_encoder_anisotropy_loss", encoder_anisotropy)

                    # Sparsity metrics of the backbone features.
                    batch_max, batch_mean, batch_min = batch_sparsity_metric(encoder_feats)
                    log_epoch("batch_sparsity_metric/encoder_batch_sparse_max", batch_max)
                    log_epoch("batch_sparsity_metric/encoder_batch_sparse_mean", batch_mean)
                    log_epoch("batch_sparsity_metric/encoder_batch_sparse_min", batch_min)

                    embed_max, embed_mean, embed_min = embedding_sparsity_metric(encoder_feats)
                    log_epoch("embedding_sparsity_metric/encoder_embed_sparse_max", embed_max)
                    log_epoch("embedding_sparsity_metric/encoder_embed_sparse_mean", embed_mean)
                    log_epoch("embedding_sparsity_metric/encoder_embed_sparse_min", embed_min)

                elif (
                    isinstance(encoder_feats, list)
                    and len(encoder_feats) > 0
                    and isinstance(encoder_feats[0], torch.Tensor)
                ):
                    raise NotImplementedError("Multicrop features are not implemented for radialbyol uniform loss - implement this in methods/radialbyol.py.")

        projected = self.projector(out["feats"])
        predicted = self.predictor(projected)
        out["z"] = projected
        out["p"] = predicted

        # projector_classifier is not created in this class; it is expected to be
        # provided by the parent (None when disabled).
        if self.projector_classifier is not None:
            # Detach so that this head cannot push gradients into the projector
            # or the backbone through this path.
            out["projector_logits"] = self.projector_classifier(projected.detach())

        return out

    def validation_step(self, batch: Sequence[Any], batch_idx: int):
        """Validation step for RadialBYOL.

        Runs the parent's base validation step and, outside of the trainer's
        sanity check, logs radial-loss and norm diagnostics for the backbone
        features and their projections.

        Args:
            batch (Sequence[Any]): a pair ``(X, targets)``.
            batch_idx (int): index of the batch.

        Returns:
            dict: batch size plus the validation loss/accuracy entries taken from
                the base validation step (also appended to
                ``self.validation_step_outputs``).
        """

        def log_epoch(key: str, value: Any) -> None:
            self.log(key, value, on_epoch=True, sync_dist=True)

        X, targets = batch
        out = self.base_validation_step(X, targets)

        if not self.trainer.sanity_checking:
            encoder_feats = out["feats"]
            projected = self.projector(encoder_feats)

            # Radial loss of the backbone features, computed on the gathered batch.
            encoder_radial = chi2_radial_nll_loss_for_lightning_logging(gather(encoder_feats))
            log_epoch("encoder/val_backbone_radial_loss", encoder_radial)

            # Radial loss of the projected features, computed on the gathered batch.
            projector_radial = chi2_radial_nll_loss_for_lightning_logging(gather(projected))
            log_epoch("projector/val_radial_loss", projector_radial)

            # Norm statistics of the (local) projected features.
            projected_norms = torch.norm(projected, dim=1)
            norm_mean = projected_norms.mean()
            norm_var = projected_norms.var()
            log_epoch("projector_norm/val_z_norm_mean", norm_mean)
            log_epoch("projector_norm/val_z_norm_var", norm_var)

        # Assemble the per-batch metrics stored for later aggregation
        # (presumably mirroring the parent's validation_step -- verify against it).
        metrics = {"batch_size": targets.size(0)}
        for name in ("loss", "acc1", "acc5"):
            metrics[f"val_{name}"] = out.get(name)
        if "proj_loss" in out:
            for name in ("proj_loss", "proj_acc1", "proj_acc5"):
                metrics[f"val_{name}"] = out.get(name)

        self.validation_step_outputs.append(metrics)
        return metrics

    def multicrop_forward(self, X: torch.Tensor, *args, **kwargs) -> Dict[str, Any]:
        """Runs the online backbone, projector and predictor on multicrop views.

        Mirrors ``momentum_forward`` in accepting either a collection of
        per-crop feature tensors or a single 2-D feature tensor from the parent.

        Args:
            X (torch.Tensor): batch of images in tensor format.

        Returns:
            Dict[str, Any]: the parent's outputs plus "z" and "p". Both are lists
                (one entry per element of "feats") when "feats" is a collection,
                and single tensors when "feats" is a 2-D tensor.
        """
        out = super().multicrop_forward(X, *args, **kwargs)
        feats = out["feats"]

        if isinstance(feats, torch.Tensor) and feats.dim() == 2:
            # A 2-D tensor is one batch of features. Iterating it would feed 1-D
            # rows to the predictor, whose BatchNorm1d rejects 1-D input.
            z = self.projector(feats)
            p = self.predictor(z)
        else:
            # One feature tensor per crop: project and predict each separately.
            z = [self.projector(crop_feats) for crop_feats in feats]
            p = [self.predictor(crop_z) for crop_z in z]

        out.update({"z": z, "p": p})
        return out

    @torch.no_grad()
    def momentum_forward(self, X: torch.Tensor, *args, **kwargs) -> Dict:
        """Runs the momentum backbone and the momentum projector without gradients.

        Args:
            X (torch.Tensor): batch of images in tensor format.

        Returns:
            Dict: the parent's outputs plus "z", the momentum projection(s):
                a list when the parent returns a list under "feats", otherwise
                a single tensor.
        """
        out = super().momentum_forward(X, *args, **kwargs)
        momentum_feats = out["feats"]

        if not isinstance(momentum_feats, list):
            projected = self.momentum_projector(momentum_feats)
        else:
            # One feature tensor per view: project each separately.
            projected = list(map(self.momentum_projector, momentum_feats))

        # training_step reads this value as "momentum_z"; the renaming is
        # presumably done by the parent class -- verify there.
        out["z"] = projected
        return out

    def training_step(self, batch: Sequence[Any], batch_idx: int) -> torch.Tensor:
        """Training step for RadialBYOL.

        Combines the BYOL similarity loss with the parent's classification loss,
        the optional variance/covariance/radial terms and, when enabled, the
        projector-classifier loss. Diagnostics of the online projections are
        logged along the way.

        Args:
            batch (Sequence[Any]): a batch of data; when the projector classifier
                is enabled it is unpacked as a 3-tuple whose last element holds
                the targets.
            batch_idx (int): index of the batch.

        Returns:
            torch.Tensor: the total loss to optimize. If fewer than two views are
                available, only the parent's classification loss is returned.
        """

        def log_epoch(key: str, value: Any) -> None:
            # Every metric in this step uses the same logging options.
            self.log(key, value, on_epoch=True, sync_dist=True)

        out = super().training_step(batch, batch_idx)
        class_loss = out["loss"]  # loss reported by the parent step

        # Per-view outputs; the first two entries are used as the two main views.
        Z_online = out["z"]  # online projector outputs
        P_online = out["p"]  # online predictor outputs
        Z_momentum = out["momentum_z"]  # momentum projector outputs

        try:
            z1_online, z2_online = Z_online[0], Z_online[1]
            p1, p2 = P_online[0], P_online[1]
            z1_momentum, z2_momentum = Z_momentum[0], Z_momentum[1]
        except IndexError:
            # Fewer than two views reached this step, which points at a data
            # pipeline / num_large_crops misconfiguration. Deliberate best
            # effort: report it once and fall back to the classification loss.
            if self.trainer and self.trainer.is_global_zero:
                print("Error: RadialBYOL training_step received fewer than 2 views in Z_online, P_online, or Z_momentum. Check data pipeline and num_large_crops.")
            return class_loss

        # --- Loss components ---
        # Per the loss suite's usage here: similarity is computed from the
        # predictions and the momentum projections, the auxiliary terms from the
        # online projections (see solo.losses.radialbyol for the definitions).
        sim_loss, var_loss, cov_loss, radial_loss = radial_byol_loss_suite(
            p1, p2,
            z1_online, z2_online,
            z1_momentum, z2_momentum,
        )

        # --- Combine losses for optimization ---
        total_loss = sim_loss + class_loss

        # In 'self_tune' mode each auxiliary term is rescaled so that its
        # magnitude is lambda * sim_loss; detaching keeps the scale factor out of
        # the gradient, so only the auxiliary term provides the direction.
        sim_loss_detached = sim_loss.detach()
        aux_terms = (
            (self.optimize_variance, self.variance_lambda, var_loss),
            (self.optimize_covariance, self.covariance_lambda, cov_loss),
            (self.optimize_radial, self.radial_lambda, radial_loss),
        )
        for enabled, aux_lambda, aux_loss in aux_terms:
            if not enabled:
                continue
            if self.lambda_strategy == "standard":
                weight = aux_lambda
            else:  # self_tune
                weight = (aux_lambda * sim_loss_detached) / (aux_loss.detach() + 1e-6)
            total_loss += weight * aux_loss

        # --- Diagnostics of the online projections (no gradients needed) ---
        stat_names = ("max", "mean", "min")
        with torch.no_grad():
            for view, z in (("z1", z1_online), ("z2", z2_online)):
                # Norm statistics; the row norms are computed once and reused.
                z_norms = torch.norm(z, dim=1)
                log_epoch(f"projector_norm/train_{view}_online_norm_mean", z_norms.mean())
                log_epoch(f"projector_norm/train_{view}_online_norm_var", z_norms.var())

                # Uniformity and anisotropy of the projections.
                log_epoch(f"unif_loss/train_{view}_online_unif_loss_projector", uniform_loss(z))
                log_epoch(
                    f"anisotropy_loss/train_{view}_anisotropy_loss_projector",
                    anisotropy_loss(z),
                )

                # Sparsity metrics (each helper returns max, mean, min).
                batch_max, batch_mean, batch_min = batch_sparsity_metric(z)
                for stat, value in zip(stat_names, (batch_max, batch_mean, batch_min)):
                    log_epoch(
                        f"batch_sparsity_metric/{view}_online_projector_batch_sparse_{stat}",
                        value,
                    )

                embed_max, embed_mean, embed_min = embedding_sparsity_metric(z)
                for stat, value in zip(stat_names, (embed_max, embed_mean, embed_min)):
                    log_epoch(
                        f"embedding_sparsity_metric/{view}_online_projector_embed_sparse_{stat}",
                        value,
                    )

        # --- Individual losses ---
        log_epoch("train_sim_loss", sim_loss)
        # The variance loss is logged divided by the embedding dimension.
        log_epoch("train_variance_loss", var_loss / z1_online.size(1))
        log_epoch("train_covariance_loss", cov_loss)
        log_epoch("train_radial_loss", radial_loss)
        if class_loss != 0:
            log_epoch("train_class_loss", class_loss)

        # Std of the normalized online projections over the large crops.
        if len(Z_online) >= self.num_large_crops and self.num_large_crops > 0:
            with torch.no_grad():
                stacked_z_online_large_crops = torch.stack(Z_online[:self.num_large_crops])
                z_std_online = F.normalize(stacked_z_online_large_crops, dim=-1).std(dim=1).mean()
                log_epoch("projector_norm/train_online_z_std", z_std_online)

        # --- Projector classifier ---
        projector_class_loss = torch.tensor(0.0, device=self.device)
        if self.projector_classifier is not None:
            _, _, targets = batch
            # _projector_classifier_step is not defined in this class (expected
            # from the parent); it is called once per view and the results are
            # averaged.
            proj_metrics1 = self._projector_classifier_step(z1_online, targets)
            proj_metrics2 = self._projector_classifier_step(z2_online, targets)

            if proj_metrics1 and proj_metrics2:  # both views produced metrics
                projector_class_loss = (proj_metrics1["proj_loss"] + proj_metrics2["proj_loss"]) / 2
                proj_acc1 = (proj_metrics1["proj_acc1"] + proj_metrics2["proj_acc1"]) / 2
                proj_acc5 = (proj_metrics1["proj_acc5"] + proj_metrics2["proj_acc5"]) / 2

                log_epoch("train_proj_loss", projector_class_loss)
                log_epoch("train_proj_acc1", proj_acc1)
                log_epoch("train_proj_acc5", proj_acc5)

        total_loss += projector_class_loss

        # Log the total only after every term has been added, so that the logged
        # value is exactly the loss that gets optimized.
        log_epoch("train_total_loss", total_loss)

        return total_loss