# DBVI / DeepGP Training Script

This repository contains a **Deep Gaussian Process (DeepGP)** training script with **Diffusion Bridge Variational Inference (DBVI)**.
The implementation supports input-dependent amortizers, observation-conditioned diffusion bridges (Doob h-transform), and SDE-based regularization.

## ✨ Features

* **Deep Gaussian Processes (DeepGP)** with multiple hidden layers.
* **DBVI variational distribution** using SDE bridge drift and conditional score loss.
* **Amortizer networks** (one per layer).
* **Supported datasets**:

  * `sine` (synthetic sinusoidal data)
  * `concrete` (Concrete Strength dataset, requires `Concrete_Data.xls`)
  * `power` (Power Consumption dataset, requires `power.xlsx`)
  * `energy` (Energy dataset, requires `energy.csv`)
  * `boston` (Boston Housing dataset, fetched from OpenML)
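For reference, the synthetic `sine` dataset can be sketched as follows. The exact input range, sample count, and noise level used by `dbvi.py` are assumptions here, not taken from the script:

```python
import numpy as np

def make_sine_dataset(n=500, noise_std=0.1, seed=0):
    """Noisy sinusoidal regression data (illustrative sketch only;
    the actual generator in dbvi.py may use different settings)."""
    rng = np.random.default_rng(seed)
    x = rng.uniform(-3.0, 3.0, size=(n, 1))
    y = np.sin(x) + noise_std * rng.normal(size=(n, 1))
    return x.astype(np.float32), y.astype(np.float32)

x, y = make_sine_dataset()
```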

## 📦 Dependencies

Make sure you have the following installed:

```bash
pip install torch torchsde gpytorch numpy pandas scikit-learn tqdm linear-operator
```

## 🚀 Usage

Run the training script:

```bash
python dbvi.py --dataset energy --data_path energy.csv --epochs 2000 --layers 2
```

### Key Arguments

* `--dataset {sine,concrete,power,energy,boston}`
  Dataset to use.
* `--data_path PATH`
  Path to the dataset file (for Excel/CSV datasets).
* `--layers N`
  Number of hidden layers (default: 2).
* `--epochs N`
  Number of training epochs (default: 200).
* `--batch_size` / `--test_batch_size`
  Training and test batch sizes.
* `--num_inducing`
  Number of inducing points per layer.
* `--lr`
  Learning rate (default: 1e-2).
* `--t1`
  SDE integration horizon (default: 1e-3).
* `--loss_sde_w`
  Weight for the SDE regularization term.
* `--device {cuda,cpu}`
  Compute device to use.
* `--log_elbo_csv`
  Path to save ELBO/RMSE logs as CSV.
* `--log_every`
  Log every N minibatches.
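An `argparse` sketch matching the flags above may help when extending the script. The defaults for `--layers`, `--epochs`, `--lr`, and `--t1` come from the list; the remaining defaults are assumptions:

```python
import argparse

def build_parser():
    """Parser mirroring the key arguments of dbvi.py (defaults partly assumed)."""
    p = argparse.ArgumentParser(description="DBVI / DeepGP training")
    p.add_argument("--dataset", choices=["sine", "concrete", "power", "energy", "boston"],
                   default="sine")
    p.add_argument("--data_path", type=str, default=None)   # for Excel/CSV datasets
    p.add_argument("--layers", type=int, default=2)
    p.add_argument("--epochs", type=int, default=200)
    p.add_argument("--batch_size", type=int, default=256)        # assumed default
    p.add_argument("--test_batch_size", type=int, default=256)   # assumed default
    p.add_argument("--num_inducing", type=int, default=128)      # assumed default
    p.add_argument("--lr", type=float, default=1e-2)
    p.add_argument("--t1", type=float, default=1e-3)
    p.add_argument("--loss_sde_w", type=float, default=1.0)      # assumed default
    p.add_argument("--device", choices=["cuda", "cpu"], default="cpu")
    p.add_argument("--log_elbo_csv", type=str, default=None)
    p.add_argument("--log_every", type=int, default=10)          # assumed default
    return p

args = build_parser().parse_args(["--dataset", "energy", "--epochs", "2000"])
```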

## 📊 Output

* During training, RMSE and ELBO values are printed.
* After training, results are written to the CSV file specified by `--log_elbo_csv` with the following format:

  ```
  step, epoch, minibatch, rmse
  ```
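The resulting log can be inspected with pandas, for example to find the epoch with the best RMSE. The column names follow the format above; the numeric values are illustrative only:

```python
import io
import pandas as pd

# Example contents of the CSV written via --log_elbo_csv
# (values are made up for illustration).
csv_text = """step, epoch, minibatch, rmse
1, 0, 0, 0.95
2, 0, 1, 0.80
3, 1, 0, 0.62
"""

# skipinitialspace handles the space after each comma in the header.
df = pd.read_csv(io.StringIO(csv_text), skipinitialspace=True)
best = df.loc[df["rmse"].idxmin()]
print(f"best RMSE {best['rmse']:.2f} at epoch {int(best['epoch'])}")
# → best RMSE 0.62 at epoch 1
```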


