from argparse import ArgumentParser
import torch

global_args = None

def parse_args():
    parser = ArgumentParser()
    parser.add_argument('--root', type=str, default='./data')
    parser.add_argument('--load_dir', type=str, default=None, help='if provided, load model and test')
    parser.add_argument('--load_task_id', type=int, default=None)
    parser.add_argument('--print_filename', type=str, default=None, help="if None, prints on 'result.txt' file")
    parser.add_argument('--dataset', type=str, default='cifar100', choices=['mnist', 'svhn', 'cifar100', 'cifar10', 'timgnet', 'imgnet380'])
    parser.add_argument('--model', type=str, default='derpp', choices=['vitadapter_ewt', 'deitadapter_ewt'])
    parser.add_argument('--noCL', action='store_true')
    parser.add_argument('--task_type', type=str, default='standardCL_randomcls',
                            choices=['cov', 'concept', 'pre-define',
                                     'standardCL_supercls', 'standardCL_randomcls'], help='learning scenarios')
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--init_task', type=int, default=0)
    parser.add_argument('--n_tasks', type=int, default=5)
    parser.add_argument('--validation', type=float, default=None, help='Propertion of dataset used e.g. if set 0.9, 90\% of training data is used for training and rest 10\% is used for validation')
    parser.add_argument('--optim_type', type=str, default='adam', choices=['adam', 'sgd'])
    parser.add_argument('--zero_shot', action='store_true', help='Print zeroshot accuracy')
    parser.add_argument('--n_epochs', type=int, default=1)
    parser.add_argument('--init_epoch', type=int, default=0, help='initial epoch. Epoch starts from init_epoch and finishes at n_epochs-1')
    parser.add_argument('--loss_f', type=str, default='ce', choices=['ce', 'bce', 'nll'])
    parser.add_argument('--revisit', type=int, default=2, help='number of revisits')
    parser.add_argument('--confusion', action='store_true')
    parser.add_argument('--tsne', action='store_true')
    parser.add_argument('--prob', type=float, default=None, help='probability; how many samples of a class are used for revisit')
    parser.add_argument('--coin', type=int, default=None, choices=[0, 1], help='whether a class experiences concept shift or not')
    parser.add_argument('--choose', type=int, default=0, help='whether a class experiences concept shift or not')
    parser.add_argument('--clip_init', action='store_true')
    parser.add_argument('--normalize', action='store_true')
    parser.add_argument('--separate_buffer', action='store_true')
    parser.add_argument('--use_buffer', action='store_true', help='if true, use buffer. Some systems do not use buffer by default. Use it for them.')
    parser.add_argument('--epsilon', type=float, default=None, help='epsilon noise for ODIN')
    parser.add_argument('--T_odin', type=float, default=None, help='temperature scale for ODIN')
    parser.add_argument('--noise_odin', type=float, default=None, help='noise scale for ODIN')
    parser.add_argument('--modify_previous_ood', action='store_true')
    parser.add_argument('--select', action='store_true', help='if true, update only the heads of classes in current batch, and fix other heads')
    parser.add_argument('--choice', default='uniform')
    parser.add_argument('--holdout', type=int, default=None, help='number of holdout samples per class. If None, no calibration')
    parser.add_argument('--modify_alpha', type=float, default=1e-15)
    parser.add_argument('--modify_beta', type=float, default=1e-15)
    parser.add_argument('--save_output', action='store_true')
    parser.add_argument('--save_statistics', action='store_true', help='save parameters of normal distribution when ipca is used')
    parser.add_argument('--save_holdout', action='store_true')
    parser.add_argument('--task_bdry', action='store_true', help='True if task bdry is known during training')
    parser.add_argument('--outlier_exposure', type=str, default='label', choices=['uniform', 'label'])
    parser.add_argument('--output_learning', type=int, default=None, help='number of samples to save to learn outputs')
    parser.add_argument('--n_components', type=int, default=5)
    parser.add_argument('--folder', type=str, default=None, help='directory NAME. e.g. save under ./logs/NAME')
    parser.add_argument('--ff', type=float, default=1.)
    parser.add_argument('--dynamic', type=int, default=None, help='Set the max memory size. If set, use dynamic memory. Only works for PCAs. Use buffer_size for other methods')
    parser.add_argument('--compute_md', action='store_true', help='If true, compute mahalanobis distance of features')
    parser.add_argument('--eval_every', type=int, default=5)
    parser.add_argument('--use_amp', action='store_true')
    parser.add_argument('--resume_id', type=int, default=None, help='resume id. If provided, training begins when task_id == resume_id')
    parser.add_argument('--resume', type=str, default=None, help='resume path')
    parser.add_argument('--train_clf', action='store_true')
    parser.add_argument('--train_ebd', action='store_true')
    parser.add_argument('--obtain_val_outputs', action='store_true')
    parser.add_argument('--obtain_val_outputs_comp', action='store_true')
    parser.add_argument('--test_model_name', type=str, default=None, help='model_task_, model_task_clf_')
    parser.add_argument('--class_order', type=int, default=0, help='class split. Choices=[0, 1, 2]')
    parser.add_argument('--train_clf_save_name', type=str, default='model_task_clf')
    parser.add_argument('--model_copy', action='store_true')
    parser.add_argument('--train_clf_id', type=int, default=None)
    parser.add_argument('--task_inference', type=str, default=None, choices=['entropy'])
    parser.add_argument('--framework_test', action='store_true', help='use for running experiment for "framework" paper')
    parser.add_argument('--finetune_clip', action='store_true')
    parser.add_argument('--prompt', action='store_true')
    parser.add_argument('--load_calibration', action='store_true')
    parser.add_argument('--tr_dynamics', action='store_true')
    parser.add_argument('--rotation', action='store_true')
    parser.add_argument('--recompute_md', action='store_true')
    parser.add_argument('--mean_label_name', type=str, default='mean_label')
    parser.add_argument('--cov_task_name', type=str, default='cov_task')
    parser.add_argument('--mean_task_name', type=str, default='mean_task')
    parser.add_argument('--cov_task_noise_name', type=str, default='cov_task_noise')
    parser.add_argument('--obtain_pseudo_features_id', type=int, default=None)
    parser.add_argument('--save_cil_md', action='store_true')
    parser.add_argument('--n_pre_cls', default=None, type=int)

    # System
    parser.add_argument('--exe', action='store_true', help='If true, use xx samples per class for fast checking code execution')
    parser.add_argument('--exe_n_samples', type=int, default=20, help='Activated if --exe. xx samples for class')

    # Network
    parser.add_argument('--in_dim', type=int, default=512, help='feature size')
    parser.add_argument('--out_dim', type=int, default=1)
    parser.add_argument('--freeze_head', action='store_true', help="If true, don't update classifier")

    # DataLoader
    parser.add_argument('--pin_memory', action='store_false')
    parser.add_argument('--num_workers', type=int, default=15)
    parser.add_argument('--lr', type=float, default=0.01)
    parser.add_argument('--batch_size', type=int, default=16)
    parser.add_argument('--minibatch_size', type=int, default=16)
    parser.add_argument('--test_batch_size', type=int, default=512)

    parser.add_argument('--distillation', action='store_true')
    parser.add_argument('--T', type=float, default=40)
    parser.add_argument('--distill_lambda', type=float, default=0.25)

    # For vitadapter + OOD approaches
    parser.add_argument('--compute_auc', action='store_true')
    parser.add_argument('--calibration', action='store_true')
    parser.add_argument('--use_md', action='store_true', help='use MD value for CIL prediction')
    parser.add_argument('--noise', action='store_true', help='use MD-noise')
    parser.add_argument('--cal_lr', type=float, default=0.01)
    parser.add_argument('--cal_batch_size', type=int, default=8)
    parser.add_argument('--cal_epochs', type=int, default=5)
    parser.add_argument('--cal_size', type=int, default=20, help='number of samples saved per class for calibration')
    parser.add_argument('--momentum', type=float, default=0.9, help='momentum value for sgd')
    parser.add_argument('--adapter_latent', type=int, default=64, help='adapter latent size')
    parser.add_argument('--softmax', action='store_true', help='use softmax for task output before cat for CIL')
    parser.add_argument('--lamb', type=float, default=0.9)

    # For HAT
    parser.add_argument('--smax', type=float, default=500)
    parser.add_argument('--lamb0', type=float, default=0.75)
    parser.add_argument('--lamb1', type=float, default=0.75)
    parser.add_argument('--thres_cosh', type=float, default=50)
    parser.add_argument('--thres_emb', type=float, default=6)

    parser.add_argument('--no_pree', action='store_false')
    parser.add_argument('--train_clf_pree_id', type=int, default=None)
    parser.add_argument('--lr_pree', type=float, default=None)
    parser.add_argument('--n_epochs_clf_pree', type=int, default=None)
    parser.add_argument('--test_pree', action='store_true')
    parser.add_argument('--n_steps_cal_clf_pree', type=int, default=None)
    parser.add_argument('--cal_clf_pree', action='store_true')
    parser.add_argument('--lr_cal_pree', type=float, default=None)

    parser.add_argument('--test_joint', action='store_true')

    parser.add_argument('--train_joint_clf', action='store_true')

    parser.add_argument('--lr_single_head', type=float, default=None)
    parser.add_argument('--generate_ood', action='store_true')
    parser.add_argument('--train_single_head_id', type=int, default=None)
    parser.add_argument('--single_head_model_name', type=str, default='model_single_head')
    parser.add_argument('--test_single_head', action='store_true')
    parser.add_argument('--report_auc_at_each_update', action='store_true')

    parser.add_argument('--test_task_id', type=int, default=None)
    parser.add_argument('--load_path', type=str, default=None)

    parser.add_argument('--use_logit', action='store_true')

    parser.add_argument('--finetune_clf_usingMDStats_id', type=int, default=None)
    parser.add_argument('--finetune_clf_usingMDStats_model_name', type=str, default='model_finetune_clf_usingMDStats')
    parser.add_argument('--n_epochs_finetune_clf_usingMDStats', type=int, default=None)

    parser.add_argument('--use_ood_at_training', action='store_true')

    parser.add_argument('--finetune_clf_using_real_pseudo_features_after_train_id', type=int, default=None)

    parser.add_argument('--train_clf_pree_real_and_ood_id', type=int, default=None)

    args = parser.parse_args()
    if args.dataset == 'mnist':
        args.total_cls = 10
    elif args.dataset == 'svhn':
        args.total_cls = 10
    elif args.dataset == 'cifar10':
        args.total_cls = 10
    elif args.dataset == 'cifar100':
        args.total_cls = 100
    elif args.dataset == 'timgnet':
        args.total_cls = 200
    elif args.dataset == 'imgnet380':
        args.total_cls = 380
    else:
        raise NotImplementedError()

    return args

def run_args():
    device = "cuda" if torch.cuda.is_available() else "cpu"
    global global_args
    if global_args is None:
        global_args = parse_args()
        global_args.device = device

run_args()
