## Muon code from Moonlight
## https://github.com/MoonshotAI/Moonlight/blob/master/examples/toy_train.py

# This code snippet is a modified version adapted from the following GitHub repository:
# https://github.com/KellerJordan/Muon/blob/master/muon.py
import torch
import math

from .polar import PolarExpress


class Muon(torch.optim.Optimizer):
    """
    Muon - MomentUm Orthogonalized by Newton-schulz

    Muon internally runs SGD-momentum and then post-processes each matrix parameter's update
    with an orthogonalization step (performed here by `PolarExpress`, imported from `.polar`).
    Parameters that are not routed to Muon are updated with an internal AdamW.

    Some warnings:
    - We believe this optimizer is unlikely to work well for training with small batch size.
    - We believe it may not work well for finetuning pretrained models, but we haven't tested this.

    Arguments:
        named_params: Iterable of `(name, parameter)` pairs (e.g. `model.named_parameters()`).
            A parameter is optimized by Muon when it has `ndim >= 2` and its name contains none
            of "embeddings", "embed_tokens", "wte", "lm_head", "wpe"; every other parameter is
            optimized by the internal AdamW.
        lr: The learning rate, shared by the Muon and AdamW updates.
        wd: Decoupled weight decay; every parameter is multiplied by `1 - lr * wd` each step.
        momentum: The momentum used by the internal SGD.
        heavy_ball: Whether the momentum accumulation should be a moving average, i.e. the
            gradient is weighted by `1 - momentum` instead of `1`.
        nesterov: Whether to use Nesterov-style momentum in the internal SGD.
        ns_steps: The number of orthogonalization steps passed to `PolarExpress`.
        lmo: Whether to use LMO instead variational viewpoint of gradient descent to derive
            update rule. If lmo=False, update is additionally scaled by the dual norm of the
            gradient.
        l2_prod_norm: Whether to use the L2 norm for the product space over layers
            instead of the max norm, which scales each layer's LR by the nuclear norm of the
            gradient.
        nuc_approx: How to approximate the gradient nuclear norm. Choices: [None, 'fro', 'past']
        rms_layer_norm: Whether to use the RMS norm the input/output space of each
            layer, which scale each layer's LR by sqrt(fan_out/fan_in).
        truncate_model: Lower bound of loss, if using a truncated model. When set, `step` must
            be given either a closure or the current loss.
        adamw_betas: The betas for the internal AdamW.
        adamw_eps: The epsilon for the internal AdamW.

    Attributes:
        step_size_list: The (possibly truncated) base step size used at every call of `step`,
            one entry per param group.
    """
    def __init__(self,
                 named_params,
                 lr=1e-3,
                 wd=0.1,
                 momentum=0.95,
                 heavy_ball=False,
                 nesterov=True,
                 ns_steps=5,
                 lmo=True,
                 l2_prod_norm=False,
                 nuc_approx=None,
                 rms_layer_norm=False,
                 truncate_model=None,
                 adamw_betas=(0.95, 0.95),
                 adamw_eps=1e-8):

        defaults = dict(
                lr=lr,
                wd=wd,
                momentum=momentum,
                heavy_ball=heavy_ball,
                nesterov=nesterov,
                ns_steps=ns_steps,
                lmo=lmo,
                l2_prod_norm=l2_prod_norm,
                nuc_approx=nuc_approx,
                rms_layer_norm=rms_layer_norm,
                truncate_model=truncate_model,
                adamw_betas=adamw_betas,
                adamw_eps=adamw_eps,
        )

        # Sort parameters into those for which we will use Muon, and those for which we will not.
        # Muon handles every >= 2D parameter whose name doesn't look like an embedding or head layer.
        excluded_names = ("embeddings", "embed_tokens", "wte", "lm_head", "wpe")
        muon_params = []
        adamw_params = []
        for name, p in named_params:
            if p.ndim >= 2 and not any(excluded in name for excluded in excluded_names):
                muon_params.append(p)
            else:
                adamw_params.append(p)

        super().__init__(muon_params + adamw_params, defaults)

        for p in muon_params:
            # step() adds the update returned by PolarExpress directly to p, so only
            # matrix-shaped parameters are accepted here.
            assert p.ndim == 2, p.ndim
            self.state[p]["use_muon"] = True

        for p in adamw_params:
            # Do not use Muon for parameters in adamw_params
            self.state[p]["use_muon"] = False

        self.use_truncation = truncate_model is not None
        if self.use_truncation:
            # Running average of the loss model; initialized on the first step.
            self.loss_model = None
        self.step_size_list = list()

        # Sanity check for options.
        if self.use_truncation and not heavy_ball:
            print("Using truncated models without heavy ball momentum. Does this make any sense?")

    @staticmethod
    def _momentum_direction(g, buf, momentum, nesterov):
        """Return the direction to orthogonalize: Nesterov look-ahead or the plain buffer."""
        return g.add(buf, alpha=momentum) if nesterov else buf

    @staticmethod
    def _as_float(value):
        """Convert a single-element tensor or a Python number to a Python float."""
        return value.item() if torch.is_tensor(value) else float(value)

    def step(self, closure=None, loss=None):
        """Perform a single optimization step.

        Args:
            closure (Callable, optional): A closure that reevaluates the model
                and returns the loss.
            loss (torch.Tensor, optional): Tensor holding the loss of the current iteration.
                Only needed (and only one of `closure`/`loss` allowed) when a truncated
                model is used.

        Returns:
            The loss returned by `closure` if given, otherwise the `loss` argument.
        """

        if self.use_truncation:
            assert (closure is not None) or (loss is not None), "Either loss tensor or closure must be passed."
            assert (closure is None) or (loss is None), "Pass either the loss tensor or the closure, not both."

        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            ############################
            #           Muon           #
            ############################

            params = [p for p in group["params"] if self.state[p]["use_muon"]]
            lr = group["lr"]
            wd = group["wd"]
            momentum = group["momentum"]
            heavy_ball = group["heavy_ball"]
            nesterov = group["nesterov"]
            ns_steps = group["ns_steps"]
            lmo = group["lmo"]
            l2_prod_norm = group["l2_prod_norm"]
            nuc_approx = group["nuc_approx"]
            rms_layer_norm = group["rms_layer_norm"]
            truncate_model = group["truncate_model"]

            # initial pass over parameters to compute update direction and LR scalings.
            # Warning for the future: if we ever use more than one param group, these
            # scalings are not going to behave exactly right. Here we compute scaling
            # factors that depend on all layers of the network, so we assume that all
            # layers of the network are inside the current param group.
            layer_nuc_norms = None
            need_nuc_norms = (not lmo) or l2_prod_norm or self.use_truncation
            momentum_coeff = 1.0 - momentum if heavy_ball else 1.0
            current_loss_model = 0.0
            new_loss_model = 0.0
            for i, p in enumerate(params):

                # sanity check
                g = p.grad
                if g is None:
                    continue
                if g.ndim > 2:
                    g = g.view(g.size(0), -1)

                # calc momentum. The buffer starts as a copy of the first gradient.
                state = self.state[p]
                if "momentum_buffer" not in state:
                    state["momentum_buffer"] = g.clone()
                buf = state["momentum_buffer"]
                buf.mul_(momentum).add_(g, alpha=momentum_coeff)

                # Compute inner product term of running average for truncated model.
                if self.use_truncation:
                    current_loss_model += torch.sum(torch.mul(p.data, p.grad.data))
                    new_loss_model += torch.sum(torch.mul(p.data, buf.data))

                # quit now if update doesn't depend on nuclear norm of layer gradients.
                if layer_nuc_norms is None:
                    layer_nuc_norms = torch.zeros(len(params), device=p.device)
                if not need_nuc_norms:
                    continue

                # Compute (or approximate) nuclear norms of each layer's gradient.
                if nuc_approx is None or (nuc_approx == "past" and "past_nuc" not in state):

                    # calc update.
                    g = self._momentum_direction(g, buf, momentum, nesterov)
                    u = PolarExpress(g, steps=ns_steps)

                    # If G = UDV^T, then nuc(G) = tr(G @ UV^T).
                    layer_nuc_norms[i] = torch.trace(g.bfloat16().T @ u)

                elif nuc_approx == "fro":
                    # NOTE: this uses the raw gradient, not the momentum direction.
                    layer_nuc_norms[i] = torch.linalg.matrix_norm(g, ord="fro")
                elif nuc_approx == "past":
                    # Reuse the value stored during the previous step's weight update.
                    layer_nuc_norms[i] = state["past_nuc"]
                else:
                    raise NotImplementedError

                # Apply RMS scaling to nuclear norms.
                if rms_layer_norm:
                    fan_out, fan_in = p.shape[:2]
                    layer_nuc_norms[i] *= math.sqrt(fan_out / fan_in)

            # Compute dual norm of gradient, which is used to scale LR.
            if layer_nuc_norms is None:
                # No Muon parameter received a gradient in this group (e.g. the group
                # only holds AdamW parameters), so there is nothing to take a norm of.
                global_dual_norm = 0.0
            elif l2_prod_norm:
                global_dual_norm = float(torch.linalg.vector_norm(layer_nuc_norms, ord=2))
            else:
                global_dual_norm = float(torch.sum(layer_nuc_norms))

            # Update running average for truncated model and compute truncated lr.
            current_lr = lr
            if self.use_truncation:
                loss_model_update = self._as_float(loss) - self._as_float(current_loss_model)
                if self.loss_model is None:
                    self.loss_model = loss_model_update
                self.loss_model = momentum * self.loss_model + momentum_coeff * loss_model_update
                # With a zero dual norm the truncated step size is undefined (division
                # by zero), so keep the base learning rate in that case.
                if global_dual_norm > 0.0:
                    truncated_lr = (
                        self.loss_model - truncate_model + self._as_float(new_loss_model)
                    ) / global_dual_norm ** 2
                    current_lr = min(truncated_lr, lr)
            self.step_size_list.append(current_lr)

            # apply weight updates
            for i, p in enumerate(params):

                # sanity check
                g = p.grad
                if g is None:
                    continue
                if g.ndim > 2:
                    g = g.view(g.size(0), -1)

                # calc update. Note that we already computed and stored the momentum
                # term before, but we are re-computing the matrix sign. This is
                # suboptimal w.r.t.  time but doesn't use any extra memory. We can
                # always tweak this later.
                state = self.state[p]
                buf = state["momentum_buffer"]
                g = self._momentum_direction(g, buf, momentum, nesterov)
                u = PolarExpress(g, steps=ns_steps)

                # Store the nuclear norm estimate for use in the next step.
                if nuc_approx == "past":
                    state["past_nuc"] = torch.trace(g.bfloat16().T @ u)

                # apply scaling factors to lr depending on steepest descent variations
                lr_scale = 1.0
                if lmo and not l2_prod_norm:
                    if rms_layer_norm:
                        fan_out, fan_in = p.shape[:2]
                        lr_scale = math.sqrt(fan_out / fan_in)
                if lmo and l2_prod_norm:
                    lr_scale = layer_nuc_norms[i] / global_dual_norm
                if not lmo and not l2_prod_norm:
                    lr_scale = global_dual_norm
                if not lmo and l2_prod_norm:
                    lr_scale = layer_nuc_norms[i]
                adjusted_lr = lr_scale * current_lr

                # apply weight decay (uses the base lr, not the adjusted one)
                p.data.mul_(1 - lr * wd)

                # apply update
                p.data.add_(u, alpha=-adjusted_lr)

            ############################
            #       AdamW backup       #
            ############################

            params = [p for p in group["params"] if not self.state[p]["use_muon"]]
            lr = group['lr']
            beta1, beta2 = group["adamw_betas"]
            eps = group["adamw_eps"]
            weight_decay = group["wd"]

            for p in params:
                g = p.grad
                if g is None:
                    continue
                state = self.state[p]
                if "step" not in state:
                    state["step"] = 0
                    state["moment1"] = torch.zeros_like(g)
                    state["moment2"] = torch.zeros_like(g)
                state["step"] += 1
                step = state["step"]
                buf1 = state["moment1"]
                buf2 = state["moment2"]
                # Exponential moving averages of the gradient and its square.
                buf1.lerp_(g, 1 - beta1)
                buf2.lerp_(g.square(), 1 - beta2)

                g = buf1 / (eps + buf2.sqrt())

                # Bias corrections are folded into the step size below.
                bias_correction1 = 1 - beta1**step
                bias_correction2 = 1 - beta2**step
                scale = bias_correction1 / bias_correction2**0.5
                p.data.mul_(1 - lr * weight_decay)
                p.data.add_(g, alpha=-lr / scale)

        return loss



