import os
import argparse
import utils.util as util

def get_args(argv=None):
    """Parse the command-line options for 2D UniMedI pre-training.

    Besides the parsed options, the returned namespace carries two derived
    attributes, and both directories are created (via
    ``create_experiment_dir``) before returning:

    * ``experiment_path`` -- ``<output_root>/<exp_name>``
    * ``tfboard_path``    -- ``<output_root>/<exp_name>/TFBoard``

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, as before.

    Returns:
        argparse.Namespace with every option plus the derived paths.
    """
    parser = argparse.ArgumentParser()

    # Distributed training parameters
    parser.add_argument("--dist_url", default="env://", type=str, help="""url used to set up
        distributed training; see https://pytorch.org/docs/stable/distributed.html""")
    parser.add_argument('--seed', default=42, type=int, help='Random seed.')
    parser.add_argument("--local_rank", default=0, type=int, help="Please ignore and do not set this argument.")
    parser.add_argument('--batch_size_per_gpu', type=int, default=18)

    # Dataset parameters
    # A local copy of the dataset list is available as './data/dataset_2D_local.json'.
    parser.add_argument('--datalist_2d', default='./data/dataset_2D_remote.json', type=str, help='json for 2D datasets')
    parser.add_argument('--data_pct_2d', default=1.0, type=float, help='Sample ratio for 2D datasets')
    parser.add_argument('--imsize_2d', default=256, type=int, help='Resize H, W')
    parser.add_argument('--max_words_2d', default=112, type=int, help='Max len of words')
    parser.add_argument('--sent_num_2d', default=3, type=int, help='Sent num')
    parser.add_argument('--cropsize_2d', default=224, type=int, help='Augmentation parameters for 2D images')

    # Model parameters
    parser.add_argument('--hidden_dim', default=2048, type=int, help='The hidden dim in the projection head')
    parser.add_argument('--output_dim', default=128, type=int, help='The output dim in the projection head')
    parser.add_argument('--out_dim', default=65536, type=int, help='Dimensionality of output for [CLS] token')
    parser.add_argument('--patch_out_dim', default=65536, type=int, help='Dimensionality of output for patch tokens')
    # float, not int: the default is 0.0 and fractional rates such as 0.1
    # must be accepted (type=int rejected them with a parse error).
    parser.add_argument('--drop_path', default=0.0, type=float)
    parser.add_argument('--mask_ratio', default=0.75, type=float)
    parser.add_argument('--mask_type', default='attn', type=str)
    # util.bool_flag instead of bool: bool('False') is True, so with type=bool
    # this flag could never be disabled from the command line.
    parser.add_argument('--with_distill', default=True, type=util.bool_flag)

    # Loss parameters
    parser.add_argument('--warmup_teacher_temp', default=0.04, type=float, help='Initial value for the teacher temperature: 0.04 works well in most cases. Try decreasing it if the training loss does not decrease')
    parser.add_argument('--teacher_temp', default=0.04, type=float, help='Final value (after linear warmup) of the teacher temperature. For most experiments, anything above 0.07 is unstable. We recommend starting with the default value of 0.04 and increase this slightly if needed')
    parser.add_argument('--warmup_teacher_patch_temp', default=0.04, type=float, help='See `--warmup_teacher_temp`')
    parser.add_argument('--teacher_patch_temp', default=0.04, type=float, help='See `--teacher_temp`')
    parser.add_argument('--warmup_teacher_temp_epochs', default=0, type=int, help='Number of warmup epochs for the teacher temperature (Default: 0)')
    parser.add_argument('--lambda1', default=1.0, type=float, help='loss weight for dino loss over [CLS] tokens (Default: 1.0)')
    parser.add_argument('--lambda2', default=0.001, type=float, help='loss weight for beit loss over masked patch tokens (Default: 0.001)')

    # Training/Optimization parameters
    parser.add_argument('--resume', default=None, type=str, help='Path for continue training')
    parser.add_argument('--epochs', default=50, type=int, help='Number of epochs of training')
    parser.add_argument('--use_fp16', type=util.bool_flag, default=False, help='Whether or not to use half precision for training')
    parser.add_argument("--lr", default=2e-5, type=float, help='Learning rate at the end of linear warmup (highest LR used during training)')
    parser.add_argument("--momentum", default=0.9, type=float, help='momentum for optimizer')
    parser.add_argument("--weight_decay", default=0.05, type=float, help="Initial value of the weight decay.")
    parser.add_argument('--momentum_teacher', default=0.996, type=float, help='Base EMA parameter for teacher update. The value is increased to 1 during training with cosine schedule. We recommend setting a higher value with small batches: for example use 0.9995 with batch size of 256')
    parser.add_argument('--clip_grad', type=float, default=None, help='Maximal parameter gradient norm if using gradient clipping')
    parser.add_argument('--interval', default=1, type=int)
    parser.add_argument('--num_workers', type=int, default=8)
    parser.add_argument('--saveckp_freq', default=5, type=int, help='Save checkpoint every x epochs.')

    # Output parameters
    # The root used to be hard-coded (with a commented-out './UniMedI_2D_debug'
    # alternative); it is now an option with the same default.
    parser.add_argument('--output_root', default='/mnt/output/UniMedI_2D_experiments', type=str, help='Root directory under which the experiment folder is created.')
    parser.add_argument('--exp_name', default="UniMedI_2D_ita_masklow_maskprob08_maskratio075_001iia_65536_64", type=str, help='Experiment folder name under --output_root (holds logs, tensorboard and checkpoints).')
    parser.add_argument('--log_name', default="UniMedI", type=str, help='log name')
    args = parser.parse_args(argv)

    # Derived output locations; TFBoard lives inside the experiment folder.
    args.experiment_path = os.path.join(args.output_root, args.exp_name)
    args.tfboard_path = os.path.join(args.experiment_path, 'TFBoard')
    create_experiment_dir(args)
    return args

def create_experiment_dir(args):
    """Create ``args.experiment_path`` and ``args.tfboard_path`` if missing.

    Directories are created with ``os.makedirs`` and an already-existing path
    is detected via ``FileExistsError`` rather than a prior ``os.path.exists``
    check. This avoids the check-then-create race when several processes
    (for example distributed workers) call this at the same time: the loser
    of the race no longer crashes. A message is printed only by the call that
    actually created a directory.

    Args:
        args: Object exposing ``experiment_path`` and ``tfboard_path``.
    """
    targets = (
        (args.experiment_path, 'Create experiment path successfully at %s'),
        (args.tfboard_path, 'Create TFBoard path successfully at %s'),
    )
    for path, message in targets:
        try:
            os.makedirs(path)
        except FileExistsError:
            # Already present (possibly created concurrently) -- nothing to do.
            continue
        print(message % path)