# export SCRIPT_DIR="$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
# export PROJECT_DIR="$( cd -- "$( dirname -- "$SCRIPT_DIR" )" &> /dev/null && pwd )"
# cd $PROJECT_DIR
# export PYTHONPATH="$PYTHONPATH:$PROJECT_DIR"
# export LIBTPU_INIT_ARGS="--xla_tpu_megacore_fusion_allow_ags=false --xla_enable_async_collective_permute=true --xla_tpu_enable_ag_backward_pipelining=true --xla_tpu_enable_data_parallel_all_reduce_opt=true --xla_tpu_data_parallel_opt_different_sized_ops=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true"

# export llama_tokenizer_path="llava/checkpoints/lwm_checkpoints/tokenizer.model"
# export dataset_path="/home/World-Model/data/bridge_cross_attn_nsvq_code32_layer8_seq1_train.jsonl"
# export eval_dataset_path="/home/World-Model/data/bridge_cross_attn_nsvq_code32_layer8_seq1_val.jsonl"

# export output_dir="llava/checkpoints/debug"

# export project_id='lwm'
# export experiment_note='world-model'
# export experiment_id='bridge_cross_attn_nsvq_code32_layer8_seq1'

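# # mesh_dim axis order for --mesh_dim: dp (data parallel), fsdp (fully sharded data parallel), tp (tensor parallel), sp (sequence parallel)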
# CUDA_VISIBLE_DEVICES=4,5,6,7 python3 -u -m lwm.train \
#     --modality='vision,text,delta' \
#     --mesh_dim='!-1,4,1,1' \
#     --dtype='bf16' \
#     --total_steps=15701 \
#     --log_freq=1 \
#     --delta_tokens=1 \
#     --eval_steps=1 \
#     --save_model_freq=15700 \
#     --eval_log_freq=100 \
#     --save_milestone_freq=15700 \
#     --load_llama_config='7b' \
#     --load_checkpoint="params::/home/World-Model/llava/checkpoints/lwm_checkpoints/params"\
#     --update_llama_config="dict(delta_vocab_size=32,theta=50000000,max_sequence_length=2048,use_flash_attention=True,scan_attention=True,scan_query_chunk_size=512,scan_key_chunk_size=1024,remat_attention='nothing_saveable',scan_mlp=True,scan_mlp_chunk_size=8192,remat_mlp='nothing_saveable',remat_block='nothing_saveable',scan_layers=True)" \
#     --tokenizer.vocab_file="$llama_tokenizer_path" \
#     --optimizer.type='adamw' \
#     --llama.delta_vocab_size=32 \
#     --optimizer.accumulate_gradient_steps=1 \
#     --optimizer.adamw_optimizer.weight_decay=0 \
#     --optimizer.adamw_optimizer.lr=2e-5 \
#     --optimizer.adamw_optimizer.end_lr=0 \
#     --optimizer.adamw_optimizer.lr_warmup_steps=1000 \
#     --optimizer.adamw_optimizer.lr_decay_steps=15700 \
#     --use_data_sharded_loader=True \
#     --train_dataset.type='json_vision_delta' \
#     --train_dataset.delta_vision_text_processor.fields_from_example='fields' \
#     --train_dataset.delta_vision_text_processor.n_tokens_per_delta=1 \
#     --train_dataset.delta_vision_text_processor.max_n_frames=1 \
#     --train_dataset.json_delta_dataset.mode="pad" \
#     --train_dataset.json_delta_dataset.path="$dataset_path" \
#     --train_dataset.json_delta_dataset.seq_length=384 \
#     --train_dataset.json_delta_dataset.batch_size=128 \
#     --train_dataset.json_delta_dataset.tokenizer_processes=1 \
#     --train_dataset.json_delta_dataset.tokenizer_parallel_chunk_size=128 \
#     --train_dataset.json_delta_dataset.tokenizer_parallel_batch_size=128 \
#     --train_dataset.json_delta_dataset.use_data_sharded_loader=True \
#     --eval_dataset.type='json_vision_delta' \
#     --eval_dataset.delta_vision_text_processor.fields_from_example='fields' \
#     --eval_dataset.delta_vision_text_processor.n_tokens_per_delta=1 \
#     --eval_dataset.delta_vision_text_processor.max_n_frames=1 \
#     --eval_dataset.json_delta_dataset.mode="pad" \
#     --eval_dataset.json_delta_dataset.path="$eval_dataset_path" \
#     --eval_dataset.json_delta_dataset.seq_length=384 \
#     --eval_dataset.json_delta_dataset.batch_size=128 \
#     --eval_dataset.json_delta_dataset.tokenizer_processes=1 \
#     --eval_dataset.json_delta_dataset.tokenizer_parallel_chunk_size=128 \
#     --eval_dataset.json_delta_dataset.tokenizer_parallel_batch_size=128 \
#     --eval_dataset.json_delta_dataset.use_data_sharded_loader=True \
#     --checkpointer.save_optimizer_state=False \
#     --autoresume=False \
#     --logger.append_uuid=False \
#     --logger.online=True \
#     --logger.project_id="$project_id" \
#     --logger.experiment_id="$experiment_id" \
#     --logger.experiment_note="$experiment_note" \
#     --logger.output_dir="$output_dir" \
#     --logger.wandb_dir="$HOME/experiment_output/$project_id"




# export SCRIPT_DIR="$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
# export PROJECT_DIR="$( cd -- "$( dirname -- "$SCRIPT_DIR" )" &> /dev/null && pwd )"
# cd $PROJECT_DIR
# export PYTHONPATH="$PYTHONPATH:$PROJECT_DIR"
# export LIBTPU_INIT_ARGS="--xla_tpu_megacore_fusion_allow_ags=false --xla_enable_async_collective_permute=true --xla_tpu_enable_ag_backward_pipelining=true --xla_tpu_enable_data_parallel_all_reduce_opt=true --xla_tpu_data_parallel_opt_different_sized_ops=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true"

# export llama_tokenizer_path="llava/checkpoints/lwm_checkpoints/tokenizer.model"
# export dataset_path="/home/World-Model/data/bridge_cross_attn_nsvq_code2_layer8_seq4_train.jsonl"
# export eval_dataset_path="/home/World-Model/data/bridge_cross_attn_nsvq_code2_layer8_seq4_val.jsonl"

# export output_dir="llava/checkpoints/debug"

# export project_id='lwm'
# export experiment_note='world-model'
# export experiment_id='bridge_cross_attn_nsvq_code2_layer8_seq4'

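# # mesh_dim axis order for --mesh_dim: dp (data parallel), fsdp (fully sharded data parallel), tp (tensor parallel), sp (sequence parallel)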
# CUDA_VISIBLE_DEVICES=4,5,6,7 python3 -u -m lwm.train \
#     --modality='vision,text,delta' \
#     --mesh_dim='!-1,4,1,1' \
#     --dtype='bf16' \
#     --total_steps=15701 \
#     --log_freq=1 \
#     --delta_tokens=4 \
#     --eval_steps=1 \
#     --save_model_freq=15700 \
#     --eval_log_freq=100 \
#     --save_milestone_freq=15700 \
#     --load_llama_config='7b' \
#     --load_checkpoint="params::/home/World-Model/llava/checkpoints/lwm_checkpoints/params"\
#     --update_llama_config="dict(delta_vocab_size=2,theta=50000000,max_sequence_length=2048,use_flash_attention=True,scan_attention=True,scan_query_chunk_size=512,scan_key_chunk_size=1024,remat_attention='nothing_saveable',scan_mlp=True,scan_mlp_chunk_size=8192,remat_mlp='nothing_saveable',remat_block='nothing_saveable',scan_layers=True)" \
#     --tokenizer.vocab_file="$llama_tokenizer_path" \
#     --optimizer.type='adamw' \
#     --llama.delta_vocab_size=2 \
#     --optimizer.accumulate_gradient_steps=1 \
#     --optimizer.adamw_optimizer.weight_decay=0 \
#     --optimizer.adamw_optimizer.lr=2e-5 \
#     --optimizer.adamw_optimizer.end_lr=0 \
#     --optimizer.adamw_optimizer.lr_warmup_steps=1000 \
#     --optimizer.adamw_optimizer.lr_decay_steps=15700 \
#     --use_data_sharded_loader=True \
#     --train_dataset.type='json_vision_delta' \
#     --train_dataset.delta_vision_text_processor.fields_from_example='fields' \
#     --train_dataset.delta_vision_text_processor.n_tokens_per_delta=4 \
#     --train_dataset.delta_vision_text_processor.max_n_frames=1 \
#     --train_dataset.json_delta_dataset.mode="pad" \
#     --train_dataset.json_delta_dataset.path="$dataset_path" \
#     --train_dataset.json_delta_dataset.seq_length=384 \
#     --train_dataset.json_delta_dataset.batch_size=128 \
#     --train_dataset.json_delta_dataset.tokenizer_processes=1 \
#     --train_dataset.json_delta_dataset.tokenizer_parallel_chunk_size=128 \
#     --train_dataset.json_delta_dataset.tokenizer_parallel_batch_size=128 \
#     --train_dataset.json_delta_dataset.use_data_sharded_loader=True \
#     --eval_dataset.type='json_vision_delta' \
#     --eval_dataset.delta_vision_text_processor.fields_from_example='fields' \
#     --eval_dataset.delta_vision_text_processor.n_tokens_per_delta=4 \
#     --eval_dataset.delta_vision_text_processor.max_n_frames=1 \
#     --eval_dataset.json_delta_dataset.mode="pad" \
#     --eval_dataset.json_delta_dataset.path="$eval_dataset_path" \
#     --eval_dataset.json_delta_dataset.seq_length=384 \
#     --eval_dataset.json_delta_dataset.batch_size=128 \
#     --eval_dataset.json_delta_dataset.tokenizer_processes=1 \
#     --eval_dataset.json_delta_dataset.tokenizer_parallel_chunk_size=128 \
#     --eval_dataset.json_delta_dataset.tokenizer_parallel_batch_size=128 \
#     --eval_dataset.json_delta_dataset.use_data_sharded_loader=True \
#     --checkpointer.save_optimizer_state=False \
#     --autoresume=False \
#     --logger.append_uuid=False \
#     --logger.online=True \
#     --logger.project_id="$project_id" \
#     --logger.experiment_id="$experiment_id" \
#     --logger.experiment_note="$experiment_note" \
#     --logger.output_dir="$output_dir" \
#     --logger.wandb_dir="$HOME/experiment_output/$project_id"





# export SCRIPT_DIR="$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
# export PROJECT_DIR="$( cd -- "$( dirname -- "$SCRIPT_DIR" )" &> /dev/null && pwd )"
# cd $PROJECT_DIR
# export PYTHONPATH="$PYTHONPATH:$PROJECT_DIR"
# export LIBTPU_INIT_ARGS="--xla_tpu_megacore_fusion_allow_ags=false --xla_enable_async_collective_permute=true --xla_tpu_enable_ag_backward_pipelining=true --xla_tpu_enable_data_parallel_all_reduce_opt=true --xla_tpu_data_parallel_opt_different_sized_ops=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true"

# export llama_tokenizer_path="llava/checkpoints/lwm_checkpoints/tokenizer.model"
# export dataset_path="/home/World-Model/data/bridge_singleview_total_action_processed_filtered_1000traj.jsonl"
# export eval_dataset_path="/home/World-Model/data/bridge_singleview_total_action_processed_val.jsonl"

# export output_dir="llava/checkpoints/debug"

# export project_id='lwm'
# export experiment_note='world-model'
# export experiment_id='lwm_1000traj_bridge_bf16_from_whole_delta8_batch128_seq1_600'

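# # mesh_dim axis order for --mesh_dim: dp (data parallel), fsdp (fully sharded data parallel), tp (tensor parallel), sp (sequence parallel)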
# CUDA_VISIBLE_DEVICES=0,1,2,3 python3 -u -m lwm.train \
#     --modality='vision,text,delta' \
#     --mesh_dim='!-1,4,1,1' \
#     --dtype='bf16' \
#     --total_steps=603 \
#     --log_freq=1 \
#     --delta_tokens=1 \
#     --eval_steps=1 \
#     --save_model_freq=600 \
#     --eval_log_freq=100 \
#     --save_milestone_freq=0 \
#     --load_llama_config='7b' \
#     --load_checkpoint="params::/home/World-Model/llava/checkpoints/debug/bridge_cross_attn_nsvq_code8_layer8_seq1/streaming_params_15700"\
#     --update_llama_config="dict(delta_vocab_size=8,theta=50000000,max_sequence_length=2048,use_flash_attention=True,scan_attention=True,scan_query_chunk_size=512,scan_key_chunk_size=1024,remat_attention='nothing_saveable',scan_mlp=True,scan_mlp_chunk_size=8192,remat_mlp='nothing_saveable',remat_block='nothing_saveable',scan_layers=True)" \
#     --tokenizer.vocab_file="$llama_tokenizer_path" \
#     --optimizer.type='adamw' \
#     --llama.delta_vocab_size=8 \
#     --optimizer.accumulate_gradient_steps=1 \
#     --optimizer.adamw_optimizer.weight_decay=0 \
#     --optimizer.adamw_optimizer.lr=2e-5 \
#     --optimizer.adamw_optimizer.end_lr=2e-5 \
#     --optimizer.adamw_optimizer.lr_warmup_steps=0 \
#     --optimizer.adamw_optimizer.lr_decay_steps=3000 \
#     --use_data_sharded_loader=True \
#     --train_dataset.type='json_vision_delta' \
#     --train_dataset.delta_vision_text_processor.fields_from_example='fields' \
#     --train_dataset.delta_vision_text_processor.n_tokens_per_delta=1 \
#     --train_dataset.delta_vision_text_processor.max_n_frames=1 \
#     --train_dataset.json_delta_dataset.mode="pad" \
#     --train_dataset.json_delta_dataset.path="$dataset_path" \
#     --train_dataset.json_delta_dataset.seq_length=384 \
#     --train_dataset.json_delta_dataset.batch_size=128 \
#     --train_dataset.json_delta_dataset.tokenizer_processes=1 \
#     --train_dataset.json_delta_dataset.tokenizer_parallel_chunk_size=128 \
#     --train_dataset.json_delta_dataset.tokenizer_parallel_batch_size=128 \
#     --train_dataset.json_delta_dataset.use_data_sharded_loader=True \
#     --eval_dataset.type='json_vision_delta' \
#     --eval_dataset.delta_vision_text_processor.fields_from_example='fields' \
#     --eval_dataset.delta_vision_text_processor.n_tokens_per_delta=1 \
#     --eval_dataset.delta_vision_text_processor.max_n_frames=1 \
#     --eval_dataset.json_delta_dataset.mode="pad" \
#     --eval_dataset.json_delta_dataset.path="$eval_dataset_path" \
#     --eval_dataset.json_delta_dataset.seq_length=384 \
#     --eval_dataset.json_delta_dataset.batch_size=128 \
#     --eval_dataset.json_delta_dataset.tokenizer_processes=1 \
#     --eval_dataset.json_delta_dataset.tokenizer_parallel_chunk_size=128 \
#     --eval_dataset.json_delta_dataset.tokenizer_parallel_batch_size=128 \
#     --eval_dataset.json_delta_dataset.use_data_sharded_loader=True \
#     --checkpointer.save_optimizer_state=False \
#     --autoresume=False \
#     --logger.append_uuid=False \
#     --logger.online=True \
#     --logger.project_id="$project_id" \
#     --logger.experiment_id="$experiment_id" \
#     --logger.experiment_note="$experiment_note" \
#     --logger.output_dir="$output_dir" \
#     --logger.wandb_dir="$HOME/experiment_output/$project_id"




# export SCRIPT_DIR="$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
# export PROJECT_DIR="$( cd -- "$( dirname -- "$SCRIPT_DIR" )" &> /dev/null && pwd )"
# cd $PROJECT_DIR
# export PYTHONPATH="$PYTHONPATH:$PROJECT_DIR"
# export LIBTPU_INIT_ARGS="--xla_tpu_megacore_fusion_allow_ags=false --xla_enable_async_collective_permute=true --xla_tpu_enable_ag_backward_pipelining=true --xla_tpu_enable_data_parallel_all_reduce_opt=true --xla_tpu_data_parallel_opt_different_sized_ops=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true"

# export llama_tokenizer_path="llava/checkpoints/lwm_checkpoints/tokenizer.model"
# export dataset_path="/home/World-Model/data/bridge_singleview_total_action_processed_filtered_1000traj.jsonl"
# export eval_dataset_path="/home/World-Model/data/bridge_singleview_total_action_processed_val.jsonl"

# export output_dir="llava/checkpoints/debug"

# export project_id='lwm'
# export experiment_note='world-model'
# export experiment_id='lwm_1000traj_bridge_bf16_from_whole_delta32_batch128_seq1_filtered'

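# # mesh_dim axis order for --mesh_dim: dp (data parallel), fsdp (fully sharded data parallel), tp (tensor parallel), sp (sequence parallel)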
# CUDA_VISIBLE_DEVICES=0,1,2,3 python3 -u -m lwm.train \
#     --modality='vision,text,delta' \
#     --mesh_dim='!-1,4,1,1' \
#     --dtype='bf16' \
#     --total_steps=603 \
#     --log_freq=1 \
#     --delta_tokens=1 \
#     --eval_steps=1 \
#     --save_model_freq=600 \
#     --eval_log_freq=100 \
#     --save_milestone_freq=0 \
#     --load_llama_config='7b' \
#     --load_checkpoint="params::/home/World-Model/llava/checkpoints/debug/bridge_cross_attn_nsvq_code32_layer8_seq1_filtered_epoch3/streaming_params_15287"\
#     --update_llama_config="dict(delta_vocab_size=32,theta=50000000,max_sequence_length=2048,use_flash_attention=True,scan_attention=True,scan_query_chunk_size=512,scan_key_chunk_size=1024,remat_attention='nothing_saveable',scan_mlp=True,scan_mlp_chunk_size=8192,remat_mlp='nothing_saveable',remat_block='nothing_saveable',scan_layers=True)" \
#     --tokenizer.vocab_file="$llama_tokenizer_path" \
#     --optimizer.type='adamw' \
#     --llama.delta_vocab_size=32 \
#     --optimizer.accumulate_gradient_steps=1 \
#     --optimizer.adamw_optimizer.weight_decay=0 \
#     --optimizer.adamw_optimizer.lr=2e-5 \
#     --optimizer.adamw_optimizer.end_lr=2e-5 \
#     --optimizer.adamw_optimizer.lr_warmup_steps=0 \
#     --optimizer.adamw_optimizer.lr_decay_steps=3000 \
#     --use_data_sharded_loader=True \
#     --train_dataset.type='json_vision_delta' \
#     --train_dataset.delta_vision_text_processor.fields_from_example='fields' \
#     --train_dataset.delta_vision_text_processor.n_tokens_per_delta=1 \
#     --train_dataset.delta_vision_text_processor.max_n_frames=1 \
#     --train_dataset.json_delta_dataset.mode="pad" \
#     --train_dataset.json_delta_dataset.path="$dataset_path" \
#     --train_dataset.json_delta_dataset.seq_length=384 \
#     --train_dataset.json_delta_dataset.batch_size=128 \
#     --train_dataset.json_delta_dataset.tokenizer_processes=1 \
#     --train_dataset.json_delta_dataset.tokenizer_parallel_chunk_size=128 \
#     --train_dataset.json_delta_dataset.tokenizer_parallel_batch_size=128 \
#     --train_dataset.json_delta_dataset.use_data_sharded_loader=True \
#     --eval_dataset.type='json_vision_delta' \
#     --eval_dataset.delta_vision_text_processor.fields_from_example='fields' \
#     --eval_dataset.delta_vision_text_processor.n_tokens_per_delta=1 \
#     --eval_dataset.delta_vision_text_processor.max_n_frames=1 \
#     --eval_dataset.json_delta_dataset.mode="pad" \
#     --eval_dataset.json_delta_dataset.path="$eval_dataset_path" \
#     --eval_dataset.json_delta_dataset.seq_length=384 \
#     --eval_dataset.json_delta_dataset.batch_size=128 \
#     --eval_dataset.json_delta_dataset.tokenizer_processes=1 \
#     --eval_dataset.json_delta_dataset.tokenizer_parallel_chunk_size=128 \
#     --eval_dataset.json_delta_dataset.tokenizer_parallel_batch_size=128 \
#     --eval_dataset.json_delta_dataset.use_data_sharded_loader=True \
#     --checkpointer.save_optimizer_state=False \
#     --autoresume=False \
#     --logger.append_uuid=False \
#     --logger.online=True \
#     --logger.project_id="$project_id" \
#     --logger.experiment_id="$experiment_id" \
#     --logger.experiment_note="$experiment_note" \
#     --logger.output_dir="$output_dir" \
#     --logger.wandb_dir="$HOME/experiment_output/$project_id"

# conda deactivate
# conda activate simpler_env
# cd SimplerEnv 
# ./scripts/lwm_bridge.sh 
# cd ..
# conda deactivate
# conda activate lwm



# export SCRIPT_DIR="$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
# export PROJECT_DIR="$( cd -- "$( dirname -- "$SCRIPT_DIR" )" &> /dev/null && pwd )"
# cd $PROJECT_DIR
# export PYTHONPATH="$PYTHONPATH:$PROJECT_DIR"
# export LIBTPU_INIT_ARGS="--xla_tpu_megacore_fusion_allow_ags=false --xla_enable_async_collective_permute=true --xla_tpu_enable_ag_backward_pipelining=true --xla_tpu_enable_data_parallel_all_reduce_opt=true --xla_tpu_data_parallel_opt_different_sized_ops=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true"

# export llama_tokenizer_path="llava/checkpoints/lwm_checkpoints/tokenizer.model"
# export dataset_path="/home/World-Model/data/bridge_singleview_total_action_processed_filtered_1000traj.jsonl"
# export eval_dataset_path="/home/World-Model/data/bridge_singleview_total_action_processed_val.jsonl"

# export output_dir="llava/checkpoints/debug"

# export project_id='lwm'
# export experiment_note='world-model'
# export experiment_id='lwm_1000traj_bridge_bf16_from_whole_delta32_batch128_seq1_filtered_long'

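# # mesh_dim axis order for --mesh_dim: dp (data parallel), fsdp (fully sharded data parallel), tp (tensor parallel), sp (sequence parallel)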
# CUDA_VISIBLE_DEVICES=0,1,2,3 python3 -u -m lwm.train \
#     --modality='vision,text,delta' \
#     --mesh_dim='!-1,4,1,1' \
#     --dtype='bf16' \
#     --total_steps=5005 \
#     --log_freq=1 \
#     --delta_tokens=1 \
#     --eval_steps=1 \
#     --save_model_freq=5000 \
#     --eval_log_freq=100 \
#     --save_milestone_freq=0 \
#     --load_llama_config='7b' \
#     --load_checkpoint="params::/home/World-Model/llava/checkpoints/debug/bridge_cross_attn_nsvq_code32_layer8_seq1_filtered_epoch3/streaming_params_15287"\
#     --update_llama_config="dict(delta_vocab_size=32,theta=50000000,max_sequence_length=2048,use_flash_attention=True,scan_attention=True,scan_query_chunk_size=512,scan_key_chunk_size=1024,remat_attention='nothing_saveable',scan_mlp=True,scan_mlp_chunk_size=8192,remat_mlp='nothing_saveable',remat_block='nothing_saveable',scan_layers=True)" \
#     --tokenizer.vocab_file="$llama_tokenizer_path" \
#     --optimizer.type='adamw' \
#     --llama.delta_vocab_size=32 \
#     --optimizer.accumulate_gradient_steps=1 \
#     --optimizer.adamw_optimizer.weight_decay=0 \
#     --optimizer.adamw_optimizer.lr=2e-5 \
#     --optimizer.adamw_optimizer.end_lr=2e-5 \
#     --optimizer.adamw_optimizer.lr_warmup_steps=0 \
#     --optimizer.adamw_optimizer.lr_decay_steps=3000 \
#     --use_data_sharded_loader=True \
#     --train_dataset.type='json_vision_delta' \
#     --train_dataset.delta_vision_text_processor.fields_from_example='fields' \
#     --train_dataset.delta_vision_text_processor.n_tokens_per_delta=1 \
#     --train_dataset.delta_vision_text_processor.max_n_frames=1 \
#     --train_dataset.json_delta_dataset.mode="pad" \
#     --train_dataset.json_delta_dataset.path="$dataset_path" \
#     --train_dataset.json_delta_dataset.seq_length=384 \
#     --train_dataset.json_delta_dataset.batch_size=128 \
#     --train_dataset.json_delta_dataset.tokenizer_processes=1 \
#     --train_dataset.json_delta_dataset.tokenizer_parallel_chunk_size=128 \
#     --train_dataset.json_delta_dataset.tokenizer_parallel_batch_size=128 \
#     --train_dataset.json_delta_dataset.use_data_sharded_loader=True \
#     --eval_dataset.type='json_vision_delta' \
#     --eval_dataset.delta_vision_text_processor.fields_from_example='fields' \
#     --eval_dataset.delta_vision_text_processor.n_tokens_per_delta=1 \
#     --eval_dataset.delta_vision_text_processor.max_n_frames=1 \
#     --eval_dataset.json_delta_dataset.mode="pad" \
#     --eval_dataset.json_delta_dataset.path="$eval_dataset_path" \
#     --eval_dataset.json_delta_dataset.seq_length=384 \
#     --eval_dataset.json_delta_dataset.batch_size=128 \
#     --eval_dataset.json_delta_dataset.tokenizer_processes=1 \
#     --eval_dataset.json_delta_dataset.tokenizer_parallel_chunk_size=128 \
#     --eval_dataset.json_delta_dataset.tokenizer_parallel_batch_size=128 \
#     --eval_dataset.json_delta_dataset.use_data_sharded_loader=True \
#     --checkpointer.save_optimizer_state=False \
#     --autoresume=False \
#     --logger.append_uuid=False \
#     --logger.online=True \
#     --logger.project_id="$project_id" \
#     --logger.experiment_id="$experiment_id" \
#     --logger.experiment_note="$experiment_note" \
#     --logger.output_dir="$output_dir" \
#     --logger.wandb_dir="$HOME/experiment_output/$project_id"


# export SCRIPT_DIR="$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
# export PROJECT_DIR="$( cd -- "$( dirname -- "$SCRIPT_DIR" )" &> /dev/null && pwd )"
# cd $PROJECT_DIR
# export PYTHONPATH="$PYTHONPATH:$PROJECT_DIR"
# export LIBTPU_INIT_ARGS="--xla_tpu_megacore_fusion_allow_ags=false --xla_enable_async_collective_permute=true --xla_tpu_enable_ag_backward_pipelining=true --xla_tpu_enable_data_parallel_all_reduce_opt=true --xla_tpu_data_parallel_opt_different_sized_ops=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true"

# export llama_tokenizer_path="llava/checkpoints/lwm_checkpoints/tokenizer.model"
# export dataset_path="/home/World-Model/data/bridge_singleview_total_action_processed_filtered_1000traj_action.jsonl"
# export eval_dataset_path="/home/World-Model/data/bridge_singleview_total_action_processed_action_val.jsonl"

# export output_dir="llava/checkpoints/debug"

# export project_id='lwm'
# export experiment_note='world-model'
# export experiment_id='lwm_1000traj_bridge_bf16_from_whole_delta32_batch128_seq1_filtered_epoch2_action'

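# # mesh_dim axis order for --mesh_dim: dp (data parallel), fsdp (fully sharded data parallel), tp (tensor parallel), sp (sequence parallel)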
# CUDA_VISIBLE_DEVICES=0,1,2,3 python3 -u -m lwm.train \
#     --modality='vision,action,delta' \
#     --mesh_dim='!-1,4,1,1' \
#     --dtype='bf16' \
#     --total_steps=603 \
#     --log_freq=1 \
#     --delta_tokens=1 \
#     --eval_steps=1 \
#     --save_model_freq=600 \
#     --eval_log_freq=100 \
#     --save_milestone_freq=0 \
#     --load_llama_config='7b' \
#     --load_checkpoint="params::/home/World-Model/llava/checkpoints/debug/bridge_cross_attn_nsvq_code32_layer8_seq1_filtered_epoch3/streaming_params_30574"\
#     --update_llama_config="dict(action_vocab_size=245,delta_vocab_size=32,theta=50000000,max_sequence_length=2048,use_flash_attention=True,scan_attention=True,scan_query_chunk_size=512,scan_key_chunk_size=1024,remat_attention='nothing_saveable',scan_mlp=True,scan_mlp_chunk_size=8192,remat_mlp='nothing_saveable',remat_block='nothing_saveable',scan_layers=True)" \
#     --tokenizer.vocab_file="$llama_tokenizer_path" \
#     --optimizer.type='adamw' \
#     --llama.delta_vocab_size=32 \
#     --llama.action_vocab_size=245 \
#     --optimizer.accumulate_gradient_steps=1 \
#     --optimizer.adamw_optimizer.weight_decay=0 \
#     --optimizer.adamw_optimizer.lr=2e-5 \
#     --optimizer.adamw_optimizer.end_lr=2e-5 \
#     --optimizer.adamw_optimizer.lr_warmup_steps=0 \
#     --optimizer.adamw_optimizer.lr_decay_steps=3000 \
#     --use_data_sharded_loader=True \
#     --train_dataset.type='json_vision_delta_action' \
#     --train_dataset.delta_vision_action_processor.fields_from_example='fields' \
#     --train_dataset.delta_vision_action_processor.n_tokens_per_delta=1 \
#     --train_dataset.delta_vision_action_processor.n_tokens_per_action=7 \
#     --train_dataset.delta_vision_action_processor.max_n_frames=1 \
#     --train_dataset.json_delta_action_dataset.mode="pad" \
#     --train_dataset.json_delta_action_dataset.path="$dataset_path" \
#     --train_dataset.json_delta_action_dataset.seq_length=384 \
#     --train_dataset.json_delta_action_dataset.batch_size=128 \
#     --train_dataset.json_delta_action_dataset.tokenizer_processes=1 \
#     --train_dataset.json_delta_action_dataset.tokenizer_parallel_chunk_size=128 \
#     --train_dataset.json_delta_action_dataset.tokenizer_parallel_batch_size=128 \
#     --train_dataset.json_delta_action_dataset.use_data_sharded_loader=True \
#     --eval_dataset.type='json_vision_delta_action' \
#     --eval_dataset.delta_vision_action_processor.fields_from_example='fields' \
#     --eval_dataset.delta_vision_action_processor.n_tokens_per_delta=1 \
#     --eval_dataset.delta_vision_action_processor.n_tokens_per_action=7 \
#     --eval_dataset.delta_vision_action_processor.max_n_frames=1 \
#     --eval_dataset.json_delta_action_dataset.mode="pad" \
#     --eval_dataset.json_delta_action_dataset.path="$eval_dataset_path" \
#     --eval_dataset.json_delta_action_dataset.seq_length=384 \
#     --eval_dataset.json_delta_action_dataset.batch_size=128 \
#     --eval_dataset.json_delta_action_dataset.tokenizer_processes=1 \
#     --eval_dataset.json_delta_action_dataset.tokenizer_parallel_chunk_size=128 \
#     --eval_dataset.json_delta_action_dataset.tokenizer_parallel_batch_size=128 \
#     --eval_dataset.json_delta_action_dataset.use_data_sharded_loader=True \
#     --checkpointer.save_optimizer_state=False \
#     --autoresume=False \
#     --logger.append_uuid=False \
#     --logger.online=True \
#     --logger.project_id="$project_id" \
#     --logger.experiment_id="$experiment_id" \
#     --logger.experiment_note="$experiment_note" \
#     --logger.output_dir="$output_dir" \
#     --logger.wandb_dir="$HOME/experiment_output/$project_id"



# export SCRIPT_DIR="$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
# export PROJECT_DIR="$( cd -- "$( dirname -- "$SCRIPT_DIR" )" &> /dev/null && pwd )"
# cd $PROJECT_DIR
# export PYTHONPATH="$PYTHONPATH:$PROJECT_DIR"
# export LIBTPU_INIT_ARGS="--xla_tpu_megacore_fusion_allow_ags=false --xla_enable_async_collective_permute=true --xla_tpu_enable_ag_backward_pipelining=true --xla_tpu_enable_data_parallel_all_reduce_opt=true --xla_tpu_data_parallel_opt_different_sized_ops=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true"

# export llama_tokenizer_path="llava/checkpoints/lwm_checkpoints/tokenizer.model"
# export dataset_path="/home/World-Model/data/bridge_singleview_total_action_processed_filtered_1000traj_action.jsonl"
# export eval_dataset_path="/home/World-Model/data/bridge_singleview_total_action_processed_action_val.jsonl"

# export output_dir="llava/checkpoints/debug"

# export project_id='lwm'
# export experiment_note='world-model'
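# # NOTE(review): the experiment_id below is identical to the one in the next disabled run block, yet this
# # block uses the *_1000traj_action.jsonl dataset and delta_vocab_size=32 (the id says "whole_traj" and
# # "delta128"). Looks like a copy-paste; confirm the intended id before re-enabling this run.
# export experiment_id='lwm_whole_traj_bridge_bf16_from_whole_delta128_batch32_seq1_filtered_epoch1_noinst_action'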

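# # mesh_dim axis order for --mesh_dim: dp (data parallel), fsdp (fully sharded data parallel), tp (tensor parallel), sp (sequence parallel)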
# CUDA_VISIBLE_DEVICES=0,1,2,3 python3 -u -m lwm.train \
#     --modality='vision,action,delta' \
#     --mesh_dim='!-1,4,1,1' \
#     --dtype='bf16' \
#     --total_steps=4004 \
#     --log_freq=1 \
#     --delta_tokens=1 \
#     --eval_steps=1 \
#     --save_model_freq=0 \
#     --eval_log_freq=100 \
#     --save_milestone_freq=1000 \
#     --load_llama_config='7b' \
#     --load_checkpoint="params::/home/World-Model/llava/checkpoints/debug/bridge_cross_attn_nsvq_code32_layer8_seq1_filtered_epoch3/streaming_params_15287"\
#     --update_llama_config="dict(action_vocab_size=245,delta_vocab_size=32,theta=50000000,max_sequence_length=2048,use_flash_attention=True,scan_attention=True,scan_query_chunk_size=512,scan_key_chunk_size=1024,remat_attention='nothing_saveable',scan_mlp=True,scan_mlp_chunk_size=8192,remat_mlp='nothing_saveable',remat_block='nothing_saveable',scan_layers=True)" \
#     --tokenizer.vocab_file="$llama_tokenizer_path" \
#     --optimizer.type='adamw' \
#     --llama.delta_vocab_size=32 \
#     --llama.action_vocab_size=245 \
#     --optimizer.accumulate_gradient_steps=1 \
#     --optimizer.adamw_optimizer.weight_decay=0 \
#     --optimizer.adamw_optimizer.lr=2e-5 \
#     --optimizer.adamw_optimizer.end_lr=2e-5 \
#     --optimizer.adamw_optimizer.lr_warmup_steps=0 \
#     --optimizer.adamw_optimizer.lr_decay_steps=3000 \
#     --use_data_sharded_loader=True \
#     --train_dataset.type='json_vision_delta_action' \
#     --train_dataset.delta_vision_action_processor.fields_from_example='fields' \
#     --train_dataset.delta_vision_action_processor.n_tokens_per_delta=1 \
#     --train_dataset.delta_vision_action_processor.n_tokens_per_action=7 \
#     --train_dataset.delta_vision_action_processor.max_n_frames=1 \
#     --train_dataset.json_delta_action_dataset.mode="pad" \
#     --train_dataset.json_delta_action_dataset.path="$dataset_path" \
#     --train_dataset.json_delta_action_dataset.seq_length=384 \
#     --train_dataset.json_delta_action_dataset.batch_size=32 \
#     --train_dataset.json_delta_action_dataset.tokenizer_processes=1 \
#     --train_dataset.json_delta_action_dataset.tokenizer_parallel_chunk_size=32 \
#     --train_dataset.json_delta_action_dataset.tokenizer_parallel_batch_size=32 \
#     --train_dataset.json_delta_action_dataset.use_data_sharded_loader=True \
#     --eval_dataset.type='json_vision_delta_action' \
#     --eval_dataset.delta_vision_action_processor.fields_from_example='fields' \
#     --eval_dataset.delta_vision_action_processor.n_tokens_per_delta=1 \
#     --eval_dataset.delta_vision_action_processor.n_tokens_per_action=7 \
#     --eval_dataset.delta_vision_action_processor.max_n_frames=1 \
#     --eval_dataset.json_delta_action_dataset.mode="pad" \
#     --eval_dataset.json_delta_action_dataset.path="$eval_dataset_path" \
#     --eval_dataset.json_delta_action_dataset.seq_length=384 \
#     --eval_dataset.json_delta_action_dataset.batch_size=32 \
#     --eval_dataset.json_delta_action_dataset.tokenizer_processes=1 \
#     --eval_dataset.json_delta_action_dataset.tokenizer_parallel_chunk_size=32 \
#     --eval_dataset.json_delta_action_dataset.tokenizer_parallel_batch_size=32 \
#     --eval_dataset.json_delta_action_dataset.use_data_sharded_loader=True \
#     --checkpointer.save_optimizer_state=False \
#     --autoresume=False \
#     --logger.append_uuid=False \
#     --logger.online=True \
#     --logger.project_id="$project_id" \
#     --logger.experiment_id="$experiment_id" \
#     --logger.experiment_note="$experiment_note" \
#     --logger.output_dir="$output_dir" \
#     --logger.wandb_dir="$HOME/experiment_output/$project_id"



# export SCRIPT_DIR="$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
# export PROJECT_DIR="$( cd -- "$( dirname -- "$SCRIPT_DIR" )" &> /dev/null && pwd )"
# cd "$PROJECT_DIR"
# export PYTHONPATH="$PYTHONPATH:$PROJECT_DIR"
# export LIBTPU_INIT_ARGS="--xla_tpu_megacore_fusion_allow_ags=false --xla_enable_async_collective_permute=true --xla_tpu_enable_ag_backward_pipelining=true --xla_tpu_enable_data_parallel_all_reduce_opt=true --xla_tpu_data_parallel_opt_different_sized_ops=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true"

# export llama_tokenizer_path="llava/checkpoints/lwm_checkpoints/tokenizer.model"
# export dataset_path="/home/World-Model/data/bridge_singleview_total_action_processed_filtered_action_no_inst_train.jsonl"
# export eval_dataset_path="/home/World-Model/data/bridge_singleview_total_action_processed_action_val.jsonl"

# export output_dir="llava/checkpoints/debug"

# export project_id='lwm'
# export experiment_note='world-model'
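# # NOTE(review): this id says "delta128_batch32", but the flags below set delta_vocab_size=32 and batch_size=128 - confirm the intended name before reusing.
# export experiment_id='lwm_whole_traj_bridge_bf16_from_whole_delta128_batch32_seq1_filtered_epoch1_noinst_action'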

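# # mesh_dim axis order: dp (data parallel), fsdp (fully sharded data parallel), tp (tensor parallel), sp (sequence parallel);
# # the comma-separated --mesh_dim value below lists one size per axis in that order.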
# CUDA_VISIBLE_DEVICES=0,1,2,3 python3 -u -m lwm.train \
#     --modality='vision,action,delta' \
#     --mesh_dim='!-1,4,1,1' \
#     --dtype='bf16' \
#     --total_steps=10450 \
#     --log_freq=1 \
#     --delta_tokens=1 \
#     --eval_steps=1 \
#     --save_model_freq=10430 \
#     --eval_log_freq=100 \
#     --save_milestone_freq=10430 \
#     --load_llama_config='7b' \
#     --load_checkpoint="params::/home/World-Model/llava/checkpoints/debug/bridge_cross_attn_nsvq_code32_layer8_seq1_filtered_epoch3/streaming_params_15287"\
#     --update_llama_config="dict(action_vocab_size=245,delta_vocab_size=32,theta=50000000,max_sequence_length=2048,use_flash_attention=True,scan_attention=True,scan_query_chunk_size=512,scan_key_chunk_size=1024,remat_attention='nothing_saveable',scan_mlp=True,scan_mlp_chunk_size=8192,remat_mlp='nothing_saveable',remat_block='nothing_saveable',scan_layers=True)" \
#     --tokenizer.vocab_file="$llama_tokenizer_path" \
#     --optimizer.type='adamw' \
#     --llama.delta_vocab_size=32 \
#     --llama.action_vocab_size=245 \
#     --optimizer.accumulate_gradient_steps=1 \
#     --optimizer.adamw_optimizer.weight_decay=0 \
#     --optimizer.adamw_optimizer.lr=2e-5 \
#     --optimizer.adamw_optimizer.end_lr=2e-5 \
#     --optimizer.adamw_optimizer.lr_warmup_steps=0 \
#     --optimizer.adamw_optimizer.lr_decay_steps=3000 \
#     --use_data_sharded_loader=True \
#     --train_dataset.type='json_vision_delta_action' \
#     --train_dataset.delta_vision_action_processor.fields_from_example='fields' \
#     --train_dataset.delta_vision_action_processor.n_tokens_per_delta=1 \
#     --train_dataset.delta_vision_action_processor.n_tokens_per_action=7 \
#     --train_dataset.delta_vision_action_processor.max_n_frames=1 \
#     --train_dataset.json_delta_action_dataset.mode="pad" \
#     --train_dataset.json_delta_action_dataset.path="$dataset_path" \
#     --train_dataset.json_delta_action_dataset.seq_length=384 \
#     --train_dataset.json_delta_action_dataset.batch_size=128 \
#     --train_dataset.json_delta_action_dataset.tokenizer_processes=1 \
#     --train_dataset.json_delta_action_dataset.tokenizer_parallel_chunk_size=128 \
#     --train_dataset.json_delta_action_dataset.tokenizer_parallel_batch_size=128 \
#     --train_dataset.json_delta_action_dataset.use_data_sharded_loader=True \
#     --eval_dataset.type='json_vision_delta_action' \
#     --eval_dataset.delta_vision_action_processor.fields_from_example='fields' \
#     --eval_dataset.delta_vision_action_processor.n_tokens_per_delta=1 \
#     --eval_dataset.delta_vision_action_processor.n_tokens_per_action=7 \
#     --eval_dataset.delta_vision_action_processor.max_n_frames=1 \
#     --eval_dataset.json_delta_action_dataset.mode="pad" \
#     --eval_dataset.json_delta_action_dataset.path="$eval_dataset_path" \
#     --eval_dataset.json_delta_action_dataset.seq_length=384 \
#     --eval_dataset.json_delta_action_dataset.batch_size=128 \
#     --eval_dataset.json_delta_action_dataset.tokenizer_processes=1 \
#     --eval_dataset.json_delta_action_dataset.tokenizer_parallel_chunk_size=128 \
#     --eval_dataset.json_delta_action_dataset.tokenizer_parallel_batch_size=128 \
#     --eval_dataset.json_delta_action_dataset.use_data_sharded_loader=True \
#     --checkpointer.save_optimizer_state=False \
#     --autoresume=False \
#     --logger.append_uuid=False \
#     --logger.online=True \
#     --logger.project_id="$project_id" \
#     --logger.experiment_id="$experiment_id" \
#     --logger.experiment_note="$experiment_note" \
#     --logger.output_dir="$output_dir" \
#     --logger.wandb_dir="$HOME/experiment_output/$project_id"




# export SCRIPT_DIR="$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
# export PROJECT_DIR="$( cd -- "$( dirname -- "$SCRIPT_DIR" )" &> /dev/null && pwd )"
# cd "$PROJECT_DIR"
# export PYTHONPATH="$PYTHONPATH:$PROJECT_DIR"
# export LIBTPU_INIT_ARGS="--xla_tpu_megacore_fusion_allow_ags=false --xla_enable_async_collective_permute=true --xla_tpu_enable_ag_backward_pipelining=true --xla_tpu_enable_data_parallel_all_reduce_opt=true --xla_tpu_data_parallel_opt_different_sized_ops=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true"

# export llama_tokenizer_path="llava/checkpoints/lwm_checkpoints/tokenizer.model"
# export dataset_path='/home/World-Model/data/bridge_simplerenv_tabletop_total_256_filtered_no_inst_train.jsonl'
# export eval_dataset_path='/home/World-Model/data/bridge_simplerenv_tabletop_total_256_filtered_no_inst_val.jsonl'

# export output_dir="llava/checkpoints/debug"

# export project_id='lwm'
# export experiment_note='world-model'
# export experiment_id='lwm_1000traj_bridge_bf16_from_whole_delta32_batch32_seq1_only_tabletop_total_re'

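# # mesh_dim axis order: dp (data parallel), fsdp (fully sharded data parallel), tp (tensor parallel), sp (sequence parallel);
# # the comma-separated --mesh_dim value below lists one size per axis in that order.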
# CUDA_VISIBLE_DEVICES=4,5,6,7 python3 -u -m lwm.train \
#     --modality='vision,action,delta' \
#     --mesh_dim='!-1,4,1,1' \
#     --dtype='bf16' \
#     --total_steps=10005 \
#     --log_freq=1 \
#     --delta_tokens=1 \
#     --eval_steps=1 \
#     --save_model_freq=2000 \
#     --eval_log_freq=100 \
#     --save_milestone_freq=2000 \
#     --load_llama_config='7b' \
#     --load_checkpoint="params::/home/World-Model/llava/checkpoints/debug/bridge_cross_attn_nsvq_code32_layer8_seq1_filtered_epoch3/streaming_params_15287"\
#     --update_llama_config="dict(action_vocab_size=238,delta_vocab_size=32,theta=50000000,max_sequence_length=2048,use_flash_attention=True,scan_attention=True,scan_query_chunk_size=512,scan_key_chunk_size=1024,remat_attention='nothing_saveable',scan_mlp=True,scan_mlp_chunk_size=8192,remat_mlp='nothing_saveable',remat_block='nothing_saveable',scan_layers=True)" \
#     --tokenizer.vocab_file="$llama_tokenizer_path" \
#     --optimizer.type='adamw' \
#     --llama.delta_vocab_size=32 \
#     --llama.action_vocab_size=238 \
#     --optimizer.accumulate_gradient_steps=1 \
#     --optimizer.adamw_optimizer.weight_decay=0 \
#     --optimizer.adamw_optimizer.lr=2e-5 \
#     --optimizer.adamw_optimizer.end_lr=2e-5 \
#     --optimizer.adamw_optimizer.lr_warmup_steps=0 \
#     --optimizer.adamw_optimizer.lr_decay_steps=3000 \
#     --use_data_sharded_loader=True \
#     --train_dataset.type='json_vision_delta_action' \
#     --train_dataset.delta_vision_action_processor.fields_from_example='fields' \
#     --train_dataset.delta_vision_action_processor.n_tokens_per_delta=1 \
#     --train_dataset.delta_vision_action_processor.n_tokens_per_action=7 \
#     --train_dataset.delta_vision_action_processor.max_n_frames=1 \
#     --train_dataset.json_delta_action_dataset.mode="pad" \
#     --train_dataset.json_delta_action_dataset.path="$dataset_path" \
#     --train_dataset.json_delta_action_dataset.seq_length=384 \
#     --train_dataset.json_delta_action_dataset.batch_size=32 \
#     --train_dataset.json_delta_action_dataset.tokenizer_processes=1 \
#     --train_dataset.json_delta_action_dataset.tokenizer_parallel_chunk_size=32 \
#     --train_dataset.json_delta_action_dataset.tokenizer_parallel_batch_size=32 \
#     --train_dataset.json_delta_action_dataset.use_data_sharded_loader=True \
#     --eval_dataset.type='json_vision_delta_action' \
#     --eval_dataset.delta_vision_action_processor.fields_from_example='fields' \
#     --eval_dataset.delta_vision_action_processor.n_tokens_per_delta=1 \
#     --eval_dataset.delta_vision_action_processor.n_tokens_per_action=7 \
#     --eval_dataset.delta_vision_action_processor.max_n_frames=1 \
#     --eval_dataset.json_delta_action_dataset.mode="pad" \
#     --eval_dataset.json_delta_action_dataset.path="$eval_dataset_path" \
#     --eval_dataset.json_delta_action_dataset.seq_length=384 \
#     --eval_dataset.json_delta_action_dataset.batch_size=32 \
#     --eval_dataset.json_delta_action_dataset.tokenizer_processes=1 \
#     --eval_dataset.json_delta_action_dataset.tokenizer_parallel_chunk_size=32 \
#     --eval_dataset.json_delta_action_dataset.tokenizer_parallel_batch_size=32 \
#     --eval_dataset.json_delta_action_dataset.use_data_sharded_loader=True \
#     --checkpointer.save_optimizer_state=False \
#     --autoresume=False \
#     --logger.append_uuid=False \
#     --logger.online=True \
#     --logger.project_id="$project_id" \
#     --logger.experiment_id="$experiment_id" \
#     --logger.experiment_note="$experiment_note" \
#     --logger.output_dir="$output_dir" \
#     --logger.wandb_dir="$HOME/experiment_output/$project_id"




# export SCRIPT_DIR="$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
# export PROJECT_DIR="$( cd -- "$( dirname -- "$SCRIPT_DIR" )" &> /dev/null && pwd )"
# cd "$PROJECT_DIR"
# export PYTHONPATH="$PYTHONPATH:$PROJECT_DIR"
# export LIBTPU_INIT_ARGS="--xla_tpu_megacore_fusion_allow_ags=false --xla_enable_async_collective_permute=true --xla_tpu_enable_ag_backward_pipelining=true --xla_tpu_enable_data_parallel_all_reduce_opt=true --xla_tpu_data_parallel_opt_different_sized_ops=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true"

# export llama_tokenizer_path="llava/checkpoints/lwm_checkpoints/tokenizer.model"
# export dataset_path='/home/World-Model/data/bridge_simplerenv_tabletop_total_256_filtered_no_inst_train.jsonl'
# export eval_dataset_path='/home/World-Model/data/bridge_simplerenv_tabletop_total_256_filtered_no_inst_val.jsonl'

# export output_dir="llava/checkpoints/debug"

# export project_id='lwm'
# export experiment_note='world-model'
# export experiment_id='lwm_1000traj_bridge_bf16_from_whole_delta32_batch32_seq1_only_tabletop_total_re_continued'

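# # mesh_dim axis order: dp (data parallel), fsdp (fully sharded data parallel), tp (tensor parallel), sp (sequence parallel);
# # the comma-separated --mesh_dim value below lists one size per axis in that order.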
# CUDA_VISIBLE_DEVICES=4,5,6,7 python3 -u -m lwm.train \
#     --modality='vision,action,delta' \
#     --mesh_dim='!-1,4,1,1' \
#     --dtype='bf16' \
#     --total_steps=32005 \
#     --log_freq=1 \
#     --delta_tokens=1 \
#     --eval_steps=1 \
#     --save_model_freq=0 \
#     --eval_log_freq=100 \
#     --save_milestone_freq=4000 \
#     --load_llama_config='7b' \
#     --load_checkpoint="params::/home/World-Model/llava/checkpoints/debug/lwm_1000traj_bridge_bf16_from_whole_delta32_batch32_seq1_only_tabletop_total_re/streaming_params_10000"\
#     --update_llama_config="dict(action_vocab_size=238,delta_vocab_size=32,theta=50000000,max_sequence_length=2048,use_flash_attention=True,scan_attention=True,scan_query_chunk_size=512,scan_key_chunk_size=1024,remat_attention='nothing_saveable',scan_mlp=True,scan_mlp_chunk_size=8192,remat_mlp='nothing_saveable',remat_block='nothing_saveable',scan_layers=True)" \
#     --tokenizer.vocab_file="$llama_tokenizer_path" \
#     --optimizer.type='adamw' \
#     --llama.delta_vocab_size=32 \
#     --llama.action_vocab_size=238 \
#     --optimizer.accumulate_gradient_steps=1 \
#     --optimizer.adamw_optimizer.weight_decay=0 \
#     --optimizer.adamw_optimizer.lr=2e-5 \
#     --optimizer.adamw_optimizer.end_lr=2e-5 \
#     --optimizer.adamw_optimizer.lr_warmup_steps=0 \
#     --optimizer.adamw_optimizer.lr_decay_steps=3000 \
#     --use_data_sharded_loader=True \
#     --train_dataset.type='json_vision_delta_action' \
#     --train_dataset.delta_vision_action_processor.fields_from_example='fields' \
#     --train_dataset.delta_vision_action_processor.n_tokens_per_delta=1 \
#     --train_dataset.delta_vision_action_processor.n_tokens_per_action=7 \
#     --train_dataset.delta_vision_action_processor.max_n_frames=1 \
#     --train_dataset.json_delta_action_dataset.mode="pad" \
#     --train_dataset.json_delta_action_dataset.path="$dataset_path" \
#     --train_dataset.json_delta_action_dataset.seq_length=384 \
#     --train_dataset.json_delta_action_dataset.batch_size=32 \
#     --train_dataset.json_delta_action_dataset.tokenizer_processes=1 \
#     --train_dataset.json_delta_action_dataset.tokenizer_parallel_chunk_size=32 \
#     --train_dataset.json_delta_action_dataset.tokenizer_parallel_batch_size=32 \
#     --train_dataset.json_delta_action_dataset.use_data_sharded_loader=True \
#     --eval_dataset.type='json_vision_delta_action' \
#     --eval_dataset.delta_vision_action_processor.fields_from_example='fields' \
#     --eval_dataset.delta_vision_action_processor.n_tokens_per_delta=1 \
#     --eval_dataset.delta_vision_action_processor.n_tokens_per_action=7 \
#     --eval_dataset.delta_vision_action_processor.max_n_frames=1 \
#     --eval_dataset.json_delta_action_dataset.mode="pad" \
#     --eval_dataset.json_delta_action_dataset.path="$eval_dataset_path" \
#     --eval_dataset.json_delta_action_dataset.seq_length=384 \
#     --eval_dataset.json_delta_action_dataset.batch_size=32 \
#     --eval_dataset.json_delta_action_dataset.tokenizer_processes=1 \
#     --eval_dataset.json_delta_action_dataset.tokenizer_parallel_chunk_size=32 \
#     --eval_dataset.json_delta_action_dataset.tokenizer_parallel_batch_size=32 \
#     --eval_dataset.json_delta_action_dataset.use_data_sharded_loader=True \
#     --checkpointer.save_optimizer_state=False \
#     --autoresume=False \
#     --logger.append_uuid=False \
#     --logger.online=True \
#     --logger.project_id="$project_id" \
#     --logger.experiment_id="$experiment_id" \
#     --logger.experiment_note="$experiment_note" \
#     --logger.output_dir="$output_dir" \
#     --logger.wandb_dir="$HOME/experiment_output/$project_id"



# export SCRIPT_DIR="$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
# export PROJECT_DIR="$( cd -- "$( dirname -- "$SCRIPT_DIR" )" &> /dev/null && pwd )"
# cd "$PROJECT_DIR"
# export PYTHONPATH="$PYTHONPATH:$PROJECT_DIR"
# export LIBTPU_INIT_ARGS="--xla_tpu_megacore_fusion_allow_ags=false --xla_enable_async_collective_permute=true --xla_tpu_enable_ag_backward_pipelining=true --xla_tpu_enable_data_parallel_all_reduce_opt=true --xla_tpu_data_parallel_opt_different_sized_ops=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true"

# export llama_tokenizer_path="llava/checkpoints/lwm_checkpoints/tokenizer.model"
# export dataset_path='/home/World-Model/data/bridge_simplerenv_tabletop_total_256_filtered_no_inst_train.jsonl'
# export eval_dataset_path='/home/World-Model/data/bridge_simplerenv_tabletop_total_256_filtered_no_inst_val.jsonl'

# export output_dir="llava/checkpoints/debug"

# export project_id='lwm'
# export experiment_note='world-model'
# export experiment_id='lwm_1000traj_bridge_bf16_rt2_only_tabletop_total_re'

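# # mesh_dim axis order: dp (data parallel), fsdp (fully sharded data parallel), tp (tensor parallel), sp (sequence parallel);
# # the comma-separated --mesh_dim value below lists one size per axis in that order.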
# CUDA_VISIBLE_DEVICES=4,5,6,7 python3 -u -m lwm.train \
#     --modality='vision,action' \
#     --mesh_dim='!-1,4,1,1' \
#     --dtype='bf16' \
#     --total_steps=10005 \
#     --log_freq=1 \
#     --delta_tokens=1 \
#     --eval_steps=1 \
#     --save_model_freq=0 \
#     --eval_log_freq=100 \
#     --save_milestone_freq=5000 \
#     --load_llama_config='7b' \
#     --load_checkpoint="params::/home/World-Model/llava/checkpoints/lwm_checkpoints/params"\
#     --update_llama_config="dict(action_vocab_size=238,theta=50000000,max_sequence_length=2048,use_flash_attention=True,scan_attention=True,scan_query_chunk_size=512,scan_key_chunk_size=1024,remat_attention='nothing_saveable',scan_mlp=True,scan_mlp_chunk_size=8192,remat_mlp='nothing_saveable',remat_block='nothing_saveable',scan_layers=True)" \
#     --tokenizer.vocab_file="$llama_tokenizer_path" \
#     --optimizer.type='adamw' \
#     --llama.action_vocab_size=238 \
#     --optimizer.accumulate_gradient_steps=1 \
#     --optimizer.adamw_optimizer.weight_decay=0 \
#     --optimizer.adamw_optimizer.lr=2e-5 \
#     --optimizer.adamw_optimizer.end_lr=2e-5 \
#     --optimizer.adamw_optimizer.lr_warmup_steps=0 \
#     --optimizer.adamw_optimizer.lr_decay_steps=3000 \
#     --use_data_sharded_loader=True \
#     --train_dataset.type='json_vision_action' \
#     --train_dataset.vision_action_processor.fields_from_example='fields' \
#     --train_dataset.vision_action_processor.n_tokens_per_action=7 \
#     --train_dataset.vision_action_processor.max_n_frames=1 \
#     --train_dataset.json_vision_action_dataset.mode="pad" \
#     --train_dataset.json_vision_action_dataset.path="$dataset_path" \
#     --train_dataset.json_vision_action_dataset.seq_length=384 \
#     --train_dataset.json_vision_action_dataset.batch_size=32 \
#     --train_dataset.json_vision_action_dataset.tokenizer_processes=1 \
#     --train_dataset.json_vision_action_dataset.tokenizer_parallel_chunk_size=32 \
#     --train_dataset.json_vision_action_dataset.tokenizer_parallel_batch_size=32 \
#     --train_dataset.json_vision_action_dataset.use_data_sharded_loader=True \
#     --eval_dataset.type='json_vision_action' \
#     --eval_dataset.vision_action_processor.fields_from_example='fields' \
#     --eval_dataset.vision_action_processor.n_tokens_per_action=7 \
#     --eval_dataset.vision_action_processor.max_n_frames=1 \
#     --eval_dataset.json_vision_action_dataset.mode="pad" \
#     --eval_dataset.json_vision_action_dataset.path="$eval_dataset_path" \
#     --eval_dataset.json_vision_action_dataset.seq_length=384 \
#     --eval_dataset.json_vision_action_dataset.batch_size=32 \
#     --eval_dataset.json_vision_action_dataset.tokenizer_processes=1 \
#     --eval_dataset.json_vision_action_dataset.tokenizer_parallel_chunk_size=32 \
#     --eval_dataset.json_vision_action_dataset.tokenizer_parallel_batch_size=32 \
#     --eval_dataset.json_vision_action_dataset.use_data_sharded_loader=True \
#     --checkpointer.save_optimizer_state=False \
#     --autoresume=False \
#     --logger.append_uuid=False \
#     --logger.online=True \
#     --logger.project_id="$project_id" \
#     --logger.experiment_id="$experiment_id" \
#     --logger.experiment_note="$experiment_note" \
#     --logger.output_dir="$output_dir" \
#     --logger.wandb_dir="$HOME/experiment_output/$project_id"





# export SCRIPT_DIR="$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
# export PROJECT_DIR="$( cd -- "$( dirname -- "$SCRIPT_DIR" )" &> /dev/null && pwd )"
# cd "$PROJECT_DIR"
# export PYTHONPATH="$PYTHONPATH:$PROJECT_DIR"
# export LIBTPU_INIT_ARGS="--xla_tpu_megacore_fusion_allow_ags=false --xla_enable_async_collective_permute=true --xla_tpu_enable_ag_backward_pipelining=true --xla_tpu_enable_data_parallel_all_reduce_opt=true --xla_tpu_data_parallel_opt_different_sized_ops=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true"

# export llama_tokenizer_path="llava/checkpoints/lwm_checkpoints/tokenizer.model"
# export dataset_path='/home/World-Model/data/bridge_simplerenv_tabletop_256_filtered_no_inst_train.jsonl'
# export eval_dataset_path='/home/World-Model/data/bridge_simplerenv_tabletop_256_filtered_no_inst_val.jsonl'

# export output_dir="llava/checkpoints/debug"

# export project_id='lwm'
# export experiment_note='world-model'
# export experiment_id='lwm_1000traj_bridge_bf16_from_whole_delta32_batch128_seq1_only_tabletop_total'

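# # mesh_dim axis order: dp (data parallel), fsdp (fully sharded data parallel), tp (tensor parallel), sp (sequence parallel);
# # the comma-separated --mesh_dim value below lists one size per axis in that order.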
# CUDA_VISIBLE_DEVICES=0,1,2,3 python3 -u -m lwm.train \
#     --modality='vision,action,delta' \
#     --mesh_dim='!-1,4,1,1' \
#     --dtype='bf16' \
#     --total_steps=2005 \
#     --log_freq=1 \
#     --delta_tokens=1 \
#     --eval_steps=1 \
#     --save_model_freq=500 \
#     --eval_log_freq=100 \
#     --save_milestone_freq=500 \
#     --load_llama_config='7b' \
#     --load_checkpoint="params::/home/World-Model/llava/checkpoints/debug/bridge_cross_attn_nsvq_code32_layer8_seq1_filtered_epoch3/streaming_params_15287"\
#     --update_llama_config="dict(action_vocab_size=238,delta_vocab_size=32,theta=50000000,max_sequence_length=2048,use_flash_attention=True,scan_attention=True,scan_query_chunk_size=512,scan_key_chunk_size=1024,remat_attention='nothing_saveable',scan_mlp=True,scan_mlp_chunk_size=8192,remat_mlp='nothing_saveable',remat_block='nothing_saveable',scan_layers=True)" \
#     --tokenizer.vocab_file="$llama_tokenizer_path" \
#     --optimizer.type='adamw' \
#     --llama.delta_vocab_size=32 \
#     --llama.action_vocab_size=238 \
#     --optimizer.accumulate_gradient_steps=1 \
#     --optimizer.adamw_optimizer.weight_decay=0 \
#     --optimizer.adamw_optimizer.lr=2e-5 \
#     --optimizer.adamw_optimizer.end_lr=2e-5 \
#     --optimizer.adamw_optimizer.lr_warmup_steps=0 \
#     --optimizer.adamw_optimizer.lr_decay_steps=3000 \
#     --use_data_sharded_loader=True \
#     --train_dataset.type='json_vision_delta_action' \
#     --train_dataset.delta_vision_action_processor.fields_from_example='fields' \
#     --train_dataset.delta_vision_action_processor.n_tokens_per_delta=1 \
#     --train_dataset.delta_vision_action_processor.n_tokens_per_action=7 \
#     --train_dataset.delta_vision_action_processor.max_n_frames=1 \
#     --train_dataset.json_delta_action_dataset.mode="pad" \
#     --train_dataset.json_delta_action_dataset.path="$dataset_path" \
#     --train_dataset.json_delta_action_dataset.seq_length=384 \
#     --train_dataset.json_delta_action_dataset.batch_size=128 \
#     --train_dataset.json_delta_action_dataset.tokenizer_processes=1 \
#     --train_dataset.json_delta_action_dataset.tokenizer_parallel_chunk_size=128 \
#     --train_dataset.json_delta_action_dataset.tokenizer_parallel_batch_size=128 \
#     --train_dataset.json_delta_action_dataset.use_data_sharded_loader=True \
#     --eval_dataset.type='json_vision_delta_action' \
#     --eval_dataset.delta_vision_action_processor.fields_from_example='fields' \
#     --eval_dataset.delta_vision_action_processor.n_tokens_per_delta=1 \
#     --eval_dataset.delta_vision_action_processor.n_tokens_per_action=7 \
#     --eval_dataset.delta_vision_action_processor.max_n_frames=1 \
#     --eval_dataset.json_delta_action_dataset.mode="pad" \
#     --eval_dataset.json_delta_action_dataset.path="$eval_dataset_path" \
#     --eval_dataset.json_delta_action_dataset.seq_length=384 \
#     --eval_dataset.json_delta_action_dataset.batch_size=128 \
#     --eval_dataset.json_delta_action_dataset.tokenizer_processes=1 \
#     --eval_dataset.json_delta_action_dataset.tokenizer_parallel_chunk_size=128 \
#     --eval_dataset.json_delta_action_dataset.tokenizer_parallel_batch_size=128 \
#     --eval_dataset.json_delta_action_dataset.use_data_sharded_loader=True \
#     --checkpointer.save_optimizer_state=False \
#     --autoresume=False \
#     --logger.append_uuid=False \
#     --logger.online=True \
#     --logger.project_id="$project_id" \
#     --logger.experiment_id="$experiment_id" \
#     --logger.experiment_note="$experiment_note" \
#     --logger.output_dir="$output_dir" \
#     --logger.wandb_dir="$HOME/experiment_output/$project_id"




# export SCRIPT_DIR="$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
# export PROJECT_DIR="$( cd -- "$( dirname -- "$SCRIPT_DIR" )" &> /dev/null && pwd )"
# cd "$PROJECT_DIR"
# export PYTHONPATH="$PYTHONPATH:$PROJECT_DIR"
# export LIBTPU_INIT_ARGS="--xla_tpu_megacore_fusion_allow_ags=false --xla_enable_async_collective_permute=true --xla_tpu_enable_ag_backward_pipelining=true --xla_tpu_enable_data_parallel_all_reduce_opt=true --xla_tpu_data_parallel_opt_different_sized_ops=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true"

# export llama_tokenizer_path="llava/checkpoints/lwm_checkpoints/tokenizer.model"
# export dataset_path="/home/World-Model/data/bridge_singleview_total_action_processed_filtered_1000traj.jsonl"
# export eval_dataset_path="/home/World-Model/data/bridge_singleview_total_action_processed_val.jsonl"

# export output_dir="llava/checkpoints/debug"

# export project_id='lwm'
# export experiment_note='world-model'
# export experiment_id='lwm_1000traj_bridge_bf16_from_whole_delta4_batch128_seq2'

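# # mesh_dim axis order: dp (data parallel), fsdp (fully sharded data parallel), tp (tensor parallel), sp (sequence parallel);
# # the comma-separated --mesh_dim value below lists one size per axis in that order.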
# CUDA_VISIBLE_DEVICES=0,1,2,3 python3 -u -m lwm.train \
#     --modality='vision,text,delta' \
#     --mesh_dim='!-1,4,1,1' \
#     --dtype='bf16' \
#     --total_steps=603 \
#     --log_freq=1 \
#     --delta_tokens=2 \
#     --eval_steps=1 \
#     --save_model_freq=600 \
#     --eval_log_freq=100 \
#     --save_milestone_freq=0 \
#     --load_llama_config='7b' \
#     --load_checkpoint="params::/home/World-Model/llava/checkpoints/debug/bridge_cross_attn_nsvq_code4_layer8_seq2/streaming_params_15700"\
#     --update_llama_config="dict(delta_vocab_size=4,theta=50000000,max_sequence_length=2048,use_flash_attention=True,scan_attention=True,scan_query_chunk_size=512,scan_key_chunk_size=1024,remat_attention='nothing_saveable',scan_mlp=True,scan_mlp_chunk_size=8192,remat_mlp='nothing_saveable',remat_block='nothing_saveable',scan_layers=True)" \
#     --tokenizer.vocab_file="$llama_tokenizer_path" \
#     --optimizer.type='adamw' \
#     --llama.delta_vocab_size=4 \
#     --optimizer.accumulate_gradient_steps=1 \
#     --optimizer.adamw_optimizer.weight_decay=0 \
#     --optimizer.adamw_optimizer.lr=2e-5 \
#     --optimizer.adamw_optimizer.end_lr=2e-5 \
#     --optimizer.adamw_optimizer.lr_warmup_steps=0 \
#     --optimizer.adamw_optimizer.lr_decay_steps=3000 \
#     --use_data_sharded_loader=True \
#     --train_dataset.type='json_vision_delta' \
#     --train_dataset.delta_vision_text_processor.fields_from_example='fields' \
#     --train_dataset.delta_vision_text_processor.n_tokens_per_delta=2 \
#     --train_dataset.delta_vision_text_processor.max_n_frames=1 \
#     --train_dataset.json_delta_dataset.mode="pad" \
#     --train_dataset.json_delta_dataset.path="$dataset_path" \
#     --train_dataset.json_delta_dataset.seq_length=384 \
#     --train_dataset.json_delta_dataset.batch_size=128 \
#     --train_dataset.json_delta_dataset.tokenizer_processes=1 \
#     --train_dataset.json_delta_dataset.tokenizer_parallel_chunk_size=128 \
#     --train_dataset.json_delta_dataset.tokenizer_parallel_batch_size=128 \
#     --train_dataset.json_delta_dataset.use_data_sharded_loader=True \
#     --eval_dataset.type='json_vision_delta' \
#     --eval_dataset.delta_vision_text_processor.fields_from_example='fields' \
#     --eval_dataset.delta_vision_text_processor.n_tokens_per_delta=2 \
#     --eval_dataset.delta_vision_text_processor.max_n_frames=1 \
#     --eval_dataset.json_delta_dataset.mode="pad" \
#     --eval_dataset.json_delta_dataset.path="$eval_dataset_path" \
#     --eval_dataset.json_delta_dataset.seq_length=384 \
#     --eval_dataset.json_delta_dataset.batch_size=128 \
#     --eval_dataset.json_delta_dataset.tokenizer_processes=1 \
#     --eval_dataset.json_delta_dataset.tokenizer_parallel_chunk_size=128 \
#     --eval_dataset.json_delta_dataset.tokenizer_parallel_batch_size=128 \
#     --eval_dataset.json_delta_dataset.use_data_sharded_loader=True \
#     --checkpointer.save_optimizer_state=False \
#     --autoresume=False \
#     --logger.append_uuid=False \
#     --logger.online=True \
#     --logger.project_id="$project_id" \
#     --logger.experiment_id="$experiment_id" \
#     --logger.experiment_note="$experiment_note" \
#     --logger.output_dir="$output_dir" \
#     --logger.wandb_dir="$HOME/experiment_output/$project_id"




# export SCRIPT_DIR="$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
# export PROJECT_DIR="$( cd -- "$( dirname -- "$SCRIPT_DIR" )" &> /dev/null && pwd )"
# cd "$PROJECT_DIR"
# export PYTHONPATH="$PYTHONPATH:$PROJECT_DIR"
# export LIBTPU_INIT_ARGS="--xla_tpu_megacore_fusion_allow_ags=false --xla_enable_async_collective_permute=true --xla_tpu_enable_ag_backward_pipelining=true --xla_tpu_enable_data_parallel_all_reduce_opt=true --xla_tpu_data_parallel_opt_different_sized_ops=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true"

# export llama_tokenizer_path="llava/checkpoints/lwm_checkpoints/tokenizer.model"
# export dataset_path="/home/World-Model/data/bridge_singleview_total_action_processed_filtered_1000traj.jsonl"
# export eval_dataset_path="/home/World-Model/data/bridge_singleview_total_action_processed_val.jsonl"

# export output_dir="llava/checkpoints/debug"

# export project_id='lwm'
# export experiment_note='world-model'
# export experiment_id='lwm_1000traj_bridge_bf16_from_whole_delta2_batch128_seq4'

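# # mesh_dim axis order: dp (data parallel), fsdp (fully sharded data parallel), tp (tensor parallel), sp (sequence parallel);
# # the comma-separated --mesh_dim value below lists one size per axis in that order.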
# CUDA_VISIBLE_DEVICES=0,1,2,3 python3 -u -m lwm.train \
#     --modality='vision,text,delta' \
#     --mesh_dim='!-1,4,1,1' \
#     --dtype='bf16' \
#     --total_steps=603 \
#     --log_freq=1 \
#     --delta_tokens=4 \
#     --eval_steps=1 \
#     --save_model_freq=600 \
#     --eval_log_freq=100 \
#     --save_milestone_freq=0 \
#     --load_llama_config='7b' \
#     --load_checkpoint="params::/home/World-Model/llava/checkpoints/debug/bridge_cross_attn_nsvq_code2_layer8_seq4/streaming_params_15700"\
#     --update_llama_config="dict(delta_vocab_size=2,theta=50000000,max_sequence_length=2048,use_flash_attention=True,scan_attention=True,scan_query_chunk_size=512,scan_key_chunk_size=1024,remat_attention='nothing_saveable',scan_mlp=True,scan_mlp_chunk_size=8192,remat_mlp='nothing_saveable',remat_block='nothing_saveable',scan_layers=True)" \
#     --tokenizer.vocab_file="$llama_tokenizer_path" \
#     --optimizer.type='adamw' \
#     --llama.delta_vocab_size=2 \
#     --optimizer.accumulate_gradient_steps=1 \
#     --optimizer.adamw_optimizer.weight_decay=0 \
#     --optimizer.adamw_optimizer.lr=2e-5 \
#     --optimizer.adamw_optimizer.end_lr=2e-5 \
#     --optimizer.adamw_optimizer.lr_warmup_steps=0 \
#     --optimizer.adamw_optimizer.lr_decay_steps=3000 \
#     --use_data_sharded_loader=True \
#     --train_dataset.type='json_vision_delta' \
#     --train_dataset.delta_vision_text_processor.fields_from_example='fields' \
#     --train_dataset.delta_vision_text_processor.n_tokens_per_delta=4 \
#     --train_dataset.delta_vision_text_processor.max_n_frames=1 \
#     --train_dataset.json_delta_dataset.mode="pad" \
#     --train_dataset.json_delta_dataset.path="$dataset_path" \
#     --train_dataset.json_delta_dataset.seq_length=384 \
#     --train_dataset.json_delta_dataset.batch_size=128 \
#     --train_dataset.json_delta_dataset.tokenizer_processes=1 \
#     --train_dataset.json_delta_dataset.tokenizer_parallel_chunk_size=128 \
#     --train_dataset.json_delta_dataset.tokenizer_parallel_batch_size=128 \
#     --train_dataset.json_delta_dataset.use_data_sharded_loader=True \
#     --eval_dataset.type='json_vision_delta' \
#     --eval_dataset.delta_vision_text_processor.fields_from_example='fields' \
#     --eval_dataset.delta_vision_text_processor.n_tokens_per_delta=4 \
#     --eval_dataset.delta_vision_text_processor.max_n_frames=1 \
#     --eval_dataset.json_delta_dataset.mode="pad" \
#     --eval_dataset.json_delta_dataset.path="$eval_dataset_path" \
#     --eval_dataset.json_delta_dataset.seq_length=384 \
#     --eval_dataset.json_delta_dataset.batch_size=128 \
#     --eval_dataset.json_delta_dataset.tokenizer_processes=1 \
#     --eval_dataset.json_delta_dataset.tokenizer_parallel_chunk_size=128 \
#     --eval_dataset.json_delta_dataset.tokenizer_parallel_batch_size=128 \
#     --eval_dataset.json_delta_dataset.use_data_sharded_loader=True \
#     --checkpointer.save_optimizer_state=False \
#     --autoresume=False \
#     --logger.append_uuid=False \
#     --logger.online=True \
#     --logger.project_id="$project_id" \
#     --logger.experiment_id="$experiment_id" \
#     --logger.experiment_note="$experiment_note" \
#     --logger.output_dir="$output_dir" \
#     --logger.wandb_dir="$HOME/experiment_output/$project_id"


# export SCRIPT_DIR="$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
# export PROJECT_DIR="$( cd -- "$( dirname -- "$SCRIPT_DIR" )" &> /dev/null && pwd )"
# cd $PROJECT_DIR
# export PYTHONPATH="$PYTHONPATH:$PROJECT_DIR"
# export LIBTPU_INIT_ARGS="--xla_tpu_megacore_fusion_allow_ags=false --xla_enable_async_collective_permute=true --xla_tpu_enable_ag_backward_pipelining=true --xla_tpu_enable_data_parallel_all_reduce_opt=true --xla_tpu_data_parallel_opt_different_sized_ops=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true"

# export llama_tokenizer_path="llava/checkpoints/lwm_checkpoints/tokenizer.model"
# export dataset_path="/home/World-Model/data/bridge_cross_attn_nsvq_code32_layer8_seq1_train_no_inst.jsonl"
# export eval_dataset_path="/home/World-Model/data/bridge_cross_attn_nsvq_code32_layer8_seq1_val.jsonl"

# export output_dir="llava/checkpoints/debug"

# export project_id='lwm'
# export experiment_note='world-model'
# export experiment_id='bridge_cross_attn_nsvq_code32_layer8_seq1_filtered_epoch3_no_inst_re'

# # mesh_dim: dp, fsdp, tp, sp
# CUDA_VISIBLE_DEVICES=4,5,6,7 python3 -u -m lwm.train \
#     --modality='vision,action,delta' \
#     --mesh_dim='!-1,4,1,1' \
#     --dtype='bf16' \
#     --total_steps=32200 \
#     --log_freq=1 \
#     --delta_tokens=1 \
#     --eval_steps=1 \
#     --save_model_freq=10725 \
#     --eval_log_freq=100 \
#     --save_milestone_freq=10725 \
#     --load_llama_config='7b' \
#     --load_checkpoint="params::/home/World-Model/llava/checkpoints/lwm_checkpoints/params"\
#     --update_llama_config="dict(delta_vocab_size=32,action_vocab_size=238,theta=50000000,max_sequence_length=2048,use_flash_attention=True,scan_attention=True,scan_query_chunk_size=512,scan_key_chunk_size=1024,remat_attention='nothing_saveable',scan_mlp=True,scan_mlp_chunk_size=8192,remat_mlp='nothing_saveable',remat_block='nothing_saveable',scan_layers=True)" \
#     --tokenizer.vocab_file="$llama_tokenizer_path" \
#     --optimizer.type='adamw' \
#     --llama.delta_vocab_size=32 \
#     --llama.action_vocab_size=238 \
#     --optimizer.accumulate_gradient_steps=1 \
#     --optimizer.adamw_optimizer.weight_decay=0 \
#     --optimizer.adamw_optimizer.lr=2e-5 \
#     --optimizer.adamw_optimizer.end_lr=2e-5 \
#     --optimizer.adamw_optimizer.lr_warmup_steps=0 \
#     --optimizer.adamw_optimizer.lr_decay_steps=15700 \
#     --use_data_sharded_loader=True \
#     --train_dataset.type='json_vision_delta_action' \
#     --train_dataset.delta_vision_action_processor.fields_from_example='fields' \
#     --train_dataset.delta_vision_action_processor.n_tokens_per_delta=1 \
#     --train_dataset.delta_vision_action_processor.max_n_frames=1 \
#     --train_dataset.json_delta_action_dataset.mode="pad" \
#     --train_dataset.json_delta_action_dataset.path="$dataset_path" \
#     --train_dataset.json_delta_action_dataset.seq_length=384 \
#     --train_dataset.json_delta_action_dataset.batch_size=128 \
#     --train_dataset.json_delta_action_dataset.tokenizer_processes=1 \
#     --train_dataset.json_delta_action_dataset.tokenizer_parallel_chunk_size=128 \
#     --train_dataset.json_delta_action_dataset.tokenizer_parallel_batch_size=128 \
#     --train_dataset.json_delta_action_dataset.use_data_sharded_loader=True \
#     --eval_dataset.type='json_vision_delta_action' \
#     --eval_dataset.delta_vision_action_processor.fields_from_example='fields' \
#     --eval_dataset.delta_vision_action_processor.n_tokens_per_delta=1 \
#     --eval_dataset.delta_vision_action_processor.max_n_frames=1 \
#     --eval_dataset.json_delta_action_dataset.mode="pad" \
#     --eval_dataset.json_delta_action_dataset.path="$eval_dataset_path" \
#     --eval_dataset.json_delta_action_dataset.seq_length=384 \
#     --eval_dataset.json_delta_action_dataset.batch_size=128 \
#     --eval_dataset.json_delta_action_dataset.tokenizer_processes=1 \
#     --eval_dataset.json_delta_action_dataset.tokenizer_parallel_chunk_size=128 \
#     --eval_dataset.json_delta_action_dataset.tokenizer_parallel_batch_size=128 \
#     --eval_dataset.json_delta_action_dataset.use_data_sharded_loader=True \
#     --checkpointer.save_optimizer_state=False \
#     --autoresume=False \
#     --logger.append_uuid=False \
#     --logger.online=True \
#     --logger.project_id="$project_id" \
#     --logger.experiment_id="$experiment_id" \
#     --logger.experiment_note="$experiment_note" \
#     --logger.output_dir="$output_dir" \
#     --logger.wandb_dir="$HOME/experiment_output/$project_id"




# export SCRIPT_DIR="$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
# export PROJECT_DIR="$( cd -- "$( dirname -- "$SCRIPT_DIR" )" &> /dev/null && pwd )"
# cd $PROJECT_DIR
# export PYTHONPATH="$PYTHONPATH:$PROJECT_DIR"
# export LIBTPU_INIT_ARGS="--xla_tpu_megacore_fusion_allow_ags=false --xla_enable_async_collective_permute=true --xla_tpu_enable_ag_backward_pipelining=true --xla_tpu_enable_data_parallel_all_reduce_opt=true --xla_tpu_data_parallel_opt_different_sized_ops=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true"

# export llama_tokenizer_path="llava/checkpoints/lwm_checkpoints/tokenizer.model"
# export dataset_path='/home/World-Model/data/bridge_simplerenv_tabletop_total_256_filtered_no_inst_train.jsonl'
# export eval_dataset_path='/home/World-Model/data/bridge_simplerenv_tabletop_total_256_filtered_no_inst_val.jsonl'


# export output_dir="llava/checkpoints/debug"

# export project_id='lwm'
# export experiment_note='world-model'
# export experiment_id='lwm_1000traj_bridge_bf16_from_whole_delta32_batch32_seq1_only_tabletop_total_no_inst_nolang_epoch3'

# # mesh_dim: dp, fsdp, tp, sp
# CUDA_VISIBLE_DEVICES=4,5,6,7 python3 -u -m lwm.train \
#     --modality='vision,action,delta' \
#     --mesh_dim='!-1,4,1,1' \
#     --dtype='bf16' \
#     --total_steps=5005 \
#     --log_freq=1 \
#     --delta_tokens=1 \
#     --eval_steps=1 \
#     --save_model_freq=0 \
#     --eval_log_freq=100 \
#     --save_milestone_freq=5000 \
#     --load_llama_config='7b' \
#     --load_checkpoint="params::/home/World-Model/llava/checkpoints/debug/bridge_cross_attn_nsvq_code32_layer8_seq1_filtered_epoch3_no_inst_re/streaming_params_32175"\
#     --update_llama_config="dict(action_vocab_size=238,delta_vocab_size=32,theta=50000000,max_sequence_length=2048,use_flash_attention=True,scan_attention=True,scan_query_chunk_size=512,scan_key_chunk_size=1024,remat_attention='nothing_saveable',scan_mlp=True,scan_mlp_chunk_size=8192,remat_mlp='nothing_saveable',remat_block='nothing_saveable',scan_layers=True)" \
#     --tokenizer.vocab_file="$llama_tokenizer_path" \
#     --optimizer.type='adamw' \
#     --llama.delta_vocab_size=32 \
#     --llama.action_vocab_size=238 \
#     --optimizer.accumulate_gradient_steps=1 \
#     --optimizer.adamw_optimizer.weight_decay=0 \
#     --optimizer.adamw_optimizer.lr=2e-5 \
#     --optimizer.adamw_optimizer.end_lr=2e-5 \
#     --optimizer.adamw_optimizer.lr_warmup_steps=0 \
#     --optimizer.adamw_optimizer.lr_decay_steps=3000 \
#     --use_data_sharded_loader=True \
#     --train_dataset.type='json_vision_delta_action' \
#     --train_dataset.delta_vision_action_processor.fields_from_example='fields' \
#     --train_dataset.delta_vision_action_processor.n_tokens_per_delta=1 \
#     --train_dataset.delta_vision_action_processor.n_tokens_per_action=7 \
#     --train_dataset.delta_vision_action_processor.max_n_frames=1 \
#     --train_dataset.json_delta_action_dataset.mode="pad" \
#     --train_dataset.json_delta_action_dataset.path="$dataset_path" \
#     --train_dataset.json_delta_action_dataset.seq_length=384 \
#     --train_dataset.json_delta_action_dataset.batch_size=32 \
#     --train_dataset.json_delta_action_dataset.tokenizer_processes=1 \
#     --train_dataset.json_delta_action_dataset.tokenizer_parallel_chunk_size=32 \
#     --train_dataset.json_delta_action_dataset.tokenizer_parallel_batch_size=32 \
#     --train_dataset.json_delta_action_dataset.use_data_sharded_loader=True \
#     --eval_dataset.type='json_vision_delta_action' \
#     --eval_dataset.delta_vision_action_processor.fields_from_example='fields' \
#     --eval_dataset.delta_vision_action_processor.n_tokens_per_delta=1 \
#     --eval_dataset.delta_vision_action_processor.n_tokens_per_action=7 \
#     --eval_dataset.delta_vision_action_processor.max_n_frames=1 \
#     --eval_dataset.json_delta_action_dataset.mode="pad" \
#     --eval_dataset.json_delta_action_dataset.path="$eval_dataset_path" \
#     --eval_dataset.json_delta_action_dataset.seq_length=384 \
#     --eval_dataset.json_delta_action_dataset.batch_size=32 \
#     --eval_dataset.json_delta_action_dataset.tokenizer_processes=1 \
#     --eval_dataset.json_delta_action_dataset.tokenizer_parallel_chunk_size=32 \
#     --eval_dataset.json_delta_action_dataset.tokenizer_parallel_batch_size=32 \
#     --eval_dataset.json_delta_action_dataset.use_data_sharded_loader=True \
#     --checkpointer.save_optimizer_state=False \
#     --autoresume=False \
#     --logger.append_uuid=False \
#     --logger.online=True \
#     --logger.project_id="$project_id" \
#     --logger.experiment_id="$experiment_id" \
#     --logger.experiment_note="$experiment_note" \
#     --logger.output_dir="$output_dir" \
#     --logger.wandb_dir="$HOME/experiment_output/$project_id"



# export SCRIPT_DIR="$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
# export PROJECT_DIR="$( cd -- "$( dirname -- "$SCRIPT_DIR" )" &> /dev/null && pwd )"
# cd $PROJECT_DIR
# export PYTHONPATH="$PYTHONPATH:$PROJECT_DIR"
# export LIBTPU_INIT_ARGS="--xla_tpu_megacore_fusion_allow_ags=false --xla_enable_async_collective_permute=true --xla_tpu_enable_ag_backward_pipelining=true --xla_tpu_enable_data_parallel_all_reduce_opt=true --xla_tpu_data_parallel_opt_different_sized_ops=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true"

# export llama_tokenizer_path="llava/checkpoints/lwm_checkpoints/tokenizer.model"
# export dataset_path="/home/World-Model/data/bridge_singleview_total_action_processed_filtered_action_no_inst_train.jsonl"
# export eval_dataset_path="/home/World-Model/data/bridge_singleview_total_action_processed_action_val.jsonl"

# export output_dir="llava/checkpoints/debug"

# export project_id='lwm'
# export experiment_note='world-model'
# export experiment_id='bridge_rt2_epoch_filtered_no_inst_epoch3_continued'

# # mesh_dim: dp, fsdp, tp, sp
# CUDA_VISIBLE_DEVICES=4,5,6,7 python3 -u -m lwm.train \
#     --modality='vision,action' \
#     --mesh_dim='!-1,4,1,1' \
#     --dtype='bf16' \
#     --total_steps=20870 \
#     --log_freq=1 \
#     --eval_steps=1 \
#     --save_model_freq=0 \
#     --eval_log_freq=100 \
#     --save_milestone_freq=10430 \
#     --load_llama_config='7b' \
#     --load_checkpoint="params::/home/World-Model/llava/checkpoints/debug/bridge_rt2_epoch_filtered_no_inst/streaming_params_10430"\
#     --update_llama_config="dict(action_vocab_size=245,theta=50000000,max_sequence_length=2048,use_flash_attention=True,scan_attention=True,scan_query_chunk_size=512,scan_key_chunk_size=1024,remat_attention='nothing_saveable',scan_mlp=True,scan_mlp_chunk_size=8192,remat_mlp='nothing_saveable',remat_block='nothing_saveable',scan_layers=True)" \
#     --tokenizer.vocab_file="$llama_tokenizer_path" \
#     --optimizer.type='adamw' \
#     --llama.action_vocab_size=245 \
#     --optimizer.accumulate_gradient_steps=1 \
#     --optimizer.adamw_optimizer.weight_decay=0 \
#     --optimizer.adamw_optimizer.lr=2e-5 \
#     --optimizer.adamw_optimizer.end_lr=2e-5 \
#     --optimizer.adamw_optimizer.lr_warmup_steps=0 \
#     --optimizer.adamw_optimizer.lr_decay_steps=45861 \
#     --use_data_sharded_loader=True \
#     --train_dataset.type='json_vision_action' \
#     --train_dataset.vision_action_processor.fields_from_example='fields' \
#     --train_dataset.vision_action_processor.n_tokens_per_action=7 \
#     --train_dataset.vision_action_processor.max_n_frames=1 \
#     --train_dataset.json_vision_action_dataset.mode="pad" \
#     --train_dataset.json_vision_action_dataset.path="$dataset_path" \
#     --train_dataset.json_vision_action_dataset.seq_length=384 \
#     --train_dataset.json_vision_action_dataset.batch_size=128 \
#     --train_dataset.json_vision_action_dataset.tokenizer_processes=1 \
#     --train_dataset.json_vision_action_dataset.tokenizer_parallel_chunk_size=128 \
#     --train_dataset.json_vision_action_dataset.tokenizer_parallel_batch_size=128 \
#     --train_dataset.json_vision_action_dataset.use_data_sharded_loader=True \
#     --eval_dataset.type='json_vision_action' \
#     --eval_dataset.vision_action_processor.fields_from_example='fields' \
#     --eval_dataset.vision_action_processor.n_tokens_per_action=7 \
#     --eval_dataset.vision_action_processor.max_n_frames=1 \
#     --eval_dataset.json_vision_action_dataset.mode="pad" \
#     --eval_dataset.json_vision_action_dataset.path="$eval_dataset_path" \
#     --eval_dataset.json_vision_action_dataset.seq_length=384 \
#     --eval_dataset.json_vision_action_dataset.batch_size=128 \
#     --eval_dataset.json_vision_action_dataset.tokenizer_processes=1 \
#     --eval_dataset.json_vision_action_dataset.tokenizer_parallel_chunk_size=128 \
#     --eval_dataset.json_vision_action_dataset.tokenizer_parallel_batch_size=128 \
#     --eval_dataset.json_vision_action_dataset.use_data_sharded_loader=True \
#     --checkpointer.save_optimizer_state=False \
#     --autoresume=False \
#     --logger.append_uuid=False \
#     --logger.online=True \
#     --logger.project_id="$project_id" \
#     --logger.experiment_id="$experiment_id" \
#     --logger.experiment_note="$experiment_note" \
#     --logger.output_dir="$output_dir" \
#     --logger.wandb_dir="$HOME/experiment_output/$project_id"





# export SCRIPT_DIR="$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
# export PROJECT_DIR="$( cd -- "$( dirname -- "$SCRIPT_DIR" )" &> /dev/null && pwd )"
# cd $PROJECT_DIR
# export PYTHONPATH="$PYTHONPATH:$PROJECT_DIR"
# export LIBTPU_INIT_ARGS="--xla_tpu_megacore_fusion_allow_ags=false --xla_enable_async_collective_permute=true --xla_tpu_enable_ag_backward_pipelining=true --xla_tpu_enable_data_parallel_all_reduce_opt=true --xla_tpu_data_parallel_opt_different_sized_ops=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true"

# export llama_tokenizer_path="llava/checkpoints/lwm_checkpoints/tokenizer.model"
# export dataset_path="/home/World-Model/data/bridge_cross_attn_nsvq_code2_layer8_seq4_train_no_inst.jsonl"
# export eval_dataset_path="/home/World-Model/data/bridge_cross_attn_nsvq_code2_layer8_seq4_val.jsonl"

# export output_dir="llava/checkpoints/debug"

# export project_id='lwm'
# export experiment_note='world-model'
# export experiment_id='bridge_cross_attn_nsvq_code2_layer8_seq4_no_inst_noact_re'

# # mesh_dim: dp, fsdp, tp, sp
# CUDA_VISIBLE_DEVICES=4,5,6,7 python3 -u -m lwm.train \
#     --modality='vision,text,delta' \
#     --mesh_dim='!-1,4,1,1' \
#     --dtype='bf16' \
#     --total_steps=10730 \
#     --log_freq=1 \
#     --delta_tokens=4 \
#     --eval_steps=1 \
#     --save_model_freq=0 \
#     --eval_log_freq=100 \
#     --save_milestone_freq=10725 \
#     --load_llama_config='7b' \
#     --load_checkpoint="params::/home/World-Model/llava/checkpoints/lwm_checkpoints/params"\
#     --update_llama_config="dict(delta_vocab_size=2,theta=50000000,max_sequence_length=2048,use_flash_attention=True,scan_attention=True,scan_query_chunk_size=512,scan_key_chunk_size=1024,remat_attention='nothing_saveable',scan_mlp=True,scan_mlp_chunk_size=8192,remat_mlp='nothing_saveable',remat_block='nothing_saveable',scan_layers=True)" \
#     --tokenizer.vocab_file="$llama_tokenizer_path" \
#     --optimizer.type='adamw' \
#     --llama.delta_vocab_size=2 \
#     --optimizer.accumulate_gradient_steps=1 \
#     --optimizer.adamw_optimizer.weight_decay=0 \
#     --optimizer.adamw_optimizer.lr=2e-5 \
#     --optimizer.adamw_optimizer.end_lr=0 \
#     --optimizer.adamw_optimizer.lr_warmup_steps=1000 \
#     --optimizer.adamw_optimizer.lr_decay_steps=15700 \
#     --use_data_sharded_loader=True \
#     --train_dataset.type='json_vision_delta' \
#     --train_dataset.delta_vision_text_processor.fields_from_example='fields' \
#     --train_dataset.delta_vision_text_processor.n_tokens_per_delta=4 \
#     --train_dataset.delta_vision_text_processor.max_n_frames=1 \
#     --train_dataset.json_delta_dataset.mode="pad" \
#     --train_dataset.json_delta_dataset.path="$dataset_path" \
#     --train_dataset.json_delta_dataset.seq_length=384 \
#     --train_dataset.json_delta_dataset.batch_size=128 \
#     --train_dataset.json_delta_dataset.tokenizer_processes=1 \
#     --train_dataset.json_delta_dataset.tokenizer_parallel_chunk_size=128 \
#     --train_dataset.json_delta_dataset.tokenizer_parallel_batch_size=128 \
#     --train_dataset.json_delta_dataset.use_data_sharded_loader=True \
#     --eval_dataset.type='json_vision_delta' \
#     --eval_dataset.delta_vision_text_processor.fields_from_example='fields' \
#     --eval_dataset.delta_vision_text_processor.n_tokens_per_delta=4 \
#     --eval_dataset.delta_vision_text_processor.max_n_frames=1 \
#     --eval_dataset.json_delta_dataset.mode="pad" \
#     --eval_dataset.json_delta_dataset.path="$eval_dataset_path" \
#     --eval_dataset.json_delta_dataset.seq_length=384 \
#     --eval_dataset.json_delta_dataset.batch_size=128 \
#     --eval_dataset.json_delta_dataset.tokenizer_processes=1 \
#     --eval_dataset.json_delta_dataset.tokenizer_parallel_chunk_size=128 \
#     --eval_dataset.json_delta_dataset.tokenizer_parallel_batch_size=128 \
#     --eval_dataset.json_delta_dataset.use_data_sharded_loader=True \
#     --checkpointer.save_optimizer_state=False \
#     --autoresume=False \
#     --logger.append_uuid=False \
#     --logger.online=True \
#     --logger.project_id="$project_id" \
#     --logger.experiment_id="$experiment_id" \
#     --logger.experiment_note="$experiment_note" \
#     --logger.output_dir="$output_dir" \
#     --logger.wandb_dir="$HOME/experiment_output/$project_id"



# export SCRIPT_DIR="$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
# export PROJECT_DIR="$( cd -- "$( dirname -- "$SCRIPT_DIR" )" &> /dev/null && pwd )"
# cd $PROJECT_DIR
# export PYTHONPATH="$PYTHONPATH:$PROJECT_DIR"
# export LIBTPU_INIT_ARGS="--xla_tpu_megacore_fusion_allow_ags=false --xla_enable_async_collective_permute=true --xla_tpu_enable_ag_backward_pipelining=true --xla_tpu_enable_data_parallel_all_reduce_opt=true --xla_tpu_data_parallel_opt_different_sized_ops=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true"

# export llama_tokenizer_path="llava/checkpoints/lwm_checkpoints/tokenizer.model"
# export dataset_path='/home/World-Model/data/bridge_simplerenv_tabletop_total_256_filtered_no_inst_train.jsonl'
# export eval_dataset_path='/home/World-Model/data/bridge_simplerenv_tabletop_total_256_filtered_no_inst_val.jsonl'


# export output_dir="llava/checkpoints/debug"

# export project_id='lwm'
# export experiment_note='world-model'
# export experiment_id='lwm_1000traj_bridge_bf16_from_whole_delta2_batch32_seq4_only_tabletop_total_no_inst_nolang_freeze'

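# # mesh_dim: dp, fsdp, tp, sp
# # NOTE(review): this run passes --delta_tokens=1, but its train/eval processors set n_tokens_per_delta=4 — confirm the mismatch is intended before reusing.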
# CUDA_VISIBLE_DEVICES=0,1,2,3 python3 -u -m lwm.train \
#     --modality='vision,action,delta' \
#     --mesh_dim='!-1,4,1,1' \
#     --dtype='bf16' \
#     --total_steps=10005 \
#     --log_freq=1 \
#     --delta_tokens=1 \
#     --eval_steps=1 \
#     --save_model_freq=0 \
#     --eval_log_freq=100 \
#     --save_milestone_freq=5000 \
#     --load_llama_config='7b' \
#     --load_checkpoint="params::/home/World-Model/llava/checkpoints/debug/bridge_cross_attn_nsvq_code2_layer8_seq4_no_inst_noact_re/streaming_params_10725"\
#     --update_llama_config="dict(action_vocab_size=238,delta_vocab_size=2,theta=50000000,max_sequence_length=2048,use_flash_attention=True,scan_attention=True,scan_query_chunk_size=512,scan_key_chunk_size=1024,remat_attention='nothing_saveable',scan_mlp=True,scan_mlp_chunk_size=8192,remat_mlp='nothing_saveable',remat_block='nothing_saveable',scan_layers=True)" \
#     --tokenizer.vocab_file="$llama_tokenizer_path" \
#     --optimizer.type='adamw' \
#     --freeze=1 \
#     --llama.delta_vocab_size=2 \
#     --llama.action_vocab_size=238 \
#     --optimizer.accumulate_gradient_steps=1 \
#     --optimizer.adamw_optimizer.weight_decay=0 \
#     --optimizer.adamw_optimizer.lr=1e-4 \
#     --optimizer.adamw_optimizer.end_lr=1e-4 \
#     --optimizer.adamw_optimizer.lr_warmup_steps=0 \
#     --optimizer.adamw_optimizer.lr_decay_steps=3000 \
#     --use_data_sharded_loader=True \
#     --train_dataset.type='json_vision_delta_action' \
#     --train_dataset.delta_vision_action_processor.fields_from_example='fields' \
#     --train_dataset.delta_vision_action_processor.n_tokens_per_delta=4 \
#     --train_dataset.delta_vision_action_processor.n_tokens_per_action=7 \
#     --train_dataset.delta_vision_action_processor.max_n_frames=1 \
#     --train_dataset.json_delta_action_dataset.mode="pad" \
#     --train_dataset.json_delta_action_dataset.path="$dataset_path" \
#     --train_dataset.json_delta_action_dataset.seq_length=384 \
#     --train_dataset.json_delta_action_dataset.batch_size=32 \
#     --train_dataset.json_delta_action_dataset.tokenizer_processes=1 \
#     --train_dataset.json_delta_action_dataset.tokenizer_parallel_chunk_size=32 \
#     --train_dataset.json_delta_action_dataset.tokenizer_parallel_batch_size=32 \
#     --train_dataset.json_delta_action_dataset.use_data_sharded_loader=True \
#     --eval_dataset.type='json_vision_delta_action' \
#     --eval_dataset.delta_vision_action_processor.fields_from_example='fields' \
#     --eval_dataset.delta_vision_action_processor.n_tokens_per_delta=4 \
#     --eval_dataset.delta_vision_action_processor.n_tokens_per_action=7 \
#     --eval_dataset.delta_vision_action_processor.max_n_frames=1 \
#     --eval_dataset.json_delta_action_dataset.mode="pad" \
#     --eval_dataset.json_delta_action_dataset.path="$eval_dataset_path" \
#     --eval_dataset.json_delta_action_dataset.seq_length=384 \
#     --eval_dataset.json_delta_action_dataset.batch_size=32 \
#     --eval_dataset.json_delta_action_dataset.tokenizer_processes=1 \
#     --eval_dataset.json_delta_action_dataset.tokenizer_parallel_chunk_size=32 \
#     --eval_dataset.json_delta_action_dataset.tokenizer_parallel_batch_size=32 \
#     --eval_dataset.json_delta_action_dataset.use_data_sharded_loader=True \
#     --checkpointer.save_optimizer_state=False \
#     --autoresume=False \
#     --logger.append_uuid=False \
#     --logger.online=True \
#     --logger.project_id="$project_id" \
#     --logger.experiment_id="$experiment_id" \
#     --logger.experiment_note="$experiment_note" \
#     --logger.output_dir="$output_dir" \
#     --logger.wandb_dir="$HOME/experiment_output/$project_id"


# export SCRIPT_DIR="$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
# export PROJECT_DIR="$( cd -- "$( dirname -- "$SCRIPT_DIR" )" &> /dev/null && pwd )"
# cd $PROJECT_DIR
# export PYTHONPATH="$PYTHONPATH:$PROJECT_DIR"
# export LIBTPU_INIT_ARGS="--xla_tpu_megacore_fusion_allow_ags=false --xla_enable_async_collective_permute=true --xla_tpu_enable_ag_backward_pipelining=true --xla_tpu_enable_data_parallel_all_reduce_opt=true --xla_tpu_data_parallel_opt_different_sized_ops=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true"

# export llama_tokenizer_path="llava/checkpoints/lwm_checkpoints/tokenizer.model"
# export dataset_path='/home/World-Model/data/bridge_singleview_total_action_processed_filtered_action_no_inst_train_carrot.jsonl'
# export eval_dataset_path="/home/World-Model/data/bridge_singleview_total_action_processed_action_val.jsonl"


# export output_dir="llava/checkpoints/debug"

# export project_id='lwm'
# export experiment_note='world-model'
# export experiment_id='lwm_1000traj_bridge_bf16_from_whole_delta2_batch32_seq4_only_tabletop_total_no_inst_nolang_carrot_10000'

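# # mesh_dim: dp, fsdp, tp, sp
# # NOTE(review): this run passes --delta_tokens=1, but its train/eval processors set n_tokens_per_delta=4 — confirm the mismatch is intended before reusing.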
# CUDA_VISIBLE_DEVICES=0,1,2,3 python3 -u -m lwm.train \
#     --modality='vision,action,delta' \
#     --mesh_dim='!-1,4,1,1' \
#     --dtype='bf16' \
#     --total_steps=10005 \
#     --log_freq=1 \
#     --delta_tokens=1 \
#     --eval_steps=1 \
#     --save_model_freq=0 \
#     --eval_log_freq=100 \
#     --save_milestone_freq=1000 \
#     --load_llama_config='7b' \
#     --load_checkpoint="params::/home/World-Model/llava/checkpoints/debug/bridge_cross_attn_nsvq_code2_layer8_seq4_no_inst_noact_re/streaming_params_10725"\
#     --update_llama_config="dict(action_vocab_size=245,delta_vocab_size=2,theta=50000000,max_sequence_length=2048,use_flash_attention=True,scan_attention=True,scan_query_chunk_size=512,scan_key_chunk_size=1024,remat_attention='nothing_saveable',scan_mlp=True,scan_mlp_chunk_size=8192,remat_mlp='nothing_saveable',remat_block='nothing_saveable',scan_layers=True)" \
#     --tokenizer.vocab_file="$llama_tokenizer_path" \
#     --optimizer.type='adamw' \
#     --llama.delta_vocab_size=2 \
#     --llama.action_vocab_size=245 \
#     --optimizer.accumulate_gradient_steps=1 \
#     --optimizer.adamw_optimizer.weight_decay=0 \
#     --optimizer.adamw_optimizer.lr=2e-5 \
#     --optimizer.adamw_optimizer.end_lr=2e-5 \
#     --optimizer.adamw_optimizer.lr_warmup_steps=0 \
#     --optimizer.adamw_optimizer.lr_decay_steps=3000 \
#     --use_data_sharded_loader=True \
#     --train_dataset.type='json_vision_delta_action' \
#     --train_dataset.delta_vision_action_processor.fields_from_example='fields' \
#     --train_dataset.delta_vision_action_processor.n_tokens_per_delta=4 \
#     --train_dataset.delta_vision_action_processor.n_tokens_per_action=7 \
#     --train_dataset.delta_vision_action_processor.max_n_frames=1 \
#     --train_dataset.json_delta_action_dataset.mode="pad" \
#     --train_dataset.json_delta_action_dataset.path="$dataset_path" \
#     --train_dataset.json_delta_action_dataset.seq_length=384 \
#     --train_dataset.json_delta_action_dataset.batch_size=32 \
#     --train_dataset.json_delta_action_dataset.tokenizer_processes=1 \
#     --train_dataset.json_delta_action_dataset.tokenizer_parallel_chunk_size=32 \
#     --train_dataset.json_delta_action_dataset.tokenizer_parallel_batch_size=32 \
#     --train_dataset.json_delta_action_dataset.use_data_sharded_loader=True \
#     --eval_dataset.type='json_vision_delta_action' \
#     --eval_dataset.delta_vision_action_processor.fields_from_example='fields' \
#     --eval_dataset.delta_vision_action_processor.n_tokens_per_delta=4 \
#     --eval_dataset.delta_vision_action_processor.n_tokens_per_action=7 \
#     --eval_dataset.delta_vision_action_processor.max_n_frames=1 \
#     --eval_dataset.json_delta_action_dataset.mode="pad" \
#     --eval_dataset.json_delta_action_dataset.path="$eval_dataset_path" \
#     --eval_dataset.json_delta_action_dataset.seq_length=384 \
#     --eval_dataset.json_delta_action_dataset.batch_size=32 \
#     --eval_dataset.json_delta_action_dataset.tokenizer_processes=1 \
#     --eval_dataset.json_delta_action_dataset.tokenizer_parallel_chunk_size=32 \
#     --eval_dataset.json_delta_action_dataset.tokenizer_parallel_batch_size=32 \
#     --eval_dataset.json_delta_action_dataset.use_data_sharded_loader=True \
#     --checkpointer.save_optimizer_state=False \
#     --autoresume=False \
#     --logger.append_uuid=False \
#     --logger.online=True \
#     --logger.project_id="$project_id" \
#     --logger.experiment_id="$experiment_id" \
#     --logger.experiment_note="$experiment_note" \
#     --logger.output_dir="$output_dir" \
#     --logger.wandb_dir="$HOME/experiment_output/$project_id"

# export SCRIPT_DIR="$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
# export PROJECT_DIR="$( cd -- "$( dirname -- "$SCRIPT_DIR" )" &> /dev/null && pwd )"
# cd $PROJECT_DIR
# export PYTHONPATH="$PYTHONPATH:$PROJECT_DIR"
# export LIBTPU_INIT_ARGS="--xla_tpu_megacore_fusion_allow_ags=false --xla_enable_async_collective_permute=true --xla_tpu_enable_ag_backward_pipelining=true --xla_tpu_enable_data_parallel_all_reduce_opt=true --xla_tpu_data_parallel_opt_different_sized_ops=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true"

# export llama_tokenizer_path="llava/checkpoints/lwm_checkpoints/tokenizer.model"
# export dataset_path="/home/World-Model/data/bridge_singleview_total_action_processed_filtered_action_no_inst_train.jsonl"
# export eval_dataset_path="/home/World-Model/data/bridge_singleview_total_action_processed_action_val.jsonl"

# export output_dir="llava/checkpoints/debug"

# export project_id='lwm'
# export experiment_note='world-model'
# export experiment_id='lwm_total_bridge_bf16_from_whole_delta2_batch32_seq4_filtered_no_inst_epoch3'

# # mesh_dim: dp, fsdp, tp, sp
# CUDA_VISIBLE_DEVICES=4,5,6,7 python3 -u -m lwm.train \
#     --modality='vision,action,delta' \
#     --mesh_dim='!-1,4,1,1' \
#     --dtype='bf16' \
#     --total_steps=20870 \
#     --log_freq=1 \
#     --eval_steps=1 \
#     --save_model_freq=0 \
#     --eval_log_freq=100 \
#     --save_milestone_freq=10430 \
#     --load_llama_config='7b' \
#     --load_checkpoint="params::/home/World-Model/llava/checkpoints/debug/lwm_total_bridge_bf16_from_whole_delta2_batch32_seq4_filtered_no_inst/streaming_params_10430"\
#     --update_llama_config="dict(action_vocab_size=245,delta_vocab_size=2,theta=50000000,max_sequence_length=2048,use_flash_attention=True,scan_attention=True,scan_query_chunk_size=512,scan_key_chunk_size=1024,remat_attention='nothing_saveable',scan_mlp=True,scan_mlp_chunk_size=8192,remat_mlp='nothing_saveable',remat_block='nothing_saveable',scan_layers=True)" \
#     --tokenizer.vocab_file="$llama_tokenizer_path" \
#     --optimizer.type='adamw' \
#     --llama.delta_vocab_size=2 \
#     --llama.action_vocab_size=245 \
#     --optimizer.accumulate_gradient_steps=1 \
#     --optimizer.adamw_optimizer.weight_decay=0 \
#     --optimizer.adamw_optimizer.lr=2e-5 \
#     --optimizer.adamw_optimizer.end_lr=2e-5 \
#     --optimizer.adamw_optimizer.lr_warmup_steps=0 \
#     --optimizer.adamw_optimizer.lr_decay_steps=45861 \
#     --use_data_sharded_loader=True \
#     --train_dataset.type='json_vision_delta_action' \
#     --train_dataset.delta_vision_action_processor.fields_from_example='fields' \
#     --train_dataset.delta_vision_action_processor.n_tokens_per_delta=4 \
#     --train_dataset.delta_vision_action_processor.n_tokens_per_action=7 \
#     --train_dataset.delta_vision_action_processor.max_n_frames=1 \
#     --train_dataset.json_delta_action_dataset.mode="pad" \
#     --train_dataset.json_delta_action_dataset.path="$dataset_path" \
#     --train_dataset.json_delta_action_dataset.seq_length=384 \
#     --train_dataset.json_delta_action_dataset.batch_size=128 \
#     --train_dataset.json_delta_action_dataset.tokenizer_processes=1 \
#     --train_dataset.json_delta_action_dataset.tokenizer_parallel_chunk_size=128 \
#     --train_dataset.json_delta_action_dataset.tokenizer_parallel_batch_size=128 \
#     --train_dataset.json_delta_action_dataset.use_data_sharded_loader=True \
#     --eval_dataset.type='json_vision_delta_action' \
#     --eval_dataset.delta_vision_action_processor.fields_from_example='fields' \
#     --eval_dataset.delta_vision_action_processor.n_tokens_per_delta=4 \
#     --eval_dataset.delta_vision_action_processor.n_tokens_per_action=7 \
#     --eval_dataset.delta_vision_action_processor.max_n_frames=1 \
#     --eval_dataset.json_delta_action_dataset.mode="pad" \
#     --eval_dataset.json_delta_action_dataset.path="$eval_dataset_path" \
#     --eval_dataset.json_delta_action_dataset.seq_length=384 \
#     --eval_dataset.json_delta_action_dataset.batch_size=128 \
#     --eval_dataset.json_delta_action_dataset.tokenizer_processes=1 \
#     --eval_dataset.json_delta_action_dataset.tokenizer_parallel_chunk_size=128 \
#     --eval_dataset.json_delta_action_dataset.tokenizer_parallel_batch_size=128 \
#     --eval_dataset.json_delta_action_dataset.use_data_sharded_loader=True \
#     --checkpointer.save_optimizer_state=False \
#     --autoresume=False \
#     --logger.append_uuid=False \
#     --logger.online=True \
#     --logger.project_id="$project_id" \
#     --logger.experiment_id="$experiment_id" \
#     --logger.experiment_note="$experiment_note" \
#     --logger.output_dir="$output_dir" \
#     --logger.wandb_dir="$HOME/experiment_output/$project_id"





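# --- Archived (commented-out) run: modality vision,action,delta; action_vocab_size=250, delta_vocab_size=2; GPUs 0-3; batch 32; 5005 steps; initialised from bridge_cross_attn_nsvq_code2_layer8_seq4_no_inst_noact_re/streaming_params_10725 ---
# export SCRIPT_DIR="$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"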
# export PROJECT_DIR="$( cd -- "$( dirname -- "$SCRIPT_DIR" )" &> /dev/null && pwd )"
# cd "$PROJECT_DIR"
# export PYTHONPATH="$PYTHONPATH:$PROJECT_DIR"
# export LIBTPU_INIT_ARGS="--xla_tpu_megacore_fusion_allow_ags=false --xla_enable_async_collective_permute=true --xla_tpu_enable_ag_backward_pipelining=true --xla_tpu_enable_data_parallel_all_reduce_opt=true --xla_tpu_data_parallel_opt_different_sized_ops=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true"

# export llama_tokenizer_path="llava/checkpoints/lwm_checkpoints/tokenizer.model"
# export dataset_path='/home/World-Model/data/bridge_singleview_total_action_processed_filtered_action_no_inst_train_1000traj.jsonl'
# export eval_dataset_path='/home/World-Model/data/bridge_simplerenv_tabletop_total_256_filtered_no_inst_val_1000traj.jsonl'


# export output_dir="llava/checkpoints/debug"

# export project_id='lwm'
# export experiment_note='world-model'
# export experiment_id='lwm_1000traj_bridge_bf16_from_whole_delta2_batch32_seq4_only_tabletop_total_no_inst_nolang_1000traj'

# # mesh_dim: dp, fsdp, tp, sp
# CUDA_VISIBLE_DEVICES=0,1,2,3 python3 -u -m lwm.train \
#     --modality='vision,action,delta' \
#     --mesh_dim='!-1,4,1,1' \
#     --dtype='bf16' \
#     --total_steps=5005 \
#     --log_freq=1 \
#     --delta_tokens=1 \
#     --eval_steps=1 \
#     --save_model_freq=0 \
#     --eval_log_freq=100 \
#     --save_milestone_freq=5000 \
#     --load_llama_config='7b' \
#     --load_checkpoint="params::/home/World-Model/llava/checkpoints/debug/bridge_cross_attn_nsvq_code2_layer8_seq4_no_inst_noact_re/streaming_params_10725" \
#     --update_llama_config="dict(action_vocab_size=250,delta_vocab_size=2,theta=50000000,max_sequence_length=2048,use_flash_attention=True,scan_attention=True,scan_query_chunk_size=512,scan_key_chunk_size=1024,remat_attention='nothing_saveable',scan_mlp=True,scan_mlp_chunk_size=8192,remat_mlp='nothing_saveable',remat_block='nothing_saveable',scan_layers=True)" \
#     --tokenizer.vocab_file="$llama_tokenizer_path" \
#     --optimizer.type='adamw' \
#     --llama.delta_vocab_size=2 \
#     --llama.action_vocab_size=250 \
#     --optimizer.accumulate_gradient_steps=1 \
#     --optimizer.adamw_optimizer.weight_decay=0 \
#     --optimizer.adamw_optimizer.lr=2e-5 \
#     --optimizer.adamw_optimizer.end_lr=2e-5 \
#     --optimizer.adamw_optimizer.lr_warmup_steps=0 \
#     --optimizer.adamw_optimizer.lr_decay_steps=3000 \
#     --use_data_sharded_loader=True \
#     --train_dataset.type='json_vision_delta_action' \
#     --train_dataset.delta_vision_action_processor.fields_from_example='fields' \
#     --train_dataset.delta_vision_action_processor.n_tokens_per_delta=4 \
#     --train_dataset.delta_vision_action_processor.n_tokens_per_action=7 \
#     --train_dataset.delta_vision_action_processor.max_n_frames=1 \
#     --train_dataset.json_delta_action_dataset.mode="pad" \
#     --train_dataset.json_delta_action_dataset.path="$dataset_path" \
#     --train_dataset.json_delta_action_dataset.seq_length=384 \
#     --train_dataset.json_delta_action_dataset.batch_size=32 \
#     --train_dataset.json_delta_action_dataset.tokenizer_processes=1 \
#     --train_dataset.json_delta_action_dataset.tokenizer_parallel_chunk_size=32 \
#     --train_dataset.json_delta_action_dataset.tokenizer_parallel_batch_size=32 \
#     --train_dataset.json_delta_action_dataset.use_data_sharded_loader=True \
#     --eval_dataset.type='json_vision_delta_action' \
#     --eval_dataset.delta_vision_action_processor.fields_from_example='fields' \
#     --eval_dataset.delta_vision_action_processor.n_tokens_per_delta=4 \
#     --eval_dataset.delta_vision_action_processor.n_tokens_per_action=7 \
#     --eval_dataset.delta_vision_action_processor.max_n_frames=1 \
#     --eval_dataset.json_delta_action_dataset.mode="pad" \
#     --eval_dataset.json_delta_action_dataset.path="$eval_dataset_path" \
#     --eval_dataset.json_delta_action_dataset.seq_length=384 \
#     --eval_dataset.json_delta_action_dataset.batch_size=32 \
#     --eval_dataset.json_delta_action_dataset.tokenizer_processes=1 \
#     --eval_dataset.json_delta_action_dataset.tokenizer_parallel_chunk_size=32 \
#     --eval_dataset.json_delta_action_dataset.tokenizer_parallel_batch_size=32 \
#     --eval_dataset.json_delta_action_dataset.use_data_sharded_loader=True \
#     --checkpointer.save_optimizer_state=False \
#     --autoresume=False \
#     --logger.append_uuid=False \
#     --logger.online=True \
#     --logger.project_id="$project_id" \
#     --logger.experiment_id="$experiment_id" \
#     --logger.experiment_note="$experiment_note" \
#     --logger.output_dir="$output_dir" \
#     --logger.wandb_dir="$HOME/experiment_output/$project_id"



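# --- Archived (commented-out) run: modality vision,contaction,delta with --freeze=0 --mse_loss=1; delta_vocab_size=4, n_tokens_per_delta=2; GPUs 4-7; batch 32; 5002 steps ---
# export SCRIPT_DIR="$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"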
# export PROJECT_DIR="$( cd -- "$( dirname -- "$SCRIPT_DIR" )" &> /dev/null && pwd )"
# cd "$PROJECT_DIR"
# export PYTHONPATH="$PYTHONPATH:$PROJECT_DIR"
# export LIBTPU_INIT_ARGS="--xla_tpu_megacore_fusion_allow_ags=false --xla_enable_async_collective_permute=true --xla_tpu_enable_ag_backward_pipelining=true --xla_tpu_enable_data_parallel_all_reduce_opt=true --xla_tpu_data_parallel_opt_different_sized_ops=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true"

# export llama_tokenizer_path="llava/checkpoints/lwm_checkpoints/tokenizer.model"
# export dataset_path='/home/World-Model/data/bridge_simplerenv_tabletop_total_256_filtered_no_inst_delta_train.jsonl'
# export eval_dataset_path='/home/World-Model/data/bridge_simplerenv_tabletop_total_256_filtered_no_inst_delta_val.jsonl'


# export output_dir="llava/checkpoints/debug"

# export project_id='lwm'
# export experiment_note='world-model'
# export experiment_id='lwm_1000traj_bridge_bf16_from_whole_delta4_batch32_seq2_only_tabletop_total_no_inst_nolang_mse_unfreeze_re'
# # export experiment_id='debug'

# # mesh_dim: dp, fsdp, tp, sp
# CUDA_VISIBLE_DEVICES=4,5,6,7 python3 -u -m lwm.train \
#     --modality='vision,contaction,delta' \
#     --mesh_dim='!-1,4,1,1' \
#     --dtype='bf16' \
#     --total_steps=5002 \
#     --freeze=0 \
#     --mse_loss=1 \
#     --log_freq=1 \
#     --eval_steps=1 \
#     --save_model_freq=0 \
#     --eval_log_freq=100 \
#     --save_milestone_freq=2500 \
#     --load_llama_config='7b' \
#     --load_checkpoint='params::/home/World-Model/llava/checkpoints/debug/bridge_cross_attn_nsvq_code4_layer8_seq2_no_inst_noact_re/streaming_params_10725' \
#     --update_llama_config="dict(delta_vocab_size=4,theta=50000000,max_sequence_length=2048,use_flash_attention=True,scan_attention=True,scan_query_chunk_size=512,scan_key_chunk_size=1024,remat_attention='nothing_saveable',scan_mlp=True,scan_mlp_chunk_size=8192,remat_mlp='nothing_saveable',remat_block='nothing_saveable',scan_layers=True)" \
#     --tokenizer.vocab_file="$llama_tokenizer_path" \
#     --optimizer.type='adamw' \
#     --llama.delta_vocab_size=4 \
#     --optimizer.accumulate_gradient_steps=1 \
#     --optimizer.adamw_optimizer.weight_decay=0.1 \
#     --optimizer.adamw_optimizer.lr=2e-5 \
#     --optimizer.adamw_optimizer.end_lr=2e-5 \
#     --optimizer.adamw_optimizer.lr_warmup_steps=0 \
#     --optimizer.adamw_optimizer.lr_decay_steps=3000 \
#     --use_data_sharded_loader=True \
#     --train_dataset.type='json_vision_delta_cont_action' \
#     --train_dataset.delta_vision_cont_action_processor.fields_from_example='fields' \
#     --train_dataset.delta_vision_cont_action_processor.n_tokens_per_delta=2 \
#     --train_dataset.delta_vision_cont_action_processor.n_tokens_per_action=7 \
#     --train_dataset.delta_vision_cont_action_processor.max_n_frames=1 \
#     --train_dataset.json_delta_cont_action_dataset.mode="pad" \
#     --train_dataset.json_delta_cont_action_dataset.path="$dataset_path" \
#     --train_dataset.json_delta_cont_action_dataset.seq_length=384 \
#     --train_dataset.json_delta_cont_action_dataset.batch_size=32 \
#     --train_dataset.json_delta_cont_action_dataset.tokenizer_processes=1 \
#     --train_dataset.json_delta_cont_action_dataset.tokenizer_parallel_chunk_size=32 \
#     --train_dataset.json_delta_cont_action_dataset.tokenizer_parallel_batch_size=32 \
#     --train_dataset.json_delta_cont_action_dataset.use_data_sharded_loader=True \
#     --eval_dataset.type='json_vision_delta_cont_action' \
#     --eval_dataset.delta_vision_cont_action_processor.fields_from_example='fields' \
#     --eval_dataset.delta_vision_cont_action_processor.n_tokens_per_delta=2 \
#     --eval_dataset.delta_vision_cont_action_processor.n_tokens_per_action=7 \
#     --eval_dataset.delta_vision_cont_action_processor.max_n_frames=1 \
#     --eval_dataset.json_delta_cont_action_dataset.mode="pad" \
#     --eval_dataset.json_delta_cont_action_dataset.path="$eval_dataset_path" \
#     --eval_dataset.json_delta_cont_action_dataset.seq_length=384 \
#     --eval_dataset.json_delta_cont_action_dataset.batch_size=32 \
#     --eval_dataset.json_delta_cont_action_dataset.tokenizer_processes=1 \
#     --eval_dataset.json_delta_cont_action_dataset.tokenizer_parallel_chunk_size=32 \
#     --eval_dataset.json_delta_cont_action_dataset.tokenizer_parallel_batch_size=32 \
#     --eval_dataset.json_delta_cont_action_dataset.use_data_sharded_loader=True \
#     --checkpointer.save_optimizer_state=False \
#     --autoresume=False \
#     --logger.append_uuid=False \
#     --logger.online=True \
#     --logger.project_id="$project_id" \
#     --logger.experiment_id="$experiment_id" \
#     --logger.experiment_note="$experiment_note" \
#     --logger.output_dir="$output_dir" \
#     --logger.wandb_dir="$HOME/experiment_output/$project_id"




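# --- Archived (commented-out) run: modality vision,contaction,delta with --freeze=1 --mse_loss=1; delta_vocab_size=4, n_tokens_per_delta=2; GPUs 4-7; batch 32; 5002 steps ---
# export SCRIPT_DIR="$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"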
# export PROJECT_DIR="$( cd -- "$( dirname -- "$SCRIPT_DIR" )" &> /dev/null && pwd )"
# cd "$PROJECT_DIR"
# export PYTHONPATH="$PYTHONPATH:$PROJECT_DIR"
# export LIBTPU_INIT_ARGS="--xla_tpu_megacore_fusion_allow_ags=false --xla_enable_async_collective_permute=true --xla_tpu_enable_ag_backward_pipelining=true --xla_tpu_enable_data_parallel_all_reduce_opt=true --xla_tpu_data_parallel_opt_different_sized_ops=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true"

# export llama_tokenizer_path="llava/checkpoints/lwm_checkpoints/tokenizer.model"
# export dataset_path='/home/World-Model/data/bridge_simplerenv_tabletop_total_256_filtered_no_inst_delta_train.jsonl'
# export eval_dataset_path='/home/World-Model/data/bridge_simplerenv_tabletop_total_256_filtered_no_inst_delta_val.jsonl'


# export output_dir="llava/checkpoints/debug"

# export project_id='lwm'
# export experiment_note='world-model'
# export experiment_id='lwm_1000traj_bridge_bf16_from_whole_delta4_batch32_seq2_only_tabletop_total_no_inst_nolang_mse_freeze'
# # export experiment_id='debug'

# # mesh_dim: dp, fsdp, tp, sp
# CUDA_VISIBLE_DEVICES=4,5,6,7 python3 -u -m lwm.train \
#     --modality='vision,contaction,delta' \
#     --mesh_dim='!-1,4,1,1' \
#     --dtype='bf16' \
#     --total_steps=5002 \
#     --freeze=1 \
#     --mse_loss=1 \
#     --log_freq=1 \
#     --eval_steps=1 \
#     --save_model_freq=0 \
#     --eval_log_freq=100 \
#     --save_milestone_freq=2500 \
#     --load_llama_config='7b' \
#     --load_checkpoint='params::/home/World-Model/llava/checkpoints/debug/bridge_cross_attn_nsvq_code4_layer8_seq2_no_inst_noact_re/streaming_params_10725' \
#     --update_llama_config="dict(delta_vocab_size=4,theta=50000000,max_sequence_length=2048,use_flash_attention=True,scan_attention=True,scan_query_chunk_size=512,scan_key_chunk_size=1024,remat_attention='nothing_saveable',scan_mlp=True,scan_mlp_chunk_size=8192,remat_mlp='nothing_saveable',remat_block='nothing_saveable',scan_layers=True)" \
#     --tokenizer.vocab_file="$llama_tokenizer_path" \
#     --optimizer.type='adamw' \
#     --llama.delta_vocab_size=4 \
#     --optimizer.accumulate_gradient_steps=1 \
#     --optimizer.adamw_optimizer.weight_decay=0.1 \
#     --optimizer.adamw_optimizer.lr=2e-5 \
#     --optimizer.adamw_optimizer.end_lr=2e-5 \
#     --optimizer.adamw_optimizer.lr_warmup_steps=0 \
#     --optimizer.adamw_optimizer.lr_decay_steps=3000 \
#     --use_data_sharded_loader=True \
#     --train_dataset.type='json_vision_delta_cont_action' \
#     --train_dataset.delta_vision_cont_action_processor.fields_from_example='fields' \
#     --train_dataset.delta_vision_cont_action_processor.n_tokens_per_delta=2 \
#     --train_dataset.delta_vision_cont_action_processor.n_tokens_per_action=7 \
#     --train_dataset.delta_vision_cont_action_processor.max_n_frames=1 \
#     --train_dataset.json_delta_cont_action_dataset.mode="pad" \
#     --train_dataset.json_delta_cont_action_dataset.path="$dataset_path" \
#     --train_dataset.json_delta_cont_action_dataset.seq_length=384 \
#     --train_dataset.json_delta_cont_action_dataset.batch_size=32 \
#     --train_dataset.json_delta_cont_action_dataset.tokenizer_processes=1 \
#     --train_dataset.json_delta_cont_action_dataset.tokenizer_parallel_chunk_size=32 \
#     --train_dataset.json_delta_cont_action_dataset.tokenizer_parallel_batch_size=32 \
#     --train_dataset.json_delta_cont_action_dataset.use_data_sharded_loader=True \
#     --eval_dataset.type='json_vision_delta_cont_action' \
#     --eval_dataset.delta_vision_cont_action_processor.fields_from_example='fields' \
#     --eval_dataset.delta_vision_cont_action_processor.n_tokens_per_delta=2 \
#     --eval_dataset.delta_vision_cont_action_processor.n_tokens_per_action=7 \
#     --eval_dataset.delta_vision_cont_action_processor.max_n_frames=1 \
#     --eval_dataset.json_delta_cont_action_dataset.mode="pad" \
#     --eval_dataset.json_delta_cont_action_dataset.path="$eval_dataset_path" \
#     --eval_dataset.json_delta_cont_action_dataset.seq_length=384 \
#     --eval_dataset.json_delta_cont_action_dataset.batch_size=32 \
#     --eval_dataset.json_delta_cont_action_dataset.tokenizer_processes=1 \
#     --eval_dataset.json_delta_cont_action_dataset.tokenizer_parallel_chunk_size=32 \
#     --eval_dataset.json_delta_cont_action_dataset.tokenizer_parallel_batch_size=32 \
#     --eval_dataset.json_delta_cont_action_dataset.use_data_sharded_loader=True \
#     --checkpointer.save_optimizer_state=False \
#     --autoresume=False \
#     --logger.append_uuid=False \
#     --logger.online=True \
#     --logger.project_id="$project_id" \
#     --logger.experiment_id="$experiment_id" \
#     --logger.experiment_note="$experiment_note" \
#     --logger.output_dir="$output_dir" \
#     --logger.wandb_dir="$HOME/experiment_output/$project_id"



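# --- Archived (commented-out) run: modality vision,contaction,delta with --freeze=1 --mse_loss=0; delta_vocab_size=4, n_tokens_per_delta=2; GPUs 4-7; batch 32; 5002 steps ---
# export SCRIPT_DIR="$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"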
# export PROJECT_DIR="$( cd -- "$( dirname -- "$SCRIPT_DIR" )" &> /dev/null && pwd )"
# cd "$PROJECT_DIR"
# export PYTHONPATH="$PYTHONPATH:$PROJECT_DIR"
# export LIBTPU_INIT_ARGS="--xla_tpu_megacore_fusion_allow_ags=false --xla_enable_async_collective_permute=true --xla_tpu_enable_ag_backward_pipelining=true --xla_tpu_enable_data_parallel_all_reduce_opt=true --xla_tpu_data_parallel_opt_different_sized_ops=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true"

# export llama_tokenizer_path="llava/checkpoints/lwm_checkpoints/tokenizer.model"
# export dataset_path='/home/World-Model/data/bridge_simplerenv_tabletop_total_256_filtered_no_inst_delta_train.jsonl'
# export eval_dataset_path='/home/World-Model/data/bridge_simplerenv_tabletop_total_256_filtered_no_inst_delta_val.jsonl'


# export output_dir="llava/checkpoints/debug"

# export project_id='lwm'
# export experiment_note='world-model'
# export experiment_id='lwm_1000traj_bridge_bf16_from_whole_delta4_batch32_seq2_only_tabletop_total_no_inst_nolang_normal_freeze'
# # export experiment_id='debug'

# # mesh_dim: dp, fsdp, tp, sp
# CUDA_VISIBLE_DEVICES=4,5,6,7 python3 -u -m lwm.train \
#     --modality='vision,contaction,delta' \
#     --mesh_dim='!-1,4,1,1' \
#     --dtype='bf16' \
#     --total_steps=5002 \
#     --freeze=1 \
#     --mse_loss=0 \
#     --log_freq=1 \
#     --eval_steps=1 \
#     --save_model_freq=0 \
#     --eval_log_freq=100 \
#     --save_milestone_freq=2500 \
#     --load_llama_config='7b' \
#     --load_checkpoint='params::/home/World-Model/llava/checkpoints/debug/bridge_cross_attn_nsvq_code4_layer8_seq2_no_inst_noact_re/streaming_params_10725' \
#     --update_llama_config="dict(delta_vocab_size=4,theta=50000000,max_sequence_length=2048,use_flash_attention=True,scan_attention=True,scan_query_chunk_size=512,scan_key_chunk_size=1024,remat_attention='nothing_saveable',scan_mlp=True,scan_mlp_chunk_size=8192,remat_mlp='nothing_saveable',remat_block='nothing_saveable',scan_layers=True)" \
#     --tokenizer.vocab_file="$llama_tokenizer_path" \
#     --optimizer.type='adamw' \
#     --llama.delta_vocab_size=4 \
#     --optimizer.accumulate_gradient_steps=1 \
#     --optimizer.adamw_optimizer.weight_decay=0.1 \
#     --optimizer.adamw_optimizer.lr=2e-5 \
#     --optimizer.adamw_optimizer.end_lr=2e-5 \
#     --optimizer.adamw_optimizer.lr_warmup_steps=0 \
#     --optimizer.adamw_optimizer.lr_decay_steps=3000 \
#     --use_data_sharded_loader=True \
#     --train_dataset.type='json_vision_delta_cont_action' \
#     --train_dataset.delta_vision_cont_action_processor.fields_from_example='fields' \
#     --train_dataset.delta_vision_cont_action_processor.n_tokens_per_delta=2 \
#     --train_dataset.delta_vision_cont_action_processor.n_tokens_per_action=7 \
#     --train_dataset.delta_vision_cont_action_processor.max_n_frames=1 \
#     --train_dataset.json_delta_cont_action_dataset.mode="pad" \
#     --train_dataset.json_delta_cont_action_dataset.path="$dataset_path" \
#     --train_dataset.json_delta_cont_action_dataset.seq_length=384 \
#     --train_dataset.json_delta_cont_action_dataset.batch_size=32 \
#     --train_dataset.json_delta_cont_action_dataset.tokenizer_processes=1 \
#     --train_dataset.json_delta_cont_action_dataset.tokenizer_parallel_chunk_size=32 \
#     --train_dataset.json_delta_cont_action_dataset.tokenizer_parallel_batch_size=32 \
#     --train_dataset.json_delta_cont_action_dataset.use_data_sharded_loader=True \
#     --eval_dataset.type='json_vision_delta_cont_action' \
#     --eval_dataset.delta_vision_cont_action_processor.fields_from_example='fields' \
#     --eval_dataset.delta_vision_cont_action_processor.n_tokens_per_delta=2 \
#     --eval_dataset.delta_vision_cont_action_processor.n_tokens_per_action=7 \
#     --eval_dataset.delta_vision_cont_action_processor.max_n_frames=1 \
#     --eval_dataset.json_delta_cont_action_dataset.mode="pad" \
#     --eval_dataset.json_delta_cont_action_dataset.path="$eval_dataset_path" \
#     --eval_dataset.json_delta_cont_action_dataset.seq_length=384 \
#     --eval_dataset.json_delta_cont_action_dataset.batch_size=32 \
#     --eval_dataset.json_delta_cont_action_dataset.tokenizer_processes=1 \
#     --eval_dataset.json_delta_cont_action_dataset.tokenizer_parallel_chunk_size=32 \
#     --eval_dataset.json_delta_cont_action_dataset.tokenizer_parallel_batch_size=32 \
#     --eval_dataset.json_delta_cont_action_dataset.use_data_sharded_loader=True \
#     --checkpointer.save_optimizer_state=False \
#     --autoresume=False \
#     --logger.append_uuid=False \
#     --logger.online=True \
#     --logger.project_id="$project_id" \
#     --logger.experiment_id="$experiment_id" \
#     --logger.experiment_note="$experiment_note" \
#     --logger.output_dir="$output_dir" \
#     --logger.wandb_dir="$HOME/experiment_output/$project_id"





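# --- Archived (commented-out) run: modality vision,action,delta; action_vocab_size=238, delta_vocab_size=2; GPUs 4-7; batch 32; 10005 steps; initialised from lwm_1000traj_bridge_bf16_from_whole_delta2_batch32_seq4_only_tabletop_total_no_inst_nolang_10000/streaming_params_5000 ---
# export SCRIPT_DIR="$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"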
# export PROJECT_DIR="$( cd -- "$( dirname -- "$SCRIPT_DIR" )" &> /dev/null && pwd )"
# cd "$PROJECT_DIR"
# export PYTHONPATH="$PYTHONPATH:$PROJECT_DIR"
# export LIBTPU_INIT_ARGS="--xla_tpu_megacore_fusion_allow_ags=false --xla_enable_async_collective_permute=true --xla_tpu_enable_ag_backward_pipelining=true --xla_tpu_enable_data_parallel_all_reduce_opt=true --xla_tpu_data_parallel_opt_different_sized_ops=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true"

# export llama_tokenizer_path="llava/checkpoints/lwm_checkpoints/tokenizer.model"
# export dataset_path='/home/World-Model/data/bridge_simplerenv_tabletop_total_256_filtered_no_inst_train.jsonl'
# export eval_dataset_path='/home/World-Model/data/bridge_simplerenv_tabletop_total_256_filtered_no_inst_val.jsonl'


# export output_dir="llava/checkpoints/debug"

# export project_id='lwm'
# export experiment_note='world-model'
# export experiment_id='lwm_1000traj_bridge_bf16_from_whole_delta2_batch32_seq4_only_tabletop_total_no_inst_nolang_20000'

# # mesh_dim: dp, fsdp, tp, sp
# CUDA_VISIBLE_DEVICES=4,5,6,7 python3 -u -m lwm.train \
#     --modality='vision,action,delta' \
#     --mesh_dim='!-1,4,1,1' \
#     --dtype='bf16' \
#     --total_steps=10005 \
#     --log_freq=1 \
#     --delta_tokens=1 \
#     --eval_steps=1 \
#     --save_model_freq=0 \
#     --eval_log_freq=100 \
#     --save_milestone_freq=5000 \
#     --load_llama_config='7b' \
#     --load_checkpoint="params::/home/World-Model/llava/checkpoints/debug/lwm_1000traj_bridge_bf16_from_whole_delta2_batch32_seq4_only_tabletop_total_no_inst_nolang_10000/streaming_params_5000" \
#     --update_llama_config="dict(action_vocab_size=238,delta_vocab_size=2,theta=50000000,max_sequence_length=2048,use_flash_attention=True,scan_attention=True,scan_query_chunk_size=512,scan_key_chunk_size=1024,remat_attention='nothing_saveable',scan_mlp=True,scan_mlp_chunk_size=8192,remat_mlp='nothing_saveable',remat_block='nothing_saveable',scan_layers=True)" \
#     --tokenizer.vocab_file="$llama_tokenizer_path" \
#     --optimizer.type='adamw' \
#     --llama.delta_vocab_size=2 \
#     --llama.action_vocab_size=238 \
#     --optimizer.accumulate_gradient_steps=1 \
#     --optimizer.adamw_optimizer.weight_decay=0 \
#     --optimizer.adamw_optimizer.lr=2e-5 \
#     --optimizer.adamw_optimizer.end_lr=2e-5 \
#     --optimizer.adamw_optimizer.lr_warmup_steps=0 \
#     --optimizer.adamw_optimizer.lr_decay_steps=3000 \
#     --use_data_sharded_loader=True \
#     --train_dataset.type='json_vision_delta_action' \
#     --train_dataset.delta_vision_action_processor.fields_from_example='fields' \
#     --train_dataset.delta_vision_action_processor.n_tokens_per_delta=4 \
#     --train_dataset.delta_vision_action_processor.n_tokens_per_action=7 \
#     --train_dataset.delta_vision_action_processor.max_n_frames=1 \
#     --train_dataset.json_delta_action_dataset.mode="pad" \
#     --train_dataset.json_delta_action_dataset.path="$dataset_path" \
#     --train_dataset.json_delta_action_dataset.seq_length=384 \
#     --train_dataset.json_delta_action_dataset.batch_size=32 \
#     --train_dataset.json_delta_action_dataset.tokenizer_processes=1 \
#     --train_dataset.json_delta_action_dataset.tokenizer_parallel_chunk_size=32 \
#     --train_dataset.json_delta_action_dataset.tokenizer_parallel_batch_size=32 \
#     --train_dataset.json_delta_action_dataset.use_data_sharded_loader=True \
#     --eval_dataset.type='json_vision_delta_action' \
#     --eval_dataset.delta_vision_action_processor.fields_from_example='fields' \
#     --eval_dataset.delta_vision_action_processor.n_tokens_per_delta=4 \
#     --eval_dataset.delta_vision_action_processor.n_tokens_per_action=7 \
#     --eval_dataset.delta_vision_action_processor.max_n_frames=1 \
#     --eval_dataset.json_delta_action_dataset.mode="pad" \
#     --eval_dataset.json_delta_action_dataset.path="$eval_dataset_path" \
#     --eval_dataset.json_delta_action_dataset.seq_length=384 \
#     --eval_dataset.json_delta_action_dataset.batch_size=32 \
#     --eval_dataset.json_delta_action_dataset.tokenizer_processes=1 \
#     --eval_dataset.json_delta_action_dataset.tokenizer_parallel_chunk_size=32 \
#     --eval_dataset.json_delta_action_dataset.tokenizer_parallel_batch_size=32 \
#     --eval_dataset.json_delta_action_dataset.use_data_sharded_loader=True \
#     --checkpointer.save_optimizer_state=False \
#     --autoresume=False \
#     --logger.append_uuid=False \
#     --logger.online=True \
#     --logger.project_id="$project_id" \
#     --logger.experiment_id="$experiment_id" \
#     --logger.experiment_note="$experiment_note" \
#     --logger.output_dir="$output_dir" \
#     --logger.wandb_dir="$HOME/experiment_output/$project_id"





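# --- Archived (commented-out) run: modality vision,action (no delta modality); action_vocab_size=238; GPUs 4-7; batch 32; 10005 steps; initialised from lwm_checkpoints/params ---
# export SCRIPT_DIR="$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"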
# export PROJECT_DIR="$( cd -- "$( dirname -- "$SCRIPT_DIR" )" &> /dev/null && pwd )"
# cd "$PROJECT_DIR"
# export PYTHONPATH="$PYTHONPATH:$PROJECT_DIR"
# export LIBTPU_INIT_ARGS="--xla_tpu_megacore_fusion_allow_ags=false --xla_enable_async_collective_permute=true --xla_tpu_enable_ag_backward_pipelining=true --xla_tpu_enable_data_parallel_all_reduce_opt=true --xla_tpu_data_parallel_opt_different_sized_ops=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true"

# export llama_tokenizer_path="llava/checkpoints/lwm_checkpoints/tokenizer.model"
# export dataset_path='/home/World-Model/data/bridge_simplerenv_tabletop_total_256_filtered_no_inst_train.jsonl'
# export eval_dataset_path='/home/World-Model/data/bridge_simplerenv_tabletop_total_256_filtered_no_inst_val.jsonl'

# export output_dir="llava/checkpoints/debug"

# export project_id='lwm'
# export experiment_note='world-model'
# export experiment_id='lwm_1000traj_bridge_bf16_rt2_only_tabletop_total_re'

# # mesh_dim: dp, fsdp, tp, sp
# CUDA_VISIBLE_DEVICES=4,5,6,7 python3 -u -m lwm.train \
#     --modality='vision,action' \
#     --mesh_dim='!-1,4,1,1' \
#     --dtype='bf16' \
#     --total_steps=10005 \
#     --log_freq=1 \
#     --delta_tokens=1 \
#     --eval_steps=1 \
#     --save_model_freq=0 \
#     --eval_log_freq=100 \
#     --save_milestone_freq=5000 \
#     --load_llama_config='7b' \
#     --load_checkpoint="params::/home/World-Model/llava/checkpoints/lwm_checkpoints/params" \
#     --update_llama_config="dict(action_vocab_size=238,theta=50000000,max_sequence_length=2048,use_flash_attention=True,scan_attention=True,scan_query_chunk_size=512,scan_key_chunk_size=1024,remat_attention='nothing_saveable',scan_mlp=True,scan_mlp_chunk_size=8192,remat_mlp='nothing_saveable',remat_block='nothing_saveable',scan_layers=True)" \
#     --tokenizer.vocab_file="$llama_tokenizer_path" \
#     --optimizer.type='adamw' \
#     --llama.action_vocab_size=238 \
#     --optimizer.accumulate_gradient_steps=1 \
#     --optimizer.adamw_optimizer.weight_decay=0 \
#     --optimizer.adamw_optimizer.lr=2e-5 \
#     --optimizer.adamw_optimizer.end_lr=2e-5 \
#     --optimizer.adamw_optimizer.lr_warmup_steps=0 \
#     --optimizer.adamw_optimizer.lr_decay_steps=3000 \
#     --use_data_sharded_loader=True \
#     --train_dataset.type='json_vision_action' \
#     --train_dataset.vision_action_processor.fields_from_example='fields' \
#     --train_dataset.vision_action_processor.n_tokens_per_action=7 \
#     --train_dataset.vision_action_processor.max_n_frames=1 \
#     --train_dataset.json_vision_action_dataset.mode="pad" \
#     --train_dataset.json_vision_action_dataset.path="$dataset_path" \
#     --train_dataset.json_vision_action_dataset.seq_length=384 \
#     --train_dataset.json_vision_action_dataset.batch_size=32 \
#     --train_dataset.json_vision_action_dataset.tokenizer_processes=1 \
#     --train_dataset.json_vision_action_dataset.tokenizer_parallel_chunk_size=32 \
#     --train_dataset.json_vision_action_dataset.tokenizer_parallel_batch_size=32 \
#     --train_dataset.json_vision_action_dataset.use_data_sharded_loader=True \
#     --eval_dataset.type='json_vision_action' \
#     --eval_dataset.vision_action_processor.fields_from_example='fields' \
#     --eval_dataset.vision_action_processor.n_tokens_per_action=7 \
#     --eval_dataset.vision_action_processor.max_n_frames=1 \
#     --eval_dataset.json_vision_action_dataset.mode="pad" \
#     --eval_dataset.json_vision_action_dataset.path="$eval_dataset_path" \
#     --eval_dataset.json_vision_action_dataset.seq_length=384 \
#     --eval_dataset.json_vision_action_dataset.batch_size=32 \
#     --eval_dataset.json_vision_action_dataset.tokenizer_processes=1 \
#     --eval_dataset.json_vision_action_dataset.tokenizer_parallel_chunk_size=32 \
#     --eval_dataset.json_vision_action_dataset.tokenizer_parallel_batch_size=32 \
#     --eval_dataset.json_vision_action_dataset.use_data_sharded_loader=True \
#     --checkpointer.save_optimizer_state=False \
#     --autoresume=False \
#     --logger.append_uuid=False \
#     --logger.online=True \
#     --logger.project_id="$project_id" \
#     --logger.experiment_id="$experiment_id" \
#     --logger.experiment_note="$experiment_note" \
#     --logger.output_dir="$output_dir" \
#     --logger.wandb_dir="$HOME/experiment_output/$project_id"




# export SCRIPT_DIR="$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
# export PROJECT_DIR="$( cd -- "$( dirname -- "$SCRIPT_DIR" )" &> /dev/null && pwd )"
# cd $PROJECT_DIR
# export PYTHONPATH="$PYTHONPATH:$PROJECT_DIR"
# export LIBTPU_INIT_ARGS="--xla_tpu_megacore_fusion_allow_ags=false --xla_enable_async_collective_permute=true --xla_tpu_enable_ag_backward_pipelining=true --xla_tpu_enable_data_parallel_all_reduce_opt=true --xla_tpu_data_parallel_opt_different_sized_ops=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true"

# export llama_tokenizer_path="llava/checkpoints/lwm_checkpoints/tokenizer.model"
# export dataset_path='/home/World-Model/data/bridge_rollout_carrot_256.jsonl'


# export output_dir="llava/checkpoints/debug"

# export project_id='lwm'
# export experiment_note='world-model'
# export experiment_id='lwm_100traj_bridge_bf16_from_whole_delta2_batch32_seq4_carrot_rollout'

# # mesh_dim: dp, fsdp, tp, sp
# CUDA_VISIBLE_DEVICES=0,1,2,3 python3 -u -m lwm.train \
#     --modality='vision,action,delta' \
#     --mesh_dim='!-1,4,1,1' \
#     --dtype='bf16' \
#     --total_steps=7501 \
#     --log_freq=1 \
#     --delta_tokens=1 \
#     --eval_steps=0\
#     --save_model_freq=0 \
#     --eval_log_freq=0 \
#     --save_milestone_freq=2500 \
#     --load_llama_config='7b' \
#     --load_checkpoint="params::/home/World-Model/llava/checkpoints/debug/lwm_100traj_bridge_bf16_from_whole_delta2_batch32_seq4_carrot_rollout/streaming_params_2500"\
#     --update_llama_config="dict(action_vocab_size=215,delta_vocab_size=2,theta=50000000,max_sequence_length=2048,use_flash_attention=True,scan_attention=True,scan_query_chunk_size=512,scan_key_chunk_size=1024,remat_attention='nothing_saveable',scan_mlp=True,scan_mlp_chunk_size=8192,remat_mlp='nothing_saveable',remat_block='nothing_saveable',scan_layers=True)" \
#     --tokenizer.vocab_file="$llama_tokenizer_path" \
#     --optimizer.type='adamw' \
#     --llama.delta_vocab_size=2 \
#     --llama.action_vocab_size=215 \
#     --optimizer.accumulate_gradient_steps=1 \
#     --optimizer.adamw_optimizer.weight_decay=0 \
#     --optimizer.adamw_optimizer.lr=2e-5 \
#     --optimizer.adamw_optimizer.end_lr=2e-5 \
#     --optimizer.adamw_optimizer.lr_warmup_steps=0 \
#     --optimizer.adamw_optimizer.lr_decay_steps=3000 \
#     --use_data_sharded_loader=True \
#     --train_dataset.type='json_vision_delta_action' \
#     --train_dataset.delta_vision_action_processor.fields_from_example='fields' \
#     --train_dataset.delta_vision_action_processor.n_tokens_per_delta=4 \
#     --train_dataset.delta_vision_action_processor.n_tokens_per_action=7 \
#     --train_dataset.delta_vision_action_processor.max_n_frames=1 \
#     --train_dataset.json_delta_action_dataset.mode="pad" \
#     --train_dataset.json_delta_action_dataset.path="$dataset_path" \
#     --train_dataset.json_delta_action_dataset.seq_length=384 \
#     --train_dataset.json_delta_action_dataset.batch_size=32 \
#     --train_dataset.json_delta_action_dataset.tokenizer_processes=1 \
#     --train_dataset.json_delta_action_dataset.tokenizer_parallel_chunk_size=32 \
#     --train_dataset.json_delta_action_dataset.tokenizer_parallel_batch_size=32 \
#     --train_dataset.json_delta_action_dataset.use_data_sharded_loader=True \
#     --checkpointer.save_optimizer_state=False \
#     --autoresume=False \
#     --logger.append_uuid=False \
#     --logger.online=True \
#     --logger.project_id="$project_id" \
#     --logger.experiment_id="$experiment_id" \
#     --logger.experiment_note="$experiment_note" \
#     --logger.output_dir="$output_dir" \
#     --logger.wandb_dir="$HOME/experiment_output/$project_id"



# export SCRIPT_DIR="$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
# export PROJECT_DIR="$( cd -- "$( dirname -- "$SCRIPT_DIR" )" &> /dev/null && pwd )"
# cd $PROJECT_DIR
# export PYTHONPATH="$PYTHONPATH:$PROJECT_DIR"
# export LIBTPU_INIT_ARGS="--xla_tpu_megacore_fusion_allow_ags=false --xla_enable_async_collective_permute=true --xla_tpu_enable_ag_backward_pipelining=true --xla_tpu_enable_data_parallel_all_reduce_opt=true --xla_tpu_data_parallel_opt_different_sized_ops=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true"


# export llama_tokenizer_path="/mnt/default/lwm/data/checkpoints/lwm_checkpoints/tokenizer.model"

# export dataset_path='/mnt/default/lwm/data/train_data/bridge_window3_codebook8_codeseqlen9_shuffled_filtered_action_preprocessed.jsonl'
# export eval_dataset_path="/mnt/default/lwm/data/train_data/bridge_window3_codebook8_codeseqlen9_shuffled_filtered_action_preprocessed_validation.jsonl"

# export output_dir=$AMLT_OUTPUT_DIR
# export WANDB_API_KEY="<your-wandb-api-key>"  # placeholder: supply via the environment, never commit a real key


# export project_id='lwm'
# export experiment_note='world-model'
# export experiment_id='bridge_rt2_epoch_filtered_no_inst_epoch3_256'

# # mesh_dim: dp, fsdp, tp, sp
# python3 -u -m lwm.train \
#     --modality='vision,action' \
#     --mesh_dim='!-1,4,1,1' \
#     --dtype='bf16' \
#     --total_steps=31520 \
#     --log_freq=1 \
#     --eval_steps=1 \
#     --save_model_freq=0 \
#     --eval_log_freq=100 \
#     --save_milestone_freq=10503 \
#     --load_llama_config='7b' \
#     --load_checkpoint="params::/mnt/default/lwm/data/checkpoints/lwm_checkpoints/params" \
#     --update_llama_config="dict(action_vocab_size=256,theta=50000000,max_sequence_length=2048,use_flash_attention=True,scan_attention=True,scan_query_chunk_size=512,scan_key_chunk_size=1024,remat_attention='nothing_saveable',scan_mlp=True,scan_mlp_chunk_size=8192,remat_mlp='nothing_saveable',remat_block='nothing_saveable',scan_layers=True)" \
#     --tokenizer.vocab_file="$llama_tokenizer_path" \
#     --optimizer.type='adamw' \
#     --llama.action_vocab_size=256 \
#     --optimizer.accumulate_gradient_steps=1 \
#     --optimizer.adamw_optimizer.weight_decay=0 \
#     --optimizer.adamw_optimizer.lr=2e-5 \
#     --optimizer.adamw_optimizer.end_lr=2e-5 \
#     --optimizer.adamw_optimizer.lr_warmup_steps=0 \
#     --optimizer.adamw_optimizer.lr_decay_steps=45861 \
#     --use_data_sharded_loader=True \
#     --train_dataset.vision_action_processor.fields_from_example='fields' \
#     --train_dataset.vision_action_processor.n_tokens_per_action=7 \
#     --train_dataset.vision_action_processor.max_n_frames=1 \
#     --train_dataset.json_vision_action_dataset.mode="pad" \
#     --train_dataset.json_vision_action_dataset.path="$dataset_path" \
#     --train_dataset.json_vision_action_dataset.seq_length=384 \
#     --train_dataset.json_vision_action_dataset.batch_size=128 \
#     --train_dataset.json_vision_action_dataset.tokenizer_processes=1 \
#     --train_dataset.json_vision_action_dataset.tokenizer_parallel_chunk_size=128 \
#     --train_dataset.json_vision_action_dataset.tokenizer_parallel_batch_size=128 \
#     --train_dataset.json_vision_action_dataset.use_data_sharded_loader=True \
#     --eval_dataset.type='json_vision_action' \
#     --eval_dataset.vision_action_processor.fields_from_example='fields' \
#     --eval_dataset.vision_action_processor.n_tokens_per_action=7 \
#     --eval_dataset.vision_action_processor.max_n_frames=1 \
#     --eval_dataset.json_vision_action_dataset.mode="pad" \
#     --eval_dataset.json_vision_action_dataset.path="$eval_dataset_path" \
#     --eval_dataset.json_vision_action_dataset.seq_length=384 \
#     --eval_dataset.json_vision_action_dataset.batch_size=128 \
#     --eval_dataset.json_vision_action_dataset.tokenizer_processes=1 \
#     --eval_dataset.json_vision_action_dataset.tokenizer_parallel_chunk_size=128 \
#     --eval_dataset.json_vision_action_dataset.tokenizer_parallel_batch_size=128 \
#     --eval_dataset.json_vision_action_dataset.use_data_sharded_loader=True \
#     --checkpointer.save_optimizer_state=False \
#     --autoresume=False \
#     --logger.append_uuid=False \
#     --logger.online=True \
#     --logger.project_id="$project_id" \
#     --logger.experiment_id="$experiment_id" \
#     --logger.experiment_note="$experiment_note" \
#     --logger.output_dir="$output_dir" \
#     --logger.wandb_dir="$HOME/experiment_output/$project_id"


# Weights & Biases credentials must be supplied by the caller's environment
# (e.g. `WANDB_API_KEY=... bash <this script>`) rather than stored in the
# script, where anyone with read access to the repository can see them.
# NOTE(review): the key that used to be hard-coded on this line should be
# treated as leaked and revoked.
if [[ -n "${WANDB_API_KEY:-}" ]]; then
  # Re-export so the value reaches the training process even if the caller
  # only set it as a plain shell variable before sourcing this file.
  export WANDB_API_KEY
else
  # Warn instead of aborting: credentials may already be configured by other
  # means (presumably a prior `wandb login` -- confirm for this setup).
  printf '%s\n' 'warning: WANDB_API_KEY is not set; online W&B logging may fail unless wandb credentials are configured elsewhere.' >&2
fi

# Resolve the directory that contains this script (SCRIPT_DIR) and its parent
# (PROJECT_DIR), then switch to PROJECT_DIR so the trainer is launched from it.
# Assignment is kept separate from `export` so a failing `cd`/`pwd` inside the
# command substitution is detected instead of being masked by `export`'s own
# exit status. The `cd` target is quoted so paths containing spaces work.
if ! SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)" ||
  ! PROJECT_DIR="$(cd -- "$(dirname -- "$SCRIPT_DIR")" &>/dev/null && pwd)" ||
  ! cd -- "$PROJECT_DIR"; then
  printf '%s\n' 'error: cannot resolve or enter the project directory' >&2
  # `return` succeeds only when this file is sourced; otherwise fall back to exit.
  return 1 2>/dev/null || exit 1
fi
export SCRIPT_DIR PROJECT_DIR

# Append the project directory to PYTHONPATH. The ${var:+...} form avoids a
# leading ':' (an empty path entry) when PYTHONPATH is unset or empty.
export PYTHONPATH="${PYTHONPATH:+${PYTHONPATH}:}${PROJECT_DIR}"

# Flags exported through LIBTPU_INIT_ARGS, listed one per line for
# readability and joined with single spaces into one string.
libtpu_flags=(
  --xla_tpu_megacore_fusion_allow_ags=false
  --xla_enable_async_collective_permute=true
  --xla_tpu_enable_ag_backward_pipelining=true
  --xla_tpu_enable_data_parallel_all_reduce_opt=true
  --xla_tpu_data_parallel_opt_different_sized_ops=true
  --xla_tpu_enable_async_collective_fusion=true
  --xla_tpu_enable_async_collective_fusion_multiple_steps=true
  --xla_tpu_overlap_compute_collective_tc=true
  --xla_enable_async_all_gather=true
)
export LIBTPU_INIT_ARGS="${libtpu_flags[*]}"

# Input locations: tokenizer model file and the training dataset (.jsonl).
llama_tokenizer_path='/mnt/default/lwm/data/checkpoints/lwm_checkpoints/tokenizer.model'
dataset_path="/mnt/default/lwm/data/unipi/vpt_preprocessed_bridge/bridge_total_vpt.jsonl"

# Output location for this run.
output_dir="/mnt/default/lwm/data/unipi/lwm_vpt"

# Labels identifying the project and experiment.
project_id="lwm"
experiment_note="world-model"
experiment_id="bridge_vpt_whole"

# Export everything in one place so child processes can read the values.
export llama_tokenizer_path dataset_path output_dir \
  project_id experiment_note experiment_id

# Launch lwm.train. The arguments are gathered in an array (instead of one
# backslash-continued command) so related flags can be grouped and commented;
# the argument list handed to python3 is the same, in the same order.
train_args=(
  # Modalities, device mesh and numeric precision.
  # mesh_dim: dp, fsdp, tp, sp
  --modality='vision,action'
  --mesh_dim='!-1,8,1,1'
  --dtype='bf16'

  # Step budget and logging / evaluation / checkpoint cadence.
  --total_steps=45865
  --log_freq=1
  --eval_steps=0
  --save_model_freq=15287
  --eval_log_freq=100
  --save_milestone_freq=15287

  # Model configuration and the checkpoint to load.
  --load_llama_config='7b'
  --load_checkpoint="params::/mnt/default/lwm/data/checkpoints/lwm_checkpoints/params"
  --update_llama_config="dict(action_vocab_size=245,theta=50000000,max_sequence_length=2048,use_flash_attention=True,scan_attention=True,scan_query_chunk_size=512,scan_key_chunk_size=1024,remat_attention='nothing_saveable',scan_mlp=True,scan_mlp_chunk_size=8192,remat_mlp='nothing_saveable',remat_block='nothing_saveable',scan_layers=True)"
  --tokenizer.vocab_file="$llama_tokenizer_path"

  # Optimizer settings (the action vocab size flag keeps its original position).
  --optimizer.type='adamw'
  --llama.action_vocab_size=245
  --optimizer.accumulate_gradient_steps=1
  --optimizer.adamw_optimizer.weight_decay=0
  --optimizer.adamw_optimizer.lr=2e-5
  --optimizer.adamw_optimizer.end_lr=2e-5
  --optimizer.adamw_optimizer.lr_warmup_steps=0
  --optimizer.adamw_optimizer.lr_decay_steps=45861

  # Training dataset and its processor.
  --use_data_sharded_loader=True
  --train_dataset.type='json_vision_action'
  --train_dataset.vision_action_processor.fields_from_example='fields'
  --train_dataset.vision_action_processor.n_tokens_per_action=7
  --train_dataset.vision_action_processor.max_n_frames=1
  --train_dataset.json_vision_action_dataset.mode="pad"
  --train_dataset.json_vision_action_dataset.path="$dataset_path"
  --train_dataset.json_vision_action_dataset.seq_length=384
  --train_dataset.json_vision_action_dataset.batch_size=128
  --train_dataset.json_vision_action_dataset.tokenizer_processes=1
  --train_dataset.json_vision_action_dataset.tokenizer_parallel_chunk_size=128
  --train_dataset.json_vision_action_dataset.tokenizer_parallel_batch_size=128
  --train_dataset.json_vision_action_dataset.use_data_sharded_loader=True

  # Checkpointing and resume behaviour.
  --checkpointer.save_optimizer_state=False
  --autoresume=False

  # Experiment logger.
  --logger.append_uuid=False
  --logger.online=True
  --logger.project_id="$project_id"
  --logger.experiment_id="$experiment_id"
  --logger.experiment_note="$experiment_note"
  --logger.output_dir="$output_dir"
  --logger.wandb_dir="$HOME/experiment_output/$project_id"
)

python3 -u -m lwm.train "${train_args[@]}"
