#! -*- coding: utf-8
import logging.handlers
import os
import typing
from datetime import timedelta
from logging import getLogger

import numpy as np
import torch
import torch.distributed

from log import LogLevel, config_logger

from .graphs.dynamic_graph import DynamicGraph
from .graphs.utils import get_graph


def _init_process(rank, world_size, outdir: str, graph: DynamicGraph,
                  trainer_class_name: str, *args,
                  t_initial: int = None,
                  backend: str = "gloo", init_method: str = None, timeout: float = 1800.0, process_group: str = "",
                  master_addr: str = "127.0.0.1", master_port: int = 29501,
                  log_queue: torch.multiprocessing.Queue = None,
                  loglevel: LogLevel = LogLevel.INFO, logfile: str = None,
                  is_trainer: bool = True,
                  **kwargs):
    """Child-process entry point: build one node, join the process group, run it.

    Args:
        rank: rank of this process in the process group.
        world_size: total number of processes in the process group.
        outdir: output directory forwarded to the node constructor.
        graph: graph object forwarded to the node constructor.
        trainer_class_name: node family, ``"ClassificationNodeTrainer"`` or
            ``"GlueNodeTrainer"``; combined with ``is_trainer`` it selects the
            trainer or evaluator class of that family.
        *args: extra positional arguments for the node constructor.
        t_initial: forwarded to the node's ``run(t_initial=...)``.
        backend, init_method, timeout, process_group: forwarded to
            ``torch.distributed.init_process_group`` (``timeout`` is in seconds,
            ``process_group`` is passed as ``group_name``).
        master_addr, master_port: exported as ``MASTER_ADDR``/``MASTER_PORT``
            when ``init_method`` is None.
        log_queue: queue that this process's log records are sent to through a
            ``QueueHandler``; also forwarded to the node constructor.
        loglevel: level for ``logging.basicConfig``; also forwarded to the node.
        logfile: accepted for interface compatibility; not used here.
        is_trainer: build the trainer class if True, the evaluator otherwise.
        **kwargs: extra keyword arguments for the node constructor.

    Raises:
        ValueError: if ``trainer_class_name`` is not a supported node family.
    """
    handler = logging.handlers.QueueHandler(log_queue)
    logging.basicConfig(level=loglevel, handlers=[handler])

    # Node classes are imported lazily so only the selected family's module is
    # loaded in this process.
    if trainer_class_name == "ClassificationNodeTrainer":
        if is_trainer:
            from .classification_node_trainer import ClassificationNodeTrainer as node_class
        else:
            from .classification_node_trainer import ClassificationNodeEvaluator as node_class
    elif trainer_class_name == "GlueNodeTrainer":
        if is_trainer:
            from .glue_node_trainer import GlueNodeTrainer as node_class
        else:
            from .glue_node_trainer import GlueNodeEvaluator as node_class
    else:
        # Previously this fell through and failed later with UnboundLocalError.
        raise ValueError(
            f"unknown trainer_class_name {trainer_class_name!r}; expected "
            f"'ClassificationNodeTrainer' or 'GlueNodeTrainer'")

    trainer = node_class(rank, world_size, outdir, graph, *args,
                         log_queue=log_queue, loglevel=loglevel, **kwargs)

    # set init process group
    if init_method is None:
        # With no explicit init_method, torch.distributed falls back to the
        # env:// rendezvous, which reads MASTER_ADDR / MASTER_PORT.
        os.environ["MASTER_ADDR"] = master_addr
        os.environ["MASTER_PORT"] = str(master_port)
        trainer.logger.info(
            f"backend={backend}, rank={rank}/{world_size}, master addr={master_addr}:{master_port}")
    else:
        trainer.logger.info(
            f"backend={backend}, rank={rank}/{world_size}, init_method={init_method}")
    torch.distributed.init_process_group(backend, init_method=init_method, timeout=timedelta(seconds=timeout),
                                         world_size=world_size, rank=rank, group_name=process_group)

    trainer.run(t_initial=t_initial)


class Trainer(object):
    """Spawn and supervise one process per configured train/eval node.

    Every spawned process runs ``_init_process``, which builds the node
    object, joins a ``torch.distributed`` process group and calls the node's
    ``run``. Train nodes get ranks ``0 .. ntrain_node - 1``; eval nodes get the
    ranks that follow.
    """

    def __init__(self, outdir: str, node_config: typing.Dict, *args, **kwargs):
        """Read the per-node configuration.

        Args:
            outdir: output directory passed to every spawned node.
            node_config: mapping whose optional ``"nodes"`` entry holds
                ``"train"`` and ``"eval"`` lists of per-node dicts; a missing
                entry means no nodes of that kind.
            *args, **kwargs: accepted for interface compatibility; unused.
        """
        self.outdir = outdir
        self.node_config = node_config
        nodes = node_config.get("nodes", {})
        # each per-node entry, e.g. {name: "node-0", device: "cuda:0"}; its keys
        # are forwarded as keyword arguments to the spawned process.
        self.train_node_configs: typing.List[typing.Dict[str, str]] = nodes.get("train", [])
        self.eval_node_configs: typing.List[typing.Dict[str, str]] = nodes.get("eval", [])
        self.ntrain_node = len(self.train_node_configs)
        self.neval_node = len(self.eval_node_configs)
        self.logger = getLogger(__name__)
        self.logger.info(f"train node config: {self.train_node_configs}")
        self.logger.info(f"eval node config: {self.eval_node_configs}")

    def get_graph(self, name, *args, **kwargs) -> DynamicGraph:
        """Build graph ``name`` via ``.graphs.utils.get_graph`` for the train nodes.

        The number of train nodes is passed as the second argument; ``*args``
        and ``**kwargs`` are forwarded unchanged.
        """
        return get_graph(name, self.ntrain_node, *args, **kwargs)

    def _spawn_node(self, rank: int, world_size: int, graph: DynamicGraph,
                    trainer_class_name: str, node_kwargs: typing.Dict):
        """Start an ``_init_process`` child for ``rank`` and return the process."""
        process = torch.multiprocessing.Process(
            target=_init_process,
            args=[rank, world_size, self.outdir, graph, trainer_class_name],
            kwargs=node_kwargs)
        process.start()
        return process

    def run(self, config: typing.Dict,
            t_initial: typing.Optional[int] = None,
            backend: str = "gloo", init_method: typing.Optional[str] = None, timeout: float = 1800.0,
            process_group: str = "", master_addr: str = "127.0.0.1", master_port: int = 29501,
            dtype: typing.Union[str, torch.dtype] = "float",
            *args, **kwargs):
        """Launch all node processes, wait for them, and forward their logs.

        Keys read from ``config``: ``"graph"`` (``name``/``args``/``kwargs``
        for :meth:`get_graph`), ``"seed"`` (default 11), ``"model_seed"``
        (falls back to ``seed`` if missing or not an int), ``"model"``,
        ``"dataset"``, ``"train"`` (its ``"name"`` selects the node family,
        default ``"ClassificationNodeTrainer"``) and ``"process_group"``.
        ``config`` itself is not modified.

        ``t_initial`` and ``model_seed`` are passed to train nodes only.
        The process-group arguments of this method always override same-named
        entries of ``config["process_group"]``.

        Exceptions raised while setting up or waiting are logged at CRITICAL
        level and not re-raised; any child still alive at that point is
        terminated. ``*args``/``**kwargs`` are accepted but unused.
        """
        # Only set the start method when needed: calling set_start_method when
        # it is already "spawn" raises "context has already been set".
        if torch.multiprocessing.get_start_method(allow_none=True) != "spawn":
            torch.multiprocessing.set_start_method("spawn")

        # multi-processing log: children push records into log_queue and the
        # listener hands them to the root logger's last handler (stderr if the
        # root logger has no handler, instead of raising IndexError).
        root_handlers = getLogger().handlers
        sink = root_handlers[-1] if root_handlers else logging.StreamHandler()
        log_queue = torch.multiprocessing.Queue()
        listener = logging.handlers.QueueListener(log_queue, sink,
                                                  respect_handler_level=False)
        listener.start()

        processes = []
        try:
            loglevel = int(max(getLogger().level, self.logger.level))

            self.logger.info(f"config: {config}")
            # Copy so the pops/updates below do not mutate the caller's dicts.
            config = dict(config)

            graph_config = config.get("graph", {})
            graph = self.get_graph(graph_config.get("name", "BaseGraph"),
                                   *graph_config.get("args", []), **graph_config.get("kwargs", {}))
            self.logger.info(f"graph: {graph}")

            seed = config.pop("seed", 11)
            model_seed = config.pop("model_seed", seed)
            if not isinstance(model_seed, int):
                model_seed = seed
            model_config = config.pop("model", {})
            dataset_config = config.pop("dataset", {})
            train_config = dict(config.get("train", {}))
            process_group_config = dict(config.get("process_group", {}))
            process_group_config.update(backend=backend, init_method=init_method, timeout=timeout,
                                        process_group=process_group,
                                        master_addr=master_addr, master_port=master_port)
            trainer_class_name = train_config.pop("name",
                                                  "ClassificationNodeTrainer")

            world_size = self.ntrain_node + self.neval_node
            eval_node_ranks = list(range(self.ntrain_node, world_size))

            for rank, node_config in enumerate(self.train_node_configs):
                node_kwargs = dict(seed=seed, model_seed=model_seed, dtype=dtype, model_config=model_config,
                                   dataset_config=dataset_config, train_config=train_config,
                                   t_initial=t_initial, eval_node_ranks=eval_node_ranks,
                                   log_queue=log_queue, loglevel=loglevel, is_trainer=True)
                node_kwargs.update(node_config)
                node_kwargs.update(process_group_config)
                processes.append(self._spawn_node(rank, world_size, graph, trainer_class_name, node_kwargs))

            for rank, node_config in zip(eval_node_ranks, self.eval_node_configs):
                node_kwargs = dict(seed=seed, dtype=dtype, model_config=model_config,
                                   dataset_config=dataset_config, train_config=train_config,
                                   eval_node_ranks=eval_node_ranks,
                                   log_queue=log_queue, loglevel=loglevel, is_trainer=False)
                node_kwargs.update(node_config)
                node_kwargs.update(process_group_config)
                processes.append(self._spawn_node(rank, world_size, graph, trainer_class_name, node_kwargs))

            for p in processes:
                p.join()
                if p.exitcode != 0:
                    self.logger.error(f"node process {p.name} (pid={p.pid}) exited with code {p.exitcode}")
        except Exception as e:
            self.logger.critical(e, exc_info=True)
            # Children already started would otherwise block in
            # init_process_group waiting for peers that never arrive, and the
            # parent would hang joining them at interpreter exit.
            for p in processes:
                if p.is_alive():
                    p.terminate()
            for p in processes:
                p.join()
        finally:
            listener.stop()
