# Early-Stopping with CND (Conditional Neuron Divergence)

This repository contains code to train image/text classifiers and track early-stopping proxies based on:
- **CND** – Conditional Neuron Divergence (per-neuron Jensen–Shannon divergence across classes)
- **PC** – Predictions that change between consecutive epochs
- **KPA** – Accuracy on known polluted (noisy) samples

The pipeline runs fully from JSON simulation files and saves models, logs, and plots per run.
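
To make the CND definition concrete, here is a minimal pure-Python sketch of a per-neuron Jensen–Shannon divergence across classes. It is an illustration only, not the code in `programs/cnd.py` or `programs/jsd.py`: the equal-width histogram binning, the equal class weights, the base-2 logarithm, and the names `entropy`, `histogram`, and `neuron_jsd` are all assumptions.

```python
import math
from collections import defaultdict


def entropy(p):
    """Shannon entropy (in bits) of a discrete distribution."""
    return -sum(q * math.log2(q) for q in p if q > 0)


def histogram(values, lo, hi, n_bins):
    """Normalized histogram of `values` over equal-width bins on [lo, hi]."""
    counts = [0] * n_bins
    width = (hi - lo) / n_bins or 1.0  # avoid a zero width when lo == hi
    for v in values:
        idx = min(int((v - lo) / width), n_bins - 1)
        counts[idx] += 1
    return [c / len(values) for c in counts]


def neuron_jsd(activations, labels, n_bins=10):
    """Generalized JSD between the class-conditional activation
    distributions of ONE neuron (equal class weights, assumed)."""
    by_class = defaultdict(list)
    for a, y in zip(activations, labels):
        by_class[y].append(a)
    lo, hi = min(activations), max(activations)
    dists = [histogram(v, lo, hi, n_bins) for v in by_class.values()]
    mixture = [sum(col) / len(dists) for col in zip(*dists)]
    # JSD = H(mixture) - mean of H(class-conditional distributions)
    return entropy(mixture) - sum(entropy(d) for d in dists) / len(dists)


# A neuron whose activation separates the two classes...
separating = neuron_jsd([0.0, 0.0, 1.0, 1.0], [0, 0, 1, 1])
# ...versus one whose activation ignores the class.
uninformative = neuron_jsd([0.0, 1.0, 0.0, 1.0], [0, 0, 1, 1])
print(separating, uninformative)
```

With base-2 logarithms and equal weights, this quantity ranges from 0 (identical class-conditional distributions) to log2 of the number of classes (fully disjoint ones). How the repository aggregates per-neuron values into a per-layer CND is not shown here.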

## TL;DR – Quick start

```bash
# 1) Create a fresh environment (example with conda)
conda create -n cnd python=3.10 -y
conda activate cnd

# 2) Install dependencies
pip install -r requirements.txt

# 3) (Optional) Prepare additional data (CIFAR-N labels, Tiny ImageNet, NEWS – see sections below)

# 4) Run a simulation
python programs/main.py early_stopping_CIFAR10_real_noise

# Results (models, plots, logs) will appear under models/<experiment>/exp_<timestamp>_<job>_sim_<k>/.
```

---

## Installation

Python: 3.9–3.11 recommended (tested on 3.10)

Install required packages using:

```bash
pip install -r requirements.txt
```

---

## How the code is organized

```
programs/
  main.py                         # Entry point that runs simulations from JSON files
  simulations/*.json              # Ready-to-run experiment grids
  load_dataset.py                 # Dataset loading and noise injection
  network_structure.py            # Model factory (FCN, ResNet, News MLP)
  net_resnet.py, net_FCN.py       # Model definitions
  model_trainining_and_evaluation.py  # Training and evaluation loops
  performances.py                 # Metric container + plotting
  cnd.py, jsd.py                  # CND computation (JSD on neuron activations)
  polluting_samples.py            # Label/image noise utilities
  dataset_classes.py              # Dataset wrapper with index tracking & perturbations
  news.py                         # Helpers for the NEWS dataset
models/
  ...                             # Outputs per run (created automatically)
```

---

## Running simulations

Each simulation is defined by a JSON file under `programs/simulations/`. The runner expands list-valued fields into their full Cartesian product and executes the resulting configurations sequentially.

```bash
python programs/main.py <job_name_without_.json>
```
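
The Cartesian expansion of list-valued fields can be sketched as follows. This is a hedged illustration, not the code in `programs/main.py`: the function name `expand_grid` is hypothetical, and the sketch treats every top-level list as a sweep, so a field whose value is itself a list (such as `metrics`) has to be wrapped in an outer list here. How the actual runner handles that case may differ.

```python
import itertools


def expand_grid(config):
    """Yield one resolved config per combination of the list-valued fields."""
    list_keys = [k for k, v in config.items() if isinstance(v, list)]
    value_lists = [config[k] for k in list_keys]
    for combo in itertools.product(*value_lists):
        resolved = dict(config)  # copy the scalar fields unchanged
        resolved.update(zip(list_keys, combo))
        yield resolved


# Example with keys mentioned in this README; the values are made up.
job = {
    "seed": 0,
    "symmetric_label_noise_ratio": [0.1, 0.3],
    "metrics": [["CND", "PC"]],  # wrapped once so it is not swept over
}
for k, run_config in enumerate(expand_grid(job)):
    print(k, run_config)
```

Each resolved configuration would then correspond to one `sim_<k>` run.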

### Simulations available

```bash
# CIFAR-10 with human-annotated label noise variants (CIFAR-10N)
python programs/main.py early_stopping_CIFAR10_real_noise

# CIFAR-10 with 30% symmetric (synthetic) label noise
python programs/main.py early_stopping_CIFAR10_symmetric

# CIFAR-100 with symmetric label noise
python programs/main.py early_stopping_CIFAR100_symmetric

# CIFAR-100 with human-annotated label noise (CIFAR-100N)
python programs/main.py early_stopping_CIFAR100_real_noise

# MNIST (FCN) with symmetric label noise
python programs/main.py early_stopping_MNIST_symmetric

# 20 Newsgroups (NEWS) – needs extra data; see data section
python programs/main.py early_stopping_NEWS20
```

### What gets saved per run

- `models/<experiment>/exp_<timestamp>_<job>_sim_<k>/args.pkl` – resolved arguments
- `.../performances.pkl` – tracked metrics
- `.../training_curves_<ts>.png` – accuracy/loss curves
- `.../CND_layers_<ts>.png` – per-layer CND summary (if enabled)
- `.../early_stopping_metrics_<ts>.png` – proxy metrics vs accuracy
- `.../<filename>.log` – full log
- `.../<filename>_model.pth` – trained model checkpoint
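
To inspect a finished run, the two pickles can be read back with the standard library. The sketch below is an assumption-laden helper, not part of the repository: `latest_run_dir` and `load_run` are hypothetical names, and the structure of the unpickled objects is not documented here. If `performances.pkl` stores an instance of a class from `programs/performances.py`, that module must be importable when unpickling. Only unpickle files you created yourself.

```python
import pickle
from pathlib import Path


def latest_run_dir(experiment_dir):
    """Return the most recently modified exp_* directory of an experiment."""
    runs = [p for p in Path(experiment_dir).glob("exp_*") if p.is_dir()]
    if not runs:
        raise FileNotFoundError(f"no exp_* run directories under {experiment_dir}")
    return max(runs, key=lambda p: p.stat().st_mtime)


def load_run(run_dir):
    """Load args.pkl and performances.pkl from one run directory."""
    run_dir = Path(run_dir)
    loaded = {}
    for name in ("args", "performances"):
        with open(run_dir / f"{name}.pkl", "rb") as fh:
            loaded[name] = pickle.load(fh)
    return loaded
```

A typical call would be `load_run(latest_run_dir("models/<experiment>"))`, with `<experiment>` replaced by the actual folder name.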

---

## Datasets & where to get the data

Most image datasets are downloaded automatically by torchvision into `./data/` on first use. Some experiments require additional files:

### 1) CIFAR-10N / CIFAR-100N (human-annotated noisy labels)

Required by:

- `programs/simulations/early_stopping_CIFAR10_real_noise.json`
- `programs/simulations/early_stopping_CIFAR100_real_noise.json`

What you need:

```
./data/cifarN_labels/CIFAR-10_human.pt
./data/cifarN_labels/CIFAR-100_human.pt
```

How to place them:

```bash
mkdir -p data/cifarN_labels
# Put the two *.pt files inside data/cifarN_labels/
```

These files do not touch the CIFAR-10/100 images: they only replace the training labels in memory, after the code checks that they align with the training set.
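
The alignment check can be pictured as below. This is a sketch under assumptions, not the logic in `load_dataset.py`: it assumes the label file provides a clean-label array next to each noisy one, as the CIFAR-N release is understood to do, and `replace_train_labels` is a hypothetical name. The error message mirrors the one listed under Troubleshooting.

```python
def replace_train_labels(targets, clean_from_file, noisy_from_file):
    """Swap in noisy labels only if the file's clean labels match the dataset.

    Returns the noisy labels and the fraction of labels that differ.
    """
    targets = [int(t) for t in targets]
    clean = [int(c) for c in clean_from_file]
    noisy = [int(n) for n in noisy_from_file]
    # Same length and same order as the dataset, otherwise refuse to proceed.
    if not (len(targets) == len(clean) == len(noisy)) or targets != clean:
        raise RuntimeError("Labels uploaded are not the correct ones")
    changed = sum(n != t for n, t in zip(noisy, targets))
    noise_rate = changed / max(len(targets), 1)
    return noisy, noise_rate


# Toy example: one of four labels was changed by the annotators.
labels, rate = replace_train_labels([0, 1, 2, 1], [0, 1, 2, 1], [0, 2, 2, 1])
print(labels, rate)
```

Comparing the file's clean labels against the dataset's own targets catches shuffled or mismatched files before any training starts.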

---

### 2) NEWS (20 Newsgroups) – prebuilt pickle

Required by:

- `programs/simulations/early_stopping_NEWS20.json`

`load_dataset.py` expects a pickle at:

```
data/20news-bydate/news.pkl
```

with:

- `embedding_weights`: FloatTensor `[vocab_size, emb_dim]` (e.g., GloVe 300d)
- `data`: LongTensor `[N, max_length]` token IDs
- `labels`: LongTensor `[N]` (mapped into 7 superclasses via `programs/news.py::regroup_dataset`)

#### Minimal recipe to build `news.pkl` locally

The pickle can be built directly from `programs/news.py`; no extra script is needed:

```bash
# 1) Download GloVe 300d and place it here:
#    data/glove.6B/glove.6B.300d.txt
mkdir -p data/glove.6B

# 2) Build the pickle (adjust --glove if your path differs)
python -m programs.news --build --glove data/glove.6B/glove.6B.300d.txt --out data/20news-bydate/news.pkl --max_len 1000
```

This creates `data/20news-bydate/news.pkl` containing `(embedding_weights, data_ids, labels)` with labels regrouped into 7 superclasses. By default, the simulation splits 70/30 train/test inside `load_dataset.py`.
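
For intuition, a seeded 70/30 index split could look like the sketch below. Whether `load_dataset.py` shuffles before splitting, and how it seeds the split, is not stated here, so the shuffling, the `seed` parameter, and the name `split_indices` are assumptions.

```python
import random


def split_indices(n_samples, train_fraction=0.7, seed=0):
    """Return (train_idx, test_idx) from a seeded shuffle of range(n_samples)."""
    indices = list(range(n_samples))
    random.Random(seed).shuffle(indices)  # local RNG, global state untouched
    cut = int(round(train_fraction * n_samples))
    return indices[:cut], indices[cut:]


train_idx, test_idx = split_indices(10)
print(len(train_idx), len(test_idx))
```

Using a local `random.Random(seed)` keeps the split reproducible without affecting other random draws in the program.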

---

## Choosing metrics & options

In each simulation JSON:

- `metrics`: any subset of `["CND", "PC", "KPA"]`
- `noise_type`:
  - `"label_noise"` – symmetric label noise, ratio set by `symmetric_label_noise_ratio`
  - `"worse_label"`, `"random_label1"`, `"random_label2"`, `"random_label3"`, `"aggre_label"` – use CIFAR-10N human labels
  - `"clean_label"` – use the dataset’s clean labels
- `expand_dataset_*` – (for KPA) appends a small set of known-noisy samples to estimate accuracy on known corrupted items
- LR schedules: `"lr_multistep"`, `"cosine_decay"` or constant (default). There’s also an `"lr_plato"` option inside the early-stopping logic.
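
The two simpler proxies, PC and KPA, can be sketched in a few lines. This is illustrative only and not taken from the repository: the function names are hypothetical, PC is computed as a fraction (the code may report a raw count), and since it is not stated here which labels KPA is scored against, the reference labels are passed in explicitly.

```python
def prediction_changes(prev_preds, curr_preds):
    """PC: fraction of samples whose predicted class differs from last epoch."""
    if len(prev_preds) != len(curr_preds):
        raise ValueError("prediction lists must have the same length")
    changed = sum(p != c for p, c in zip(prev_preds, curr_preds))
    return changed / max(len(curr_preds), 1)


def known_polluted_accuracy(preds, reference_labels, polluted_idx):
    """KPA: accuracy restricted to the indices of known polluted samples."""
    if not polluted_idx:
        raise ValueError("need at least one known polluted sample")
    hits = sum(preds[i] == reference_labels[i] for i in polluted_idx)
    return hits / len(polluted_idx)


prev_epoch = [0, 1, 2, 2]
this_epoch = [0, 2, 2, 1]
reference = [0, 1, 2, 1]   # labels the polluted samples are scored against
polluted = [1, 3]          # indices of the appended known-noisy samples
print(prediction_changes(prev_epoch, this_epoch))
print(known_polluted_accuracy(this_epoch, reference, polluted))
```

Both functions need per-sample predictions in a fixed order across epochs, which is presumably why the dataset wrapper tracks sample indices.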

---

## Reproducibility, device selection & performance notes

- Seeds are fixed per simulation (`seed` in JSON).
- Device is chosen automatically (CUDA → MPS → CPU) with sensible DataLoader settings.
- Logs are streamed to console and to `<results_dir>/<filename>.log`.

---

## Troubleshooting

- **RuntimeError: Labels uploaded are not the correct ones**  
  The CIFAR-N `.pt` file must match the standard torchvision training order. Use the official files and place them under `data/cifarN_labels/`.

- **FileNotFoundError: data/20news-bydate/news.pkl**  
  Build the NEWS pickle as shown above or disable NEWS simulations.

- **CND plots are empty**  
  Enable `"CND"` in the metrics list and ensure `train_loader_fixed` is created (uses `fixed_batch_size`).

---

## License

MIT — see `programs/LICENSE`.