import torch
import torch.optim as optim
from torch.optim.lr_scheduler import MultiStepLR, CosineAnnealingLR, ExponentialLR


def build_optimizer(model, config):
    """Create the optimizer selected by ``config.TRAINER.OPTIMIZER``.

    All parameters of ``model`` share a single learning rate
    (``config.TRAINER.TRUE_LR``). The weight decay comes from
    ``config.TRAINER.ADAM_DECAY`` for ``"adam"`` and from
    ``config.TRAINER.ADAMW_DECAY`` for ``"adamw"``. Only the decay attribute
    of the chosen optimizer is read.

    Args:
        model: module whose ``parameters()`` will be optimized.
        config: config object exposing a ``TRAINER`` node.

    Returns:
        A ``torch.optim.Adam`` or ``torch.optim.AdamW`` instance.

    Raises:
        ValueError: if the optimizer name is neither ``"adam"`` nor ``"adamw"``.
    """
    trainer_cfg = config.TRAINER
    optimizer_name = trainer_cfg.OPTIMIZER
    base_lr = trainer_cfg.TRUE_LR
    print("Starting learning rate: ", base_lr)

    # (name, optimizer class, config attribute holding its weight decay).
    # Per-module learning rates could be introduced here if ever needed.
    supported = (
        ("adam", optim.Adam, "ADAM_DECAY"),
        ("adamw", optim.AdamW, "ADAMW_DECAY"),
    )
    for candidate, optimizer_cls, decay_attr in supported:
        if optimizer_name == candidate:
            decay = getattr(trainer_cfg, decay_attr)
            return optimizer_cls(model.parameters(), lr=base_lr, weight_decay=decay)

    raise ValueError(f"TRAINER.OPTIMIZER = {optimizer_name} is not a valid optimizer!")
    
def build_bb_optimizer(model_bb, model_head, config, *, bb_ratio=2e-3):
    """Build an optimizer with a scaled-down learning rate for the backbone.

    Backbone parameters are trained at ``config.TRAINER.TRUE_LR * bb_ratio``.
    Head parameters use the full ``config.TRAINER.TRUE_LR``. Both parameter
    groups share the weight decay configured for the selected optimizer
    (``ADAM_DECAY`` for ``"adam"``, ``ADAMW_DECAY`` for ``"adamw"``).

    Args:
        model_bb: backbone module; its ``parameters()`` form the first group.
        model_head: head module; its ``parameters()`` form the second group.
        config: config object exposing a ``TRAINER`` node.
        bb_ratio: multiplier applied to the base learning rate for the
            backbone group. Defaults to ``2e-3``, the value this function
            previously hard-coded.

    Returns:
        A ``torch.optim.Adam`` or ``torch.optim.AdamW`` instance with two
        parameter groups: backbone first, head second.

    Raises:
        ValueError: if the optimizer name is neither ``"adam"`` nor ``"adamw"``.
    """
    name = config.TRAINER.OPTIMIZER
    lr = config.TRAINER.TRUE_LR
    bb_lr = lr * bb_ratio
    print("Starting learning rate: ", lr)
    print("Starting learning rate for backbone: ", bb_lr)

    # Only the decay attribute of the chosen optimizer is read from config.
    if name == "adam":
        optimizer_cls, weight_decay = optim.Adam, config.TRAINER.ADAM_DECAY
    elif name == "adamw":
        optimizer_cls, weight_decay = optim.AdamW, config.TRAINER.ADAMW_DECAY
    else:
        raise ValueError(f"TRAINER.OPTIMIZER = {name} is not a valid optimizer!")

    return optimizer_cls([
        {'params': model_bb.parameters(), 'lr': bb_lr, 'weight_decay': weight_decay},
        {'params': model_head.parameters(), 'lr': lr, 'weight_decay': weight_decay},
    ])

def build_scheduler(config, optimizer):
    """Build the learning-rate scheduler selected by ``config.TRAINER.SCHEDULER``.

    Supported names and the config attributes they read:
      - ``"MultiStepLR"``: ``MSLR_MILESTONES``, ``MSLR_GAMMA``
      - ``"CosineAnnealing"``: ``COSA_TMAX``
      - ``"ExponentialLR"``: ``ELR_GAMMA``

    Args:
        config: config object exposing a ``TRAINER`` node that also provides
            ``SCHEDULER_INTERVAL``.
        optimizer: the ``torch.optim`` optimizer to schedule.

    Returns:
        scheduler (dict):{
            'scheduler': lr_scheduler,
            'interval': 'step',  # or 'epoch'
            'monitor': 'val_f1', (optional)
            'frequency': x, (optional)
        }
        Only ``'interval'`` and ``'scheduler'`` are populated here.

    Raises:
        NotImplementedError: if the scheduler name is not one of the
            supported names above.
    """
    trainer_cfg = config.TRAINER
    scheduler = {'interval': trainer_cfg.SCHEDULER_INTERVAL}
    name = trainer_cfg.SCHEDULER

    if name == "MultiStepLR":
        lr_scheduler = MultiStepLR(optimizer, trainer_cfg.MSLR_MILESTONES, gamma=trainer_cfg.MSLR_GAMMA)
    elif name == 'CosineAnnealing':
        lr_scheduler = CosineAnnealingLR(optimizer, trainer_cfg.COSA_TMAX)
    elif name == 'ExponentialLR':
        lr_scheduler = ExponentialLR(optimizer, trainer_cfg.ELR_GAMMA)
    else:
        # Name the bad value so config typos are easy to spot.
        raise NotImplementedError(
            f"TRAINER.SCHEDULER = {name} is not a supported scheduler! "
            "Expected one of: 'MultiStepLR', 'CosineAnnealing', 'ExponentialLR'."
        )

    scheduler['scheduler'] = lr_scheduler
    return scheduler