import lightning.pytorch as pl
import torch
import numpy as np

from src.callbacks.auto_clip import AutoClip

class AutoClipCallback(pl.Callback):
    """Lightning callback that clips gradients to a threshold supplied by ``AutoClip``.

    Before every optimizer step the callback:

    1. computes the global L2 norm of all parameter gradients of the module,
    2. feeds that norm into the wrapped ``AutoClip`` instance's history,
    3. asks ``AutoClip`` for the current clip value,
    4. logs both numbers to the trainer's logger (if one is configured), and
    5. clips the gradients in place when the clip value is finite.

    The policy that turns the norm history into a clip value lives entirely
    in ``AutoClip``; this class only wires it into the training loop.
    """

    def __init__(self,
                 grad_clip_enabled: bool = True,
                 auto_clip_enabled: bool = True,
                 default_clip_value: float = np.inf,
                 percentile: float = 0.9,
                 hist_len: int = 100,
                 min_hist_len: int = 10) -> None:
        """Create the callback and the ``AutoClip`` instance it drives.

        All arguments are forwarded unchanged to ``AutoClip``; their exact
        semantics are defined there.
        NOTE(review): the descriptions below are inferred from the argument
        names -- confirm against ``AutoClip``.

        Args:
            grad_clip_enabled: Forwarded as ``grad_clip_enabled``
                (presumably the master switch for clipping).
            auto_clip_enabled: Forwarded as ``auto_clip_enabled``
                (presumably enables the history-based threshold).
            default_clip_value: Forwarded as ``default_clip_value``. With the
                default of ``np.inf`` this callback skips clipping whenever
                ``AutoClip`` returns that value.
            percentile: Forwarded as ``auto_clip_percentile``.
            hist_len: Forwarded as ``auto_clip_hist_len``.
            min_hist_len: Forwarded as ``auto_clip_min_len``.
        """
        super().__init__()
        self.auto_clip = AutoClip(
            grad_clip_enabled=grad_clip_enabled,
            auto_clip_enabled=auto_clip_enabled,
            default_clip_value=default_clip_value,
            auto_clip_percentile=percentile,
            auto_clip_hist_len=hist_len,
            auto_clip_min_len=min_hist_len
        )

    @staticmethod
    def _total_grad_norm(parameters) -> float:
        """Return the global L2 norm over the gradients of ``parameters``.

        Parameters without a gradient are ignored; if none has a gradient the
        result is ``0.0``.
        """
        grads = [p.grad.detach() for p in parameters if p.grad is not None]
        if not grads:
            return 0.0

        # Gather the per-parameter norms on one device and in one dtype so
        # they can be stacked even when the model is sharded across devices
        # or mixes precisions.
        device = grads[0].device
        per_param_norms = torch.stack(
            [g.norm(2).float().to(device) for g in grads]
        )

        # The L2 norm of the per-parameter L2 norms equals the L2 norm of all
        # gradient elements. A single ``.item()`` means a single host/device
        # synchronisation instead of one per parameter.
        return per_param_norms.norm(2).item()

    def on_before_optimizer_step(self,
                                 trainer: "pl.Trainer",
                                 pl_module: "pl.LightningModule",
                                 optimizer: "torch.optim.Optimizer") -> None:
        """Record the gradient norm, log it, and clip gradients in place.

        Args:
            trainer: The running trainer; its ``logger`` and ``global_step``
                are used for metric logging.
            pl_module: The module whose parameter gradients are measured and
                clipped.
            optimizer: The optimizer about to step (unused).
        """
        # Compute the total gradient norm for the entire network
        total_norm = self._total_grad_norm(pl_module.parameters())

        # Update the history first so the clip value returned below already
        # takes the current step into account.
        self.auto_clip.update_gradient_norm_history(total_norm)
        clip_value = self.auto_clip.get_clip_value()

        # ``trainer.logger`` is None when the Trainer runs without a logger;
        # skip logging in that case instead of raising AttributeError.
        logger = trainer.logger
        if logger is not None:
            logger.log_metrics(
                {
                    "train/grad_norm/total": total_norm,
                    "train/grad_norm/clip": clip_value,
                },
                step=trainer.global_step,
            )

        # Apply gradient clipping only for a finite threshold: ``inf`` means
        # "no clipping", and a NaN threshold would turn every gradient into
        # NaN inside ``clip_grad_norm_``.
        if np.isfinite(clip_value):
            torch.nn.utils.clip_grad_norm_(pl_module.parameters(), clip_value)

