#!/bin/bash

# Put the Open MPI 4.1.7rc1 installation first on the executable search path.
export PATH="/usr/mpi/gcc/openmpi-4.1.7rc1/bin:$PATH"
# Same for shared libraries. The previous value is appended only when it is
# non-empty: a bare trailing ':' would create an empty entry, which the dynamic
# linker interprets as the current working directory.
export LD_LIBRARY_PATH="/usr/mpi/gcc/openmpi-4.1.7rc1/lib${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}}"

# Absolute directory that contains this script, independent of the caller's
# working directory. Abort if it cannot be resolved: an empty SCRIPT_DIR would
# make REPO_ROOT silently resolve relative to '/'.
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" || exit 1
# The repository root is two directory levels above the script.
REPO_ROOT="$(cd "${SCRIPT_DIR}/../.." && pwd)" || exit 1

# Activate the repository-local virtualenv. The path is quoted so that a
# checkout location containing whitespace or glob characters is not split.
# shellcheck disable=SC1091
source "${REPO_ROOT}/.venv/bin/activate"


# Positional arguments:
#   $1  epoch                  forwarded as --epoch; also part of the run names
#   $2  DATASET_NAME           selects <DATASET_NAME>-train.jsonl; part of the
#                              checkpoint directory name
#   $3  MODEL_NAME             model name (second half of <org>/<name>)
#   $4  MODEL_ORG              model organisation (first half of <org>/<name>)
#   $5  NUM_GPU_PER_NODE       processes per node for mpirun (exported)
#   $6  CHECK_POINT_ROOT_PATH  directory under which checkpoints are saved
#   $7  SYSTEM_PROMPT          optional; forwarded as --system-prompt-content
#
# The first six are required. Without them the GPU-count arithmetic fails and
# the checkpoint directory would be rooted at '/', so stop before doing any work.
if (( $# < 6 )); then
  printf 'Usage: %s EPOCH DATASET_NAME MODEL_NAME MODEL_ORG NUM_GPU_PER_NODE CHECK_POINT_ROOT_PATH [SYSTEM_PROMPT]\n' \
    "${0##*/}" >&2
  exit 2
fi

epoch="$1"
DATASET_NAME="$2"
MODEL_NAME="$3"
MODEL_ORG="$4"
export NUM_GPU_PER_NODE="$5"
CHECK_POINT_ROOT_PATH="$6"
SYSTEM_PROMPT="${7:-}"

# NUM_GPU_PER_NODE is used in shell arithmetic and as an mpirun count, so it
# must be a positive decimal integer (no leading zero, which would be octal).
if ! [[ "${NUM_GPU_PER_NODE}" =~ ^[1-9][0-9]*$ ]]; then
  printf 'error: NUM_GPU_PER_NODE must be a positive integer, got: %s\n' \
    "${NUM_GPU_PER_NODE}" >&2
  exit 2
fi











































# Rendezvous endpoint for the distributed job: this host is the master.
MASTER_ADDR="$(hostname)"
export MASTER_ADDR

# Outside Slurm use the fixed port 29500; inside Slurm derive a port in
# 10000..59999 from the job id so different jobs get different ports.
if [[ -z "${SLURM_JOBID}" ]]; then
  export MASTER_PORT=29500
else
  export MASTER_PORT=$(( SLURM_JOBID % 50000 + 10000 ))
fi

# Log the chosen endpoint.
printf 'MASTER_ADDR=%s\nMASTER_PORT=%s\n' "${MASTER_ADDR}" "${MASTER_PORT}"


# Hardware label for this launcher.
NODE_TYPE="H200"

# Node count comes from Slurm when available; otherwise assume a single node.
NUM_NODES="${SLURM_JOB_NUM_NODES:-1}"

# Total process count: nodes times GPUs per node.
NUM_GPUS=$(( ${NUM_NODES} * ${NUM_GPU_PER_NODE} ))


# --- Sequence length and parallelism --------------------------------------
SEQ_LENGTH="8192"
# Mirrors the total GPU count; not referenced by the launch command below.
DATA_PARALLEL_SIZE="${NUM_GPUS}"

# --- Batch sizes -----------------------------------------------------------
GLOBAL_BATCH_SIZE="64"
MICRO_BATCH_SIZE="4"

# --- Optimizer schedule ----------------------------------------------------
# LR and MIN_LR are also embedded in the checkpoint directory and job names.
LR="2e-5"
MIN_LR="4e-6"
GRAD_CLIP="1"
WEIGHT_DECAY="0.1"


# Tokenizer and base model are both addressed as "<org>/<name>".
# NOTE(review): this looks like a model-hub identifier rather than a local
# path - confirm against finetuning.py.
TOKENIZER_MODEL="${MODEL_ORG}/${MODEL_NAME}"
CHECKPOINT_DIR="${MODEL_ORG}/${MODEL_NAME}"
# Checkpoints live in a per-run directory whose name encodes the dataset,
# learning rates, global batch size and epoch count.
CHECKPOINT_SAVE_DIR="${CHECK_POINT_ROOT_PATH}/${MODEL_NAME}/${DATASET_NAME}_lr_${LR}-minlr_${MIN_LR}_GB_${GLOBAL_BATCH_SIZE}_${epoch}epoch"

# Quoted (and '--'-terminated) so a root path or dataset name containing
# whitespace creates exactly one directory instead of being word-split.
mkdir -p -- "${CHECKPOINT_SAVE_DIR}"

# Training data is selected by DATASET_NAME; the validation file is fixed.
TRAIN_DATA_PATH="${REPO_ROOT}/scripts/instruction/convert_datasets/${DATASET_NAME}-train.jsonl"
VALID_DATA_PATH="${REPO_ROOT}/scripts/instruction/convert_datasets/LLTM-cruxeval-numeric-depth-val.jsonl"

# Run name reported to Weights & Biases via --wandb-name.
JOB_NAME="LLTM-${MODEL_NAME}_BS=${GLOBAL_BATCH_SIZE}-LR=${LR}-MINLR=${MIN_LR}-${epoch}epoch"


# Launch finetuning.py under mpirun: NUM_GPUS processes in total,
# NUM_GPU_PER_NODE per node.
#
# Arguments are collected in arrays and every expansion is quoted, so values
# containing whitespace or glob characters reach the program as one argument
# each. Using arrays also removes the dangling line continuation after the
# last option, which would have swallowed any line appended after it.

# Options consumed by mpirun itself. Each -x forwards an environment variable
# to the launched processes.
mpirun_args=(
  -np "${NUM_GPUS}"
  --oversubscribe
  --npernode "${NUM_GPU_PER_NODE}"
  -x "MASTER_ADDR=${MASTER_ADDR}"
  -x "MASTER_PORT=${MASTER_PORT}"
  -bind-to none
  -x LD_LIBRARY_PATH
  -x PATH
)

# Options passed through to finetuning.py.
train_args=(
  --seq-length "${SEQ_LENGTH}"
  --micro-batch-size "${MICRO_BATCH_SIZE}"
  --global-batch-size "${GLOBAL_BATCH_SIZE}"
  --hf-transformer-model-dir "${CHECKPOINT_DIR}"
  --tokenizer-type Llama3Tokenizer
  --tokenizer-model "${TOKENIZER_MODEL}"
  --instruction-train-data-path "${TRAIN_DATA_PATH}"
  --instruction-valid-data-path "${VALID_DATA_PATH}"
  --epoch "${epoch}"
  --lr "${LR}"
  --min-lr "${MIN_LR}"
  --lr-decay-style cosine
  --weight-decay "${WEIGHT_DECAY}"
  --grad-clip-norm "${GRAD_CLIP}"
  --optimizer adam
  --adam-beta1 0.9
  --adam-beta2 0.95
  --adam-eps 1e-8
  --save-interval 500
  --eval-interval 10000
  --eval-iters 20
  --bf16
  --mixed-precision
  --base-model "${CHECKPOINT_DIR}"
  # --save and --load point at the same directory.
  # NOTE(review): presumably so a rerun resumes from the latest checkpoint -
  # confirm against finetuning.py.
  --save "${CHECKPOINT_SAVE_DIR}"
  --load "${CHECKPOINT_SAVE_DIR}"
  --low-cpu-fsdp
  --sharding-strategy FULL_SHARD
  --checkpoint-type LOCAL_STATE_DICT
  --fsdp-activation-checkpointing
  --instruction-tuning
  --save-sampler-state
  --use-mpi
  --system-prompt-content "${SYSTEM_PROMPT}"
  --wandb-name "${JOB_NAME}"
  --wandb-entity "fuga"
  --wandb-project "lltm"
)

# mpirun is the last command, so its exit status is the script's exit status.
mpirun "${mpirun_args[@]}" \
  python "${REPO_ROOT}/finetuning.py" "${train_args[@]}"



