# Entry-point guard: everything in this script -- the imports, the definition of
# main(), the CLI parser and the final main() call -- lives inside this block, so
# importing the module executes nothing.
# NOTE(review): presumably done so that processes which re-import this module
# (e.g. multiprocessing workers) skip the heavy imports -- confirm.
if __name__ == '__main__':
    # Early progress marker, flushed immediately so it shows up in buffered job logs.
    print('run.py -- starting run.py', flush=True)

    import os
    import random
    import numpy as np
    import gc
    import argparse
    import torch
    from tqdm import tqdm

    from config import configs
    from dataset.common import get_dataset
    from model.net import *
    import model.vgg as vgg
    import model.resnet as resnet
    import model.efficientnet as efficientnet
    import model.mobilenet as mobilenet
    import model.wideresnet as wideresnet
    import model.preact_resnet as preact_resnet
    import model.preact_wideresnet as preact_wideresnet

    from model.parallel import ParallelWrapper

    from al.strategy import set_seed
    from al import *

    import wandb

    # Base name of the Weights & Biases project. main() appends
    # '-<project_suffix>' to it when --project_suffix is non-empty.
    PROJECT_NAME = 'topk-al-iclr25'

    # Progress marker printed once all of the imports above have completed.
    print('run.py -- import done', flush=True)

    def _build_net(model_name, args):
        """Instantiate the backbone selected by ``--model``.

        Args:
            model_name: value of ``--model``.
            args: experiment config dict; ``args['nClasses']`` is passed to the
                model constructor as the number of output classes.

        Returns:
            The freshly constructed network (not yet moved to a device).

        Raises:
            NotImplementedError: for 'mlp' and 'lin', which are recognised but disabled.
            ValueError: for any other unknown model name.
        """
        if model_name == 'mlp':
            raise NotImplementedError('MLP is not supported yet (dim issue)')
        if model_name == 'lin':
            raise NotImplementedError('Linear model is not supported yet (dim issue)')

        # Lambdas keep construction lazy: only the selected architecture is built.
        builders = {
            'resnet18': lambda n: resnet.ResNet18(n_classes=n),
            'resnet34': lambda n: resnet.ResNet34(n_classes=n),
            'resnet50': lambda n: resnet.ResNet50(n_classes=n),
            'resnet101': lambda n: resnet.ResNet101(n_classes=n),
            'efficientnet': lambda n: efficientnet.EfficientNetV2s(n_classes=n),
            'mobilenet': lambda n: mobilenet.MobileNetV3s(n_classes=n),
            # NOTE(review): this entry used widen_factor=2, contradicting its
            # 'wrn-28-5' name (wrn-36-2 / wrn-36-5 match theirs). Checkpoints
            # trained with the old width will no longer load into this model.
            'wrn-28-5': lambda n: wideresnet.wideresnet(depth=28, widen_factor=5, n_classes=n),
            'wrn-36-2': lambda n: wideresnet.wideresnet(depth=36, widen_factor=2, n_classes=n),
            'wrn-36-5': lambda n: wideresnet.wideresnet(depth=36, widen_factor=5, n_classes=n),
            'vgg': lambda n: vgg.VGG('VGG16', n_classes=n),
            'preactresnet18': lambda n: preact_resnet.PreActResNet18(n_classes=n),
            # This constructor takes `num_classes`, unlike the others.
            'preactwideresnet18': lambda n: preact_wideresnet.PreActWideResNet18(widen_factor=3, num_classes=n),
        }
        if model_name not in builders:
            raise ValueError(
                f'invalid model {model_name!r}; choose one of: ' + ', '.join(sorted(builders))
            )
        return builders[model_name](args['nClasses'])

    def _build_strategy(name, *ctor_args):
        """Instantiate the active-learning strategy selected by ``--strategy``.

        Args:
            name: value of ``--strategy``.
            *ctor_args: positional arguments forwarded unchanged to the strategy
                constructor (train set, raw train set, test set, labeled mask,
                network, config dict).

        Returns:
            The constructed strategy object.

        Raises:
            ValueError: if ``name`` has no live branch here. This includes the
                static 'topk_*' names and a few 'dtopk*' names that the CLI
                still accepts but that are currently disabled.
        """
        if name == 'rand':  # random sampling
            return RandomSampling(*ctor_args)
        if name == 'conf':  # confidence-based sampling
            return LeastConfidence(*ctor_args)
        if name == 'marg':  # margin-based sampling
            return MarginSampling(*ctor_args)
        if name == 'badge':  # batch active learning by diverse gradient embeddings
            return BadgeSampling(*ctor_args)
        if name == 'coreset':  # coreset sampling
            return CoreSet(*ctor_args)
        if name == 'entropy':  # entropy-based sampling
            return EntropySampling(*ctor_args)
        if name == 'baseline':  # badge but with k-DPP sampling instead of k-means++
            return BaselineSampling(*ctor_args)
        if name == 'saal':
            return SAALSampling(*ctor_args)
        # dynamic top-k strategies
        if name == 'dtopk_random':
            return DynamicTopKStrategyRandom(*ctor_args)
        if name == 'dtopk_entropy':
            return DynamicTopKStrategyEntropy(*ctor_args)
        if name == 'dtopk_coreset':
            return DynamicTopKStrategyCoreset(*ctor_args)
        if name == 'dtopk_badge':
            return DynamicTopKStrategyBadge(*ctor_args)
        # '*_conf' variants
        if name == 'dtopk_random_conf':
            return DynamicTopKStrategyRandomConf(*ctor_args)
        if name == 'dtopk_entropy_conf':
            return DynamicTopKStrategyEntropyConf(*ctor_args)
        if name == 'dtopk_cost_conf':
            return DynamicTopKStrategyCostConf(*ctor_args)
        if name == 'dtopk_coreset_conf':
            return DynamicTopKStrategyCoresetConf(*ctor_args)
        if name == 'dtopk_badge_conf':
            return DynamicTopKStrategyBadgeConf(*ctor_args)
        if name == 'dtopk_hybrid_entropy_conf':
            return DynamicTopKStrategyHybridEntropyConf(*ctor_args)
        if name == 'dtopk_hybrid_badge_conf':
            return DynamicTopKStrategyHybridBadgeConf(*ctor_args)
        # 'ubdtopk_*' variants
        if name == 'ubdtopk_random':
            return UBDynamicTopKStrategyRandom(*ctor_args)
        if name == 'ubdtopk_entropy':
            return UBDynamicTopKStrategyEntropy(*ctor_args)
        if name == 'ubdtopk_badge':
            return UBDynamicTopKStrategyBadge(*ctor_args)
        if name == 'dtopk_entropy_conf_adap_log':
            return DynamicTopKStrategyEntropyConfAdapLog(*ctor_args)
        # '*_wall' variants
        if name == 'dtopk_entropy_wall':
            return DynamicTopKStrategyEntropyWall(*ctor_args)
        if name == 'dtopk_badge_wall':
            return DynamicTopKStrategyBadgeWall(*ctor_args)
        if name == 'dtopk_random_conf_wall':
            return DynamicTopKStrategyRandomConfWall(*ctor_args)
        if name == 'dtopk_hybrid_entropy_conf_adap_wall':
            return DynamicTopKStrategyHybridEntropyConfAdapWall(*ctor_args)
        raise ValueError('Invalid strategy name')

    def main(cmd_args: argparse.Namespace):
        """Run one active-learning experiment as described by ``cmd_args``.

        Steps: seed the RNGs, pick and patch the dataset config, load the
        datasets, build the network and the query strategy, optionally restore
        a previous labeled pool / network / start a W&B run, then loop over
        rounds: query -> update labeled pool -> train -> log / checkpoint.

        Args:
            cmd_args: parsed command-line options (see the parser at the bottom
                of this script). Options left at their zero/empty default do not
                override the corresponding value in the dataset config.

        Raises:
            ValueError: ``--start_round > 1`` without ``--load_idxs_lb``, or an
                unknown model / strategy name.
            FileNotFoundError: ``--load_idxs_lb`` or ``--load_net`` does not exist.
            NotImplementedError: ``--model`` is 'mlp' or 'lin'.
        """
        set_seed(cmd_args.seed)
        # torch.multiprocessing.set_start_method('spawn', force=True)

        dataset_name = cmd_args.dataset
        initial_budget = cmd_args.initial_budget
        budget = cmd_args.budget

        if cmd_args.no_aug:
            # Train with the test-time transform, i.e. without data augmentation.
            for cfg_name in ('CIFAR10', 'CIFAR100', 'CUB', 'TINYIMAGENET',
                             'IMAGENET32', 'IMAGENET64', 'IMAGENET64S'):
                configs[cfg_name]['transform'] = configs[cfg_name]['transformTest']

        # These datasets always evaluate with their training transform.
        # NOTE(review): the key is 'FashionMNIST' while --dataset is upper-cased
        # to 'FASHIONMNIST' -- confirm that `configs` exposes the key used below.
        for cfg_name in ('MNIST', 'FashionMNIST', 'SVHN'):
            configs[cfg_name]['transformTest'] = configs[cfg_name]['transform']

        # --custom_config, when given, replaces the per-dataset config entry.
        if cmd_args.custom_config:
            args = configs[cmd_args.custom_config]
        else:
            args = configs[cmd_args.dataset]

        args['cmd_args'] = cmd_args
        args['budget'] = budget
        args['seed'] = cmd_args.seed

        train_dataset, train_raw_dataset, test_dataset = get_dataset(
            dataset_name,
            train_transform=args['transform'],
            test_transform=args['transformTest'],
            download=cmd_args.download
        )

        args['modelType'] = cmd_args.model
        args['device'] = 'cuda' if cmd_args.cuda and torch.cuda.is_available() else 'cpu'
        args['wandb'] = cmd_args.wandb
        args['rho'] = cmd_args.rho
        args['calibration'] = cmd_args.calibration
        args['alpha'] = cmd_args.alpha
        args['d'] = cmd_args.d

        # Below, a zero / empty CLI value means "keep the value from the config".
        if cmd_args.batch_size > 0:
            args['loader_tr_args']['batch_size'] = cmd_args.batch_size
            args['loader_te_args']['batch_size'] = cmd_args.batch_size

        if cmd_args.num_workers > 0:
            args['loader_tr_args']['num_workers'] = cmd_args.num_workers
            args['loader_te_args']['num_workers'] = cmd_args.num_workers

        if cmd_args.lr > 0:
            args['lr'] = cmd_args.lr

        if cmd_args.milestones:
            args['milestones'] = cmd_args.milestones

        if cmd_args.gamma > 0:
            args['gamma'] = cmd_args.gamma

        if cmd_args.weight_decay > 0:
            args['weight_decay'] = cmd_args.weight_decay

        if hasattr(cmd_args, 'k'):
            # --k is parsed as float; values >= 1 are stored as int, values
            # below 1 are kept as float.
            args['k'] = cmd_args.k
            if args['k'] >= 1:
                args['k'] = int(args['k'])
        if cmd_args.n_epochs > 0:
            args['n_epochs'] = cmd_args.n_epochs

        args['calibration_set_size'] = cmd_args.calibration_set_size
        args['no_lr_modif'] = cmd_args.no_lr_modif
        args['adaptive_alpha_mode'] = cmd_args.adaptive_alpha_mode

        args['optimizer'] = cmd_args.optimizer
        args['scheduler'] = cmd_args.scheduler

        args['sync_bn'] = cmd_args.sync_bn
        args['port'] = cmd_args.port

        # Running under SLURM is detected through the SLURM_JOB_ID variable.
        args['slurm_job_id'] = os.environ.get('SLURM_JOB_ID')
        args['slurm'] = args['slurm_job_id'] is not None

        args['warmup_epochs'] = cmd_args.warmup_epochs
        args['warmup_lr'] = cmd_args.warmup_lr
        args['push_warmup'] = cmd_args.push_warmup

        print(args, flush=True)

        # start experiment
        n_pool = len(train_dataset)

        # generate initial labeled pool
        resumed = cmd_args.start_round > 1
        data = {}
        if resumed:
            if not cmd_args.load_idxs_lb:
                raise ValueError('load_idxs_lb should be specified for starting from a specific round')
            if not os.path.exists(cmd_args.load_idxs_lb):
                raise FileNotFoundError(f'{cmd_args.load_idxs_lb} not found for loading idxs_lb')

            # ** idxs_lb is stored after query, right before training. So not necessary to query again. **
            # The archive is opened as a context manager so its file handle is closed.
            with np.load(cmd_args.load_idxs_lb) as loaded_data:
                idxs_lb = loaded_data['idxs_lb']
                # NOTE(review): cur_k is read here but never restored onto the
                # strategy (only epsilon / adaptive_epsilon are) -- confirm
                # whether that is intentional.
                data['cur_k'] = loaded_data['cur_k']
                if 'epsilon' in loaded_data:
                    data['epsilon'] = float(loaded_data['epsilon'].item())
                if 'adaptive_epsilon' in loaded_data:
                    data['adaptive_epsilon'] = bool(loaded_data['adaptive_epsilon'].item())
        else:
            idxs_lb = np.zeros(n_pool, dtype=bool)

        # load specified network
        net = _build_net(cmd_args.model, args)
        net = net.to(args['device'])

        # set up the specified sampler
        strategy = _build_strategy(
            cmd_args.strategy,
            train_dataset, train_raw_dataset, test_dataset, idxs_lb, net, args
        )

        if 'epsilon' in data:
            strategy.epsilon = data['epsilon']
        if 'adaptive_epsilon' in data:
            strategy.adaptive_epsilon = data['adaptive_epsilon']

        if cmd_args.load_net:
            if not os.path.exists(cmd_args.load_net):
                raise FileNotFoundError(f'{cmd_args.load_net} not found for loading net')
            # map_location lets a checkpoint saved on GPU load on a CPU-only host.
            net.load_state_dict(torch.load(cmd_args.load_net, map_location=args['device']))
            print(f'net loaded from {cmd_args.load_net}', flush=True)

        if cmd_args.wandb:
            project_name_suffix = f'-{cmd_args.project_suffix}' if cmd_args.project_suffix else ''
            exp_name_prefix = f'{cmd_args.exp_note}-' if cmd_args.exp_note else ''
            if resumed:
                exp_name_prefix = 'cont-' + exp_name_prefix
            full_exp_name = f'{exp_name_prefix}{cmd_args.strategy}-{cmd_args.dataset}-{cmd_args.model}-B{cmd_args.budget}'
            # Under SLURM the run id is derived from the job id; otherwise
            # id=None is passed to wandb.init.
            logger_id = f'{args["slurm_job_id"]}-{full_exp_name}' if args['slurm'] else None
            run = wandb.init(
                project=f'{PROJECT_NAME}{project_name_suffix}',
                id=logger_id,
                name=full_exp_name,
                config={
                    'args': args,
                    'cmd_args': vars(cmd_args)
                }
            )
            strategy.wandb_run = run

        # print info
        print(dataset_name, flush=True)
        print(type(strategy).__name__, flush=True)

        # Indexed by round number; index 0 is unused when start_round >= 1.
        acc = np.zeros(cmd_args.n_rounds + 1)
        top5_acc = np.zeros(cmd_args.n_rounds + 1)
        for round_idx in range(cmd_args.start_round, cmd_args.n_rounds + 1):
            is_first_round = round_idx == cmd_args.start_round
            log_prefix = f'round{round_idx}/'
            strategy.set_log_prefix(log_prefix)

            if args['device'] == 'cuda':
                torch.cuda.empty_cache()
            gc.collect()

            # On the first round of a resumed run the pool has already been
            # sampled (it was saved right before training), so only train.
            if not (resumed and is_first_round):
                if round_idx == 1:
                    # Round 1: label a uniformly random subset of the pool.
                    idxs_tmp = np.arange(n_pool)
                    np.random.shuffle(idxs_tmp)
                    q_idxs = idxs_tmp[:initial_budget]
                else:
                    # query
                    q_idxs = strategy.query(budget)

                # update labeled pool
                strategy.update(q_idxs)
                print(f'round{round_idx}\tlabeled: {sum(strategy.idxs_lb)}\tunlabeled: {sum(~strategy.idxs_lb)}', flush=True)

                # save idxs_lb
                if cmd_args.save_idxs_lb:
                    os.makedirs(cmd_args.save_idxs_lb, exist_ok=True)
                    idxs_lb_store_path = os.path.join(cmd_args.save_idxs_lb, f'round{round_idx}.npz')
                    if isinstance(strategy, DynamicTopKStrategyConfBase):
                        np.savez(idxs_lb_store_path, idxs_lb=strategy.idxs_lb, cur_k=strategy.cur_k, epsilon=strategy.epsilon, adaptive_epsilon=strategy.adaptive_epsilon)
                    else:
                        np.savez(idxs_lb_store_path, idxs_lb=strategy.idxs_lb, cur_k=strategy.cur_k)
                    print(f'AL properties saved at {idxs_lb_store_path}', flush=True)

            # Debug aid: shorten round 1 to two epochs so round 2 is reached quickly.
            if round_idx == 1 and cmd_args.second_round_debug:
                orig_nepochs = strategy.args['n_epochs']
                strategy.args['n_epochs'] = 2
                val_acc, val_acc_top5 = strategy.train()
                strategy.args['n_epochs'] = orig_nepochs
            else:
                val_acc, val_acc_top5 = strategy.train()

            acc[round_idx] = val_acc
            top5_acc[round_idx] = val_acc_top5
            print(f'round{round_idx}\t{sum(strategy.idxs_lb)}\ttesting accuracy {acc[round_idx]}', flush=True)

            # Model checkpoints go to the same directory as the labeled-pool snapshots.
            if cmd_args.save_idxs_lb:
                os.makedirs(cmd_args.save_idxs_lb, exist_ok=True)
                model_store_path = os.path.join(cmd_args.save_idxs_lb, f'round{round_idx}_model.pth')
                torch.save(strategy.net.state_dict(), model_store_path)
                print(f'model saved at {model_store_path}', flush=True)

            if cmd_args.wandb:
                wandb_run = strategy.wandb_run
                # Define the per-round metrics once, on the first round that is
                # actually executed, so resumed runs (start_round > 1) get them too.
                if is_first_round:
                    wandb_run.define_metric('round')
                    wandb_run.define_metric('labeled_samples', step_metric='round')
                    wandb_run.define_metric('unlabeled_samples', step_metric='round')
                    wandb_run.define_metric('final_accuracy', step_metric='round')
                    wandb_run.define_metric('final_accuracy_top5', step_metric='round')

                    if 'topk' in cmd_args.strategy:
                        wandb_run.define_metric('certain_ratio', step_metric='round')
                        wandb_run.define_metric('ambiguous_ratio', step_metric='round')
                wandb_run.log({
                    'round': round_idx,
                    'labeled_samples': sum(strategy.idxs_lb),
                    'unlabeled_samples': sum(~strategy.idxs_lb)
                })

                # Per-epoch metrics live under a 'round<N>/' prefix of their own.
                wandb_run.define_metric(log_prefix + 'loss')
                wandb_run.define_metric(log_prefix + 'epoch')
                wandb_run.define_metric(log_prefix + 'train_acc', step_metric=log_prefix + 'epoch')
                wandb_run.define_metric(log_prefix + 'train_loss_avg', step_metric=log_prefix + 'epoch')
                wandb_run.define_metric(log_prefix + 'val_acc', step_metric=log_prefix + 'epoch')
                wandb_run.define_metric(log_prefix + 'val_acc_top5', step_metric=log_prefix + 'epoch')

                wandb_run.log({
                    'final_accuracy': acc[round_idx],
                    'final_accuracy_top5': top5_acc[round_idx]
                })

            if strategy.stop_condition():
                break

        if cmd_args.wandb:
            strategy.wandb_run.finish()

    ########### START FROM HERE ###########
    # Every value accepted by --strategy, as one flat list.
    # Not every name is dispatched by main(): the static 'topk_*' names, plain
    # 'dtopk', and the two 'dtopk_entropy_predk*' names have no live branch
    # there and end in ValueError('Invalid strategy name').
    strategies = [
        # classic baselines
        'rand', 'conf', 'marg', 'badge',
        'coreset', 'entropy', 'baseline', 'saal',
        # static top-k family
        'topk_random', 'topk_random_conf', 'topk_random_proden', 'topk_random_nproden',
        'topk_entropy', 'topk_entropy_conf', 'topk_entropy_proden', 'topk_entropy_nproden',
        'topk_wentropy', 'topk_wentropy_conf', 'topk_wentropy_proden', 'topk_wentropy_nproden',
        'topk_badge', 'topk_badge_proden', 'topk_badge_nproden',
        'topk_wbadge', 'topk_wbadge_proden', 'topk_wbadge_nproden',
        # dynamic top-k family
        'dtopk', 'dtopk_random', 'dtopk_entropy', 'dtopk_coreset', 'dtopk_badge',
        'dtopk_entropy_predk', 'dtopk_entropy_predk_simple',
        'dtopk_random_conf', 'dtopk_entropy_conf', 'dtopk_coreset_conf', 'dtopk_badge_conf',
        'dtopk_cost_conf', 'dtopk_hybrid_entropy_conf', 'dtopk_hybrid_badge_conf',
        # 'ubdtopk_*' variants
        'ubdtopk_random', 'ubdtopk_entropy', 'ubdtopk_badge',
        # adaptive variant
        'dtopk_entropy_conf_adap_log',
        # '*_wall' variants
        'dtopk_entropy_wall', 'dtopk_badge_wall',
        'dtopk_random_conf_wall', 'dtopk_hybrid_entropy_conf_adap_wall',
    ]
    
    # Command-line interface. For --lr, --weight_decay, --gamma, --n_epochs,
    # --batch_size, --num_workers and --milestones, a zero / empty default means
    # "keep the value from the dataset config"; main() only overrides the config
    # when a positive / non-empty value is given.
    parser = argparse.ArgumentParser()
    parser.add_argument('--strategy', type=str, required=True, choices=strategies, help='selection strategy for AL')
    parser.add_argument('--lr', type=float, default=0, help='learning rate')
    parser.add_argument('--weight_decay', type=float, default=0, help='weight decay')
    # NOTE(review): the default 'mlp' is rejected by main() with NotImplementedError,
    # so --model effectively has to be given explicitly.
    parser.add_argument('--model', help='model to train', type=str, default='mlp')
    # The value is upper-cased before being checked against `choices`.
    parser.add_argument('--dataset', type=str.upper, required=True, choices=['MNIST', 'FASHIONMNIST', 'SVHN', 'CIFAR10', 'CIFAR100', 'CUB', 'TINYIMAGENET', 'IMAGENET32', 'IMAGENET64'], help='dataset name')
    parser.add_argument('--download', action='store_true', help='download dataset')
    parser.add_argument('--start_round', type=int, default=1, help='starting round')
    parser.add_argument('--n_rounds', type=int, default=500, help='number of rounds')
    parser.add_argument('--initial_budget', type=int, default=100, help='initial budget')
    parser.add_argument('--budget', help='number of points to query in a batch', type=int, default=0)
    parser.add_argument('--embedding_dim', help='number of embedding dims (mlp)', type=int, default=128)
    parser.add_argument('--no_aug', action='store_true', help='disable data augmentation (train with the test transform) for CIFAR10/100, CUB, TinyImageNet and the ImageNet32/64/64S configs')
    parser.add_argument('--wandb', action='store_true', help='use wandb')
    parser.add_argument('--k', type=float, default=3, help='k value for top-k sampling')
    parser.add_argument('--n_epochs', type=int, default=0, help='number of epochs')
    parser.add_argument('--rho', type=float, default=0.05, help='norm restriction for SAAL')
    parser.add_argument('--calibration', type=int, default=0, help='number of samples for calibration')
    parser.add_argument('--alpha', type=float, default=0.0, help='alpha for mix-up beta distribution')
    parser.add_argument('--d', type=float, default=0.0, help='hyperparameter for hybrid sampling')
    parser.add_argument('--cuda', action='store_true', help='use cuda')
    parser.add_argument('--batch_size', type=int, default=0, help='batch size')
    parser.add_argument('--num_workers', type=int, default=0, help='number of workers')
    parser.add_argument('--seed', type=int, default=42, help='random seed')
    parser.add_argument('--milestones', type=int, nargs='*', default=[], help='milestones for learning rate scheduler')
    parser.add_argument('--gamma', type=float, default=0, help='gamma for learning rate scheduler')
    parser.add_argument('--second_round_debug', action='store_true', help='debug aid: train round 1 for only 2 epochs so that the second round is reached quickly')
    parser.add_argument('--project_suffix', type=str, default='', help='wandb project name suffix')
    parser.add_argument('--exp_note', type=str, default='', help='wandb experiment identifier')
    parser.add_argument('--estimated_cost', action='store_true', help='use estimated cost for cost efficient sampling')
    parser.add_argument('--optimizer', type=str, default='AdamW', choices=['SGD', 'AdamW'], help='specify optimizer')
    parser.add_argument('--scheduler', type=str, default='MultiStepLR', choices=['MultiStepLR', 'CosineAnnealingLR'], help='specify scheduler')
    parser.add_argument('--calibration_set_size', type=int, default=0, help='size of calibration set for conformal prediction. If zero, it uses full val set as calibration set. Otherwise, use this number of samples from most recently labeled ones on each round for calibration')
    parser.add_argument('--no_lr_modif', action='store_true', help='no lr modification for DDP')
    parser.add_argument('--adaptive_alpha_mode', type=str, choices=['base', 'div', 'mul'], default='base', help='adaptive alpha mode among base, div, mul') # Only for Adaptive Alpha
    # When set, main() reads this (upper-cased) config entry instead of the --dataset one.
    parser.add_argument('--custom_config', type=str.upper, default='', help='custom config for dataset')
    parser.add_argument('--sync_bn', action='store_true', help='use sync batch norm')
    parser.add_argument('--port', type=int, default=3000, help='port number for ddp')

    # Resume / checkpoint paths. --save_idxs_lb names a directory that receives
    # both the labeled-pool snapshots and the per-round model weights.
    parser.add_argument('--load_idxs_lb', type=str, default='', help='path to load idxs_lb')
    parser.add_argument('--load_net', type=str, default='', help='path to load net')
    parser.add_argument('--save_idxs_lb', type=str, default='', help='path to save idxs_lb')

    ## LR warmup
    parser.add_argument('--warmup_epochs', type=int, default=0, help='number of warmup epochs')
    parser.add_argument('--warmup_lr', type=float, default=0, help='warmup learning rate (at the beginning of training)')
    parser.add_argument('--push_warmup', action='store_true', help='push warmup to the beginning of training (n_epochs is increased by warmup_epochs)')

    cmd_args = parser.parse_args()

    # Print the PID and the parsed options before the experiment starts.
    print('::::: PID :::::', os.getpid(), flush=True)
    print(cmd_args, flush=True)
    main(cmd_args)
