import torch
from torch.optim import Adam, SGD
from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR

from .lr_scheduler import CyclicLinearLR


# Names exported by ``from <this module> import *``.
__all__ = ["build_optimizer", "build_lr_scheduler", "reset_optimizer"]


def build_optimizer(model, args):
    """Create the optimizer named by ``args.optimizer`` over ``model``'s parameters.

    Args:
        model: Object whose ``parameters()`` are handed to the optimizer.
        args: Config namespace. Always read: ``optimizer`` (``'sgd'`` or
            ``'adam'``). Read for both optimizers: ``lr`` and
            ``weight_decay``. SGD also reads ``momentum``; Adam also reads
            ``adam_beta1`` and ``adam_beta2``.

    Returns:
        A ``torch.optim.SGD`` or ``torch.optim.Adam`` instance.

    Raises:
        ValueError: if ``args.optimizer`` is not a supported name.
    """
    name = args.optimizer
    if name == 'sgd':
        return SGD(
            model.parameters(),
            lr=args.lr,
            momentum=args.momentum,
            weight_decay=args.weight_decay,
        )
    if name == 'adam':
        # NOTE: torch's Adam applies ``weight_decay`` as a coupled L2 penalty
        # on the gradient (not AdamW-style decoupled decay).
        return Adam(
            model.parameters(),
            lr=args.lr,
            betas=(args.adam_beta1, args.adam_beta2),
            weight_decay=args.weight_decay,
        )
    # Name the offending value so a config typo is obvious from the traceback.
    raise ValueError(
        f"Unsupported optimizer: {name!r} (expected 'sgd' or 'adam')"
    )

def build_lr_scheduler(optimizer, args):
    """Create the learning-rate scheduler named by ``args.lr_scheduler``.

    Args:
        optimizer: Optimizer whose learning rate the scheduler drives.
        args: Config namespace. Reads ``lr_scheduler`` (``None``,
            ``'linear'``, ``'cosine'`` or ``'cyclic_linear'``) and
            ``num_train_steps``. ``'cyclic_linear'`` also reads the optional
            ``cycle_steps``. If it is absent or falsy, ``num_train_steps``
            is used instead.

    Returns:
        The scheduler instance, or ``None`` when ``args.lr_scheduler`` is
        ``None``.

    Raises:
        ValueError: if ``args.lr_scheduler`` is not a supported name.
    """
    name = args.lr_scheduler
    if name is None:
        return None
    if name == 'linear':
        # Scale the base lr linearly from 1.0x down to 0.0x over
        # ``num_train_steps`` scheduler steps.
        return LinearLR(
            optimizer,
            start_factor=1.0,
            end_factor=0.0,
            total_iters=args.num_train_steps,
        )
    if name == 'cosine':
        # Cosine-anneal from the base lr to 0 over ``num_train_steps`` steps.
        return CosineAnnealingLR(
            optimizer,
            T_max=args.num_train_steps,
            eta_min=0,
        )
    if name == 'cyclic_linear':
        # ``cycle_steps`` is optional. Tolerate namespaces that never define
        # it instead of raising AttributeError.
        cycle_steps = getattr(args, 'cycle_steps', None) or args.num_train_steps
        return CyclicLinearLR(
            optimizer,
            cycle_steps,
            start_factor=1.0,
            end_factor=0.0,
        )
    raise ValueError(
        f"Unsupported lr_scheduler: {name!r} "
        "(expected None, 'linear', 'cosine' or 'cyclic_linear')"
    )

def reset_optimizer(optimizer):
    """Clear the accumulated per-parameter statistics of ``optimizer`` in place.

    For every parameter that already has optimizer state, this zeroes:

    * ``momentum_buffer`` (SGD with momentum),
    * ``exp_avg``, ``exp_avg_sq`` and ``max_exp_avg_sq`` (Adam, including
      AMSGrad).

    Whenever Adam-style moment estimates are zeroed, the ``step`` counter is
    rewound as well. Adam derives its bias correction from ``step``. Keeping
    a large ``step`` next to freshly zeroed moments effectively disables that
    correction, which makes the first post-reset updates several times larger
    than ``lr`` with the default betas.

    Hyperparameters stored in ``param_groups`` (lr, betas, ...) are not
    modified. Parameters that have never been stepped are skipped without
    creating state for them.

    Args:
        optimizer: a ``torch.optim.Optimizer``. State keys other than those
            listed above are left untouched.
    """
    buffer_keys = ('momentum_buffer', 'exp_avg', 'exp_avg_sq', 'max_exp_avg_sq')
    moment_keys = ('exp_avg', 'exp_avg_sq')
    for param_group in optimizer.param_groups:
        for param in param_group['params']:
            # ``optimizer.state`` is a defaultdict. Indexing it would insert an
            # empty entry for never-stepped parameters, so use ``.get``.
            state = optimizer.state.get(param)
            if not state:
                continue
            for key in buffer_keys:
                buf = state.get(key)
                # Some torch versions store ``momentum_buffer = None`` (e.g.
                # SGD with momentum=0). There is nothing to zero in that case.
                if buf is not None:
                    buf.zero_()
            if any(state.get(key) is not None for key in moment_keys):
                step = state.get('step')
                if torch.is_tensor(step):
                    step.zero_()  # recent torch keeps ``step`` as a tensor
                elif step is not None:
                    state['step'] = 0  # older torch keeps a plain number

                
