import deepspeed
import pytorch_lightning as pl
import torch
from pytorch_lightning import Callback


class DetectGradientBurst(Callback):
    """Report parameters whose gradients contain NaN or Inf values.

    After every backward pass during training, each named parameter of
    ``pl_module.model`` (or of ``pl_module`` itself when it has no ``model``
    attribute) is inspected. Offending parameters are printed and logged to
    every attached logger as metrics ``gradNaN/rank_<r>/<name>`` and
    ``gradInf/rank_<r>/<name>`` with value 1, where ``<r>`` is the process's
    global rank.
    """

    def __init__(self):
        super().__init__()

    @staticmethod
    def _full_grad(param):
        """Return the gradient of ``param``, assembling it via DeepSpeed if needed.

        Returns None when no gradient is available for the parameter.
        """
        if param.grad is not None:
            return param.grad
        # The gradient is not materialized on this parameter (e.g. partitioned by
        # DeepSpeed ZeRO, even on a single device), so ask DeepSpeed to assemble
        # it. NOTE(review): this may be a collective under ZeRO; it is invoked
        # for the same parameters in the same order on every rank.
        return deepspeed.utils.safe_get_full_grad(param)

    def on_after_backward(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"):
        """Scan gradients after ``backward`` and report any NaN/Inf occurrences.

        Args:
            trainer: The running trainer. Only acts while ``trainer.training``.
            pl_module: The module being trained. Its ``model`` attribute is
                scanned when present; otherwise the module itself is.
        """
        if not trainer.training:
            return

        # global_rank is unique across all nodes (local_rank repeats on every
        # node) and is 0 in a single-process run.
        rank = trainer.global_rank
        model = getattr(pl_module, "model", pl_module)

        stats = {}
        for name, param in model.named_parameters():
            grad = self._full_grad(param)
            if grad is None:
                continue
            if torch.isnan(grad).any():
                stats[f"gradNaN/rank_{rank}/{name}"] = 1
                print(f"# NaN in grad rank {rank}: {name}")
            if torch.isinf(grad).any():
                stats[f"gradInf/rank_{rank}/{name}"] = 1
                print(f"# Inf in grad rank {rank} {name}")

        # Only log when something was detected, so clean steps don't push
        # empty metric rows to every logger.
        if stats:
            for logger in trainer.loggers:
                logger.log_metrics(stats, step=trainer.global_step)
