# 🏥 Self-Improving Medical Visual Question Answering through Reasoning Trajectory Clustering

<div align="center">

[![ICLR 2026](https://img.shields.io/badge/Conference-ICLR%202026-blue)](https://iclr.cc/)
[![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)
[![Python](https://img.shields.io/badge/Python-3.8%2B-green)](https://www.python.org/)
[![PyTorch](https://img.shields.io/badge/PyTorch-2.0%2B-red)](https://pytorch.org/)
[![Paper](https://img.shields.io/badge/Paper-Under%20Review-orange)]()

**Official Implementation** | [📄 Paper (Under Review)]() | [💾 Datasets]() | [🤗 Models]()

</div>

---

## 📋 Abstract

While large language models have shown promise in medical applications, their performance on medical visual question answering (VQA) remains limited by insufficient vision-language reasoning capabilities. We address this challenge through two complementary approaches. First, we generate high-quality reasoning annotations for existing medical VQA datasets using the COMCTS (Collective Monte Carlo Tree Search) algorithm. Second, we introduce a self-improvement framework that bootstraps model performance by learning from its own outputs, guided by a small set of high-quality reasoning samples. To optimize this self-improvement process, we propose a novel filtering mechanism based on K-medoids clustering of reasoning trajectories, which uses Dynamic Time Warping (DTW) distances to select the most effective generated reasoning paths. Our approach improves mean accuracy across five medical VQA benchmarks. We release both the COMCTS-generated reasoning datasets and our code to support future research.

## 🎯 Overview

<div align="center">
  <img src="https://img.shields.io/badge/🔬-Medical%20VQA-brightgreen?style=for-the-badge" alt="Medical VQA"/>
  <img src="https://img.shields.io/badge/🤖-Self%20Improvement-blue?style=for-the-badge" alt="Self Improvement"/>
  <img src="https://img.shields.io/badge/📊-DTW%20Clustering-purple?style=for-the-badge" alt="DTW Clustering"/>
  <img src="https://img.shields.io/badge/🧠-Reasoning%20Paths-orange?style=for-the-badge" alt="Reasoning"/>
</div>

<br/>

This project focuses on medical visual question answering with the following key components:

| Component | Description |
|-----------|-------------|
| 🔧 **COMCTS-based reasoning** | High-quality reasoning dataset generation for medical VQA |
| 🚀 **Self-improvement framework** | Learning from the model's own outputs |
| 🎯 **Novel filtering mechanism** | DTW-based trajectory clustering |
| 📊 **Comprehensive evaluation** | Five medical VQA datasets |

## 📚 Datasets

We work with five medical VQA datasets:

| Dataset | Domain | Icon |
|---------|--------|------|
| **VQA-RAD** | Radiology visual question answering | 🩻 |
| **Slake-VQA** | Semantically-labeled knowledge-enhanced dataset | 🏷️ |
| **Path-VQA** | Pathology visual question answering | 🔬 |
| **PMC-VQA** | PubMed Central visual question answering | 📖 |
| **Omnimed-VQA** | Comprehensive medical visual question answering | 🏥 |

### 💾 Self-Improvement Datasets

> **📢 Note:** Upon acceptance, these files will be available for download from [Google Drive](your-google-drive-link-here)

```
self_imp_processed_files/
├── 📄 DS-R1-Qwen-1.5B-{dataset}-COMCTS-SelfImprove-Inference_self_imp_processed_combined.json
└── 📄 DS-R1-Llama-8B-{dataset}-COMCTS-SelfImprove-Inference_self_imp_processed_combined.json
```
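
The `{dataset}` placeholder is filled with a dataset identifier. A minimal sketch for building the expected paths, assuming the identifiers match the `--dataset` values used in the commands below (e.g. `omnimed-vqa`) and that each file holds JSON; `self_imp_path` and `load_self_imp` are hypothetical helpers, not part of the repository.

```python
import json
from pathlib import Path

# Model prefixes taken from the file names listed above.
MODEL_PREFIXES = ("DS-R1-Qwen-1.5B", "DS-R1-Llama-8B")


def self_imp_path(model: str, dataset: str, root: str = "self_imp_processed_files") -> Path:
    """Build the path of a combined self-improvement file (hypothetical helper)."""
    if model not in MODEL_PREFIXES:
        raise ValueError(f"unknown model prefix: {model}")
    name = f"{model}-{dataset}-COMCTS-SelfImprove-Inference_self_imp_processed_combined.json"
    return Path(root) / name


def load_self_imp(model: str, dataset: str, root: str = "self_imp_processed_files"):
    """Load one file; its internal schema is not documented here (assumption: plain JSON)."""
    with open(self_imp_path(model, dataset, root), encoding="utf-8") as f:
        return json.load(f)


print(self_imp_path("DS-R1-Llama-8B", "omnimed-vqa"))
```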

## 🗂️ Repository Structure

### ⚙️ Core Components

| File | Purpose |
|------|---------|
| 🎯 **`train.py`** | Main training script for all model configurations |
| 🔮 **`inference.py`** | Inference script for generating model outputs |
| 📚 **`dataset_registry.py`** | Dataset management utilities |
| 📝 **`dataset_registry_instruct.py`** | Instruction-based dataset utilities |

### 🏗️ Model Architecture

```
files/model/
├── 🧱 blocks.py          # Transformer building blocks
├── 💬 llm.py             # Language model components
├── 🔍 query_decoder.py   # Query decoding mechanisms
├── 🤖 transformer.py     # Transformer architecture
├── 👁️ vision_model.py    # Vision encoder components
├── 🔗 vl_model.py        # Complete vision-language model
├── ⚡ plmodel.py         # PyTorch Lightning wrapper
├── 📝 vqa_eval.py        # Text generation evaluation
└── 📊 vqa_eval_acc.py    # Accuracy evaluation metrics
```

### 🔄 Self-Improvement Pipeline

| Component | File | Function |
|-----------|------|----------|
| 🎯 **Feature Extraction** | `self_imp_reas_features_parallel.py` | Extract reasoning trajectory features and compute DTW distances |
| ✅ **Answer Verification** | `self_imp_reas_answer_check_parallel.py` | Verify answer correctness using an LLM API |
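
DTW compares sequences of possibly different lengths, so each generated reasoning output must first become a sequence of per-step vectors (a trajectory). The feature extractor used by `self_imp_reas_features_parallel.py` is not documented here; the sketch below swaps in a deliberately simple stand-in (sentence/line splitting plus hashed bag-of-words vectors) only to show the shape of a trajectory. `reasoning_to_trajectory`, `step_vector`, and `DIM` are hypothetical.

```python
from __future__ import annotations

import re
import zlib

DIM = 64  # hypothetical feature size per reasoning step


def step_vector(step: str, dim: int = DIM) -> list[float]:
    """Hash the words of one reasoning step into an L2-normalised bag-of-words vector."""
    vec = [0.0] * dim
    for word in re.findall(r"[a-z0-9]+", step.lower()):
        vec[zlib.crc32(word.encode()) % dim] += 1.0  # crc32 is stable across runs
    norm = sum(v * v for v in vec) ** 0.5
    return [v / norm for v in vec] if norm else vec


def reasoning_to_trajectory(text: str) -> list[list[float]]:
    """Split a reasoning trace into steps (sentences or lines) and embed each step."""
    steps = [s.strip() for s in re.split(r"(?<=[.!?])\s+|\n+", text) if s.strip()]
    return [step_vector(s) for s in steps]


trace = "The image is a chest X-ray. The left lung shows an opacity. Answer: yes."
traj = reasoning_to_trajectory(trace)
print(len(traj), len(traj[0]))  # 3 steps, 64 features each
```

In practice a sentence-embedding model or the model's own hidden states would give more meaningful step features; hashing is used here only to keep the example dependency-free.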

### 📁 Generated Datasets

```
reasoning_datasets/          # 🧠 COMCTS-generated reasoning datasets
├── 📄 rad-vqa_train.json
├── 📄 rad-vqa_valid.json
└── ...

inference_json_files/        # 💾 Model inference outputs
├── 📄 {model_name}-{dataset}-inference.json
└── ...
```

## 🛠️ Installation

```bash
# 📥 Clone the repository
git clone [repository-url]

# 📦 Install dependencies
pip install -r requirements.txt
```

## 🚀 Usage

### 1️⃣ Training Models

#### 🎯 Baseline Training Example (Original Dataset Only) 
```bash
python train.py \
    --max_epochs 25 \
    --model_name deepseek-ai/DeepSeek-R1-Distill-Llama-8B \
    --seq_length 64 \
    --experiment_name DS-R1-Llama-8B-omnimed-vqa \
    --batch_size 18 \
    --accumulation 8 \
    --dataset_list omnimed-vqa \
    --project wandb-project-name
```
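
Assuming `--accumulation` sets the number of gradient-accumulation steps (for example forwarded to PyTorch Lightning's `Trainer(accumulate_grad_batches=...)`, which the `plmodel.py` wrapper suggests but this README does not confirm), each optimizer update sees `batch_size × accumulation × num_devices` samples. The device count is an assumption; the helper name is hypothetical.

```python
def effective_batch_size(batch_size: int, accumulation: int, num_devices: int = 1) -> int:
    """Samples contributing to one optimizer step under data-parallel gradient accumulation."""
    return batch_size * accumulation * num_devices


print(effective_batch_size(18, 8))                 # 144 on a single GPU
print(effective_batch_size(18, 8, num_devices=4))  # 576 if run on four GPUs
```

When changing `--batch_size` to fit memory, scaling `--accumulation` inversely keeps the effective batch size, and thus the optimization dynamics, roughly unchanged.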

#### 🧠 COMCTS Reasoning Dataset Training Example
```bash
python train.py \
    --max_epochs 25 \
    --model_name deepseek-ai/DeepSeek-R1-Distill-Llama-8B \
    --seq_length 64 \
    --experiment_name DS-R1-Llama-8B-omnimed-vqa-COMCTS \
    --batch_size 18 \
    --accumulation 8 \
    --dataset_list omnimed-vqa-COMCTS \
    --project wandb-project-name
```

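#### 🚀 Self-Improvement Training Example
This example assumes the filtered self-improvement dataset has already been produced by the Self-Improvement Pipeline (step 3️⃣ below).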
```bash
python train.py \
    --max_epochs 25 \
    --model_name deepseek-ai/DeepSeek-R1-Distill-Llama-8B \
    --seq_length 64 \
    --experiment_name DS-R1-Llama-8B-omnimed-vqa-COMCTS-SelfImp \
    --batch_size 18 \
    --accumulation 8 \
    --dataset_list omnimed-vqa-COMCTS-SelfImp-iscorrect-Kmedoid_dtw_distances_ncluster_10_ratio_50 \
    --project wandb-project-name
```
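
The `--dataset_list` value above appears to encode the clustering settings of the self-improvement pipeline (`--n_clusters 10` and `--filter_ratio 0.5` map to `ncluster_10_ratio_50`). The helper below reproduces that naming pattern as inferred from this single example; `selfimp_dataset_name` is hypothetical, and other combinations may not be registered in `dataset_registry.py`.

```python
def selfimp_dataset_name(dataset: str, n_clusters: int, filter_ratio: float) -> str:
    """Compose a self-improvement dataset name using the pattern inferred from this README."""
    if not 0.0 < filter_ratio <= 1.0:
        raise ValueError("filter_ratio must be in (0, 1]")
    ratio_pct = round(filter_ratio * 100)  # 0.5 -> 50
    return (f"{dataset}-COMCTS-SelfImp-iscorrect-"
            f"Kmedoid_dtw_distances_ncluster_{n_clusters}_ratio_{ratio_pct}")


print(selfimp_dataset_name("omnimed-vqa", 10, 0.5))
```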

### 2️⃣ Running Inference

```bash
python inference.py \
    --max_epochs 8 \
    --model_name deepseek-ai/DeepSeek-R1-Distill-Llama-8B \
    --seq_length 64 \
    --experiment_name DS-R1-Llama-8B-omnimed-vqa-inf \
    --batch_size 18 \
    --dataset_list omnimed-vqa \
    --project wandb-project-name \
    --ckpt_path /path/to/checkpoint.ckpt
```

### 3️⃣ Self-Improvement Pipeline

#### ✅ Answer Correctness Check
```bash
python self_imp_reas_answer_check_parallel.py \
    --experiment_name DS-R1-Llama-8B-omnimed-vqa-COMCTS-SelfImprove-Inference \
    --dataset omnimed-vqa
```
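
`self_imp_reas_answer_check_parallel.py` judges correctness with an LLM API. For closed-ended questions, a cheap normalized exact-match comparison can serve as a sanity check on the judge's verdicts. The sketch below is such a swapped-in heuristic, not the script's method, and it assumes (hypothetically) that the model ends its reasoning with an `Answer:` marker.

```python
from __future__ import annotations

import re
import string


def extract_answer(output: str) -> str | None:
    """Return the text after the last 'Answer:' marker (assumed output format)."""
    matches = re.findall(r"answer\s*:\s*(.+)", output, flags=re.IGNORECASE)
    return matches[-1] if matches else None


def normalize(text: str) -> str:
    """Lower-case, strip punctuation and collapse whitespace."""
    text = text.lower().translate(str.maketrans("", "", string.punctuation))
    return " ".join(text.split())


def exact_match(output: str, ground_truth: str) -> bool:
    """True if the extracted answer equals the ground truth after normalisation."""
    pred = extract_answer(output)
    return pred is not None and normalize(pred) == normalize(ground_truth)


print(exact_match("The opacity is in the left lung. Answer: Yes.", "yes"))  # True
print(exact_match("No effusion is visible. Answer: no", "yes"))             # False
```

Exact matching misses paraphrased open-ended answers (e.g. "left lower lobe" vs "LLL"), which is the usual motivation for an LLM-based judge.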

#### 📊 Reasoning Trajectory Clustering
```bash
python self_imp_reas_features_parallel.py \
    --experiment_name DS-R1-Llama-8B-omnimed-vqa-COMCTS \
    --dataset omnimed-vqa \
    --n_clusters 10 \
    --filter_ratio 0.5
```

## 🤖 Model Configurations

We support two main model configurations:

| Model | Configuration | Icon |
|-------|--------------|------|
| **DS-R1-Qwen-1.5B** | DeepSeek-R1-Distill-Qwen-1.5B + ViT-B/16 | 🟢 |
| **DS-R1-Llama-8B** | DeepSeek-R1-Distill-Llama-8B + ViT-B/16 | 🔵 |

## 📈 Training Strategies

Our framework implements four training strategies:

| Strategy | Description | Icon |
|----------|-------------|------|
| **M-base** | Training only on the original dataset | 🎯 |
| **M-reas** | Training on the COMCTS reasoning dataset, then fine-tuning on the original dataset | 🧠 |
| **M-SI-ca** | Training on the reasoning dataset plus self-generated samples with correct answers, then fine-tuning on the original dataset | ✅ |
| **M-SI-ca+cl** | Training on the reasoning dataset plus correct self-generated samples further filtered by DTW trajectory clustering, then fine-tuning on the original dataset | 🚀 |

## 📊 Results

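Our method (M-SI-ca+cl) achieves the highest mean accuracy and the best scores on VQA-RAD, Slake-VQA, and Omnimed-VQA, while M-base remains strongest on Path-VQA and PMC-VQA: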

| Model | Dataset | M-base | M-reas | M-SI-ca | M-SI-ca+cl (Ours) 🏆 |
|-------|---------|--------|--------|---------|-------------------|
| DS-R1-Llama-8B | 🩻 VQA-RAD | 37.82 | 41.76 | 41.53 | **42.46** |
| DS-R1-Llama-8B | 🏷️ Slake-VQA | 68.0 | 67.2 | 65.2 | **70.8** |
| DS-R1-Llama-8B | 🔬 Path-VQA | **50.6** | 48.2 | 49.4 | 47.2 |
| DS-R1-Llama-8B | 🏥 Omnimed-VQA | 46.2 | 47.2 | 45.8 | **47.4** |
| DS-R1-Llama-8B | 📖 PMC-VQA | **5.6** | 4.4 | 3.4 | 3.6 |
| **📈 Mean** | | 39.57 | 39.61 | 39.82 | **40.59** ✨ |

> 📌 Accuracy (%) on the original (raw) medical VQA datasets.
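
For reference, the unweighted per-strategy average over the five DS-R1-Llama-8B rows can be recomputed directly from the table. These values (41.64, 41.75, 41.07, 42.29) do not coincide with the reported 📈 Mean row, which presumably aggregates a different set of results; the snippet only averages the numbers shown here.

```python
from statistics import mean

# Accuracy values copied from the results table (DS-R1-Llama-8B).
results = {
    "VQA-RAD":     {"M-base": 37.82, "M-reas": 41.76, "M-SI-ca": 41.53, "M-SI-ca+cl": 42.46},
    "Slake-VQA":   {"M-base": 68.0,  "M-reas": 67.2,  "M-SI-ca": 65.2,  "M-SI-ca+cl": 70.8},
    "Path-VQA":    {"M-base": 50.6,  "M-reas": 48.2,  "M-SI-ca": 49.4,  "M-SI-ca+cl": 47.2},
    "Omnimed-VQA": {"M-base": 46.2,  "M-reas": 47.2,  "M-SI-ca": 45.8,  "M-SI-ca+cl": 47.4},
    "PMC-VQA":     {"M-base": 5.6,   "M-reas": 4.4,   "M-SI-ca": 3.4,   "M-SI-ca+cl": 3.6},
}

strategies = ["M-base", "M-reas", "M-SI-ca", "M-SI-ca+cl"]
averages = {s: mean(r[s] for r in results.values()) for s in strategies}
for s, avg in averages.items():
    print(f"{s:<11} {avg:.2f}")
```

Even on this unweighted average, M-SI-ca+cl remains the best strategy.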

## ✨ Key Features

### 1️⃣ COMCTS Reasoning Generation
- 🤖 Uses Gemma-3-27B and Qwen2-VL-7B for reasoning path generation
- 🧠 Uses a DeepSeek-R1 model for reasoning verification
- 📝 Generates high-quality chain-of-thought annotations

### 2️⃣ Self-Improvement Framework
- 🔄 Learns from the model's own generated outputs
- 🎯 Two-stage filtering: correctness + trajectory clustering
- 📊 DTW-based similarity computation for reasoning paths
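
DTW aligns two reasoning trajectories of different lengths by finding the cheapest monotonic alignment of their steps. A minimal textbook sketch, assuming each trajectory is a list of per-step feature vectors compared with Euclidean distance; the repository's step features and local cost may differ, and `dtw_distance` is a hypothetical name.

```python
from __future__ import annotations

import math


def euclidean(a: list[float], b: list[float]) -> float:
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))


def dtw_distance(s: list[list[float]], t: list[list[float]]) -> float:
    """Classic O(len(s) * len(t)) dynamic time warping distance."""
    n, m = len(s), len(t)
    inf = float("inf")
    # cost[i][j] = cheapest cumulative cost of aligning s[:i] with t[:j]
    cost = [[inf] * (m + 1) for _ in range(n + 1)]
    cost[0][0] = 0.0
    for i in range(1, n + 1):
        for j in range(1, m + 1):
            d = euclidean(s[i - 1], t[j - 1])
            cost[i][j] = d + min(cost[i - 1][j],      # advance in s only
                                 cost[i][j - 1],      # advance in t only
                                 cost[i - 1][j - 1])  # advance in both
    return cost[n][m]


a = [[0.0], [1.0], [2.0]]
b = [[0.0], [1.0], [1.0], [2.0]]
print(dtw_distance(a, b))  # 0.0: b is a time-stretched copy of a
```

Because a repeated step costs nothing when it matches, DTW treats a more verbose but structurally identical reasoning path as close to a concise one, which plain step-by-step distances would not.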

### 3️⃣ Trajectory Clustering
- 📈 K-medoids clustering with DTW distances
- 🗑️ Filters noisy reasoning paths
- ✨ Preserves high-quality reasoning trajectories
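
Given a precomputed matrix of pairwise DTW distances, K-medoids can run directly on distances because every medoid is an actual sample. The sketch below uses a simple alternating (Voronoi-iteration) K-medoids and then keeps, within each cluster, the fraction of samples closest to the medoid. That keep-nearest-to-medoid rule is an assumption about how `--filter_ratio` is applied, not a description of `self_imp_reas_features_parallel.py`; `kmedoids` and `filter_by_clusters` are hypothetical names, and a toy 1-D distance matrix stands in for DTW distances.

```python
from __future__ import annotations

import math
import random


def kmedoids(dist: list[list[float]], k: int, n_iter: int = 100, seed: int = 0):
    """Alternating K-medoids on a symmetric distance matrix; returns (medoids, labels)."""
    n = len(dist)
    medoids = random.Random(seed).sample(range(n), k)
    for _ in range(n_iter):
        # Assignment step: each sample joins its nearest medoid.
        labels = [min(range(k), key=lambda c: dist[i][medoids[c]]) for i in range(n)]
        # Update step: the new medoid minimises total distance within its cluster.
        new_medoids = []
        for c in range(k):
            members = [i for i in range(n) if labels[i] == c] or [medoids[c]]
            new_medoids.append(min(members, key=lambda m: sum(dist[m][j] for j in members)))
        if new_medoids == medoids:
            break
        medoids = new_medoids
    labels = [min(range(k), key=lambda c: dist[i][medoids[c]]) for i in range(n)]
    return medoids, labels


def filter_by_clusters(dist, medoids, labels, ratio: float) -> list[int]:
    """Keep the ceil(ratio * size) members of each cluster closest to its medoid (assumed rule)."""
    kept = []
    for c, m in enumerate(medoids):
        members = sorted((i for i, lab in enumerate(labels) if lab == c), key=lambda i: dist[i][m])
        kept.extend(members[: math.ceil(ratio * len(members))])
    return sorted(kept)


positions = [0.0, 1.0, 2.0, 10.0, 11.0, 12.0]  # toy stand-ins for six trajectories
dist = [[abs(p - q) for q in positions] for p in positions]
medoids, labels = kmedoids(dist, k=2)
kept = filter_by_clusters(dist, medoids, labels, ratio=0.5)
print(sorted(medoids), kept)  # [1, 4] [0, 1, 3, 4]
```

For larger sample counts, an optimized implementation that accepts a precomputed distance matrix would replace this quadratic-per-iteration loop.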

## 📝 Citation

If you use this code or datasets in your research, please cite our paper:

```bibtex
@article{anonymous2026selfimproving,
  title={Self-Improving Medical Visual Question Answering through Reasoning Trajectory Clustering},
  author={Anonymous Authors},
  journal={Under review at ICLR 2026},
  year={2026}
}
```

## 🔗 Related Work

This project builds upon several key papers:

| Project | Links |
|---------|-------|
| **PMC-CLIP** | [📄 Paper](https://arxiv.org/abs/2303.07240) \| [💻 Code](https://github.com/WeixiongLin/PMC-CLIP/) |
| **PMC-LLaMA** | [📄 Paper](https://arxiv.org/abs/2304.14454) \| [💻 Code](https://github.com/chaoyi-wu/PMC-LLaMA) |
| **COMCTS (Mulberry)** | [📄 Paper](https://arxiv.org/abs/2412.18319) \| [💻 Code](https://github.com/HJYao00/Mulberry) |

## 📜 License

This project is released under the MIT License. See [LICENSE](LICENSE) for details.

## 📧 Contact

For questions about the code or paper, please open an issue in this repository.

---

<div align="center">

**📢 Note**: This repository contains the implementation for a paper under review. Some details may be updated upon publication.

⭐ **Star this repository if you find it helpful!** ⭐

</div>