import os

import matplotlib.pyplot as plt
import numpy as np

# Compare PAGD against two ZO perturbed-AGD variants across several problem
# dimensions.
# Layout: one column per dimension. The top row plots the objective against
# function queries; the bottom row plots it against the iteration index.
# Each curve is the mean over axis 0 of the stored arrays, with a shaded
# min-max band.

# Problem dimensions to compare; one subplot column per entry.
dim = [20, 100, 200, 1000]

# (prefix, legend label, colour) for each algorithm. The prefix names both the
# file `<prefix>.npz` and the arrays `<prefix>_complexity` / `<prefix>_vals`
# stored inside it.
ALGORITHMS = [
    ('pagd', 'PAGD', 'b'),
    ('zo_p_agd', 'ZO_Perturbed_AGD', 'r'),
    ('zo_p_agd_ancf', 'ZO_Perturbed_AGD_ANCF', 'g'),
]


def load_stats(d, prefix, data_root='.'):
    """Load one algorithm's saved results for dimension ``d`` and summarise them.

    Reads ``<data_root>/data_mean_<d>/<prefix>.npz`` and reduces the arrays
    ``<prefix>_complexity`` and ``<prefix>_vals`` over axis 0. Axis 0 presumably
    indexes independent runs; confirm against the script that writes the data.

    Args:
        d: Problem dimension; selects the ``data_mean_<d>`` directory.
        prefix: Algorithm prefix (see ``ALGORITHMS``).
        data_root: Directory that contains the ``data_mean_*`` folders.

    Returns:
        dict with keys ``complexity_mean``, ``vals_mean``, ``vals_min`` and
        ``vals_max``, each the corresponding reduction over axis 0.
    """
    path = os.path.join(data_root, 'data_mean_' + str(d), prefix + '.npz')
    # NpzFile keeps the archive open; use it as a context manager so the handle
    # is released. Indexing reads each array fully into memory, so the arrays
    # stay valid after the file is closed.
    with np.load(path) as data:
        complexity = data[prefix + '_complexity']
        vals = data[prefix + '_vals']
    return {
        'complexity_mean': np.mean(complexity, axis=0),
        'vals_mean': np.mean(vals, axis=0),
        'vals_min': np.min(vals, axis=0),
        'vals_max': np.max(vals, axis=0),
    }


def plot_band(axis, x, stats, label, color):
    """Draw the mean objective curve and a translucent min-max band on ``axis``."""
    axis.plot(x, stats['vals_mean'], label=label, color=color)
    axis.fill_between(x, stats['vals_min'], stats['vals_max'], alpha=0.1, color=color)


def main(dims=None, output_path='figures/cubic_mean.pdf'):
    """Build the comparison figure, save it to ``output_path`` and show it.

    Args:
        dims: Dimensions to plot (one column each); defaults to module ``dim``.
        output_path: Where the figure is written; parent dirs are created.
    """
    dims = list(dim if dims is None else dims)
    # squeeze=False keeps `ax` 2-D even when there is a single column.
    fig, ax = plt.subplots(2, len(dims), figsize=(5 * len(dims), 8), squeeze=False)

    for col, d in enumerate(dims):
        top, bottom = ax[0, col], ax[1, col]
        for prefix, label, color in ALGORITHMS:
            stats = load_stats(d, prefix)
            # Top row: objective vs. mean number of function queries.
            plot_band(top, stats['complexity_mean'], stats, label, color)
            # Bottom row: objective vs. iteration index 0..n-1.
            iters = np.arange(len(stats['vals_mean']))
            plot_band(bottom, iters, stats, label, color)

        top.ticklabel_format(style='sci', scilimits=(0, 0), axis='x')
        top.set_title('d=' + str(d))
        top.set_xlabel('Function Query', fontsize=13)
        bottom.ticklabel_format(style='sci', scilimits=(0, 0), axis='x')
        bottom.set_xlabel('Iterations', fontsize=13)
        if col == 0:
            # Only the left-most column carries y-labels and legends.
            top.set_ylabel('Objective Function', fontsize=13)
            top.legend(fontsize='large')
            bottom.set_ylabel('Objective Function', fontsize=13)
            bottom.legend()

    # Lay out BEFORE saving so the written file gets the adjusted spacing too.
    fig.tight_layout()
    out_dir = os.path.dirname(output_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    fig.savefig(output_path, bbox_inches='tight')
    plt.show()


if __name__ == '__main__':
    main()