import torch.nn as nn
from torch.optim import Adam, AdamW
from transformers.trainer_pt_utils import get_parameter_names
from transformers import (
    get_linear_schedule_with_warmup, 
    get_constant_schedule_with_warmup, 
    get_cosine_schedule_with_warmup,
)    


def set_optim_and_schedule(config, model, extra_args):
    """Build the AdamW optimizer and learning-rate scheduler for ``model``.

    Parameters are split into two groups: names that qualify for weight decay
    (decayed with ``config.weight_decay``) and everything else (decay 0.0).
    Any parameter whose name contains ``"bias"`` is never decayed.

    Args:
        config: Object exposing ``lr_scheduler_type``, ``optim``,
            ``weight_decay``, ``learning_rate``, ``adam_beta1``,
            ``adam_beta2``, ``adam_epsilon``, ``warmup_steps`` and
            ``max_steps``.
        model: Module providing ``named_parameters()``.
        extra_args: Mapping of optional settings (``None`` is treated as
            empty). If ``decay_forbidden_layer_types`` is present and
            non-empty, it must list ``torch.nn`` class names (e.g.
            ``"LayerNorm"``); these are passed to ``get_parameter_names`` as
            the forbidden layer types, and only the names it returns are
            candidates for weight decay.

    Returns:
        A ``(optimizer, scheduler)`` tuple.

    Raises:
        ValueError: If ``config.lr_scheduler_type`` or ``config.optim`` is
            not one of the supported values.
    """
    supported_schedulers = ("linear", "cosine", "constant_with_warmup")
    supported_optimizers = ("adamw_torch", "adamw_hf")

    # Validate both options up front so a bad configuration fails before any
    # work is done on the model.
    if config.lr_scheduler_type not in supported_schedulers:
        raise ValueError(
            f"Unsupported learning rate scheduler {config.lr_scheduler_type!r}. "
            f"Please set --lr_scheduler_type to one of: {', '.join(supported_schedulers)}"
        )
    if config.optim not in supported_optimizers:
        raise ValueError(
            f"Unsupported optimizer {config.optim!r}; only AdamW is supported. "
            f"Please set --optim to one of: {', '.join(supported_optimizers)}"
        )

    forbidden_type_names = (extra_args or {}).get("decay_forbidden_layer_types")
    named_params = list(model.named_parameters())

    if forbidden_type_names:
        forbidden_types = [getattr(nn, type_name) for type_name in forbidden_type_names]
        candidate_names = get_parameter_names(model, forbidden_types)
    else:
        candidate_names = [name for name, _ in named_params]

    # A set keeps the membership tests below O(1) per parameter.
    decay_names = {name for name in candidate_names if "bias" not in name}

    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in named_params if n in decay_names],
            "weight_decay": config.weight_decay,
        },
        {
            "params": [p for n, p in named_params if n not in decay_names],
            "weight_decay": 0.0,
        },
    ]

    # Import lazily so only the selected implementation has to be importable.
    if config.optim == "adamw_torch":
        from torch.optim import AdamW as optimizer_cls
    else:  # "adamw_hf"
        from transformers.optimization import AdamW as optimizer_cls

    optim = optimizer_cls(
        optimizer_grouped_parameters,
        lr=config.learning_rate,
        betas=(config.adam_beta1, config.adam_beta2),
        eps=config.adam_epsilon,
        # Default only; both groups above set their own weight_decay.
        weight_decay=config.weight_decay,
    )

    if config.lr_scheduler_type == "linear":
        schedule = get_linear_schedule_with_warmup(
            optimizer=optim,
            num_warmup_steps=config.warmup_steps,
            num_training_steps=config.max_steps,
        )
    elif config.lr_scheduler_type == "constant_with_warmup":
        schedule = get_constant_schedule_with_warmup(
            optimizer=optim,
            num_warmup_steps=config.warmup_steps,
        )
    else:  # "cosine" (guaranteed by the validation above)
        schedule = get_cosine_schedule_with_warmup(
            optimizer=optim,
            num_warmup_steps=config.warmup_steps,
            num_training_steps=config.max_steps,
            num_cycles=0.5,  # library default, spelled out explicitly
        )

    return optim, schedule