import matplotlib.pyplot as plt
from matplotlib.ticker import FuncFormatter
from common import summary
import pandas as pd

# Render every float in printed DataFrames with exactly two decimal places.
pd.set_option('display.float_format', '{:.2f}'.format)

# Easy mode reports its results under "final_*" column names. Copy each one
# (coerced to a number; unparseable entries become NaN) into the column name
# used by the other modes so the tables can be stacked. The source columns
# are kept, not removed.
_EASY_MODE_COLUMNS = {
    'final_samples': 'required_samples',
    'final_epsilon': 'acc_epsilon',
    'final_sparsity': 'achieved_sparsity',
    'post_prune_final_acc': 'post_prune_acc',
    'post_prune_final_eps': 'post_prune_eps',
    'post_prune_final_eps_std': 'post_prune_eps_std',
    'FinalAccGuarantee': 'AccGuarantee',
}

table = summary('vgg_easy_mode')
for source_col, target_col in _EASY_MODE_COLUMNS.items():
    table[target_col] = pd.to_numeric(table[source_col], errors='coerce')

table['PruneType'] = table['FinalPruneType']
table['Case'] = "Easy"


# The moderate and difficult plans already use the shared column names. Only
# coerce them to numbers (unparseable entries become NaN) and tag each row
# with its difficulty case.
_PLAN_CASES = {
    'vgg_mod_mode': "Moderate",
    'vgg_hard_mode': "Difficult",
}
_NUMERIC_COLUMNS = [
    'required_samples',
    'AccGuarantee',
    'post_prune_acc',
    'achieved_sparsity',
    'post_prune_eps',
    'post_prune_eps_std',
    'acc_epsilon',
]

plan_tables = [table]
for plan, case in _PLAN_CASES.items():
    new_table = summary(plan)
    for column in _NUMERIC_COLUMNS:
        new_table[column] = pd.to_numeric(new_table[column], errors='coerce')
    new_table['Case'] = case
    plan_tables.append(new_table)
# Concatenate once rather than re-growing `table` on every iteration.
table = pd.concat(plan_tables)

# Ordered categorical so rows sort by difficulty, not alphabetically.
table['Case'] = pd.Categorical(table['Case'], categories=["Easy", "Moderate", "Difficult"], ordered=True)
table = table.sort_values(['PruneType', 'Case', 'AccGuarantee'])

def format_num(x):
    """Format a count compactly for the results table.

    Values of at least one million render with one decimal and an ``M``
    suffix (``1_200_000 -> '1.2M'``). Values of at least one thousand are
    rounded to whole thousands with a ``K`` suffix (``15_400 -> '15K'``).
    Anything smaller is rounded to an integer. NaN fails both comparisons
    and renders as ``'nan'``.
    """
    # ">=" so exact thresholds use the larger unit ("1.0M", not "1000K").
    if x >= 1_000_000:
        return f'{x / 1e6:.1f}M'
    if x >= 1_000:
        return f'{x / 1e3:.0f}K'
    return f'{x:.0f}'

# Compact sample counts ("15K", "1.2M") for the table.
table['required_samples'] = table['required_samples'].map(format_num)
# Scaled by 100 for display; presumably stored as a fraction — confirm.
table['achieved_sparsity'] = table['achieved_sparsity'] * 100

# Must run before PruneType is replaced with display names below: both
# lookups key on the raw prune-type identifiers. Unknown types map to None.
# NOTE(review): layer sizes are hard-coded; confirm they match the plan configs.
table['layer_size'] = table['PruneType'].map({
    "DataDepCoreset": "10K",
    "DataIndCoreset": "1K",
    "DataDepDet": "10K",
}.get)

table['PruneType'] = table['PruneType'].map({
    "DataDepCoreset": "Weight Rand",
    "DataIndCoreset": "Neuron",
    "DataDepDet": "Weight Det",
}.get)

# "mean$\pm$std" cell text. Raw string because '\p' is an invalid escape
# sequence (SyntaxWarning on Python 3.12+); the resulting text is unchanged.
table['eps_combined'] = (
    table['post_prune_eps'].map(lambda x: '%.2f' % x)
    + r'$\pm$'
    + table['post_prune_eps_std'].map(lambda x: '%.2f' % x)
)

# Report only the runs with an accuracy guarantee of 0.5.
report_columns = [
    'PruneType',
    'Case',
    'acc_epsilon',
    'eps_combined',
    'post_prune_acc',
    'layer_size',
    'required_samples',
    'achieved_sparsity',
]
rel = table[table['AccGuarantee'] == 0.5][report_columns]

print(rel)
print(rel.to_latex(index=False))
