#coding=utf-8
from os import fwalk
import torch

def get_params(alg,args,inner=False,domainadv=False,commonnet=False,nettype='',alias=True):
    if args.schuse:
        if args.schusech=='cos':
            init_lr=args.lr
        else:
            init_lr=1.0
    else:
        init_lr=args.lr
    if nettype=='FEA':
        if args.task.startswith('cross'):
            params = [
                {'params': alg.featurizer.parameters(), 'lr': args.lr_decay1 * init_lr},
                {'params': alg.bottleneck.parameters(), 'lr': args.lr_decay2 * init_lr}
            ]
        else:
            params = [
                {'params': alg.featurizer.parameters(), 'lr': args.lr_decay1 * init_lr}
            ]
        return params
    if nettype=='CLS':
        params = [
            {'params': alg.classifier.parameters(), 'lr': args.lr_decay2 * init_lr}
        ]
        if ('DANN' in args.algorithm) or ('CDANN' in args.algorithm):
            params.append({'params': alg.discriminator.parameters(), 'lr': args.lr_decay2 * init_lr})
        if ('CDANN' in args.algorithm):
            params.append({'params': alg.class_embeddings.parameters(), 'lr': args.lr_decay2 * init_lr})
        return params        
    if nettype=='TDBADV':
        params = [
            {'params': alg.dbottleneck.parameters(), 'lr': args.lr_decay2 * init_lr},
            {'params': alg.dclassifier.parameters(), 'lr': args.lr_decay2 * init_lr},
            {'params': alg.ddiscriminator.parameters(), 'lr': args.lr_decay2 * init_lr}
        ]  
        return params
    elif nettype=='TDBCLS':
        params = [
            {'params': alg.bottleneck.parameters(), 'lr': args.lr_decay2 * init_lr},
            {'params': alg.classifier.parameters(), 'lr': args.lr_decay2 * init_lr},
            {'params': alg.discriminator.parameters(), 'lr': args.lr_decay2 * init_lr}
        ]  
        return params
    elif nettype=='TDBALL':
        params = [
            {'params': alg.featurizer.parameters(), 'lr': args.lr_decay1 * init_lr},
            {'params': alg.abottleneck.parameters(), 'lr': args.lr_decay2 * init_lr},
            {'params': alg.aclassifier.parameters(), 'lr': args.lr_decay2 * init_lr}
        ]  
        return params        
    if commonnet:
        return [{'params':alg.parameters(),'lr':args.lr_decay2*init_lr}]
    if domainadv:
    
        params = [
            {'params': alg.dfeaturizer.parameters(), 'lr': args.lr_decay1 * init_lr},
            {'params': alg.dbottleneck.parameters(), 'lr': args.lr_decay2 * init_lr},
            {'params': alg.dclassifier.parameters(), 'lr': args.lr_decay2 * init_lr},
            {'params': alg.ddiscriminator.parameters(), 'lr': args.lr_decay2 * init_lr}
        ]  
    if inner:
 
        if args.task.startswith('cross'):
            params = [
                {'params': alg[0].parameters(), 'lr': args.lr_decay1 * init_lr},
                {'params': alg[1].parameters(), 'lr': args.lr_decay2 * init_lr},
                {'params': alg[2].parameters(), 'lr': args.lr_decay2 * init_lr}
            ]
        else:
            params = [
                {'params': alg[0].parameters(), 'lr': args.lr_decay1 * init_lr},
                {'params': alg[1].parameters(), 'lr': args.lr_decay2 * init_lr}
            ]            
        return params
    elif alias:
        params = [
            {'params': alg.featurizer.parameters(), 'lr': args.lr_decay1 * init_lr},
            {'params': alg.classifier.parameters(), 'lr': args.lr_decay2 * init_lr}
        ]
        if args.task.startswith('cross') or args.algorithm in ['FFTdisexp']:
            params.append({'params': alg.bottleneck.parameters(), 'lr': args.lr_decay2 * init_lr})
    else:
        params = [
            {'params': alg[0].parameters(), 'lr': args.lr_decay1 * init_lr},
            # {'params': alg[1].parameters(), 'lr': args.lr_decay2 * init_lr},
            {'params': alg[1].parameters(), 'lr': args.lr_decay2 * init_lr}
        ]
    if ('DANN' in args.algorithm) or ('CDANN' in args.algorithm):
        params.append({'params': alg.discriminator.parameters(), 'lr': args.lr_decay2 * init_lr})
    if ('CDANN' in args.algorithm):
        params.append({'params': alg.class_embeddings.parameters(), 'lr': args.lr_decay2 * init_lr})
    if 'DLDA' in args.algorithm:
        params.append({'params': alg.dldanet.parameters(), 'lr': args.lr_decay2 * init_lr})
    if 'PAW' in args.algorithm and (nettype!='CLS'):
        params.append({'params': alg.bottleneck2.parameters(), 'lr': args.lr_decay2 * init_lr})
    return params

def get_optimizer(alg, args,inner=False,domainadv=False,commonnet=False,nettype='',alias=True):
    params = get_params(alg,args,inner,domainadv,commonnet,nettype=nettype,alias=alias)
    if args.task.startswith('cross'):
        optimizer=torch.optim.Adam(params,lr=args.lr,weight_decay=args.weight_decay,betas=(args.beta1, 0.9))
    else:
        # optimizer = torch.optim.SGD(params, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay, nesterov=False)
        optimizer = torch.optim.SGD(params, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay, nesterov=True)
        # optimizer = torch.optim.Adam(
        #     alg.network.parameters(),
        #     lr=args.lr,
        #     weight_decay=args.weight_decay
        # )
    return optimizer

def get_scheduler(optimizer, args):
    if args.task.startswith('cross'):
        return None
    else:
        if args.schuse:
            if args.schusech=='cos':
                # print(args.max_epoch * args.steps_per_epoch)
                scheduler=torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.max_epoch * args.steps_per_epoch)
            else:
                scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda x:  args.lr * (1. + args.lr_gamma * float(x)) ** (-args.lr_decay))
            return scheduler
        else:
            return None