from transformers import TrainerCallback
import matplotlib.pyplot as plt

class GradientNormCallback(TrainerCallback):
    """Record the global L2 norm of the model's gradients during training.

    Each recorded value is the square root of the sum of squared per-parameter
    gradient L2 norms, taken over every parameter whose ``.grad`` is set.
    Values accumulate in ``gradient_norms`` until ``get_mean_and_reset`` is
    called.

    NOTE(review): ``on_training_step_end`` is not one of the events documented
    for ``TrainerCallback``, so a stock ``Trainer`` is not expected to dispatch
    it. ``on_pre_optimizer_step`` is provided as the documented entry point;
    ``on_training_step_end`` is kept so existing explicit callers keep
    working. If a custom training loop dispatches both events, each dispatch
    records its own value.
    """

    def __init__(self):
        super().__init__()
        # Gradient norms (floats) recorded since the last get_mean_and_reset().
        self.gradient_norms = []

    @staticmethod
    def _global_grad_norm(model) -> float:
        """Return the L2 norm over all gradients currently stored on *model*.

        Parameters whose ``.grad`` is ``None`` are skipped; the result is 0.0
        when no parameter has a gradient.
        """
        squared_sum = 0.0
        for param in model.parameters():
            grad = param.grad
            if grad is None:
                continue
            # detach() replaces the legacy ``.data`` access; the value read is
            # the same and nothing is added to the autograd graph.
            squared_sum += grad.detach().norm(2).item() ** 2
        return squared_sum ** 0.5

    def on_pre_optimizer_step(self, args, state, control, **kwargs):
        """Record the gradient norm just before the optimizer step.

        Per the ``TrainerCallback`` documentation this event is dispatched
        before the optimizer step and after gradient clipping, i.e. while the
        gradients are still populated. Releases of ``transformers`` that do not
        define the event simply never call this method.
        """
        self.on_training_step_end(args, state, control, **kwargs)

    def on_training_step_end(self, args, state, control, **kwargs):
        """Compute the current global gradient norm and append it to the log.

        Args:
            args, state, control: Passed through by the caller; not read here.
            **kwargs: Must contain ``model``, an object exposing
                ``parameters()``.

        Raises:
            KeyError: If ``model`` is not present in ``kwargs``.
        """
        model = kwargs['model']
        self.gradient_norms.append(self._global_grad_norm(model))

    def get_mean_and_reset(self) -> float:
        """Return the mean of the recorded norms and clear the record.

        Returns:
            The arithmetic mean of ``gradient_norms``, or 0.0 when nothing has
            been recorded since the last reset.
        """
        if not self.gradient_norms:
            return 0.0
        mean = sum(self.gradient_norms) / len(self.gradient_norms)
        # Rebind rather than clear in place, so a list object a caller grabbed
        # earlier keeps its contents.
        self.gradient_norms = []
        return mean