#! -*- coding: utf-8
import os
import os.path as path
import sys
from argparse import ArgumentParser
from logging import getLogger

import yaml

from log import LogLevel, config_logger
from trainers.trainer import Trainer
import typing


def dict_deep_merge(dest: typing.Dict, src: typing.Dict) -> typing.Dict:
    """Recursively merge ``src`` over ``dest`` and return a new dict.

    For every key present in either input:

    * key only in ``dest``  -> ``dest``'s value is kept;
    * key only in ``src``   -> ``src``'s value is taken;
    * both values are dicts -> they are merged recursively;
    * otherwise             -> ``src``'s value wins (even an explicit ``None``).

    The result has a deterministic key order: ``dest``'s keys in their
    original order, followed by keys that appear only in ``src`` in ``src``'s
    order. Neither input dict is mutated. Values that are not merged
    recursively are shared with the inputs (no deep copy is made).

    Args:
        dest: Base mapping (lower priority).
        src: Overriding mapping (higher priority).

    Returns:
        A new merged dictionary.
    """
    # Start from a shallow copy of dest so its key order is preserved and
    # dest itself is never modified.
    merged = dict(dest)
    for key, src_value in src.items():
        dest_value = merged.get(key)
        if isinstance(dest_value, dict) and isinstance(src_value, dict):
            merged[key] = dict_deep_merge(dest_value, src_value)
        else:
            # Overwriting an existing key keeps its position; new keys are
            # appended in src's order.
            merged[key] = src_value
    return merged


if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument("outdir", type=str)
    parser.add_argument("config", type=str)
    parser.add_argument("node_config", type=str, help="node settings")
    parser.add_argument("extra_configs", type=str, nargs="*")

    # override parameters
    parser.add_argument("--seed", type=int, default=None)
    parser.add_argument("--model_seed", type=int, default=None)
    parser.add_argument("--data_seed", type=int, default=None)
    parser.add_argument("--data_nduplicate", type=int, default=None)

    parser.add_argument("--dtype", type=str, default="float32",
                        choices=["float32", "float16", "float64"])

    parser.add_argument("--nstep", type=int, default=None,
                        help="override nstep on train config parameter.")
    parser.add_argument("--nouter", type=int, default=None,
                        help="override nouter on train config parameter.")
    parser.add_argument("--ninner", type=int, default=None,
                        help="override ninner on train config parameter.")
    parser.add_argument("--lr", type=float, default=None,
                        help="override lr on train config parameter.")
    parser.add_argument("--lr_min", type=float, default=None,
                        help="override lr_min on train config parameter.")
    parser.add_argument("--t_initial", type=int, default=None,
                        help="t_initial passed to the trainer; defaults to train nstep.")
    parser.add_argument("--warmup_lr_init", type=float, default=None,
                        help="override warmup_lr_init on train config parameter.")
    parser.add_argument("--warmup_t", type=int, default=None,
                        help="override warmup_t on train config parameter.")
    parser.add_argument("--batch_size", type=int, default=None,
                        help="override batch_size on train config parameter.")
    parser.add_argument("--eval_interval", type=int, default=None,
                        help="override eval_interval on train config parameter.")

    parser.add_argument("--force_merge_parameter_interval", type=int, default=None,
                        help="override force_merge_parameter_interval on train config parameter.")

    parser.add_argument("--backend", type=str, default="gloo")
    parser.add_argument("--init_method", type=str, default=None)
    # default raised from the former 1800.0
    parser.add_argument("--timeout", type=float, default=7200.0)
    parser.add_argument("--process_group", type=str, default="")
    parser.add_argument("--master_addr", type=str, default="127.0.0.1")
    parser.add_argument("--master_port", type=int, default=29501)

    parser.add_argument("--loglevel", type=lambda level: LogLevel.nameof(level),
                        default=LogLevel.INFO)
    parser.add_argument("--logfile", type=str, default=None)
    parser.add_argument("--quiet", default=False, action="store_true")

    args = parser.parse_args()

    os.makedirs(args.outdir, exist_ok=True)

    config_logger(loglevel=args.loglevel, stream=None if args.quiet else sys.stderr,
                  logfile=args.logfile)
    logger = getLogger("main")
    logger.info("python " + " ".join(sys.argv))
    logger.info(args)

    # YAML is UTF-8 by spec; do not depend on the platform's locale encoding.
    with open(args.config, encoding="utf-8") as f:
        config = yaml.load(f, Loader=yaml.SafeLoader)
    with open(args.node_config, encoding="utf-8") as f:
        node_config = yaml.load(f, Loader=yaml.SafeLoader)

    # Each extra config is deep-merged on top of the result so far; later
    # files take priority.
    for conffile in args.extra_configs:
        with open(conffile, encoding="utf-8") as f:
            conf = yaml.load(f, Loader=yaml.SafeLoader)
        config = dict_deep_merge(config, conf)
        logger.debug(f"{conffile}: {config}")

    # override config from commandline args (only options actually given)
    if args.seed is not None:
        config["seed"] = args.seed
    if args.model_seed is not None:
        config["model_seed"] = args.model_seed
    if args.data_seed is not None:
        config["dataset"]["seed"] = args.data_seed
    if args.data_nduplicate is not None:
        config["dataset"]["kwargs"]["nduplicate"] = args.data_nduplicate
    for confname in ["nstep", "nouter", "ninner", "lr", "batch_size", "eval_interval",
                     "force_merge_parameter_interval"]:
        param = getattr(args, confname)
        if param is not None:
            config["train"][confname] = param

    # Scheduler overrides are written into train.scheduler.kwargs.
    scheduler_overrides = {
        name: value
        for name, value in (("lr_min", args.lr_min),
                            ("warmup_t", args.warmup_t),
                            ("warmup_lr_init", args.warmup_lr_init))
        if value is not None
    }
    if scheduler_overrides:
        scheduler_config = config.get("train", {}).get("scheduler")
        if isinstance(scheduler_config, dict):
            # Assign kwargs back so the overrides persist even when the
            # scheduler section had no (or an empty) "kwargs" entry.
            scheduler_kwargs = scheduler_config.get("kwargs") or {}
            scheduler_kwargs.update(scheduler_overrides)
            scheduler_config["kwargs"] = scheduler_kwargs
        else:
            logger.warning("config has no train.scheduler section; "
                           "ignoring scheduler overrides: %s", scheduler_overrides)

    # Falls back to the (possibly overridden) train nstep.
    t_initial = args.t_initial if args.t_initial is not None \
        else config["train"]["nstep"]

    logger.info(config)
    logger.info(node_config)

    # allow_unicode=True may emit non-ASCII text, so write UTF-8 explicitly.
    with open(path.join(args.outdir, "config.yaml"), "w", encoding="utf-8") as f:
        yaml.dump(config, f, allow_unicode=True)
    with open(path.join(args.outdir, "node_config.yaml"), "w", encoding="utf-8") as f:
        yaml.dump(node_config, f, allow_unicode=True)

    trainer = Trainer(args.outdir, node_config=node_config)
    trainer.run(config,
                t_initial=t_initial, dtype=args.dtype,
                backend=args.backend, init_method=args.init_method, timeout=args.timeout, process_group=args.process_group,
                master_addr=args.master_addr, master_port=args.master_port)
