# AVOID CATASTROPHIC FORGETTING WITH RANK-1 FISHER FROM DIFFUSION MODELS (Supplementary)

Code and instructions to reproduce continual learning results with diffusion models using Elastic Weight Consolidation (EWC) with a diagonal or rank‑1 Fisher, and Generative Distillation.

## Setup
- Python 3.8+; CUDA optional.
- Install dependencies:
```powershell
pip install torch torchvision diffusers torchmetrics tqdm datasets wandb
```
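Before a first run, it can help to confirm that the interpreter and packages match the requirements above. The following is a minimal sketch that uses only the standard library (plus PyTorch if present). The helper names are hypothetical and not part of the repository. `wandb` is treated as optional because logging is disabled without it.

```python
import importlib.util
import sys

# Import names of the packages in the pip command above; wandb is optional.
REQUIRED = ["torch", "torchvision", "diffusers", "torchmetrics", "tqdm", "datasets"]
OPTIONAL = ["wandb"]


def missing_packages(names):
    """Return the names in `names` that cannot be found on the import path."""
    return [name for name in names if importlib.util.find_spec(name) is None]


def cuda_available():
    """True if PyTorch is installed and reports a usable CUDA device."""
    if importlib.util.find_spec("torch") is None:
        return False
    import torch
    return torch.cuda.is_available()


def check_environment():
    """Summarise Python version support, missing packages and CUDA status."""
    return {
        "python_ok": sys.version_info >= (3, 8),
        "missing_required": missing_packages(REQUIRED),
        "missing_optional": missing_packages(OPTIONAL),
        "cuda": cuda_available(),
    }


if __name__ == "__main__":
    print(check_environment())
```

An empty `missing_required` list and `python_ok: True` mean the setup step succeeded. `cuda: False` is fine, since training can run on CPU.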

## Run (MNIST demo)
With no arguments, the script reads `example_config.json`, which runs MNIST as 5 tasks (`group_size=2`) using EWC with a rank‑1 Fisher and 10 epochs per task:
```powershell
python train-model.py
```
Change any setting by editing `example_config.json`, or pass another file:
```powershell
python train-model.py --config path\to\config.json
```
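To compare settings without editing `example_config.json` in place, you can write a variant file and pass it with `--config`. This is a minimal sketch that assumes the config is a flat JSON object whose keys are the fields listed under Config Reference. `write_variant_config` is a hypothetical helper, not part of the repository.

```python
import json
from pathlib import Path


def write_variant_config(base_path, out_path, *, strict=True, **overrides):
    """Load a JSON config, apply top-level overrides and save the result.

    Hypothetical helper; assumes a flat JSON object of config fields.
    """
    config = json.loads(Path(base_path).read_text())
    unknown = sorted(set(overrides) - set(config))
    if strict and unknown:
        # Catch typos such as "ewc_fisher" instead of "ewc_fisher_type".
        raise KeyError(f"unknown config keys: {unknown}")
    config.update(overrides)
    out = Path(out_path)
    out.parent.mkdir(parents=True, exist_ok=True)
    out.write_text(json.dumps(config, indent=2))
    return config
```

For example, `write_variant_config("example_config.json", "configs/mnist_diag.json", ewc_fisher_type="diag")` creates a diagonal-Fisher variant that can then be run with `python train-model.py --config configs\mnist_diag.json`. Also overriding `output_dir` keeps the two runs' outputs apart.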

Outputs (models, samples, logs) are saved to `output_dir` from the config.

## Config Reference
Key fields in `example_config.json`:
- `dataset`: mnist | fmnist | cifar10 | imagenet32
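- `num_classes` and `group_size`: define the tasks; classes are split into groups of `group_size`, giving `num_classes / group_size` tasks
- `use_ewc`, `ewc_lambda`, `ewc_fisher_type`: enable EWC, set its penalty weight, and choose the Fisher approximation (diag | rank1_opt)
- `use_generative_replay`, `use_distillation`, `gr_*`: enable generative replay and/or distillation; the `gr_*` fields configure them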
- `epochs`, `batch_size`, `lr`, `seed`, `normalize`, `greyscale`
- `use_wandb`, `wandb_project`, `wandb_run_name`, `output_dir`
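The task split defined by `num_classes` and `group_size` can be sketched as follows. This assumes classes are assigned to tasks in consecutive label order; the ordering actually used by `train-model.py` may differ. `make_tasks` and `task_indices` are hypothetical helpers.

```python
def make_tasks(num_classes, group_size):
    """Split labels 0..num_classes-1 into consecutive groups, one per task.

    Assumes label-order assignment; this sketch rejects uneven splits.
    """
    if num_classes % group_size != 0:
        raise ValueError("num_classes must be divisible by group_size")
    return [list(range(start, start + group_size))
            for start in range(0, num_classes, group_size)]


def task_indices(labels, task_classes):
    """Return the dataset indices whose label belongs to the given task."""
    wanted = set(task_classes)
    return [i for i, y in enumerate(labels) if y in wanted]
```

With the MNIST demo settings (`num_classes=10`, `group_size=2`) this gives five two-class tasks, `[[0, 1], [2, 3], ...]`. Setting `group_size = num_classes` yields a single task containing every class.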

Notes:
- ImageNet32 uses the Hugging Face dataset `benjamin-paine/imagenet-1k-32x32`; the data and cache directories default to `./data` and `./cache` and can be overridden with the environment variables `DATA_ROOT` and `HF_DATASETS_CACHE`.
- If `wandb` is not installed, logging is disabled automatically.
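Both notes reduce to simple fallbacks: an environment lookup with defaults, and an optional-import check. This is a minimal sketch that assumes the script resolves them roughly this way. `resolve_paths` and `wandb_enabled` are hypothetical names, and the repository's actual lookup may differ.

```python
import importlib.util
import os


def resolve_paths():
    """Return (data_root, hf_cache), preferring environment overrides.

    Uses the documented defaults ./data and ./cache when the variables are unset.
    """
    data_root = os.environ.get("DATA_ROOT", "./data")
    hf_cache = os.environ.get("HF_DATASETS_CACHE", "./cache")
    return data_root, hf_cache


def wandb_enabled(use_wandb):
    """Log to wandb only if the config asks for it and the package is installed."""
    return bool(use_wandb) and importlib.util.find_spec("wandb") is not None
```

In PowerShell, running `$env:DATA_ROOT = "D:\datasets"` before `python train-model.py` sets the override for the current session. `HF_DATASETS_CACHE` is also read directly by the Hugging Face `datasets` library.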

## Fisher Utilities
During training, Fisher statistics are computed after each task (see `src/parameter_scoring.py`). To experiment with them separately on small models:
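```python
import torch

from src.parameter_scoring import compute_rank1_coeff_and_mean
from src.ddim import get_model

# Fall back to CPU when CUDA is unavailable (CUDA is optional).
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = get_model(channels=1, im_size=32, device=device, num_classes=10, model_size='small-big')
# train_loader: a DataLoader over the task's training images (e.g. built with
# the dataset helpers in src/utils.py); it is not defined in this snippet.
c_star, mu, diag = compute_rank1_coeff_and_mean(model, train_loader, device=device)
```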
To replicate our experiments, set `group_size = num_classes` (a single task containing all classes) and train a `'small-big'` model on MNIST.
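The practical difference between `diag` and `rank1_opt` lies in the shape of the quadratic EWC penalty. The sketch below uses the standard diagonal penalty `(lambda/2) * sum_i F_i * (theta_i - theta*_i)^2`. For the rank-1 case it *assumes* a Fisher of the form `F = c * mu mu^T`; reading `c_star` and `mu` from `compute_rank1_coeff_and_mean` as `c` and `mu` is our interpretation of the function name, not a documented contract. Parameters are plain Python lists for clarity; in practice they would be flattened model parameters.

```python
def diag_ewc_penalty(theta, theta_star, fisher_diag, lam):
    """Standard diagonal EWC penalty: (lam/2) * sum_i F_i (theta_i - theta*_i)^2."""
    return 0.5 * lam * sum(f * (t - s) ** 2
                           for t, s, f in zip(theta, theta_star, fisher_diag))


def rank1_ewc_penalty(theta, theta_star, mu, c, lam):
    """EWC penalty under the assumed rank-1 Fisher F = c * mu mu^T.

    The quadratic form d^T (c mu mu^T) d equals c * (mu . d)^2 with
    d = theta - theta*, so no P x P matrix is ever built.
    """
    proj = sum(m * (t - s) for t, s, m in zip(theta, theta_star, mu))
    return 0.5 * lam * c * proj ** 2
```

The rank-1 form stores one extra vector and costs a single dot product per step, which is what makes a non-diagonal Fisher affordable for models with millions of parameters.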

## Structure
- `train-model.py`: training/eval loop with EWC and optional GR
- `src/ddim.py`: class‑conditional DDIM model and samplers
- `src/utils.py`: datasets, training, FID
- `src/parameter_scoring.py`, `src/ewc.py`, `src/gr.py`, `src/fisher_analysis.py`