# Spectral Decomposed Variational Inference (SD-VI)

This repository contains the official PyTorch implementation for the paper: **"Spectral Decomposed Variational Inference: A Principled Framework for Posterior Covariance Modeling"**.

Our work introduces SD-VI, a variational inference (VI) framework that optimizes the posterior covariance in the spectral domain, through its eigendecomposition, rather than directly in the space of distributions. By replacing the monolithic KL divergence with a flexible, spectrally decomposed regularizer, we gain fine-grained control over the posterior geometry. This objective is optimized by our efficient and provably convergent Proximal Spectral Optimization (PSO) algorithm, which discovers sparse, low-rank posterior structures in a principled manner.

This repository provides code to reproduce key results from our paper, including:
1.  **State-of-the-art uncertainty quantification for Bayesian Neural Networks (BNNs)** on CIFAR-10 (Table 1 in the paper).
2.  **Principled structure discovery in Bayesian Logistic Regression** on UCI datasets (related to Table 3 in the paper).

## Table of Contents
- [Project Structure](#project-structure)
- [Requirements and Installation](#requirements-and-installation)
- [How to Run the Experiments](#how-to-run-the-experiments)
  - [Experiment 1: Bayesian Neural Networks on CIFAR-10 (Table 1)](#experiment-1-bayesian-neural-networks-on-cifar-10-table-1)
  - [Experiment 2: Bayesian Logistic Regression on UCI Datasets (Table 3)](#experiment-2-sparse-gps-on-uci-datasets-table-3)
- [Code Description](#code-description)
  - [`SD-VI_BNN(table1).py`](#sd-vi_bnntable1py)
  - [`SD_VI.py`](#sd_vipy)
  - [`table3.py`](#table3py)

## Project Structure

```
.
├── SD-VI_BNN(table1).py    # Main script for BNN experiments on CIFAR-10
├── SD_VI.py                # Core implementation of the SD-VI algorithm for logistic regression
├── table3.py               # Script to run SD-VI on UCI datasets
├── australian.txt          # UCI Australian credit dataset
├── breast-cancer.txt       # UCI Breast Cancer Wisconsin (Diagnostic) dataset
├── heart.txt               # UCI Heart Disease dataset
└── readme.md               # This file
```

## Requirements and Installation

This project is implemented in Python 3 using PyTorch. We recommend using a virtual environment (like `conda` or `venv`) to manage dependencies.

1.  **Clone the repository:**
    ```bash
    git clone https://anonymous.4open.science/r/SD-VI
    cd SD-VI
    ```

2.  **Create and activate a conda environment:**
    ```bash
    conda create -n sd-vi python=3.9
    conda activate sd-vi
    ```

3.  **Install the required packages:**
    You will need PyTorch (preferably with CUDA support for the BNN experiment), torchvision, NumPy, scikit-learn, and pandas.

    ```bash
    # For CUDA-enabled systems (recommended for BNNs)
    pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118

    # For CPU-only systems
    # pip install torch torchvision torchaudio

    pip install numpy scikit-learn pandas
    ```
    This setup has been tested with PyTorch `2.1.0` and Python `3.9`.
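
Before launching a long run, it can help to confirm that the environment is complete. The following is a minimal sketch, not part of the repository: the helper name `missing_packages` and the `REQUIRED` list are illustrative assumptions based on the packages listed above.

```python
import importlib.util
import sys

# Import names of the packages this README asks you to install.
# Note: scikit-learn is imported as "sklearn".
REQUIRED = ["torch", "torchvision", "numpy", "sklearn", "pandas"]


def missing_packages(names):
    """Return the subset of `names` that cannot be found on the import path."""
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    print("Python", sys.version.split()[0])
    missing = missing_packages(REQUIRED)
    if missing:
        print("Missing packages:", ", ".join(missing))
    else:
        import torch  # only imported once it is known to be installed

        print("PyTorch", torch.__version__)
        print("CUDA available:", torch.cuda.is_available())
```

If `CUDA available` prints `False`, the BNN experiment will presumably fall back to the CPU or fail, depending on how the script selects its device.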

## How to Run the Experiments

### Experiment 1: Bayesian Neural Networks on CIFAR-10 (Table 1)

This script trains and evaluates a Wide ResNet-28-10 model on the CIFAR-10 dataset using both our proposed **SD-VI** and a baseline **Mean-Field VI (MFVI)**. It will download the CIFAR-10 dataset automatically.

**To run the experiment:**
```bash
python "SD-VI_BNN(table1).py"
```

- **Expected Output:** The script will first train the SD-VI model for 50 epochs, followed by the MFVI model for 50 epochs. During training, it will print progress updates, including loss and accuracy. After both models are trained, it will print a final summary table comparing the two methods on Accuracy, Expected Calibration Error (ECE), Negative Log-Likelihood (NLL), and training time. This table corresponds to the results presented in Table 1 of our paper.

- **Note on Runtimes:** Training the Wide ResNet models is computationally intensive. On a modern GPU like an NVIDIA RTX 3080 or better, this script may take several hours to complete.
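
For readers unfamiliar with the calibration metrics in the summary table, here is a minimal pure-Python sketch of NLL and ECE. It is an illustration under stated assumptions, not the evaluation code from the script: the function names, the equal-width binning, and the default of 15 bins are all assumptions.

```python
import math


def nll(probs, labels):
    """Average negative log-likelihood of the true labels."""
    eps = 1e-12  # guards against log(0)
    return -sum(math.log(max(p[y], eps)) for p, y in zip(probs, labels)) / len(labels)


def ece(probs, labels, n_bins=15):
    """Expected Calibration Error with equal-width confidence bins (assumed setup)."""
    bins = [[] for _ in range(n_bins)]
    for p, y in zip(probs, labels):
        conf = max(p)          # confidence = probability of the predicted class
        pred = p.index(conf)   # predicted class
        idx = min(int(conf * n_bins), n_bins - 1)  # confidence 1.0 goes in the last bin
        bins[idx].append((conf, pred == y))
    total = len(labels)
    error = 0.0
    for b in bins:
        if not b:
            continue
        avg_conf = sum(c for c, _ in b) / len(b)
        accuracy = sum(hit for _, hit in b) / len(b)
        # Weight each bin's |confidence - accuracy| gap by its share of samples.
        error += (len(b) / total) * abs(avg_conf - accuracy)
    return error


# Toy predictive probabilities for a 2-class problem (illustrative only).
probs = [[0.9, 0.1], [0.6, 0.4], [0.2, 0.8], [0.7, 0.3]]
labels = [0, 1, 1, 0]
print("NLL:", nll(probs, labels), "ECE:", ece(probs, labels))
```

Lower is better for both metrics; an ECE of zero means the average confidence matches the accuracy in every bin.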

<a id="experiment-2-sparse-gps-on-uci-datasets-table-3"></a>
### Experiment 2: Bayesian Logistic Regression on UCI Datasets (Table 3)

This script demonstrates our SD-VI framework on a Bayesian Logistic Regression task, focusing on discovering a parsimonious posterior covariance structure. It runs our method on one of the provided UCI datasets and reports key performance metrics.

**To run the experiment on the "Heart" dataset:**
```bash
python table3.py
```
This will run the experiment 5 times with different random seeds (and data splits) and report the median and 95% confidence interval for Test AUC, MSE, ECE, and the final objective value.
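
As a rough illustration of how such a summary can be formed, the sketch below aggregates per-seed metric values into a median and a percentile interval. How `table3.py` actually computes its confidence interval is not stated here, so the 2.5%-97.5% percentile interval, the function names, and the input numbers are all assumptions.

```python
import statistics


def percentile(values, q):
    """Linearly interpolated percentile, with q in [0, 100]."""
    ordered = sorted(values)
    if len(ordered) == 1:
        return ordered[0]
    pos = (len(ordered) - 1) * q / 100.0
    lo = int(pos)
    hi = min(lo + 1, len(ordered) - 1)
    frac = pos - lo
    return ordered[lo] + (ordered[hi] - ordered[lo]) * frac


def summarize(values):
    """Return (median, lower, upper) using a 2.5%-97.5% percentile interval."""
    return statistics.median(values), percentile(values, 2.5), percentile(values, 97.5)


# Placeholder metric values for 5 seeds; these are NOT results from the paper.
aucs = [0.90, 0.88, 0.91, 0.87, 0.89]
median, lower, upper = summarize(aucs)
print(f"Test AUC: {median:.3f} [{lower:.3f}, {upper:.3f}]")
```

With only 5 runs, any interval estimate is coarse, so treat the reported bounds as indicative rather than precise.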

**To run on other datasets:**
Simply edit the `DATA_FILE` variable inside the `if __name__ == "__main__":` block in `table3.py`.
```python
# In table3.py, change this line:
DATA_FILE = "heart.txt"

# to:
# DATA_FILE = "australian.txt"
# or
# DATA_FILE = "breast-cancer.txt"
```

You can also tune the hyperparameters (`LR_MU`, `LR_S`, `LAMBDA_1`, `MAX_ITER`) for SD-VI directly within the script to explore their effects.

## Code Description

### `SD-VI_BNN(table1).py`

This is a self-contained script for the BNN experiment. It includes:
- **`SDVILayer`**: A module that manages the mean (`mu`) and full covariance (`S`) for a layer's weights. It contains the core `pso_step` method which implements our Proximal Spectral Optimization.
- **`BayesianConv2d_SDVI` & `BayesianLinear_SDVI`**: Bayesian layers that wrap `SDVILayer` to be used in a standard PyTorch model.
- **`WideResNet`**: The implementation of the Wide ResNet-28-10 architecture, adapted to accept different Bayesian layer types (`sdvi`, `mfvi`).
- **Training and Evaluation Loops**: Functions to train the SD-VI and MFVI models and evaluate them on key uncertainty metrics (Accuracy, ECE, NLL).

### `SD_VI.py`

This script provides a standalone, class-based implementation of our SD-VI algorithm for Bayesian logistic regression. It is designed for clarity and to showcase the core mechanics of the method.
- **`SD_VI` class**: Encapsulates the entire logic.
  - `__init__`: Initializes variational parameters (`mu`, `S`) and spectral regularizer hyperparameters (`lambda1`, etc.).
  - `fit`: Contains the main optimization loop, which alternates between updating `mu` with a standard optimizer and updating `S` using our PSO algorithm.
  - `_proximal_spectral_map`: The heart of the algorithm. This method performs the eigendecomposition, applies the analytical scalar shrinkage to the eigenvalues, and reconstructs the covariance matrix, as detailed in our paper.
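
The three steps described for `_proximal_spectral_map` (eigendecomposition, scalar shrinkage of the eigenvalues, reconstruction) can be sketched in NumPy as follows. This is not the repository's implementation: the function name, the `threshold` and `floor` parameters, and the use of soft-thresholding as the shrinkage rule are assumptions standing in for the analytical rule derived in the paper.

```python
import numpy as np


def proximal_spectral_map(S, threshold, floor=1e-6):
    """Shrink the eigenvalues of a symmetric matrix and rebuild it.

    Soft-thresholding is used here as a stand-in shrinkage rule (assumption).
    """
    S = 0.5 * (S + S.T)                    # symmetrize against round-off
    eigvals, eigvecs = np.linalg.eigh(S)   # eigenvalues in ascending order
    # Scalar shrinkage applied to each eigenvalue, floored to stay positive definite.
    shrunk = np.maximum(eigvals - threshold, floor)
    return (eigvecs * shrunk) @ eigvecs.T  # V diag(shrunk) V^T


# Toy example: a nearly rank-2 covariance in 5 dimensions.
rng = np.random.default_rng(0)
A = rng.standard_normal((5, 2))
S = A @ A.T + 0.01 * np.eye(5)
S_new = proximal_spectral_map(S, threshold=0.05)
print(np.linalg.eigvalsh(S_new))
```

In this toy case the three small eigenvalues are pushed down to the floor while the two dominant ones are only slightly reduced, which is the kind of low-rank structure the shrinkage step is meant to expose.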

### `table3.py`

This script serves as a driver to run experiments using the `SD_VI` class on various UCI datasets.
- **`load_data`**: A data loader that handles different text-based formats (including libsvm) and performs the necessary preprocessing, such as feature standardization and label mapping.
- **`run_multiple_experiments`**: A utility function that runs the `SD_VI.fit` method multiple times with different random seeds and train/test splits, then reports the median and 95% confidence interval of each metric.
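
To make the preprocessing steps concrete, here is a minimal pure-Python sketch of parsing libsvm-style lines, standardizing features, and mapping labels to 0/1. It is not the `load_data` function from `table3.py`; all function names are hypothetical, and the exact format of the bundled `.txt` files is an assumption.

```python
def parse_libsvm_line(line, n_features):
    """Parse 'label idx:val idx:val ...' into (label, dense feature list)."""
    parts = line.split()
    label = float(parts[0])
    features = [0.0] * n_features
    for item in parts[1:]:
        index, value = item.split(":")
        features[int(index) - 1] = float(value)  # libsvm indices are 1-based
    return label, features


def standardize(rows):
    """Scale each column to zero mean and unit variance (constant columns become 0)."""
    n = len(rows)
    cols = list(zip(*rows))
    means = [sum(c) / n for c in cols]
    stds = [(sum((v - m) ** 2 for v in c) / n) ** 0.5 for c, m in zip(cols, means)]
    return [
        [(v - m) / s if s > 0 else 0.0 for v, m, s in zip(row, means, stds)]
        for row in rows
    ]


def map_labels(labels):
    """Map two distinct label values (e.g. -1/+1 or 1/2) to 0 and 1."""
    classes = sorted(set(labels))
    if len(classes) != 2:
        raise ValueError("expected exactly two classes")
    return [classes.index(y) for y in labels]


# Two toy records in libsvm format (illustrative only).
lines = ["+1 1:0.5 3:2.0", "-1 2:1.0 3:4.0"]
parsed = [parse_libsvm_line(line, n_features=3) for line in lines]
y = map_labels([label for label, _ in parsed])
X = standardize([feats for _, feats in parsed])
print(y, X)
```

In practice the standardization statistics should be computed on the training split only and then applied to the test split, to avoid leaking test information.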


