#!/usr/bin/env bash
# Launch bifrost/train/train_fsdp.py through `accelerate launch` with the
# DeepSpeed config deepspeed_stage2.yaml, after pointing CUDA at the system
# toolkit.
set -euo pipefail

# Set correct CUDA path for DeepSpeed compatibility. This deliberately
# overrides any CUDA_HOME already present in the environment.
export CUDA_HOME=/usr/lib/nvidia-cuda-toolkit
if [[ ! -d "$CUDA_HOME" ]]; then
  # Warn only: the job may still work, but DeepSpeed will not find this toolkit.
  printf 'warning: CUDA_HOME=%s does not exist\n' "$CUDA_HOME" >&2
fi

# Prepend the toolkit directories. ${VAR:+:$VAR} appends the old value only
# when it is non-empty, so no trailing ':' is produced; an empty entry in
# PATH / LD_LIBRARY_PATH would make the shell / dynamic loader search the
# current working directory.
export PATH="$CUDA_HOME/bin${PATH:+:$PATH}"
export LD_LIBRARY_PATH="$CUDA_HOME/lib64${LD_LIBRARY_PATH:+:$LD_LIBRARY_PATH}"


# Launch training via `accelerate launch` using deepspeed_stage2.yaml.
#
# Required environment:
#   OUTPUT_DIR  value passed to --output_dir; the script aborts if it is unset
#               or empty (previously an empty value silently shifted the
#               remaining arguments, making --output_dir swallow --save_steps).
#
# Any arguments given to this script are appended after the defaults below, so
# e.g. `./this_script.sh --learning_rate 1e-4` can override a default.
# NOTE(review): assumes the trainer's argparse-based parser lets the last
# occurrence of a flag win — confirm.
#
# All paths (deepspeed_stage2.yaml, bifrost/..., data/...) are relative, so the
# script must be run from the directory that contains them.
: "${OUTPUT_DIR:?OUTPUT_DIR must be set to the training output directory}"

# Kept in the original order; an array lets each value be quoted safely.
train_args=(
  --dataset_list '["InternVL-SA1B-Caption-WebDataset", "LLaVA-ReCap-CC12M"]'
  --dataset_path_list '["data/InternVL-SA1B-Caption-WebDataset/webdataset", "data/LLaVA-ReCap-CC12M/data"]'
  --vision_language_model 'Qwen2_5_VLForConditionalGeneration'
  --vision_language_model_name "Qwen/Qwen2.5-VL-3B-Instruct"
  --vision_gen_enc 'ShallowUViTEncoder'
  --vision_gen_dec 'ShallowUViTDecoder'
  --frozen_modules_in_vlm '["vision_language_model.visual", "vision_language_model.lm_head", "vision_language_model.model.embed_tokens", "vision_language_model.model.layers", "vision_language_model.model.norm"]'
  --remove_vision_und_encoder False
  --remove_vae False
  --frozen_vision_gen_vae False
  --frozen_vision_gen_encdec False
  --max_seq_length 320
  --num_visual_gen_tokens 256
  --t2i_resolution 448
  --learning_rate 0.0005
  --warmup_steps 5000
  # NOTE(review): both limits are very large; presumably the run is stopped
  # manually — confirm which of max_steps / num_train_epochs the trainer honors.
  --max_steps 100000000
  --num_train_epochs 1000
  --batch_size_t2i 9
  --output_dir "$OUTPUT_DIR"
  --save_steps 200
  --log_task_specific_loss False
  --bf16 True
  --vision_denoising_type 'mar'
  --add_timestep_token False
  --tf32 False
  --vision_gen_tokenizer 'magvitv2'
  --cond_dropout_prob 0.1
  --dataloader_num_workers 9
  --add_vision_gen_mask_token False
  --add_vision_soi_token False
  --add_vision_soi_eoi_tokens False
  --fully_trainable False
  --lambda_gpu True
  --gradient_accumulation_steps 7
  # FSDP is disabled despite the script name; the launcher is configured by
  # the DeepSpeed yaml above.
  --is_fsdp_enabled False
  --dataloader_drop_last True
  --vision_head_type 'linear'
  --vision_loss_type 'mse'
  --vision_pos_emb_type 'learnable_pos_emb'
  --lambda_clip 1.0
  --full_vision_mask True
  --precise_prompt_mask True
  --skip_text_part2 True
  --add_vision_branch True
  --add_vision_branch_reuse_layernorm False
  --use_discrete_visual_tokenizer False
  --use_clip_visual_encoder True
  # NOTE(review): use_lora is False, so the lora_* values below are presumably
  # unused — confirm.
  --use_lora False
  --lora_r 1024
  --lora_alpha 2048
  --use_rslora False
  --use_2d_query_tokens False
  --e2e_training False
  --ctrlnet_training False
  --pretrained_diffusion_decoder_name_or_path "black-forest-labs/FLUX.1-dev"
  --num_single_layers 4
  --num_double_layers 1
  --diffusion_decoder_text_dropout_prob 0.0
)

accelerate launch --config_file deepspeed_stage2.yaml bifrost/train/train_fsdp.py \
  "${train_args[@]}" "$@"
