# # train w/ pretrained ckpt
# python pretrain_umd/train_retrieval_w_anticausal.py \
#         --model_checkpoint /fs/XXXX-37/llm-pretraining/llm-retrieval/checkpoints/TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T \
#         --tokenizer_path /fs/XXXX-37/llm-pretraining/llm-retrieval/checkpoints/TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T \
#         --out_dir /fs/XXXX-37/llm-pretraining/llm-retrieval/out/cosmopedia-retrieval-dual-causal-llama-1.1b-bsz-1-ctx-2048-batch_negative_default_hparam_hfdata \
#         --seed 1337 \
#         --model_name tiny-llama-1.1b \
#         --run_name cosmopedia-retrieval-dual-causal-llama-1.1b-bsz-1-ctx-2048-batch_negative_default_hparam_hfdata  \
#         --logger_name wandb \
#         --compile_model False \
#         --fabric_precision bf16-mixed \
#         --world_batch_size 2 \
#         --micro_batch_size 1 \
#         --block_size 2048 \
#         --n_chunks 4 \
#         --max_tokens 1_000_000_000_000 \
#         --warmup_steps 2000 \
#         --log_step_interval 1 \
#         --eval_iters 100 \
#         --save_and_eval_interval 2000 \
#         --grad_clip 1.0 \
#         --decay_lr True \
#         --min_lr 4e-5 \
#         --data_telemetry False \
#         --data_config /XXXX-36/XXXX-22/XXXX-40/launch_scripts/XXXX-22/sample_hfds_only.json

# # train w/ random init
# python pretrain_umd/train_retrieval_w_anticausal.py \
#         --tokenizer_path /fs/XXXX-37/llm-pretraining/llm-retrieval/checkpoints/TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T \
#         --out_dir /fs/XXXX-37/llm-pretraining/llm-retrieval/out/cosmopedia-retrieval-dual-causal-llama-1.1b-bsz-1-ctx-2048-batch_negative_default_hparam_hfdata_rand_init \
#         --seed 1337 \
#         --model_name tiny-llama-1.1b \
#         --run_name cosmopedia-retrieval-dual-causal-llama-1.1b-bsz-1-ctx-2048-batch_negative_default_hparam_hfdata_rand_init  \
#         --logger_name wandb \
#         --compile_model False \
#         --fabric_precision bf16-mixed \
#         --world_batch_size 2 \
#         --micro_batch_size 1 \
#         --block_size 2048 \
#         --n_chunks 4 \
#         --max_tokens 1_000_000_000_000 \
#         --warmup_steps 2000 \
#         --log_step_interval 1 \
#         --eval_iters 100 \
#         --save_and_eval_interval 2000 \
#         --grad_clip 1.0 \
#         --decay_lr True \
#         --min_lr 4e-5 \
#         --data_telemetry False \
#         --data_config /XXXX-36/XXXX-22/XXXX-40/launch_scripts/XXXX-22/sample_hfds_only.json

# # train w/ 4 micro_bsz of var length sequences (lr 0.0008, beta1 0.6, beta2 0.8, adamw_eps 5e-9)
# python pretrain_umd/train_retrieval_w_anticausal.py \
#         --model_checkpoint /fs/XXXX-37/llm-pretraining/llm-retrieval/checkpoints/TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T \
#         --tokenizer_path /fs/XXXX-37/llm-pretraining/llm-retrieval/checkpoints/TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T \
#         --out_dir /fs/XXXX-37/llm-pretraining/llm-retrieval/out/cosmopedia-retrieval-dual-causal-llama-1.1b-bsz-8-ctx-rand-batch_negative_hfdata \
#         --seed 1337 \
#         --model_name tiny-llama-1.1b \
#         --run_name cosmopedia-retrieval-dual-causal-llama-1.1b-bsz-8-ctx-rand-batch_negative_hfdata  \
#         --logger_name wandb \
#         --compile_model False \
#         --fabric_precision bf16-mixed \
#         --learning_rate 0.0008 \
#         --beta1 0.6 \
#         --beta2 0.8 \
#         --adamw_eps 5e-9 \
#         --world_batch_size 8 \
#         --micro_batch_size 4 \
#         --block_size 2048 \
#         --n_chunks 4 \
#         --max_tokens 1_000_000_000_000 \
#         --warmup_steps 2000 \
#         --log_step_interval 1 \
#         --eval_iters 100 \
#         --save_and_eval_interval 2000 \
#         --grad_clip 1.0 \
#         --decay_lr True \
#         --min_lr 4e-5 \
#         --data_telemetry False \
#         --data_config /XXXX-36/XXXX-22/XXXX-40/launch_scripts/XXXX-22/sample_hfds_only.json

# # train w/ 4 micro_bsz of var length sequences (lr 0.0016; linear_lr run)
# python pretrain_umd/train_retrieval_w_anticausal.py \
#         --model_checkpoint /fs/XXXX-37/llm-pretraining/llm-retrieval/checkpoints/TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T \
#         --tokenizer_path /fs/XXXX-37/llm-pretraining/llm-retrieval/checkpoints/TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T \
#         --out_dir /fs/XXXX-37/llm-pretraining/llm-retrieval/out/cosmopedia-retrieval-dual-causal-llama-1.1b-bsz-8-ctx-rand-batch_negative_hfdata_linear_lr \
#         --seed 1337 \
#         --model_name tiny-llama-1.1b \
#         --run_name cosmopedia-retrieval-dual-causal-llama-1.1b-bsz-8-ctx-rand-batch_negative_hfdata_linear_lr  \
#         --logger_name wandb \
#         --compile_model False \
#         --fabric_precision bf16-mixed \
#         --learning_rate 0.0016 \
#         --world_batch_size 8 \
#         --micro_batch_size 4 \
#         --block_size 2048 \
#         --n_chunks 4 \
#         --max_tokens 1_000_000_000_000 \
#         --warmup_steps 2000 \
#         --log_step_interval 1 \
#         --eval_iters 100 \
#         --save_and_eval_interval 2000 \
#         --grad_clip 1.0 \
#         --decay_lr True \
#         --min_lr 4e-5 \
#         --data_telemetry False \
#         --data_config /XXXX-36/XXXX-22/XXXX-40/launch_scripts/XXXX-22/sample_hfds_only.json

# # train w/ 4 micro_bsz of var length sequences (lr 0.0001; warmup_1 run)
# python pretrain_umd/train_retrieval_w_anticausal.py \
#         --model_checkpoint /fs/XXXX-37/llm-pretraining/llm-retrieval/checkpoints/TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T \
#         --tokenizer_path /fs/XXXX-37/llm-pretraining/llm-retrieval/checkpoints/TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T \
#         --out_dir /fs/XXXX-37/llm-pretraining/llm-retrieval/out/cosmopedia-retrieval-dual-causal-llama-1.1b-bsz-8-ctx-rand-batch_negative_hfdata_lr_0.0001_warmup_1 \
#         --seed 1337 \
#         --model_name tiny-llama-1.1b \
#         --run_name cosmopedia-retrieval-dual-causal-llama-1.1b-bsz-8-ctx-rand-batch_negative_hfdata_lr_0.0001_warmup_1  \
#         --logger_name wandb \
#         --compile_model False \
#         --fabric_precision bf16-mixed \
#         --learning_rate 0.0001 \
#         --warmup_steps 1 \
#         --world_batch_size 8 \
#         --micro_batch_size 4 \
#         --block_size 2048 \
#         --n_chunks 4 \
#         --max_tokens 1_000_000_000_000 \
#         --log_step_interval 1 \
#         --eval_iters 100 \
#         --save_and_eval_interval 2000 \
#         --grad_clip 1.0 \
#         --decay_lr True \
#         --min_lr 4e-5 \
#         --data_telemetry False \
#         --data_config /XXXX-36/XXXX-22/XXXX-40/launch_scripts/XXXX-22/sample_hfds_only.json

# # train w/ 4 micro_bsz of var length sequences (lr 0.0005; warmup_1 run)
# python pretrain_umd/train_retrieval_w_anticausal.py \
#         --model_checkpoint /fs/XXXX-37/llm-pretraining/llm-retrieval/checkpoints/TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T \
#         --tokenizer_path /fs/XXXX-37/llm-pretraining/llm-retrieval/checkpoints/TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T \
#         --out_dir /fs/XXXX-37/llm-pretraining/llm-retrieval/out/cosmopedia-retrieval-dual-causal-llama-1.1b-bsz-8-ctx-rand-batch_negative_hfdata_lr_0.0005_warmup_1 \
#         --seed 1337 \
#         --model_name tiny-llama-1.1b \
#         --run_name cosmopedia-retrieval-dual-causal-llama-1.1b-bsz-8-ctx-rand-batch_negative_hfdata_lr_0.0005_warmup_1  \
#         --logger_name wandb \
#         --compile_model False \
#         --fabric_precision bf16-mixed \
#         --learning_rate 0.0005 \
#         --warmup_steps 1 \
#         --world_batch_size 8 \
#         --micro_batch_size 4 \
#         --block_size 2048 \
#         --n_chunks 4 \
#         --max_tokens 1_000_000_000_000 \
#         --log_step_interval 1 \
#         --eval_iters 100 \
#         --save_and_eval_interval 2000 \
#         --grad_clip 1.0 \
#         --decay_lr True \
#         --min_lr 4e-5 \
#         --data_telemetry False \
#         --data_config /XXXX-36/XXXX-22/XXXX-40/launch_scripts/XXXX-22/sample_hfds_only.json

# # train w/ 4 micro_bsz of var length sequences (lr 0.001; warmup_1 run)
# python pretrain_umd/train_retrieval_w_anticausal.py \
#         --model_checkpoint /fs/XXXX-37/llm-pretraining/llm-retrieval/checkpoints/TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T \
#         --tokenizer_path /fs/XXXX-37/llm-pretraining/llm-retrieval/checkpoints/TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T \
#         --out_dir /fs/XXXX-37/llm-pretraining/llm-retrieval/out/cosmopedia-retrieval-dual-causal-llama-1.1b-bsz-8-ctx-rand-batch_negative_hfdata_lr_0.001_warmup_1 \
#         --seed 1337 \
#         --model_name tiny-llama-1.1b \
#         --run_name cosmopedia-retrieval-dual-causal-llama-1.1b-bsz-8-ctx-rand-batch_negative_hfdata_lr_0.001_warmup_1  \
#         --logger_name wandb \
#         --compile_model False \
#         --fabric_precision bf16-mixed \
#         --learning_rate 0.001 \
#         --warmup_steps 1 \
#         --world_batch_size 8 \
#         --micro_batch_size 4 \
#         --block_size 2048 \
#         --n_chunks 4 \
#         --max_tokens 1_000_000_000_000 \
#         --log_step_interval 1 \
#         --eval_iters 100 \
#         --save_and_eval_interval 2000 \
#         --grad_clip 1.0 \
#         --decay_lr True \
#         --min_lr 4e-5 \
#         --data_telemetry False \
#         --data_config /XXXX-36/XXXX-22/XXXX-40/launch_scripts/XXXX-22/sample_hfds_only.json

# # train w/ 4 micro_bsz of var length sequences (lr 0.005; warmup_1 run)
# python pretrain_umd/train_retrieval_w_anticausal.py \
#         --model_checkpoint /fs/XXXX-37/llm-pretraining/llm-retrieval/checkpoints/TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T \
#         --tokenizer_path /fs/XXXX-37/llm-pretraining/llm-retrieval/checkpoints/TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T \
#         --out_dir /fs/XXXX-37/llm-pretraining/llm-retrieval/out/cosmopedia-retrieval-dual-causal-llama-1.1b-bsz-8-ctx-rand-batch_negative_hfdata_lr_0.005_warmup_1 \
#         --seed 1337 \
#         --model_name tiny-llama-1.1b \
#         --run_name cosmopedia-retrieval-dual-causal-llama-1.1b-bsz-8-ctx-rand-batch_negative_hfdata_lr_0.005_warmup_1  \
#         --logger_name wandb \
#         --compile_model False \
#         --fabric_precision bf16-mixed \
#         --learning_rate 0.005 \
#         --warmup_steps 1 \
#         --world_batch_size 8 \
#         --micro_batch_size 4 \
#         --block_size 2048 \
#         --n_chunks 4 \
#         --max_tokens 1_000_000_000_000 \
#         --log_step_interval 1 \
#         --eval_iters 100 \
#         --save_and_eval_interval 2000 \
#         --grad_clip 1.0 \
#         --decay_lr True \
#         --min_lr 4e-5 \
#         --data_telemetry False \
#         --data_config /XXXX-36/XXXX-22/XXXX-40/launch_scripts/XXXX-22/sample_hfds_only.json

# # train w/ var length sequences (length ~900, std 25)
# python pretrain_umd/train_retrieval_w_anticausal.py \
#         --model_checkpoint /fs/XXXX-37/llm-pretraining/llm-retrieval/checkpoints/TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T \
#         --tokenizer_path /fs/XXXX-37/llm-pretraining/llm-retrieval/checkpoints/TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T \
#         --out_dir /fs/XXXX-37/llm-pretraining/llm-retrieval/out/cosmopedia-retrieval-dual-causal-llama-1.1b-bsz-1-ctx-2048-batch_negative_default_hparam_hfdata_var_length_900_std_25 \
#         --seed 1337 \
#         --model_name tiny-llama-1.1b \
#         --run_name cosmopedia-retrieval-dual-causal-llama-1.1b-bsz-1-ctx-2048-batch_negative_default_hparam_hfdata_var_length_900_std_25  \
#         --logger_name wandb \
#         --compile_model False \
#         --fabric_precision bf16-mixed \
#         --world_batch_size 2 \
#         --micro_batch_size 1 \
#         --block_size 2048 \
#         --n_chunks 4 \
#         --max_tokens 1_000_000_000_000 \
#         --warmup_steps 2000 \
#         --log_step_interval 1 \
#         --eval_iters 100 \
#         --save_and_eval_interval 2000 \
#         --grad_clip 1.0 \
#         --decay_lr True \
#         --min_lr 4e-5 \
#         --data_telemetry False \
#         --data_config /XXXX-36/XXXX-22/XXXX-40/launch_scripts/XXXX-22/var_length_hfds.json

# # train w/ fixed length sequence (800)
# python pretrain_umd/train_retrieval_w_anticausal.py \
#         --model_checkpoint /fs/XXXX-37/llm-pretraining/llm-retrieval/checkpoints/TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T \
#         --tokenizer_path /fs/XXXX-37/llm-pretraining/llm-retrieval/checkpoints/TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T \
#         --out_dir /fs/XXXX-37/llm-pretraining/llm-retrieval/out/cosmopedia-retrieval-dual-causal-llama-1.1b-bsz-1-ctx-2048-batch_negative_default_hparam_hfdata_fixed_len_800 \
#         --seed 1337 \
#         --model_name tiny-llama-1.1b \
#         --run_name cosmopedia-retrieval-dual-causal-llama-1.1b-bsz-1-ctx-2048-batch_negative_default_hparam_hfdata_fixed_len_800  \
#         --logger_name wandb \
#         --compile_model False \
#         --fabric_precision bf16-mixed \
#         --world_batch_size 2 \
#         --micro_batch_size 1 \
#         --n_chunks 4 \
#         --max_tokens 1_000_000_000_000 \
#         --warmup_steps 2000 \
#         --log_step_interval 1 \
#         --eval_iters 100 \
#         --save_and_eval_interval 2000 \
#         --grad_clip 1.0 \
#         --decay_lr True \
#         --min_lr 4e-5 \
#         --data_telemetry False \
#         --data_config /XXXX-36/XXXX-22/XXXX-40/launch_scripts/XXXX-22/var_length_hfds.json \
#         --fixed_length True \
#         --block_size 800 \
#         --ignore_block_size_mismatch True

# # train w/ 4 micro_bsz of var length sequences (lr 0.0002)
# python pretrain_umd/train_retrieval_w_anticausal.py \
#         --model_checkpoint /fs/XXXX-37/llm-pretraining/llm-retrieval/checkpoints/TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T \
#         --tokenizer_path /fs/XXXX-37/llm-pretraining/llm-retrieval/checkpoints/TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T \
#         --out_dir /fs/XXXX-37/llm-pretraining/llm-retrieval/out/cosmopedia-retrieval-dual-causal-llama-1.1b-bsz-8-ctx-rand-batch_negative_hfdata_lr_0.0002 \
#         --seed 1337 \
#         --model_name tiny-llama-1.1b \
#         --run_name cosmopedia-retrieval-dual-causal-llama-1.1b-bsz-8-ctx-rand-batch_negative_hfdata_lr_0.0002  \
#         --logger_name wandb \
#         --compile_model False \
#         --fabric_precision bf16-mixed \
#         --learning_rate 0.0002 \
#         --world_batch_size 8 \
#         --micro_batch_size 4 \
#         --block_size 2048 \
#         --n_chunks 4 \
#         --max_tokens 1_000_000_000_000 \
#         --warmup_steps 2000 \
#         --log_step_interval 1 \
#         --eval_iters 100 \
#         --save_and_eval_interval 2000 \
#         --grad_clip 1.0 \
#         --decay_lr True \
#         --min_lr 4e-5 \
#         --data_telemetry False \
#         --data_config /XXXX-36/XXXX-22/XXXX-40/launch_scripts/XXXX-22/sample_hfds_only.json

# python pretrain_umd/train_retrieval_w_anticausal.py \
#         --model_checkpoint /fs/XXXX-37/llm-pretraining/llm-retrieval/checkpoints/TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T \
#         --tokenizer_path /fs/XXXX-37/llm-pretraining/llm-retrieval/checkpoints/TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T \
#         --out_dir /fs/XXXX-37/llm-pretraining/llm-retrieval/out/cosmopedia-retrieval-dual-causal-llama-1.1b-bsz-8-ctx-rand-batch_negative_hfdata_lr_0.0003 \
#         --seed 1337 \
#         --model_name tiny-llama-1.1b \
#         --run_name cosmopedia-retrieval-dual-causal-llama-1.1b-bsz-8-ctx-rand-batch_negative_hfdata_lr_0.0003  \
#         --logger_name wandb \
#         --compile_model False \
#         --fabric_precision bf16-mixed \
#         --learning_rate 0.0003 \
#         --world_batch_size 8 \
#         --micro_batch_size 4 \
#         --block_size 2048 \
#         --n_chunks 4 \
#         --max_tokens 1_000_000_000_000 \
#         --warmup_steps 2000 \
#         --log_step_interval 1 \
#         --eval_iters 100 \
#         --save_and_eval_interval 2000 \
#         --grad_clip 1.0 \
#         --decay_lr True \
#         --min_lr 4e-5 \
#         --data_telemetry False \
#         --data_config /XXXX-36/XXXX-22/XXXX-40/launch_scripts/XXXX-22/sample_hfds_only.json

# python pretrain_umd/train_retrieval_w_anticausal.py \
#         --model_checkpoint /fs/XXXX-37/llm-pretraining/llm-retrieval/checkpoints/TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T \
#         --tokenizer_path /fs/XXXX-37/llm-pretraining/llm-retrieval/checkpoints/TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T \
#         --out_dir /fs/XXXX-37/llm-pretraining/llm-retrieval/out/cosmopedia-retrieval-dual-causal-llama-1.1b-bsz-8-ctx-rand-batch_negative_hfdata_sqrt_lr_0.0008 \
#         --seed 1337 \
#         --model_name tiny-llama-1.1b \
#         --run_name cosmopedia-retrieval-dual-causal-llama-1.1b-bsz-8-ctx-rand-batch_negative_hfdata_sqrt_lr_0.0008 \
#         --logger_name wandb \
#         --compile_model False \
#         --fabric_precision bf16-mixed \
#         --learning_rate 0.0008 \
#         --beta1 0.8 \
#         --beta2 0.9 \
#         --adamw_eps 7.07e-7 \
#         --world_batch_size 8 \
#         --micro_batch_size 4 \
#         --block_size 2048 \
#         --n_chunks 4 \
#         --max_tokens 1_000_000_000_000 \
#         --warmup_steps 2000 \
#         --log_step_interval 1 \
#         --eval_iters 100 \
#         --save_and_eval_interval 2000 \
#         --grad_clip 1.0 \
#         --decay_lr True \
#         --min_lr 4e-5 \
#         --data_telemetry False \
#         --data_config /XXXX-36/XXXX-22/XXXX-40/launch_scripts/XXXX-22/sample_hfds_only.json

# #### DDP ####
# python pretrain_umd/train_retrieval_w_anticausal.py \
#         --model_checkpoint /fs/XXXX-37/llm-pretraining/llm-retrieval/checkpoints/TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T \
#         --tokenizer_path /fs/XXXX-37/llm-pretraining/llm-retrieval/checkpoints/TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T \
#         --out_dir /fs/XXXX-37/llm-pretraining/llm-retrieval/out/cosmopedia-retrieval-dual-causal-llama-1.1b-bsz-8-ctx-rand-batch_negative_ddp_PP \
#         --seed 1337 \
#         --max_tokens 25_000_000_000 \
#         --model_name tiny-llama-1.1b \
#         --run_name cosmopedia-retrieval-dual-causal-llama-1.1b-bsz-8-ctx-rand-batch_negative_ddp_PP \
#         --logger_name wandb \
#         --compile_model False \
#         --fabric_precision bf16-mixed \
#         --world_batch_size 2 \
#         --micro_batch_size 1 \
#         --block_size 2048 \
#         --n_chunks 4 \
#         --warmup_steps 2000 \
#         --log_step_interval 1 \
#         --eval_iters 100 \
#         --save_and_eval_interval 2000 \
#         --grad_clip 1.0 \
#         --decay_lr True \
#         --min_lr 4e-5 \
#         --data_telemetry False \
#         --data_config /XXXX-36/XXXX-22/XXXX-40/launch_scripts/XXXX-22/sample_hfds_only.json \
#         --fabric_strategy ddp \
#         --pretrained_prefix_model True \
#         --pretrained_suffix_model True

# python pretrain_umd/train_retrieval_w_anticausal.py \
#         --model_checkpoint /fs/XXXX-37/llm-pretraining/llm-retrieval/checkpoints/TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T \
#         --tokenizer_path /fs/XXXX-37/llm-pretraining/llm-retrieval/checkpoints/TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T \
#         --out_dir /fs/XXXX-37/llm-pretraining/llm-retrieval/out/cosmopedia-retrieval-dual-causal-llama-1.1b-bsz-8-ctx-rand-batch_negative_ddp_PR \
#         --seed 1337 \
#         --max_tokens 25_000_000_000 \
#         --model_name tiny-llama-1.1b \
#         --run_name cosmopedia-retrieval-dual-causal-llama-1.1b-bsz-8-ctx-rand-batch_negative_ddp_PR \
#         --logger_name wandb \
#         --compile_model False \
#         --fabric_precision bf16-mixed \
#         --world_batch_size 2 \
#         --micro_batch_size 1 \
#         --block_size 2048 \
#         --n_chunks 4 \
#         --warmup_steps 2000 \
#         --log_step_interval 1 \
#         --eval_iters 100 \
#         --save_and_eval_interval 2000 \
#         --grad_clip 1.0 \
#         --decay_lr True \
#         --min_lr 4e-5 \
#         --data_telemetry False \
#         --data_config /XXXX-36/XXXX-22/XXXX-40/launch_scripts/XXXX-22/sample_hfds_only.json \
#         --fabric_strategy ddp \
#         --pretrained_prefix_model True \
#         --pretrained_suffix_model False

# python pretrain_umd/train_retrieval_w_anticausal.py \
#         --model_checkpoint /fs/XXXX-37/llm-pretraining/llm-retrieval/checkpoints/TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T \
#         --tokenizer_path /fs/XXXX-37/llm-pretraining/llm-retrieval/checkpoints/TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T \
#         --out_dir /fs/XXXX-37/llm-pretraining/llm-retrieval/out/cosmopedia-retrieval-dual-causal-llama-1.1b-bsz-8-ctx-rand-batch_negative_ddp_RP \
#         --seed 1337 \
#         --max_tokens 25_000_000_000 \
#         --model_name tiny-llama-1.1b \
#         --run_name cosmopedia-retrieval-dual-causal-llama-1.1b-bsz-8-ctx-rand-batch_negative_ddp_RP \
#         --logger_name wandb \
#         --compile_model False \
#         --fabric_precision bf16-mixed \
#         --world_batch_size 2 \
#         --micro_batch_size 1 \
#         --block_size 2048 \
#         --n_chunks 4 \
#         --warmup_steps 2000 \
#         --log_step_interval 1 \
#         --eval_iters 100 \
#         --save_and_eval_interval 2000 \
#         --grad_clip 1.0 \
#         --decay_lr True \
#         --min_lr 4e-5 \
#         --data_telemetry False \
#         --data_config /XXXX-36/XXXX-22/XXXX-40/launch_scripts/XXXX-22/sample_hfds_only.json \
#         --fabric_strategy ddp \
#         --pretrained_prefix_model False \
#         --pretrained_suffix_model True

# python pretrain_umd/train_retrieval_w_anticausal.py \
#         --model_checkpoint /fs/XXXX-37/llm-pretraining/llm-retrieval/checkpoints/TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T \
#         --tokenizer_path /fs/XXXX-37/llm-pretraining/llm-retrieval/checkpoints/TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T \
#         --out_dir /fs/XXXX-37/llm-pretraining/llm-retrieval/out/cosmopedia-retrieval-dual-causal-llama-1.1b-bsz-8-ctx-rand-batch_negative_ddp_RR \
#         --seed 1337 \
#         --max_tokens 25_000_000_000 \
#         --model_name tiny-llama-1.1b \
#         --run_name cosmopedia-retrieval-dual-causal-llama-1.1b-bsz-8-ctx-rand-batch_negative_ddp_RR \
#         --logger_name wandb \
#         --compile_model False \
#         --fabric_precision bf16-mixed \
#         --world_batch_size 2 \
#         --micro_batch_size 1 \
#         --block_size 2048 \
#         --n_chunks 4 \
#         --warmup_steps 2000 \
#         --log_step_interval 1 \
#         --eval_iters 100 \
#         --save_and_eval_interval 2000 \
#         --grad_clip 1.0 \
#         --decay_lr True \
#         --min_lr 4e-5 \
#         --data_telemetry False \
#         --data_config /XXXX-36/XXXX-22/XXXX-40/launch_scripts/XXXX-22/sample_hfds_only.json \
#         --fabric_strategy ddp \
#         --pretrained_prefix_model False \
#         --pretrained_suffix_model False

# #### FSDP ####
# python pretrain_umd/train_retrieval_w_anticausal.py \
#         --model_checkpoint /fs/XXXX-37/llm-pretraining/llm-retrieval/checkpoints/TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T \
#         --tokenizer_path /fs/XXXX-37/llm-pretraining/llm-retrieval/checkpoints/TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T \
#         --out_dir /fs/XXXX-37/llm-pretraining/llm-retrieval/out/cosmopedia-retrieval-dual-causal-llama-1.1b-bsz-8-ctx-rand-batch_negative_FSDP_PP \
#         --seed 1337 \
#         --max_tokens 25_000_000_000 \
#         --model_name tiny-llama-1.1b \
#         --run_name cosmopedia-retrieval-dual-causal-llama-1.1b-bsz-8-ctx-rand-batch_negative_FSDP_PP \
#         --logger_name wandb \
#         --compile_model False \
#         --fabric_precision bf16-mixed \
#         --world_batch_size 2 \
#         --micro_batch_size 1 \
#         --block_size 2048 \
#         --n_chunks 4 \
#         --warmup_steps 2000 \
#         --log_step_interval 1 \
#         --eval_iters 100 \
#         --save_and_eval_interval 2000 \
#         --grad_clip 1.0 \
#         --decay_lr True \
#         --min_lr 4e-5 \
#         --data_telemetry False \
#         --data_config /XXXX-36/XXXX-22/XXXX-40/launch_scripts/XXXX-22/sample_hfds_only.json \
#         --fabric_strategy fsdp \
#         --pretrained_prefix_model True \
#         --pretrained_suffix_model True

# python pretrain_umd/train_retrieval_w_anticausal.py \
#         --model_checkpoint /fs/XXXX-37/llm-pretraining/llm-retrieval/checkpoints/TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T \
#         --tokenizer_path /fs/XXXX-37/llm-pretraining/llm-retrieval/checkpoints/TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T \
#         --out_dir /fs/XXXX-37/llm-pretraining/llm-retrieval/out/cosmopedia-retrieval-dual-causal-llama-1.1b-bsz-8-ctx-rand-batch_negative_FSDP_PR \
#         --seed 1337 \
#         --max_tokens 25_000_000_000 \
#         --model_name tiny-llama-1.1b \
#         --run_name cosmopedia-retrieval-dual-causal-llama-1.1b-bsz-8-ctx-rand-batch_negative_FSDP_PR \
#         --logger_name wandb \
#         --compile_model False \
#         --fabric_precision bf16-mixed \
#         --world_batch_size 2 \
#         --micro_batch_size 1 \
#         --block_size 2048 \
#         --n_chunks 4 \
#         --warmup_steps 2000 \
#         --log_step_interval 1 \
#         --eval_iters 100 \
#         --save_and_eval_interval 2000 \
#         --grad_clip 1.0 \
#         --decay_lr True \
#         --min_lr 4e-5 \
#         --data_telemetry False \
#         --data_config /XXXX-36/XXXX-22/XXXX-40/launch_scripts/XXXX-22/sample_hfds_only.json \
#         --fabric_strategy fsdp \
#         --pretrained_prefix_model True \
#         --pretrained_suffix_model False

# python pretrain_umd/train_retrieval_w_anticausal.py \
#         --model_checkpoint /fs/XXXX-37/llm-pretraining/llm-retrieval/checkpoints/TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T \
#         --tokenizer_path /fs/XXXX-37/llm-pretraining/llm-retrieval/checkpoints/TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T \
#         --out_dir /fs/XXXX-37/llm-pretraining/llm-retrieval/out/cosmopedia-retrieval-dual-causal-llama-1.1b-bsz-8-ctx-rand-batch_negative_FSDP_RP \
#         --seed 1337 \
#         --max_tokens 25_000_000_000 \
#         --model_name tiny-llama-1.1b \
#         --run_name cosmopedia-retrieval-dual-causal-llama-1.1b-bsz-8-ctx-rand-batch_negative_FSDP_RP \
#         --logger_name wandb \
#         --compile_model False \
#         --fabric_precision bf16-mixed \
#         --world_batch_size 2 \
#         --micro_batch_size 1 \
#         --block_size 2048 \
#         --n_chunks 4 \
#         --warmup_steps 2000 \
#         --log_step_interval 1 \
#         --eval_iters 100 \
#         --save_and_eval_interval 2000 \
#         --grad_clip 1.0 \
#         --decay_lr True \
#         --min_lr 4e-5 \
#         --data_telemetry False \
#         --data_config /XXXX-36/XXXX-22/XXXX-40/launch_scripts/XXXX-22/sample_hfds_only.json \
#         --fabric_strategy fsdp \
#         --pretrained_prefix_model False \
#         --pretrained_suffix_model True

# python pretrain_umd/train_retrieval_w_anticausal.py \
#         --model_checkpoint /fs/XXXX-37/llm-pretraining/llm-retrieval/checkpoints/TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T \
#         --tokenizer_path /fs/XXXX-37/llm-pretraining/llm-retrieval/checkpoints/TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T \
#         --out_dir /fs/XXXX-37/llm-pretraining/llm-retrieval/out/cosmopedia-retrieval-dual-causal-llama-1.1b-bsz-8-ctx-rand-batch_negative_FSDP_RR \
#         --seed 1337 \
#         --max_tokens 25_000_000_000 \
#         --model_name tiny-llama-1.1b \
#         --run_name cosmopedia-retrieval-dual-causal-llama-1.1b-bsz-8-ctx-rand-batch_negative_FSDP_RR \
#         --logger_name wandb \
#         --compile_model False \
#         --fabric_precision bf16-mixed \
#         --world_batch_size 2 \
#         --micro_batch_size 1 \
#         --block_size 2048 \
#         --n_chunks 4 \
#         --warmup_steps 2000 \
#         --log_step_interval 1 \
#         --eval_iters 100 \
#         --save_and_eval_interval 2000 \
#         --grad_clip 1.0 \
#         --decay_lr True \
#         --min_lr 4e-5 \
#         --data_telemetry False \
#         --data_config /XXXX-36/XXXX-22/XXXX-40/launch_scripts/XXXX-22/sample_hfds_only.json \
#         --fabric_strategy fsdp \
#         --pretrained_prefix_model False \
#         --pretrained_suffix_model False

# #### Pythia LLM training ####
# python pretrain_umd/train.py \
#         --tokenizer_path /fs/XXXX-37/llm-pretraining/llm-retrieval/checkpoints/EleutherAI/pythia-160m \
#         --out_dir /fs/XXXX-37/llm-pretraining/llm-retrieval/out/pile-lm-pythia-160m-bsz-8_ddp_lr_6e-4_debug \
#         --seed 1337 \
#         --max_tokens 25_000_000_000 \
#         --model_name pythia-160m \
#         --run_name pile-lm-pythia-160m-bsz-8_ddp_lr_6e-4_debug \
#         --logger_name wandb \
#         --logger_project llm-retrieval \
#         --compile_model False \
#         --world_batch_size 16 \
#         --micro_batch_size 8 \
#         --block_size 2048 \
#         --n_chunks 4 \
#         --warmup_steps 4000 \
#         --log_step_interval 1 \
#         --eval_iters 100 \
#         --save_and_eval_interval 10000 \
#         --grad_clip 1.0 \
#         --decay_lr True \
#         --learning_rate 6e-4 \
#         --min_lr 6e-5 \
#         --data_telemetry False \
#         --data_config /XXXX-36/XXXX-22/XXXX-40/launch_scripts/XXXX-22/sample_hfds_only.json \
#         --fabric_strategy ddp

# python pretrain_umd/train.py \
#         --tokenizer_path /fs/XXXX-37/llm-pretraining/llm-retrieval/checkpoints/EleutherAI/pythia-160m \
#         --out_dir /fs/XXXX-37/llm-pretraining/llm-retrieval/out/pile-lm-pythia-160m-bsz-8_ddp_lr_3e-4_debug \
#         --seed 1337 \
#         --max_tokens 25_000_000_000 \
#         --model_name pythia-160m \
#         --run_name pile-lm-pythia-160m-bsz-8_ddp_lr_3e-4_debug \
#         --logger_name wandb \
#         --logger_project llm-retrieval \
#         --compile_model False \
#         --world_batch_size 16 \
#         --micro_batch_size 8 \
#         --block_size 2048 \
#         --n_chunks 4 \
#         --warmup_steps 4000 \
#         --log_step_interval 1 \
#         --eval_iters 100 \
#         --save_and_eval_interval 10000 \
#         --grad_clip 1.0 \
#         --decay_lr True \
#         --learning_rate 3e-4 \
#         --min_lr 3e-5 \
#         --data_telemetry False \
#         --data_config /XXXX-36/XXXX-22/XXXX-40/launch_scripts/XXXX-22/sample_hfds_only.json \
#         --fabric_strategy ddp

# python pretrain_umd/train.py \
#         --tokenizer_path /fs/XXXX-37/llm-pretraining/llm-retrieval/checkpoints/EleutherAI/pythia-160m \
#         --out_dir /fs/XXXX-37/llm-pretraining/llm-retrieval/out/pile-lm-pythia-160m-bsz-8_ddp_lr_2e-4_debug \
#         --seed 1337 \
#         --max_tokens 25_000_000_000 \
#         --model_name pythia-160m \
#         --run_name pile-lm-pythia-160m-bsz-8_ddp_lr_2e-4_debug \
#         --logger_name wandb \
#         --logger_project llm-retrieval \
#         --compile_model False \
#         --world_batch_size 16 \
#         --micro_batch_size 8 \
#         --block_size 2048 \
#         --n_chunks 4 \
#         --warmup_steps 4000 \
#         --log_step_interval 1 \
#         --eval_iters 100 \
#         --save_and_eval_interval 10000 \
#         --grad_clip 1.0 \
#         --decay_lr True \
#         --learning_rate 2e-4 \
#         --min_lr 2e-5 \
#         --data_telemetry False \
#         --data_config /XXXX-36/XXXX-22/XXXX-40/launch_scripts/XXXX-22/sample_hfds_only.json \
#         --fabric_strategy ddp

# python pretrain_umd/train.py \
#         --tokenizer_path /fs/XXXX-37/llm-pretraining/llm-retrieval/checkpoints/EleutherAI/pythia-160m \
#         --out_dir /fs/XXXX-37/llm-pretraining/llm-retrieval/out/pile-lm-pythia-160m-bsz-8_ddp_lr_1e-3_debug \
#         --seed 1337 \
#         --max_tokens 25_000_000_000 \
#         --model_name pythia-160m \
#         --run_name pile-lm-pythia-160m-bsz-8_ddp_lr_1e-3_debug \
#         --logger_name wandb \
#         --logger_project llm-retrieval \
#         --compile_model False \
#         --world_batch_size 16 \
#         --micro_batch_size 8 \
#         --block_size 2048 \
#         --n_chunks 4 \
#         --warmup_steps 4000 \
#         --log_step_interval 1 \
#         --eval_iters 100 \
#         --save_and_eval_interval 10000 \
#         --grad_clip 1.0 \
#         --decay_lr True \
#         --learning_rate 1e-3 \
#         --min_lr 1e-4 \
#         --data_telemetry False \
#         --data_config /XXXX-36/XXXX-22/XXXX-40/launch_scripts/XXXX-22/sample_hfds_only.json \
#         --fabric_strategy ddp

# python pretrain_umd/train.py \
#         --tokenizer_path /fs/XXXX-37/llm-pretraining/llm-retrieval/checkpoints/EleutherAI/pythia-160m \
#         --out_dir /fs/XXXX-37/llm-pretraining/llm-retrieval/out/pile-lm-pythia-160m-bsz-8_ddp_lr_2e-5_debug \
#         --seed 1337 \
#         --max_tokens 25_000_000_000 \
#         --model_name pythia-160m \
#         --run_name pile-lm-pythia-160m-bsz-8_ddp_lr_2e-5_debug \
#         --logger_name wandb \
#         --logger_project llm-retrieval \
#         --compile_model False \
#         --world_batch_size 16 \
#         --micro_batch_size 8 \
#         --block_size 2048 \
#         --n_chunks 4 \
#         --warmup_steps 4000 \
#         --log_step_interval 1 \
#         --eval_iters 100 \
#         --save_and_eval_interval 10000 \
#         --grad_clip 1.0 \
#         --decay_lr True \
#         --learning_rate 2e-5 \
#         --min_lr 2e-6 \
#         --data_telemetry False \
#         --data_config /XXXX-36/XXXX-22/XXXX-40/launch_scripts/XXXX-22/sample_hfds_only.json \
#         --fabric_strategy ddp

# python pretrain_umd/train.py \
#         --tokenizer_path /fs/XXXX-37/llm-pretraining/llm-retrieval/checkpoints/EleutherAI/pythia-160m \
#         --out_dir /fs/XXXX-37/llm-pretraining/llm-retrieval/out/pile-lm-pythia-160m-bsz-8_ddp_lr_5e-5_debug \
#         --seed 1337 \
#         --max_tokens 25_000_000_000 \
#         --model_name pythia-160m \
#         --run_name pile-lm-pythia-160m-bsz-8_ddp_lr_5e-5_debug \
#         --logger_name wandb \
#         --logger_project llm-retrieval \
#         --compile_model False \
#         --world_batch_size 16 \
#         --micro_batch_size 8 \
#         --block_size 2048 \
#         --n_chunks 4 \
#         --warmup_steps 4000 \
#         --log_step_interval 1 \
#         --eval_iters 100 \
#         --save_and_eval_interval 10000 \
#         --grad_clip 1.0 \
#         --decay_lr True \
#         --learning_rate 5e-5 \
#         --min_lr 5e-6 \
#         --data_telemetry False \
#         --data_config /XXXX-36/XXXX-22/XXXX-40/launch_scripts/XXXX-22/sample_hfds_only.json \
#         --fabric_strategy ddp

# #### Pythia Retrieval Training ####
# python pretrain_umd/train_retrieval_w_anticausal.py \
#         --model_checkpoint /fs/XXXX-37/llm-pretraining/llm-retrieval/checkpoints/EleutherAI/pythia-160m \
#         --tokenizer_path /fs/XXXX-37/llm-pretraining/llm-retrieval/checkpoints/EleutherAI/pythia-160m \
#         --out_dir /fs/XXXX-37/llm-pretraining/llm-retrieval/out/cosmopedia-retrieval-dual-causal-pythia-160m-bsz-8-ctx-rand-batch_negative_ddp_RR_lr_5e-5_debug \
#         --seed 1337 \
#         --max_tokens 25_000_000_000 \
#         --model_name pythia-160m \
#         --run_name cosmopedia-retrieval-dual-causal-pythia-160m-bsz-8-ctx-rand-batch_negative_ddp_RR_lr_5e-5_debug \
#         --logger_name wandb \
#         --compile_model False \
#         --fabric_precision bf16-mixed \
#         --world_batch_size 16 \
#         --micro_batch_size 8 \
#         --block_size 2048 \
#         --n_chunks 4 \
#         --warmup_steps 2000 \
#         --log_step_interval 1 \
#         --eval_iters 100 \
#         --save_and_eval_interval 10000 \
#         --grad_clip 1.0 \
#         --decay_lr True \
#         --learning_rate 5e-5 \
#         --min_lr 5e-6 \
#         --data_telemetry False \
#         --data_config /XXXX-36/XXXX-22/XXXX-40/launch_scripts/XXXX-22/sample_hfds_only.json \
#         --fabric_strategy ddp \
#         --pretrained_prefix_model False \
#         --pretrained_suffix_model False

# python pretrain_umd/train_retrieval_w_anticausal.py \
#         --model_checkpoint /fs/XXXX-37/llm-pretraining/llm-retrieval/checkpoints/EleutherAI/pythia-160m \
#         --tokenizer_path /fs/XXXX-37/llm-pretraining/llm-retrieval/checkpoints/EleutherAI/pythia-160m \
#         --out_dir /fs/XXXX-37/llm-pretraining/llm-retrieval/out/cosmopedia-retrieval-dual-causal-pythia-160m-bsz-8-ctx-rand-batch_negative_ddp_RR_lr_2e-5_debug \
#         --seed 1337 \
#         --max_tokens 25_000_000_000 \
#         --model_name pythia-160m \
#         --run_name cosmopedia-retrieval-dual-causal-pythia-160m-bsz-8-ctx-rand-batch_negative_ddp_RR_lr_2e-5_debug \
#         --logger_name wandb \
#         --compile_model False \
#         --fabric_precision bf16-mixed \
#         --world_batch_size 16 \
#         --micro_batch_size 8 \
#         --block_size 2048 \
#         --n_chunks 4 \
#         --warmup_steps 2000 \
#         --log_step_interval 1 \
#         --eval_iters 100 \
#         --save_and_eval_interval 10000 \
#         --grad_clip 1.0 \
#         --decay_lr True \
#         --learning_rate 2e-5 \
#         --min_lr 2e-6 \
#         --data_telemetry False \
#         --data_config /XXXX-36/XXXX-22/XXXX-40/launch_scripts/XXXX-22/sample_hfds_only.json \
#         --fabric_strategy ddp \
#         --pretrained_prefix_model False \
#         --pretrained_suffix_model False

# python pretrain_umd/train_retrieval_w_anticausal.py \
#         --model_checkpoint /fs/XXXX-37/llm-pretraining/llm-retrieval/checkpoints/EleutherAI/pythia-160m \
#         --tokenizer_path /fs/XXXX-37/llm-pretraining/llm-retrieval/checkpoints/EleutherAI/pythia-160m \
#         --out_dir /fs/XXXX-37/llm-pretraining/llm-retrieval/out/cosmopedia-retrieval-dual-causal-pythia-160m-bsz-8-ctx-rand-batch_negative_ddp_RR_lr_1e-3_debug \
#         --seed 1337 \
#         --max_tokens 25_000_000_000 \
#         --model_name pythia-160m \
#         --run_name cosmopedia-retrieval-dual-causal-pythia-160m-bsz-8-ctx-rand-batch_negative_ddp_RR_lr_1e-3_debug \
#         --logger_name wandb \
#         --compile_model False \
#         --fabric_precision bf16-mixed \
#         --world_batch_size 16 \
#         --micro_batch_size 8 \
#         --block_size 2048 \
#         --n_chunks 4 \
#         --warmup_steps 2000 \
#         --log_step_interval 1 \
#         --eval_iters 100 \
#         --save_and_eval_interval 10000 \
#         --grad_clip 1.0 \
#         --decay_lr True \
#         --learning_rate 1e-3 \
#         --min_lr 1e-4 \
#         --data_telemetry False \
#         --data_config /XXXX-36/XXXX-22/XXXX-40/launch_scripts/XXXX-22/sample_hfds_only.json \
#         --fabric_strategy ddp \
#         --pretrained_prefix_model False \
#         --pretrained_suffix_model False

# python pretrain_umd/train_retrieval_w_anticausal.py \
#         --model_checkpoint /fs/XXXX-37/llm-pretraining/llm-retrieval/checkpoints/EleutherAI/pythia-160m \
#         --tokenizer_path /fs/XXXX-37/llm-pretraining/llm-retrieval/checkpoints/EleutherAI/pythia-160m \
#         --out_dir /fs/XXXX-37/llm-pretraining/llm-retrieval/out/cosmopedia-retrieval-dual-causal-pythia-160m-bsz-8-ctx-rand-batch_negative_ddp_RR_lr_2e-4_debug \
#         --seed 1337 \
#         --max_tokens 25_000_000_000 \
#         --model_name pythia-160m \
#         --run_name cosmopedia-retrieval-dual-causal-pythia-160m-bsz-8-ctx-rand-batch_negative_ddp_RR_lr_2e-4_debug \
#         --logger_name wandb \
#         --compile_model False \
#         --fabric_precision bf16-mixed \
#         --world_batch_size 16 \
#         --micro_batch_size 8 \
#         --block_size 2048 \
#         --n_chunks 4 \
#         --warmup_steps 2000 \
#         --log_step_interval 1 \
#         --eval_iters 100 \
#         --save_and_eval_interval 10000 \
#         --grad_clip 1.0 \
#         --decay_lr True \
#         --learning_rate 2e-4 \
#         --min_lr 2e-5 \
#         --data_telemetry False \
#         --data_config /XXXX-36/XXXX-22/XXXX-40/launch_scripts/XXXX-22/sample_hfds_only.json \
#         --fabric_strategy ddp \
#         --pretrained_prefix_model False \
#         --pretrained_suffix_model False

# python pretrain_umd/train_retrieval_w_anticausal.py \
#         --model_checkpoint /fs/XXXX-37/llm-pretraining/llm-retrieval/checkpoints/EleutherAI/pythia-160m \
#         --tokenizer_path /fs/XXXX-37/llm-pretraining/llm-retrieval/checkpoints/EleutherAI/pythia-160m \
#         --out_dir /fs/XXXX-37/llm-pretraining/llm-retrieval/out/cosmopedia-retrieval-dual-causal-pythia-160m-bsz-8-ctx-rand-batch_negative_ddp_RR_lr_3e-4_debug \
#         --seed 1337 \
#         --max_tokens 25_000_000_000 \
#         --model_name pythia-160m \
#         --run_name cosmopedia-retrieval-dual-causal-pythia-160m-bsz-8-ctx-rand-batch_negative_ddp_RR_lr_3e-4_debug \
#         --logger_name wandb \
#         --compile_model False \
#         --fabric_precision bf16-mixed \
#         --world_batch_size 16 \
#         --micro_batch_size 8 \
#         --block_size 2048 \
#         --n_chunks 4 \
#         --warmup_steps 2000 \
#         --log_step_interval 1 \
#         --eval_iters 100 \
#         --save_and_eval_interval 10000 \
#         --grad_clip 1.0 \
#         --decay_lr True \
#         --learning_rate 3e-4 \
#         --min_lr 3e-5 \
#         --data_telemetry False \
#         --data_config /XXXX-36/XXXX-22/XXXX-40/launch_scripts/XXXX-22/sample_hfds_only.json \
#         --fabric_strategy ddp \
#         --pretrained_prefix_model False \
#         --pretrained_suffix_model False

# python pretrain_umd/train_retrieval_w_anticausal.py \
#         --model_checkpoint /fs/XXXX-37/llm-pretraining/llm-retrieval/checkpoints/EleutherAI/pythia-160m \
#         --tokenizer_path /fs/XXXX-37/llm-pretraining/llm-retrieval/checkpoints/EleutherAI/pythia-160m \
#         --out_dir /fs/XXXX-37/llm-pretraining/llm-retrieval/out/cosmopedia-retrieval-dual-causal-pythia-160m-bsz-8-ctx-rand-batch_negative_ddp_RR_lr_6e-4_debug \
#         --seed 1337 \
#         --max_tokens 25_000_000_000 \
#         --model_name pythia-160m \
#         --run_name cosmopedia-retrieval-dual-causal-pythia-160m-bsz-8-ctx-rand-batch_negative_ddp_RR_lr_6e-4_debug \
#         --logger_name wandb \
#         --compile_model False \
#         --fabric_precision bf16-mixed \
#         --world_batch_size 16 \
#         --micro_batch_size 8 \
#         --block_size 2048 \
#         --n_chunks 4 \
#         --warmup_steps 2000 \
#         --log_step_interval 1 \
#         --eval_iters 100 \
#         --save_and_eval_interval 10000 \
#         --grad_clip 1.0 \
#         --decay_lr True \
#         --learning_rate 6e-4 \
#         --min_lr 6e-5 \
#         --data_telemetry False \
#         --data_config /XXXX-36/XXXX-22/XXXX-40/launch_scripts/XXXX-22/sample_hfds_only.json \
#         --fabric_strategy ddp \
#         --pretrained_prefix_model False \
#         --pretrained_suffix_model False





#### WORLD BSZ SCALING EXP ####
# Sweep the world (global) batch size for the Pythia-160m dual-causal
# retrieval run while every other flag stays fixed. Sizes are run one after
# another, in the order listed in the `for` loop below. --out_dir and
# --run_name are both derived from worldbsz_run_name, so the two can no
# longer drift apart between copies.
#
# A failed run is reported on stderr and the sweep moves on to the next
# size. The final statement leaves a non-zero status if any run failed, but
# does not `exit`, so anything that follows in this file still runs.
# Written in plain POSIX sh (no arrays) because no bash shebang is visible.
worldbsz_ckpt=/fs/XXXX-37/llm-pretraining/llm-retrieval/checkpoints/EleutherAI/pythia-160m
worldbsz_out_root=/fs/XXXX-37/llm-pretraining/llm-retrieval/out
worldbsz_data_config=/XXXX-36/XXXX-22/XXXX-40/launch_scripts/XXXX-22/sample_hfds_only.json

# Print the run name for world batch size $1. The same string is used as the
# wandb --run_name and as the leaf directory under worldbsz_out_root.
worldbsz_run_name() {
  printf '%s\n' "dolma-retrieval-dual-causal-pythia-160m-worldbsz-$1-ctx-rand-batch_negative_ddp_RR_lr_3e-4_debug"
}

# Space-separated list of world batch sizes whose run exited non-zero.
worldbsz_failed=
for world_bsz in 20 30 40 50 60 70 80; do
  worldbsz_name=$(worldbsz_run_name "$world_bsz")
  if ! python pretrain_umd/train_retrieval_w_anticausal.py \
      --model_checkpoint "$worldbsz_ckpt" \
      --tokenizer_path "$worldbsz_ckpt" \
      --out_dir "$worldbsz_out_root/$worldbsz_name" \
      --resume True \
      --seed 1337 \
      --max_tokens 25_000_000_000 \
      --model_name pythia-160m \
      --run_name "$worldbsz_name" \
      --logger_name wandb \
      --compile_model False \
      --fabric_precision bf16-mixed \
      --world_batch_size "$world_bsz" \
      --micro_batch_size 10 \
      --block_size 2048 \
      --n_chunks 4 \
      --warmup_steps 2000 \
      --log_step_interval 1 \
      --eval_iters 100 \
      --save_and_eval_interval 2000 \
      --grad_clip 1.0 \
      --decay_lr True \
      --learning_rate 3e-4 \
      --min_lr 3e-5 \
      --data_telemetry False \
      --data_config "$worldbsz_data_config" \
      --fabric_strategy ddp \
      --pretrained_prefix_model False \
      --pretrained_suffix_model False; then
    # Record the failure but keep sweeping: the remaining sizes are independent runs.
    printf 'world-bsz sweep: run %s failed\n' "$worldbsz_name" >&2
    worldbsz_failed="${worldbsz_failed:+$worldbsz_failed }$world_bsz"
  fi
done

if [ -n "$worldbsz_failed" ]; then
  printf 'world-bsz sweep: failed world_batch_size values: %s\n' "$worldbsz_failed" >&2
fi
# Leave $? non-zero when any run failed (becomes the script's exit status if
# this is the last command), without aborting later content via `exit`.
[ -z "$worldbsz_failed" ]