#!/usr/bin/env python
"""
Usage:
   train.py [options] MODEL_NAME TASK_NAME

MODEL_NAME must be one of the supported models, which currently are:
 GGNN, GNN-Edge-MLP, GNN-FiLM, RGAT, RGCN, RGDCN

Options:
    -h --help                       Show this screen.
    --data-path PATH                Path to load data from, has task-specific defaults under data/.
    --result-dir DIR                Directory to store logfiles and trained models. [default: trained_models]
    --run-test                      Indicate if the task's test should be run.
    --model-param-overrides PARAMS  Parameter settings overriding model defaults (in JSON format).
    --task-param-overrides PARAMS   Parameter settings overriding task defaults (in JSON format).
    --quiet                         Show less output.
    --tensorboard DIR               Dump tensorboard event files to DIR.
    --azure-info=<path>             Azure authentication information file (JSON). [default: azure_auth.json]
    --debug                         Turn on debugger.
"""
import json
import os
import sys
import time

from docopt import docopt
from dpu_utils.utils import run_and_debug, RichPath, git_tag_run

from utils.model_utils import name_to_model_class, name_to_task_class
from test import test


def _collect_params(args, task_cls, additional_task_params, model_cls, additional_model_params):
    """Build the task and model parameter dicts for a run.

    Sources are applied in increasing order of precedence:
      1. ``task_cls.default_params()`` / ``model_cls.default_params()``;
      2. the additional parameters returned by the name-to-class lookups;
      3. ``tasks/default_hypers/<task name>_<model name>.json`` next to this
         script, if it exists (its ``task_params`` and ``model_params`` entries);
      4. the JSON strings given via ``--task-param-overrides`` /
         ``--model-param-overrides``.

    Returns:
        A ``(task_params, model_params)`` tuple of fresh dicts.
    """
    # Copy the class defaults so the updates below never write into whatever
    # object default_params() hands back (it may or may not be shared).
    task_params = dict(task_cls.default_params())
    task_params.update(additional_task_params)
    model_params = dict(model_cls.default_params())
    model_params.update(additional_model_params)

    # Task/model-specific defaults. Note that the model name used in the file
    # name is computed from the parameters *before* CLI overrides are applied.
    task_model_default_hypers_file = \
        os.path.join(os.path.dirname(__file__),
                     "tasks",
                     "default_hypers",
                     "%s_%s.json" % (task_cls.name(), model_cls.name(model_params)))
    if os.path.exists(task_model_default_hypers_file):
        print("Loading task/model-specific default parameters from %s." % task_model_default_hypers_file)
        with open(task_model_default_hypers_file, "rt", encoding="utf-8") as f:
            default_task_model_hypers = json.load(f)
        task_params.update(default_task_model_hypers['task_params'])
        model_params.update(default_task_model_hypers['model_params'])

    # Command-line overrides take precedence over everything else.
    task_params.update(json.loads(args.get('--task-param-overrides') or '{}'))
    model_params.update(json.loads(args.get('--model-param-overrides') or '{}'))
    return task_params, model_params


def _try_git_tag_run(model, run_id):
    """Best-effort: call ``git_tag_run(run_id)`` and log the returned SHA.

    If tagging raises an ordinary exception, a notice is printed and the run
    continues; KeyboardInterrupt/SystemExit are not intercepted.
    """
    try:
        git_sha = git_tag_run(run_id)
    except Exception:
        # Tagging is optional, so ordinary errors are tolerated. Unlike a bare
        # ``except:``, this does not swallow Ctrl-C or sys.exit().
        print(" Tried tagging run in git, but failed.")
    else:
        model.log_line(" git tagged as %s" % git_sha)


def run(args):
    """Train (and optionally test) a model on a task, as configured on the CLI.

    Parameters are assembled from class defaults, optional per-task/model
    hyper-parameter files and CLI overrides (see ``_collect_params``). Task
    parameters whose name ends in ``_path`` are converted to ``RichPath``
    objects. The task data is loaded once; then one model is trained per
    value of the model parameter ``random_seed`` (a single value or a list).

    Args:
        args: dict of parsed command-line arguments, as produced by ``docopt``
            from this module's docstring.
    """
    azure_info_path = args.get('--azure-info', None)
    quiet = args.get('--quiet')
    model_cls, additional_model_params = name_to_model_class(args['MODEL_NAME'])
    task_cls, additional_task_params = name_to_task_class(args['TASK_NAME'])

    task_params, model_params = _collect_params(
        args, task_cls, additional_task_params, model_cls, additional_model_params)

    # Keep the plain values (JSON-serialisable, used for logging below), then
    # upgrade every parameter that is a path to a RichPath. Iterating over the
    # snapshot avoids mutating the dict we are looping over.
    task_params_orig = dict(task_params)
    for param_name, param_value in task_params_orig.items():
        if param_name.endswith("_path"):
            task_params[param_name] = RichPath.create(param_value, azure_info_path)

    # Set up the output directory, the task and its data (shared by all seeds).
    result_dir = args.get('--result-dir') or 'trained_models'
    os.makedirs(result_dir, exist_ok=True)
    task = task_cls(task_params)
    data_path = RichPath.create(args.get('--data-path') or task.default_data_path(), azure_info_path)
    task.load_data(data_path)

    # A list-valued random_seed requests one training run per seed.
    random_seeds = model_params['random_seed']
    if not isinstance(random_seeds, list):
        random_seeds = [random_seeds]

    # Git tagging is only attempted when stdin is a terminal. sys.stdin can be
    # None (e.g. under pythonw), in which case tagging is skipped, not crashed on.
    interactive = sys.stdin is not None and sys.stdin.isatty()

    for random_seed in random_seeds:
        model_params['random_seed'] = random_seed
        run_id = "_".join([task_cls.name(), model_cls.name(model_params),
                           time.strftime("%Y-%m-%d-%H-%M-%S"), str(os.getpid())])

        model = model_cls(model_params, task, run_id, result_dir)
        model.log_line("Run %s starting." % run_id)
        model.log_line(" Using the following task params: %s" % json.dumps(task_params_orig))
        model.log_line(" Using the following model params: %s" % json.dumps(model_params))

        if interactive:
            _try_git_tag_run(model, run_id)

        model.initialize_model()
        model.train(quiet=quiet, tf_summary_path=args.get('--tensorboard'))

        if args.get('--run-test'):
            test(model.best_model_file, data_path, result_dir, quiet=quiet, run_id=run_id)


if __name__ == "__main__":
    # Parse the command line against the usage/options text in the module docstring.
    cli_args = docopt(__doc__)

    def _main():
        """Entry point handed to run_and_debug (which may attach a debugger)."""
        run(cli_args)

    run_and_debug(_main, enable_debugging=cli_args['--debug'])
