import argparse
import os

def vit_small_finetune(dataset='cifar100'):
    """Build the fine-tuning configuration for ViT-Small on CIFAR-10/100.

    Args:
        dataset: CIFAR variant to configure, ``'cifar100'`` (the default,
            identical to the previously hard-coded value) or ``'cifar10'``.
            The two variants differ only in the base learning rate and the
            data root directory.

    Returns:
        An ``argparse.Namespace`` populated with every setting below.

    Raises:
        ValueError: if ``dataset`` is not a supported CIFAR variant. Without
            this check the namespace would silently lack ``epochs``, ``lr``,
            ``data_root`` and the other dataset-dependent fields.

    Side effects:
        Sets ``os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"`` for the current
        process, overwriting any existing value.
    """
    # Base learning rate per supported dataset; every other schedule
    # setting is shared between the two CIFAR variants.
    base_lr = {'cifar10': 1e-4, 'cifar100': 7.5e-5}
    if dataset not in base_lr:
        raise ValueError(
            f"unsupported dataset {dataset!r}; expected one of {sorted(base_lr)}"
        )

    args = argparse.Namespace()

    args.dataset = dataset
    args.arch = 'vit-small'
    args.resume = None
    args.pretrained_weights = None
    args.evaluate = None
    args.start_epoch = 0
    args.output_dir = './out'
    args.seed = 0

    # --- data / loader settings (CIFAR: 32x32 inputs, patch size 2) ---
    args.epochs = 100
    args.num_workers = 4
    args.prefetch_factor = 2
    args.pin_memory = True
    args.patch_size = 2
    args.input_size = 32
    args.batch_size = 256
    args.data_root = f'/dataset/{args.dataset}'
    args.encoder = 'momentum_encoder'
    args.print_freq = 100

    # --- EMA of model weights ---
    args.model_ema = True
    args.model_ema_decay = 0.99996
    args.model_ema_force_cpu = False

    # --- optimizer ---
    args.opt = 'adamw'
    args.opt_eps = 1e-8
    args.opt_betas = None
    args.clip_grad = None
    args.momentum = 0.9

    # --- learning-rate schedule ---
    args.sched = 'cosine'
    args.lr = base_lr[dataset]
    args.warmup_lr = 1e-6
    args.warmup_epochs = 5
    args.min_lr = 1e-5
    args.weight_decay = 0.05
    args.init_values = None
    args.layer_decay = 1.0
    args.drop_path = 0.1

    args.lr_noise = None
    args.lr_noise_pct = 0.67
    args.lr_noise_std = 1.0
    args.decay_epochs = 30
    args.cooldown_epochs = 10
    args.patience_epochs = 10
    args.decay_rate = 0.1

    # --- augmentation ---
    args.color_jitter = 0.4
    args.aa = 'rand-m9-mstd0.5-inc1'
    args.smoothing = 0.1
    args.train_interpolation = 'bicubic'
    args.repeated_aug = True

    # --- random erase ---
    args.reprob = 0.25
    args.remode = 'pixel'
    args.recount = 1
    args.resplit = False

    # --- mixup / cutmix ---
    args.mixup = 0.8
    args.cutmix = 1.0
    args.cutmix_minmax = None  # float
    args.mixup_prob = 1.0
    args.mixup_switch_prob = 0.5
    args.mixup_mode = 'batch'

    # --- distributed launch ---
    args.dist_url = 'tcp://localhost:12614'
    args.dist_backend = 'nccl'
    args.ngpus_per_node = 2
    # Process-wide side effect: restricts visible GPUs to devices 0 and 1.
    os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"
    # NOTE(review): world_size=1 next to ngpus_per_node=2 suggests the
    # launcher treats world_size as a node count and scales it by the GPUs
    # per node -- confirm against the training entry point.
    args.world_size = 1
    args.gpu = None

    args.distributed = True

    args.save_freq = 20

    args.rank = 0

    args.exclude_file_list = ['__pycache__', '.vscode', 'log', 'ckpt', '.git', '.core', 'out', 'dataset', 'weight']

    return args