"""Gradient monitoring callback for PyTorch Lightning training."""

import torch
import pytorch_lightning as pl
from pytorch_lightning.callbacks import Callback
import warnings
import typing as tp


class GradientMonitorCallback(Callback):
    """Full gradient monitoring with per-parameter tracking and anomaly detection.

    Features:
    - Per-parameter gradient norm, mean, std, max tracking
    - NaN/Inf detection with console warnings
    - Gradient histograms (WandB only)
    - Parameter group tracking (e.g., ssm, conv, proj)
    - Explosion alerts when norm > threshold

    Statistics are read in ``on_after_backward``, i.e. before the optimizer
    step and before any gradient clipping.

    NOTE(review): per the Lightning hook documentation, gradients are still
    multiplied by the loss scale at this point when native fp16 AMP is used,
    so ``grad_norm_threshold`` is compared against scaled values in that
    setup -- confirm against the precision plugin in use.

    Args:
        log_every_n_steps: How often to log gradient stats (default: 50).
            Must be >= 1.
        log_histograms: Whether to log gradient histograms to WandB (default: True)
        detect_anomalies: Whether to detect NaN/Inf/explosion (default: True)
        grad_norm_threshold: Threshold for gradient explosion alert (default: 100.0)
        parameter_groups: Dict mapping group names to parameter name patterns
            e.g., {'ssm': ['A_log', 'h_', 'q_'], 'conv': ['conv']}.
            A parameter belongs to a group when any pattern is a substring
            of its name.

    Raises:
        ValueError: If ``log_every_n_steps`` is smaller than 1.
    """

    def __init__(
        self,
        log_every_n_steps: int = 50,
        log_histograms: bool = True,
        detect_anomalies: bool = True,
        grad_norm_threshold: float = 100.0,
        parameter_groups: tp.Optional[tp.Dict[str, tp.List[str]]] = None,
    ):
        super().__init__()
        # Fail at construction instead of with a ZeroDivisionError in the
        # middle of training (the value is used as a modulus below).
        if log_every_n_steps < 1:
            raise ValueError(
                f"log_every_n_steps must be >= 1, got {log_every_n_steps}"
            )
        self.log_every_n_steps = log_every_n_steps
        self.log_histograms = log_histograms
        self.detect_anomalies = detect_anomalies
        self.grad_norm_threshold = grad_norm_threshold
        self.parameter_groups = parameter_groups or {}
        # Histogram logging is best-effort; report the first failure only so
        # a persistent problem does not flood the console.
        self._histogram_warning_emitted = False

    def on_after_backward(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
        """Collect and log gradient statistics after ``loss.backward()``.

        Runs only when ``trainer.global_step`` is a multiple of
        ``log_every_n_steps``. Logs per-parameter norm/mean/std/max, the
        global norm, one norm per configured parameter group, and -- when
        anomalies were found -- a warning plus ``grad_anomaly_count``.
        """
        if trainer.global_step % self.log_every_n_steps != 0:
            return

        grad_stats: tp.Dict[str, float] = {}
        anomalies: tp.List[str] = []
        # Per-parameter norm tensors, reused for the global and group norms
        # so each gradient is reduced only once.
        norms_by_name: tp.Dict[str, torch.Tensor] = {}

        # Per-parameter tracking
        for name, param in pl_module.named_parameters():
            grad = param.grad
            if grad is None:
                continue
            # Zero-element tensors have no statistics (max() would raise)
            # and contribute nothing to any norm.
            if grad.numel() == 0:
                continue

            norm_tensor = grad.norm()
            norms_by_name[name] = norm_tensor
            grad_norm = norm_tensor.item()

            # NaN/Inf detection
            if self.detect_anomalies:
                if torch.isnan(grad).any():
                    anomalies.append(f"{name}: NaN gradients")
                if torch.isinf(grad).any():
                    anomalies.append(f"{name}: Inf gradients")
                if grad_norm > self.grad_norm_threshold:
                    anomalies.append(f"{name}: grad_norm={grad_norm:.2f} > threshold={self.grad_norm_threshold}")

            # Log per-parameter stats
            grad_stats[f"grad_norm/{name}"] = grad_norm
            grad_stats[f"grad_mean/{name}"] = grad.mean().item()
            # The sample std of a single value is undefined (torch returns
            # NaN); report 0.0 so scalar parameters do not log NaN.
            grad_stats[f"grad_std/{name}"] = grad.std().item() if grad.numel() > 1 else 0.0
            grad_stats[f"grad_max/{name}"] = grad.abs().max().item()

            # Histograms (WandB only)
            if self.log_histograms:
                self._log_histogram(trainer, name, grad)

        # Global gradient norm (without clipping): the L2 norm of the
        # per-parameter L2 norms equals the L2 norm of all gradients.
        if norms_by_name:
            total_norm = torch.stack(list(norms_by_name.values())).norm().item()
            grad_stats["grad_norm/global"] = total_norm

        # Parameter group norms
        for group_name, patterns in self.parameter_groups.items():
            group_norms = [
                norm
                for name, norm in norms_by_name.items()
                if any(pat in name for pat in patterns)
            ]
            if group_norms:
                group_norm = torch.stack(group_norms).norm().item()
                grad_stats[f"grad_norm/group_{group_name}"] = group_norm

        # Log all stats
        if grad_stats:
            pl_module.log_dict(grad_stats, on_step=True, on_epoch=False)

        # Alert on anomalies
        if anomalies:
            warning_msg = f"Step {trainer.global_step} gradient anomalies:\n" + "\n".join(anomalies)
            warnings.warn(warning_msg)
            pl_module.log("grad_anomaly_count", float(len(anomalies)), on_step=True, on_epoch=False)

    def _log_histogram(self, trainer: pl.Trainer, name: str, grad: torch.Tensor) -> None:
        """Log a gradient histogram to the first WandB run among the trainer's loggers.

        Does nothing when ``wandb`` is not installed or no logger exposes a
        WandB run. Any other failure is reported once via ``warnings.warn``
        and otherwise ignored, so histogram logging never interrupts training.
        """
        try:
            import wandb
        except ImportError:
            return

        try:
            for logger in trainer.loggers:
                run = getattr(logger, "experiment", None)
                if not isinstance(run, wandb.sdk.wandb_run.Run):
                    continue

                # Upcast first: numpy has no bfloat16, so converting a bf16
                # gradient directly would raise.
                values = grad.detach().float().flatten()
                # Histogram binning fails on NaN/Inf; those are reported
                # separately by the anomaly detection.
                values = values[torch.isfinite(values)]
                if values.numel() > 0:
                    # commit=False attaches the histogram to the row that
                    # the next committed log call writes, instead of
                    # advancing the WandB step once per parameter.
                    run.log(
                        {f"grad_hist/{name}": wandb.Histogram(values.cpu().numpy())},
                        commit=False,
                    )
                break
        except Exception as exc:  # best-effort: never break training
            if not self._histogram_warning_emitted:
                self._histogram_warning_emitted = True
                warnings.warn(
                    f"Gradient histogram logging failed for '{name}' "
                    f"({type(exc).__name__}: {exc}); further histogram errors are suppressed."
                )
