# GradPCA: Leveraging NTK Alignment for Reliable Out-of-Distribution Detection

This repository provides an implementation of **GradPCA**, a method for detecting out-of-distribution (OOD) data using gradient-based representations. It includes tools for benchmarking GradPCA alongside a variety of baseline OOD detectors on standard image classification datasets.
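For intuition only, here is a toy NumPy sketch of one general idea: score samples by how far their gradient features fall outside a principal subspace fitted on in-distribution data. It is not the repository's `GradPCA` class. The synthetic features, the rank `k`, the use of "class-mean" stand-ins, and the residual-based score are illustrative assumptions. It also assumes, as a simplification, that ID gradient features concentrate near a low-dimensional subspace. The helper names `fit_principal_subspace` and `residual_score` are hypothetical, and the actual method may build and score its gradient representations differently.

```python
import numpy as np


def fit_principal_subspace(features: np.ndarray, k: int) -> np.ndarray:
    """Return a (d, k) orthonormal basis for the top-k principal directions.

    Rows of `features` are samples (here: stand-ins for ID gradient features).
    The matrix is not mean-centered, so the basis spans the dominant directions
    of the raw features rather than of their deviations from the mean.
    """
    _, _, vt = np.linalg.svd(features, full_matrices=False)
    return vt[:k].T


def residual_score(x: np.ndarray, basis: np.ndarray) -> np.ndarray:
    """Relative norm of each row of `x` outside span(basis), in [0, 1].

    Higher scores mean the sample is less aligned with the ID subspace.
    """
    projection = (x @ basis) @ basis.T
    return np.linalg.norm(x - projection, axis=1) / np.linalg.norm(x, axis=1)


# Toy data (hypothetical): ID features live in a random 5-dim subspace of R^50.
rng = np.random.default_rng(0)
d, rank = 50, 5
mixing = rng.standard_normal((rank, d))
class_means = rng.standard_normal((10, rank)) @ mixing  # 10 "class-mean" features
basis = fit_principal_subspace(class_means, k=rank)

id_scores = residual_score(rng.standard_normal((100, rank)) @ mixing, basis)
ood_scores = residual_score(rng.standard_normal((100, d)), basis)  # isotropic
print(f"max ID score: {id_scores.max():.2e}, min OOD score: {ood_scores.min():.2f}")
```

With real networks, ID features would only approximately occupy such a subspace. A decision threshold on the score would therefore need to be chosen on held-out ID data.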

---

## 🚀 Usage

We provide the following example notebooks:

- [`example.ipynb`](example.ipynb):  
  A minimal, self-contained demonstration of the `GradPCA` class across a range of datasets.

- [`benchmark.ipynb`](benchmark.ipynb):  
  A benchmark comparing **GradPCA** against several baseline OOD detectors on multiple OOD datasets.

---

## 🗂️ Datasets

The following datasets are supported as **in-distribution (ID)** and/or **out-of-distribution (OOD)** data:

`CIFAR-10`, `CIFAR-100`, `SVHN`, `Places`, `LSUN_r`, `LSUN_c`, `iSUN`, `Textures`, `ImageNet-1k`, `ImageNet-V2`, `iNaturalist (MOS)`, `Places (MOS)`, `SUN (MOS)`

### 📥 Manual Downloads Required

Some datasets must be downloaded manually:

- **ImageNet-1k**  
  [Download from the official site](https://image-net.org/download)

- **MOS Datasets** (for ImageNet OOD benchmarks):  
  - [`iNaturalist`](http://pages.cs.wisc.edu/~huangrui/imagenet_ood_dataset/iNaturalist.tar.gz)  
  - [`SUN`](http://pages.cs.wisc.edu/~huangrui/imagenet_ood_dataset/SUN.tar.gz)  
  - [`Places`](http://pages.cs.wisc.edu/~huangrui/imagenet_ood_dataset/Places.tar.gz)

> ⚠️ Place the extracted folders in the expected location (e.g., `data/datasets/`) for compatibility with dataset loaders.
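A minimal standard-library sketch for fetching and unpacking the three MOS archives listed above is shown below. The helper names (`MOS_URLS`, `extract_archive`, `fetch_mos`) are hypothetical and not part of the repository. `data/datasets/` is only the example location given above, so point `dest` at whatever directory your dataset loaders actually read from. The folder layout inside each archive is not checked here.

```python
import tarfile
import urllib.request
from pathlib import Path

# Archive URLs listed above (hypothetical helper, not part of the repository).
MOS_URLS = {
    "iNaturalist": "http://pages.cs.wisc.edu/~huangrui/imagenet_ood_dataset/iNaturalist.tar.gz",
    "SUN": "http://pages.cs.wisc.edu/~huangrui/imagenet_ood_dataset/SUN.tar.gz",
    "Places": "http://pages.cs.wisc.edu/~huangrui/imagenet_ood_dataset/Places.tar.gz",
}


def extract_archive(archive: Path, dest: Path) -> None:
    """Unpack a .tar.gz archive into `dest`, creating it if needed."""
    dest.mkdir(parents=True, exist_ok=True)
    with tarfile.open(archive, "r:gz") as tar:
        # Use the safer "data" extraction filter when this Python provides it.
        if hasattr(tarfile, "data_filter"):
            tar.extractall(dest, filter="data")
        else:
            tar.extractall(dest)


def fetch_mos(dest: Path = Path("data/datasets")) -> None:
    """Download each MOS archive (skipping ones already present) and extract it."""
    dest.mkdir(parents=True, exist_ok=True)
    for name, url in MOS_URLS.items():
        archive = dest / f"{name}.tar.gz"
        if not archive.exists():
            urllib.request.urlretrieve(url, archive)
        extract_archive(archive, dest)
```

Calling `fetch_mos()` once is enough. The archives are kept next to the extracted folders, so a rerun skips the download and only re-extracts.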

### ✅ Auto-loaded Datasets

All other datasets (e.g., CIFAR, SVHN, LSUN, iSUN) are automatically downloaded when their respective loaders are instantiated.

---

## 🏛️ Models

The framework supports the following backbone architectures:

### ✅ ResNetV2 (BiT Models)

We use **Big Transfer (BiT)** pretrained models from the official [BiT repository](https://github.com/google-research/big_transfer). The following weights must be **downloaded manually**:

- **`BiT-M-R50x1` for CIFAR-10**  
  [Download](https://storage.googleapis.com/bit_models/BiT-M-R50x1-CIFAR10.npz)

- **`BiT-M-R101x1` for CIFAR-100**  
  [Download](https://storage.googleapis.com/bit_models/vtab/BiT-M-R101x1-run0-cifar100.npz)

- **`BiT-S-R50x1` for ImageNet-1k**  
  [Download](https://storage.googleapis.com/bit_models/BiT-S-R50x1.npz)

- **`BiT-M-R101x1` for ImageNet-1k**  
  [Download](https://storage.googleapis.com/bit_models/BiT-M-R101x1-ILSVRC2012.npz)

> 📥 Place the `.npz` files in `models/pretrained/` (or in whichever directory your model-loading code points to).
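The four checkpoints can also be fetched with a short standard-library script like the one below. This is a sketch under two assumptions: the loaders expect the original file names from the URLs, and the target directory is `models/pretrained/`. `BIT_WEIGHTS`, `local_path`, and `fetch_bit_weights` are hypothetical names.

```python
import urllib.request
from pathlib import Path
from urllib.parse import urlparse

# Checkpoints listed above, keyed by a descriptive (hypothetical) label.
BIT_WEIGHTS = {
    "BiT-M-R50x1 / CIFAR-10": "https://storage.googleapis.com/bit_models/BiT-M-R50x1-CIFAR10.npz",
    "BiT-M-R101x1 / CIFAR-100": "https://storage.googleapis.com/bit_models/vtab/BiT-M-R101x1-run0-cifar100.npz",
    "BiT-S-R50x1 / ImageNet-1k": "https://storage.googleapis.com/bit_models/BiT-S-R50x1.npz",
    "BiT-M-R101x1 / ImageNet-1k": "https://storage.googleapis.com/bit_models/BiT-M-R101x1-ILSVRC2012.npz",
}


def local_path(url: str, dest: Path) -> Path:
    """Keep the original file name from the URL, e.g. BiT-S-R50x1.npz."""
    return dest / Path(urlparse(url).path).name


def fetch_bit_weights(dest: Path = Path("models/pretrained")) -> list:
    """Download any missing .npz checkpoint into `dest` and return all paths."""
    dest.mkdir(parents=True, exist_ok=True)
    paths = []
    for url in BIT_WEIGHTS.values():
        target = local_path(url, dest)
        if not target.exists():  # skip files that are already downloaded
            urllib.request.urlretrieve(url, target)
        paths.append(target)
    return paths
```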

### ✅ ResNet (from TIMM, ported to JAX)

We provide **ResNet-34** models originally from the [TIMM](https://github.com/huggingface/pytorch-image-models) library, converted to **JAX** and fine-tuned. The pretrained JAX weights for CIFAR-100 are included in the repository.
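The conversion procedure itself is not described here. To illustrate what porting TIMM (PyTorch) weights to JAX typically involves, the sketch below rearranges tensor layouts. PyTorch stores convolution weights as `(out, in, kH, kW)` and linear weights as `(out, in)`, whereas Flax's `nn.Conv` and `nn.Dense` expect kernels shaped `(kH, kW, in, out)` and `(in, out)`. The function names are hypothetical. A full port would also need to map parameter names and BatchNorm statistics, which this sketch omits.

```python
import numpy as np


def torch_conv_to_flax(weight: np.ndarray) -> np.ndarray:
    """(out, in, kH, kW) PyTorch conv weight -> (kH, kW, in, out) Flax kernel."""
    return np.transpose(weight, (2, 3, 1, 0))


def torch_linear_to_flax(weight: np.ndarray) -> np.ndarray:
    """(out, in) PyTorch linear weight -> (in, out) Flax Dense kernel."""
    return weight.T


# Toy tensors with ResNet-like shapes (3x3 conv, 64 -> 128 channels; 512 -> 100 head).
conv_w = np.random.default_rng(0).standard_normal((128, 64, 3, 3))
fc_w = np.random.default_rng(1).standard_normal((100, 512))
flax_conv = torch_conv_to_flax(conv_w)
flax_fc = torch_linear_to_flax(fc_w)
print(flax_conv.shape, flax_fc.shape)
```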

### 🧩 Extensibility

The framework, and in particular our GradPCA implementation, can be extended to new model architectures. To add a model, implement it in JAX (optionally also in PyTorch, for the baseline comparisons) and pass its `TrainState` directly to GradPCA. Make sure its parameter names match the structure GradPCA expects.
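The exact parameter structure GradPCA expects is not spelled out here. The sketch below therefore shows only a generic way to compare a model's parameter paths against a list of required names. With a Flax `TrainState`, the nested mapping to inspect would be `state.params`. The helper names and the `required` entries in the example are hypothetical placeholders, not names GradPCA actually looks for.

```python
from collections.abc import Mapping


def flatten_param_names(params: Mapping, prefix: str = "") -> list:
    """Return '/'-joined paths to every leaf of a nested parameter mapping."""
    names = []
    for key, value in params.items():
        path = f"{prefix}/{key}" if prefix else str(key)
        if isinstance(value, Mapping):
            names.extend(flatten_param_names(value, path))
        else:
            names.append(path)
    return names


def missing_params(params: Mapping, required: list) -> list:
    """List the required parameter paths that `params` does not contain."""
    present = set(flatten_param_names(params))
    return sorted(set(required) - present)


# Toy parameter tree shaped like Flax's auto-named modules; leaves stand in for arrays.
params = {
    "Conv_0": {"kernel": "array", "bias": "array"},
    "Dense_0": {"kernel": "array", "bias": "array"},
}
print(flatten_param_names(params))
print(missing_params(params, ["Dense_0/kernel", "head/kernel"]))  # hypothetical names
```

Using `collections.abc.Mapping` rather than `dict` keeps the check working for both plain dicts and Flax's `FrozenDict`.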

---

## 🧰 Setup

We provide a setup script compatible with:

- **macOS (arm64)**
- **Linux with CUDA 11.8**

To install the required dependencies, run:

```bash
bash setup.sh
```
This creates a dedicated conda environment named `gradpca_env` with the required packages. Activate it with `conda activate gradpca_env` before running the notebooks.

> ⚠️ Note: This project integrates JAX, PyTorch, and TensorFlow within a single environment. Installation may require manual adjustments depending on your system configuration, CUDA version, and Python environment.
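Once the environment is active, a standard-library check like the one below can confirm that all three frameworks resolved. The function name is hypothetical, and it only checks that each module can be located, not that it imports cleanly. On some platforms TensorFlow is installed under a different distribution name (for example `tensorflow-macos`). In that case its version is reported as `unknown`.

```python
import importlib.util
from importlib.metadata import PackageNotFoundError, version

# Import name -> pip distribution name (assumed to match for these packages).
FRAMEWORKS = {"jax": "jax", "torch": "torch", "tensorflow": "tensorflow"}


def framework_versions() -> dict:
    """Map each framework to its installed version, 'unknown', or None if absent."""
    found = {}
    for module, dist in FRAMEWORKS.items():
        if importlib.util.find_spec(module) is None:
            found[module] = None  # module cannot be located at all
            continue
        try:
            found[module] = version(dist)
        except PackageNotFoundError:
            found[module] = "unknown"  # locatable, but under another distribution name
    return found


if __name__ == "__main__":
    for name, ver in framework_versions().items():
        print(f"{name:<11} {ver or 'NOT FOUND'}")
```

To check accelerator access as well, `jax.devices()` and `torch.cuda.is_available()` (or `torch.backends.mps.is_available()` on macOS) report whether each framework sees the GPU.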
