from optparse import Option
from random import gammavariate
from pytest import param
from torch.optim import SGD, RMSprop, Adam
from torch.optim.lr_scheduler import StepLR, MultiStepLR, CosineAnnealingLR


def make(name, params, lr, weight_decay=0,
        schedule='step', milestones=None, gamma=0.1, prompt=None):
    """
    Prepares an optimizer and its learning-rate scheduler.

    Args:
        name (str): name of the optimizer.
            Options: 'sgd', 'rmsprop', 'adam', 'prompt'.
            'prompt' builds an Adam optimizer with two parameter groups:
            ``params[0]`` uses the 'lr' / 'weight_decay' taken from ``prompt``,
            ``params[1]`` uses the top-level ``lr`` / ``weight_decay``.
        params (iterable): parameters to optimize. For name='prompt' this
            must be indexable, with ``params[0]`` and ``params[1]`` being the
            two parameter collections described above.
        lr (float): initial learning rate.
        weight_decay (float, optional): weight decay. Default: 0.
        schedule (str, optional): type of learning-rate schedule.
            Default: 'step'.
            Options: 'step' (MultiStepLR), 'StepLR' (StepLR), 'cosine'
            (CosineAnnealingLR).
            (This argument is ignored if milestones=None.)
        milestones (int list, optional): a list of epochs when the learning
            rate is altered. For 'StepLR' this value is passed directly as
            ``step_size`` (so it must be an int there); for 'cosine' its last
            element is used as ``T_max``. Default: None
        gamma (float, optional): multiplicative factor of learning rate decay,
            used by 'step' and 'StepLR' (not by 'cosine'). Default: 0.1
        prompt (dict, optional): required when name='prompt'; must provide
            the keys 'lr' and 'weight_decay' for the first parameter group.

    Returns:
        tuple: ``(optimizer, lr_scheduler)``. ``lr_scheduler`` is None when
        ``milestones`` is None.

    Raises:
        ValueError: if ``name`` or ``schedule`` is not recognised, or if
            name='prompt' is requested without a ``prompt`` dict.
    """
    #* set the optimizer
    if name == 'sgd':
        optimizer = SGD(params=params, lr=lr, weight_decay=weight_decay)
    elif name == 'rmsprop':
        optimizer = RMSprop(params=params, lr=lr, weight_decay=weight_decay)
    elif name == 'adam':
        optimizer = Adam(params=params, lr=lr, weight_decay=weight_decay)
    elif name == 'prompt':
        if prompt is None:
            raise ValueError("optimizer 'prompt' requires the prompt argument")
        # First group gets its own lr / weight decay; the second group falls
        # back to the optimizer-wide defaults passed as keyword arguments.
        optimizer = Adam(
            [{'params': params[0],
              'lr': prompt['lr'],
              'weight_decay': prompt['weight_decay']},
             {'params': params[1]}],
            lr=lr, weight_decay=weight_decay)
    else:
        raise ValueError('invalid optimizer')

    #* set the learning rate scheduler
    if milestones is None:
        return optimizer, None

    if schedule == "StepLR":
        # NOTE: StepLR decays every `step_size` epochs, so `milestones` is
        # expected to be an int for this schedule.
        lr_scheduler = StepLR(optimizer=optimizer, step_size=milestones, gamma=gamma)
    elif schedule == 'step':
        lr_scheduler = MultiStepLR(optimizer=optimizer, milestones=milestones, gamma=gamma)
    elif schedule == 'cosine':
        # CosineAnnealingLR takes the annealing period as T_max; the last
        # milestone marks the end of the schedule.
        lr_scheduler = CosineAnnealingLR(optimizer=optimizer, T_max=milestones[-1])
    else:
        raise ValueError('invalid lr_scheduler')

    return optimizer, lr_scheduler