# # #! /bin/bash

# # export SCRIPT_DIR="$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
# # export PROJECT_DIR="$( cd -- "$( dirname -- "$SCRIPT_DIR" )" &> /dev/null && pwd )"
# # cd $PROJECT_DIR
# # export PYTHONPATH="$PYTHONPATH:$PROJECT_DIR"
# # export LIBTPU_INIT_ARGS="--xla_tpu_megacore_fusion_allow_ags=false --xla_enable_async_collective_permute=true --xla_tpu_enable_ag_backward_pipelining=true --xla_tpu_enable_data_parallel_all_reduce_opt=true --xla_tpu_data_parallel_opt_different_sized_ops=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true"

# # export llama_tokenizer_path="llava/checkpoints/lwm_checkpoints/tokenizer.model"
# # export dataset_path="data/cross_attn_nsvq_code32_real_2M_train.jsonl"
# # export eval_dataset_path="data/cross_attn_nsvq_code32_real_2M_val.jsonl"

# # export output_dir="llava/checkpoints/debug"

# # export project_id='lwm'
# # export experiment_note='world-model'
# # export experiment_id='cross_attn_nsvq_code32_real_2M_val_no_mix_batch128'

# # # mesh_dim: dp, fsdp, tp, sp
# # CUDA_VISIBLE_DEVICES=0,1,2,3 python3 -u -m lwm.train \
# #     --modality='vision,text,delta' \
# #     --mesh_dim='!-1,4,1,1' \
# #     --dtype='bf16' \
# #     --total_steps=13605 \
# #     --log_freq=1 \
# #     --delta_tokens=1 \
# #     --eval_steps=1 \
# #     --save_model_freq=13600 \
# #     --eval_log_freq=100 \
# #     --save_milestone_freq=13600 \
# #     --load_llama_config='7b' \
# #     --load_checkpoint="params::/home/World-Model/llava/checkpoints/lwm_checkpoints/params"\
# #     --update_llama_config="dict(delta_vocab_size=32,theta=50000000,max_sequence_length=2048,use_flash_attention=True,scan_attention=True,scan_query_chunk_size=512,scan_key_chunk_size=1024,remat_attention='nothing_saveable',scan_mlp=True,scan_mlp_chunk_size=8192,remat_mlp='nothing_saveable',remat_block='nothing_saveable',scan_layers=True)" \
# #     --tokenizer.vocab_file="$llama_tokenizer_path" \
# #     --optimizer.type='adamw' \
# #     --llama.delta_vocab_size=32 \
# #     --optimizer.accumulate_gradient_steps=1 \
# #     --optimizer.adamw_optimizer.weight_decay=0 \
# #     --optimizer.adamw_optimizer.lr=2e-5 \
# #     --optimizer.adamw_optimizer.end_lr=0 \
# #     --optimizer.adamw_optimizer.lr_warmup_steps=1000 \
# #     --optimizer.adamw_optimizer.lr_decay_steps=13605 \
# #     --use_data_sharded_loader=True \
# #     --train_dataset.type='json_vision_delta' \
# #     --train_dataset.delta_vision_text_processor.fields_from_example='fields' \
# #     --train_dataset.delta_vision_text_processor.n_tokens_per_delta=1 \
# #     --train_dataset.delta_vision_text_processor.max_n_frames=1 \
# #     --train_dataset.json_delta_dataset.mode="pad" \
# #     --train_dataset.json_delta_dataset.path="$dataset_path" \
# #     --train_dataset.json_delta_dataset.seq_length=384 \
# #     --train_dataset.json_delta_dataset.batch_size=128 \
# #     --train_dataset.json_delta_dataset.tokenizer_processes=1 \
# #     --train_dataset.json_delta_dataset.tokenizer_parallel_chunk_size=128 \
# #     --train_dataset.json_delta_dataset.tokenizer_parallel_batch_size=128 \
# #     --train_dataset.json_delta_dataset.use_data_sharded_loader=True \
# #     --eval_dataset.type='json_vision_delta' \
# #     --eval_dataset.delta_vision_text_processor.fields_from_example='fields' \
# #     --eval_dataset.delta_vision_text_processor.n_tokens_per_delta=1 \
# #     --eval_dataset.delta_vision_text_processor.max_n_frames=1 \
# #     --eval_dataset.json_delta_dataset.mode="pad" \
# #     --eval_dataset.json_delta_dataset.path="$eval_dataset_path" \
# #     --eval_dataset.json_delta_dataset.seq_length=384 \
# #     --eval_dataset.json_delta_dataset.batch_size=128 \
# #     --eval_dataset.json_delta_dataset.tokenizer_processes=1 \
# #     --eval_dataset.json_delta_dataset.tokenizer_parallel_chunk_size=128 \
# #     --eval_dataset.json_delta_dataset.tokenizer_parallel_batch_size=128 \
# #     --eval_dataset.json_delta_dataset.use_data_sharded_loader=True \
# #     --checkpointer.save_optimizer_state=False \
# #     --autoresume=False \
# #     --logger.append_uuid=False \
# #     --logger.online=True \
# #     --logger.project_id="$project_id" \
# #     --logger.experiment_id="$experiment_id" \
# #     --logger.experiment_note="$experiment_note" \
# #     --logger.output_dir="$output_dir" \
# #     --logger.wandb_dir="$HOME/experiment_output/$project_id"




# # export SCRIPT_DIR="$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
# # export PROJECT_DIR="$( cd -- "$( dirname -- "$SCRIPT_DIR" )" &> /dev/null && pwd )"
# # cd $PROJECT_DIR
# # export PYTHONPATH="$PYTHONPATH:$PROJECT_DIR"
# # export LIBTPU_INIT_ARGS="--xla_tpu_megacore_fusion_allow_ags=false --xla_enable_async_collective_permute=true --xla_tpu_enable_ag_backward_pipelining=true --xla_tpu_enable_data_parallel_all_reduce_opt=true --xla_tpu_data_parallel_opt_different_sized_ops=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true"

# # export llama_tokenizer_path="llava/checkpoints/lwm_checkpoints/tokenizer.model"
# # export dataset_path="data/lang_table_whole_ver1_processed_1000traj.jsonl"
# # export eval_dataset_path="data/lang_table_whole_ver1_processed_val.jsonl"

# # export output_dir="llava/checkpoints/debug"

# # export project_id='lwm'
# # export experiment_note='world-model'
# # export experiment_id='lwm_1000traj_bf16_from_whole_delta32_batch128_real_2M_val_no_mix'

# # # mesh_dim: dp, fsdp, tp, sp
# # CUDA_VISIBLE_DEVICES=0,1,2,3 python3 -u -m lwm.train \
# #     --modality='vision,text,delta' \
# #     --mesh_dim='!-1,4,1,1' \
# #     --dtype='bf16' \
# #     --total_steps=505 \
# #     --log_freq=1 \
# #     --delta_tokens=1 \
# #     --eval_steps=1 \
# #     --save_model_freq=500 \
# #     --eval_log_freq=100 \
# #     --save_milestone_freq=0 \
# #     --load_llama_config='7b' \
# #     --load_checkpoint="params::/home/World-Model/llava/checkpoints/debug/cross_attn_nsvq_code32_real_2M_val_no_mix_batch128/streaming_params_13600"\
# #     --update_llama_config="dict(delta_vocab_size=32,theta=50000000,max_sequence_length=2048,use_flash_attention=True,scan_attention=True,scan_query_chunk_size=512,scan_key_chunk_size=1024,remat_attention='nothing_saveable',scan_mlp=True,scan_mlp_chunk_size=8192,remat_mlp='nothing_saveable',remat_block='nothing_saveable',scan_layers=True)" \
# #     --tokenizer.vocab_file="$llama_tokenizer_path" \
# #     --optimizer.type='adamw' \
# #     --llama.delta_vocab_size=32 \
# #     --optimizer.accumulate_gradient_steps=1 \
# #     --optimizer.adamw_optimizer.weight_decay=0.1 \
# #     --optimizer.adamw_optimizer.lr=2e-5 \
# #     --optimizer.adamw_optimizer.end_lr=2e-5 \
# #     --optimizer.adamw_optimizer.lr_warmup_steps=40 \
# #     --optimizer.adamw_optimizer.lr_decay_steps=505 \
# #     --use_data_sharded_loader=True \
# #     --train_dataset.type='json_vision_delta' \
# #     --train_dataset.delta_vision_text_processor.fields_from_example='fields' \
# #     --train_dataset.delta_vision_text_processor.n_tokens_per_delta=1 \
# #     --train_dataset.delta_vision_text_processor.max_n_frames=1 \
# #     --train_dataset.json_delta_dataset.mode="pad" \
# #     --train_dataset.json_delta_dataset.path="$dataset_path" \
# #     --train_dataset.json_delta_dataset.seq_length=384 \
# #     --train_dataset.json_delta_dataset.batch_size=128 \
# #     --train_dataset.json_delta_dataset.tokenizer_processes=1 \
# #     --train_dataset.json_delta_dataset.tokenizer_parallel_chunk_size=128 \
# #     --train_dataset.json_delta_dataset.tokenizer_parallel_batch_size=128 \
# #     --train_dataset.json_delta_dataset.use_data_sharded_loader=True \
# #     --eval_dataset.type='json_vision_delta' \
# #     --eval_dataset.delta_vision_text_processor.fields_from_example='fields' \
# #     --eval_dataset.delta_vision_text_processor.n_tokens_per_delta=1 \
# #     --eval_dataset.delta_vision_text_processor.max_n_frames=1 \
# #     --eval_dataset.json_delta_dataset.mode="pad" \
# #     --eval_dataset.json_delta_dataset.path="$eval_dataset_path" \
# #     --eval_dataset.json_delta_dataset.seq_length=384 \
# #     --eval_dataset.json_delta_dataset.batch_size=128 \
# #     --eval_dataset.json_delta_dataset.tokenizer_processes=1 \
# #     --eval_dataset.json_delta_dataset.tokenizer_parallel_chunk_size=128 \
# #     --eval_dataset.json_delta_dataset.tokenizer_parallel_batch_size=128 \
# #     --eval_dataset.json_delta_dataset.use_data_sharded_loader=True \
# #     --checkpointer.save_optimizer_state=False \
# #     --autoresume=False \
# #     --logger.append_uuid=False \
# #     --logger.online=True \
# #     --logger.project_id="$project_id" \
# #     --logger.experiment_id="$experiment_id" \
# #     --logger.experiment_note="$experiment_note" \
# #     --logger.output_dir="$output_dir" \
# #     --logger.wandb_dir="$HOME/experiment_output/$project_id"



# # export SCRIPT_DIR="$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
# # export PROJECT_DIR="$( cd -- "$( dirname -- "$SCRIPT_DIR" )" &> /dev/null && pwd )"
# # cd $PROJECT_DIR
# # export PYTHONPATH="$PYTHONPATH:$PROJECT_DIR"
# # export LIBTPU_INIT_ARGS="--xla_tpu_megacore_fusion_allow_ags=false --xla_enable_async_collective_permute=true --xla_tpu_enable_ag_backward_pipelining=true --xla_tpu_enable_data_parallel_all_reduce_opt=true --xla_tpu_data_parallel_opt_different_sized_ops=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true"

# # export llama_tokenizer_path="llava/checkpoints/lwm_checkpoints/tokenizer.model"
# # export dataset_path="data/lang_table_separate_whole_shuffled.jsonl"
# # export eval_dataset_path="data/lang_table_whole_ver1_processed_val.jsonl"

# # export output_dir="llava/checkpoints/debug"

# # export project_id='lwm'
# # export experiment_note='world-model'
# # export experiment_id='lwm_separate_bf16_from_whole_delta32_batch128_real_2M_val_no_mix'

# # # mesh_dim: dp, fsdp, tp, sp
# # CUDA_VISIBLE_DEVICES=0,1,2,3 python3 -u -m lwm.train \
# #     --modality='vision,text,delta' \
# #     --mesh_dim='!-1,4,1,1' \
# #     --dtype='bf16' \
# #     --total_steps=1301 \
# #     --log_freq=1 \
# #     --delta_tokens=1 \
# #     --eval_steps=1 \
# #     --save_model_freq=1300 \
# #     --eval_log_freq=100 \
# #     --save_milestone_freq=0 \
# #     --load_llama_config='7b' \
# #     --load_checkpoint="params::/home/World-Model/llava/checkpoints/debug/cross_attn_nsvq_code32_real_2M_val_no_mix_batch128/streaming_params_13600"\
# #     --update_llama_config="dict(delta_vocab_size=32,theta=50000000,max_sequence_length=2048,use_flash_attention=True,scan_attention=True,scan_query_chunk_size=512,scan_key_chunk_size=1024,remat_attention='nothing_saveable',scan_mlp=True,scan_mlp_chunk_size=8192,remat_mlp='nothing_saveable',remat_block='nothing_saveable',scan_layers=True)" \
# #     --tokenizer.vocab_file="$llama_tokenizer_path" \
# #     --optimizer.type='adamw' \
# #     --llama.delta_vocab_size=32 \
# #     --optimizer.accumulate_gradient_steps=1 \
# #     --optimizer.adamw_optimizer.weight_decay=0.1 \
# #     --optimizer.adamw_optimizer.lr=2e-5 \
# #     --optimizer.adamw_optimizer.end_lr=2e-5 \
# #     --optimizer.adamw_optimizer.lr_warmup_steps=40 \
# #     --optimizer.adamw_optimizer.lr_decay_steps=1300 \
# #     --use_data_sharded_loader=True \
# #     --train_dataset.type='json_vision_delta' \
# #     --train_dataset.delta_vision_text_processor.fields_from_example='fields' \
# #     --train_dataset.delta_vision_text_processor.n_tokens_per_delta=1 \
# #     --train_dataset.delta_vision_text_processor.max_n_frames=1 \
# #     --train_dataset.json_delta_dataset.mode="pad" \
# #     --train_dataset.json_delta_dataset.path="$dataset_path" \
# #     --train_dataset.json_delta_dataset.seq_length=384 \
# #     --train_dataset.json_delta_dataset.batch_size=128 \
# #     --train_dataset.json_delta_dataset.tokenizer_processes=1 \
# #     --train_dataset.json_delta_dataset.tokenizer_parallel_chunk_size=128 \
# #     --train_dataset.json_delta_dataset.tokenizer_parallel_batch_size=128 \
# #     --train_dataset.json_delta_dataset.use_data_sharded_loader=True \
# #     --eval_dataset.type='json_vision_delta' \
# #     --eval_dataset.delta_vision_text_processor.fields_from_example='fields' \
# #     --eval_dataset.delta_vision_text_processor.n_tokens_per_delta=1 \
# #     --eval_dataset.delta_vision_text_processor.max_n_frames=1 \
# #     --eval_dataset.json_delta_dataset.mode="pad" \
# #     --eval_dataset.json_delta_dataset.path="$eval_dataset_path" \
# #     --eval_dataset.json_delta_dataset.seq_length=384 \
# #     --eval_dataset.json_delta_dataset.batch_size=128 \
# #     --eval_dataset.json_delta_dataset.tokenizer_processes=1 \
# #     --eval_dataset.json_delta_dataset.tokenizer_parallel_chunk_size=128 \
# #     --eval_dataset.json_delta_dataset.tokenizer_parallel_batch_size=128 \
# #     --eval_dataset.json_delta_dataset.use_data_sharded_loader=True \
# #     --checkpointer.save_optimizer_state=False \
# #     --autoresume=False \
# #     --logger.append_uuid=False \
# #     --logger.online=True \
# #     --logger.project_id="$project_id" \
# #     --logger.experiment_id="$experiment_id" \
# #     --logger.experiment_note="$experiment_note" \
# #     --logger.output_dir="$output_dir" \
# #     --logger.wandb_dir="$HOME/experiment_output/$project_id"


# # export SCRIPT_DIR="$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
# # export PROJECT_DIR="$( cd -- "$( dirname -- "$SCRIPT_DIR" )" &> /dev/null && pwd )"
# # cd $PROJECT_DIR
# # export PYTHONPATH="$PYTHONPATH:$PROJECT_DIR"
# # export LIBTPU_INIT_ARGS="--xla_tpu_megacore_fusion_allow_ags=false --xla_enable_async_collective_permute=true --xla_tpu_enable_ag_backward_pipelining=true --xla_tpu_enable_data_parallel_all_reduce_opt=true --xla_tpu_data_parallel_opt_different_sized_ops=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true"

# # export llama_tokenizer_path="llava/checkpoints/lwm_checkpoints/tokenizer.model"
# # export dataset_path="data/lang_table_separate_whole_shuffled.jsonl"
# # export eval_dataset_path="data/lang_table_whole_ver1_processed_val.jsonl"

# # export output_dir="llava/checkpoints/debug"

# # export project_id='lwm'
# # export experiment_note='world-model'
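# # NOTE(review): this experiment_id says "delta32", but the run below loads the
# # code8 checkpoint (cross_attn_nsvq_code8_...) and sets delta_vocab_size=8.
# # Rename it (e.g. "...delta8...") before re-running so its logs/checkpoints are
# # not confused with the delta32 runs.
# # export experiment_id='lwm_separate_bf16_from_whole_delta32_batch128_real_2M_sim_25K_mixed'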

# # # mesh_dim: dp, fsdp, tp, sp
# # CUDA_VISIBLE_DEVICES=0,1,2,3 python3 -u -m lwm.train \
# #     --modality='vision,text,delta' \
# #     --mesh_dim='!-1,4,1,1' \
# #     --dtype='bf16' \
# #     --total_steps=1301 \
# #     --log_freq=1 \
# #     --delta_tokens=1 \
# #     --eval_steps=1 \
# #     --save_model_freq=1300 \
# #     --eval_log_freq=100 \
# #     --save_milestone_freq=0 \
# #     --load_llama_config='7b' \
# #     --load_checkpoint="params::/home/World-Model/llava/checkpoints/debug/cross_attn_nsvq_code8_real_2M_sim_25K_mixed_batch128/streaming_params_13870"\
# #     --update_llama_config="dict(delta_vocab_size=8,theta=50000000,max_sequence_length=2048,use_flash_attention=True,scan_attention=True,scan_query_chunk_size=512,scan_key_chunk_size=1024,remat_attention='nothing_saveable',scan_mlp=True,scan_mlp_chunk_size=8192,remat_mlp='nothing_saveable',remat_block='nothing_saveable',scan_layers=True)" \
# #     --tokenizer.vocab_file="$llama_tokenizer_path" \
# #     --optimizer.type='adamw' \
# #     --llama.delta_vocab_size=8 \
# #     --optimizer.accumulate_gradient_steps=1 \
# #     --optimizer.adamw_optimizer.weight_decay=0.1 \
# #     --optimizer.adamw_optimizer.lr=2e-5 \
# #     --optimizer.adamw_optimizer.end_lr=2e-5 \
# #     --optimizer.adamw_optimizer.lr_warmup_steps=40 \
# #     --optimizer.adamw_optimizer.lr_decay_steps=1300 \
# #     --use_data_sharded_loader=True \
# #     --train_dataset.type='json_vision_delta' \
# #     --train_dataset.delta_vision_text_processor.fields_from_example='fields' \
# #     --train_dataset.delta_vision_text_processor.n_tokens_per_delta=1 \
# #     --train_dataset.delta_vision_text_processor.max_n_frames=1 \
# #     --train_dataset.json_delta_dataset.mode="pad" \
# #     --train_dataset.json_delta_dataset.path="$dataset_path" \
# #     --train_dataset.json_delta_dataset.seq_length=384 \
# #     --train_dataset.json_delta_dataset.batch_size=128 \
# #     --train_dataset.json_delta_dataset.tokenizer_processes=1 \
# #     --train_dataset.json_delta_dataset.tokenizer_parallel_chunk_size=128 \
# #     --train_dataset.json_delta_dataset.tokenizer_parallel_batch_size=128 \
# #     --train_dataset.json_delta_dataset.use_data_sharded_loader=True \
# #     --eval_dataset.type='json_vision_delta' \
# #     --eval_dataset.delta_vision_text_processor.fields_from_example='fields' \
# #     --eval_dataset.delta_vision_text_processor.n_tokens_per_delta=1 \
# #     --eval_dataset.delta_vision_text_processor.max_n_frames=1 \
# #     --eval_dataset.json_delta_dataset.mode="pad" \
# #     --eval_dataset.json_delta_dataset.path="$eval_dataset_path" \
# #     --eval_dataset.json_delta_dataset.seq_length=384 \
# #     --eval_dataset.json_delta_dataset.batch_size=128 \
# #     --eval_dataset.json_delta_dataset.tokenizer_processes=1 \
# #     --eval_dataset.json_delta_dataset.tokenizer_parallel_chunk_size=128 \
# #     --eval_dataset.json_delta_dataset.tokenizer_parallel_batch_size=128 \
# #     --eval_dataset.json_delta_dataset.use_data_sharded_loader=True \
# #     --checkpointer.save_optimizer_state=False \
# #     --autoresume=False \
# #     --logger.append_uuid=False \
# #     --logger.online=True \
# #     --logger.project_id="$project_id" \
# #     --logger.experiment_id="$experiment_id" \
# #     --logger.experiment_note="$experiment_note" \
# #     --logger.output_dir="$output_dir" \
# #     --logger.wandb_dir="$HOME/experiment_output/$project_id"


# cd lang_table
# # Assuming CUDA_VISIBLE_DEVICES is used to specify the GPU
# CUDA_VISIBLE_DEVICES=4 ./eval_langtable3.sh &
# PID1=$!

# CUDA_VISIBLE_DEVICES=5 ./eval_langtable4.sh &
# PID2=$!


# wait $PID1 $PID2 

# echo "All post-training scripts completed."

# cd ..

# #! /bin/bash

# export SCRIPT_DIR="$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
# export PROJECT_DIR="$( cd -- "$( dirname -- "$SCRIPT_DIR" )" &> /dev/null && pwd )"
# cd $PROJECT_DIR
# export PYTHONPATH="$PYTHONPATH:$PROJECT_DIR"
# export LIBTPU_INIT_ARGS="--xla_tpu_megacore_fusion_allow_ags=false --xla_enable_async_collective_permute=true --xla_tpu_enable_ag_backward_pipelining=true --xla_tpu_enable_data_parallel_all_reduce_opt=true --xla_tpu_data_parallel_opt_different_sized_ops=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true"

# export llama_tokenizer_path="llava/checkpoints/lwm_checkpoints/tokenizer.model"
# export dataset_path="/home/World-Model/data/cross_attn_nsvq_code32_real_whole_train.jsonl"
# export eval_dataset_path="/home/World-Model/data/cross_attn_nsvq_code32_real_whole_val.jsonl"

# export output_dir="llava/checkpoints/debug"

# export project_id='lwm'
# export experiment_note='world-model'
# export experiment_id='cross_attn_nsvq_code32_real_whole_no_mix_batch128'

# # mesh_dim: dp, fsdp, tp, sp
# CUDA_VISIBLE_DEVICES=4,5,6,7 python3 -u -m lwm.train \
#     --modality='vision,text,delta' \
#     --mesh_dim='!-1,4,1,1' \
#     --dtype='bf16' \
#     --total_steps=55005 \
#     --log_freq=1 \
#     --delta_tokens=1 \
#     --eval_steps=1 \
#     --save_model_freq=55000 \
#     --eval_log_freq=100 \
#     --save_milestone_freq=36448 \
#     --load_llama_config='7b' \
#     --load_checkpoint="params::/home/World-Model/llava/checkpoints/lwm_checkpoints/params"\
#     --update_llama_config="dict(delta_vocab_size=32,theta=50000000,max_sequence_length=2048,use_flash_attention=True,scan_attention=True,scan_query_chunk_size=512,scan_key_chunk_size=1024,remat_attention='nothing_saveable',scan_mlp=True,scan_mlp_chunk_size=8192,remat_mlp='nothing_saveable',remat_block='nothing_saveable',scan_layers=True)" \
#     --tokenizer.vocab_file="$llama_tokenizer_path" \
#     --optimizer.type='adamw' \
#     --llama.delta_vocab_size=32 \
#     --optimizer.accumulate_gradient_steps=1 \
#     --optimizer.adamw_optimizer.weight_decay=0 \
#     --optimizer.adamw_optimizer.lr=2e-5 \
#     --optimizer.adamw_optimizer.end_lr=0 \
#     --optimizer.adamw_optimizer.lr_warmup_steps=1000 \
#     --optimizer.adamw_optimizer.lr_decay_steps=55005 \
#     --use_data_sharded_loader=True \
#     --train_dataset.type='json_vision_delta' \
#     --train_dataset.delta_vision_text_processor.fields_from_example='fields' \
#     --train_dataset.delta_vision_text_processor.n_tokens_per_delta=1 \
#     --train_dataset.delta_vision_text_processor.max_n_frames=1 \
#     --train_dataset.json_delta_dataset.mode="pad" \
#     --train_dataset.json_delta_dataset.path="$dataset_path" \
#     --train_dataset.json_delta_dataset.seq_length=384 \
#     --train_dataset.json_delta_dataset.batch_size=128 \
#     --train_dataset.json_delta_dataset.tokenizer_processes=1 \
#     --train_dataset.json_delta_dataset.tokenizer_parallel_chunk_size=128 \
#     --train_dataset.json_delta_dataset.tokenizer_parallel_batch_size=128 \
#     --train_dataset.json_delta_dataset.use_data_sharded_loader=True \
#     --eval_dataset.type='json_vision_delta' \
#     --eval_dataset.delta_vision_text_processor.fields_from_example='fields' \
#     --eval_dataset.delta_vision_text_processor.n_tokens_per_delta=1 \
#     --eval_dataset.delta_vision_text_processor.max_n_frames=1 \
#     --eval_dataset.json_delta_dataset.mode="pad" \
#     --eval_dataset.json_delta_dataset.path="$eval_dataset_path" \
#     --eval_dataset.json_delta_dataset.seq_length=384 \
#     --eval_dataset.json_delta_dataset.batch_size=128 \
#     --eval_dataset.json_delta_dataset.tokenizer_processes=1 \
#     --eval_dataset.json_delta_dataset.tokenizer_parallel_chunk_size=128 \
#     --eval_dataset.json_delta_dataset.tokenizer_parallel_batch_size=128 \
#     --eval_dataset.json_delta_dataset.use_data_sharded_loader=True \
#     --checkpointer.save_optimizer_state=False \
#     --autoresume=False \
#     --logger.append_uuid=False \
#     --logger.online=True \
#     --logger.project_id="$project_id" \
#     --logger.experiment_id="$experiment_id" \
#     --logger.experiment_note="$experiment_note" \
#     --logger.output_dir="$output_dir" \
#     --logger.wandb_dir="$HOME/experiment_output/$project_id"



# export SCRIPT_DIR="$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
# export PROJECT_DIR="$( cd -- "$( dirname -- "$SCRIPT_DIR" )" &> /dev/null && pwd )"
# cd $PROJECT_DIR
# export PYTHONPATH="$PYTHONPATH:$PROJECT_DIR"
# export LIBTPU_INIT_ARGS="--xla_tpu_megacore_fusion_allow_ags=false --xla_enable_async_collective_permute=true --xla_tpu_enable_ag_backward_pipelining=true --xla_tpu_enable_data_parallel_all_reduce_opt=true --xla_tpu_data_parallel_opt_different_sized_ops=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true"

# export llama_tokenizer_path="llava/checkpoints/lwm_checkpoints/tokenizer.model"
# export dataset_path="data/lang_table_whole_ver1_processed_1000traj.jsonl"
# export eval_dataset_path="data/lang_table_whole_ver1_processed_val.jsonl"

# export output_dir="llava/checkpoints/debug"

# export project_id='lwm'
# export experiment_note='world-model'
# export experiment_id='lwm_1000traj_bf16_from_whole_delta32_batch128_real_whole'

# # mesh_dim: dp, fsdp, tp, sp
# CUDA_VISIBLE_DEVICES=4,5,6,7 python3 -u -m lwm.train \
#     --modality='vision,text,delta' \
#     --mesh_dim='!-1,4,1,1' \
#     --dtype='bf16' \
#     --total_steps=505 \
#     --log_freq=1 \
#     --delta_tokens=1 \
#     --eval_steps=1 \
#     --save_model_freq=500 \
#     --eval_log_freq=100 \
#     --save_milestone_freq=0 \
#     --load_llama_config='7b' \
#     --load_checkpoint="params::/home/World-Model/llava/checkpoints/debug/cross_attn_nsvq_code32_real_whole_no_mix_batch128/streaming_params"\
#     --update_llama_config="dict(delta_vocab_size=32,theta=50000000,max_sequence_length=2048,use_flash_attention=True,scan_attention=True,scan_query_chunk_size=512,scan_key_chunk_size=1024,remat_attention='nothing_saveable',scan_mlp=True,scan_mlp_chunk_size=8192,remat_mlp='nothing_saveable',remat_block='nothing_saveable',scan_layers=True)" \
#     --tokenizer.vocab_file="$llama_tokenizer_path" \
#     --optimizer.type='adamw' \
#     --llama.delta_vocab_size=32 \
#     --optimizer.accumulate_gradient_steps=1 \
#     --optimizer.adamw_optimizer.weight_decay=0.1 \
#     --optimizer.adamw_optimizer.lr=2e-5 \
#     --optimizer.adamw_optimizer.end_lr=2e-5 \
#     --optimizer.adamw_optimizer.lr_warmup_steps=40 \
#     --optimizer.adamw_optimizer.lr_decay_steps=505 \
#     --use_data_sharded_loader=True \
#     --train_dataset.type='json_vision_delta' \
#     --train_dataset.delta_vision_text_processor.fields_from_example='fields' \
#     --train_dataset.delta_vision_text_processor.n_tokens_per_delta=1 \
#     --train_dataset.delta_vision_text_processor.max_n_frames=1 \
#     --train_dataset.json_delta_dataset.mode="pad" \
#     --train_dataset.json_delta_dataset.path="$dataset_path" \
#     --train_dataset.json_delta_dataset.seq_length=384 \
#     --train_dataset.json_delta_dataset.batch_size=128 \
#     --train_dataset.json_delta_dataset.tokenizer_processes=1 \
#     --train_dataset.json_delta_dataset.tokenizer_parallel_chunk_size=128 \
#     --train_dataset.json_delta_dataset.tokenizer_parallel_batch_size=128 \
#     --train_dataset.json_delta_dataset.use_data_sharded_loader=True \
#     --eval_dataset.type='json_vision_delta' \
#     --eval_dataset.delta_vision_text_processor.fields_from_example='fields' \
#     --eval_dataset.delta_vision_text_processor.n_tokens_per_delta=1 \
#     --eval_dataset.delta_vision_text_processor.max_n_frames=1 \
#     --eval_dataset.json_delta_dataset.mode="pad" \
#     --eval_dataset.json_delta_dataset.path="$eval_dataset_path" \
#     --eval_dataset.json_delta_dataset.seq_length=384 \
#     --eval_dataset.json_delta_dataset.batch_size=128 \
#     --eval_dataset.json_delta_dataset.tokenizer_processes=1 \
#     --eval_dataset.json_delta_dataset.tokenizer_parallel_chunk_size=128 \
#     --eval_dataset.json_delta_dataset.tokenizer_parallel_batch_size=128 \
#     --eval_dataset.json_delta_dataset.use_data_sharded_loader=True \
#     --checkpointer.save_optimizer_state=False \
#     --autoresume=False \
#     --logger.append_uuid=False \
#     --logger.online=True \
#     --logger.project_id="$project_id" \
#     --logger.experiment_id="$experiment_id" \
#     --logger.experiment_note="$experiment_note" \
#     --logger.output_dir="$output_dir" \
#     --logger.wandb_dir="$HOME/experiment_output/$project_id"


# export SCRIPT_DIR="$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
# export PROJECT_DIR="$( cd -- "$( dirname -- "$SCRIPT_DIR" )" &> /dev/null && pwd )"
# cd $PROJECT_DIR
# export PYTHONPATH="$PYTHONPATH:$PROJECT_DIR"
# export LIBTPU_INIT_ARGS="--xla_tpu_megacore_fusion_allow_ags=false --xla_enable_async_collective_permute=true --xla_tpu_enable_ag_backward_pipelining=true --xla_tpu_enable_data_parallel_all_reduce_opt=true --xla_tpu_data_parallel_opt_different_sized_ops=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true"

# export llama_tokenizer_path="llava/checkpoints/lwm_checkpoints/tokenizer.model"
# export dataset_path="data/lang_table_whole_ver1_processed_1000traj.jsonl"
# export eval_dataset_path="data/lang_table_whole_ver1_processed_val.jsonl"

# export output_dir="llava/checkpoints/debug"

# export project_id='lwm'
# export experiment_note='world-model'
# export experiment_id='lwm_1000traj_bf16_from_whole_delta32_batch128_real_36K_step'

# # mesh_dim: dp, fsdp, tp, sp
# CUDA_VISIBLE_DEVICES=4,5,6,7 python3 -u -m lwm.train \
#     --modality='vision,text,delta' \
#     --mesh_dim='!-1,4,1,1' \
#     --dtype='bf16' \
#     --total_steps=505 \
#     --log_freq=1 \
#     --delta_tokens=1 \
#     --eval_steps=1 \
#     --save_model_freq=500 \
#     --eval_log_freq=100 \
#     --save_milestone_freq=0 \
#     --load_llama_config='7b' \
#     --load_checkpoint="params::/home/World-Model/llava/checkpoints/debug/cross_attn_nsvq_code32_real_whole_no_mix_batch128/streaming_params_36448"\
#     --update_llama_config="dict(delta_vocab_size=32,theta=50000000,max_sequence_length=2048,use_flash_attention=True,scan_attention=True,scan_query_chunk_size=512,scan_key_chunk_size=1024,remat_attention='nothing_saveable',scan_mlp=True,scan_mlp_chunk_size=8192,remat_mlp='nothing_saveable',remat_block='nothing_saveable',scan_layers=True)" \
#     --tokenizer.vocab_file="$llama_tokenizer_path" \
#     --optimizer.type='adamw' \
#     --llama.delta_vocab_size=32 \
#     --optimizer.accumulate_gradient_steps=1 \
#     --optimizer.adamw_optimizer.weight_decay=0.1 \
#     --optimizer.adamw_optimizer.lr=2e-5 \
#     --optimizer.adamw_optimizer.end_lr=2e-5 \
#     --optimizer.adamw_optimizer.lr_warmup_steps=40 \
#     --optimizer.adamw_optimizer.lr_decay_steps=505 \
#     --use_data_sharded_loader=True \
#     --train_dataset.type='json_vision_delta' \
#     --train_dataset.delta_vision_text_processor.fields_from_example='fields' \
#     --train_dataset.delta_vision_text_processor.n_tokens_per_delta=1 \
#     --train_dataset.delta_vision_text_processor.max_n_frames=1 \
#     --train_dataset.json_delta_dataset.mode="pad" \
#     --train_dataset.json_delta_dataset.path="$dataset_path" \
#     --train_dataset.json_delta_dataset.seq_length=384 \
#     --train_dataset.json_delta_dataset.batch_size=128 \
#     --train_dataset.json_delta_dataset.tokenizer_processes=1 \
#     --train_dataset.json_delta_dataset.tokenizer_parallel_chunk_size=128 \
#     --train_dataset.json_delta_dataset.tokenizer_parallel_batch_size=128 \
#     --train_dataset.json_delta_dataset.use_data_sharded_loader=True \
#     --eval_dataset.type='json_vision_delta' \
#     --eval_dataset.delta_vision_text_processor.fields_from_example='fields' \
#     --eval_dataset.delta_vision_text_processor.n_tokens_per_delta=1 \
#     --eval_dataset.delta_vision_text_processor.max_n_frames=1 \
#     --eval_dataset.json_delta_dataset.mode="pad" \
#     --eval_dataset.json_delta_dataset.path="$eval_dataset_path" \
#     --eval_dataset.json_delta_dataset.seq_length=384 \
#     --eval_dataset.json_delta_dataset.batch_size=128 \
#     --eval_dataset.json_delta_dataset.tokenizer_processes=1 \
#     --eval_dataset.json_delta_dataset.tokenizer_parallel_chunk_size=128 \
#     --eval_dataset.json_delta_dataset.tokenizer_parallel_batch_size=128 \
#     --eval_dataset.json_delta_dataset.use_data_sharded_loader=True \
#     --checkpointer.save_optimizer_state=False \
#     --autoresume=False \
#     --logger.append_uuid=False \
#     --logger.online=True \
#     --logger.project_id="$project_id" \
#     --logger.experiment_id="$experiment_id" \
#     --logger.experiment_note="$experiment_note" \
#     --logger.output_dir="$output_dir" \
#     --logger.wandb_dir="$HOME/experiment_output/$project_id"



# export SCRIPT_DIR="$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
# export PROJECT_DIR="$( cd -- "$( dirname -- "$SCRIPT_DIR" )" &> /dev/null && pwd )"
# cd $PROJECT_DIR
# export PYTHONPATH="$PYTHONPATH:$PROJECT_DIR"
# export LIBTPU_INIT_ARGS="--xla_tpu_megacore_fusion_allow_ags=false --xla_enable_async_collective_permute=true --xla_tpu_enable_ag_backward_pipelining=true --xla_tpu_enable_data_parallel_all_reduce_opt=true --xla_tpu_data_parallel_opt_different_sized_ops=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true"

# export llama_tokenizer_path="llava/checkpoints/lwm_checkpoints/tokenizer.model"
# export dataset_path="data/lang_table_separate_whole_shuffled.jsonl"
# export eval_dataset_path="data/lang_table_whole_ver1_processed_val.jsonl"

# export output_dir="llava/checkpoints/debug"

# export project_id='lwm'
# export experiment_note='world-model'
# export experiment_id='lwm_separate_bf16_from_whole_delta32_batch128_real_whole'

# # mesh_dim: dp, fsdp, tp, sp
# CUDA_VISIBLE_DEVICES=4,5,6,7 python3 -u -m lwm.train \
#     --modality='vision,text,delta' \
#     --mesh_dim='!-1,4,1,1' \
#     --dtype='bf16' \
#     --total_steps=1301 \
#     --log_freq=1 \
#     --delta_tokens=1 \
#     --eval_steps=1 \
#     --save_model_freq=1300 \
#     --eval_log_freq=100 \
#     --save_milestone_freq=0 \
#     --load_llama_config='7b' \
#     --load_checkpoint="params::/home/World-Model/llava/checkpoints/debug/cross_attn_nsvq_code32_real_whole_no_mix_batch128/streaming_params"\
#     --update_llama_config="dict(delta_vocab_size=32,theta=50000000,max_sequence_length=2048,use_flash_attention=True,scan_attention=True,scan_query_chunk_size=512,scan_key_chunk_size=1024,remat_attention='nothing_saveable',scan_mlp=True,scan_mlp_chunk_size=8192,remat_mlp='nothing_saveable',remat_block='nothing_saveable',scan_layers=True)" \
#     --tokenizer.vocab_file="$llama_tokenizer_path" \
#     --optimizer.type='adamw' \
#     --llama.delta_vocab_size=32 \
#     --optimizer.accumulate_gradient_steps=1 \
#     --optimizer.adamw_optimizer.weight_decay=0.1 \
#     --optimizer.adamw_optimizer.lr=2e-5 \
#     --optimizer.adamw_optimizer.end_lr=2e-5 \
#     --optimizer.adamw_optimizer.lr_warmup_steps=40 \
#     --optimizer.adamw_optimizer.lr_decay_steps=1300 \
#     --use_data_sharded_loader=True \
#     --train_dataset.type='json_vision_delta' \
#     --train_dataset.delta_vision_text_processor.fields_from_example='fields' \
#     --train_dataset.delta_vision_text_processor.n_tokens_per_delta=1 \
#     --train_dataset.delta_vision_text_processor.max_n_frames=1 \
#     --train_dataset.json_delta_dataset.mode="pad" \
#     --train_dataset.json_delta_dataset.path="$dataset_path" \
#     --train_dataset.json_delta_dataset.seq_length=384 \
#     --train_dataset.json_delta_dataset.batch_size=128 \
#     --train_dataset.json_delta_dataset.tokenizer_processes=1 \
#     --train_dataset.json_delta_dataset.tokenizer_parallel_chunk_size=128 \
#     --train_dataset.json_delta_dataset.tokenizer_parallel_batch_size=128 \
#     --train_dataset.json_delta_dataset.use_data_sharded_loader=True \
#     --eval_dataset.type='json_vision_delta' \
#     --eval_dataset.delta_vision_text_processor.fields_from_example='fields' \
#     --eval_dataset.delta_vision_text_processor.n_tokens_per_delta=1 \
#     --eval_dataset.delta_vision_text_processor.max_n_frames=1 \
#     --eval_dataset.json_delta_dataset.mode="pad" \
#     --eval_dataset.json_delta_dataset.path="$eval_dataset_path" \
#     --eval_dataset.json_delta_dataset.seq_length=384 \
#     --eval_dataset.json_delta_dataset.batch_size=128 \
#     --eval_dataset.json_delta_dataset.tokenizer_processes=1 \
#     --eval_dataset.json_delta_dataset.tokenizer_parallel_chunk_size=128 \
#     --eval_dataset.json_delta_dataset.tokenizer_parallel_batch_size=128 \
#     --eval_dataset.json_delta_dataset.use_data_sharded_loader=True \
#     --checkpointer.save_optimizer_state=False \
#     --autoresume=False \
#     --logger.append_uuid=False \
#     --logger.online=True \
#     --logger.project_id="$project_id" \
#     --logger.experiment_id="$experiment_id" \
#     --logger.experiment_note="$experiment_note" \
#     --logger.output_dir="$output_dir" \
#     --logger.wandb_dir="$HOME/experiment_output/$project_id"





# export SCRIPT_DIR="$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
# export PROJECT_DIR="$( cd -- "$( dirname -- "$SCRIPT_DIR" )" &> /dev/null && pwd )"
# cd $PROJECT_DIR
# export PYTHONPATH="$PYTHONPATH:$PROJECT_DIR"
# export LIBTPU_INIT_ARGS="--xla_tpu_megacore_fusion_allow_ags=false --xla_enable_async_collective_permute=true --xla_tpu_enable_ag_backward_pipelining=true --xla_tpu_enable_data_parallel_all_reduce_opt=true --xla_tpu_data_parallel_opt_different_sized_ops=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true"

# export llama_tokenizer_path="llava/checkpoints/lwm_checkpoints/tokenizer.model"
# export dataset_path="data/lang_table_separate_whole_shuffled.jsonl"
# export eval_dataset_path="data/lang_table_whole_ver1_processed_val.jsonl"

# export output_dir="llava/checkpoints/debug"

# export project_id='lwm'
# export experiment_note='world-model'
# export experiment_id='lwm_separate_bf16_from_whole_delta32_batch128_real_36K_step'

# # mesh_dim: dp, fsdp, tp, sp
# CUDA_VISIBLE_DEVICES=4,5,6,7 python3 -u -m lwm.train \
#     --modality='vision,text,delta' \
#     --mesh_dim='!-1,4,1,1' \
#     --dtype='bf16' \
#     --total_steps=1301 \
#     --log_freq=1 \
#     --delta_tokens=1 \
#     --eval_steps=1 \
#     --save_model_freq=1300 \
#     --eval_log_freq=100 \
#     --save_milestone_freq=0 \
#     --load_llama_config='7b' \
#     --load_checkpoint="params::/home/World-Model/llava/checkpoints/debug/cross_attn_nsvq_code32_real_whole_no_mix_batch128/streaming_params_36448"\
#     --update_llama_config="dict(delta_vocab_size=32,theta=50000000,max_sequence_length=2048,use_flash_attention=True,scan_attention=True,scan_query_chunk_size=512,scan_key_chunk_size=1024,remat_attention='nothing_saveable',scan_mlp=True,scan_mlp_chunk_size=8192,remat_mlp='nothing_saveable',remat_block='nothing_saveable',scan_layers=True)" \
#     --tokenizer.vocab_file="$llama_tokenizer_path" \
#     --optimizer.type='adamw' \
#     --llama.delta_vocab_size=32 \
#     --optimizer.accumulate_gradient_steps=1 \
#     --optimizer.adamw_optimizer.weight_decay=0.1 \
#     --optimizer.adamw_optimizer.lr=2e-5 \
#     --optimizer.adamw_optimizer.end_lr=2e-5 \
#     --optimizer.adamw_optimizer.lr_warmup_steps=40 \
#     --optimizer.adamw_optimizer.lr_decay_steps=1300 \
#     --use_data_sharded_loader=True \
#     --train_dataset.type='json_vision_delta' \
#     --train_dataset.delta_vision_text_processor.fields_from_example='fields' \
#     --train_dataset.delta_vision_text_processor.n_tokens_per_delta=1 \
#     --train_dataset.delta_vision_text_processor.max_n_frames=1 \
#     --train_dataset.json_delta_dataset.mode="pad" \
#     --train_dataset.json_delta_dataset.path="$dataset_path" \
#     --train_dataset.json_delta_dataset.seq_length=384 \
#     --train_dataset.json_delta_dataset.batch_size=128 \
#     --train_dataset.json_delta_dataset.tokenizer_processes=1 \
#     --train_dataset.json_delta_dataset.tokenizer_parallel_chunk_size=128 \
#     --train_dataset.json_delta_dataset.tokenizer_parallel_batch_size=128 \
#     --train_dataset.json_delta_dataset.use_data_sharded_loader=True \
#     --eval_dataset.type='json_vision_delta' \
#     --eval_dataset.delta_vision_text_processor.fields_from_example='fields' \
#     --eval_dataset.delta_vision_text_processor.n_tokens_per_delta=1 \
#     --eval_dataset.delta_vision_text_processor.max_n_frames=1 \
#     --eval_dataset.json_delta_dataset.mode="pad" \
#     --eval_dataset.json_delta_dataset.path="$eval_dataset_path" \
#     --eval_dataset.json_delta_dataset.seq_length=384 \
#     --eval_dataset.json_delta_dataset.batch_size=128 \
#     --eval_dataset.json_delta_dataset.tokenizer_processes=1 \
#     --eval_dataset.json_delta_dataset.tokenizer_parallel_chunk_size=128 \
#     --eval_dataset.json_delta_dataset.tokenizer_parallel_batch_size=128 \
#     --eval_dataset.json_delta_dataset.use_data_sharded_loader=True \
#     --checkpointer.save_optimizer_state=False \
#     --autoresume=False \
#     --logger.append_uuid=False \
#     --logger.online=True \
#     --logger.project_id="$project_id" \
#     --logger.experiment_id="$experiment_id" \
#     --logger.experiment_note="$experiment_note" \
#     --logger.output_dir="$output_dir" \
#     --logger.wandb_dir="$HOME/experiment_output/$project_id"

# cd lang_table

# # Assuming CUDA_VISIBLE_DEVICES is used to specify the GPU
# CUDA_VISIBLE_DEVICES=4 ./eval_langtable.sh &
# PID1=$!

# CUDA_VISIBLE_DEVICES=5 ./eval_langtable2.sh &
# PID2=$!

# CUDA_VISIBLE_DEVICES=6 ./eval_langtable3.sh &
# PID3=$!

# CUDA_VISIBLE_DEVICES=7 ./eval_langtable4.sh &
# PID4=$!


# wait $PID1 $PID2 $PID3 $PID4 

# echo "All post-training scripts completed."

# cd ..





# export SCRIPT_DIR="$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
# export PROJECT_DIR="$( cd -- "$( dirname -- "$SCRIPT_DIR" )" &> /dev/null && pwd )"
# cd $PROJECT_DIR
# export PYTHONPATH="$PYTHONPATH:$PROJECT_DIR"
# export LIBTPU_INIT_ARGS="--xla_tpu_megacore_fusion_allow_ags=false --xla_enable_async_collective_permute=true --xla_tpu_enable_ag_backward_pipelining=true --xla_tpu_enable_data_parallel_all_reduce_opt=true --xla_tpu_data_parallel_opt_different_sized_ops=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true"


# export llama_tokenizer_path="/mnt/default/lwm/data/checkpoints/lwm_checkpoints/tokenizer.model"
# export dataset_path='/mnt/default/lwm/data/unipi/vpt_preprocessed/lang_table_whole_sim_total_vpt.jsonl'

# export output_dir='/mnt/default/lwm/data/unipi/lwm_vpt'
# # WANDB_API_KEY must come from the caller's environment; never hard-code the key here.


# export project_id='lwm'
# export experiment_note='world-model'
# # export experiment_id='sthv2_bf16_from_whole_delta8_batch128_seq4_layer8_100K_re2'
# export experiment_id='lang_table_vpt_sim_whole'


# # mesh_dim: dp, fsdp, tp, sp
# python3 -u -m lwm.train \
#     --modality='vision,text' \
#     --mesh_dim='!-1,8,1,1' \
#     --dtype='bf16' \
#     --total_steps=36450 \
#     --log_freq=1 \
#     --save_model_freq=36448 \
#     --save_milestone_freq=36448 \
#     --load_llama_config='7b' \
#     --load_checkpoint="params::/mnt/default/lwm/data/checkpoints/lwm_checkpoints/params" \
#     --update_llama_config="dict(theta=50000000,max_sequence_length=2048,use_flash_attention=True,scan_attention=True,scan_query_chunk_size=512,scan_key_chunk_size=1024,remat_attention='nothing_saveable',scan_mlp=True,scan_mlp_chunk_size=8192,remat_mlp='nothing_saveable',remat_block='nothing_saveable',scan_layers=True)" \
#     --tokenizer.vocab_file="$llama_tokenizer_path" \
#     --optimizer.type='adamw' \
#     --optimizer.accumulate_gradient_steps=1 \
#     --optimizer.adamw_optimizer.weight_decay=0 \
#     --optimizer.adamw_optimizer.lr=2e-5 \
#     --optimizer.adamw_optimizer.end_lr=2e-5 \
#     --optimizer.adamw_optimizer.lr_warmup_steps=0 \
#     --optimizer.adamw_optimizer.lr_decay_steps=36450 \
#     --use_data_sharded_loader=True \
#     --train_dataset.type='json_vision' \
#     --train_dataset.vision_text_processor.fields_from_example='fields' \
#     --train_dataset.vision_text_processor.max_n_frames=1 \
#     --train_dataset.json_vision_dataset.mode="pad" \
#     --train_dataset.json_vision_dataset.path="$dataset_path" \
#     --train_dataset.json_vision_dataset.seq_length=384 \
#     --train_dataset.json_vision_dataset.batch_size=128 \
#     --train_dataset.json_vision_dataset.tokenizer_processes=1 \
#     --train_dataset.json_vision_dataset.tokenizer_parallel_chunk_size=128 \
#     --train_dataset.json_vision_dataset.tokenizer_parallel_batch_size=128 \
#     --train_dataset.json_vision_dataset.use_data_sharded_loader=True \
#     --checkpointer.save_optimizer_state=False \
#     --autoresume=False \
#     --logger.append_uuid=False \
#     --logger.online=False \
#     --logger.project_id="$project_id" \
#     --logger.experiment_id="$experiment_id" \
#     --logger.experiment_note="$experiment_note" \
#     --logger.output_dir="$output_dir" \
#     --logger.wandb_dir="$HOME/experiment_output/$project_id"





# --- Run: lang_table_vpt_real_whole -----------------------------------------
# Trains the 7B LWM config (vision,text modalities) from the base LWM params on
# the real Language-Table VPT jsonl, launching `python3 -m lwm.train` from the
# project root (the parent of this script's directory).

export SCRIPT_DIR="$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
export PROJECT_DIR="$( cd -- "$( dirname -- "$SCRIPT_DIR" )" &> /dev/null && pwd )"
# Refuse to continue if the project root could not be resolved or entered;
# otherwise the long training job would start from an arbitrary directory.
cd -- "${PROJECT_DIR:?could not resolve PROJECT_DIR}" || exit 1
export PYTHONPATH="$PYTHONPATH:$PROJECT_DIR"
export LIBTPU_INIT_ARGS="--xla_tpu_megacore_fusion_allow_ags=false --xla_enable_async_collective_permute=true --xla_tpu_enable_ag_backward_pipelining=true --xla_tpu_enable_data_parallel_all_reduce_opt=true --xla_tpu_data_parallel_opt_different_sized_ops=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true"


export llama_tokenizer_path="/mnt/default/lwm/data/checkpoints/lwm_checkpoints/tokenizer.model"
export dataset_path='/mnt/default/lwm/data/unipi/vpt_preprocessed_real/lang_table_whole_real_total_vpt.jsonl'

export output_dir='/mnt/default/lwm/data/unipi/lwm_vpt'
# The W&B API key is a secret: it must be supplied by the caller's environment
# (export it before running, or use `wandb login`), never committed here.
# Any key previously hard-coded in this file should be revoked/rotated.
if [ -z "${WANDB_API_KEY:-}" ]; then
    printf 'warning: WANDB_API_KEY is not set in the environment\n' >&2
fi


export project_id='lwm'
export experiment_note='world-model'
# export experiment_id='sthv2_bf16_from_whole_delta8_batch128_seq4_layer8_100K_re2'
export experiment_id='lang_table_vpt_real_whole'


# mesh_dim: dp, fsdp, tp, sp
# NOTE(review): lr_decay_steps (36450) is larger than total_steps (13765); it
# appears to be copied from the 36450-step sim run. With lr == end_lr the
# schedule is presumably flat either way -- confirm before changing it.
python3 -u -m lwm.train \
    --modality='vision,text' \
    --mesh_dim='!-1,8,1,1' \
    --dtype='bf16' \
    --total_steps=13765 \
    --log_freq=1 \
    --save_model_freq=0 \
    --save_milestone_freq=13760 \
    --load_llama_config='7b' \
    --load_checkpoint="params::/mnt/default/lwm/data/checkpoints/lwm_checkpoints/params" \
    --update_llama_config="dict(theta=50000000,max_sequence_length=2048,use_flash_attention=True,scan_attention=True,scan_query_chunk_size=512,scan_key_chunk_size=1024,remat_attention='nothing_saveable',scan_mlp=True,scan_mlp_chunk_size=8192,remat_mlp='nothing_saveable',remat_block='nothing_saveable',scan_layers=True)" \
    --tokenizer.vocab_file="$llama_tokenizer_path" \
    --optimizer.type='adamw' \
    --optimizer.accumulate_gradient_steps=1 \
    --optimizer.adamw_optimizer.weight_decay=0 \
    --optimizer.adamw_optimizer.lr=2e-5 \
    --optimizer.adamw_optimizer.end_lr=2e-5 \
    --optimizer.adamw_optimizer.lr_warmup_steps=0 \
    --optimizer.adamw_optimizer.lr_decay_steps=36450 \
    --use_data_sharded_loader=True \
    --train_dataset.type='json_vision' \
    --train_dataset.vision_text_processor.fields_from_example='fields' \
    --train_dataset.vision_text_processor.max_n_frames=1 \
    --train_dataset.json_vision_dataset.mode="pad" \
    --train_dataset.json_vision_dataset.path="$dataset_path" \
    --train_dataset.json_vision_dataset.seq_length=384 \
    --train_dataset.json_vision_dataset.batch_size=128 \
    --train_dataset.json_vision_dataset.tokenizer_processes=1 \
    --train_dataset.json_vision_dataset.tokenizer_parallel_chunk_size=128 \
    --train_dataset.json_vision_dataset.tokenizer_parallel_batch_size=128 \
    --train_dataset.json_vision_dataset.use_data_sharded_loader=True \
    --checkpointer.save_optimizer_state=False \
    --autoresume=False \
    --logger.append_uuid=False \
    --logger.online=False \
    --logger.project_id="$project_id" \
    --logger.experiment_id="$experiment_id" \
    --logger.experiment_note="$experiment_note" \
    --logger.output_dir="$output_dir" \
    --logger.wandb_dir="$HOME/experiment_output/$project_id"





# export SCRIPT_DIR="$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
# export PROJECT_DIR="$( cd -- "$( dirname -- "$SCRIPT_DIR" )" &> /dev/null && pwd )"
# cd $PROJECT_DIR
# export PYTHONPATH="$PYTHONPATH:$PROJECT_DIR"
# export LIBTPU_INIT_ARGS="--xla_tpu_megacore_fusion_allow_ags=false --xla_enable_async_collective_permute=true --xla_tpu_enable_ag_backward_pipelining=true --xla_tpu_enable_data_parallel_all_reduce_opt=true --xla_tpu_data_parallel_opt_different_sized_ops=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true"


# export llama_tokenizer_path="/mnt/default/lwm/data/checkpoints/lwm_checkpoints/tokenizer.model"

# export dataset_path='/mnt/default/lwm/data/lang_table/lang_table_whole_ver1_processed_1000traj.jsonl'

# export output_dir='/mnt/default/lwm/data/unipi/lwm_vpt'
# # WANDB_API_KEY must come from the caller's environment; never hard-code the key here.


# export project_id='lwm'
# export experiment_note='world-model'
# # export experiment_id='sthv2_bf16_from_whole_delta8_batch128_seq4_layer8_100K_re2'
# export experiment_id='lang_table_vpt_sim_finetuned_1000'


# # mesh_dim: dp, fsdp, tp, sp
# python3 -u -m lwm.train \
#     --modality='vision,text' \
#     --mesh_dim='!-1,8,1,1' \
#     --dtype='bf16' \
#     --total_steps=505 \
#     --log_freq=1 \
#     --save_model_freq=0 \
#     --save_milestone_freq=500 \
#     --load_llama_config='7b' \
#     --load_checkpoint="params::/mnt/default/lwm/data/unipi/lwm_vpt/lang_table_vpt_sim_whole/streaming_params_36448" \
#     --update_llama_config="dict(theta=50000000,max_sequence_length=2048,use_flash_attention=True,scan_attention=True,scan_query_chunk_size=512,scan_key_chunk_size=1024,remat_attention='nothing_saveable',scan_mlp=True,scan_mlp_chunk_size=8192,remat_mlp='nothing_saveable',remat_block='nothing_saveable',scan_layers=True)" \
#     --tokenizer.vocab_file="$llama_tokenizer_path" \
#     --optimizer.type='adamw' \
#     --optimizer.accumulate_gradient_steps=1 \
#     --optimizer.adamw_optimizer.weight_decay=0 \
#     --optimizer.adamw_optimizer.lr=2e-5 \
#     --optimizer.adamw_optimizer.end_lr=2e-5 \
#     --optimizer.adamw_optimizer.lr_warmup_steps=0 \
#     --optimizer.adamw_optimizer.lr_decay_steps=36450 \
#     --use_data_sharded_loader=True \
#     --train_dataset.type='json_vision' \
#     --train_dataset.vision_text_processor.fields_from_example='fields' \
#     --train_dataset.vision_text_processor.max_n_frames=1 \
#     --train_dataset.json_vision_dataset.mode="pad" \
#     --train_dataset.json_vision_dataset.path="$dataset_path" \
#     --train_dataset.json_vision_dataset.seq_length=384 \
#     --train_dataset.json_vision_dataset.batch_size=128 \
#     --train_dataset.json_vision_dataset.tokenizer_processes=1 \
#     --train_dataset.json_vision_dataset.tokenizer_parallel_chunk_size=128 \
#     --train_dataset.json_vision_dataset.tokenizer_parallel_batch_size=128 \
#     --train_dataset.json_vision_dataset.use_data_sharded_loader=True \
#     --checkpointer.save_optimizer_state=False \
#     --autoresume=False \
#     --logger.append_uuid=False \
#     --logger.online=False \
#     --logger.project_id="$project_id" \
#     --logger.experiment_id="$experiment_id" \
#     --logger.experiment_note="$experiment_note" \
#     --logger.output_dir="$output_dir" \
#     --logger.wandb_dir="$HOME/experiment_output/$project_id"





# export SCRIPT_DIR="$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
# export PROJECT_DIR="$( cd -- "$( dirname -- "$SCRIPT_DIR" )" &> /dev/null && pwd )"
# cd $PROJECT_DIR
# export PYTHONPATH="$PYTHONPATH:$PROJECT_DIR"
# export LIBTPU_INIT_ARGS="--xla_tpu_megacore_fusion_allow_ags=false --xla_enable_async_collective_permute=true --xla_tpu_enable_ag_backward_pipelining=true --xla_tpu_enable_data_parallel_all_reduce_opt=true --xla_tpu_data_parallel_opt_different_sized_ops=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true"


# export llama_tokenizer_path="/mnt/default/lwm/data/checkpoints/lwm_checkpoints/tokenizer.model"

# export dataset_path='/mnt/default/lwm/data/lang_table/lang_table_whole_ver1_processed_1000traj.jsonl'

# export output_dir='/mnt/default/lwm/data/unipi/lwm_vpt'
# # WANDB_API_KEY must come from the caller's environment; never hard-code the key here.


# export project_id='lwm'
# export experiment_note='world-model'
# # export experiment_id='sthv2_bf16_from_whole_delta8_batch128_seq4_layer8_100K_re2'
# export experiment_id='lang_table_vpt_real_finetuned_1000'


# # mesh_dim: dp, fsdp, tp, sp
# python3 -u -m lwm.train \
#     --modality='vision,text' \
#     --mesh_dim='!-1,8,1,1' \
#     --dtype='bf16' \
#     --total_steps=505 \
#     --log_freq=1 \
#     --save_model_freq=0 \
#     --save_milestone_freq=500 \
#     --load_llama_config='7b' \
#     --load_checkpoint="params::/mnt/default/lwm/data/unipi/lwm_vpt/lang_table_vpt_real_whole/streaming_params_13760" \
#     --update_llama_config="dict(theta=50000000,max_sequence_length=2048,use_flash_attention=True,scan_attention=True,scan_query_chunk_size=512,scan_key_chunk_size=1024,remat_attention='nothing_saveable',scan_mlp=True,scan_mlp_chunk_size=8192,remat_mlp='nothing_saveable',remat_block='nothing_saveable',scan_layers=True)" \
#     --tokenizer.vocab_file="$llama_tokenizer_path" \
#     --optimizer.type='adamw' \
#     --optimizer.accumulate_gradient_steps=1 \
#     --optimizer.adamw_optimizer.weight_decay=0 \
#     --optimizer.adamw_optimizer.lr=2e-5 \
#     --optimizer.adamw_optimizer.end_lr=2e-5 \
#     --optimizer.adamw_optimizer.lr_warmup_steps=0 \
#     --optimizer.adamw_optimizer.lr_decay_steps=36450 \
#     --use_data_sharded_loader=True \
#     --train_dataset.type='json_vision' \
#     --train_dataset.vision_text_processor.fields_from_example='fields' \
#     --train_dataset.vision_text_processor.max_n_frames=1 \
#     --train_dataset.json_vision_dataset.mode="pad" \
#     --train_dataset.json_vision_dataset.path="$dataset_path" \
#     --train_dataset.json_vision_dataset.seq_length=384 \
#     --train_dataset.json_vision_dataset.batch_size=128 \
#     --train_dataset.json_vision_dataset.tokenizer_processes=1 \
#     --train_dataset.json_vision_dataset.tokenizer_parallel_chunk_size=128 \
#     --train_dataset.json_vision_dataset.tokenizer_parallel_batch_size=128 \
#     --train_dataset.json_vision_dataset.use_data_sharded_loader=True \
#     --checkpointer.save_optimizer_state=False \
#     --autoresume=False \
#     --logger.append_uuid=False \
#     --logger.online=False \
#     --logger.project_id="$project_id" \
#     --logger.experiment_id="$experiment_id" \
#     --logger.experiment_note="$experiment_note" \
#     --logger.output_dir="$output_dir" \
#     --logger.wandb_dir="$HOME/experiment_output/$project_id"




# export SCRIPT_DIR="$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
# export PROJECT_DIR="$( cd -- "$( dirname -- "$SCRIPT_DIR" )" &> /dev/null && pwd )"
# cd $PROJECT_DIR
# export PYTHONPATH="$PYTHONPATH:$PROJECT_DIR"
# export LIBTPU_INIT_ARGS="--xla_tpu_megacore_fusion_allow_ags=false --xla_enable_async_collective_permute=true --xla_tpu_enable_ag_backward_pipelining=true --xla_tpu_enable_data_parallel_all_reduce_opt=true --xla_tpu_data_parallel_opt_different_sized_ops=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true"


# export llama_tokenizer_path="/mnt/default/lwm/data/checkpoints/lwm_checkpoints/tokenizer.model"

# export dataset_path='/mnt/default/lwm/data/lang_table/lang_table_whole_ver1_processed.jsonl'

# export output_dir='/mnt/default/lwm/data/unipi/lwm_vpt'
# # WANDB_API_KEY must come from the caller's environment; never hard-code the key here.


# export project_id='lwm'
# export experiment_note='world-model'
# # export experiment_id='sthv2_bf16_from_whole_delta8_batch128_seq4_layer8_100K_re2'
# export experiment_id='lang_table_action_real_whole'


# # mesh_dim: dp, fsdp, tp, sp
# python3 -u -m lwm.train \
#     --modality='vision,text' \
#     --mesh_dim='!-1,8,1,1' \
#     --dtype='bf16' \
#     --total_steps=13761 \
#     --log_freq=1 \
#     --save_model_freq=0 \
#     --save_milestone_freq=13760 \
#     --load_llama_config='7b' \
#     --load_checkpoint="params::/mnt/default/lwm/data/checkpoints/lwm_checkpoints/params" \
#     --update_llama_config="dict(theta=50000000,max_sequence_length=2048,use_flash_attention=True,scan_attention=True,scan_query_chunk_size=512,scan_key_chunk_size=1024,remat_attention='nothing_saveable',scan_mlp=True,scan_mlp_chunk_size=8192,remat_mlp='nothing_saveable',remat_block='nothing_saveable',scan_layers=True)" \
#     --tokenizer.vocab_file="$llama_tokenizer_path" \
#     --optimizer.type='adamw' \
#     --optimizer.accumulate_gradient_steps=1 \
#     --optimizer.adamw_optimizer.weight_decay=0 \
#     --optimizer.adamw_optimizer.lr=2e-5 \
#     --optimizer.adamw_optimizer.end_lr=2e-5 \
#     --optimizer.adamw_optimizer.lr_warmup_steps=0 \
#     --optimizer.adamw_optimizer.lr_decay_steps=36450 \
#     --use_data_sharded_loader=True \
#     --train_dataset.type='json_vision' \
#     --train_dataset.vision_text_processor.fields_from_example='fields' \
#     --train_dataset.vision_text_processor.max_n_frames=1 \
#     --train_dataset.json_vision_dataset.mode="pad" \
#     --train_dataset.json_vision_dataset.path="$dataset_path" \
#     --train_dataset.json_vision_dataset.seq_length=384 \
#     --train_dataset.json_vision_dataset.batch_size=128 \
#     --train_dataset.json_vision_dataset.tokenizer_processes=1 \
#     --train_dataset.json_vision_dataset.tokenizer_parallel_chunk_size=128 \
#     --train_dataset.json_vision_dataset.tokenizer_parallel_batch_size=128 \
#     --train_dataset.json_vision_dataset.use_data_sharded_loader=True \
#     --checkpointer.save_optimizer_state=False \
#     --autoresume=False \
#     --logger.append_uuid=False \
#     --logger.online=False \
#     --logger.project_id="$project_id" \
#     --logger.experiment_id="$experiment_id" \
#     --logger.experiment_note="$experiment_note" \
#     --logger.output_dir="$output_dir" \
#     --logger.wandb_dir="$HOME/experiment_output/$project_id"



# export SCRIPT_DIR="$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
# export PROJECT_DIR="$( cd -- "$( dirname -- "$SCRIPT_DIR" )" &> /dev/null && pwd )"
# cd $PROJECT_DIR
# export PYTHONPATH="$PYTHONPATH:$PROJECT_DIR"
# export LIBTPU_INIT_ARGS="--xla_tpu_megacore_fusion_allow_ags=false --xla_enable_async_collective_permute=true --xla_tpu_enable_ag_backward_pipelining=true --xla_tpu_enable_data_parallel_all_reduce_opt=true --xla_tpu_data_parallel_opt_different_sized_ops=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true"


# export llama_tokenizer_path="/mnt/default/lwm/data/checkpoints/lwm_checkpoints/tokenizer.model"

# export dataset_path='/mnt/default/lwm/data/lang_table/lang_table_whole_ver1_processed_1000traj.jsonl'

# export output_dir='/mnt/default/lwm/data/unipi/lwm_vpt'
# # WANDB_API_KEY must come from the caller's environment; never hard-code the key here.


# export project_id='lwm'
# export experiment_note='world-model'
# # export experiment_id='sthv2_bf16_from_whole_delta8_batch128_seq4_layer8_100K_re2'
# export experiment_id='lang_table_action_real_finetuned_1000_sim'


# # mesh_dim: dp, fsdp, tp, sp
# python3 -u -m lwm.train \
#     --modality='vision,text' \
#     --mesh_dim='!-1,8,1,1' \
#     --dtype='bf16' \
#     --total_steps=505 \
#     --log_freq=1 \
#     --save_model_freq=0 \
#     --save_milestone_freq=500 \
#     --load_llama_config='7b' \
#     --load_checkpoint="params::/mnt/default/lwm/data/unipi/lwm_vpt/lang_table_action_real_whole/streaming_params_13760" \
#     --update_llama_config="dict(theta=50000000,max_sequence_length=2048,use_flash_attention=True,scan_attention=True,scan_query_chunk_size=512,scan_key_chunk_size=1024,remat_attention='nothing_saveable',scan_mlp=True,scan_mlp_chunk_size=8192,remat_mlp='nothing_saveable',remat_block='nothing_saveable',scan_layers=True)" \
#     --tokenizer.vocab_file="$llama_tokenizer_path" \
#     --optimizer.type='adamw' \
#     --optimizer.accumulate_gradient_steps=1 \
#     --optimizer.adamw_optimizer.weight_decay=0 \
#     --optimizer.adamw_optimizer.lr=2e-5 \
#     --optimizer.adamw_optimizer.end_lr=2e-5 \
#     --optimizer.adamw_optimizer.lr_warmup_steps=0 \
#     --optimizer.adamw_optimizer.lr_decay_steps=36450 \
#     --use_data_sharded_loader=True \
#     --train_dataset.type='json_vision' \
#     --train_dataset.vision_text_processor.fields_from_example='fields' \
#     --train_dataset.vision_text_processor.max_n_frames=1 \
#     --train_dataset.json_vision_dataset.mode="pad" \
#     --train_dataset.json_vision_dataset.path="$dataset_path" \
#     --train_dataset.json_vision_dataset.seq_length=384 \
#     --train_dataset.json_vision_dataset.batch_size=128 \
#     --train_dataset.json_vision_dataset.tokenizer_processes=1 \
#     --train_dataset.json_vision_dataset.tokenizer_parallel_chunk_size=128 \
#     --train_dataset.json_vision_dataset.tokenizer_parallel_batch_size=128 \
#     --train_dataset.json_vision_dataset.use_data_sharded_loader=True \
#     --checkpointer.save_optimizer_state=False \
#     --autoresume=False \
#     --logger.append_uuid=False \
#     --logger.online=False \
#     --logger.project_id="$project_id" \
#     --logger.experiment_id="$experiment_id" \
#     --logger.experiment_note="$experiment_note" \
#     --logger.output_dir="$output_dir" \
#     --logger.wandb_dir="$HOME/experiment_output/$project_id"




# export SCRIPT_DIR="$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
# export PROJECT_DIR="$( cd -- "$( dirname -- "$SCRIPT_DIR" )" &> /dev/null && pwd )"
# cd $PROJECT_DIR
# export PYTHONPATH="$PYTHONPATH:$PROJECT_DIR"
# export LIBTPU_INIT_ARGS="--xla_tpu_megacore_fusion_allow_ags=false --xla_enable_async_collective_permute=true --xla_tpu_enable_ag_backward_pipelining=true --xla_tpu_enable_data_parallel_all_reduce_opt=true --xla_tpu_data_parallel_opt_different_sized_ops=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true"


# export llama_tokenizer_path="/root/checkpoints/lwm_checkpoints/tokenizer.model"
# export dataset_path='/root/World-Model/vpt_preprocessed/lang_table_whole_sim_sep_total_vpt_action.jsonl'

# export output_dir='/root/data/unipi/lwm_vpt'
# # WANDB_API_KEY must come from the caller's environment; never hard-code the key here.


# export project_id='lwm'
# export experiment_note='world-model'
# # export experiment_id='sthv2_bf16_from_whole_delta8_batch128_seq4_layer8_100K_re2'
# export experiment_id='lang_table_vpt_sim_sep_whole'


# # mesh_dim: dp, fsdp, tp, sp
# CUDA_VISIBLE_DEVICES=4,5,6,7 python3 -u -m lwm.train \
#     --modality='vision,text' \
#     --mesh_dim='!-1,4,1,1' \
#     --dtype='bf16' \
#     --total_steps=36450 \
#     --log_freq=1 \
#     --eval_steps=0 \
#     --save_model_freq=36448 \
#     --save_milestone_freq=36448 \
#     --load_llama_config='7b' \
#     --load_checkpoint="params::/root/checkpoints/lwm_checkpoints/params" \
#     --update_llama_config="dict(theta=50000000,max_sequence_length=2048,use_flash_attention=True,scan_attention=True,scan_query_chunk_size=512,scan_key_chunk_size=1024,remat_attention='nothing_saveable',scan_mlp=True,scan_mlp_chunk_size=8192,remat_mlp='nothing_saveable',remat_block='nothing_saveable',scan_layers=True)" \
#     --tokenizer.vocab_file="$llama_tokenizer_path" \
#     --optimizer.type='adamw' \
#     --optimizer.accumulate_gradient_steps=1 \
#     --optimizer.adamw_optimizer.weight_decay=0 \
#     --optimizer.adamw_optimizer.lr=2e-5 \
#     --optimizer.adamw_optimizer.end_lr=2e-5 \
#     --optimizer.adamw_optimizer.lr_warmup_steps=0 \
#     --optimizer.adamw_optimizer.lr_decay_steps=36450 \
#     --use_data_sharded_loader=True \
#     --train_dataset.type='json_vision' \
#     --train_dataset.vision_text_processor.fields_from_example='fields' \
#     --train_dataset.vision_text_processor.max_n_frames=1 \
#     --train_dataset.json_vision_dataset.mode="pad" \
#     --train_dataset.json_vision_dataset.path="$dataset_path" \
#     --train_dataset.json_vision_dataset.seq_length=384 \
#     --train_dataset.json_vision_dataset.batch_size=128 \
#     --train_dataset.json_vision_dataset.tokenizer_processes=1 \
#     --train_dataset.json_vision_dataset.tokenizer_parallel_chunk_size=128 \
#     --train_dataset.json_vision_dataset.tokenizer_parallel_batch_size=128 \
#     --train_dataset.json_vision_dataset.use_data_sharded_loader=True \
#     --checkpointer.save_optimizer_state=False \
#     --autoresume=False \
#     --logger.append_uuid=False \
#     --logger.online=False \
#     --logger.project_id="$project_id" \
#     --logger.experiment_id="$experiment_id" \
#     --logger.experiment_note="$experiment_note" \
#     --logger.output_dir="$output_dir" \
#     --logger.wandb_dir="$HOME/experiment_output/$project_id"




# ---------------------------------------------------------------------------
# Fine-tune LWM (modalities: vision,text) on the lang_table "separate whole"
# JSONL dataset, initialised from the streaming params saved at step 36448 of
# the lang_table_vpt_sim_sep_whole run.
#
# Required environment:
#   WANDB_API_KEY  Weights & Biases API key. It is needed because the logger
#                  runs with --logger.online=True. It is intentionally NOT
#                  hard-coded in this file.
# ---------------------------------------------------------------------------

export SCRIPT_DIR="$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
export PROJECT_DIR="$( cd -- "$( dirname -- "$SCRIPT_DIR" )" &> /dev/null && pwd )"

# Run from the project root (the parent of this script's directory) so that
# `python -m lwm.train` resolves. Abort instead of silently continuing:
# an unquoted `cd $PROJECT_DIR` with an empty value would jump to $HOME.
if [[ -z "$PROJECT_DIR" ]] || ! cd -- "$PROJECT_DIR"; then
  printf 'error: could not resolve/enter project directory (PROJECT_DIR=%q)\n' "$PROJECT_DIR" >&2
  exit 1
fi

# Append the project root without leaving an empty leading entry (which
# Python treats as the current directory) when PYTHONPATH was unset/empty.
export PYTHONPATH="${PYTHONPATH:+$PYTHONPATH:}$PROJECT_DIR"
export LIBTPU_INIT_ARGS="--xla_tpu_megacore_fusion_allow_ags=false --xla_enable_async_collective_permute=true --xla_tpu_enable_ag_backward_pipelining=true --xla_tpu_enable_data_parallel_all_reduce_opt=true --xla_tpu_data_parallel_opt_different_sized_ops=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true"


export llama_tokenizer_path="/root/checkpoints/lwm_checkpoints/tokenizer.model"
export dataset_path='/root/data/data_0919/data/lang_table_separate_whole_shuffled.jsonl'

export output_dir='/root/data/unipi/lwm_vpt'

# Secrets must not live in source control. A key that was previously committed
# here should be treated as leaked and rotated. Fail fast if it is missing,
# since online logging is enabled below.
: "${WANDB_API_KEY:?WANDB_API_KEY must be set in the environment (--logger.online=True)}"
export WANDB_API_KEY


export project_id='lwm'
export experiment_note='world-model'
export experiment_id='lang_table_vpt_sim_sep_whole_finetuned'

# NOTE(review): lr_decay_steps=36450 is carried over from the 36448-step base
# run, while total_steps here is 1303. With lr == end_lr and no warmup, the
# schedule is presumably constant, so this is likely inert. Confirm against
# the lwm optimizer before changing either value.
# mesh_dim: dp, fsdp, tp, sp
CUDA_VISIBLE_DEVICES=4,5,6,7 python3 -u -m lwm.train \
    --modality='vision,text' \
    --mesh_dim='!-1,4,1,1' \
    --dtype='bf16' \
    --total_steps=1303 \
    --log_freq=1 \
    --eval_steps=0 \
    --save_model_freq=0 \
    --save_milestone_freq=1300 \
    --load_llama_config='7b' \
    --load_checkpoint="params::/root/data/unipi/lwm_vpt/lang_table_vpt_sim_sep_whole/streaming_params_36448" \
    --update_llama_config="dict(theta=50000000,max_sequence_length=2048,use_flash_attention=True,scan_attention=True,scan_query_chunk_size=512,scan_key_chunk_size=1024,remat_attention='nothing_saveable',scan_mlp=True,scan_mlp_chunk_size=8192,remat_mlp='nothing_saveable',remat_block='nothing_saveable',scan_layers=True)" \
    --tokenizer.vocab_file="$llama_tokenizer_path" \
    --optimizer.type='adamw' \
    --optimizer.accumulate_gradient_steps=1 \
    --optimizer.adamw_optimizer.weight_decay=0 \
    --optimizer.adamw_optimizer.lr=2e-5 \
    --optimizer.adamw_optimizer.end_lr=2e-5 \
    --optimizer.adamw_optimizer.lr_warmup_steps=0 \
    --optimizer.adamw_optimizer.lr_decay_steps=36450 \
    --use_data_sharded_loader=True \
    --train_dataset.type='json_vision' \
    --train_dataset.vision_text_processor.fields_from_example='fields' \
    --train_dataset.vision_text_processor.max_n_frames=1 \
    --train_dataset.json_vision_dataset.mode="pad" \
    --train_dataset.json_vision_dataset.path="$dataset_path" \
    --train_dataset.json_vision_dataset.seq_length=384 \
    --train_dataset.json_vision_dataset.batch_size=128 \
    --train_dataset.json_vision_dataset.tokenizer_processes=1 \
    --train_dataset.json_vision_dataset.tokenizer_parallel_chunk_size=128 \
    --train_dataset.json_vision_dataset.tokenizer_parallel_batch_size=128 \
    --train_dataset.json_vision_dataset.use_data_sharded_loader=True \
    --checkpointer.save_optimizer_state=False \
    --autoresume=False \
    --logger.append_uuid=False \
    --logger.online=True \
    --logger.project_id="$project_id" \
    --logger.experiment_id="$experiment_id" \
    --logger.experiment_note="$experiment_note" \
    --logger.output_dir="$output_dir" \
    --logger.wandb_dir="$HOME/experiment_output/$project_id"



