# GMM4PR Training Scripts

## fit_classifiers.py
Trains standard image classifiers (ResNet, VGG, Wide ResNet) on CIFAR-10, CIFAR-100, and TinyImageNet datasets.

**Usage:**
```bash
python fit_classifiers.py --dataset cifar10 --arch resnet18 --epochs 5 --batch_size 128
```

**Key Arguments:**
- `--dataset`: One of `cifar10`, `cifar100`, `tinyimagenet`
- `--arch`: One of `resnet18`, `resnet50`, `wide_resnet50_2`, `vgg16`
- `--pretrained`: Initialize from ImageNet-pretrained weights
- `--img_size`: Input image size (default: auto-selected)
- `--lr`: Learning rate (default: 0.01)
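
A small helper can assemble command lines from the flags documented above, which makes it easy to sweep several dataset/architecture combinations. This is a minimal sketch. `build_classifier_cmd` is a hypothetical helper and is not part of the repository. Its `epochs=5` and `batch_size=128` defaults are copied from the usage example and may not match the script's own defaults. It also assumes `--pretrained` is a boolean switch.

```python
import shlex

# Choices documented for fit_classifiers.py in this README.
DATASETS = ("cifar10", "cifar100", "tinyimagenet")
ARCHS = ("resnet18", "resnet50", "wide_resnet50_2", "vgg16")


def build_classifier_cmd(dataset, arch, epochs=5, batch_size=128,
                         lr=None, img_size=None, pretrained=False):
    """Return an argv list for fit_classifiers.py (hypothetical helper)."""
    if dataset not in DATASETS:
        raise ValueError(f"unknown dataset: {dataset}")
    if arch not in ARCHS:
        raise ValueError(f"unknown arch: {arch}")
    cmd = ["python", "fit_classifiers.py",
           "--dataset", dataset, "--arch", arch,
           "--epochs", str(epochs), "--batch_size", str(batch_size)]
    if lr is not None:
        cmd += ["--lr", str(lr)]
    if img_size is not None:  # omit to let the script auto-select the size
        cmd += ["--img_size", str(img_size)]
    if pretrained:  # assumption: --pretrained takes no value
        cmd.append("--pretrained")
    return cmd


if __name__ == "__main__":
    # Print one shell-ready command per dataset/architecture pair.
    for ds in DATASETS:
        for arch in ARCHS:
            print(shlex.join(build_classifier_cmd(ds, arch)))
```

You can paste each printed line into a shell, or pass the list form directly to `subprocess.run(cmd, check=True)`.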

**Output:** Saves the trained model checkpoint to `./model_zoo/trained_model/`

## fit_gmm2.py
Trains a Gaussian Mixture Model for Probabilistic Robustness (GMM4PR) using a pretrained classifier.

**Usage:**
```bash
python fit_gmm2.py --config resnet18_on_cifar10
```

**Key Arguments:**
- `--config`: Name of a configuration defined in `config.py`
- `--list-configs`: List the available configurations
- `--epochs`: Override the number of training epochs
- `--K`: Override the number of GMM components
- `--clf_ckpt`: Override the classifier checkpoint path
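
The `--epochs`, `--K`, and `--clf_ckpt` flags override values taken from a named configuration. The following sketch shows one plausible way to resolve such overrides. `GMMConfig`, `CONFIGS`, `list_configs`, and `resolve_config` are hypothetical stand-ins for whatever `config.py` actually defines. The field values are placeholders, not the repository's settings.

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class GMMConfig:
    # Hypothetical fields; the real entries live in config.py.
    name: str
    epochs: int
    K: int
    clf_ckpt: str


# Hypothetical registry keyed by the names accepted by --config.
CONFIGS = {
    "resnet18_on_cifar10": GMMConfig(
        name="resnet18_on_cifar10", epochs=100, K=10,  # placeholder values
        clf_ckpt="path/to/resnet18_cifar10.pt"),       # placeholder path
}


def list_configs():
    """What a --list-configs style option might print."""
    return sorted(CONFIGS)


def resolve_config(name, epochs=None, K=None, clf_ckpt=None):
    """Look up a named config, then apply only the overrides actually given."""
    if name not in CONFIGS:
        raise KeyError(f"unknown config {name!r}; choices: {list_configs()}")
    given = {"epochs": epochs, "K": K, "clf_ckpt": clf_ckpt}
    overrides = {k: v for k, v in given.items() if v is not None}
    return replace(CONFIGS[name], **overrides)  # returns a new frozen copy
```

Because `None` means "not given", the stored configuration stays the single source of defaults. The command line only changes the fields the user explicitly passes.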

**Features:**
- Conditional GMMs with `x`, `y`, or `xy` conditioning
- Multiple covariance types (diagonal, full, low-rank)
- Temperature scheduling and Gumbel-Softmax reparameterization
- Regularization to prevent mode collapse
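
Gumbel-Softmax makes the choice of mixture component differentiable. It adds Gumbel noise to the mixture logits and applies a softmax at temperature `tau`, so the output approaches a one-hot sample as `tau` shrinks. This sketch uses plain Python to stay self-contained. The exponential schedule and its `tau_start`/`tau_end` values are assumptions and may differ from the schedule `fit_gmm2.py` uses. In PyTorch, the sampling step corresponds to `torch.nn.functional.gumbel_softmax`.

```python
import math
import random


def gumbel_softmax(logits, tau, rng=random):
    """Relaxed one-hot sample over mixture components."""
    # Gumbel(0, 1) noise via -log(-log(U)), with U kept away from 0 and 1.
    gumbels = [-math.log(-math.log(rng.uniform(1e-10, 1.0 - 1e-10)))
               for _ in logits]
    scores = [(l + g) / tau for l, g in zip(logits, gumbels)]
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def exp_temperature(step, total_steps, tau_start=1.0, tau_end=0.1):
    """Exponentially anneal tau from tau_start to tau_end (assumed schedule)."""
    frac = min(step / max(total_steps, 1), 1.0)
    return tau_start * (tau_end / tau_start) ** frac
```

If you call `exp_temperature(step, total_steps)` at every step and pass the result to `gumbel_softmax`, the component weights are soft early in training and nearly discrete by the end.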
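
The three covariance types trade expressiveness against parameter count. This sketch shows reparameterized sampling for each one. It assumes a diagonal covariance parameterized by log standard deviations, a full covariance given by its lower-triangular Cholesky factor `L`, and a low-rank-plus-diagonal covariance `W W^T + diag(d)`. These are common parameterizations, not confirmed details of this codebase. In PyTorch, `torch.distributions.LowRankMultivariateNormal` implements the third form.

```python
import math
import random


def sample_diag(mu, log_sigma, rng):
    """Sigma = diag(exp(2 * log_sigma)): x = mu + sigma * eps."""
    return [m + math.exp(s) * rng.gauss(0.0, 1.0)
            for m, s in zip(mu, log_sigma)]


def sample_full(mu, L, rng):
    """Sigma = L L^T with L lower-triangular: x = mu + L eps."""
    eps = [rng.gauss(0.0, 1.0) for _ in mu]
    return [m + sum(L[i][j] * eps[j] for j in range(i + 1))
            for i, m in enumerate(mu)]


def sample_lowrank(mu, W, d, rng):
    """Sigma = W W^T + diag(d), W is D x r: x = mu + W z + sqrt(d) * eps."""
    r = len(W[0])
    z = [rng.gauss(0.0, 1.0) for _ in range(r)]  # shared low-rank factor
    return [m + sum(W[i][k] * z[k] for k in range(r))
            + math.sqrt(d[i]) * rng.gauss(0.0, 1.0)
            for i, m in enumerate(mu)]


def cov_param_count(cov_type, D, rank=None):
    """Free covariance parameters per component in D dimensions."""
    if cov_type == "diag":
        return D
    if cov_type == "full":
        return D * (D + 1) // 2
    if cov_type == "lowrank":
        return D * rank + D
    raise ValueError(f"unknown covariance type: {cov_type}")
```

Consider a flattened 32x32x3 CIFAR image, where `D = 3072`. Each component then needs 3,072 covariance parameters with `"diag"`, 4,720,128 with `"full"`, and 33,792 with `"lowrank"` at `rank=10`. That gap is why low-rank covariances are a useful middle ground.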
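
This README does not say which regularizer `fit_gmm2.py` uses against mode collapse. One common option, shown here as an assumption, penalizes low entropy in the batch-averaged mixture weights. The penalty is zero when all `K` components are used equally and reaches `log K` when a single component takes all the mass. The same averages can also flag components whose usage drops below a threshold, which is the kind of information a collapse log might record. The threshold value here is arbitrary.

```python
import math


def mean_usage(weights_batch):
    """Average mixture weights over a batch (each row sums to 1)."""
    n, K = len(weights_batch), len(weights_batch[0])
    return [sum(row[k] for row in weights_batch) / n for k in range(K)]


def collapse_penalty(weights_batch):
    """log K minus the entropy of mean usage; 0 means perfectly balanced."""
    usage = mean_usage(weights_batch)
    entropy = -sum(p * math.log(p) for p in usage if p > 0)
    return math.log(len(usage)) - entropy


def dead_components(weights_batch, threshold=1e-3):
    """Indices of components whose mean usage falls below the threshold."""
    return [k for k, p in enumerate(mean_usage(weights_batch)) if p < threshold]


if __name__ == "__main__":
    balanced = [[0.25, 0.25, 0.25, 0.25]] * 8
    collapsed = [[1.0, 0.0, 0.0, 0.0]] * 8
    print(collapse_penalty(balanced), collapse_penalty(collapsed))
    print(dead_components(collapsed))
```

A training loop would typically add a scaled `collapse_penalty` to the loss. Penalizing the batch average, rather than each sample's weights, lets individual inputs still commit to one component while discouraging the model from abandoning components overall.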

**Output:** Saves the GMM checkpoint, training loss history, and mode-collapse logs to `./ckp/gmm_ckp/`
