# Copyright 2023 solo-learn development team.

# Permission is hereby granted, free of charge, to any person obtaining a copy of
# this software and associated documentation files (the "Software"), to deal in
# the Software without restriction, including without limitation the rights to use,
# copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the
# Software, and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all copies
# or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
# INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR
# PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE
# FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
# OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.

import logging
import os
import pdb
from typing import Any, Callable, Dict, List, Tuple, Union

import matplotlib.pyplot as plt
import numpy as np
import omegaconf
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim.lr_scheduler import ExponentialLR, MultiStepLR, ReduceLROnPlateau

from solo.utils.lars import LARS
from solo.utils.lr_scheduler import LinearWarmupCosineAnnealingLR
from solo.utils.metrics import accuracy_at_k, weighted_mean
from solo.utils.misc import (
    omegaconf_select,
    param_groups_layer_decay,
    remove_bias_and_norm_from_weight_decay,
)

def nesum(tensor):
    """Return the normalised eigenvalue spectrum of the column correlation matrix.

    The input is treated as ``N x D`` (rows are samples, columns are features).
    The ``D x D`` Pearson correlation matrix of the columns is computed with
    ``torch.corrcoef`` and its eigenvalues with ``torch.linalg.eigvalsh``
    (ascending order). The result is divided by the largest eigenvalue, so the
    top value is 1.

    Columns with zero variance make ``torch.corrcoef`` produce NaN entries.
    Those entries are replaced by 0, so such a feature contributes a zero
    eigenvalue instead of breaking the eigendecomposition.

    Args:
        tensor (torch.Tensor): 2-D tensor of shape ``(N, D)``.

    Returns:
        torch.Tensor: float32 tensor with ``D`` eigenvalues in ascending order,
        divided by the maximum. If the maximum is not positive (e.g. every
        column is constant), the raw eigenvalues are returned unnormalised
        rather than NaN.

    Raises:
        RuntimeError: if the eigendecomposition fails.
    """
    # Detach so that passing parameters (e.g. conv weights) does not build an
    # autograd graph; the spectrum is only used for analysis/plotting.
    z = tensor.detach().to(torch.float32)
    corr = torch.corrcoef(z.T)
    if torch.isnan(corr).any():
        # 0/0 from zero-variance columns; treat them as uncorrelated, no-variance features.
        corr = torch.nan_to_num(corr, nan=0.0)

    try:
        eigenvalues = torch.linalg.eigvalsh(corr)
    except RuntimeError as err:  # torch.linalg.LinAlgError subclasses RuntimeError
        raise RuntimeError(
            f"eigvalsh failed on a {tuple(corr.shape)} correlation matrix"
        ) from err

    eigenvalue_max = torch.max(eigenvalues)
    if eigenvalue_max <= 0:
        # Degenerate input (e.g. a single sample or all-constant columns).
        return eigenvalues
    return eigenvalues / eigenvalue_max


def obtain_wight_for_con2d(conv2d, name, save_dir="/extra_room/junlin/SSL/solo-learn"):
    """Save the normalised eigenvalue spectrum of a layer's ``weight`` as an image.

    The weight is flattened to ``(weight.shape[0], -1)`` and transposed before
    being passed to ``plot_log_singular_values``. Its columns therefore
    correspond to the first weight dimension, which is the output channels
    for a PyTorch ``nn.Conv2d`` weight.

    Args:
        conv2d (nn.Module): module exposing a ``weight`` tensor, presumably an
            ``nn.Conv2d``.
        name (str): legend label and prefix of the output file name.
        save_dir (str): directory in which ``{name}_singular.jpg`` is written.
            Defaults to the original author's output directory.
    """
    # Detach: the spectrum is only plotted, no gradient tracking is needed.
    weights = conv2d.weight.detach()
    weights = weights.reshape(weights.shape[0], -1)
    plot_log_singular_values(
        weights.T,
        save_path=os.path.join(save_dir, f"{name}_singular.jpg"),
        name=name,
    )

def plot_log_singular_values(tensor, save_path, name, idx=512):
    """Plot the normalised eigenvalue spectrum of ``tensor`` on log-log axes and save it.

    Despite the function name, the curve shows the values returned by ``nesum``.
    These are the eigenvalues of the column correlation matrix divided by the
    largest one, drawn from largest to smallest. A black ``y = 1/x`` reference
    line is drawn alongside.

    Args:
        tensor (torch.Tensor): input tensor. It is flattened to
            ``(tensor.shape[0], -1)``, so rows are samples and columns are features.
        save_path (str): output image path. The image format is taken from the
            file extension; a path without an extension is saved as JPEG.
        name (str): legend label of the spectrum curve.
        idx (int): unused, kept for backward compatibility; every eigenvalue
            is plotted.
    """
    matrix = tensor.detach().reshape(tensor.shape[0], -1).double()  # samples x features
    eigenvalues = nesum(matrix).cpu().tolist()
    rank = len(eigenvalues)
    ranks = [i + 1 for i in range(rank)]
    reference = [1 / r for r in ranks]

    fig = plt.figure()
    try:
        # nesum returns ascending eigenvalues; reverse so rank 1 is the largest.
        plt.plot(ranks, eigenvalues[::-1], marker='.', label=name, markersize=0, color='orange')
        plt.plot(ranks, reference, marker='.', label='y=1/x', markersize=0, color='black')

        plt.yscale('log')
        plt.xscale('log')
        plt.yticks([1, 0.1, 0.01, 0.001])
        plt.ylim(bottom=0.001, top=1.0)
        if rank > 1:  # matplotlib warns about identical low/high limits
            plt.xlim(1, rank)

        # Powers of ten up to `rank`, plus `rank` itself as the last tick.
        log_ticks = [10**i for i in range(int(np.log10(rank)) + 1)]
        if rank not in log_ticks:
            log_ticks.append(rank)
        plt.xticks(ticks=log_ticks, labels=[str(tick) for tick in log_ticks])

        plt.xlabel('Eigenvalue rank')
        plt.ylabel('Normalised eigenvalue')
        plt.legend()
        image_format = os.path.splitext(save_path)[1].lstrip('.').lower() or 'jpg'
        plt.savefig(save_path, format=image_format)
    finally:
        plt.close(fig)  # always release the figure, even if plotting fails

def plot_log_group_values(group_tensor, save_path, group_name, conv=False):
    """Overlay the normalised eigenvalue spectra of several tensors on one log-log plot.

    Each spectrum comes from ``nesum`` and is drawn from largest to smallest
    eigenvalue. A spectrum containing a non-positive value cannot be shown on a
    log axis, so its curve is drawn dashed. A black ``y = 1/x`` reference line
    spans the longest spectrum.

    Args:
        group_tensor (Sequence[torch.Tensor]): tensors to analyse. Each is
            flattened to ``(tensor.shape[0], -1)``. With ``conv=True`` the
            flattened matrix is also transposed, so the columns index the first
            dimension (output channels for a PyTorch conv weight).
        save_path (str): output image path. The image format is taken from the
            file extension; a path without an extension is saved as JPEG.
        group_name (Sequence[str]): legend labels, one per tensor. The number of
            eigenvalues is appended to each label.
        conv (bool): transpose the flattened tensors (use for conv weights).
    """
    colors = ['red', 'blue', 'green', 'purple', 'orange', 'brown', 'gold', 'grey']
    max_rank = 0

    fig = plt.figure(figsize=(20, 10))
    try:
        for k, tensor in enumerate(group_tensor):
            matrix = tensor.detach().reshape(tensor.shape[0], -1)
            if conv:
                matrix = matrix.T
            eigenvalues = nesum(matrix).cpu().tolist()
            rank = len(eigenvalues)
            max_rank = max(max_rank, rank)

            linestyle = 'solid' if min(eigenvalues) > 0 else 'dashed'
            plt.plot(
                [i + 1 for i in range(rank)],
                eigenvalues[::-1],  # nesum is ascending; rank 1 = largest
                marker='.',
                label='{}_{}'.format(group_name[k], rank),
                markersize=0,
                color=colors[k % len(colors)],  # cycle instead of IndexError past 8 groups
                linewidth=2.0,
                linestyle=linestyle,
            )

        ranks = [i + 1 for i in range(max_rank)]
        plt.plot(ranks, [1 / r for r in ranks], marker='.', label='y=1/x',
                 markersize=0, color='black', linestyle='--', linewidth=2.0)

        plt.yscale('log')
        plt.xscale('log')
        plt.yticks([1, 0.1, 0.01, 0.001], [1, 0.1, 0.01, 0.001])
        plt.ylim(bottom=0.001, top=1.0)

        if max_rank > 10:
            # Every power of ten below the largest rank, then the rank itself.
            ticks = []
            tick = 1
            while tick < max_rank:
                ticks.append(tick)
                tick *= 10
            ticks.append(max_rank)
            plt.xticks(ticks, ticks)
            plt.xlim(left=1, right=max_rank)

        plt.xlabel('Eigenvalue rank')
        plt.ylabel('Normalised eigenvalue')
        plt.legend()
        plt.grid(linestyle=':', linewidth=1)
        image_format = os.path.splitext(save_path)[1].lstrip('.').lower() or 'jpg'
        plt.savefig(save_path, format=image_format)
    finally:
        plt.close(fig)  # always release the figure, even if plotting fails

class LinearModel(pl.LightningModule):
    _OPTIMIZERS = {
        "sgd": torch.optim.SGD,
        "lars": LARS,
        "adam": torch.optim.Adam,
        "adamw": torch.optim.AdamW,
    }
    _SCHEDULERS = [
        "reduce",
        "warmup_cosine",
        "step",
        "exponential",
        "none",
    ]

    def __init__(
        self,
        backbone: nn.Module,
        cfg: omegaconf.DictConfig,
        loss_func: Callable = None,
        mixup_func: Callable = None,
    ):
        """Implements linear and finetune evaluation.

        .. note:: Cfg defaults are set in init by calling `cfg = add_and_assert_specific_cfg(cfg)`

        backbone (nn.Module): backbone architecture for feature extraction.
        Cfg basic structure:
            data:
                num_classes (int): number of classes in the dataset.
            max_epochs (int): total number of epochs.

            optimizer:
                name (str): name of the optimizer.
                batch_size (int): number of samples in the batch.
                lr (float): learning rate.
                weight_decay (float): weight decay for optimizer.
                kwargs (Dict): extra named arguments for the optimizer.
            scheduler:
                name (str): name of the scheduler.
                min_lr (float): minimum learning rate for warmup scheduler. Defaults to 0.0.
                warmup_start_lr (float): initial learning rate for warmup scheduler.
                    Defaults to 0.00003.
                warmup_epochs (float): number of warmup epochs. Defaults to 10.
                lr_decay_steps (Sequence, optional): steps to decay the learning rate
                    if scheduler is step. Defaults to None.
                interval (str): interval to update the lr scheduler. Defaults to 'step'.

            finetune (bool): whether or not to finetune the backbone. Defaults to False.

            performance:
                disable_channel_last (bool). Disables channel last conversion operation which
                speeds up training considerably. Defaults to False.
                https://pytorch.org/tutorials/intermediate/memory_format_tutorial.html#converting-existing-models

        loss_func (Callable): loss function to use (for mixup, label smoothing or default).
        Defaults to None mixup_func (Callable, optional). function to convert data and targets
        with mixup/cutmix. Defaults to None.
        """

        super().__init__()

        # add default values and assert that config has the basic needed settings
        cfg = self.add_and_assert_specific_cfg(cfg)

        # backbone
        self.backbone = backbone
        if hasattr(self.backbone, "inplanes"):
            features_dim = self.backbone.inplanes
        else:
            features_dim = self.backbone.num_features

        # classifier
        self.classifier = nn.Linear(features_dim, cfg.data.num_classes)  # type: ignore

        # mixup/cutmix function
        self.mixup_func: Callable = mixup_func

        if loss_func is None:
            loss_func = nn.CrossEntropyLoss()
        self.loss_func = loss_func

        # training related
        self.max_epochs: int = cfg.max_epochs
        self.accumulate_grad_batches: Union[int, None] = cfg.accumulate_grad_batches

        # optimizer related
        self.optimizer: str = cfg.optimizer.name
        self.batch_size: int = cfg.optimizer.batch_size
        self.lr: float = cfg.optimizer.lr
        self.weight_decay: float = cfg.optimizer.weight_decay
        self.extra_optimizer_args: Dict[str, Any] = cfg.optimizer.kwargs
        self.exclude_bias_n_norm_wd: bool = cfg.optimizer.exclude_bias_n_norm_wd
        self.layer_decay: float = cfg.optimizer.layer_decay

        # scheduler related
        self.scheduler: str = cfg.scheduler.name
        self.lr_decay_steps: Union[List[int], None] = cfg.scheduler.lr_decay_steps
        self.min_lr: float = cfg.scheduler.min_lr
        self.warmup_start_lr: float = cfg.scheduler.warmup_start_lr
        self.warmup_epochs: int = cfg.scheduler.warmup_epochs
        self.scheduler_interval: str = cfg.scheduler.interval
        assert self.scheduler_interval in ["step", "epoch"]
        if self.scheduler_interval == "step":
            # logging.warn is a deprecated alias of logging.warning.
            logging.warning(
                f"Using scheduler_interval={self.scheduler_interval} might generate "
                "issues when resuming a checkpoint."
            )

        # if finetuning the backbone
        self.finetune: bool = cfg.finetune

        # for performance
        self.no_channel_last = cfg.performance.disable_channel_last

        if not self.finetune:
            for param in self.backbone.parameters():
                param.requires_grad = False

        # keep track of validation metrics
        self.validation_step_outputs = []

    @staticmethod
    def add_and_assert_specific_cfg(cfg: omegaconf.DictConfig) -> omegaconf.DictConfig:
        """Adds method specific default values/checks for config.

        Args:
            cfg (omegaconf.DictConfig): DictConfig object.

        Returns:
            omegaconf.DictConfig: same as the argument, used to avoid errors.
        """

        # default parameters for optimizer
        cfg.optimizer.exclude_bias_n_norm_wd = omegaconf_select(
            cfg, "optimizer.exclude_bias_n_norm_wd", False
        )
        # default for extra optimizer kwargs (use pytorch's default if not available)
        cfg.optimizer.kwargs = omegaconf_select(cfg, "optimizer.kwargs", {})
        cfg.optimizer.layer_decay = omegaconf_select(cfg, "optimizer.layer_decay", 0.0)

        # whether or not to finetune the backbone
        cfg.finetune = omegaconf_select(cfg, "finetune", False)

        # default for acc grad batches
        cfg.accumulate_grad_batches = omegaconf_select(cfg, "accumulate_grad_batches", 1)

        # default parameters for the scheduler
        cfg.scheduler.lr_decay_steps = omegaconf_select(cfg, "scheduler.lr_decay_steps", None)
        cfg.scheduler.min_lr = omegaconf_select(cfg, "scheduler.min_lr", 0.0)
        cfg.scheduler.warmup_start_lr = omegaconf_select(cfg, "scheduler.warmup_start_lr", 3e-5)
        cfg.scheduler.warmup_epochs = omegaconf_select(cfg, "scheduler.warmup_epochs", 10)
        cfg.scheduler.interval = omegaconf_select(cfg, "scheduler.interval", "step")

        # default parameters for performance optimization
        cfg.performance = omegaconf_select(cfg, "performance", {})
        cfg.performance.disable_channel_last = omegaconf_select(
            cfg, "performance.disable_channel_last", False
        )

        return cfg

    def configure_optimizers(self) -> Tuple[List, List]:
        """Collects learnable parameters and configures the optimizer and learning rate scheduler.

        Returns:
            Tuple[List, List]: two lists containing the optimizer and the scheduler,
            or just the optimizer when the scheduler name is "none".
        """

        if self.layer_decay > 0:
            assert self.finetune, "Only with use layer weight decay with finetune on."
            msg = (
                "Method should implement no_weight_decay() that returns "
                "a set of parameter names to ignore from weight decay"
            )
            assert hasattr(self.backbone, "no_weight_decay"), msg

            learnable_params = param_groups_layer_decay(
                self.backbone,
                self.weight_decay,
                no_weight_decay_list=self.backbone.no_weight_decay(),
                layer_decay=self.layer_decay,
            )
            learnable_params.append({"name": "classifier", "params": self.classifier.parameters()})
        else:
            learnable_params = (
                self.classifier.parameters()
                if not self.finetune
                else [
                    {"name": "backbone", "params": self.backbone.parameters()},
                    {"name": "classifier", "params": self.classifier.parameters()},
                ]
            )

        # exclude bias and norm from weight decay
        if self.exclude_bias_n_norm_wd:
            learnable_params = remove_bias_and_norm_from_weight_decay(learnable_params)

        assert self.optimizer in self._OPTIMIZERS
        optimizer = self._OPTIMIZERS[self.optimizer]

        optimizer = optimizer(
            learnable_params,
            lr=self.lr,
            weight_decay=self.weight_decay,
            **self.extra_optimizer_args,
        )

        # select scheduler
        if self.scheduler == "none":
            return optimizer

        if self.scheduler == "warmup_cosine":
            max_warmup_steps = (
                self.warmup_epochs * (self.trainer.estimated_stepping_batches / self.max_epochs)
                if self.scheduler_interval == "step"
                else self.warmup_epochs
            )
            max_scheduler_steps = (
                self.trainer.estimated_stepping_batches
                if self.scheduler_interval == "step"
                else self.max_epochs
            )
            scheduler = {
                "scheduler": LinearWarmupCosineAnnealingLR(
                    optimizer,
                    warmup_epochs=max_warmup_steps,
                    max_epochs=max_scheduler_steps,
                    warmup_start_lr=self.warmup_start_lr if self.warmup_epochs > 0 else self.lr,
                    eta_min=self.min_lr,
                ),
                "interval": self.scheduler_interval,
                "frequency": 1,
            }
        elif self.scheduler == "reduce":
            scheduler = ReduceLROnPlateau(optimizer)
        elif self.scheduler == "step":
            scheduler = MultiStepLR(optimizer, self.lr_decay_steps, gamma=0.1)
        elif self.scheduler == "exponential":
            # NOTE(review): gamma is taken from weight_decay, which looks unintended
            # (e.g. 1e-4 shrinks the lr 10^4x per scheduler step); kept to preserve
            # behavior — confirm.
            scheduler = ExponentialLR(optimizer, self.weight_decay)
        else:
            raise ValueError(
                f"{self.scheduler} not in ({', '.join(self._SCHEDULERS)})"
            )

        return [optimizer], [scheduler]

    def forward(self, X: torch.Tensor) -> Dict[str, Any]:
        """Performs forward pass of the frozen backbone and the linear layer for evaluation.

        Args:
            X (torch.Tensor): a batch of images in the tensor format.

        Returns:
            Dict[str, Any]: a dict containing features and logits.
        """

        if not self.no_channel_last:
            X = X.to(memory_format=torch.channels_last)

        # Gradients flow into the backbone only when finetuning.
        with torch.set_grad_enabled(self.finetune):
            feats = self.backbone(X)

        logits = self.classifier(feats)
        return {"logits": logits, "feats": feats}

    def plot_spectral_diagnostics(
        self,
        X: torch.Tensor,
        save_dir: str = ".",
        tag: str = "byol(none)_resnet18",
    ) -> None:
        """Save eigenvalue-spectrum plots of backbone weights and activations.

        This is an opt-in analysis helper and is never called during
        training or evaluation. It writes two images into ``save_dir``:

        * ``group_weights_{tag}.png``: spectra of ``layer{1..4}[1].conv2`` weights.
        * ``group_features_{tag}.png``: spectra of the input batch, the backbone
          output, and the activation after ``layer4[1].conv1``.

        NOTE(review): the backbone is assumed to be a ResNet-style module exposing
        ``conv1``, ``bn1``, ``relu``, ``maxpool`` and ``layer1``..``layer4`` whose
        blocks have ``conv1``/``conv2`` — confirm before using other backbones.
        If those norm layers are BatchNorm in train mode, this extra forward pass
        updates their running statistics, so prefer calling it in eval mode.

        Args:
            X (torch.Tensor): a batch of images.
            save_dir (str): directory in which the images are written.
            tag (str): suffix used in the output file names.
        """
        backbone = self.backbone
        layer_names = ["layer1", "layer2", "layer3", "layer4"]
        weights = [getattr(backbone, layer)[1].conv2.weight for layer in layer_names]
        plot_log_group_values(
            group_tensor=weights,
            save_path=os.path.join(save_dir, f"group_weights_{tag}.png"),
            group_name=layer_names,
            conv=True,
        )

        if not self.no_channel_last:
            X = X.to(memory_format=torch.channels_last)

        with torch.no_grad():
            feats = backbone(X)
            stem = backbone.maxpool(backbone.relu(backbone.bn1(backbone.conv1(X))))
            trunk = backbone.layer3(backbone.layer2(backbone.layer1(stem)))
            hidden = backbone.layer4[1].conv1(backbone.layer4[0](trunk))

        plot_log_group_values(
            group_tensor=[X, feats, hidden],
            save_path=os.path.join(save_dir, f"group_features_{tag}.png"),
            group_name=[
                "cifar10_image",
                "representations",
                "byol(none)_hidden_feats(after layer4.baseblock1.conv1)",
            ],
            conv=False,
        )

    def shared_step(self, batch: Tuple, batch_idx: int) -> Dict[str, Any]:
        """Performs operations that are shared between the training and validation steps.

        Args:
            batch (Tuple): a batch of images in the tensor format.
            batch_idx (int): the index of the batch.

        Returns:
            Dict[str, Any]: ``batch_size`` and ``loss``, plus ``acc1`` and ``acc5``
            unless mixup is applied (training with a mixup function).
        """

        X, target = batch

        metrics = {"batch_size": X.size(0)}
        if self.training and self.mixup_func is not None:
            X, target = self.mixup_func(X, target)
            out = self(X)["logits"]
            loss = self.loss_func(out, target)
            metrics.update({"loss": loss})
        else:
            out = self(X)["logits"]
            loss = F.cross_entropy(out, target)
            acc1, acc5 = accuracy_at_k(out, target, top_k=(1, 5))
            metrics.update({"loss": loss, "acc1": acc1, "acc5": acc5})

        return metrics

    def training_step(self, batch: torch.Tensor, batch_idx: int) -> torch.Tensor:
        """Performs the training step for the linear eval.

        Args:
            batch (torch.Tensor): a batch of images in the tensor format.
            batch_idx (int): the index of the batch.

        Returns:
            torch.Tensor: cross-entropy loss between the predictions and the ground truth.
        """

        # set backbone to eval mode
        if not self.finetune:
            self.backbone.eval()

        out = self.shared_step(batch, batch_idx)

        log = {"train_loss": out["loss"]}
        if self.mixup_func is None:
            log.update({"train_acc1": out["acc1"], "train_acc5": out["acc5"]})

        self.log_dict(log, on_epoch=True, sync_dist=True)
        return out["loss"]

    def validation_step(self, batch: torch.Tensor, batch_idx: int) -> Dict[str, Any]:
        """Performs the validation step for the linear eval.

        Args:
            batch (torch.Tensor): a batch of images in the tensor format.
            batch_idx (int): the index of the batch.

        Returns:
            Dict[str, Any]:
                dict with the batch_size (used for averaging),
                the classification loss and accuracies.
        """

        out = self.shared_step(batch, batch_idx)

        metrics = {
            "batch_size": out["batch_size"],
            "val_loss": out["loss"],
            "val_acc1": out["acc1"],
            "val_acc5": out["acc5"],
        }
        self.validation_step_outputs.append(metrics)
        return metrics

    def on_validation_epoch_end(self):
        """Averages the losses and accuracies of all the validation batches.
        This is needed because the last batch can be smaller than the others,
        slightly skewing the metrics.
        """

        val_loss = weighted_mean(self.validation_step_outputs, "val_loss", "batch_size")
        val_acc1 = weighted_mean(self.validation_step_outputs, "val_acc1", "batch_size")
        val_acc5 = weighted_mean(self.validation_step_outputs, "val_acc5", "batch_size")
        self.validation_step_outputs.clear()

        log = {"val_loss": val_loss, "val_acc1": val_acc1, "val_acc5": val_acc5}
        self.log_dict(log, sync_dist=True)