# -*- coding: utf-8 -*-
import logging.handlers
import os
import typing
from datetime import timedelta
from logging import getLogger

import numpy as np
import torch
import torch.distributed

from log import LOGFORMAT, LogLevel


def _init_process(rank, world_size, outdir: str,
                  trainer_class_name: str, *args,
                  name: str = "", t_initial: typing.Optional[int] = None,
                  server_node_rank: typing.Optional[int] = None,
                  client_node_ranks: typing.Optional[typing.List[int]] = None,
                  backend: str = "gloo", init_method: typing.Optional[str] = None,
                  timeout: float = 1800.0, process_group: str = "",
                  master_addr: str = "127.0.0.1", master_port: int = 29501,
                  log_queue: torch.multiprocessing.Queue = None,
                  loglevel: LogLevel = LogLevel.INFO, logfile: typing.Optional[str] = None,
                  **kwargs):
    """Entry point executed inside each spawned node process.

    Configures logging for the process, builds the server or client trainer
    selected by ``trainer_class_name``, initialises the ``torch.distributed``
    process group and finally calls ``trainer.run(t_initial=t_initial)``.

    Args:
        rank: Rank of this process inside the process group.
        world_size: Total number of processes in the process group.
        outdir: Output directory forwarded to the trainer.
        trainer_class_name: ``"ClassificationNodeTrainer"`` or
            ``"GlueNodeTrainer"``.
        *args: Extra positional arguments forwarded to the trainer.
        name: Node name; forwarded to the trainer and used as log prefix.
        t_initial: Forwarded to ``trainer.run``.
        server_node_rank: Rank of the server node. The process whose ``rank``
            equals this value builds the server trainer; every other process
            builds the client trainer.
        client_node_ranks: Ranks of the client nodes (server trainer only).
            ``None`` is treated as an empty list.
        backend: ``torch.distributed`` backend name.
        init_method: Init URL for the process group. When ``None`` the
            ``MASTER_ADDR`` / ``MASTER_PORT`` environment variables are set
            from ``master_addr`` / ``master_port`` instead.
        timeout: Process group timeout in seconds.
        process_group: Passed to ``init_process_group`` as ``group_name``.
        master_addr: Master address used when ``init_method`` is ``None``.
        master_port: Master port used when ``init_method`` is ``None``.
        log_queue: Queue that receives this process's log records. When
            ``None`` the default ``logging.basicConfig`` handler is used.
        loglevel: Root log level for this process.
        logfile: Accepted but not used by this function.
        **kwargs: Extra keyword arguments forwarded to the trainer.

    Raises:
        ValueError: If ``trainer_class_name`` is not a known trainer name.
    """
    if client_node_ranks is None:
        client_node_ranks = []

    # Route records to the parent through the queue when one is given;
    # with handlers=None, basicConfig falls back to its default handler.
    handlers = None
    if log_queue is not None:
        handlers = [logging.handlers.QueueHandler(log_queue)]
    if loglevel <= LogLevel.DEBUG:
        log_format = f"{name} [%(module)s.%(funcName)s#%(lineno)d] %(message)s"
    else:
        log_format = f"{name} %(message)s"
    logging.basicConfig(level=loglevel, handlers=handlers, format=log_format)

    is_server = rank == server_node_rank

    # Trainer modules are imported lazily so only the requested one is loaded.
    if trainer_class_name == "ClassificationNodeTrainer":
        from .classification_node_trainer import (
            ClassificationClientNodeTrainer, ClassificationServerNodeTrainer)
        trainer_class = (ClassificationServerNodeTrainer if is_server
                         else ClassificationClientNodeTrainer)
    elif trainer_class_name == "GlueNodeTrainer":
        from .glue_node_trainer import (GlueClientNodeTrainer,
                                        GlueServerNodeTrainer)
        trainer_class = (GlueServerNodeTrainer if is_server
                         else GlueClientNodeTrainer)
    else:
        # Fail early with a clear message instead of an UnboundLocalError
        # when the trainer is used further down.
        raise ValueError(
            f"unknown trainer_class_name: {trainer_class_name!r}")

    # The server is told who its clients are; a client is told who the server is.
    if is_server:
        role_kwargs = dict(client_node_ranks=client_node_ranks)
    else:
        role_kwargs = dict(server_node_rank=server_node_rank)
    trainer = trainer_class(rank, world_size, outdir,
                            *args,
                            name=name,
                            log_queue=log_queue, loglevel=loglevel,
                            **role_kwargs,
                            **kwargs)

    # set init process group
    if init_method is None:
        os.environ["MASTER_ADDR"] = master_addr
        os.environ["MASTER_PORT"] = str(master_port)
        trainer.logger.info(
            f"backend={backend}, rank={rank}/{world_size}, mastar addr={master_addr}:{master_port}")
    else:
        trainer.logger.info(
            f"backend={backend}, rank={rank}/{world_size}, init_method={init_method}")
    torch.distributed.init_process_group(backend, init_method=init_method, timeout=timedelta(seconds=timeout),
                                         world_size=world_size, rank=rank, group_name=process_group)

    trainer.run(t_initial=t_initial)


class Trainer(object):
    """Launch one process per configured node and relay their log records.

    The node layout is read from ``node_config["nodes"]``: an optional
    ``"server"`` mapping and a ``"clients"`` list of mappings. Each mapping
    is merged into the keyword arguments of the spawned ``_init_process``.
    """

    def __init__(self, outdir: str, node_config: typing.Dict, *args, **kwargs):
        """Store the output directory and split the node configuration.

        Args:
            outdir: Output directory handed to every spawned process.
            node_config: Mapping whose ``"nodes"`` entry holds the ``"server"``
                config (optional) and the ``"clients"`` config list. A missing
                or ``None`` entry is treated as "no server" / "no clients".
            *args: Ignored.
            **kwargs: Ignored.
        """
        self.outdir = outdir
        # e.g. {name: "node-0",  device: "cuda:0"}
        self.node_config = node_config
        # `or` also covers keys that are present but set to None.
        nodes = node_config.get("nodes") or {}
        self.server_node_config: typing.Optional[typing.Dict[str, str]] = nodes.get(
            "server", None)
        self.client_node_configs: typing.List[typing.Dict[str, str]] = nodes.get(
            "clients") or []

        self.nclient_node = len(self.client_node_configs)
        self.logger = getLogger(__name__)
        self.logger.info(f"server node config: {self.server_node_config}")
        self.logger.info(f"client node configs: {self.client_node_configs}")

    def run(self, config: typing.Dict,
            t_initial: typing.Optional[int] = None,
            backend: str = "gloo", init_method: typing.Optional[str] = None, timeout: float = 1800.0,
            process_group: str = "", master_addr: str = "127.0.0.1", master_port: int = 29501,
            dtype: typing.Union[str, torch.dtype] = "float",
            *args, **kwargs):
        """Spawn one process per node, wait for all of them and return.

        Exceptions raised while preparing or spawning are logged at CRITICAL
        level and not re-raised. ``config`` is not modified.

        Args:
            config: Experiment configuration. The keys ``seed`` (default 11),
                ``model_seed`` (falls back to ``seed`` when missing or not an
                int), ``model``, ``dataset``, ``train`` and ``process_group``
                are read. ``train["name"]`` selects the trainer class
                (default ``"ClassificationNodeTrainer"``).
            t_initial: Forwarded to every spawned process.
            backend: Process group setting; overrides the value in
                ``config["process_group"]``.
            init_method: Process group setting; same override rule.
            timeout: Process group setting; same override rule.
            process_group: Process group setting; same override rule.
            master_addr: Process group setting; same override rule.
            master_port: Process group setting; same override rule.
            dtype: Forwarded to every spawned process.
            *args: Ignored.
            **kwargs: Ignored.
        """
        # Use the "spawn" start method for child processes. Guarded so that
        # calling run() again does not raise "context has already been set".
        if torch.multiprocessing.get_start_method(allow_none=True) != "spawn":
            torch.multiprocessing.set_start_method("spawn", force=True)
        # Queue through which child processes send their log records.
        log_queue = torch.multiprocessing.Queue()
        # Child records are replayed on the LAST root handler only.
        # NOTE(review): presumably the file handler when both console and file
        # output are configured (i.e. it is added last) - confirm against the
        # logging setup. With no root handler the records are dropped.
        listener = logging.handlers.QueueListener(log_queue, *getLogger().handlers[-1:],
                                                  respect_handler_level=False)
        listener.start()

        processes = []
        try:
            loglevel = int(max(getLogger().level, self.logger.level))

            self.logger.info(f"config: {config}")

            # Work on shallow copies so the caller's config stays untouched.
            config = dict(config)
            seed = config.pop("seed", 11)
            model_seed = config.pop("model_seed", seed)
            if not isinstance(model_seed, int):
                model_seed = seed
            model_config = config.pop("model", {})
            dataset_config = config.pop("dataset", {})
            train_config = dict(config.get("train", {}))
            process_group_config = dict(config.get("process_group", {}))
            # Process group settings: explicit arguments win over the config.
            process_group_config.update(dict(backend=backend, init_method=init_method, timeout=timeout, process_group=process_group,
                                             master_addr=master_addr, master_port=master_port))
            # Name of the trainer class to instantiate in each process.
            trainer_class_name = train_config.pop("name",
                                                  "ClassificationNodeTrainer")

            # Assign a rank to every node; the server (if any) gets rank 0.
            configs: typing.List[typing.Dict] = []
            server_node_rank = None
            client_node_ranks: typing.List[int] = []
            if self.server_node_config is not None:
                server_node_rank = len(configs)
                configs.append(self.server_node_config)
            for client_config in self.client_node_configs:
                client_node_ranks.append(len(configs))
                configs.append(client_config)

            world_size = len(configs)
            for rank, node_config in enumerate(configs):
                process_kwargs = dict(seed=seed, model_seed=model_seed, dtype=dtype, model_config=model_config, dataset_config=dataset_config,
                                      train_config=train_config, t_initial=t_initial,
                                      server_node_rank=server_node_rank,
                                      client_node_ranks=client_node_ranks,
                                      log_queue=log_queue, loglevel=loglevel, is_trainer=True)
                # Later updates win: node settings first, then process group.
                process_kwargs.update(node_config)
                process_kwargs.update(process_group_config)

                p = torch.multiprocessing.Process(
                    target=_init_process,
                    args=[rank, world_size, self.outdir, trainer_class_name],
                    kwargs=process_kwargs)
                p.start()
                processes.append(p)
            for p in processes:
                p.join()
            # Surface child failures instead of ignoring them silently.
            for rank, p in enumerate(processes):
                if p.exitcode != 0:
                    self.logger.error(
                        f"process rank={rank} exited with code {p.exitcode}")
        except Exception as e:
            self.logger.critical(e, exc_info=True)
        finally:
            # Only reached with live children when spawning/joining was
            # interrupted; do not leave them waiting on the process group.
            for p in processes:
                if p.is_alive():
                    p.terminate()
                    p.join()
            listener.stop()
