import pandas as pd
from common import summary

# Build and print a LaTeX table for the 'first_layer_adaptation' summary.
# Each eps metric is rendered as "<value>$\pm$<std>" with two decimals.
pd.set_option("display.precision", 2)
table = summary('first_layer_adaptation')

# Metrics that have a companion '<metric>_std' column and are reported as
# a single "<metric>_combined" string column.
EPS_METRICS = (
    'post_prune_eps',
    'post_prune_hidden_eps',
    'post_kd_eps',
    'post_kd_hidden_eps',
)

# Raw string: '\p' is not a valid escape sequence, so a plain literal
# triggers a SyntaxWarning on recent Python versions.
PLUS_MINUS = r'$\pm$'


def _two_decimals(value):
    """Format a number with exactly two digits after the decimal point."""
    return '%.2f' % value


for metric in EPS_METRICS:
    std = metric + '_std'
    # Non-numeric entries are coerced to NaN instead of raising.
    table[metric] = pd.to_numeric(table[metric], errors='coerce')
    table[std] = pd.to_numeric(table[std], errors='coerce')
    table[metric + '_combined'] = (
        table[metric].map(_two_decimals)
        + PLUS_MINUS
        + table[std].map(_two_decimals)
    )

# Replace internal prune-type identifiers with the labels used in the table.
# Identifiers missing from the mapping become None (dict.get's default).
table['PruneType'] = table['PruneType'].map({
    "DataDepCoreset": "Weight Rand",
    "DataIndCoreset": "Neuron",
    "DataDepDet": "Weight Det",
}.get)

OUTPUT_COLUMNS = [
    'PruneType',
    'Sparsity',
    'post_prune_acc',
    'post_prune_eps_combined',
    'post_prune_hidden_eps_combined',
    'post_kd_acc',
    'post_kd_eps_combined',
    'post_kd_hidden_eps_combined',
]
print(table[OUTPUT_COLUMNS].to_latex(index=False))
