import argparse
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import os

from advbench.lib import misc

# Use seaborn's dark-grid theme for every figure this script draws.
# (font_scale=1.5 can be passed here as well to enlarge the text.)
sns.set(style='darkgrid')

# Columns kept from the per-epoch frames before plotting.
CURVE_COLUMNS = ['PGD-Accuracy', 'Split', 'Epoch', 'trial_rank']


def parse_args():
    """Parse the command-line options of the learning-curve plotter."""
    parser = argparse.ArgumentParser(description='Plot learning curve')
    parser.add_argument('--input_dir', type=str, required=True)
    parser.add_argument('--hparams', nargs='+', type=str)
    parser.add_argument('--algorithm', type=str, required=True)
    parser.add_argument('--include-train', action='store_true')
    return parser.parse_args()


def select_ranked_trials(selected_test_df, algorithm):
    """Return the PGD-Accuracy rows of ``algorithm`` with trial seed 0.

    Args:
        selected_test_df: frame with at least the 'Metric-Name', 'Algorithm'
            and 'trial_seed' columns.
        algorithm: value of the 'Algorithm' column to keep.
    """
    keep = (
        (selected_test_df['Metric-Name'] == 'PGD-Accuracy') &
        (selected_test_df['Algorithm'] == algorithm) &
        (selected_test_df['trial_seed'] == 0.0)
    )
    return selected_test_df[keep]


def attach_rank(curve_df, seed_to_rank, value_column='PGD-Accuracy'):
    """Return the plotting columns of ``curve_df`` plus a 'trial_rank' column.

    The rank is looked up from each row's 'seed'; seeds missing from
    ``seed_to_rank`` get NaN. ``value_column`` is exposed under the common
    name 'PGD-Accuracy' so frames with differently named accuracy columns
    can be concatenated. The input frame is not modified.
    """
    ranked = curve_df.assign(trial_rank=curve_df['seed'].map(seed_to_rank))
    ranked = ranked[[value_column, 'Split', 'Epoch', 'trial_rank']]
    return ranked.rename(columns={value_column: 'PGD-Accuracy'})


def load_hparams_by_rank(input_dir, rank_to_path):
    """Read ``hparams.json`` for every ranked run.

    Args:
        input_dir: sweep directory containing one sub-directory per run.
        rank_to_path: mapping of trial rank to the run's 'path' value.

    Returns:
        Mapping of trial rank to whatever ``misc.read_dict`` returns for
        ``<input_dir>/<run_hash>/hparams.json``.
    """
    rank_to_hparams = {}
    for rank, path in rank_to_path.items():
        # NOTE(review): the run hash is taken as the SECOND '/'-separated
        # component of the stored path (rsplit without maxsplit behaves like
        # split) -- confirm this matches how 'path' is written by the sweep.
        run_hash = path.split('/')[1]
        full_path = os.path.join(input_dir, run_hash, 'hparams.json')
        rank_to_hparams[rank] = misc.read_dict(full_path)
    return rank_to_hparams


def format_hparams(hparams, names):
    """Render the requested hyperparameters as ``name=value`` lines.

    A single hyperparameter is shown with 3 decimals, several with 5.
    Values that cannot be formatted as floats (strings, lists, ...) are
    rendered with ``str`` instead of raising.

    Raises:
        KeyError: if a requested name is absent from ``hparams``.
    """
    precision = 3 if len(names) == 1 else 5
    lines = []
    for name in names:
        if name not in hparams:
            available = ', '.join(map(str, hparams))
            raise KeyError(
                f'hyperparameter {name!r} not found (available: {available})'
            )
        value = hparams[name]
        try:
            rendered = f'{value:.{precision}f}'
        except (TypeError, ValueError):
            rendered = str(value)
        lines.append(f'{name}={rendered}')
    return '\n'.join(lines)


def main():
    """Plot one adversarial-accuracy learning curve per ranked trial."""
    args = parse_args()

    # NOTE: read_pickle unpickles arbitrary objects; only point --input_dir
    # at sweep output you trust.
    selected_test_df = pd.read_pickle(
        os.path.join(args.input_dir, 'selected_test_df.pd')
    )
    sweep_df = pd.read_pickle(
        os.path.join(args.input_dir, 'full_sweep_df.pd')
    )

    df = select_ranked_trials(selected_test_df, args.algorithm)
    if df.empty:
        raise SystemExit(
            f'No PGD-Accuracy results found for algorithm '
            f'{args.algorithm!r} in {args.input_dir}'
        )

    seed_to_rank = df.set_index('seed')['trial_rank'].to_dict()
    # NOTE(review): the sweep frame is matched on seed only, not on
    # algorithm -- confirm it holds a single algorithm per input_dir.
    curves = attach_rank(sweep_df, seed_to_rank)

    if args.include_train:
        # The training frame is only needed (and only read) on request.
        train_df = pd.read_pickle(
            os.path.join(args.input_dir, 'sweep_train_df.pd')
        )
        train_curves = attach_rank(
            train_df, seed_to_rank, value_column='Robust Accuracy'
        )
        curves = pd.concat([curves, train_curves], ignore_index=True)

    rank_to_peak_acc = df.set_index('trial_rank')['Metric-Value'].to_dict()

    # hparams.json files are only read when they will be displayed.
    rank_to_hparams = None
    if args.hparams is not None:
        rank_to_path = df.set_index('trial_rank')['path'].to_dict()
        rank_to_hparams = load_hparams_by_rank(args.input_dir, rank_to_path)

    g = sns.FacetGrid(
        data=curves,
        col='trial_rank',
        hue='Split',
        sharey=False,
        col_wrap=5)
    g.map(sns.lineplot, 'Epoch', 'PGD-Accuracy')

    text_style = dict(boxstyle='round', alpha=0.5)

    # Pair each axis with the rank it actually shows (g.col_names) rather
    # than its position, so titles stay correct when ranks have gaps.
    for rank, ax in zip(g.col_names, g.axes.flat):
        acc = rank_to_peak_acc[rank]
        ax.set_title(f'Rank: {int(rank)} | Adv. Acc: {acc:.2f}')

        if rank_to_hparams is not None:
            textstr = format_hparams(rank_to_hparams[rank], args.hparams)
            ax.text(0.350, 0.40, textstr, transform=ax.transAxes, fontsize=10,
                    verticalalignment='top', bbox=text_style)
    plt.show()


if __name__ == '__main__':
    main()

