import os

import matplotlib.pyplot as plt
import numpy as np
from util import *

from alg_stochastic import *

num = 2000
dim = 20
mini_dim = 4
batch_size = 128
mini_batch_size = 10

#############################################################################
iters_sgd = 5000  # sgd
iters_scsg_rand = 10000
iters_scsg_coord = 6000
epoch_rand = int(iters_scsg_rand / (batch_size / mini_batch_size))  # scsg-rand
epoch_coord = int(iters_scsg_coord / (batch_size / mini_batch_size))   # scsg-coord

iters_spider = 1000
epoch_spider = int(iters_spider / (batch_size / mini_batch_size))  # spider
epoch_spider_size = 20  # spider

batch_g_scrn = batch_size       # scrn
batch_h_scrn = mini_batch_size  # scrn


iters_zpsgd = 1000
#############################################################################
# epoch_rand = 1000
# epoch_coord = 200

p = 0.01

A = construct_random_matrix(num, dim, mini_dim)
b = construct_random_vector(num, dim)

L = 100      # tuning
rho = 1       # tuning

f = construct_f_stochastic(A, b)


x_0 = list(np.zeros(dim)+0)
print(x_0)

np.random.seed(10)


fqc_sgd, zo_sgd_ncf_vals = zo_sgd_ncf(f, num, batch_size, x_0, iters_sgd, p, L, rho)
np.savez('zo_sgd_ncf.npz', fqc_sgd=fqc_sgd, zo_sgd_ncf_vals=zo_sgd_ncf_vals)

fqc_scsg_rand, zo_scsg_ncf_rand_vals = zo_scsg_ncf_rand(f, num, batch_size, mini_batch_size, x_0, epoch_rand, p, L, rho)
np.savez('zo_scsg_ncf_rand.npz', fqc_scsg_rand=fqc_scsg_rand, zo_scsg_ncf_rand_vals=zo_scsg_ncf_rand_vals)

fqc_scsg_coord, zo_scsg_ncf_coord_vals = zo_scsg_ncf_coord(f, num, batch_size, mini_batch_size, x_0, epoch_coord, p, L, rho)
np.savez('zo_scsg_ncf_coord.npz', fqc_scsg_coord=fqc_scsg_coord, zo_scsg_ncf_coord_vals=zo_scsg_ncf_coord_vals)

fqc_spider, zo_spider_ncf_vals = zo_spider_ncf(f, num, batch_size, mini_batch_size, x_0, epoch_spider, epoch_spider_size, p, L, rho)
np.savez('zo_spider_ncf.npz', fqc_spider=fqc_spider, zo_spider_ncf_vals=zo_spider_ncf_vals)


# load data
data_zo_sgd = np.load('zo_sgd_ncf.npz')
data_zo_scsg_rand = np.load('zo_scsg_ncf_rand.npz')
data_zo_scsg_coord = np.load('zo_scsg_ncf_coord.npz')
data_zo_spider = np.load('zo_spider_ncf.npz')

fqc_sgd = data_zo_sgd['fqc_sgd']
zo_sgd_ncf_vals = data_zo_sgd['zo_sgd_ncf_vals']
fqc_scsg_rand = data_zo_scsg_rand['fqc_scsg_rand']
zo_scsg_ncf_rand_vals = data_zo_scsg_rand['zo_scsg_ncf_rand_vals']
fqc_scsg_coord=data_zo_scsg_coord['fqc_scsg_coord']
zo_scsg_ncf_coord_vals=data_zo_scsg_coord['zo_scsg_ncf_coord_vals']
fqc_spider = data_zo_spider['fqc_spider']
zo_spider_ncf_vals = data_zo_spider['zo_spider_ncf_vals']


# plot figures
plt.plot(fqc_sgd, zo_sgd_ncf_vals, label='ZO-SGD-NCF')
plt.plot(fqc_scsg_coord, zo_scsg_ncf_coord_vals, label='ZO-SCSG-NCF (Option I)')
plt.plot(fqc_scsg_rand, zo_scsg_ncf_rand_vals, label='ZO-SCSG-NCF (Option II)')
plt.plot(fqc_spider, zo_spider_ncf_vals, label='ZO-SPIDER-NCF')



plt.xlabel('# of Function Query')
plt.ylabel('Objective Function')
plt.legend()
plt.savefig('figures/cubic_stochastic_20.pdf', bbox_inches='tight')
plt.show()
