from functools import partial

import torch.nn as nn
import torch.optim as optim
import torch.optim.lr_scheduler as lr_sched

from .fastai_optim import OptimWrapper
from .learning_schedules_fastai import CosineWarmupLR, OneCycle


def _leaf_modules(module: nn.Module) -> list:
    """Return the child-less sub-modules of ``module`` in depth-first order.

    A module that has no children is returned as a single-element list
    containing the module itself.
    """
    kids = list(module.children())
    if not kids:
        return [module]
    leaves = []
    for kid in kids:
        # extend() keeps traversal order without re-copying the accumulated
        # list at every level, as repeated list concatenation would.
        leaves.extend(_leaf_modules(kid))
    return leaves


def _build_onecycle_optimizer(model: nn.Module, optimizer_cls, lr, weight_decay):
    """Create an ``OptimWrapper`` over a single layer group holding all leaves of ``model``.

    Args:
        model: module whose leaf sub-modules are packed into one ``nn.Sequential``.
        optimizer_cls: optimizer class to wrap; it is called with ``betas=(0.9, 0.99)``.
        lr: learning rate handed to ``OptimWrapper.create``.
        weight_decay: forwarded to ``OptimWrapper.create`` as ``wd``.

    Returns:
        Whatever ``OptimWrapper.create`` returns.
    """
    layer_groups = [nn.Sequential(*_leaf_modules(model))]
    optimizer_func = partial(optimizer_cls, betas=(0.9, 0.99))
    return OptimWrapper.create(
        optimizer_func, lr, layer_groups, wd=weight_decay, true_wd=True, bn_wd=True
    )


def build_optimizer(model, optim_cfg):
    """Build the optimizer selected by ``optim_cfg.OPTIMIZER``.

    Supported values are ``'adam'``, ``'adamW'``, ``'sgd'``, ``'adam_onecycle'``
    and ``'adamW_onecycle'``.

    Args:
        model: ``nn.Module`` whose parameters will be optimized.
        optim_cfg: config object providing ``OPTIMIZER``, ``WEIGHT_DECAY`` and
            (for every choice except ``'adam_onecycle'``) ``LR``; ``MOMENTUM``
            is additionally read for ``'sgd'``.

    Returns:
        A ``torch.optim`` optimizer for ``'adam'``/``'adamW'``/``'sgd'``, or the
        object returned by ``OptimWrapper.create`` for the one-cycle variants.

    Raises:
        NotImplementedError: if ``optim_cfg.OPTIMIZER`` is not a supported name.
    """
    name = optim_cfg.OPTIMIZER

    if name == 'adam':
        return optim.Adam(model.parameters(), lr=optim_cfg.LR, weight_decay=optim_cfg.WEIGHT_DECAY)

    if name == 'adamW':
        return optim.AdamW(model.parameters(), lr=optim_cfg.LR, weight_decay=optim_cfg.WEIGHT_DECAY)

    if name == 'sgd':
        return optim.SGD(
            model.parameters(), lr=optim_cfg.LR, weight_decay=optim_cfg.WEIGHT_DECAY,
            momentum=optim_cfg.MOMENTUM
        )

    if name == 'adam_onecycle':
        # NOTE(review): this variant uses a fixed 3e-3 rather than optim_cfg.LR,
        # unlike 'adamW_onecycle'. Kept as-is because existing configs may rely
        # on it; presumably the one-cycle scheduler drives the effective LR --
        # confirm before unifying.
        return _build_onecycle_optimizer(model, optim.Adam, 3e-3, optim_cfg.WEIGHT_DECAY)

    if name == 'adamW_onecycle':
        return _build_onecycle_optimizer(model, optim.AdamW, optim_cfg.LR, optim_cfg.WEIGHT_DECAY)

    # Name the offending value so a config typo is easy to spot.
    raise NotImplementedError('Unsupported OPTIMIZER: {!r}'.format(name))


def build_scheduler(optimizer, total_iters_each_epoch, total_epochs, last_epoch, optim_cfg):
    """Build the learning-rate scheduler selected by ``optim_cfg.SCHEDULER``.

    * ``'STEP'``     -> ``MultiStepLR`` with milestones expressed in iterations.
    * ``'OneCycle'`` -> ``OneCycle`` spanning ``total_iters_each_epoch * total_epochs`` steps.
    * anything else  -> ``LambdaLR`` applying a clipped multi-step decay.

    Args:
        optimizer: optimizer handed to the scheduler.
        total_iters_each_epoch: multiplier that turns the entries of
            ``optim_cfg.DECAY_STEP_LIST`` into milestones.
        total_epochs: used together with ``total_iters_each_epoch`` to size
            the one-cycle schedule.
        last_epoch: ``-1`` for a fresh run. For ``'STEP'`` any other value is
            converted to ``max((last_epoch - 1) * total_iters_each_epoch, 0)``;
            the fallback branch forwards it to ``LambdaLR`` unchanged.
        optim_cfg: config object; the attributes read depend on the branch
            (``DECAY_STEP_LIST``, ``LR_DECAY``, ``LR_CLIP``, ``LR``, ``MOMS``,
            ``DIV_FACTOR``, ``PCT_START``).

    Returns:
        Tuple ``(lr_scheduler, lr_warmup_scheduler)``. Warm-up is currently
        disabled, so the second element is always ``None``.
    """
    warmup_scheduler = None  # warm-up scheduling is not built at the moment
    n_total_steps = total_iters_each_epoch * total_epochs

    if optim_cfg.SCHEDULER == 'STEP':
        milestones = [epoch * total_iters_each_epoch for epoch in optim_cfg.DECAY_STEP_LIST]
        # -1 means "start from scratch"; otherwise resume at an iteration index
        # derived from the epoch value, never below zero.
        resume_iter = -1 if last_epoch == -1 else max((last_epoch - 1) * total_iters_each_epoch, 0)
        step_scheduler = lr_sched.MultiStepLR(
            optimizer,
            milestones=milestones,
            gamma=optim_cfg.LR_DECAY,
            last_epoch=resume_iter,
        )
        return step_scheduler, warmup_scheduler

    if optim_cfg.SCHEDULER == 'OneCycle':
        momentums = list(optim_cfg.MOMS)
        one_cycle = OneCycle(
            optimizer, n_total_steps, optim_cfg.LR, momentums, optim_cfg.DIV_FACTOR, optim_cfg.PCT_START
        )
        return one_cycle, warmup_scheduler

    # Fallback: multiplicative decay at each milestone, clipped from below.
    milestones = [epoch * total_iters_each_epoch for epoch in optim_cfg.DECAY_STEP_LIST]

    def decay_multiplier(step):
        """Return the LR multiplier for ``step``, floored at ``LR_CLIP / LR``."""
        multiplier = 1
        for milestone in milestones:
            if step < milestone:
                continue
            multiplier = multiplier * optim_cfg.LR_DECAY
        return max(multiplier, optim_cfg.LR_CLIP / optim_cfg.LR)

    # NOTE(review): unlike the 'STEP' branch, last_epoch is not converted to
    # an iteration index here although the milestones are scaled by
    # total_iters_each_epoch -- confirm against the caller.
    lambda_scheduler = lr_sched.LambdaLR(optimizer, decay_multiplier, last_epoch=last_epoch)
    return lambda_scheduler, warmup_scheduler
