import math
import torch
from torch.optim import Optimizer

class GradientStabilizer:
    """Rescale each parameter's gradient norm from running norm statistics.

    For every parameter with a dense gradient ``g`` two linear exponential
    moving averages (EMAs) of the gradient's L2 norm are kept (the comments
    in this file call the scheme "AdaGN"):

      * ``mnorm`` -- EMA of ``||g||``   with decay ``gamma1``
      * ``vnorm`` -- EMA of ``||g||^2`` with decay ``gamma2``

    The gradient is then rewritten in place as::

        g <- g / max(||g||, eps) * m_hat / (sqrt(v_hat) + eps)

    where ``m_hat`` / ``v_hat`` are the (optionally bias-corrected) EMAs.

    Notes:
      * Bias-correction is applied only for *linear* EMA (standard 1 - gamma^step).
      * Statistics are stored per parameter object, keyed by ``id(p)``.
      * Parameters without a gradient, sparse gradients, and gradients whose
        norm is zero, non-finite, or whose squared norm overflows float32 are
        left untouched and do not update the statistics.
      * With ``use_optimizer_step=True`` the bias-correction step is taken
        from ``optimizer.state[p]["step"]`` when that entry exists; it can
        therefore differ from the number of EMA updates performed here
        (e.g. when the stabilizer is attached to an already-trained
        optimizer).

    Args:
        gamma1: EMA decay for ``||g||``.
        gamma2: EMA decay for ``||g||^2``.
        eps: Added to ``sqrt(v_hat)`` and used as the lower clamp of the
            norm the gradient is divided by.
        bias_correction: Divide each EMA by ``1 - gamma**step``.
        use_optimizer_step: Prefer the optimizer's per-parameter ``"step"``
            over the internal update counter for bias correction.
    """

    def __init__(
        self,
        gamma1: float = 0.7,     # AdaGN EMA coef for ||g||
        gamma2: float = 0.999,     # AdaGN EMA coef for ||g||^2
        eps: float = 1e-12,
        bias_correction: bool = True,       # applies to linear EMA only
        use_optimizer_step: bool = True,
    ):
        self.gamma1 = float(gamma1)
        self.gamma2 = float(gamma2)
        self.eps = eps
        self.bias_correction = bias_correction
        self.use_optimizer_step = use_optimizer_step

        print(
            "gamma1, gamma2", self.gamma1, self.gamma2,
        )

        # per-parameter state: id(param) -> {"mnorm", "vnorm", "step"}
        self.state = {}

    def reset(self):
        """Forget all per-parameter statistics."""
        self.state.clear()

    def _get_state(self, key: int, device, dtype=torch.float32):
        """Return the statistics dict for ``key``, creating it on first use."""
        st = self.state.get(key)
        if st is None:
            st = {
                # linear stores
                "mnorm": torch.zeros((), device=device, dtype=dtype),   # E[||g||]
                "vnorm": torch.zeros((), device=device, dtype=dtype),   # E[||g||^2]
                "step": 0,  # number of EMA updates applied to this entry
            }
            self.state[key] = st
        return st

    @staticmethod
    def _ema_linear_(accum: torch.Tensor, new: torch.Tensor, gamma: float):
        """In place: ``accum <- gamma * accum + (1 - gamma) * new``."""
        accum.mul_(gamma).add_(new, alpha=1.0 - gamma)
        return accum

    def _maybe_bias_correct(self, val: torch.Tensor, gamma: float, step: int):
        """
        For linear EMA: val_hat = val / (1 - gamma**step).

        Returns ``val`` itself (not a copy) when bias correction is disabled.
        """
        if not self.bias_correction:
            return val
        bc = 1.0 - (gamma ** step)
        return val / bc

    def _resolve_step(self, p, st, opt_state) -> int:
        """Advance the internal counter and return the bias-correction step.

        The optimizer's own ``"step"`` is used only when it is actually
        present for ``p``; optimizer states that hold other entries but no
        step counter fall back to the internal count instead of being
        treated as "step 1" forever.
        """
        st["step"] += 1
        if self.use_optimizer_step and opt_state is not None and p in opt_state:
            opt_step = opt_state[p].get("step")
            if opt_step is not None:
                # The optimizer has not stepped yet for this call, hence +1.
                return max(int(opt_step) + 1, 1)
        return int(st["step"])

    @torch.no_grad()
    def _scale_param(self, p: torch.nn.Parameter, opt_state=None):
        """Update the statistics for ``p`` and rescale ``p.grad`` in place."""
        g = p.grad
        if g is None or g.is_sparse:
            return

        # Tensor.norm handles arbitrary strides, so no flattened copy is needed.
        g_norm = g.norm(p=2).to(torch.float32)
        g_sq = g_norm * g_norm

        # Skip before touching any state: a zero / non-finite norm (or a
        # squared norm that overflows float32) would otherwise poison the
        # EMAs, and counting the call would desynchronize bias correction.
        if not torch.isfinite(g_sq) or g_norm <= 0:
            return

        st = self._get_state(id(p), device=g.device, dtype=torch.float32)
        step = self._resolve_step(p, st, opt_state)

        # Update E[||g||] and E[||g||^2]
        self._ema_linear_(st["mnorm"], g_norm, self.gamma1)
        self._ema_linear_(st["vnorm"], g_sq, self.gamma2)
        m_est = self._maybe_bias_correct(st["mnorm"], self.gamma1, step)
        v_est = self._maybe_bias_correct(st["vnorm"], self.gamma2, step)

        # target norm = m_hat / (sqrt(v_hat) + eps)
        target = (m_est / (v_est.sqrt() + self.eps)).to(g.dtype)

        # Normalize the gradient to unit norm, then scale to the target.
        # Operating on ``g`` itself guarantees p.grad is the tensor modified,
        # including when it is non-contiguous.
        divisor = torch.clamp(g_norm.to(g.dtype), min=self.eps)
        g.div_(divisor).mul_(target)

    @torch.no_grad()
    def __call__(self, obj, opt_state=None):
        """Stabilize the gradients of every parameter reachable from ``obj``.

        Args:
            obj: An ``Optimizer``, a non-empty list/tuple of param-group
                dicts (each with a ``"params"`` entry), a single tensor /
                parameter, or any iterable of parameters.
            opt_state: Optional mapping ``param -> optimizer state dict``
                used to look up the bias-correction step. Defaults to
                ``obj.state`` when ``obj`` is an ``Optimizer``.
        """
        if isinstance(obj, Optimizer):
            if opt_state is None:
                opt_state = obj.state
            groups = obj.param_groups
        elif isinstance(obj, torch.Tensor):
            # A lone parameter: iterating it would walk its rows instead.
            groups = [{"params": [obj]}]
        elif isinstance(obj, (list, tuple)) and len(obj) > 0 and isinstance(obj[0], dict):
            groups = obj
        else:
            groups = [{"params": obj}]

        for group in groups:
            for p in group["params"]:
                self._scale_param(p, opt_state=opt_state)


class GSWrapper:
    """Wrap an optimizer so gradients are stabilized before each update.

    ``step()`` runs a ``GradientStabilizer`` over the wrapped optimizer's
    parameters and then delegates to the wrapped optimizer. ``param_groups``,
    ``state`` and ``zero_grad`` are forwarded unchanged.

    Args:
        optimizer: The optimizer to wrap.
        **scale_kwargs: Forwarded to ``GradientStabilizer``.
    """

    def __init__(self, optimizer: Optimizer, **scale_kwargs):
        self.optimizer = optimizer
        self.scaler = GradientStabilizer(**scale_kwargs)

    @property
    def param_groups(self):
        """Parameter groups of the wrapped optimizer."""
        return self.optimizer.param_groups

    @property
    def state(self):
        """Per-parameter state of the wrapped optimizer."""
        return self.optimizer.state

    def zero_grad(self, set_to_none: bool = False):
        """Clear gradients through the wrapped optimizer."""
        self.optimizer.zero_grad(set_to_none=set_to_none)

    def step(self, closure=None):
        """Stabilize gradients, then perform one optimizer step.

        When ``closure`` is given it is wrapped so that the stabilizer runs
        *after* every closure evaluation: the closure is what (re)computes
        the gradients, so scaling beforehand would be overwritten.

        Returns:
            Whatever the wrapped optimizer's ``step`` returns.
        """
        if closure is None:
            self.scaler(self.optimizer)
            return self.optimizer.step()

        def stabilized_closure():
            loss = closure()
            self.scaler(self.optimizer)
            return loss

        return self.optimizer.step(stabilized_closure)

    def _ordered_params(self):
        """Parameters in param-group order; their position is the save key."""
        return [p for group in self.optimizer.param_groups for p in group["params"]]

    def state_dict(self):
        """Return optimizer and stabilizer state.

        The stabilizer keeps its statistics keyed by ``id(param)``, which is
        meaningless in another process, so entries are saved under the
        parameter's position in ``param_groups`` instead. Entries whose
        parameter is no longer in the optimizer are dropped.
        """
        position = {id(p): i for i, p in enumerate(self._ordered_params())}
        scaler_state = {
            position[key]: st
            for key, st in self.scaler.state.items()
            if key in position
        }
        return {
            "optimizer": self.optimizer.state_dict(),
            "scaler_state": scaler_state,
            "scaler_state_format": "param_index",
        }

    def load_state_dict(self, state_dict):
        """Restore optimizer and stabilizer state.

        Accepts both the position-keyed layout produced by ``state_dict``
        and the legacy layout (no ``"scaler_state_format"`` entry), which is
        keyed by ``id(param)`` and installed as-is.
        """
        self.optimizer.load_state_dict(state_dict["optimizer"])
        saved = state_dict.get("scaler_state", {})

        if state_dict.get("scaler_state_format") != "param_index":
            # Legacy layout: only matches parameters of the process that saved it.
            self.scaler.state = saved
            return

        params = self._ordered_params()
        restored = {}
        for index, st in saved.items():
            index = int(index)
            if not 0 <= index < len(params):
                continue
            p = params[index]
            # Copy so the checkpoint dict is not aliased, and place the
            # statistics on the parameter's device.
            restored[id(p)] = {
                name: value.detach().clone().to(p.device) if torch.is_tensor(value) else value
                for name, value in st.items()
            }
        self.scaler.state = restored
"""
Usage: optimizer = GSWrapper(optimizer, gamma1=0.6, gamma2=0.999)
"""
