## Overview

The project contains several scripts for training models and running benchmarks:
- **MNIST Classification**: Image classification on the MNIST dataset
- **MIL Benchmarks**: Multiple Instance Learning on the `tiger`, `fox`, and `elephant` datasets

## Setup

Create and activate a virtual environment, then install the dependencies from `requirements.txt`:

```bash
python -m venv venv
source venv/bin/activate  # On Windows: venv\Scripts\activate
pip install -r requirements.txt
```

## Command-Line Usage

### MNIST Training (`test_core.py`)

Train a single model on MNIST with customizable parameters:

```bash
python test_core.py [OPTIONS]
```

**Options:**
- `--model`: Model type to use
  - Choices: `kf_attention`, `kf_layer`, `kf_pooling`, `hf_attention`, `hf_layer`, `hf_pooling`
  - Default: `kf_attention`
- `--batch-size`: Training batch size (default: 64)
- `--test-batch-size`: Testing batch size (default: 1000)
- `--epochs`: Number of training epochs (default: 14)
- `--lr`: Learning rate (default: 0.001)
- `--gamma`: Learning rate scheduler gamma (default: 0.7)
- `--seed`: Random seed (default: 1)
- `--log-interval`: Batches between logging (default: 50)
- `--no-accel`: Disable accelerator (use CPU)
- `--dry-run`: Quick single-pass check
- `--save-model`: Save the trained model

**Example:**
```bash
python test_core.py --model hf_attention --epochs 20 --lr 0.0005 --batch-size 128
```

**Note:** Model parameters (hidden_dim, beta, num_states, num_memories) are hardcoded in `test_core.py` at lines 180-187.
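
For orientation, the options above map naturally onto an `argparse` parser. The following is a minimal sketch, not the actual contents of `test_core.py`: the flag names, choices, and defaults come from the list above, while the `build_parser` helper, the argument types, and the use of `store_true` for the boolean switches are assumptions.

```python
import argparse

# Model names as documented in the options list.
MODEL_CHOICES = [
    "kf_attention", "kf_layer", "kf_pooling",
    "hf_attention", "hf_layer", "hf_pooling",
]


def build_parser():
    """Hypothetical helper: builds a parser mirroring the documented flags."""
    parser = argparse.ArgumentParser(description="MNIST training (illustrative sketch)")
    parser.add_argument("--model", choices=MODEL_CHOICES, default="kf_attention")
    parser.add_argument("--batch-size", type=int, default=64)
    parser.add_argument("--test-batch-size", type=int, default=1000)
    parser.add_argument("--epochs", type=int, default=14)
    parser.add_argument("--lr", type=float, default=0.001)
    parser.add_argument("--gamma", type=float, default=0.7)
    parser.add_argument("--seed", type=int, default=1)
    parser.add_argument("--log-interval", type=int, default=50)
    # Assumption: the boolean switches are plain store_true flags.
    parser.add_argument("--no-accel", action="store_true")
    parser.add_argument("--dry-run", action="store_true")
    parser.add_argument("--save-model", action="store_true")
    return parser


# Parse the same arguments as the example command above.
args = build_parser().parse_args(
    ["--model", "hf_attention", "--epochs", "20", "--lr", "0.0005", "--batch-size", "128"]
)
print(args.model, args.epochs, args.lr, args.batch_size)
```

Note that `argparse` converts dashes in flag names to underscores, so `--batch-size` is read back as `args.batch_size`.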

---

### MIL Training (`test_mil.py`)

Train a model on Multiple Instance Learning datasets:

```bash
python test_mil.py [OPTIONS]
```

**Options:**
- `--model`: Model type to use
  - Choices: `kf_attention`, `kf_layer`, `kf_pooling`, `hf_attention`, `hf_layer`, `hf_pooling`
  - Default: `kf_pooling`
- `--dataset`: Dataset to use
  - Choices: `tiger`, `fox`, `elephant`
  - Default: `tiger`
- `--batch-size`: Training batch size (default: 64)
- `--test-batch-size`: Testing batch size (default: 1000)
- `--epochs`: Number of training epochs (default: 14)
- `--lr`: Learning rate (default: 0.001)
- `--gamma`: Learning rate scheduler gamma (default: 0.7)
- `--seed`: Random seed (default: 1)
- `--log-interval`: Batches between logging (default: 10000)
- `--no-accel`: Disable accelerator (use CPU)
- `--dry-run`: Quick single-pass check
- `--save-model`: Save the trained model
- `--multiply`: Apply feature multiplication (see `datasets/loader.py`)

**Example:**
```bash
python test_mil.py --dataset fox --model hf_pooling --epochs 100 --batch-size 16
```

**Note:** Model parameters (hidden_dim, beta, num_states, num_memories, bag_dropout) are hardcoded in `test_mil.py` at lines 160-168.
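
To get a feel for how `--lr` and `--gamma` interact, the sketch below tabulates the learning rate per epoch. It assumes the scheduler multiplies the learning rate by `gamma` once after every epoch; that behavior is an assumption about `test_mil.py`, and `lr_schedule` is a hypothetical helper, not a function from the project.

```python
def lr_schedule(base_lr, gamma, epochs):
    """Learning rate used in each epoch, assuming one multiplicative
    decay step of `gamma` after every epoch (an assumption, see above)."""
    return [base_lr * gamma ** epoch for epoch in range(epochs)]


# Default settings from the options list: lr=0.001, gamma=0.7, epochs=14.
schedule = lr_schedule(base_lr=0.001, gamma=0.7, epochs=14)
for epoch, lr in enumerate(schedule, start=1):
    print(f"epoch {epoch:2d}: lr = {lr:.6g}")
```

Under this assumption, a `gamma` close to 1 keeps the learning rate nearly constant, while smaller values shrink it quickly, which matters more as `--epochs` grows.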

---

## Benchmark Scripts

These scripts run multiple experiments and generate comparison tables/plots. Their configuration is hardcoded; to change it, edit the variables listed below directly in each file.

### MNIST Dimension Benchmark Table (`run_mnist_table.py`)

Runs multiple trials for each combination of model and hidden dimension, then writes a CSV table with the mean and standard deviation of the accuracy.

```bash
python run_mnist_table.py
```

**Configuration Variables (edit in file):**
- **Line 68**: `models` - List of models to benchmark (e.g., `["kf_attention", "hf_attention"]`)
- **Line 69**: `hidden_dims` - List of hidden dimensions (e.g., `[4, 8, 16, 32, 64]`)
- **Line 70**: `num_trials` - Number of trials per configuration (default: 5)
- **Line 71**: `epochs` - Number of training epochs (default: 14)
- **Line 108**: Output CSV filename (default: `mnist_attention_benchmark.csv`)

**Output:** Saves a CSV file with mean accuracy and standard deviation for each model/dimension combination.
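
The aggregation step can be sketched as follows. This is an illustration using only the standard library, not the code in `run_mnist_table.py`: the helper names, the column names, and the choice of sample (rather than population) standard deviation are assumptions, and the numbers are arbitrary placeholders, not measured results.

```python
import csv
import io
import statistics

FIELDNAMES = ["model", "hidden_dim", "mean_acc", "std_acc"]  # hypothetical column names


def summarize_trials(results):
    """Turn {(model, hidden_dim): [accuracy per trial]} into one row per configuration."""
    rows = []
    for (model, hidden_dim), accuracies in sorted(results.items()):
        rows.append({
            "model": model,
            "hidden_dim": hidden_dim,
            "mean_acc": statistics.mean(accuracies),
            # Sample standard deviation needs at least two trials.
            "std_acc": statistics.stdev(accuracies) if len(accuracies) > 1 else 0.0,
        })
    return rows


def write_csv(rows, handle):
    """Write the summary rows to any file-like object."""
    writer = csv.DictWriter(handle, fieldnames=FIELDNAMES)
    writer.writeheader()
    writer.writerows(rows)


# Arbitrary placeholder values, only to exercise the code.
placeholder = {
    ("kf_attention", 4): [0.5, 0.6, 0.7],
    ("hf_attention", 4): [0.4, 0.4, 0.4],
}
rows = summarize_trials(placeholder)
buffer = io.StringIO()
write_csv(rows, buffer)
print(buffer.getvalue())
```

To write to disk instead of an in-memory buffer, pass a handle opened with `open(path, "w", newline="")`.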

---

### MIL Benchmark Table (`run_mil_table.py`)

Runs multiple trials comparing models on MIL datasets and generates a CSV table.

```bash
python run_mil_table.py
```

**Configuration Variables (edit in file):**
- **Line 10**: `target_datasets` - List of datasets (e.g., `['tiger', 'fox', 'elephant']`)
- **Line 11**: `target_models` - List of models (e.g., `['kf_pooling', 'hf_pooling']`)
- **Line 12**: `num_trials` - Number of trials per configuration (default: 5)
- **Line 31**: `--epochs` - Number of epochs (default: `"100"`)
- **Line 32**: `--batch-size` - Batch size (default: `"16"`)
- **Line 33**: `--lr` - Learning rate (default: `"0.001"`)
- **Line 34**: `--gamma` - Learning rate scheduler gamma (default: `"0.96"`)
- **Line 61**: Output CSV filename (default: `mil_benchmark_results_with_std.csv`)

**Output:** Saves a CSV file with mean AUC and standard deviation for each dataset/model combination.
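
The string-valued defaults above suggest that the script passes these flags to `test_mil.py` as command-line arguments. The sketch below shows one way such a sweep could be assembled. It is an illustration under that assumption, not the script's actual code: `build_command` and `iter_commands` are hypothetical helpers, and varying `--seed` per trial is a guess.

```python
import sys

# Values taken from the configuration list above.
TARGET_DATASETS = ["tiger", "fox", "elephant"]
TARGET_MODELS = ["kf_pooling", "hf_pooling"]
NUM_TRIALS = 5


def build_command(dataset, model, seed):
    """Argument list for a single test_mil.py run (hypothetical helper)."""
    return [
        sys.executable, "test_mil.py",
        "--dataset", dataset,
        "--model", model,
        "--seed", str(seed),  # assumption: each trial uses a different seed
        "--epochs", "100",
        "--batch-size", "16",
        "--lr", "0.001",
        "--gamma", "0.96",
    ]


def iter_commands():
    """Yield one command per (dataset, model, trial) combination."""
    for dataset in TARGET_DATASETS:
        for model in TARGET_MODELS:
            for trial in range(NUM_TRIALS):
                yield build_command(dataset, model, seed=trial + 1)


commands = list(iter_commands())
print(f"{len(commands)} runs planned")
print(" ".join(commands[0]))
```

Each command could then be executed with `subprocess.run(cmd, check=True)`. The sketch only builds the argument lists, so it is safe to run without the datasets present.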

---

### Download MIL Dataset

The MIL datasets are publicly available at <https://github.com/MAGICS-LAB/NonparametricHopfield/tree/main/datasets/mil_datasets>.