import pandas as pd
import torch

from mind_the_pad.paths import experiments_folder, metrics_folder
import re
from model import MnistPadding
import pytorch_lightning as pl
from mind_the_pad.data.mnist import letters_mnist_test_dl

# Pattern for experiment-class folders (under experiments_folder) whose name
# starts with EMNIST_padding_type='same'.
_SAME_PADDING_EXPERIMENTS_GLOB = "EMNIST_padding_type='same'*"
experiments_uneven_padding = experiments_folder.dirs(_SAME_PADDING_EXPERIMENTS_GLOB)

# Captures <mode> from a "padding_mode='<mode>'" fragment of a folder name.
_PADDING_MODE_PATTERN = r"padding\_mode='(\w*)'"
regex_padding_mode = re.compile(_PADDING_MODE_PATTERN)


def main():
    """Evaluate the padding_type='same' EMNIST checkpoints and export accuracy tables.

    Per-checkpoint accuracies are cached in ``uneven_padding_same_results.csv``
    under ``metrics_folder``. When that file exists the (slow) evaluation is
    skipped and the cached rows are reloaded. Each checkpoint is tested twice,
    with ``random_pad_input=0`` and ``1``; these are recorded as ``input_size``
    28 and 29 respectively.

    Two summaries of the accuracy (mean and std, with margins) are written as
    LaTeX / markdown:

    * "same input size": only rows with input_size 28 for runs whose folder
      name has input_size 28 (``training_uneven``), and input_size 29 for the
      other runs;
    * all rows, split by ``input_size`` (plus separate LaTeX tables for the
      28 and 29 columns).

    Raises:
        RuntimeError: if no checkpoint was found to evaluate, so no empty
            cache file is written.
        ValueError: if an experiment folder name lacks padding_mode/input_size.
    """
    uneven_padding_same_results_path = metrics_folder / 'uneven_padding_same_results.csv'
    if uneven_padding_same_results_path.exists():
        dataset = pd.read_csv(uneven_padding_same_results_path)
    else:
        rows = _evaluate_checkpoints()
        if not rows:
            # Writing an empty CSV would make every later run fail in read_csv.
            raise RuntimeError(
                "No 'last.ckpt' checkpoints found in the padding_type='same' "
                "experiment folders; nothing to evaluate."
            )
        dataset = pd.DataFrame(rows)
        dataset.to_csv(uneven_padding_same_results_path, index=False)
    print(dataset)

    float_format = "%.3f"

    # Keep each run only at the input size matching its training regime.
    dataset_same_input_size = dataset.query(
        '(input_size == 28 and training_uneven) or (input_size == 29 and not training_uneven)'
    ).drop(columns='input_size')
    results_aggr_same = _aggregate_accuracy(dataset_same_input_size)
    mean_std_table_same = _mean_std_table(results_aggr_same, float_format)
    print(mean_std_table_same)
    mean_std_table_same.to_latex(metrics_folder / 'mean_std_accuracy_uneven_same_results.tex')
    mean_std_table_same.to_markdown(metrics_folder / 'mean_std_accuracy_uneven_same_results.md')

    results_aggr = _aggregate_accuracy(dataset)
    mean_std_table = _mean_std_table(results_aggr, float_format)
    # First column level after dropping the aggfunc level is input_size.
    mean_std_table_28 = mean_std_table.loc[:, 28]
    mean_std_table_29 = mean_std_table.loc[:, 29]
    mean_std_table.to_latex(metrics_folder / 'mean_std_accuracy_uneven_results.tex')
    mean_std_table.to_markdown(metrics_folder / 'mean_std_accuracy_uneven_results.md')
    mean_std_table_28.to_latex(metrics_folder / 'mean_std_accuracy_uneven_results_28.tex')
    mean_std_table_29.to_latex(metrics_folder / 'mean_std_accuracy_uneven_results_29.tex')

    print(results_aggr_same)
    results_aggr_same.to_markdown(metrics_folder / 'acc_uneven_padding_same_input.md')

    results_aggr.to_markdown(metrics_folder / 'acc_uneven_padding.md')


def _evaluate_checkpoints():
    """Test every 'last.ckpt' in the same-padding experiments, with random_pad_input 0 and 1.

    Returns:
        list of dicts with keys ``input_size``, ``accuracy``, ``padding_mode``
        and ``training_uneven`` (one per checkpoint and random_pad value).

    Raises:
        ValueError: if a folder name has no padding_mode or input_size fragment.
    """
    test_dl = letters_mnist_test_dl(32)
    rows = []
    for exprmnt_class_folder in experiments_uneven_padding:
        folder_name = exprmnt_class_folder.basename()
        padding_match = regex_padding_mode.search(folder_name)
        size_match = regex_input_size.search(folder_name)
        if padding_match is None or size_match is None:
            raise ValueError(f'Cannot parse padding_mode/input_size from folder name {folder_name!r}')
        padding_mode = padding_match.group(1)
        input_size = int(size_match.group(1))
        is_uneven = input_size == 28
        for exprm_folder in exprmnt_class_folder.dirs():
            print(str(exprm_folder))
            ckpt_path = exprm_folder / 'last.ckpt'
            if not ckpt_path.exists():
                continue
            # Load once per run: load_state_dict copies the tensors into the
            # model parameters, so the same dict can seed both evaluations.
            # NOTE(review): newer torch versions default torch.load to
            # weights_only=True, which may reject Lightning checkpoints — confirm.
            state_dict = torch.load(ckpt_path, 'cpu')['state_dict']
            for random_pad in [0, 1]:
                model = MnistPadding(padding_type='same', padding_mode=padding_mode, random_pad_input=random_pad)
                trainer = pl.Trainer(logger=False, accelerator='cpu')
                model.load_state_dict(state_dict)
                trainer.test(model, test_dl)
                row = dict(input_size=28 + random_pad, accuracy=model.accuracy.compute().item(),
                           padding_mode=padding_mode, training_uneven=is_uneven)
                rows.append(row)
                print('accuracy =', row['accuracy'])
    return rows


def _aggregate_accuracy(results):
    """Pivot ``accuracy``: rows by ``training_uneven``, columns by every other column.

    Aggregates with mean and std and includes ``margins`` ('All') totals.
    """
    column_keys = results.columns.drop(['accuracy', 'training_uneven']).tolist()
    return results.pivot_table('accuracy', ['training_uneven'], column_keys,
                               aggfunc=['mean', 'std'], margins=True)


def _mean_std_table(results_aggr, float_format):
    """Combine the 'mean' and 'std' halves of a pivot into 'mean +- std' strings."""
    mean_results = results_aggr.loc[:, 'mean'].applymap(lambda x: float_format % x)
    std_results = results_aggr.loc[:, 'std'].applymap(lambda x: float_format % x)
    return mean_results.combine(std_results, lambda m, s: m.str.cat(' +- ' + s))

# Matches an "input_size=(<int>, <int>)" fragment of an experiment folder name
# (optional single whitespace after the comma); group(1) is the first integer.
# Must be defined before the entry point below: main() reads it, and a
# definition placed after the `main()` call is not executed yet when the
# script runs, which raised NameError.
regex_input_size = re.compile(r"input\_size=\((\d+)\,\s{0,1}(\d+)\)")


if __name__ == '__main__':
    main()