import argparse
import os
from datetime import datetime

import pandas as pd
import yaml

# Render floats with 4 digits of precision whenever a DataFrame is printed;
# this only affects display, not the values written by `to_csv`.
pd.options.display.precision = 4


def report_trial_result(trial_dir, reported_param=None):
    """Collect the final test metrics of every seed run inside one trial directory.

    Each entry of ``trial_dir`` that is a directory and whose name contains
    ``'seed'`` must hold a ``metrics.csv`` with ``test/mae`` and ``test/mse``
    columns; the value in the last row of each column is taken as that seed's
    result.

    Args:
        trial_dir: Directory containing the per-seed run directories.
        reported_param: Optional iterable of hyper-parameter names to read from
            ``<trial_dir>/seed_0/hparams.yaml``. When ``None`` the file is not read.

    Returns:
        A ``(result_df, special_params)`` tuple.
        ``result_df`` has columns ``mae``, ``mse`` and ``seed``. It holds one row
        per seed run, in sorted directory-name order, with ``seed`` set to the run
        directory name. These are followed by a ``'mean'`` row and a ``'std'`` row
        (pandas sample std, ddof=1), both computed over the per-seed rows only.
        ``special_params`` lists the requested hyper-parameter values in the
        order of ``reported_param`` (empty list when it is ``None``).

    Raises:
        FileNotFoundError: if ``seed_0/hparams.yaml`` (when requested) or a
            seed's ``metrics.csv`` is missing.
        KeyError: if a requested hyper-parameter or metric column is absent.
    """
    special_params = []
    if reported_param is not None:
        trial_param_path = os.path.join(trial_dir, 'seed_0', 'hparams.yaml')
        with open(trial_param_path, 'r') as f:
            # SECURITY: unsafe_load can construct arbitrary Python objects from
            # YAML tags; only run this on hparams files you trust.
            hparams = yaml.unsafe_load(f.read())
        special_params = [hparams[hparam_name] for hparam_name in reported_param]

    results = []
    # Sorted so per-seed rows come out in a deterministic order
    # (os.listdir order is arbitrary).
    for version_name in sorted(os.listdir(trial_dir)):
        version_dir = os.path.join(trial_dir, version_name)
        # Only seed run directories; skip stray files whose name contains 'seed'.
        if 'seed' not in version_name or not os.path.isdir(version_dir):
            continue
        metrics = pd.read_csv(os.path.join(version_dir, 'metrics.csv'))
        results.append([
            metrics['test/mae'].iloc[-1],
            metrics['test/mse'].iloc[-1],
            version_name,
        ])
    result_df = pd.DataFrame(results, columns=['mae', 'mse', 'seed'])

    # Compute every statistic BEFORE appending summary rows; otherwise the
    # freshly added 'mean' row would be counted as an extra sample in the std.
    summary = {
        'mean': (result_df['mae'].mean(), result_df['mse'].mean()),
        'std': (result_df['mae'].std(), result_df['mse'].std()),
    }
    for label, (mae, mse) in summary.items():
        result_df.loc[label, 'mae'] = mae
        result_df.loc[label, 'mse'] = mse
        result_df.loc[label, 'seed'] = label

    return result_df, special_params


def report_exp_abstract(exp_dir, reported_param, metrics=None, with_std=False):
    """Summarise every trial (config) directory of an experiment into one CSV.

    For each sub-directory ``<exp_dir>/<config_hash>``, the per-seed results are
    aggregated with ``report_trial_result``. One row is then produced containing
    the config hash, the requested hyper-parameter values and each metric's mean
    across seeds. The table is written to
    ``<exp_dir>/abstract_<YYYY_mm_dd_HH_MM_SS>.csv`` (index included).

    Args:
        exp_dir: Experiment directory whose sub-directories are trial directories.
        reported_param: Hyper-parameter names to include as columns (read from each
            trial's ``seed_0/hparams.yaml``). ``None`` means no extra columns.
        metrics: Metric columns to report; defaults to ``["mae", "mse"]``.
        with_std: If True, each metric cell is the string ``'<mean:.2f>+<std:.3f>'``
            instead of the float mean.

    Returns:
        The abstract DataFrame that was written to disk.
    """
    metrics = ["mae", "mse"] if metrics is None else list(metrics)
    reported_param = [] if reported_param is None else list(reported_param)

    abstract = []
    abstract_header = ['config_hash'] + reported_param + metrics

    # Sorted so the output rows have a deterministic order.
    for config_hash in sorted(os.listdir(exp_dir)):
        config_cp_dir = os.path.join(exp_dir, config_hash)
        # Skip plain files, e.g. abstract CSVs written into exp_dir by earlier runs.
        if not os.path.isdir(config_cp_dir):
            continue
        # Pass None for an empty list so hparams.yaml is only read when needed.
        result_df, special_params = report_trial_result(config_cp_dir, reported_param or None)
        metric_cells = []
        for metric in metrics:
            mean = result_df.loc['mean', metric]
            if with_std:
                metric_cells.append('{:.2f}+{:.3f}'.format(mean, result_df.loc['std', metric]))
            else:
                metric_cells.append(mean)
        abstract.append([config_hash] + special_params + metric_cells)

    abstract_df = pd.DataFrame(abstract, columns=abstract_header)
    time_stamp = datetime.now().strftime('%Y_%m_%d_%H_%M_%S')
    save_path = os.path.join(exp_dir, 'abstract_{}.csv'.format(time_stamp))
    abstract_df.to_csv(save_path)
    return abstract_df


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description="Summarise per-seed test metrics of every config in an experiment directory.")
    parser.add_argument("--exp_dir", default=None, type=str, help="config save dir")
    parser.add_argument("--reported_param", nargs='+', default=["hist_len", "pred_len", "lr"],
                        help="hyper-parameter names (from seed_0/hparams.yaml) to report as columns")
    parser.add_argument("--metrics", nargs='+', default=None,
                        help="metric columns to report (default: mae mse)")
    parser.add_argument("--with_std", action='store_true',
                        help="report each metric as 'mean+std' instead of the mean only")
    args = parser.parse_args()

    if args.exp_dir is None:
        # Nothing to summarise: show usage rather than exiting silently.
        parser.print_help()
    else:
        report_exp_abstract(
            exp_dir=args.exp_dir,
            reported_param=args.reported_param,
            metrics=args.metrics,
            with_std=args.with_std,
        )
