import torch
import torch.optim as optim


def create_optimizer(cfg, parameters) -> optim.Optimizer:
    """Build a torch optimizer from a config object.

    Args:
        cfg: Config object read via attribute access. Uses ``cfg.name``
            (one of ``"sgd"``, ``"adamw"``, ``"adam"``), ``cfg.lr`` and
            ``cfg.weight_decay``; ``cfg.momentum`` is read only for ``"sgd"``.
        parameters: Parameters (or param-group dicts) passed straight
            through to the optimizer constructor.

    Returns:
        The constructed optimizer.

    Raises:
        NotImplementedError: If ``cfg.name`` is not a supported optimizer.
    """
    name = cfg.name
    if name == "sgd":
        return optim.SGD(
            parameters, lr=cfg.lr, weight_decay=cfg.weight_decay, momentum=cfg.momentum
        )
    if name == "adamw":
        return optim.AdamW(parameters, lr=cfg.lr, weight_decay=cfg.weight_decay)
    if name == "adam":
        return optim.Adam(parameters, lr=cfg.lr, weight_decay=cfg.weight_decay)
    # Fail loudly (as create_scheduler does) rather than falling through and
    # returning None, which would only surface later as an AttributeError.
    raise NotImplementedError(f"Optimizer {name} not implemented.")


def create_scheduler(scheduler, opt, loader, epochs, *, warmup_epochs=10):
    """Build a learning-rate scheduler for ``opt`` by name.

    Args:
        scheduler: Scheduler name: ``"cos"``, ``"warmstartcos"`` or ``"onecycle"``.
        opt: Optimizer to schedule. ``"onecycle"`` reads ``opt.defaults["lr"]``.
        loader: Only ``len(loader)`` is used, and only by ``"onecycle"``
            (as ``steps_per_epoch``).
        epochs: Number of training epochs the schedule spans.
        warmup_epochs: ``"onecycle"`` only. This is the number of epochs spent
            ramping up to the peak LR. It is clamped to ``epochs`` when
            training is shorter than the warmup.

    Returns:
        A ``torch.optim.lr_scheduler`` instance. ``"cos"`` and ``"warmstartcos"``
        are sized in epochs. ``"onecycle"`` is sized in batches
        (``epochs * len(loader)`` scheduler steps).

    Raises:
        ValueError: If ``"onecycle"`` is requested with ``epochs <= 0``.
        NotImplementedError: If ``scheduler`` is not a supported name.
    """
    if scheduler == "cos":
        return torch.optim.lr_scheduler.CosineAnnealingLR(
            opt, T_max=epochs, eta_min=0, last_epoch=-1
        )
    if scheduler == "warmstartcos":
        # Restart period is a tenth of training. CosineAnnealingWarmRestarts
        # requires a positive T_0, and epochs // 10 is 0 when epochs < 10.
        return torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
            opt, T_0=max(1, epochs // 10)
        )
    if scheduler == "onecycle":
        if epochs <= 0:
            raise ValueError(f"epochs must be positive for onecycle, got {epochs}.")
        # OneCycleLR rejects pct_start > 1. When training is shorter than the
        # warmup, spend the whole run warming up instead of crashing.
        pct_start = min(warmup_epochs / epochs, 1.0)
        return torch.optim.lr_scheduler.OneCycleLR(
            opt,
            max_lr=opt.defaults["lr"] * 10,  # peak LR is 10x the optimizer's base LR
            epochs=epochs,
            steps_per_epoch=len(loader),
            pct_start=pct_start,
        )
    raise NotImplementedError(f"Scheduler {scheduler} not implemented.")
