import torch

from .optimizer import SGD, Adam, AdamW
from .cmd_runner import CMDRunner
from torch.optim.lr_scheduler import MultiStepLR

# Registries mapping the type names used in configs to the classes that
# `build_optimizers` / `build_runner` instantiate.
OPTIMIZER_DICT = dict(SGD=SGD, Adam=Adam, AdamW=AdamW)

RUNNER_DICT = dict(CMD=CMDRunner)

SCHEDULER_DICT = dict(MultiStepLR=MultiStepLR)


def build_optimizers(models, cfg):
    """Build one optimizer (and optionally one LR scheduler) per model according to the cfg.

    Args:
        models: Mapping of models to be optimized, indexed by the same keys
            used in ``cfg.type``; ``models[key].parameters()`` is handed to the
            optimizer built for ``key``.
        cfg: Config object with attribute access. Fields read:
            ``type`` (key -> optimizer name registered in ``OPTIMIZER_DICT``),
            ``kwargs`` (key -> keyword arguments for that optimizer), and,
            optionally, ``scheduler_type`` (key -> scheduler name registered in
            ``SCHEDULER_DICT``) together with ``scheduler_kwargs``
            (key -> keyword arguments for that scheduler).
    Returns:
        tuple[dict, dict]: ``(optims, schedulers)``. ``optims`` has an entry for
        every key in ``cfg.type``; ``schedulers`` only has entries for keys
        listed in ``cfg.scheduler_type``.
    Raises:
        ValueError: If the number of configured optimizers differs from the
            number of models.
        KeyError: If an optimizer or scheduler name is not registered.
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if len(cfg.type) != len(models):
        raise ValueError(
            f'The number of optimizers should equal the number of models, '
            f'now {len(cfg.type)} optims, {len(models)} models.')

    # Schedulers are optional: a config without `scheduler_type` builds none
    # instead of failing with AttributeError.
    scheduler_types = getattr(cfg, 'scheduler_type', None) or {}
    optims, schedulers = {}, {}

    for key in cfg.type:
        optim_type = cfg.type[key]
        if optim_type not in OPTIMIZER_DICT:
            raise KeyError(
                f'Unsupported optimizer {optim_type!r} for {key!r}; '
                f'available: {sorted(OPTIMIZER_DICT)}')
        optim_args = cfg.kwargs[key]
        optims[key] = OPTIMIZER_DICT[optim_type](models[key].parameters(), **optim_args)

        if key in scheduler_types:
            scheduler_type = scheduler_types[key]
            if scheduler_type not in SCHEDULER_DICT:
                raise KeyError(
                    f'Unsupported scheduler {scheduler_type!r} for {key!r}; '
                    f'available: {sorted(SCHEDULER_DICT)}')
            scheduler_args = cfg.scheduler_kwargs[key]
            schedulers[key] = SCHEDULER_DICT[scheduler_type](optims[key], **scheduler_args)
    return optims, schedulers


def build_runner(cfg):
    """Build a runner according to the cfg.

    Args:
        cfg (dict): Mapping whose ``'type'`` entry names a runner registered in
            ``RUNNER_DICT`` (e.g. ``'CMD'``); every other entry is passed to the
            runner's constructor as a keyword argument.
    Returns:
        runner (runner.Runner): The constructed runner instance.
    Raises:
        KeyError: If ``cfg['type']`` is not a registered runner.
    """
    runner_type = cfg['type']
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if runner_type not in RUNNER_DICT:
        raise KeyError('Unsupported runner {}.'.format(runner_type))
    # Everything except the registry key is forwarded to the constructor.
    runner_args = {key: cfg[key] for key in cfg if key != 'type'}
    return RUNNER_DICT[runner_type](**runner_args)