import argparse
import os


def vit_small_pretrain(*, dataset='imagenet1k'):
    """Build the ViT-small pretraining configuration namespace.

    Only ``'cifar10'`` and ``'cifar100'`` currently define the dataset-specific
    settings (``input_size``, ``patch_size``, ``batch_size``, ``epochs``, ...)
    that ``exp_dir`` and downstream code read.

    Args:
        dataset: Dataset name stored in ``args.dataset``; ``data_root`` is
            derived from it. The default ``'imagenet1k'`` is kept for
            backward compatibility but has no settings defined yet.

    Returns:
        argparse.Namespace: The populated configuration.

    Raises:
        ValueError: If ``dataset`` has no dataset-specific settings. Without
            this check the call would fail later with an ``AttributeError``
            while formatting ``exp_dir``.

    Side effects:
        Overwrites ``os.environ["CUDA_VISIBLE_DEVICES"]`` with
        ``"0,1,2,3,4,5,6,7"``.
    """
    args = argparse.Namespace()
    args.arch = 'vit-small'
    args.resume = None
    args.dataset = dataset
    args.seed = 0

    # Dataset-specific settings. Add an ``elif`` branch here for new datasets;
    # everything below (notably ``exp_dir``) relies on these attributes.
    if args.dataset in ('cifar10', 'cifar100'):
        args.data_root = f'/dataset/{args.dataset}'
        args.input_size = 32
        args.patch_size = 2
        args.num_workers = 8
        args.prefetch_factor = 2
        args.pin_memory = False
        args.save_freq = 100
        args.epochs = 800
        args.batch_size = 512
        args.warmup_epoch = args.epochs // 8
        args.multi_crop_size = 14
        args.temp = 0.2
        args.warmup_temp = 0.2
        args.warmup_temp_epochs = 30
    else:
        raise ValueError(
            f"vit_small_pretrain: no dataset-specific settings for "
            f"{args.dataset!r}; supported datasets are ('cifar10', 'cifar100')"
        )

    args.use_save_mem = True
    args.drop_path = 0.1

    # lr params
    args.lr = 5e-4
    args.min_lr = 1e-6
    # Weight decay goes from weight_decay to weight_decay_end; presumably on a
    # cosine schedule when use_wd_cos is set (consumer not visible here).
    args.weight_decay = 0.04
    args.weight_decay_end = 0.4
    args.use_wd_cos = True

    args.use_moco = True
    args.moco_m = 0.99
    args.moco_m_cos = True

    args.print_freq = None

    args.out_dim = 256
    args.hidden_dim = 4096
    args.proj_layer = 3
    args.pred_layer = 2

    args.multi_crop = True

    args.mix_p = 1.0
    args.switch_p = 0.25
    args.mix_n = 4
    args.mix_n2 = 4
    args.smoothing = 0.0
    args.min_crop = 0.35
    args.min_mix_crop = 0.05
    args.global_crop = 0.5
    # The experiment directory name encodes the key hyper-parameters. Keep the
    # format stable: existing checkpoints are located by this path.
    args.exp_dir = f'./log/pretrain/{args.dataset}/ckpts_{args.arch}_p{args.patch_size}' \
                   f'_moco_{args.use_moco}_mm{args.moco_m}_min_crop{args.min_crop}' \
                   f'_lr{args.lr}_wd{args.weight_decay}' \
                   f'_bs{args.batch_size}_epoch{args.epochs}' \
                   f'_mix-n{args.mix_n}_mix_crop{args.min_mix_crop}_global_crop{args.global_crop}' \
                   f'_mc_n{args.multi_crop}_dp{args.drop_path}_all_patches'
    args.rank = 0
    args.distributed = True
    args.use_mix_precision = True
    args.init_method = 'tcp://localhost:17995'
    args.world_size = 1
    # NOTE: process-wide side effect — overrides any value set by the caller.
    os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3,4,5,6,7"

    # Presumably names/paths to skip when snapshotting the source tree;
    # the consumer is not visible in this file.
    args.exclude_file_list = ['__pycache__', '.vscode',
                              'log', 'ckpt', '.git', 'out', 'dataset', 'weight', 'core', '.png']

    return args
