#!/bin/bash

#SBATCH -J sft_llama
#SBATCH --partition=partition
#SBATCH -N1
#SBATCH --quotatype=auto
#SBATCH --gres=gpu:4 
#SBATCH --cpus-per-task=16
#SBATCH --ntasks-per-node=1    
#SBATCH --mem-per-cpu=4G  
#SBATCH --output=logs/train/llama8b-lora-r4-cot01-lr5e5.out
#SBATCH --time=72:00:00
###SBATCH --kill-on-bad-exit=1

# Fail fast on any failing command, unset variable, or failing pipeline stage.
# This must come after all #SBATCH lines: sbatch stops reading directives at the
# first executable command.
set -euo pipefail

: "${SLURM_JOB_NODELIST:?this script must run inside a Slurm allocation (submit with sbatch)}"

# Expand the compact Slurm node list into one hostname per array element.
node_list=$(scontrol show hostnames "$SLURM_JOB_NODELIST")
mapfile -t nodes <<<"$node_list"
# Keep a full copy of the host list (not just its first element).
nodes_array=("${nodes[@]}")
head_node=${nodes_array[0]}

# `hostname --ip-address` may print several space-separated addresses (for example
# IPv6 and IPv4). Pick the first IPv4-looking one (no ':'), falling back to the
# first address, so the rendezvous endpoint is a single "ip:port" word.
head_node_addrs_raw=$(srun -N1 -n1 -w "$head_node" hostname --ip-address)
read -ra head_node_addrs <<<"$head_node_addrs_raw"
if (( ${#head_node_addrs[@]} == 0 )); then
  echo "error: could not determine an IP address for head node '$head_node'" >&2
  exit 1
fi
head_node_ip=${head_node_addrs[0]}
for addr in "${head_node_addrs[@]}"; do
  if [[ "$addr" != *:* ]]; then
    head_node_ip=$addr
    break
  fi
done

# GPUs per node: use the count Slurm reports for this allocation when available.
# Otherwise fall back to 4, which matches the #SBATCH --gres=gpu:4 request.
GPUS_PER_NODE=${SLURM_GPUS_ON_NODE:-4}
NNODES=$SLURM_NNODES

echo "Node IP: $head_node_ip nodes_array: ${nodes_array[*]}"
# Print the GPU indices allocated on each node (one line per srun task).
srun bash -c 'echo $SLURMD_NODENAME-$SLURM_JOB_GPUS'

# Launcher / NCCL environment, exported so that srun'd torchrun processes see it.
LOGLEVEL=INFO        # log level picked up from the LOGLEVEL env var by the launcher
MASTER_PORT=29491    # port used for the c10d rendezvous endpoint on the head node
NCCL_DEBUG=ERROR     # only surface NCCL errors
export LOGLEVEL MASTER_PORT NCCL_DEBUG
# Uncomment to pin NCCL to a specific network interface if auto-detection picks the wrong one.
# export NCCL_SOCKET_IFNAME="eth0"

# Run-specific locations. CHECKPOINT_DIR matches the run name used in the
# #SBATCH --output log path; keep the two in sync when changing the run.
petrelfs_root=/mnt/petrelfs/usr
CHECKPOINT_DIR="llama8b-lora-r4-cot01-lr5e5"
OUTPUT_DIR="${petrelfs_root}/checkpoints/${CHECKPOINT_DIR}"
MODEL_PATH="${petrelfs_root}/models/Meta-Llama-3-8B-Instruct"
TRAIN_DATA_PATH="s3://bucket/datasets/diverse_domain/train_wo_triviaqa/mcot_01.json"

# Arguments forwarded to ming/train/train_mem.py. Building them as an array keeps
# every value a single, correctly quoted word, and lets options be toggled by
# (un)commenting individual lines.
train_args=(
  --lora_enable --lora_r 4 --lora_alpha 8
  --deepspeed scripts/zero3.json
  --prompt_type llama3
  --model_name_or_path "$MODEL_PATH"
  --train_data_path "$TRAIN_DATA_PATH"
  --bf16 True
  --output_dir "$OUTPUT_DIR"
  --num_train_epochs 1
  --per_device_train_batch_size 4
  --per_device_eval_batch_size 4
  --gradient_accumulation_steps 8
  --evaluation_strategy "no"
  --save_strategy "steps"
  --save_steps 100
  --save_total_limit 1
  --learning_rate 5e-5
  --weight_decay 0.
  --warmup_ratio 0.03
  --lr_scheduler_type "cosine"
  --logging_steps 1
  --tf32 True
  --model_max_length 3072
  --gradient_checkpointing True
  --dataloader_num_workers 6
  --lazy_preprocess True
  --report_to wandb
  # Currently disabled options; uncomment a line to pass it to the trainer.
  # --num_experts 4
  # --num_experts_per_token 2
  # --expert_selection "sampling"
  # --add_identity_mapping 1
  # --inference_path 2
  # --output_logit_loss 2
  # --router_loss_coeff 1
)

# Launch one torchrun agent per srun task. The c10d rendezvous hosted at
# head_node_ip:MASTER_PORT assigns node ranks dynamically, so no --node_rank is
# passed. Any $SLURM_PROCID written here would be expanded once by this batch
# shell, giving every node the same value, and torchrun only honours --node_rank
# for the "static" rendezvous backend. The rendezvous id is the Slurm job id,
# which is unique per job and identical on all nodes.
srun --jobid "$SLURM_JOB_ID" python -u -m torch.distributed.run \
  --nproc_per_node "$GPUS_PER_NODE" \
  --nnodes "$NNODES" \
  --rdzv_id "$SLURM_JOB_ID" \
  --rdzv_backend c10d \
  --rdzv_endpoint "$head_node_ip:$MASTER_PORT" \
  ming/train/train_mem.py \
  "${train_args[@]}"

