import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import sys

# Usage: python <this script> <results-subdir>
# Loads the per-algorithm CSV summaries from ./results/<results-subdir>/.
if len(sys.argv) < 2:
    # Fail with a readable message instead of a bare IndexError.
    sys.exit(f'usage: {sys.argv[0]} <results-subdir under ./results>')
shortpath = sys.argv[1]

path = f'./results/{shortpath}'

# All CSVs are header-less. The plotting code below selects rows by comparing
# column 1 to an algorithm id and evaluates the last column of that row as a
# Python expression (presumably a list of numbers -- confirm against the writer).
results_mean = pd.read_csv(f'{path}/results_mean.csv', header=None)
results_std = pd.read_csv(f'{path}/results_std.csv', header=None)
# NOTE(review): it_count is loaded but not used by the plots in this script.
it_count = pd.read_csv(f'{path}/it_count.csv', header=None)
nizo = pd.read_csv(f'{path}/nizo.csv', header=None)
nht = pd.read_csv(f'{path}/nht.csv', header=None)


# alg_names = ['fgzoht', 'svrgzoht', 'szoht', 'sarah-zht', 'saga-zht', 'q-saga-zht-03', 'q-saga-zht-06']
# true_names = {'fgzoht': 'FGSZOHT',
#               'szoht': 'SZOHT',
#               'svrgzoht': 'VR-SZHT',
#               'sarah-zht': 'SARAH-ZHT',
#               'saga-zht': 'SAGA-ZHT',
#               'q-saga-zht-03': '3-SAGA-ZHT',
#               'q-saga-zht-06': '6-SAGA-ZHT'
#               }

# Legend label for each algorithm id. The dict's insertion order is also the
# order in which the curves are drawn (and listed in the legend).
true_names = {
    'szoht': 'SZOHT',
    'svrgzoht': 'VR-SZHT',
    'sarah-zht': 'SARAH-ZHT',
    'b-svrgzoht': 'BVR-SZHT(ours)',
}
# Algorithm ids to plot, derived from the mapping so the two cannot drift apart.
alg_names = list(true_names)

# alg_names = ["svrgzoht", "b-svrgzoht"]
# true_names = {
#               'szoht': 'SZOHT',
#               'svrgzoht': 'VR-SZHT',
#               'sarah-zht': 'SARAH-ZHT',
#               'b-svrgzoht' : 'BVR-SZHT'
#               }

def _last_cell(table, alg_name):
    """Return the last-column value of the first row whose column 1 equals ``alg_name``.

    Raises:
        ValueError: if ``table`` has no row for ``alg_name``.
    """
    rows = table[table[1] == alg_name]
    if rows.empty:
        raise ValueError(f'no row for algorithm {alg_name!r} in results table')
    return rows.iloc[0, -1]


def _parse_series(cell):
    """Evaluate a CSV cell holding a Python expression and return it as a numpy array."""
    # NOTE(review): eval executes arbitrary code stored in the CSV. This is only
    # acceptable because the results files are produced locally by the experiment
    # runner. ast.literal_eval would be safer, but it rejects numpy reprs such as
    # np.float64(...) that may appear in the cells -- confirm the format before switching.
    return np.array(eval(cell))


def _plot_against(x_table, xlabel, outfile, *, echo_names=False):
    """Plot mean +/- std of f(omega) for every algorithm against ``x_table``'s x-values.

    Args:
        x_table: header-less DataFrame holding each algorithm's x-axis series
            (selected by column 1, values in the last column).
        xlabel: x-axis label.
        outfile: path the figure is saved to (relative to the working directory).
        echo_names: print each algorithm id as it is plotted.
    """
    plt.figure()
    for alg_name in alg_names:
        if echo_names:
            print(alg_name)
        curve_mean = _parse_series(_last_cell(results_mean, alg_name))
        curve_std = _parse_series(_last_cell(results_std, alg_name))
        x_data = _parse_series(_last_cell(x_table, alg_name))
        plt.plot(x_data, curve_mean, label=f'{true_names[alg_name]}', linestyle='-', marker='^', markersize=1)
        # Shaded band of one standard deviation around the mean curve.
        plt.fill_between(x_data, curve_mean - curve_std, curve_mean + curve_std, alpha=0.3)
    # Build the legend once, after every curve has been added.
    plt.legend()
    plt.ylabel(r'$f(\omega)$', fontweight='bold', fontsize=20)
    plt.xlabel(xlabel, fontsize=20)
    # Alternative used previously: save next to the inputs, e.g. f'{path}/izo_{shortpath}.png'.
    plt.savefig(outfile)


_plot_against(nizo, 'IZO', 'IZO_reg.png')
_plot_against(nht, 'NHT', 'NHT_reg.png', echo_names=True)