import argparse
import os
import shutil
import sys
import time
import numpy as np
import torch
from utils.util import Tee

def get_args() -> argparse.Namespace:
    """Parse command-line options and derive the full run configuration.

    Besides returning the raw CLI options, this function:

    * rewrites ``data_dir`` to ``<data_dir>/<dataset>/`` (with trailing separator);
    * builds ``result_dir`` as ``<output>/<dataset>/<tag>``, where the tag is made
      from ``sourceAlg``, ``targetAlg``, ``pLabelAlg``, ``replay``, ``forAug`` and the
      ``freeze`` / ``distill`` flags, plus a ``tSNE`` subdirectory, and creates
      ``output``, ``result_dir`` and ``tSNE_dir`` on disk (side effect);
    * sets the ``CUDA_VISIBLE_DEVICES`` environment variable to ``--gpu``;
    * attaches dataset metadata via ``img_param_init`` and derived defaults
      (``order``, ``log_file``, ``memory_size``, ``steps_per_epoch``, ...) via
      ``set_default_args``;
    * sets ``num_task`` (number of domains minus the held-out ``test_envs``) and
      ``saved_model_name`` (a ``source<first order index>.pt`` path in ``result_dir``).

    Returns:
        The populated ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(description='DG')
    # Data
    parser.add_argument('--dataset', type=str, default='PACS', choices=['PACS', 'office_home', 'idomain_net', 'subdomain_net', 'dg5', 'cifar10_c_common', 'cifar10_c_all', 'office', 'office_caltech'])
    parser.add_argument('--test_envs', type=int, nargs='+',
                        default=[], help='no fixed target domains')
    parser.add_argument('--data_dir', type=str, default='./Dataset', help='root data dir')
    # When omitted, set_default_args fills this with 0 .. num_task-1.
    parser.add_argument('--order', type=int, nargs='+', help='training domain order')
    parser.add_argument('--N_WORKERS', type=int, default=4)
    parser.add_argument('--split_style', type=str, default='strat',help="the style to split the train and eval datasets")
    parser.add_argument('--target_data', type=str, default='split', choices=['split', 'all'], help='all means use all target data for training and testing')
    parser.add_argument('--commonloader', action='store_true', help='use dataloader instead of infinite dataloader')

    #training algorithm
    parser.add_argument('--sourceAlg', type=str, default="PCL2", choices=['ERM', 'PCL','PCL2', 'supcon', 'SupPCL', 'SupPCL2', 'ERM_bot', "FP", 'MFP'], help='labeled source domain training algorithm')
    parser.add_argument('--targetAlg', type=str, default="PCL2", choices=['ERM', 'PCL','PCL2', 'supcon', 'SupPCL', 'SupPCL2', 'ERM_bot', 'LDAuCID', 'ERMPCL', "FP", 'MFP'], help='unlabeled target domain adaptation algorithm')
    parser.add_argument('--loss_alpha1', type=float, default=1.0, help='loss weight')
    parser.add_argument('--PCL_scale', default=12, type=float, help='scale of cross entropy in PCL')
    parser.add_argument('--pLabelAlg', type=str, default="topkSHOTknn", choices=['softmax', 'ground', 'proxy', 'softmax_proxy', 'proxy_softmax', 'SHOT', 'SHOT_PCL', 'MSHOT', 'BMD', 'knn', 'mknn', 'MSHOTknn', 'knnMSHOT', 'topkSHOTknn'], help='pesudo label assigning algorithm in target domain. ground is ground true label')
    parser.add_argument('--pseudo_tau', default=0.0, type=float, help='pseudo label threshold')
    parser.add_argument('--alpha_tau', default=2, type=int, help='progress pseudo label')
    parser.add_argument('--pseudo_fre', default=1, type=int, help='assign new pseudo label each pseduo_fre epoch')
    parser.add_argument('--gmm_tau', default=0.1, type=float, help='LDAuCID gmm pseudo label tau')
    # NOTE(review): declared as int while --PCL_scale is float — confirm integer-only scales are intended.
    parser.add_argument('--pseudo_scale', default=12, type=int, help='scale of logits in pseudo labels')
    # 'icarl' is normalised to 'iCaRL' in set_default_args.
    parser.add_argument('--replay', type=str, default='icarl', choices=['icarl', 'Finetune', 'LDAuCID_buff'], help='data replay algorithm')
    parser.add_argument('--replay_mode', type=str, default='class', choices=['class', 'domain'])
    # No argparse default: a per-dataset value is filled in by set_default_args.
    parser.add_argument('--memory_size', type=int, help="replay exemplar size")
    parser.add_argument('--freeze', action='store_true', help='freeze classifier and proxy in adaptation step.')
    parser.add_argument('--SHOT_IM', action='store_true', help='Information maximization of CE for SHOT pseudo label')
    parser.add_argument('--SHOT_step', type=int, default=2, help='SHOT cluster steps')
    parser.add_argument('--MSHOT_tau', type=float, default=0.7, help='MSHOT reatio of current domain center and replay center')
    parser.add_argument('--aug_tau', type=float, default=0.8, help='do augmentation whose pseudo label confidence larger than this value ')
    parser.add_argument('--distance', type=str, default='cosine', choices=['cosine', 'euclidean'])
    parser.add_argument('--distill', action='store_true', help='distill loss')
    parser.add_argument('--distillProxy', action='store_true', help='distill proxy predict loss')
    parser.add_argument('--distill_alpha', type=float, default=0.5)
    parser.add_argument('--mix_classifier', action='store_true', help='mix classifier ')
    parser.add_argument('--mix_proxy', action='store_true', help='mix current proxy and old proxy')
    parser.add_argument('--classifier_mix_tau', type=float, default=0.5)
    parser.add_argument('--topk_alpha', default=20, type=int, help='k nears in knn pseudo labeling.')
    parser.add_argument('--topk_alpha2', default=2, type=int, help='topk fitting samples in knn pseudo labeling.')
    parser.add_argument('--knn_softmax', action='store_true', help='use softmax to select top k')
    parser.add_argument('--weight_pcl', action='store_true', help='use softmax as weight in pcl loss in adaptation')
    parser.add_argument('--MPCL', type=int, default=0, help='MPCL version')
    parser.add_argument('--MPCL_alpha', type=float, default=0.5, help='MPCL weight')
    parser.add_argument('--pseudo_max_epoch', type=int, default=30, help='assign pseudo label epoch')
    parser.add_argument('--classifier_proxy', action='store_true')
    

    # data augmentation based method
    parser.add_argument('--bacAug', type=str, choices=['styleTransfer', 'EFDMix0', 'EFDMix1', 'MixWave', 'MixWaveSoftmax'], help='backward transfer augmentation method')
    parser.add_argument('--forAug', type=str, choices=['v1'], help='forward transfer augmentation method')
    parser.add_argument('--bacAug_tau', type=float, default=0.3, help='the probability of applying backward augmentation when training')
    parser.add_argument('--mix_layers', type=int, nargs='+', default=[1,2], help='apply mixstyle after which res block')

    
    # Utils
    parser.add_argument('--seed', type=int, default=2022)
    # parser.add_argument('--save_model_every_checkpoint', action='store_true')
    parser.add_argument('--output', type=str,
                        default="result_develop", help='result output path')
    # When omitted, set_default_args derives 'order<indices>.log'; either way it is placed under result_dir.
    parser.add_argument('--log_file', type=str, help="logging file name under output dir")
    parser.add_argument('--tsne', action='store_true', help='visualize embedding space using tSNE')

    # Model
    parser.add_argument('--net', type=str, default='resnet50',
                        help="featurizer: vgg16, resnet50, resnet101,DTNBase")
    parser.add_argument('--classifier', type=str,
                        default="linear", choices=["linear", "wn"])
    parser.add_argument('--no_pretrained', action='store_true', help='default using pretrained model')

    # Training
    parser.add_argument('--lr', type=float, default=2e-3, help="learning rate")
    parser.add_argument('--targetLR', type=float, help='target domain learning rate')
    parser.add_argument('--no_lr_sch', action='store_true')
    parser.add_argument('--lr_decay1', type=float, default=1.0, help='feature extractor lr scheduler')
    parser.add_argument("--lr_sc", default=0.0005, type=float, help="forward augmentation model learning rate")
    parser.add_argument('--max_epoch', type=int,
                        default=30, help="max epoch")
    # No argparse default: a per-dataset value is filled in by set_default_args.
    parser.add_argument('--steps_per_epoch', type=int, help='training steps in each epoch. totaly trained sampels in each epoch is steps_per_epoch*batch_size')
    parser.add_argument('--batch_size', type=int,
                        default=64, help='batch_size')
    parser.add_argument('--gpu', type=int, default=0, help="device id to run")
    parser.add_argument('--weight_decay', type=float, default=5e-4)
    parser.add_argument('--momentum', type=float,
                        default=0.9, help='for optimizer')


    # Don't need to change
    parser.add_argument('--data_file', type=str, default='',
                        help='root_dir')
    parser.add_argument('--task', type=str, default="img_dg",
                        choices=["img_dg"], help='now only support image tasks')
    

    args = parser.parse_args()

    # I/O
    # The trailing '' component makes data_dir end with a path separator.
    args.data_dir = os.path.join(args.data_dir, args.dataset, '')
    # The run directory name encodes the algorithm choices so different configurations do not overwrite each other.
    args.result_dir = os.path.join(args.output, args.dataset, '{}_{}_{}_{}_{}_{}_{}'.format(
                    args.sourceAlg, args.targetAlg, args.pLabelAlg, args.replay, args.forAug if args.forAug is not None else '_', 'freeze' if args.freeze else 'nofreeze', 'distill' if args.distill else 'nodistill'))
    args.tSNE_dir = os.path.join(args.result_dir, 'tSNE')
    os.makedirs(args.output, exist_ok=True)
    os.makedirs(args.result_dir, exist_ok=True)
    os.makedirs(args.tSNE_dir, exist_ok=True)
    # os.makedirs(os.path.join(args.output, 'saved_model'), exist_ok=True)

    # NOTE(review): this only restricts visible devices if CUDA has not been initialised in this
    # process yet (torch is imported at module level) — confirm no earlier CUDA call.
    os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)
    # img_param_init must run first: set_default_args reads args.domains that it sets.
    args = img_param_init(args)
    args = set_default_args(args)
    # Number of sequential training domains = all domains minus held-out test environments.
    args.num_task = len(args.domains) - len(args.test_envs)

    # Checkpoint path for the model trained on the first domain in the training order.
    args.saved_model_name = os.path.join(args.result_dir, 'source{}.pt'.format(args.order[0]))  

    return args

def set_default_args(args):
    """Fill in derived and dataset-dependent defaults on ``args`` (in place).

    Requires ``img_param_init`` to have run already, because ``args.domains``
    is read.

    Sets or overrides:

    * ``MixAlg`` / ``PCL_net`` / ``ERM_net`` -- algorithm-name groupings.
    * ``order`` -- ``[0, 1, ..., num_task - 1]`` when not given on the CLI.
    * ``log_file`` -- always placed under ``result_dir``; when unset it is named
      ``order<indices>.log`` after the training order.
    * ``replay`` -- ``'icarl'`` is normalised to ``'iCaRL'``.
    * ``memory_size`` / ``steps_per_epoch`` -- per-dataset defaults, looked up
      only when the user did not supply a value.

    Returns:
        The same ``args`` object.
    """
    # Algorithm families used to branch on the chosen method elsewhere.
    args.MixAlg = ['EFDMix0', 'EFDMix1', 'MixWave', 'MixWaveSoftmax']
    args.PCL_net = ['PCL', 'PCL2', 'SupPCL', 'SupPCL2']
    args.ERM_net = ['ERM', 'supcon', 'ERM_bot']

    if args.order is None:
        n_train_domains = len(args.domains) - len(args.test_envs)
        args.order = list(range(n_train_domains))

    if args.log_file is None:
        order_tag = ''.join(map(str, args.order))
        log_name = 'order{}.log'.format(order_tag)
    else:
        log_name = args.log_file
    args.log_file = os.path.join(args.result_dir, log_name)

    if args.replay == 'icarl':
        args.replay = 'iCaRL'

    default_memory = {
        'PACS': 200, 'office_home': 600, 'idomain_net': 1000, 'subdomain_net': 200,
        'dg5': 200, 'office': 100, 'office_caltech': 40,
        'cifar10_c_common': 200, 'cifar10_c_all': 200,
    }
    default_steps = {
        'PACS': 50, 'office_home': 60, 'idomain_net': 750, 'subdomain_net': 70,
        'dg5': 800, 'office': 35, 'office_caltech': 14,
        'cifar10_c_common': 5400, 'cifar10_c_all': 900,
    }
    # Look-ups happen only when needed, so an explicit CLI value never hits a missing key.
    if args.memory_size is None:
        args.memory_size = default_memory[args.dataset]
    if args.steps_per_epoch is None:
        args.steps_per_epoch = default_steps[args.dataset]

    return args

def img_param_init(args):
    """Attach dataset-specific metadata to ``args`` (in place) and return it.

    Attributes set, based on ``args.dataset``:

    * ``domains`` -- list of domain names of the dataset (an independent copy,
      so mutating it does not alter ``img_dataset``).
    * ``img_dataset`` -- mapping from every known dataset name to its domains.
    * ``input_shape`` -- ``(3, 32, 32)`` for ``dg5`` and the CIFAR-10-C variants,
      ``(3, 224, 224)`` for everything else.
    * ``num_classes`` -- number of classes of the dataset.
    * ``proj_dim`` -- mapping from dataset name to the projection dimension used
      for the contrastive loss.

    Args:
        args: namespace with at least a ``dataset`` attribute.

    Returns:
        The same ``args`` object.

    Raises:
        ValueError: if ``args.dataset`` is not a known dataset. ``args`` is not
            modified in that case.
    """
    img_dataset = {
        'PACS': ['art_painting', 'cartoon', 'photo', 'sketch'],
        'office_home': ['Art', 'Clipart', 'Product', 'Real_World'],
        'domain_net': ['clipart', 'infograph', 'painting', 'quickdraw', 'real', 'sketch'],
        # idomain_net: the 100 classes of domain_net with the most images.
        'idomain_net': ['clipart', 'infograph', 'painting', 'quickdraw', 'real', 'sketch'],
        # subdomain_net: the 10 classes of domain_net with the most images.
        'subdomain_net': ['clipart', 'infograph', 'painting', 'quickdraw', 'real', 'sketch'],
        'cifar10_c_common': ['noise', 'blur', 'weather', 'digital'],
        'cifar10_c_all': ['gaussian_noise', 'shot_noise', 'impulse_noise', 'defocus_blur', 'glass_blur', 'motion_blur', 'zoom_blur', 'snow', 'frost', 'fog', 'brightness', 'contrast', 'elastic_transform', 'pixelate', 'jpeg_compression', 'speckle_noise', 'gaussian_blur', 'spatter', 'saturate'],

        'office': ['amazon', 'dslr', 'webcam'],
        'office_caltech': ['amazon', 'dslr', 'webcam', 'caltech'],
        'dg5': ['mnist', 'mnist_m', 'svhn', 'syn', 'usps'],
        'VLCS': ['Caltech101', 'LabelMe', 'SUN09', 'VOC2007']
    }
    num_classes = {
        'PACS': 7, 'office_home': 65, 'domain_net': 345, 'idomain_net': 100,
        'subdomain_net': 10, 'cifar10_c_common': 10, 'cifar10_c_all': 10,
        'office': 31, 'office_caltech': 10, 'dg5': 10, 'VLCS': 5,
    }
    # Datasets made of small 32x32 images; all others use 224x224 inputs.
    small_image_datasets = ('dg5', 'cifar10_c_common', 'cifar10_c_all')

    dataset = args.dataset
    # Fail loudly before touching args; the old code printed a message and then
    # crashed with an UnboundLocalError on the unset `domains` local.
    if dataset not in img_dataset:
        raise ValueError('No such dataset exists! Got {!r}, expected one of {}.'.format(
            dataset, sorted(img_dataset)))

    # Copy so args.domains is a separate list from the lookup-table entry.
    args.domains = list(img_dataset[dataset])
    args.img_dataset = img_dataset
    args.input_shape = (3, 32, 32) if dataset in small_image_datasets else (3, 224, 224)
    args.num_classes = num_classes[dataset]

    args.proj_dim = {'dg5':128, 'cifar10_c_common':256, 'cifar10_c_all':256, 'PACS':256, 'office':256, 'office_caltech':256, 'subdomain_net':256, 'office_home':512, 'idomain_net':512, 'domain_net':1024}   # project dim for contrastive loss.

    return args



