import torch

def load_optimizer(parameters, optim_alg, config):
    """Build a torch optimizer for ``parameters`` from a name and a config.

    Hyper-parameters are read from ``config.agent``:
    ``start_lr`` and ``weight_decay`` for every optimizer, plus
    ``momentum`` for SGD.

    Args:
        parameters: Iterable of tensors / parameter groups, passed straight
            through to the optimizer constructor.
        optim_alg: One of ``'sgd'``, ``'adam'`` or ``'adamw'``.
        config: Object exposing the ``agent`` attributes listed above.

    Returns:
        The constructed ``torch.optim.Optimizer``.

    Raises:
        ValueError: If ``optim_alg`` is not a supported name. Previously this
            case silently returned ``None``, which only failed later and far
            from the cause.
    """
    agent = config.agent

    if optim_alg == 'sgd':
        return torch.optim.SGD(parameters,
                               lr=agent.start_lr,
                               momentum=agent.momentum,
                               weight_decay=agent.weight_decay)

    if optim_alg == 'adam':
        return torch.optim.Adam(parameters,
                                lr=agent.start_lr,
                                weight_decay=agent.weight_decay)

    if optim_alg == 'adamw':
        return torch.optim.AdamW(parameters,
                                 lr=agent.start_lr,
                                 weight_decay=agent.weight_decay)

    raise ValueError(
        f"Unknown optimizer {optim_alg!r}; expected one of 'sgd', 'adam', 'adamw'"
    )
    

def load_scheduler(optimizer, scheduler_type, config):
    """Build an optional learning-rate scheduler for ``optimizer``.

    Args:
        optimizer: The ``torch.optim.Optimizer`` to schedule.
        scheduler_type: ``'cosine'`` or ``'linear'``; any other value means
            "no scheduler".
        config: Object exposing ``agent.total_steps`` and ``agent.final_lr``
            (read only by the ``'cosine'`` branch).

    Returns:
        A ``torch.optim.lr_scheduler`` instance, or ``None`` when
        ``scheduler_type`` is not recognised.
    """
    if scheduler_type == 'cosine':
        # Anneal from the optimizer's current lr down to final_lr over total_steps.
        agent_cfg = config.agent
        return torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer,
            T_max=agent_cfg.total_steps,
            eta_min=agent_cfg.final_lr,
        )

    if scheduler_type == 'linear':
        # NOTE(review): this scales the lr from 0.99x up to 1x over LinearLR's
        # default total_iters, then holds it constant. It ignores final_lr and
        # total_steps, unlike the cosine branch. Confirm this is intended.
        return torch.optim.lr_scheduler.LinearLR(optimizer,
                                                 start_factor=0.99,
                                                 end_factor=1)

    # Unrecognised type: run without a scheduler.
    return None