import os

import matplotlib.pyplot as plt
import numpy as np

from util import *

from alg_deterministic import *

# ---------------------------------------------------------------------------
# Problem setup: build the deterministic test objective and solver settings.
# ---------------------------------------------------------------------------
dim = 100          # dimension of the optimization variable x
mini_dim = 10      # forwarded to construct_matrix_deterministic (meaning defined there)
A = construct_matrix_deterministic(dim, mini_dim)
b = construct_vector_deterministic(dim)
L = 100            # tuning; presumably a smoothness-constant estimate -- confirm
rho = 1
rho_zpsgd = 2e-3   # separate rho value used only by zpsgd
f = construct_f_deterministic(A, b)

iters = 5000
iters_zpsgd = iters * 2   # zpsgd is given twice the iteration budget of the others
iters_rspi = 100
batch_size = dim
x = list(np.zeros(dim))   # common starting point (the origin) for every method

np.random.seed(10)

# ---------------------------------------------------------------------------
# Run each method and persist its (query-complexity, objective-value) traces.
# Every solver receives its own copy of the starting point, so an in-place
# update inside one experiment cannot change where the next one starts.
# ---------------------------------------------------------------------------
pagd_complexity, pagd_values = experiment_pagd(f, list(x), iters, L)
np.savez('pagd', pagd_complexity=pagd_complexity, pagd_values=pagd_values)

zo_gd_ncf_complexity, zo_gd_ncf_vals = experiment_zo_ncf_gd(f, list(x), iters, L, rho)
np.savez('zo_gd_ncf', zo_gd_ncf_complexity=zo_gd_ncf_complexity, zo_gd_ncf_vals=zo_gd_ncf_vals)

zpsgd_complexity, zpsgd_vals = zpsgd(f, list(x), iters_zpsgd, batch_size, rho_zpsgd, L)
np.savez('zpsgd', zpsgd_complexity=zpsgd_complexity, zpsgd_vals=zpsgd_vals)

rspi_complexity, rspi_vals = rspi(f, list(x), iters_rspi, L, sigma_1=0.5, sigma_2=0.4, T_sigma_1=20, ratio=0.98)
np.savez('rspi', rspi_complexity=rspi_complexity, rspi_vals=rspi_vals)


# ---------------------------------------------------------------------------
# Load the saved traces (this section can be re-run on its own once the .npz
# files exist). `with` closes each NpzFile; indexing it materializes the array,
# so the arrays stay valid after the file is closed.
# ---------------------------------------------------------------------------
with np.load('pagd.npz') as data_pagd:
    pagd_complexity = data_pagd['pagd_complexity']
    pagd_values = data_pagd['pagd_values']

with np.load('zo_gd_ncf.npz') as data_zo_gd_ncf:
    zo_gd_ncf_complexity = data_zo_gd_ncf['zo_gd_ncf_complexity']
    zo_gd_ncf_vals = data_zo_gd_ncf['zo_gd_ncf_vals']

with np.load('zpsgd.npz') as data_zpsgd:
    zpsgd_complexity = data_zpsgd['zpsgd_complexity']
    zpsgd_vals = data_zpsgd['zpsgd_vals']

with np.load('rspi.npz') as data_rspi:
    rspi_complexity = data_rspi['rspi_complexity']
    rspi_vals = data_rspi['rspi_vals']

# ---------------------------------------------------------------------------
# Plot objective value against number of function queries for every method.
# ---------------------------------------------------------------------------
plt.plot(pagd_complexity, pagd_values, label='PAGD')
plt.plot(zo_gd_ncf_complexity, zo_gd_ncf_vals, label='ZO-NCF-GD')
plt.plot(zpsgd_complexity, zpsgd_vals, label='ZPSGD')
plt.plot(rspi_complexity, rspi_vals, label='RSPI')

plt.xlabel('# of Function Query')
plt.ylabel('Objective Function')
plt.legend()

# Name the figure after the problem dimension and make sure the output
# directory exists; savefig does not create missing directories.
fig_path = os.path.join('figures', f'cubic_deterministic_{dim}.pdf')
os.makedirs(os.path.dirname(fig_path), exist_ok=True)
plt.savefig(fig_path, bbox_inches='tight')
plt.show()
