from torch import optim as optim


def build_optimizer(config, trainable_list):
    """
    Create an optimizer for the module stored at ``trainable_list[0]``.

    The module's parameters are split by ``set_weight_decay`` into a decayed
    group and a group with ``weight_decay=0`` (1-D tensors, names ending in
    ``.bias``, names returned by the module's optional ``no_weight_decay()``
    hook, and names containing a keyword from its optional
    ``no_weight_decay_keywords()`` hook).

    Args:
        config: config object; reads ``TRAIN.OPTIMIZER.NAME`` (compared
            case-insensitively against 'sgd' / 'adamw'), ``TRAIN.BASE_LR``,
            ``TRAIN.WEIGHT_DECAY`` and ``TRAIN.OPTIMIZER.MOMENTUM`` (SGD) or
            ``TRAIN.OPTIMIZER.EPS`` / ``TRAIN.OPTIMIZER.BETAS`` (AdamW).
        trainable_list: indexable collection; only element 0 is used.

    Returns:
        ``optim.SGD`` (Nesterov momentum) or ``optim.AdamW``, or ``None``
        when the configured name matches neither.
    """
    net = trainable_list[0]

    skip_names = net.no_weight_decay() if hasattr(net, 'no_weight_decay') else {}
    skip_kw = (net.no_weight_decay_keywords()
               if hasattr(net, 'no_weight_decay_keywords') else {})
    param_groups = set_weight_decay(net, skip_names, skip_kw)

    opt_cfg = config.TRAIN.OPTIMIZER
    name = opt_cfg.NAME.lower()
    if name not in ('sgd', 'adamw'):
        return None

    shared = {'lr': config.TRAIN.BASE_LR, 'weight_decay': config.TRAIN.WEIGHT_DECAY}
    if name == 'sgd':
        return optim.SGD(param_groups, momentum=opt_cfg.MOMENTUM, nesterov=True, **shared)
    return optim.AdamW(param_groups, eps=opt_cfg.EPS, betas=opt_cfg.BETAS, **shared)

def build_optimizer_shadow(config, trainable_list):
    """
    Build one optimizer for each of the "cnn" and "inn" modules in ``trainable_list``.

    Each optimizer receives all of its module's parameters in a single group,
    so ``config.TRAIN.WEIGHT_DECAY`` is applied uniformly. No parameters
    (biases, normalization weights, ...) are exempted from weight decay.

    Args:
        config: config object; reads ``TRAIN.OPTIMIZER.NAME`` (compared
            case-insensitively against 'sgd' / 'adamw'), ``TRAIN.BASE_LR``,
            ``TRAIN.WEIGHT_DECAY`` and ``TRAIN.OPTIMIZER.MOMENTUM`` (SGD) or
            ``TRAIN.OPTIMIZER.EPS`` / ``TRAIN.OPTIMIZER.BETAS`` (AdamW).
        trainable_list: indexable collection; element 1 is the "cnn" module
            and element 2 the "inn" module. Element 0 is not used here.

    Returns:
        Tuple ``(optimizer_cnn, optimizer_inn)``. When the configured name is
        not recognised, both entries are ``None``. This mirrors
        ``build_optimizer``, which returns ``None`` in that case, instead of
        failing on unbound local variables.
    """
    trainable_list_cnn = trainable_list[1]
    trainable_list_inn = trainable_list[2]

    opt_cfg = config.TRAIN.OPTIMIZER
    opt_lower = opt_cfg.NAME.lower()

    if opt_lower == 'sgd':
        optimizer_cls = optim.SGD
        extra = {'momentum': opt_cfg.MOMENTUM, 'nesterov': True}
    elif opt_lower == 'adamw':
        optimizer_cls = optim.AdamW
        extra = {'eps': opt_cfg.EPS, 'betas': opt_cfg.BETAS}
    else:
        # Unknown optimizer name: return placeholders rather than crashing
        # on unbound locals.
        return None, None

    common = dict(lr=config.TRAIN.BASE_LR, weight_decay=config.TRAIN.WEIGHT_DECAY, **extra)
    # Same construction order as before (inn first, then cnn).
    optimizer_shadow_inn = optimizer_cls(trainable_list_inn.parameters(), **common)
    optimizer_shadow_cnn = optimizer_cls(trainable_list_cnn.parameters(), **common)

    return optimizer_shadow_cnn, optimizer_shadow_inn

def build_optimizer_cinn(config, model):
    """
    Create an optimizer for ``model`` with selective weight decay.

    ``set_weight_decay`` places 1-D tensors, ``.bias`` parameters, names from
    the model's optional ``no_weight_decay()`` hook, and names containing a
    keyword from its optional ``no_weight_decay_keywords()`` hook into a
    group with ``weight_decay=0``. All other trainable parameters use
    ``config.TRAIN.WEIGHT_DECAY``.

    Args:
        config: config object; reads ``TRAIN.OPTIMIZER.NAME`` (compared
            case-insensitively against 'sgd' / 'adamw'), ``TRAIN.BASE_LR``,
            ``TRAIN.WEIGHT_DECAY`` and ``TRAIN.OPTIMIZER.MOMENTUM`` (SGD) or
            ``TRAIN.OPTIMIZER.EPS`` / ``TRAIN.OPTIMIZER.BETAS`` (AdamW).
        model: object exposing ``named_parameters()``.

    Returns:
        ``optim.SGD`` (Nesterov momentum), ``optim.AdamW``, or ``None`` for an
        unrecognised optimizer name.
    """
    no_decay_names = {}
    if hasattr(model, 'no_weight_decay'):
        no_decay_names = model.no_weight_decay()
    no_decay_kw = {}
    if hasattr(model, 'no_weight_decay_keywords'):
        no_decay_kw = model.no_weight_decay_keywords()
    groups = set_weight_decay(model, no_decay_names, no_decay_kw)

    train_cfg = config.TRAIN
    requested = train_cfg.OPTIMIZER.NAME.lower()

    # Lazy factories: only the selected optimizer (and its config fields) is touched.
    factories = {
        'sgd': lambda: optim.SGD(groups, momentum=train_cfg.OPTIMIZER.MOMENTUM, nesterov=True,
                                 lr=train_cfg.BASE_LR, weight_decay=train_cfg.WEIGHT_DECAY),
        'adamw': lambda: optim.AdamW(groups, eps=train_cfg.OPTIMIZER.EPS,
                                     betas=train_cfg.OPTIMIZER.BETAS,
                                     lr=train_cfg.BASE_LR, weight_decay=train_cfg.WEIGHT_DECAY),
    }
    factory = factories.get(requested)
    return None if factory is None else factory()

def set_weight_decay(model, skip_list=(), skip_keywords=()):
    """
    Split ``model``'s trainable parameters into two optimizer param groups.

    A parameter is exempt from weight decay when it is 1-D, its name ends in
    ``.bias``, its name is in ``skip_list``, or any entry of ``skip_keywords``
    occurs as a substring of its name. Parameters with
    ``requires_grad == False`` are omitted from both groups.

    Returns:
        ``[{'params': decayed}, {'params': exempt, 'weight_decay': 0.}]``
    """
    decayed, exempt = [], []

    for pname, tensor in model.named_parameters():
        if not tensor.requires_grad:
            # Frozen weights are not handed to the optimizer at all.
            continue
        # Order matters: the keyword scan only runs if the cheaper checks fail.
        skip_decay = (len(tensor.shape) == 1
                      or pname.endswith(".bias")
                      or pname in skip_list
                      or check_keywords_in_name(pname, skip_keywords))
        target = exempt if skip_decay else decayed
        target.append(tensor)

    return [{'params': decayed},
            {'params': exempt, 'weight_decay': 0.}]


def check_keywords_in_name(name, keywords=()):
    """Return True if at least one entry of ``keywords`` is a substring of ``name``.

    Every keyword is tested (no early exit) before the result is returned.
    """
    matches = [keyword in name for keyword in keywords]
    return any(matches)
