import os

import matplotlib.pyplot as plt
import numpy as np

from util import *
from algs import *

# ---- Experiment configuration ----------------------------------------------
dim = 201                 # number of coordinates in the starting point below
f = construct_f()         # objective under test (star-imported from util/algs)
# f_star = - 0.25 * (dim - 1)
f_star = 0                # offset subtracted from every curve when plotting
batch_size = dim          # NOTE(review): not referenced in the visible part of the script
x = [coord for coord in np.zeros(dim)]   # all-zeros starting point, kept as a Python list
print(x)

repeat = 10               # independent runs per algorithm

# np.random.seed(11)

# The three run sections below (PAGD, RSPI, RSPI with SPSA) are disabled.
# Uncomment a section to re-run that algorithm `repeat` times and overwrite
# its .npz file in data_mean_dim_200/. The loading code further down reads
# these files unconditionally, so they must already exist on disk.
# NOTE(review): these calls pass the shared `x` list to every repeat; if an
# algorithm modifies its input in place, later repeats start from a different
# point — consider passing list(x) when re-enabling.
"""=================== Perturbed Approximate Gradient Descent (PAGD)  ==================="""

# pagd_vals = []
# pagd_complexity = []
# for i in range(repeat):
#     complexity, vals = pagd(f, x, iters=1800, L=110)
#     pagd_vals.append(vals)
#     pagd_complexity.append(complexity)
# np.savez('data_mean_dim_200/pagd', pagd_complexity=pagd_complexity, pagd_vals=pagd_vals)


"""=================== Random Search Power Iteration  ==================="""

# rspi_vals = []
# rspi_complexity = []
# for i in range(repeat):
#     complexity, vals = rspi(f, x, iters=40, L=200, sigma_1=1.5, sigma_2=0.65, T_sigma_1=15, ratio=0.96)
#     rspi_vals.append(vals)
#     rspi_complexity.append(complexity)
# np.savez('data_mean_dim_200/rspi', rspi_complexity=rspi_complexity, rspi_vals=rspi_vals)


"""=================== Random Search Power Iteration (with SPSA acceleration) ==================="""

# rspi_spsa_vals = []
# rspi_spsa_complexity = []
# for i in range(repeat):
#     complexity, vals = rspi_spsa(f, x, iters=3000, L=200, sigma_1=1.75, sigma_2=0.65, T_sigma_1=15, ratio=0.98)
#     rspi_spsa_vals.append(vals)
#     rspi_spsa_complexity.append(complexity)
# np.savez('data_mean_dim_200/rspi_spsa', rspi_spsa_complexity=rspi_spsa_complexity, rspi_spsa_vals=rspi_spsa_vals)

"""=================== ZO-GD-NCF ==================="""

# Run ZO-GD-NCF `repeat` times and save every run's (complexity, vals) trace
# to data_mean_dim_200/zo_gd_ncf.npz under the keys the loading code reads.
zo_gd_ncf_vals = []
zo_gd_ncf_complexity = []
for i in range(repeat):
    # Give each run its own copy of the starting point so a run that modifies
    # its input in place cannot change where the next run starts.
    complexity, vals = zo_gd_ncf(f, list(x), iters=1000, L=150, rho=10)
    zo_gd_ncf_vals.append(vals)
    zo_gd_ncf_complexity.append(complexity)
# np.savez does not create missing parent directories.
os.makedirs('data_mean_dim_200', exist_ok=True)
np.savez('data_mean_dim_200/zo_gd_ncf', zo_gd_ncf_complexity=zo_gd_ncf_complexity, zo_gd_ncf_vals=zo_gd_ncf_vals)

# The two run sections below (ZO Perturbed AGD, with and without ANCF) are
# disabled. Uncomment a section to re-run that algorithm `repeat` times and
# overwrite its .npz file in data_mean_dim_200/. The loading code further down
# reads these files unconditionally, so they must already exist on disk.
# NOTE(review): these calls pass the shared `x` list to every repeat; if an
# algorithm modifies its input in place, later repeats start from a different
# point — consider passing list(x) when re-enabling.
"""=================== ZO Perturbed AGD ==================="""

# zo_p_agd_vals = []
# zo_p_agd_complexity = []
# for i in range(repeat):
#     complexity, vals = zo_p_agd(f, x, iters = 1000, L=200, rho = 10)
#     zo_p_agd_vals.append(vals)
#     zo_p_agd_complexity.append(complexity)
# np.savez('data_mean_dim_200/zo_p_agd', zo_p_agd_complexity=zo_p_agd_complexity, zo_p_agd_vals=zo_p_agd_vals)


"""=================== ZO Perturbed AGD with ANCF==================="""

# zo_p_agd_ancf_vals = []
# zo_p_agd_ancf_complexity = []
# for i in range(repeat):
#     complexity, vals = zo_p_agd_ancf(f, x, iters = 1000, L=200, rho = 10)
#     zo_p_agd_ancf_vals.append(vals)
#     zo_p_agd_ancf_complexity.append(complexity)
# np.savez('data_mean_dim_200/zo_p_agd_ancf', zo_p_agd_ancf_complexity=zo_p_agd_ancf_complexity, zo_p_agd_ancf_vals=zo_p_agd_ancf_vals)


# load data
def _summarize_runs(data, prefix):
    """Extract one algorithm's saved runs and their statistics across repeats.

    ``data`` is the mapping returned by ``np.load`` on an ``.npz`` written by
    one of the run sections above, which store one entry per repeat under the
    keys ``<prefix>_complexity`` and ``<prefix>_vals``. Statistics are taken
    over axis 0, i.e. across repeats.

    Returns:
        (complexity, vals, complexity_mean, vals_mean, vals_max, vals_min)
    """
    complexity = data[prefix + '_complexity']
    vals = data[prefix + '_vals']
    return (
        complexity,
        vals,
        np.mean(complexity, axis=0),
        np.mean(vals, axis=0),
        np.max(vals, axis=0),
        np.min(vals, axis=0),
    )


data_pagd = np.load('data_mean_dim_200/pagd.npz')
(pagd_complexity, pagd_vals, pagd_complexity_mean,
 pagd_vals_mean, pagd_vals_max, pagd_vals_min) = _summarize_runs(data_pagd, 'pagd')

data_rspi = np.load('data_mean_dim_200/rspi.npz')
(rspi_complexity, rspi_vals, rspi_complexity_mean,
 rspi_vals_mean, rspi_vals_max, rspi_vals_min) = _summarize_runs(data_rspi, 'rspi')

data_spsa_rspi = np.load('data_mean_dim_200/rspi_spsa.npz')
(rspi_spsa_complexity, rspi_spsa_vals, rspi_spsa_complexity_mean,
 rspi_spsa_vals_mean, rspi_spsa_vals_max, rspi_spsa_vals_min) = _summarize_runs(data_spsa_rspi, 'rspi_spsa')

# allow_pickle is kept only for this file, as before.
# NOTE(review): presumably its arrays were saved with dtype=object — confirm.
data_zo_gd_ncf = np.load('data_mean_dim_200/zo_gd_ncf.npz', allow_pickle=True)
(zo_gd_ncf_complexity, zo_gd_ncf_vals, zo_gd_ncf_complexity_mean,
 zo_gd_ncf_vals_mean, zo_gd_ncf_vals_max, zo_gd_ncf_vals_min) = _summarize_runs(data_zo_gd_ncf, 'zo_gd_ncf')

data_zo_p_agd = np.load('data_mean_dim_200/zo_p_agd.npz')
(zo_p_agd_complexity, zo_p_agd_vals, zo_p_agd_complexity_mean,
 zo_p_agd_vals_mean, zo_p_agd_vals_max, zo_p_agd_vals_min) = _summarize_runs(data_zo_p_agd, 'zo_p_agd')

data_zo_p_agd_ancf = np.load('data_mean_dim_200/zo_p_agd_ancf.npz')
(zo_p_agd_ancf_complexity, zo_p_agd_ancf_vals, zo_p_agd_ancf_complexity_mean,
 zo_p_agd_ancf_vals_mean, zo_p_agd_ancf_vals_max, zo_p_agd_ancf_vals_min) = _summarize_runs(data_zo_p_agd_ancf, 'zo_p_agd_ancf')

# plot-FQC
def _plot_mean_with_envelope(complexity_mean, vals_mean, vals_min, vals_max, label, color=None):
    """Plot one algorithm's mean curve and shade its [min, max] envelope.

    All values are shifted by the module-level ``f_star``. The shaded band
    reuses the line's own colour, so line and band always match regardless of
    the colour-cycle position. Pass ``color`` to override the default cycle.
    """
    style = {} if color is None else {'color': color}
    line, = plt.plot(complexity_mean, vals_mean - f_star, label=label, **style)
    plt.fill_between(complexity_mean, vals_min - f_star, vals_max - f_star,
                     alpha=0.1, color=line.get_color())


_plot_mean_with_envelope(pagd_complexity_mean, pagd_vals_mean,
                         pagd_vals_min, pagd_vals_max, 'PAGD')
_plot_mean_with_envelope(rspi_complexity_mean, rspi_vals_mean,
                         rspi_vals_min, rspi_vals_max, 'RSPI')
_plot_mean_with_envelope(rspi_spsa_complexity_mean, rspi_spsa_vals_mean,
                         rspi_spsa_vals_min, rspi_spsa_vals_max, 'RSPI (SPSA)')
_plot_mean_with_envelope(zo_gd_ncf_complexity_mean, zo_gd_ncf_vals_mean,
                         zo_gd_ncf_vals_min, zo_gd_ncf_vals_max, 'ZO_GD_NCF', color='coral')
_plot_mean_with_envelope(zo_p_agd_complexity_mean, zo_p_agd_vals_mean,
                         zo_p_agd_vals_min, zo_p_agd_vals_max, 'ZO_Perturbed_AGD')
_plot_mean_with_envelope(zo_p_agd_ancf_complexity_mean, zo_p_agd_ancf_vals_mean,
                         zo_p_agd_ancf_vals_min, zo_p_agd_ancf_vals_max, 'ZO_Perturbed_AGD_ANCF')


plt.xlabel('# of Function Query')
plt.ylabel('Objective Function')
# plt.yscale('log')
plt.ticklabel_format(style='sci', scilimits=(0, 0), axis='x')
plt.legend()
# savefig does not create missing parent directories.
os.makedirs('figures', exist_ok=True)
plt.savefig('figures/quartic_mean_200.pdf', bbox_inches='tight')
plt.show()

# plot-iteration