#!/bin/bash

# Install Python dependencies.
# The commands run strictly in sequence, so later installs can replace
# versions pinned earlier (the tokenizers pin below is followed by an
# --upgrade, and the local transformers-minillm package is installed after
# the PyPI transformers).
# NOTE(review): some commands use `pip` and others `pip3`; confirm both
# resolve to the same interpreter on the target image.
pip3 install mpu
pip3 install accelerate==0.34.2
# Fixed package name: this previously read "torchtypin", a typo for
# "torchtyping".
pip3 install torchtyping
pip3 install transformers
pip3 install deepspeed==0.15.0
pip3 install tokenizers==0.14.1
pip install --upgrade --force-reinstall certifi
pip install --upgrade datasets huggingface_hub
pip install torchtyping rouge_score
pip install --upgrade transformers tokenizers
# Local package installed last among the transformers-related steps
# (presumably a patched transformers build for MiniLLM — TODO confirm).
pip3 install /opt/dpcvol/models/pkge/transformers-minillm/.
pip3 install thop
pip3 install pytorch_model_summary

# Reinstall py-cpuinfo from scratch (uninstall, then install again).
pip3 uninstall py-cpuinfo -y
pip3 install py-cpuinfo

# Distributed launch configuration: one node, rank 0, eight processes per
# node, rendezvous on localhost:2013.
NNODES=1
NODE_RANK=0
GPUS_PER_NODE=8
MASTER_ADDR=localhost
MASTER_PORT=2013

# Launcher flags, appended one flag/value pair at a time. The string is
# meant to be word-split when the command is launched, so only the sequence
# of words matters, not the amount of whitespace between them.
DISTRIBUTED_ARGS="--nproc_per_node ${GPUS_PER_NODE}"
DISTRIBUTED_ARGS+=" --nnodes ${NNODES}"
DISTRIBUTED_ARGS+=" --node_rank ${NODE_RANK}"
DISTRIBUTED_ARGS+=" --master_addr ${MASTER_ADDR}"
DISTRIBUTED_ARGS+=" --master_port ${MASTER_PORT}"

# Path settings.
# The dataset/checkpoint locations share common prefixes, which are factored
# into lowercase helper variables; the uppercase variables are the ones the
# rest of the script reads.
data_root="/opt/dpcvol/datasets/8625883998351850434"
ckpt_root="${data_root}/ckpt/minillm"
dolly_root="${data_root}/datasets/llm/minillm/processed_data/dolly"

BASE_PATH="/home/work/user-job-dir/app/minillm/"
STUDENT_CKPT="${ckpt_root}/minillm_official/gpt2/train/sft/gpt2-xl/"
TEACHER_OUTPUT="${dolly_root}/pseudo/llama3-8b_2_gpt2/"
REAL_DATA="${dolly_root}/full/gpt2/"
SAVE_PATH="${ckpt_root}/learngene/llama3-8b/gpt2_xl-ans/"

# Hyperparameters, each forwarded to the training script as a flag:
# --epochs, --batch-size, --gradient-accumulation-steps, --lr (also reused
# for --lr-min) and --max-length.
EPOCHS=10
BATCH_SIZE=8
GRAD_ACC=2
LR=5e-6
MAX_LEN=512

# Training-script options. The flags are listed as individual words and then
# joined into OPTS, with a single space in front of every word, so OPTS is
# the same space-separated string (including its leading space) that
# incremental concatenation would produce.
opt_words=(
  # Model paths
  --base-path "${BASE_PATH}"
  --model-path "${STUDENT_CKPT}"
  --tokenizer-path "${STUDENT_CKPT}"
  --model-type gpt2
  --ckpt-name gpt2-xl
  --n-gpu "${GPUS_PER_NODE}"

  # Data paths
  --full-data-dir "${REAL_DATA}"
  --pseudo-data-dir "${TEACHER_OUTPUT}"

  # Training settings (--lr-min is given the same value as --lr)
  --do-train
  --epochs "${EPOCHS}"
  --batch-size "${BATCH_SIZE}"
  --lr "${LR}"
  --lr-min "${LR}"
  --gradient-accumulation-steps "${GRAD_ACC}"
  --warmup-iters 100
  --weight-decay 0.01
  --clip-grad 1.0

  # Sequence lengths
  --max-length "${MAX_LEN}"
  --max-prompt-length 256

  # Loss/data weights
  --full-data-weight 0.5
  --pseudo-data-weight 0.5

  # Logging and checkpoint saving
  --save "${SAVE_PATH}"
  --save-interval -1
  --eval-interval -1
  --log-interval 10
  --mid-log-num 1
  --seed 42
  --num-workers 2

  # DeepSpeed
  --deepspeed
  --deepspeed_config "${BASE_PATH}/configs/deepspeed/ds_config.json"

  # Generation settings
  --do-sample
  --top-k 0
  --top-p 1.0
  --temperature 1.0
)

# The ' %s' format is reused for every word, emitting " word" each time.
printf -v OPTS ' %s' "${opt_words[@]}"

# Environment variables for the launched processes: assign first, then
# export them all in a single statement.
NCCL_DEBUG=""                # exported as an empty string
WANDB_DISABLED=True
TF_CPP_MIN_LOG_LEVEL=3
PYTHONPATH="${BASE_PATH}"    # replaces any PYTHONPATH inherited from the caller
HCCL_CONNECT_TIMEOUT=1000    # NOTE(review): presumably seconds — confirm
export NCCL_DEBUG WANDB_DISABLED TF_CPP_MIN_LOG_LEVEL PYTHONPATH HCCL_CONNECT_TIMEOUT

# Launch command.
#######################################
# Assemble the torchrun command line, print it together with PYTHONPATH,
# create the checkpoint directory and run the command.
# Globals:   DISTRIBUTED_ARGS, BASE_PATH, OPTS, SAVE_PATH, PYTHONPATH (read)
# Arguments: any extra command-line arguments; each one is forwarded to the
#            training script as a single, intact argument
# Outputs:   the command line and PYTHONPATH on stdout
# Returns:   the exit status of the torchrun invocation
#######################################
launch_training() {
  local -a cmd

  # DISTRIBUTED_ARGS and OPTS are space-separated flag strings, so they are
  # expanded unquoted on purpose to split them into words. "$@" is kept
  # quoted so that forwarded arguments containing whitespace are not broken
  # apart (flattening $@ into a string loses the argument boundaries).
  # shellcheck disable=SC2206
  cmd=(torchrun ${DISTRIBUTED_ARGS} "${BASE_PATH}/finetune_hereto.py" ${OPTS} "$@")

  echo "${cmd[*]}"
  echo "PYTHONPATH=${PYTHONPATH}"
  # Quoted so a save path containing whitespace yields one directory.
  mkdir -p -- "${SAVE_PATH}"
  "${cmd[@]}"
}

launch_training "$@"
