#!/bin/bash

# for dir in /XXXX-30/XXXX-29/XXXX-31/scratch/XXXX-22/output/v4_pythia-160m-retr-32k_w_meta_mb2-wb2048-grp1024_keep_368k_negs_128N_truncate_normal/checkpoints-AxonnStrategy/*; do
#     rm -v "$dir/step-00065284"*
# done

#######################################
# Report whether one of the current user's Slurm jobs with a given wall-time
# limit is still listed by squeue (pending or running).
#
# The job is identified only by its time limit, so the limit must be unique
# among the user's queued jobs. The default, 1:58:00, corresponds to the
# --budget_minutes=118 launches in this script.
#
# Arguments:
#   $1 - (optional) time limit exactly as squeue prints it in the %l column,
#        e.g. "58:00" or "30:00". Defaults to "1:58:00".
# Outputs:
#   Writes a warning to stderr when squeue itself fails.
# Returns:
#   0 if a matching job is listed, or if squeue failed (the state is unknown,
#     so the job is conservatively assumed to still be running);
#   1 if squeue succeeded and no job has that time limit.
#######################################
check_job_running() {
    local time_limit="${1:-1:58:00}"
    local queue_listing
    local listed_limit

    # Declare and assign separately so the exit status of squeue is visible.
    if ! queue_listing=$(squeue -u "$USER" -o "%i %l"); then
        # A scheduler hiccup must not be mistaken for "job finished":
        # the caller would otherwise submit the next job too early.
        echo "Warning: squeue failed; assuming the job is still running." >&2
        return 0
    fi

    # Compare the whole TIME_LIMIT field. A substring match would also accept
    # limits such as 11:58:00 or 1-01:58:00.
    while read -r _ listed_limit _; do
        if [[ "$listed_limit" == "$time_limit" ]]; then
            return 0  # Job is still queued or running
        fi
    done <<<"$queue_listing"

    return 1  # No job with this time limit is listed
}

# Block until check_job_running reports that the tracked job is gone.
# Polls once immediately, then once every five minutes, printing a progress
# line before each sleep and a final line once the job is no longer listed.
wait_for_job_completion() {
    printf '%s\n' "Waiting for job to complete..."

    # Keep sleeping until the status check stops reporting a running job.
    until ! check_job_running; do
        printf '%s\n' "Job still running, checking again in 5 minutes..."
        sleep 5m
    done

    printf '%s\n' "Job completed!"
}

# pythonAll train.py --config launch_configs/njain/base_nomic_clm.json --run_name clm_pretraing_run_v1_100BT --out_dir /XXXX-30/XXXX-29/XXXX-31/scratch/XXXX-22/output/clm_pretraing_run_v1_100BT --micro_batch_size 16 --world_batch_size 2048 --validate_only true > load_axonn_clm_ckpt.txt
# pythonAll train.py --config launch_configs/njain/base_nomic_clm.json --run_name clm_pretraing_run_v1_100BT --out_dir /XXXX-30/XXXX-29/XXXX-31/scratch/XXXX-22/output/clm_pretraing_run_v1_100BT --micro_batch_size 16 --world_batch_size 2048 --validate_only true --model_checkpoint /XXXX-30/XXXX-29/XXXX-31/scratch/XXXX-22/output/clm_pretraing_run_v1_100BT/combined_ckpts/step-00072000_ckpt.pth > load_combined_clm_ckpt.txt
# pythonAll train.py --config launch_configs/njain/base_nomic_clm.json --run_name test --out_dir /XXXX-30/XXXX-29/XXXX-31/scratch/XXXX-22/output/test --micro_batch_size 16 --world_batch_size 16 --validate_only true --model_checkpoint /XXXX-30/XXXX-29/XXXX-31/proj-shared/language_models/external/EleutherAI/pythia-160m/lit_model.pth --tokenizer_path /XXXX-30/XXXX-29/XXXX-31/proj-shared/language_models/external/EleutherAI/pythia-160m --fabric.depth_tensor_parallel_size=1 --model_name=pythia-160m -- > load_combined_clm_ckpt.txt
# Direct (non-launcher) run of train.py under the run name "test", loading the
# combined step-72000 checkpoint. This runs in the foreground, so the
# submission loops below do not start until it exits.
python train.py \
    --config launch_configs/njain/base_nomic_clm.json \
    --run_name test \
    --out_dir /XXXX-30/XXXX-29/XXXX-31/scratch/XXXX-22/output/test \
    --micro_batch_size 16 \
    --world_batch_size 16 \
    --fabric.depth_tensor_parallel_size=1 \
    --model_checkpoint /XXXX-30/XXXX-29/XXXX-31/scratch/XXXX-22/output/clm_pretraing_run_v1_100BT/combined_ckpts/step-00072000_ckpt.pth
# python /XXXX-30/XXXX-29/XXXX-31/scratch/XXXX-22/lit-gpt-dev_new/launch_frontier.py \
#     --python_script="/XXXX-30/XXXX-29/XXXX-31/scratch/XXXX-22/lit-gpt-dev_new/train_retrieval_w_anticausal.py" \
#     --rccl_installdir="${WRKSPC}/aws-ofi-rccl_25_62_retrieval" \
#     --environment="${WRKSPC}/frontier_conda_25_62_retrieval" \
#     --budget_minutes=118 \
#     --nodes 16 \
#     --output_dir=/XXXX-30/XXXX-29/XXXX-31/scratch/XXXX-22/output \
#     --config /XXXX-30/XXXX-29/XXXX-31/scratch/XXXX-22/lit-gpt-dev_new/launch_configs/retrieval/base_nomic_finetune_lockstep.json \
#     --run_name axonn_nomic_finetune_phase3_pt_step_60k_fineweb_100b_lr_2e-5_w_lockstep_wb_tgrp_8_negs_4096_mean_pool_v4_pythia-160m-retr-32k_w_meta_truncate_normal_mb16-wb2048-grp1-1-8_32_16N_max_steps_2484_max_seq_512 \
#     --extra_args='--lockstep_sampling=world_batch --data_telemetry=10 --max_seq_len=512 --fabric_strategy="axonn_tp" --attn_impl=rocm --fabric.depth_tensor_parallel_size=8 --batch_prefix_and_suffix=true --model_name=pythia-160m-retr-32k_w_meta --pad_to_block_size=true --train_group_size=8 --mean_pooling=true --max_steps=2484 --negatives_cross_device=true --save_step_interval=500 --eval_step_interval=20000 --negatives_cross_device_group_size=32 --optim_config.lr=2e-5 --min_lr=2e-6 --save_n_min_before_job_done=2 --world_batch_size=2048 --micro_batch_size=16 --validate_at_end=false --save_last_step=true --finetune_checkpoint="/XXXX-30/XXXX-29/XXXX-31/scratch/XXXX-22/output/ft_nomic_pos_only_pt_step_36k_fineweb_100b_w_lockstep_wb_negs_16384_mean_pool_v4_pythia-160m-retr-32k_w_meta_mb128-wb16384-grp1-1-8_128_16N_max_steps_14587_max_seq_512/combined_ckpts/step-00014586_ckpt.pth"' \
#     --sub_output_dir_name axonn_nomic_finetune_phase3_pt_step_60k_fineweb_100b_lr_2e-5_w_lockstep_wb_tgrp_8_negs_4096_mean_pool_v4_pythia-160m-retr-32k_w_meta_truncate_normal_mb16-wb2048-grp1-1-8_32_16N_max_steps_2484_max_seq_512 \
#     --disable_net_gdr \
#     --debug_qos
# echo "Running axonn_nomic_finetune_phase3_pt_step_36k_fineweb_100b"
# wait_for_job_completion

# --- Positive-only fine-tune from the step-130k checkpoint, mean pooling ---
# Wait for any job that is already queued, then submit the same run eight
# times back to back, waiting for each submission to leave the queue before
# the next one is sent.
# NOTE(review): presumably each resubmission resumes from the run's latest
# checkpoint in the shared output directory; confirm in launch_frontier.py.
wait_for_job_completion

# The same identifier names both the run and its output sub-directory.
mean_pool_run_name="ft_nomic_pos_only_pt_step_130k_w_lockstep_wb_negs_16384_mean_pool_v4_pythia-160m-retr-32k_w_meta_truncate_normal_mb128-wb16384-grp1-1-8_128_16N_max_steps_14587_max_seq_512"

# Forwarded verbatim as one --extra_args value; the embedded double quotes
# are part of the value.
mean_pool_extra_args='--lockstep_sampling=world_batch --data_telemetry=10 --max_seq_len=512 --fabric_strategy="axonn_tp" --attn_impl=rocm --fabric.depth_tensor_parallel_size=8 --batch_prefix_and_suffix=true --model_name=pythia-160m-retr-32k_w_meta --pad_to_block_size=true --mean_pooling=true --max_steps=14587 --negatives_cross_device=true --save_step_interval=500 --eval_step_interval=20000 --negatives_cross_device_group_size=128 --optim_config.lr=2e-5 --min_lr=2e-6 --save_n_min_before_job_done=2 --world_batch_size=16384 --micro_batch_size=128 --validate_at_end=false --save_last_step=true --finetune_checkpoint="/XXXX-30/XXXX-29/XXXX-31/scratch/XXXX-22/output/v4_pythia-160m-retr-32k_w_meta_mb2-wb2048-grp1024_keep_368k_negs_128N_truncate_normal/combined_ckpts/step-00130000_ckpt.pth"'

# --budget_minutes=118 is the 1:58:00 time limit that check_job_running
# looks for in the squeue listing.
mean_pool_launch_args=(
    --python_script="train_retrieval_w_anticausal.py"
    --rccl_installdir="${WRKSPC}/aws-ofi-rccl_25_62_retrieval"
    --environment="${WRKSPC}/frontier_conda_25_62_retrieval"
    --budget_minutes=118
    --nodes 16
    --output_dir=/XXXX-30/XXXX-29/XXXX-31/scratch/XXXX-22/output
    --config launch_configs/retrieval/base_nomic_positive_only_finetune_lockstep.json
    --run_name "${mean_pool_run_name}"
    --extra_args="${mean_pool_extra_args}"
    --sub_output_dir_name "${mean_pool_run_name}"
    --debug_qos
)

for i in {1..8}; do
    python launch_frontier.py "${mean_pool_launch_args[@]}"
    echo "Current iteration: $i ..."
    wait_for_job_completion
    echo "Iteration $i completed."
done

# --- Positive-only fine-tune from the step-130k checkpoint, last-token pool ---
# Same submission pattern as the mean-pooling run, but the run is named
# "lasttoken_pool" and passes --mean_pooling=false --keep_eos=true.
# The run is submitted eight times, each after the previous job has left
# the queue.

# The same identifier names both the run and its output sub-directory.
last_token_run_name="ft_nomic_pos_only_pt_step_130k_w_lockstep_wb_negs_16384_lasttoken_pool_v4_pythia-160m-retr-32k_w_meta_truncate_normal_mb128-wb16384-grp1-1-8_128_16N_max_steps_14587_max_seq_512"

# Forwarded verbatim as one --extra_args value; the embedded double quotes
# are part of the value.
last_token_extra_args='--mean_pooling=false --keep_eos=true --lockstep_sampling=world_batch --data_telemetry=10 --max_seq_len=512 --fabric_strategy="axonn_tp" --attn_impl=rocm --fabric.depth_tensor_parallel_size=8 --batch_prefix_and_suffix=true --model_name=pythia-160m-retr-32k_w_meta --pad_to_block_size=true --max_steps=14587 --negatives_cross_device=true --save_step_interval=500 --eval_step_interval=20000 --negatives_cross_device_group_size=128 --optim_config.lr=2e-5 --min_lr=2e-6 --save_n_min_before_job_done=2 --world_batch_size=16384 --micro_batch_size=128 --validate_at_end=false --save_last_step=true --finetune_checkpoint="/XXXX-30/XXXX-29/XXXX-31/scratch/XXXX-22/output/v4_pythia-160m-retr-32k_w_meta_mb2-wb2048-grp1024_keep_368k_negs_128N_truncate_normal/combined_ckpts/step-00130000_ckpt.pth"'

# --budget_minutes=118 is the 1:58:00 time limit that check_job_running
# looks for in the squeue listing.
last_token_launch_args=(
    --python_script="train_retrieval_w_anticausal.py"
    --rccl_installdir="${WRKSPC}/aws-ofi-rccl_25_62_retrieval"
    --environment="${WRKSPC}/frontier_conda_25_62_retrieval"
    --budget_minutes=118
    --nodes 16
    --output_dir=/XXXX-30/XXXX-29/XXXX-31/scratch/XXXX-22/output
    --config launch_configs/retrieval/base_nomic_positive_only_finetune_lockstep.json
    --run_name "${last_token_run_name}"
    --extra_args="${last_token_extra_args}"
    --sub_output_dir_name "${last_token_run_name}"
    --debug_qos
)

for i in {1..8}; do
    python launch_frontier.py "${last_token_launch_args[@]}"
    echo "Current iteration: $i ..."
    wait_for_job_completion
    echo "Iteration $i completed."
done

# for i in {1..8}
# do
#     python launch_frontier.py \
#         --python_script="train_retrieval_w_anticausal.py" \
#         --rccl_installdir="${WRKSPC}/aws-ofi-rccl_25_62_retrieval" \
#         --environment="${WRKSPC}/frontier_conda_25_62_retrieval" \
#         --budget_minutes=118 \
#         --nodes 16 \
#         --output_dir=/XXXX-30/XXXX-29/XXXX-31/scratch/XXXX-22/output \
#         --config launch_configs/retrieval/base_nomic_positive_only_finetune_lockstep.json \
#         --run_name ft_nomic_pos_only_pt_step_72k_fineweb_100b_w_lockstep_wb_negs_16384_mean_pool_v4_pythia-160m-retr-32k_w_meta_mb128-wb16384-grp1-1-8_128_16N_max_steps_14587_max_seq_512 \
#         --extra_args='--lockstep_sampling=world_batch --data_telemetry=10 --max_seq_len=512 --fabric_strategy="axonn_tp" --attn_impl=rocm --fabric.depth_tensor_parallel_size=8 --batch_prefix_and_suffix=true --model_name=pythia-160m-retr-32k_w_meta --pad_to_block_size=true --mean_pooling=true --max_steps=14587 --negatives_cross_device=true --save_step_interval=500 --eval_step_interval=20000 --negatives_cross_device_group_size=128 --optim_config.lr=2e-5 --min_lr=2e-6 --save_n_min_before_job_done=2 --world_batch_size=16384 --micro_batch_size=128 --validate_at_end=false --save_last_step=true --finetune_checkpoint="/XXXX-30/XXXX-29/XXXX-31/scratch/XXXX-22/output/v4_fineweb_100b_pythia-160m-retr-32k_w_meta_mb2-wb2048-grp1024_keep_368k_negs_128N_truncate_normal/combined_ckpts/step-00072000_ckpt.pth"' \
#         --sub_output_dir_name ft_nomic_pos_only_pt_step_72k_fineweb_100b_w_lockstep_wb_negs_16384_mean_pool_v4_pythia-160m-retr-32k_w_meta_mb128-wb16384-grp1-1-8_128_16N_max_steps_14587_max_seq_512 \
#         --debug_qos
#     echo "Current iteration: $i ..."
#     wait_for_job_completion
#     echo "Iteration $i completed."
# done

# wait_for_job_completion
# for i in {1..5}
# do
#     python launch_frontier.py \
#         --python_script="train_retrieval_w_anticausal.py" \
#         --rccl_installdir="${WRKSPC}/aws-ofi-rccl_25_62_retrieval" \
#         --environment="${WRKSPC}/frontier_conda_25_62_retrieval" \
#         --budget_minutes=118 \
#         --nodes 16 \
#         --output_dir=/XXXX-30/XXXX-29/XXXX-31/scratch/XXXX-22/output \
#         --config launch_configs/retrieval/base_nomic_positive_only_finetune_lockstep.json \
#         --run_name ft_nomic_pos_only_pt_step_60k_w_lockstep_wb_negs_16384_mean_pool_v4_pythia-160m-retr-32k_w_meta_truncate_normal_mb128-wb16384-grp1-1-8_128_16N_max_steps_14587_max_seq_512 \
#         --extra_args='--lockstep_sampling=world_batch --data_telemetry=10 --max_seq_len=512 --fabric_strategy="axonn_tp" --attn_impl=rocm --fabric.depth_tensor_parallel_size=8 --batch_prefix_and_suffix=true --model_name=pythia-160m-retr-32k_w_meta --pad_to_block_size=true --mean_pooling=true --max_steps=14587 --negatives_cross_device=true --save_step_interval=500 --eval_step_interval=20000 --negatives_cross_device_group_size=128 --optim_config.lr=2e-5 --min_lr=2e-6 --save_n_min_before_job_done=2 --world_batch_size=16384 --micro_batch_size=128 --validate_at_end=false --save_last_step=true --finetune_checkpoint="/XXXX-30/XXXX-29/XXXX-31/scratch/XXXX-22/output/v4_pythia-160m-retr-32k_w_meta_mb2-wb2048-grp1024_keep_368k_negs_128N_truncate_normal/combined_ckpts/step-00060000_ckpt.pth"' \
#         --sub_output_dir_name ft_nomic_pos_only_pt_step_60k_w_lockstep_wb_negs_16384_mean_pool_v4_pythia-160m-retr-32k_w_meta_truncate_normal_mb128-wb16384-grp1-1-8_128_16N_max_steps_14587_max_seq_512 \
#         --debug_qos
#     echo "Current iteration: $i ..."
#     wait_for_job_completion
#     echo "Iteration $i completed."
# done

# for i in {1..8}
# do
#     python launch_frontier.py \
#         --python_script="train_retrieval_w_anticausal.py" \
#         --rccl_installdir="${WRKSPC}/aws-ofi-rccl_25_62_retrieval" \
#         --environment="${WRKSPC}/frontier_conda_25_62_retrieval" \
#         --budget_minutes=118 \
#         --nodes 16 \
#         --output_dir=/XXXX-30/XXXX-29/XXXX-31/scratch/XXXX-22/output \
#         --config launch_configs/retrieval/base_nomic_positive_only_finetune_lockstep.json \
#         --run_name ft_nomic_pos_only_pt_step_60k_w_lockstep_wb_negs_16384_lasttoken_pool_v4_pythia-160m-retr-32k_w_meta_truncate_normal_mb128-wb16384-grp1-1-8_128_16N_max_steps_14587_max_seq_512 \
#         --extra_args='--lockstep_sampling=world_batch --data_telemetry=10 --max_seq_len=512 --fabric_strategy="axonn_tp" --attn_impl=rocm --fabric.depth_tensor_parallel_size=8 --batch_prefix_and_suffix=true --model_name=pythia-160m-retr-32k_w_meta --pad_to_block_size=true --mean_pooling=false --max_steps=14587 --negatives_cross_device=true --save_step_interval=500 --eval_step_interval=20000 --negatives_cross_device_group_size=128 --optim_config.lr=2e-5 --min_lr=2e-6 --save_n_min_before_job_done=2 --world_batch_size=16384 --micro_batch_size=128 --validate_at_end=false --save_last_step=true --finetune_checkpoint="/XXXX-30/XXXX-29/XXXX-31/scratch/XXXX-22/output/v4_pythia-160m-retr-32k_w_meta_mb2-wb2048-grp1024_keep_368k_negs_128N_truncate_normal/combined_ckpts/step-00060000_ckpt.pth"' \
#         --sub_output_dir_name ft_nomic_pos_only_pt_step_60k_w_lockstep_wb_negs_16384_lasttoken_pool_v4_pythia-160m-retr-32k_w_meta_truncate_normal_mb128-wb16384-grp1-1-8_128_16N_max_steps_14587_max_seq_512 \
#         --debug_qos
#     echo "Current iteration: $i ..."
#     wait_for_job_completion
#     echo "Iteration $i completed."
# done

# python /XXXX-30/XXXX-29/XXXX-31/scratch/XXXX-22/lit-gpt-dev_new/launch_frontier.py \
#     --python_script="/XXXX-30/XXXX-29/XXXX-31/scratch/XXXX-22/lit-gpt-dev_new/train_retrieval_w_anticausal.py" \
#     --rccl_installdir="${WRKSPC}/aws-ofi-rccl_25_62_retrieval" \
#     --environment="${WRKSPC}/frontier_conda_25_62_retrieval" \
#     --budget_minutes=118 \
#     --nodes 16 \
#     --output_dir=/XXXX-30/XXXX-29/XXXX-31/scratch/XXXX-22/output \
#     --config /XXXX-30/XXXX-29/XXXX-31/scratch/XXXX-22/lit-gpt-dev_new/launch_configs/retrieval/base_nomic_finetune_lockstep.json \
#     --run_name axonn_nomic_finetune_phase3_pt_step_36k_fineweb_100b_lr_2e-5_w_lockstep_wb_tgrp_8_negs_4096_mean_pool_v4_pythia-160m-retr-32k_w_meta_truncate_normal_mb16-wb2048-grp1-1-8_32_16N_max_steps_2484_max_seq_512 \
#     --extra_args='--lockstep_sampling=world_batch --data_telemetry=10 --max_seq_len=512 --fabric_strategy="axonn_tp" --attn_impl=rocm --fabric.depth_tensor_parallel_size=8 --batch_prefix_and_suffix=true --model_name=pythia-160m-retr-32k_w_meta --pad_to_block_size=true --train_group_size=8 --mean_pooling=true --max_steps=2484 --negatives_cross_device=true --save_step_interval=500 --eval_step_interval=20000 --negatives_cross_device_group_size=32 --optim_config.lr=2e-5 --min_lr=2e-6 --save_n_min_before_job_done=2 --world_batch_size=2048 --micro_batch_size=16 --validate_at_end=false --save_last_step=true --finetune_checkpoint="/XXXX-30/XXXX-29/XXXX-31/scratch/XXXX-22/output/ft_nomic_pos_only_pt_step_36k_fineweb_100b_w_lockstep_wb_negs_16384_mean_pool_v4_pythia-160m-retr-32k_w_meta_mb128-wb16384-grp1-1-8_128_16N_max_steps_14587_max_seq_512/combined_ckpts/step-00014586_ckpt.pth"' \
#     --sub_output_dir_name axonn_nomic_finetune_phase3_pt_step_36k_fineweb_100b_lr_2e-5_w_lockstep_wb_tgrp_8_negs_4096_mean_pool_v4_pythia-160m-retr-32k_w_meta_truncate_normal_mb16-wb2048-grp1-1-8_32_16N_max_steps_2484_max_seq_512 \
#     --disable_net_gdr \
#     --debug_qos
# echo "Running axonn_nomic_finetune_phase3_pt_step_36k_fineweb_100b"
# wait_for_job_completion

# python /XXXX-30/XXXX-29/XXXX-31/scratch/XXXX-22/lit-gpt-dev_new/launch_frontier.py \
#     --python_script="/XXXX-30/XXXX-29/XXXX-31/scratch/XXXX-22/lit-gpt-dev_new/train_retrieval_w_anticausal.py" \
#     --rccl_installdir="${WRKSPC}/aws-ofi-rccl_25_62_retrieval" \
#     --environment="${WRKSPC}/frontier_conda_25_62_retrieval" \
#     --budget_minutes=118 \
#     --nodes 16 \
#     --output_dir=/XXXX-30/XXXX-29/XXXX-31/scratch/XXXX-22/output \
#     --config /XXXX-30/XXXX-29/XXXX-31/scratch/XXXX-22/lit-gpt-dev_new/launch_configs/retrieval/base_nomic_finetune_lockstep.json \
#     --run_name axonn_nomic_finetune_phase3_pt_step_36k_fineweb_100b_no_combine_lr_2e-5_w_lockstep_wb_tgrp_8_negs_4096_mean_pool_v4_pythia-160m-retr-32k_w_meta_truncate_normal_mb16-wb2048-grp1-1-8_32_16N_max_steps_2484_max_seq_512 \
#     --extra_args='--lockstep_sampling=world_batch --data_telemetry=10 --max_seq_len=512 --fabric_strategy="axonn_tp" --attn_impl=rocm --fabric.depth_tensor_parallel_size=8 --batch_prefix_and_suffix=true --model_name=pythia-160m-retr-32k_w_meta --pad_to_block_size=true --train_group_size=8 --mean_pooling=true --max_steps=2484 --negatives_cross_device=true --save_step_interval=500 --eval_step_interval=20000 --negatives_cross_device_group_size=32 --optim_config.lr=2e-5 --min_lr=2e-6 --save_n_min_before_job_done=2 --world_batch_size=2048 --micro_batch_size=16 --validate_at_end=false --save_last_step=true --axonn_ckpt="/XXXX-30/XXXX-29/XXXX-31/scratch/XXXX-22/output/ft_nomic_pos_only_pt_step_36k_fineweb_100b_w_lockstep_wb_negs_16384_mean_pool_v4_pythia-160m-retr-32k_w_meta_mb128-wb16384-grp1-1-8_128_16N_max_steps_14587_max_seq_512/checkpoints-AxonnStrategy"' \
#     --sub_output_dir_name axonn_nomic_finetune_phase3_pt_step_36k_fineweb_100b_no_combine_lr_2e-5_w_lockstep_wb_tgrp_8_negs_4096_mean_pool_v4_pythia-160m-retr-32k_w_meta_truncate_normal_mb16-wb2048-grp1-1-8_32_16N_max_steps_2484_max_seq_512 \
#     --disable_net_gdr \
#     --debug_qos
# echo "Running axonn_nomic_finetune_phase3_pt_step_36k_fineweb_100b_no_combine_lr_2e-5"
# wait_for_job_completion

# python launch_frontier.py \
#     --python_script="train_retrieval_w_anticausal.py" \
#     --rccl_installdir="${WRKSPC}/aws-ofi-rccl_25_62_retrieval" \
#     --environment="${WRKSPC}/frontier_conda_25_62_retrieval" \
#     --budget_minutes=119 \
#     --nodes 16 \
#     --output_dir=/XXXX-30/XXXX-29/XXXX-31/scratch/XXXX-22/output \
#     --config launch_configs/retrieval/base_nomic_positive_only_finetune_lockstep.json \
#     --run_name ft_nomic_pos_only_pt_step_36k_fineweb_100b_w_lockstep_wb_negs_16384_mean_pool_v4_pythia-160m-retr-32k_w_meta_mb128-wb16384-grp1-1-8_128_16N_max_steps_14587_max_seq_512 \
#     --extra_args='--lockstep_sampling=world_batch --data_telemetry=10 --max_seq_len=512 --fabric_strategy="axonn_tp" --attn_impl=rocm --fabric.depth_tensor_parallel_size=8 --batch_prefix_and_suffix=true --model_name=pythia-160m-retr-32k_w_meta --pad_to_block_size=true --mean_pooling=true --max_steps=14587 --negatives_cross_device=true --save_step_interval=500 --eval_step_interval=20000 --negatives_cross_device_group_size=128 --optim_config.lr=2e-5 --min_lr=2e-6 --save_n_min_before_job_done=2 --world_batch_size=16384 --micro_batch_size=128 --validate_at_end=false --save_last_step=true --finetune_checkpoint="/XXXX-30/XXXX-29/XXXX-31/scratch/XXXX-22/output/v4_fineweb_100b_pythia-160m-retr-32k_w_meta_mb2-wb2048-grp1024_keep_368k_negs_128N_truncate_normal/combined_ckpts/step-00036000_ckpt.pth"' \
#     --sub_output_dir_name ft_nomic_pos_only_pt_step_36k_fineweb_100b_w_lockstep_wb_negs_16384_mean_pool_v4_pythia-160m-retr-32k_w_meta_mb128-wb16384-grp1-1-8_128_16N_max_steps_14587_max_seq_512 \
#     --disable_net_gdr \
#     --extended_partition \
#     --repetitions=8

# python /XXXX-30/XXXX-29/XXXX-31/scratch/XXXX-22/lit-gpt-dev_new/launch_frontier.py \
#     --python_script="/XXXX-30/XXXX-29/XXXX-31/scratch/XXXX-22/lit-gpt-dev_new/train_retrieval_w_anticausal.py" \
#     --rccl_installdir="${WRKSPC}/aws-ofi-rccl_25_62_retrieval" \
#     --environment="${WRKSPC}/frontier_conda_25_62_retrieval" \
#     --budget_minutes=119 \
#     --nodes 16 \
#     --output_dir=/XXXX-30/XXXX-29/XXXX-31/scratch/XXXX-22/output \
#     --config /XXXX-30/XXXX-29/XXXX-31/scratch/XXXX-22/lit-gpt-dev_new/launch_configs/retrieval/base_nomic_finetune_lockstep.json \
#     --run_name axonn_nomic_finetune_phase3_pt_step_20k_no_combine_lr_2e-5_w_lockstep_wb_tgrp_8_negs_4096_mean_pool_v4_pythia-160m-retr-32k_w_meta_truncate_normal_mb16-wb2048-grp1-1-8_32_16N_max_steps_2484_max_seq_512 \
#     --extra_args='--lockstep_sampling=world_batch --data_telemetry=10 --max_seq_len=512 --fabric_strategy="axonn_tp" --attn_impl=rocm --fabric.depth_tensor_parallel_size=8 --batch_prefix_and_suffix=true --model_name=pythia-160m-retr-32k_w_meta --pad_to_block_size=true --train_group_size=8 --mean_pooling=true --max_steps=2484 --negatives_cross_device=true --save_step_interval=500 --eval_step_interval=20000 --negatives_cross_device_group_size=32 --optim_config.lr=2e-5 --min_lr=2e-6 --save_n_min_before_job_done=2 --world_batch_size=2048 --micro_batch_size=16 --validate_at_end=false --save_last_step=true --axonn_ckpt="/XXXX-30/XXXX-29/XXXX-31/scratch/XXXX-22/output/ft_nomic_pos_only_pt_step_20k_w_lockstep_wb_negs_16384_mean_pool_v4_pythia-160m-retr-32k_w_meta_mb128-wb16384-grp1-1-8_128_16N_max_steps_14587_max_seq_512/checkpoints-AxonnStrategy"' \
#     --sub_output_dir_name axonn_nomic_finetune_phase3_pt_step_20k_no_combine_lr_2e-5_w_lockstep_wb_tgrp_8_negs_4096_mean_pool_v4_pythia-160m-retr-32k_w_meta_truncate_normal_mb16-wb2048-grp1-1-8_32_16N_max_steps_2484_max_seq_512 \
#     --disable_net_gdr \
#     --extended_partition \

# python launch_frontier.py \
#     --python_script="train_retrieval_w_anticausal.py" \
#     --rccl_installdir="${WRKSPC}/aws-ofi-rccl_25_62_retrieval" \
#     --environment="${WRKSPC}/frontier_conda_25_62_retrieval" \
#     --budget_minutes=119 \
#     --nodes 16 \
#     --output_dir=/XXXX-30/XXXX-29/XXXX-31/scratch/XXXX-22/output \
#     --config launch_configs/retrieval/base_nomic_positive_only_finetune_lockstep.json \
#     --run_name ft_nomic_pos_only_pt_step_20k_symmetric_meta_tok_w_lockstep_wb_negs_16384_mean_pool_v4_pythia-160m-retr-32k_w_meta_mb128-wb16384-grp1-1-8_128_16N_max_steps_14587_max_seq_512 \
#     --extra_args='--train_symmetric=true --query_tok_id=32002 --lockstep_sampling=world_batch --data_telemetry=10 --max_seq_len=512 --fabric_strategy="axonn_tp" --attn_impl=rocm --fabric.depth_tensor_parallel_size=8 --batch_prefix_and_suffix=true --model_name=pythia-160m-retr-32k_w_meta --pad_to_block_size=true --mean_pooling=true --max_steps=14587 --negatives_cross_device=true --save_step_interval=500 --eval_step_interval=20000 --negatives_cross_device_group_size=128 --optim_config.lr=2e-5 --min_lr=2e-6 --save_n_min_before_job_done=2 --world_batch_size=16384 --micro_batch_size=128 --validate_at_end=false --save_last_step=true --finetune_checkpoint="/XXXX-30/XXXX-29/XXXX-31/scratch/XXXX-22/output/v4_pythia-160m-retr-32k_w_meta_mb2-wb2048-grp1024_keep_368k_negs_128N_truncate_normal/combined_ckpts/step-00020000_ckpt.pth" --data_config="/XXXX-30/XXXX-29/XXXX-31/scratch/XXXX-22/lit-gpt-dev_new/launch_configs/retrieval/nomic_positive_only_symmetric_meta_tok_src_separated_norm_weighted.json"' \
#     --sub_output_dir_name ft_nomic_pos_only_pt_step_20k_symmetric_meta_tok_w_lockstep_wb_negs_16384_mean_pool_v4_pythia-160m-retr-32k_w_meta_mb128-wb16384-grp1-1-8_128_16N_max_steps_14587_max_seq_512 \
#     --disable_net_gdr \
#     --extended_partition \
#     --repetitions=8

# python /XXXX-30/XXXX-29/XXXX-31/scratch/XXXX-22/lit-gpt-dev_new/train_retrieval_w_anticausal.py \
#     --out_dir=/XXXX-30/XXXX-29/XXXX-31/scratch/XXXX-22/interactive \
#     --config /XXXX-30/XXXX-29/XXXX-31/scratch/XXXX-22/lit-gpt-dev_new/launch_configs/retrieval/base_nomic_finetune_lockstep.json \
#     --run_name test \
#     --train_symmetric=true --query_tok_id=32002 --lockstep_sampling=world_batch --data_config="/XXXX-30/XXXX-29/XXXX-31/scratch/XXXX-22/lit-gpt-dev_new/launch_configs/retrieval/nomic_supervised_symmetric_meta_tok_src_separated_norm_weighted.json" --max_seq_len=512 --fabric_strategy="axonn_tp" --attn_impl=rocm --fabric.depth_tensor_parallel_size=1 --batch_prefix_and_suffix=true --model_name=pythia-160m-retr-32k_w_meta --pad_to_block_size=true --train_group_size=8 --mean_pooling=true --max_steps=2062 --negatives_cross_device=true --negatives_cross_device_group_size=1 --save_step_interval=500 --eval_step_interval=20000 --optim_config.lr=5e-5 --min_lr=5e-6 --save_n_min_before_job_done=2 --world_batch_size=16 --micro_batch_size=16 --validate_at_end=false --save_last_step=true --finetune_checkpoint="/XXXX-30/XXXX-29/XXXX-31/scratch/XXXX-22/output/ft_nomic_pos_only_pt_step_125k_w_lockstep_wb_negs_16384_mean_pool_v3_pythia-160m-retr-32k_w_meta_mb128-wb16384-grp1-1-8_128_16N_max_steps_14587_max_seq_512/combined_ckpts/step-00014586_ckpt.pth"
# pythonAll /XXXX-30/XXXX-29/XXXX-31/scratch/XXXX-22/lit-gpt-dev_new/train_retrieval_w_anticausal.py \
#     --out_dir=/XXXX-30/XXXX-29/XXXX-31/scratch/XXXX-22/interactive \
#     --config /XXXX-30/XXXX-29/XXXX-31/scratch/XXXX-22/lit-gpt-dev_new/launch_configs/retrieval/base_nomic_finetune_lockstep.json \
#     --run_name test \
#     --train_symmetric=true --query_tok_id=32002 --lockstep_sampling=world_batch --data_config="/XXXX-30/XXXX-29/XXXX-31/scratch/XXXX-22/lit-gpt-dev_new/launch_configs/retrieval/nomic_supervised_symmetric_meta_tok_src_separated_norm_weighted.json" --max_seq_len=512 --fabric_strategy="axonn_tp" --attn_impl=rocm --fabric.depth_tensor_parallel_size=8 --batch_prefix_and_suffix=true --model_name=pythia-160m-retr-32k_w_meta --pad_to_block_size=true --train_group_size=8 --mean_pooling=true --max_steps=2062 --negatives_cross_device=true --negatives_cross_device_group_size=1 --save_step_interval=500 --eval_step_interval=20000 --optim_config.lr=5e-5 --min_lr=5e-6 --save_n_min_before_job_done=2 --world_batch_size=128 --micro_batch_size=16 --validate_at_end=false --save_last_step=true --axonn_ckpt="/XXXX-30/XXXX-29/XXXX-31/scratch/XXXX-22/output/ft_nomic_pos_only_pt_step_125k_w_lockstep_wb_negs_16384_mean_pool_v3_pythia-160m-retr-32k_w_meta_mb128-wb16384-grp1-1-8_128_16N_max_steps_14587_max_seq_512/combined_ckpts/step-00014586_ckpt.pth"


# python /XXXX-30/XXXX-29/XXXX-31/scratch/XXXX-22/lit-gpt-dev_new/launch_frontier.py \
#     --python_script="/XXXX-30/XXXX-29/XXXX-31/scratch/XXXX-22/lit-gpt-dev_new/train_retrieval_w_anticausal.py" \
#     --rccl_installdir="${WRKSPC}/aws-ofi-rccl_25_62_retrieval" \
#     --environment="${WRKSPC}/frontier_conda_25_62_retrieval" \
#     --budget_minutes=119 \
#     --nodes 16 \
#     --output_dir=/XXXX-30/XXXX-29/XXXX-31/scratch/XXXX-22/output \
#     --config /XXXX-30/XXXX-29/XXXX-31/scratch/XXXX-22/lit-gpt-dev_new/launch_configs/retrieval/base_nomic_finetune_lockstep.json \
#     --run_name axonn_nomic_finetune_phase3_pt_step_20k_lr_2e-5_w_lockstep_wb_tgrp_8_negs_4096_mean_pool_v4_pythia-160m-retr-32k_w_meta_truncate_normal_mb16-wb2048-grp1-1-8_32_16N_max_steps_2484_max_seq_512 \
#     --extra_args='--lockstep_sampling=world_batch --data_telemetry=10 --max_seq_len=512 --fabric_strategy="axonn_tp" --attn_impl=rocm --fabric.depth_tensor_parallel_size=8 --batch_prefix_and_suffix=true --model_name=pythia-160m-retr-32k_w_meta --pad_to_block_size=true --train_group_size=8 --mean_pooling=true --max_steps=2484 --negatives_cross_device=true --save_step_interval=500 --eval_step_interval=20000 --negatives_cross_device_group_size=32 --optim_config.lr=2e-5 --min_lr=2e-6 --save_n_min_before_job_done=2 --world_batch_size=2048 --micro_batch_size=16 --validate_at_end=false --save_last_step=true --finetune_checkpoint="/XXXX-30/XXXX-29/XXXX-31/scratch/XXXX-22/output/ft_nomic_pos_only_pt_step_20k_w_lockstep_wb_negs_16384_mean_pool_v4_pythia-160m-retr-32k_w_meta_mb128-wb16384-grp1-1-8_128_16N_max_steps_14587_max_seq_512/combined_ckpts/step-00014586_ckpt.pth"' \
#     --sub_output_dir_name axonn_nomic_finetune_phase3_pt_step_20k_lr_2e-5_w_lockstep_wb_tgrp_8_negs_4096_mean_pool_v4_pythia-160m-retr-32k_w_meta_truncate_normal_mb16-wb2048-grp1-1-8_32_16N_max_steps_2484_max_seq_512 \
#     --disable_net_gdr \
#     --extended_partition \

# wait_for_job_completion
# for i in {1..50}
# do
#     python launch_frontier.py \
#         --rccl_installdir="${WRKSPC}/aws-ofi-rccl_25_62_retrieval.tar.gz" \
#         --environment="${WRKSPC}/frontier_conda_25_62_retrieval.tar.gz" \
#         --python_script=train_retrieval_w_anticausal.py \
#         --config=launch_configs/retrieval/base_optim_longwu_highlr_cos_fineweb_100B.json \
#         --run_name=v4_fineweb_100b_pythia-160m-retr-32k_w_meta_mb2-wb2048-grp1024_keep_368k_negs_128N_truncate_normal \
#         --output_dir=/XXXX-30/XXXX-29/XXXX-31/scratch/XXXX-22/output \
#         --sub_output_dir_name=v4_fineweb_100b_pythia-160m-retr-32k_w_meta_mb2-wb2048-grp1024_keep_368k_negs_128N_truncate_normal \
#         --budget_minutes=58 \
#         --nodes=128 \
#         --extra_args="--keep_k_cross_device_negatives=368640 --length_shortcut_ablation=truncate_lens_100_normal  --micro_batch_size=2 --world_batch_size=2048 --negatives_cross_device_group_size=1024 --target_token_count=100e9 --max_tokens=null --max_steps=72113 --warmup_steps=4000 --optim_config.lr=2e-3 --min_lr=2e-4 --fabric_strategy=axonn_tp --attn_impl=rocm --fabric.depth_tensor_parallel_size=8 --save_n_min_before_job_done=5 --wandb_tags='[prod,fineweb_100b,160m,v3,25_62_env]'" \
#         --debug_qos
#     echo "Current iteration: $i ..."
#     wait_for_job_completion

#     echo "Iteration $i completed."
# done

# python launch_frontier.py \
#     --python_script="train_retrieval_w_anticausal.py" \
#     --rccl_installdir="${WRKSPC}/aws-ofi-rccl_25_62_retrieval" \
#     --environment="${WRKSPC}/frontier_conda_25_62_retrieval" \
#     --budget_minutes=119 \
#     --nodes 16 \
#     --output_dir=/XXXX-30/XXXX-29/XXXX-31/scratch/XXXX-22/output \
#     --config launch_configs/retrieval/base_nomic_positive_only_finetune_lockstep.json \
#     --run_name ft_nomic_pos_only_all_split_gold_split_max_seq_512_truncate_normal_w_lockstep_wb_negs_16384_mean_pool_v4_pythia-160m-retr-32k_w_meta_mb128-wb16384-grp1-1-8_128_16N_max_steps_14587_max_seq_512 \
#     --extra_args='--lockstep_sampling=world_batch --data_telemetry=10 --max_seq_len=512 --fabric_strategy="axonn_tp" --attn_impl=rocm --fabric.depth_tensor_parallel_size=8 --batch_prefix_and_suffix=true --model_name=pythia-160m-retr-32k_w_meta --pad_to_block_size=true --mean_pooling=true --max_steps=14587 --negatives_cross_device=true --save_step_interval=500 --eval_step_interval=20000 --negatives_cross_device_group_size=128 --optim_config.lr=2e-5 --min_lr=2e-6 --save_n_min_before_job_done=2 --world_batch_size=16384 --micro_batch_size=128 --validate_at_end=false --save_last_step=true --finetune_checkpoint="/XXXX-30/XXXX-29/XXXX-31/scratch/XXXX-22/output/pt_nomic_pos_only_all_split_max_seq_512_truncate_normal_negs_131k_v3_pythia-160m-retr-32k_w_meta_mb32-wb16384-grp1-1-8_8_64N_max_steps_14587/combined_ckpts/step-00014586_ckpt.pth"' \
#     --sub_output_dir_name ft_nomic_pos_only_all_split_gold_split_max_seq_512_truncate_normal_w_lockstep_wb_negs_16384_mean_pool_v4_pythia-160m-retr-32k_w_meta_mb128-wb16384-grp1-1-8_128_16N_max_steps_14587_max_seq_512 \
#     --repetitions=8

# for i in {1..3}
# do
#     python launch_frontier.py \
#         --python_script="train_retrieval_w_anticausal.py" \
#         --rccl_installdir="${WRKSPC}/aws-ofi-rccl_25_62_retrieval" \
#         --environment="${WRKSPC}/frontier_conda_25_62_retrieval" \
#         --budget_minutes=58 \
#         --nodes 64 \
#         --output_dir=/XXXX-30/XXXX-29/XXXX-31/scratch/XXXX-22/output \
#         --config launch_configs/retrieval/base_optim_longwu_highlr_cos_nomic_pos_only.json \
#         --run_name pt_nomic_pos_only_all_split_max_seq_512_truncate_normal_negs_131k_v3_pythia-160m-retr-32k_w_meta_mb32-wb16384-grp1-1-8_8_64N_max_steps_14587 \
#         --extra_args='--keep_k_cross_device_negatives=131072 --length_shortcut_ablation=truncate_lens_100_normal --lockstep_sampling=world_batch --data_telemetry=10 --max_seq_len=512 --fabric_strategy="axonn_tp" --attn_impl=rocm --fabric.depth_tensor_parallel_size=8 --batch_prefix_and_suffix=true --model_name=pythia-160m-retr-32k_w_meta --max_steps=14587 --negatives_cross_device=true --save_step_interval=500 --eval_step_interval=20000 --negatives_cross_device_group_size=8 --warmup_steps=700 --optim_config.lr=2e-5 --min_lr=2e-6 --save_n_min_before_job_done=2 --world_batch_size=16384 --micro_batch_size=32 --validate_at_end=false --save_last_step=true --wandb_tags="[pt-phase-2,160m,v3,25_62_env]" --finetune_checkpoint=null' \
#         --sub_output_dir_name pt_nomic_pos_only_all_split_max_seq_512_truncate_normal_negs_131k_v3_pythia-160m-retr-32k_w_meta_mb32-wb16384-grp1-1-8_8_64N_max_steps_14587 \
#         --disable_net_gdr \
#         --debug_qos 
#         # \
#         # --extended_partition
#     echo "Current iteration: $i ..."
#     wait_for_job_completion
#     echo "Iteration $i completed."
# done

# python /XXXX-30/XXXX-29/XXXX-31/scratch/XXXX-22/lit-gpt-dev_mlm/launch_frontier.py \
#     --python_script="/XXXX-30/XXXX-29/XXXX-31/scratch/XXXX-22/lit-gpt-dev_mlm/finetune_lm_retrieval.py" \
#     --rccl_installdir="${WRKSPC}/aws-ofi-rccl_25_62_retrieval" \
#     --environment="${WRKSPC}/frontier_conda_25_62_retrieval" \
#     --budget_minutes=118 \
#     --nodes 16 \
#     --output_dir=/XXXX-30/XXXX-29/XXXX-31/scratch/XXXX-22/output \
#     --config /XXXX-30/XXXX-29/XXXX-31/scratch/XXXX-22/lit-gpt-dev_mlm/launch_configs/retrieval/base_nomic_finetune_lockstep.json \
#     --run_name axonn_nomic_finetune_mlm_w_lockstep_world_batch_train_grp_8_negs_4096_mean_pool_pythia-160m_mb16-wb2048-grp1-1-8_32_16N_max_steps_2484_max_seq_512 \
#     --extra_args='--mlm=true --lockstep_sampling=world_batch --data_telemetry=10 --max_seq_len=512 --fabric_strategy="axonn_tp" --attn_impl=sdpa --fabric.depth_tensor_parallel_size=8 --batch_prefix_and_suffix=true --model_name=pythia-160m --pad_to_block_size=true --train_group_size=8 --mean_pooling=true --max_steps=2484 --negatives_cross_device=true --negatives_cross_device_group_size=32 --save_step_interval=500 --eval_step_interval=20000 --optim_config.lr=5e-5 --min_lr=5e-6 --save_n_min_before_job_done=2 --world_batch_size=2048 --micro_batch_size=16 --validate_at_end=false --save_last_step=true --wandb_tags="[phase-3,mlm,160m,v3,25_62_env]" --finetune_checkpoint="/XXXX-30/XXXX-29/XXXX-31/proj-shared/mlm_pretraing_run_v1_100BT/mlm_pretraing_run_v1_100BT/combined_ckpts/step-00072000_ckpt.pth"' \
#     --sub_output_dir_name axonn_nomic_finetune_mlm_w_lockstep_world_batch_train_grp_8_negs_4096_mean_pool_pythia-160m_mb16-wb2048-grp1-1-8_32_16N_max_steps_2484_max_seq_512 \
#     --disable_net_gdr \
#     --debug_qos

# pythonAll /XXXX-30/XXXX-29/XXXX-31/scratch/XXXX-22/lit-gpt-dev_mlm/finetune_lm_retrieval.py \
#     --out_dir=/XXXX-30/XXXX-29/XXXX-31/scratch/XXXX-22/interactive \
#     --config /XXXX-30/XXXX-29/XXXX-31/scratch/XXXX-22/lit-gpt-dev_mlm/launch_configs/retrieval/base_nomic_finetune_lockstep.json \
#     --run_name test \
#     --mlm=true --lockstep_sampling=world_batch --max_seq_len=512 --fabric_strategy="axonn_tp" --attn_impl=sdpa --fabric.depth_tensor_parallel_size=8 --batch_prefix_and_suffix=true --model_name=pythia-160m --pad_to_block_size=true --train_group_size=8 --mean_pooling=true --max_steps=2062 --negatives_cross_device=true --negatives_cross_device_group_size=32 --save_step_interval=500 --eval_step_interval=20000 --optim_config.lr=5e-5 --min_lr=5e-6 --save_n_min_before_job_done=2 --world_batch_size=2048 --micro_batch_size=16 --validate_at_end=false --save_last_step=true --finetune_checkpoint="/XXXX-30/XXXX-29/XXXX-31/proj-shared/mlm_pretraing_run_v1_100BT/mlm_pretraing_run_v1_100BT/combined_ckpts/step-00072000_ckpt.pth"

# wait_for_job_completion
# python /XXXX-30/XXXX-29/XXXX-31/scratch/XXXX-22/lit-gpt-dev_new/launch_frontier.py \
#     --python_script="/XXXX-30/XXXX-29/XXXX-31/scratch/XXXX-22/lit-gpt-dev_new/train_retrieval_w_anticausal.py" \
#     --rccl_installdir="${WRKSPC}/aws-ofi-rccl_25_62_retrieval" \
#     --environment="${WRKSPC}/frontier_conda_25_62_retrieval" \
#     --budget_minutes=118 \
#     --nodes 16 \
#     --config /XXXX-30/XXXX-29/XXXX-31/scratch/XXXX-22/lit-gpt-dev_new/launch_configs/retrieval/base_nomic_finetune_lockstep.json \
#     --run_name axonn_nomic_finetune_zero_pretrain_v2_w_lockstep_world_batch_train_grp_8_negs_4096_mean_pool_v3_pythia-160m-retr-32k_w_meta_mb16-wb2048-grp1-1-8_16_16N_max_steps_2484_max_seq_512 \
#     --extra_args='--lockstep_sampling=world_batch --data_telemetry=10 --max_seq_len=512 --fabric_strategy="axonn_tp" --attn_impl=rocm --fabric.depth_tensor_parallel_size=8 --batch_prefix_and_suffix=true --model_name=pythia-160m-retr-32k_w_meta --pad_to_block_size=true --train_group_size=8 --mean_pooling=true --max_steps=2484 --negatives_cross_device=true --save_step_interval=500 --eval_step_interval=20000 --negatives_cross_device_group_size=32 --optim_config.lr=5e-5 --min_lr=5e-6 --save_n_min_before_job_done=2 --world_batch_size=2048 --micro_batch_size=16 --validate_at_end=false --save_last_step=true --wandb_tags="[phase-3,160m,v3,25_62_env]" --finetune_checkpoint=null --pretrained_prefix_model=false --pretrained_suffix_model=false' \
#     --sub_output_dir_name axonn_nomic_finetune_zero_pretrain_v2_w_lockstep_world_batch_train_grp_8_negs_4096_mean_pool_v3_pythia-160m-retr-32k_w_meta_mb16-wb2048-grp1-1-8_16_16N_max_steps_2484_max_seq_512 \
#     --disable_net_gdr \
#     --debug_qos
# echo "Current iteration: axonn_nomic_finetune_zero_pretrain_v2 ..."
# wait_for_job_completion

# python /XXXX-30/XXXX-29/XXXX-31/scratch/XXXX-22/lit-gpt-dev_new/launch_frontier.py \
#     --python_script="/XXXX-30/XXXX-29/XXXX-31/scratch/XXXX-22/lit-gpt-dev_new/train_retrieval_w_anticausal.py" \
#     --rccl_installdir="${WRKSPC}/aws-ofi-rccl_25_62_retrieval" \
#     --environment="${WRKSPC}/frontier_conda_25_62_retrieval" \
#     --budget_minutes=119 \
#     --nodes 4 \
#     --output_dir=/XXXX-30/XXXX-29/XXXX-31/scratch/XXXX-22/output \
#     --config /XXXX-30/XXXX-29/XXXX-31/scratch/XXXX-22/lit-gpt-dev_new/launch_configs/retrieval/base_orca_finetune.json \
#     --run_name axonn_orca_finetune_phase3_pt_step_125k_negs_64_mean_pool_v3_pythia-160m-retr-32k_w_meta_mb64-wb2048-grp1-1-8_1_4N_max_steps_2062_max_seq_512 \
#     --extra_args='--max_seq_len=512 --fabric_strategy="axonn_tp" --attn_impl=rocm --fabric.depth_tensor_parallel_size=8 --batch_prefix_and_suffix=true --model_name=pythia-160m-retr-32k_w_meta --pad_to_block_size=true --mean_pooling=true --max_steps=2062 --negatives_cross_device=true --negatives_cross_device_group_size=1 --save_step_interval=1000 --eval_step_interval=20000 --optim_config.lr=5e-5 --min_lr=5e-6 --save_n_min_before_job_done=2 --world_batch_size=2048 --micro_batch_size=64 --validate_at_end=false --save_last_step=true --finetune_checkpoint="/XXXX-30/XXXX-29/XXXX-31/scratch/XXXX-22/output/axonn_nomic_finetune_phase3_pt_step_125k_w_lockstep_world_batch_train_grp_8_negs_4096_mean_pool_v3_pythia-160m-retr-32k_w_meta_mb16-wb2048-grp1-1-8_32_16N_max_steps_2484_max_seq_512/combined_ckpts/step-00002483_ckpt.pth"' \
#     --sub_output_dir_name axonn_orca_finetune_phase3_pt_step_125k_negs_64_mean_pool_v3_pythia-160m-retr-32k_w_meta_mb64-wb2048-grp1-1-8_1_4N_max_steps_2062_max_seq_512 \
#     --disable_net_gdr \
#     --extended_partition

# python /XXXX-30/XXXX-29/XXXX-31/scratch/XXXX-22/lit-gpt-dev_new/launch_frontier.py \
#     --python_script="/XXXX-30/XXXX-29/XXXX-31/scratch/XXXX-22/lit-gpt-dev_new/train_retrieval_w_anticausal.py" \
#     --rccl_installdir="${WRKSPC}/aws-ofi-rccl_25_62_retrieval" \
#     --environment="${WRKSPC}/frontier_conda_25_62_retrieval" \
#     --budget_minutes=118 \
#     --nodes 4 \
#     --output_dir=/XXXX-30/XXXX-29/XXXX-31/scratch/XXXX-22/output \
#     --config /XXXX-30/XXXX-29/XXXX-31/scratch/XXXX-22/lit-gpt-dev_new/launch_configs/retrieval/base_orca_finetune.json \
#     --run_name axonn_orca_finetune_phase3_pt_step_20k_negs_64_mean_pool_v3_pythia-160m-retr-32k_w_meta_mb64-wb2048-grp1-1-8_1_4N_max_steps_2062_max_seq_512 \
#     --extra_args='--max_seq_len=512 --fabric_strategy="axonn_tp" --attn_impl=rocm --fabric.depth_tensor_parallel_size=8 --batch_prefix_and_suffix=true --model_name=pythia-160m-retr-32k_w_meta --pad_to_block_size=true --mean_pooling=true --max_steps=2062 --negatives_cross_device=true --negatives_cross_device_group_size=1 --save_step_interval=1000 --eval_step_interval=20000 --optim_config.lr=5e-5 --min_lr=5e-6 --save_n_min_before_job_done=2 --world_batch_size=2048 --micro_batch_size=64 --validate_at_end=false --save_last_step=true --finetune_checkpoint="/XXXX-30/XXXX-29/XXXX-31/scratch/XXXX-22/output/axonn_nomic_finetune_phase3_w_lockstep_world_batch_train_grp_8_negs_4096_mean_pool_v3_pythia-160m-retr-32k_w_meta_mb16-wb2048-grp1-1-8_16_16N_max_steps_2484_max_seq_512/combined_ckpts/step-00002483_ckpt.pth"' \
#     --sub_output_dir_name axonn_orca_finetune_phase3_pt_step_20k_negs_64_mean_pool_v3_pythia-160m-retr-32k_w_meta_mb64-wb2048-grp1-1-8_1_4N_max_steps_2062_max_seq_512 \
#     --disable_net_gdr \
#     --extended_partition

# python /XXXX-30/XXXX-29/XXXX-31/scratch/XXXX-22/lit-gpt-dev_new/launch_frontier.py \
#     --python_script="/XXXX-30/XXXX-29/XXXX-31/scratch/XXXX-22/lit-gpt-dev_new/train_retrieval_w_anticausal.py" \
#     --rccl_installdir="${WRKSPC}/aws-ofi-rccl_25_62_retrieval" \
#     --environment="${WRKSPC}/frontier_conda_25_62_retrieval" \
#     --budget_minutes=118 \
#     --nodes 4 \
#     --output_dir=/XXXX-30/XXXX-29/XXXX-31/scratch/XXXX-22/output \
#     --config /XXXX-30/XXXX-29/XXXX-31/scratch/XXXX-22/lit-gpt-dev_new/launch_configs/retrieval/base_orca_finetune.json \
#     --run_name axonn_orca_finetune_fineweb_stack_pt_step_20k_negs_64_mean_pool_v3_pythia-160m-retr-32k_w_meta_mb64-wb2048-grp1-1-8_1_4N_max_steps_2062_max_seq_512 \
#     --extra_args='--max_seq_len=512 --fabric_strategy="axonn_tp" --attn_impl=rocm --fabric.depth_tensor_parallel_size=8 --batch_prefix_and_suffix=true --model_name=pythia-160m-retr-32k_w_meta --pad_to_block_size=true --mean_pooling=true --max_steps=2062 --negatives_cross_device=true --negatives_cross_device_group_size=1 --save_step_interval=1000 --eval_step_interval=20000 --optim_config.lr=8e-5 --min_lr=8e-6 --save_n_min_before_job_done=2 --world_batch_size=2048 --micro_batch_size=64 --validate_at_end=false --save_last_step=true --finetune_checkpoint="/XXXX-30/XXXX-29/XXXX-31/scratch/XXXX-22/output/v3_fineweb_stack_pythia-160m-retr-32k_w_meta_mb2-wb2048-grp128_128N/combined_ckpts/step-00020000_ckpt.pth"' \
#     --sub_output_dir_name axonn_orca_finetune_fineweb_stack_pt_step_20k_negs_64_mean_pool_v3_pythia-160m-retr-32k_w_meta_mb64-wb2048-grp1-1-8_1_4N_max_steps_2062_max_seq_512 \
#     --disable_net_gdr \
#     --debug_qos
# echo "Current iteration: axonn_orca_finetune_fineweb_stack_pt_step_20k ..."
# wait_for_job_completion

# pythonAll train_retrieval_w_anticausal.py \
#     --out_dir=/XXXX-30/XXXX-29/XXXX-31/scratch/XXXX-22/interactive \
#     --config /XXXX-30/XXXX-29/XXXX-31/scratch/XXXX-22/lit-gpt-dev_new/launch_configs/retrieval/base_orca_finetune.json \
#     --run_name test \
#     --max_seq_len=512 --fabric_strategy="axonn_tp" --attn_impl=rocm --fabric.depth_tensor_parallel_size=1 --batch_prefix_and_suffix=true --model_name=pythia-160m-retr-32k_w_meta --pad_to_block_size=true --mean_pooling=true --max_steps=2062 --negatives_cross_device=true --negatives_cross_device_group_size=1 --save_step_interval=1000 --eval_step_interval=20000 --optim_config.lr=8e-5 --min_lr=8e-6 --save_n_min_before_job_done=2 --world_batch_size=64 --micro_batch_size=64 --validate_at_end=false --save_last_step=true --finetune_checkpoint="/XXXX-30/XXXX-29/XXXX-31/scratch/XXXX-22/output/v3_fineweb_stack_pythia-160m-retr-32k_w_meta_mb2-wb2048-grp128_128N/combined_ckpts/step-00020000_ckpt.pth"

# python /XXXX-30/XXXX-29/XXXX-31/scratch/XXXX-22/lit-gpt-dev_new/launch_frontier.py \
#     --python_script="/XXXX-30/XXXX-29/XXXX-31/scratch/XXXX-22/lit-gpt-dev_new/train_retrieval_w_anticausal.py" \
#     --rccl_installdir="${WRKSPC}/aws-ofi-rccl_25_62_retrieval" \
#     --environment="${WRKSPC}/frontier_conda_25_62_retrieval" \
#     --budget_minutes=118 \
#     --nodes 16 \
#     --config /XXXX-30/XXXX-29/XXXX-31/scratch/XXXX-22/lit-gpt-dev_new/launch_configs/retrieval/base_nomic_finetune_lockstep.json \
#     --run_name axonn_nomic_finetune_phase3_pt_step_125k_w_lockstep_world_batch_train_grp_8_negs_4096_mean_pool_v3_pythia-160m-retr-32k_w_meta_mb16-wb2048-grp1-1-8_32_16N_max_steps_2484_max_seq_512 \
#     --extra_args='--lockstep_sampling=world_batch --data_telemetry=10 --max_seq_len=512 --fabric_strategy="axonn_tp" --attn_impl=rocm --fabric.depth_tensor_parallel_size=8 --batch_prefix_and_suffix=true --model_name=pythia-160m-retr-32k_w_meta --pad_to_block_size=true --train_group_size=8 --mean_pooling=true --max_steps=2484 --negatives_cross_device=true --save_step_interval=500 --eval_step_interval=20000 --negatives_cross_device_group_size=32 --optim_config.lr=5e-5 --min_lr=5e-6 --save_n_min_before_job_done=2 --world_batch_size=2048 --micro_batch_size=16 --validate_at_end=false --save_last_step=true --finetune_checkpoint="/XXXX-30/XXXX-29/XXXX-31/scratch/XXXX-22/output/ft_nomic_pos_only_pt_step_125k_w_lockstep_wb_negs_16384_mean_pool_v3_pythia-160m-retr-32k_w_meta_mb128-wb16384-grp1-1-8_128_16N_max_steps_14587_max_seq_512/combined_ckpts/step-00014586_ckpt.pth"' \
#     --sub_output_dir_name axonn_nomic_finetune_phase3_pt_step_125k_w_lockstep_world_batch_train_grp_8_negs_4096_mean_pool_v3_pythia-160m-retr-32k_w_meta_mb16-wb2048-grp1-1-8_32_16N_max_steps_2484_max_seq_512 \
#     --disable_net_gdr \
#     --debug_qos
# echo "Current iteration: axonn_nomic_finetune_phase3_pt_step_125k ..."
# wait_for_job_completion
# wait_for_job_completion
# python /XXXX-30/XXXX-29/XXXX-31/scratch/XXXX-22/lit-gpt-dev_new/launch_frontier.py \
#     --python_script="/XXXX-30/XXXX-29/XXXX-31/scratch/XXXX-22/lit-gpt-dev_new/train_retrieval_w_anticausal.py" \
#     --rccl_installdir="${WRKSPC}/aws-ofi-rccl_25_62_retrieval" \
#     --environment="${WRKSPC}/frontier_conda_25_62_retrieval" \
#     --budget_minutes=118 \
#     --nodes 16 \
#     --config /XXXX-30/XXXX-29/XXXX-31/scratch/XXXX-22/lit-gpt-dev_new/launch_configs/retrieval/base_nomic_finetune_lockstep.json \
#     --run_name axonn_nomic_finetune_phase3_fineweb_stack_w_lockstep_world_batch_train_grp_8_negs_4096_mean_pool_v3_pythia-160m-retr-32k_w_meta_mb16-wb2048-grp1-1-8_32_16N_max_steps_2484_max_seq_512 \
#     --extra_args='--lockstep_sampling=world_batch --data_telemetry=10 --max_seq_len=512 --fabric_strategy="axonn_tp" --attn_impl=rocm --fabric.depth_tensor_parallel_size=8 --batch_prefix_and_suffix=true --model_name=pythia-160m-retr-32k_w_meta --pad_to_block_size=true --train_group_size=8 --mean_pooling=true --max_steps=2484 --negatives_cross_device=true --save_step_interval=500 --eval_step_interval=20000 --negatives_cross_device_group_size=32 --optim_config.lr=5e-5 --min_lr=5e-6 --save_n_min_before_job_done=2 --world_batch_size=2048 --micro_batch_size=16 --validate_at_end=false --save_last_step=true --finetune_checkpoint="/XXXX-30/XXXX-29/XXXX-31/scratch/XXXX-22/output/ft_nomic_pos_only_fineweb_stack_w_lockstep_wb_negs_16384_mean_pool_v3_pythia-160m-retr-32k_w_meta_mb128-wb16384-grp1-1-8_128_16N_max_steps_14587_max_seq_512/combined_ckpts/step-00014586_ckpt.pth"' \
#     --sub_output_dir_name axonn_nomic_finetune_phase3_fineweb_stack_w_lockstep_world_batch_train_grp_8_negs_4096_mean_pool_v3_pythia-160m-retr-32k_w_meta_mb16-wb2048-grp1-1-8_32_16N_max_steps_2484_max_seq_512 \
#     --disable_net_gdr \
#     --debug_qos
# echo "Current iteration: axonn_nomic_finetune_phase3_fineweb_stack ..."
# wait_for_job_completion

# for i in {1..2}
# do
#     python launch_frontier.py \
#         --python_script="train_retrieval_w_anticausal.py" \
#         --rccl_installdir="${WRKSPC}/aws-ofi-rccl_25_62_retrieval" \
#         --environment="${WRKSPC}/frontier_conda_25_62_retrieval" \
#         --budget_minutes=118 \
#         --nodes 16 \
#         --output_dir=/XXXX-30/XXXX-29/XXXX-31/scratch/XXXX-22/output \
#         --config launch_configs/retrieval/base_nomic_positive_only_finetune_lockstep.json \
#         --run_name ft_nomic_pos_only_pt_step_20k_w_lockstep_wb_negs_16384_mean_pool_v4_pythia-160m-retr-32k_w_meta_mb128-wb16384-grp1-1-8_128_16N_max_steps_14587_max_seq_512 \
#         --extra_args='--lockstep_sampling=world_batch --data_telemetry=10 --max_seq_len=512 --fabric_strategy="axonn_tp" --attn_impl=rocm --fabric.depth_tensor_parallel_size=8 --batch_prefix_and_suffix=true --model_name=pythia-160m-retr-32k_w_meta --pad_to_block_size=true --mean_pooling=true --max_steps=14587 --negatives_cross_device=true --save_step_interval=500 --eval_step_interval=20000 --negatives_cross_device_group_size=128 --optim_config.lr=2e-5 --min_lr=2e-6 --save_n_min_before_job_done=2 --world_batch_size=16384 --micro_batch_size=128 --validate_at_end=false --save_last_step=true --finetune_checkpoint="/XXXX-30/XXXX-29/XXXX-31/scratch/XXXX-22/output/v4_pythia-160m-retr-32k_w_meta_mb2-wb2048-grp1024_keep_368k_negs_128N_truncate_normal/combined_ckpts/step-00020000_ckpt.pth"' \
#         --sub_output_dir_name ft_nomic_pos_only_pt_step_20k_w_lockstep_wb_negs_16384_mean_pool_v4_pythia-160m-retr-32k_w_meta_mb128-wb16384-grp1-1-8_128_16N_max_steps_14587_max_seq_512 \
#         --disable_net_gdr \
#         --debug_qos
#     echo "Current iteration: $i ..."
#     wait_for_job_completion

    # python launch_frontier.py \
    #     --python_script="train_retrieval_w_anticausal.py" \
    #     --rccl_installdir="${WRKSPC}/aws-ofi-rccl_25_62_retrieval" \
    #     --environment="${WRKSPC}/frontier_conda_25_62_retrieval" \
    #     --budget_minutes=118 \
    #     --nodes 16 \
    #     --output_dir=/XXXX-30/XXXX-29/XXXX-31/scratch/XXXX-22/output \
    #     --config launch_configs/retrieval/base_nomic_positive_only_finetune_lockstep.json \
    #     --run_name ft_nomic_pos_only_zero_pretrain_v2_w_lockstep_wb_negs_16384_mean_pool_v3_pythia-160m-retr-32k_w_meta_mb128-wb16384-grp1-1-8_128_16N_max_steps_14587_max_seq_512 \
    #     --extra_args='--lockstep_sampling=world_batch --data_telemetry=10 --max_seq_len=512 --fabric_strategy="axonn_tp" --attn_impl=rocm --fabric.depth_tensor_parallel_size=8 --batch_prefix_and_suffix=true --model_name=pythia-160m-retr-32k_w_meta --pad_to_block_size=true --mean_pooling=true --max_steps=14587 --negatives_cross_device=true --save_step_interval=500 --eval_step_interval=20000 --negatives_cross_device_group_size=128 --optim_config.lr=2e-5 --min_lr=2e-6 --save_n_min_before_job_done=2 --world_batch_size=16384 --micro_batch_size=128 --validate_at_end=false --save_last_step=true --wandb_tags="[phase-2,160m,v3,25_62_env]" --finetune_checkpoint=null --pretrained_prefix_model=false --pretrained_suffix_model=false' \
    #     --sub_output_dir_name ft_nomic_pos_only_zero_pretrain_v2_w_lockstep_wb_negs_16384_mean_pool_v3_pythia-160m-retr-32k_w_meta_mb128-wb16384-grp1-1-8_128_16N_max_steps_14587_max_seq_512 \
    #     --disable_net_gdr \
    #     --debug_qos
    # echo "Current iteration: $i ..."
    # wait_for_job_completion

    # pythonAll train_retrieval_w_anticausal.py \
    #     --out_dir=/XXXX-30/XXXX-29/XXXX-31/scratch/XXXX-22/interactive \
    #     --config launch_configs/retrieval/base_optim_longwu_highlr_cos_nomic_pos_only.json \
    #     --run_name test \
    #     --keep_k_cross_device_negatives=131072 --lockstep_sampling=world_batch --data_telemetry=10 --max_seq_len=512 --fabric_strategy="axonn_tp" --attn_impl=rocm --fabric.depth_tensor_parallel_size=8 --batch_prefix_and_suffix=true --model_name=pythia-160m-retr-32k_w_meta --max_steps=14587 --negatives_cross_device=true --save_step_interval=500 --eval_step_interval=20000 --negatives_cross_device_group_size=8 --warmup_steps=700 --optim_config.lr=2e-5 --min_lr=2e-6 --save_n_min_before_job_done=2 --world_batch_size=4096 --micro_batch_size=32 --validate_at_end=false --save_last_step=true --wandb_tags="[pt-phase-2,160m,v3,25_62_env]" --finetune_checkpoint=null

#     echo "Iteration $i completed."
# done

# for i in {1..8}
# do
#     python launch_frontier.py \
#         --python_script="train_retrieval_w_anticausal.py" \
#         --rccl_installdir="${WRKSPC}/aws-ofi-rccl_25_62_retrieval" \
#         --environment="${WRKSPC}/frontier_conda_25_62_retrieval" \
#         --budget_minutes=118 \
#         --nodes 16 \
#         --output_dir=/XXXX-30/XXXX-29/XXXX-31/scratch/XXXX-22/output \
#         --config launch_configs/retrieval/base_nomic_positive_only_finetune_lockstep.json \
#         --run_name ft_nomic_pos_only_all_split_gold_split_lockstep_wb_negs_16384_mean_pool_v4_pythia-160m-retr-32k_w_meta_mb128-wb16384-grp1-1-8_128_16N_max_steps_14587_max_seq_512 \
#         --extra_args='--lockstep_sampling=world_batch --data_telemetry=10 --max_seq_len=512 --fabric_strategy="axonn_tp" --attn_impl=rocm --fabric.depth_tensor_parallel_size=8 --batch_prefix_and_suffix=true --model_name=pythia-160m-retr-32k_w_meta --pad_to_block_size=true --mean_pooling=true --max_steps=14587 --negatives_cross_device=true --save_step_interval=500 --eval_step_interval=20000 --negatives_cross_device_group_size=128 --optim_config.lr=2e-5 --min_lr=2e-6 --save_n_min_before_job_done=2 --world_batch_size=16384 --micro_batch_size=128 --validate_at_end=false --save_last_step=true --finetune_checkpoint="/XXXX-30/XXXX-29/XXXX-31/scratch/XXXX-22/output/pt_nomic_pos_only_all_split_negs_131k_v3_pythia-160m-retr-32k_w_meta_mb64-wb16384-grp1-1-8_8_64N_max_steps_14587_max_seq_512/combined_ckpts/step-00014586_ckpt.pth"' \
#         --sub_output_dir_name ft_nomic_pos_only_all_split_gold_split_lockstep_wb_negs_16384_mean_pool_v4_pythia-160m-retr-32k_w_meta_mb128-wb16384-grp1-1-8_128_16N_max_steps_14587_max_seq_512 \
#         --disable_net_gdr \
#         --debug_qos
#     echo "Current iteration: $i ..."
#     wait_for_job_completion

#     echo "Iteration $i completed."
# done

