import os
import ast
import moxing as mox
import argparse
import logging
import time

os.environ["NCCL_NET_GDR_LEVEL"] = '0'

parser = argparse.ArgumentParser()
parser.add_argument('--model', type=str, default='SMR_B_compress')
parser.add_argument('--clip_path', type=str, default='./')
parser.add_argument('--clip_model', type=str, default='ViT-B/16')
parser.add_argument('--epochs', type=int, default=300, help='world size')
parser.add_argument('--warmup_epochs', type=int, default=20, help='world size')
parser.add_argument('--dpr', type=float, default=0.1, help='world size')
parser.add_argument('--update_freq', type=int, default=1, help='world size')
parser.add_argument('--batch_size', type=int, default=128, help='world size')
parser.add_argument('--input_size', type=int, default=224, help='world size')
parser.add_argument('--second_input_size', type=int, default=224, help='world size')
parser.add_argument('--num_gpus', type=int, default=8, help='the number of gpus')
parser.add_argument('--rank', type=int, default=0, help='node rank')
parser.add_argument('--world_size', type=int, default=1, help='world size')
parser.add_argument('--init_method', type=str, default='tcp://127.0.0.1:6666')

args, unparsed = parser.parse_known_args()

MODEL = args.model
input_size = args.input_size
second_input_size = args.second_input_size
num_mask_patches = int(args.input_size * args.input_size / 14 / 14 * 0.4)  ### 224*224/14/14 * 0.4

batch_size = args.batch_size  # 32(bsz_per_gpu)*8(#gpus_per_node)*8(#nodes)*1(update_freq)=2048(total_bsz)
update_freq = args.update_freq

lr = 1.5e-3
b2 = 0.98
eps = 1e-8

dpr = args.dpr
ls = 0.0

epochs = args.epochs
wmep = args.warmup_epochs
save_ckpt_freq = 20

mixup = 0.0
cj = 0.0

zero_stage = 1

teacher_type = 'clip'
clip_model = args.clip_model

OUTPUT_DIR = f'/cache/output/{MODEL}'

master_addr = ...
master_port = ...

# Assemble the torch.distributed.launch command line. Each entry is one
# shell fragment; they are joined with single spaces, which the shell
# tokenizes the same way as the previous backslash-continued string.
_launch_fragments = [
    "python -m torch.distributed.launch",
    f"--nnodes={args.world_size}",
    # Honor --num_gpus instead of a hard-coded 8 (its default is still 8).
    f"--nproc_per_node={args.num_gpus}",
    f"--node_rank={args.rank}",
    f"--master_addr={master_addr}",
    f"--master_port={master_port}",
    "--use_env run_pretraining_compress.py",
    "--data_path /path/to/imagenet",
    "--val_data_path /path/to/imagenet",
    f"--output_dir {OUTPUT_DIR}",
    f"--log_dir {OUTPUT_DIR}/tb_log",
    f"--model {MODEL}",
    f"--teacher_type {teacher_type}",
    f"--clip_model {clip_model}",
    f"--input_size {input_size}",
    f"--second_input_size {second_input_size}",
    f"--num_mask_patches {num_mask_patches}",
    f"--layer_scale_init_value {ls}",
    f"--batch_size {batch_size}",
    f"--lr {lr}",
    f"--opt_betas 0.9 {b2}",
    f"--opt_eps {eps}",
    f"--drop_path {dpr}",
    f"--epochs {epochs}",
    f"--mixup {mixup}",
    f"--color_jitter {cj}",
    f"--warmup_epochs {wmep}",
    f"--update_freq {update_freq}",
    "--weight_decay 0.05",
    f"--zero_stage {zero_stage}",
    f"--save_ckpt_freq {save_ckpt_freq}",
    "--stop_grad_conv1",
    "--grad_ckpt",
    "--enable_deepspeed",
]
cmd_str = " ".join(_launch_fragments)

# os.system returns the child's status; previously it was discarded, so this
# wrapper reported success even when the training command failed.
_launch_status = os.system(cmd_str)
if _launch_status != 0:
    raise SystemExit(
        f"training command failed (os.system status {_launch_status})"
    )

