# Adapted from https://github.com/AGI-Arena/MARS/tree/main
import torch

def batched_zeropower_newtonschulz5(G, steps=10, eps=1e-7):
    """Run a quintic Newton-Schulz iteration on a batch of square matrices.

    Each matrix is first scaled to unit Frobenius norm. Then ``steps`` rounds of
    ``X <- a*X + b*(X X^T) X + c*(X X^T)^2 X`` are applied. The coefficients are
    the tuned Muon values, which push singular values approximately (not
    exactly) toward 1.

    Args:
        G: Tensor of shape [k, d, d] (a batch of k square matrices).
        steps: Number of Newton-Schulz iterations to apply.
        eps: Added to each matrix's Frobenius norm before dividing, so an
            all-zero matrix does not produce a division by zero.

    Returns:
        Tensor of shape [k, d, d] holding the iterated matrices.
    """
    assert G.ndim == 3 and G.size(1) == G.size(2), "G must be [k, d, d] with square matrices"
    coef_a, coef_b, coef_c = 3.4445, -4.7750, 2.0315

    # Normalize each matrix independently so the iteration starts from a
    # small-norm input.
    frob = G.norm(dim=(1, 2), keepdim=True)
    X = G / (frob + eps)

    for _ in range(steps):
        gram = X @ X.transpose(-2, -1)    # X X^T
        gram_x = gram @ X                 # (X X^T) X
        X = coef_a * X + coef_b * gram_x + coef_c * (gram @ gram_x)

    return X

def batched_zeropower_newtonschulz5_rect(G, steps=10, eps=1e-7):
    """Apply the quintic Newton-Schulz iteration to a batch of rectangular matrices.

    Each matrix in ``G`` is scaled to unit Frobenius norm. Tall matrices
    (d1 > d2) are transposed so the iteration always runs on the wide
    orientation, where the Gram product ``X @ X^T`` is the smaller
    [min(d1, d2), min(d1, d2)] matrix. The result is transposed back at the end.

    Args:
        G: Tensor of shape [k, d1, d2].
        steps: Number of Newton-Schulz iterations.
        eps: Added to each Frobenius norm before dividing, so an all-zero
            matrix does not cause a division by zero.

    Returns:
        Tensor with the same shape as ``G``.
    """
    assert G.ndim == 3, "Expected [k, d1, d2]"
    coef_a, coef_b, coef_c = 3.4445, -4.7750, 2.0315

    X = G / (G.norm(dim=(1, 2), keepdim=True) + eps)

    tall = X.size(1) > X.size(2)
    if tall:
        X = X.transpose(1, 2)

    for _ in range(steps):
        gram = X @ X.transpose(-2, -1)    # [k, m, m] with m = min(d1, d2)
        gram_x = gram @ X                 # same shape as X
        X = coef_a * X + coef_b * gram_x + coef_c * (gram @ gram_x)

    return X.transpose(1, 2) if tall else X


class BatchedMuon_VR(torch.optim.Optimizer):
    """Muon-style optimizer for batched 3D weights, with an optional variance-reduction term.

    Every parameter that receives a gradient must be a 3D tensor of shape
    [k, d1, d2]. Each step does the following per parameter:

    1. Updates a momentum buffer: ``buf = momentum * buf + grad``.
    2. Optionally adds a variance-reduction (VR) correction:
       ``buf += vr_gamma * momentum / (1 - momentum) * (curr_g - prev_g)``.
       ``curr_g`` and ``prev_g`` are looked up by parameter in
       ``self.curr_grads`` and ``self.prev_grads``. Both attributes start as
       ``None``. NOTE(review): presumably the training loop assigns mappings of
       the form ``{param: grad_tensor}`` before calling ``step`` — confirm
       against the caller. The correction is skipped for any parameter missing
       from either mapping.
    3. Orthogonalizes the k matrices of ``buf`` with a Newton-Schulz iteration.
       For tall matrices (d1 > d2) the result is scaled by ``sqrt(d1 / d2)``.
    4. Applies decoupled weight decay and the update:
       ``p = p * (1 - lr * weight_decay) - lr * update``.

    Args:
        params: Iterable of parameters or parameter-group dicts.
        lr: Learning rate.
        momentum: Momentum coefficient. It must be < 1 whenever the VR term is
            applied, because its scale divides by ``1 - momentum``.
        backend_steps: Number of Newton-Schulz iterations per step.
        weight_decay: Decoupled weight-decay coefficient.
        vr_gamma: Scale of the VR correction. The default 0.1 matches the
            previously hard-coded value.
    """

    def __init__(self, params, lr=0.02, momentum=0.95, backend_steps=5, weight_decay=0.0, vr_gamma=0.1):
        defaults = dict(
            lr=lr,
            momentum=momentum,
            backend_steps=backend_steps,
            weight_decay=weight_decay,
            vr_gamma=vr_gamma,
        )
        super().__init__(params, defaults)

        # Optional, externally assigned mappings of parameter -> gradient tensor,
        # consumed by the variance-reduction term in step().
        self.prev_grads = None
        self.curr_grads = None

    @torch.no_grad()
    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure: Optional callable that re-evaluates the model and returns
                the loss (the standard ``torch.optim.Optimizer`` contract).

        Returns:
            The value returned by ``closure``, or ``None`` if no closure was given.
        """
        loss = None
        if closure is not None:
            # The closure typically runs a backward pass, so it needs autograd enabled.
            with torch.enable_grad():
                loss = closure()

        use_vr = self.prev_grads is not None and self.curr_grads is not None

        for group in self.param_groups:
            lr = group['lr']
            weight_decay = group['weight_decay']
            momentum = group['momentum']
            backend_steps = group['backend_steps']
            # Groups restored from a state_dict saved before vr_gamma existed lack the key.
            vr_gamma = group.get('vr_gamma', 0.1)

            for p in group['params']:
                if p.grad is None:
                    continue
                assert p.ndim == 3, "BatchedMuon only supports 3D tensors [k, d1, d2]"
                g = p.grad
                state = self.state[p]

                if 'momentum_buffer' not in state:
                    state['momentum_buffer'] = torch.zeros_like(p)
                buf = state['momentum_buffer']

                buf.mul_(momentum).add_(g)

                if use_vr:
                    prev_g = self.prev_grads.get(p)
                    curr_g = self.curr_grads.get(p)
                    if prev_g is not None and curr_g is not None:
                        # Computed only when needed, so momentum == 1 cannot divide by zero
                        # for parameters that have no VR gradients.
                        vr_alpha = vr_gamma * (momentum / (1 - momentum))
                        buf.add_(curr_g - prev_g, alpha=vr_alpha)

                update = batched_zeropower_newtonschulz5_rect(buf, steps=backend_steps)
                rows, cols = p.size(1), p.size(2)
                if rows > cols:
                    # Shape correction for tall matrices: scale by sqrt(d1 / d2).
                    update = update * (rows / cols) ** 0.5

                p.mul_(1.0 - lr * weight_decay).add_(update, alpha=-lr)

        return loss