# -*- coding: utf-8 -*-
import itertools
import os.path as path
import time
import typing
from abc import abstractmethod
from collections.abc import Iterable
from datetime import datetime
from logging import getLogger

import numpy as np
import torch
import torch.distributed as dist

from .datasets.utils import load_dataset
from .models.utils import load_model
from .optimizers.utils import get_client_optimizer, get_server_optimizer

__all__ = ["ANodeTrainer", "AServerNodeTrainer", "AClientNodeTrainer"]


@torch.no_grad()
def mean_param(model: torch.nn.Module, params: typing.List[typing.List[torch.Tensor]]):
    """Overwrite every trainable parameter of ``model`` with an element-wise mean.

    Args:
        model: Module whose parameters with ``requires_grad`` set are
            replaced, in ``model.parameters()`` order.
        params: One list of tensors per trainable parameter, in the same
            order. The tensors of each list are stacked along a new leading
            axis and averaged over it.
    """
    trainable = [p for p in model.parameters() if p.requires_grad]
    for dest, recved in zip(trainable, params):
        # The mean is computed wherever the received tensors live and only
        # then moved to the parameter's device.
        dest.data = torch.mean(torch.stack(recved, dim=0),
                               dim=0).to(dest.device)


@torch.no_grad()
def send_model_parameter(dest_ranks: typing.List[int], model: torch.nn.Module):
    """Send the trainable parameters of ``model`` to every rank in ``dest_ranks``.

    Does nothing when ``dest_ranks`` is empty. Otherwise the send operations
    built by ``send_param`` for all destinations are issued as one batch and
    the call blocks until each of them has completed.
    """
    if len(dest_ranks) == 0:
        return

    # One flat list with the per-parameter send ops of every destination.
    pending: typing.List[torch.distributed.P2POp] = [
        op
        for dest in dest_ranks
        for op in send_param(dest, model)
    ]

    for request in torch.distributed.batch_isend_irecv(pending):
        request.wait()


@torch.no_grad()
def send_param(rank: int, model: torch.nn.Module, tagoffset: int = 0) -> typing.List[torch.distributed.P2POp]:
    """Build one ``isend`` operation per trainable parameter of ``model``.

    Args:
        rank: Peer rank passed to every ``P2POp``.
        model: Module whose parameters with ``requires_grad`` set are sent.
        tagoffset: Added to the index of each trainable parameter to form
            the message tag.

    Returns:
        The list of (not yet started) send operations.
    """
    use_cpu = torch.distributed.get_backend() == "gloo"

    ops = []
    index = 0  # counts trainable parameters only, so tags stay contiguous
    for param in model.parameters():
        if not param.requires_grad:
            continue
        # With the gloo backend a CPU copy of the parameter is sent.
        payload = param.cpu() if use_cpu else param
        ops.append(torch.distributed.P2POp(torch.distributed.isend,
                                           payload,
                                           rank, tag=index + tagoffset))
        index += 1
    return ops


@torch.no_grad()
def recv_model_parameter(src_ranks: typing.List[int], model: torch.nn.Module):
    """Receive parameters from ``src_ranks`` and average them into ``model``.

    Does nothing when ``src_ranks`` is empty. Otherwise one set of receive
    buffers is obtained from ``recv_param`` per source rank, all receives are
    issued as one batch and waited for, and ``mean_param`` then replaces
    every trainable parameter with the mean of its received buffers.
    """
    if len(src_ranks) == 0:
        return

    trainable_count = sum(1 for p in model.parameters() if p.requires_grad)
    # received[k] collects, across all sources, the buffers of the k-th
    # trainable parameter.
    received = [[] for _ in range(trainable_count)]
    pending = []

    for src in src_ranks:
        ops, buffers = recv_param(src, model)
        pending.extend(ops)
        for slot, buffer in zip(received, buffers):
            slot.append(buffer)

    for request in torch.distributed.batch_isend_irecv(pending):
        request.wait()

    mean_param(model, received)


@torch.no_grad()
def recv_param(rank: int, model: torch.nn.Module, tagoffset: int = 0) -> typing.Tuple[typing.List[torch.distributed.P2POp], typing.List[torch.Tensor]]:
    """Build one ``irecv`` operation and buffer per trainable parameter.

    Args:
        rank: Peer rank passed to every ``P2POp``.
        model: Module whose parameters with ``requires_grad`` set define the
            shape and dtype of the receive buffers.
        tagoffset: Added to the index of each trainable parameter to form
            the message tag.

    Returns:
        ``(ops, buffers)``: the (not yet started) receive operations and the
        zero-initialised tensors they write into, both in the order of the
        trainable parameters.
    """
    backend = torch.distributed.get_backend()
    tasklist, params = [], []

    for tag, p in enumerate([p for p in model.parameters() if p.requires_grad]):
        # Receive buffer: allocated on the CPU for the gloo backend,
        # otherwise on the parameter's own device.
        buff = torch.zeros_like(
            p, device="cpu" if backend == "gloo" else p.device)
        tasklist.append(torch.distributed.P2POp(torch.distributed.irecv,
                                                buff,
                                                rank, tag=tag+tagoffset))
        params.append(buff)

    return tasklist, params


class ANodeTrainer(object):
    """Base trainer for a single node.

    ``run`` loads the dataset and the model (both with retries), then calls
    ``train``, which drives ``train_iter`` for ``nstep`` steps and writes
    ``train.<name>.csv``, ``eval.<name>.csv`` and ``<name>.pth`` into
    ``outdir``.

    Subclasses are expected to override ``init_model``,
    ``distribute_dataset``, ``get_optimizer``, ``load_dataset``,
    ``train_iter`` and ``calc_grad_diff_norm``; the versions defined here
    only raise ``NotImplementedError``.
    """

    def __init__(self, rank: int, world_size: int, outdir: str,
                 *args,
                 seed: int = 11, model_seed: int = 17,
                 model_config: typing.Dict = {}, dataset_config: typing.Dict = {}, train_config: typing.Dict = {},
                 name: str = "", device: str = "cpu", dtype: typing.Union[str, torch.dtype] = "float",
                 nretry_http: int = 10, nretry_http_wait: float = 120.0,
                 **kwargs):
        """Store the node configuration.

        Args:
            rank: Stored as ``self.rank``.
            world_size: Stored as ``self.world_size``.
            outdir: Directory the CSV logs and the final weights are written to.
            seed: Seed applied to numpy and torch around data loading and training.
            model_seed: Seed applied to numpy and torch before the model is built.
            model_config: Keyword arguments for ``load_model``; the optional
                ``"name"`` entry selects the model.
            dataset_config: Stored as ``self.dataset_config`` for ``load_dataset``.
            train_config: Keyword arguments forwarded to ``train``.
            name: Logger name and file-name component of the outputs.
            device: Device the model is moved to in ``train``.
            dtype: A ``torch.dtype`` or the name of one as an attribute of
                ``torch`` (e.g. ``"float"``).
            nretry_http: Number of attempts for loading the dataset/model.
            nretry_http_wait: Seconds to sleep between two attempts.

        Extra positional and keyword arguments are accepted and ignored.
        """
        self.rank = rank
        self.world_size = world_size
        self.name = name
        self.device = device
        self.logger = getLogger(self.name)

        self.outdir = outdir
        self.dtype = getattr(torch, dtype) if not isinstance(dtype, torch.dtype) \
            else dtype
        assert isinstance(self.dtype, torch.dtype)
        self.seed = seed
        self.model_seed = model_seed

        # NOTE: the configs are stored by reference (the defaults are shared
        # dict objects), so this class only ever reads them or works on copies.
        self.model_config = model_config
        self.dataset_config = dataset_config
        self.train_config = train_config

        # Communication with Hugging Face occasionally fails; loading is
        # retried when an error occurs.
        self.nretry_http = nretry_http
        self.nretry_http_wait = nretry_http_wait

    @abstractmethod
    def init_model(self, model: torch.nn.Module):
        # override on server/client node trainer.
        raise NotImplementedError("abstract method")

    @abstractmethod
    def distribute_dataset(self, dataset: torch.utils.data.Dataset, dist_method, seed: int = 17,
                           **dist_kwargs):
        # override on server/client node trainer.
        raise NotImplementedError("abstract method")

    @abstractmethod
    def get_optimizer(self, model: torch.nn.Module, t_initial: int, nouter: int, ninner: int, lr: float,
                      optimizer_config: typing.Dict = {}, scheduler_config: typing.Dict = {}) -> typing.Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LRScheduler]:
        raise NotImplementedError("Override client/server node trainer.")

    @abstractmethod
    def load_dataset(self, ) -> typing.Tuple[torch.utils.data.Dataset, torch.utils.data.Dataset, ]:
        raise NotImplementedError("Override client/server node trainer.")

    def release_resource(self):
        """Hook called at the end of ``run``; the base class holds nothing."""
        pass

    @staticmethod
    def _seed_all(seed: int):
        """Seed numpy, torch and torch.cuda with ``seed``."""
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)

    @staticmethod
    def _csv_line(values) -> str:
        """Format ``values`` as one comma-separated line ending in a newline."""
        return ",".join(map(str, values)) + "\n"

    def _retry_http(self, action: typing.Callable[[], typing.Any], what: str):
        """Call ``action`` until it succeeds or the attempts are used up.

        At least one attempt is always made, even when ``nretry_http`` is
        smaller than 1. The exception of the last attempt is re-raised.
        """
        attempts = max(1, self.nretry_http)
        for i in range(1, attempts + 1):
            try:
                return action()
            except Exception as e:
                if i >= attempts:
                    raise  # retry budget exhausted: abort
                self.logger.warning("[%d/%d] failed load " + what + " error: %s, wait %f sec to retry.",
                                    i, attempts, e, self.nretry_http_wait,
                                    exc_info=True)
                time.sleep(self.nretry_http_wait)

    def _build_model(self) -> torch.nn.Module:
        """Create the model with ``load_model`` and pass it to ``init_model``."""
        self._seed_all(self.model_seed)
        # Work on a copy: popping from ``self.model_config`` itself would
        # lose the configured name when this is called again on a retry.
        config = dict(self.model_config)
        name = config.pop("name", "ResNet18")
        model = load_model(name, **config)
        self.init_model(model)
        self.logger.info(f"model: {model}")
        return model

    def _log_dataset_summary(self, trains, evals):
        """Log dataset sizes and, when targets are available, label counts."""
        if hasattr(trains, "targets"):  # Dataset class
            targets = trains.targets
        elif hasattr(trains, "dataset") and hasattr(trains, "indices"):
            # Subset class
            targets = np.array(trains.dataset.targets)[np.array(trains.indices)]
        else:
            # no classification
            self.logger.info(
                f"load dataset: train={len(trains)}, eval={len(evals)}")
            return

        # Labels of this node's training data.
        labels, counts = np.unique(targets, return_counts=True)
        self.logger.info(
            f"load dataset: train={len(trains)}, eval={len(evals)}"
            + ", train labels=[" + ", ".join([f"{label}: {count}" for label, count in zip(labels, counts)]) + "]")

    def run(self, t_initial: int = None):
        """Load data and model, train, and always release resources.

        Any exception raised while loading or training is logged at
        CRITICAL level and not propagated.

        Args:
            t_initial: Forwarded to ``train``.
        """
        self._seed_all(self.seed)

        try:
            # Loading may talk to Hugging Face and fail, hence the retries.
            trains, evals = self._retry_http(self.load_dataset, "data")
            self._log_dataset_summary(trains, evals)

            model = self._retry_http(self._build_model, "model")

            self._seed_all(self.seed)

            self.train(model, trains, evals, **self.train_config,
                       t_initial=t_initial)
        except Exception as e:
            self.logger.critical(e, exc_info=True)
        finally:  # finalize resources
            self.release_resource()

    @property
    def metric_keys(self) -> typing.List[str]: return ["metric"]

    def train(self,
              model: torch.nn.Module, train_dataset: torch.utils.data.Dataset, eval_dataset: torch.utils.data.Dataset,
              *args,
              nstep: int = 10, nouter: int = -1, ninner: int = -1,
              lr: float = 1e-3, batch_size: int = 32,
              eval_interval: int = 5, grad_diff_norm_interval: int = -1,
              optimizer: typing.Dict = {}, scheduler: typing.Dict = {},
              t_initial: int = None,
              **kwargs):
        """Run ``nstep`` training steps, logging to CSV, then save the weights.

        Args:
            model: Model to train; moved to ``self.device`` / ``self.dtype``.
            train_dataset: Wrapped in a shuffling ``DataLoader`` (or one using
                ``InnerLoopSampler`` when ``ninner > 0``).
            eval_dataset: Wrapped in a non-shuffling ``DataLoader``.
            nstep: Number of calls to ``train_iter`` in training mode.
            nouter, ninner: Forwarded to ``get_optimizer`` and ``train_iter``.
            lr: Forwarded to ``get_optimizer``.
            batch_size: Batch size of both loaders.
            eval_interval: Evaluate every this many steps; 0 disables it.
            grad_diff_norm_interval: When a positive int, call
                ``calc_grad_diff_norm`` every this many steps.
            optimizer, scheduler: Config dicts forwarded to ``get_optimizer``.
            t_initial: Passed to ``get_optimizer`` when it is an int,
                otherwise ``nstep`` is used.

        Extra positional and keyword arguments are forwarded to
        ``train_iter`` and ``calc_grad_diff_norm``.
        """
        self.logger.info(f"use dtype: {self.dtype}")
        model.to(self.device)
        if isinstance(self.dtype, torch.dtype):
            model.type(self.dtype)

        schedule_length = t_initial if isinstance(t_initial, int) else nstep
        self.logger.info(f"t_initial: {schedule_length}")
        optimizer, scheduler = self.get_optimizer(
            model, schedule_length,
            nouter, ninner, lr,
            optimizer_config=optimizer,
            scheduler_config=scheduler)
        if ninner <= 0:
            train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size,
                                                       shuffle=True)
        else:
            from .datasets.sampler import InnerLoopSampler
            sampler = InnerLoopSampler(train_dataset, batch_size=batch_size, inner_loop=ninner,
                                       shuffle=True, seed=self.seed)
            train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size,
                                                       sampler=sampler)
            self.logger.info(f"use data sampler: {sampler}")
        eval_loader = torch.utils.data.DataLoader(eval_dataset, batch_size=batch_size,
                                                  shuffle=False)
        self.logger.info(f"optimizer: {optimizer}, scheduler: {scheduler}")

        def lr_columns(step: int) -> list:
            # NOTE(review): relies on the scheduler's private ``_get_lr(t)``
            # returning a sequence of learning rates - confirm against the
            # scheduler implementation in use.
            if scheduler is None:
                return [np.nan]
            return list(scheduler._get_lr(step))

        # nouter corresponds to epochs and ninner to the number of batches.
        # They are set to align the iteration count when the amount of data
        # is uneven.
        with open(path.join(self.outdir, f"train.{self.name}.csv"), "wt") as trainf, \
                open(path.join(self.outdir, f"eval.{self.name}.csv"), "wt") as evalf:
            ext_log_items = ["eta", "mu", "r", "vu", "G", "rbar",
                             "rdist", "client_loss", "server_loss", "delta",
                             "server_loss_out", "server_loss_best",
                             "w", "y", "w1", "w2",
                             "local_c_norm", "global_c_norm",
                             "loss_correct",
                             "param_norm", "grad_norm", "grad_sq_norm", "gamma", "gamma_t"]
            # The extra columns are fixed once, from the attributes the
            # optimizer has now, so every later row matches the header.
            ext_keys = [key for key in ext_log_items if hasattr(optimizer, key)]

            def ext_columns() -> list:
                return [getattr(optimizer, key, np.nan) for key in ext_keys]

            # write header
            trainf.write(self._csv_line(["step", "loss", *self.metric_keys, *ext_keys,
                                         "lr", "train_proc", "comm_proc"]))
            evalf.write(self._csv_line(["step", "loss", *self.metric_keys, *ext_keys]))

            # logging initial parameters.
            ext_values, metrics = ext_columns(), [0.0 for _ in self.metric_keys]
            trainf.write(self._csv_line(
                [0, np.nan, *metrics, *ext_values]
                + lr_columns(0)
                + [0.0, getattr(optimizer, "communication_proc", np.nan)]))
            evalf.write(self._csv_line([0, np.nan, *metrics, *ext_values]))

            train_proc = 0.0  # accumulated wall-clock seconds spent in train_iter
            for istep in range(1, nstep+1):
                model.train()
                t = datetime.now()
                loss, _niter, metrics = self.train_iter(nouter, ninner, model, train_loader, optimizer,
                                                        *args, **kwargs)
                train_proc += (datetime.now() - t).total_seconds()

                if not isinstance(metrics, Iterable):
                    metrics = [metrics]

                # Sampled once per step; the eval row below reuses them.
                ext_values = ext_columns()

                trainf.write(self._csv_line(
                    [istep, loss, *metrics, *ext_values]
                    + lr_columns(istep-1)
                    + [train_proc, getattr(optimizer, "communication_proc", np.nan)]))
                self.logger.debug(f"[{istep:4d}/{nstep:4d}] train loss={loss:.3f}, "
                                  + ", ".join([f"{k}={metric:.3f}" for k, metric in zip(self.metric_keys, metrics)]))

                # eval_interval == 0 disables periodic evaluation instead of
                # dividing by zero.
                if eval_interval != 0 and istep % eval_interval == 0:
                    with torch.no_grad():
                        model.eval()

                        # Evaluate with the local model.
                        loss, _niter, metrics = self.train_iter(0, 0, model, eval_loader, optimizer,
                                                                *args, **kwargs)
                        if not isinstance(metrics, Iterable):
                            metrics = [metrics]

                        evalf.write(self._csv_line(
                            [istep, loss, *metrics, *ext_values]))
                        self.logger.info(f"[{istep:4d}/{nstep:4d}] eval loss={loss:.3f}, "
                                         + ", ".join([f"{k}={metric:.3f}" for k, metric in zip(self.metric_keys, metrics)]))
                    trainf.flush()
                    evalf.flush()
                if scheduler is not None:
                    scheduler.step(istep)
                if isinstance(grad_diff_norm_interval, int) and grad_diff_norm_interval > 0 and istep % grad_diff_norm_interval == 0:
                    self.calc_grad_diff_norm(istep, model, train_loader,
                                             *args, **kwargs)

        torch.save(model.state_dict(),
                   path.join(self.outdir, f"{self.name}.pth"))

    @abstractmethod
    def train_iter(self, nouter: int, ninner: int, model: torch.nn.Module, loader: torch.utils.data.DataLoader,
                   optimizer: torch.optim.Optimizer, clip_norm: float = None) -> typing.Tuple[float, int, float]:
        # return loss, niter, metric
        raise NotImplementedError("abstract method")

    @abstractmethod
    def calc_grad_diff_norm(self, istep: int, model: torch.nn.Module, loader: torch.utils.data.DataLoader, *args, **kwargs):
        # ``train`` calls this as (istep, model, loader, *args, **kwargs).
        raise NotImplementedError("abstract method")


class AServerNodeTrainer(ANodeTrainer):
    """Trainer base for a server node.

    Sends the initial model to ``client_node_ranks``, takes its share of the
    training data among the non-client ranks, and logs pairwise norms of
    client buffer differences to ``client_grad_diff_norms.csv`` in ``outdir``.
    """

    def __init__(self, rank: int, world_size: int, outdir: str,
                 *args,
                 seed: int = 11, model_seed: int = 17,
                 model_config: typing.Dict = {}, dataset_config: typing.Dict = {},
                 train_config: typing.Dict = {},
                 name: str = "", device: str = "cpu", dtype: typing.Union[str, torch.dtype] = "float",
                 nretry_http: int = 10, nretry_http_wait: float = 120.0,
                 client_node_ranks: typing.List[int] = [],
                 **kwargs):
        """Initialise the base trainer and open the grad-diff-norm log.

        Args:
            client_node_ranks: Ranks of the client nodes this server talks to.

        All other arguments are forwarded to ``ANodeTrainer.__init__``.
        The log file is opened here and closed by ``release_resource``.
        """
        super().__init__(rank, world_size, outdir, *args, seed=seed, model_seed=model_seed,
                         model_config=model_config, dataset_config=dataset_config, train_config=train_config,
                         name=name, device=device, dtype=dtype,
                         nretry_http=nretry_http, nretry_http_wait=nretry_http_wait,
                         **kwargs)

        self.client_node_ranks = client_node_ranks
        self.client_grad_diff_norm_log = open(path.join(outdir, "client_grad_diff_norms.csv"),
                                              "wt")
        self.client_grad_diff_norm_log.write(
            ",".join(["step", "node_a", "node_b", "grad_diff_norm"])+"\n")

    def init_model(self, model: torch.nn.Module):
        """Send the freshly initialised model to every client node."""
        # In the single-node case no server node is started.
        send_model_parameter(self.client_node_ranks, model)
        self.logger.info(
            f"send initialized model to client nodes: {self.client_node_ranks}")

    def distribute_dataset(self, dataset: torch.utils.data.Dataset, dist_method, seed: int = 17, **kwargs):
        """Split ``dataset`` among the non-client ranks and return this rank's part.

        Raises:
            ValueError: If ``self.rank`` is listed in ``client_node_ranks``.
        """
        server_nodes = [i for i in range(self.world_size)
                        if i not in self.client_node_ranks]
        self_idx = server_nodes.index(self.rank)

        from .datasets.utils import distribute_dataset
        datasets = distribute_dataset(dataset, len(server_nodes), dist_method, seed=seed,
                                      **kwargs)
        return datasets[self_idx]

    def load_dataset(self):
        """Load the dataset named in ``dataset_config`` and take this node's share.

        ``self.dataset_config`` is left untouched, so the same settings are
        used again when loading is retried after a failure.

        Returns:
            ``(trains, evals)`` where ``trains`` is the part selected by
            ``distribute_dataset`` using the ``"server"`` settings.
        """
        # Work on copies: popping from the stored config would make a retry
        # fall back to the default name/seed/distribution settings.
        config = dict(self.dataset_config)
        name = config.pop("name", "cifar-10")
        seed = config.pop("seed", 0)

        # The client settings are not used here but must not reach load_dataset.
        config.pop("client", None)
        server_dist_config = dict(config.pop("server", None) or {})

        dist_method = server_dist_config.pop("method", "even")
        dist_kwargs = dict(server_dist_config.pop("kwargs", None) or {})

        if seed != 0 and "seed" not in dist_kwargs:
            dist_kwargs["seed"] = seed
        trains, evals = load_dataset(name, **config)

        trains = self.distribute_dataset(trains, dist_method,
                                         **dist_kwargs)  # choose how to split
        return trains, evals

    def get_optimizer(self, model: torch.nn.Module, t_initial: int, nouter: int, ninner: int, lr: float,
                      optimizer_config: typing.Dict = {},
                      scheduler_config: typing.Dict = {}) -> typing.Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LRScheduler]:
        """Return the ``(optimizer, scheduler)`` pair from ``get_server_optimizer``."""
        optimizer, scheduler = get_server_optimizer(model, self.rank, t_initial, nouter, ninner, lr,
                                                    optimizer_config=optimizer_config,
                                                    scheduler_config=scheduler_config,
                                                    client_node_ranks=self.client_node_ranks, )
        return optimizer, scheduler

    def release_resource(self):
        """Close the grad-diff-norm log; safe to call more than once."""
        log = getattr(self, "client_grad_diff_norm_log", None)
        self.client_grad_diff_norm_log = None
        if log is None:
            return
        try:
            log.close()  # close() also flushes buffered rows
        except OSError as e:
            # Best effort: a failing close must not mask the real outcome.
            self.logger.warning(
                "failed to close client_grad_diff_norms.csv: %s", e)

    @torch.no_grad()
    def calc_grad_diff_norm(self, istep: int, model: torch.nn.Module, loader: torch.utils.data.DataLoader, *args, **kwargs):
        """Receive one flat buffer per client and log pairwise difference norms.

        A buffer with as many elements as ``model`` has trainable parameter
        values is received from every rank in ``client_node_ranks`` (tag
        12345); for each pair of clients the norm of the difference of their
        buffers is appended to the log together with ``istep``.
        ``loader`` and the extra arguments are not used.
        """
        nparams = sum(p.numel() for p in model.parameters()
                      if p.requires_grad)
        tasks, grads = [], {}

        for node_id in self.client_node_ranks:
            # Per the original author's note the buffer must be a CPU float
            # tensor, which is what torch.zeros creates by default.
            buf = torch.zeros(nparams)
            tasks.append(dist.P2POp(dist.irecv, buf, node_id, tag=12345))
            grads[node_id] = buf

        for task in dist.batch_isend_irecv(tasks):
            task.wait()

        for node_a, node_b in itertools.combinations(self.client_node_ranks, 2):
            grad_diff_norm = (grads[node_a] - grads[node_b]).norm().item()
            self.client_grad_diff_norm_log.write(
                ",".join(map(str, [istep, node_a, node_b, grad_diff_norm]))+"\n")
        self.client_grad_diff_norm_log.flush()


class AClientNodeTrainer(ANodeTrainer):
    """Trainer base for a client node.

    Receives the initial model from ``server_node_rank`` (unless running
    without a server) and takes its share of the training data among all
    ranks other than the server.
    """

    def __init__(self, rank: int, world_size: int, outdir: str,
                 *args,
                 seed: int = 11, model_seed: int = 17,
                 model_config: typing.Dict = {}, dataset_config: typing.Dict = {},
                 train_config: typing.Dict = {},
                 name: str = "", device: str = "cpu", dtype: typing.Union[str, torch.dtype] = "float",
                 nretry_http: int = 10, nretry_http_wait: float = 120.0,
                 server_node_rank: int = -1,
                 **kwargs):
        """Initialise the base trainer and remember the server rank.

        Args:
            server_node_rank: Rank of the server node; a negative or
                non-int value means there is no server (single node).

        All other arguments are forwarded to ``ANodeTrainer.__init__``.
        """
        super().__init__(rank, world_size, outdir, *args, seed=seed, model_seed=model_seed,
                         model_config=model_config, dataset_config=dataset_config, train_config=train_config,
                         name=name, device=device, dtype=dtype,
                         nretry_http=nretry_http, nretry_http_wait=nretry_http_wait,
                         **kwargs)

        self.server_node_rank = server_node_rank

    def init_model(self, model: torch.nn.Module):
        """Replace the model's parameters with those received from the server.

        Returns:
            ``model`` unchanged when there is no server node, otherwise
            ``None`` (the model is updated in place).
        """
        if not isinstance(self.server_node_rank, int) or self.server_node_rank < 0:
            # single node.
            return model

        recv_model_parameter([self.server_node_rank], model)
        self.logger.info("initialized model from server node.")

    def distribute_dataset(self, dataset: torch.utils.data.Dataset, dist_method, seed: int = 17, **kwargs):
        """Split ``dataset`` among the non-server ranks and return this rank's part.

        Raises:
            ValueError: If ``self.rank`` equals ``server_node_rank``.
        """
        client_nodes = [i for i in range(self.world_size)
                        if i != self.server_node_rank]
        self_idx = client_nodes.index(self.rank)

        from .datasets.utils import distribute_dataset
        datasets = distribute_dataset(dataset, len(client_nodes), dist_method, seed=seed,
                                      **kwargs)
        return datasets[self_idx]

    def load_dataset(self):
        """Load the dataset named in ``dataset_config`` and take this node's share.

        ``self.dataset_config`` is left untouched, so the same settings are
        used again when loading is retried after a failure.

        Returns:
            ``(trains, evals)`` where ``trains`` is the part selected by
            ``distribute_dataset`` using the ``"client"`` settings.
        """
        # Work on copies: popping from the stored config would make a retry
        # fall back to the default name/seed/distribution settings.
        config = dict(self.dataset_config)
        name = config.pop("name", "cifar-10")
        seed = config.pop("seed", 0)

        client_dist_config = dict(config.pop("client", None) or {})
        # The server settings are not used here but must not reach load_dataset.
        config.pop("server", None)

        dist_method = client_dist_config.pop("method", "even")
        dist_kwargs = dict(client_dist_config.pop("kwargs", None) or {})
        if seed != 0 and "seed" not in dist_kwargs:
            dist_kwargs["seed"] = seed
        trains, evals = load_dataset(name, **config)

        trains = self.distribute_dataset(trains, dist_method,
                                         **dist_kwargs)  # choose how to split
        return trains, evals

    def get_optimizer(self, model: torch.nn.Module, t_initial: int, nouter: int, ninner: int, lr: float,
                      optimizer_config: typing.Dict = {},
                      scheduler_config: typing.Dict = {}) -> typing.Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LRScheduler]:
        """Return the ``(optimizer, scheduler)`` pair from ``get_client_optimizer``."""
        optimizer, scheduler = get_client_optimizer(model, self.rank, t_initial, nouter, ninner, lr,
                                                    optimizer_config=optimizer_config,
                                                    scheduler_config=scheduler_config,
                                                    server_node_rank=self.server_node_rank)
        return optimizer, scheduler

    @abstractmethod
    def calc_grad_diff_norm(self, istep: int, model: torch.nn.Module, loader: torch.utils.data.DataLoader, *args, **kwargs):
        # To be provided by concrete client trainers.
        raise NotImplementedError()
