import lightning.pytorch as pl
import torch

class GradientNormLogger(pl.Callback):
    """Log the global L2 norm of the module's gradients before each optimizer step.

    The value is sent to ``trainer.logger`` under the key
    ``"train/grad_norm/total"`` at step ``trainer.global_step``.
    """

    def on_before_optimizer_step(self, trainer, pl_module, optimizer):
        """Compute the total gradient L2 norm over ``pl_module`` and log it.

        Args:
            trainer: The trainer; its ``logger`` and ``global_step`` are used.
            pl_module: The module whose ``parameters()`` are inspected.
            optimizer: Unused; present to match the hook signature.

        Parameters whose ``grad`` is ``None`` are skipped. If no parameter
        has a gradient, ``0.0`` is logged. Nothing is logged when the
        trainer has no logger attached.
        """
        logger = trainer.logger
        if logger is None:
            # No logger attached (e.g. ``Trainer(logger=False)``): avoid
            # raising AttributeError in the middle of training.
            return

        # Per-parameter L2 norms, kept as tensors so that no host
        # synchronization happens inside the loop.
        norms = [
            p.grad.detach().norm(2)
            for p in pl_module.parameters()
            if p.grad is not None
        ]

        if norms:
            # Gather the scalars on one device / dtype so they can be stacked
            # even if parameters live on different devices or use mixed
            # precision.
            device = norms[0].device
            stacked = torch.stack(
                [n.to(device=device, dtype=torch.float32) for n in norms]
            )
            # The L2 norm of the per-parameter L2 norms equals the L2 norm of
            # all gradients concatenated; a single .item() call syncs once.
            total_norm = torch.linalg.vector_norm(stacked, 2).item()
        else:
            total_norm = 0.0

        # Log the total gradient norm
        logger.log_metrics(
            {"train/grad_norm/total": total_norm}, step=trainer.global_step
        )
