import torch
import numpy as np
from muon_batch import BatchedMuon 
from muon_batch_VR import BatchedMuon_VR

def f_clean(x):
    """Compute the clean quadratic objective 0.5 * ||x||_F^2 per run.

    Args:
        x: Tensor reduced over dims 1 and 2; presumably shaped (k, d, d),
            one matrix per independent run — TODO confirm with callers.

    Returns:
        Tensor of per-run objective values (shape (k,) for a 3-D input).
        Its gradient with respect to x is x itself.
    """
    squared = x.square()
    return squared.sum(dim=(1, 2)).mul(0.5)

def pareto_noise_like(x, alpha=1.5, symmetric=True, eps=1e-8):
    U = torch.clamp(torch.rand_like(x), min=eps)
    noise = U.pow(-1 / alpha)
    if symmetric:
        signs = torch.randint_like(x, low=0, high=2, dtype=torch.float32) * 2 - 1
        noise *= signs
    return noise

def f_noisy(x, noise_fn=torch.randn_like):
    """Evaluate the clean objective plus a random linear perturbation, per run.

    The loss is ``f_clean(x) + <noise, x>`` (inner product over dims 1 and 2).
    Its gradient with respect to ``x`` is therefore ``x + noise``: a stochastic
    gradient of ``f_clean`` corrupted by ``noise``.

    Args:
        x: Iterate tensor, presumably (k, d, d) — TODO confirm.
        noise_fn: Callable mapping ``x`` to a noise tensor of the same shape.
            Defaults to standard Gaussian noise (``torch.randn_like``). For
            heavy-tailed noise, pass e.g.
            ``lambda t: pareto_noise_like(t, alpha=1.5, symmetric=True)``.

    Returns:
        Tuple ``(loss, noise)``. ``loss`` holds the per-run losses and
        ``noise`` is the sample that was used, so callers can reuse it.
    """
    noise = noise_fn(x)
    loss = f_clean(x) + torch.sum(noise * x, dim=(1, 2))
    return loss, noise

def main(*, d=1, k=100000, T=100, lr=1e-1, weight_decay=1.0, momentum=0.9,
         backend_steps=5, clip_norm=None, variance_reduction=False,
         seed=None, out_path="Muon_synthetic.npz"):
    """Run k independent Muon trajectories on the noisy quadratic and save stats.

    Each run optimizes its own d x d matrix, starting from ``1.5 * ones``,
    for T steps. Before each step, the norm ``||x||_F`` is recorded and folded
    into a running average over the steps so far. Because
    ``f_clean(x) = 0.5 * ||x||_F^2``, this norm is the clean-objective
    gradient norm.

    Args:
        d: Side length of each square iterate.
        k: Number of independent runs (batch dimension).
        T: Number of optimization steps.
        lr: Learning rate passed to the optimizer.
        weight_decay: Weight decay passed to the optimizer.
        momentum: Momentum passed to the optimizer.
        backend_steps: ``backend_steps`` passed to the optimizer.
        clip_norm: If not None, rescale each run's gradient so its Frobenius
            norm is at most this value before ``optimizer.step()``. The
            default None disables clipping.
        variance_reduction: Use ``BatchedMuon_VR`` instead of ``BatchedMuon``.
        seed: If not None, seed torch's RNG for reproducible runs.
        out_path: Destination ``.npz`` file.

    The saved file contains:
        - ``avg_norms``: array of shape (k, T).
        - ``median_index`` and ``median_final_avg_norm``: identify the run
          whose final running average is the median across all k runs.

    Raises:
        ValueError: If ``k`` or ``T`` is smaller than 1.
    """
    if k < 1 or T < 1:
        raise ValueError(f"k and T must be >= 1, got k={k}, T={T}")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if seed is not None:
        torch.manual_seed(seed)

    # All runs start from the same point, 1.5 * ones.
    x = torch.full((k, d, d), 1.5, device=device).requires_grad_()

    optimizer_cls = BatchedMuon_VR if variance_reduction else BatchedMuon
    optimizer = optimizer_cls([x], lr=lr, momentum=momentum,
                              weight_decay=weight_decay,
                              backend_steps=backend_steps)

    prev_x = None
    running_avg = torch.zeros(k, device=device)
    avg_grad_norms = torch.zeros(k, T, device=device)

    for step in range(T):
        loss, noise = f_noisy(x)
        # Runs share no parameters, so one backward over all per-run losses
        # gives every run its own gradient.
        loss.backward(torch.ones_like(loss))

        # The gradient of f_noisy is x + noise. Evaluating the *current* noise
        # sample at the previous iterate gives prev_x + noise. Presumably the
        # VR optimizer uses this for a STORM-style correction — confirm.
        if prev_x is not None:
            optimizer.prev_grads = {x: prev_x + noise.detach()}
        else:
            optimizer.prev_grads = None
        optimizer.curr_grads = {x: x.grad.detach().clone()}

        with torch.no_grad():
            # grad f_clean(x) = x, so ||x||_F is the clean gradient norm.
            grad_norms = x.norm(p="fro", dim=(1, 2))
            running_avg = (running_avg * step + grad_norms) / (step + 1)
            avg_grad_norms[:, step] = running_avg

        prev_x = x.detach().clone()

        if clip_norm is not None:
            # NOTE(review): curr_grads/prev_grads were captured above and stay
            # unclipped; only x.grad is rescaled — confirm this is intended.
            with torch.no_grad():
                norm = x.grad.norm(p="fro", dim=(1, 2), keepdim=True)
                x.grad.mul_(torch.clamp(clip_norm / (norm + 1e-6), max=1.0))

        optimizer.step()
        optimizer.zero_grad()

    avg_grad_norms_cpu = avg_grad_norms.cpu().numpy()
    final_vals = avg_grad_norms_cpu[:, -1]
    # Median over all k runs (the original indexed position 100 // 2).
    median_index = int(np.argsort(final_vals)[len(final_vals) // 2])
    median_run = final_vals[median_index]

    np.savez(out_path, avg_norms=avg_grad_norms_cpu,
             median_index=median_index, median_final_avg_norm=median_run)

if __name__ == "__main__":
    # Script entry point: run main() with no arguments when executed directly.
    main()