# GVALID: Gradient Variance Active Learning for Individualized Dosing

A comprehensive implementation of **GVALID** (Gradient Variance Active Learning for Individualized Dosing), an active learning framework designed for optimal policy discovery in continuous treatment effect estimation.

## 1 Project Structure

```
├── raw/                                # Raw data storage (e.g., news_pp_V2.npy)
├── plots/                              # Automatically generated dose-response curves
│   └── {DatasetName}/dose_curves.pdf   # Visualizations of dataset characteristics
│ 
├── results/                            # Experimental logs and results
│   ├── validation_logs/                # Logs for theoretical assumption validation
│   └── {GroupName}/{SettingName}/      # CSV and PKL files for performance metrics
│ 
├── src/                                # Core source code
│   ├── datasets.py                     # Synthetic & Semi-synthetic data (News, HardNonLinear, etc.)
│   ├── gp_model.py                     # GPyTorch wrapper with GPU optimization and auto-normalization
│   ├── samplers.py                     # GVALID and baseline sampler implementations
│   ├── evaluator.py                    # Metrics calculation (Policy Suboptimality, MSE, Dose Error)
│   ├── validation_utils.py             # Theoretical verification tools (Assumptions A1-A4, Theorem 1)
│   └── utils.py                        # Torch-based parallel acquisition function optimization
│ 
├── main.py                             # Distributed multi-GPU experiment entry point
├── experiments.yml                     # Configuration file for ablation and group settings
└── requirements.txt                    # Project dependencies
```
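The metrics listed for `evaluator.py` can be illustrated with a minimal sketch. It assumes one common definition of policy suboptimality (the mean gap between each unit's true outcome at its optimal dose and at the dose the learned policy recommends) and locates the optimal dose by grid search. The helper `evaluate_policy` and the toy response `toy_mu` are hypothetical and may not match `evaluator.py`.

```python
import numpy as np


def evaluate_policy(mu, X, t_hat, t_grid):
    """Hypothetical helper: policy suboptimality and dose error.

    mu     : true response function mu(X, t) -> outcomes, shape (n,)
    X      : covariates, shape (n, d)
    t_hat  : doses recommended by the learned policy, shape (n,)
    t_grid : candidate doses used to locate the true optimum
    """
    n = len(X)
    # True outcomes for every unit at every grid dose: shape (n, len(t_grid))
    Y = np.stack([mu(X, np.full(n, t)) for t in t_grid], axis=1)
    best = Y.argmax(axis=1)
    t_star = t_grid[best]               # true optimal dose per unit
    y_star = Y[np.arange(n), best]      # outcome at the true optimum
    y_hat = mu(X, t_hat)                # outcome at the recommended dose
    suboptimality = float(np.mean(y_star - y_hat))
    dose_error = float(np.mean(np.abs(t_star - t_hat)))
    return suboptimality, dose_error


# Toy response whose optimal dose equals the first covariate
def toy_mu(X, t):
    return -(t - X[:, 0]) ** 2


X = np.array([[0.2], [0.6]])
subopt, dose_err = evaluate_policy(toy_mu, X, np.array([0.2, 0.5]), np.linspace(0.0, 1.0, 11))
print(subopt, dose_err)  # approximately 0.005 and 0.05
```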

## 2 Features

* **GVALID**: Policy-oriented active learning that selects samples to minimize gradient variance at the estimated optimal doses.
* **GPU Acceleration**: The entire acquisition process (acquisition function optimization) and GP training run on CUDA.
* **Robust Baselines**: Implementations of **ABC3** (Cohn criterion), **CATS**, **GPUCB**, **Thompson Sampling**, and **Policy Gradient** for continuous dosing.
* **Auto-Normalization**: `GPModel` handles input/output scaling internally to ensure numerical stability during AL loops.
* **Theoretical Validation**: Built-in tools to verify theoretical assumptions (Strong Concavity, Confidence Coverage) during runtime.
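The auto-normalization idea can be sketched as follows. This is a minimal NumPy illustration, not the `GPModel` code: it assumes min-max scaling of inputs to [0, 1] and standardization of outputs, with the hypothetical class `Normalizer` refit each time the labeled set grows. Predicted variances are mapped back by multiplying by the squared output scale.

```python
import numpy as np


class Normalizer:
    """Hypothetical input/output scaler for a GP inside an active learning loop."""

    def fit(self, X, y):
        X = np.asarray(X, dtype=np.float64)
        y = np.asarray(y, dtype=np.float64)
        self.x_min = X.min(axis=0)
        span = X.max(axis=0) - self.x_min
        # Guard against constant columns to avoid division by zero
        self.x_span = np.where(span > 0, span, 1.0)
        self.y_mean = y.mean()
        std = y.std()
        self.y_std = std if std > 0 else 1.0
        return self

    def transform_x(self, X):
        return (np.asarray(X, dtype=np.float64) - self.x_min) / self.x_span

    def transform_y(self, y):
        return (np.asarray(y, dtype=np.float64) - self.y_mean) / self.y_std

    def inverse_mean(self, mean_scaled):
        # Undo standardization of the predictive mean
        return np.asarray(mean_scaled) * self.y_std + self.y_mean

    def inverse_var(self, var_scaled):
        # Variances scale with the square of the output standard deviation
        return np.asarray(var_scaled) * self.y_std ** 2


norm = Normalizer().fit([[0.0, 10.0], [2.0, 20.0], [4.0, 30.0]], [1.0, 2.0, 3.0])
print(norm.transform_x([[2.0, 20.0]]))  # [[0.5 0.5]]
```

Refitting inside the loop keeps the scaled data well-conditioned as new, possibly out-of-range, points are labeled.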

## 3 Datasets

The project focuses on continuous treatment effect estimation across diverse response surfaces:

1. **SemiSynthNews**: Semi-synthetic dataset based on real-world news features.
2. **HardNonLinear8D**: Highly non-linear synthetic data with variable curvature.
3. **ComplexSharpConcave**: Asymmetric response curves with sharp peaks to test exploration efficiency.
4. **SimpleWavy8D**: Sinusoidal interference on top of concavity to simulate local optima.
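To make "sinusoidal interference on top of concavity" concrete, here is a hypothetical response surface in the spirit of **SimpleWavy8D**. The functional form, the covariate-dependent optimum, and the constants `amp` and `freq` are illustrative assumptions, not the definitions in `datasets.py`. Counting interior local maxima on a dose grid shows how the wavy term turns a single peak into several.

```python
import numpy as np


def wavy_response(X, t, amp=0.05, freq=40.0):
    """Hypothetical concave dose-response with a sinusoidal ripple."""
    t_opt = 0.5 + 0.3 * np.tanh(X.mean(axis=1))  # optimum shifts with covariates
    concave = -(t - t_opt) ** 2                  # single global peak
    ripple = amp * np.sin(freq * t)              # creates local optima
    return concave + ripple


def count_local_maxima(values):
    """Number of strict interior local maxima in a 1-D sequence."""
    v = np.asarray(values)
    return int(np.sum((v[1:-1] > v[:-2]) & (v[1:-1] > v[2:])))


x = np.zeros((1, 8))                              # one 8-D covariate vector
t_grid = np.linspace(0.0, 1.0, 1001)
X_rep = np.repeat(x, len(t_grid), axis=0)
print(count_local_maxima(wavy_response(X_rep, t_grid, amp=0.0)))  # 1
print(count_local_maxima(wavy_response(X_rep, t_grid)))           # several
```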

## 4 Usage

### 4.1 Running Experiments

The framework supports automated group experiments across multiple GPUs.

**Basic Run:**

```bash
python main.py --dataset SemiSynthNews --n_seeds 20 --group Group_Init_Ablation --num_workers 4
```

**With Theoretical Validation:**

```bash
python main.py --dataset HardNonLinear8D --validate_theory --group Main_Comparison
```

### 4.2 Configuration (`experiments.yml`)

Define your experimental groups, budget ratios, and batch sizes:

```yaml
Group_Init_Ablation:
  - setting_name: "Low_Budget"
    total_budget_ratio: 0.05
    init_ratio: 0.01
    batch_size: 4
```
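How these ratios become sample counts is not spelled out above. The sketch below assumes both ratios are fractions of the candidate pool size and that the budget remaining after initialization is spent in batches of `batch_size`; `budget_schedule` is a hypothetical helper, not a function from `main.py`.

```python
import math


def budget_schedule(pool_size, total_budget_ratio, init_ratio, batch_size):
    """Hypothetical mapping from the YAML ratios to label counts and AL rounds."""
    n_total = int(round(total_budget_ratio * pool_size))  # overall labeling budget
    n_init = max(1, int(round(init_ratio * pool_size)))   # initial (e.g. LHS) design
    n_acquire = max(0, n_total - n_init)                  # labels left for the AL loop
    n_rounds = math.ceil(n_acquire / batch_size)          # last batch may be partial
    return n_init, n_total, n_rounds


# "Low_Budget" on a hypothetical pool of 2,000 candidate units
print(budget_schedule(2000, 0.05, 0.01, 4))  # (20, 100, 20)
```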

## 5 Sampling Algorithms

| Sampler | Strategy Type | Objective |
| --- | --- | --- |
| **GVALID** | Policy-Oriented | Minimizes gradient variance at the estimated optimal dose |
| **ABC3** | Estimation-Oriented | Minimizes integrated posterior variance (Cohn) |
| **GPUCB** | Bandit-Based | Maximizes the upper confidence bound $\mu(x) + \beta^{1/2}\sigma(x)$ |
| **CATS** | Bandit-Based | Smoothed mean over the continuous treatment space |
| **PG** | Reinforcement Learning | Policy-gradient optimization of an expected information gain (EIG) proxy |
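The contrast between the policy-oriented and bandit-based rows can be illustrated on a 1-D dose axis. The sketch below is an assumption-laden illustration, not the GVALID implementation in `samplers.py`: it uses an RBF-kernel GP with fixed hyperparameters, takes the estimated optimum as the maximizer of the posterior mean, and scores each candidate by the posterior variance of the derivative f'(t*) after hypothetically adding that candidate. A GP-UCB pick from the same posterior is included for comparison. All function names and constants are hypothetical.

```python
import numpy as np

LS, VAR, NOISE = 0.2, 1.0, 1e-2  # illustrative RBF hyperparameters


def rbf(a, b):
    d = a[:, None] - b[None, :]
    return VAR * np.exp(-0.5 * d ** 2 / LS ** 2)


def posterior(X, y, Xs):
    """GP posterior mean and variance at test doses Xs."""
    L = np.linalg.cholesky(rbf(X, X) + NOISE * np.eye(len(X)))
    Ks = rbf(X, Xs)                                    # (n, m)
    alpha = np.linalg.solve(L.T, np.linalg.solve(L, y))
    v = np.linalg.solve(L, Ks)
    return Ks.T @ alpha, VAR - np.sum(v ** 2, axis=0)


def grad_variance(X, t_star):
    """Posterior variance of f'(t_star) given noisy observations at doses X."""
    k = rbf(np.array([t_star]), X)[0]
    g = -(t_star - X) / LS ** 2 * k                    # Cov(f'(t*), f(x_i))
    K = rbf(X, X) + NOISE * np.eye(len(X))
    prior = VAR / LS ** 2                              # Var(f'(t)) under the RBF prior
    return prior - g @ np.linalg.solve(K, g)


def gradient_variance_pick(X, y, candidates, t_grid):
    """Pick the candidate that most reduces Var[f'(t*)] at the estimated optimum."""
    mean, _ = posterior(X, y, t_grid)
    t_star = t_grid[np.argmax(mean)]
    scores = np.array([grad_variance(np.append(X, c), t_star) for c in candidates])
    return candidates[np.argmin(scores)], t_star


def ucb_pick(X, y, candidates, beta=4.0):
    """GP-UCB: maximize mu + sqrt(beta) * sigma over the candidates."""
    mean, var = posterior(X, y, candidates)
    return candidates[np.argmax(mean + np.sqrt(beta) * np.sqrt(np.maximum(var, 0.0)))]


X = np.array([0.1, 0.4, 0.9])
y = -(X - 0.5) ** 2
cands = np.linspace(0.0, 1.0, 21)
x_gv, t_hat = gradient_variance_pick(X, y, cands, np.linspace(0.0, 1.0, 201))
x_ucb = ucb_pick(X, y, cands)
print(t_hat, x_gv, x_ucb)
```

Because GP posterior variance does not depend on the observed outcomes, the score needs no simulated labels. Under a stationary kernel a lone observation exactly at t* is uncorrelated with f'(t*) (the covariance term `g` vanishes there), so this kind of criterion tends to favor doses on either side of the estimated optimum rather than on it.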

## 6 Requirements

* **Python**: 3.9+
* **PyTorch**: 2.0+ (CUDA support recommended)
* **GPyTorch**: 1.11+
* **SciPy**: 1.10+ (for Latin hypercube sampling (LHS) initialization)
* **psutil**: For robust multi-process management
* **Pandas/Matplotlib**: For result analysis and visualization
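SciPy's `scipy.stats.qmc` module provides Latin hypercube sampling. A minimal sketch of drawing an initial design with it follows; the helper name `lhs_initial_design`, the bounds, and sampling directly in a box-shaped space are assumptions for illustration, since the README does not say how initial points are mapped onto the candidate pool.

```python
import numpy as np
from scipy.stats import qmc


def lhs_initial_design(n_init, lower, upper, seed=0):
    """Hypothetical helper: n_init Latin hypercube points inside [lower, upper]."""
    sampler = qmc.LatinHypercube(d=len(lower), seed=seed)
    unit = sampler.random(n=n_init)          # points in [0, 1)^d
    return qmc.scale(unit, lower, upper)     # affine map to the requested box


design = lhs_initial_design(10, lower=[0.0, 0.0, 0.0], upper=[1.0, 1.0, 1.0])
print(design.shape)  # (10, 3)
```

Compared with uniform random draws, each of the `n_init` equal-width bins along every dimension receives exactly one point, which spreads a small initial budget more evenly.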

## 7 Experimental Configuration

### 7.1 Hardware Specifications

* **Graphics Processing Units (GPUs)**: 4x NVIDIA GeForce RTX 3090 / 4090
* **Parallelization**: `ProcessPoolExecutor` with GPU-ID cycling for asynchronous execution across seeds and samplers.
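GPU-ID cycling can be sketched with the standard library alone. In the version below, the hypothetical `assign_gpus` pairs each (seed, sampler) job with a GPU index round-robin, and the hypothetical worker `run_job` receives a device string such as `cuda:1` rather than setting `CUDA_VISIBLE_DEVICES`, because a pooled worker process may run several jobs and CUDA reads that variable only at initialization. The actual scheduling in `main.py` may differ.

```python
import itertools
from concurrent.futures import ProcessPoolExecutor


def assign_gpus(jobs, n_gpus):
    """Pair each job with a GPU index in round-robin order."""
    return list(zip(jobs, itertools.cycle(range(n_gpus))))


def run_job(job, gpu_id):
    """Hypothetical worker; a real one would build the GP and sampler on this device."""
    seed, sampler = job
    device = f"cuda:{gpu_id}"
    return f"{sampler} seed={seed} -> {device}"


def run_all(jobs, n_gpus, max_workers):
    """Run every job asynchronously in a process pool, returning results in job order."""
    with ProcessPoolExecutor(max_workers=max_workers) as pool:
        futures = [pool.submit(run_job, job, gpu) for job, gpu in assign_gpus(jobs, n_gpus)]
        return [f.result() for f in futures]


print([gpu for _, gpu in assign_gpus(list(range(6)), 4)])  # [0, 1, 2, 3, 0, 1]
```

From a script, `run_all(jobs, n_gpus=4, max_workers=4)` should be called under an `if __name__ == "__main__":` guard so that worker processes started with the `spawn` method do not re-launch the pool.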

### 7.2 Numerical Stability

* **Precision**: `torch.float64` used for all GP computations.
* **Jitter**: Adaptive Cholesky jitter, escalated between a minimum and a maximum value, to prevent non-positive-definite errors during active learning updates.
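An adaptive-jitter Cholesky factorization can be sketched in NumPy as below. The escalation scheme (multiply by `factor` after each failure, stop at `jitter_max`) is an assumption, and the numbers in the usage line are illustrative placeholders, not the range used in this project. The hypothetical helper `cholesky_with_jitter` returns the jitter it ended up using so that callers can log it.

```python
import numpy as np


def cholesky_with_jitter(K, jitter_start, jitter_max, factor=10.0):
    """Return (L, jitter) with L @ L.T == K + jitter * I, escalating jitter on failure."""
    K = np.asarray(K, dtype=np.float64)   # float64, matching the GP precision above
    try:
        return np.linalg.cholesky(K), 0.0
    except np.linalg.LinAlgError:
        pass
    eye = np.eye(K.shape[0])
    jitter = jitter_start
    while True:
        try:
            return np.linalg.cholesky(K + jitter * eye), jitter
        except np.linalg.LinAlgError:
            if jitter >= jitter_max:
                raise                      # give up once the cap has been tried
            jitter = min(jitter * factor, jitter_max)


# A singular (rank-1) kernel matrix, as produced by two duplicate inputs
K = np.array([[1.0, 1.0], [1.0, 1.0]])
L, used = cholesky_with_jitter(K, jitter_start=1e-8, jitter_max=1e-2)
print(used)  # 1e-08
```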

