#! -*- coding: utf-8
import typing
from logging import getLogger

import torch
import numpy as np
from timm.scheduler.scheduler import Scheduler

# Names exported by ``from <this module> import *``.
# NOTE(review): calc_state_norm is not listed although its siblings
# calc_param_norm / calc_grad_norm are — confirm the omission is intentional.
__all__ = ["get_group_value",
           "clip_param_norm", "calc_param_norm", "calc_grad_norm",
           "get_server_optimizer", "get_client_optimizer"]


@torch.no_grad
def get_group_value(param_groups, key) -> float:
    v = np.nan
    for group in param_groups:
        if not key in group:
            continue
        v = group[key]
        break
    if isinstance(v, torch.Tensor):
        v = v.detach().cpu().item()
    return v


@torch.no_grad
def clip_param_norm(group, max_norm: float):
    if not isinstance(max_norm, (float, torch.Tensor)):
        return

    norm = torch.tensor([p.norm()
                        for p in group["params"] if p.requires_grad]).norm()
    if norm <= max_norm:
        return
    for p in group["params"]:
        p.data.div_(norm/max_norm)


@torch.no_grad
def calc_state_norm(group, state, key: str = None):
    return torch.stack([state[p][key].norm() for p in group["params"]
                        if p.requires_grad and key in state[p]]).norm()


@torch.no_grad
def calc_param_norm(group):
    return torch.stack([p.norm() for p in group["params"]
                        if p.requires_grad]).norm()


@torch.no_grad
def calc_grad_norm(group):
    norms = [p.grad.norm() for p in group["params"]
             if p.requires_grad and p.grad is not None]
    if len(norms) == 0:
        return torch.tensor(0.0)
    return torch.stack(norms).norm()


# Module logger. The optimizer factories below log the constructor and the
# arguments they pass to it through this logger.
# NOTE(review): the logger name is hard-coded rather than ``__name__``;
# presumably it mirrors this module's package path — confirm.
_LOGGER = getLogger("trainers.optimizers.utils")


def _get_scheduler(optimizer: torch.optim.Optimizer, t_initial: int, name: str = "CosineLRScheduler",
                   args: typing.List = [], kwargs: typing.Dict = {}) -> Scheduler:
    if name == "CosineLRScheduler":
        from timm.scheduler import CosineLRScheduler
        return CosineLRScheduler(optimizer, t_initial, *args, **kwargs)
    elif name == "StepLRScheduler":
        from timm.scheduler import StepLRScheduler
        _ = kwargs.pop("lr_min", None)
        return StepLRScheduler(optimizer, t_initial, *args, **kwargs)
    elif name == "PlateauLRScheduler":
        from timm.scheduler import PlateauLRScheduler
        return PlateauLRScheduler(optimizer, *args, **kwargs)
    elif name == "" or name is None:
        return None
    else:
        raise ValueError(f"Unsupported LRScheduler: {name}")


def get_server_optimizer(model: torch.nn.Module, rank: int,
                         t_initial: int, nouter: int, ninner: int, lr: float,
                         optimizer_config: typing.Optional[typing.Dict] = None,
                         scheduler_config: typing.Optional[typing.Dict] = None,
                         client_node_ranks: typing.Optional[typing.List[int]] = None,
                         **kwargs) -> "typing.Tuple[torch.optim.Optimizer, typing.Optional[Scheduler]]":
    """Build the server-side federated optimizer for ``model``.

    ``optimizer_config`` is only read, never modified, so the same dict can be
    reused across calls. Recognised keys:

    * ``"name"``: federated algorithm (default ``"AdamW"``, which is rejected).
      The ``<name>Server`` class is imported lazily from the matching sibling
      module.
    * ``"args"``: extra positional arguments after ``(params, rank, ninner)``.
    * ``"kwargs"``: extra keyword arguments after ``lr`` and ``client_node_ranks``.

    Args:
        model: model whose parameters the optimizer manages.
        rank: forwarded to the optimizer as its second positional argument.
        t_initial: unused (only referenced by the disabled scheduler code).
        nouter: unused.
        ninner: local step count, forwarded as the third positional argument.
        lr: learning rate.
        optimizer_config: see above.
        scheduler_config: ignored; server-side LR scheduling is disabled.
        client_node_ranks: non-empty list of client ranks, forwarded as a keyword.
        **kwargs: accepted and ignored.

    Returns:
        ``(optimizer, None)``. No scheduler is created on the server.

    Raises:
        AssertionError: if ``client_node_ranks`` is not a non-empty list.
        ValueError: if ``name`` is a plain ``torch.optim`` optimizer or ``"DoG"``
            (standalone-only), or is otherwise unsupported.
    """
    import importlib

    assert isinstance(client_node_ranks, list) and len(client_node_ranks) > 0
    optimizer_config = optimizer_config or {}
    name = optimizer_config.get("name", "AdamW")
    # Plain torch optimizers and DoG are only valid in standalone (client-only) mode.
    if hasattr(torch.optim, name) or name == "DoG":
        raise ValueError("Standalone mode!!")

    # Federated algorithm name -> sibling module that defines ``<name>Server``.
    modules = {
        "FedAvg": "fedavg",
        "FedAvgMVDR": "fedavg_mvdr",
        "FedAvgM": "fedavgm",
        "FedProx": "fedprox",
        "Scaffold": "scaffold",
        "FedDyn": "feddyn",
        "FedAdagrad": "fedadagrad",
        "FedYogi": "fedyogi",
        "FedAdam": "fedadam",
        "FedProxDoLEarly": "fedprox_dol_early",
        "FedProxDoL": "fedprox_dol",
        "FedProxDoWLEarly": "fedprox_dowl_early",
        "FedProxDoWL": "fedprox_dowl",
        "FedSPS": "fedsps",
        "FedDecSPS": "feddecsps",
    }
    if name not in modules:
        raise ValueError(f"Unsupported optimizer: {name}")
    # Lazy relative import, equivalent to ``from .<module> import <name>Server``.
    ctor = getattr(importlib.import_module("." + modules[name], __package__), f"{name}Server")

    # A missing or null "args"/"kwargs" entry means "none".
    args = optimizer_config.get("args") or []
    opt_kwargs = optimizer_config.get("kwargs") or {}
    common_kwargs = {"lr": lr, "client_node_ranks": client_node_ranks}
    _LOGGER.info("server optimizer: %s args=%s, kwargs=%s",
                 ctor, [rank, ninner, *args], {**common_kwargs, **opt_kwargs})
    # (params, rank, ninner as local_step) are the positional arguments shared
    # by every federated optimizer; lr and client_node_ranks are the shared
    # keyword arguments.
    optim = ctor(model.parameters(), rank, ninner, *args,
                 lr=lr, client_node_ranks=client_node_ranks, **opt_kwargs)

    # Server-side LR scheduling is currently disabled (the previous
    # _get_scheduler(optim, t_initial, **scheduler_config) call was commented out).
    scheduler = None
    return optim, scheduler


def get_client_optimizer(model: torch.nn.Module, rank: int,
                         t_initial: int, nouter: int, ninner: int, lr: float,
                         optimizer_config: typing.Optional[typing.Dict] = None,
                         scheduler_config: typing.Optional[typing.Dict] = None,
                         server_node_rank: typing.Optional[int] = None,
                         **kwargs) -> "typing.Tuple[torch.optim.Optimizer, typing.Optional[Scheduler]]":
    """Build a client-side optimizer and its LR scheduler for ``model``.

    ``optimizer_config`` is only read, never modified, so the same dict can be
    reused for every client. It holds ``"name"`` (default ``"AdamW"``), ``"args"``
    and ``"kwargs"``. The name selects one of two construction paths:

    * a name found in ``torch.optim``, or ``"DoG"``, builds a standalone
      optimizer ``ctor(params, *args, lr=lr, **kwargs)``;
    * a federated algorithm name builds ``<name>Client`` from the matching
      sibling module as ``ctor(params, rank, ninner, *args, lr=lr,
      server_node_rank=server_node_rank, **kwargs)``.

    In both cases the scheduler is ``_get_scheduler(optim, t_initial, **scheduler_config)``.

    Args:
        model: model whose parameters the optimizer manages.
        rank: forwarded to federated optimizers as the second positional argument.
        t_initial: forwarded to ``_get_scheduler``.
        nouter: unused.
        ninner: local step count, forwarded to federated optimizers.
        lr: learning rate.
        optimizer_config: see above.
        scheduler_config: keyword arguments for ``_get_scheduler``.
        server_node_rank: forwarded to federated optimizers as a keyword.
        **kwargs: accepted and ignored.

    Returns:
        ``(optimizer, scheduler)``; ``scheduler`` may be None.

    Raises:
        ValueError: if ``name`` is unsupported (or the scheduler name is).
    """
    import importlib

    optimizer_config = optimizer_config or {}
    scheduler_config = scheduler_config or {}
    name = optimizer_config.get("name", "AdamW")
    # A missing or null "args"/"kwargs" entry means "none".
    args = optimizer_config.get("args") or []
    opt_kwargs = optimizer_config.get("kwargs") or {}

    # Standalone mode: a plain torch optimizer or DoG, without federated arguments.
    standalone_ctor = None
    if hasattr(torch.optim, name):
        standalone_ctor = getattr(torch.optim, name)
    elif name == "DoG":
        from .dog import DoG
        standalone_ctor = DoG
    if standalone_ctor is not None:
        # Log exactly what is passed; rank/ninner/server_node_rank are not used here.
        _LOGGER.info("client optimizer: %s args=%s, kwargs=%s",
                     standalone_ctor, list(args), {"lr": lr, **opt_kwargs})
        optim = standalone_ctor(model.parameters(), *args, lr=lr, **opt_kwargs)
        scheduler = _get_scheduler(optim, t_initial, **scheduler_config)
        return optim, scheduler

    # Federated algorithm name -> sibling module that defines ``<name>Client``.
    modules = {
        "FedAvg": "fedavg",
        "FedAvgMVDR": "fedavg_mvdr",
        "FedAvgM": "fedavgm",
        "FedProx": "fedprox",
        "Scaffold": "scaffold",
        "FedDyn": "feddyn",
        "FedAdagrad": "fedadagrad",
        "FedYogi": "fedyogi",
        "FedAdam": "fedadam",
        "FedProxDoLEarly": "fedprox_dol_early",
        "FedProxDoL": "fedprox_dol",
        "FedProxDoWLEarly": "fedprox_dowl_early",
        "FedProxDoWL": "fedprox_dowl",
        "FedSPS": "fedsps",
        "FedDecSPS": "feddecsps",
    }
    if name not in modules:
        raise ValueError(f"Unsupported optimizer: {name}")
    # Lazy relative import, equivalent to ``from .<module> import <name>Client``.
    ctor = getattr(importlib.import_module("." + modules[name], __package__), f"{name}Client")

    common_kwargs = {"lr": lr, "server_node_rank": server_node_rank}
    _LOGGER.info("client optimizer: %s args=%s, kwargs=%s",
                 ctor, [rank, ninner, *args], {**common_kwargs, **opt_kwargs})
    # (params, rank, ninner as local_step) are the positional arguments shared
    # by every federated optimizer; lr and server_node_rank are the shared
    # keyword arguments.
    optim = ctor(model.parameters(), rank, ninner, *args,
                 lr=lr, server_node_rank=server_node_rank, **opt_kwargs)

    scheduler = _get_scheduler(optim, t_initial, **scheduler_config)
    return optim, scheduler
