# On the Convergence of Muon and Beyond: Supplementary Material

We introduce MuonMVR, which applies variance reduction to Muon, and provide a detailed theoretical analysis of its convergence.
The core idea of MuonMVR is to reduce gradient variance by reusing information from the previous training step.
This requires a specific training-loop structure: a reference gradient (from the previous batch or the previous model state) must be computed before the optimizer calls `step()` with the gradient of the current batch.
This reference gradient is recorded by the `optimizer.update_last_grad()` method.
In practice, the existing and widely used approximate version, MVR1, is sufficient and keeps computational cost and memory requirements low.

# Stochastic Gradient Estimation
This section gives the formulas behind the different modes of the MuonMVR optimizer. Let $\mathbf{X}_t$ be the model parameters at step $t$, and $\xi_t$ the data batch used at step $t$.
The stochastic gradient of the loss function is denoted $\nabla f(\mathbf{X}_t;\xi_t)$.

**Standard Muon (EMA)**
The baseline optimizer uses a standard exponential moving average (EMA) of the gradients and serves as the foundation for the variance-reduced variants.
$$\mathbf{M}_t = \beta_t \mathbf{M}_{t-1} + (1-\beta_t)\nabla f(\mathbf{X}_t;\xi_t)$$
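
As a minimal sketch (plain Python, not the MuonMVR implementation), the EMA update is a single elementwise expression. The function name `ema_momentum` is hypothetical, and the scalar values stand in for gradient matrices:

```python
def ema_momentum(m_prev, grad, beta):
    """Standard Muon momentum: M_t = beta * M_{t-1} + (1 - beta) * grad f(X_t; xi_t)."""
    return beta * m_prev + (1.0 - beta) * grad


# Toy usage: scalar "gradients" stand in for the gradient matrices
m = 0.0
for g in [1.0, 2.0, 0.0]:
    m = ema_momentum(m, g, beta=0.9)
print(round(m, 4))  # prints 0.261
```

Because the function only uses elementwise arithmetic, the same code also works on NumPy arrays.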

**MVR1: Single-Batch Variance Reduction**
This is the approximate variance reduction method used in practice. Its correction term is the difference between the gradient of the current batch at the current parameters $\mathbf{X}_t$ and the gradient of the previous batch at the previous parameters $\mathbf{X}_{t-1}$. The latter is exactly the gradient already computed at step $t-1$, so MVR1 needs only one gradient evaluation per step. It corresponds to the `is_approx=True` mode in the practical implementation.

$$\mathbf{M}_t = \beta_t \mathbf{M}_{t-1} + (1-\beta_t)\nabla f(\mathbf{X}_t;\xi_t) + \gamma\cdot \beta_t \cdot (\nabla f(\mathbf{X}_t;\xi_t)-\nabla f(\mathbf{X}_{t-1};\xi_{t-1}))$$
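
The sketch below (hypothetical names, scalars standing in for matrices) shows why MVR1 is cheap: the reference gradient $\nabla f(\mathbf{X}_{t-1};\xi_{t-1})$ is simply the previous step's gradient, so it can be cached instead of recomputed. It assumes $\mathbf{M}_0 = \mathbf{0}$ and treats the gradient before the first step as zero.

```python
def mvr1_momentum(m_prev, grad_cur, grad_prev, beta, gamma):
    """MVR1: M_t = beta*M_{t-1} + (1-beta)*g_t + gamma*beta*(g_t - g_{t-1}).

    grad_cur  -- grad f(X_t; xi_t): current batch at the current parameters
    grad_prev -- grad f(X_{t-1}; xi_{t-1}): the gradient computed at the previous step
    """
    return beta * m_prev + (1.0 - beta) * grad_cur + gamma * beta * (grad_cur - grad_prev)


# Toy loop. Assumption: M_0 = 0 and the gradient "before" step 1 is 0.
beta = 0.9
m, g_prev = 0.0, 0.0
for g in [1.0, 2.0, 0.0]:
    m = mvr1_momentum(m, g, g_prev, beta, gamma=1.0 - beta)
    g_prev = g  # cache: MVR1 needs no extra gradient evaluation
```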

Setting $\gamma = 1-\beta_t$ with a constant momentum $\beta_t = \mu$, and rescaling the momentum as $\mathbf{M}_t/(1-\beta_t)$, makes MVR1 equivalent to the practical (Nesterov-style) momentum used in [Muon](https://github.com/KellerJordan/Muon) and [Moonlight-Muon](https://github.com/MoonshotAI/Moonlight):
$$
\begin{aligned}
 \mathbf{C}_{t} &=\mu \mathbf{C}_{t-1} + \nabla f(\mathbf{X}_{t};\xi_t)\\
 \mathbf{M}_t &= \mu \mathbf{C}_{t} + \nabla f(\mathbf{X}_{t};\xi_t)\\
\end{aligned}
$$
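
The equivalence can be checked in a few lines. Write $\mathbf{g}_t = \nabla f(\mathbf{X}_t;\xi_t)$ and assume matching initial conditions $\mathbf{C}_0 = \mathbf{0}$, $\mathbf{g}_0 = \mathbf{0}$ (so $\mathbf{M}_0 = \mathbf{0}$). Eliminating $\mathbf{C}$ from the practical recursion, using $\mu\mathbf{C}_{t-1} = \mathbf{M}_{t-1} - \mathbf{g}_{t-1}$:

$$
\begin{aligned}
\mathbf{M}_t &= \mu(\mu \mathbf{C}_{t-1} + \mathbf{g}_t) + \mathbf{g}_t
= \mu(\mathbf{M}_{t-1} - \mathbf{g}_{t-1}) + (1+\mu)\mathbf{g}_t \\
&= \mu \mathbf{M}_{t-1} + \mathbf{g}_t + \mu(\mathbf{g}_t - \mathbf{g}_{t-1}).
\end{aligned}
$$

On the MVR1 side, setting $\gamma = 1-\mu$ and dividing by $1-\mu$ gives, for $\tilde{\mathbf{M}}_t = \mathbf{M}_t/(1-\mu)$,

$$\tilde{\mathbf{M}}_t = \mu \tilde{\mathbf{M}}_{t-1} + \mathbf{g}_t + \mu(\mathbf{g}_t - \mathbf{g}_{t-1}),$$

which is the same recursion.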

**MVR2: Dual-Batch Variance Reduction (Standard MVR)**
This is the standard MVR estimator studied in theory: the correction term is computed on the same data batch $\xi_t$ but at two different model states, the current $\mathbf{X}_t$ and the previous $\mathbf{X}_{t-1}$.

$$\mathbf{M}_t = \beta_t \mathbf{M}_{t-1} + (1-\beta_t)\nabla f(\mathbf{X}_t;\xi_t) + \gamma\cdot \beta_t \cdot (\nabla f(\mathbf{X}_t;\xi_t)-\nabla f(\mathbf{X}_{t-1};\xi_t))$$
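
A toy sketch of MVR2, assuming a hypothetical scalar loss $f(x;\xi) = \tfrac12 (x-\xi)^2$, so that $\nabla f(x;\xi) = x - \xi$. The caller must keep the previous parameters $\mathbf{X}_{t-1}$, and each step needs a second gradient evaluation on the current batch:

```python
def grad_f(x, xi):
    """Hypothetical toy loss f(x; xi) = 0.5 * (x - xi)**2, so grad f(x; xi) = x - xi."""
    return x - xi


def mvr2_momentum(m_prev, x_cur, x_prev, xi_cur, beta, gamma):
    """MVR2: both gradients use the same batch xi_t, at X_t and at X_{t-1}."""
    g_cur = grad_f(x_cur, xi_cur)    # grad f(X_t; xi_t)
    g_ref = grad_f(x_prev, xi_cur)   # grad f(X_{t-1}; xi_t): the extra evaluation
    return beta * m_prev + (1.0 - beta) * g_cur + gamma * beta * (g_cur - g_ref)


m = mvr2_momentum(0.5, x_cur=1.0, x_prev=1.2, xi_cur=0.3, beta=0.9, gamma=0.1)
```

On this particular quadratic, the correction $\nabla f(\mathbf{X}_t;\xi_t)-\nabla f(\mathbf{X}_{t-1};\xi_t)$ reduces to $\mathbf{X}_t - \mathbf{X}_{t-1}$, because the batch-dependent term cancels.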

**MVR3: Dual-Batch Variance Reduction (Approximate MVR)**
This variant corresponds to the `is_approx=False` mode combined with the MVR3 training loop in the Usage section. Its correction term is the difference between the gradient of the current batch $\xi_t$ and the gradient of the previous batch $\xi_{t-1}$, both evaluated at the current parameters $\mathbf{X}_t$.

$$\mathbf{M}_t = \beta_t \mathbf{M}_{t-1} + (1-\beta_t)\nabla f(\mathbf{X}_t;\xi_t) + \gamma\cdot \beta_t \cdot (\nabla f(\mathbf{X}_t;\xi_t)-\nabla f(\mathbf{X}_{t};\xi_{t-1}))$$
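
The matching toy sketch for MVR3 uses the same hypothetical quadratic loss. Here the caller keeps the previous batch $\xi_{t-1}$ instead of the previous parameters and re-evaluates it at the current $\mathbf{X}_t$. This mirrors what the MVR3 training loop in the Usage section does before calling `update_last_grad()`:

```python
def grad_f(x, xi):
    """Hypothetical toy loss f(x; xi) = 0.5 * (x - xi)**2, so grad f(x; xi) = x - xi."""
    return x - xi


def mvr3_momentum(m_prev, x_cur, xi_cur, xi_prev, beta, gamma):
    """MVR3: both gradients use the current parameters X_t, on batches xi_t and xi_{t-1}."""
    g_cur = grad_f(x_cur, xi_cur)    # grad f(X_t; xi_t)
    g_ref = grad_f(x_cur, xi_prev)   # grad f(X_t; xi_{t-1}): the extra evaluation
    return beta * m_prev + (1.0 - beta) * g_cur + gamma * beta * (g_cur - g_ref)


m = mvr3_momentum(0.5, x_cur=1.0, xi_cur=0.3, xi_prev=0.5, beta=0.9, gamma=0.1)
```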

MVR2 and MVR3 each require an extra forward and backward pass per step, and must store either the previous model state (MVR2) or the previous data batch (MVR3). In practice, the existing and widely used approximate version, MVR1, is sufficient and keeps computational and memory costs low.

# Usage
### MVR3 Usage
1.  **Standard Training (No Gradient Accumulation)**
    This is the most basic usage. In each iteration, first calculate the gradient for the previous batch and call `update_last_grad()`, then calculate the gradient for the current batch and call `step()`.
    ```python
    # Assumes MuonMVR, model, data_loader and epochs are defined elsewhere.
    # is_approx=False: combined with this loop, this gives the MVR3 estimator
    optimizer = MuonMVR(model.parameters(), lr=1e-3, is_approx=False)
    
    # Training loop
    previous_X, previous_Y = None, None
    for epoch in range(epochs):
        for X, Y in data_loader:
            if previous_X is not None:
                # Calculate the gradient for the previous batch
                logits, loss = model(previous_X, previous_Y)
                loss.backward()
                optimizer.update_last_grad()
                optimizer.zero_grad(set_to_none=True)
            
            # Process the current batch
            logits, loss = model(X, Y)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad(set_to_none=True)
            
            # Store the current batch for the next iteration
            previous_X, previous_Y = X.clone(), Y.clone()
    ```
2.  **Training with Gradient Accumulation**
    To use gradient accumulation, the losses for both the current and previous batches must be scaled by dividing by the number of accumulation steps. `optimizer.step()` is only called after the required number of gradients has been accumulated.
    ```python
    # Assumes MuonMVR, model, data_loader and epochs are defined elsewhere, and that
    # update_last_grad() records whatever gradients are currently stored in .grad.
    optimizer = MuonMVR(model.parameters(), lr=1e-3, is_approx=False)
    
    accum_steps = 4      # Number of gradient accumulation steps
    previous_cycle = []  # Micro-batches used for the previous optimizer step
    current_cycle = []   # Micro-batches accumulated for the upcoming step
    
    for epoch in range(epochs):
        for X, Y in data_loader:
            if not current_cycle and previous_cycle:
                # Start of a new cycle: accumulate the previous step's gradients at the
                # current parameters, record them, then clear .grad. This must happen
                # before any current backward pass, since zero_grad() clears all gradients.
                for prev_X, prev_Y in previous_cycle:
                    prev_logits, prev_loss = model(prev_X, prev_Y)
                    (prev_loss / accum_steps).backward()
                optimizer.update_last_grad()
                optimizer.zero_grad(set_to_none=True)
    
            # Accumulate the scaled gradient of the current micro-batch
            logits, loss = model(X, Y)
            (loss / accum_steps).backward()
            current_cycle.append((X.clone(), Y.clone()))
    
            # Update parameters once enough gradients have been accumulated
            if len(current_cycle) == accum_steps:
                optimizer.step()
                optimizer.zero_grad(set_to_none=True)
                previous_cycle, current_cycle = current_cycle, []
    ```
### MVR2 Usage
3.  **Optimizer Modes**
    MuonMVR can be initialized in different modes to trade off between precision and computational cost.

    **Exact Variance Reduction (`is_approx=False`)**
    To achieve the most precise variance reduction, you must manage the model state manually. Before computing the reference gradient, load the model state from the previous iteration and run the current batch through it, so that the recorded gradient is $\nabla f(\mathbf{X}_{t-1};\xi_t)$ as in MVR2. Then restore the current state before the regular forward/backward pass and `step()`.
    ```python
    # Assumes MuonMVR, net, trainloader and criterion are defined elsewhere
    optimizer = MuonMVR(net.parameters(), lr=1e-3, is_approx=False)
    old_state_dict = {}
    
    for batch_idx, (inputs, targets) in enumerate(trainloader):
        # Store the current model state
        cur_state_dict = {k: v.data.clone() for k, v in net.state_dict().items()}
    
        if old_state_dict:
            # Load the previous model state X_{t-1} to compute the reference gradient on the current batch
            net.load_state_dict(old_state_dict)
            optimizer.zero_grad()
            outputs = net(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.update_last_grad()
    
        # Restore the current model state to compute the new gradient
        net.load_state_dict(cur_state_dict)
        old_state_dict = {k: v.data.clone() for k, v in cur_state_dict.items()}
        
        # Standard forward/backward pass and step
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
    ```

### MVR1 Usage
**Approximate Version (`is_approx=True`)**
This mode corresponds to MVR1 and uses a computationally cheaper approximation than `is_approx=False`. Its training loop has the same structure as the Standard Training example.
```python
# Initialize the optimizer and enable the approximate mode
optimizer = MuonMVR(model.parameters(), lr=1e-3, is_approx=True)
```

# CITE

We referenced the original author's implementation of [Muon](https://github.com/KellerJordan/Muon) as well as the implementation from [Moonlight-Muon](https://github.com/MoonshotAI/Moonlight), and adopted the gradient clipping strategy from [MARS](https://github.com/AGI-Arena/MARS) to further control gradient noise. We thank them for their open-source contributions!