# AGOPMIPL

**Enhanced Multi-Instance Partial Label Learning via Average Gradient Outer Product**

## Requirements

```bash
pip install torch numpy scipy scikit-learn
```

## Usage

```bash
python main.py --ds <DATASET_NAME> --nr_fea <FEATURE_DIM> --nr_class <NUM_CLASSES>
```

## Key Arguments

| Argument             | Description                       | Default      |
| -------------------- | --------------------------------- | ------------ |
| `--ds`               | Dataset name                      | CRC-MIPL-SBN |
| `--ds_suffix`        | Dataset suffix (e.g., r1, r2, r3) | ''           |
| `--nr_fea`           | Feature dimension                 | 256          |
| `--nr_class`         | Number of classes                 | 7            |
| `--agop_rounds`      | Number of AGOP update rounds      | 3            |
| `--epochs_per_round` | Epochs per AGOP round             | 100          |
| `--lr`               | Learning rate                     | 0.005        |
| `--reg`              | Weight decay                      | 1e-5         |
| `--mu`               | Sparsity loss weight              | 0.1          |
| `--gamma`            | Inhibition loss weight            | 0.5          |
| `--attn_lambda`      | Raw attention path weight         | 0.3          |
| `--agop_momentum`    | AGOP matrix momentum              | 0.0          |
| `--normalize`        | Normalize features                | False        |
| `--no_agop`          | Disable AGOP updates              | False        |

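The `--agop_rounds`, `--epochs_per_round` and `--agop_momentum` arguments suggest a schedule in which the model is trained for `--epochs_per_round` epochs, the average gradient outer product (AGOP) is recomputed, and the cycle repeats `--agop_rounds` times. The sketch below computes the standard AGOP, `M = (1/n) * sum_i J(x_i)^T J(x_i)`, for a toy one-hidden-layer network in NumPy and blends successive matrices with a momentum term. It is not taken from `rfm.py`: the toy model, the helper names (`jacobian`, `agop`, `update_agop`) and the exact momentum rule are assumptions.

```python
import numpy as np

def jacobian(x, W, V):
    """Jacobian of the toy model f(x) = V @ tanh(W @ x) w.r.t. x; shape (c, d)."""
    h = np.tanh(W @ x)                          # hidden activations, shape (k,)
    return V @ ((1.0 - h ** 2)[:, None] * W)    # V . diag(1 - h^2) . W

def agop(X, W, V):
    """Average gradient outer product (1/n) * sum_i J(x_i)^T J(x_i); shape (d, d)."""
    d = X.shape[1]
    M = np.zeros((d, d))
    for x in X:
        J = jacobian(x, W, V)
        M += J.T @ J
    return M / len(X)

def update_agop(M_prev, M_new, momentum):
    """Assumed momentum rule: momentum=0.0 keeps only the freshly computed AGOP."""
    if M_prev is None:
        return M_new
    return momentum * M_prev + (1.0 - momentum) * M_new

rng = np.random.default_rng(0)
n, d, k, c = 64, 15, 32, 7   # 64 toy instances, 15 features (as in CRC-MIPL-SBN), 7 classes
X = rng.standard_normal((n, d))
W = rng.standard_normal((k, d)) / np.sqrt(d)
V = rng.standard_normal((c, k)) / np.sqrt(k)

M = None
for _ in range(3):  # stands in for --agop_rounds
    # Placeholder for --epochs_per_round epochs of training: just perturb W.
    W = W + 0.01 * rng.standard_normal(W.shape)
    M = update_agop(M, agop(X, W, V), momentum=0.5)

print(M.shape)  # (15, 15)
```

Under this assumed rule, the default `--agop_momentum 0.0` uses the fresh AGOP each round, while a value such as the SIVAL example's `0.5` averages it with the previous matrix.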
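## Training Examples

The commands below also use `--proto_agg` (`linear` or `mean`) and `--inst_weight`, which are not listed in the Key Arguments table; see `main.py` for their defaults.
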
### CRC-MIPL Datasets

```bash
# CRC-MIPL-SBN
python main.py --ds CRC-MIPL-SBN --nr_fea 15 --nr_class 7 --epochs_per_round 50 --agop_rounds 2 --lr 0.003 --reg 5e-4 --normalize --proto_agg linear --inst_weight 0.0 --attn_lambda 0.0 --agop_momentum 0.0

# CRC-MIPL-Row
python main.py --ds CRC-MIPL-Row --nr_fea 9 --nr_class 7 --epochs_per_round 50 --agop_rounds 2 --lr 0.003 --reg 5e-4 --normalize --proto_agg linear --inst_weight 0.0 --attn_lambda 0.0 --agop_momentum 0.0

# CRC-MIPL-KMeansSeg
python main.py --ds CRC-MIPL-KMeansSeg --nr_fea 6 --nr_class 7 --epochs_per_round 50 --agop_rounds 2 --lr 0.003 --reg 5e-4 --normalize --proto_agg linear --inst_weight 0.0 --attn_lambda 0.0 --agop_momentum 0.0

# CRC-MIPL-SIFT
python main.py --ds CRC-MIPL-SIFT --nr_fea 128 --nr_class 7 --epochs_per_round 50 --agop_rounds 2 --lr 0.003 --reg 5e-4 --normalize --proto_agg linear --inst_weight 0.0 --attn_lambda 0.0 --agop_momentum 0.0
```
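
The four CRC-MIPL commands differ only in `--ds` and `--nr_fea`. A small Python launcher can make that explicit; the helper name `crc_command` is hypothetical, and the script only prints the commands (launching them is left as a comment).

```python
import shlex

# Settings shared by all four CRC-MIPL commands above.
SHARED = [
    "--nr_class", "7", "--epochs_per_round", "50", "--agop_rounds", "2",
    "--lr", "0.003", "--reg", "5e-4", "--normalize", "--proto_agg", "linear",
    "--inst_weight", "0.0", "--attn_lambda", "0.0", "--agop_momentum", "0.0",
]

# Per-dataset feature dimension, copied from the commands above.
CRC_FEATURES = {
    "CRC-MIPL-SBN": 15,
    "CRC-MIPL-Row": 9,
    "CRC-MIPL-KMeansSeg": 6,
    "CRC-MIPL-SIFT": 128,
}

def crc_command(ds):
    """Build the argv list for one CRC-MIPL run."""
    return ["python", "main.py", "--ds", ds, "--nr_fea", str(CRC_FEATURES[ds]), *SHARED]

if __name__ == "__main__":
    for ds in CRC_FEATURES:
        print(shlex.join(crc_command(ds)))
        # To launch instead: import subprocess; subprocess.run(crc_command(ds), check=True)
```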

### SIVAL Dataset

```bash
python main.py --ds SIVAL_MIPL --ds_suffix r1 --nr_fea 30 --nr_class 25 --epochs_per_round 80 --agop_rounds 3 --lr 0.005 --reg 1e-3 --normalize --inst_weight 0.1 --proto_agg mean --attn_lambda 1.0 --agop_momentum 0.5
```

## Dataset Information

| Dataset            | #bags | #instances | #features   | #classes |
| ------------------ | ----- | ---------- | ----------- | -------- |
| MNIST-MIPL         | 500   | ~20,000    | 784 (28×28) | 10       |
| FMNIST-MIPL        | 500   | ~20,000    | 784 (28×28) | 10       |
| Birdsong-MIPL      | 1300  | ~48,000    | 76          | 13       |
| SIVAL-MIPL         | 1500  | ~47,000    | 30          | 25       |
| CRC-MIPL-Row       | 7000  | ~56,000    | 9           | 7        |
| CRC-MIPL-SBN       | 7000  | ~63,000    | 15          | 7        |
| CRC-MIPL-KMeansSeg | 7000  | ~30,000    | 6           | 7        |
| CRC-MIPL-SIFT      | 7000  | ~175,000   | 128         | 7        |

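Each run needs `--nr_fea` and `--nr_class` values that match its dataset, and the code defaults (256 features, 7 classes) match no row of the table, so both flags should be passed explicitly. A hypothetical lookup helper keyed by the table's dataset names is sketched below; the strings accepted by `--ds` may differ (the SIVAL example uses `SIVAL_MIPL`).

```python
# (#features, #classes) per dataset, copied from the table above.
DATASET_DIMS = {
    "MNIST-MIPL": (784, 10),
    "FMNIST-MIPL": (784, 10),
    "Birdsong-MIPL": (76, 13),
    "SIVAL-MIPL": (30, 25),
    "CRC-MIPL-Row": (9, 7),
    "CRC-MIPL-SBN": (15, 7),
    "CRC-MIPL-KMeansSeg": (6, 7),
    "CRC-MIPL-SIFT": (128, 7),
}

def dim_args(name):
    """Return the --nr_fea/--nr_class flags for a dataset listed in the table."""
    nr_fea, nr_class = DATASET_DIMS[name]
    return ["--nr_fea", str(nr_fea), "--nr_class", str(nr_class)]

print(dim_args("Birdsong-MIPL"))  # ['--nr_fea', '76', '--nr_class', '13']
```
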
## File Structure

```
code/
├── main.py          # Training script
├── model.py         # AGOPMIPL model definition
├── rfm.py           # AGOP computation module
├── dataloader.py    # Dataset loading utilities
└── utils.py         # Helper functions
```

## License

MIT License
