#!/bin/bash

# Paths on the remote storage side.
REMOTE_DATA_PATH="MANAS_DATA/7782338973314557816/dataset/processed_data/processed_data/"
REMOTE_CODE_PATH="MANAS_DATA/7782338973314557816/code/code/minillm/"

# Paths inside the container.
DATA_PATH="/home/work/user-job-dir/app/shiboyu/dataset/minillm_dataset/"
CODE_PATH="/home/work/user-job-dir/app/shiboyu/code/"

# Training variant; model_training compares it against "llama-3B" and
# "llama-1.5B" to choose the launcher script.
TRAIN_METHOD="llama-3B"

# Set up the Ascend toolkit environment variables.
. /usr/local/Ascend/ascend-toolkit/set_env.sh
#######################################
# Invoke download_parameters.py to fetch remote content into the container.
# Globals (read):
#   REMOTE_WEIGHT_PATH  optional; value passed as --weight-path when set
#   REMOTE_CODE_PATH    used for --weight-path when REMOTE_WEIGHT_PATH is empty
#   CODE_PATH           passed as --target-path
#   REMOTE_DATA_PATH    passed as --data-path
#   DATA_PATH           passed as --data-save-path
# Outputs: progress messages on stdout, configuration errors on stderr.
# Exits the script with status 1 when no remote source is configured or the
# Python helper returns non-zero.
#######################################
download_parameters() {
  # Download the model weights and dataset from remote storage into the
  # container (destinations are CODE_PATH and DATA_PATH).
  echo "Start download code and dataset..."

  # REMOTE_WEIGHT_PATH is not assigned anywhere in this script; expanded
  # unquoted it vanished and --weight-path swallowed "--target-path" as its
  # value. Honour it when the environment provides it, otherwise use
  # REMOTE_CODE_PATH, whose content belongs in the CODE_PATH target.
  # NOTE(review): confirm REMOTE_CODE_PATH is the intended --weight-path source.
  local remote_src="${REMOTE_WEIGHT_PATH:-${REMOTE_CODE_PATH:-}}"
  if [[ -z "${remote_src}" ]]; then
    echo "[Error]Neither REMOTE_WEIGHT_PATH nor REMOTE_CODE_PATH is set" >&2
    echo "[Error]Failed to download parameters..."
    exit 1
  fi

  # Every expansion is quoted so an empty or space-containing value can never
  # shift the following option into the wrong position.
  if ! python /home/work/user-job-dir/app/shiboyu/code/run_on_npu/download_parameters.py \
    --weight-path "${remote_src}" \
    --target-path "${CODE_PATH}" \
    --data-path "${REMOTE_DATA_PATH}" \
    --data-save-path "${DATA_PATH}"; then
    echo "[Error]Failed to download parameters..."
    exit 1
  fi
  echo "[OK]Parameters download complete!!!"
}

#######################################
# Run the llama pre-training launcher selected by TRAIN_METHOD.
# Globals:
#   TRAIN_METHOD (read)        "llama-3B" -> train_3B.sh,
#                              "llama-1.5B" -> train_1.5B.sh
#   STRAT_SCRIPT_PATH (written) name of the launcher script that was chosen
# Side effects: changes the working directory to the learngene script
# directory before starting the launcher.
# Exits the script with status 1 on an unsupported TRAIN_METHOD, or (after
# sleeping 24h) when the directory cannot be entered or the launcher fails.
#######################################
model_training() {
  # Pre-training of the llama model.
  echo "Start to pretrain the model..."
  pwd
  local script_dir=/home/work/user-job-dir/app/shiboyu/code/minillm/script/llama/learngene/

  # NOTE(review): the original ran `cp <script_dir>` here with a single
  # operand, meant to copy modified scripts into this directory. With no
  # destination operand cp always fails and copies nothing, so the call was
  # dropped. Restore it with both a source and a destination if the patched
  # scripts still have to be copied in.

  # STRAT_SCRIPT_PATH keeps its original spelling and global scope.
  case "${TRAIN_METHOD}" in
    llama-3B)
      STRAT_SCRIPT_PATH=train_3B.sh
      ;;
    llama-1.5B)
      STRAT_SCRIPT_PATH=train_1.5B.sh
      ;;
    *)
      # Previously an unknown method left the script name empty, so a bare
      # `bash` was started that read commands from stdin instead of training.
      STRAT_SCRIPT_PATH=""
      echo "[Error]Unsupported TRAIN_METHOD '${TRAIN_METHOD}' (expected llama-3B or llama-1.5B)" >&2
      exit 1
      ;;
  esac

  # Start the training script from its own directory; a failed cd counts as a
  # training failure rather than launching from the wrong place.
  local rc=0
  if cd -- "${script_dir}"; then
    bash "${STRAT_SCRIPT_PATH}" || rc=$?
  else
    rc=1
  fi

  if ((rc != 0)); then
    echo "[Error]Model training failed..."
    # Presumably keeps the container alive for inspection — confirm.
    sleep 24h
    exit 1
  fi
  echo "[OK]Model training completed!!!"
}

#######################################
# Invoke upload_model.py to send the llama results directory to remote storage.
# Outputs: progress messages on stdout.
# Exits the script with status 1 (after sleeping 20h) when the Python helper
# returns non-zero.
#######################################
upload_model() {
  echo " start upload model ......"

  local upload_script=/home/work/user-job-dir/app/code/run_on_npu/upload_model.py
  # NOTE(review): every other path in this script lives under app/shiboyu/,
  # so the path above looks like it is missing that component. Use the
  # shiboyu location when the original one does not exist — confirm which is
  # correct.
  if [[ ! -f "${upload_script}" ]]; then
    upload_script=/home/work/user-job-dir/app/shiboyu/code/run_on_npu/upload_model.py
  fi

  # The original lacked the trailing backslash after --source-model-path, so
  # "--target-model-path ..." was executed as its own command: the upload ran
  # without a target and the status check that followed always saw a failure.
  if ! python "${upload_script}" \
    --source-model-path /home/work/user-job-dir/app/shiboyu/code/minillm/results/llama \
    --target-model-path MANAS_DATA/7782338973314557816/code/code/minillm/results/llama; then
    echo "[Error]Failed to upload the model..."
    # Presumably keeps the container alive for inspection — confirm.
    sleep 20h
    exit 1
  fi
  echo "[OK]Model upload complete!!!"
}

# Entry point: run the download, training and upload stages in that order.
# Each stage calls exit itself on the failures it detects, so a later stage
# only starts if the earlier ones did not terminate the script.
main() {
  local stage
  for stage in download_parameters model_training upload_model; do
    "${stage}"
  done
}

main "$@"
