import argparse
import os
import time

def get_args():
    """Parse the command line and return the fully initialised namespace.

    The options are read from ``sys.argv`` and the resulting namespace is
    passed through ``img_param_init``, which attaches the derived settings
    (domain tables, input shape, checkpoint paths, run title, log file).

    Returns:
        argparse.Namespace: parsed options plus the derived attributes.

    Raises:
        SystemExit: raised by argparse on invalid usage, including a missing
            ``--order``.
    """
    parser = argparse.ArgumentParser(description='main')

    # --- general / data --------------------------------------------------
    parser.add_argument('--seed', type=int, default=2022)
    parser.add_argument('--data_dir', type=str, default="./Dataset", help='root data dir')
    parser.add_argument('--dataset', type=str, default='PACS', choices=['PACS', 'Digits', 'OfficeHome'])
    # img_param_init cannot derive anything without an order, so a missing
    # --order is reported here as a usage error rather than crashing later
    # with a TypeError on tuple(None).
    parser.add_argument('--order', type=int, nargs='+', required=True,
                        help='training domain order, a permutation of 0 1 2 3 (source domain first)')
    parser.add_argument('--output_dir', type=str, default="./results", help='result output path')
    # NOTE(review): img_param_init builds its checkpoint paths under a
    # hard-coded "saved_models/" and does not read model_root.
    parser.add_argument('--model_root', type=str, default="saved_models", help='saved models path')

    # --- source pretraining ----------------------------------------------
    # NOTE: img_param_init overwrites lr_pre with a per-dataset value, and
    # overwrites epoch_pre when --augmentation is 'convmix'.
    parser.add_argument('--lr_pre', type=float, default=0.001, help="learning rate")
    parser.add_argument('--schedule_pre', type=str, default='poly')
    parser.add_argument('--power_pre', type=float, default=0.75)
    parser.add_argument('--epoch_pre', type=int, default=100, help="pretraining max epoch")
    parser.add_argument('--wd_pre', type=float, default=5e-4)
    parser.add_argument('--plot_pre', action='store_true')
    parser.add_argument('--fix_pre', action='store_true')
    parser.add_argument('--saved_title', type=str, default='')
    parser.add_argument('--re', action='store_true')

    # --- optimisation ----------------------------------------------------
    parser.add_argument('--batch_size', type=int, default=128, help='batch_size')
    parser.add_argument('--optimizer', type=str, default='SGD')
    # NOTE: img_param_init replaces epoch with a per-dataset value unless
    # --fix_pre is set or --epoch is exactly 1.
    parser.add_argument('--epoch', type=int, default=80, help="max epoch")
    parser.add_argument('--lr', type=float, default=0.001, help="learning rate")
    parser.add_argument('--lr_max', type=float, default=0.001)
    parser.add_argument('--schedule', type=str, default='none')
    parser.add_argument('--power', type=float, default=0.75)
    parser.add_argument('--wd', type=float, default=0.01)
    parser.add_argument('--momentum', type=float, default=0.9)
    parser.add_argument('--nesterov', action='store_true')
    parser.add_argument('--classifier_wn', action='store_true')
    parser.add_argument('--trans_ensemble', action='store_true')
    parser.add_argument('--num_ens', type=int, default=3)
    parser.add_argument('--proto_notrans', action='store_true')

    # --- method hyper-parameters -----------------------------------------
    parser.add_argument('--mixup_alpha', type=float, default=0.4)
    parser.add_argument('--pl', type=str, default='shot')
    parser.add_argument('--transforms', type=str, default='simpleV2')
    parser.add_argument('--augmentation', type=str, default='none')
    parser.add_argument('--aug_ogcat', action='store_true')
    parser.add_argument('--beta', type=float, default=0.0)
    parser.add_argument('--gamma', type=float, default=0.0)
    parser.add_argument('--delta', type=float, default=0.5)
    parser.add_argument('--lam_thres', type=float, default=0.35)
    parser.add_argument('--ltemp', type=float, default=0.5, help="temperature for lambda")
    parser.add_argument('--stemp', type=float, default=1.0, help="temperature for cosine similarity softmax")
    # ctemp is replaced by stemp in img_param_init when --ctemp_align is set.
    parser.add_argument('--ctemp', type=float, default=1.0)
    parser.add_argument('--ctemp_align', action='store_true')
    parser.add_argument('--distance', type=str, default='cosine')
    parser.add_argument('--fea_dim', action='store_true')
    parser.add_argument('--fea_norm', action='store_true')
    parser.add_argument('--fea_norm2', action='store_true')
    parser.add_argument('--no_bias', action='store_true')
    parser.add_argument('--classifier_temp', type=float, default=1.0)

    args = parser.parse_args()
    return img_param_init(args)

# Per-dataset input geometry: dataset -> (square input size, class count).
_INPUT_SPECS = {
    'Digits': (32, 10),
    'PACS': (224, 7),
    'OfficeHome': (224, 65),
}

# Per-dataset training presets:
# dataset -> (pretraining lr, adaptation epochs,
#             convmix pretraining epochs keyed by source-domain letter).
_DATASET_PRESETS = {
    'PACS': (0.005, 80, {'A': 500, 'C': 100, 'P': 300, 'S': 100}),
    'OfficeHome': (0.002, 60, {'Ar': 100, 'Cl': 100, 'Pr': 60, 'Rw': 60}),
    'Digits': (0.05, 800, {'Mn': 1000, 'Mm': 1500, 'Sv': 1500, 'Sy': 1000}),
}


def _order_index(order):
    """Validate a domain order and return its index into ``order_names``.

    ``order_names`` lists, for each of the 4 source domains, the 3 possible
    final domains in ascending domain index. Only the first and last entries
    of the order select the name; the two middle entries do not matter.
    """
    if order is None:
        raise ValueError(
            "a domain order is required: pass a permutation of 0 1 2 3 "
            "(e.g. --order 0 1 2 3)")
    order = tuple(order)
    if sorted(order) != [0, 1, 2, 3]:
        raise ValueError(
            "invalid domain order {}: expected a permutation of 0 1 2 3".format(order))
    source, final = order[0], order[-1]
    # The source domain never appears as its own target, so targets above
    # the source shift down by one slot inside the source's group of three.
    slot = final if final < source else final - 1
    return source * 3 + slot


def _pretrain_file_stem(args):
    """Return the shared tail (including '.pt') of the pretrain checkpoint names."""
    augmented = args.augmentation != 'none'
    pieces = []
    if augmented:
        pieces.append('{}_ogcat_{}_wn_{}_'.format(
            args.augmentation, args.aug_ogcat, args.classifier_wn))
    pieces.append('ep{}_lr{}_sch_{}'.format(
        args.epoch_pre, args.lr_pre, args.schedule_pre))
    # The poly power is only encoded for augmented pretraining runs.
    if augmented and args.schedule_pre == 'poly':
        pieces.append('_pw{}'.format(args.power_pre))
    pieces.append('_wd{}_{}-{}'.format(args.wd_pre, args.dataset, args.order[0]))
    # NOTE(review): fea_norm and fea_norm2 share the '_feanorm_temp' suffix
    # here, while the run title distinguishes '_feanorm2_temp'. Kept as-is so
    # existing checkpoint names keep resolving -- confirm whether intended.
    if args.fea_norm or args.fea_norm2:
        pieces.append('_feanorm_temp{}'.format(args.classifier_temp))
    if args.no_bias:
        pieces.append('_nobias')
    pieces.append('.pt')
    return ''.join(pieces)


def img_param_init(args):
    """Attach dataset-dependent and derived settings to ``args``.

    Sets, on ``args``: the domain lookup tables (``domains``,
    ``domains_letter``, ``order_names``), ``input_shape`` / ``input_size`` /
    ``num_classes``, ``order_num`` / ``order_name``, the per-dataset
    ``lr_pre`` / ``epoch`` / ``epoch_pre`` presets, the checkpoint paths
    ``src_featurizer_restore`` / ``src_classifier_restore``, the run
    ``title``, ``log_file`` and ``start_time``.

    Args:
        args: namespace-like object carrying the parsed options; it is
            mutated in place.

    Returns:
        The same ``args`` object.

    Raises:
        ValueError: if ``args.order`` is missing or is not a permutation of
            0, 1, 2, 3.
        KeyError: if ``args.dataset`` is not one of the known datasets.
    """
    args.domains = {
        'PACS': ['art_painting', 'cartoon', 'photo', 'sketch'],
        'Digits': ['mnist', 'mnist_m', 'svhn', 'syn'],
        'OfficeHome': ['art', 'clipart', 'product', 'real_world']
    }
    args.domains_letter = {
        'PACS': ['A', 'C', 'P', 'S'],
        'Digits': ['Mn', 'Mm', 'Sv', 'Sy'],
        'OfficeHome': ['Ar', 'Cl', 'Pr', 'Rw']
    }
    args.order_names = {
        'PACS': ['A2C', 'A2P', 'A2S', 'C2A', 'C2P', 'C2S', 'P2A', 'P2C', 'P2S', 'S2A', 'S2C', 'S2P'],
        'Digits': ['Mn2Mm', 'Mn2Sv', 'Mn2Sy', 'Mm2Mn', 'Mm2Sv', 'Mm2Sy', 'Sv2Mn', 'Sv2Mm', 'Sv2Sy', 'Sy2Mn', 'Sy2Mm', 'Sy2Sv'],
        'OfficeHome': ['Ar2Cl', 'Ar2Pr', 'Ar2Rw', 'Cl2Ar', 'Cl2Pr', 'Cl2Rw', 'Pr2Ar', 'Pr2Cl', 'Pr2Rw', 'Rw2Ar', 'Rw2Cl', 'Rw2Pr']
    }

    input_size, num_classes = _INPUT_SPECS[args.dataset]
    args.input_shape = (3, input_size, input_size)
    args.input_size = input_size
    args.num_classes = num_classes

    args.order_num = _order_index(args.order)
    args.order_name = args.order_names[args.dataset][args.order_num]
    src_domain = args.domains_letter[args.dataset][args.order[0]]

    # Dataset presets take precedence over the command-line values.
    lr_pre, adapt_epochs, convmix_epochs = _DATASET_PRESETS[args.dataset]
    args.lr_pre = lr_pre
    # A user-supplied epoch survives only with --fix_pre or when it is 1.
    if not args.fix_pre and args.epoch != 1:
        args.epoch = adapt_epochs
    if args.augmentation == 'convmix':
        args.epoch_pre = convmix_epochs[src_domain]

    if args.ctemp_align:
        args.ctemp = args.stemp

    # NOTE(review): the checkpoint directory is hard-coded; args.model_root
    # is not consulted here.
    if args.fix_pre:
        # The saved title names a complete checkpoint file, so nothing else
        # is appended. Previously the feanorm/nobias suffixes and a second
        # '.pt' were tacked on after the extension (e.g. 'x.pt.pt').
        args.src_featurizer_restore = "saved_models/featurizer_" + args.saved_title + '.pt'
        args.src_classifier_restore = "saved_models/classifier_" + args.saved_title + '.pt'
    else:
        stem = _pretrain_file_stem(args)
        args.src_featurizer_restore = "saved_models/source-featurizer-pretrain_" + stem
        args.src_classifier_restore = "saved_models/source-classifier-pretrain_" + stem

    # Run title: fixed head, optional augmentation segment, fixed tail,
    # then one suffix per enabled flag (order matters for file names).
    title = '{}_{}_ep{}_lr{}_wd{}_sch_{}'.format(
        args.seed, args.order_name, args.epoch, args.lr, args.wd, args.schedule)
    if args.augmentation != 'none':
        title += '_{}_ogcat_{}'.format(args.augmentation, args.aug_ogcat)
    title += '_ltemp{}_stemp{}_ctemp{}_thres{}_del{}_malp{}_{}'.format(
        args.ltemp, args.stemp, args.ctemp, args.lam_thres, args.delta,
        args.mixup_alpha, args.distance)

    if args.fea_dim:
        title += '_feadim'
    if args.fea_norm:
        title += '_feanorm_temp{}'.format(args.classifier_temp)
    elif args.fea_norm2:
        title += '_feanorm2_temp{}'.format(args.classifier_temp)
    if args.no_bias:
        title += '_nobias'
    if args.classifier_wn:
        title += '_wn'
    if args.trans_ensemble:
        title += '_transens_num{}'.format(args.num_ens)
    if args.proto_notrans:
        title += '_notrans'
    if args.fix_pre:
        title += '_continued'
    if args.re:
        title += '_re'

    args.title = title
    args.log_file = 'log/' + args.title + ".txt"

    args.start_time = time.time()

    return args