# -*- coding: utf-8 -*-
import os.path as path
import typing
from datetime import datetime
from logging import getLogger

import numpy as np
import torch

from .graphs.dynamic_graph import DynamicGraph
from .optimizers.utils import get_optimizer
from .utils import (merge_model_parameter, recv_model_parameter,
                    send_model_parameter)

# Names exported by a star-import of this module: the two node classes defined below.
__all__ = ["ClassificationNodeTrainer", "ClassificationNodeEvaluator"]


class ClassificationNodeTrainer(object):
    """Run classification training for a single node (rank) of a graph.

    The node loads a model and a dataset, trains for a number of steps,
    periodically evaluates on the evaluation dataset, and writes per-step
    results to ``train.<name>.csv`` / ``eval.<name>.csv`` and the final
    weights to ``<name>.pth`` inside ``outdir``.
    """

    def __init__(self, rank: int, world_size: int, outdir: str, graph: "DynamicGraph",
                 *args,
                 seed: int = 11,
                 model_seed: int = 17,
                 model_config: typing.Optional[typing.Dict] = None,
                 dataset_config: typing.Optional[typing.Dict] = None,
                 train_config: typing.Optional[typing.Dict] = None,
                 eval_node_ranks: typing.Optional[typing.List[int]] = None,
                 name: str = "", device: str = "cpu", dtype: typing.Union[str, torch.dtype] = "float",
                 **kwargs):
        """Store the node configuration; no model or data is loaded here.

        Args:
            rank: Rank of this node. Ranks below ``graph.n_nodes`` are
                treated as training nodes by :meth:`run`.
            world_size: Stored as-is; not otherwise used by this class.
            outdir: Directory that receives the CSV logs and the weights.
            graph: Graph object; ``graph.n_nodes`` is read here and the
                object is forwarded to ``get_optimizer``.
            seed: Seed applied to NumPy and torch for the training loop.
            model_seed: Seed applied while the model and datasets are loaded.
            model_config: ``{"name", "args", "kwargs"}`` for the model loader.
            dataset_config: Dataset options (``name``, ``seed``, ``method``,
                ``kwargs``; remaining keys go to the dataset loader).
            train_config: Keyword arguments forwarded to :meth:`train`.
            eval_node_ranks: Ranks passed to ``send_model_parameter``.
            name: Logger name and file-name component for the outputs.
            device: Device the model and batches are moved to.
            dtype: A ``torch.dtype`` or the name of one (e.g. ``"float"``).
            *args, **kwargs: Accepted and ignored.

        Omitted (``None``) config arguments become fresh empty containers, so
        instances never share a default dict/list.
        """
        self.rank = rank
        self.world_size = world_size
        self.name = name
        self.device = device
        self.graph = graph

        self.logger = getLogger(self.name)

        self.outdir = outdir
        self.dtype = dtype if isinstance(dtype, torch.dtype) \
            else getattr(torch, dtype)
        assert isinstance(self.dtype, torch.dtype)
        self.seed = seed
        self.model_seed = model_seed

        self.model_config = {} if model_config is None else model_config
        self.dataset_config = {} if dataset_config is None else dataset_config
        self.train_config = {} if train_config is None else train_config

        self.ntrain_node = self.graph.n_nodes
        self.eval_node_ranks = [] if eval_node_ranks is None \
            else eval_node_ranks

    @staticmethod
    def _seed_everything(seed: int) -> None:
        """Seed NumPy, torch and torch.cuda with the same value."""
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)

    @staticmethod
    def _label_summary(dataset) -> str:
        """Return the ``", train labels=[...]"`` log suffix for ``dataset``.

        Handles objects exposing ``targets`` directly and subset-like objects
        exposing ``dataset.targets`` plus ``indices``. Returns ``""`` when
        neither shape is present.
        """
        if hasattr(dataset, "targets"):  # Dataset class
            targets = dataset.targets
        elif hasattr(dataset, "dataset") and hasattr(dataset, "indices"):
            # Subset class
            targets = np.array(dataset.dataset.targets)[
                np.array(dataset.indices)]
        else:
            # no classification
            return ""
        labels, counts = np.unique(targets, return_counts=True)
        pairs = ", ".join(f"{label}: {count}"
                          for label, count in zip(labels, counts))
        return ", train labels=[" + pairs + "]"

    @staticmethod
    def _train_csv_row(istep: int, niter: int, loss: float, metric: float,
                       scheduler, optimizer, train_proc: float) -> str:
        """Format one line of ``train.<name>.csv`` (newline included).

        Columns: step, iter, loss, metric, lr, train_proc, comm_proc. The lr
        part is whatever ``scheduler._get_lr(istep - 1)`` returns (NaN when
        the scheduler has no such method). If that returns several values the
        row has more columns than the header, exactly as the concatenation
        in the original expression intended.
        """
        if hasattr(scheduler, "_get_lr"):
            lrs = list(scheduler._get_lr(istep - 1))
        else:
            lrs = [np.nan]
        # Built in three explicit pieces: the former one-liner mixed ``+``
        # with a conditional expression and dropped columns from every row.
        fields = [istep, niter, loss, metric] + lrs \
            + [train_proc, getattr(optimizer, "communication_proc", np.nan)]
        return ",".join(map(str, fields)) + "\n"

    def run(self, t_initial: int = None):
        """Load the model and datasets, then call :meth:`train`.

        Any ``Exception`` raised along the way is logged at critical level
        and not re-raised. The stored config dicts are copied before keys are
        popped, so they are left untouched and ``run`` can be repeated.

        Args:
            t_initial: Forwarded to :meth:`train`.
        """
        self._seed_everything(self.seed)

        is_train = self.rank < self.graph.n_nodes

        try:
            from .datasets.utils import distribute_dataset, load_dataset
            from .models.utils import load_model

            # load model
            self._seed_everything(self.model_seed)
            model_config = dict(self.model_config)
            name = model_config.pop("name", "ResNet18")
            model = load_model(name, *model_config.get("args", []),
                               **model_config.get("kwargs", {}))
            self.logger.info(f"model: {model}")

            # load datasets (seeded with model_seed, as the model is)
            self._seed_everything(self.model_seed)
            dataset_config = dict(self.dataset_config)
            name = dataset_config.pop("name", "cifar-10")
            seed = dataset_config.pop("seed", 0)
            dist_method = dataset_config.pop("method", "even")
            dist_kwargs = dataset_config.pop("kwargs", {})
            trains, evals = load_dataset(name, **dataset_config)
            if is_train:
                trains = distribute_dataset(trains, self.graph.n_nodes, dist_method,
                                            seed=seed, **dist_kwargs)
                trains = trains[self.rank]
                self.logger.info(
                    f"load dataset: train={len(trains)}, eval={len(evals)}"
                    + self._label_summary(trains))
            else:
                self.logger.info(
                    f"load dataset: train=this node is evaluator, eval={len(evals)}")

            self._seed_everything(self.seed)

            # run training
            self.train(model, trains, evals,
                       **self.train_config, t_initial=t_initial)

        except Exception as e:
            self.logger.critical(e, exc_info=True)
        finally:  # finalize resources
            self.logger.info("end")

    def train(self,
              model: torch.nn.Module, train_dataset: torch.utils.data.Dataset, eval_dataset: torch.utils.data.Dataset,
              *args,
              nstep: int = 10, nouter: int = -1, ninner: int = -1,
              lr: float = 1e-3,  batch_size: int = 32,
              eval_interval: int = 5,
              optimizer: typing.Optional[typing.Dict] = None,
              scheduler: typing.Optional[typing.Dict] = None,
              t_initial: int = None,
              **kwargs):
        """Train ``model`` for ``nstep`` steps and evaluate periodically.

        Args:
            model: Model to train; moved to ``self.device`` / ``self.dtype``.
            train_dataset: Dataset for the training loader.
            eval_dataset: Dataset for the evaluation loader.
            *args, **kwargs: Forwarded to :meth:`train_iter`.
            nstep: Number of steps; each step is one :meth:`train_iter` call.
            nouter, ninner: Forwarded to ``get_optimizer`` and
                :meth:`train_iter`. A positive ``ninner`` switches the
                training loader to ``InnerLoopSampler``.
            lr: Learning rate handed to ``get_optimizer``.
            batch_size: Batch size of both loaders.
            eval_interval: Evaluate every this many steps.
            optimizer: Optimizer config dict (``None`` means empty).
            scheduler: Scheduler config dict (``None`` means empty).
            t_initial: Passed to ``get_optimizer`` when it is an ``int``;
                otherwise ``nstep`` is used in its place.
        """
        self.logger.info(f"use dtype: {self.dtype}")
        model.to(self.device).type(self.dtype)

        schedule_length = t_initial if isinstance(t_initial, int) else nstep
        self.logger.info(f"t_initial: {schedule_length}")
        optim, sched = get_optimizer(model, self.rank, self.graph,
                                     schedule_length,
                                     nouter, ninner, lr,
                                     optimizer_config={} if optimizer is None else optimizer,
                                     scheduler_config={} if scheduler is None else scheduler)
        if ninner <= 0:
            train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size,
                                                       shuffle=True)
        else:
            from .datasets.sampler import InnerLoopSampler
            sampler = InnerLoopSampler(train_dataset, batch_size=batch_size, inner_loop=ninner,
                                       shuffle=True, seed=self.seed)
            train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size,
                                                       sampler=sampler)
            self.logger.info(f"use data sampler: {sampler}")
        eval_loader = torch.utils.data.DataLoader(eval_dataset, batch_size=batch_size,
                                                  shuffle=False)
        self.logger.info(f"optimizer: {optim}, scheduler: {sched}")

        with open(path.join(self.outdir, f"train.{self.name}.csv"), "wt") as trainf, \
                open(path.join(self.outdir, f"eval.{self.name}.csv"), "wt") as evalf:
            # write header
            trainf.write(",".join(["step", "iter", "loss", "metric", "lr",
                                   "train_proc", "comm_proc"])+"\n")
            evalf.write(",".join(["step", "iter", "loss", "metric"])+"\n")

            train_proc = 0.0  # cumulative seconds spent inside train_iter
            for istep in range(1, nstep+1):
                model.train()
                started = datetime.now()
                loss, niter, metric = self.train_iter(nouter, ninner, model, train_loader, optim,
                                                      *args, **kwargs)
                train_proc += (datetime.now() - started).total_seconds()
                trainf.write(self._train_csv_row(
                    istep, niter, loss, metric, sched, optim, train_proc))
                self.logger.debug(
                    f"[{istep:4d}/{nstep:4d}] train loss={loss:.3f}, metric={metric:.3f}")

                if istep % eval_interval == 0:
                    with torch.no_grad():
                        model.eval()
                        started = datetime.now()
                        send_model_parameter(self.eval_node_ranks, model)
                        send_proc = (datetime.now() - started).total_seconds()
                        self.logger.info(
                            f"send model prameter to eval node: ranks={self.eval_node_ranks}, proc={send_proc}")

                        # Under no_grad the loss carries no graph, so
                        # train_iter only measures and never updates.
                        loss, niter, metric = self.train_iter(0, 0, model, eval_loader, optim,
                                                              *args, **kwargs)
                        evalf.write(
                            ",".join(map(str, [istep, niter, loss, metric]))+"\n")
                        self.logger.info(
                            f"[{istep:4d}/{nstep:4d}] eval loss={loss:.3f}, metric={metric:.3f}")
                    trainf.flush()
                    evalf.flush()
                if sched is not None:
                    sched.step(istep)

        torch.save(model.state_dict(),
                   path.join(self.outdir, f"{self.name}.pth"))

    def train_iter(self, nouter: int, ninner: int, model: torch.nn.Module, loader: torch.utils.data.DataLoader,
                   optimizer: torch.optim.Optimizer, force_merge_parameter_interval: int = -1,
                   criterion: str = "CrossEntropyLoss") -> typing.Tuple[float, int, float]:
        """Pass over ``loader`` ``max(nouter, 1)`` times; update if grads are on.

        A batch triggers an optimizer update only when its loss requires
        grad, so calling this inside ``torch.no_grad()`` is a pure
        evaluation pass.

        Args:
            nouter: Number of passes over ``loader`` (values <= 0 mean one).
            ninner: Unused here; kept for signature compatibility.
            model: Model called on each batch.
            loader: Iterable yielding ``(x, y)`` batches.
            optimizer: Optimizer stepped on each training batch. Its optional
                ``comm_cnt`` attribute drives the forced merge.
            force_merge_parameter_interval: When positive and the graph has
                more than one node, ``merge_model_parameter`` is called over
                all training ranks whenever ``comm_cnt`` is a positive
                multiple of it.
            criterion: Name of a loss class in ``torch.nn`` (instantiated
                without arguments), or a callable ``criterion(output, y)``.
                Anything else falls back to ``CrossEntropyLoss``.

        Returns:
            ``(mean loss per batch, number of batches, accuracy)``. For a
            loader that yields nothing this is ``(nan, 0, nan)``.
        """
        nouter = nouter if nouter > 0 else 1

        if isinstance(criterion, str):
            criterion = getattr(torch.nn, criterion)()
        elif not callable(criterion):
            criterion = torch.nn.CrossEntropyLoss()

        n_batches = 0
        loss_sum, n_correct, n_samples = 0.0, 0, 0
        for _ in range(nouter):
            for x, y in loader:
                n_batches += 1

                x = x.to(device=self.device, dtype=self.dtype)
                y = y.to(device=self.device)
                output: torch.Tensor = model(x)
                batch_loss: torch.Tensor = criterion(output, y)

                if batch_loss.requires_grad:  # update model
                    assert torch.isfinite(batch_loss), \
                        f"loss is NaN or Inf: {batch_loss}"
                    optimizer.zero_grad()
                    batch_loss.backward()
                    optimizer.step()

                    com_count = getattr(optimizer, "comm_cnt", -1)
                    if (self.graph.n_nodes > 1 and com_count > 0
                            and force_merge_parameter_interval > 0
                            and com_count % force_merge_parameter_interval == 0):
                        merge_model_parameter(self.rank, list(
                            range(self.graph.n_nodes)), model)
                        self.logger.info(
                            f"force merge model parameter: {com_count}")

                loss_sum += batch_loss.detach().float().cpu().item()
                n_correct += int((output.argmax(dim=-1) == y).sum().item())
                n_samples += len(x)

        if n_batches == 0 or n_samples == 0:
            # Nothing to average: report NaN instead of dividing by zero.
            self.logger.warning(
                "loader yielded no data; returning NaN loss and metric")
            return float("nan"), n_batches, float("nan")

        return loss_sum / n_batches, n_batches, n_correct / n_samples  # metric = acc


class ClassificationNodeEvaluator(ClassificationNodeTrainer):
    """Evaluation-only counterpart of the trainer.

    Overrides ``train`` so that no optimisation happens: at every evaluation
    step the node calls ``recv_model_parameter`` for the training ranks and
    scores the model on the evaluation dataset.
    """

    def train(self,
              model: torch.nn.Module, train_dataset: torch.utils.data.Dataset, eval_dataset: torch.utils.data.Dataset,
              *args,
              nstep: int = 10, nouter: int = -1, ninner: int = -1,
              lr: float = 1e-3,  batch_size: int = 32,
              eval_interval: int = 5,
              optimizer: typing.Dict = {}, scheduler: typing.Dict = {},
              t_initial: int = None,
              **kwargs):
        """Evaluate ``model`` every ``eval_interval`` steps up to ``nstep``.

        ``train_dataset``, ``nouter``, ``ninner``, ``lr``, ``scheduler`` and
        ``t_initial`` are accepted for signature parity and are not used.
        ``optimizer`` is only passed through to ``train_iter``; ``*args`` and
        ``**kwargs`` are forwarded to it as well.

        Results go to ``eval.<name>.csv`` and the final ``state_dict`` to
        ``<name>.pth``, both in ``self.outdir``.
        """
        model.to(self.device).type(self.dtype)

        # Ranks 0 .. n_nodes-1 are the ones parameters are received from.
        source_ranks = [rank for rank in range(self.graph.n_nodes)]
        loader = torch.utils.data.DataLoader(eval_dataset, shuffle=False,
                                             batch_size=batch_size)
        csv_path = path.join(self.outdir, f"eval.{self.name}.csv")

        with open(csv_path, "wt") as evalf:
            # write header
            evalf.write("step,iter,loss,metric\n")

            for istep in range(1, nstep+1):
                if istep % eval_interval != 0:
                    continue  # not an evaluation step

                with torch.no_grad():
                    model.eval()
                    started = datetime.now()
                    recv_model_parameter(source_ranks, model)
                    elapsed = datetime.now() - started
                    recv_proc = elapsed.total_seconds()
                    self.logger.info(
                        f"recv model prameter from train node: ranks={source_ranks}, proc={recv_proc}")

                    # Gradients are disabled, so train_iter only measures.
                    loss, niter, metric = self.train_iter(0, 0, model, loader, optimizer,
                                                          *args, **kwargs)
                    record = (istep, niter, loss, metric)
                    evalf.write(",".join(str(value) for value in record)+"\n")
                    self.logger.info(
                        f"[{istep:4d}/{nstep:4d}] eval loss={loss:.3f}, metric={metric:.3f}")
                evalf.flush()

        weights_path = path.join(self.outdir, f"{self.name}.pth")
        torch.save(model.state_dict(), weights_path)
