# Example environment setup: point these at your own conda install and conda
# environment; the launch commands below pass both to launch_nexus.py.
# Values already set in the calling environment take precedence over these
# defaults, and ENV_PATH is derived from CONDA_INSTALL_DIR so the two cannot
# drift apart when only the install dir is changed.
export CONDA_INSTALL_DIR="${CONDA_INSTALL_DIR:-/XXXX-36/${USER}}"
export ENV_PATH="${ENV_PATH:-${CONDA_INSTALL_DIR}/pslm_25_conda}"

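# Reference only (commented out): earlier launch of the retrieval trainer
# train_retrieval_w_anticausal.py via launch_scripts/launch_nexus.py.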
# python launch_scripts/launch_nexus.py \
#    --python_script="/XXXX-36/XXXX-22/XXXX-40/train_retrieval_w_anticausal.py" \
#    --conda_install_dir=$CONDA_INSTALL_DIR \
#    --env_path=$ENV_PATH \
#    --run_name="dolma-retrieval-dual-causal-pythia-160m-worldbsz-16-ctx-rand-batch_negative_ddp_RR_lr_3e-4_debug" \
#    --sub_output_dir_name="dolma-retrieval-dual-causal-pythia-160m-worldbsz-16-ctx-rand-batch_negative_ddp_RR_lr_3e-4_debug" \
#    --config=/XXXX-36/XXXX-22/XXXX-40/launch_scripts/XXXX-22/nexus_jobs/retrieval_train_pythia.json \
#    --gpus_per_node=4 \
#    --gpu_type=rtxa5000 \
#    --budget_hours 2 \
#    --qos scavenger \
#    --account scavenger \
#    --partition scavenger \
#    --extra_args='--world_batch_size=32 --micro_batch_size=8 --fabric_strategy="ddp"'

# Pretrained-init vs from-scratch runs of train.py on pythia-160m-deduped,
# launched through launch_nexus.py. Both runs share every launcher flag and
# every training flag except the three model-initialization flags, so the
# shared pieces are defined once here to keep the two runs from drifting apart.
#
# Reads CONDA_INSTALL_DIR and ENV_PATH (exported above). Uses relative paths,
# so this assumes it is run from the directory containing launch_nexus.py and
# launch_configs/.

pythia_160m_dir=/fs/XXXX-37/llm-pretraining/models/external/pythia-160m-deduped

# Print all arguments joined by single spaces (no trailing newline). The
# subshell body keeps the IFS change from leaking into the caller.
pslm_join() ( IFS=' '; printf '%s' "$*" )

# Training flags shared by both runs, forwarded to train.py via --extra_args.
# The double quotes around axonn_tp are deliberately part of the string,
# exactly as in the original hand-written --extra_args value.
pslm_common_extra_args=$(pslm_join \
  --keep_k_cross_device_negatives=368640 \
  --length_shortcut_ablation=truncate_lens_100_normal \
  --batch_prefix_and_suffix=false \
  --warmup_steps=6000 \
  --optim_config.lr=2e-3 \
  --min_lr=2e-4 \
  --max_tokens=null \
  --max_steps=72000 \
  --world_batch_size=64 \
  --micro_batch_size=8 \
  '--fabric_strategy="axonn_tp"' \
  --attn_impl=sdpa \
  --fabric.depth_tensor_parallel_size=8 \
  --negatives_cross_device_group_size=8 \
  --model_name=pythia-160m-deduped \
  --tokenizer_path="$pythia_160m_dir")

# The pretrained run prepends the model-initialization flags to the shared set.
pslm_pretrained_extra_args=$(pslm_join \
  --pretrained_prefix_model=true \
  --pretrained_suffix_model=true \
  --model_checkpoint="$pythia_160m_dir" \
  "$pslm_common_extra_args")

# Launch train.py via launch_nexus.py on 8x rtxa6000 for 24h (scavenger
# qos/account/partition).
#   $1 - run name; also used as the output sub-directory name
#   $2 - string forwarded to train.py via --extra_args
launch_pslm() {
  python launch_nexus.py \
    --python_script="/XXXX-36/XXXX-22/XXXX-40/train.py" \
    --conda_install_dir="$CONDA_INSTALL_DIR" \
    --env_path="$ENV_PATH" \
    --run_name="$1" \
    --sub_output_dir_name="$1" \
    --config=launch_configs/base_optim_longwu_highlr_cos.json \
    --gpus_per_node=8 \
    --mem=64 \
    --gpu_type=rtxa6000 \
    --budget_hours 24 \
    --qos scavenger \
    --account scavenger \
    --partition scavenger \
    --extra_args="$2"
}

launch_pslm pslm_from_pretrained "$pslm_pretrained_extra_args"
launch_pslm pslm_from_scratch "$pslm_common_extra_args"

# Direct (non-launcher) invocation of train.py under run name "test", using
# the pretrained-init configuration with world/micro batch size 8, a
# cross-device negatives group size of 1 and tensor-parallel size 1.
# Runs in a subshell so the helper variable below does not leak into the
# rest of the script.
# NOTE(review): run_name is "test" but wandb_tags include "prod" - confirm
# this tagging is intended.
(
  pythia_dir=/fs/XXXX-37/llm-pretraining/models/external/pythia-160m-deduped
  python train.py \
    --config=launch_configs/base_optim_longwu_highlr_cos.json \
    --run_name=test \
    --out_dir=/XXXX-36/XXXX-22/output \
    --keep_k_cross_device_negatives=368640 \
    --length_shortcut_ablation=truncate_lens_100_normal \
    --batch_prefix_and_suffix=false \
    --micro_batch_size=8 \
    --world_batch_size=8 \
    --negatives_cross_device_group_size=1 \
    --max_tokens=null \
    --max_steps=131900 \
    --warmup_steps=6000 \
    --optim_config.lr=2e-3 \
    --min_lr=2e-4 \
    --fabric_strategy=axonn_tp \
    --attn_impl=sdpa \
    --fabric.depth_tensor_parallel_size=1 \
    --save_n_min_before_job_done=5 \
    --pretrained_prefix_model=true \
    --pretrained_suffix_model=true \
    --model_name=pythia-160m-deduped \
    --tokenizer_path="$pythia_dir" \
    --model_checkpoint="$pythia_dir" \
    --wandb_tags='[prod,160m,v3,25_62_env]'
)

# Reference only (commented out): equivalent launch on Frontier via launch_frontier.py.
# python /XXXX-30/XXXX-29/XXXX-31/scratch/XXXX-22/XXXX-40/launch_scripts/launch_frontier.py \
#     --python_script="/XXXX-30/XXXX-29/XXXX-31/scratch/XXXX-22/XXXX-40/pretrain_umd/train_retrieval_w_anticausal.py" \
#     --rccl_installdir="${HOME}/tiny_plugins_rccl/lib" \
#     --env_packed="${HOME}/frontier_conda_env_packed.tar.gz" \
#     --budget_hours=24 \
#     --nodes 1 \
#     --config /XXXX-30/XXXX-29/XXXX-31/scratch/XXXX-22/XXXX-40/launch_scripts/XXXX-22/frontier_jobs/retrieval_train_pythia.json \
#     --run_name cosmopedia-retrieval-dual-causal-pythia-160m-bsz-25-ctx-rand-batch_negative_ddp_RR_lr_1e-4 \
#     --extra_args='--learning_rate=1e-4 --min_lr=1e-5 --fabric_strategy="ddp"' \
#     --sub_output_dir_name cosmopedia-retrieval-dual-causal-pythia-160m-bsz-25-ctx-rand-batch_negative_ddp_RR_lr_1e-4 \
#     --launch_immediately
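#
# Reference "optim_config" JSON snippet (the active runs above override its
# lr on the command line with --optim_config.lr):
#     "optim_config":{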
#         "lr": 3e-4,
#         "weight_decay": 0.1,
#         "betas": [0.9, 0.95],
#         "eps": 1e-8
#     },