#! -*- coding: utf-8
import os.path as path
import typing

import evaluate
import numpy as np
import torch
import torch.distributed as dist

from .anode_trainer import AClientNodeTrainer, AServerNodeTrainer
from .datasets.sampler import InnerLoopSampler

__all__ = ["GlueServerNodeTrainer", "GlueClientNodeTrainer"]


def calc_grad_norm(model: torch.nn.Module, optimizer: torch.optim.Optimizer, loader: torch.utils.data.DataLoader,
                   device: str = "cpu", dtype=torch.float32):
    """Return the per-parameter-group L2 norm of the sample-averaged gradient.

    The gradient is accumulated over one in-order pass of ``loader.dataset``.
    If the optimizer has an int ``nsample`` attribute with
    ``0 < nsample < len(dataset)``, only the first ``nsample`` items are used.
    Each batch gradient is weighted by that batch's actual number of samples.
    The sum is divided by the total sample count, so the result does not depend
    on the batch size or on a short final batch.

    Args:
        model: module whose forward returns a mapping with a ``"loss"`` entry.
        optimizer: its ``param_groups`` define how the norms are grouped; an
            optional ``nsample`` attribute limits the number of samples used.
        loader: only its ``dataset`` and ``batch_size`` are used; a fresh,
            unshuffled loader is built from them.
        device: device that tensor entries of each batch are moved to.
        dtype: unused; kept for interface compatibility.

    Returns:
        1-D float32 CPU tensor with one entry per optimizer param group. Groups
        containing no trainable parameter of ``model`` get 0.0.
    """
    dataset, batch_size = loader.dataset, loader.batch_size
    nsample = getattr(optimizer, "nsample", None)

    # Non-positive nsample means "use everything". An empty subset would leave
    # the average undefined.
    if isinstance(nsample, int) and 0 < nsample < len(dataset):
        dataset = torch.utils.data.Subset(dataset, list(range(nsample)))

    # Fresh, unshuffled loader so the estimate does not depend on the training sampler.
    loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size)

    with torch.enable_grad():
        params = [p for p in model.parameters() if p.requires_grad]
        grads = {p: torch.zeros_like(p) for p in params}
        total = 0
        for inputs in loader:
            ndata = batch_size
            batch = {}
            for k, v in inputs.items():
                if isinstance(v, torch.Tensor):
                    if v.dim() > 0:
                        ndata = v.shape[0]  # actual batch size; the last batch may be short
                    v = v.to(device)
                batch[k] = v
            loss = model(**batch)["loss"]
            # Presumably ``loss`` is a per-batch mean. Scaling its gradient by the
            # batch size turns it into a per-sample sum, so batches combine correctly.
            for p, g in zip(params, torch.autograd.grad(loss, params)):
                grads[p] += g * ndata
            total += ndata

    denom = max(total, 1)  # avoid dividing by zero on an empty dataset
    with torch.no_grad():
        norms = []
        for group in optimizer.param_groups:
            group_norms = [(grads[p] / denom).norm().detach().float().cpu()
                           for p in group["params"] if p in grads]
            if group_norms:
                # Norm of the per-parameter norms = L2 norm of the whole group's gradient.
                norms.append(torch.stack(group_norms).norm())
            else:
                norms.append(torch.tensor(0.0))

    return torch.stack(norms) if norms else torch.zeros(0)


def calc_loss(model: torch.nn.Module, dataset: torch.utils.data.Dataset,
              loss: torch.Tensor, inputs: typing.Dict[typing.Any, typing.Union[torch.Tensor, typing.Any]],
              device: str = "cpu", dtype=torch.float32,
              batch_size: int = 64, nsample: int = None, recalc: bool = True,
              seed: int = 11):
    """Return a detached CPU loss value, optionally re-evaluating ``model``.

    Args:
        model: module whose forward returns a mapping with a ``"loss"`` entry.
        dataset: dataset sampled from when ``nsample`` is not None.
        loss: loss the caller already computed. It is returned, detached and
            moved to the CPU, when ``recalc`` is False.
        inputs: batch used to recompute the loss when ``nsample`` is None.
        device: device that tensor entries of a batch are moved to.
        dtype: unused; kept for interface compatibility.
        batch_size: batch size used on the ``nsample`` path.
        nsample: None recomputes the loss on ``inputs``. Otherwise the value is
            passed as ``inner_loop`` to ``InnerLoopSampler``; ``<= 0`` is
            replaced by ``len(dataset)``.
        recalc: if False, skip recomputation entirely.
        seed: seed for the shuffling sampler.

    Returns:
        Loss tensor on the CPU with no autograd history. On the ``nsample``
        path it is the unweighted mean of the per-batch losses.
    """
    if not recalc:
        # Return the same kind of value as the recompute path below: detached,
        # on the CPU, so the caller does not hold onto the graph or device memory.
        return loss.detach().cpu() if isinstance(loss, torch.Tensor) else loss

    with torch.no_grad():
        if nsample is None:
            batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v
                     for k, v in inputs.items()}
            loss = model(**batch)["loss"]
        else:
            nsample = nsample if nsample > 0 else len(dataset)
            sampler = InnerLoopSampler(dataset, batch_size=batch_size, inner_loop=nsample,
                                       shuffle=True, seed=seed)
            loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                                 sampler=sampler)
            losses = []
            for sampled in loader:
                batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v
                         for k, v in sampled.items()}
                losses.append(model(**batch)["loss"])
            loss = torch.stack(losses, dim=0).mean(dim=0)

    return loss.detach().cpu()


class GlueServerNodeTrainer(AServerNodeTrainer):
    """Server-side node trainer for GLUE tasks.

    With gradients enabled, :meth:`train_iter` runs no forward pass of its own.
    It only calls ``optimizer.step``, which (per the original notes) performs
    parameter exchange and the update internally. With gradients disabled it
    evaluates ``model`` on ``loader`` and scores the predictions with the GLUE
    metric for ``self.task``.
    """

    def __init__(self, rank: int, world_size: int, outdir: str,
                 *args,
                 seed: int = 11, model_seed: int = 17,
                 model_config: typing.Optional[typing.Dict] = None,
                 dataset_config: typing.Optional[typing.Dict] = None,
                 train_config: typing.Optional[typing.Dict] = None,
                 name: str = "", device: str = "cpu", dtype: typing.Union[str, torch.dtype] = "float",
                 nretry_http: int = 10, nretry_http_wait: float = 120.0,
                 **kwargs):
        # Fresh dicts per instance instead of mutable defaults shared across instances.
        model_config = {} if model_config is None else model_config
        dataset_config = {} if dataset_config is None else dataset_config
        train_config = {} if train_config is None else train_config
        super().__init__(rank, world_size, outdir,
                         *args,
                         seed=seed, model_seed=model_seed,
                         model_config=model_config, dataset_config=dataset_config, train_config=train_config,
                         name=name, device=device, dtype=dtype,
                         nretry_http=nretry_http, nretry_http_wait=nretry_http_wait,
                         **kwargs)

        # NOTE(review): this discards whatever the base class stored for ``dtype``.
        # Here it is only forwarded to calc_grad_norm; confirm the override is intended.
        self.dtype = None
        # GLUE task name (e.g. "sst2").
        self.task = self.dataset_config.get("task", "sst2")
        cache_dir = self.dataset_config.get("datadir", None)
        # "datadir" is optional. path.join(None, ...) raises TypeError, so fall
        # back to evaluate's default cache location when it is absent.
        if cache_dir is not None:
            cache_dir = path.join(cache_dir, "evaluate", "glue", self.task)
        self.evaluator = evaluate.load("glue", self.task, cache_dir=cache_dir)

    @property
    def metric_keys(self) -> typing.List[str]:
        """Names of the metrics reported for ``self.task``, in return order.

        Raises:
            ValueError: if ``self.task`` is not a known GLUE task.
        """
        if self.task in ["sst2", "mnli", "mnli_mismatched", "mnli_matched", "qnli", "rte", "wnli", "hans"]:
            return ["accuracy"]
        elif self.task in ["mrpc", "qqp"]:
            return ["accuracy", "f1"]
        elif self.task in ["stsb"]:
            return ["pearson", "spearmanr"]
        elif self.task in ["cola"]:
            return ["matthews_correlation"]
        else:
            raise ValueError(f"Unknown glue task: {self.task}")

    def train_iter(self, nouter: int, ninner: int, model: torch.nn.Module, loader: torch.utils.data.DataLoader,
                   optimizer: torch.optim.Optimizer, clip_norm: float = None
                   ) -> typing.Tuple[float, int, typing.List[float]]:
        """Run one round of server work (training) or one evaluation pass.

        Args:
            nouter: number of outer repetitions; values <= 0 mean a single pass.
            ninner: number of ``optimizer.step`` calls per outer repetition
                (training mode only).
            model: model to step or evaluate.
            loader: data source. Per the original notes, it is expected to use
                InnerLoopSampler in training mode.
            optimizer: optimizer whose ``step`` accepts ``closure`` and
                ``calc_grad_norm_fn``.
            clip_norm: unused on the server; kept for interface parity with the client.

        Returns:
            ``(mean_loss, niter, metrics)``, where ``metrics`` follows
            ``metric_keys``. In training mode the loss is 0.0 and all metrics are
            NaN, because the server computes neither. An evaluation pass over an
            empty loader also yields NaN metrics.
        """
        is_grad_enabled = torch.is_grad_enabled()
        # Without an inner loop, make a single pass over the loader.
        nouter = nouter if nouter > 0 else 1

        niter = 0
        loss = 0.0
        predictions, references = [], []

        def server_loss_callback():
            """Mean loss over ``loader``. Used as the optimizer closure (e.g. inside FedProxDoL)."""
            with torch.no_grad():
                total = 0.0
                for batch in loader:
                    outputs = model(**{k: v.to(self.device) if isinstance(v, torch.Tensor) else v
                                       for k, v in batch.items()})
                    total += outputs["loss"]
                total = total / max(len(loader), 1)
            return total if isinstance(total, torch.Tensor) else torch.tensor(total)

        def grad_norm_callback():
            return calc_grad_norm(model, optimizer, loader, device=self.device, dtype=self.dtype)

        for _ in range(nouter):
            if is_grad_enabled:
                # No forward/backward here. The optimizer performs parameter exchange
                # and the update itself, and invokes the closure only during exchange.
                for _ in range(ninner):
                    niter += 1
                    optimizer.step(closure=server_loss_callback,
                                   calc_grad_norm_fn=grad_norm_callback)
            else:
                for batch in loader:
                    niter += 1
                    batch = {k: v.to(self.device) if isinstance(v, torch.Tensor) else v
                             for k, v in batch.items()}
                    outputs = model(**batch)
                    # Gradients are disabled in this branch, so the loss carries no graph.
                    loss += outputs["loss"].detach().float().cpu().item()
                    references.extend(batch["labels"].detach().cpu().numpy())
                    logits = outputs["logits"].detach()
                    if self.task == "stsb":
                        # Regression: the logits are the scores. reshape(-1) turns a
                        # (batch, 1) head output into one scalar per example and is a
                        # no-op for (batch,); presumably num_labels == 1 — confirm.
                        preds = logits.reshape(-1)
                    else:
                        preds = logits.argmax(dim=-1)
                    predictions.extend(preds.cpu().numpy())

        loss /= max(niter, 1)
        if is_grad_enabled or not predictions:
            return loss, niter, [np.nan for _ in self.metric_keys]

        metrics = self.evaluator.compute(predictions=predictions, references=references)
        self.logger.log(5, "references=%s", references)
        self.logger.log(5, "predictions=%s", predictions)
        return loss, niter, [metrics[k] for k in self.metric_keys]


class GlueClientNodeTrainer(AClientNodeTrainer):
    """Client-side node trainer for GLUE tasks.

    :meth:`train_iter` runs local optimization when gradients are enabled, or
    plain evaluation otherwise, and scores the predictions with the GLUE metric
    for ``self.task``. :meth:`calc_grad_diff_norm` sends the dataset-averaged
    gradient to the server rank.
    """

    def __init__(self, rank: int, world_size: int, outdir: str,
                 *args,
                 seed: int = 11, model_seed: int = 17,
                 model_config: typing.Optional[typing.Dict] = None,
                 dataset_config: typing.Optional[typing.Dict] = None,
                 train_config: typing.Optional[typing.Dict] = None,
                 name: str = "", device: str = "cpu", dtype: typing.Union[str, torch.dtype] = "float",
                 nretry_http: int = 10, nretry_http_wait: float = 120.0,
                 **kwargs):
        # Fresh dicts per instance instead of mutable defaults shared across instances.
        model_config = {} if model_config is None else model_config
        dataset_config = {} if dataset_config is None else dataset_config
        train_config = {} if train_config is None else train_config
        super().__init__(rank, world_size, outdir,
                         *args,
                         seed=seed, model_seed=model_seed,
                         model_config=model_config, dataset_config=dataset_config, train_config=train_config,
                         name=name, device=device, dtype=dtype,
                         nretry_http=nretry_http, nretry_http_wait=nretry_http_wait,
                         **kwargs)

        # NOTE(review): this discards whatever the base class stored for ``dtype``.
        # Here it is only forwarded to calc_loss / calc_grad_norm; confirm the
        # override is intended.
        self.dtype = None
        # GLUE task name (e.g. "sst2").
        self.task = self.dataset_config.get("task", "sst2")
        cache_dir = self.dataset_config.get("datadir", None)
        # "datadir" is optional. path.join(None, ...) raises TypeError, so fall
        # back to evaluate's default cache location when it is absent.
        if cache_dir is not None:
            cache_dir = path.join(cache_dir, "evaluate", "glue", self.task)
        self.evaluator = evaluate.load("glue", self.task, cache_dir=cache_dir)

    @property
    def metric_keys(self) -> typing.List[str]:
        """Names of the metrics reported for ``self.task``, in return order.

        Raises:
            ValueError: if ``self.task`` is not a known GLUE task.
        """
        if self.task in ["sst2", "mnli", "mnli_mismatched", "mnli_matched", "qnli", "rte", "wnli", "hans"]:
            return ["accuracy"]
        elif self.task in ["mrpc", "qqp"]:
            return ["accuracy", "f1"]
        elif self.task in ["stsb"]:
            return ["pearson", "spearmanr"]
        elif self.task in ["cola"]:
            return ["matthews_correlation"]
        else:
            raise ValueError(f"Unknown glue task: {self.task}")

    def train_iter(self, nouter: int, ninner: int, model: torch.nn.Module, loader: torch.utils.data.DataLoader,
                   optimizer: torch.optim.Optimizer, clip_norm: float = None
                   ) -> typing.Tuple[float, int, typing.List[float]]:
        """Run ``nouter`` passes over ``loader``, updating the model when gradients are enabled.

        Args:
            nouter: number of passes over ``loader``; values <= 0 mean a single pass.
            ninner: unused here; kept for interface parity with the server trainer.
            model: model to train/evaluate; its forward returns ``"loss"`` and ``"logits"``.
            loader: batches of keyword inputs that include ``"labels"``.
            optimizer: optimizer whose ``step`` accepts ``closure`` and ``calc_grad_norm_fn``.
            clip_norm: if a positive number, gradients are clipped to this total
                norm before each step. None or a non-positive value disables clipping.

        Returns:
            ``(mean_loss, niter, metrics)``, where ``metrics`` follows
            ``metric_keys`` and is computed from the pre-update outputs of each
            batch. If no batch was seen, the loss is 0.0 and all metrics are NaN.
        """
        # Without an inner loop, make a single pass over the loader.
        nouter = nouter if nouter > 0 else 1
        niter = 0
        loss = 0.0
        predictions, references = [], []
        do_clip = isinstance(clip_norm, (int, float)) and not isinstance(clip_norm, bool) and clip_norm > 0

        for iouter in range(nouter):
            for inputs in loader:
                niter += 1
                inputs = {k: v.to(self.device) if isinstance(v, torch.Tensor) else v
                          for k, v in inputs.items()}

                outputs = model(**inputs)
                batch_loss = outputs["loss"]
                if batch_loss.requires_grad:  # training mode: update the model
                    optimizer.zero_grad()
                    batch_loss.backward()
                    if do_clip:
                        torch.nn.utils.clip_grad_norm_(model.parameters(), clip_norm)
                    # Both callbacks capture this iteration's loss/inputs. They are
                    # assumed to be invoked only within this step() call.
                    optimizer.step(
                        closure=lambda recalc=False, nsample=None: calc_loss(model, loader.dataset, batch_loss, inputs,
                                                                             device=self.device, dtype=self.dtype,
                                                                             batch_size=loader.batch_size, nsample=nsample,
                                                                             recalc=recalc, seed=self.seed),
                        calc_grad_norm_fn=lambda: calc_grad_norm(model, optimizer, loader,
                                                                 device=self.device, dtype=self.dtype),
                    )

                loss += batch_loss.detach().float().cpu().item()
                references.extend(inputs["labels"].detach().cpu().numpy())
                logits = outputs["logits"].detach()
                if self.task == "stsb":
                    # Regression: the logits are the scores. reshape(-1) turns a
                    # (batch, 1) head output into one scalar per example and is a
                    # no-op for (batch,); presumably num_labels == 1 — confirm.
                    preds = logits.reshape(-1)
                else:
                    preds = logits.argmax(dim=-1)
                predictions.extend(preds.cpu().numpy())
                self.logger.log(5, "[%d.%d] loss=%f", iouter, niter, loss / niter)
                assert not np.isnan(loss)

        # Guard against an empty loader (previously a ZeroDivisionError).
        loss /= max(niter, 1)
        if not predictions:
            return loss, niter, [np.nan for _ in self.metric_keys]
        metrics = self.evaluator.compute(predictions=predictions, references=references)
        return loss, niter, [metrics[k] for k in self.metric_keys]

    @torch.no_grad()
    def calc_grad_diff_norm(self, istep: int, model: torch.nn.Module, loader: torch.utils.data.DataLoader, *args, **kwargs):
        """Compute the sample-averaged gradient over ``loader.dataset`` and send it to the server.

        Despite the name, this method only computes the flattened mean gradient
        and sends it with ``dist.isend`` (tag 12345) to ``self.server_node_rank``.
        Any difference/norm computation presumably happens on the receiving
        side — confirm against the server code. ``istep``, ``*args`` and
        ``**kwargs`` are unused.
        """
        dataset, batch_size = loader.dataset, loader.batch_size
        # Fresh, unshuffled loader: one complete in-order pass over the dataset.
        loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size)
        with torch.enable_grad():
            params = [p for p in model.parameters() if p.requires_grad]
            grads = {p: torch.zeros_like(p) for p in params}

            for inputs in loader:
                batch, ndata = {}, batch_size
                for k, v in inputs.items():
                    if isinstance(v, torch.Tensor):
                        if v.dim() > 0:
                            ndata = v.shape[0]  # actual batch size; the last batch may be short
                        v = v.to(self.device)
                    batch[k] = v
                loss = model(**batch)["loss"]
                # Weight by batch size so the final division yields a per-sample mean.
                for p, g in zip(params, torch.autograd.grad(loss, params)):
                    grads[p] += g * ndata

            flat = torch.concat([g.flatten().detach().cpu() for g in grads.values()])
            flat = flat / max(len(dataset), 1)  # average over the number of samples

        tasks = [dist.P2POp(dist.isend, flat,
                            self.server_node_rank, tag=12345)]
        for task in dist.batch_isend_irecv(tasks):
            task.wait()
