import math
import torch
from torch.optim import Optimizer

@torch.no_grad()
def row_normalize_gradient(
    grad: torch.Tensor,
    eps: float = 1e-8,
) -> torch.Tensor:
    """Scale every row of ``grad`` by the inverse of its root-mean-square.

    No mean is subtracted: each row is only divided by
    ``sqrt(mean(row ** 2)) + eps``.

    Args:
        grad: Tensor reduced along ``dim=1``; the rest of this file passes
            a 2-D ``[B, D]`` gradient.
        eps: Added to the per-row RMS (after the square root) so that
            all-zero rows do not divide by zero.

    Returns:
        A new tensor with the same shape as ``grad``; ``grad`` itself is
        left untouched.
    """
    # Second moment per row, kept as a column so it broadcasts over dim 1.
    row_mean_sq = grad.pow(2).mean(dim=1, keepdim=True)
    row_rms = row_mean_sq.sqrt() + eps
    return grad / row_rms

# Fast Row Orthogonal STandardization SGD

class SRON(Optimizer):
    """Hybrid optimizer: row-normalized SGD for 2D+ params, Adam-style for the rest.

    Parameters passed as ``sgd_params`` are updated with their gradient
    flattened to ``[rows, -1]`` and passed through
    ``row_normalize_gradient``, with optional (Nesterov) momentum applied
    to the normalized gradient. Parameters passed as ``adamw_params`` get
    a bias-corrected Adam update. Both branches apply decoupled weight
    decay (the parameter is shrunk multiplicatively before the update).

    Args:
        lr: Learning rate shared by both branches.
        wd: Decoupled weight-decay coefficient; ``0`` disables it.
        sgd_params: Iterable of parameters (each with ``ndim >= 2``) for
            the row-normalized branch. May be a generator.
        adamw_params: Iterable of parameters for the Adam-style branch.
            May be a generator or ``None``.
        momentum: EMA coefficient for the normalized-gradient buffer;
            ``0`` disables momentum.
        nesterov: When momentum is enabled, use
            ``g + momentum * buf`` instead of ``buf`` as the update.
        adam_betas: ``(beta1, beta2)`` for the Adam-style branch.
        adam_eps: Added to the Adam denominator.
        scale: Extra multiplier on ``lr`` (both the step and the weight
            decay) in the row-normalized branch only.

    Raises:
        ValueError: If any entry of ``sgd_params`` has fewer than 2 dims.
    """

    def __init__(
        self,
        lr=1e-3,
        wd=0.01,
        sgd_params=None,
        adamw_params=None,
        momentum=0.0,
        nesterov=True,
        adam_betas=(0.9, 0.95),
        adam_eps=1e-8,
        scale=1.0
    ):
        defaults = dict(
            lr=lr,
            wd=wd,
            momentum=momentum,
            nesterov=nesterov,
            scale=scale,
            adam_betas=adam_betas,
            adam_eps=adam_eps,
        )

        # Materialize exactly once: the inputs are walked twice (to build
        # the parameter list and to tag per-parameter state), and a
        # generator would be empty on the second pass.
        sgd_params = list(sgd_params) if sgd_params is not None else []
        adamw_params = list(adamw_params) if adamw_params is not None else []

        # Explicit raise instead of assert so the check survives `python -O`.
        for p in sgd_params:
            if p.ndim < 2:
                raise ValueError(f"Expected 2D+ param, got {p.ndim}D")

        super().__init__(sgd_params + adamw_params, defaults)

        # The branch selector lives in per-parameter state, read by step().
        for p in sgd_params:
            self.state[p]["use_white"] = True
        for p in adamw_params:
            self.state[p]["use_white"] = False

    @torch.no_grad()
    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure: Optional callable that re-evaluates the model and
                returns the loss; it is run with gradients enabled.

        Returns:
            The value returned by ``closure``, or ``None`` without one.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            lr = group["lr"]
            wd = group["wd"]
            momentum = group["momentum"]
            nesterov = group["nesterov"]
            beta1, beta2 = group["adam_betas"]
            eps = group["adam_eps"]
            scale = group["scale"]
            use_momentum = momentum > 0.0

            for p in group["params"]:
                g = p.grad
                if g is None:
                    continue
                state = self.state[p]

                if state["use_white"]:
                    # Collapse trailing dims so each leading-dim slice is
                    # normalized as one row. reshape (not view) also
                    # accepts non-contiguous gradients.
                    if g.ndim > 2:
                        g = g.reshape(g.size(0), -1)

                    g_white = row_normalize_gradient(
                        grad=g,
                    )

                    if use_momentum:
                        if "momentum_buffer" not in state:
                            state["momentum_buffer"] = torch.zeros_like(g_white)
                        buf = state["momentum_buffer"]
                        # Exponential moving average of normalized grads.
                        buf.mul_(momentum).add_(g_white, alpha=1.0 - momentum)

                        if nesterov:
                            g_update = g_white.add(buf, alpha=momentum)
                        else:
                            g_update = buf
                    else:
                        g_update = g_white

                    # Decoupled weight decay, then the step itself.
                    if wd > 0.0:
                        p.data.mul_(1 - lr * scale * wd)
                    # Restore the parameter's own shape: for >2-D params
                    # the flattened [rows, -1] update would not broadcast
                    # onto p.
                    p.data.add_(g_update.reshape(p.shape), alpha=-lr * scale)

                else:
                    # Adam-style update
                    if "step" not in state:
                        state["step"] = 0
                        state["exp_avg"] = torch.zeros_like(g)
                        state["exp_avg_sq"] = torch.zeros_like(g)
                    state["step"] += 1
                    exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]

                    # First and second moment running averages.
                    exp_avg.mul_(beta1).add_(g, alpha=1 - beta1)
                    exp_avg_sq.mul_(beta2).addcmul_(g, g, value=1 - beta2)

                    denom = exp_avg_sq.sqrt().add_(eps)
                    # Bias corrections are folded into the scalar step size.
                    bias_correction1 = 1 - beta1 ** state["step"]
                    bias_correction2 = 1 - beta2 ** state["step"]
                    step_size = lr * math.sqrt(bias_correction2) / bias_correction1

                    if wd > 0.0:
                        p.data.mul_(1 - lr * wd)
                    p.data.addcdiv_(exp_avg, denom, value=-step_size)

        return loss

