#!/bin/bash
# Bootstrap the Python environment used by the fine-tuning runs launched
# further down in this script. Installs run in order; later commands may
# replace versions installed by earlier ones.
#
# NOTE(review): both `pip` and `pip3` are used below; they may resolve to
# different interpreters on some hosts — confirm they target the same one.
pip3 install mpu
pip3 install accelerate==0.34.2
# Package name is "torchtyping" (the same package installed again below).
pip3 install torchtyping
pip3 install transformers
pip3 install deepspeed==0.15.0
# NOTE(review): tokenizers is pinned here but upgraded again a few lines
# down via `--upgrade`; confirm which version is actually wanted.
pip3 install tokenizers==0.14.1
# Reinstall certifi even when a copy is already present.
pip install --upgrade --force-reinstall certifi
pip install --upgrade datasets huggingface_hub
pip install torchtyping rouge_score
pip install --upgrade transformers tokenizers
# Editable install of the local transformers checkout. It runs after the
# PyPI upgrade above, so presumably this is the copy meant to be imported —
# TODO confirm.
pip3 install --no-cache-dir -e /opt/dpcvol/models/pkge/transformers-minillm/.
pip3 install thop
pip3 install pytorch_model_summary

# Remove and reinstall py-cpuinfo to get a clean copy.
pip3 uninstall py-cpuinfo -y
pip3 install py-cpuinfo

# Append the per-user site-packages directory to PYTHONPATH. This is a plain
# assignment (no `export` on this line).
PYTHONPATH="${PYTHONPATH}:/home/naie/.local/lib/python3.9/site-packages"

# Topology of the launch: one node, rank 0, one process per node.
NNODES=1
NODE_RANK=0
GPUS_PER_NODE=1

# Rendezvous endpoint handed to the launcher.
MASTER_ADDR="localhost"
MASTER_PORT=3018

# Launcher flags, kept as a single whitespace-separated string; it is later
# expanded unquoted so that it splits back into individual words.
DISTRIBUTED_ARGS="--nproc_per_node ${GPUS_PER_NODE} \
                  --nnodes ${NNODES} \
                  --node_rank ${NODE_RANK} \
                  --master_addr ${MASTER_ADDR} \
                  --master_port ${MASTER_PORT}"

# --- Paths -------------------------------------------------------------
# BASE_PATH already ends in "/", so BASE_CODE_PATH contains a double slash;
# the value is kept exactly as the script has always produced it.
BASE_PATH=/home/naie/work/
BASE_CODE_PATH=${BASE_PATH}/minillm

# Prefix shared by the output directory and the starting checkpoint.
ckpt_root=/opt/dpcvol/datasets/8625883998351850434/ckpt/minillm/learngene
# Root directory under which each dataset gets its own save directory.
BASE_SAVE_PATH=${ckpt_root}/des-sft/gpt2-xl/knowledge-distillation/desnet-220M/minillm/
# Checkpoint directory used as both the model path and the tokenizer path.
CKPT=${ckpt_root}/gpt2/kd/Des-220M-Pre-10Btoken-minillm-xl/bs16-lr5e-06-G1-N1-NN1-lm1-len512/pe4_rs0.5_nr256_ln_sr_tm0.2/5000/

# --- Hyperparameters ---------------------------------------------------
LR=0.0005
BATCH_SIZE=16
EVAL_BATCH_SIZE=8
GRAD_ACC=1

# --- Sequence length ---------------------------------------------------
MAX_LENGTH=512

# --- Random seed -------------------------------------------------------
SEED=10

# Launch one fine-tuning run per dataset directory. Each run starts from
# ${CKPT} and saves under ${BASE_SAVE_PATH}/<dataset name>/. Any arguments
# given to this script are forwarded unchanged to finetune.py.
#
# The command line is assembled as arrays rather than one long string so
# that forwarded arguments containing spaces or glob characters reach
# finetune.py intact.

# Environment for every launch; none of it depends on the dataset, so it is
# set once instead of on every iteration.
export NCCL_DEBUG=""
export WANDB_DISABLED=True
export TF_CPP_MIN_LOG_LEVEL=3
export HCCL_CONNECT_TIMEOUT=1000
# NOTE(review): this replaces, rather than extends, whatever PYTHONPATH held
# before this point — confirm that discarding the earlier value is intended.
export PYTHONPATH="${BASE_CODE_PATH}"

# DISTRIBUTED_ARGS is a whitespace-separated flag string; split it into
# separate words on purpose.
# shellcheck disable=SC2206
dist_args=(${DISTRIBUTED_ARGS})

for DATA_DIR in \
    /opt/dpcvol/datasets/8625883998351850434/datasets/llm/minillm/processed_data_jyc/dataset-JYC/processed_mmlu/full/gpt2/ \
    /opt/dpcvol/datasets/8625883998351850434/datasets/llm/minillm/processed_data/dolly/full/gpt2/ \
    /opt/dpcvol/datasets/8625883998351850434/datasets/llm/minillm/processed_data_jyc/dataset-JYC/processed_boolq/full/gpt2/
do
    # Dataset name = the directory two levels above the leaf,
    # e.g. ".../dolly/full/gpt2/" -> "dolly".
    data_name=$(basename "$(dirname "$(dirname "${DATA_DIR}")")")
    SAVE_PATH="${BASE_SAVE_PATH}/${data_name}/"
    mkdir -p "${SAVE_PATH}"
    # NOTE(review): the name says "dolly" for every dataset in the loop;
    # runs are still kept apart by SAVE_PATH — confirm this is intended.
    CKPT_NAME="DesNet-220M-kd-sft_dolly"

    OPTS=(
        # model
        --base-path "${BASE_PATH}"
        --model-path "${CKPT}"
        --tokenizer-path "${CKPT}"
        --ckpt-name "${CKPT_NAME}"
        --n-gpu "${GPUS_PER_NODE}"
        # --gradient-checkpointing
        # data
        --data-dir "${DATA_DIR}"
        --num-workers 0
        --dev-num 1000
        # hp
        --lr "${LR}"
        --batch-size "${BATCH_SIZE}"
        --eval-batch-size "${EVAL_BATCH_SIZE}"
        --gradient-accumulation-steps "${GRAD_ACC}"
        --warmup-iters 0
        --lr-decay-style cosine
        --weight-decay 1e-2
        --clip-grad 1.0
        --epochs 10
        # length
        --max-length "${MAX_LENGTH}"
        --max-prompt-length 256
        # runtime
        --do-train
        --do-valid
        # --eval-gen
        --save-interval -1
        --eval-interval -1
        --log-interval 4
        --mid-log-num -1
        --save "${SAVE_PATH}"
        # seed
        --seed "${SEED}"
        # deepspeed
        --deepspeed
        --deepspeed_config "${BASE_CODE_PATH}/configs/deepspeed/ds_config_fp32.json"
        # type
        --type lm
        # gen
        --do-sample
        --top-k 0
        --top-p 1.0
        --temperature 1.0
        --fp32
    )

    CMD=(torchrun "${dist_args[@]}" "${BASE_CODE_PATH}/finetune.py" "${OPTS[@]}" "$@")

    # Log the command and module search path before launching.
    echo "${CMD[*]}"
    echo "PYTHONPATH=${PYTHONPATH}"
    # A failed run does not stop the loop; the next dataset is still tried.
    "${CMD[@]}"
done