#!/bin/bash -l

# SLURM SUBMIT SCRIPT
#SBATCH --job-name=retrieval
#SBATCH --array=1
#SBATCH --nodes=1             # This needs to match Fabric(num_nodes=...)
#SBATCH --ntasks-per-node=7     # This needs to match Fabric(devices=...)
#SBATCH --gres=gpu:rtxa6000:7
#SBATCH --qos=scavenger
#SBATCH --partition=scavenger
#SBATCH --account=scavenger
#SBATCH --cpus-per-task=8
#SBATCH --mem=64G
#SBATCH --time=6:00:00
#SBATCH --output=slurm-%j_retrieval.out
# To serialize runs that share this job name, enable the directive: SBATCH --dependency=singleton (add the leading '#')

# Root of the project's storage area (referenced by the launch commands as
# $ROOT_DIR/out/...).
ROOT_DIR=/fs/XXXX-37/llm-pretraining/llm-retrieval

# Node count: taken from the SLURM allocation so it cannot drift from what was
# actually granted; defaults to 1 (the value requested by `#SBATCH --nodes`)
# when SLURM_JOB_NUM_NODES is not set. NUM_NODES and NNODES are kept as two
# names for the same value because both are used as identifiers in this file.
NNODES=${SLURM_JOB_NUM_NODES:-1}
NUM_NODES=${NNODES}

# Tasks (GPUs) per node; keep in sync with `#SBATCH --ntasks-per-node` and
# `#SBATCH --gres` above.
GPUS_PER_NODE=7 ## change as per your machine

# Total task count passed to `srun -n`. With 1 node x 7 GPUs this is 7, the
# same value the script previously hard-coded after computing it.
GPUS=$(( NNODES * GPUS_PER_NODE ))

# Working directory for the job. The training script is launched through a
# relative path (pretrain_umd/...), so a failed `cd` must abort the job rather
# than let it continue from whatever directory it happened to start in.
project_dir=/XXXX-36/XXXX-22/XXXX-40
cd "${project_dir}" || {
  printf 'error: cannot cd to %s\n' "${project_dir}" >&2
  exit 1
}

# Load the user's shell configuration, then activate the conda environment
# identified by its prefix path.
source ~/.bashrc
# conda activate pretraining
source activate /XXXX-36/XXXX-22/litgpt_conda

# Rendezvous endpoint exported to the environment of the launched processes.
# Original note: remember that these are fixed across the entire array.
# NOTE(review): whether array tasks really share these values depends on how
# the array is scheduled; confirm before relying on it.
MASTER_PORT=$(shuf -i 2000-65000 -n 1)  # one random integer in [2000, 65000]
MASTER_ADDR=$(/bin/hostname -s)         # short hostname of the node running this script
export MASTER_PORT MASTER_ADDR

# Launch the retrieval training run through srun, using NUM_NODES nodes and
# GPUS tasks in total (both set earlier in this script).

# The same checkpoint directory serves as model checkpoint and tokenizer path.
ckpt_dir=/fs/XXXX-37/llm-pretraining/llm-retrieval/checkpoints/EleutherAI/pythia-160m
# The output directory is the run name plus an "_sbatch" suffix; deriving it
# from run_name keeps the two from drifting apart.
run_name=dolma-retrieval-dual-causal-pythia-160m-worldbsz-56-ctx-rand-batch_negative_ddp_RR_lr_3e-4_debug
out_dir=/fs/XXXX-37/llm-pretraining/llm-retrieval/out/${run_name}_sbatch

# Arguments for the training script, kept in an array so every value reaches
# the program as exactly one word.
train_args=(
  --model_checkpoint "${ckpt_dir}"
  --tokenizer_path "${ckpt_dir}"
  --out_dir "${out_dir}"
  --resume True
  --seed 1337
  --max_tokens 25_000_000_000
  --model_name pythia-160m
  --run_name "${run_name}"
  --logger_name wandb
  --compile_model False
  --fabric_precision bf16-mixed
  --world_batch_size 56
  --micro_batch_size 8
  --block_size 2048
  --n_chunks 4
  --warmup_steps 2000
  --log_step_interval 1
  --eval_iters 100
  --save_and_eval_interval 2000
  --grad_clip 1.0
  --decay_lr True
  --learning_rate 3e-4
  --min_lr 3e-5
  --data_telemetry False
  --data_config /XXXX-36/XXXX-22/XXXX-40/launch_scripts/XXXX-22/sample_hfds_only.json
  --fabric_strategy ddp
  --pretrained_prefix_model False
  --pretrained_suffix_model False
)

# Quoted expansions: an empty or malformed NUM_NODES/GPUS is passed to srun as
# a single (invalid) value instead of silently shifting the remaining options.
srun -N "${NUM_NODES}" -n "${GPUS}" \
  python pretrain_umd/train_retrieval_w_anticausal.py "${train_args[@]}"

# if [ $SLURM_ARRAY_TASK_ID == 1 ]; then
#     srun -N ${NUM_NODES} -n ${GPUS} python pretrain_umd/train_retrieval_w_anticausal.py \
#         --out_dir $ROOT_DIR/out/cosmopedia-retrieval-dual-causal-llama-1.1b-bsz-8-ctx-2048-batch_negative_packeddata \
#         --seed 1337 \
#         --model_name tiny-llama-1.1b \
#         --run_name cosmopedia-retrieval-dual-causal-llama-1.1b-bsz-1-ctx-2048-batch_negative_packeddata  \
#         --logger_name wandb \
#         --compile_model False \
#         --fabric_precision "bf16-mixed" \
#         --world_batch_size 8 \
#         --micro_batch_size 4 \
#         --learning_rate 0.0016 \
#         --block_size 2048 \
#         --n_chunks 4 \
#         --max_tokens "1_000_000_000_000" \
#         --warmup_steps 2000 \
#         --log_step_interval 1 \
#         --eval_iters 100 \
#         --save_and_eval_interval 2000 \
#         --grad_clip 1.0 \
#         --decay_lr True \
#         --min_lr "4e-5" \
#         --data_telemetry False \
#         --data_config /XXXX-36/XXXX-22/XXXX-40/launch_scripts/XXXX-22/sample_pkds_only.json \
#         --train_data_dir /fs/XXXX-37/llm-pretraining/llm-retrieval/data/splitted_cosmopedia \
#         --val_data_dir /fs/XXXX-37/llm-pretraining/llm-retrieval/data/splitted_cosmopedia

# elif [ $SLURM_ARRAY_TASK_ID == 2 ]; then
#     srun -N ${NUM_NODES} -n ${GPUS} python pretrain_umd/train_retrieval_w_anticausal.py \
#         --out_dir $ROOT_DIR/out/cosmopedia-retrieval-dual-causal-llama-1.1b-bsz-1-ctx-2048-batch_negative_default_hparam_hfdata_refactored_code \
#         --seed 1337 \
#         --model_name tiny-llama-1.1b \
#         --run_name cosmopedia-retrieval-dual-causal-llama-1.1b-bsz-1-ctx-2048-batch_negative_default_hparam_hfdata_refactored_code  \
#         --logger_name wandb \
#         --compile_model False \
#         --fabric_precision "bf16-mixed" \
#         --world_batch_size 2 \
#         --micro_batch_size 1 \
#         --block_size 2048 \
#         --n_chunks 4 \
#         --max_tokens "1_000_000_000_000" \
#         --warmup_steps 2000 \
#         --log_step_interval 1 \
#         --eval_iters 100 \
#         --save_and_eval_interval 2000 \
#         --grad_clip 1.0 \
#         --decay_lr True \
#         --min_lr "4e-5" \
#         --data_telemetry False \
#         --data_config /XXXX-36/XXXX-22/XXXX-40/launch_scripts/XXXX-22/sample_hfds_only.json \
#         --train_data_dir /fs/XXXX-37/llm-pretraining/llm-retrieval/data/splitted_cosmopedia \
#         --val_data_dir /fs/XXXX-37/llm-pretraining/llm-retrieval/data/splitted_cosmopedia
# fi
# # python pretrain_umd/train_retrieval_w_anticausal.py --out_dir /fs/XXXX-37/llm-pretraining/llm-retrieval/out/openwebtext-retrieval-dual-causal-llama-1.1b-bsz-1-ctx-2048-batch_negative-bsz-1_fixedlrscheduler_packeddata_old_logging --seed 1337 --model_name tiny-llama-1.1b --run_name openwebtext-retrieval-dual-causal-llama-1.1b-bsz-1-ctx-2048-batch_negative-bsz-1_fixedlrscheduler_packeddata_old_logging --logger_name wandb --compile_model False --world_batch_size 2 --learning_rate 3e-3 --micro_batch_size 1 --block_size 2048 --n_chunks 4 --max_tokens 1_000_000_000_000 --warmup_steps 2000 --log_step_interval 1 --eval_iters 100 --save_and_eval_interval 2000 --weight_decay 2e-2 --beta1 0.9 --beta2 0.999 --grad_clip 1.0 --decay_lr True --min_lr 4e-5 --data_telemetry False --data_config /XXXX-36/XXXX-22/XXXX-40/launch_scripts/XXXX-22/sample_pkds_only.json --train_data_dir /fs/XXXX-37/llm-pretraining/llm-retrieval/data/packed_openwebtext --val_data_dir /fs/XXXX-37/llm-pretraining/llm-retrieval/data/packed_openwebtext
#     #     --out_dir $ROOT_DIR/out/openwebtext-retrieval-dual-causal-llama-1.1b-bsz-1-ctx-2048-batch_negative-bsz-1-meanpooling_fixedlrscheduler_packeddata \
#     #     --seed 1337 \
#     #     --model_name tiny-llama-1.1b \
#     #     --run_name openwebtext-retrieval-dual-causal-llama-1.1b-bsz-1-ctx-2048-batch_negative-bsz-1-meanpooling_fixedlrscheduler_packeddata  \
#     #     --logger_name wandb \
#     #     --compile_model False \
#     #     --world_batch_size 2 \
#     #     --learning_rate "3e-3" \
#     #     --micro_batch_size 1 \
#     #     --block_size 2048 \
#     #     --n_chunks 4 \
#     #     --max_tokens "1_000_000_000_000" \
#     #     --warmup_steps 2000 \
#     #     --log_step_interval 1 \
#     #     --eval_iters 100 \
#     #     --save_and_eval_interval 2000 \
#     #     --weight_decay "2e-2" \
#     #     --beta1 0.9 \
#     #     --beta2 0.999 \
#     #     --grad_clip 1.0 \
#     #     --decay_lr True \
#     #     --min_lr "4e-5" \
#     #     --data_telemetry False \
#     #     --data_config /XXXX-36/XXXX-22/XXXX-40/launch_scripts/XXXX-22/sample_pkds_only.json \
#     #     --train_data_dir /fs/XXXX-37/llm-pretraining/llm-retrieval/data/packed_openwebtext \
#     #     --val_data_dir /fs/XXXX-37/llm-pretraining/llm-retrieval/data/packed_openwebtext \
#     #     --mean_pooling True

# # srun -N ${NUM_NODES} -n ${GPUS} python pretrain_umd/train_retrieval_w_anticausal.py \
# #     --out_dir $ROOT_DIR/out/cosmopedia-retrieval-dual-causal-llama-1.1b-bsz-1-ctx-2048-batch_negative_default_hparam_hfdata \
# #     --seed 1337 \
# #     --model_name tiny-llama-1.1b \
# #     --run_name cosmopedia-retrieval-dual-causal-llama-1.1b-bsz-1-ctx-2048-batch_negative_default_hparam_hfdata  \
# #     --logger_name wandb \
# #     --compile_model False \
# #     --fabric_precision "bf16-mixed" \
# #     --world_batch_size 2 \
# #     --micro_batch_size 1 \
# #     --block_size 2048 \
# #     --n_chunks 4 \
# #     --max_tokens "1_000_000_000_000" \
# #     --warmup_steps 2000 \
# #     --log_step_interval 1 \
# #     --eval_iters 100 \
# #     --save_and_eval_interval 2000 \
# #     --grad_clip 1.0 \
# #     --decay_lr True \
# #     --min_lr "4e-5" \
# #     --data_telemetry False \
# #     --data_config /XXXX-36/XXXX-22/XXXX-40/launch_scripts/XXXX-22/sample_hfds_only.json \
# #     --train_data_dir /fs/XXXX-37/llm-pretraining/llm-retrieval/data/splitted_cosmopedia \
# #     --val_data_dir /fs/XXXX-37/llm-pretraining/llm-retrieval/data/splitted_cosmopedia \
#     # srun -N ${NUM_NODES} -n ${GPUS} python pretrain_umd/train_retrieval_w_anticausal.py \