import torch
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import seaborn as sns

# Global figure appearance: seaborn grid theme plus large fonts so the
# saved plots stay legible.
sns.set_style("whitegrid")
plt.rcParams['font.size'] = 20
plt.rcParams['figure.figsize'] = (12, 8)
plt.rcParams['axes.titlesize'] = 20
plt.rcParams['axes.labelsize'] = 20
plt.rcParams['xtick.labelsize'] = 16
plt.rcParams['ytick.labelsize'] = 16

# Run on the GPU when one is visible to torch, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

def hard_thresholding(arr, k):
    """Keep the ``k`` largest-magnitude entries of a 1-D tensor, zero the rest.

    Args:
        arr: 1-D tensor to sparsify. It is not modified.
        k: Number of entries to keep. A value larger than the number of
            entries is clamped, in which case the result is a copy of ``arr``.

    Returns:
        A new tensor with the same shape, dtype and device as ``arr`` whose
        only non-zero entries are the ``k`` entries of ``arr`` with the
        largest absolute value.
    """
    # torch.topk raises when k exceeds the dimension size; clamping turns
    # "keep more entries than exist" into "keep everything".
    k = min(k, arr.numel())
    top_k_indices = torch.topk(torch.abs(arr), k).indices
    thresholded_arr = torch.zeros_like(arr)
    thresholded_arr[top_k_indices] = arr[top_k_indices]
    return thresholded_arr

# Problem size: n samples, d features, k entries kept by hard thresholding.
n, d, k = 100, 100, 30
# Optimiser settings: step size, finite-difference radius, and number of
# random directions per zeroth-order gradient estimate.
eta, mu, q = 1e-5, 1e-4, 20

# Iteration budgets: T outer rounds of m inner steps for VR_SZHT, and
# iter_max steps for the single-loop routines (SZOHT, SAGA).
T, m = 2000, 5
iter_max = 5000


# Fixed synthetic regression task shared by every run: random features, a
# k-sparse ground-truth weight vector, and the targets it generates.
torch.manual_seed(42)
base_x_train = torch.rand(n, d) * 10
base_w_true = hard_thresholding(torch.rand(d), k)
base_y_train = base_x_train @ base_w_true

def loss_function(x, y, w, Delta):
    """Return the mean squared error of ``x @ w`` against ``y`` plus noise.

    Args:
        x: Feature matrix, or a single feature row.
        y: Targets matching ``x @ w``.
        w: Weight vector.
        Delta: Scale of the additive noise. Every call draws a fresh sample
            uniformly from ``[0, 1)`` and multiplies it by ``Delta``.

    Returns:
        Tensor of shape ``(1,)`` holding the MSE plus the noise sample.
    """
    residual = x @ w - y
    # Draw the noise on the inputs' device rather than a module-level one,
    # so the function also works for tensors that live elsewhere.
    noise = Delta * torch.rand(1, device=x.device)
    return torch.mean(residual ** 2) + noise

def ZO_gradient(x, y, w, q, mu, Delta):
    """Estimate the gradient of ``loss_function`` at ``w`` from loss values only.

    Sums ``q`` forward finite differences of the (noisy) loss taken along
    random unit directions, then rescales by ``dim / (q * mu)``.

    Args:
        x: Feature matrix or single feature row, passed to ``loss_function``.
        y: Targets, passed to ``loss_function``.
        w: 1-D weight vector at which the gradient is estimated.
        q: Number of random directions to sample.
        mu: Finite-difference step length along each direction.
        Delta: Noise scale forwarded to ``loss_function``.

    Returns:
        Tensor with the same shape as ``w`` holding the gradient estimate.
    """
    dim = w.shape[0]
    estimate = torch.zeros_like(w)

    for _ in range(q):
        # Sample on w's own device so the estimator does not depend on a
        # module-level device matching the inputs.
        direction = torch.randn(dim, device=w.device)
        direction = direction / torch.norm(direction)

        loss_plus = loss_function(x, y, w + mu * direction, Delta)
        # The baseline is re-queried for every direction: each call to
        # loss_function draws its own noise, so hoisting this out of the
        # loop would change both the noise model and the random stream.
        loss_orig = loss_function(x, y, w, Delta)

        estimate += (loss_plus - loss_orig) * direction

    return estimate * dim / (q * mu)

def SZOHT(eta, iter_max, q, k, mu, Delta, experiment_seed):
    """Run single-sample zeroth-order gradient steps with hard thresholding.

    Every iteration records the loss, picks one training sample at random,
    takes a step of size ``eta`` along the negative zeroth-order gradient
    estimate for that sample, and keeps only the ``k`` largest-magnitude
    weights.

    Args:
        eta: Step size.
        iter_max: Number of iterations.
        q: Random directions per gradient estimate.
        k: Number of weights kept by hard thresholding.
        mu: Finite-difference step length used by ``ZO_gradient``.
        Delta: Noise scale of the loss queries.
        experiment_seed: Seed applied to both the torch and NumPy RNGs.

    Returns:
        List of ``iter_max`` floats: the noise-free training MSE measured
        before each update.
    """
    torch.manual_seed(experiment_seed)
    np.random.seed(experiment_seed)

    x_train = base_x_train.clone().to(device)
    y_train = base_y_train.clone().to(device)
    # Take the sizes from the data itself instead of module-level constants.
    num_samples, dim = x_train.shape

    loss_list = []
    w = hard_thresholding(10 * torch.ones(dim, device=device), k)

    for _ in range(iter_max):
        # Progress is tracked with the exact (noise-free) full-batch loss.
        with torch.no_grad():
            current_loss = torch.mean((x_train @ w - y_train) ** 2).item()
        loss_list.append(current_loss)

        sample_i = torch.randint(0, num_samples, (1,)).item()

        grad_est = ZO_gradient(
            x_train[sample_i], y_train[sample_i], w, q, mu, Delta
        )
        w = hard_thresholding(w - eta * grad_est, k)

    return loss_list

def VR_SZHT(eta, T, m, q, k, mu, Delta, experiment_seed):
    """Run the snapshot-corrected zeroth-order hard-thresholding loop.

    Each of the ``T`` outer rounds stores a snapshot of the iterate and a
    zeroth-order gradient estimate over the whole training set. Each of the
    ``m`` inner steps then combines single-sample estimates at the current
    iterate and at the snapshot with that full estimate, takes a step of
    size ``eta``, and keeps only the ``k`` largest-magnitude weights.

    Args:
        eta: Step size.
        T: Number of outer rounds (snapshots).
        m: Number of inner steps per outer round.
        q: Random directions per gradient estimate.
        k: Number of weights kept by hard thresholding.
        mu: Finite-difference step length used by ``ZO_gradient``.
        Delta: Noise scale of the loss queries.
        experiment_seed: Seed applied to both the torch and NumPy RNGs.

    Returns:
        List of ``T * m`` floats: the noise-free training MSE measured
        before each inner update.
    """
    torch.manual_seed(experiment_seed)
    np.random.seed(experiment_seed)

    x_train = base_x_train.clone().to(device)
    y_train = base_y_train.clone().to(device)
    # Take the sizes from the data itself instead of module-level constants.
    num_samples, dim = x_train.shape

    loss_list = []
    w = hard_thresholding(10 * torch.ones(dim, device=device), k)

    for _ in range(T):
        # Snapshot of the iterate and its full-data gradient estimate.
        w_last = w.clone()
        fg = ZO_gradient(x_train, y_train, w_last, q, mu, Delta)

        for _ in range(m):
            # Progress is tracked with the exact (noise-free) full-batch loss.
            with torch.no_grad():
                current_loss = torch.mean((x_train @ w - y_train) ** 2).item()
            loss_list.append(current_loss)

            sample_i = torch.randint(0, num_samples, (1,)).item()
            x_i, y_i = x_train[sample_i], y_train[sample_i]

            # NOTE(review): the two estimates below draw their random
            # directions independently -- confirm that is intended for the
            # correction term rather than sharing one set of directions.
            grad_current = ZO_gradient(x_i, y_i, w, q, mu, Delta)
            grad_last = ZO_gradient(x_i, y_i, w_last, q, mu, Delta)

            w = w - eta * (grad_current - grad_last + fg)
            w = hard_thresholding(w, k)

    return loss_list

def SAGA(eta, iter_max, q, k, mu, Delta, experiment_seed):
    """Run the gradient-table zeroth-order hard-thresholding loop.

    A table stores the most recent zeroth-order gradient estimate for every
    training sample, together with the running average of its rows. Each
    iteration estimates the gradient for one random sample, steps along
    ``new - stored + average``, refreshes the table and the average, and
    keeps only the ``k`` largest-magnitude weights.

    Args:
        eta: Step size.
        iter_max: Number of iterations.
        q: Random directions per gradient estimate.
        k: Number of weights kept by hard thresholding.
        mu: Finite-difference step length used by ``ZO_gradient``.
        Delta: Noise scale of the loss queries.
        experiment_seed: Seed applied to both the torch and NumPy RNGs.

    Returns:
        List of ``iter_max`` floats: the noise-free training MSE measured
        before each update.
    """
    torch.manual_seed(experiment_seed)
    np.random.seed(experiment_seed)

    x_train = base_x_train.clone().to(device)
    y_train = base_y_train.clone().to(device)
    # Take the sizes from the data itself instead of module-level constants.
    num_samples, dim = x_train.shape

    loss_list = []
    w = hard_thresholding(10 * torch.ones(dim, device=device), k)

    # The table starts at zero, so its row average starts at zero as well.
    gradient_memory = torch.zeros((num_samples, dim), device=device)
    gradient_avg = torch.mean(gradient_memory, dim=0)

    for _ in range(iter_max):
        # Progress is tracked with the exact (noise-free) full-batch loss.
        with torch.no_grad():
            current_loss = torch.mean((x_train @ w - y_train) ** 2).item()
        loss_list.append(current_loss)

        sample_i = torch.randint(0, num_samples, (1,)).item()

        current_gradient = ZO_gradient(
            x_train[sample_i], y_train[sample_i], w, q, mu, Delta
        )

        # Change of this sample's table entry; computed once and reused for
        # both the step and the running average.
        delta = current_gradient - gradient_memory[sample_i]

        # The step uses the average from before this sample's refresh.
        w = w - eta * (delta + gradient_avg)

        gradient_avg += delta / num_samples
        gradient_memory[sample_i] = current_gradient

        w = hard_thresholding(w, k)

    return loss_list

# Experiment grid: one curve per noise bound, averaged over several seeds.
noise_bound_list = [0, 10, 20, 30, 40]
num_experiments = 5
iter_max = 5000
seed = [12, 24, 36, 37, 42]

# Single switch for the algorithm under test. The runner, the x-axis length
# and the output file name all follow from it, so changing algorithm means
# editing only this one line.
algorithm = 'VR_SZHT'
runners = {
    'SZOHT': lambda delta, s: SZOHT(
        eta, iter_max, q, k, mu, delta, experiment_seed=s
    ),
    'VR_SZHT': lambda delta, s: VR_SZHT(
        eta, T, m, q, k, mu, delta, experiment_seed=s
    ),
    'SAGA': lambda delta, s: SAGA(
        eta, iter_max, q, k, mu, delta, experiment_seed=s
    ),
}
run_experiment = runners[algorithm]

# Maps noise bound -> array of shape (num_experiments, iterations).
all_results = {}

for noise_bound in noise_bound_list:
    print(f"\nRunning experiments for noise bound: {noise_bound}")

    noise_results = []

    for exp_idx in tqdm(range(num_experiments), desc=f"Noise Δ={noise_bound}"):
        loss_history = run_experiment(noise_bound, seed[exp_idx])
        noise_results.append(loss_history)

    all_results[noise_bound] = np.array(noise_results)

for noise_bound in noise_bound_list:
    results = all_results[noise_bound]

    mean_loss = np.mean(results, axis=0)
    std_loss = np.std(results, axis=0)

    # Take the x-axis length from the recorded history so it is right for
    # every algorithm (T * m steps for VR_SZHT, iter_max for the others).
    iterations = np.arange(results.shape[1])

    plt.plot(iterations, mean_loss, label=f'Δ={noise_bound}',
             linewidth=2, alpha=0.8)

    # Shaded band: one standard deviation across seeds.
    plt.fill_between(iterations,
                     mean_loss - std_loss,
                     mean_loss + std_loss,
                     alpha=0.2)

plt.xlabel('Iteration', fontsize=20)
plt.ylabel('Loss', fontsize=20)

plt.legend(title='Noise Level', fontsize=20, title_fontsize=20)
plt.grid(True, which="both", ls="--", alpha=0.3)

plt.tight_layout()
plt.savefig(f'{algorithm}.png', dpi=300, bbox_inches='tight')
plt.show()