import os

import matplotlib.pyplot as plt
import numpy as np

from algs import *
# from octopus.zo_ncf_gd_1 import zo_gd_ncf

# Experiment configuration for the d = 10 comparison.
iters = 2000        # passed as the iteration argument to PAGD, ZO-GD-NCF and ZPSGD
iters_rspi = 300    # RSPI is given its own, smaller iteration argument
x = [0.1] * 10      # starting point: a 10-dimensional list with every entry 0.1
batch_size = len(x)  # ZPSGD batch size, tied to the problem dimension

# NOTE(review): a single global seed is set once before all four runs. If the
# algorithms draw from NumPy's global RNG (not visible here — confirm in algs),
# each method's random stream depends on how much the previous ones consumed.
np.random.seed(10)

# np.savez cannot create missing parent directories, so make sure 'data/' exists.
os.makedirs('data', exist_ok=True)

# Run each method, then persist its (complexity, values) pair; np.savez appends
# the '.npz' extension to each path.
pagd_complexity, pagd_vals = experiment_pagd(x, iters)
np.savez('data/pagd_10', pagd_complexity=pagd_complexity, pagd_vals=pagd_vals)

zo_gd_ncf_complexity, zo_gd_ncf_vals = zo_gd_ncf(x, iters)
np.savez('data/zo_gd_ncf_10', zo_gd_ncf_complexity=zo_gd_ncf_complexity, zo_gd_ncf_vals=zo_gd_ncf_vals)

zpsgd_complexity, zpsgd_vals = zpsgd(x, iters, batch_size)
np.savez('data/zpsgd_10', zpsgd_complexity=zpsgd_complexity, zpsgd_vals=zpsgd_vals)

rspi_complexity, rspi_vals = rspi(x, iters_rspi, sigma_1=1, sigma_2=1.25, T_sigma_1=20, ratio=0.95)
np.savez('data/rspi_10', rspi_complexity=rspi_complexity, rspi_vals=rspi_vals)





# Reload every method's saved results from disk. The `with` blocks close each
# NpzFile handle once its arrays are extracted. Indexing an NpzFile reads the
# whole array into memory, so the arrays stay valid after the file is closed.
with np.load('data/pagd_10.npz') as data_pagd:
    pagd_complexity = data_pagd['pagd_complexity']
    pagd_vals = data_pagd['pagd_vals']

with np.load('data/zo_gd_ncf_10.npz') as data_zo_gd_ncf:
    zo_gd_ncf_complexity = data_zo_gd_ncf['zo_gd_ncf_complexity']
    zo_gd_ncf_vals = data_zo_gd_ncf['zo_gd_ncf_vals']

with np.load('data/zpsgd_10.npz') as data_zpsgd:
    zpsgd_complexity = data_zpsgd['zpsgd_complexity']
    zpsgd_vals = data_zpsgd['zpsgd_vals']

with np.load('data/rspi_10.npz') as data_rspi:
    rspi_complexity = data_rspi['rspi_complexity']
    rspi_vals = data_rspi['rspi_vals']


# Plot objective value against function-query count for every method on one
# set of axes, save the figure as a PDF, and display it.
plt.rcParams.update({'font.size': 14})
fig, ax = plt.subplots(figsize=(8, 6))

series = (
    (pagd_complexity, pagd_vals, 'PAGD'),
    (zo_gd_ncf_complexity, zo_gd_ncf_vals, 'ZO-GD-NCF'),
    (zpsgd_complexity, zpsgd_vals, 'ZPSGD'),
    (rspi_complexity, rspi_vals, 'RSPI'),
)
for queries, values, label in series:
    ax.plot(queries, values, label=label)

ax.set_xlabel('Function Query')
ax.set_ylabel('Objective Function')
ax.legend()
ax.set_title('d = 10')

# savefig does not create missing directories, so make sure 'figures/' exists.
os.makedirs('figures', exist_ok=True)
fig.savefig('figures/octopus_dim_10.pdf', bbox_inches='tight')
plt.show()