# coding=utf-8
import torch


def get_params(alg, args, alg_name, inner=False, alias=True):
    """Build per-module parameter groups for a torch optimizer.

    Args:
        alg: The model. When ``inner`` is true or ``alias`` is false it is
            indexed as ``alg[0]`` / ``alg[1]``. Otherwise its ``featurizer``
            and ``classifier`` attributes are used. For DANN-style algorithms
            it must also expose ``discriminator`` (and ``class_embeddings``
            for CDANN).
        args: Config namespace. Reads ``schuse``, ``schusech``, ``lr``,
            ``lr_decay1`` and ``lr_decay2``, plus ``inner_lr`` when
            ``inner`` is true and no scheduler is used.
        alg_name: Algorithm name. Substring matching decides which extra
            modules get their own group.
        inner: Use ``args.inner_lr`` (when no scheduler is active) and
            sequence-indexed modules.
        alias: When ``inner`` is false, pick ``featurizer``/``classifier``
            attributes instead of ``alg[0]``/``alg[1]``.

    Returns:
        A list of ``{'params': ..., 'lr': ...}`` dicts. The first group is
        scaled by ``args.lr_decay1``; every other group by ``args.lr_decay2``.
    """
    # Base value that the per-group decay factors are multiplied with.
    if args.schuse:
        # Non-cosine schedules use a base of 1.0, presumably because the
        # LambdaLR in get_scheduler multiplies its factor by args.lr itself.
        initlr = args.lr if args.schusech == 'cos' else 1.0
    else:
        initlr = args.inner_lr if inner else args.lr

    # `inner` takes precedence over `alias`: both inner mode and
    # non-alias mode treat `alg` as a two-element sequence.
    if inner or not alias:
        first, second = alg[0], alg[1]
    else:
        first, second = alg.featurizer, alg.classifier

    params = [
        {'params': first.parameters(), 'lr': args.lr_decay1 * initlr},
        {'params': second.parameters(), 'lr': args.lr_decay2 * initlr},
    ]
    # 'CDANN' contains 'DANN', so this single test covers DANN and CDANN.
    if 'DANN' in alg_name:
        params.append({'params': alg.discriminator.parameters(),
                       'lr': args.lr_decay2 * initlr})
    if 'CDANN' in alg_name:
        params.append({'params': alg.class_embeddings.parameters(),
                       'lr': args.lr_decay2 * initlr})
    return params


def get_optimizer(alg, args, inner=False, alias=True):
    """Build an SGD optimizer for a model trained with ``args.DGalgorithm``.

    Parameter groups and their per-group learning rates come from
    ``get_params(alg, args, args.DGalgorithm, inner, alias)``.

    Args:
        alg: Model (or module pair) passed through to ``get_params``.
        args: Config namespace. Also reads ``lr``, ``momentum`` and
            ``weight_decay`` here.
        inner: Forwarded to ``get_params``.
        alias: Forwarded to ``get_params``.

    Returns:
        A ``torch.optim.SGD`` instance.
    """
    params = get_params(alg, args, args.DGalgorithm, inner, alias)
    # Every group built by get_params sets its own 'lr', so lr=args.lr is only
    # the default for groups lacking one.
    # torch's SGD rejects nesterov=True when momentum is 0, so request
    # Nesterov momentum only when momentum is actually enabled.
    optimizer = torch.optim.SGD(
        params, lr=args.lr, momentum=args.momentum,
        weight_decay=args.weight_decay, nesterov=args.momentum > 0)
    return optimizer

def get_BF_optimizer(alg, args, inner=False, alias=True):
    """Build an SGD optimizer for a model trained with ``args.BFalgorithm``.

    Identical to ``get_optimizer`` except that the algorithm name forwarded to
    ``get_params`` is ``args.BFalgorithm``.

    Args:
        alg: Model (or module pair) passed through to ``get_params``.
        args: Config namespace. Also reads ``lr``, ``momentum`` and
            ``weight_decay`` here.
        inner: Forwarded to ``get_params``.
        alias: Forwarded to ``get_params``.

    Returns:
        A ``torch.optim.SGD`` instance.
    """
    params = get_params(alg, args, args.BFalgorithm, inner, alias)
    # torch's SGD rejects nesterov=True when momentum is 0, so request
    # Nesterov momentum only when momentum is actually enabled.
    optimizer = torch.optim.SGD(
        params, lr=args.lr, momentum=args.momentum,
        weight_decay=args.weight_decay, nesterov=args.momentum > 0)
    return optimizer


def get_scheduler(optimizer, args):
    """Return the learning-rate scheduler selected by ``args``, or None.

    - ``args.schuse`` falsy: no scheduler; returns None.
    - ``args.schusech == 'cos'``: ``CosineAnnealingLR`` with
      ``T_max = args.max_epoch * args.steps_per_epoch``.
    - any other ``args.schusech``: ``LambdaLR`` whose multiplier at step ``x``
      is ``args.lr * (1 + args.lr_gamma * x) ** (-args.lr_decay)``. The
      multiplier includes ``args.lr`` because the param groups are expected to
      start from a base of 1.0 in this mode.
    """
    if args.schuse:
        if args.schusech == 'cos':
            total_steps = args.max_epoch * args.steps_per_epoch
            return torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, total_steps)

        def inv_decay(step):
            # Reads `args` at call time, so later changes to it are observed.
            return args.lr * (1. + args.lr_gamma * float(step)) ** (-args.lr_decay)

        return torch.optim.lr_scheduler.LambdaLR(optimizer, inv_decay)
    return None

# def ERM_lr_decrease(args, opt, epoch, max_epoch, task_id):
#     '''
#     Manually decrease lr when the source algorithm is ERM.
#     When the source algorithm is ERM, manually decrease lr in the source training stage.
#     When the source and target algorithms are both ERM, manually decrease lr in the source and target training stages.
#     '''
#     if args.sourceAlg == 'ERM':
#         if task_id == 0 or (task_id > 0 and args.targetAlg == 'ERM'):

#             if (epoch in [int(max_epoch*0.7), int(max_epoch*0.9)]):
#                 # print('manually descrease lr')
#                 for params in opt.param_groups:
#                     params['lr'] = params['lr']*0.1
#     else:
#         pass
#     return opt

# def lr_scheduler(optimizer, args, iter_num, max_iter, gamma=10, power=0.75):
#     decay = (1 + gamma * iter_num / max_iter) ** (-power)
#     for param_group in optimizer.param_groups:
#         param_group['lr'] = args.lr * decay
#     return optimizer