# EHR-ChatQA

A benchmark framework for evaluating Large Language Model (LLM) agents on Electronic Health Record (EHR) database question-answering (QA) tasks. This project assesses agent performance on two EHR databases (MIMIC-IV* and eICU*) across incremental and adaptive QA scenarios.

## 📋 Overview

EHR-ChatQA evaluates LLM agents' ability to:
- Navigate and query complex EHR databases
- Utilize various search and retrieval tools
- Handle incremental refinement tasks (`incre`)
- Handle adaptive refinement tasks (`adapt`)

## 🏗️ Project Structure

```
EHR-ChatQA-Private/
├── README.md
├── run.py                      # Main entry point for running experiments
├── validator.py                # Validator for user simulation
├── results/                    # Experiment results
├── src/
│   ├── agent_factory.py        # Agent initialization
│   ├── agents/                 # Agent implementations
│   │   ├── base.py
│   │   ├── tool_calling_agent.py           # Full tool access
│   │   ├── tool_calling_agent_no_tool.py   # No web search & no value search
│   │   └── tool_calling_agent_no_web.py    # No web search
│   ├── envs/                   # Environment configurations
│   │   ├── base.py
│   │   ├── user.py             # User simulator
│   │   ├── rules.py
│   │   ├── mimic_iv_star/      # MIMIC-IV* database environment
│   │   │   ├── env.py
│   │   │   ├── mimic_iv_star.sqlite
│   │   │   ├── eval_incre.jsonl
│   │   │   ├── eval_adapt.jsonl
│   │   │   └── tools/
│   │   └── eicu_star/          # eICU* database environment
│   │       ├── env.py
│   │       ├── eicu_star.sqlite
│   │       ├── eval_incre.jsonl
│   │       ├── eval_adapt.jsonl
│   │       └── tools/
│   ├── types.py                # Type definitions
│   └── utils.py                # Utility functions
└── .env                        # API keys and environment variables
```

## 🚀 Installation

### Prerequisites

- Python 3.8+
- CUDA-compatible GPUs (for local model deployment)
- API keys for commercial models (OpenAI, Anthropic, Google, etc.)

### Setup

1. Clone the repository:
```bash
git clone <repository-url>
cd EHR-ChatQA-Private
```

2. Install dependencies:
```bash
pip install -r requirements.txt
```

3. Configure environment variables in `.env`:
```bash
# API Keys
OPENAI_API_KEY=your_openai_api_key # used for OpenAI models and for the default embedding model in value similarity search (text-embedding-3-large)
GOOGLE_API_KEY=your_google_api_key # needed for Gemini models, including the default user simulator and validation models
TAVILY_API_KEY=your_tavily_api_key # only needed for web search
```

## 📖 Usage

### Command Line Arguments

| Argument | Type | Required | Description |
|----------|------|----------|-------------|
| `--env` | str | Yes | Environment: `mimic_iv_star`, `eicu_star`, or `all` |
| `--task_type` | str | Yes | Task type: `incre` (incremental) or `adapt` (adaptive) |
| `--model` | str | Yes | Model name (e.g., `gpt-5`, `Qwen/Qwen3-32B`) |
| `--agent_strategy` | str | Yes | Agent strategy: `tool-calling`, `tool-calling-no-tool`, or `tool-calling-no-web` |
| `--api_base` | str | No | Custom API base URL (required for local models) |
| `--temperature` | float | No | Sampling temperature for agent (default: 0.0) |
| `--num_trials` | int | No | Number of trials per task (default: 5) |
| `--max_concurrency` | int | No | Maximum concurrent tasks (default: 1) |
| `--max_agent_turns` | int | No | Maximum agent turns per task (default: 30) |
| `--timeout` | int | No | Timeout per task in seconds (default: 600) |
| `--user_model` | str | No | User simulator model (default: `gemini/gemini-2.0-flash`) |
| `--user_strategy` | str | No | User strategy: `llm`, `react`, `verifier`, `reflection`, or `hierreflection` (default: `hierreflection`) |
| `--validation_model` | str | No | Validation model (default: `gemini/gemini-2.5-flash`) |
| `--verbose` | flag | No | Print detailed conversations during execution |
| `--start_index` | int | No | Start of the task index range to run (see Running Specific Tasks) |
| `--end_index` | int | No | End of the task index range to run (see Running Specific Tasks) |
| `--task_ids` | int (one or more) | No | Specific task IDs to run (see Running Specific Tasks) |

### Agent Strategies

- **`tool-calling`**: Full access to all tools (`table_search`, `column_search`, `sql_execute`, `value_substring_search`, `value_similarity_search`, `web_search`)
- **`tool-calling-no-tool`**: Basic tools only, with no value search and no web search (`table_search`, `column_search`, `sql_execute`)
- **`tool-calling-no-web`**: All tools except web search (`table_search`, `column_search`, `sql_execute`, `value_substring_search`, `value_similarity_search`)

### Available Tools

1. **`table_search`**: Search for relevant tables in the database
2. **`column_search`**: Search for columns within specific tables
3. **`sql_execute`**: Execute SQL queries on the database
4. **`value_substring_search`**: Search for values using substring matching
5. **`value_similarity_search`**: Search for values using semantic similarity
6. **`web_search`**: Search the web for additional information
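
The relationship between agent strategies and tools can be summarized as a lookup table. The sketch below is illustrative only: the tool and strategy names come from the lists above, while `STRATEGY_TOOLS` and `allowed_tools` are hypothetical names and do not necessarily match how the repository wires tools to agents.

```python
from typing import Dict, List

# Tool names are taken from the "Available Tools" list. The grouping and the
# dictionary layout are assumptions made for illustration.
BASIC_TOOLS = ["table_search", "column_search", "sql_execute"]
VALUE_TOOLS = ["value_substring_search", "value_similarity_search"]
WEB_TOOLS = ["web_search"]

STRATEGY_TOOLS: Dict[str, List[str]] = {
    "tool-calling": BASIC_TOOLS + VALUE_TOOLS + WEB_TOOLS,
    "tool-calling-no-web": BASIC_TOOLS + VALUE_TOOLS,
    "tool-calling-no-tool": BASIC_TOOLS,
}


def allowed_tools(agent_strategy: str) -> List[str]:
    """Return the tool names available under a given --agent_strategy value."""
    try:
        return list(STRATEGY_TOOLS[agent_strategy])
    except KeyError:
        raise ValueError(f"Unknown agent strategy: {agent_strategy!r}") from None
```

Comparing `allowed_tools("tool-calling")` with the other two strategies shows which capabilities each ablation removes.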

## 🔬 Example Usage

### GPT-5 Evaluation

```bash
# Incremental QA tasks
python run.py --env all \
    --task_type incre \
    --model gpt-5 \
    --agent_strategy tool-calling \
    --num_trials 1 \
    --max_concurrency 8

# Adaptive QA tasks
python run.py --env all \
    --task_type adapt \
    --model gpt-5 \
    --agent_strategy tool-calling \
    --num_trials 5 \
    --max_concurrency 16
```

### Qwen 3 32B (Local Deployment)

**Step 1: Start vLLM Server**
```bash
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -m vllm.entrypoints.openai.api_server \
    --model Qwen/Qwen3-32B \
    --load-format safetensors \
    --max-model-len 32768 \
    --tensor-parallel-size 8 \
    --port 8001 \
    --enable-auto-tool-choice \
    --tool-call-parser llama3_json \
    --gpu-memory-utilization 0.85 \
    --download-dir /nfs_data_storage/huggingface
```

**Step 2: Run Evaluation**
```bash
# Incremental QA tasks
python run.py --env all \
    --model Qwen/Qwen3-32B \
    --api_base http://localhost:8001/v1 \
    --task_type incre \
    --agent_strategy tool-calling \
    --temperature 0.0 \
    --num_trials 5 \
    --max_concurrency 32

# Adaptive QA tasks
python run.py --env all \
    --model Qwen/Qwen3-32B \
    --api_base http://localhost:8001/v1 \
    --task_type adapt \
    --agent_strategy tool-calling \
    --temperature 0.0 \
    --num_trials 5 \
    --max_concurrency 32
```
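
When several environment and task-type combinations need to be run, the commands above can be generated from a small script. The following is a minimal sketch that only uses flags documented in this README; `build_command` is a hypothetical helper, not part of the repository, and the launch line is left commented out.

```python
import itertools
import shlex
import sys
from typing import List, Optional


def build_command(
    env: str,
    task_type: str,
    model: str,
    agent_strategy: str = "tool-calling",
    api_base: Optional[str] = None,
    num_trials: int = 5,
    max_concurrency: int = 1,
) -> List[str]:
    """Build an argument list for run.py from documented CLI flags."""
    cmd = [
        sys.executable, "run.py",
        "--env", env,
        "--task_type", task_type,
        "--model", model,
        "--agent_strategy", agent_strategy,
        "--num_trials", str(num_trials),
        "--max_concurrency", str(max_concurrency),
    ]
    if api_base is not None:
        # Only passed for locally served models (e.g. a vLLM endpoint).
        cmd += ["--api_base", api_base]
    return cmd


if __name__ == "__main__":
    envs = ["mimic_iv_star", "eicu_star"]
    task_types = ["incre", "adapt"]
    for env, task_type in itertools.product(envs, task_types):
        cmd = build_command(
            env, task_type, "Qwen/Qwen3-32B",
            api_base="http://localhost:8001/v1", max_concurrency=32,
        )
        print(shlex.join(cmd))
        # import subprocess; subprocess.run(cmd, check=True)  # uncomment to launch
```

Printing the commands first makes it easy to review a sweep before spending API budget or GPU time on it.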

## 📊 Results

Results are saved in the `results/` directory with the following naming convention:
```
{db_id}-{task_type}-{agent_strategy}-{model}-{model_temperature}_k={num_trial}_range_{task_id_range}_user-{user_strategy}-{user_model}-{user_temperature}_{timestamp}.json
```
Each result includes:
- Task completion status (success/failure)
- Reward (0.0 or 1.0)
- Agent-user conversation history
- Cost information (agent, user simulator, validation costs)
- Validation results
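
Because each trial receives a reward of 0.0 or 1.0, a result file can be aggregated with a few lines of Python. The sketch below assumes that a result file is a JSON list of records and that each record has `task_id` and `reward` keys; these key names, and the helper names `load_records` and `summarize`, are assumptions for illustration and should be adjusted to the actual file schema.

```python
import json
from collections import defaultdict
from pathlib import Path
from typing import Dict, List


def load_records(path: Path) -> List[dict]:
    """Load one result file (assumed here to be a JSON list of trial records)."""
    with path.open() as f:
        return json.load(f)


def summarize(records: List[dict]) -> Dict[str, float]:
    """Return the mean reward over all trials and the share of tasks solved in every trial."""
    rewards_by_task = defaultdict(list)
    for rec in records:
        # "task_id" and "reward" are assumed key names.
        rewards_by_task[rec["task_id"]].append(float(rec["reward"]))

    n_trials = sum(len(r) for r in rewards_by_task.values())
    total_reward = sum(sum(r) for r in rewards_by_task.values())
    mean_reward = total_reward / n_trials if n_trials else 0.0

    solved_every_trial = sum(all(x == 1.0 for x in r) for r in rewards_by_task.values())
    all_trials_pass_rate = solved_every_trial / len(rewards_by_task) if rewards_by_task else 0.0
    return {"mean_reward": mean_reward, "all_trials_pass_rate": all_trials_pass_rate}
```

With `--num_trials` greater than 1, the second number is stricter than the mean reward, since a task only counts if every trial succeeded.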

## 🔧 Advanced Configuration

### Running Specific Tasks

```bash
# Run tasks 0-10 only
python run.py --env mimic_iv_star \
    --task_type incre \
    --model gpt-5 \
    --agent_strategy tool-calling \
    --start_index 0 \
    --end_index 10

# Run specific task IDs
python run.py --env eicu_star \
    --task_type adapt \
    --model gpt-5 \
    --agent_strategy tool-calling \
    --task_ids 5 10 15 20
```

### Verbose Mode

Enable detailed logging of agent-user conversations:
```bash
python run.py --env all \
    --task_type incre \
    --model gpt-5 \
    --agent_strategy tool-calling \
    --verbose
```

## 🗄️ Database Environments

### MIMIC-IV* (`mimic_iv_star`)
- Clinical data from ICU patients at Beth Israel Deaconess Medical Center
- Tasks: `eval_incre.jsonl`, `eval_adapt.jsonl`
- Database: `mimic_iv_star.sqlite`

### eICU* (`eicu_star`)
- Multi-center ICU database
- Tasks: `eval_incre.jsonl`, `eval_adapt.jsonl`
- Database: `eicu_star.sqlite`

## 📝 Task Types

### Incremental (`incre`)
Tasks that build upon previous queries with refinements and additional constraints.

### Adaptive (`adapt`)
Tasks requiring adaptation to new information, corrections, or changing requirements.

## 🛡️ Validation System

The framework includes a validation system for the user simulator (`validator.py`) that:
1. Verifies user simulator behavior
2. Detects and handles errors in simulation
3. Automatically retries failed simulations with feedback
4. Tracks validation costs separately
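
The validate-and-retry behavior listed above can be pictured as a simple loop. This is a conceptual sketch only, not the logic in `validator.py`: the `simulate` and `validate` callables, the `max_attempts` parameter, and the function name are all hypothetical.

```python
from typing import Callable, Optional, Tuple


def simulate_with_validation(
    simulate: Callable[[Optional[str]], str],
    validate: Callable[[str], Tuple[bool, str]],
    max_attempts: int = 3,
) -> Tuple[str, int]:
    """Generate a simulated user message, retrying with validator feedback on failure.

    simulate(feedback) produces a user message; feedback is None on the first attempt.
    validate(message) returns (is_valid, feedback_text).
    Returns the accepted message and the number of attempts used.
    """
    feedback: Optional[str] = None
    for attempt in range(1, max_attempts + 1):
        message = simulate(feedback)
        is_valid, feedback = validate(message)
        if is_valid:
            return message, attempt
    raise RuntimeError(f"Simulation still invalid after {max_attempts} attempts")
```

In a real run both callables would wrap LLM calls, so each retry adds to the user and validation costs.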

## 💰 Cost Tracking

All API costs are tracked separately:
- **Agent Cost**: LLM calls made by the agent
- **User Cost**: LLM calls made by the user simulator
- **Eval Cost**: LLM calls made during validation
- **Total Cost**: Sum of the agent, user, and eval costs
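
The cost breakdown amounts to three separately tracked numbers and their sum. The sketch below illustrates this; the `CostBreakdown` class and its field names are assumptions for illustration and may not match the keys used in the result files.

```python
from dataclasses import dataclass


@dataclass
class CostBreakdown:
    """Per-task API cost, split by the component that made the LLM calls."""
    agent_cost: float = 0.0  # LLM calls made by the agent
    user_cost: float = 0.0   # LLM calls made by the user simulator
    eval_cost: float = 0.0   # LLM calls made during validation

    @property
    def total_cost(self) -> float:
        # Total cost is the sum of the three tracked components.
        return self.agent_cost + self.user_cost + self.eval_cost
```

For example, `CostBreakdown(agent_cost=0.5, user_cost=0.25, eval_cost=0.25).total_cost` evaluates to `1.0`.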
