import truststore
truststore.inject_into_ssl()

import argparse
from easy_tpp.config_factory import Config
from easy_tpp.runner import Runner
from easy_tpp.utils import set_seed, create_folder, save_yaml_config
import datetime
import pickle
import git


def _collect_git_info():
    """Return commit hash, branch name and dirty flag of the enclosing git repo.

    Any value that could not be read keeps the placeholder ``'NoneFound'``, so
    a failure here never aborts the experiment.

    Returns:
        dict with the keys ``'gitcommit'``, ``'git_branch'`` and
        ``'git_is_dirty'``.
    """
    info = {
        'gitcommit': 'NoneFound',
        'git_branch': 'NoneFound',
        'git_is_dirty': 'NoneFound',
    }
    try:
        # Open the repository once instead of once per field.
        repo = git.Repo(search_parent_directories=True)
        info['gitcommit'] = repo.head.object.hexsha
        info['git_is_dirty'] = repo.is_dirty()
        # Queried last: looking up the active branch fails on a detached HEAD,
        # and the commit hash / dirty flag are still worth keeping then.
        # Store the plain name (a str) rather than the git.Head object so the
        # pickled results only contain a simple string for the branch.
        info['git_branch'] = repo.active_branch.name
    except Exception:
        # Best effort only. `Exception` (rather than a bare `except`) still
        # lets KeyboardInterrupt / SystemExit propagate.
        print('Failed to grab git info...')
    return info


def _dump_and_print_results(results, results_path):
    """Pickle ``results`` to ``results_path`` and print every entry."""
    with open(results_path, 'wb') as f:
        pickle.dump(results, f)
    for k, val in results.items():
        print(f'Model {k}:')
        print(val)


def main(args):
    """Train one model per (num_layers, hidden_size) pair and record the results.

    The grid is num_layers in [1, 2, 3] crossed with hidden_size in
    [16, 32, 64]. For model ``i`` the config is written to
    ``{saved_model_dir}/model_{i}.yaml`` and the runner is invoked with
    ``save_model_id='_{i}'``. After every model the accumulated results are
    pickled to ``saved_model_dir + '_results.pkl'``; at the end the entry with
    the highest ``best_valid_ll`` is pickled to
    ``saved_model_dir + '_best_results.pkl'``.

    Args:
        args: object with the attributes ``config_dir`` (yaml config path),
            ``experiment_id`` (experiment id inside that yaml) and ``gpu``
            (GPU id written to ``config.trainer_config.gpu``).
    """
    # Local import keeps this block self-contained.
    import copy

    config = Config.build_from_yaml_file(args.config_dir, experiment_id=args.experiment_id)
    config.trainer_config.gpu = args.gpu

    # Grab some git information.  # TODO: sync git info to server
    results = {'date': datetime.date.today(), **_collect_git_info()}

    save_path = config.base_config.specs['saved_model_dir']
    create_folder(save_path)
    # NOTE(review): plain string concatenation, so these files land *beside*
    # save_path (e.g. `<dir>_results.pkl`) unless save_path ends with a path
    # separator -- confirm this is intended.
    results_path = save_path + '_results.pkl'
    best_results_path = save_path + '_best_results.pkl'

    best_valid_ll = float('-inf')
    best_valid_res = {}
    best_model_id = -1

    grid = [(num_layers, hidden_size)
            for num_layers in [1, 2, 3]
            for hidden_size in [16, 32, 64]]

    for i, (num_layers, hidden_size) in enumerate(grid):
        config.model_config.num_layers = num_layers
        config.model_config.hidden_size = hidden_size

        print(config.model_config)

        print(f"Start training model with num_layers={num_layers}, hidden_size={hidden_size}...")
        save_yaml_config(f'{save_path}/model_{i}.yaml', config)
        set_seed(config.trainer_config.seed)
        model_runner = Runner.build_from_config(config)  #, unique_model_dir=True)

        res = model_runner.run(save_model_id=f'_{i}')

        # save configs for model i
        res['model_id'] = i
        # Snapshot the config: the same object is mutated on every iteration,
        # so storing a bare reference would make every saved result report the
        # hyperparameters of the last model trained.
        res['config'] = copy.deepcopy(config)
        res['params'] = {'num_layers': num_layers,
                         'hidden_size': hidden_size,
                         'random_seed': config.trainer_config.seed}
        results[i] = res
        if res['best_valid_ll'] > best_valid_ll:
            best_valid_ll = res['best_valid_ll']
            best_valid_res = res
            best_model_id = i

        # Persist after every model so a crash mid-grid keeps finished runs.
        _dump_and_print_results(results, results_path)
        print(f'Best valid ll so far: {best_valid_ll}')
        print(best_valid_res)

    _dump_and_print_results(results, results_path)
    print('Experiment finished')

    with open(best_results_path, 'wb') as f:
        pickle.dump(best_valid_res, f)
    print(f'Global best validation ll: {best_valid_ll} at model {best_model_id}')
    print(best_valid_res)


# Default experiment config for each benchmark dataset; all of them are run,
# in this order, when --config_dir is not given on the command line.
DATASET_CONFIG_PATHS = {
    'taxi': 'configs/exp_config_taxi_gpu.yaml',
    'amazon': 'configs/exp_config_amazon_gpu.yaml',
    'retweet': 'configs/exp_config_rt_gpu.yaml',
    'taobao': 'configs/exp_config_tb_gpu.yaml',
    'stackoverflow': 'configs/exp_config_so_gpu.yaml',
}


def parse_cli_args(argv=None):
    """Parse the command-line options of this script.

    Args:
        argv: list of argument strings; ``None`` means ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with ``config_dir`` (``None`` when not given),
        ``experiment_id`` and ``gpu``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--config_dir', type=str, required=False, default=None,
                        help='Dir of configuration yaml to train and evaluate the model. '
                             'If omitted, every built-in dataset config is run in turn.')
    parser.add_argument('--experiment_id', type=str, required=False, default='THP_train',
                        help='Experiment id in the config file.')
    parser.add_argument('--gpu', type=int, required=False, default=3, help='GPU ID.')
    return parser.parse_args(argv)


def select_runs(config_dir):
    """Return the ``(label, config_path)`` pairs to run.

    An explicitly supplied ``config_dir`` is run exactly once (labelled with
    its own path); otherwise every entry of ``DATASET_CONFIG_PATHS`` is run.
    """
    if config_dir is not None:
        return [(config_dir, config_dir)]
    return list(DATASET_CONFIG_PATHS.items())


if __name__ == '__main__':
    # Parse once, outside the loop: previously the parser was rebuilt per
    # dataset, so an explicit --config_dir trained the same config five times.
    cli_args = parse_cli_args()
    for data, config_path in select_runs(cli_args.config_dir):
        print(f'Current dataset: {data}')
        # Fresh namespace per run so one run cannot affect the next one's args.
        run_args = argparse.Namespace(**vars(cli_args))
        run_args.config_dir = config_path
        main(run_args)
