import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import sys

# Plot result curves for one experiment.
# Usage: python <this script> SHORTPATH
# Results are read from ./results/<SHORTPATH> and, for the beta sweep,
# ./results/<SHORTPATH>_beta_<b>; the figures are written to unit1.png
# (x-axis: nizo series, labelled IZO) and unit2.png (x-axis: nht series,
# labelled NHT) in the current directory.
if len(sys.argv) < 2:
    # Fail with a clear message instead of a bare IndexError.
    sys.exit(f'usage: {sys.argv[0]} SHORTPATH')
shortpath = sys.argv[1]

# Algorithms to plot, in drawing (and legend) order.
alg_names = ["svrgzoht", "b-svrgzoht"]

# Legend display names, keyed by the algorithm identifier stored in
# column 1 of the result CSVs.
true_names = {
    'szoht': 'SZOHT',
    'svrgzoht': 'VR-SZHT',
    'sarah-zht': 'SARAH-ZHT',
    'b-svrgzoht': 'BVR-SZHT',
}

# Beta values swept for 'b-svrgzoht'; each value has its own results
# directory. 1.0 is currently excluded from the sweep.
betas = [0.2, 0.4, 0.6, 0.8]


def _load_tables(path):
    """Read the result tables stored in directory *path*.

    Returns a dict mapping each table name ('results_mean', 'results_std',
    'nizo', 'nht') to the header-less DataFrame read from
    '<path>/<name>.csv'.
    """
    return {
        name: pd.read_csv(f'{path}/{name}.csv', header=None)
        for name in ('results_mean', 'results_std', 'nizo', 'nht')
    }


def _series(table, alg_name):
    """Return the series recorded for *alg_name* in *table* as a NumPy array.

    The row is selected by matching column 1 against *alg_name* (the first
    matching row is used); its last cell holds the series as Python source
    text, which is evaluated and converted with ``np.array``.
    """
    cell = table[table[1] == alg_name].iloc[:, -1].iloc[0]
    # NOTE(review): eval() executes whatever text is in the CSV cell. The
    # result files are presumably written by this project's own experiment
    # runner; do not point this script at untrusted data. If the cells are
    # plain numeric list literals, ast.literal_eval would be a safer choice.
    return np.array(eval(cell))


def _collect_runs():
    """Return one (alg_name, legend_label, tables) tuple per curve, in plot order.

    'b-svrgzoht' contributes one curve per value in ``betas``, each read from
    ./results/<shortpath>_beta_<b>; every other algorithm is read from
    ./results/<shortpath>. Each directory is loaded from disk at most once,
    and the loaded tables are shared by both figures.
    """
    cache = {}

    def tables_for(path):
        # Load each results directory only once.
        if path not in cache:
            cache[path] = _load_tables(path)
        return cache[path]

    runs = []
    for alg_name in alg_names:
        if alg_name == 'b-svrgzoht':
            for b in betas:
                runs.append((alg_name, f'{true_names[alg_name]}_{b}',
                             tables_for(f'./results/{shortpath}_beta_{b}')))
        else:
            runs.append((alg_name, true_names[alg_name],
                         tables_for(f'./results/{shortpath}')))
    return runs


def _plot_figure(runs, x_table, xlabel, outfile):
    """Plot each run's mean curve with a +/-1 std band and save the figure.

    Args:
        runs: tuples as returned by ``_collect_runs``.
        x_table: name of the table supplying x values ('nizo' or 'nht').
        xlabel: x-axis label text.
        outfile: path of the image file to write.
    """
    fig, ax = plt.subplots()
    for alg_name, label, tables in runs:
        x = _series(tables[x_table], alg_name)
        mean = _series(tables['results_mean'], alg_name)
        std = _series(tables['results_std'], alg_name)
        (line,) = ax.plot(x, mean, label=label, linestyle='-', marker='^', markersize=1)
        # Use the line's own colour so each band always matches its curve.
        ax.fill_between(x, mean - std, mean + std, alpha=0.3, facecolor=line.get_color())
    ax.legend()  # once, after every curve has been added
    ax.set_ylabel(r'$f(\theta)$', fontweight='bold', fontsize=20)
    ax.set_xlabel(xlabel, fontsize=20)
    fig.savefig(outfile)
    plt.close(fig)


_runs = _collect_runs()
_plot_figure(_runs, 'nizo', 'IZO', 'unit1.png')
_plot_figure(_runs, 'nht', 'NHT', 'unit2.png')