#!/bin/bash
source activate $env

export TRITON_CACHE_DIR=...
export TRITON_HOME=...
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
export TRITON_PRINT_AUTOTUNING=1

export ROOT_DIR=./
export OUTPUT_DIR=...
export INIT_OUTPUT_DIR=...
export RUN_NAME=...

deepspeed --include localhost:0,1,2,3,4,5,6,7 --master_port 60010 \
--module tevatron.co_retriever.driver.train \
  --deepspeed $ROOT_DIR/deepspeed/ds_zero3_config.json \
  --output_dir $ROOT_DIR/$MODEL_NAME \
  --model_name_or_path meta-llama/Llama-3.2-1B \
  --reference_model_name_or_path HuggingFaceTB/SmolLM2-135M \
  --retriever_lora_name_or_path $ROOT_DIR/$INIT_OUTPUT_DIR/encoder \
  --reference_lora_name_or_path $ROOT_DIR/$INIT_OUTPUT_DIR/reference \
  --disable_v_norm True \
  --lora \
  --lora_r 256 \
  --lora_target_modules q_proj,k_proj,v_proj,o_proj,down_proj,up_proj,gate_proj \
  --save_steps 1000 \
  --dataset_name Tevatron/msmarco-passage-aug \
  --top_k 8 \
  --query_prefix "query: " \
  --passage_prefix "passage: " \
  --bf16 \
  --pooling eos \
  --append_eos_token \
  --normalize \
  --temperature 0.01 \
  --attn_temperature 0.1 \
  --contrastive_loss_weight 0.5 \
  --per_device_train_batch_size 4 \
  --gradient_accumulation_steps 4 \
  --gradient_checkpointing \
  --train_group_size 16 \
  --learning_rate 5e-5 \
  --query_max_len 32 \
  --passage_max_len 196 \
  --num_train_epochs 1 \
  --warmup_steps 100 \
  --logging_steps 1 \
  --overwrite_output_dir \
  --run_name "$(echo "$MODEL_NAME" | sed 's/-/_/g')"