import math
import torch
import torch.nn.functional as F
from torch.optim import Optimizer
from torch.cuda.amp import autocast
from torch import Tensor

import math
import torch

@torch.no_grad()
def row_normalize_gradient(
    grad: torch.Tensor,
    eps: float = 1e-8,
) -> torch.Tensor:
    """Scale a gradient by its root-mean-square (RMS) magnitude.

    No mean is subtracted; the gradient is only divided by
    ``sqrt(mean(grad ** 2)) + eps``.

    * 0-dim or 1-dim ``grad``: a single RMS over all elements is used.
    * ``grad`` with 2 or more dims (typically ``[B, D]``): the RMS is taken
      along ``dim=1`` and broadcast back, so for a 2-D input every row is
      scaled independently.

    Args:
        grad: Gradient tensor to normalize. It is not modified.
        eps: Added to the RMS (after the square root) so that all-zero
            slices divide by ``eps`` instead of zero.

    Returns:
        A new tensor with the same shape as ``grad``.
    """
    squared = grad.pow(2)
    if grad.ndim <= 1:
        # Reducing over every element also covers 0-dim gradients, for which
        # a reduction along dim=1 would raise an IndexError.
        rms = squared.mean().sqrt()
    else:
        # keepdim=True leaves a size-1 axis at dim=1 so the division below
        # broadcasts one scale per row.
        rms = squared.mean(dim=1, keepdim=True).sqrt()

    return grad / rms.add(eps)

from torch.optim.optimizer import Optimizer

class SRONSGD(Optimizer):
    """SGD on row-normalized gradients with optional momentum and weight decay.

    For every parameter that has a gradient, ``step`` flattens gradients with
    more than two dimensions to ``[size(0), -1]``, rescales the result with
    ``row_normalize_gradient``, optionally blends it into a moving-average
    buffer, and then applies ``p <- p * (1 - lr * wd) - lr * update``.

    Args:
        params: Iterable of parameters or of parameter-group dicts.
        lr: Learning rate.
        wd: Weight-decay coefficient, applied directly to the parameter
            (not added to the gradient). Values ``<= 0`` disable decay.
        momentum: Moving-average coefficient. Values ``<= 0`` disable the
            momentum buffer entirely.
        nesterov: Only used when momentum is enabled. If true the update is
            ``g + momentum * buf``; otherwise the buffer itself is used.
    """

    def __init__(
        self,
        params,
        lr=1e-3,
        wd=0.01,
        momentum=0.0,
        nesterov=True,
    ):
        # Only hyperparameters belong in ``defaults``. The parameter iterable
        # is handed to the base class separately; keeping it in ``defaults``
        # would store a (possibly exhausted) generator on the optimizer.
        defaults = dict(
            lr=lr,
            wd=wd,
            momentum=momentum,
            nesterov=nesterov,
        )
        if momentum > 0.0:
            print("Enable momentum")

        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        """Perform one optimization step.

        Args:
            closure: Optional callable that re-evaluates the model and
                returns the loss. It is invoked with gradients enabled.

        Returns:
            The value returned by ``closure``, or ``None`` if none was given.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            lr = group["lr"]
            wd = group["wd"]
            momentum = group["momentum"]
            nesterov = group["nesterov"]
            use_momentum = momentum > 0.0

            for p in group["params"]:
                g = p.grad
                if g is None:
                    continue
                state = self.state[p]

                if g.ndim > 2:
                    # Treat dim 0 as rows and everything else as columns.
                    # reshape (not view) also accepts non-contiguous grads.
                    g = g.reshape(g.size(0), -1)

                g_white = row_normalize_gradient(
                    grad=g,
                )

                if use_momentum:
                    if "momentum_buffer" not in state:
                        # The buffer lives in the flattened gradient layout.
                        state["momentum_buffer"] = torch.zeros_like(g_white)
                    buf = state["momentum_buffer"]
                    # Exponential moving average of the normalized gradient.
                    buf.mul_(momentum).add_(g_white, alpha=1.0 - momentum)

                    if nesterov:
                        g_update = g_white.add(buf, alpha=momentum)
                    else:
                        g_update = buf
                else:
                    g_update = g_white

                if wd > 0.0:
                    p.mul_(1 - lr * wd)
                # The update may have been flattened above; restore the
                # parameter's own shape so the in-place add lines up
                # element-for-element instead of relying on broadcasting.
                p.add_(g_update.reshape(p.shape), alpha=-lr)

        return loss

