import os.path as osp
import sys

# Project root = everything in this file's absolute directory that precedes the
# first occurrence of 'src'. Appending '<root>src' to sys.path is presumably what
# lets the `utils` and `tune.<model>` imports below resolve when this file is run
# as a script -- verify against the repository layout.
proj_path = osp.abspath(osp.dirname(__file__)).split('src')[0]
sys.path.append(proj_path + 'src')

import multiprocessing
from utils.util_funcs import *
from utils import *
import pandas as pd
import time
from copy import deepcopy
import os
import ast
from itertools import product
from pprint import pformat
import traceback
import argparse
import importlib


class Tuner():
    """Grid-search driver for one experiment.

    Major functions:
      * Maintains a dataset specific tune dict
      * Converts the tune dict to a tune dataframe (parameter combinations)
      * Pretty printing of the search setup
      * Built-in grid search function
      * Result summarization
      * Try-catch handling so one failing trial does not stop the search
      * Failure report written to a log file

    Every attribute of ``exp_args`` is copied onto the instance. Attributes
    whose name does not start with an underscore are treated as experiment
    configs (see ``cf`` / ``trial_cf``).
    """

    def __init__(self, exp_args, search_dict, default_dict=None):
        """
        Args:
            exp_args: namespace-like object; its ``__dict__`` is copied onto
                the tuner. Must provide ``model_config`` (callable taking
                ``exp_args`` and returning an object with a ``model`` attribute).
            search_dict: parameter name -> candidate value(s). An optional
                ``'data_spec_configs'`` entry maps parameter name ->
                {dataset: candidates}; only the entry for ``self.dataset`` is
                kept. Plain entries of ``search_dict`` override it.
            default_dict: optional base candidates, deep-copied before use.
        """
        self.birth_time = get_cur_time(t_format='%m_%d-%H_%M_%S')
        self.__dict__.update(exp_args.__dict__)

        self.model = exp_args.model_config(exp_args).model
        self._d = deepcopy(default_dict) if default_dict is not None else {}
        if 'data_spec_configs' in search_dict:
            self.update_data_specific_cand_dict(search_dict['data_spec_configs'])
        self._d.update(search_dict)
        # 'data_spec_configs' was copied in by the update above; it is not a parameter.
        self._d.pop('data_spec_configs', None)

        # Configs
        self._searched_conf_list = list(self._d.keys())
        module = importlib.import_module(f'tune.{self.model}')
        self.important_paras = module.important_paras

    def update_data_specific_cand_dict(self, cand_dict):
        """Store, for every parameter in ``cand_dict``, the candidates of ``self.dataset``."""
        for k, v in cand_dict.items():
            self._d.update({k: v[self.dataset]})

    # * ============================= Properties =============================

    def __str__(self):
        return f'\nExperimental config: {pformat(self.cf)}\n' \
               f'\nGrid searched parameters:{pformat(self._d)}' \
               f'\nTune_df:\n{self.tune_df}\n'

    @property
    def cf(self):
        """All configs (tune specific + trial configs): every public instance attribute."""
        return {k: v for k, v in self.__dict__.items() if k[0] != '_'}

    @property
    def trial_cf(self):
        """Configs handed to each trial: ``cf`` minus the tuner-level entries."""
        tune_global_cf = ['run_times', 'start_ind', 'reverse_iter',
                          'model', 'model_config', 'train_func', 'birth_time']
        return {k: v for k, v in self.cf.items() if k not in tune_global_cf}

    @property
    def tune_df(self):
        """Tune dataframe: each row stands for a trial (hyper-parameter combination).

        Note: scalar candidates in ``self._d`` are wrapped into one-element
        lists in place, so ``self._d`` is normalised as a side effect.
        """
        for para in self._d:
            if not isinstance(self._d[para], list):
                self._d[para] = [self._d[para]]
        return pd.DataFrame.from_records(dict_product(self._d))

    def _run_seeds(self, para_dict, inner_start_time):
        """Run ``train_func`` once per seed in ``range(self.run_times)``.

        ``para_dict['seed']`` is updated in place before each run, so after a
        failure it still holds the seed that was being trained.

        Returns:
            The value returned by the last ``train_func`` call, or ``None``
            when ``run_times < 1`` and nothing was run.
        """
        cf = None
        for seed in range(self.run_times):
            para_dict['seed'] = seed
            if not self.log_on: block_log()
            cf = self.train_func(Dict2Config(para_dict))
            if not self.log_on: enable_logs()
            iter_time_estimate(f'\tSeed {seed}/{self.run_times}', '',
                               inner_start_time, seed + 1, self.run_times)
        return cf

    @time_logger
    def grid_search(self, debug_mode=False):
        """Run every remaining trial of the grid.

        Args:
            debug_mode: when true, an exception raised by a trial propagates
                to the caller. Otherwise it is appended to a log file under
                ``log/`` and the search continues with the next trial.
        """
        # ! Step 1: Remove finished trials.
        full_df = self.tune_df
        if self.ignore_prev:
            print('Ignoring previous results and rerun experiments')
            tune_df = full_df
        else:
            (finished_ids, _), __ = self.check_running_status()
            print(f'Found {len(finished_ids)}/{len(full_df)} previous finished trials.')
            tune_df = full_df.drop(finished_ids)

        # ! Step 2: Subset trials left to run using start and end points.
        print(self)

        # Parse start and end point: a start_point below 1 is a fraction of the
        # remaining trials, otherwise an absolute index; a non-positive
        # end_point means "until the last trial".
        end_ind = int(self.end_point) if self.end_point > 0 else len(tune_df)
        if self.start_point < 1:
            start_ind = int(len(tune_df) * self.start_point)
        else:
            start_ind = int(self.start_point)
        end_ind = min(len(tune_df), end_ind)
        assert start_ind >= 0 and start_ind <= len(tune_df)

        tune_dict = tune_df.to_dict('records')[start_ind:end_ind]

        total_trials = len(tune_dict)
        finished_trials = 0
        outer_start_time = time.time()

        # ! Step 3: Grid search
        failed_trials, skipped_trials = 0, 0

        for i in range(len(tune_dict)):
            ind = len(tune_dict) - i - 1 if self.reverse_iter else i
            para_dict = deepcopy(self.trial_cf)
            para_dict.update(tune_dict[ind])
            inner_start_time = time.time()
            print(f'\n{i}/{len(tune_dict)} <{self.exp_name}> Start tuning: {para_dict}, {get_cur_time()}')
            res_file = self.model_config(Dict2Config(para_dict)).res_file

            # ! Step 3.1: Check whether previous results exists
            if skip_results(res_file, self.ignore_prev):
                print('Found previous results, skipped running current trial.')
                total_trials -= 1
                skipped_trials += 1
                continue

            # ! Step 3.2: Train all seeds of this trial
            try:
                cf = self._run_seeds(para_dict, inner_start_time)
                finished_trials += 1
                iter_time_estimate('Trial finished, ', '',
                                   outer_start_time, finished_trials, total_trials)
            except Exception as e:
                if debug_mode:
                    # Debug mode: stop at the first error with the original traceback.
                    raise
                seed = para_dict.get('seed')
                log_file = f'log/{self.model}-{self.dataset}-{self.exp_name}-{self.birth_time}.log'
                mkdir_list([log_file])
                error_msg = ''.join(traceback.format_exception(None, e, e.__traceback__))
                with open(log_file, 'a+') as f:
                    f.write(
                        f'\nTrain failed at {get_cur_time()} while running {para_dict} at seed {seed},'
                        f' error message:{error_msg}\n'
                        f'Tunning command: tu -s{self.start_point} -e{self.end_point} -d{self.dataset} -x{self.exp_name}')
                    f.write(f'{"-" * 100}')
                # The failing run may have left the logs blocked.
                if not self.log_on: enable_logs()
                print(f'Train failed, error message: {error_msg}')
                failed_trials += 1
                continue

            # ! Step 3.3: Aggregate the seeds (nothing to aggregate if no seed was run)
            if self.run_times > 0:
                calc_mean_std(cf.res_file)
        print(f'\n\n{"=" * 24 + " Grid Search Finished " + "=" * 24}\n'
              f'Successfully run {finished_trials} trials, skipped {skipped_trials} previous trials, '
              f'failed {failed_trials} trials.')
        # log_file is always bound here: it is assigned whenever failed_trials is incremented.
        if failed_trials > 0: print(f'Check {log_file} for bug reports.\n{"=" * 70}\n')

    # * ============================= Results Processing =============================

    def summarize(self):
        """Collect the finished result files of this experiment into one Excel sheet.

        Summarization is best-effort: a failure is reported on stdout and does
        not raise.
        """
        exp_name, model, dataset = self.exp_name, self.model, self.dataset
        _, res_f_list = self.check_running_status()
        out_prefix = f'{SUM_PATH}{model}/{dataset}/{model}_{dataset}<{exp_name}>'

        try:
            res_file = res_to_excel(res_f_list, out_prefix, set(self._searched_conf_list + self.important_paras))
        except Exception as e:
            # Keep the best-effort contract, but do not hide the reason and do
            # not swallow KeyboardInterrupt / SystemExit as a bare except would.
            print(f'Cannot summarize {self.exp_name}: {type(e).__name__}: {e}')
            return
        print(f'Summary of {self.exp_name} finished. Results saved to {res_file}')

    def check_running_status(self):
        """Find the trials whose result file already holds a summary.

        A trial counts as finished when its result file exists, contains
        ``'avg_'`` and does not contain ``'±nan'``.

        Returns:
            ``((finished_ids, n_trials), res_f_list)`` where ``finished_ids``
            are row positions in ``tune_df`` and ``res_f_list`` the matching
            result file paths.
        """
        finished_ids, res_f_list = [], []
        tune_df = self.tune_df
        base_cf = self.trial_cf  # identical for every trial, build it once

        for i, trial in enumerate(tune_df.to_dict('records')):
            para_dict = deepcopy(base_cf)
            para_dict.update(trial)
            res_file = self.model_config(Dict2Config(para_dict)).res_file
            if not os.path.exists(res_file):
                continue
            with open(res_file, 'r') as f:
                content = f.read()
            if 'avg_' in content and '±nan' not in content:
                finished_ids.append(i)
                res_f_list.append(res_file)
        return (finished_ids, len(tune_df)), res_f_list


def iter_time_estimate(prefix, postfix, start_time, iters_finished, total_iters):
    """
    Print a one-line time report AFTER an iteration has finished.

    The total run time is extrapolated linearly from the time spent so far.
    Args:
        prefix: text printed before the time report
        postfix: text printed after the time report
        start_time: ``time.time()`` value taken when the iterations started
        iters_finished: number of iterations finished so far (must be non-zero)
        total_iters: number of iterations that will be run in total

    Returns: None; the report is printed.
    """
    elapsed = time.time() - start_time
    projected_total = elapsed * total_iters / iters_finished
    remaining = projected_total - elapsed
    report = f'[{time2str(elapsed)}/{time2str(projected_total)}, {time2str(remaining)} left]'
    print(f'{prefix} {report} {postfix} [{get_cur_time()}]')


def dict_product(d):
    """Expand a dict of candidate lists into the list of all combinations.

    Each returned dict maps every key of ``d`` to one value taken from that
    key's iterable; combinations follow ``itertools.product`` order.
    """
    names = list(d)
    combinations = []
    for values in product(*d.values()):
        combinations.append(dict(zip(names, values)))
    return combinations


def add_tune_df_common_paras(tune_df, para_dict):
    """Add every entry of ``para_dict`` to ``tune_df`` as a constant column.

    ``tune_df`` is modified in place and also returned.
    """
    n_rows = len(tune_df)
    for name, value in para_dict.items():
        # One copy of the value per row, so list-valued parameters stay one cell each.
        tune_df[name] = [value] * n_rows
    return tune_df


@time_logger
def run_multiple_process(func, func_arg_list):
    '''
    Run ``func`` once per argument dict, each call in its own process, and
    wait until all of them have finished.

    Args:
        func: Function to run
        func_arg_list: An iterable of dicts. Each dict holds the keyword
            arguments (**kwargs) of one call to ``func``.

    Returns:
        None
    '''
    workers = []
    for call_kwargs in func_arg_list:
        worker = multiprocessing.Process(target=func, kwargs=call_kwargs)
        workers.append(worker)
        worker.start()
    for worker in workers:
        worker.join()


def calc_mean_std(f_name):
    """
    Load results from f_name, calculate the mean and std value of every metric
    over the recorded runs and append them to the same file.

    Metrics whose name contains one of 'acc', 'Acc', 'AUC', 'ROC', 'f1', 'F1'
    are reported as percentages (scaled by 100 exactly once).

    Args:
        f_name: path of the result file. If it does not exist, a notice is
            printed and nothing else happens.
    """
    if not os.path.exists(f_name):
        print('Result file missing, skipped!!')
        return
    out_df, metric_set = load_dict_results(f_name)
    mean_res = out_df[metric_set].mean()
    std_res = out_df[metric_set].std()
    percent_markers = ('acc', 'Acc', 'AUC', 'ROC', 'f1', 'F1')
    for k in metric_set:
        # any(): a name matching several markers must still be scaled only once.
        if any(marker in k for marker in percent_markers):
            mean_res[k] = mean_res[k] * 100
            std_res[k] = std_res[k] * 100
    mean_dict = {f'avg_{m}': f'{mean_res[m]:.2f}' for m in metric_set}
    std_dict = {f'std_{m}': f'{std_res[m]:.2f}' for m in metric_set}
    # Human-readable summary first, then the machine-readable avg/std dicts.
    with open(f_name, 'a+') as f:
        f.write('\n\n' + '#' * 10 + 'AVG RESULTS' + '#' * 10 + '\n')
        for m in metric_set:
            f.write(f'{m}: {mean_res[m]:.4f} ({std_res[m]:.4f})\n')
        f.write('#' * 10 + '###########' + '#' * 10)
    write_nested_dict({'avg': mean_dict, 'std': std_dict}, f_name)


def res_to_excel(res_f_list, out_prefix, interested_conf_list):
    """
    Summarize a list of result files into one Excel sheet.

    Args:
        res_f_list: paths of result files; only existing files ending in
            'txt' are read. A file containing 'nan' results is deleted.
        out_prefix: path prefix of the output file; the best final metric and
            '.xlsx' are appended to it.
        interested_conf_list: config names to keep as columns.

    Returns:
        Path of the written Excel file.
    """

    def post_process_results(sum_df):
        # ! Format mean and std
        metric_names = [cname[4:] for cname in sum_df.columns if cname.startswith('avg_')]
        for m in metric_names:
            avg_col, std_col = 'avg_' + m, 'std_' + m
            sum_df[avg_col] = sum_df[avg_col].apply(lambda x: f'{x:.2f}')
            sum_df[std_col] = sum_df[std_col].apply(lambda x: f'{x:.2f}')
            sum_df[m] = sum_df[avg_col] + '±' + sum_df[std_col]
            sum_df = sum_df.drop(columns=[avg_col, std_col])
        # ! Deal with failed experiments and NA columns
        unmatched = []
        for col in sum_df.columns[sum_df.isnull().any()]:
            for index in sum_df.index[sum_df[col].isnull()]:
                try:
                    sum_df.loc[index, col] = sum_df.loc[index, 'config2str'][col]
                except KeyError:
                    # Found previous results that don't match the current configs.
                    unmatched.append(index)
        if unmatched:
            # DataFrame.drop returns a new frame, so the result must be kept.
            # Drop after the loops so the frame is not modified while iterating.
            sum_df = sum_df.drop(index=list(dict.fromkeys(unmatched)))
        bad_trails = [i for i in sum_df.index if '±0.00' in sum_df.loc[i, f'test_{METRIC}']]
        sum_df = sum_df.drop(bad_trails)
        sum_df.reset_index(inplace=True, drop=True)
        # Reorder column order list : move config2str to the end
        col_names = list(sum_df.columns) + ['config2str']
        col_names.remove('config2str')
        sum_df = sum_df[col_names]
        return sum_df

    eval_metric = f'avg_val_{METRIC}'

    sum_res_list = []
    for res_file in res_f_list:
        if not (os.path.isfile(res_file) and res_file.endswith('txt')):
            continue
        # Load records
        with open(res_file, 'r') as f:
            res_lines = f.readlines()
        content = ''.join(res_lines)
        if "'nan'" in content:
            # Remove only after the handle is closed (deleting an open file fails on Windows).
            os.remove(res_file)
            continue
        if 'avg_' not in content:
            continue

        # Reset per file so values of a previous file are never reused.
        conf_dict, avg_res_dict, std_res_dict = {}, None, None
        for line in res_lines:
            if line.startswith('{'):
                d = ast.literal_eval(line.strip('\n'))
                if 'model' in d.keys():  # parameters
                    conf_dict = d.copy()
                elif 'avg_' in list(d.keys())[0]:  # mean results
                    avg_res_dict = {k: float(v) for k, v in d.items()}
                elif 'std_' in list(d.keys())[0]:
                    std_res_dict = {k: float(v) for k, v in d.items()}
        if avg_res_dict is None or std_res_dict is None:
            # File has no complete avg/std summary: not summarized, skipped.
            continue

        sum_dict = {}
        try:
            sum_dict.update(subset_dict(conf_dict, interested_conf_list))
        except Exception:
            # Config of this file cannot be subset: not summarized, skipped.
            continue
        sum_dict.update(avg_res_dict)
        sum_dict.update(std_res_dict)
        sum_dict['config2str'] = conf_dict
        sum_res_list.append(sum_dict)

    sum_df = pd.DataFrame.from_dict(sum_res_list).sort_values(eval_metric, ascending=False)
    sum_df = post_process_results(sum_df)
    # Save to excel
    mkdir_list([out_prefix])
    # Save test best results. Reduce only the needed column: a frame-wide max()
    # also tries to compare the dict-valued config2str column.
    final_col = f'final_{METRIC}'
    test_best_res = sum_df[final_col].max()
    sum_df = sum_df.sort_values(final_col, ascending=False)
    test_best_res_file = f'{out_prefix}{test_best_res}.xlsx'
    sum_df.to_excel(test_best_res_file)
    return test_best_res_file


def load_dict_results(f_name):
    """
    Parse a result file into a per-run dataframe.

    Lines starting with '{' are read as dict literals:
      * a dict with a 'model' key starts a new run and stores its parameters,
      * a dict whose first key contains 'avg_' or 'std_' (a summary) is ignored,
      * any other dict is a result record of the current run. The keys of the
        first such record define the metric set.

    Args:
        f_name: path of the result file.

    Returns:
        (out_df, metric_set): ``out_df`` has one row per run id and the
        columns ``['parameters', *metric_set]``; ``metric_set`` is the list of
        metric names.

    Raises:
        ValueError: if the file contains no result record.
    """
    parameters = {}
    metric_set = None
    metric_records = {}  # metric name -> {run id: value}
    eid = 0
    with open(f_name, 'r') as f:
        for line in f:
            if not line.startswith('{'):
                continue
            d = ast.literal_eval(line.strip('\n'))
            if 'model' in d.keys():  # parameters
                eid += 1
                parameters[eid] = line.strip('\n')
                continue
            first_key = list(d.keys())[0]
            if 'avg_' in first_key or 'std_' in first_key:
                continue
            # results
            if metric_set is None:
                metric_set = list(d.keys())
                metric_records = {m: {} for m in metric_set}
            for m in metric_set:
                metric_records[m][eid] = float(d[m])
    if metric_set is None:
        raise ValueError(f'No result records found in {f_name}')
    # Plain dicts replace the former exec()-created variables: exec cannot
    # reliably create function locals and breaks on non-identifier metric names.
    out_list = [parameters, *(metric_records[m] for m in metric_set)]
    out_df = pd.DataFrame.from_records(out_list).T
    out_df.columns = ['parameters', *metric_set]
    return out_df, metric_set


def skip_results(res_file, ignore_prev):
    """
    Decide whether the trial writing to ``res_file`` can be skipped.

    Case 0: Ignore previous: Clear and rerun => Clear and return False
    Case 1: Previous results exists and summarized => skip => Return True
    Case 2: Previous results exists but unfinished => clear and rerun => Clear and return False
    Case 3: Previous results doesn't exist => run => Return False

    A file counts as summarized when one of its dict lines (lines starting
    with '{') has a first key containing 'avg_'.

    Args:
        res_file: path of the trial's result file.
        ignore_prev: if true, an existing file is always removed.

    Returns:
        True if the trial should be skipped, False if it has to be run.
    """
    if not os.path.isfile(res_file):
        # ! Case 3: Previous results doesn't exist => run => Return False
        return False
    if ignore_prev:
        # ! Case 0: Ignore previous: Clear and rerun => Clear and return False
        os.remove(res_file)
        return False

    # Read everything first so the file is closed before it may be removed
    # (deleting a file that is still open fails on Windows).
    with open(res_file, 'r') as f:
        lines = f.readlines()
    for line in lines:
        if line.startswith('{'):
            d = ast.literal_eval(line.strip('\n'))
            # An empty dict has no first key and cannot be a summary line.
            if d and 'avg_' in next(iter(d)):
                # ! Case 1: Previous results exists and summarized => skip => Return True
                return True
    # ! Case 2: Previous results exists but unfinished => clear and rerun => Clear and return False
    os.remove(res_file)
    print(f'Resuming from {res_file}')
    return False


@time_logger
def tune_model(model='GraphHD', prt_dataset='zinc_standard_agent',  dataset='sider', run_times=10,
               split='scaffold_80', exp_name='GraphHD_Debug'):
    """Command-line entry: parse the tuning options and run the grid search.

    The function arguments only provide the defaults of the matching command
    line options. The search space is ``EXP_DICT[exp_name]`` of the module
    ``tune.<model>``; the config class and training function come from that
    module's ``model_settings``.
    """
    import importlib

    valued_options = (
        ('-m', '--model', str, model),
        ('-d', '--dataset', str, dataset),
        ('-pd', '--prt_dataset', str, prt_dataset),
        ('-sp', '--split', str, split),
        ('-r', '--run_times', int, run_times),
        ('-x', '--exp_name', str, exp_name),
        ('-s', '--start_point', float, 0),
        ('-e', '--end_point', float, -1),
        ('-g', '--gpus', int, 0),
    )
    switch_options = (
        ('-v', '--reverse_iter', 'reverse iter or not'),
        ('-b', '--log_on', 'show log or not'),
        ('-i', '--ignore_prev', 'ignore previous results or not'),
        ('-D', '--debug_mode', 'stop if run into error'),
    )

    cli = argparse.ArgumentParser()
    for short_flag, long_flag, value_type, default_value in valued_options:
        cli.add_argument(short_flag, long_flag, type=value_type, default=default_value)
    for short_flag, long_flag, help_text in switch_options:
        cli.add_argument(short_flag, long_flag, action='store_true', help=help_text)
    args = cli.parse_args()

    # NOTE(review): this sets `gpu`, while the CLI option and exp_init() below
    # use `gpus` -- confirm which attribute the model configs actually read.
    if is_runing_on_local():
        args.gpu = -1

    exp_init(gpu_id=args.gpus)

    tune_module = importlib.import_module(f'tune.{args.model}')
    search_dict = tune_module.EXP_DICT[args.exp_name]
    # 'gpu_conf' is not a searched parameter; it is removed from the module's dict in place.
    search_dict.pop('gpu_conf', None)
    settings = tune_module.model_settings
    conf_class = settings['model_config']
    train_func = settings['train_func']

    args.model_config = conf_class
    args.train_func = train_func
    # Debug runs always show the logs.
    if args.debug_mode:
        args.log_on = True

    tuner = Tuner(args, search_dict=search_dict)
    tuner.grid_search(args.debug_mode)


# Script entry point: run the tuner with its default arguments; tune_model()
# reads the actual options from the command line.
if __name__ == '__main__':
    tune_model()
