#!/bin/bash

# Two-stage training launcher for experiment EXP_NAME.
#
# Stage 1 runs main.py with --do_pretrain on STAGE1_DIR; stage 2 runs main.py
# on STAGE2_DIR and loads the checkpoint written by stage 1.
#
# Usage: <this script> DATA_DIR
#   DATA_DIR  root directory holding the dataset folders; experiment outputs
#             are written under DATA_DIR/external_data_exp/EXP_NAME/.

# Abort on the first failing command (so stage 2 never starts after a failed
# stage 1), on use of an unset variable, and on a failure inside a pipeline.
set -euo pipefail

export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
# Append the NVRTC library directory. The ${VAR:+...} form avoids a leading
# ':' when LD_LIBRARY_PATH is unset/empty; an empty entry would make the
# dynamic loader search the current working directory.
export LD_LIBRARY_PATH="${LD_LIBRARY_PATH:+${LD_LIBRARY_PATH}:}/home/aiscuser/.conda/envs/default/lib/python3.10/site-packages/nvidia/cuda_nvrtc/lib"

EXP_NAME="moebert_base-100_NEFD-scratch_30k-500k_d814"

DATA_DIR=${1:-}
if [[ -z "$DATA_DIR" ]]; then
  # Without this guard every path below would silently be rooted at '/'.
  printf 'Usage: %s DATA_DIR\n' "${0##*/}" >&2
  exit 2
fi

FOLDER_NAME="tokenize_500k"
STAGE1_DIR=NomicEmbedFullDataset
STAGE2_DIR=NomicEmbedFTDataset
STAGE1_INFO=./nefd_dataset_info/pretrain_dataset_info_30k-500k.json
STAGE2_INFO=./nefd_dataset_info/finetune_dataset_info_30k-500k.json

# Output locations, computed once so the stage-2 --load_model path cannot
# drift from the stage-1 --output_dir.
EXP_ROOT="${DATA_DIR}/external_data_exp/${EXP_NAME}"
STAGE1_OUT="${EXP_ROOT}/stage1"
STAGE2_OUT="${EXP_ROOT}/stage2"
STAGE1_CKPT="${STAGE1_OUT}/epoch_1.pt/model.safetensors"

# Argument groups are arrays (not strings) so that values containing
# whitespace or glob characters reach main.py as single, intact words.

# Per-stage optimisation settings.
STAGE1=(
  --batch_size 256
  --lr 5e-5
  --warmup_steps 2800
)
STAGE2=(
  --batch_size 16
  --lr 2e-5
  --warmup_steps 400
)

# Settings expected to vary between experiments.
CHANGE_ARGS=(
  # Dataset
  --folder_name "$FOLDER_NAME"
  # Model
  --embedding_dim 768
  --max_orig_positional_len 3072
  --vocab_size 30003
  --hidden_size 768
  --num_hidden_layers 6
  --num_attention_heads 12
  --pad_token_id 30002
  --intermediate_size 3072
  --intermediate_size_expert 814
  --num_expert_heads 0
  --hidden_act gelu
  --token_moe True
  --moe_type hash
  --hash_list_path "${DATA_DIR}/NomicEmbedFullDataset/hash_lists/balance_hash_bucket_100_500k.pkl"
  --topk 1
  --num_experts 100
  --num_sparse_layers 3
  --gradient_checkpointing False
  --mean_pooling True
)

# Settings shared by all experiments.
NO_CHANGE_ARGS=(
  # General
  --seed 42
  --epochs 1
  # Dataloader
  --num_workers 0
  # Optimizer
  --weight_decay 0.01
  --adam_epsilon 1e-6
  # Loss
  --temperature 0.05
  # Logging
  --project_name mteb_train
)

# Stage 1
mkdir -p -- "$STAGE1_OUT"
accelerate launch --config_file=./config.yaml \
  main.py \
  --do_pretrain \
  --data_dir "${DATA_DIR}/${STAGE1_DIR}" \
  --dataset_info_path "$STAGE1_INFO" \
  --run_name "${EXP_NAME}_stage1" \
  --output_dir "$STAGE1_OUT" \
  "${STAGE1[@]}" "${CHANGE_ARGS[@]}" "${NO_CHANGE_ARGS[@]}"

# Stage 2
# Fail early with a clear message if stage 1 did not leave the checkpoint
# that stage 2 is about to load.
if [[ ! -e "$STAGE1_CKPT" ]]; then
  printf 'error: stage 1 checkpoint not found: %s\n' "$STAGE1_CKPT" >&2
  exit 1
fi

mkdir -p -- "$STAGE2_OUT"
accelerate launch --config_file=./config.yaml \
  main.py \
  --data_dir "${DATA_DIR}/${STAGE2_DIR}" \
  --dataset_info_path "$STAGE2_INFO" \
  --run_name "${EXP_NAME}_stage2" \
  --output_dir "$STAGE2_OUT" \
  --load_model "$STAGE1_CKPT" \
  "${STAGE2[@]}" "${CHANGE_ARGS[@]}" "${NO_CHANGE_ARGS[@]}"