import pytorch_lightning as pl
from pytorch_lightning.callbacks import Callback
from typing import Dict
import math


class GradientDiagnosticsCallback(Callback):
    """Log gradient norms and effective ("preconditioned") update norms.

    Every ``log_every_n_steps`` optimizer steps the callback:

    * before the step, logs the L2 norm of each parameter's raw gradient
      (total, mean, max and per-parameter) and snapshots the parameters
      together with their parameter group's learning rate;
    * after the step, logs the L2 norm of
      ``(param_before - param_after) / lr`` with the same breakdown, i.e.
      the update the optimizer actually applied, rescaled by the learning
      rate.

    Metrics are emitted through ``pl_module.log`` under
    ``opt_{optimizer_idx}/grad_norm*`` and
    ``opt_{optimizer_idx}/precond_norm*``.

    The after-step statistics are produced by whichever of
    ``on_after_optimizer_step`` (explicit call) or ``on_train_batch_end``
    (called by the trainer) runs first after the snapshot was taken.

    NOTE(review): ``on_after_optimizer_step`` is not part of the documented
    Lightning ``Callback`` hook set, so the trainer is not expected to call
    it on its own -- confirm against the installed Lightning version.
    """

    def __init__(self, log_every_n_steps: int = 100):
        """Create the callback.

        Args:
            log_every_n_steps: Diagnostics are computed only when
                ``trainer.global_step`` is a multiple of this value.

        Raises:
            ValueError: If ``log_every_n_steps`` is smaller than 1 (a value
                of 0 would otherwise fail later with ZeroDivisionError).
        """
        super().__init__()
        if log_every_n_steps < 1:
            raise ValueError(
                f"log_every_n_steps must be >= 1, got {log_every_n_steps}"
            )
        self.log_every_n_steps = log_every_n_steps
        # optimizer_idx -> {id(param): {"data": tensor copy, "lr": float}}
        self._params_before_step = {}
        # optimizer_idx -> optimizer whose snapshot is still waiting to be
        # compared against the post-step parameters.
        self._pending_optimizers = {}
        # id(param) -> qualified parameter name, filled by on_train_start
        # (or lazily on first use).
        self._param_to_name = {}

    def on_train_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
        """Index parameter names and drop snapshots left from a prior run."""
        self._index_param_names(pl_module)
        self._params_before_step.clear()
        self._pending_optimizers.clear()

    def on_before_optimizer_step(
        self,
        trainer: pl.Trainer,
        pl_module: pl.LightningModule,
        optimizer,
        optimizer_idx: int = 0,
    ):
        """Log raw gradient norms and snapshot the parameters.

        Does nothing unless ``trainer.global_step`` is a multiple of
        ``log_every_n_steps``.
        """
        if trainer.global_step % self.log_every_n_steps != 0:
            return

        # Fall back to indexing here in case on_train_start never ran for
        # this instance; otherwise per-layer metrics would silently vanish.
        if not self._param_to_name:
            self._index_param_names(pl_module)

        grad_norms = self._compute_gradient_norms_from_optimizer(optimizer)
        self._log_norm_stats(
            pl_module, f"opt_{optimizer_idx}", "grad_norm", grad_norms
        )

        snapshot = {}
        for param_group in optimizer.param_groups:
            lr = param_group["lr"]
            for param in param_group["params"]:
                if param.requires_grad:
                    snapshot[id(param)] = {
                        "data": param.detach().clone(),
                        "lr": lr,
                    }
        self._params_before_step[optimizer_idx] = snapshot
        self._pending_optimizers[optimizer_idx] = optimizer

    def on_after_optimizer_step(
        self,
        trainer: pl.Trainer,
        pl_module: pl.LightningModule,
        optimizer,
        optimizer_idx: int = 0,
    ):
        """Log preconditioned norms for ``optimizer`` if a snapshot exists.

        The presence of a snapshot is the only gate: it is created solely on
        steps selected by ``on_before_optimizer_step``, whereas re-checking
        ``trainer.global_step`` here could disagree with the earlier check
        if the step counter has advanced in between.
        """
        if optimizer_idx not in self._params_before_step:
            return

        self._flush_preconditioned_norms(pl_module, optimizer, optimizer_idx)

    def on_train_batch_end(
        self,
        trainer: pl.Trainer,
        pl_module: pl.LightningModule,
        outputs,
        batch,
        batch_idx,
        *_args,
        **_kwargs,
    ):
        """Consume any snapshot not handled by ``on_after_optimizer_step``.

        Extra positional/keyword arguments are accepted and ignored so the
        hook tolerates signature differences between Lightning versions.
        When several optimizers stepped during the batch, each snapshot is
        compared against the parameters as they are at batch end.
        """
        for optimizer_idx, optimizer in list(self._pending_optimizers.items()):
            if optimizer_idx not in self._params_before_step:
                # Snapshot was removed externally; nothing to compare.
                self._pending_optimizers.pop(optimizer_idx, None)
                continue
            self._flush_preconditioned_norms(pl_module, optimizer, optimizer_idx)

    def _flush_preconditioned_norms(
        self, pl_module: pl.LightningModule, optimizer, optimizer_idx: int
    ) -> None:
        """Compute and log preconditioned norms, then release the snapshot."""
        try:
            precond_norms = self._compute_preconditioned_norms_from_optimizer(
                optimizer, optimizer_idx, pl_module
            )
        finally:
            # The snapshot holds a full copy of the parameters; free it even
            # if the computation above fails.
            self._params_before_step.pop(optimizer_idx, None)
            self._pending_optimizers.pop(optimizer_idx, None)

        self._log_norm_stats(
            pl_module, f"opt_{optimizer_idx}", "precond_norm", precond_norms
        )

    @staticmethod
    def _log_norm_stats(
        pl_module: pl.LightningModule, prefix: str, metric: str, stats: Dict
    ) -> None:
        """Log total/mean/max and per-parameter norms via ``pl_module.log``."""
        pl_module.log(f"{prefix}/{metric}_total", stats["total"])
        pl_module.log(f"{prefix}/{metric}_mean", stats["mean"])
        pl_module.log(f"{prefix}/{metric}_max", stats["max"])

        for name, norm in stats["per_layer"].items():
            # Dots become slashes so the logged key is split per module level.
            log_name = f"{prefix}/{metric}/{name.replace('.', '/')}"
            pl_module.log(log_name, norm)

    @staticmethod
    def _summarize_norms(norms, per_layer: Dict) -> Dict:
        """Build the stats dict (total, mean, max, per_layer) from norms.

        ``total`` is the L2 norm over all entries combined, i.e. the square
        root of the sum of squared per-parameter norms. All aggregates are
        0.0 when ``norms`` is empty.
        """
        total_norm_sq = 0.0
        for value in norms:
            total_norm_sq += value**2

        return {
            "total": math.sqrt(total_norm_sq),
            "mean": sum(norms) / len(norms) if norms else 0.0,
            "max": max(norms) if norms else 0.0,
            "per_layer": per_layer,
        }

    def _compute_gradient_norms_from_optimizer(self, optimizer) -> Dict:
        """Return L2 gradient-norm statistics for the optimizer's parameters.

        Parameters whose ``grad`` is ``None`` are skipped. Parameters that
        cannot be resolved to a name still count towards total/mean/max but
        are omitted from ``per_layer``.

        Returns:
            Dict with keys ``total``, ``mean``, ``max`` (floats) and
            ``per_layer`` (parameter name -> norm).
        """
        norms = []
        per_layer = {}

        for param_group in optimizer.param_groups:
            for param in param_group["params"]:
                if param.grad is None:
                    continue

                param_norm = param.grad.detach().norm(2).item()
                norms.append(param_norm)

                param_name = self._find_param_name(param)
                if param_name is not None:
                    per_layer[param_name] = param_norm

        return self._summarize_norms(norms, per_layer)

    def _compute_preconditioned_norms_from_optimizer(
        self, optimizer, optimizer_idx: int, pl_module: pl.LightningModule
    ) -> Dict:
        """Return L2 norms of the applied update divided by the learning rate.

        For every snapshotted parameter the quantity measured is
        ``(param_before - param_after) / lr`` where ``lr`` is the learning
        rate recorded with the snapshot. Parameters that do not require
        grad, have no snapshot entry, or were recorded with ``lr == 0`` are
        skipped.

        Args:
            optimizer: Optimizer whose ``param_groups`` are inspected.
            optimizer_idx: Key of the snapshot in ``_params_before_step``.
            pl_module: Unused; kept for interface compatibility.

        Returns:
            Dict with keys ``total``, ``mean``, ``max`` (floats) and
            ``per_layer`` (parameter name -> norm).

        Raises:
            KeyError: If no snapshot is stored for ``optimizer_idx``.
        """
        norms = []
        per_layer = {}

        saved_params = self._params_before_step[optimizer_idx]

        for param_group in optimizer.param_groups:
            for param in param_group["params"]:
                if not param.requires_grad:
                    continue

                saved_info = saved_params.get(id(param))
                if saved_info is None:
                    continue

                lr = saved_info["lr"]
                if lr == 0:
                    # Dividing by a zero learning rate is undefined.
                    continue

                param_update = saved_info["data"] - param.detach()
                param_norm = (param_update / lr).norm(2).item()
                norms.append(param_norm)

                param_name = self._find_param_name(param)
                if param_name is not None:
                    per_layer[param_name] = param_norm

        return self._summarize_norms(norms, per_layer)

    def _index_param_names(self, pl_module: pl.LightningModule) -> None:
        """Rebuild the ``id(param) -> name`` map from ``named_parameters``."""
        self._param_to_name = {
            id(param): name for name, param in pl_module.named_parameters()
        }

    def _find_param_name(self, param) -> str | None:
        """Return the indexed name of ``param``, or ``None`` if unknown."""
        return self._param_to_name.get(id(param))
