# Reflection Trigger: Latent Self-Correction for Question Answering by Steering Vector Injection
<img src="imgs/framework.jpg" alt="Framework" width="700"/>

🔗 [**View the presentation on Google Slides**](https://docs.google.com/presentation/d/1wsufEVgrbza5dMZSY0_2VHPOa9hLSUYoc0K4mQFxv3U/edit?usp=sharing)

🔗 **Repo structure**
```bash
┌── main.py
├── evaluate.py
├── requirements.txt
├── README.md
├── results_output/              # Outputs generated by main.py
├── reflection_trigger/          # Core method
│   ├── gen_training_data/       # Training data construction pipeline (3 steps)
│   │   ├── gen_reflection.py    # Step 1: generate reflection answers
│   │   ├── filter_reflection.py # Step 2: filter reflection outputs, keep only correct answers
│   │   ├── gen_vectors.py       # Step 3: build vector (difference between reflective & non-reflective)
│   │   ├── prompts.py
│   │   ├── utils.py
│   │   └── output/
│   │       ├── step1_reflection_gen/      # JSON files from Step 1
│   │       ├── step2_reflection_filtered/ # Filtered JSON files from Step 2
│   │       └── step3_vectors_gen/         # .pt vector datasets from Step 3
│   │  
│   ├── output_model/            # Trained BERT regressor checkpoints
│   ├── gen_training_data.sh     # Run Step 1–3 to generate training data
│   ├── modeling_trigger.py
│   ├── train_trigger.py         # Train BERT regressor
│   └── trigger_datasets.py
│ 
├── analysis/                    # Experiment analysis & plots
│   ├── output/                  # Plots and figures generated by analysis
│   ├── param_sensitivity.py
│   ├── reflection_intensity_analysis.py
│   ├── reflective_vis.py
│   ├── token_count.py
│   └── train_efficiency.py
│ 
├── orig_dataset/           # Original datasets
│   ├── train/
│   └── test/
│ 
└── thesis/                 # Thesis-related materials
    ├── figures/
    ├── Thesis_Jia_Jen_Final.pdf
    ├── Thesis_Jia_Jen_latex.zip
    └── Thesis_Jia_Jen_oral_defense.pdf
```

## 0. Requirements
- OS: Ubuntu 22.04

- GPU: 1× NVIDIA A6000 (48 GB VRAM)

## 1. Setup Environment
### Create environment
```bash
conda create -n reflection_trigger python=3.9
conda activate reflection_trigger
pip install -r requirements.txt
```
## 2. Datasets
We organize datasets into two reasoning domains, each with train/test splits stored in [`orig_dataset/`](orig_dataset/).

The Reflection Trigger pipeline is trained separately on **Commonsense Reasoning** and **Biomedical Reasoning** datasets.

- 🧠**Commonsense Reasoning**

    - ARC Challenge, CommonsenseQA

- 🩺**Biomedical Reasoning**

    - MedQA, MedMCQA, MMLU-Med (test only)

## 3. Training Data Construction
<img src="imgs/gen_train_data.jpg" alt="Data Construction" width="650"/>

The training dataset for **steering vectors** is built in three steps.

Step 1: Reflection Generation ([`gen_reflection.py`](reflection_trigger/gen_training_data/gen_reflection.py))
- Given a question, the model generates an initial answer and a reflection answer.
- Outputs are stored in `reflection_trigger/gen_training_data/output/step1_reflection_gen/`.

Step 2: Reflection Filtering ([`filter_reflection.py`](reflection_trigger/gen_training_data/filter_reflection.py))
- Keep only samples where the reflection answer matches the ground truth.
- Outputs are stored in `reflection_trigger/gen_training_data/output/step2_reflection_filtered/`.
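
The filtering rule in Step 2 can be sketched as follows. This is a minimal illustration, not the code in `filter_reflection.py`; the record fields `reflection_answer` and `ground_truth` are hypothetical names chosen for the example, and the JSON produced by Step 1 may use different keys.

```python
def filter_correct_reflections(samples):
    """Keep only samples whose reflection answer equals the ground truth.

    The field names are assumptions made for this sketch.
    """
    kept = []
    for sample in samples:
        # Normalize whitespace and case so that " d " matches "D".
        predicted = sample["reflection_answer"].strip().upper()
        gold = sample["ground_truth"].strip().upper()
        if predicted == gold:
            kept.append(sample)
    return kept


samples = [
    {"question": "Q1", "reflection_answer": "B", "ground_truth": "B"},
    {"question": "Q2", "reflection_answer": "a", "ground_truth": "C"},
    {"question": "Q3", "reflection_answer": " d ", "ground_truth": "D"},
]
filtered = filter_correct_reflections(samples)
print([s["question"] for s in filtered])
```

Only samples that survive this check would contribute a target vector in Step 3.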

Step 3: Vector Construction ([`gen_vectors.py`](reflection_trigger/gen_training_data/gen_vectors.py))
- Extract hidden states and compute the reflection vector as the difference between the reflective and non-reflective representations.
- Outputs are stored in `reflection_trigger/gen_training_data/output/step3_vectors_gen/`.
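
As a rough illustration of Step 3, the reflection vector can be viewed as a difference between two pooled hidden-state representations. The sketch below uses plain Python lists and mean pooling over tokens; both the pooling strategy and the function names are assumptions, and `gen_vectors.py` may extract and aggregate hidden states differently.

```python
def mean_pool(token_states):
    """Average per-token hidden states into one vector (pooling is an assumption)."""
    n_tokens = len(token_states)
    dim = len(token_states[0])
    return [sum(tok[d] for tok in token_states) / n_tokens for d in range(dim)]


def reflection_vector(reflective_states, non_reflective_states):
    """Difference between pooled reflective and non-reflective hidden states."""
    reflective = mean_pool(reflective_states)
    non_reflective = mean_pool(non_reflective_states)
    return [r - n for r, n in zip(reflective, non_reflective)]


# Toy hidden states: 2 tokens x 3 dimensions (a real LLM layer is much wider).
reflective_states = [[1.0, 2.0, 3.0], [3.0, 4.0, 5.0]]
non_reflective_states = [[0.0, 1.0, 1.0], [2.0, 1.0, 3.0]]
vector = reflection_vector(reflective_states, non_reflective_states)
print(vector)
```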

The entire **Training Data Construction** pipeline can be executed automatically with:
```bash
bash reflection_trigger/gen_training_data.sh
```

>📝Note: The provided [`gen_training_data.sh`](reflection_trigger/gen_training_data.sh) script is configured for a single dataset at a time. If you want to run it on another dataset, please modify the arguments inside the script (`INPUT_JSON`, `TASK`, `LAYER`, `GPU`) accordingly.

## 4. Model Training
<img src="imgs/train_model.jpg" alt="Model Training" width="300"/>

After building the reflection vector dataset (Step 3 in [Training Data Construction](#3-training-data-construction)), we train a BERT-based regressor to learn a mapping from the input question to its corresponding reflection steering vector.
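
To make the regression target concrete, the sketch below scores a predicted steering vector against its target with mean squared error. MSE is an assumption here, since the loss used in `train_trigger.py` is not stated in this README, and the vectors are toy values.

```python
def mse_loss(predicted, target):
    """Mean squared error between a predicted and a target steering vector."""
    if len(predicted) != len(target):
        raise ValueError("vectors must have the same dimension")
    return sum((p - t) ** 2 for p, t in zip(predicted, target)) / len(target)


# Toy 4-d vectors; in the real setup the dimension would match --output_dim.
target_vector = [0.5, -1.0, 0.0, 2.0]
predicted_vector = [0.0, -1.0, 1.0, 2.0]
loss = mse_loss(predicted_vector, target_vector)
print(loss)
```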

>💡For convenience, we also provide ready-to-use reflection vector datasets, so you can skip the data construction pipeline and directly start training. The training framework in Reflection Trigger supports two domains, each with its own training dataset.

📥https://drive.google.com/file/d/1wXSLQaVNxFcyutNbCDSMtWhdfvzyKqLg/view?usp=sharing

>Place the downloaded `.pt` files inside the following folder:

```bash
reflection_trigger/gen_training_data/output/step3_vectors_gen
├── cs_domain_train_steer_vec_16_v0.pt
└── med_domain_train_steer_vec_16_v0.pt
```
### Training Command
Run the following example command to start training:
```bash
python reflection_trigger/train_trigger.py --data reflection_trigger/gen_training_data/output/step3_vectors_gen/cs_domain_train_steer_vec_16_v0.pt --save_path reflection_trigger/output_model/steering_trigger_cs_domain_16.pt --epochs 20 --output_dim 4096
```
| Argument | Default | Description |
|---|---|---|
| `--data` | required | Path to the `.pt` reflection vector dataset (generated in [Step 3](#3-training-data-construction) or [provided preprocessed](reflection_trigger/gen_training_data/output/step3_vectors_gen/)). |
| `--save_path` | required  | Path to save the best trained checkpoint. |
| `--epochs` | `20` | Number of training epochs. |
| `--output_dim` | `4096` | Dimension of reflection steering vectors (depends on LLM hidden size). |

### Outputs
- Best checkpoint saved to [`reflection_trigger/output_model/`](reflection_trigger/output_model).
- Separate regressors are trained for each domain:

    - 🧠Commonsense Reasoning (ARC-C, CSQA)
    - 🩺Biomedical Reasoning (MedQA, MedMCQA, MMLU-Med)
    
## 5. LLM Reasoning & Activation Injection
After training the BERT regressor, we use it to predict reflection steering vectors for unseen questions. These vectors are then injected into the hidden activations of the LLM at a specific layer during inference.
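
Conceptually, the injection adds a scaled copy of the predicted vector to the hidden states at the chosen layer. The sketch below shows that arithmetic on plain Python lists. Applying the shift to every token position is an assumption of this example, and `inject_steering_vector` is a hypothetical helper, not a function from `main.py`.

```python
def inject_steering_vector(hidden_states, steering_vector, coeff):
    """Add coeff * steering_vector to the hidden state of every token.

    Shifting all token positions is an assumption of this sketch.
    """
    return [
        [h + coeff * v for h, v in zip(token_state, steering_vector)]
        for token_state in hidden_states
    ]


# Toy example: 2 tokens x 2 dimensions.
hidden_states = [[1.0, 0.0], [0.5, 2.0]]
steering_vector = [0.5, -1.0]
steered = inject_steering_vector(hidden_states, steering_vector, coeff=2.0)
print(steered)
```

A larger `coeff` moves the activations further along the reflection direction, which is what the `--coeff` argument controls.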

>💡 We provide ready-to-use BERT regressor checkpoints, so you can directly run inference without retraining.

📥 https://drive.google.com/file/d/1PXjvczOcpl-vjROu7j3qbevNR4DgaWeK/view?usp=sharing

>Place the downloaded `.pt` files inside the following folder:

```bash
reflection_trigger/output_model/
├── steering_trigger_cs_domain_16_v0.pt
└── steering_trigger_med_domain_16_v0.pt
```
>📝Note: File names with `cs_domain` correspond to Commonsense Reasoning tasks (ARC-C, CSQA), while `med_domain` corresponds to Biomedical Reasoning tasks (MedQA, MedMCQA, MMLU-Med). Please make sure to use the correct checkpoint for your target dataset.

>📝Note: To run **Cross-domain** evaluation, simply specify a checkpoint from the other domain in `--trigger`.

### Inference Command
Run the following example command to start inference with vector injection:

```bash
python main.py --input orig_dataset/test/csqa_val.json --model meta-llama/Llama-3.1-8B-Instruct --trigger reflection_trigger/output_model/steering_trigger_cs_domain_16.pt --layer 16 --coeff 1.0
```
| Argument | Default | Description |
|---|---|---|
| `--input` | required | Path to the test dataset JSON file (e.g., `csqa_val.json`, `medqa_test.json`, `arc_c_test.json`, `medmcqa_val.json`, `mmlu_test.json`). |
| `--output` | auto  | Path to save the inference results. If not provided, an output file will be automatically generated in `results_output/`. |
| `--model` | `meta-llama/Llama-3.1-8B-Instruct` | LLM used for reasoning. |
| `--trigger` | required | Path to the trained BERT regressor checkpoint (`.pt` file). |
| `--layer` | `16` | Transformer layer index in LLM where the reflection steering vector is injected. |
| `--coeff` | `1.0,2.0,3.0,5.0` | Comma-separated list of coefficients controlling the injection strength. |
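
One way to accept a comma-separated coefficient list on the command line is a custom `argparse` type, as in the sketch below. `parse_coeffs` is a hypothetical helper written for this example; the actual parsing in `main.py` may differ.

```python
import argparse


def parse_coeffs(text):
    """Turn a comma-separated string such as "1.0,2.0" into a list of floats."""
    return [float(part) for part in text.split(",") if part.strip()]


parser = argparse.ArgumentParser()
parser.add_argument("--layer", type=int, default=16)
parser.add_argument(
    "--coeff",
    type=parse_coeffs,
    default=parse_coeffs("1.0,2.0,3.0,5.0"),
)

# Simulate: python main.py --coeff 1.0,3.0
args = parser.parse_args(["--coeff", "1.0,3.0"])
print(args.layer, args.coeff)
```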

### Outputs
- After running inference, the results are automatically saved under [`results_output/`](results_output/).
- The file name is generated based on the dataset name and layer index: `{dataset_name}_infer_layer{layer}.json`
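
The naming pattern above can be reproduced with a few lines of `pathlib`. This is a sketch under the stated pattern only; `default_output_path` is a hypothetical helper, not the function used in `main.py`.

```python
from pathlib import PurePosixPath


def default_output_path(input_path, layer, output_dir="results_output"):
    """Build `{dataset_name}_infer_layer{layer}.json` inside the output folder."""
    dataset_name = PurePosixPath(input_path).stem  # "csqa_val.json" -> "csqa_val"
    return str(PurePosixPath(output_dir) / f"{dataset_name}_infer_layer{layer}.json")


path = default_output_path("orig_dataset/test/csqa_val.json", layer=16)
print(path)
```

The provided sample outputs, such as `arc_c_test_infer_v0_layer16.json`, carry an extra `v0` tag that this sketch does not reproduce.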

## 6. Evaluation
After inference is completed, we provide a simple evaluation script to check how reflection steering affects QA accuracy.

### Example Command
```bash 
python evaluate.py --input results_output/arc_c_test_infer_v0_layer16.json
```
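
For readers who want to see what such an accuracy check might look like, the sketch below computes accuracy per injection coefficient. The record layout (an `answer` field plus a `predictions` mapping from coefficient to predicted option) is entirely hypothetical; inspect the JSON files in `results_output/` for the real schema used by `evaluate.py`.

```python
from collections import defaultdict


def accuracy_by_coeff(records):
    """Return {coeff: accuracy} over all records.

    Assumed (hypothetical) record layout:
    {"answer": "B", "predictions": {"0.0": "A", "1.0": "B"}}
    """
    correct = defaultdict(int)
    total = defaultdict(int)
    for record in records:
        for coeff, predicted in record["predictions"].items():
            total[coeff] += 1
            correct[coeff] += int(predicted == record["answer"])
    return {coeff: correct[coeff] / total[coeff] for coeff in total}


records = [
    {"answer": "B", "predictions": {"0.0": "A", "1.0": "B"}},
    {"answer": "C", "predictions": {"0.0": "C", "1.0": "C"}},
]
scores = accuracy_by_coeff(records)
print(scores)
```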

>💡We also provide ready-to-use inference outputs (e.g., `results_output/arc_c_test_infer_v0_layer16.json`), so you can run the evaluation script directly without first running inference.

## 7. Analysis
We provide multiple analysis scripts under the [`analysis/`](analysis/) folder to help evaluate and visualize the effects of reflection steering.

>💡We also provide ready-to-use inference outputs, so you can run the analysis directly without re-generating inference results: `results_output/arc_c_test_infer_v0_layer16.json`, `results_output/csqa_val_infer_v0_layer16.json`, `results_output/mmlu_test_infer_v0_layer16.json`

- Analysis 5.1: Comparison with Prompt-based Methods

    - Counts the number of tokens in generated answers to measure efficiency.
    ```bash
    python analysis/token_count.py
    ```

- Analysis 5.2: Parameter Sensitivity
    
    - Evaluates how steering parameters (`layer`, `coeff`) affect QA performance.
    ```bash
    python analysis/param_sensitivity.py
    ```

- Analysis 5.3: Analysis of Training Data Efficiency
    
    - We analyze the impact of training data volume on the development of reflective capability.
    ```bash
    python analysis/train_efficiency.py
    ```

- Analysis 5.4: Impact of Reflection Intensity on Task Difficulty
    
    - Studies the impact of reflection strength (coefficient scaling) on task difficulty.
    ```bash
    python analysis/reflection_intensity_analysis.py
    ```

- Analysis 5.5: Visualization of Reflective Representations
    
    - Uses UMAP to visualize hidden state embeddings of reflective vs. non-reflective responses.
    ```bash
    python analysis/reflective_vis.py
    ```

- All analysis figures and processed results are saved in [`analysis/output/`](analysis/output).