#!/bin/bash
# Environment setup: installs the Python packages used by the fine-tuning run
# below. The installs are order-sensitive (later lines upgrade or replace
# earlier ones), so the sequence is kept as-is.
pip3 install mpu
pip3 install accelerate==0.34.2
# Was "torchtypin", a misspelling of the torchtyping package that is installed
# again further down.
pip3 install torchtyping
pip3 install transformers
pip3 install deepspeed==0.15.0
pip3 install tokenizers==0.14.1
pip install --upgrade --force-reinstall certifi
pip install --upgrade datasets huggingface_hub
pip install torchtyping rouge_score
# NOTE(review): this upgrade can replace the tokenizers==0.14.1 pin installed
# above -- confirm that is intended.
pip install --upgrade transformers tokenizers
# Editable install of the local transformers-minillm checkout; it runs after
# the PyPI transformers install/upgrade above.
pip3 install --no-cache-dir -e /opt/dpcvol/models/pkge/transformers-minillm/.
pip3 install thop
pip3 install pytorch_model_summary

# Remove and reinstall py-cpuinfo (presumably to replace a broken preinstalled
# copy -- TODO confirm).
pip3 uninstall py-cpuinfo -y
pip3 install py-cpuinfo

# Append the per-user Python 3.9 site-packages directory to PYTHONPATH.
PYTHONPATH="${PYTHONPATH}:/home/naie/.local/lib/python3.9/site-packages"


# Launch topology: one node, one process per node, rendezvous on localhost.
GPUS_PER_NODE="1"
NNODES="1"
NODE_RANK="0"
MASTER_ADDR="localhost"
MASTER_PORT="2015"

# Launcher flags kept as one flat string; it is expanded unquoted (and thus
# word-split) when the torchrun command line is assembled.
DISTRIBUTED_ARGS="--nproc_per_node ${GPUS_PER_NODE} \
                  --nnodes ${NNODES} \
                  --node_rank ${NODE_RANK} \
                  --master_addr ${MASTER_ADDR} \
                  --master_port ${MASTER_PORT}"

# Locations: project root, code checkout, and the root under which each
# dataset's checkpoints are saved.
BASE_PATH=/home/naie/work/
BASE_CODE_PATH=${BASE_PATH}/minillm
BASE_SAVE_PATH=/opt/dpcvol/datasets/8625883998351850434/ckpt/minillm/learngene/des-sft/llama3-8b/sft-on-downstream-tasks/220M-LInit-78Mtoken/

# Optimisation hyperparameters.
LR="0.0005"
BATCH_SIZE="16"
EVAL_BATCH_SIZE="8"
GRAD_ACC="1"

# Maximum sequence length.
MAX_LENGTH="512"

# Random seed.
SEED="10"

# Fine-tune the same starting checkpoint once per downstream dataset.
#
# For every dataset directory, the dataset name is taken from the path
# component two levels above the leaf (".../<name>/full/gpt2/"), a matching
# output directory is created under BASE_SAVE_PATH, and finetune.py is launched
# through torchrun. Any arguments given to this script are appended verbatim to
# the finetune.py command line.
#
# The command line is built with arrays so that argument boundaries are
# preserved (including script arguments that contain whitespace).
DATA_DIRS=(
    /opt/dpcvol/datasets/8625883998351850434/datasets/llm/minillm/processed_data_jyc/dataset-JYC/processed_boolq/full/gpt2/
    /opt/dpcvol/datasets/8625883998351850434/datasets/llm/minillm/processed_data/dolly/full/gpt2/
    /opt/dpcvol/datasets/8625883998351850434/datasets/llm/minillm/processed_data_jyc/dataset-JYC/processed_mmlu/full/gpt2/
    /opt/dpcvol/datasets/8625883998351850434/datasets/llm/minillm/processed_data_jyc/dataset-JYC/processed_english_XLSum//full/gpt2/
    /opt/dpcvol/datasets/8625883998351850434/datasets/llm/minillm/processed_data_jyc/dataset-JYC/processed_hellaswag//full/gpt2/
    /opt/dpcvol/datasets/8625883998351850434/datasets/llm/minillm/processed_data_jyc/dataset-JYC/processed_winogrande//full/gpt2/
    /opt/dpcvol/datasets/8625883998351850434/datasets/llm/minillm/processed_data_jyc/dataset-JYC/processed_arc-easy//full/gpt2/
    /opt/dpcvol/datasets/8625883998351850434/datasets/llm/minillm/processed_data_jyc/dataset-JYC/processed_arc-challenge//full/gpt2/
)

# Starting checkpoint and its label; identical for every dataset.
CKPT_NAME="GPT2-NEmbed_768_NHead_12_NLayer_14-sft_dolly"
CKPT="/opt/dpcvol/models/LLM_Distillation/des-sft/llama3-8b/gpt2_220M-LInit-78Mtoken/"

# Environment for the training processes; identical for every dataset.
export NCCL_DEBUG=""
export WANDB_DISABLED=True
export TF_CPP_MIN_LOG_LEVEL=3
export HCCL_CONNECT_TIMEOUT=1000
# NOTE(review): this replaces PYTHONPATH outright, so the site-packages entry
# appended to PYTHONPATH near the top of the script is discarded -- confirm
# that is intended.
export PYTHONPATH="${BASE_CODE_PATH}"

# DISTRIBUTED_ARGS is a flat flag string; split it into words exactly once.
# shellcheck disable=SC2206
DIST_ARGS=(${DISTRIBUTED_ARGS})

for DATA_DIR in "${DATA_DIRS[@]}"; do
    # ".../<name>/full/gpt2/" -> "<name>"
    DATA_NAME=$(basename "$(dirname "$(dirname "${DATA_DIR}")")")
    SAVE_PATH="${BASE_SAVE_PATH}/${DATA_NAME}/"
    mkdir -p "${SAVE_PATH}"

    OPTS=(
        # model
        --base-path "${BASE_PATH}"
        --model-path "${CKPT}"
        --tokenizer-path /opt/dpcvol/datasets/8625883998351850434/ckpt/minillm/minillm_official/gpt2/train/minillm/medium-init-xlarge-sft/
        --ckpt-name "${CKPT_NAME}"
        --n-gpu "${GPUS_PER_NODE}"
        # --gradient-checkpointing
        # data
        --data-dir "${DATA_DIR}"
        --num-workers 0
        --dev-num 1000
        # hp
        --lr "${LR}"
        --batch-size "${BATCH_SIZE}"
        --eval-batch-size "${EVAL_BATCH_SIZE}"
        --gradient-accumulation-steps "${GRAD_ACC}"
        --warmup-iters 0
        --lr-decay-style cosine
        --weight-decay 1e-2
        --clip-grad 1.0
        --epochs 10
        # length
        --max-length "${MAX_LENGTH}"
        --max-prompt-length 256
        # runtime
        --do-train
        --do-valid
        # --eval-gen
        --save-interval -1
        --eval-interval -1
        --log-interval 4
        --mid-log-num -1
        --save "${SAVE_PATH}"
        # seed
        --seed "${SEED}"
        # deepspeed
        --deepspeed
        --deepspeed_config "${BASE_CODE_PATH}/configs/deepspeed/ds_config_fp32.json"
        # type
        --type lm
        # gen
        --do-sample
        --top-k 0
        --top-p 1.0
        --temperature 1.0
        --fp32
    )

    CMD=(torchrun "${DIST_ARGS[@]}" "${BASE_CODE_PATH}/finetune.py" "${OPTS[@]}" "$@")

    # Log the command (space-joined) and the effective PYTHONPATH, then run.
    # A failing run does not abort the loop; the next dataset is still tried.
    echo "${CMD[*]}"
    echo "PYTHONPATH=${PYTHONPATH}"
    "${CMD[@]}"
done