import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from matplotlib.lines import Line2D
from matplotlib.ticker import AutoMinorLocator
from sklearn.metrics import confusion_matrix

# Global Matplotlib text settings applied to every figure produced below:
# keep external LaTeX rendering off, use the 'cm' mathtext font set for
# formulas and request Times New Roman for regular text.
plt.rcParams['text.usetex'] = False
plt.rcParams['mathtext.fontset'] = 'cm'
plt.rcParams['font.family'] = 'Times New Roman'


def read_csv_with_ref(path, match):
    """Read a results CSV and attach the accuracies of each row's reference run.

    The reference of a row is the single row of the same file that has no
    auxiliary data (``aux`` is NaN), the same ``epochs`` value and for which
    ``match`` is true. Rows without auxiliary data can therefore be their own
    reference.

    Args:
        path: Path of the CSV file. It must provide the columns ``aux``,
            ``epochs``, ``best_train_acc``, ``best_std_acc`` and
            ``best_cert_acc``.
        match: Callable ``match(df, row)`` returning a boolean mask over ``df``
            that selects the rows sharing the configuration of ``row``.

    Returns:
        The loaded frame with three extra columns ``ref_train_acc``,
        ``ref_std_acc`` and ``ref_cert_acc``. Rows for which no unique
        reference exists are reported on stdout and receive NaN in these
        columns, so the new columns always line up with the frame.
    """
    df = pd.read_csv(path)
    is_baseline = df['aux'].isna()  # loop invariant: rows trained without auxiliary data
    # Maps each new column to the reference column it is copied from.
    ref_sources = {
        'ref_train_acc': 'best_train_acc',
        'ref_std_acc': 'best_std_acc',
        'ref_cert_acc': 'best_cert_acc',
    }
    collected = {ref_name: [] for ref_name in ref_sources}
    for _, row in df.iterrows():
        reference = df[is_baseline & match(df, row) & (df['epochs'] == row['epochs'])]
        found = reference.shape[0] == 1
        if not found:
            print('Couldn\'t find reference for \n{}'.format(row))
        # Always append exactly one value per row; skipping the row would make
        # the lists shorter than the frame and break the assignment below.
        for ref_name, source in ref_sources.items():
            collected[ref_name].append(reference[source].iloc[0] if found else np.nan)
    for ref_name, values in collected.items():
        df[ref_name] = values
    return df


def read_csv_merged():
    """Load the SortNet, LOT and GloroNet result tables and stack them.

    Each table is read through ``read_csv_with_ref`` so that it carries the
    reference accuracies, reduced to a common set of columns and extended with
    three derived columns:

    * ``model``: display name of the architecture,
    * ``epoch_size``: relative epoch count of the run,
    * ``model_size``: relative model size.

    The relative values come from ``np.select``, so any epoch count or
    architecture that is not listed below is mapped to 0.

    Returns:
        One frame with the rows of all three tables and a fresh integer index.
    """
    target_columns = ['best_epoch', 'best_train_acc', 'best_std_acc', 'best_cert_acc', 'last_train_acc', 'last_std_acc',
                      'last_cert_acc', 'ref_train_acc', 'ref_std_acc', 'ref_cert_acc', 'aux', 'epochs']
    # Load linf-dist Net and SortNet
    sortnet = read_csv_with_ref('sortnet-results.csv', lambda df, row: (
        df['dropout'].isna() if np.isnan(row['dropout']) else df['dropout'] == row['dropout']))
    # Explicit copies: the frames below get new columns and must not be views
    # of the loaded tables (avoids pandas' SettingWithCopyWarning).
    sortnet_merge = sortnet.loc[:, target_columns].copy()
    # Raw string: '\e' and '\i' are not valid Python escape sequences.
    sortnet_merge['model'] = np.where(sortnet['dropout'] == 1.0, r'$\ell_\infty$-dist Net', 'SortNet')
    epochs = sortnet['epochs']
    epoch_list, epoch_map = [epochs == 800, epochs == 1600, epochs == 3000, epochs == 6000], [1, 2, 1, 2]
    sortnet_merge['epoch_size'] = np.select(epoch_list, epoch_map)  # relative epoch count
    sortnet_merge['model_size'] = 1
    # Load LOT
    lot = read_csv_with_ref('lot-results.csv',
                            lambda df, row: (df['opt'] == row['opt']) & (df['blocks'] == row['blocks']))
    lot = lot[lot['opt'] == 'oc']  # only keep one-cycle scheduler
    lot_merge = lot.loc[:, target_columns].copy()
    lot_merge['model'] = 'LOT'
    epochs = lot['epochs']
    epoch_list, epoch_map = [epochs == 200, epochs == 400, epochs == 600], [1, 2, 3]
    lot_merge['epoch_size'] = np.select(epoch_list, epoch_map)  # relative epoch count
    blocks = lot['blocks']
    model_list, model_map = [blocks == 2, blocks == 4, blocks == 8], [3, 2, 1]
    lot_merge['model_size'] = np.select(model_list, model_map)  # relative model size
    # Load GloroNet
    gloro = read_csv_with_ref('gloro-results.csv',
                              lambda df, row: (df['depth'] == row['depth']) & (df['width'] == row['width']))
    gloro_merge = gloro.loc[:, target_columns].copy()
    gloro_merge['model'] = 'GloroNet'
    epochs = gloro['epochs']
    epoch_list, epoch_map = [epochs == 800, epochs == 1600, epochs == 2400], [1, 2, 3]
    gloro_merge['epoch_size'] = np.select(epoch_list, epoch_map)  # relative epoch count
    depth, width = gloro['depth'], gloro['width']
    model_list = [depth == 6, depth == 12, (depth == 18) & (width == 128), (depth == 18) & (width == 256)]
    model_map = [4, 3, 2, 1]
    gloro_merge['model_size'] = np.select(model_list, model_map)  # relative model size
    return pd.concat([sortnet_merge, lot_merge, gloro_merge], ignore_index=True)


def load_certificates(data, scale):
    """Extract certificates, predictions and labels from a result mapping.

    Args:
        data: Mapping (for example a loaded ``.npz`` archive) that holds
            ``labels`` and either precomputed ``predictions`` /
            ``certificates`` or the raw ``outputs`` to derive them from.
        scale: Divisor applied to ``outputs`` before the certificate is
            derived; unused when ``certificates`` is stored in ``data``.

    Returns:
        Tuple ``(certificates, predictions, labels)``.
    """
    available = data.keys()

    if 'predictions' in available:
        predictions = data['predictions']
    else:
        predictions = data['outputs'].argmax(axis=-1)

    labels = data['labels']

    if 'certificates' in available:
        return data['certificates'], predictions, labels

    # Half the gap between the largest and second largest scaled output.
    ranked = np.sort(data['outputs'] / scale)
    top_two = ranked[:, -2:]
    certificates = (top_two[:, 1] - top_two[:, 0]) / 2.0
    return certificates, predictions, labels


def plot_bounds(stat='count'):
    """Plot complementary ECDFs of the certification radii for SortNet and LOT.

    For each model the radii of the run without and with auxiliary data are
    drawn, split into correctly (solid) and incorrectly (dotted) classified
    samples. The number of correct samples whose radius reaches ``epsilon`` is
    printed for both runs. The figure is written to ``bounds.png`` and
    ``bounds.pdf``.

    Args:
        stat: Statistic forwarded to ``sns.ecdfplot``; ``'count'`` labels the
            y axis 'Number of Images', any other value 'Proportion'.
    """
    colors = ['#04316A', '#FDB735', '#C50F3C', '#18B4F1', '#7BB725', '#8C9FB1'][3:5]
    fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(6.5, 2.0))
    # The titles contain a real newline, so the mathtext backslashes are
    # doubled instead of switching to raw strings.
    lot_title = 'CIFAR-10 ($\\ell_2, \\epsilon = 36/255$)\nLOT-L'
    lot_config = (axes[1], lot_title, 'lot', 36/255, 1.0, 1.0)
    sortnet_title = 'CIFAR-10 ($\\ell_\\infty, \\epsilon = 8/255$)\nSortNet w/o dropout'
    f = 0.1561473369835739 / (8/255)
    sortnet_config = (axes[0], sortnet_title, 'sortnet', 8/255, f * 5.384060382843018, f * 3.1185569763183594)
    for axis, title, model, epsilon, std_scale, aux_scale in [lot_config, sortnet_config]:
        axis.set_title(title, fontsize=10)
        # (file suffix, output scale, colour, printed tag, legend prefix) per run
        runs = [
            ('std', std_scale, colors[0], 'Standard', 'w/o auxiliary'),
            ('aux', aux_scale, colors[1], 'Auxiliary', 'w/ auxiliary'),
        ]
        for suffix, scale, color, tag, prefix in runs:
            # Context manager closes the archive once the arrays are extracted.
            with np.load('bound/{}-{}.npz'.format(model, suffix)) as data:
                certificates, predictions, labels = load_certificates(data, scale)
            correct_certs = certificates[predictions == labels]
            incorrect_certs = certificates[predictions != labels]
            print('{}: {}'.format(tag, np.sum(correct_certs >= epsilon)))
            sns.ecdfplot(correct_certs, complementary=True, stat=stat, lw=1, c=color, ax=axis,
                         label='{} (Correct)'.format(prefix))
            sns.ecdfplot(incorrect_certs, complementary=True, stat=stat, lw=1, ls=':', c=color, ax=axis,
                         label='{} (Incorrect)'.format(prefix))
        # Annotations
        axis.set_xlabel('Certification Radius')
        axis.set_ylabel('Number of Images' if stat == 'count' else 'Proportion')
        axis.axvline(epsilon, c='#cccc', lw=1)
        axis.text(epsilon, 1, r'$\epsilon$', c='black', ha='left', va='top', transform=axis.get_xaxis_transform())
        if axis is axes[1]:
            # One shared legend next to the right panel: colour encodes the
            # run, line style encodes correctness.
            handles = [
                Line2D([0], [0], color=colors[0], linewidth=1, linestyle='-'),
                Line2D([0], [0], color=colors[1], linewidth=1, linestyle='-'),
                Line2D([0], [0], color='black', linewidth=1, linestyle='-'),
                Line2D([0], [0], color='black', linewidth=1, linestyle=':'),
            ]
            legend_labels = ['w/o auxiliary', 'w/ auxiliary', 'correct', 'incorrect']
            kwargs = {'prop': {'size': 8}, 'handletextpad': 0.5, 'handlelength': 1.0, 'frameon': True}
            axis.legend(handles, legend_labels, loc='center left', bbox_to_anchor=(1.05, 0.5), **kwargs)
        axis.xaxis.grid(zorder=-1, c='#eeee')
        axis.xaxis.grid(which='minor', zorder=-1, c='#eeee', ls=':')
        axis.xaxis.set_minor_locator(AutoMinorLocator())
        axis.yaxis.grid(zorder=-1, c='#eeee')
        axis.yaxis.grid(which='minor', zorder=-1, c='#eeee', ls=':')
        axis.yaxis.set_minor_locator(AutoMinorLocator())
        axis.set_axisbelow(True)
    plt.tight_layout()
    plt.savefig('bounds.png', dpi=300, bbox_inches='tight')
    plt.savefig('bounds.pdf', bbox_inches='tight')


def _classify_certificates(certificates, predictions, labels, epsilon):
    """Assign each sample one of the four certified/correct category names."""
    certified = certificates >= epsilon  # NaN radii compare False, i.e. count as not certified
    correct = predictions == labels
    return np.where(certified,
                    np.where(correct, 'certified_correct', 'certified_incorrect'),
                    np.where(correct, 'non_certified_correct', 'non_certified_incorrect'))


def plot_certification_heatmaps():
    """Plot how samples move between certified/correct categories with auxiliary data.

    For SortNet and LOT every sample is categorised once for the run without
    and once for the run with auxiliary data; the resulting 4x4 confusion
    matrix (rows: without, columns: with auxiliary data) is drawn as an
    annotated heatmap. The figure is written to ``cert-heatmap.png`` and
    ``cert-heatmap.pdf``.
    """
    fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(5.5, 2.5))
    # The titles contain a real newline, so the mathtext backslashes are
    # doubled instead of switching to raw strings.
    lot_title = 'CIFAR-10 ($\\ell_2, \\epsilon = 36/255$)\nLOT-L'
    lot_config = (axes[1], lot_title, 'lot', 36/255, 1.0, 1.0)
    sortnet_title = 'CIFAR-10 ($\\ell_\\infty, \\epsilon = 8/255$)\nSortNet w/o dropout'
    sortnet_config = (axes[0], sortnet_title, 'sortnet', 0.1561473369835739, 5.384060382843018, 3.1185569763183594)
    class_keys = ['certified_correct', 'certified_incorrect', 'non_certified_correct', 'non_certified_incorrect']
    tick_labels = ['Cert. & Corr.', 'Cert. & $\\neg$Corr.', '$\\neg$Cert. & Corr.', '$\\neg$Cert. & $\\neg$Corr.']
    for axis, title, model, epsilon, std_scale, aux_scale in [lot_config, sortnet_config]:
        axis.set_title(title, fontsize=10)
        # Get correct and/or certified images for base model
        with np.load('bound/{}-std.npz'.format(model)) as data:
            certificates, predictions, labels = load_certificates(data, std_scale)
        std_classes = _classify_certificates(certificates, predictions, labels, epsilon)
        # Get correct and/or certified images for model with auxiliary data
        with np.load('bound/{}-aux.npz'.format(model)) as data:
            certificates, predictions, labels = load_certificates(data, aux_scale)
        aux_classes = _classify_certificates(certificates, predictions, labels, epsilon)
        # Fix the category order explicitly so the matrix is always 4x4 and
        # matches the tick labels even when a category does not occur.
        cm = confusion_matrix(std_classes, aux_classes, labels=class_keys)
        sns.heatmap(cm, annot=True, fmt='d', ax=axis, cmap='mako')
        axis.set_xticklabels(tick_labels, fontdict={'fontsize': 4})
        axis.set_xlabel('w/ auxiliary')
        axis.set_yticklabels(tick_labels, fontdict={'fontsize': 4})
        axis.set_ylabel('w/o auxiliary')
    plt.tight_layout()
    plt.savefig('cert-heatmap.png', dpi=300, bbox_inches='tight')
    plt.savefig('cert-heatmap.pdf', bbox_inches='tight')


def plot_generalization():
    """Plot the generalization gap against the gain from auxiliary data.

    Left panel: per model, a regression of the certified-accuracy change
    (``best_cert_acc - ref_cert_acc``) over the generalization gap of the
    reference run (``ref_train_acc - ref_std_acc``), restricted to runs with
    auxiliary data. Right panel: swarm plot of the generalization gap of all
    runs grouped by the ``aux`` value. The figure is written to
    ``generalization.png`` and ``generalization.pdf``.
    """
    colors = ['#04316A', '#FDB735', '#C50F3C', '#18B4F1', '#7BB725', '#8C9FB1'][1:5]
    merged = read_csv_merged()
    # Independent copy: columns are added to it below.
    with_aux = merged.loc[~merged['aux'].isna(), :].copy()
    merged['gen_gap'] = merged['best_train_acc'] - merged['best_std_acc']
    merged['aux'] = merged['aux'].fillna('None')
    with_aux['gen_gap'] = with_aux['ref_train_acc'] - with_aux['ref_std_acc']
    with_aux['cert_improve'] = with_aux['best_cert_acc'] - with_aux['ref_cert_acc']
    fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(5.5, 2.0))
    ax = axes[0]
    for i, model in enumerate(with_aux['model'].unique()):
        subset = with_aux[with_aux['model'] == model]
        sns.regplot(subset, x='gen_gap', y='cert_improve', color=colors[i], ax=ax,
                    scatter_kws={'s': 8, 'fc': colors[i], 'ec': None, 'alpha': 0.8}, line_kws={'lw': 1, 'ls': 'dotted'})
    ax.set_xlabel('Base Model Generalization Gap')
    ax.set_ylabel(r'$\Delta$ Certified Accuracy')
    ax.legend([], [], frameon=False)  # empty legend on the left panel
    ax.xaxis.grid(zorder=-1, c='#eeee')
    ax.yaxis.grid(zorder=-1, c='#eeee')
    ax.set_axisbelow(True)
    ax = axes[1]
    sns.swarmplot(merged, x='aux', y='gen_gap', hue='model', palette=colors, size=np.sqrt(8), alpha=0.8, ax=ax)
    ax.set_xlabel('Auxiliary Data')
    ax.set_ylabel('Generalization Gap')
    handles, labels = ax.get_legend_handles_labels()
    for h in handles:
        h.set_sizes([8])
    kwargs = {'prop': {'size': 8}, 'handletextpad': 0.5, 'handlelength': 1.0, 'frameon': True}
    ax.legend(handles, labels, **kwargs)
    ax.yaxis.grid(zorder=-1, c='#eeee')
    ax.set_axisbelow(True)
    plt.tight_layout()
    plt.savefig('generalization.png', dpi=300, bbox_inches='tight')
    plt.savefig('generalization.pdf', bbox_inches='tight')


def plot_aaai_figure():
    """Render the three-panel summary figure ``aaai.png`` / ``aaai.pdf``.

    Panel a: per model, a regression of the certified-accuracy change over the
    reduction of the generalization gap relative to the reference run (runs
    with auxiliary data only). Panel b: swarm plot of the generalization gap
    per ``aux`` value, restricted to each model's largest epoch count.
    Panel c: clean and certified accuracy over the ``frac`` column of the
    SortNet and LOT ablation tables.
    """
    colors = ['#04316A', '#FDB735', '#C50F3C', '#18B4F1', '#7BB725', '#8C9FB1'][1:5]
    merged = read_csv_merged()
    # Independent copy: columns are added to it below.
    with_aux = merged.loc[~merged['aux'].isna(), :].copy()
    merged['gen_gap'] = merged['best_train_acc'] - merged['best_std_acc']
    merged['aux'] = merged['aux'].fillna('None')
    with_aux['cert_improve'] = with_aux['best_cert_acc'] - with_aux['ref_cert_acc']
    with_aux['ref_gen_gap'] = with_aux['ref_train_acc'] - with_aux['ref_std_acc']
    with_aux['best_gen_gap'] = with_aux['best_train_acc'] - with_aux['best_std_acc']
    with_aux['gen_gap_improve'] = with_aux['ref_gen_gap'] - with_aux['best_gen_gap']
    hue_order = [r'$\ell_\infty$-dist Net', 'SortNet', 'LOT', 'GloroNet']
    fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(7, 2.5))
    # Generalization gap versus certified accuracy improvement
    ax = axes[0]
    ax.set_title('a', fontsize=10, fontweight='bold')
    for i, model in enumerate(hue_order):
        subset = with_aux[with_aux['model'] == model]
        sns.regplot(subset, x='gen_gap_improve', y='cert_improve', color=colors[i], ax=ax,
                    scatter_kws={'s': 8, 'fc': colors[i], 'ec': None, 'alpha': 0.8}, line_kws={'lw': 1, 'ls': 'dotted'})
    ax.set_xlabel(r'$\Delta$ Generalization Gap')
    ax.set_ylabel(r'$\Delta$ Certified Accuracy')
    ax.legend([], [], frameon=False)  # empty legend on panel a
    ax.xaxis.grid(zorder=-1, c='#eeee')
    ax.yaxis.grid(zorder=-1, c='#eeee')
    ax.set_axisbelow(True)
    # Generalization gap
    ax = axes[1]
    ax.set_title('b', fontsize=10, fontweight='bold')
    # Keep only the runs with the largest epoch count of their model. The
    # string alias 'max' replaces the builtin, whose use in transform is deprecated.
    idx = merged.groupby('model')['epochs'].transform('max') == merged['epochs']
    max_epochs = merged[idx]
    sns.swarmplot(max_epochs, x='aux', y='gen_gap', hue='model', hue_order=hue_order, palette=colors, size=np.sqrt(8), alpha=0.8, ax=ax)
    ax.set_xticklabels(ax.get_xticklabels(), rotation=45)
    ax.set_xlabel(r'$|\mathcal{D}_{gen}|$')
    ax.set_ylabel('Generalization Gap')
    handles, labels = ax.get_legend_handles_labels()
    for h in handles:
        h.set_sizes([8])
    kwargs = {'prop': {'size': 8}, 'handletextpad': 0.5, 'handlelength': 1.0, 'frameon': True}
    ax.legend(handles, labels, **kwargs)
    ax.yaxis.grid(zorder=-1, c='#eeee')
    ax.set_axisbelow(True)
    # Generated-to-original ratio
    ax = axes[2]
    ax.set_title('c', fontsize=10, fontweight='bold')
    for ablation, color in [('sortnet-ablation.csv', colors[0]), ('lot-ablation.csv', colors[2])]:
        df = pd.read_csv(ablation)
        frac = df['frac'].to_numpy()
        best_std_acc, best_cert_acc = df['best_std_acc'].to_numpy(), df['best_cert_acc'].to_numpy()
        ax.plot(frac, best_std_acc, '^:', lw=1, ms=2, c=color)
        ax.plot(frac, best_cert_acc, 'o:', lw=1, ms=2, c=color)
    ax.xaxis.set_minor_locator(AutoMinorLocator())
    ax.yaxis.grid(zorder=-1, c='#eeee')
    ax.yaxis.set_minor_locator(AutoMinorLocator())
    ax.set_xlabel('Generated-to-Original Ratio')
    ax.set_ylabel('Accuracy (%)')
    # Marker shape distinguishes clean from certified accuracy.
    handles = [
        Line2D([0], [0], color='black', linewidth=1, linestyle='-', markersize=2, marker='^'),
        Line2D([0], [0], color='black', linewidth=1, linestyle='-', markersize=2, marker='o'),
    ]
    labels = ['clean', 'certified']
    kwargs = {'prop': {'size': 8}, 'handletextpad': 0.5, 'handlelength': 1.0, 'frameon': True}
    ax.legend(handles, labels, loc='center right', bbox_to_anchor=(1, 0.4), **kwargs)
    plt.tight_layout()
    plt.savefig('aaai.png', dpi=300, bbox_inches='tight')
    plt.savefig('aaai.pdf', bbox_inches='tight')


def plot_epoch():
    """Scatter the last-vs-best certified accuracy against the epochs left after the best one.

    x axis: ``epochs - best_epoch``; y axis: ``last_cert_acc - best_cert_acc``;
    points are coloured by the ``aux`` value (missing values shown as 'None').
    The figure is written to ``epoch.png`` and ``epoch.pdf``.
    """
    colors = ['#04316A', '#FDB735', '#C50F3C', '#18B4F1', '#7BB725', '#8C9FB1'][1:5]
    merged = read_csv_merged()
    merged['aux'] = merged['aux'].fillna('None')
    merged['epoch_diff'] = merged['epochs'] - merged['best_epoch']
    merged['cert_improve'] = merged['last_cert_acc'] - merged['best_cert_acc']
    fig, axis = plt.subplots(nrows=1, ncols=1, figsize=(5.5, 2.5))
    sns.scatterplot(merged, x='epoch_diff', y='cert_improve', hue='aux', palette=colors, s=8, alpha=0.8, ax=axis)
    axis.set_xlabel('Epoch Difference')
    # Raw string: '\D' is not a valid Python escape sequence.
    axis.set_ylabel(r'$\Delta$ Certified Accuracy')
    handles, labels = axis.get_legend_handles_labels()
    for handle in handles:
        handle.set_sizes([8])
    kwargs = {'prop': {'size': 8}, 'handletextpad': 0.5, 'handlelength': 1.0, 'frameon': True}
    axis.legend(handles, labels, **kwargs)
    plt.tight_layout()
    plt.savefig('epoch.png', dpi=300, bbox_inches='tight')
    plt.savefig('epoch.pdf', bbox_inches='tight')


def plot_model_and_epoch_size():
    """Plot the mean certified-accuracy gain per model size and epoch count.

    For LOT and GloroNet the runs with auxiliary data are averaged per
    (``epoch_size``, ``model_size``) cell and drawn as an annotated heatmap of
    ``best_cert_acc - ref_cert_acc``. The figure is written to
    ``influences.png`` and ``influences.pdf``.
    """
    merged = read_csv_merged()
    # Independent copy: a column is added to the filtered frame below.
    merged = merged[~merged['aux'].isna()].copy()
    merged['cert_improve'] = merged['best_cert_acc'] - merged['ref_cert_acc']
    fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(5.5, 2.5))
    # (axis, model name, y tick labels, extra heatmap arguments) per panel;
    # only the right panel labels its colour bar.
    panels = [
        (axes[0], 'LOT', ['L', 'M', 'S'], {}),
        (axes[1], 'GloroNet', ['L', 'M', 'S', 'XS'], {'cbar_kws': {'label': r'$\Delta$ Certified Accuracy'}}),
    ]
    for axis, model, size_labels, heatmap_kwargs in panels:
        axis.set_title(model, fontsize=10)
        subset = merged[merged['model'] == model].groupby(['epoch_size', 'model_size'])['cert_improve'].mean().reset_index()
        grid = subset.pivot(index='model_size', columns='epoch_size', values='cert_improve')
        # '+.2g' keeps the two significant digits of the default format but
        # lets the formatter pick the sign, so negative cells read '-0.5'
        # instead of '+-0.5'.
        sns.heatmap(grid, annot=True, fmt='+.2g', ax=axis, cmap='crest', **heatmap_kwargs)
        for text in axis.texts:
            text.set_text(text.get_text() + ' %')
        axis.set_xlabel('Number of Epochs')
        axis.set_xticklabels(['$\\times 1$', '$\\times 2$', '$\\times 3$'])
        axis.set_ylabel('Model Size')
        axis.set_yticklabels(size_labels, fontdict={'fontsize': 8})
    plt.tight_layout()
    plt.savefig('influences.png', dpi=300, bbox_inches='tight')
    plt.savefig('influences.pdf', bbox_inches='tight')


def plot_leaderboard():
    """Scatter clean versus certified accuracy of the leaderboard entries.

    Reads ``leaderboard.csv`` and draws one panel per norm (``linf`` left,
    ``l2`` right). Entries are coloured by ``year``; rows with ``year == 0``
    are drawn as black stars labelled 'Ours'. Arrows connect selected
    leaderboard rows (addressed by position within the norm's rows) to the
    'Ours' entries. The figure is written to ``leaderboard.png`` and
    ``leaderboard.pdf``.
    """
    colors = ['#04316A', '#FDB735', '#C50F3C', '#18B4F1', '#7BB725', '#8C9FB1']
    arrow_props = dict(arrowstyle='-|>', ls='-', color='#cccc', connectionstyle='arc3', mutation_scale=8)
    leaderboard = pd.read_csv('leaderboard.csv')
    fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(6.5, 2.0))
    axis = axes[0]  # Linf norm
    axis.set_title(r'CIFAR-10 ($\ell_\infty, \epsilon = 8/255$)', fontsize=10)
    axis.set_xlabel('Certified Accuracy (%)')
    axis.set_ylabel('Clean Accuracy (%)')
    board = leaderboard[leaderboard['norm'] == 'linf']
    for i, year in enumerate([2018, 2019, 2020, 2021, 2022, 2023]):
        models = board[board['year'] == year]
        cert_acc, std_acc = models['cert_acc'].values, models['std_acc'].values
        axis.scatter(cert_acc, std_acc, c=colors[i], s=8, zorder=3, label=year)
    ours = board[board['year'] == 0]
    cert_acc, std_acc = ours['cert_acc'].values, ours['std_acc'].values
    axis.scatter(cert_acc, std_acc, c='black', s=8, marker='*', zorder=3, label='Ours')
    # (source row position in the board, target position among 'Ours')
    for src_i, dst_i in [(2, 0), (0, 1)]:
        axis.annotate('', xy=(cert_acc[dst_i], std_acc[dst_i]), zorder=2,
                      xytext=(board['cert_acc'].iloc[src_i], board['std_acc'].iloc[src_i]), arrowprops=arrow_props)
    # The legend entries are collected here but drawn next to the right panel.
    handles, labels = axis.get_legend_handles_labels()
    axis = axes[1]  # L2 norm
    axis.set_title(r'CIFAR-10 ($\ell_2, \epsilon = 36/255$)', fontsize=10)
    axis.set_xlabel('Certified Accuracy (%)')
    axis.set_ylabel('Clean Accuracy (%)')
    board = leaderboard[leaderboard['norm'] == 'l2']
    for i, year in enumerate([2018, 2020, 2021, 2022, 2023]):
        models = board[board['year'] == year]
        cert_acc, std_acc = models['cert_acc'].values, models['std_acc'].values
        # 2019 is skipped in this panel; shift the index so every year keeps
        # the colour it has in the left panel.
        axis.scatter(cert_acc, std_acc, c=colors[i if i == 0 else i + 1], s=8, zorder=3)
    ours = board[board['year'] == 0]
    cert_acc, std_acc = ours['cert_acc'].values, ours['std_acc'].values
    axis.scatter(cert_acc, std_acc, c='black', s=8, marker='*', zorder=3)
    for src_i, dst_i in [(0, 1), (1, 0)]:
        axis.annotate('', xy=(cert_acc[dst_i], std_acc[dst_i]), zorder=2,
                      xytext=(board['cert_acc'].iloc[src_i], board['std_acc'].iloc[src_i]), arrowprops=arrow_props)
    kwargs = {'prop': {'size': 8}, 'handletextpad': 0.5, 'handlelength': 1.0, 'frameon': True}
    axis.legend(handles, labels, loc='center left', bbox_to_anchor=(1.05, 0.5), **kwargs)
    for ax in axes:
        ax.xaxis.grid(zorder=-1, c='#eeee')
        ax.xaxis.grid(which='minor', zorder=-1, c='#eeee', ls=':')
        ax.xaxis.set_minor_locator(AutoMinorLocator())
        ax.yaxis.grid(zorder=-1, c='#eeee')
        ax.yaxis.grid(which='minor', zorder=-1, c='#eeee', ls=':')
        ax.yaxis.set_minor_locator(AutoMinorLocator())
    plt.tight_layout()
    plt.savefig('leaderboard.png', dpi=300, bbox_inches='tight')
    plt.savefig('leaderboard.pdf', bbox_inches='tight')


def plot_ablation():
    """Plot clean and certified accuracy over the ``frac`` column of the ablation tables.

    One panel each for ``sortnet-ablation.csv`` and ``lot-ablation.csv``; the
    legend is drawn on the left panel only. The figure is written to
    ``ablation.png`` and ``ablation.pdf``.
    """
    fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(5.5, 2.0))
    colors = ['#04316A', '#FDB735', '#C50F3C', '#18B4F1', '#7BB725', '#8C9FB1'][3:5]
    sortnet = pd.read_csv('sortnet-ablation.csv')
    lot = pd.read_csv('lot-ablation.csv')
    # Raw string: '\e' and '\i' are not valid Python escape sequences.
    sortnet_config = (axes[0], r'$\ell_\infty$-dist Net', sortnet)
    lot_config = (axes[1], 'LOT-S', lot)
    for axis, title, df in [sortnet_config, lot_config]:
        axis.set_title(title, fontsize=10)
        axis.set_xlabel('Generated-to-Original Ratio')
        # Both the clean and the certified curve share this axis, so the label
        # must not single out certified accuracy.
        axis.set_ylabel('Accuracy (%)')
        frac = df['frac'].to_numpy()
        best_std_acc, best_cert_acc = df['best_std_acc'].to_numpy(), df['best_cert_acc'].to_numpy()
        axis.plot(frac, best_std_acc, 'o:', lw=1, ms=2, c=colors[0])
        axis.plot(frac, best_cert_acc, '^:', lw=1, ms=2, c=colors[1])
        axis.xaxis.grid(zorder=-1, c='#eeee')
        axis.xaxis.set_minor_locator(AutoMinorLocator())
        axis.yaxis.grid(zorder=-1, c='#eeee')
        axis.yaxis.set_minor_locator(AutoMinorLocator())
        if axis is axes[0]:
            handles = [
                Line2D([0], [0], color=colors[0], linewidth=1, linestyle='-', markersize=2, marker='o'),
                Line2D([0], [0], color=colors[1], linewidth=1, linestyle='-', markersize=2, marker='^'),
            ]
            labels = ['clean', 'certified']
            kwargs = {'prop': {'size': 8}, 'handletextpad': 0.5, 'handlelength': 1.0, 'frameon': True}
            axis.legend(handles, labels, loc='center right', **kwargs)
    plt.tight_layout()
    plt.savefig('ablation.png', dpi=300, bbox_inches='tight')
    plt.savefig('ablation.pdf', bbox_inches='tight')


def plot_scaling():
    """Plot certified accuracy over the total number of training images.

    Reads ``scaling-results.csv`` and draws, per model, one dotted line for
    every distinct ``size`` value (missing sizes form their own group). The
    first two models share the left panel. The figure is written to
    ``scaling.png`` and ``scaling.pdf``.
    """
    fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(7, 2.5))
    # Four models on three panels: the first two models share the left panel.
    axes = [axes[0], axes[0], axes[1], axes[2]]
    colors = ['#04316A', '#FDB735', '#C50F3C', '#18B4F1', '#7BB725', '#8C9FB1'][1:5]
    models = [r'$\ell_\infty$-dist Net', 'SortNet', 'LOT', 'GloroNet']
    # Marker per named size; rows without a size use 'o' (handled below, since
    # NaN is not a reliable dictionary key).
    markers = {'XS': 'd', 'S': 'v', 'M': 'o', 'L': '^'}
    df = pd.read_csv('scaling-results.csv')
    # Turn counts such as '500k' / '1m' into integers; missing values become 0.
    # NOTE(review): pd.eval evaluates the CSV cell as an expression. Only use
    # this with trusted result files.
    df['aux'] = df['aux'].fillna(0).replace({'k': '*1e3', 'm': '*1e6'}, regex=True).map(pd.eval).astype(int)
    df['aux'] += 50000  # calculate total number of images (plus original)
    for axis, model, color in zip(axes, models, colors):
        axis.set_title(model)
        axis.set_xscale('log')
        axis.set_xlabel(r'$|\mathcal{D}_{orig}| + |\mathcal{D}_{gen}|$')
        model_df = df[df['model'] == model]
        for size in model_df['size'].unique():
            if pd.isna(size):
                size_df = model_df[model_df['size'].isna()]
                marker = 'o'
            else:
                size_df = model_df[model_df['size'] == size]
                marker = markers[size]
            print(model, size, size_df.shape)
            axis.plot(size_df['aux'], size_df['best_cert_acc'], ls=':', lw=1, marker=marker, ms=2, c=color)
        axis.xaxis.grid(zorder=-1, c='#eeee')
        axis.yaxis.grid(zorder=-1, c='#eeee')
        axis.set_axisbelow(True)
    axes[0].set_ylabel('Certified Accuracy (%)')
    # The shared left panel gets a combined title instead of the last model name.
    axes[0].set_title(r'$\ell_\infty$-dist Net & SortNet')
    handles = [
        Line2D([0], [0], color='black', linestyle='', marker='v', markersize=2),
        Line2D([0], [0], color='black', linestyle='', marker='o', markersize=2),
    ]
    labels = ['$\\rho=.85$', '$\\rho=.00$']
    kwargs = {'prop': {'size': 8}, 'handletextpad': 0.5, 'handlelength': 1.0, 'frameon': True}
    axes[0].legend(handles, labels, **kwargs)
    # Colour legend for the models, placed on the middle panel.
    handles = [
        Line2D([0], [0], color=c, linestyle=':', linewidth=1, marker='o', markersize=2) for c in colors
    ]
    kwargs = {'prop': {'size': 8}, 'handletextpad': 0.5, 'handlelength': 1.0, 'frameon': True}
    axes[2].legend(handles, models, **kwargs)
    # Marker legend for the model sizes, placed on the right panel.
    handles = [
        Line2D([0], [0], color='black', linestyle='', marker='d', markersize=2),
        Line2D([0], [0], color='black', linestyle='', marker='v', markersize=2),
        Line2D([0], [0], color='black', linestyle='', marker='o', markersize=2),
        Line2D([0], [0], color='black', linestyle='', marker='^', markersize=2),
    ]
    labels = ['XS', 'S', 'M', 'L']
    kwargs = {'prop': {'size': 8}, 'handletextpad': 0.5, 'handlelength': 1.0, 'ncol': 2, 'frameon': True}
    axes[3].legend(handles, labels, **kwargs)
    plt.tight_layout()
    plt.savefig('scaling.png', dpi=300, bbox_inches='tight')
    plt.savefig('scaling.pdf', bbox_inches='tight')


def main():
    """Render the currently enabled figures, one after another."""
    # Currently disabled: plot_certification_heatmaps, plot_generalization,
    # plot_epoch, plot_model_and_epoch_size, plot_improvement, plot_ablation.
    enabled_plots = (
        plot_bounds,
        plot_leaderboard,
        plot_scaling,
        plot_aaai_figure,
    )
    for render in enabled_plots:
        render()


# Generate the figures only when this file is executed as a script, not when
# it is imported.
if __name__ == '__main__':
    main()
