> **Note**: This repository is anonymized for double-blind peer review.

# GEM-FI: Evidential Mixture Models with Fisher Modulation

GEM (Gated Evidential Mixtures) is an uncertainty estimation framework that combines Evidential Deep Learning (EDL) with Energy-Based Models (EBMs). It targets reliable uncertainty quantification for deep learning models, particularly for Out-of-Distribution (OOD) detection.

## Features

-   **GEM-CORE**: The baseline GEM model, which uses a single Dirichlet head.
-   **GEM-MIX**: Extends GEM with a Mixture of Beliefs (MoB) to capture multi-modal distributions.
-   **GEM-FI**: Gated Evidential Mixtures with Fisher Modulation. Enhances GEM-MIX with Fisher Information (FI) regularization for better uncertainty calibration.
-   **OOD Detection**: Built-in mechanisms to detect OOD samples using aleatoric and epistemic uncertainty.
-   **Virtual Outlier Synthesis (VOS)**: Synthesizes virtual outliers near the decision boundary to improve OOD detection (tuned primarily for use with GEM-FI).
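
For readers new to EDL, the quantities a single Dirichlet head produces can be illustrated with the standard evidential formulation: non-negative evidence `e_k` gives concentration parameters `alpha_k = e_k + 1`, and the uncertainty mass is `K / sum(alpha)`. The following is a minimal sketch under that assumption, not the repository's implementation; the function name `dirichlet_summary` is hypothetical, and the actual GEM heads may parameterize evidence differently.

```python
def dirichlet_summary(evidence):
    """Summarize one Dirichlet head from non-negative per-class evidence."""
    if any(e < 0 for e in evidence):
        raise ValueError("evidence must be non-negative")
    num_classes = len(evidence)
    alpha = [e + 1.0 for e in evidence]  # Dirichlet concentration parameters
    strength = sum(alpha)                # total evidence plus the uniform prior
    return {
        "probs": [a / strength for a in alpha],      # expected class probabilities
        "belief": [e / strength for e in evidence],  # per-class belief mass
        "vacuity": num_classes / strength,           # uncertainty mass in (0, 1]
    }


# Strong evidence for class 0 versus no evidence at all (assumed toy inputs).
confident = dirichlet_summary([40.0, 1.0, 1.0])
unsure = dirichlet_summary([0.0, 0.0, 0.0])
print(confident["vacuity"], unsure["vacuity"])
```

Belief masses and the uncertainty mass always sum to one, so an input with little total evidence is pushed toward high uncertainty, which is the signal OOD detection relies on.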

## Installation

1.  Clone the repository.
2.  Install the required dependencies:

```bash
pip install -r requirements.txt
```

**Note**: The code was developed with **Python 3.11** and targets CUDA 12.1 and PyTorch 2.3.0.

## Data Preparation

Standard datasets (CIFAR-10, CIFAR-100, MNIST, SVHN, FashionMNIST, KMNIST) are downloaded automatically. For OOD and robustness benchmarks, please download the following datasets manually and extract them into the `data` directory:

*   **TinyImageNet**: Download `tiny-imagenet-200.zip` (e.g., from [Stanford CS231n](http://cs231n.stanford.edu/tiny-imagenet-200.zip)) and extract it to `./data/tiny-imagenet-200`.
*   **CIFAR-10-C**: Download from [Zenodo](https://zenodo.org/record/2535967) and extract to `./data/CIFAR-10-C`.
*   **MNIST-C**: Download from [Zenodo](https://zenodo.org/record/3239543) and extract to `./data/mnist_c`.
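
Before launching a benchmark run, it can help to confirm that the three manually downloaded datasets were extracted to the locations listed above. The sketch below only checks that the folders exist; the helper name `missing_datasets` is hypothetical and not part of the repository, and it does not validate the contents of each folder.

```python
from pathlib import Path

# Extraction targets listed above, relative to the data root.
EXPECTED_DIRS = ("tiny-imagenet-200", "CIFAR-10-C", "mnist_c")


def missing_datasets(data_root="./data"):
    """Return the expected dataset folders that are absent under data_root."""
    root = Path(data_root)
    return [name for name in EXPECTED_DIRS if not (root / name).is_dir()]


if __name__ == "__main__":
    missing = missing_datasets()
    if missing:
        print("Missing dataset folders:", ", ".join(missing))
    else:
        print("All manually downloaded datasets are in place.")
```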

## Usage

The main entry point is `main.py`. You can train different variants of the GEM model using command-line arguments.

### Training Examples

**1. Train GEM-CORE (Baseline) on CIFAR-10:**

```bash
python main.py --ID_dataset CIFAR-10 --backbone ResNet18
```

**2. Train GEM-MIX (Mixture of Beliefs) on CIFAR-10:**

```bash
python main.py --ID_dataset CIFAR-10 --use_mob --num_components 3
```

**3. Train GEM-FI (Gated Evidential Mixtures with Fisher Modulation):**

*   **CIFAR-10:**
    ```bash
    python main.py --ID_dataset CIFAR-10 --backbone ResNet18 --batch_size 128 \
      --num_epochs 100 --learning_rate 1e-3 --dropout_rate 0.1 --reg_param 1e-4 \
      --use_mob --num_components 3 --use_fi_regularization --fi_lambda 0.1
    ```

*   **MNIST:**
    ```bash
    python main.py --ID_dataset MNIST --backbone ConvNet3C3F --batch_size 64 \
      --num_epochs 50 --learning_rate 5e-4 --dropout_rate 0.05 --reg_param 1e-3 \
      --use_mob --num_components 3 --use_fi_regularization --fi_lambda 0.3
    ```

> [!NOTE]
> These commands reflect the exact hyperparameters used in the paper for reproducibility.
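
The three variants differ only in which flags are added on top of the baseline command. As a rough illustration of how the flags compose, the sketch below builds the argument list for each variant using only flags that appear in the examples above. The helper `build_command` is hypothetical, it passes `--backbone` explicitly in every case, and it omits the remaining paper hyperparameters (batch size, epochs, learning rate, and so on), so it is not a substitute for the full commands.

```python
import shlex

VARIANTS = ("GEM-CORE", "GEM-MIX", "GEM-FI")


def build_command(variant, dataset="CIFAR-10", backbone="ResNet18",
                  num_components=3, fi_lambda=0.1):
    """Return the argv list that trains the given variant via main.py."""
    if variant not in VARIANTS:
        raise ValueError(f"unknown variant: {variant}")
    cmd = ["python", "main.py", "--ID_dataset", dataset, "--backbone", backbone]
    if variant in ("GEM-MIX", "GEM-FI"):
        # Both mixture variants enable the Mixture of Beliefs.
        cmd += ["--use_mob", "--num_components", str(num_components)]
    if variant == "GEM-FI":
        # GEM-FI additionally enables Fisher Information regularization.
        cmd += ["--use_fi_regularization", "--fi_lambda", str(fi_lambda)]
    return cmd


for name in VARIANTS:
    print(shlex.join(build_command(name)))
```

The resulting list can be passed to `subprocess.run` directly, which avoids shell quoting issues when scripting several runs.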

### Key Arguments

-   `--ID_dataset`: In-distribution (ID) dataset: `CIFAR-10`, `CIFAR-100`, or `MNIST`.
-   `--backbone`: Architecture to use (e.g., `ResNet18`, `VGG16`, `ConvNet3C3F`).
-   `--use_mob`: Enables the Mixture of Beliefs (GEM-MIX) mode.
-   `--num_components`: Number of mixture components (default: 3).
-   `--use_fi_regularization`: Enables Fisher Information regularization (GEM-FI).
-   `--fi_lambda`: Weight of the FI regularization term (0.1 for CIFAR-10 and 0.3 for MNIST in the examples above).
-   `--use_vos`: Enables Virtual Outlier Synthesis (VOS) for improved OOD detection.

## Project Structure

-   `main.py`: Main training and evaluation script.
-   `train.py`: Training loops and logic.
-   `utility.py`: Helper functions, data loading, and model loading.
-   `density_estimation.py`: Functions for fitting GDA/GMM for density-based uncertainty.
-   `conf_calibration.py`: Confidence calibration metrics, such as Expected Calibration Error (ECE).
-   `load_corrupted.py`: Helpers for loading corrupted datasets (CIFAR-10-C, MNIST-C).
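
For reference, the ECE metric mentioned for `conf_calibration.py` is commonly computed by binning predictions by confidence and averaging the gap between accuracy and mean confidence in each bin, weighted by bin size. The sketch below is a plain-Python illustration of that common definition with equal-width bins; it is not the code in `conf_calibration.py`, and the function name, default bin count, and bin-edge handling are assumptions.

```python
def expected_calibration_error(confidences, correct, num_bins=10):
    """ECE with equal-width bins: sum over bins of (n_b / N) * |acc_b - conf_b|."""
    if len(confidences) != len(correct) or not confidences:
        raise ValueError("inputs must be non-empty and of equal length")
    bins = [[] for _ in range(num_bins)]
    for conf, hit in zip(confidences, correct):
        # Clamp so that a confidence of exactly 1.0 falls in the last bin.
        index = min(int(conf * num_bins), num_bins - 1)
        bins[index].append((conf, 1.0 if hit else 0.0))
    total = len(confidences)
    ece = 0.0
    for members in bins:
        if not members:
            continue  # empty bins contribute nothing
        avg_conf = sum(c for c, _ in members) / len(members)
        accuracy = sum(h for _, h in members) / len(members)
        ece += (len(members) / total) * abs(accuracy - avg_conf)
    return ece


# Four predictions at 90% confidence, three of them correct (toy inputs).
ece = expected_calibration_error([0.9, 0.9, 0.9, 0.9], [True, True, True, False])
print(f"ECE = {ece:.3f}")
```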

## Results

Results are saved in the `saved_results` directory, organized by dataset and model configuration.
