from plot_traj_eq import *
import numpy as np
from ZO_methods_eq import *
if __name__ == "__main__":
    # Benchmark problem: a nonconvex equality-constrained quadratic program
    #
    #     minimize    f(x) = 0.5 * x^T Q x + c^T x
    #     subject to  h(x) = A^T x - b + 0.5 * ||x||^2 = 0
    #
    # h(x) = 0 is equivalent to ||x + A||^2 = 2b + ||A||^2, so the feasible set
    # is a sphere centred at -A -- a nonconvex set (the nonconvex constraints).
    #
    # The seed fixes Q, c and A. If the solvers draw from NumPy's global RNG
    # (not visible here), it also fixes their sampling.
    np.random.seed(14893)
    n = 100
    Q = np.random.randn(n, n)
    Q = Q.T @ Q  # symmetric positive semi-definite, so f is convex
    # For a perfectly conditioned objective, Q can be replaced by np.eye(n).
    c = np.random.randn(n)
    A = np.random.randn(n)  # a vector: h is a single scalar equality constraint
    b = 20

    def f(x):
        """Return the objective value 0.5 * x^T Q x + c^T x."""
        return 0.5 * x @ Q @ x + c @ x

    def h(x):
        """Return the constraint value A^T x - b + 0.5 * ||x||^2 (feasible iff 0)."""
        return A @ x - b + 0.5 * x @ x

    def df(x):
        """Return the gradient of f; equals Q x + c because Q is symmetric."""
        return Q @ x + c

    def dh(x):
        """Return the gradient of h."""
        return A + x

    # Settings shared by the solvers, so the comparison uses a common budget.
    MAX_ITER = 60000
    R = 0.01

    # Every solver starts from the same point. Each one receives its own copy,
    # so a solver that updates its starting iterate in place cannot move the
    # starting point of the solvers that run after it.
    initial_guess = np.zeros(n)
    # A random start (np.random.randn(n)) is also possible. Note that it
    # consumes draws from the seeded RNG, which shifts every later random number.

    # NOTE: keep the call order fixed -- if the solvers use NumPy's global RNG,
    # reordering them changes their results.
    best_point, best_value, history = ZO_baseline(
        f, h, initial_guess.copy(),
        learning_rate=0.0001, r=R, epsilon=1e-5, TB=8, k=100000,
        max_iter=MAX_ITER,
    )
    best_point2, best_value2, history2 = ZOFL(
        f, h, initial_guess.copy(),
        learning_rate=0.0001, r=R, epsilon=1e-5, TB=4, k=1000,
        max_iter=MAX_ITER,
    )
    best_point3, best_value3, history3 = ZOFL_midpoint(
        f, h, initial_guess.copy(),
        learning_rate=0.0001, r=R, epsilon=1e-5, TB=4, k=1000,
        max_iter=MAX_ITER,
    )
    best_point4, best_value4, history4 = ZOGDA(
        f, h, initial_guess.copy(),
        learning_rate_x=0.0001, learning_rate_lambda=0.1, r=R, TB=8,
        max_iter=MAX_ITER,
    )
    # Gradient-based run on the same problem (it receives df and dh). Its
    # returned value is the reference `optimal_value` drawn in the plot.
    optimal_x, optimal_value, _ = first_order_opt(
        f, h, df, dh, initial_guess.copy(),
        learning_rate=0.001, r=R, epsilon=1e-5, TB=4, k=1,
        max_iter=MAX_ITER, tol=1e-15,
    )

    best_point5, best_value5, history5 = ConEx_meta(
        f, h, initial_guess.copy(),
        learning_rate_x=0.0003, learning_rate_lambda=0.003, r=R, TB=8,
        max_iter_inner=100, max_iter_outer=601, theta=0.2, mu_h=10,
    )

    # Histories are passed in the same order as the legend labels below.
    plot_optimization_history_moving_average(
        history4, history5, history, history2, history3,
        optimal_value=optimal_value,
        legends=['ZOGDA', 'SZO-ConEx', 'ZO-Baseline', 'ZOFL', 'ZOFL-midpoint'],
        linewidth=3,
        savename='comparison_eq.png',
        colors=['#1f77b4', '#2ca02c', '#ff7f0e', '#d62728', '#9467bd'],
    )

    # Persist every run so the figures can be regenerated without re-solving.
    np.savez(
        'test_qp.npz',
        history=history, history2=history2, history3=history3,
        history4=history4, history5=history5,
        best_point=best_point, best_point2=best_point2, best_point3=best_point3,
        best_point4=best_point4, best_point5=best_point5,
        best_value=best_value, best_value2=best_value2, best_value3=best_value3,
        best_value4=best_value4, best_value5=best_value5,
        optimal_value=optimal_value,
    )