"""
Optimizer utilities: build a torch optimizer from a config and assemble per-parameter optimizer groups.
"""
import torch.optim as optim

def get_optimizer(parameters, opt_config):
    """
    Construct an optimizer selected by ``opt_config.opt``.

    Args:
        parameters: passed unchanged to the optimizer constructor; either an
            iterable of tensors or of param-group dicts (such as the output of
            ``get_trainable_params``). Per-group 'lr' / 'weight_decay' entries
            override the defaults taken from ``opt_config``.
        opt_config: attribute-style config. Always read: ``opt``, ``lr``,
            ``weight_decay``. "sgd" additionally reads ``momentum``;
            "adamw" and "adamp" additionally read ``betas`` and ``eps``.

    Returns:
        The constructed optimizer (``optim.SGD``, ``optim.AdamW`` or
        ``adamp.AdamP``).

    Raises:
        ValueError: if ``opt_config.opt`` is not "sgd", "adamw" or "adamp".
    """
    name = opt_config.opt

    if name == "sgd":
        # Nesterov is always enabled; torch.optim.SGD rejects nesterov=True
        # unless momentum > 0 (with zero dampening, the default).
        return optim.SGD(
            parameters,
            lr=opt_config.lr,
            momentum=opt_config.momentum,
            weight_decay=opt_config.weight_decay,
            nesterov=True,
        )

    if name in ("adamw", "adamp"):
        # AdamW and AdamP take the same hyper-parameters; build them once.
        adam_kwargs = {
            "lr": opt_config.lr,
            "weight_decay": opt_config.weight_decay,
            "betas": opt_config.betas,
            "eps": opt_config.eps,
        }
        if name == "adamw":
            return optim.AdamW(parameters, **adam_kwargs)
        # Imported lazily so the third-party `adamp` package is only needed
        # when this optimizer is actually selected.
        from adamp import AdamP
        return AdamP(parameters, **adam_kwargs)

    raise ValueError(f"Invalid optimizer: {opt_config.opt}")

def get_trainable_params(model, train_config, verbose=False):
    """
    Return trainable parameters (for CLIP) as per-parameter optimizer groups.

    Every parameter of ``model`` with ``requires_grad`` set becomes its own
    group dict with keys 'params', 'lr' and 'weight_decay'. Frozen parameters
    are left out, and their names are always printed.

    Note:
    - weight decay is 0.0 for any parameter whose name contains "bias",
      "LayerNorm.weight", "ln_1", "ln_2" or "ln_final"; every other parameter
      uses train_config['optimizer']['weight_decay']
    - (experimental) lr for parameters whose name contains "gau" is
      train_config['optimizer']['lr'] * train_config['pde']['mul_lr']

    Args:
        model: object providing ``named_parameters()`` (e.g. an nn.Module).
        train_config: nested mapping. The 'optimizer' section is read only
            when some parameter is trainable, and 'pde' only when a trainable
            parameter name contains "gau".
        verbose: if True, print the lr / weight decay chosen for each
            trainable parameter.

    Returns:
        list of group dicts, in ``named_parameters()`` order.
    """
    no_decay_markers = ("bias", "LayerNorm.weight", "ln_1", "ln_2", "ln_final")
    param_groups = []

    for param_name, tensor in model.named_parameters():
        if not tensor.requires_grad:
            print("--------> not trainable: ", param_name)
            continue

        # Plain substring match, e.g. every name containing "bias" is exempt.
        exempt = any(marker in param_name for marker in no_decay_markers)
        weight_decay = 0.0 if exempt else train_config['optimizer']['weight_decay']

        base_lr = train_config['optimizer']['lr']
        group_lr = base_lr * train_config['pde']['mul_lr'] if "gau" in param_name else base_lr

        param_groups.append({'params': tensor, 'lr': group_lr, 'weight_decay': weight_decay})
        if verbose:
            print("----> trainable {}: lr={}, wd={}".format(param_name, group_lr, weight_decay))

    return param_groups