import numpy as np
import pandas as pd


def get_entry(row):
    """Format one result as two LaTeX table cells: ``clean & certified`` accuracy.

    Args:
        row: a single result (e.g. a ``pd.Series`` or any mapping with
            ``best_std_acc`` and ``best_cert_acc`` keys), or a ``pd.DataFrame``
            selection that should contain exactly one row.

    Returns:
        ``'<best_std_acc> & <best_cert_acc>'`` with both values formatted to two
        decimals, or the placeholder ``'- & -'`` when a DataFrame selection does
        not contain exactly one row (missing or ambiguous result).
    """
    if isinstance(row, pd.DataFrame):
        # A filtered DataFrame must identify exactly one run; otherwise emit placeholders.
        if row.shape[0] != 1:
            return '- & -'
        row = row.iloc[0]
    return '{:.2f} & {:.2f}'.format(row['best_std_acc'], row['best_cert_acc'])


def get_lot_table(df):
    """Print LaTeX rows comparing LOT (LipConvnet) runs with and without auxiliary data.

    For every model depth and training length one row is printed with the
    clean/certified accuracy of the baseline run (``aux`` missing) and of the
    ``1m``/``5m``/``10m`` runs, followed by the largest certified-accuracy gain
    of an aux run over the baseline (``-`` if any of those runs is missing or
    duplicated). A ``\\midrule`` line separates model depths.

    Args:
        df: results with ``blocks``, ``epochs``, ``aux``, ``best_std_acc`` and
            ``best_cert_acc`` columns.
    """
    aux_keys = ('1m', '5m', '10m')
    for model, blocks in [('LipConvnet-10', 2), ('LipConvnet-20', 4), ('LipConvnet-40', 8)]:
        for epochs in [200, 400, 600]:
            rows = df[(df['blocks'] == blocks) & (df['epochs'] == epochs)]
            row_none = rows[rows['aux'].isna()]
            aux_rows = [rows[rows['aux'] == key] for key in aux_keys]
            selections = [row_none, *aux_rows]
            # .item() needs exactly one run per selection; mirror get_entry's '-' placeholder otherwise.
            if all(len(sel) == 1 for sel in selections):
                best_aux = max(sel['best_cert_acc'].item() for sel in aux_rows)
                diff = '{:.2f}'.format(best_aux - row_none['best_cert_acc'].item())
            else:
                diff = '-'
            # Only the first row of each model group carries the model name.
            label = model if epochs == 200 else ''
            entries = [get_entry(sel) for sel in selections]
            print('        {} & {} & {} & {} & {} & {} & {} \\\\'.format(label, epochs, *entries, diff))
        print('        \\midrule')


def get_gloro_table(df):
    """Print LaTeX rows comparing GloroNet (LiResNet) runs with and without auxiliary data.

    For every (depth, width) architecture and training length one row is printed
    with the clean/certified accuracy of the baseline run (``aux`` missing) and
    of the ``1m``/``5m``/``10m`` runs, followed by the largest certified-accuracy
    gain of an aux run over the baseline (``-`` if any of those runs is missing
    or duplicated). A ``\\midrule`` line separates architectures.

    Args:
        df: results with ``depth``, ``width``, ``epochs``, ``aux``,
            ``best_std_acc`` and ``best_cert_acc`` columns.
    """
    aux_keys = ('1m', '5m', '10m')
    for depth, width in [(6, 128), (12, 128), (18, 128), (18, 256)]:
        for epochs in [800, 1600, 2400]:
            rows = df[(df['depth'] == depth) & (df['width'] == width) & (df['epochs'] == epochs)]
            row_none = rows[rows['aux'].isna()]
            aux_rows = [rows[rows['aux'] == key] for key in aux_keys]
            selections = [row_none, *aux_rows]
            # .item() needs exactly one run per selection; mirror get_entry's '-' placeholder otherwise.
            if all(len(sel) == 1 for sel in selections):
                best_aux = max(sel['best_cert_acc'].item() for sel in aux_rows)
                diff = '{:.2f}'.format(best_aux - row_none['best_cert_acc'].item())
            else:
                diff = '-'
            # Only the first row of each architecture group carries the model name.
            label = 'LiResNet L{}W{}'.format(depth, width) if epochs == 800 else ''
            entries = [get_entry(sel) for sel in selections]
            print('        {} & {} & {} & {} & {} & {} & {} \\\\'.format(label, epochs, *entries, diff))
        print('        \\midrule')


def get_sortnet_table(df):
    """Print LaTeX rows comparing l_inf-dist Net / SortNet runs with and without auxiliary data.

    Models are selected by their ``dropout`` value (1.0, 0.85, or missing) and
    each has its own pair of training lengths. Every row lists the baseline run
    (``aux`` missing) and the ``1m``/``5m``/``10m`` runs, followed by the largest
    certified-accuracy gain of an aux run over the baseline (``-`` if any of
    those runs is missing or duplicated). A ``\\midrule`` line separates models.

    Args:
        df: results with ``dropout``, ``epochs``, ``aux``, ``best_std_acc`` and
            ``best_cert_acc`` columns.
    """
    aux_keys = ('1m', '5m', '10m')
    configs = [
        # Raw string: '\e' and '\i' are invalid escape sequences in a normal literal.
        (r'$\ell_\infty$-dist Net', 1.0, [800, 1600]),
        ('SortNet w/ dropout', 0.85, [3000, 6000]),
        ('SortNet w/o dropout', None, [3000, 6000]),
    ]
    for model, dropout, epoch_list in configs:
        # `is None` rather than truthiness so a dropout of 0.0 is not mistaken for "no dropout".
        dropout_mask = df['dropout'].isna() if dropout is None else df['dropout'] == dropout
        for epochs in epoch_list:
            rows = df[dropout_mask & (df['epochs'] == epochs)]
            row_none = rows[rows['aux'].isna()]
            aux_rows = [rows[rows['aux'] == key] for key in aux_keys]
            selections = [row_none, *aux_rows]
            # .item() needs exactly one run per selection; mirror get_entry's '-' placeholder otherwise.
            if all(len(sel) == 1 for sel in selections):
                best_aux = max(sel['best_cert_acc'].item() for sel in aux_rows)
                diff = '{:.2f}'.format(best_aux - row_none['best_cert_acc'].item())
            else:
                diff = '-'
            # Only the first row of each model group carries the model name.
            label = model if epochs == epoch_list[0] else ''
            entries = [get_entry(sel) for sel in selections]
            print('        {} & {} & {} & {} & {} & {} & {} \\\\'.format(label, epochs, *entries, diff))
        print('        \\midrule')


def get_cifar100_table():
    """Print LaTeX rows with the baseline and best result for each CIFAR-100 model.

    Reads ``cifar100-results.csv`` and selects runs by the ``model`` column.
    Each printed row starts with an empty first cell (the model name is not
    printed), then has these cells:

    - the baseline run (``aux`` missing);
    - the run with the highest ``best_cert_acc``, which may be the baseline;
    - that run's ``aux`` value.

    Placeholders ``-`` are printed when a model has no certified-accuracy
    results.
    """
    print('CIFAR-100 table:')
    df = pd.read_csv('cifar100-results.csv')
    models = [
        # Raw string: '\e' and '\i' are invalid escape sequences in a normal literal.
        (r'$\ell_\infty$-dist Net', 'linfnet'),
        ('SortNet w/o dropout', 'sortnet'),
        ('LOT', 'lot'),
        ('GloroNet', 'gloro'),
    ]
    for _model, key in models:
        rows = df[df['model'] == key]
        none = get_entry(rows[rows['aux'].isna()])
        cert = rows['best_cert_acc'].dropna()
        if cert.empty:
            # idxmax cannot select a best run from an empty / all-NaN column.
            print('        & {} & - & - & - \\\\'.format(none))
            continue
        row_best = rows.loc[cert.idxmax()]
        print('        & {} & {} & {} \\\\'.format(none, get_entry(row_best), row_best['aux']))


def _print_best_aux_row(rows, epochs):
    """Print one best-aux LaTeX row for the runs in *rows* (one model config trained for *epochs*).

    The row lists *epochs*, the baseline run (``aux`` missing), the run with the
    highest ``best_cert_acc`` (which may be the baseline) and that run's ``aux``
    value. When no run has a certified accuracy, the best-run cells are ``-``.
    """
    none = get_entry(rows[rows['aux'].isna()])
    cert = rows['best_cert_acc'].dropna()
    if cert.empty:
        # idxmax cannot select a best run from an empty / all-NaN column.
        print('        & {} & {} & - & - & - \\\\'.format(epochs, none))
        return
    row_best = rows.loc[cert.idxmax()]
    print('        & {} & {} & {} & {} \\\\'.format(epochs, none, get_entry(row_best), row_best['aux']))


def get_aaai_tables():
    """Print the GloroNet, LOT and l_inf-dist Net / SortNet best-aux tables.

    Reads ``gloro-results.csv``, ``lot-results.csv`` (restricted to the
    ``opt == 'oc'`` runs) and ``sortnet-results.csv``. Each configuration gets
    one row with the baseline result, the best result over all of its runs and
    the ``aux`` value of that best run; a ``\\midrule`` line separates model
    groups.
    """
    # GloroNet
    print('GloroNet table:')
    df = pd.read_csv('gloro-results.csv')
    for depth, width in [(6, 128), (12, 128), (18, 128), (18, 256)]:
        for epochs in [800, 1600, 2400]:
            rows = df[(df['depth'] == depth) & (df['width'] == width) & (df['epochs'] == epochs)]
            _print_best_aux_row(rows, epochs)
        print('        \\midrule')
    # Layer-wise Orthogonal Training (LOT)
    print('LOT table:')
    df = pd.read_csv('lot-results.csv')
    df = df[df['opt'] == 'oc']  # we limit the main evaluation to one-cycle
    for blocks in [2, 4, 8]:
        for epochs in [200, 400, 600]:
            rows = df[(df['blocks'] == blocks) & (df['epochs'] == epochs)]
            _print_best_aux_row(rows, epochs)
        print('        \\midrule')
    # l_inf-dist Net and SortNet (selected by dropout value); no header line is printed here
    df = pd.read_csv('sortnet-results.csv')
    for dropout in [1.0, 0.85, None]:
        # `is None` rather than truthiness so a dropout of 0.0 is not mistaken for "no dropout".
        dropout_mask = df['dropout'].isna() if dropout is None else df['dropout'] == dropout
        for epochs in [800, 1600] if dropout == 1.0 else [3000, 6000]:
            rows = df[dropout_mask & (df['epochs'] == epochs)]
            _print_best_aux_row(rows, epochs)
        print('        \\midrule')


def get_scaling_table():
    """Print a markdown table of certified accuracy per model and auxiliary-data size.

    Reads ``scaling-results.csv``. Rows are named ``<model>-<size>`` (an empty
    size when missing). Columns are the ``aux`` values in a fixed order, with a
    missing ``aux`` labelled ``'None'``. Values are ``best_cert_acc`` with two
    decimals.
    """
    results = pd.read_csv('scaling-results.csv')
    results['name'] = results['model'] + '-' + results['size'].fillna('')
    results['aux'] = results['aux'].fillna('None')
    aux_order = ['None', '50k', '100k', '200k', '500k', '1m', '5m', '10m']
    table = results.pivot(index='name', columns='aux', values='best_cert_acc')
    table = table.reindex(columns=aux_order)
    print(table.to_markdown(floatfmt='.2f'))


def main(legacy_tables=False):
    """Print the result tables to stdout.

    By default only the best-aux tables (``get_aaai_tables``) and the CIFAR-100
    table are printed.

    Args:
        legacy_tables: when True, additionally print the per-aux-size tables for
            GloroNet, LOT (multi-step and one-cycle schedulers) and SortNet.
    """
    get_aaai_tables()
    get_cifar100_table()
    if not legacy_tables:
        return
    # GloroNet
    get_gloro_table(pd.read_csv('gloro-results.csv'))
    # Layer-wise Orthogonal Training (LOT): one table per learning-rate scheduler
    df = pd.read_csv('lot-results.csv')
    get_lot_table(df[df['opt'] == 'ms'])  # multi-step scheduler
    get_lot_table(df[df['opt'] == 'oc'])  # one-cycle scheduler
    # SortNet
    get_sortnet_table(pd.read_csv('sortnet-results.csv'))
    # The CIFAR-100 table is already printed above; get_cifar100_table takes no arguments.


if __name__ == '__main__':
    # Script entry point: print the tables when run directly, not on import.
    main()
