
# export SCRIPT_DIR="$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
# export PROJECT_DIR="$( cd -- "$( dirname -- "$SCRIPT_DIR" )" &> /dev/null && pwd )"
# cd $PROJECT_DIR
# export PYTHONPATH="$PYTHONPATH:$PROJECT_DIR"
# export LIBTPU_INIT_ARGS="--xla_tpu_megacore_fusion_allow_ags=false --xla_enable_async_collective_permute=true --xla_tpu_enable_ag_backward_pipelining=true --xla_tpu_enable_data_parallel_all_reduce_opt=true --xla_tpu_data_parallel_opt_different_sized_ops=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true"

# export llama_tokenizer_path="/mnt/default/lwm/data/checkpoints/lwm_checkpoints/tokenizer.model"

# # export dataset_path='data/0808_multiobject_sink_llava_5hz_action_preprocessed.jsonl'
# export dataset_path='data/0809_multiobject_sink_llava_256.jsonl'
# # export eval_dataset_path="data/0731_multi_200_5hz_val.jsonl"

# export output_dir=$AMLT_OUTPUT_DIR
# : "${WANDB_API_KEY:?WANDB_API_KEY must be provided via the environment/secret store; never hard-code it}"

# export project_id='lwm'
# export experiment_note='world-model'
# export experiment_id='latent_finetune_150traj_multitask_to_sink_re'

# # mesh_dim: dp, fsdp, tp, sp
# python3 -u -m lwm.train \
#     --modality='vision,action,delta' \
#     --mesh_dim='!-1,4,1,1' \
#     --dtype='bf16' \
#     --total_steps=2005 \
#     --log_freq=1 \
#     --delta_tokens=1 \
#     --eval_steps=0 \
#     --save_model_freq=0 \
#     --eval_log_freq=100 \
#     --save_milestone_freq=100 \
#     --load_llama_config='7b' \
#     --load_checkpoint="params::/mnt/default/projects/lwm/amlt-results/7277524497.23794-ffa900cf-6610-43f0-934f-c10572ffca61/lwm_total_bridge_bf16_from_whole_delta8_batch32_seq4_filtered_no_inst_epoch3/streaming_params_10503" \
#     --update_llama_config="dict(action_vocab_size=256,delta_vocab_size=8,theta=50000000,max_sequence_length=2048,use_flash_attention=True,scan_attention=True,scan_query_chunk_size=512,scan_key_chunk_size=1024,remat_attention='nothing_saveable',scan_mlp=True,scan_mlp_chunk_size=8192,remat_mlp='nothing_saveable',remat_block='nothing_saveable',scan_layers=True)" \
#     --tokenizer.vocab_file="$llama_tokenizer_path" \
#     --optimizer.type='adamw' \
#     --llama.action_vocab_size=256 \
#     --llama.delta_vocab_size=8 \
#     --optimizer.accumulate_gradient_steps=1 \
#     --optimizer.adamw_optimizer.weight_decay=0 \
#     --optimizer.adamw_optimizer.lr=2e-5 \
#     --optimizer.adamw_optimizer.end_lr=2e-5 \
#     --optimizer.adamw_optimizer.lr_warmup_steps=0 \
#     --optimizer.adamw_optimizer.lr_decay_steps=3000 \
#     --use_data_sharded_loader=True \
#     --train_dataset.type='json_vision_delta_action' \
#     --train_dataset.delta_vision_action_processor.fields_from_example='fields' \
#     --train_dataset.delta_vision_action_processor.n_tokens_per_action=7 \
#     --train_dataset.delta_vision_action_processor.n_tokens_per_delta=4 \
#     --train_dataset.delta_vision_action_processor.max_n_frames=1 \
#     --train_dataset.json_delta_action_dataset.mode="pad" \
#     --train_dataset.json_delta_action_dataset.path="$dataset_path" \
#     --train_dataset.json_delta_action_dataset.seq_length=384 \
#     --train_dataset.json_delta_action_dataset.batch_size=128 \
#     --train_dataset.json_delta_action_dataset.tokenizer_processes=1 \
#     --train_dataset.json_delta_action_dataset.tokenizer_parallel_chunk_size=128 \
#     --train_dataset.json_delta_action_dataset.tokenizer_parallel_batch_size=128 \
#     --train_dataset.json_delta_action_dataset.use_data_sharded_loader=True \
#     --checkpointer.save_optimizer_state=False \
#     --autoresume=False \
#     --logger.append_uuid=False \
#     --logger.online=True \
#     --logger.project_id="$project_id" \
#     --logger.experiment_id="$experiment_id" \
#     --logger.experiment_note="$experiment_note" \
#     --logger.output_dir="$output_dir" \
#     --logger.wandb_dir="$HOME/experiment_output/$project_id"



# # --eval_dataset.type='json_vision_delta_action' \
#     # --eval_dataset.delta_vision_action_processor.fields_from_example='fields' \
#     # --eval_dataset.delta_vision_action_processor.n_tokens_per_action=7 \
#     # --eval_dataset.delta_vision_action_processor.n_tokens_per_delta=4 \
#     # --eval_dataset.delta_vision_action_processor.max_n_frames=1 \
#     # --eval_dataset.json_delta_action_dataset.mode="pad" \
#     # --eval_dataset.json_delta_action_dataset.path="$eval_dataset_path" \
#     # --eval_dataset.json_delta_action_dataset.seq_length=384 \
#     # --eval_dataset.json_delta_action_dataset.batch_size=128 \
#     # --eval_dataset.json_delta_action_dataset.tokenizer_processes=1 \
#     # --eval_dataset.json_delta_action_dataset.tokenizer_parallel_chunk_size=128 \
#     # --eval_dataset.json_delta_action_dataset.tokenizer_parallel_batch_size=128 \
#     # --eval_dataset.json_delta_action_dataset.use_data_sharded_loader=True \




# export SCRIPT_DIR="$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
# export PROJECT_DIR="$( cd -- "$( dirname -- "$SCRIPT_DIR" )" &> /dev/null && pwd )"
# cd $PROJECT_DIR
# export PYTHONPATH="$PYTHONPATH:$PROJECT_DIR"
# export LIBTPU_INIT_ARGS="--xla_tpu_megacore_fusion_allow_ags=false --xla_enable_async_collective_permute=true --xla_tpu_enable_ag_backward_pipelining=true --xla_tpu_enable_data_parallel_all_reduce_opt=true --xla_tpu_data_parallel_opt_different_sized_ops=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true"

# export llama_tokenizer_path="/mnt/default/lwm/data/checkpoints/lwm_checkpoints/tokenizer.model"

# # export dataset_path='data/0808_multiobject_sink_llava_5hz_action_preprocessed.jsonl'
# export dataset_path='data/0809_multiobject_sink_llava_256.jsonl'
# # export eval_dataset_path="data/0731_multi_200_5hz_val.jsonl"

# export output_dir=$AMLT_OUTPUT_DIR
# : "${WANDB_API_KEY:?WANDB_API_KEY must be provided via the environment/secret store; never hard-code it}"

# export project_id='lwm'
# export experiment_note='world-model'
# export experiment_id='latent_finetune_100traj_multitask2sink_15hz_imgaug_re'

# # mesh_dim: dp, fsdp, tp, sp
# python3 -u -m lwm.train \
#     --modality='vision,action,delta' \
#     --mesh_dim='!-1,4,1,1' \
#     --dtype='bf16' \
#     --total_steps=4005 \
#     --log_freq=1 \
#     --delta_tokens=1 \
#     --eval_steps=0 \
#     --save_model_freq=0 \
#     --eval_log_freq=100 \
#     --save_milestone_freq=100 \
#     --load_llama_config='7b' \
#     --load_checkpoint="params::/mnt/default/projects/lwm/amlt-results/7277524497.23794-ffa900cf-6610-43f0-934f-c10572ffca61/lwm_total_bridge_bf16_from_whole_delta8_batch32_seq4_filtered_no_inst_epoch3/streaming_params_10503" \
#     --update_llama_config="dict(action_vocab_size=256,delta_vocab_size=8,theta=50000000,max_sequence_length=2048,use_flash_attention=True,scan_attention=True,scan_query_chunk_size=512,scan_key_chunk_size=1024,remat_attention='nothing_saveable',scan_mlp=True,scan_mlp_chunk_size=8192,remat_mlp='nothing_saveable',remat_block='nothing_saveable',scan_layers=True)" \
#     --tokenizer.vocab_file="$llama_tokenizer_path" \
#     --optimizer.type='adamw' \
#     --llama.action_vocab_size=256 \
#     --llama.delta_vocab_size=8 \
#     --optimizer.accumulate_gradient_steps=1 \
#     --optimizer.adamw_optimizer.weight_decay=0 \
#     --optimizer.adamw_optimizer.lr=2e-5 \
#     --optimizer.adamw_optimizer.end_lr=2e-5 \
#     --optimizer.adamw_optimizer.lr_warmup_steps=0 \
#     --optimizer.adamw_optimizer.lr_decay_steps=3000 \
#     --use_data_sharded_loader=True \
#     --train_dataset.type='json_vision_delta_action' \
#     --train_dataset.delta_vision_action_processor.fields_from_example='fields' \
#     --train_dataset.delta_vision_action_processor.n_tokens_per_action=7 \
#     --train_dataset.delta_vision_action_processor.n_tokens_per_delta=4 \
#     --train_dataset.delta_vision_action_processor.img_aug=True \
#     --train_dataset.delta_vision_action_processor.vqgan_checkpoint_path='/mnt/default/lwm/data/checkpoints/lwm_checkpoints/vqgan'\
#     --train_dataset.delta_vision_action_processor.image_absolute_path='/mnt/default/lwm/data/finetune_data/'\
#     --train_dataset.delta_vision_action_processor.max_n_frames=1 \
#     --train_dataset.json_delta_action_dataset.mode="pad" \
#     --train_dataset.json_delta_action_dataset.path="$dataset_path" \
#     --train_dataset.json_delta_action_dataset.seq_length=384 \
#     --train_dataset.json_delta_action_dataset.batch_size=128 \
#     --train_dataset.json_delta_action_dataset.tokenizer_processes=1 \
#     --train_dataset.json_delta_action_dataset.tokenizer_parallel_chunk_size=128 \
#     --train_dataset.json_delta_action_dataset.tokenizer_parallel_batch_size=128 \
#     --train_dataset.json_delta_action_dataset.use_data_sharded_loader=True \
#     --checkpointer.save_optimizer_state=False \
#     --autoresume=False \
#     --logger.append_uuid=False \
#     --logger.online=True \
#     --logger.project_id="$project_id" \
#     --logger.experiment_id="$experiment_id" \
#     --logger.experiment_note="$experiment_note" \
#     --logger.output_dir="$output_dir" \
#     --logger.wandb_dir="$HOME/experiment_output/$project_id"


# export SCRIPT_DIR="$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
# export PROJECT_DIR="$( cd -- "$( dirname -- "$SCRIPT_DIR" )" &> /dev/null && pwd )"
# cd $PROJECT_DIR
# export PYTHONPATH="$PYTHONPATH:$PROJECT_DIR"
# export LIBTPU_INIT_ARGS="--xla_tpu_megacore_fusion_allow_ags=false --xla_enable_async_collective_permute=true --xla_tpu_enable_ag_backward_pipelining=true --xla_tpu_enable_data_parallel_all_reduce_opt=true --xla_tpu_data_parallel_opt_different_sized_ops=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true"

# export llama_tokenizer_path="/mnt/default/lwm/data/checkpoints/lwm_checkpoints/tokenizer.model"

# # export dataset_path='data/0808_multiobject_sink_llava_5hz_action_preprocessed.jsonl'
# export dataset_path='data/0809_multiobject_sink_llava_5hz_256.jsonl'
# # export eval_dataset_path="data/0731_multi_200_5hz_val.jsonl"

# export output_dir=$AMLT_OUTPUT_DIR
# : "${WANDB_API_KEY:?WANDB_API_KEY must be provided via the environment/secret store; never hard-code it}"

# export project_id='lwm'
# export experiment_note='world-model'
# export experiment_id='latent_finetune_100traj_multitask2sink_5hz_imgaug'

# # mesh_dim: dp, fsdp, tp, sp
# python3 -u -m lwm.train \
#     --modality='vision,action,delta' \
#     --mesh_dim='!-1,4,1,1' \
#     --dtype='bf16' \
#     --total_steps=2005 \
#     --log_freq=1 \
#     --delta_tokens=1 \
#     --eval_steps=0 \
#     --save_model_freq=0 \
#     --eval_log_freq=100 \
#     --save_milestone_freq=100 \
#     --load_llama_config='7b' \
#     --load_checkpoint="params::/mnt/default/projects/lwm/amlt-results/7277524497.23794-ffa900cf-6610-43f0-934f-c10572ffca61/lwm_total_bridge_bf16_from_whole_delta8_batch32_seq4_filtered_no_inst_epoch3/streaming_params_10503" \
#     --update_llama_config="dict(action_vocab_size=256,delta_vocab_size=8,theta=50000000,max_sequence_length=2048,use_flash_attention=True,scan_attention=True,scan_query_chunk_size=512,scan_key_chunk_size=1024,remat_attention='nothing_saveable',scan_mlp=True,scan_mlp_chunk_size=8192,remat_mlp='nothing_saveable',remat_block='nothing_saveable',scan_layers=True)" \
#     --tokenizer.vocab_file="$llama_tokenizer_path" \
#     --optimizer.type='adamw' \
#     --llama.action_vocab_size=256 \
#     --llama.delta_vocab_size=8 \
#     --optimizer.accumulate_gradient_steps=1 \
#     --optimizer.adamw_optimizer.weight_decay=0 \
#     --optimizer.adamw_optimizer.lr=2e-5 \
#     --optimizer.adamw_optimizer.end_lr=2e-5 \
#     --optimizer.adamw_optimizer.lr_warmup_steps=0 \
#     --optimizer.adamw_optimizer.lr_decay_steps=3000 \
#     --use_data_sharded_loader=True \
#     --train_dataset.type='json_vision_delta_action' \
#     --train_dataset.delta_vision_action_processor.fields_from_example='fields' \
#     --train_dataset.delta_vision_action_processor.n_tokens_per_action=7 \
#     --train_dataset.delta_vision_action_processor.n_tokens_per_delta=4 \
#     --train_dataset.delta_vision_action_processor.img_aug=True \
#     --train_dataset.delta_vision_action_processor.vqgan_checkpoint_path='/mnt/default/lwm/data/checkpoints/lwm_checkpoints/vqgan'\
#     --train_dataset.delta_vision_action_processor.image_absolute_path='/mnt/default/lwm/data/finetune_data/'\
#     --train_dataset.delta_vision_action_processor.max_n_frames=1 \
#     --train_dataset.json_delta_action_dataset.mode="pad" \
#     --train_dataset.json_delta_action_dataset.path="$dataset_path" \
#     --train_dataset.json_delta_action_dataset.seq_length=384 \
#     --train_dataset.json_delta_action_dataset.batch_size=128 \
#     --train_dataset.json_delta_action_dataset.tokenizer_processes=1 \
#     --train_dataset.json_delta_action_dataset.tokenizer_parallel_chunk_size=128 \
#     --train_dataset.json_delta_action_dataset.tokenizer_parallel_batch_size=128 \
#     --train_dataset.json_delta_action_dataset.use_data_sharded_loader=True \
#     --checkpointer.save_optimizer_state=False \
#     --autoresume=False \
#     --logger.append_uuid=False \
#     --logger.online=True \
#     --logger.project_id="$project_id" \
#     --logger.experiment_id="$experiment_id" \
#     --logger.experiment_note="$experiment_note" \
#     --logger.output_dir="$output_dir" \
#     --logger.wandb_dir="$HOME/experiment_output/$project_id"




# export SCRIPT_DIR="$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
# export PROJECT_DIR="$( cd -- "$( dirname -- "$SCRIPT_DIR" )" &> /dev/null && pwd )"
# cd $PROJECT_DIR
# export PYTHONPATH="$PYTHONPATH:$PROJECT_DIR"
# export LIBTPU_INIT_ARGS="--xla_tpu_megacore_fusion_allow_ags=false --xla_enable_async_collective_permute=true --xla_tpu_enable_ag_backward_pipelining=true --xla_tpu_enable_data_parallel_all_reduce_opt=true --xla_tpu_data_parallel_opt_different_sized_ops=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true"

# export llama_tokenizer_path="/mnt/default/lwm/data/checkpoints/lwm_checkpoints/tokenizer.model"

# # export dataset_path='data/0808_multiobject_sink_llava_5hz_action_preprocessed.jsonl'
# # export dataset_path='data/0809_multiobject_sink_llava_256.jsonl'
# export dataset_path='data/0809_multiobject_sink_llava_5hz_256_filtered.jsonl'
# # export eval_dataset_path="data/0731_multi_200_5hz_val.jsonl"

# export output_dir=$AMLT_OUTPUT_DIR
# : "${WANDB_API_KEY:?WANDB_API_KEY must be provided via the environment/secret store; never hard-code it}"

# export project_id='lwm'
# export experiment_note='world-model'
# export experiment_id='latent_finetune_100traj_multitask2sink_5hz_imgaug_filtered_epoch3'

# # mesh_dim: dp, fsdp, tp, sp
# python3 -u -m lwm.train \
#     --modality='vision,action,delta' \
#     --mesh_dim='!-1,4,1,1' \
#     --dtype='bf16' \
#     --total_steps=2005 \
#     --log_freq=1 \
#     --delta_tokens=1 \
#     --eval_steps=0 \
#     --save_model_freq=0 \
#     --eval_log_freq=100 \
#     --save_milestone_freq=100 \
#     --load_llama_config='7b' \
#     --load_checkpoint="params::/mnt/default/projects/lwm/amlt-results/7277371979.87448-e016daf0-ac38-43ff-bb99-d2a50e63f89c/lwm_total_bridge_bf16_from_whole_delta8_batch32_seq4_filtered_no_inst_epoch3/streaming_params_10503" \
#     --update_llama_config="dict(action_vocab_size=256,delta_vocab_size=8,theta=50000000,max_sequence_length=2048,use_flash_attention=True,scan_attention=True,scan_query_chunk_size=512,scan_key_chunk_size=1024,remat_attention='nothing_saveable',scan_mlp=True,scan_mlp_chunk_size=8192,remat_mlp='nothing_saveable',remat_block='nothing_saveable',scan_layers=True)" \
#     --tokenizer.vocab_file="$llama_tokenizer_path" \
#     --optimizer.type='adamw' \
#     --llama.action_vocab_size=256 \
#     --llama.delta_vocab_size=8 \
#     --optimizer.accumulate_gradient_steps=1 \
#     --optimizer.adamw_optimizer.weight_decay=0 \
#     --optimizer.adamw_optimizer.lr=2e-5 \
#     --optimizer.adamw_optimizer.end_lr=2e-5 \
#     --optimizer.adamw_optimizer.lr_warmup_steps=0 \
#     --optimizer.adamw_optimizer.lr_decay_steps=3000 \
#     --use_data_sharded_loader=True \
#     --train_dataset.type='json_vision_delta_action' \
#     --train_dataset.delta_vision_action_processor.fields_from_example='fields' \
#     --train_dataset.delta_vision_action_processor.n_tokens_per_action=7 \
#     --train_dataset.delta_vision_action_processor.n_tokens_per_delta=4 \
#     --train_dataset.delta_vision_action_processor.img_aug=True \
#     --train_dataset.delta_vision_action_processor.vqgan_checkpoint_path='/mnt/default/lwm/data/checkpoints/lwm_checkpoints/vqgan'\
#     --train_dataset.delta_vision_action_processor.image_absolute_path='/mnt/default/lwm/data/finetune_data/'\
#     --train_dataset.delta_vision_action_processor.max_n_frames=1 \
#     --train_dataset.json_delta_action_dataset.mode="pad" \
#     --train_dataset.json_delta_action_dataset.path="$dataset_path" \
#     --train_dataset.json_delta_action_dataset.seq_length=384 \
#     --train_dataset.json_delta_action_dataset.batch_size=128 \
#     --train_dataset.json_delta_action_dataset.tokenizer_processes=1 \
#     --train_dataset.json_delta_action_dataset.tokenizer_parallel_chunk_size=128 \
#     --train_dataset.json_delta_action_dataset.tokenizer_parallel_batch_size=128 \
#     --train_dataset.json_delta_action_dataset.use_data_sharded_loader=True \
#     --checkpointer.save_optimizer_state=False \
#     --autoresume=False \
#     --logger.append_uuid=False \
#     --logger.online=True \
#     --logger.project_id="$project_id" \
#     --logger.experiment_id="$experiment_id" \
#     --logger.experiment_note="$experiment_note" \
#     --logger.output_dir="$output_dir" \
#     --logger.wandb_dir="$HOME/experiment_output/$project_id"




# export SCRIPT_DIR="$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
# export PROJECT_DIR="$( cd -- "$( dirname -- "$SCRIPT_DIR" )" &> /dev/null && pwd )"
# cd $PROJECT_DIR
# export PYTHONPATH="$PYTHONPATH:$PROJECT_DIR"
# export LIBTPU_INIT_ARGS="--xla_tpu_megacore_fusion_allow_ags=false --xla_enable_async_collective_permute=true --xla_tpu_enable_ag_backward_pipelining=true --xla_tpu_enable_data_parallel_all_reduce_opt=true --xla_tpu_data_parallel_opt_different_sized_ops=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true"

# export llama_tokenizer_path="/mnt/default/lwm/data/checkpoints/lwm_checkpoints/tokenizer.model"

# # export dataset_path='data/0808_multiobject_sink_llava_5hz_action_preprocessed.jsonl'
# # export dataset_path='data/0809_multiobject_sink_llava_256.jsonl'
# export dataset_path='data/0809_multiobject_sink_llava_256_filtered.jsonl'
# # export eval_dataset_path="data/0731_multi_200_5hz_val.jsonl"

# export output_dir=$AMLT_OUTPUT_DIR
# : "${WANDB_API_KEY:?WANDB_API_KEY must be provided via the environment/secret store; never hard-code it}"

# export project_id='lwm'
# export experiment_note='world-model'
# export experiment_id='latent_finetune_100traj_multitask2sink_15hz_imgaug_filtered'

# # mesh_dim: dp, fsdp, tp, sp
# python3 -u -m lwm.train \
#     --modality='vision,action,delta' \
#     --mesh_dim='!-1,4,1,1' \
#     --dtype='bf16' \
#     --total_steps=4005 \
#     --log_freq=1 \
#     --delta_tokens=1 \
#     --eval_steps=0 \
#     --save_model_freq=0 \
#     --eval_log_freq=100 \
#     --save_milestone_freq=100 \
#     --load_llama_config='7b' \
#     --load_checkpoint="params::/mnt/default/projects/lwm/amlt-results/7277524497.23794-ffa900cf-6610-43f0-934f-c10572ffca61/lwm_total_bridge_bf16_from_whole_delta8_batch32_seq4_filtered_no_inst_epoch3/streaming_params_10503" \
#     --update_llama_config="dict(action_vocab_size=256,delta_vocab_size=8,theta=50000000,max_sequence_length=2048,use_flash_attention=True,scan_attention=True,scan_query_chunk_size=512,scan_key_chunk_size=1024,remat_attention='nothing_saveable',scan_mlp=True,scan_mlp_chunk_size=8192,remat_mlp='nothing_saveable',remat_block='nothing_saveable',scan_layers=True)" \
#     --tokenizer.vocab_file="$llama_tokenizer_path" \
#     --optimizer.type='adamw' \
#     --llama.action_vocab_size=256 \
#     --llama.delta_vocab_size=8 \
#     --optimizer.accumulate_gradient_steps=1 \
#     --optimizer.adamw_optimizer.weight_decay=0 \
#     --optimizer.adamw_optimizer.lr=2e-5 \
#     --optimizer.adamw_optimizer.end_lr=2e-5 \
#     --optimizer.adamw_optimizer.lr_warmup_steps=0 \
#     --optimizer.adamw_optimizer.lr_decay_steps=3000 \
#     --use_data_sharded_loader=True \
#     --train_dataset.type='json_vision_delta_action' \
#     --train_dataset.delta_vision_action_processor.fields_from_example='fields' \
#     --train_dataset.delta_vision_action_processor.n_tokens_per_action=7 \
#     --train_dataset.delta_vision_action_processor.n_tokens_per_delta=4 \
#     --train_dataset.delta_vision_action_processor.img_aug=True \
#     --train_dataset.delta_vision_action_processor.vqgan_checkpoint_path='/mnt/default/lwm/data/checkpoints/lwm_checkpoints/vqgan'\
#     --train_dataset.delta_vision_action_processor.image_absolute_path='/mnt/default/lwm/data/finetune_data/'\
#     --train_dataset.delta_vision_action_processor.max_n_frames=1 \
#     --train_dataset.json_delta_action_dataset.mode="pad" \
#     --train_dataset.json_delta_action_dataset.path="$dataset_path" \
#     --train_dataset.json_delta_action_dataset.seq_length=384 \
#     --train_dataset.json_delta_action_dataset.batch_size=128 \
#     --train_dataset.json_delta_action_dataset.tokenizer_processes=1 \
#     --train_dataset.json_delta_action_dataset.tokenizer_parallel_chunk_size=128 \
#     --train_dataset.json_delta_action_dataset.tokenizer_parallel_batch_size=128 \
#     --train_dataset.json_delta_action_dataset.use_data_sharded_loader=True \
#     --checkpointer.save_optimizer_state=False \
#     --autoresume=False \
#     --logger.append_uuid=False \
#     --logger.online=True \
#     --logger.project_id="$project_id" \
#     --logger.experiment_id="$experiment_id" \
#     --logger.experiment_note="$experiment_note" \
#     --logger.output_dir="$output_dir" \
#     --logger.wandb_dir="$HOME/experiment_output/$project_id"



# export SCRIPT_DIR="$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
# export PROJECT_DIR="$( cd -- "$( dirname -- "$SCRIPT_DIR" )" &> /dev/null && pwd )"
# cd $PROJECT_DIR
# export PYTHONPATH="$PYTHONPATH:$PROJECT_DIR"
# export LIBTPU_INIT_ARGS="--xla_tpu_megacore_fusion_allow_ags=false --xla_enable_async_collective_permute=true --xla_tpu_enable_ag_backward_pipelining=true --xla_tpu_enable_data_parallel_all_reduce_opt=true --xla_tpu_data_parallel_opt_different_sized_ops=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true"

# export llama_tokenizer_path="/mnt/default/lwm/data/checkpoints/lwm_checkpoints/tokenizer.model"

# # export dataset_path='data/0808_multiobject_sink_llava_5hz_action_preprocessed.jsonl'
# # export dataset_path='data/0809_multiobject_sink_llava_256.jsonl'
# # export dataset_path='data/0811_multiobject_sink_llava_train_filtered.jsonl'
# export dataset_path='data/817_multiobject_sink_rand_location_llava_train_filtered.jsonl'
# # export dataset_path='data/0812_play_data_llava_train_filtered.jsonl'
# # export dataset_path='data/0812_play_data_llava_concat_multi2sink_train_filtered.jsonl'
# export eval_dataset_path='data/0811_multiobject_sink_llava_seen_val_filtered.jsonl'
# export unseen_eval_dataset_path='data/0811_multiobject_sink_llava_unseen_val_filtered.jsonl'
# # export eval_dataset_path="data/0731_multi_200_5hz_val.jsonl"

# export output_dir=$AMLT_OUTPUT_DIR
# : "${WANDB_API_KEY:?WANDB_API_KEY must be provided via the environment/secret store; never hard-code it}"

# export project_id='lwm'
# export experiment_note='world-model'
# # export experiment_id='817_latent_c8_s9_epoch3_multiobject_sink_rand_location_llava_train_filtered_4000'
# export experiment_id='817_latent_c8_s4_epoch1_multiobject_sink_rand_location_llava_train_filtered_4000'
# # export experiment_id='0811_latent_finetune_100traj_multitask2sink_15hz_imgaug_filtered_bridge_5epoch_h100'
# # export experiment_id='0811_latent_finetune_100traj_multitask2sink_15hz_imgaug_filtered_codeboook8_seq9_bridge_epoch3'

# # export experiment_id='0815_multiobject_sink_rand_location_seen_llava_train_filtered_codebook8_seq9_bridge_e3'
# # export experiment_id='0812_latent_play_data_multi2sink_mixture_finetune_15hz_imgaug_filtered'
# # export experiment_id='0811_multiobject_sink_llava_train_filtered_bridge_10percent_codebook8_seq1'

# # mesh_dim: dp, fsdp, tp, sp
# python3 -u -m lwm.train \
#     --modality='vision,action,delta' \
#     --mesh_dim='!-1,4,1,1' \
#     --dtype='bf16' \
#     --total_steps=4005 \
#     --log_freq=1 \
#     --delta_tokens=1 \
#     --eval_steps=5 \
#     --save_model_freq=0 \
#     --eval_log_freq=100 \
#     --save_milestone_freq=1000 \
#     --load_llama_config='7b' \
#     --load_checkpoint="params::/mnt/default/projects/lwm/amlt-results/7277524497.23794-ffa900cf-6610-43f0-934f-c10572ffca61/lwm_total_bridge_bf16_from_whole_delta8_batch32_seq4_filtered_no_inst_epoch3/streaming_params_10503" \
#     --update_llama_config="dict(action_vocab_size=256,delta_vocab_size=8,theta=50000000,max_sequence_length=2048,use_flash_attention=True,scan_attention=True,scan_query_chunk_size=512,scan_key_chunk_size=1024,remat_attention='nothing_saveable',scan_mlp=True,scan_mlp_chunk_size=8192,remat_mlp='nothing_saveable',remat_block='nothing_saveable',scan_layers=True)" \
#     --tokenizer.vocab_file="$llama_tokenizer_path" \
#     --optimizer.type='adamw' \
#     --llama.action_vocab_size=256 \
#     --llama.delta_vocab_size=8 \
#     --optimizer.accumulate_gradient_steps=1 \
#     --optimizer.adamw_optimizer.weight_decay=0 \
#     --optimizer.adamw_optimizer.lr=2e-5 \
#     --optimizer.adamw_optimizer.end_lr=2e-5 \
#     --optimizer.adamw_optimizer.lr_warmup_steps=0 \
#     --optimizer.adamw_optimizer.lr_decay_steps=3000 \
#     --use_data_sharded_loader=True \
#     --train_dataset.type='json_vision_delta_action' \
#     --train_dataset.delta_vision_action_processor.fields_from_example='fields' \
#     --train_dataset.delta_vision_action_processor.n_tokens_per_action=7 \
#     --train_dataset.delta_vision_action_processor.n_tokens_per_delta=4 \
#     --train_dataset.delta_vision_action_processor.img_aug=True \
#     --train_dataset.delta_vision_action_processor.vqgan_checkpoint_path='/mnt/default/lwm/data/checkpoints/lwm_checkpoints/vqgan'\
#     --train_dataset.delta_vision_action_processor.image_absolute_path='/mnt/default/lwm/data/finetune_data/'\
#     --train_dataset.delta_vision_action_processor.max_n_frames=1 \
#     --train_dataset.json_delta_action_dataset.mode="pad" \
#     --train_dataset.json_delta_action_dataset.path="$dataset_path" \
#     --train_dataset.json_delta_action_dataset.seq_length=384 \
#     --train_dataset.json_delta_action_dataset.batch_size=128 \
#     --train_dataset.json_delta_action_dataset.tokenizer_processes=1 \
#     --train_dataset.json_delta_action_dataset.tokenizer_parallel_chunk_size=128 \
#     --train_dataset.json_delta_action_dataset.tokenizer_parallel_batch_size=128 \
#     --train_dataset.json_delta_action_dataset.use_data_sharded_loader=True \
#     --eval_dataset.type='json_vision_delta_action' \
#     --eval_dataset.delta_vision_action_processor.fields_from_example='fields' \
#     --eval_dataset.delta_vision_action_processor.n_tokens_per_action=7 \
#     --eval_dataset.delta_vision_action_processor.n_tokens_per_delta=4 \
#     --eval_dataset.delta_vision_action_processor.img_aug=True \
#     --eval_dataset.delta_vision_action_processor.vqgan_checkpoint_path='/mnt/default/lwm/data/checkpoints/lwm_checkpoints/vqgan'\
#     --eval_dataset.delta_vision_action_processor.image_absolute_path='/mnt/default/lwm/data/finetune_data/'\
#     --eval_dataset.delta_vision_action_processor.max_n_frames=1 \
#     --eval_dataset.json_delta_action_dataset.mode="pad" \
#     --eval_dataset.json_delta_action_dataset.path="$eval_dataset_path" \
#     --eval_dataset.json_delta_action_dataset.seq_length=384 \
#     --eval_dataset.json_delta_action_dataset.batch_size=128 \
#     --eval_dataset.json_delta_action_dataset.tokenizer_processes=1 \
#     --eval_dataset.json_delta_action_dataset.tokenizer_parallel_chunk_size=128 \
#     --eval_dataset.json_delta_action_dataset.tokenizer_parallel_batch_size=128 \
#     --eval_dataset.json_delta_action_dataset.use_data_sharded_loader=True \
#     --unseen_eval_dataset.type='json_vision_delta_action' \
#     --unseen_eval_dataset.delta_vision_action_processor.fields_from_example='fields' \
#     --unseen_eval_dataset.delta_vision_action_processor.n_tokens_per_action=7 \
#     --unseen_eval_dataset.delta_vision_action_processor.n_tokens_per_delta=4 \
#     --unseen_eval_dataset.delta_vision_action_processor.img_aug=True \
#     --unseen_eval_dataset.delta_vision_action_processor.vqgan_checkpoint_path='/mnt/default/lwm/data/checkpoints/lwm_checkpoints/vqgan'\
#     --unseen_eval_dataset.delta_vision_action_processor.image_absolute_path='/mnt/default/lwm/data/finetune_data/'\
#     --unseen_eval_dataset.delta_vision_action_processor.max_n_frames=1 \
#     --unseen_eval_dataset.json_delta_action_dataset.mode="pad" \
#     --unseen_eval_dataset.json_delta_action_dataset.path="$unseen_eval_dataset_path" \
#     --unseen_eval_dataset.json_delta_action_dataset.seq_length=384 \
#     --unseen_eval_dataset.json_delta_action_dataset.batch_size=128 \
#     --unseen_eval_dataset.json_delta_action_dataset.tokenizer_processes=1 \
#     --unseen_eval_dataset.json_delta_action_dataset.tokenizer_parallel_chunk_size=128 \
#     --unseen_eval_dataset.json_delta_action_dataset.tokenizer_parallel_batch_size=128 \
#     --unseen_eval_dataset.json_delta_action_dataset.use_data_sharded_loader=True \
#     --checkpointer.save_optimizer_state=False \
#     --autoresume=False \
#     --logger.append_uuid=False \
#     --logger.online=True \
#     --logger.project_id="$project_id" \
#     --logger.experiment_id="$experiment_id" \
#     --logger.experiment_note="$experiment_note" \
#     --logger.output_dir="$output_dir" \
#     --logger.wandb_dir="$HOME/experiment_output/$project_id"

    # --load_checkpoint="params::/mnt/default/projects/lwm/amlt-results/7277524497.23794-ffa900cf-6610-43f0-934f-c10572ffca61/lwm_total_bridge_bf16_from_whole_delta8_batch32_seq4_filtered_no_inst_epoch3/streaming_params_10503" \
    # --load_checkpoint="params::/mnt/default/projects/lwm/amlt-results/7276778139.80181-f4720d54-07a5-4829-a1e4-6fda178ae02a/lwm_total_bridge_bf16_from_whole_delta8_batch32_seq9_filtered_no_inst_epoch3/streaming_params_31509" \




# export SCRIPT_DIR="$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
# export PROJECT_DIR="$( cd -- "$( dirname -- "$SCRIPT_DIR" )" &> /dev/null && pwd )"
# cd $PROJECT_DIR
# export PYTHONPATH="$PYTHONPATH:$PROJECT_DIR"
# export LIBTPU_INIT_ARGS="--xla_tpu_megacore_fusion_allow_ags=false --xla_enable_async_collective_permute=true --xla_tpu_enable_ag_backward_pipelining=true --xla_tpu_enable_data_parallel_all_reduce_opt=true --xla_tpu_data_parallel_opt_different_sized_ops=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true"

# export llama_tokenizer_path="/mnt/default/lwm/data/checkpoints/lwm_checkpoints/tokenizer.model"

# # export dataset_path='data/0808_multiobject_sink_llava_5hz_action_preprocessed.jsonl'
# # export dataset_path='data/0809_multiobject_sink_llava_256.jsonl'
# export dataset_path='data/0811_multiobject_sink_llava_train_filtered.jsonl'
# # export dataset_path='data/817_multiobject_sink_rand_location_llava_train_filtered.jsonl'
# # export dataset_path='data/0812_play_data_llava_train_filtered.jsonl'
# # export dataset_path='data/0812_play_data_llava_concat_multi2sink_train_filtered.jsonl'
# export eval_dataset_path='data/0811_multiobject_sink_llava_seen_val_filtered.jsonl'
# export unseen_eval_dataset_path='data/0811_multiobject_sink_llava_unseen_val_filtered.jsonl'
# # export eval_dataset_path="data/0731_multi_200_5hz_val.jsonl"

# export output_dir=$AMLT_OUTPUT_DIR
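# # Read the W&B key from the environment; never commit it (revoke any key that was committed here).
# export WANDB_API_KEY="${WANDB_API_KEY:?WANDB_API_KEY must be set in the environment}"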

# export project_id='lwm'
# export experiment_note='world-model'
# # export experiment_id='817_latent_c8_s9_epoch3_multiobject_sink_rand_location_llava_train_filtered_4000'
# # export experiment_id='821_latent_c8_s4_l2_40k_epoch1_multiobject_sink_rand_location_llava_train_filtered_2000_stage2_10k_4000'
# export experiment_id='0811_latent_c8_s4_l2_40k_epoch1_multiobject_sink_location_llava_train_filtered_stage2_10k'
# # export experiment_id='0811_latent_finetune_100traj_multitask2sink_15hz_imgaug_filtered_bridge_5epoch_h100'
# # export experiment_id='0811_latent_finetune_100traj_multitask2sink_15hz_imgaug_filtered_codeboook8_seq9_bridge_epoch3'

# # export experiment_id='0815_multiobject_sink_rand_location_seen_llava_train_filtered_codebook8_seq9_bridge_e3'
# # export experiment_id='0812_latent_play_data_multi2sink_mixture_finetune_15hz_imgaug_filtered'
# # export experiment_id='0811_multiobject_sink_llava_train_filtered_bridge_10percent_codebook8_seq1'

# # mesh_dim: dp, fsdp, tp, sp
# python3 -u -m lwm.train \
#     --modality='vision,action,delta' \
#     --mesh_dim='!-1,8,1,1' \
#     --dtype='bf16' \
#     --total_steps=1005 \
#     --log_freq=1 \
#     --delta_tokens=1 \
#     --eval_steps=5 \
#     --save_model_freq=0 \
#     --eval_log_freq=100 \
#     --save_milestone_freq=1000 \
#     --load_llama_config='7b' \
#     --load_checkpoint="params::/mnt/default/projects/lwm/amlt-results/7275617329.34871-f7049ed3-da94-4e55-8346-7d6e8ff29739/lwm_human_bf16_from_whole_delta8_batch32_seq4_layer2_filtered_no_inst_imgaug_821_40Ktrain_re/streaming_params_10000" \
#     --update_llama_config="dict(action_vocab_size=256,delta_vocab_size=8,theta=50000000,max_sequence_length=2048,use_flash_attention=True,scan_attention=True,scan_query_chunk_size=512,scan_key_chunk_size=1024,remat_attention='nothing_saveable',scan_mlp=True,scan_mlp_chunk_size=8192,remat_mlp='nothing_saveable',remat_block='nothing_saveable',scan_layers=True)" \
#     --tokenizer.vocab_file="$llama_tokenizer_path" \
#     --optimizer.type='adamw' \
#     --llama.action_vocab_size=256 \
#     --llama.delta_vocab_size=8 \
#     --optimizer.accumulate_gradient_steps=1 \
#     --optimizer.adamw_optimizer.weight_decay=0 \
#     --optimizer.adamw_optimizer.lr=2e-5 \
#     --optimizer.adamw_optimizer.end_lr=2e-5 \
#     --optimizer.adamw_optimizer.lr_warmup_steps=0 \
#     --optimizer.adamw_optimizer.lr_decay_steps=3000 \
#     --use_data_sharded_loader=True \
#     --train_dataset.type='json_vision_delta_action' \
#     --train_dataset.delta_vision_action_processor.fields_from_example='fields' \
#     --train_dataset.delta_vision_action_processor.n_tokens_per_action=7 \
#     --train_dataset.delta_vision_action_processor.n_tokens_per_delta=4 \
#     --train_dataset.delta_vision_action_processor.img_aug=True \
#     --train_dataset.delta_vision_action_processor.vqgan_checkpoint_path='/mnt/default/lwm/data/checkpoints/lwm_checkpoints/vqgan'\
#     --train_dataset.delta_vision_action_processor.image_absolute_path='/mnt/default/lwm/data/finetune_data/'\
#     --train_dataset.delta_vision_action_processor.max_n_frames=1 \
#     --train_dataset.json_delta_action_dataset.mode="pad" \
#     --train_dataset.json_delta_action_dataset.path="$dataset_path" \
#     --train_dataset.json_delta_action_dataset.seq_length=384 \
#     --train_dataset.json_delta_action_dataset.batch_size=128 \
#     --train_dataset.json_delta_action_dataset.tokenizer_processes=1 \
#     --train_dataset.json_delta_action_dataset.tokenizer_parallel_chunk_size=128 \
#     --train_dataset.json_delta_action_dataset.tokenizer_parallel_batch_size=128 \
#     --train_dataset.json_delta_action_dataset.use_data_sharded_loader=True \
#     --eval_dataset.type='json_vision_delta_action' \
#     --eval_dataset.delta_vision_action_processor.fields_from_example='fields' \
#     --eval_dataset.delta_vision_action_processor.n_tokens_per_action=7 \
#     --eval_dataset.delta_vision_action_processor.n_tokens_per_delta=4 \
#     --eval_dataset.delta_vision_action_processor.img_aug=True \
#     --eval_dataset.delta_vision_action_processor.vqgan_checkpoint_path='/mnt/default/lwm/data/checkpoints/lwm_checkpoints/vqgan'\
#     --eval_dataset.delta_vision_action_processor.image_absolute_path='/mnt/default/lwm/data/finetune_data/'\
#     --eval_dataset.delta_vision_action_processor.max_n_frames=1 \
#     --eval_dataset.json_delta_action_dataset.mode="pad" \
#     --eval_dataset.json_delta_action_dataset.path="$eval_dataset_path" \
#     --eval_dataset.json_delta_action_dataset.seq_length=384 \
#     --eval_dataset.json_delta_action_dataset.batch_size=128 \
#     --eval_dataset.json_delta_action_dataset.tokenizer_processes=1 \
#     --eval_dataset.json_delta_action_dataset.tokenizer_parallel_chunk_size=128 \
#     --eval_dataset.json_delta_action_dataset.tokenizer_parallel_batch_size=128 \
#     --eval_dataset.json_delta_action_dataset.use_data_sharded_loader=True \
#     --unseen_eval_dataset.type='json_vision_delta_action' \
#     --unseen_eval_dataset.delta_vision_action_processor.fields_from_example='fields' \
#     --unseen_eval_dataset.delta_vision_action_processor.n_tokens_per_action=7 \
#     --unseen_eval_dataset.delta_vision_action_processor.n_tokens_per_delta=4 \
#     --unseen_eval_dataset.delta_vision_action_processor.img_aug=True \
#     --unseen_eval_dataset.delta_vision_action_processor.vqgan_checkpoint_path='/mnt/default/lwm/data/checkpoints/lwm_checkpoints/vqgan'\
#     --unseen_eval_dataset.delta_vision_action_processor.image_absolute_path='/mnt/default/lwm/data/finetune_data/'\
#     --unseen_eval_dataset.delta_vision_action_processor.max_n_frames=1 \
#     --unseen_eval_dataset.json_delta_action_dataset.mode="pad" \
#     --unseen_eval_dataset.json_delta_action_dataset.path="$unseen_eval_dataset_path" \
#     --unseen_eval_dataset.json_delta_action_dataset.seq_length=384 \
#     --unseen_eval_dataset.json_delta_action_dataset.batch_size=128 \
#     --unseen_eval_dataset.json_delta_action_dataset.tokenizer_processes=1 \
#     --unseen_eval_dataset.json_delta_action_dataset.tokenizer_parallel_chunk_size=128 \
#     --unseen_eval_dataset.json_delta_action_dataset.tokenizer_parallel_batch_size=128 \
#     --unseen_eval_dataset.json_delta_action_dataset.use_data_sharded_loader=True \
#     --checkpointer.save_optimizer_state=False \
#     --autoresume=False \
#     --logger.append_uuid=False \
#     --logger.online=True \
#     --logger.project_id="$project_id" \
#     --logger.experiment_id="$experiment_id" \
#     --logger.experiment_note="$experiment_note" \
#     --logger.output_dir="$output_dir" \
#     --logger.wandb_dir="$HOME/experiment_output/$project_id"


# export SCRIPT_DIR="$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
# export PROJECT_DIR="$( cd -- "$( dirname -- "$SCRIPT_DIR" )" &> /dev/null && pwd )"
# cd "$PROJECT_DIR" || exit 1
# export PYTHONPATH="$PYTHONPATH:$PROJECT_DIR"
# export LIBTPU_INIT_ARGS="--xla_tpu_megacore_fusion_allow_ags=false --xla_enable_async_collective_permute=true --xla_tpu_enable_ag_backward_pipelining=true --xla_tpu_enable_data_parallel_all_reduce_opt=true --xla_tpu_data_parallel_opt_different_sized_ops=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true"

# export llama_tokenizer_path="/mnt/default/lwm/data/checkpoints/lwm_checkpoints/tokenizer.model"

# # export dataset_path='data/0808_multiobject_sink_llava_5hz_action_preprocessed.jsonl'
# # export dataset_path='data/0809_multiobject_sink_llava_256.jsonl'
# export dataset_path='data/0811_multiobject_sink_llava_train_filtered.jsonl'
# # export dataset_path='data/0815_multiobject_sink_rand_location_seen_llava_train_filtered.jsonl'
# # export dataset_path='data/0812_play_data_llava_train_filtered.jsonl'
# # export dataset_path='data/0812_play_data_llava_concat_multi2sink_train_filtered.jsonl'
# export eval_dataset_path='data/0811_multiobject_sink_llava_seen_val_filtered.jsonl'
# export unseen_eval_dataset_path='data/0811_multiobject_sink_llava_unseen_val_filtered.jsonl'
# # export eval_dataset_path="data/0731_multi_200_5hz_val.jsonl"

# export output_dir=$AMLT_OUTPUT_DIR
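# # Read the W&B key from the environment; never commit it (revoke any key that was committed here).
# export WANDB_API_KEY="${WANDB_API_KEY:?WANDB_API_KEY must be set in the environment}"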

# export project_id='lwm'
# export experiment_note='world-model'
# # export experiment_id='0811_latent_finetune_100traj_multitask2sink_15hz_imgaug_filtered_bridge_5epoch_h100'
# # export experiment_id='0811_latent_finetune_100traj_multitask2sink_15hz_imgaug_filtered_codeboook8_seq9_bridge_epoch3'

# # export experiment_id='0815_multiobject_sink_rand_location_seen_llava_train_filtered_codebook8_seq9_bridge_e3'
# # export experiment_id='0812_latent_play_data_multi2sink_mixture_finetune_15hz_imgaug_filtered'

# # export experiment_id='0811_multiobject_sink_llava_train_filtered_bridge_10percent_human_seen_mix_codebook8_seq1_epoch3'
# export experiment_id='0811_multiobject_sink_llava_train_filtered_human_seen_delta8_seq1_epoch1'
# # mesh_dim: dp, fsdp, tp, sp
# python3 -u -m lwm.train \
#     --modality='vision,action,delta' \
#     --mesh_dim='!-1,4,1,1' \
#     --dtype='bf16' \
#     --total_steps=2005 \
#     --log_freq=1 \
#     --delta_tokens=1 \
#     --eval_steps=5 \
#     --save_model_freq=0 \
#     --eval_log_freq=100 \
#     --save_milestone_freq=1000 \
#     --load_llama_config='7b' \
#     --load_checkpoint="params::/mnt/default/projects/lwm/amlt-results/7276144719.35290-ac86dab9-f4c2-4f5e-9df5-a2cbd25e5d82/lwm_human_seen_bf16_from_whole_delta8_batch32_seq1_filtered_no_inst/streaming_params_310" \
#     --update_llama_config="dict(action_vocab_size=256,delta_vocab_size=8,theta=50000000,max_sequence_length=2048,use_flash_attention=True,scan_attention=True,scan_query_chunk_size=512,scan_key_chunk_size=1024,remat_attention='nothing_saveable',scan_mlp=True,scan_mlp_chunk_size=8192,remat_mlp='nothing_saveable',remat_block='nothing_saveable',scan_layers=True)" \
#     --tokenizer.vocab_file="$llama_tokenizer_path" \
#     --optimizer.type='adamw' \
#     --llama.action_vocab_size=256 \
#     --llama.delta_vocab_size=8 \
#     --optimizer.accumulate_gradient_steps=1 \
#     --optimizer.adamw_optimizer.weight_decay=0 \
#     --optimizer.adamw_optimizer.lr=2e-5 \
#     --optimizer.adamw_optimizer.end_lr=2e-5 \
#     --optimizer.adamw_optimizer.lr_warmup_steps=0 \
#     --optimizer.adamw_optimizer.lr_decay_steps=3000 \
#     --use_data_sharded_loader=True \
#     --train_dataset.type='json_vision_delta_action' \
#     --train_dataset.delta_vision_action_processor.fields_from_example='fields' \
#     --train_dataset.delta_vision_action_processor.n_tokens_per_action=7 \
#     --train_dataset.delta_vision_action_processor.n_tokens_per_delta=1 \
#     --train_dataset.delta_vision_action_processor.img_aug=True \
#     --train_dataset.delta_vision_action_processor.vqgan_checkpoint_path='/mnt/default/lwm/data/checkpoints/lwm_checkpoints/vqgan'\
#     --train_dataset.delta_vision_action_processor.image_absolute_path='/mnt/default/lwm/data/finetune_data/'\
#     --train_dataset.delta_vision_action_processor.max_n_frames=1 \
#     --train_dataset.json_delta_action_dataset.mode="pad" \
#     --train_dataset.json_delta_action_dataset.path="$dataset_path" \
#     --train_dataset.json_delta_action_dataset.seq_length=384 \
#     --train_dataset.json_delta_action_dataset.batch_size=128 \
#     --train_dataset.json_delta_action_dataset.tokenizer_processes=1 \
#     --train_dataset.json_delta_action_dataset.tokenizer_parallel_chunk_size=128 \
#     --train_dataset.json_delta_action_dataset.tokenizer_parallel_batch_size=128 \
#     --train_dataset.json_delta_action_dataset.use_data_sharded_loader=True \
#     --eval_dataset.type='json_vision_delta_action' \
#     --eval_dataset.delta_vision_action_processor.fields_from_example='fields' \
#     --eval_dataset.delta_vision_action_processor.n_tokens_per_action=7 \
#     --eval_dataset.delta_vision_action_processor.n_tokens_per_delta=1 \
#     --eval_dataset.delta_vision_action_processor.img_aug=True \
#     --eval_dataset.delta_vision_action_processor.vqgan_checkpoint_path='/mnt/default/lwm/data/checkpoints/lwm_checkpoints/vqgan'\
#     --eval_dataset.delta_vision_action_processor.image_absolute_path='/mnt/default/lwm/data/finetune_data/'\
#     --eval_dataset.delta_vision_action_processor.max_n_frames=1 \
#     --eval_dataset.json_delta_action_dataset.mode="pad" \
#     --eval_dataset.json_delta_action_dataset.path="$eval_dataset_path" \
#     --eval_dataset.json_delta_action_dataset.seq_length=384 \
#     --eval_dataset.json_delta_action_dataset.batch_size=128 \
#     --eval_dataset.json_delta_action_dataset.tokenizer_processes=1 \
#     --eval_dataset.json_delta_action_dataset.tokenizer_parallel_chunk_size=128 \
#     --eval_dataset.json_delta_action_dataset.tokenizer_parallel_batch_size=128 \
#     --eval_dataset.json_delta_action_dataset.use_data_sharded_loader=True \
#     --unseen_eval_dataset.type='json_vision_delta_action' \
#     --unseen_eval_dataset.delta_vision_action_processor.fields_from_example='fields' \
#     --unseen_eval_dataset.delta_vision_action_processor.n_tokens_per_action=7 \
#     --unseen_eval_dataset.delta_vision_action_processor.n_tokens_per_delta=1 \
#     --unseen_eval_dataset.delta_vision_action_processor.img_aug=True \
#     --unseen_eval_dataset.delta_vision_action_processor.vqgan_checkpoint_path='/mnt/default/lwm/data/checkpoints/lwm_checkpoints/vqgan'\
#     --unseen_eval_dataset.delta_vision_action_processor.image_absolute_path='/mnt/default/lwm/data/finetune_data/'\
#     --unseen_eval_dataset.delta_vision_action_processor.max_n_frames=1 \
#     --unseen_eval_dataset.json_delta_action_dataset.mode="pad" \
#     --unseen_eval_dataset.json_delta_action_dataset.path="$unseen_eval_dataset_path" \
#     --unseen_eval_dataset.json_delta_action_dataset.seq_length=384 \
#     --unseen_eval_dataset.json_delta_action_dataset.batch_size=128 \
#     --unseen_eval_dataset.json_delta_action_dataset.tokenizer_processes=1 \
#     --unseen_eval_dataset.json_delta_action_dataset.tokenizer_parallel_chunk_size=128 \
#     --unseen_eval_dataset.json_delta_action_dataset.tokenizer_parallel_batch_size=128 \
#     --unseen_eval_dataset.json_delta_action_dataset.use_data_sharded_loader=True \
#     --checkpointer.save_optimizer_state=False \
#     --autoresume=False \
#     --logger.append_uuid=False \
#     --logger.online=True \
#     --logger.project_id="$project_id" \
#     --logger.experiment_id="$experiment_id" \
#     --logger.experiment_note="$experiment_note" \
#     --logger.output_dir="$output_dir" \
#     --logger.wandb_dir="$HOME/experiment_output/$project_id"


# # Alternative --load_checkpoint value for the run above (not part of the command; kept for reference):
# # --load_checkpoint="params::/mnt/default/projects/lwm/amlt-results/7276144104.44081-77cab47d-2a7c-4d2f-9195-8abd53f96505/lwm_bridge_10percent_human_seen_mix_bf16_from_whole_delta8_batch32_seq1_filtered_no_inst/streaming_params_13530" \



# export SCRIPT_DIR="$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
# export PROJECT_DIR="$( cd -- "$( dirname -- "$SCRIPT_DIR" )" &> /dev/null && pwd )"
# cd "$PROJECT_DIR" || exit 1
# export PYTHONPATH="$PYTHONPATH:$PROJECT_DIR"
# export LIBTPU_INIT_ARGS="--xla_tpu_megacore_fusion_allow_ags=false --xla_enable_async_collective_permute=true --xla_tpu_enable_ag_backward_pipelining=true --xla_tpu_enable_data_parallel_all_reduce_opt=true --xla_tpu_data_parallel_opt_different_sized_ops=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true"

# export llama_tokenizer_path="/mnt/default/lwm/data/checkpoints/lwm_checkpoints/tokenizer.model"

# # export dataset_path='data/0808_multiobject_sink_llava_5hz_action_preprocessed.jsonl'
# # export dataset_path='data/0809_multiobject_sink_llava_256.jsonl'
# # export dataset_path='data/0811_multiobject_sink_llava_train_filtered.jsonl'
# export dataset_path='data/817_multiobject_sink_rand_location_llava_train_filtered.jsonl'
# # export dataset_path='data/0812_play_data_llava_train_filtered.jsonl'
# # export dataset_path='data/0812_play_data_llava_concat_multi2sink_train_filtered.jsonl'
# export eval_dataset_path='data/0811_multiobject_sink_llava_seen_val_filtered.jsonl'
# export unseen_eval_dataset_path='data/0811_multiobject_sink_llava_unseen_val_filtered.jsonl'
# # export eval_dataset_path="data/0731_multi_200_5hz_val.jsonl"

# export output_dir=$AMLT_OUTPUT_DIR
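# # Read the W&B key from the environment; never commit it (revoke any key that was committed here).
# export WANDB_API_KEY="${WANDB_API_KEY:?WANDB_API_KEY must be set in the environment}"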

# export project_id='lwm'
# export experiment_note='world-model'
# # export experiment_id='817_latent_c8_s9_epoch3_multiobject_sink_rand_location_llava_train_filtered_4000'
# export experiment_id='821_latent_c8_s4_l4_40k_multiobject_sink_rand_location_llava_train_filtered_2000_stage2_5k'
# # export experiment_id='0811_latent_finetune_100traj_multitask2sink_15hz_imgaug_filtered_bridge_5epoch_h100'
# # export experiment_id='0811_latent_finetune_100traj_multitask2sink_15hz_imgaug_filtered_codeboook8_seq9_bridge_epoch3'

# # export experiment_id='0815_multiobject_sink_rand_location_seen_llava_train_filtered_codebook8_seq9_bridge_e3'
# # export experiment_id='0812_latent_play_data_multi2sink_mixture_finetune_15hz_imgaug_filtered'
# # export experiment_id='0811_multiobject_sink_llava_train_filtered_bridge_10percent_codebook8_seq1'

# # mesh_dim: dp, fsdp, tp, sp
# python3 -u -m lwm.train \
#     --modality='vision,action,delta' \
#     --mesh_dim='!-1,8,1,1' \
#     --dtype='bf16' \
#     --total_steps=2005 \
#     --log_freq=1 \
#     --delta_tokens=1 \
#     --eval_steps=5 \
#     --save_model_freq=0 \
#     --eval_log_freq=100 \
#     --save_milestone_freq=1000 \
#     --load_llama_config='7b' \
#     --load_checkpoint="params::/mnt/default/projects/lwm/amlt-results/7275566570.87539-24bb832a-3031-404b-8339-5fe92f838119/lwm_human_bf16_from_whole_delta8_batch32_seq4_layer4_filtered_no_inst_imgaug_821_40Ktrain/streaming_params_5000"\
#     --update_llama_config="dict(action_vocab_size=256,delta_vocab_size=8,theta=50000000,max_sequence_length=2048,use_flash_attention=True,scan_attention=True,scan_query_chunk_size=512,scan_key_chunk_size=1024,remat_attention='nothing_saveable',scan_mlp=True,scan_mlp_chunk_size=8192,remat_mlp='nothing_saveable',remat_block='nothing_saveable',scan_layers=True)" \
#     --tokenizer.vocab_file="$llama_tokenizer_path" \
#     --optimizer.type='adamw' \
#     --llama.action_vocab_size=256 \
#     --llama.delta_vocab_size=8 \
#     --optimizer.accumulate_gradient_steps=1 \
#     --optimizer.adamw_optimizer.weight_decay=0 \
#     --optimizer.adamw_optimizer.lr=2e-5 \
#     --optimizer.adamw_optimizer.end_lr=2e-5 \
#     --optimizer.adamw_optimizer.lr_warmup_steps=0 \
#     --optimizer.adamw_optimizer.lr_decay_steps=3000 \
#     --use_data_sharded_loader=True \
#     --train_dataset.type='json_vision_delta_action' \
#     --train_dataset.delta_vision_action_processor.fields_from_example='fields' \
#     --train_dataset.delta_vision_action_processor.n_tokens_per_action=7 \
#     --train_dataset.delta_vision_action_processor.n_tokens_per_delta=4 \
#     --train_dataset.delta_vision_action_processor.img_aug=True \
#     --train_dataset.delta_vision_action_processor.vqgan_checkpoint_path='/mnt/default/lwm/data/checkpoints/lwm_checkpoints/vqgan'\
#     --train_dataset.delta_vision_action_processor.image_absolute_path='/mnt/default/lwm/data/finetune_data/'\
#     --train_dataset.delta_vision_action_processor.max_n_frames=1 \
#     --train_dataset.json_delta_action_dataset.mode="pad" \
#     --train_dataset.json_delta_action_dataset.path="$dataset_path" \
#     --train_dataset.json_delta_action_dataset.seq_length=384 \
#     --train_dataset.json_delta_action_dataset.batch_size=128 \
#     --train_dataset.json_delta_action_dataset.tokenizer_processes=1 \
#     --train_dataset.json_delta_action_dataset.tokenizer_parallel_chunk_size=128 \
#     --train_dataset.json_delta_action_dataset.tokenizer_parallel_batch_size=128 \
#     --train_dataset.json_delta_action_dataset.use_data_sharded_loader=True \
#     --eval_dataset.type='json_vision_delta_action' \
#     --eval_dataset.delta_vision_action_processor.fields_from_example='fields' \
#     --eval_dataset.delta_vision_action_processor.n_tokens_per_action=7 \
#     --eval_dataset.delta_vision_action_processor.n_tokens_per_delta=4 \
#     --eval_dataset.delta_vision_action_processor.img_aug=True \
#     --eval_dataset.delta_vision_action_processor.vqgan_checkpoint_path='/mnt/default/lwm/data/checkpoints/lwm_checkpoints/vqgan'\
#     --eval_dataset.delta_vision_action_processor.image_absolute_path='/mnt/default/lwm/data/finetune_data/'\
#     --eval_dataset.delta_vision_action_processor.max_n_frames=1 \
#     --eval_dataset.json_delta_action_dataset.mode="pad" \
#     --eval_dataset.json_delta_action_dataset.path="$eval_dataset_path" \
#     --eval_dataset.json_delta_action_dataset.seq_length=384 \
#     --eval_dataset.json_delta_action_dataset.batch_size=128 \
#     --eval_dataset.json_delta_action_dataset.tokenizer_processes=1 \
#     --eval_dataset.json_delta_action_dataset.tokenizer_parallel_chunk_size=128 \
#     --eval_dataset.json_delta_action_dataset.tokenizer_parallel_batch_size=128 \
#     --eval_dataset.json_delta_action_dataset.use_data_sharded_loader=True \
#     --unseen_eval_dataset.type='json_vision_delta_action' \
#     --unseen_eval_dataset.delta_vision_action_processor.fields_from_example='fields' \
#     --unseen_eval_dataset.delta_vision_action_processor.n_tokens_per_action=7 \
#     --unseen_eval_dataset.delta_vision_action_processor.n_tokens_per_delta=4 \
#     --unseen_eval_dataset.delta_vision_action_processor.img_aug=True \
#     --unseen_eval_dataset.delta_vision_action_processor.vqgan_checkpoint_path='/mnt/default/lwm/data/checkpoints/lwm_checkpoints/vqgan'\
#     --unseen_eval_dataset.delta_vision_action_processor.image_absolute_path='/mnt/default/lwm/data/finetune_data/'\
#     --unseen_eval_dataset.delta_vision_action_processor.max_n_frames=1 \
#     --unseen_eval_dataset.json_delta_action_dataset.mode="pad" \
#     --unseen_eval_dataset.json_delta_action_dataset.path="$unseen_eval_dataset_path" \
#     --unseen_eval_dataset.json_delta_action_dataset.seq_length=384 \
#     --unseen_eval_dataset.json_delta_action_dataset.batch_size=128 \
#     --unseen_eval_dataset.json_delta_action_dataset.tokenizer_processes=1 \
#     --unseen_eval_dataset.json_delta_action_dataset.tokenizer_parallel_chunk_size=128 \
#     --unseen_eval_dataset.json_delta_action_dataset.tokenizer_parallel_batch_size=128 \
#     --unseen_eval_dataset.json_delta_action_dataset.use_data_sharded_loader=True \
#     --checkpointer.save_optimizer_state=False \
#     --autoresume=False \
#     --logger.append_uuid=False \
#     --logger.online=True \
#     --logger.project_id="$project_id" \
#     --logger.experiment_id="$experiment_id" \
#     --logger.experiment_note="$experiment_note" \
#     --logger.output_dir="$output_dir" \
#     --logger.wandb_dir="$HOME/experiment_output/$project_id"




# export SCRIPT_DIR="$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
# export PROJECT_DIR="$( cd -- "$( dirname -- "$SCRIPT_DIR" )" &> /dev/null && pwd )"
# cd "$PROJECT_DIR" || exit 1
# export PYTHONPATH="$PYTHONPATH:$PROJECT_DIR"
# export LIBTPU_INIT_ARGS="--xla_tpu_megacore_fusion_allow_ags=false --xla_enable_async_collective_permute=true --xla_tpu_enable_ag_backward_pipelining=true --xla_tpu_enable_data_parallel_all_reduce_opt=true --xla_tpu_data_parallel_opt_different_sized_ops=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true"

# export llama_tokenizer_path="/mnt/default/lwm/data/checkpoints/lwm_checkpoints/tokenizer.model"

# # export dataset_path='data/0808_multiobject_sink_llava_5hz_action_preprocessed.jsonl'
# # export dataset_path='data/0809_multiobject_sink_llava_256.jsonl'
# # export dataset_path='data/0811_multiobject_sink_llava_train_filtered.jsonl'
# export dataset_path='data/817_multiobject_sink_rand_location_llava_train_filtered.jsonl'
# # export dataset_path='data/0812_play_data_llava_train_filtered.jsonl'
# # export dataset_path='data/0812_play_data_llava_concat_multi2sink_train_filtered.jsonl'
# export eval_dataset_path='data/0811_multiobject_sink_llava_seen_val_filtered.jsonl'
# export unseen_eval_dataset_path='data/0811_multiobject_sink_llava_unseen_val_filtered.jsonl'
# # export eval_dataset_path="data/0731_multi_200_5hz_val.jsonl"

# export output_dir=$AMLT_OUTPUT_DIR
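# # Read the W&B key from the environment; never commit it (revoke any key that was committed here).
# export WANDB_API_KEY="${WANDB_API_KEY:?WANDB_API_KEY must be set in the environment}"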

# export project_id='lwm'
# export experiment_note='world-model'
# # export experiment_id='817_latent_c8_s9_epoch3_multiobject_sink_rand_location_llava_train_filtered_4000'
# export experiment_id='821_latent_c8_s4_l4_20k_multiobject_sink_rand_location_llava_train_filtered_2000_stage2_5k'
# # export experiment_id='0811_latent_finetune_100traj_multitask2sink_15hz_imgaug_filtered_bridge_5epoch_h100'
# # export experiment_id='0811_latent_finetune_100traj_multitask2sink_15hz_imgaug_filtered_codeboook8_seq9_bridge_epoch3'

# # export experiment_id='0815_multiobject_sink_rand_location_seen_llava_train_filtered_codebook8_seq9_bridge_e3'
# # export experiment_id='0812_latent_play_data_multi2sink_mixture_finetune_15hz_imgaug_filtered'
# # export experiment_id='0811_multiobject_sink_llava_train_filtered_bridge_10percent_codebook8_seq1'

# # mesh_dim: dp, fsdp, tp, sp
# python3 -u -m lwm.train \
#     --modality='vision,action,delta' \
#     --mesh_dim='!-1,8,1,1' \
#     --dtype='bf16' \
#     --total_steps=2005 \
#     --log_freq=1 \
#     --delta_tokens=1 \
#     --eval_steps=5 \
#     --save_model_freq=0 \
#     --eval_log_freq=100 \
#     --save_milestone_freq=1000 \
#     --load_llama_config='7b' \
#     --load_checkpoint="params::/mnt/default/projects/lwm/amlt-results/7275549870.27458-879d4c3e-a522-4a7e-b6bb-15b1d12c2a2b/lwm_human_bf16_from_whole_delta8_batch32_seq4_layer4_filtered_no_inst_imgaug_821_20Ktrain/streaming_params_5000" \
#     --update_llama_config="dict(action_vocab_size=256,delta_vocab_size=8,theta=50000000,max_sequence_length=2048,use_flash_attention=True,scan_attention=True,scan_query_chunk_size=512,scan_key_chunk_size=1024,remat_attention='nothing_saveable',scan_mlp=True,scan_mlp_chunk_size=8192,remat_mlp='nothing_saveable',remat_block='nothing_saveable',scan_layers=True)" \
#     --tokenizer.vocab_file="$llama_tokenizer_path" \
#     --optimizer.type='adamw' \
#     --llama.action_vocab_size=256 \
#     --llama.delta_vocab_size=8 \
#     --optimizer.accumulate_gradient_steps=1 \
#     --optimizer.adamw_optimizer.weight_decay=0 \
#     --optimizer.adamw_optimizer.lr=2e-5 \
#     --optimizer.adamw_optimizer.end_lr=2e-5 \
#     --optimizer.adamw_optimizer.lr_warmup_steps=0 \
#     --optimizer.adamw_optimizer.lr_decay_steps=3000 \
#     --use_data_sharded_loader=True \
#     --train_dataset.type='json_vision_delta_action' \
#     --train_dataset.delta_vision_action_processor.fields_from_example='fields' \
#     --train_dataset.delta_vision_action_processor.n_tokens_per_action=7 \
#     --train_dataset.delta_vision_action_processor.n_tokens_per_delta=4 \
#     --train_dataset.delta_vision_action_processor.img_aug=True \
#     --train_dataset.delta_vision_action_processor.vqgan_checkpoint_path='/mnt/default/lwm/data/checkpoints/lwm_checkpoints/vqgan'\
#     --train_dataset.delta_vision_action_processor.image_absolute_path='/mnt/default/lwm/data/finetune_data/'\
#     --train_dataset.delta_vision_action_processor.max_n_frames=1 \
#     --train_dataset.json_delta_action_dataset.mode="pad" \
#     --train_dataset.json_delta_action_dataset.path="$dataset_path" \
#     --train_dataset.json_delta_action_dataset.seq_length=384 \
#     --train_dataset.json_delta_action_dataset.batch_size=128 \
#     --train_dataset.json_delta_action_dataset.tokenizer_processes=1 \
#     --train_dataset.json_delta_action_dataset.tokenizer_parallel_chunk_size=128 \
#     --train_dataset.json_delta_action_dataset.tokenizer_parallel_batch_size=128 \
#     --train_dataset.json_delta_action_dataset.use_data_sharded_loader=True \
#     --eval_dataset.type='json_vision_delta_action' \
#     --eval_dataset.delta_vision_action_processor.fields_from_example='fields' \
#     --eval_dataset.delta_vision_action_processor.n_tokens_per_action=7 \
#     --eval_dataset.delta_vision_action_processor.n_tokens_per_delta=4 \
#     --eval_dataset.delta_vision_action_processor.img_aug=True \
#     --eval_dataset.delta_vision_action_processor.vqgan_checkpoint_path='/mnt/default/lwm/data/checkpoints/lwm_checkpoints/vqgan'\
#     --eval_dataset.delta_vision_action_processor.image_absolute_path='/mnt/default/lwm/data/finetune_data/'\
#     --eval_dataset.delta_vision_action_processor.max_n_frames=1 \
#     --eval_dataset.json_delta_action_dataset.mode="pad" \
#     --eval_dataset.json_delta_action_dataset.path="$eval_dataset_path" \
#     --eval_dataset.json_delta_action_dataset.seq_length=384 \
#     --eval_dataset.json_delta_action_dataset.batch_size=128 \
#     --eval_dataset.json_delta_action_dataset.tokenizer_processes=1 \
#     --eval_dataset.json_delta_action_dataset.tokenizer_parallel_chunk_size=128 \
#     --eval_dataset.json_delta_action_dataset.tokenizer_parallel_batch_size=128 \
#     --eval_dataset.json_delta_action_dataset.use_data_sharded_loader=True \
#     --unseen_eval_dataset.type='json_vision_delta_action' \
#     --unseen_eval_dataset.delta_vision_action_processor.fields_from_example='fields' \
#     --unseen_eval_dataset.delta_vision_action_processor.n_tokens_per_action=7 \
#     --unseen_eval_dataset.delta_vision_action_processor.n_tokens_per_delta=4 \
#     --unseen_eval_dataset.delta_vision_action_processor.img_aug=True \
#     --unseen_eval_dataset.delta_vision_action_processor.vqgan_checkpoint_path='/mnt/default/lwm/data/checkpoints/lwm_checkpoints/vqgan'\
#     --unseen_eval_dataset.delta_vision_action_processor.image_absolute_path='/mnt/default/lwm/data/finetune_data/'\
#     --unseen_eval_dataset.delta_vision_action_processor.max_n_frames=1 \
#     --unseen_eval_dataset.json_delta_action_dataset.mode="pad" \
#     --unseen_eval_dataset.json_delta_action_dataset.path="$unseen_eval_dataset_path" \
#     --unseen_eval_dataset.json_delta_action_dataset.seq_length=384 \
#     --unseen_eval_dataset.json_delta_action_dataset.batch_size=128 \
#     --unseen_eval_dataset.json_delta_action_dataset.tokenizer_processes=1 \
#     --unseen_eval_dataset.json_delta_action_dataset.tokenizer_parallel_chunk_size=128 \
#     --unseen_eval_dataset.json_delta_action_dataset.tokenizer_parallel_batch_size=128 \
#     --unseen_eval_dataset.json_delta_action_dataset.use_data_sharded_loader=True \
#     --checkpointer.save_optimizer_state=False \
#     --autoresume=False \
#     --logger.append_uuid=False \
#     --logger.online=True \
#     --logger.project_id="$project_id" \
#     --logger.experiment_id="$experiment_id" \
#     --logger.experiment_note="$experiment_note" \
#     --logger.output_dir="$output_dir" \
#     --logger.wandb_dir="$HOME/experiment_output/$project_id"




# export SCRIPT_DIR="$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
# export PROJECT_DIR="$( cd -- "$( dirname -- "$SCRIPT_DIR" )" &> /dev/null && pwd )"
# cd "$PROJECT_DIR" || exit 1
# export PYTHONPATH="$PYTHONPATH:$PROJECT_DIR"
# export LIBTPU_INIT_ARGS="--xla_tpu_megacore_fusion_allow_ags=false --xla_enable_async_collective_permute=true --xla_tpu_enable_ag_backward_pipelining=true --xla_tpu_enable_data_parallel_all_reduce_opt=true --xla_tpu_data_parallel_opt_different_sized_ops=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true"

# export llama_tokenizer_path="/mnt/default/lwm/data/checkpoints/lwm_checkpoints/tokenizer.model"

# # export dataset_path='data/0808_multiobject_sink_llava_5hz_action_preprocessed.jsonl'
# # export dataset_path='data/0809_multiobject_sink_llava_256.jsonl'
# # export dataset_path='data/0811_multiobject_sink_llava_train_filtered.jsonl'
# export dataset_path='data/817_multiobject_sink_rand_location_llava_train_filtered.jsonl'
# # export dataset_path='data/0812_play_data_llava_train_filtered.jsonl'
# # export dataset_path='data/0812_play_data_llava_concat_multi2sink_train_filtered.jsonl'
# export eval_dataset_path='data/0811_multiobject_sink_llava_seen_val_filtered.jsonl'
# export unseen_eval_dataset_path='data/0811_multiobject_sink_llava_unseen_val_filtered.jsonl'
# # export eval_dataset_path="data/0731_multi_200_5hz_val.jsonl"

# export output_dir=$AMLT_OUTPUT_DIR
# export WANDB_API_KEY="${WANDB_API_KEY:?WANDB_API_KEY must be set in the environment}"

# export project_id='lwm'
# export experiment_note='world-model'
# # export experiment_id='817_latent_c8_s9_epoch3_multiobject_sink_rand_location_llava_train_filtered_4000'
# export experiment_id='821_latent_c8_s4_l4_20k_multiobject_sink_rand_location_llava_train_filtered_2000_stage2_5k_domain_rand'
# # export experiment_id='0811_latent_finetune_100traj_multitask2sink_15hz_imgaug_filtered_bridge_5epoch_h100'
# # export experiment_id='0811_latent_finetune_100traj_multitask2sink_15hz_imgaug_filtered_codeboook8_seq9_bridge_epoch3'

# # export experiment_id='0815_multiobject_sink_rand_location_seen_llava_train_filtered_codebook8_seq9_bridge_e3'
# # export experiment_id='0812_latent_play_data_multi2sink_mixture_finetune_15hz_imgaug_filtered'
# # export experiment_id='0811_multiobject_sink_llava_train_filtered_bridge_10percent_codebook8_seq1'

# # mesh_dim: dp, fsdp, tp, sp
# python3 -u -m lwm.train \
#     --modality='vision,action,delta' \
#     --mesh_dim='!-1,8,1,1' \
#     --dtype='bf16' \
#     --total_steps=2005 \
#     --log_freq=1 \
#     --delta_tokens=1 \
#     --eval_steps=5 \
#     --save_model_freq=0 \
#     --eval_log_freq=100 \
#     --save_milestone_freq=1000 \
#     --load_llama_config='7b' \
#     --load_checkpoint="params::/mnt/default/projects/lwm/amlt-results/7275530656.79753-a06252f0-da2a-482f-a785-d918d16c636b/lwm_human_bf16_from_whole_delta8_batch32_seq4_layer4_filtered_no_inst_imgaug_821_20Ktrain_domain_rand/streaming_params_5000" \
#     --update_llama_config="dict(action_vocab_size=256,delta_vocab_size=8,theta=50000000,max_sequence_length=2048,use_flash_attention=True,scan_attention=True,scan_query_chunk_size=512,scan_key_chunk_size=1024,remat_attention='nothing_saveable',scan_mlp=True,scan_mlp_chunk_size=8192,remat_mlp='nothing_saveable',remat_block='nothing_saveable',scan_layers=True)" \
#     --tokenizer.vocab_file="$llama_tokenizer_path" \
#     --optimizer.type='adamw' \
#     --llama.action_vocab_size=256 \
#     --llama.delta_vocab_size=8 \
#     --optimizer.accumulate_gradient_steps=1 \
#     --optimizer.adamw_optimizer.weight_decay=0 \
#     --optimizer.adamw_optimizer.lr=2e-5 \
#     --optimizer.adamw_optimizer.end_lr=2e-5 \
#     --optimizer.adamw_optimizer.lr_warmup_steps=0 \
#     --optimizer.adamw_optimizer.lr_decay_steps=3000 \
#     --use_data_sharded_loader=True \
#     --train_dataset.type='json_vision_delta_action' \
#     --train_dataset.delta_vision_action_processor.fields_from_example='fields' \
#     --train_dataset.delta_vision_action_processor.n_tokens_per_action=7 \
#     --train_dataset.delta_vision_action_processor.n_tokens_per_delta=4 \
#     --train_dataset.delta_vision_action_processor.img_aug=True \
#     --train_dataset.delta_vision_action_processor.vqgan_checkpoint_path='/mnt/default/lwm/data/checkpoints/lwm_checkpoints/vqgan'\
#     --train_dataset.delta_vision_action_processor.image_absolute_path='/mnt/default/lwm/data/finetune_data/'\
#     --train_dataset.delta_vision_action_processor.max_n_frames=1 \
#     --train_dataset.json_delta_action_dataset.mode="pad" \
#     --train_dataset.json_delta_action_dataset.path="$dataset_path" \
#     --train_dataset.json_delta_action_dataset.seq_length=384 \
#     --train_dataset.json_delta_action_dataset.batch_size=128 \
#     --train_dataset.json_delta_action_dataset.tokenizer_processes=1 \
#     --train_dataset.json_delta_action_dataset.tokenizer_parallel_chunk_size=128 \
#     --train_dataset.json_delta_action_dataset.tokenizer_parallel_batch_size=128 \
#     --train_dataset.json_delta_action_dataset.use_data_sharded_loader=True \
#     --eval_dataset.type='json_vision_delta_action' \
#     --eval_dataset.delta_vision_action_processor.fields_from_example='fields' \
#     --eval_dataset.delta_vision_action_processor.n_tokens_per_action=7 \
#     --eval_dataset.delta_vision_action_processor.n_tokens_per_delta=4 \
#     --eval_dataset.delta_vision_action_processor.img_aug=True \
#     --eval_dataset.delta_vision_action_processor.vqgan_checkpoint_path='/mnt/default/lwm/data/checkpoints/lwm_checkpoints/vqgan'\
#     --eval_dataset.delta_vision_action_processor.image_absolute_path='/mnt/default/lwm/data/finetune_data/'\
#     --eval_dataset.delta_vision_action_processor.max_n_frames=1 \
#     --eval_dataset.json_delta_action_dataset.mode="pad" \
#     --eval_dataset.json_delta_action_dataset.path="$eval_dataset_path" \
#     --eval_dataset.json_delta_action_dataset.seq_length=384 \
#     --eval_dataset.json_delta_action_dataset.batch_size=128 \
#     --eval_dataset.json_delta_action_dataset.tokenizer_processes=1 \
#     --eval_dataset.json_delta_action_dataset.tokenizer_parallel_chunk_size=128 \
#     --eval_dataset.json_delta_action_dataset.tokenizer_parallel_batch_size=128 \
#     --eval_dataset.json_delta_action_dataset.use_data_sharded_loader=True \
#     --unseen_eval_dataset.type='json_vision_delta_action' \
#     --unseen_eval_dataset.delta_vision_action_processor.fields_from_example='fields' \
#     --unseen_eval_dataset.delta_vision_action_processor.n_tokens_per_action=7 \
#     --unseen_eval_dataset.delta_vision_action_processor.n_tokens_per_delta=4 \
#     --unseen_eval_dataset.delta_vision_action_processor.img_aug=True \
#     --unseen_eval_dataset.delta_vision_action_processor.vqgan_checkpoint_path='/mnt/default/lwm/data/checkpoints/lwm_checkpoints/vqgan'\
#     --unseen_eval_dataset.delta_vision_action_processor.image_absolute_path='/mnt/default/lwm/data/finetune_data/'\
#     --unseen_eval_dataset.delta_vision_action_processor.max_n_frames=1 \
#     --unseen_eval_dataset.json_delta_action_dataset.mode="pad" \
#     --unseen_eval_dataset.json_delta_action_dataset.path="$unseen_eval_dataset_path" \
#     --unseen_eval_dataset.json_delta_action_dataset.seq_length=384 \
#     --unseen_eval_dataset.json_delta_action_dataset.batch_size=128 \
#     --unseen_eval_dataset.json_delta_action_dataset.tokenizer_processes=1 \
#     --unseen_eval_dataset.json_delta_action_dataset.tokenizer_parallel_chunk_size=128 \
#     --unseen_eval_dataset.json_delta_action_dataset.tokenizer_parallel_batch_size=128 \
#     --unseen_eval_dataset.json_delta_action_dataset.use_data_sharded_loader=True \
#     --checkpointer.save_optimizer_state=False \
#     --autoresume=False \
#     --logger.append_uuid=False \
#     --logger.online=True \
#     --logger.project_id="$project_id" \
#     --logger.experiment_id="$experiment_id" \
#     --logger.experiment_note="$experiment_note" \
#     --logger.output_dir="$output_dir" \
#     --logger.wandb_dir="$HOME/experiment_output/$project_id"




# export SCRIPT_DIR="$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
# export PROJECT_DIR="$( cd -- "$( dirname -- "$SCRIPT_DIR" )" &> /dev/null && pwd )"
# cd "$PROJECT_DIR" || exit 1
# export PYTHONPATH="$PYTHONPATH:$PROJECT_DIR"
# export LIBTPU_INIT_ARGS="--xla_tpu_megacore_fusion_allow_ags=false --xla_enable_async_collective_permute=true --xla_tpu_enable_ag_backward_pipelining=true --xla_tpu_enable_data_parallel_all_reduce_opt=true --xla_tpu_data_parallel_opt_different_sized_ops=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true"

# export llama_tokenizer_path="/mnt/default/lwm/data/checkpoints/lwm_checkpoints/tokenizer.model"

# # export dataset_path='data/0808_multiobject_sink_llava_5hz_action_preprocessed.jsonl'
# # export dataset_path='data/0809_multiobject_sink_llava_256.jsonl'
# # export dataset_path='data/0811_multiobject_sink_llava_train_filtered.jsonl'
# export dataset_path='data/817_multiobject_sink_rand_location_llava_train_filtered.jsonl'
# # export dataset_path='data/0812_play_data_llava_train_filtered.jsonl'
# # export dataset_path='data/0812_play_data_llava_concat_multi2sink_train_filtered.jsonl'
# export eval_dataset_path='data/0811_multiobject_sink_llava_seen_val_filtered.jsonl'
# export unseen_eval_dataset_path='data/0811_multiobject_sink_llava_unseen_val_filtered.jsonl'
# # export eval_dataset_path="data/0731_multi_200_5hz_val.jsonl"

# export output_dir=$AMLT_OUTPUT_DIR
# export WANDB_API_KEY="${WANDB_API_KEY:?WANDB_API_KEY must be set in the environment}"

# export project_id='lwm'
# export experiment_note='world-model'
# # export experiment_id='817_latent_c8_s9_epoch3_multiobject_sink_rand_location_llava_train_filtered_4000'
# export experiment_id='821_latent_c8_s4_l4_40k_multiobject_sink_rand_location_llava_train_filtered_2000_stage2_5k_domain_rand'
# # export experiment_id='0811_latent_finetune_100traj_multitask2sink_15hz_imgaug_filtered_bridge_5epoch_h100'
# # export experiment_id='0811_latent_finetune_100traj_multitask2sink_15hz_imgaug_filtered_codeboook8_seq9_bridge_epoch3'

# # export experiment_id='0815_multiobject_sink_rand_location_seen_llava_train_filtered_codebook8_seq9_bridge_e3'
# # export experiment_id='0812_latent_play_data_multi2sink_mixture_finetune_15hz_imgaug_filtered'
# # export experiment_id='0811_multiobject_sink_llava_train_filtered_bridge_10percent_codebook8_seq1'

# # mesh_dim: dp, fsdp, tp, sp
# python3 -u -m lwm.train \
#     --modality='vision,action,delta' \
#     --mesh_dim='!-1,8,1,1' \
#     --dtype='bf16' \
#     --total_steps=2005 \
#     --log_freq=1 \
#     --delta_tokens=1 \
#     --eval_steps=5 \
#     --save_model_freq=0 \
#     --eval_log_freq=100 \
#     --save_milestone_freq=1000 \
#     --load_llama_config='7b' \
#     --load_checkpoint="params::/mnt/default/projects/lwm/amlt-results/7275530793.43282-b4f131cb-7866-41c5-9c27-64b6aa1becb5/lwm_human_bf16_from_whole_delta8_batch32_seq4_layer4_filtered_no_inst_imgaug_821_40Ktrain_domain_rand/streaming_params_5000" \
#     --update_llama_config="dict(action_vocab_size=256,delta_vocab_size=8,theta=50000000,max_sequence_length=2048,use_flash_attention=True,scan_attention=True,scan_query_chunk_size=512,scan_key_chunk_size=1024,remat_attention='nothing_saveable',scan_mlp=True,scan_mlp_chunk_size=8192,remat_mlp='nothing_saveable',remat_block='nothing_saveable',scan_layers=True)" \
#     --tokenizer.vocab_file="$llama_tokenizer_path" \
#     --optimizer.type='adamw' \
#     --llama.action_vocab_size=256 \
#     --llama.delta_vocab_size=8 \
#     --optimizer.accumulate_gradient_steps=1 \
#     --optimizer.adamw_optimizer.weight_decay=0 \
#     --optimizer.adamw_optimizer.lr=2e-5 \
#     --optimizer.adamw_optimizer.end_lr=2e-5 \
#     --optimizer.adamw_optimizer.lr_warmup_steps=0 \
#     --optimizer.adamw_optimizer.lr_decay_steps=3000 \
#     --use_data_sharded_loader=True \
#     --train_dataset.type='json_vision_delta_action' \
#     --train_dataset.delta_vision_action_processor.fields_from_example='fields' \
#     --train_dataset.delta_vision_action_processor.n_tokens_per_action=7 \
#     --train_dataset.delta_vision_action_processor.n_tokens_per_delta=4 \
#     --train_dataset.delta_vision_action_processor.img_aug=True \
#     --train_dataset.delta_vision_action_processor.vqgan_checkpoint_path='/mnt/default/lwm/data/checkpoints/lwm_checkpoints/vqgan'\
#     --train_dataset.delta_vision_action_processor.image_absolute_path='/mnt/default/lwm/data/finetune_data/'\
#     --train_dataset.delta_vision_action_processor.max_n_frames=1 \
#     --train_dataset.json_delta_action_dataset.mode="pad" \
#     --train_dataset.json_delta_action_dataset.path="$dataset_path" \
#     --train_dataset.json_delta_action_dataset.seq_length=384 \
#     --train_dataset.json_delta_action_dataset.batch_size=128 \
#     --train_dataset.json_delta_action_dataset.tokenizer_processes=1 \
#     --train_dataset.json_delta_action_dataset.tokenizer_parallel_chunk_size=128 \
#     --train_dataset.json_delta_action_dataset.tokenizer_parallel_batch_size=128 \
#     --train_dataset.json_delta_action_dataset.use_data_sharded_loader=True \
#     --eval_dataset.type='json_vision_delta_action' \
#     --eval_dataset.delta_vision_action_processor.fields_from_example='fields' \
#     --eval_dataset.delta_vision_action_processor.n_tokens_per_action=7 \
#     --eval_dataset.delta_vision_action_processor.n_tokens_per_delta=4 \
#     --eval_dataset.delta_vision_action_processor.img_aug=True \
#     --eval_dataset.delta_vision_action_processor.vqgan_checkpoint_path='/mnt/default/lwm/data/checkpoints/lwm_checkpoints/vqgan'\
#     --eval_dataset.delta_vision_action_processor.image_absolute_path='/mnt/default/lwm/data/finetune_data/'\
#     --eval_dataset.delta_vision_action_processor.max_n_frames=1 \
#     --eval_dataset.json_delta_action_dataset.mode="pad" \
#     --eval_dataset.json_delta_action_dataset.path="$eval_dataset_path" \
#     --eval_dataset.json_delta_action_dataset.seq_length=384 \
#     --eval_dataset.json_delta_action_dataset.batch_size=128 \
#     --eval_dataset.json_delta_action_dataset.tokenizer_processes=1 \
#     --eval_dataset.json_delta_action_dataset.tokenizer_parallel_chunk_size=128 \
#     --eval_dataset.json_delta_action_dataset.tokenizer_parallel_batch_size=128 \
#     --eval_dataset.json_delta_action_dataset.use_data_sharded_loader=True \
#     --unseen_eval_dataset.type='json_vision_delta_action' \
#     --unseen_eval_dataset.delta_vision_action_processor.fields_from_example='fields' \
#     --unseen_eval_dataset.delta_vision_action_processor.n_tokens_per_action=7 \
#     --unseen_eval_dataset.delta_vision_action_processor.n_tokens_per_delta=4 \
#     --unseen_eval_dataset.delta_vision_action_processor.img_aug=True \
#     --unseen_eval_dataset.delta_vision_action_processor.vqgan_checkpoint_path='/mnt/default/lwm/data/checkpoints/lwm_checkpoints/vqgan'\
#     --unseen_eval_dataset.delta_vision_action_processor.image_absolute_path='/mnt/default/lwm/data/finetune_data/'\
#     --unseen_eval_dataset.delta_vision_action_processor.max_n_frames=1 \
#     --unseen_eval_dataset.json_delta_action_dataset.mode="pad" \
#     --unseen_eval_dataset.json_delta_action_dataset.path="$unseen_eval_dataset_path" \
#     --unseen_eval_dataset.json_delta_action_dataset.seq_length=384 \
#     --unseen_eval_dataset.json_delta_action_dataset.batch_size=128 \
#     --unseen_eval_dataset.json_delta_action_dataset.tokenizer_processes=1 \
#     --unseen_eval_dataset.json_delta_action_dataset.tokenizer_parallel_chunk_size=128 \
#     --unseen_eval_dataset.json_delta_action_dataset.tokenizer_parallel_batch_size=128 \
#     --unseen_eval_dataset.json_delta_action_dataset.use_data_sharded_loader=True \
#     --checkpointer.save_optimizer_state=False \
#     --autoresume=False \
#     --logger.append_uuid=False \
#     --logger.online=True \
#     --logger.project_id="$project_id" \
#     --logger.experiment_id="$experiment_id" \
#     --logger.experiment_note="$experiment_note" \
#     --logger.output_dir="$output_dir" \
#     --logger.wandb_dir="$HOME/experiment_output/$project_id"




# export SCRIPT_DIR="$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
# export PROJECT_DIR="$( cd -- "$( dirname -- "$SCRIPT_DIR" )" &> /dev/null && pwd )"
# cd "$PROJECT_DIR" || exit 1
# export PYTHONPATH="$PYTHONPATH:$PROJECT_DIR"
# export LIBTPU_INIT_ARGS="--xla_tpu_megacore_fusion_allow_ags=false --xla_enable_async_collective_permute=true --xla_tpu_enable_ag_backward_pipelining=true --xla_tpu_enable_data_parallel_all_reduce_opt=true --xla_tpu_data_parallel_opt_different_sized_ops=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true"

# export llama_tokenizer_path="/mnt/default/lwm/data/checkpoints/lwm_checkpoints/tokenizer.model"

# # export dataset_path='data/0808_multiobject_sink_llava_5hz_action_preprocessed.jsonl'
# # export dataset_path='data/0809_multiobject_sink_llava_256.jsonl'
# export dataset_path='data/0811_multiobject_sink_llava_train_filtered.jsonl'
# # export dataset_path='data/817_multiobject_sink_rand_location_llava_train_filtered.jsonl'
# # export dataset_path='data/0812_play_data_llava_train_filtered.jsonl'
# # export dataset_path='data/0812_play_data_llava_concat_multi2sink_train_filtered.jsonl'
# export eval_dataset_path='data/0811_multiobject_sink_llava_seen_val_filtered.jsonl'
# export unseen_eval_dataset_path='data/0811_multiobject_sink_llava_unseen_val_filtered.jsonl'
# # export eval_dataset_path="data/0731_multi_200_5hz_val.jsonl"

# export output_dir=$AMLT_OUTPUT_DIR
# export WANDB_API_KEY="${WANDB_API_KEY:?WANDB_API_KEY must be set in the environment}"

# export project_id='lwm'
# export experiment_note='world-model'
# # export experiment_id='817_latent_c8_s9_epoch3_multiobject_sink_rand_location_llava_train_filtered_4000'
# # export experiment_id='821_latent_bridge10_c8_s4_l4_40k_multiobject_sink_rand_location_llava_train_filtered_10k'

# export experiment_id='0811_latent_bridge10_c8_s4_l4_40k_multiobject_sink_llava_train_filtered_4k'

# # mesh_dim: dp, fsdp, tp, sp
# python3 -u -m lwm.train \
#     --modality='vision,action,delta' \
#     --mesh_dim='!-1,8,1,1' \
#     --dtype='bf16' \
#     --total_steps=1005 \
#     --log_freq=1 \
#     --delta_tokens=1 \
#     --eval_steps=5 \
#     --save_model_freq=0 \
#     --eval_log_freq=100 \
#     --save_milestone_freq=1000 \
#     --load_llama_config='7b' \
#     --load_checkpoint="params::/mnt/default/projects/lwm/amlt-results/7275559807.38266-3a465a47-cb9e-4d09-8827-26da7927e7b7/lwm_bridge_10percent_delta8_batch32_seq4_layer4_filtered_no_inst_40K_re2_8gpu/streaming_params_5000" \
#     --update_llama_config="dict(action_vocab_size=256,delta_vocab_size=8,theta=50000000,max_sequence_length=2048,use_flash_attention=True,scan_attention=True,scan_query_chunk_size=512,scan_key_chunk_size=1024,remat_attention='nothing_saveable',scan_mlp=True,scan_mlp_chunk_size=8192,remat_mlp='nothing_saveable',remat_block='nothing_saveable',scan_layers=True)" \
#     --tokenizer.vocab_file="$llama_tokenizer_path" \
#     --optimizer.type='adamw' \
#     --llama.action_vocab_size=256 \
#     --llama.delta_vocab_size=8 \
#     --optimizer.accumulate_gradient_steps=1 \
#     --optimizer.adamw_optimizer.weight_decay=0 \
#     --optimizer.adamw_optimizer.lr=2e-5 \
#     --optimizer.adamw_optimizer.end_lr=2e-5 \
#     --optimizer.adamw_optimizer.lr_warmup_steps=0 \
#     --optimizer.adamw_optimizer.lr_decay_steps=3000 \
#     --use_data_sharded_loader=True \
#     --train_dataset.type='json_vision_delta_action' \
#     --train_dataset.delta_vision_action_processor.fields_from_example='fields' \
#     --train_dataset.delta_vision_action_processor.n_tokens_per_action=7 \
#     --train_dataset.delta_vision_action_processor.n_tokens_per_delta=4 \
#     --train_dataset.delta_vision_action_processor.img_aug=True \
#     --train_dataset.delta_vision_action_processor.vqgan_checkpoint_path='/mnt/default/lwm/data/checkpoints/lwm_checkpoints/vqgan'\
#     --train_dataset.delta_vision_action_processor.image_absolute_path='/mnt/default/lwm/data/finetune_data/'\
#     --train_dataset.delta_vision_action_processor.max_n_frames=1 \
#     --train_dataset.json_delta_action_dataset.mode="pad" \
#     --train_dataset.json_delta_action_dataset.path="$dataset_path" \
#     --train_dataset.json_delta_action_dataset.seq_length=384 \
#     --train_dataset.json_delta_action_dataset.batch_size=128 \
#     --train_dataset.json_delta_action_dataset.tokenizer_processes=1 \
#     --train_dataset.json_delta_action_dataset.tokenizer_parallel_chunk_size=128 \
#     --train_dataset.json_delta_action_dataset.tokenizer_parallel_batch_size=128 \
#     --train_dataset.json_delta_action_dataset.use_data_sharded_loader=True \
#     --eval_dataset.type='json_vision_delta_action' \
#     --eval_dataset.delta_vision_action_processor.fields_from_example='fields' \
#     --eval_dataset.delta_vision_action_processor.n_tokens_per_action=7 \
#     --eval_dataset.delta_vision_action_processor.n_tokens_per_delta=4 \
#     --eval_dataset.delta_vision_action_processor.img_aug=True \
#     --eval_dataset.delta_vision_action_processor.vqgan_checkpoint_path='/mnt/default/lwm/data/checkpoints/lwm_checkpoints/vqgan'\
#     --eval_dataset.delta_vision_action_processor.image_absolute_path='/mnt/default/lwm/data/finetune_data/'\
#     --eval_dataset.delta_vision_action_processor.max_n_frames=1 \
#     --eval_dataset.json_delta_action_dataset.mode="pad" \
#     --eval_dataset.json_delta_action_dataset.path="$eval_dataset_path" \
#     --eval_dataset.json_delta_action_dataset.seq_length=384 \
#     --eval_dataset.json_delta_action_dataset.batch_size=128 \
#     --eval_dataset.json_delta_action_dataset.tokenizer_processes=1 \
#     --eval_dataset.json_delta_action_dataset.tokenizer_parallel_chunk_size=128 \
#     --eval_dataset.json_delta_action_dataset.tokenizer_parallel_batch_size=128 \
#     --eval_dataset.json_delta_action_dataset.use_data_sharded_loader=True \
#     --unseen_eval_dataset.type='json_vision_delta_action' \
#     --unseen_eval_dataset.delta_vision_action_processor.fields_from_example='fields' \
#     --unseen_eval_dataset.delta_vision_action_processor.n_tokens_per_action=7 \
#     --unseen_eval_dataset.delta_vision_action_processor.n_tokens_per_delta=4 \
#     --unseen_eval_dataset.delta_vision_action_processor.img_aug=True \
#     --unseen_eval_dataset.delta_vision_action_processor.vqgan_checkpoint_path='/mnt/default/lwm/data/checkpoints/lwm_checkpoints/vqgan'\
#     --unseen_eval_dataset.delta_vision_action_processor.image_absolute_path='/mnt/default/lwm/data/finetune_data/'\
#     --unseen_eval_dataset.delta_vision_action_processor.max_n_frames=1 \
#     --unseen_eval_dataset.json_delta_action_dataset.mode="pad" \
#     --unseen_eval_dataset.json_delta_action_dataset.path="$unseen_eval_dataset_path" \
#     --unseen_eval_dataset.json_delta_action_dataset.seq_length=384 \
#     --unseen_eval_dataset.json_delta_action_dataset.batch_size=128 \
#     --unseen_eval_dataset.json_delta_action_dataset.tokenizer_processes=1 \
#     --unseen_eval_dataset.json_delta_action_dataset.tokenizer_parallel_chunk_size=128 \
#     --unseen_eval_dataset.json_delta_action_dataset.tokenizer_parallel_batch_size=128 \
#     --unseen_eval_dataset.json_delta_action_dataset.use_data_sharded_loader=True \
#     --checkpointer.save_optimizer_state=False \
#     --autoresume=False \
#     --logger.append_uuid=False \
#     --logger.online=True \
#     --logger.project_id="$project_id" \
#     --logger.experiment_id="$experiment_id" \
#     --logger.experiment_note="$experiment_note" \
#     --logger.output_dir="$output_dir" \
#     --logger.wandb_dir="$HOME/experiment_output/$project_id"



# export SCRIPT_DIR="$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
# export PROJECT_DIR="$( cd -- "$( dirname -- "$SCRIPT_DIR" )" &> /dev/null && pwd )"
# cd "$PROJECT_DIR" || exit 1
# export PYTHONPATH="$PYTHONPATH:$PROJECT_DIR"
# export LIBTPU_INIT_ARGS="--xla_tpu_megacore_fusion_allow_ags=false --xla_enable_async_collective_permute=true --xla_tpu_enable_ag_backward_pipelining=true --xla_tpu_enable_data_parallel_all_reduce_opt=true --xla_tpu_data_parallel_opt_different_sized_ops=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true"

# export llama_tokenizer_path="/mnt/default/lwm/data/checkpoints/lwm_checkpoints/tokenizer.model"

# export dataset_path='data/817_multiobject_sink_rand_location_llava_train_filtered.jsonl'
# export eval_dataset_path='data/0811_multiobject_sink_llava_seen_val_filtered.jsonl'
# export unseen_eval_dataset_path='data/0811_multiobject_sink_llava_unseen_val_filtered.jsonl'

# export output_dir=$AMLT_OUTPUT_DIR
# export WANDB_API_KEY="${WANDB_API_KEY:?WANDB_API_KEY must be set in the environment}"

# export project_id='lwm'
# export experiment_note='world-model'
# # export experiment_id='817_latent_c8_s9_epoch3_multiobject_sink_rand_location_llava_train_filtered_4000'
# # export experiment_id='821_latent_bridge10_c8_s4_l4_40k_multiobject_sink_rand_location_llava_train_filtered_10k'

# export experiment_id='817_latent_sthv2_10percent_c8_s4_l8_100k_multiobject_sink_rand_location_llava_train_filtered_14k_re'

# # mesh_dim: dp, fsdp, tp, sp
# python3 -u -m lwm.train \
#     --modality='vision,action,delta' \
#     --mesh_dim='!-1,8,1,1' \
#     --dtype='bf16' \
#     --total_steps=4005 \
#     --log_freq=1 \
#     --delta_tokens=1 \
#     --eval_steps=5 \
#     --save_model_freq=0 \
#     --eval_log_freq=100 \
#     --save_milestone_freq=1000 \
#     --load_llama_config='7b' \
#     --load_checkpoint="params::/mnt/default/projects/lwm/amlt-results/7275056058.06230-948e80e8-b4b8-4fe8-b547-2237a94c1d8f/sthv2_bf16_from_whole_delta8_batch128_seq4_layer8_100K_re2/streaming_params_14454" \
#     --update_llama_config="dict(action_vocab_size=256,delta_vocab_size=8,theta=50000000,max_sequence_length=2048,use_flash_attention=True,scan_attention=True,scan_query_chunk_size=512,scan_key_chunk_size=1024,remat_attention='nothing_saveable',scan_mlp=True,scan_mlp_chunk_size=8192,remat_mlp='nothing_saveable',remat_block='nothing_saveable',scan_layers=True)" \
#     --tokenizer.vocab_file="$llama_tokenizer_path" \
#     --optimizer.type='adamw' \
#     --llama.action_vocab_size=256 \
#     --llama.delta_vocab_size=8 \
#     --optimizer.accumulate_gradient_steps=1 \
#     --optimizer.adamw_optimizer.weight_decay=0 \
#     --optimizer.adamw_optimizer.lr=2e-5 \
#     --optimizer.adamw_optimizer.end_lr=2e-5 \
#     --optimizer.adamw_optimizer.lr_warmup_steps=0 \
#     --optimizer.adamw_optimizer.lr_decay_steps=3000 \
#     --use_data_sharded_loader=True \
#     --train_dataset.type='json_vision_delta_action' \
#     --train_dataset.delta_vision_action_processor.fields_from_example='fields' \
#     --train_dataset.delta_vision_action_processor.n_tokens_per_action=7 \
#     --train_dataset.delta_vision_action_processor.n_tokens_per_delta=4 \
#     --train_dataset.delta_vision_action_processor.img_aug=True \
#     --train_dataset.delta_vision_action_processor.vqgan_checkpoint_path='/mnt/default/lwm/data/checkpoints/lwm_checkpoints/vqgan'\
#     --train_dataset.delta_vision_action_processor.image_absolute_path='/mnt/default/lwm/data/finetune_data/'\
#     --train_dataset.delta_vision_action_processor.max_n_frames=1 \
#     --train_dataset.json_delta_action_dataset.mode="pad" \
#     --train_dataset.json_delta_action_dataset.path="$dataset_path" \
#     --train_dataset.json_delta_action_dataset.seq_length=384 \
#     --train_dataset.json_delta_action_dataset.batch_size=128 \
#     --train_dataset.json_delta_action_dataset.tokenizer_processes=1 \
#     --train_dataset.json_delta_action_dataset.tokenizer_parallel_chunk_size=128 \
#     --train_dataset.json_delta_action_dataset.tokenizer_parallel_batch_size=128 \
#     --train_dataset.json_delta_action_dataset.use_data_sharded_loader=True \
#     --eval_dataset.type='json_vision_delta_action' \
#     --eval_dataset.delta_vision_action_processor.fields_from_example='fields' \
#     --eval_dataset.delta_vision_action_processor.n_tokens_per_action=7 \
#     --eval_dataset.delta_vision_action_processor.n_tokens_per_delta=4 \
#     --eval_dataset.delta_vision_action_processor.img_aug=True \
#     --eval_dataset.delta_vision_action_processor.vqgan_checkpoint_path='/mnt/default/lwm/data/checkpoints/lwm_checkpoints/vqgan'\
#     --eval_dataset.delta_vision_action_processor.image_absolute_path='/mnt/default/lwm/data/finetune_data/'\
#     --eval_dataset.delta_vision_action_processor.max_n_frames=1 \
#     --eval_dataset.json_delta_action_dataset.mode="pad" \
#     --eval_dataset.json_delta_action_dataset.path="$eval_dataset_path" \
#     --eval_dataset.json_delta_action_dataset.seq_length=384 \
#     --eval_dataset.json_delta_action_dataset.batch_size=128 \
#     --eval_dataset.json_delta_action_dataset.tokenizer_processes=1 \
#     --eval_dataset.json_delta_action_dataset.tokenizer_parallel_chunk_size=128 \
#     --eval_dataset.json_delta_action_dataset.tokenizer_parallel_batch_size=128 \
#     --eval_dataset.json_delta_action_dataset.use_data_sharded_loader=True \
#     --unseen_eval_dataset.type='json_vision_delta_action' \
#     --unseen_eval_dataset.delta_vision_action_processor.fields_from_example='fields' \
#     --unseen_eval_dataset.delta_vision_action_processor.n_tokens_per_action=7 \
#     --unseen_eval_dataset.delta_vision_action_processor.n_tokens_per_delta=4 \
#     --unseen_eval_dataset.delta_vision_action_processor.img_aug=True \
#     --unseen_eval_dataset.delta_vision_action_processor.vqgan_checkpoint_path='/mnt/default/lwm/data/checkpoints/lwm_checkpoints/vqgan'\
#     --unseen_eval_dataset.delta_vision_action_processor.image_absolute_path='/mnt/default/lwm/data/finetune_data/'\
#     --unseen_eval_dataset.delta_vision_action_processor.max_n_frames=1 \
#     --unseen_eval_dataset.json_delta_action_dataset.mode="pad" \
#     --unseen_eval_dataset.json_delta_action_dataset.path="$unseen_eval_dataset_path" \
#     --unseen_eval_dataset.json_delta_action_dataset.seq_length=384 \
#     --unseen_eval_dataset.json_delta_action_dataset.batch_size=128 \
#     --unseen_eval_dataset.json_delta_action_dataset.tokenizer_processes=1 \
#     --unseen_eval_dataset.json_delta_action_dataset.tokenizer_parallel_chunk_size=128 \
#     --unseen_eval_dataset.json_delta_action_dataset.tokenizer_parallel_batch_size=128 \
#     --unseen_eval_dataset.json_delta_action_dataset.use_data_sharded_loader=True \
#     --checkpointer.save_optimizer_state=False \
#     --autoresume=False \
#     --logger.append_uuid=False \
#     --logger.online=True \
#     --logger.project_id="$project_id" \
#     --logger.experiment_id="$experiment_id" \
#     --logger.experiment_note="$experiment_note" \
#     --logger.output_dir="$output_dir" \
#     --logger.wandb_dir="$HOME/experiment_output/$project_id"



# export SCRIPT_DIR="$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
# export PROJECT_DIR="$( cd -- "$( dirname -- "$SCRIPT_DIR" )" &> /dev/null && pwd )"
# cd "$PROJECT_DIR" || exit 1
# export PYTHONPATH="$PYTHONPATH:$PROJECT_DIR"
# export LIBTPU_INIT_ARGS="--xla_tpu_megacore_fusion_allow_ags=false --xla_enable_async_collective_permute=true --xla_tpu_enable_ag_backward_pipelining=true --xla_tpu_enable_data_parallel_all_reduce_opt=true --xla_tpu_data_parallel_opt_different_sized_ops=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true"

# export llama_tokenizer_path="/mnt/default/lwm/data/checkpoints/lwm_checkpoints/tokenizer.model"

# # export dataset_path='data/0808_multiobject_sink_llava_5hz_action_preprocessed.jsonl'
# # export dataset_path='data/0809_multiobject_sink_llava_256.jsonl'
# export dataset_path='data/0811_multiobject_sink_llava_train_filtered.jsonl'
# # export dataset_path='data/817_multiobject_sink_rand_location_llava_train_filtered.jsonl'
# # export dataset_path='data/0812_play_data_llava_train_filtered.jsonl'
# # export dataset_path='data/0812_play_data_llava_concat_multi2sink_train_filtered.jsonl'
# export eval_dataset_path='data/0811_multiobject_sink_llava_seen_val_filtered.jsonl'
# export unseen_eval_dataset_path='data/0811_multiobject_sink_llava_unseen_val_filtered.jsonl'
# # export eval_dataset_path="data/0731_multi_200_5hz_val.jsonl"

# export output_dir=$AMLT_OUTPUT_DIR
# export WANDB_API_KEY="0d0155a751a804873eedf37c29060146b377edb8"

# export project_id='lwm'
# export experiment_note='world-model'
# # export experiment_id='817_latent_c8_s9_epoch3_multiobject_sink_rand_location_llava_train_filtered_4000'
# # export experiment_id='821_latent_bridge10_c8_s4_l4_40k_multiobject_sink_rand_location_llava_train_filtered_10k'

# export experiment_id='0811_latent_sthv2_10percent_c8_s4_l8_100k_multiobject_sink_llava_train_filtered_14k_re'

# # mesh_dim: dp, fsdp, tp, sp
# python3 -u -m lwm.train \
#     --modality='vision,action,delta' \
#     --mesh_dim='!-1,8,1,1' \
#     --dtype='bf16' \
#     --total_steps=1005 \
#     --log_freq=1 \
#     --delta_tokens=1 \
#     --eval_steps=5 \
#     --save_model_freq=0 \
#     --eval_log_freq=100 \
#     --save_milestone_freq=1000 \
#     --load_llama_config='7b' \
#     --load_checkpoint="params::/mnt/default/projects/lwm/amlt-results/7275056058.06230-948e80e8-b4b8-4fe8-b547-2237a94c1d8f/sthv2_bf16_from_whole_delta8_batch128_seq4_layer8_100K_re2/streaming_params_14454" \
#     --update_llama_config="dict(action_vocab_size=256,delta_vocab_size=8,theta=50000000,max_sequence_length=2048,use_flash_attention=True,scan_attention=True,scan_query_chunk_size=512,scan_key_chunk_size=1024,remat_attention='nothing_saveable',scan_mlp=True,scan_mlp_chunk_size=8192,remat_mlp='nothing_saveable',remat_block='nothing_saveable',scan_layers=True)" \
#     --tokenizer.vocab_file="$llama_tokenizer_path" \
#     --optimizer.type='adamw' \
#     --llama.action_vocab_size=256 \
#     --llama.delta_vocab_size=8 \
#     --optimizer.accumulate_gradient_steps=1 \
#     --optimizer.adamw_optimizer.weight_decay=0 \
#     --optimizer.adamw_optimizer.lr=2e-5 \
#     --optimizer.adamw_optimizer.end_lr=2e-5 \
#     --optimizer.adamw_optimizer.lr_warmup_steps=0 \
#     --optimizer.adamw_optimizer.lr_decay_steps=3000 \
#     --use_data_sharded_loader=True \
#     --train_dataset.type='json_vision_delta_action' \
#     --train_dataset.delta_vision_action_processor.fields_from_example='fields' \
#     --train_dataset.delta_vision_action_processor.n_tokens_per_action=7 \
#     --train_dataset.delta_vision_action_processor.n_tokens_per_delta=4 \
#     --train_dataset.delta_vision_action_processor.img_aug=True \
#     --train_dataset.delta_vision_action_processor.vqgan_checkpoint_path='/mnt/default/lwm/data/checkpoints/lwm_checkpoints/vqgan'\
#     --train_dataset.delta_vision_action_processor.image_absolute_path='/mnt/default/lwm/data/finetune_data/'\
#     --train_dataset.delta_vision_action_processor.max_n_frames=1 \
#     --train_dataset.json_delta_action_dataset.mode="pad" \
#     --train_dataset.json_delta_action_dataset.path="$dataset_path" \
#     --train_dataset.json_delta_action_dataset.seq_length=384 \
#     --train_dataset.json_delta_action_dataset.batch_size=128 \
#     --train_dataset.json_delta_action_dataset.tokenizer_processes=1 \
#     --train_dataset.json_delta_action_dataset.tokenizer_parallel_chunk_size=128 \
#     --train_dataset.json_delta_action_dataset.tokenizer_parallel_batch_size=128 \
#     --train_dataset.json_delta_action_dataset.use_data_sharded_loader=True \
#     --eval_dataset.type='json_vision_delta_action' \
#     --eval_dataset.delta_vision_action_processor.fields_from_example='fields' \
#     --eval_dataset.delta_vision_action_processor.n_tokens_per_action=7 \
#     --eval_dataset.delta_vision_action_processor.n_tokens_per_delta=4 \
#     --eval_dataset.delta_vision_action_processor.img_aug=True \
#     --eval_dataset.delta_vision_action_processor.vqgan_checkpoint_path='/mnt/default/lwm/data/checkpoints/lwm_checkpoints/vqgan'\
#     --eval_dataset.delta_vision_action_processor.image_absolute_path='/mnt/default/lwm/data/finetune_data/'\
#     --eval_dataset.vision_action_processor.max_n_frames=1 \
#     --eval_dataset.json_delta_action_dataset.mode="pad" \
#     --eval_dataset.json_delta_action_dataset.path="$eval_dataset_path" \
#     --eval_dataset.json_delta_action_dataset.seq_length=384 \
#     --eval_dataset.json_delta_action_dataset.batch_size=128 \
#     --eval_dataset.json_delta_action_dataset.tokenizer_processes=1 \
#     --eval_dataset.json_delta_action_dataset.tokenizer_parallel_chunk_size=128 \
#     --eval_dataset.json_delta_action_dataset.tokenizer_parallel_batch_size=128 \
#     --eval_dataset.json_delta_action_dataset.use_data_sharded_loader=True \
#     --unseen_eval_dataset.type='json_vision_delta_action' \
#     --unseen_eval_dataset.delta_vision_action_processor.fields_from_example='fields' \
#     --unseen_eval_dataset.delta_vision_action_processor.n_tokens_per_action=7 \
#     --unseen_eval_dataset.delta_vision_action_processor.n_tokens_per_delta=4 \
#     --unseen_eval_dataset.delta_vision_action_processor.img_aug=True \
#     --unseen_eval_dataset.delta_vision_action_processor.vqgan_checkpoint_path='/mnt/default/lwm/data/checkpoints/lwm_checkpoints/vqgan'\
#     --unseen_eval_dataset.delta_vision_action_processor.image_absolute_path='/mnt/default/lwm/data/finetune_data/'\
#     --unseen_eval_dataset.vision_action_processor.max_n_frames=1 \
#     --unseen_eval_dataset.json_delta_action_dataset.mode="pad" \
#     --unseen_eval_dataset.json_delta_action_dataset.path="$unseen_eval_dataset_path" \
#     --unseen_eval_dataset.json_delta_action_dataset.seq_length=384 \
#     --unseen_eval_dataset.json_delta_action_dataset.batch_size=128 \
#     --unseen_eval_dataset.json_delta_action_dataset.tokenizer_processes=1 \
#     --unseen_eval_dataset.json_delta_action_dataset.tokenizer_parallel_chunk_size=128 \
#     --unseen_eval_dataset.json_delta_action_dataset.tokenizer_parallel_batch_size=128 \
#     --unseen_eval_dataset.json_delta_action_dataset.use_data_sharded_loader=True \
#     --checkpointer.save_optimizer_state=False \
#     --autoresume=False \
#     --logger.append_uuid=False \
#     --logger.online=True \
#     --logger.project_id="$project_id" \
#     --logger.experiment_id="$experiment_id" \
#     --logger.experiment_note="$experiment_note" \
#     --logger.output_dir="$output_dir" \
#     --logger.wandb_dir="$HOME/experiment_output/$project_id"




# # export SCRIPT_DIR="$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
# # export PROJECT_DIR="$( cd -- "$( dirname -- "$SCRIPT_DIR" )" &> /dev/null && pwd )"
# # cd $PROJECT_DIR
# # export PYTHONPATH="$PYTHONPATH:$PROJECT_DIR"
# # export LIBTPU_INIT_ARGS="--xla_tpu_megacore_fusion_allow_ags=false --xla_enable_async_collective_permute=true --xla_tpu_enable_ag_backward_pipelining=true --xla_tpu_enable_data_parallel_all_reduce_opt=true --xla_tpu_data_parallel_opt_different_sized_ops=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true"

# # export llama_tokenizer_path="/mnt/default/lwm/data/checkpoints/lwm_checkpoints/tokenizer.model"

# # # export dataset_path='data/0808_multiobject_sink_llava_5hz_action_preprocessed.jsonl'
# # # export dataset_path='data/0809_multiobject_sink_llava_256.jsonl'
# # export dataset_path='data/0811_multiobject_sink_llava_train_filtered.jsonl'
# # # export dataset_path='data/817_multiobject_sink_rand_location_llava_train_filtered.jsonl'
# # # export dataset_path='data/0812_play_data_llava_train_filtered.jsonl'
# # # export dataset_path='data/0812_play_data_llava_concat_multi2sink_train_filtered.jsonl'
# # export eval_dataset_path='data/0811_multiobject_sink_llava_seen_val_filtered.jsonl'
# # export unseen_eval_dataset_path='data/0811_multiobject_sink_llava_unseen_val_filtered.jsonl'
# # # export eval_dataset_path="data/0731_multi_200_5hz_val.jsonl"

# # export output_dir=$AMLT_OUTPUT_DIR
# # export WANDB_API_KEY="0d0155a751a804873eedf37c29060146b377edb8"

# # export project_id='lwm'
# # export experiment_note='world-model'
# # # export experiment_id='817_latent_c8_s9_epoch3_multiobject_sink_rand_location_llava_train_filtered_4000'
# # # export experiment_id='821_latent_bridge10_c8_s4_l4_40k_multiobject_sink_rand_location_llava_train_filtered_10k'

# # export experiment_id='0811_latent_human_c8_s4_l4_40k_multiobject_sink_llava_train_filtered_10k'

# # # mesh_dim: dp, fsdp, tp, sp
# # python3 -u -m lwm.train \
# #     --modality='vision,action,delta' \
# #     --mesh_dim='!-1,8,1,1' \
# #     --dtype='bf16' \
# #     --total_steps=1005 \
# #     --log_freq=1 \
# #     --delta_tokens=1 \
# #     --eval_steps=5 \
# #     --save_model_freq=0 \
# #     --eval_log_freq=100 \
# #     --save_milestone_freq=1000 \
# #     --load_llama_config='7b' \
# #     --load_checkpoint="params::/mnt/default/projects/lwm/amlt-results/7275566570.87539-24bb832a-3031-404b-8339-5fe92f838119/lwm_human_bf16_from_whole_delta8_batch32_seq4_layer4_filtered_no_inst_imgaug_821_40Ktrain/streaming_params_10000"\
# #     --update_llama_config="dict(action_vocab_size=256,delta_vocab_size=8,theta=50000000,max_sequence_length=2048,use_flash_attention=True,scan_attention=True,scan_query_chunk_size=512,scan_key_chunk_size=1024,remat_attention='nothing_saveable',scan_mlp=True,scan_mlp_chunk_size=8192,remat_mlp='nothing_saveable',remat_block='nothing_saveable',scan_layers=True)" \
# #     --tokenizer.vocab_file="$llama_tokenizer_path" \
# #     --optimizer.type='adamw' \
# #     --llama.action_vocab_size=256 \
# #     --llama.delta_vocab_size=8 \
# #     --optimizer.accumulate_gradient_steps=1 \
# #     --optimizer.adamw_optimizer.weight_decay=0 \
# #     --optimizer.adamw_optimizer.lr=2e-5 \
# #     --optimizer.adamw_optimizer.end_lr=2e-5 \
# #     --optimizer.adamw_optimizer.lr_warmup_steps=0 \
# #     --optimizer.adamw_optimizer.lr_decay_steps=3000 \
# #     --use_data_sharded_loader=True \
# #     --train_dataset.type='json_vision_delta_action' \
# #     --train_dataset.delta_vision_action_processor.fields_from_example='fields' \
# #     --train_dataset.delta_vision_action_processor.n_tokens_per_action=7 \
# #     --train_dataset.delta_vision_action_processor.n_tokens_per_delta=4 \
# #     --train_dataset.delta_vision_action_processor.img_aug=True \
# #     --train_dataset.delta_vision_action_processor.vqgan_checkpoint_path='/mnt/default/lwm/data/checkpoints/lwm_checkpoints/vqgan'\
# #     --train_dataset.delta_vision_action_processor.image_absolute_path='/mnt/default/lwm/data/finetune_data/'\
# #     --train_dataset.delta_vision_action_processor.max_n_frames=1 \
# #     --train_dataset.json_delta_action_dataset.mode="pad" \
# #     --train_dataset.json_delta_action_dataset.path="$dataset_path" \
# #     --train_dataset.json_delta_action_dataset.seq_length=384 \
# #     --train_dataset.json_delta_action_dataset.batch_size=128 \
# #     --train_dataset.json_delta_action_dataset.tokenizer_processes=1 \
# #     --train_dataset.json_delta_action_dataset.tokenizer_parallel_chunk_size=128 \
# #     --train_dataset.json_delta_action_dataset.tokenizer_parallel_batch_size=128 \
# #     --train_dataset.json_delta_action_dataset.use_data_sharded_loader=True \
# #     --eval_dataset.type='json_vision_delta_action' \
# #     --eval_dataset.delta_vision_action_processor.fields_from_example='fields' \
# #     --eval_dataset.delta_vision_action_processor.n_tokens_per_action=7 \
# #     --eval_dataset.delta_vision_action_processor.n_tokens_per_delta=4 \
# #     --eval_dataset.delta_vision_action_processor.img_aug=True \
# #     --eval_dataset.delta_vision_action_processor.vqgan_checkpoint_path='/mnt/default/lwm/data/checkpoints/lwm_checkpoints/vqgan'\
# #     --eval_dataset.delta_vision_action_processor.image_absolute_path='/mnt/default/lwm/data/finetune_data/'\
# #     --eval_dataset.vision_action_processor.max_n_frames=1 \
# #     --eval_dataset.json_delta_action_dataset.mode="pad" \
# #     --eval_dataset.json_delta_action_dataset.path="$eval_dataset_path" \
# #     --eval_dataset.json_delta_action_dataset.seq_length=384 \
# #     --eval_dataset.json_delta_action_dataset.batch_size=128 \
# #     --eval_dataset.json_delta_action_dataset.tokenizer_processes=1 \
# #     --eval_dataset.json_delta_action_dataset.tokenizer_parallel_chunk_size=128 \
# #     --eval_dataset.json_delta_action_dataset.tokenizer_parallel_batch_size=128 \
# #     --eval_dataset.json_delta_action_dataset.use_data_sharded_loader=True \
# #     --unseen_eval_dataset.type='json_vision_delta_action' \
# #     --unseen_eval_dataset.delta_vision_action_processor.fields_from_example='fields' \
# #     --unseen_eval_dataset.delta_vision_action_processor.n_tokens_per_action=7 \
# #     --unseen_eval_dataset.delta_vision_action_processor.n_tokens_per_delta=4 \
# #     --unseen_eval_dataset.delta_vision_action_processor.img_aug=True \
# #     --unseen_eval_dataset.delta_vision_action_processor.vqgan_checkpoint_path='/mnt/default/lwm/data/checkpoints/lwm_checkpoints/vqgan'\
# #     --unseen_eval_dataset.delta_vision_action_processor.image_absolute_path='/mnt/default/lwm/data/finetune_data/'\
# #     --unseen_eval_dataset.vision_action_processor.max_n_frames=1 \
# #     --unseen_eval_dataset.json_delta_action_dataset.mode="pad" \
# #     --unseen_eval_dataset.json_delta_action_dataset.path="$unseen_eval_dataset_path" \
# #     --unseen_eval_dataset.json_delta_action_dataset.seq_length=384 \
# #     --unseen_eval_dataset.json_delta_action_dataset.batch_size=128 \
# #     --unseen_eval_dataset.json_delta_action_dataset.tokenizer_processes=1 \
# #     --unseen_eval_dataset.json_delta_action_dataset.tokenizer_parallel_chunk_size=128 \
# #     --unseen_eval_dataset.json_delta_action_dataset.tokenizer_parallel_batch_size=128 \
# #     --unseen_eval_dataset.json_delta_action_dataset.use_data_sharded_loader=True \
# #     --checkpointer.save_optimizer_state=False \
# #     --autoresume=False \
# #     --logger.append_uuid=False \
# #     --logger.online=True \
# #     --logger.project_id="$project_id" \
# #     --logger.experiment_id="$experiment_id" \
# #     --logger.experiment_note="$experiment_note" \
# #     --logger.output_dir="$output_dir" \
# #     --logger.wandb_dir="$HOME/experiment_output/$project_id"




# ---------------------------------------------------------------------------
# Active run: fine-tune the sthv2 latent-action checkpoint (streaming_params_22485)
# on the 0811 multi-object sink data, evaluating on the seen and unseen splits.
# ---------------------------------------------------------------------------

# Resolve the script directory and its parent (the project root). Assignment is
# kept separate from `export` so a failing substitution is not masked (SC2155).
SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &> /dev/null && pwd)"
PROJECT_DIR="$(cd -- "$(dirname -- "$SCRIPT_DIR")" &> /dev/null && pwd)"
export SCRIPT_DIR PROJECT_DIR

# All dataset paths below are relative ('data/...'), so we must really be in the
# project root. An unquoted `cd` with an empty PROJECT_DIR would silently jump to
# $HOME instead, so fail loudly rather than train against the wrong files.
if [[ -z "$PROJECT_DIR" ]] || ! cd -- "$PROJECT_DIR"; then
    printf 'error: cannot cd to project directory %s\n' "${PROJECT_DIR:-<unresolved>}" >&2
    exit 1
fi
# Append the project root without introducing an empty leading entry when
# PYTHONPATH was previously unset.
export PYTHONPATH="${PYTHONPATH:+$PYTHONPATH:}$PROJECT_DIR"
export LIBTPU_INIT_ARGS="--xla_tpu_megacore_fusion_allow_ags=false --xla_enable_async_collective_permute=true --xla_tpu_enable_ag_backward_pipelining=true --xla_tpu_enable_data_parallel_all_reduce_opt=true --xla_tpu_data_parallel_opt_different_sized_ops=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true"

export llama_tokenizer_path="/mnt/default/lwm/data/checkpoints/lwm_checkpoints/tokenizer.model"

export dataset_path='data/0811_multiobject_sink_llava_train_filtered.jsonl'
export eval_dataset_path='data/0811_multiobject_sink_llava_seen_val_filtered.jsonl'
export unseen_eval_dataset_path='data/0811_multiobject_sink_llava_unseen_val_filtered.jsonl'
# export eval_dataset_path="data/0731_multi_200_5hz_val.jsonl"

export output_dir='/mnt/default/lwm/data/checkpoints/real_finetune'

# W&B credentials must come from the caller's environment, not from this file.
# Any key previously committed here should be treated as leaked and rotated.
if [[ -n "${WANDB_API_KEY:-}" ]]; then
    export WANDB_API_KEY
else
    printf 'warning: WANDB_API_KEY is not set; online logging (--logger.online=True) needs W&B credentials from another source\n' >&2
fi

export project_id='lwm'
export experiment_note='world-model'
export experiment_id='0811_latent_sthv2_100percent_c8_s9_l8_w10_100k_multiobject_sink_llava_train_filtered_22k'

# Delta tokens per step, shared by the train and both eval datasets. Presumably
# this has to match the loaded checkpoint (its name contains 'seq9') — keep in sync.
n_tokens_per_delta=9

#######################################
# Append the json_vision_delta_action dataset flags for one dataset config to
# the global `train_args` array, so train / eval / unseen-eval are configured
# identically apart from their data file.
# Globals:   train_args (appended), n_tokens_per_delta (read)
# Arguments: $1 - config prefix (train_dataset, eval_dataset, unseen_eval_dataset)
#            $2 - path to the .jsonl data file
#######################################
append_dataset_args() {
    local prefix=$1
    local jsonl_path=$2
    local proc="--${prefix}.delta_vision_action_processor"
    local data="--${prefix}.json_delta_action_dataset"
    # NOTE: max_n_frames is set on delta_vision_action_processor for every split.
    # The eval splits used to set it on vision_action_processor, which is not the
    # processor the other flags of this dataset type configure.
    # NOTE(review): img_aug=True also applies to the eval splits, as before — confirm intended.
    train_args+=(
        "--${prefix}.type=json_vision_delta_action"
        "${proc}.fields_from_example=fields"
        "${proc}.n_tokens_per_action=7"
        "${proc}.n_tokens_per_delta=${n_tokens_per_delta}"
        "${proc}.img_aug=True"
        "${proc}.vqgan_checkpoint_path=/mnt/default/lwm/data/checkpoints/lwm_checkpoints/vqgan"
        "${proc}.image_absolute_path=/mnt/default/lwm/data/finetune_data/"
        "${proc}.max_n_frames=1"
        "${data}.mode=pad"
        "${data}.path=${jsonl_path}"
        "${data}.seq_length=384"
        "${data}.batch_size=128"
        "${data}.tokenizer_processes=1"
        "${data}.tokenizer_parallel_chunk_size=128"
        "${data}.tokenizer_parallel_batch_size=128"
        "${data}.use_data_sharded_loader=True"
    )
}

# Build argv as an array so values with quotes/commas (update_llama_config,
# mesh_dim) are passed through as single, unmangled arguments.
# mesh_dim: dp, fsdp, tp, sp
train_args=(
    '--modality=vision,action,delta'
    '--mesh_dim=!-1,8,1,1'
    '--dtype=bf16'
    --total_steps=1005
    --log_freq=1
    --delta_tokens=1
    --eval_steps=5
    --save_model_freq=0
    --eval_log_freq=100
    --save_milestone_freq=1000
    '--load_llama_config=7b'
    "--load_checkpoint=params::/mnt/default/lwm/data/checkpoints/debug/sthv2_bf16_from_whole_delta8_batch128_seq9_layer8_w10_100K_100percent_re2/streaming_params_22485"
    "--update_llama_config=dict(action_vocab_size=256,delta_vocab_size=8,theta=50000000,max_sequence_length=2048,use_flash_attention=True,scan_attention=True,scan_query_chunk_size=512,scan_key_chunk_size=1024,remat_attention='nothing_saveable',scan_mlp=True,scan_mlp_chunk_size=8192,remat_mlp='nothing_saveable',remat_block='nothing_saveable',scan_layers=True)"
    "--tokenizer.vocab_file=$llama_tokenizer_path"
    '--optimizer.type=adamw'
    --llama.action_vocab_size=256
    --llama.delta_vocab_size=8
    --optimizer.accumulate_gradient_steps=1
    --optimizer.adamw_optimizer.weight_decay=0
    --optimizer.adamw_optimizer.lr=2e-5
    --optimizer.adamw_optimizer.end_lr=2e-5
    --optimizer.adamw_optimizer.lr_warmup_steps=0
    --optimizer.adamw_optimizer.lr_decay_steps=3000
    --use_data_sharded_loader=True
)
append_dataset_args train_dataset "$dataset_path"
append_dataset_args eval_dataset "$eval_dataset_path"
append_dataset_args unseen_eval_dataset "$unseen_eval_dataset_path"
train_args+=(
    --checkpointer.save_optimizer_state=False
    --autoresume=False
    --logger.append_uuid=False
    --logger.online=True
    "--logger.project_id=$project_id"
    "--logger.experiment_id=$experiment_id"
    "--logger.experiment_note=$experiment_note"
    "--logger.output_dir=$output_dir"
    "--logger.wandb_dir=$HOME/experiment_output/$project_id"
)

python3 -u -m lwm.train "${train_args[@]}"




