# -*- coding: utf-8 -*-
import os.path as path
import typing
from datetime import datetime
from logging import getLogger

import evaluate
import numpy as np
import torch

from .graphs.dynamic_graph import DynamicGraph
from .optimizers.utils import get_optimizer
from .utils import (merge_model_parameter, recv_model_parameter,
                    send_model_parameter)

# Public API of this module: the training-node class and the evaluation-only
# node class defined below.
__all__ = ["GlueNodeTrainer", "GlueNodeEvaluator"]


class GlueNodeTrainer(object):
    def __init__(self, rank: int, world_size: int, outdir: str, graph: DynamicGraph,
                 *args,
                 seed: int = 11,
                 model_seed: int = 17,
                 model_config: typing.Optional[typing.Dict] = None,
                 dataset_config: typing.Optional[typing.Dict] = None,
                 train_config: typing.Optional[typing.Dict] = None,
                 eval_node_ranks: typing.Optional[typing.List[int]] = None,
                 name: str = "",
                 device: str = "cpu", dtype: typing.Union[str, torch.dtype] = "float",
                 **kwargs):
        """Store the node configuration and load the GLUE metric for the task.

        Args:
            rank: Rank of this node.
            world_size: Total number of ranks; stored as-is.
            outdir: Directory used for the CSV logs and the saved weights.
            graph: Communication graph; ``graph.n_nodes`` is stored as the
                number of training nodes.
            *args: Accepted and ignored.
            seed: Seed for the training phase.
            model_seed: Seed applied while the model and datasets are loaded.
            model_config: Model settings (``name``, ``args``, ``kwargs``,
                ``lora``, ``quantization``). ``None`` means empty.
            dataset_config: Dataset settings. ``task`` (default ``"sst2"``)
                selects the GLUE metric; ``datadir`` is the optional cache root.
                ``None`` means empty.
            train_config: Keyword arguments forwarded to ``train``. ``None``
                means empty.
            eval_node_ranks: Ranks of the evaluation nodes. ``None`` means none.
            name: Logger name; also used in the output file names.
            device: Device string the model and batches are moved to.
            dtype: A ``torch.dtype`` or the name of one (e.g. ``"float"``).
            **kwargs: Accepted and ignored.

        Raises:
            AttributeError: If ``dtype`` is a string that ``torch`` does not define.
            AssertionError: If ``dtype`` does not resolve to a ``torch.dtype``.
        """
        self.rank = rank
        self.world_size = world_size
        self.name = name
        self.device = device
        self.graph = graph

        self.logger = getLogger(self.name)

        self.outdir = outdir
        self.dtype = dtype if isinstance(dtype, torch.dtype) \
            else getattr(torch, dtype)
        assert isinstance(self.dtype, torch.dtype), \
            f"dtype must resolve to a torch.dtype, got {dtype!r}"
        self.seed = seed
        self.model_seed = model_seed

        # Keep private shallow copies: a default shared between instances, or
        # a caller-owned dict, must never be affected by in-place edits made
        # through this object.
        self.model_config = dict(model_config) if model_config else {}
        self.dataset_config = dict(dataset_config) if dataset_config else {}
        self.train_config = dict(train_config) if train_config else {}

        self.ntrain_node = self.graph.n_nodes
        self.eval_node_ranks = list(eval_node_ranks) if eval_node_ranks else []

        self.task = self.dataset_config.get("task", "sst2")
        datadir = self.dataset_config.get("datadir", None)
        # Without a configured datadir let `evaluate` pick its own cache
        # location; joining onto None would raise TypeError.
        cache_dir = None if datadir is None \
            else path.join(datadir, "evaluate", "glue", self.task)
        self.evaluator = evaluate.load("glue", self.task, cache_dir=cache_dir)

    def run(self, t_initial: typing.Optional[int] = None):
        """Load the model and datasets, then hand over to ``train``.

        The model and the datasets are loaded under ``model_seed``; the
        training phase runs under ``seed``. Ranks below ``graph.n_nodes`` are
        training nodes and receive their own shard of the training data; any
        other rank keeps the undistributed training set.

        Any exception raised while loading or training is logged at CRITICAL
        level and swallowed, so this method always returns ``None``.

        Args:
            t_initial: Forwarded to ``train`` as ``t_initial``.
        """
        self._seed_everything(self.seed)

        is_train = self.rank < self.graph.n_nodes

        try:
            from .datasets.utils import distribute_dataset, load_dataset
            from .models.utils import load_model

            # load model
            self._seed_everything(self.model_seed)
            # Read (do not pop) the name so the stored config stays intact
            # and run() can be called more than once.
            model_name = self.model_config.get("name", "ResNet18")
            model = load_model(model_name, *self.model_config.get("args", []),
                               **self.model_config.get("kwargs", {}),
                               lora=self.model_config.get("lora", None),
                               quantization=self.model_config.get("quantization", None))
            self.logger.info(f"model: {model}")

            # load datasets
            self._seed_everything(self.model_seed)
            # Pop from a local copy: what is left over is forwarded to
            # load_dataset, while self.dataset_config itself is untouched.
            dataset_config = dict(self.dataset_config)
            dataset_name = dataset_config.pop("name", "cifar-10")
            seed = dataset_config.pop("seed", 0)
            dist_method = dataset_config.pop("method", "even")
            dist_kwargs = dataset_config.pop("kwargs", {})
            trains, evals = load_dataset(dataset_name, **dataset_config)
            if is_train:
                trains = distribute_dataset(trains, self.graph.n_nodes, dist_method,
                                            seed=seed, **dist_kwargs)
                trains = trains[self.rank]
                self._log_dataset_summary(trains, evals)
            else:
                self.logger.info(
                    f"load dataset: train=this node is evaluator, eval={len(evals)}")

            self._seed_everything(self.seed)

            # train
            self.train(model, trains, evals,
                       **self.train_config, t_initial=t_initial)

        except Exception as e:
            # Top-level boundary of the node: report the failure, do not re-raise.
            self.logger.critical(e, exc_info=True)
        finally:
            self.logger.info("end")

    @staticmethod
    def _seed_everything(seed: int) -> None:
        """Seed NumPy and the torch CPU/CUDA random number generators."""
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)

    def _log_dataset_summary(self, trains, evals) -> None:
        """Log dataset sizes and, when targets are exposed, per-label counts."""
        summary = f"load dataset: train={len(trains)}, eval={len(evals)}"

        if hasattr(trains, "targets"):  # Dataset class
            targets = trains.targets
        elif hasattr(trains, "dataset") and hasattr(trains, "indices"):
            # Subset class: look the targets up through the subset indices
            targets = np.array(trains.dataset.targets)[np.array(trains.indices)]
        else:
            # no classification targets available
            targets = None

        if targets is not None:
            labels, counts = np.unique(targets, return_counts=True)
            summary += ", train labels=[" \
                + ", ".join(f"{label}: {count}" for label, count in zip(labels, counts)) \
                + "]"
        self.logger.info(summary)

    @property
    def metric_keys(self) -> typing.List[str]:
        """Names of the metric values reported for ``self.task``.

        Returns:
            A new list of metric names, in the order they are written to the
            CSV logs.

        Raises:
            ValueError: If ``self.task`` is not one of the tasks listed below.
        """
        # (tasks, metric names) pairs, checked in order.
        task_groups = (
            (("sst2", "mnli", "mnli_mismatched", "mnli_matched",
              "qnli", "rte", "wnli", "hans"), ("accuracy",)),
            (("mrpc", "qqp"), ("accuracy", "f1")),
            (("stsb",), ("pearson", "spearmanr")),
            (("cola",), ("matthews_correlation",)),
        )
        for tasks, keys in task_groups:
            if self.task in tasks:
                return list(keys)
        raise ValueError(f"Unknown glue task: {self.task}")

    def train(self,
              model: torch.nn.Module, train_dataset: torch.utils.data.Dataset, eval_dataset: torch.utils.data.Dataset,
              *args,
              nstep: int = 10, nouter: int = -1, ninner: int = -1,
              lr: float = 1e-3,  batch_size: int = 32,
              eval_interval: int = 5,
              optimizer: typing.Optional[typing.Dict] = None,
              scheduler: typing.Optional[typing.Dict] = None,
              t_initial: typing.Optional[int] = None,
              **kwargs):
        """Train ``model`` for ``nstep`` steps, evaluating periodically.

        Writes ``train.<name>.csv`` and ``eval.<name>.csv`` into ``self.outdir``
        and finally saves the model ``state_dict`` as ``<name>.pth``.

        Args:
            model: Model to train; moved to ``self.device``.
            train_dataset: Training data of this node.
            eval_dataset: Evaluation data.
            *args: Forwarded to ``train_iter``.
            nstep: Number of steps; each step is one ``train_iter`` call.
            nouter: Passes over the loader per step (forwarded to
                ``get_optimizer`` and ``train_iter``).
            ninner: When positive, batches are drawn through
                ``InnerLoopSampler``; otherwise a shuffled loader is used.
            lr: Learning rate handed to ``get_optimizer``.
            batch_size: Batch size of both loaders.
            eval_interval: Evaluate every ``eval_interval`` steps.
            optimizer: Optimizer config for ``get_optimizer``; ``None`` means ``{}``.
            scheduler: Scheduler config for ``get_optimizer``; ``None`` means ``{}``.
            t_initial: When an ``int``, passed to ``get_optimizer`` instead of
                ``nstep``.
            **kwargs: Forwarded to ``train_iter``.
        """
        model.to(self.device)

        # An explicit integer t_initial takes precedence over the step count.
        horizon = t_initial if isinstance(t_initial, int) else nstep
        self.logger.info(f"t_initial: {horizon}")
        # Separate names for the built objects keep the config arguments intact.
        optim, sched = get_optimizer(model, self.rank, self.graph,
                                     horizon,
                                     nouter, ninner, lr,
                                     optimizer_config={} if optimizer is None else optimizer,
                                     scheduler_config={} if scheduler is None else scheduler)
        if ninner <= 0:
            train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size,
                                                       shuffle=True)
        else:
            from .datasets.sampler import InnerLoopSampler
            sampler = InnerLoopSampler(train_dataset, batch_size=batch_size, inner_loop=ninner,
                                       shuffle=True, seed=self.seed)
            train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size,
                                                       sampler=sampler)
            self.logger.info(f"use data sampler: {sampler}")
        eval_loader = torch.utils.data.DataLoader(eval_dataset, batch_size=batch_size,
                                                  shuffle=False)
        self.logger.info(f"optimizer: {optim}, scheduler: {sched}")

        metric_keys = self.metric_keys
        metric_fmt = ", ".join([f"{k}=%.3f" for k in metric_keys])

        with open(path.join(self.outdir, f"train.{self.name}.csv"), "wt") as trainf, \
                open(path.join(self.outdir, f"eval.{self.name}.csv"), "wt") as evalf:
            # write header
            trainf.write(",".join(["step", "iter", "loss", *metric_keys,
                                   "lr", "train_proc", "comm_proc"])+"\n")
            evalf.write(",".join(["step", "iter", "loss",
                                  *metric_keys])+"\n")

            train_proc = 0.0  # accumulated seconds spent inside train_iter
            for istep in range(1, nstep+1):
                model.train()
                t = datetime.now()
                loss, niter, metrics = self.train_iter(nouter, ninner, model, train_loader, optim,
                                                       *args, **kwargs)
                train_proc += (datetime.now() - t).total_seconds()

                # Resolve the lr column(s) first. Writing the conditional
                # inline between the `+` operands makes it swallow the
                # neighbouring columns, because `x if c else y` binds looser
                # than `+`.
                # NOTE(review): the header has a single "lr" column, so
                # _get_lr is presumably expected to return one value - confirm.
                lr_cols = list(sched._get_lr(istep-1)) if hasattr(sched, "_get_lr") \
                    else [np.nan]
                row = [istep, niter, loss, *metrics, *lr_cols,
                       train_proc, getattr(optim, "communication_proc", np.nan)]
                trainf.write(",".join(map(str, row))+"\n")
                self.logger.debug("[%4d/%4d] train loss=%.3f, " + metric_fmt,
                                  istep, nstep, loss, *metrics)

                if istep % eval_interval == 0:
                    with torch.no_grad():
                        model.eval()
                        send_proc = datetime.now()
                        send_model_parameter(self.eval_node_ranks, model)
                        send_proc = (datetime.now() -
                                     send_proc).total_seconds()
                        self.logger.info(
                            f"send model prameter to eval node: ranks={self.eval_node_ranks}, proc={send_proc}")

                        # Under no_grad the loss carries no gradient, so
                        # train_iter only scores the model here.
                        loss, niter, metrics = self.train_iter(0, 0, model, eval_loader, optim,
                                                               *args, **kwargs)
                        evalf.write(
                            ",".join(map(str, [istep, niter, loss, *metrics]))+"\n")
                        self.logger.info("[%4d/%4d] eval loss=%.3f, " + metric_fmt,
                                         istep, nstep, loss, *metrics)
                    trainf.flush()
                    evalf.flush()
                if sched is not None:
                    sched.step(istep)

        torch.save(model.state_dict(),
                   path.join(self.outdir, f"{self.name}.pth"))

    def train_iter(self, nouter: int, ninner: int, model: torch.nn.Module, loader: torch.utils.data.DataLoader,
                   optimizer: torch.optim.Optimizer, force_merge_parameter_interval: int = -1,) -> typing.Tuple[float, int, float]:
        nouter = nouter if nouter > 0 else 1

        iter = 0
        loss, metrics, ndata = 0.0, 0.0, 0
        predictions, references = [], []
        for iouter in range(nouter):
            for inputs in loader:
                iter += 1

                inputs = {k: v.to(self.device) if isinstance(v, torch.Tensor) else v
                          for k, v in inputs.items()}
                outputs = model(**inputs)
                l = outputs["loss"]

                if l.requires_grad:  # update model
                    assert not (torch.isnan(l) or torch.isinf(l)), \
                        f"loss is NaN or Inf: {l}"
                    optimizer.zero_grad()
                    l.backward()
                    optimizer.step()

                    com_count = getattr(optimizer, "comm_cnt", -1)
                    if self.graph.n_nodes > 1 and com_count > 0 and force_merge_parameter_interval > 0 and com_count % force_merge_parameter_interval == 0:
                        merge_model_parameter(self.rank, list(
                            range(self.graph.n_nodes)), model)
                        self.logger.info(
                            f"force merge model parameter: {com_count}")

                loss += l.detach().float().cpu().item()
                references.extend(inputs["labels"].detach().cpu().numpy())
                logits = outputs["logits"]
                if self.task != "stsb":
                    logits = logits.argmax(dim=-1)
                predictions.extend(logits.detach().cpu().numpy())
        loss /= iter
        metrics = self.evaluator.compute(predictions=predictions,
                                         references=references)
        return loss, iter, [metrics[k] for k in self.metric_keys]


class GlueNodeEvaluator(GlueNodeTrainer):
    """Evaluation-only node.

    Instead of training, this node calls ``recv_model_parameter`` for the
    training ranks every ``eval_interval`` steps and scores the resulting
    model on the evaluation dataset.
    """

    def train(self,
              model: torch.nn.Module, train_dataset: torch.utils.data.Dataset, eval_dataset: torch.utils.data.Dataset,
              *args,
              nstep: int = 10, nouter: int = -1, ninner: int = -1,
              lr: float = 1e-3,  batch_size: int = 32,
              eval_interval: int = 5,
              optimizer: typing.Optional[typing.Dict] = None,
              scheduler: typing.Optional[typing.Dict] = None,
              t_initial: typing.Optional[int] = None,
              **kwargs):
        """Evaluate received parameters every ``eval_interval`` steps.

        Writes ``eval.<name>.csv`` into ``self.outdir`` and finally saves the
        model ``state_dict`` as ``<name>.pth``.

        Args:
            model: Model that receives the parameters; moved to ``self.device``.
            train_dataset: Unused on this node.
            eval_dataset: Data the received model is scored on.
            *args: Forwarded to ``train_iter``.
            nstep: Number of steps to walk through.
            eval_interval: Receive and evaluate every ``eval_interval`` steps.
            batch_size: Batch size of the evaluation loader.
            nouter, ninner, lr, optimizer, scheduler, t_initial: Accepted so
                the signature matches ``GlueNodeTrainer.train``; unused here.
            **kwargs: Forwarded to ``train_iter``.
        """
        model.to(self.device)

        train_node_ranks = list(range(self.graph.n_nodes))
        eval_loader = torch.utils.data.DataLoader(eval_dataset, batch_size=batch_size,
                                                  shuffle=False)
        metric_keys = self.metric_keys
        metric_fmt = ", ".join([f"{k}=%.3f" for k in metric_keys])

        with open(path.join(self.outdir, f"eval.{self.name}.csv"), "wt") as evalf:
            # write header
            evalf.write(
                ",".join(["step", "iter", "loss", *metric_keys])+"\n")

            for istep in range(1, nstep+1):
                if istep % eval_interval != 0:
                    continue

                with torch.no_grad():
                    model.eval()
                    recv_proc = datetime.now()
                    recv_model_parameter(train_node_ranks, model)
                    recv_proc = (datetime.now() -
                                 recv_proc).total_seconds()
                    self.logger.info(
                        f"recv model prameter from train node: ranks={train_node_ranks}, proc={recv_proc}")

                    # No optimizer exists on this node, and none is needed:
                    # under no_grad the loss never requires grad, so
                    # train_iter skips its update branch.
                    loss, niter, metrics = self.train_iter(0, 0, model, eval_loader, None,
                                                           *args, **kwargs)
                    evalf.write(
                        ",".join(map(str, [istep, niter, loss, *metrics]))+"\n")
                    self.logger.info("[%4d/%4d] eval loss=%.3f, " + metric_fmt,
                                     istep, nstep, loss, *metrics)
                evalf.flush()

        torch.save(model.state_dict(),
                   path.join(self.outdir, f"{self.name}.pth"))
