import argparse
from easy_tpp.config_factory import Config
from easy_tpp.runner import Runner
from easy_tpp.utils import set_seed, create_folder, save_yaml_config
import datetime
import pickle
import git
import numpy as np
from tqdm import tqdm

# Module-level dict; nothing in the visible code reads or writes it.
# NOTE(review): looks vestigial -- confirm no external importer relies on it
# before removing.
all_results = {}

def _collect_git_info():
    """Return ``(commit, branch, is_dirty)`` for the enclosing git repository.

    This is best effort: if any step of the lookup raises, a message is
    printed and all three values fall back to the string ``'NoneFound'`` so
    that a run is never aborted because provenance information is missing.
    """
    # TODO: sync git info to server
    try:
        # Open the repository once instead of once per attribute.
        repo = git.Repo(search_parent_directories=True)
        commit = repo.head.object.hexsha
        # Keep the branch *name* (a plain string) rather than the GitPython
        # reference object, so the pickled results hold only simple data.
        branch = repo.active_branch.name
        is_dirty = repo.is_dirty()
    except Exception:
        # ``Exception`` rather than a bare ``except`` so that Ctrl-C and
        # ``SystemExit`` are not swallowed here.
        print('Failed to grab git info...')
        return 'NoneFound', 'NoneFound', 'NoneFound'
    return commit, branch, is_dirty


def main(args, config_path, seeds=None):
    """Train one model per random seed and collect the test metrics.

    For every seed the experiment config is updated, saved next to the model
    as ``model_<i>.yaml``, and a fresh runner is built and run. After each
    model the accumulated results are pickled to
    ``<saved_model_dir>_results.pkl`` so partial results survive a crash.

    Args:
        args: Object exposing an ``experiment_id`` attribute (e.g. the parsed
            argparse namespace).
        config_path: Path of the yaml file handed to
            ``Config.build_from_yaml_file``.
        seeds: Optional iterable of random seeds, one model is trained per
            seed. Defaults to ``range(2019, 2024)``.
            Seed sets that were used for specific runs:
              * ``[2023]`` -- nhp + ehrshot
              * ``[18]`` -- thp + retweet
              * ``[2019, 2020, 2021, 2023, 2024]`` -- attnhp + taobao
                (seed 2022 ran into nan)

    Returns:
        A tuple ``(results, best_valid_res, best_model_id)`` where
        ``results`` maps each model index to the dict returned by the runner
        and also holds run metadata (``'date'``, ``'gitcommit'``,
        ``'git_branch'``, ``'git_is_dirty'``) and the per-model lists
        ``'test_ll'``, ``'test_time_ll'`` and ``'test_mark_ll'``;
        ``best_valid_res`` is the runner result with the highest
        ``'best_valid_ll'`` (``{}`` if none was selected) and
        ``best_model_id`` is its index (``-1`` if none was selected).
    """
    config = Config.build_from_yaml_file(config_path, experiment_id=args.experiment_id)

    if seeds is None:
        seeds = range(2019, 2024)

    git_commit, git_branch, git_is_dirty = _collect_git_info()

    results = {
        'date': datetime.date.today(),
        'gitcommit': git_commit,
        'git_branch': git_branch,
        'git_is_dirty': git_is_dirty,
        'test_ll': [],
        'test_time_ll': [],
        'test_mark_ll': [],
    }

    save_path = config.base_config.specs['saved_model_dir']
    create_folder(save_path)
    results_file = save_path + '_results.pkl'

    best_valid_ll = float('-inf')
    best_valid_res = {}
    best_model_id = -1
    model_ids = []

    for i, seed in enumerate(seeds):
        config.trainer_config.seed = seed
        print(f"Start training model {i}...")
        save_yaml_config(f'{save_path}/model_{i}.yaml', config)
        set_seed(config.trainer_config.seed)
        model_runner = Runner.build_from_config(config)

        res = model_runner.run(save_model_id=f'_{i}')

        # Record which model / seed produced this result.
        res['model_id'] = i
        res['params'] = {'random_seed': config.trainer_config.seed}
        results[i] = res
        model_ids.append(i)

        test_metrics = res['best_metrics']['test']
        results['test_ll'].append(test_metrics['loglike'])
        results['test_time_ll'].append(test_metrics['time_ll'])
        results['test_mark_ll'].append(test_metrics['mark_ll'])

        # Strict comparison: the earliest model wins ties.
        if res['best_valid_ll'] > best_valid_ll:
            best_valid_ll = res['best_valid_ll']
            best_valid_res = res
            best_model_id = i

        # Checkpoint after every model so finished runs are not lost.
        with open(results_file, 'wb') as f:
            pickle.dump(results, f)
        print(results)

    # Final write; also covers the case where no seed was supplied.
    with open(results_file, 'wb') as f:
        pickle.dump(results, f)
    print(results)
    # Only the per-model entries are labelled as models; the metadata and
    # aggregate lists are already shown by the full dump above.
    for model_id in model_ids:
        print(f'Model {model_id}:')
        print(results[model_id])
    print('Experiment finished')

    print()
    print(f'Global best validation ll: {best_valid_ll} at model {best_model_id}')
    print(best_valid_res)
    print()

    return results, best_valid_res, best_model_id


if __name__ == '__main__':
    # Script entry point: pick a dataset config, run the multi-seed
    # experiment via ``main`` and print mean / sample std of the test
    # log-likelihoods over all trained models.

    # Maps the value of --dataset to the experiment config yaml to load.
    dataset_config_path = {
        'taxi': './configs/exp_config_taxi.yaml',
        'taobao': './configs/exp_config_tb.yaml',
        'stackoverflow': './configs/exp_config_so.yaml',
        'amazon': './configs/exp_config_amazon.yaml',
        'retweet': './configs/exp_config_rt_jitter.yaml',
        'lastfm': './configs/exp_config_lastfm.yaml',
        'mimic': './configs/exp_config_mimic.yaml',
        'ehr': './configs/exp_config_ehr.yaml',
    }

    parser = argparse.ArgumentParser()
    # ``choices`` makes argparse reject unknown datasets with a usage message
    # instead of failing later with a KeyError on the dict lookup.
    parser.add_argument('--dataset', type=str, required=False, default='taxi',
                        choices=sorted(dataset_config_path),
                        help='Name of the dataset; selects the configuration yaml '
                             'used to train and evaluate the model.')
    parser.add_argument('--experiment_id', type=str, required=False, default='RMTPP_train',
                        help='Experiment id in the config file.')
    args = parser.parse_args()
    print(f'Current dataset: {args.dataset}')
    results, best_valid_res, best_model_id = main(args, dataset_config_path[args.dataset])
    print(results)

    test_ll = np.array(results['test_ll'])
    test_time_ll = np.array(results['test_time_ll'])
    test_mark_ll = np.array(results['test_mark_ll'])

    # ddof=1 -> sample standard deviation across the trained models.
    print()
    print("     AGGREGATED RESULTS:")
    print(f'Dataset: {args.dataset}')
    print(f'Model: {args.experiment_id}')
    print(f'Total logL: mean {np.round(np.mean(test_ll), 3)}; std {np.round(np.std(test_ll, ddof=1), 3)}')
    print(f'Time logL: mean {np.round(np.mean(test_time_ll), 3)}; std {np.round(np.std(test_time_ll, ddof=1), 3)}')
    print(f'Mark logL : mean {np.round(np.mean(test_mark_ll), 3)}; std {np.round(np.std(test_mark_ll, ddof=1), 3)}')
    print('Experiments done.')

