# --------------------------------------------------------
# Copyright (C) 2020 NVIDIA Corporation. All rights reserved.
# Nvidia Source Code License-NC
# --------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import torch

from .lr_scheduler import WarmupMultiStepLR

def make_optimizer(cfg, model):
    """Build the Adam optimizer for the main model, one param group per tensor.

    Every trainable parameter (``requires_grad=True``) gets its own param
    group. Parameters whose name contains ``"bias"`` use
    ``BASE_LR * BIAS_LR_FACTOR`` and ``WEIGHT_DECAY_BIAS``; all other
    parameters use ``BASE_LR`` and ``WEIGHT_DECAY``. Frozen parameters are
    skipped.

    Args:
        cfg: config node; hyper-parameters are read from ``cfg.SOLVER``.
        model: module whose ``named_parameters()`` are optimized.

    Returns:
        A ``torch.optim.Adam`` instance.

    Raises:
        ValueError: raised by ``torch.optim.Adam`` when the model has no
            trainable parameters.
    """
    solver = cfg.SOLVER
    params = []
    for key, value in model.named_parameters():
        if not value.requires_grad:
            continue  # frozen parameters are not optimized
        lr = solver.BASE_LR
        weight_decay = solver.WEIGHT_DECAY
        if "bias" in key:
            lr = solver.BASE_LR * solver.BIAS_LR_FACTOR
            weight_decay = solver.WEIGHT_DECAY_BIAS
        params.append({"params": [value], "lr": lr, "weight_decay": weight_decay})

    # Every group above sets its own "lr" and "weight_decay", so the
    # optimizer-level defaults passed below never override them. They are
    # taken from the config rather than from the loop variables, which are
    # unbound (UnboundLocalError) when the model has no trainable parameters.
    #
    # NOTE(review): the optimizer was switched away from SGD; lr and momentum
    # should stay aligned with the WSOD0 vanilla setup when comparing runs.
    # torch.optim.Adam (not AdamW) is used, so weight_decay is applied as an
    # L2 term on the gradient, and cfg.SOLVER.MOMENTUM is not used.
    # Previous setup:
    # optimizer = torch.optim.SGD(params, lr, momentum=cfg.SOLVER.MOMENTUM)
    optimizer = torch.optim.Adam(
        params, lr=solver.BASE_LR, weight_decay=solver.WEIGHT_DECAY
    )

    return optimizer

def make_cdb_optimizer(cfg, model):
    """Build the Adam optimizer for the CDB model, one param group per tensor.

    Same grouping rules as the main optimizer, but hyper-parameters come from
    ``cfg.SOLVER_CDB``. Parameters whose name contains ``"bias"`` use
    ``BASE_LR * BIAS_LR_FACTOR`` and ``WEIGHT_DECAY_BIAS``; all other
    trainable parameters use ``BASE_LR`` and ``WEIGHT_DECAY``. Frozen
    parameters are skipped.

    Args:
        cfg: config node; hyper-parameters are read from ``cfg.SOLVER_CDB``.
        model: module whose ``named_parameters()`` are optimized.

    Returns:
        A ``torch.optim.Adam`` instance.

    Raises:
        ValueError: raised by ``torch.optim.Adam`` when the model has no
            trainable parameters.
    """
    solver = cfg.SOLVER_CDB
    params = []
    for key, value in model.named_parameters():
        if not value.requires_grad:
            continue  # frozen parameters are not optimized
        lr = solver.BASE_LR
        weight_decay = solver.WEIGHT_DECAY
        if "bias" in key:
            lr = solver.BASE_LR * solver.BIAS_LR_FACTOR
            weight_decay = solver.WEIGHT_DECAY_BIAS
        params.append({"params": [value], "lr": lr, "weight_decay": weight_decay})

    # Every group sets its own "lr" and "weight_decay"; the defaults below come
    # from the config instead of leaked loop variables, which are unbound
    # (UnboundLocalError) when the model has no trainable parameters.
    # NOTE(review): cfg.SOLVER_CDB.MOMENTUM is unused with Adam. Previous setup:
    # optimizer = torch.optim.SGD(params, lr, momentum=cfg.SOLVER_CDB.MOMENTUM)
    optimizer = torch.optim.Adam(
        params, lr=solver.BASE_LR, weight_decay=solver.WEIGHT_DECAY
    )
    # The optimizer is always Adam here, so the notice is printed unconditionally.
    print('using ADAM optimizer.')

    return optimizer

def make_lr_scheduler(cfg, optimizer):
    """Create the warmup multi-step LR scheduler for the main solver.

    Args:
        cfg: config node; STEPS, GAMMA, WARMUP_FACTOR, WARMUP_ITERS and
            WARMUP_METHOD are read from ``cfg.SOLVER``.
        optimizer: optimizer whose learning rates are scheduled.

    Returns:
        A ``WarmupMultiStepLR`` instance.
    """
    solver = cfg.SOLVER
    steps, gamma = solver.STEPS, solver.GAMMA
    warmup_options = {
        "warmup_factor": solver.WARMUP_FACTOR,
        "warmup_iters": solver.WARMUP_ITERS,
        "warmup_method": solver.WARMUP_METHOD,
    }
    return WarmupMultiStepLR(optimizer, steps, gamma, **warmup_options)

def make_lr_cdb_scheduler(cfg, optimizer):
    """Create the warmup multi-step LR scheduler for the CDB solver.

    Args:
        cfg: config node; STEPS, GAMMA, WARMUP_FACTOR, WARMUP_ITERS and
            WARMUP_METHOD are read from ``cfg.SOLVER_CDB``.
        optimizer: optimizer whose learning rates are scheduled.

    Returns:
        A ``WarmupMultiStepLR`` instance.
    """
    cdb_solver = cfg.SOLVER_CDB
    steps, gamma = cdb_solver.STEPS, cdb_solver.GAMMA
    warmup_options = {
        "warmup_factor": cdb_solver.WARMUP_FACTOR,
        "warmup_iters": cdb_solver.WARMUP_ITERS,
        "warmup_method": cdb_solver.WARMUP_METHOD,
    }
    return WarmupMultiStepLR(optimizer, steps, gamma, **warmup_options)
