import optuna
import torch
import numpy as np
import os

class ZOSA(torch.optim.Optimizer):
    """Zeroth-order optimizer with a SAM-style perturbation step.

    NOTE(review): the name presumably stands for "zeroth-order
    sharpness-aware"; confirm against the accompanying paper/project.

    No back-propagation is used. For each parameter group, one ``step``:

    1. estimates a gradient from ``m`` forward differences of the loss along
       random +/-1 directions of size ``epsilon``;
    2. moves the parameters by ``rho * g / sigma``, where ``sigma`` is the
       (population) standard deviation of the ``m`` perturbed losses;
    3. estimates the gradient again at that shifted point;
    4. undoes the shift and applies ``-(lr / sigma_pert) * g_pert``, where
       ``sigma_pert`` is the standard deviation of the second batch of
       losses (plain ``lr`` is used when that deviation is not positive).

    Args:
        params: Iterable of parameters or dicts defining parameter groups.
        rho: Scale of the perturbation applied in step 2 (must be >= 0).
        epsilon: Finite-difference step size (must be > 0).
        m: Number of random directions per gradient estimate (must be >= 1).
        lr: Base learning rate (must be >= 0).

    Raises:
        ValueError: If any hyper-parameter is outside the range given above.
    """

    def __init__(self, params, rho=0.05, epsilon=1e-3, m=4, lr=1e-3):
        if rho < 0:
            raise ValueError(f"Invalid rho: {rho} (must be >= 0)")
        if epsilon <= 0:
            raise ValueError(f"Invalid epsilon: {epsilon} (must be > 0)")
        if m < 1:
            raise ValueError(f"Invalid m: {m} (must be >= 1)")
        if lr < 0:
            raise ValueError(f"Invalid lr: {lr} (must be >= 0)")
        defaults = dict(rho=rho, epsilon=epsilon, m=m, lr=lr)
        super().__init__(params, defaults)

    @staticmethod
    def _shift(params, directions, scale):
        """Add ``scale * direction`` to every parameter in place."""
        scale = float(scale)
        # no_grad: the parameters are leaves that require grad, so in-place
        # edits must happen outside autograd tracking.
        with torch.no_grad():
            for p, d in zip(params, directions):
                p.add_(d, alpha=scale)

    @staticmethod
    def _rademacher(params):
        """Draw one +/-1 direction tensor per parameter from the global RNG."""
        return [
            torch.randint(0, 2, p.size(), device=p.device, dtype=p.dtype) * 2 - 1
            for p in params
        ]

    def _estimate_gradient(self, params, closure, epsilon, m):
        """Forward-difference gradient estimate at the current parameters.

        Evaluates ``closure`` once at the current point and once per random
        direction (``m + 1`` evaluations in total).

        Returns:
            A ``(grad, sigma)`` pair: ``grad`` is a list with one tensor per
            parameter, ``sigma`` is the population standard deviation of the
            ``m`` perturbed losses as a Python float.
        """
        # float(): drop any autograd graph attached to the loss so that the
        # estimate never keeps graphs alive or becomes part of one.
        base_loss = float(closure())
        grad = [torch.zeros_like(p) for p in params]
        losses = []
        for _ in range(m):
            # Directions are drawn one set at a time, so only a single set is
            # held in memory regardless of m.
            directions = self._rademacher(params)
            self._shift(params, directions, epsilon)
            try:
                loss = float(closure())
            finally:
                # Always undo the probe, even if the closure raises.
                self._shift(params, directions, -epsilon)
            losses.append(loss)
            coeff = (loss - base_loss) / (epsilon * m)
            for g, d in zip(grad, directions):
                g.add_(d, alpha=coeff)
        sigma = torch.std(
            torch.tensor(losses, dtype=torch.float32), unbiased=False
        ).item()
        return grad, sigma

    def step(self, closure):
        """Perform one optimization step.

        Args:
            closure: Zero-argument callable that evaluates the model at the
                current parameter values and returns the loss (a Python
                number or a single-element tensor). It is called
                ``2 * m + 2`` times per parameter group plus once at the end.

        Returns:
            The value of ``closure()`` evaluated after the update.
        """
        for group in self.param_groups:
            params = group['params']
            rho = group['rho']
            epsilon = group['epsilon']
            m = group['m']
            lr = group['lr']

            # Steps 1-2: estimate the gradient and turn it into the
            # perturbation rho * g / sigma (zero when sigma is not positive,
            # which also covers NaN).
            perturbation, sigma = self._estimate_gradient(params, closure, epsilon, m)
            if sigma > 0:
                for g in perturbation:
                    g.mul_(rho / sigma)
            else:
                for g in perturbation:
                    g.zero_()

            # Step 3: re-estimate the gradient at the perturbed point, making
            # sure the perturbation is removed afterwards in every case.
            self._shift(params, perturbation, 1.0)
            try:
                grad_pert, sigma_pert = self._estimate_gradient(
                    params, closure, epsilon, m
                )
            finally:
                self._shift(params, perturbation, -1.0)

            # Step 4: descent step with the loss-spread-normalised rate.
            step_size = lr / sigma_pert if sigma_pert > 0 else lr
            self._shift(params, grad_pert, -step_size)

        return closure()