import argparse
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import os
import numpy as np

from tueplots import bundles, figsizes


from advbench.lib import misc

# Global plot styling, applied at import time in this order: seaborn's
# dark-grid theme, then the tueplots NeurIPS 2022 rcParams bundle (later
# updates win for any overlapping keys), then a square 6x6 inch canvas and
# enlarged font sizes.
sns.set(style='darkgrid')
plt.rcParams.update(bundles.neurips2022())
plt.rcParams["figure.figsize"] = (6, 6)

# Axis labels, tick labels and titles share one large size; body text and the
# legend are slightly smaller.
fontsizes = {
    key: 25
    for key in (
        'axes.labelsize',
        'xtick.labelsize',
        'ytick.labelsize',
        'axes.titlesize',
    )
}
fontsizes['font.size'] = 20
fontsizes['legend.fontsize'] = 18
plt.rcParams.update(fontsizes)

def load_accuracy_curves(input_dir, max_epoch=None):
    """Load training and test accuracy logs as one long-format DataFrame.

    Reads two pickled DataFrames from ``input_dir``:

    * ``train.pd`` with columns ``Epoch``, ``Split``, ``Robust Accuracy``
      and ``Clean Accuracy`` (all of its rows are kept);
    * ``full_sweep_df.pd`` with columns ``Epoch``, ``Split``,
      ``PGD-Accuracy`` and ``Clean-Accuracy``. Only rows whose ``Split`` is
      ``'Test'`` are kept, and the two accuracy columns are renamed to match
      ``train.pd``.

    Args:
        input_dir: Directory holding both files. They are read with
            ``pd.read_pickle``, which is unsafe on untrusted data, so only
            point this at trusted experiment output.
        max_epoch: If not ``None``, drop rows with ``Epoch > max_epoch``.

    Returns:
        DataFrame with columns ``Epoch``, ``Split``, ``Metric Name``,
        ``Accuracy`` and ``Curve``. ``Curve`` is the legend label, built as
        ``"<Split> <metric name without ' Accuracy'>"``, e.g.
        ``'Train Robust'`` or ``'Test Clean'``.
    """
    sweep_df = pd.read_pickle(os.path.join(input_dir, 'full_sweep_df.pd'))
    sweep_df = sweep_df[
        ['Epoch', 'Split', 'PGD-Accuracy', 'Clean-Accuracy']
    ].rename(columns={
        'PGD-Accuracy': 'Robust Accuracy',
        'Clean-Accuracy': 'Clean Accuracy',
    })
    sweep_df = sweep_df[sweep_df['Split'] == 'Test']

    train_df = pd.read_pickle(os.path.join(input_dir, 'train.pd'))
    train_df = train_df[['Epoch', 'Split', 'Robust Accuracy', 'Clean Accuracy']]

    df = pd.concat([train_df, sweep_df], ignore_index=True)
    df = pd.melt(frame=df, id_vars=['Epoch', 'Split']).rename(columns={
        'variable': 'Metric Name',
        'value': 'Accuracy',
    })
    # Derive the human-readable legend label from the data itself, so the
    # legend can never drift out of sync with the plotted hue levels.
    short_metric = df['Metric Name'].str.replace(' Accuracy', '', regex=False)
    df['Curve'] = df['Split'] + ' ' + short_metric
    if max_epoch is not None:
        df = df[df['Epoch'] <= max_epoch]
    return df


def plot_accuracy_curves(df, output_path='beta_cifar.pdf', ylim=None):
    """Plot one line per ``Curve`` value of ``df`` and save the figure.

    Args:
        df: Long-format frame as returned by ``load_accuracy_curves``.
        output_path: File the figure is written to via ``plt.savefig``.
        ylim: Optional ``(low, high)`` y-axis limits.
    """
    sns.set_palette('colorblind')
    ax = sns.lineplot(
        data=df,
        x='Epoch',
        y='Accuracy',
        hue='Curve',
        linewidth=5,
    )
    # Relocate seaborn's own legend (its labels are the ``Curve`` values)
    # instead of passing a hard-coded label list: matplotlib pairs a
    # labels-only list with the axes' artists in creation order, which is not
    # guaranteed to match the hue order and can silently mislabel curves.
    sns.move_legend(
        ax,
        'upper center',
        bbox_to_anchor=(0.5, 1.3),
        ncol=2,
        title='',
    )
    if ylim is not None:
        ax.set_ylim(*ylim)

    plt.tight_layout()
    plt.savefig(output_path)


def main():
    """Parse command-line arguments, load the accuracy logs and save the plot."""
    parser = argparse.ArgumentParser(description='Plot learning curve')
    parser.add_argument('--input_dir', type=str, required=True)
    # NOTE(review): --curves is parsed but not used anywhere in this script.
    parser.add_argument('--curves', nargs='+', type=str)
    parser.add_argument('--output', type=str, default='beta_cifar.pdf',
                        help='Path of the saved figure.')
    parser.add_argument('--max_epoch', type=int, default=None,
                        help='Only plot epochs <= this value.')
    parser.add_argument('--ylim', type=float, nargs=2, default=None,
                        metavar=('LOW', 'HIGH'),
                        help='Optional y-axis limits.')
    args = parser.parse_args()

    df = load_accuracy_curves(args.input_dir, max_epoch=args.max_epoch)
    plot_accuracy_curves(df, output_path=args.output, ylim=args.ylim)


if __name__ == '__main__':
    main()