import os
import math
import torch
import numpy as np
from datasets import load_dataset
from torch.utils.data import DataLoader, Dataset
from typing import Dict


# This code snippet is a modified version adapted from the following GitHub repository:
# https://github.com/KellerJordan/Muon/blob/master/muon.py
@torch.compile
def zeropower_via_newtonschulz5(G, steps):
    """
    Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
    quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
    of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
    zero even beyond the point where the iteration no longer converges all the way to one everywhere
    on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
    where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
    performance at all relative to UV^T, where USV^T = G is the SVD.

    Args:
        G: tensor with at least 2 dims. The trailing two dims are the matrix; any
            leading dims are treated as a batch and each matrix is processed
            independently.
        steps: number of Newton-Schulz iterations to run.

    Returns:
        A bfloat16 tensor with the same shape as ``G``.

    Raises:
        ValueError: if ``G`` has fewer than 2 dimensions.
    """
    if G.ndim < 2:
        raise ValueError(f"expected a tensor with at least 2 dims, got ndim={G.ndim}")
    a, b, c = (3.4445, -4.7750, 2.0315)
    X = G.bfloat16()
    # Iterate on the wide orientation so X @ X^T is the smaller Gram matrix.
    transposed = G.size(-2) > G.size(-1)
    if transposed:
        X = X.mT
    # Ensure spectral norm is at most 1: divide each matrix by its Frobenius norm,
    # which upper-bounds the spectral norm.
    X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)
    # Perform the NS iterations
    for _ in range(steps):
        A = X @ X.mT
        B = b * A + c * A @ A  # adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng
        X = a * X + B @ X

    if transposed:
        X = X.mT
    return X


# Efficiently estimate the leading singular values with a randomized range finder.
def torch_randomized_svd(A, k=20, n_iter=1):
    """
    Estimate the top-``k`` singular values of ``A`` via randomized projection.

    A random Gaussian test matrix projects ``A`` onto a ``k``-dimensional
    subspace, ``n_iter`` power iterations (``A A^T``) sharpen that subspace toward
    the dominant singular directions, and the small projected matrix is then
    decomposed exactly. When ``k == min(m, n)`` the projection spans the whole
    range of ``A`` and the result matches the full SVD up to rounding.

    Args:
        A: tensor of shape ``(..., m, n)``; leading dims are treated as a batch.
            It is converted to float32 for stability.
        k: number of singular values to estimate. Values larger than
            ``min(m, n)`` are clamped to ``min(m, n)``.
        n_iter: number of power iterations.

    Returns:
        float32 tensor of shape ``(..., min(k, m, n))`` holding the estimated
        singular values in descending order.

    Raises:
        ValueError: if ``A`` has fewer than 2 dims or ``k < 1``.
    """
    if A.ndim < 2:
        raise ValueError(f"expected a tensor with at least 2 dims, got ndim={A.ndim}")
    if k < 1:
        raise ValueError(f"k must be >= 1, got {k}")

    # Convert to float32 for stability and compatibility
    A = A.float()
    batch_shape = A.shape[:-2]
    m, n = A.shape[-2:]
    # More than min(m, n) singular values do not exist; clamp instead of failing
    # so callers can pass a fixed budget (e.g. k=100) for matrices of any size.
    k = min(k, m, n)

    # Step 1: random Gaussian test matrix
    omega = torch.randn(*batch_shape, n, k, device=A.device, dtype=torch.float32)

    # Step 2: sample the range of A, refined by power iterations
    A_t = A.transpose(-2, -1)
    Y = A @ omega  # (..., m, k)
    for _ in range(n_iter):
        Y = A @ (A_t @ Y)

    # Step 3: orthonormal basis of the sampled range
    Q, _ = torch.linalg.qr(Y)  # (..., m, k)

    # Step 4: project A onto that basis
    B = Q.transpose(-2, -1) @ A  # (..., k, n)

    # Step 5: singular values of the small matrix (U and V are not needed)
    return torch.linalg.svdvals(B)

class LayerWiseAdaptiveGradientTracker:
    """
    Gradient tracker that assigns each layer a learning-rate scale from how much
    that layer's gradient changes between consecutive calls, relative to all
    other tracked layers.

    For every parameter name the tracker keeps the previous gradient and an EMA
    (decay ``alpha``) of the gradient difference ``g_t - g_{t-1}``, measured in a
    norm that depends on the tensor type (see ``_grad_difference``). The scale of
    a layer is ``sqrt(min_diff / diff_layer)``, which lies in ``(0, 1]``. The
    layer whose gradient changes least gets 1.0, and faster-changing layers get
    proportionally smaller scales.
    """

    def __init__(self, alpha=0.9, warmup_steps=1000, allocation_interval=10):
        """
        Args:
            alpha: EMA decay factor for gradient difference tracking.
            warmup_steps: while ``step_count <= warmup_steps`` the scale is 1.0
                and only the previous gradient is recorded.
            allocation_interval: the layer-wise allocations are recomputed when
                ``step_count % allocation_interval == 0``. They are also
                recomputed whenever a layer has no allocation yet.

        Raises:
            ValueError: if ``allocation_interval`` is smaller than 1.
        """
        if allocation_interval < 1:
            raise ValueError(f"allocation_interval must be >= 1, got {allocation_interval}")
        self.alpha = alpha
        self.warmup_steps = warmup_steps
        self.allocation_interval = allocation_interval

        # Tracking dictionaries (keyed by parameter name)
        self.prev_grads = {}  # detached copy of the previous gradient (None until first seen)
        self.grad_diff_ema = {}  # EMA of gradient differences, initialised at 0.0
        self.adaptive_lr_scales = {}  # last lr scale returned for each parameter
        self.step_count = 0  # last step_count passed to update_and_get_lr_scale

        # Layer-wise allocation tracking
        self.layer_grad_diffs = {}  # EMA of gradient differences, seeded with the first difference
        self.layer_lr_allocation = {}  # lr allocation by layer, from the last recomputation
        self.layer_n_dim = {}  # not read or written elsewhere in this class

    def _calculate_layer_allocations(self, use_muon: bool) -> Dict[str, float]:
        """
        Compute the lr allocation of every tracked layer from ``layer_grad_diffs``.

        Each allocation is ``sqrt(min_diff / diff)``. ``1e-8`` is added to every
        difference so that zero differences do not divide by zero.

        Args:
            use_muon: unused; kept for interface compatibility.

        Returns:
            Mapping from parameter name to allocation in ``(0, 1]``. The mapping is
            empty when no layer has a recorded difference yet.
        """
        if not self.layer_grad_diffs:
            return {}

        names = list(self.layer_grad_diffs)
        diffs = np.array([self.layer_grad_diffs[name] for name in names]) + 1e-8
        allocations = np.sqrt(np.min(diffs) / diffs)
        return {name: float(value) for name, value in zip(names, allocations)}

    @staticmethod
    def _grad_difference(current_grad: torch.Tensor, prev_grad: torch.Tensor, use_muon: bool) -> float:
        """
        Return the scaled norm of ``current_grad - prev_grad``.

        - Muon layers (``ndim >= 2`` and ``use_muon``): sum of the top singular
          values (approximate nuclear norm), scaled by ``sqrt(max(1, rows / cols))``.
        - 1-D tensors: ``sqrt(numel) * ||delta||_2``.
        - Other tensors: ``||delta||_1 / size(-1)``.
        """
        delta = current_grad - prev_grad
        if current_grad.ndim >= 2 and use_muon:
            rows, cols = current_grad.size(-2), current_grad.size(-1)
            # Cap the SVD budget by the matrix size so layers with a side smaller
            # than 100 can be tracked (more singular values than that do not exist).
            k = min(100, rows, cols)
            return max(1, rows / cols) ** 0.5 * torch_randomized_svd(delta, k=k).sum().item()
        if current_grad.ndim == 1:
            return math.sqrt(current_grad.size(0)) * torch.norm(delta, p=2).item()
        d = current_grad.size(-1)
        return 1 / d * torch.norm(delta, p=1).item()

    def update_and_get_lr_scale(self, param_name: str, current_grad: torch.Tensor, step_count: int, use_muon: bool) -> float:
        """
        Update gradient tracking and return the adaptive learning rate scale for a parameter.

        Args:
            param_name: key identifying the parameter across calls.
            current_grad: the parameter's current gradient. A detached copy is kept
                for the next call.
            step_count: the caller's global step. It drives warmup and the
                allocation-recompute cadence.
            use_muon: selects the gradient-difference norm (see ``_grad_difference``).

        Returns:
            1.0 during warmup or the first time ``param_name`` is seen after
            warmup. Otherwise, the layer's allocation in ``(0, 1]``.
        """
        self.step_count = step_count

        # Initialize tracking for new parameters
        if param_name not in self.prev_grads:
            self.prev_grads[param_name] = None
            self.grad_diff_ema[param_name] = 0.0
            self.adaptive_lr_scales[param_name] = 1.0

        # During warmup, use base learning rate
        if step_count <= self.warmup_steps:
            self.prev_grads[param_name] = current_grad.detach().clone()
            return 1.0

        lr_scale = 1.0
        prev_grad = self.prev_grads[param_name]
        if prev_grad is not None:
            grad_diff = self._grad_difference(current_grad, prev_grad, use_muon)

            self.grad_diff_ema[param_name] = (
                self.alpha * self.grad_diff_ema[param_name] + (1 - self.alpha) * grad_diff
            )

            # Layer-wise EMA, seeded with the first observed difference
            if param_name not in self.layer_grad_diffs:
                self.layer_grad_diffs[param_name] = grad_diff
            else:
                self.layer_grad_diffs[param_name] = (
                    self.alpha * self.layer_grad_diffs[param_name] + (1 - self.alpha) * grad_diff
                )

            # Recompute allocations only every `allocation_interval` steps to limit
            # overhead. A layer without an allocation yet forces a recompute, so
            # no layer stays at the default scale until the next cadence step.
            if (
                step_count % self.allocation_interval == 0
                or param_name not in self.layer_lr_allocation
            ):
                self.layer_lr_allocation = self._calculate_layer_allocations(use_muon)

            lr_scale = self.layer_lr_allocation.get(param_name, 1.0)
            self.adaptive_lr_scales[param_name] = lr_scale

        # Store current gradient for next iteration
        self.prev_grads[param_name] = current_grad.detach().clone()
        return lr_scale

    def get_lr_statistics(self) -> Dict[str, float]:
        """Return mean/min/max/std of the last lr scale returned for each parameter."""
        if not self.adaptive_lr_scales:
            return {"mean": 1.0, "min": 1.0, "max": 1.0, "std": 0.0}

        scales = list(self.adaptive_lr_scales.values())
        return {
            "mean": np.mean(scales),
            "min": np.min(scales),
            "max": np.max(scales),
            "std": np.std(scales),
        }

    def get_layer_statistics(self) -> Dict[str, Dict[str, float]]:
        """Return copies of the per-layer gradient-difference EMAs and lr allocations."""
        return {
            "layer_grad_diffs": dict(self.layer_grad_diffs),
            "layer_lr_allocations": dict(self.layer_lr_allocation),
        }


class LANTON(torch.optim.Optimizer):
    """
    Optimizer that applies separate update rules to two families of parameters.

    - ``muon_params`` (must be 2-D) receive a Muon update with the following steps:
      1. Accumulate momentum, optionally Nesterov.
      2. Orthogonalize with Newton-Schulz.
      3. Scale by ``sqrt(max(1, rows / cols))``.
      4. Set the step size to ``lr * 0.2 * sqrt(max(rows, cols))``, multiplied by a
         per-layer scale from ``LayerWiseAdaptiveGradientTracker``.
      5. Apply decoupled weight decay ``lr * wd``.
    - ``sign_params`` first accumulate an EMA momentum ``buf = beta * buf + (1 - beta) * g``.
      - 2-D tensors step along ``scale1 / fan_in * sign(buf)``.
      - 1-D tensors step along ``scale2 * buf / (rms(buf) + eps)``.
      - Both use ``lr`` times a per-layer scale for the step and for weight decay.
      - Tensors with any other number of dims are left untouched.

    Args:
        lr: base learning rate.
        wd: decoupled weight decay coefficient.
        muon_params: iterable of 2-D parameters updated with Muon.
        momentum: momentum coefficient of the Muon branch.
        nesterov: use Nesterov momentum in the Muon branch.
        ns_steps: number of Newton-Schulz iterations.
        sign_params: iterable of parameters updated with the sign / RMS rule.
        beta: EMA coefficient of the sign-branch momentum and of both trackers.
        eps: denominator floor of the 1-D RMS normalization.
        adaptive_warmup_steps: steps before the layer-wise lr scaling kicks in.
        scale1: multiplier of the 2-D sign update.
        scale2: multiplier of the 1-D RMS-normalized update.

    Raises:
        ValueError: if any entry of ``muon_params`` is not 2-D.
    """

    def __init__(
        self,
        lr=1e-3,
        wd=0.1,
        muon_params=None,
        momentum=0.95,
        nesterov=True,
        ns_steps=5,
        sign_params=None,
        beta=0.9,
        eps=1e-8,
        adaptive_warmup_steps=1000,
        scale1=3000.0,
        scale2=1.0,
    ):
        defaults = dict(
            lr=lr,
            wd=wd,
            momentum=momentum,
            nesterov=nesterov,
            ns_steps=ns_steps,
            beta=beta,
            eps=eps,
            scale1=scale1,
            scale2=scale2,
        )

        # Materialize both iterables exactly once. A generator such as a filtered
        # model.parameters() would otherwise be exhausted by the first pass, and
        # the tagging loops below would silently see no parameters.
        muon_params = list(muon_params) if muon_params is not None else []
        sign_params = list(sign_params) if sign_params is not None else []
        for p in muon_params:
            if p.ndim != 2:
                raise ValueError(
                    f"LANTON muon_params must be 2-D, got a parameter with ndim={p.ndim}"
                )

        self.step_count = 0
        super().__init__(muon_params + sign_params, defaults)

        # Tag every parameter with the update rule step() applies to it.
        for p in muon_params:
            self.state[p]["use_muon"] = True
        for p in sign_params:
            self.state[p]["use_muon"] = False

        self.adaptive_tracker_muon = LayerWiseAdaptiveGradientTracker(
            alpha=beta,
            warmup_steps=adaptive_warmup_steps,
        )
        self.adaptive_tracker_sign = LayerWiseAdaptiveGradientTracker(
            alpha=beta,
            warmup_steps=adaptive_warmup_steps,
        )
        self.scale1 = scale1
        self.scale2 = scale2

    def adjust_lr_for_muon(self, lr, param_shape):
        """Return ``lr`` scaled by ``0.2 * sqrt(max(A, B))`` for a parameter of shape ``(A, B, ...)``."""
        A, B = param_shape[:2]
        # Scale the learning rate by the size of the parameter matrix,
        # as described in the d-muon paper.
        adjusted_ratio = 0.2 * math.sqrt(max(A, B))
        return lr * adjusted_ratio

    def _muon_update(self, group):
        """Apply the Muon update to every parameter of ``group`` tagged ``use_muon``."""
        lr = group["lr"]
        wd = group["wd"]
        momentum = group["momentum"]
        params = [p for p in group["params"] if self.state[p]["use_muon"]]

        for i, p in enumerate(params):
            g = p.grad
            if g is None:
                continue
            if g.ndim > 2:
                g = g.view(g.size(0), -1)

            # Momentum (optionally Nesterov)
            state = self.state[p]
            if "momentum_buffer" not in state:
                state["momentum_buffer"] = torch.zeros_like(g)
            buf = state["momentum_buffer"]
            buf.mul_(momentum).add_(g)
            g = g.add(buf, alpha=momentum) if group["nesterov"] else buf

            # Orthogonalized direction, rescaled for tall matrices
            u = zeropower_via_newtonschulz5(g, steps=group["ns_steps"])
            u *= max(1, g.size(-2) / g.size(-1)) ** 0.5

            # Step size: shape-based Muon lr times the layer-wise adaptive scale.
            # The key uses the index within this group's Muon subset plus the shape,
            # which assumes a single param group (as built by __init__).
            muon_lr = self.adjust_lr_for_muon(lr, p.shape)
            param_name = f"muon_param_{i}_{p.shape}"
            layer_wise_lr_scale = self.adaptive_tracker_muon.update_and_get_lr_scale(
                param_name, p.grad, self.step_count, use_muon=True
            )
            adaptive_lr = layer_wise_lr_scale * muon_lr

            # Decoupled weight decay uses the base lr, then the update is applied.
            p.data.mul_(1 - lr * wd)
            p.data.add_(u, alpha=-adaptive_lr)

    def _sign_update(self, group):
        """Apply the sign (2-D) / RMS-normalized (1-D) update to non-Muon parameters of ``group``."""
        lr = group["lr"]
        beta = group["beta"]
        eps = group["eps"]
        weight_decay = group["wd"]
        # Read from the group, like every other hyperparameter, so per-group
        # overrides and schedulers take effect.
        scale1 = group["scale1"]
        scale2 = group["scale2"]
        params = [p for p in group["params"] if not self.state[p]["use_muon"]]

        for i, p in enumerate(params):
            g = p.grad
            if g is None:
                continue
            if g.ndim not in (1, 2):
                # No update rule for 0-D or >2-D tensors: they are left untouched.
                continue

            state = self.state[p]
            if "momentum_buffer" not in state:
                state["momentum_buffer"] = torch.zeros_like(g)
            buf = state["momentum_buffer"]
            buf.mul_(beta).add_(g, alpha=1 - beta)

            if buf.ndim == 2:
                d_in = buf.size(1)
                update = scale1 * (1 / d_in) * torch.sign(buf)
            else:
                update = scale2 * buf / (torch.sqrt(torch.mean(buf**2, dim=0, keepdim=True)) + eps)

            # Key assumes a single param group (see _muon_update).
            param_name = f"sign_param_{i}_{p.shape}"
            lr_scale = self.adaptive_tracker_sign.update_and_get_lr_scale(
                param_name, p.grad, self.step_count, use_muon=False
            )
            adaptive_lr = lr * lr_scale

            # Weight decay and update both use the layer-wise adapted lr.
            p.data.mul_(1 - adaptive_lr * weight_decay)
            p.data.add_(update, alpha=-adaptive_lr)

    @torch.no_grad()
    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure (Callable, optional): A closure that reevaluates the model
                and returns the loss.

        Returns:
            The closure's loss, or None when no closure is given.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            self._muon_update(group)
            self._sign_update(group)

        self.step_count += 1
        return loss
