#! -*- coding: utf-8
import typing

import numpy as np
import torch
import torch.distributed as dist

from .anode_trainer import AClientNodeTrainer, AServerNodeTrainer
from .datasets.sampler import InnerLoopSampler

# Public API of this module: the two concrete trainer classes defined below.
__all__ = ["ClassificationClientNodeTrainer",
           "ClassificationServerNodeTrainer"]


def calc_grad_norm(model: torch.nn.Module, optimizer: torch.optim.Optimizer, loader: torch.utils.data.DataLoader,
                   criterion: torch.nn.Module,
                   device: str = "cpu", dtype=torch.float32):
    """Return the dataset-averaged gradient norm of every optimizer param group.

    The gradient of ``criterion(model(x), y)`` is accumulated over a fresh,
    sequential ``DataLoader`` built from ``loader.dataset`` (same batch size,
    default sampler), weighted by batch length and divided by the number of
    samples. For each entry of ``optimizer.param_groups`` the L2 norm over the
    per-parameter gradient norms is reported.

    Args:
        model: Network to differentiate; only parameters with
            ``requires_grad=True`` are considered.
        optimizer: Supplies ``param_groups``. If it has an integer attribute
            ``nsample`` with ``0 < nsample < len(dataset)``, only the first
            ``nsample`` samples are used. A non-positive ``nsample`` means
            "use the whole dataset", matching the convention of ``calc_loss``.
        loader: Only its ``dataset`` and ``batch_size`` are read.
        criterion: Loss module called as ``criterion(output, target)``.
        device: Device the batches are moved to.
        dtype: Dtype the inputs (not the targets) are cast to.

    Returns:
        A 1-D CPU float tensor with one norm per param group. It is all zeros
        when the dataset is empty or the model has no trainable parameters.
    """
    dataset, batch_size = loader.dataset, loader.batch_size
    nsample = getattr(optimizer, "nsample", None)

    # Restrict to the leading samples only for a meaningful positive count;
    # nsample <= 0 previously produced an empty subset and a division by zero.
    if isinstance(nsample, int) and 0 < nsample < len(dataset):
        dataset = torch.utils.data.Subset(dataset,
                                          list(range(nsample)))

    ndata = len(dataset)
    params = [p for p in model.parameters()
              if p.requires_grad]

    # Nothing to differentiate: report a zero norm for every group instead of
    # failing on 0 / 0 or on autograd.grad with no inputs.
    if ndata == 0 or len(params) == 0:
        return torch.zeros(len(optimizer.param_groups))

    full_loader = torch.utils.data.DataLoader(
        dataset, batch_size=batch_size)

    with torch.enable_grad():
        grads = {p: 0.0 for p in params}
        for x, y in full_loader:
            x = x.to(device=device, dtype=dtype)
            # The target dtype must not be changed.
            y = y.to(device=device)
            output: torch.Tensor = model(x)
            batch_loss: torch.Tensor = criterion(output, y)
            # Weight each batch gradient by its size so that dividing by the
            # sample count below yields the average over all samples.
            for p, g in zip(params, torch.autograd.grad(batch_loss, params)):
                grads[p] += g * len(x)

    with torch.no_grad():
        norms = []
        for group in optimizer.param_groups:
            grad_norms = [(grads[p] / ndata).norm()
                          for p in group["params"] if p in grads]
            if len(grad_norms) > 0:
                norms.append(torch.tensor(
                    grad_norms).norm().detach().cpu().float())
            else:
                # Group without trainable parameters.
                norms.append(torch.tensor(0.0).cpu().float())

    return torch.tensor(norms)


def calc_loss(model: torch.nn.Module, dataset: torch.utils.data.Dataset,
              x: torch.Tensor, y: torch.Tensor,
              criterion: torch.nn.Module,
              loss: torch.Tensor,
              device: str = "cpu", dtype=torch.float32,
              batch_size: int = 64, nsample: int = None, recalc: bool = True,
              seed: int = 11):
    """Return a detached CPU loss value, optionally re-evaluated without grad.

    Args:
        model: Network used when the loss is recomputed.
        dataset: Dataset sampled from when ``nsample`` is given.
        x: Inputs of the current batch.
        y: Targets of the current batch.
        criterion: Loss module called as ``criterion(output, target)``.
        loss: Already computed loss, returned as is when ``recalc`` is false.
        device: Device the batches are moved to.
        dtype: Dtype the inputs (not the targets) are cast to.
        batch_size: Batch size of the loader built when ``nsample`` is given.
        nsample: ``None`` recomputes the loss on ``(x, y)`` only. Otherwise it
            is forwarded to ``InnerLoopSampler`` as ``inner_loop``; a
            non-positive value is replaced by ``len(dataset)``.
        recalc: When false, no forward pass is made.
        seed: Seed handed to ``InnerLoopSampler``.

    Returns:
        The loss as a detached CPU tensor. With ``nsample`` given it is the
        unweighted mean of the per-batch losses.
    """
    if not recalc:
        return loss.detach().cpu()

    def _batch_loss(inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        # Move one batch to the target device and evaluate the criterion.
        inputs = inputs.to(device=device, dtype=dtype)
        # The target dtype must not be changed.
        targets = targets.to(device=device)
        return criterion(model(inputs), targets)

    with torch.no_grad():
        if nsample is None:
            result: torch.Tensor = _batch_loss(x, y)
        else:
            inner_loop = nsample if nsample > 0 else len(dataset)
            sampler = InnerLoopSampler(dataset, batch_size=batch_size, inner_loop=inner_loop,
                                       shuffle=True, seed=seed)
            sampled_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                                         sampler=sampler)
            per_batch = [_batch_loss(bx, by) for bx, by in sampled_loader]
            result = torch.stack(per_batch, dim=0).mean(dim=0)

    return result.detach().cpu()


class AClassificationNodeTrainer:
    """Mixin declaring the metric names reported by classification trainers."""

    def metric_keys(self) -> typing.List:
        """Return a new list holding the single metric name ``"accuracy"``."""
        keys: typing.List = ["accuracy"]
        return keys


class ClassificationServerNodeTrainer(AServerNodeTrainer, AClassificationNodeTrainer):
    """Server-side classification trainer using a cross-entropy criterion.

    The constructor is exactly the one inherited from the base class, so none
    is defined here.
    """

    def train_iter(self, nouter: int, ninner: int, model: torch.nn.Module, loader: torch.utils.data.DataLoader,
                   optimizer: torch.optim.Optimizer, clip_norm: float = None) -> typing.Tuple[float, int, float]:
        """Run the server step loop (grad enabled) or an evaluation pass.

        With autograd enabled, ``optimizer.step`` is called ``ninner`` times
        per outer round, and no loss or accuracy is accumulated, so both are
        returned as ``0.0``. With autograd disabled, ``loader`` is iterated
        once per outer round and loss/accuracy are measured.

        Args:
            nouter: Number of outer rounds; values ``<= 0`` are treated as 1.
            ninner: Number of ``optimizer.step`` calls per outer round when
                autograd is enabled.
            model: Network to evaluate.
            loader: Source of ``(x, y)`` batches.
            optimizer: Its ``step`` must accept ``closure`` and
                ``calc_grad_norm_fn`` keyword arguments.
            clip_norm: Not used by this method.

        Returns:
            ``(mean loss per iteration, iteration count, accuracy)``.
        """
        grad_enabled = torch.is_grad_enabled()
        # Without an inner loop the loader is still walked once.
        outer_rounds = max(nouter, 1)

        criterion = torch.nn.CrossEntropyLoss()

        def server_loss_callback():
            # Callback for optimizers (e.g. FedProxDoL) that need the loss:
            # mean of the per-batch losses over the whole loader.
            with torch.no_grad():
                total = 0.0
                for bx, by in loader:
                    bx = bx.to(device=self.device, dtype=self.dtype)
                    # The target dtype must not be changed.
                    by = by.to(device=self.device)
                    batch_loss: torch.Tensor = criterion(model(bx), by)
                    total += batch_loss.detach().cpu()
                total = total / max(len(loader), 1)
            if isinstance(total, torch.Tensor):
                return total
            return torch.tensor(total)

        def grad_norm_callback():
            return calc_grad_norm(model, optimizer, loader, criterion,
                                  device=self.device, dtype=self.dtype)

        step_count = 0
        loss_sum, correct, seen = 0.0, 0.0, 0
        for _outer in range(outer_rounds):
            if grad_enabled:
                # The loader is expected to carry an InnerLoopSampler.
                # Parameter exchange and update happen inside the optimizer;
                # the closure is only invoked when parameters are exchanged.
                for _inner in range(ninner):
                    step_count += 1
                    optimizer.step(closure=server_loss_callback,
                                   calc_grad_norm_fn=grad_norm_callback,
                                   )
                continue

            # Autograd is off, so nothing here requires grad: evaluation only.
            for x, y in loader:
                step_count += 1

                x = x.to(device=self.device, dtype=self.dtype)
                # The target dtype must not be changed.
                y = y.to(device=self.device)
                logits: torch.Tensor = model(x)
                batch_loss = criterion(logits, y)

                loss_sum += batch_loss.detach().float().cpu().item()
                predicted = logits.argmax(dim=-1).detach().cpu().numpy()
                target = y.detach().cpu().numpy()
                correct += (predicted == target).sum()
                seen += len(x)

        loss_sum /= max(step_count, 1)
        correct /= max(seen, 1)  # = accuracy
        return loss_sum, step_count, correct


class ClassificationClientNodeTrainer(AClientNodeTrainer, AClassificationNodeTrainer):
    """Client-side classification trainer using a cross-entropy criterion.

    The methods read ``self.device``, ``self.dtype``, ``self.seed``,
    ``self.logger`` and ``self.server_node_rank``; these are presumably set by
    the base class constructor (not visible here).
    """

    def train_iter(self, nouter: int, ninner: int, model: torch.nn.Module, loader: torch.utils.data.DataLoader,
                   optimizer: torch.optim.Optimizer, clip_norm: float = None) -> typing.Tuple[float, int, float]:
        """Train (or evaluate, when grad is disabled) over ``loader``.

        Args:
            nouter: Number of passes over ``loader``; values ``<= 0`` are
                treated as 1.
            ninner: Not used by this method.
            model: Network to train.
            loader: Source of ``(x, y)`` batches.
            optimizer: Its ``step`` must accept ``closure`` and
                ``calc_grad_norm_fn`` keyword arguments.
            clip_norm: Maximum gradient norm. Any real number (int or float)
                enables clipping; ``None`` disables it.

        Returns:
            ``(mean loss per iteration, iteration count, accuracy)``. For an
            empty loader this is ``(0.0, 0, 0.0)``, as in the server trainer.

        Raises:
            AssertionError: If the accumulated loss becomes NaN.
        """
        # Without an inner loop the loader is still walked once.
        nouter = nouter if nouter > 0 else 1

        criterion = torch.nn.CrossEntropyLoss()
        # Accept ints as well as floats; bool is excluded because it is an int
        # subclass and is never a meaningful norm.
        do_clip = (isinstance(clip_norm, (int, float))
                   and not isinstance(clip_norm, bool))

        niter = 0
        loss, metrics, ndata = 0.0, 0.0, 0
        for iouter in range(nouter):
            for x, y in loader:
                niter += 1

                x = x.to(device=self.device, dtype=self.dtype)
                # The target dtype must not be changed.
                y = y.to(device=self.device)
                o: torch.Tensor = model(x)
                l: torch.Tensor = criterion(o, y)
                if l.requires_grad:  # update model
                    optimizer.zero_grad()
                    l.backward()
                    if do_clip:
                        torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                       clip_norm)

                    optimizer.step(
                        # The closure returns the current batch loss, or
                        # re-evaluates it when the optimizer asks for it.
                        closure=lambda recalc=False, nsample=None: calc_loss(model, loader.dataset, x, y, criterion, l,
                                                                             device=self.device, dtype=self.dtype,
                                                                             batch_size=loader.batch_size, nsample=nsample, recalc=recalc,
                                                                             seed=self.seed),
                        calc_grad_norm_fn=lambda: calc_grad_norm(model, optimizer, loader, criterion,
                                                                 device=self.device, dtype=self.dtype),
                    )

                loss += l.detach().float().cpu().item()
                metrics += (o.argmax(dim=-1).detach().cpu().numpy()
                            == y.detach().cpu().numpy()).sum()
                ndata += len(x)
                self.logger.log(5, "[%d.%d] loss=%f, acc=%f, N=%d",
                                iouter, niter, loss/niter, metrics/ndata, ndata)
                assert not np.isnan(loss), "training loss became NaN"

        # Guard against an empty loader (niter == ndata == 0), which used to
        # raise ZeroDivisionError here.
        loss /= max(niter, 1)
        metrics /= max(ndata, 1)  # = accuracy
        return loss, niter, metrics

    @torch.no_grad()
    def calc_grad_diff_norm(self, istep: int, model: torch.nn.Module, loader: torch.utils.data.DataLoader, *args, **kwargs):
        """Send the dataset-averaged loss gradient to the server rank.

        The cross-entropy gradient with respect to all trainable parameters is
        accumulated over a fresh sequential loader built from
        ``loader.dataset``, weighted by batch length, flattened and
        concatenated on CPU, divided by ``len(dataset)`` and sent to
        ``self.server_node_rank`` with tag 12345. Nothing is returned.

        Args:
            istep: Not used by this method.
            model: Network to differentiate.
            loader: Only its ``dataset`` and ``batch_size`` are read.
            *args: Ignored.
            **kwargs: Ignored.
        """
        dataset, batch_size = loader.dataset, loader.batch_size
        loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size)
        criterion = torch.nn.CrossEntropyLoss()
        # The method is decorated with no_grad, so gradients are re-enabled
        # just for the accumulation.
        with torch.enable_grad():
            params = [p for p in model.parameters()
                      if p.requires_grad]
            grads = {p: torch.zeros_like(p) for p in params}
            for x, y in loader:
                x = x.to(device=self.device, dtype=self.dtype)
                # The target dtype must not be changed.
                y = y.to(device=self.device)
                o: torch.Tensor = model(x)
                l: torch.Tensor = criterion(o, y)
                for p, g in zip(params, torch.autograd.grad(l, params)):
                    grads[p] += g * len(x)

            grads = torch.concat([g.flatten().detach().cpu()
                                 for g in grads.values()])
            grads = grads / len(dataset)  # average over the number of samples

        tasks = [dist.P2POp(dist.isend, grads,
                            self.server_node_rank, tag=12345)]
        for task in dist.batch_isend_irecv(tasks):
            task.wait()
