# -*- coding: utf-8 -*-
import os
import os.path as path
import sys
from argparse import ArgumentParser
from logging import getLogger

import yaml

from log import LogLevel, config_logger
from trainers.trainer import Trainer
import typing


def dict_deep_merge(dest: typing.Dict, src: typing.Dict) -> typing.Dict:
    """Return a new dict consisting of ``dest`` recursively overlaid with ``src``.

    For every key present in both mappings:

    * if both values are dicts, they are merged recursively;
    * otherwise the value from ``src`` wins.

    Keys present in only one of the mappings are carried over unchanged.
    Neither argument is modified, but values that are not merged are shared
    (not copied) with the inputs.

    The result keeps the key order of ``dest`` followed by the keys that only
    exist in ``src``, so the outcome is reproducible from run to run.

    Args:
        dest: Base mapping. ``None`` is treated as an empty mapping.
        src: Overriding mapping. ``None`` (e.g. what an empty YAML document
            loads as) is treated as an empty mapping.

    Returns:
        The merged dict.
    """
    merged = dict(dest) if dest is not None else {}
    if src is None:
        return merged

    for key, src_value in src.items():
        # ``get`` yields None for keys missing from ``dest``, which is never a
        # dict, so the recursive branch is only taken for keys present in both.
        dest_value = merged.get(key)
        if isinstance(dest_value, dict) and isinstance(src_value, dict):
            merged[key] = dict_deep_merge(dest_value, src_value)
        else:
            merged[key] = src_value

    return merged


if __name__ == "__main__":
    # Script entry point: parse the command line, build the effective
    # configuration (base file + extra files + command-line overrides),
    # save it next to the results and hand it to the Trainer.
    parser = ArgumentParser()
    parser.add_argument("outdir", type=str)
    parser.add_argument("config", type=str, help="基本設定ファイル")
    parser.add_argument("node_config", type=str, help="node settings")
    parser.add_argument("extra_configs", type=str, nargs="*",
                        help="追加、上書き設定。あとに設定した設定ファイルほど優先される。")

    # override parameters
    parser.add_argument("--seed", type=int, default=None)
    parser.add_argument("--model_seed", type=int, default=None)
    parser.add_argument("--data_seed", type=int, default=None)
    parser.add_argument("--data_nduplicate", type=int, default=None,
                        help="データ分割時に許容するデータの重複回数。")

    parser.add_argument("--dtype", type=str, default="float32",
                        choices=["float32", "float16", "float64"])

    parser.add_argument("--nstep", type=int, default=None,
                        help="override nstep on train config parameter.")
    parser.add_argument("--nouter", type=int, default=None,
                        help="override nouter on train config parameter.")
    parser.add_argument("--ninner", type=int, default=None,
                        help="override ninner on train config parameter.")
    parser.add_argument("--lr", type=float, default=None,
                        help="override lr on train config parameter.")
    parser.add_argument("--lr_min", type=float, default=None,
                        help="override lr_min on train config parameter.")
    parser.add_argument("--t_initial", type=int, default=None,
                        help="lr Schedulerの適用最終step。指定された値を越えるとlrが変わらなくなる。")
    parser.add_argument("--warmup_lr_init", type=float, default=None,
                        help="override warmup_lr_init on train config parameter.")
    parser.add_argument("--warmup_t", type=int, default=None,
                        help="override warmup_t on train config parameter.")
    parser.add_argument("--batch_size", type=int, default=None,
                        help="override batch_size on train config parameter.")
    parser.add_argument("--eval_interval", type=int, default=None,
                        help="override eval_interval on train config parameter.")
    parser.add_argument("--grad_diff_norm_interval", type=int, default=None,
                        help="override grad_diff_norm_interval on train config parameter.")

    parser.add_argument("--backend", type=str, default="gloo")
    parser.add_argument("--init_method", type=str, default=None)
    # Default raised from the earlier 1800.0.
    parser.add_argument("--timeout", type=float, default=7200.0)
    parser.add_argument("--process_group", type=str, default="")
    parser.add_argument("--master_addr", type=str, default="127.0.0.1")
    parser.add_argument("--master_port", type=int, default=29501)

    parser.add_argument("--loglevel", type=lambda level: LogLevel.nameof(level),
                        default=LogLevel.INFO)
    parser.add_argument("--logfile", type=str, default=None)
    parser.add_argument("--quiet", default=False, action="store_true")

    args = parser.parse_args()

    os.makedirs(args.outdir, exist_ok=True)

    config_logger(loglevel=args.loglevel, stream=None if args.quiet else sys.stderr,
                  logfile=args.logfile)
    logger = getLogger("main")
    cmd = "python " + " ".join(sys.argv)
    logger.info(cmd)
    logger.info(args)
    if args.quiet:
        # The logger stream is disabled in quiet mode, so echo the command
        # line on stdout only.
        print(cmd)

    # YAML files are read and written as UTF-8 explicitly so the result does
    # not depend on the locale's default encoding (the dumps below use
    # allow_unicode=True and may therefore contain non-ASCII characters).
    with open(args.config, encoding="utf-8") as f:  # base configuration file
        config = yaml.load(f, Loader=yaml.SafeLoader)
    with open(args.node_config, encoding="utf-8") as f:
        node_config = yaml.load(f, Loader=yaml.SafeLoader)

    # Load the additional / overriding configuration files; later files take
    # precedence over earlier ones.
    for conffile in args.extra_configs:
        with open(conffile, encoding="utf-8") as f:
            # An empty YAML document loads as None; treat it as "no settings".
            conf = yaml.load(f, Loader=yaml.SafeLoader) or {}
        config = dict_deep_merge(config, conf)
        logger.debug(f"{conffile}: {config}")

    # Override the configuration from command-line arguments. Every override
    # option defaults to None, so "is not None" means "given by the user".
    if args.seed is not None:
        config["seed"] = args.seed
    if args.model_seed is not None:
        config["model_seed"] = args.model_seed
    if args.data_seed is not None:
        config["dataset"]["seed"] = args.data_seed
    if args.data_nduplicate is not None:
        # As the number of nodes grows, the average amount of training data
        # per node shrinks, so allow data to be duplicated.
        for target in ["server", "client"]:
            target_kwargs = config["dataset"][target].setdefault("kwargs", {})
            target_kwargs["nduplicate"] = args.data_nduplicate
    for confname in ["nstep", "nouter", "ninner", "lr", "batch_size", "eval_interval", "grad_diff_norm_interval"]:
        # Config-file keys and command-line option names are kept identical;
        # the value types are enforced by ArgumentParser.
        param = getattr(args, confname, None)
        if param is None:
            continue
        config["train"][confname] = param  # update config

    # Scheduler overrides live in config["train"]["scheduler"]["kwargs"].
    scheduler_overrides = {
        name: getattr(args, name)
        for name in ("lr_min", "warmup_t", "warmup_lr_init")
        if getattr(args, name) is not None
    }
    if scheduler_overrides:
        # setdefault (rather than get) so that the override is stored in the
        # configuration even when the intermediate sections are missing;
        # the sections are only created when an override was requested.
        scheduler_kwargs = (config.setdefault("train", {})
                            .setdefault("scheduler", {})
                            .setdefault("kwargs", {}))
        scheduler_kwargs.update(scheduler_overrides)

    # Without an explicit --t_initial, fall back to the configured nstep.
    t_initial = args.t_initial if args.t_initial is not None \
        else config["train"]["nstep"]

    logger.info(config)
    logger.info(node_config)

    # Save the effective configuration in the output directory.
    with open(path.join(args.outdir, "config.yaml"), "w", encoding="utf-8") as f:
        yaml.dump(config, f, allow_unicode=True)
    with open(path.join(args.outdir, "node_config.yaml"), "w", encoding="utf-8") as f:
        yaml.dump(node_config, f, allow_unicode=True)

    trainer = Trainer(args.outdir, node_config=node_config)
    trainer.run(config,
                t_initial=t_initial, dtype=args.dtype,
                backend=args.backend, init_method=args.init_method, timeout=args.timeout, process_group=args.process_group,
                master_addr=args.master_addr, master_port=args.master_port)
