import torch
import wandb
from lightning.pytorch.callbacks import Callback
from torch.linalg import vector_norm

class WeightMagnitudeCallback(Callback):
    """Log per-layer and aggregate L2 norms of the module's weights.

    Before each backward pass the callback walks ``pl_module.layers`` (an
    iterable of iterables of layers), collects every layer's ``weight``
    tensor, and sends the norms to ``trainer.logger.experiment.log``.

    NOTE(review): ``experiment.log(dict)`` looks like the wandb run API
    (the file imports ``wandb``) -- confirm that only a wandb-style logger
    is used with this callback.
    """

    @torch.no_grad()
    def on_before_backward(self, trainer, pl_module, loss):
        """Log the L2 norm of the weights.

        Args:
            trainer: Trainer whose ``logger`` receives the metrics. Nothing
                is computed or logged when ``trainer.logger`` is falsy.
            pl_module: Module exposing ``layers``, an iterable of iterables
                of layers (e.g. a list of sequential containers).
            loss: Unused; present to match the hook signature.

        Logged keys:
            ``weight_norm/layer_{i}`` for the i-th layer *that has a
            weight* (layers without one are skipped and do not consume an
            index), plus ``weight_norm/mean``, ``weight_norm/max`` and
            ``weight_norm/min`` over those per-layer norms. Values are
            plain Python floats.
        """
        if not trainer.logger:
            return

        # ``getattr(..., None) is not None`` also filters layers that define
        # the attribute but leave it unset (weight is None), which
        # ``hasattr`` alone would let through and crash on below.
        weights = [
            layer.weight
            for sequential in pl_module.layers
            for layer in sequential
            if getattr(layer, "weight", None) is not None
        ]
        if not weights:
            # torch.stack raises on an empty list; nothing to report.
            return

        # vector_norm with the default ord/dim flattens each tensor and
        # returns its 2-norm. Cast to float32 and move to the host once so
        # the statistics below cost no further device synchronisations and
        # do not depend on reduced-precision CPU kernels.
        weight_norms = torch.stack(
            [vector_norm(w) for w in weights]
        ).float().cpu()

        weight_norm_dict = {
            f"weight_norm/layer_{i}": value
            for i, value in enumerate(weight_norms.tolist())
        }

        trainer.logger.experiment.log({
            "weight_norm/mean": weight_norms.mean().item(),
            "weight_norm/max": weight_norms.max().item(),
            "weight_norm/min": weight_norms.min().item(),
            **weight_norm_dict,
        })
