import numpy as np
import argparse
import prettytable
import pandas as pd
import sys
import os

from advbench.lib import reporting, misc
from advbench import model_selection

# Columns that identify one evaluation record when the sweep DataFrame is
# melted; every remaining column is treated as a metric.
ID_COLUMNS = [
    'Split', 'Algorithm', 'trial_seed', 'seed', 'path', 'Epoch', 'Architecture'
]


def bold_str(string):
    """Return `string` wrapped in ANSI codes (bright red + bold, then reset)."""
    return f'\033[91m\033[1m{string}\033[0m\033[0m'


def bold_row(ls):
    """Return a new list with every entry of `ls` passed through `bold_str`."""
    return [bold_str(s) for s in ls]


def format_mean_std(vals):
    """Format `vals` as ``'<mean> +/- <std>'`` with four decimal places.

    Uses `np.mean` and `np.std` (population standard deviation, numpy's
    default).
    """
    return f'{np.mean(vals):.4f} +/- {np.std(vals):.4f}'


def parse_args(argv=None):
    """Parse the command-line options of the results collector.

    Args:
        argv: List of argument strings. ``None`` lets argparse fall back to
            ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with `input_dir`, `depth`, `backout_metrics`,
        `selection_methods`, `sweep_df_name` and `architecture`.
    """
    parser = argparse.ArgumentParser(description='Collect results')
    parser.add_argument(
        '--input_dir',
        type=str,
        required=True,
        help='Directory holding the results to collect')
    parser.add_argument(
        '--depth',
        type=int,
        default=1,
        help='Results directories search depth')
    parser.add_argument(
        '--backout-metrics',
        action='store_true',
        help='Backout non-selection metrics.')
    parser.add_argument(
        '--selection_methods',
        type=str,
        nargs='+',
        default=['LastStep', 'EarlyStop'])
    parser.add_argument(
        '--sweep_df_name',
        type=str,
        default='selection',
        help='Name of sweep DataFrame')
    parser.add_argument(
        '--architecture',
        type=str,
        default='ResNet18',
        help='Value written to the Architecture column of every record')
    return parser.parse_args(argv)


def resolve_selection_methods(names, registry):
    """Look up each selection-method name in `registry`.

    Args:
        names: Iterable of method names, e.g. ``['LastStep', 'EarlyStop']``.
        registry: Mapping from name to the selection-method object
            (the script passes ``vars(model_selection)``).

    Returns:
        List of the looked-up objects, in the order of `names`.

    Raises:
        ValueError: If one or more names are missing from `registry`.
    """
    names = list(names)
    unknown = [name for name in names if name not in registry]
    if unknown:
        raise ValueError(
            f"Unknown selection method(s): {', '.join(unknown)}")
    return [registry[name] for name in names]


def build_table(method, metric_name, metric_df, full_df, backout_metrics):
    """Build the report table for one selection method and one metric.

    Args:
        method: Selection method; called as ``method(algorithm_df)`` and
            expected to expose a ``NAME`` attribute. The returned selection
            object must provide ``test_df``, ``trial_values`` and
            ``backout_other_metrics`` (all defined in `model_selection`,
            not visible here).
        metric_name: Value of the 'Metric-Name' group being reported.
        metric_df: Rows of the melted DataFrame for `metric_name`.
        full_df: The complete melted DataFrame, handed to
            ``backout_other_metrics`` when `backout_metrics` is set.
        backout_metrics: When truthy, emit one row per backed-out metric
            instead of a single row per algorithm.

    Returns:
        Tuple ``(table, test_dfs)``: the filled `prettytable.PrettyTable` and
        the list of ``selection.test_df`` objects, one per algorithm.
    """
    table = prettytable.PrettyTable()
    if backout_metrics:
        table.field_names = [
            'Algorithm',
            'Selection Metric',
            'Metric Name',
            'Metric Value',
            'Selection Method'
        ]
    else:
        table.field_names = ['Algorithm', metric_name, 'Selection Method']

    test_dfs = []
    for algorithm, algorithm_df in metric_df.groupby('Algorithm'):
        selection = method(algorithm_df)
        test_dfs.append(selection.test_df)

        if not backout_metrics:
            table.add_row([
                algorithm,
                format_mean_std(selection.trial_values),
                method.NAME
            ])
            continue

        backed_out = selection.backout_other_metrics(full_df)
        for backout_metric, vals in backed_out.items():
            row = [
                algorithm,
                metric_name,
                backout_metric,
                format_mean_std(vals),
                method.NAME
            ]
            # Highlight the row of the metric that drove the selection.
            if backout_metric == metric_name:
                row = bold_row(row)
            table.add_row(row)

    return table, test_dfs


def main(argv=None):
    """Collect sweep results, print per-metric tables and pickle DataFrames.

    Files written into ``--input_dir``: ``results.txt`` (via `misc.Tee`),
    ``full_sweep_df.pd`` and, when ``--depth`` is 1, ``sweep_train_df.pd`` and
    ``selected_test_df.pd``.

    Args:
        argv: Optional argument list forwarded to `parse_args`.

    Raises:
        ValueError: If ``--depth 0`` is used on a directory without a ``done``
            flag file, or if an unknown selection method is requested.
    """
    np.set_printoptions(suppress=True)
    args = parse_args(argv)

    if args.depth == 0:
        done_flag = os.path.join(args.input_dir, 'done')
        if not os.path.exists(done_flag):
            raise ValueError("This run hasn't finished yet!")

    # Resolve the methods before results.txt is handed to misc.Tee in 'w'
    # mode, so a mistyped method name cannot clobber an earlier report
    # (presumably Tee opens the file immediately -- confirm in misc.Tee).
    selection_methods = resolve_selection_methods(
        args.selection_methods, vars(model_selection))

    # Left installed for the remainder of the process: everything printed
    # below is meant to end up in results.txt as well.
    sys.stdout = misc.Tee(os.path.join(args.input_dir, 'results.txt'), 'w')

    selection_df = reporting.load_sweep_dataframes(
        path=args.input_dir,
        depth=args.depth,
        df_name=args.sweep_df_name
    ).dropna()

    if args.depth == 1:
        train_df = reporting.load_sweep_dataframes(
            path=args.input_dir,
            depth=args.depth,
            df_name='train'
        )
        train_df.to_pickle(os.path.join(args.input_dir, 'sweep_train_df.pd'))

    # Pickled before the Architecture column is added, as in the original.
    selection_df.to_pickle(os.path.join(args.input_dir, 'full_sweep_df.pd'))

    selection_df['Architecture'] = args.architecture

    df = pd.melt(
        frame=selection_df,
        id_vars=ID_COLUMNS
    ).rename(columns={'variable': 'Metric-Name', 'value': 'Metric-Value'})

    all_selection_dfs = []
    for method in selection_methods:
        for metric_name, metric_df in df.groupby('Metric-Name'):
            table, test_dfs = build_table(
                method, metric_name, metric_df, df, args.backout_metrics)
            all_selection_dfs.extend(test_dfs)
            print(table)

    # save ranked test dataframes to disk
    if args.depth == 1:
        if all_selection_dfs:
            selected_df = pd.concat(all_selection_dfs, ignore_index=True)
            selected_df.to_pickle(os.path.join(
                args.input_dir,
                'selected_test_df.pd'
            ))
        else:
            # pd.concat raises on an empty list; report the situation
            # instead of crashing after the tables have been printed.
            print(
                'No selections were produced; '
                'selected_test_df.pd was not written.',
                file=sys.stderr)


if __name__ == '__main__':
    main()

    