# Knowledge-Augmented Long-CoT Generation for Complex Biomolecular Reasoning

This repository contains the code for the paper "Knowledge-Augmented Long-CoT Generation for Complex Biomolecular Reasoning".  
It provides implementations of the main algorithms and the datasets described in the paper.

## Environment

### SFT Stage
```bash
conda create -n sft python=3.10 -y
conda activate sft
conda install pytorch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1 pytorch-cuda=12.1 -c pytorch -c nvidia -y
pip install llamafactory transformers==4.52.4 datasets==3.6.0 accelerate==1.7.0 peft==0.15.2 trl==0.9.6 
```

### RL Stage
```bash
conda create -n rl python=3.11 -y
conda activate rl
pip install unsloth
```

## Dataset
The dataset used in this project is mainly based on PrimeKGQA.
- ```./dataset/PrimeKGQA/test```: This directory contains the complete test set of the PrimeKGQA dataset, used to evaluate the performance of the model in the biomolecular question answering task.
- ```./dataset/PrimeKGQA/sft_demo.json```: A demo file with complete CoT training data for the SFT stage.
- ```./dataset/PrimeKGQA/rl_demo.json```: A demo file for the RL stage.
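
Before training, it can help to check what the demo files contain. The following is a minimal sketch that summarizes a JSON file without assuming any particular schema; `summarize_json` is a hypothetical helper written for this README and is not part of the repository.

```python
import json
from pathlib import Path


def summarize_json(path):
    """Return a small, schema-agnostic summary of a JSON file."""
    data = json.loads(Path(path).read_text(encoding="utf-8"))
    summary = {"type": type(data).__name__}
    if isinstance(data, list):
        summary["num_records"] = len(data)
        first = data[0] if data else None
        if isinstance(first, dict):
            # Only field names are reported; no field is assumed to exist.
            summary["first_record_keys"] = sorted(first)
    elif isinstance(data, dict):
        summary["top_level_keys"] = sorted(data)
    return summary


if __name__ == "__main__":
    # Paths are the demo files listed above; run from the repository root.
    for name in ("sft_demo.json", "rl_demo.json"):
        demo = Path("./dataset/PrimeKGQA") / name
        if demo.exists():
            print(name, summarize_json(demo))
        else:
            print(name, "not found; run from the repository root")
```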

## Data Generation
- Path Extraction:
The first step of the data generation pipeline. This script extracts reasoning paths from the knowledge graph (KG) for each task.
```bash
./src/data_generation/path_extraction/all_task.sh
```
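
To illustrate what a reasoning path over a KG looks like, here is a minimal sketch that finds one shortest path between two entities with breadth-first search over directed `(head, relation, tail)` triples. It is an assumption-laden illustration, not the algorithm used by `all_task.sh`: the function names, the `max_hops` limit, and the toy entities are all hypothetical.

```python
from collections import defaultdict, deque


def build_adjacency(triples):
    """Index directed (head, relation, tail) triples by head entity."""
    adjacency = defaultdict(list)
    for head, relation, tail in triples:
        adjacency[head].append((relation, tail))
    return adjacency


def shortest_path(triples, source, target, max_hops=3):
    """Return one shortest path as a list of triples, or None if unreachable."""
    adjacency = build_adjacency(triples)
    queue = deque([(source, [])])
    visited = {source}
    while queue:
        node, path = queue.popleft()
        if node == target:
            return path
        if len(path) >= max_hops:
            continue  # do not expand beyond the hop limit
        for relation, neighbor in adjacency.get(node, []):
            if neighbor not in visited:
                visited.add(neighbor)
                queue.append((neighbor, path + [(node, relation, neighbor)]))
    return None


# Toy triples with placeholder names; they are not taken from PrimeKG.
toy_kg = [
    ("drug_A", "targets", "protein_X"),
    ("protein_X", "associated_with", "disease_Y"),
    ("drug_A", "interacts_with", "drug_B"),
]
print(shortest_path(toy_kg, "drug_A", "disease_Y"))
```

On the toy graph this yields the two-hop path from `drug_A` through `protein_X` to `disease_Y`.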

- CoT Generation:
The second step of the data generation pipeline, run after path extraction. This script generates a detailed, step-by-step reasoning process (CoT) for each question from the extracted knowledge path.
```bash
python ./src/data_generation/cot_generation.py \
  --input_file ./data/PrimeKGQA_path/Indication.json \
  --output_file ./data/PrimeKGQA_cot_generation/Indication.json \
  --model_name "deepseek-ai/DeepSeek-R1" \
  --batch_size 10 \
  --api_key "your_secret_api_key" \
  --api_base "your_api_base_url"
```

- CoT Pruning:
The final step of the data generation pipeline. This script prunes and refines the generated CoT.
```bash
python run_pruning.py \
  --input_file ./data/PrimeKGQA_cot_generation/Indication.json \
  --output_file ./data/PrimeKGQA_cot_pruning/Indication.json \
  --model_name "deepseek-ai/DeepSeek-R1" \
  --batch_size 10 \
  --api_key "your_secret_api_key" \
  --api_base "your_api_base_url"
```

## Training Pipeline
Our model is trained using a two-stage process: Supervised Fine-Tuning (SFT) followed by Reinforcement Learning (RL).
- SFT stage
```bash
llamafactory-cli train ./src/train/sft.yaml \
    --dataset PrimeKGQA \
    --output_dir ./model/qwen3/sft_exp \
    --learning_rate 5e-6 \
    --num_train_epochs 4.0 \
    --per_device_train_batch_size 2 \
    --gradient_accumulation_steps 4 \
    --logging_steps 20 \
    --save_steps 500 \
    --fp16 true
```
- RL stage
```bash
CUDA_VISIBLE_DEVICES=0 python ./src/train/rl.py \
    --model_name_or_path "./model/qwen3/sft" \
    --data_path "./data/rl.json" \
    --output_dir "./model/qwen3/rl_exp" \
    --lora_rank 32 \
    --learning_rate 5e-6 \
    --num_train_epochs 1 \
    --logging_steps 10 
```

## Evaluation
Use the following command to evaluate the performance of the model on the PrimeKGQA test set.
```bash
python ./src/evaluation/evaluate.py \
    --model_path ./model/qwen3/rl \
    --data_dir ./dataset/PrimeKGQA/test \
    --output_dir ./results/predictions \
    --log_file ./results/accuracy.log
```
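
The exact metric and answer parsing used by `evaluate.py` are not described here. As a rough sketch of how an accuracy figure could be computed, the snippet below assumes exact-match scoring after light normalization; `normalize` and `exact_match_accuracy` are hypothetical helpers, not functions from this repository.

```python
def normalize(text):
    """Lowercase, trim, and collapse whitespace (an assumed normalization)."""
    return " ".join(str(text).strip().lower().split())


def exact_match_accuracy(predictions, references):
    """Fraction of predictions equal to their reference after normalization."""
    if len(predictions) != len(references):
        raise ValueError("predictions and references must have the same length")
    if not references:
        return 0.0
    correct = sum(
        normalize(pred) == normalize(ref)
        for pred, ref in zip(predictions, references)
    )
    return correct / len(references)


# Placeholder answers for illustration only.
print(exact_match_accuracy(["Aspirin ", "drug_B"], ["aspirin", "drug_C"]))
```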