import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import sys

# Name of the experiment sub-folder under ./results to plot. It can be passed
# as the first command-line argument; without one, the default experiment
# below is used (same behaviour as the previously hard-coded value).
shortpath = sys.argv[1] if len(sys.argv) > 1 else 'unit_q_200_k_3'


# Directory holding this experiment's result CSVs.
path = f'./results/{shortpath}'

# Internal algorithm ids, in the order their curves are drawn.
alg_names = [
    'fgzoht',
    'svrgzoht',
    'szoht',
    'sarah-zht',
    'saga-zht',
    'bvrht12',
    'bvrhtn',
]

# Load every header-less results table, in this order, from `path`.
results_mean, results_std, it_count, nizo, nht = (
    pd.read_csv(f'{path}/{stem}.csv', header=None)
    for stem in ('results_mean', 'results_std', 'it_count', 'nizo', 'nht')
)

# Human-readable legend label for each internal algorithm id.
true_names = dict([
    ('fgzoht', 'FGHT'),
    ('szoht', 'SHT'),
    ('svrgzoht', 'VR-SHT'),
    ('sarah-zht', 'SARAH-HT'),
    ('saga-zht', 'BSAG-HT'),
    ('bvrht12', 'BVR-HT-1/2'),
    ('bvrhtn', 'BVR-HT-n'),
])

def _curve(table, alg_name):
    """Return the array stored for ``alg_name`` in a header-less results table.

    Rows are selected where column 1 equals ``alg_name``. The last column of
    the first matching row holds the curve as a string, which is evaluated
    and wrapped in a NumPy array.

    Raises:
        KeyError: if ``table`` has no row for ``alg_name``.
    """
    rows = table[table[1] == alg_name]
    if rows.empty:
        raise KeyError(f'no row for algorithm {alg_name!r} in results table')
    # NOTE(review): eval() executes arbitrary code from the CSV cell. It is
    # kept (instead of ast.literal_eval) because cells may contain reprs such
    # as `np.float64(...)` that only resolve via the module's `np`. Only run
    # this script on trusted result files.
    return np.array(eval(rows.iloc[0, -1]))


def _plot_against(x_table, xlabel, filename):
    """Plot every algorithm's mean objective with a +/- one-std band, then save.

    Args:
        x_table: results table giving each algorithm's x-axis values
            (e.g. ``nizo`` for '# IFO' or ``nht`` for '# NHT').
        xlabel: x-axis label text.
        filename: output file name, written inside ``path``.

    Returns:
        The path the figure was saved to.
    """
    plt.figure()
    for alg_name in alg_names:
        curve_mean = _curve(results_mean, alg_name)
        curve_std = _curve(results_std, alg_name)
        x_data = _curve(x_table, alg_name)
        plt.plot(x_data, curve_mean, label=true_names[alg_name],
                 linestyle='-', marker='^', markersize=1)
        plt.fill_between(x_data, curve_mean - curve_std,
                         curve_mean + curve_std, alpha=0.3)
    # Build the legend once, after every curve has been added.
    plt.legend()
    plt.ylabel(r'$\mathcal{F}(\theta)$', fontweight='bold', fontsize=20)
    plt.xlabel(xlabel, fontsize=20)
    out_path = f'{path}/{filename}'
    plt.savefig(out_path)
    return out_path


# Objective versus number of IFO calls.
saved = _plot_against(nizo, '# IFO', f'izo_{shortpath}_main_paper.png')
print(f'figure saved at {saved}')
plt.show()

# Objective versus number of hard-thresholding (NHT) operations.
saved = _plot_against(nht, '# NHT', f'nht_{shortpath}_main_paper.png')
print(f'figure saved at {saved}')
plt.show()