env_defaults:
  # Cluster shape; substituted as ${NODES}, ${GPUS} and ${MEM} into the job
  # name and SKU string of search.job_template.
  NODES: 2
  GPUS: 8
  MEM: 32
  # SLA tier for the job; allowed values: Premium, Standard, Basic.
  SLA: 'Premium'
  # Scheduling priority passed through to job_template.priority.
  PRIORITY: 'high'

search:
  job_template:
    # Job name encodes the cluster shape plus every searched hyper-parameter
    # ({bsz}, {grad_accu}, {lr}, {ep} match the names under search.params),
    # so trials that differ in any one of them get distinct names.
    name: llava_lora_unkVQA_finetune_${NODES}x${GPUS}GPU_bsz{bsz}_gacc{grad_accu}_lr{lr}_ep{ep}
    # With the env_defaults above this resolves to 2x32G8-IB.
    sku: ${NODES}x${MEM}G${GPUS}-IB
    sla_tier: ${SLA}
    priority: ${PRIORITY}
    # mpi: False
    # A single shell command. The folded block scalar (>-) joins the argument
    # lines below with single spaces into one line, with no trailing newline.
    # <DATA_FOLDER> and <OUTPUT_FOLDER> appear to be placeholders to fill in.
    # An explicit '/' follows <DATA_FOLDER> in --data_path so the JSON path
    # resolves whether or not the folder value ends with a slash (a doubled
    # '//' is harmless on POSIX paths), matching how --output_dir is built.
    # NOTE(review): --output_dir does not include the searched
    # hyper-parameters, so a grid with more than one trial would write every
    # trial to the same directory.
    command:
      - >-
        python llava/train/train_mem.py
        --lora_enable True --lora_r 128 --lora_alpha 256 --mm_projector_lr 2e-5
        --deepspeed ./scripts/zero3.json
        --model_name_or_path models/lmsys/vicuna-7b-v1.5
        --version v1
        --data_path <DATA_FOLDER>/unk_v1+gqa+ours_caption_based.json
        --image_folder <DATA_FOLDER>
        --vision_tower models/clip/clip-vit-large-patch14-336
        --pretrain_mm_mlp_adapter models/LLaVA/llava-v1.5-mlp2x-336px-pretrain-vicuna-7b-v1.5/mm_projector.bin
        --mm_projector_type mlp2x_gelu
        --mm_vision_select_layer -2
        --mm_use_im_start_end False
        --mm_use_im_patch_token False
        --image_aspect_ratio pad
        --group_by_modality_length True
        --output_dir <OUTPUT_FOLDER>/llava/ft-from-pretrained-lora/llava-v1.5-7b-task-lora_unk_v1+gqa+ours_caption_based
        --num_train_epochs {ep}
        --per_device_train_batch_size {bsz}
        --per_device_eval_batch_size {bsz}
        --gradient_accumulation_steps {grad_accu}
        --evaluation_strategy "no"
        --save_strategy "steps"
        --save_steps 5000
        --save_total_limit 1
        --learning_rate {lr}
        --weight_decay 0.
        --warmup_ratio 0.03
        --lr_scheduler_type "cosine"
        --logging_steps 1
        --tf32 False
        --fp16 True
        --bf16 False
        --model_max_length 2048
        --gradient_checkpointing True
        --dataloader_num_workers 4
        --lazy_preprocess True
        --report_to wandb
    # Equal to GPUS (8) in env_defaults, presumably one process per GPU.
    # This is a literal, so keep it in sync manually if GPUS changes.
    process_count_per_node: 8
    submit_args:
      # Environment variables for the job. Values are quoted so they are
      # exported as strings regardless of how the loader types bare numbers.
      env:
        # 0 leaves NCCL's InfiniBand transport enabled (the SKU is -IB).
        NCCL_IB_DISABLE: '0'
        NCCL_DEBUG: INFO
        NCCL_IB_TIMEOUT: '60'
        # NOTE(review): 0 disables PyTorch's async NCCL error handling, so a
        # failed collective may hang instead of aborting — confirm intent.
        NCCL_ASYNC_ERROR_HANDLING: '0'
        MKL_THREADING_LAYER: GNU
        # NOTE(review): max_attempts is nested under env, so it is exported
        # as an environment variable, not read as a submit option — if a
        # retry limit was intended, move it up one level and verify the schema.
        max_attempts: '1'
  # Exhaustive grid over the params below. Each has a single value, so the
  # grid contains exactly one trial, matching max_trials.
  type: grid
  max_trials: 1
  params:
    - name: lr
      spec: discrete
      # Quoted because an unquoted 2e-4 is a string to YAML 1.1 loaders
      # (e.g. PyYAML, whose float form needs a '.') but a float (0.0002) to
      # YAML 1.2 loaders. Quoting pins the exact text substituted into {lr}
      # in both the job name and --learning_rate.
      values: ['2e-4']
    - name: grad_accu
      spec: discrete
      values: [4]
    - name: bsz
      spec: discrete
      values: [2]
      # values: [1]
    - name: ep
      spec: discrete
      values: [1]
