import argparse
import pandas as pd
import os
import matplotlib.pyplot as plt
import seaborn as sns
sns.set('talk')

from my_code.utils.Utils import format_mean_std

# Model architecture trained on each dataset. The key order also fixes the
# order in which datasets are loaded and plotted.
DATASET_MODEL_DICT = {
    'cifar10': 'Conv3Net',
    'cifar100': 'Conv3Net_100',
    'fmnist': 'MLP',
}

# Metric plotted/tabulated for each dataset; compared against the 'Metric'
# column of the result CSVs.
DATASET_METRIC_DICT = {
    'cifar10': 'Accuracy',
    'cifar100': 'Top 5 Accuracy',
    'fmnist': 'Accuracy',
}

# Short corruption codes (as they appear in result file names) -> display
# names. The dict order determines the x-axis order of the box plots.
SHORT_LONG_NAME_DICT = {
    'no_c': 'No Corruption',
    'c_cs': 'Chunk Shuffle',
    'c_rl': 'Random Label',
    'c_lbs': 'Batch Label Shuffle',
    'c_lbf': 'Batch Label Flip',
    'c_ns': 'Added Noise',
    'c_no': 'Replace With Noise',
}

# Run identifier (the 'Run' column of the result CSVs) -> legend label.
# Runs missing from this mapping keep their raw identifier as the label.
NAME_MAPPING = {
    'Conv3Net-no_c-drstd-ns_10-ncs_4-ssize_128-ds_0.0-ne_25': 'Traditional Model',
    'Conv3Net-no_c-drstd-ns_10-ncs_4-ssize_128-ds_1.0-stns_0.8-lap_n_25-ne_25': 'LAP Model',

    'Conv3Net-c_cs-drstd-ns_10-ncs_4-ssize_128-ds_0.0-ne_25': 'Traditional Model',
    'Conv3Net-c_cs-drstd-ns_10-ncs_4-ssize_128-ds_1.0-stns_0.8-lap_n_25-ne_25': 'LAP Model',
    'Conv3Net-c_cs_srb-drstd-ns_10-ncs_4-ssize_128-ds_0.0-ne_25': 'Corruption Oracle',

    'Conv3Net-c_rl-drstd-ns_10-ncs_4-ssize_128-ds_0.0-ne_25': 'Traditional Model',
    'Conv3Net-c_rl-drstd-ns_10-ncs_4-ssize_128-ds_1.0-stns_0.8-lap_n_25-ne_25': 'LAP Model',
    'Conv3Net-c_rl_srb-drstd-ns_10-ncs_4-ssize_128-ds_0.0-ne_25': 'Corruption Oracle',

    'Conv3Net-c_lbf-drstd-ns_10-ncs_4-ssize_128-ds_0.0-ne_25': 'Traditional Model',
    'Conv3Net-c_lbf-drstd-ns_10-ncs_4-ssize_128-ds_1.0-stns_0.8-lap_n_25-ne_25': 'LAP Model',
    'Conv3Net-c_lbf_srb-drstd-ns_10-ncs_4-ssize_128-ds_0.0-ne_25': 'Corruption Oracle',

    'Conv3Net-c_ns-drstd-ns_10-ncs_4-ssize_128-ds_0.0-ne_25': 'Traditional Model',
    'Conv3Net-c_ns-drstd-ns_10-ncs_4-ssize_128-ds_1.0-stns_0.8-lap_n_25-ne_25': 'LAP Model',
    'Conv3Net-c_ns_srb-drstd-ns_10-ncs_4-ssize_128-ds_0.0-ne_25': 'Corruption Oracle',

    'Conv3Net-c_lbs-drstd-ns_10-ncs_4-ssize_128-ds_0.0-ne_25': 'Traditional Model',
    'Conv3Net-c_lbs-drstd-ns_10-ncs_4-ssize_128-ds_1.0-stns_0.8-lap_n_25-ne_25': 'LAP Model',
    'Conv3Net-c_lbs_srb-drstd-ns_10-ncs_4-ssize_128-ds_0.0-ne_25': 'Corruption Oracle',

    'Conv3Net-c_no-drstd-ns_10-ncs_4-ssize_128-ds_0.0-ne_25': 'Traditional Model',
    'Conv3Net-c_no-drstd-ns_10-ncs_4-ssize_128-ds_1.0-stns_0.8-lap_n_25-ne_25': 'LAP Model',
    'Conv3Net-c_no_srb-drstd-ns_10-ncs_4-ssize_128-ds_0.0-ne_25': 'Corruption Oracle',

    'Conv3Net_100-no_c-drstd-ns_10-ncs_2-ssize_128-ds_0.0-ne_40': 'Traditional Model',
    'Conv3Net_100-no_c-drstd-ns_10-ncs_2-ssize_128-ds_1.0-stns_0.8-lap_n_25-ho_250-ne_40': 'LAP Model',

    'Conv3Net_100-c_cs-drstd-ns_10-ncs_2-ssize_128-ds_0.0-ne_40': 'Traditional Model',
    'Conv3Net_100-c_cs-drstd-ns_10-ncs_2-ssize_128-ds_1.0-stns_0.8-lap_n_25-ho_250-ne_40': 'LAP Model',
    'Conv3Net_100-c_cs_srb-drstd-ns_10-ncs_2-ssize_128-ds_0.0-ne_40': 'Corruption Oracle',

    'Conv3Net_100-c_rl-drstd-ns_10-ncs_2-ssize_128-ds_0.0-ne_40': 'Traditional Model',
    'Conv3Net_100-c_rl-drstd-ns_10-ncs_2-ssize_128-ds_1.0-stns_0.8-lap_n_25-ho_250-ne_40': 'LAP Model',
    'Conv3Net_100-c_rl_srb-drstd-ns_10-ncs_2-ssize_128-ds_0.0-ne_40': 'Corruption Oracle',

    'Conv3Net_100-c_lbf-drstd-ns_10-ncs_2-ssize_128-ds_0.0-ne_40': 'Traditional Model',
    'Conv3Net_100-c_lbf-drstd-ns_10-ncs_2-ssize_128-ds_1.0-stns_0.8-lap_n_25-ho_250-ne_40': 'LAP Model',
    'Conv3Net_100-c_lbf_srb-drstd-ns_10-ncs_2-ssize_128-ds_0.0-ne_40': 'Corruption Oracle',

    'Conv3Net_100-c_ns-drstd-ns_10-ncs_2-ssize_128-ds_0.0-ne_40': 'Traditional Model',
    'Conv3Net_100-c_ns-drstd-ns_10-ncs_2-ssize_128-ds_1.0-stns_0.8-lap_n_25-ho_250-ne_40': 'LAP Model',
    'Conv3Net_100-c_ns_srb-drstd-ns_10-ncs_2-ssize_128-ds_0.0-ne_40': 'Corruption Oracle',

    'Conv3Net_100-c_lbs-drstd-ns_10-ncs_2-ssize_128-ds_0.0-ne_40': 'Traditional Model',
    'Conv3Net_100-c_lbs-drstd-ns_10-ncs_2-ssize_128-ds_1.0-stns_0.8-lap_n_25-ho_250-ne_40': 'LAP Model',
    'Conv3Net_100-c_lbs_srb-drstd-ns_10-ncs_2-ssize_128-ds_0.0-ne_40': 'Corruption Oracle',

    'Conv3Net_100-c_no-drstd-ns_10-ncs_2-ssize_128-ds_0.0-ne_40': 'Traditional Model',
    'Conv3Net_100-c_no-drstd-ns_10-ncs_2-ssize_128-ds_1.0-stns_0.8-lap_n_25-ho_250-ne_40': 'LAP Model',
    'Conv3Net_100-c_no_srb-drstd-ns_10-ncs_2-ssize_128-ds_0.0-ne_40': 'Corruption Oracle',

    'MLP-no_c-drstd-ns_10-ncs_6-ssize_200-ds_0.0-ne_40': 'Traditional Model',
    'MLP-no_c-drstd-ns_10-ncs_6-ssize_200-ds_1.0-stns_0.8-lap_n_50-ne_40': 'LAP Model',

    'MLP-c_cs-drstd-ns_10-ncs_6-ssize_200-ds_0.0-ne_40': 'Traditional Model',
    'MLP-c_cs-drstd-ns_10-ncs_6-ssize_200-ds_1.0-stns_0.8-lap_n_50-ne_40': 'LAP Model',
    'MLP-c_cs_srb-drstd-ns_10-ncs_6-ssize_200-ds_0.0-ne_40': 'Corruption Oracle',

    'MLP-c_rl-drstd-ns_10-ncs_6-ssize_200-ds_0.0-ne_40': 'Traditional Model',
    'MLP-c_rl-drstd-ns_10-ncs_6-ssize_200-ds_1.0-stns_0.8-lap_n_50-ne_40': 'LAP Model',
    'MLP-c_rl_srb-drstd-ns_10-ncs_6-ssize_200-ds_0.0-ne_40': 'Corruption Oracle',

    'MLP-c_lbf-drstd-ns_10-ncs_6-ssize_200-ds_0.0-ne_40': 'Traditional Model',
    'MLP-c_lbf-drstd-ns_10-ncs_6-ssize_200-ds_1.0-stns_0.8-lap_n_50-ne_40': 'LAP Model',
    'MLP-c_lbf_srb-drstd-ns_10-ncs_6-ssize_200-ds_0.0-ne_40': 'Corruption Oracle',

    'MLP-c_ns-drstd-ns_10-ncs_6-ssize_200-ds_0.0-ne_40': 'Traditional Model',
    'MLP-c_ns-drstd-ns_10-ncs_6-ssize_200-ds_1.0-stns_0.8-lap_n_50-ne_40': 'LAP Model',
    'MLP-c_ns_srb-drstd-ns_10-ncs_6-ssize_200-ds_0.0-ne_40': 'Corruption Oracle',

    'MLP-c_lbs-drstd-ns_10-ncs_6-ssize_200-ds_0.0-ne_40': 'Traditional Model',
    'MLP-c_lbs-drstd-ns_10-ncs_6-ssize_200-ds_1.0-stns_0.8-lap_n_50-ne_40': 'LAP Model',
    'MLP-c_lbs_srb-drstd-ns_10-ncs_6-ssize_200-ds_0.0-ne_40': 'Corruption Oracle',

    'MLP-c_no-drstd-ns_10-ncs_6-ssize_200-ds_0.0-ne_40': 'Traditional Model',
    'MLP-c_no-drstd-ns_10-ncs_6-ssize_200-ds_1.0-stns_0.8-lap_n_50-ne_40': 'LAP Model',
    'MLP-c_no_srb-drstd-ns_10-ncs_6-ssize_200-ds_0.0-ne_40': 'Corruption Oracle',
}


def parse_args(argv=None):
    """Parse the command-line options of this script.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        An ``argparse.Namespace`` with ``test_result_dir`` and ``graph_dir``.
    """
    parser = argparse.ArgumentParser(
        description='Plot and tabulate the test results of models trained '
                    'with different data corruptions')
    parser.add_argument('--test-result-dir', help='The directory containing the test results',
                        type=str, default='./outputs/test_results/')
    parser.add_argument('--graph-dir', help='The directory to save the graphs',
                        type=str, default='./outputs/graphs/')
    return parser.parse_args(argv)


def load_results(test_result_dir):
    """Load the test-result CSVs found in ``test_result_dir`` into one DataFrame.

    A file is attributed to a dataset when its name contains ``'<dataset>-'``
    and to the first corruption type (in SHORT_LONG_NAME_DICT order) whose
    short code occurs in its name; the oracle runs (``'<code>_srb'``) are
    covered by the same substring test. Each file is loaded at most once.

    Args:
        test_result_dir: Directory containing the result CSV files. Each CSV
            must have a ``'Run'`` column.

    Returns:
        The concatenated CSV rows plus the columns ``'Dataset'``,
        ``'Corruption Type'`` (display name) and ``'Model Name'`` (display
        name from NAME_MAPPING, or the raw run id when unmapped).

    Raises:
        FileNotFoundError: If no file in the directory matches any
            dataset/corruption pair.
    """
    # Sorted so that row order (which drives seaborn's hue order and colours)
    # does not depend on the filesystem's arbitrary listdir order.
    result_files = sorted(os.listdir(test_result_dir))

    frames = []
    claimed = set()  # files already attributed to a corruption type
    for dataset in DATASET_MODEL_DICT:
        dataset_files = [name for name in result_files if dataset + '-' in name]
        for corruption_type in SHORT_LONG_NAME_DICT:
            matches = [name for name in dataset_files
                       if corruption_type in name and name not in claimed]
            claimed.update(matches)
            for file_name in matches:
                frame = pd.read_csv(os.path.join(test_result_dir, file_name))
                frame['Dataset'] = dataset
                frame['Corruption Type'] = corruption_type
                frames.append(frame)

    if not frames:
        raise FileNotFoundError(
            'No test result files found in {}'.format(test_result_dir))

    # One concat at the end instead of re-copying the accumulator per file.
    results = pd.concat(frames, ignore_index=True)
    results['Model Name'] = results['Run'].map(NAME_MAPPING).fillna(results['Run'])
    results['Corruption Type'] = results['Corruption Type'].map(SHORT_LONG_NAME_DICT)
    return results


def plot_and_tabulate(results, graph_dir, table_dir):
    """Save a box plot and a summary table of each dataset's metric.

    Datasets without rows for their metric are skipped.

    Args:
        results: DataFrame as returned by ``load_results`` (needs the
            ``'Dataset'``, ``'Metric'``, ``'Value'``, ``'Corruption Type'``
            and ``'Model Name'`` columns).
        graph_dir: Directory receiving ``'<Metric> <dataset> Boxplot.pdf'``;
            created if missing.
        table_dir: Directory receiving ``'<Metric> <dataset> Results.csv'``,
            one row per (corruption type, model) with ``format_mean_std``
            applied to the values.
    """
    os.makedirs(graph_dir, exist_ok=True)
    for dataset in DATASET_MODEL_DICT:
        metric = DATASET_METRIC_DICT[dataset]
        data_plot = results[(results['Dataset'] == dataset)
                            & (results['Metric'] == metric)].copy()
        if data_plot.empty:
            continue

        fig, ax = plt.subplots(1, 1, figsize=(18, 5))
        sns.boxplot(data=data_plot,
                    x='Corruption Type',
                    y='Value',
                    hue='Model Name',
                    ax=ax,
                    )
        ax.set_title('{} on {}'.format(metric, dataset))
        fig.savefig(os.path.join(graph_dir, '{} {} Boxplot.pdf'.format(metric, dataset)),
                    bbox_inches='tight')
        # Release the figure; otherwise each iteration leaves one open.
        plt.close(fig)

        table = (data_plot.groupby(by=['Corruption Type', 'Model Name'])['Value']
                 .apply(format_mean_std)
                 .reset_index())
        table.to_csv(os.path.join(table_dir, '{} {} Results.csv'.format(metric, dataset)),
                     index=False)


def main(argv=None):
    """Load the test results, then write the box plots and summary tables."""
    args = parse_args(argv)
    results = load_results(args.test_result_dir)
    # The tables are written next to the input CSVs. Their names
    # ('<Metric> <dataset> Results.csv') contain no '<dataset>-', so
    # load_results does not read them back on a later run.
    plot_and_tabulate(results, args.graph_dir, args.test_result_dir)


if __name__ == "__main__":
    main()