import argparse
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import os

from advbench.lib import misc

# Archive directory holding one sub-directory per experiment run; the script
# below reads a pickled ``selected_test_df.pd`` from each of them. The path is
# relative, so the script must be launched from the directory containing
# ``experiments_archive``.
ROOT = './experiments_archive/Scaling-Laws-for-AT'

# Use seaborn's dark-grid theme for every figure drawn by this script.
sns.set(style='darkgrid')

# Defaults reproduce the originally hard-coded experiment selection.
DEFAULT_ARCHS = ('resnet18',)
DEFAULT_AMOUNTS = (100, 1000)
DEFAULT_ALGORITHMS = ('FGSM', 'PGD', 'AdversarialMMD')


def load_results(root, archs, amounts):
    """Load and concatenate the per-run ``selected_test_df.pd`` tables.

    Each run is read from ``<root>/mnist<amount>_<arch>/selected_test_df.pd``.
    Two columns are added to every table so runs stay distinguishable after
    concatenation: ``Architecture`` (the arch name) and ``Num-Data`` (the
    amount).

    Args:
        root: Directory containing the per-run sub-directories.
        archs: Iterable of architecture names, e.g. ``['resnet18']``.
        amounts: Iterable of training-set sizes, e.g. ``[100, 1000]``.

    Returns:
        One DataFrame holding every run, re-indexed from 0.

    Raises:
        FileNotFoundError: If an expected results file does not exist; the
            message names the offending arch/amount pair.
        ValueError: If ``archs`` or ``amounts`` is empty.
    """
    frames = []
    for arch in archs:
        for amt in amounts:
            path = os.path.join(root, f'mnist{amt}_{arch}', 'selected_test_df.pd')
            try:
                # NOTE: read_pickle unpickles arbitrary objects; only point
                # this at trusted experiment output.
                run_df = pd.read_pickle(path)
            except FileNotFoundError as err:
                raise FileNotFoundError(
                    f'Missing results for arch={arch!r}, num-data={amt}: {path}'
                ) from err
            run_df['Architecture'] = arch
            run_df['Num-Data'] = amt
            frames.append(run_df)
    if not frames:
        # pd.concat([]) would also raise ValueError, but with a vaguer message.
        raise ValueError('At least one architecture and one data amount are required.')
    return pd.concat(frames, ignore_index=True)


def select_rows(df, algorithms=DEFAULT_ALGORITHMS, metric='PGD-Accuracy',
                trial_seed=0.0, trial_rank=1.0):
    """Keep only the rows for one metric, one trial, and chosen algorithms.

    Args:
        df: Table with ``Metric-Name``, ``trial_seed``, ``trial_rank`` and
            ``Algorithm`` columns (as produced by ``load_results``).
        algorithms: Algorithm names to keep.
        metric: Value of ``Metric-Name`` to keep.
        trial_seed: Value of ``trial_seed`` to keep.
        trial_rank: Value of ``trial_rank`` to keep.

    Returns:
        The filtered DataFrame (original index preserved).
    """
    mask = (
        (df['Metric-Name'] == metric)
        & (df['trial_seed'] == trial_seed)
        & (df['trial_rank'] == trial_rank)
        & df['Algorithm'].isin(list(algorithms))
    )
    return df[mask]


def main(argv=None):
    """Parse CLI options, load and filter the results, and plot them.

    With no arguments this reproduces the original run: resnet18 trained on
    100 and 1000 MNIST examples, plotting PGD-Accuracy for FGSM, PGD and
    AdversarialMMD. Larger sweeps are selected on the command line, e.g.
    ``--amounts 100 1000 10000 25000``.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser(
        description='Plot a test metric against training-set size.')
    parser.add_argument('--root', default=ROOT,
                        help='Experiment archive directory.')
    parser.add_argument('--archs', nargs='+', default=list(DEFAULT_ARCHS),
                        help='Architectures to include.')
    parser.add_argument('--amounts', nargs='+', type=int,
                        default=list(DEFAULT_AMOUNTS),
                        help='Training-set sizes to include.')
    parser.add_argument('--algorithms', nargs='+',
                        default=list(DEFAULT_ALGORITHMS),
                        help='Algorithms to plot.')
    parser.add_argument('--metric', default='PGD-Accuracy',
                        help='Metric-Name value to plot.')
    parser.add_argument('--linear-x', action='store_true',
                        help='Use a linear instead of logarithmic x axis.')
    parser.add_argument('--output', default=None,
                        help='Save the figure to this path instead of showing it.')
    args = parser.parse_args(argv)

    results = select_rows(
        load_results(args.root, args.archs, args.amounts),
        algorithms=args.algorithms,
        metric=args.metric,
    )
    if results.empty:
        # Fail loudly instead of drawing an empty figure.
        raise SystemExit(
            f'No rows matched metric={args.metric!r} for algorithms {args.algorithms}.'
        )

    ax = sns.lineplot(
        data=results,
        hue='Algorithm',
        x='Num-Data',
        y='Metric-Value',
        palette='colorblind',
        marker='o',
    )
    if not args.linear_x:
        # Training-set sizes are spaced multiplicatively (100, 1000, ...);
        # a log axis keeps the small-data points from bunching together.
        ax.set_xscale('log')
    ax.set_ylabel(args.metric)

    if args.output:
        plt.savefig(args.output, bbox_inches='tight')
    else:
        plt.show()


if __name__ == '__main__':
    main()
