#!/bin/bash
# Fine-tune a Keqing question-decomposition LoRA adapter on a base LLaMA-7B
# model, then run generation on the test split using that adapter.
#
# Usage: [DATASET_NAME=WebQSP] bash <this script>
#
# Only the WebQSP pipeline is defined in this script; any other value of
# DATASET_NAME (e.g. MetaQA) currently runs nothing.
#
# Abort on the first failing command so that generation never runs against a
# missing or stale adapter after a failed fine-tuning step.
set -euo pipefail

# WebQSP or MetaQA; may be overridden from the environment.
DATASET_NAME="${DATASET_NAME:-WebQSP}"

if [[ "$DATASET_NAME" == "WebQSP" ]]; then
  base_model='decapoda-research/llama-7b-hf'
  cache_path='/data/home/'
  # Written by the fine-tuning step and read back as --lora_weights by the
  # generation step; a single variable keeps the two in sync.
  lora_dir='./Keqing_WebQSP_QD'

  # Step 1: LoRA fine-tuning with 4 torchrun processes on GPUs 0-3.
  # NCCL_P2P_DISABLE=1 disables NCCL peer-to-peer transfers — presumably a
  # workaround for this host's GPU interconnect; confirm before removing.
  NCCL_P2P_DISABLE=1 CUDA_VISIBLE_DEVICES=0,1,2,3 \
    torchrun --nproc_per_node=4 \
    ./source/finetune_keqing.py \
    --base_model "$base_model" \
    --data_path './processed/WebQSP_train.json' \
    --output_dir "$lora_dir" \
    --cache_path "$cache_path" \
    --batch_size 4 \
    --micro_batch_size 1 \
    --num_epochs 1 \
    --cutoff_len 512 \
    --load_in_8bit False \
    --lora_r 8

  # Step 2: generation on the test split with the adapter from step 1,
  # on a single GPU with the base model loaded in 8-bit.
  CUDA_VISIBLE_DEVICES=0 \
    python ./source/generate_keqing.py \
    --base_model "$base_model" \
    --data_path './processed/WebQSP_test.json' \
    --result_path './results' \
    --result_file 'WebQSP_QD_result.json' \
    --cache_path "$cache_path" \
    --lora_weights "$lora_dir" \
    --load_8bit True \
    --num_beams 1 \
    --max_new_tokens 128
fi


