import torch
import numpy as np
from lion import Lion
from lion_VR import Lion_VR
import torch.optim as optim
from torch.nn.utils import clip_grad_norm_

def f_clean(x):
    """Return the per-row quadratic ``0.5 * ||x_i||^2``.

    The reduction is over dim 1, so an input of shape (k, d) produces a
    tensor of shape (k,) with one objective value per row.
    """
    squared = x.pow(2)
    return squared.sum(dim=1).mul(0.5)

def f_noisy(x, noise_fn=None):
    """Evaluate the quadratic objective with additive linear noise.

    Draws one noise sample per entry of ``x`` and returns
    ``f_clean(x) + sum(noise * x, dim=1)``. The gradient of that expression
    with respect to ``x`` is ``x + noise``.

    Args:
        x: tensor reduced along dim 1 (``main`` passes shape (k, d)).
        noise_fn: callable mapping ``x`` to a noise tensor of the same shape.
            Defaults to ``torch.randn_like`` (standard normal noise). For
            heavy-tailed noise pass e.g.
            ``lambda t: pareto_noise_like(t, alpha=2.5, symmetric=True)``.

    Returns:
        Tuple ``(loss, noise)``: the per-row loss and the noise sample used.
    """
    if noise_fn is None:
        noise_fn = torch.randn_like  # Normal noise
    noise = noise_fn(x)
    # Heavy-tailed samplers can overflow; surface that instead of failing later.
    if torch.isnan(noise).any():
        print("Warning: NaNs detected in noise")
    if torch.isinf(noise).any():
        print("Warning: inf detected in noise")
    return f_clean(x) + torch.sum(noise * x, dim=1), noise

def pareto_noise_like(x, alpha=2.5, symmetric=True, eps=1e-8):
    """Sample Pareto-distributed (heavy-tailed) noise shaped like ``x``.

    Uses inverse-transform sampling. With ``U ~ Uniform[eps, 1)``, the value
    ``U ** (-1/alpha)`` is Pareto with scale 1 and tail index ``alpha``, so
    every magnitude is >= 1. For ``alpha <= 2`` the variance is infinite.

    Args:
        x: reference tensor; the result has its shape, dtype and device.
        alpha: tail index; must be positive (smaller means heavier tails).
        symmetric: if True, multiply each sample by an independent random
            sign (+1 or -1), giving a distribution symmetric about zero.
        eps: lower clamp on ``U``. It avoids ``0 ** (-1/alpha) == inf`` and
            bounds samples by ``eps ** (-1/alpha)``.

    Returns:
        Tensor of noise samples with the same shape as ``x``.

    Raises:
        ValueError: if ``alpha`` is not positive.
    """
    if alpha <= 0:
        raise ValueError(f"alpha must be positive, got {alpha}")
    U = torch.clamp(torch.rand_like(x), min=eps)
    noise = U.pow(-1.0 / alpha)
    if symmetric:
        # randint_like defaults to x's dtype, so no mixed-dtype in-place multiply.
        signs = torch.randint_like(x, low=0, high=2) * 2 - 1
        noise *= signs
    return noise

def f_noisy_pareto(x, alpha=1.5, symmetric=False, return_noise=False):
    """Evaluate the quadratic objective with additive heavy-tailed noise.

    This is the same objective as ``f_noisy``, but the noise comes from
    ``pareto_noise_like``. With the default ``symmetric=False`` every noise
    sample is >= 1. The noise is therefore not zero-mean, and the stochastic
    gradient ``x + noise`` is biased.

    Args:
        x: tensor reduced along dim 1 (``main`` uses shape (k, d)).
        alpha: Pareto tail index forwarded to ``pareto_noise_like``.
        symmetric: forwarded to ``pareto_noise_like``; True randomizes signs.
        return_noise: if True, return ``(loss, noise)`` like ``f_noisy``, so
            the two can be swapped in ``main``. If False (default), return
            only the loss.

    Returns:
        The per-row loss, or ``(loss, noise)`` when ``return_noise`` is True.
    """
    noise = pareto_noise_like(x, alpha=alpha, symmetric=symmetric)
    loss = f_clean(x) + torch.sum(noise * x, dim=1)
    return (loss, noise) if return_noise else loss

def main(d=1, k=100000, T=100, lr=1e-1, wd=1.0, clip_norm=None,
         out_path="Lion_synthetic.npz"):
    """Run ``k`` independent Lion trajectories on the noisy quadratic.

    Each row of the (k, d) parameter tensor ``x`` is an independent run that
    starts at 1.5 in every coordinate. At every step:

    1. The stochastic gradient ``x + noise`` (from ``f_noisy``) is
       backpropagated.
    2. The running mean over steps of ``||x||`` is recorded per run.
    3. The optimizer takes one step.

    The index of the median run and its final averaged norm are printed.

    Args:
        d: dimension of each run.
        k: number of independent runs.
        T: number of optimizer steps.
        lr: learning rate.
        wd: weight decay.
        clip_norm: if not None, rescale each row of ``x.grad`` to an L2 norm
            of at most ``clip_norm`` before the optimizer step. Defaults to
            None (no clipping).
        out_path: destination ``.npz`` file. The (k, T) array of running
            averages is stored under the key ``avg_norms``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # k independent runs stacked as rows, all initialized to 1.5.
    x = torch.full((k, d), 1.5, device=device).requires_grad_()

    optimizer = Lion([x], betas=(0.8, 0.9), lr=lr, weight_decay=wd)
    # Variance-reduced alternative:
    # optimizer = Lion_VR([x], betas=(0.8, 0.9), lr=lr, weight_decay=wd)

    avg_grad_norms = torch.zeros(k, T, device=device)
    avg = torch.zeros(k, device=device)

    prev_x = None

    for step in range(T):
        loss, noise = f_noisy(x)
        loss.sum().backward()

        # Stochastic gradient at the previous iterate, reusing the *current*
        # noise sample: x_{t-1} + xi_t. None on the first step.
        prev_grad = None if prev_x is None else prev_x + noise.detach()

        with torch.no_grad():
            # The gradient of f_clean is x itself, so ||x|| is the clean
            # gradient norm (x.grad would also contain the noise).
            grad_norms = x.norm(dim=1)
            avg = (avg * step + grad_norms) / (step + 1)  # running mean over steps
            avg_grad_norms[:, step] = avg

        prev_x = x.detach().clone()

        # Extra state attached for optimizers that read it (presumably Lion_VR;
        # TODO confirm Lion ignores these attributes).
        optimizer.curr_grads = {
            p: p.grad.detach().clone()
            for group in optimizer.param_groups
            for p in group['params']
            if p.grad is not None
        }
        # NOTE(review): a bare tensor (or None), unlike the dict in curr_grads;
        # confirm this is the format Lion_VR expects.
        optimizer.prev_grads = prev_grad

        if clip_norm is not None:
            # Per-run (row-wise) clipping, applied after curr_grads was captured.
            with torch.no_grad():
                norms = x.grad.norm(p=2, dim=1, keepdim=True)  # shape (k, 1)
                x.grad.mul_(torch.clamp(clip_norm / (norms + 1e-6), max=1.0))

        optimizer.step()
        optimizer.zero_grad()

    avg_grad_norms_cpu = avg_grad_norms.cpu().numpy()
    final_vals = avg_grad_norms_cpu[:, -1]
    # Run whose final averaged norm is the median over all k runs
    # (upper median when k is even).
    median_index = int(np.argsort(final_vals)[len(final_vals) // 2])
    median_run = final_vals[median_index]
    print(f"Median run: index {median_index}, final avg norm {median_run:.6g}")

    np.savez(out_path, avg_norms=avg_grad_norms_cpu)

# Script entry point: run the experiment with main()'s default settings.
if __name__ == '__main__':
    main()
