import argparse
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import os

from advbench.lib import misc

# Global plotting theme, applied once at import time so every figure this
# script draws uses a dark-grid background and the Palatino font.
sns.set(font='Palatino', style='darkgrid')

# Step counts for which evaluation pickles are loaded by default.
DEFAULT_N_STEPS = (1, 5, 10, 20, 50, 100, 200)

# Metric columns that are plotted, one facet column per metric.
PLOTTED_METRICS = ('BETA-Loss', 'BETA-Accuracy', 'PGD-Loss', 'PGD-Accuracy')


def load_evaluations(input_dir, prefix, n_steps_list=DEFAULT_N_STEPS):
    """Load one pickled evaluation DataFrame per step count.

    Reads ``<input_dir>/<prefix>evaluated_ckpts_n_<n>.pd`` for every ``n`` in
    ``n_steps_list``. This script uses prefix ``''`` for the BETA evaluations
    and ``'pgd_'`` for the PGD evaluations.

    Args:
        input_dir: Directory containing the pickled evaluation frames.
        prefix: Filename prefix placed before ``evaluated_ckpts_n_``.
        n_steps_list: Step counts to load, in order.

    Returns:
        A list of DataFrames in the same order as ``n_steps_list``.

    Raises:
        FileNotFoundError: If one of the expected files does not exist.
    """
    # NOTE: read_pickle unpickles arbitrary objects; only use it on trusted,
    # locally produced result files.
    return [
        pd.read_pickle(
            os.path.join(input_dir, f'{prefix}evaluated_ckpts_n_{n_steps}.pd')
        )
        for n_steps in n_steps_list
    ]


def to_long_format(frames, metrics=PLOTTED_METRICS):
    """Stack evaluation frames and reshape them to one row per metric value.

    Args:
        frames: DataFrames that have ``epoch`` and ``n_steps`` columns plus
            (some of) the metric columns.
        metrics: Names of the metric columns to keep.

    Returns:
        A long-format DataFrame with columns ``epoch``, ``n_steps``,
        ``Metric-Name`` and ``Metric-Value``. Metrics absent from every frame
        are skipped, and rows whose value is missing are dropped.
    """
    df = pd.concat(frames, ignore_index=True)
    # Melt only the requested metrics, keeping the column order of the
    # concatenated frame (that order determines the facet order in the plot).
    value_vars = [col for col in df.columns if col in metrics]
    long_df = pd.melt(
        frame=df,
        id_vars=['epoch', 'n_steps'],
        value_vars=value_vars,
        var_name='Metric-Name',
        value_name='Metric-Value',
    )
    # Frames lacking a metric column contribute NaN for it after concat;
    # drop those rows rather than handing them to seaborn.
    return long_df.dropna(subset=['Metric-Value']).reset_index(drop=True)


def main(argv=None):
    """Parse CLI arguments, load BETA and PGD evaluations and plot them.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser(description='Plot learning curve')
    parser.add_argument('--input_dir', type=str, required=True)
    parser.add_argument(
        '--n_steps',
        type=int,
        nargs='+',
        default=list(DEFAULT_N_STEPS),
        help='Step counts whose evaluation files should be loaded.',
    )
    args = parser.parse_args(argv)

    # PGD frames come first in the concatenation; this decides the column
    # order and therefore (presumably) seaborn's facet order.
    frames = (
        load_evaluations(args.input_dir, 'pgd_', args.n_steps)
        + load_evaluations(args.input_dir, '', args.n_steps)
    )
    df = to_long_format(frames)

    sns.relplot(
        data=df,
        x='epoch',
        y='Metric-Value',
        hue='n_steps',
        col='Metric-Name',
        facet_kws={'sharey': False, 'sharex': True},
        palette='colorblind',
    )
    plt.show()


if __name__ == '__main__':
    main()