#!/bin/bash

#SBATCH -p AI4Phys
#SBATCH --job-name=Inference
#SBATCH --output=Inference.output
#SBATCH --ntasks-per-node=1
#SBATCH --cpus-per-task=8
#SBATCH --gres=gpu:1

# Positional arguments:
#   $1  DATA_NAME  - dataset name, forwarded to inference.py as --data_name
#   $2  MODEL_NAME - run name; the directory passed to inference.py as
#                    --checkpoint_dir is /mnt/hwfile/ai4chem/handong/<MODEL_NAME>/final
# Usage: sbatch <this script> DATA_NAME MODEL_NAME
DATA_NAME=$1
MODEL_NAME=$2
if [ -z "${DATA_NAME}" ]; then
    echo "Error: DATA_NAME (first positional argument) is required" >&2
    exit 1
fi
# Without MODEL_NAME the checkpoint path degenerates to ".../handong//final".
if [ -z "${MODEL_NAME}" ]; then
    echo "Error: MODEL_NAME (second positional argument) is required" >&2
    exit 1
fi

# Pre-flight check: abort before launching the srun step if any input the job
# depends on is missing. The first entry is the model passed to inference.py
# as --pretrained_model_path; the last mirrors the --checkpoint_dir argument.
# NOTE(review): data4regression and yield_ft_ds_config.json are not passed to
# inference.py by this script; presumably it reads them internally — confirm.
required_paths=(
    "/mnt/hwfile/ai4chem/share/step1_llama3_8b_0916_yearly_pistachio_ep3"
    "/mnt/petrelfs/handong/llama/train_regression/data4regression"
    "yield_ft_ds_config.json"
    "/mnt/hwfile/ai4chem/handong/${MODEL_NAME}/final"
)

for path in "${required_paths[@]}"; do
    if [ ! -e "$path" ]; then
        echo "Error: Required path not found: $path" >&2
        exit 1
    fi
done
#source activate llama


# Rendezvous settings exported for the launched processes: the first host of
# the allocation is used as master address, and the port is drawn from
# [21234, 21334] (presumably to avoid clashes between concurrent jobs).
# The nodelist is quoted because values like "node[01-04]" are glob patterns.
MASTER_ADDR=$(scontrol show hostname "$SLURM_JOB_NODELIST" | head -n1)
MASTER_PORT=$((RANDOM % 101 + 21234))
export MASTER_ADDR MASTER_PORT
if [ -z "$MASTER_ADDR" ]; then
    echo "Warning: could not resolve MASTER_ADDR from SLURM_JOB_NODELIST" >&2
fi
echo "$MASTER_ADDR"
echo "$MASTER_PORT"



# function makehostfile() {
# perl -e '$slots=split /,/, $ENV{"SLURM_STEP_GPUS"};
# $slots=8 if $slots==0; # workaround 8 gpu machines
# @nodes = split /\n/, qx[scontrol show hostnames $ENV{"SLURM_JOB_NODELIST"}];
# print map { "$b$_ slots=$slots\n" } @nodes'
# }
# makehostfile > hostfile
# hostfile=""


# --include='localhost' \
# Arguments for inference.py, held in an array so every value stays exactly
# one word: an empty or space-containing DATA_NAME / MODEL_NAME can no longer
# silently drop or split an argument.
inference_args=(
    --pretrained_model_path '/mnt/hwfile/ai4chem/share/step1_llama3_8b_0916_yearly_pistachio_ep3'
    --lora 1
    --data_name "${DATA_NAME}"
    --checkpoint_dir "/mnt/hwfile/ai4chem/handong/${MODEL_NAME}/final"
)
srun python inference.py "${inference_args[@]}"

   # --load_ds_dir "/mnt/hwfile/ai4chem/chenjianpeng/train_regression_fg_info/llama_ep3_1115-18/checkpoints" \
   # --load_ds_ckpt_id 
