import torch
from .lr_scheduler import LR_Scheduler, MultiStepLRScheduler

def get_optimizer(name, model, lr, momentum, weight_decay, partial_freeze=''):
    """Build a torch optimizer over a single 'base' parameter group of ``model``.

    Args:
        name: Optimizer kind; one of ``'sgd'``, ``'adam'`` or ``'adamw'``.
        model: Object exposing ``parameters()`` and ``named_parameters()``.
        lr: Learning rate, set both on the group and as the optimizer default.
        momentum: Forwarded to SGD only.
        weight_decay: Forwarded to SGD and AdamW; not passed to Adam.
        partial_freeze: ``'lower'`` leaves out every parameter whose name
            contains ``'layers.0.'``; ``'upper'`` leaves out those containing
            ``'layers.3.'``; any other value hands over ``model.parameters()``.
            Left-out parameters are simply not given to the optimizer.

    Returns:
        The constructed ``torch.optim`` optimizer.

    Raises:
        NotImplementedError: If ``name`` is not one of the supported kinds.
    """
    # Name fragment identifying the parameters to leave out, if any.
    if partial_freeze == 'lower':
        skipped_fragment = 'layers.0.'
    elif partial_freeze == 'upper':
        skipped_fragment = 'layers.3.'
    else:
        skipped_fragment = None

    if skipped_fragment is None:
        trainable = model.parameters()
    else:
        trainable = [
            tensor
            for param_name, tensor in model.named_parameters()
            if skipped_fragment not in param_name
        ]

    groups = [dict(name='base', params=trainable, lr=lr)]

    if name == 'sgd':
        return torch.optim.SGD(groups, lr=lr, momentum=momentum, weight_decay=weight_decay)
    if name == 'adam':
        return torch.optim.Adam(groups, lr=lr)
    if name == 'adamw':
        return torch.optim.AdamW(groups, lr=lr, weight_decay=weight_decay)
    raise NotImplementedError

def get_laps_inner_optimizer(name, parameters, lr, momentum, weight_decay):
    """Create a torch optimizer directly over ``parameters``.

    Args:
        name: Optimizer kind; one of ``'sgd'``, ``'adam'`` or ``'adamw'``.
        parameters: Passed unchanged as the optimizer's first argument.
        lr: Learning rate.
        momentum: Forwarded to SGD only.
        weight_decay: Forwarded to SGD and AdamW; not passed to Adam.

    Returns:
        The constructed ``torch.optim`` optimizer.

    Raises:
        NotImplementedError: If ``name`` is not one of the supported kinds.
    """
    # Pick the optimizer class and the keyword arguments specific to it.
    if name == 'sgd':
        factory = torch.optim.SGD
        extra = {'momentum': momentum, 'weight_decay': weight_decay}
    elif name == 'adam':
        factory = torch.optim.Adam
        extra = {}
    elif name == 'adamw':
        factory = torch.optim.AdamW
        extra = {'weight_decay': weight_decay}
    else:
        raise NotImplementedError
    return factory(parameters, lr=lr, **extra)

def get_apd_optimizer(name, model, task_id, lr, momentum, weight_decay, base_parameters=None):
    """Build a torch optimizer with four named parameter groups for one task.

    The groups, in order, are:

    * ``'base'``: ``base_parameters`` if given, otherwise ``model.parameters()``.
    * ``'tsh'``: every entry of ``model.key_weights['tsh']``.
    * ``'msk'``: entries of ``model.key_weights['msk']`` whose key ends with
      ``'_t<task_id>'``.
    * ``'tad'``: entries of ``model.key_weights['tad']`` whose integer suffix
      after the last ``'_t'`` is less than or equal to ``task_id``.

    Every group uses the same learning rate ``lr``. The number of parameters
    in each group is printed before the optimizer is created.

    Args:
        name: Optimizer kind; one of ``'sgd'``, ``'adam'`` or ``'adamw'``.
        model: Object exposing ``parameters()`` and a ``key_weights`` mapping
            with ``'tsh'``, ``'msk'`` and ``'tad'`` sub-mappings.
        task_id: Integer task index used to select ``'msk'``/``'tad'`` entries.
        lr: Learning rate.
        momentum: Forwarded to SGD only.
        weight_decay: Forwarded to SGD and AdamW; not passed to Adam.
        base_parameters: Optional iterable of parameters for the ``'base'``
            group; defaults to all of ``model.parameters()``.

    Returns:
        The constructed ``torch.optim`` optimizer.

    Raises:
        NotImplementedError: If ``name`` is not one of the supported kinds.
    """
    weights = model.key_weights

    base = model.parameters() if base_parameters is None else base_parameters
    # model.parameters() is a generator, which has no len(); materialise any
    # un-sized iterable so the count logged below cannot raise TypeError.
    if not hasattr(base, '__len__'):
        base = list(base)

    task_suffix = '_t%s' % task_id
    parameters = [{
        'name': 'base',
        'params': base,
        'lr': lr
    }, {
        'name': 'tsh',
        'params': list(weights['tsh'].values()),
        'lr': lr
    }, {
        'name': 'msk',
        'params': [w for key, w in weights['msk'].items() if key.endswith(task_suffix)],
        'lr': lr
    }, {
        'name': 'tad',
        # Keep entries belonging to this task or any earlier one.
        'params': [w for key, w in weights['tad'].items()
                   if task_id >= int(key.split('_t')[-1])],
        'lr': lr
    }]
    print('# params for task%d, base, tsh, msk, tad: %s'
          % (task_id, [len(group['params']) for group in parameters]))

    if name == 'sgd':
        optimizer = torch.optim.SGD(parameters, lr=lr, momentum=momentum, weight_decay=weight_decay)
    elif name == 'adam':
        optimizer = torch.optim.Adam(parameters, lr=lr)
    elif name == 'adamw':
        optimizer = torch.optim.AdamW(parameters, lr=lr, weight_decay=weight_decay)
    else:
        raise NotImplementedError
    return optimizer
