#! -*- coding: utf-8
import typing

import torch
from timm.scheduler.scheduler import Scheduler

from ..graphs.dynamic_graph import DynamicGraph

# Only the optimizer factory is public; the scheduler helper stays private.
__all__ = [
    "get_optimizer",
]


def _get_scheduler(optimizer: torch.optim.Optimizer, t_initial: int, name: str = "CosineLRScheduler",
                   args: typing.List = [], kwargs: typing.Dict = {}) -> Scheduler:
    if name == "CosineLRScheduler":
        from timm.scheduler import CosineLRScheduler
        return CosineLRScheduler(optimizer, t_initial, *args, **kwargs)
    elif name == "StepLRScheduler":
        from timm.scheduler import StepLRScheduler
        return StepLRScheduler(optimizer, *args, **kwargs)
    elif name == "PlateauLRScheduler":
        from timm.scheduler import PlateauLRScheduler
        return PlateauLRScheduler(optimizer, *args, **kwargs)
    elif name == "" or name is None:
        return None
    else:
        raise ValueError(f"Unsupported LRScheduler: {name}")


def _get_exchanged_optimizer_ctor(name: str) -> typing.Callable[..., typing.Any]:
    """Resolve a decentralized optimizer name to the base optimizer class it wraps.

    Raises:
        ValueError: If ``name`` is not one of ``"Gossip"``, ``"DSGD"``,
            ``"DAdam"`` or ``"DAdamW"``.
    """
    if name == "Gossip":
        from .gossip import GossipOptimizer
        return GossipOptimizer
    if name == "DSGD":
        from .dsgd import DSGDOptimizer
        return DSGDOptimizer
    if name == "DAdam":
        return torch.optim.Adam
    if name == "DAdamW":
        return torch.optim.AdamW
    raise ValueError(f"Unsupported optimizer: {name}")


def get_optimizer(model: torch.nn.Module, rank: int, graph: DynamicGraph,
                  t_initial: int, nouter: int, ninner: int, lr: float,
                  optimizer_config: typing.Optional[typing.Dict] = None,
                  scheduler_config: typing.Optional[typing.Dict] = None,
                  **kwargs) -> typing.Tuple[torch.optim.Optimizer, typing.Optional[Scheduler]]:
    """Create the optimizer (and optional LR scheduler) for node ``rank``.

    ``optimizer_config`` may contain ``"name"`` (default ``"AdamW"``),
    ``"args"`` and ``"kwargs"``. The optimizer constructor receives
    ``model.parameters()`` first, then ``args`` and ``kwargs``; ``lr`` is
    always passed as a keyword. The config is only read, never modified, so
    the same dict can be reused across calls. Dispatch on ``name``:

    * any attribute of ``torch.optim`` (e.g. ``"SGD"``, ``"AdamW"``): built
      directly and returned unwrapped;
    * ``"Scaffold"``: ``Scaffold(params, rank, graph, ninner, ...)``, returned
      unwrapped;
    * ``"Gossip"``, ``"DSGD"``, ``"DAdam"`` (``torch.optim.Adam``) and
      ``"DAdamW"`` (``torch.optim.AdamW``): the base optimizer is wrapped in a
      ``MeanModelParameterExchanger`` with ``local_step=ninner``.

    Args:
        model: Model whose parameters are optimized.
        rank: Node id, passed to ``Scaffold`` and as the exchanger's ``node_id``.
        graph: Graph passed to ``Scaffold`` and to the exchanger.
        t_initial: Forwarded to ``_get_scheduler``.
        nouter: Unused here; kept for call-site compatibility.
        ninner: Passed to ``Scaffold`` and as the exchanger's ``local_step``.
        lr: Learning rate given to the optimizer constructor.
        optimizer_config: Optimizer spec described above; ``None`` means ``{}``.
        scheduler_config: Keyword arguments (``name``/``args``/``kwargs``) for
            ``_get_scheduler``; ``None`` means ``{}``.
        **kwargs: Accepted and ignored.

    Returns:
        ``(optimizer, scheduler)``. ``scheduler`` is ``None`` when the scheduler
        config disables it. For the wrapped names the scheduler is bound to the
        inner optimizer, not to the returned wrapper.

    Raises:
        ValueError: If ``name`` is not a supported optimizer.
    """
    optimizer_config = {} if optimizer_config is None else optimizer_config
    scheduler_config = {} if scheduler_config is None else scheduler_config
    # Read with .get (never .pop) so the caller's config is left intact.
    name = optimizer_config.get("name", "AdamW")
    opt_args = optimizer_config.get("args", [])
    opt_kwargs = optimizer_config.get("kwargs", {})

    if hasattr(torch.optim, name):  # stand-alone optimizer, no parameter exchange
        optim = getattr(torch.optim, name)(model.parameters(), *opt_args, lr=lr, **opt_kwargs)
        return optim, _get_scheduler(optim, t_initial, **scheduler_config)

    if name == "Scaffold":
        from .scaffold import Scaffold
        optim = Scaffold(model.parameters(), rank, graph, ninner,
                         *opt_args, lr=lr, **opt_kwargs)
        return optim, _get_scheduler(optim, t_initial, **scheduler_config)

    ctor = _get_exchanged_optimizer_ctor(name)
    optim = ctor(model.parameters(), *opt_args, lr=lr, **opt_kwargs)

    # The scheduler is attached to the inner optimizer before it is wrapped.
    scheduler = _get_scheduler(optim, t_initial, **scheduler_config)

    from .param_exchanger import MeanModelParameterExchanger
    # NOTE(review): tag_offset=1000 presumably keeps the exchanger's message
    # tags apart from other traffic — confirm against MeanModelParameterExchanger.
    optim = MeanModelParameterExchanger(model, optim, node_id=rank, graph=graph, local_step=ninner,
                                        tag_offset=1000)

    return optim, scheduler
