## Dependencies

* PyTorch: version 1.10.0
* PyTorch Geometric: version 2.0.3

## Training & Evaluation


### attention_mil

```
python attention_mil.py --DS 'datasets/messidor' --pooling_layer 'uot_pooling' --f_method 'sinkhorn' --num 4 
```


### adgcl

```
python adgcl.py --DS 'IMDB-BINARY' --pooling_layer 'uot_pooling' --f_method 'badmm-e' --num 4 --epoch 20 --seed 0
```

```
python adgcl.py --DS 'IMDB-BINARY' --pooling_layer 'uot_pooling' --f_method 'sinkhorn' --a1 1000000000000 --num 4 --epoch 20 --seed 0
```
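
To compare the three solvers over several seeds, the `adgcl.py` commands above can be generated in a loop. The sketch below is a hypothetical helper (`build_adgcl_commands` and its `--run` switch are not part of the repository). It reuses only flags that appear in the commands above. By default it prints the commands without running them; pass `--run` to the helper script to execute them.

```python
import subprocess
import sys

# Solver variants documented for --f_method.
F_METHODS = ["sinkhorn", "badmm-e", "badmm-q"]


def build_adgcl_commands(dataset="IMDB-BINARY", seeds=(0,), num=4, epoch=20):
    """Return one argv list per (f_method, seed) pair (hypothetical helper)."""
    commands = []
    for f_method in F_METHODS:
        for seed in seeds:
            commands.append([
                sys.executable, "adgcl.py",
                "--DS", dataset,
                "--pooling_layer", "uot_pooling",
                "--f_method", f_method,
                "--num", str(num),
                "--epoch", str(epoch),
                "--seed", str(seed),
            ])
    return commands


if __name__ == "__main__":
    # Dry run by default; only execute when this helper is called with --run.
    execute = "--run" in sys.argv[1:]
    for argv in build_adgcl_commands(seeds=(0, 1, 2)):
        print(" ".join(argv[1:]))
        if execute:
            subprocess.run(argv, check=True)
```

Passing argv as a list avoids shell quoting, so `IMDB-BINARY` needs no quotes here. Extra flags such as `--a1` can be appended to each list in the same way.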

## Parameters


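```DS``` is the dataset: a dataset name for `adgcl.py` (e.g. `IMDB-BINARY`) or a dataset path for `attention_mil.py` (e.g. `datasets/messidor`).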


```pooling_layer``` is the pooling layer used in the backbone. Options: add_pooling, mean_pooling, max_pooling, deepset,
mix_pooling, gated_pooling, set_set, attention_pooling, gated_attention_pooling, dynamic_pooling, GeneralizedNormPooling,
SAGPooling, ASAPooling, uot_pooling. SAGPooling and ASAPooling are available only in adgcl.

```f_method``` selects the UOTP solver: ```sinkhorn``` for UOTP_Sinkhorn, ```badmm-e``` for UOTP_BADMM-E, and ```badmm-q``` for UOTP_BADMM-Q.

```num``` is K, the number of steps in the K-step feed-forward computation. The default value is 4.

```a1``` corresponds to alpha_0 in the paper. The default value is None.

```a2``` corresponds to alpha_1 in the paper. The default value is None.

```a3``` corresponds to alpha_2 in the paper. The default value is None.

```p0``` corresponds to p_0 in the paper. The default value is 'fixed'.

```q0``` corresponds to q_0 in the paper. The default value is 'fixed'.
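
The parameters above do not spell out what `uot_pooling` computes. Below is a minimal NumPy sketch of the general idea behind `--f_method 'sinkhorn'` with `--num` steps. It treats pooling as entropic unbalanced optimal transport between the D feature dimensions and the N set elements and solves it with `num` unrolled Sinkhorn scaling steps. Several pieces are assumptions rather than facts about this repository: the names `eps`, `rho_p` and `rho_q` (entropic weight and KL weights on the two marginals), the cost `-X`, reading 'fixed' as uniform `p0`/`q0`, and the row-normalized readout. The sketch also does not claim how `a1`, `a2` and `a3` map onto these weights. It is not the repository's implementation.

```python
import numpy as np


def uot_sinkhorn_pool(X, eps=1.0, rho_p=1.0, rho_q=1.0, num=4):
    """Pool a set X of shape (D, N) into a vector of shape (D,).

    Hypothetical sketch: entropic unbalanced OT solved by `num` unrolled
    Sinkhorn scaling steps (generic scheme, not the repo's code).
    """
    D, N = X.shape
    p0 = np.full(D, 1.0 / D)  # assumption: 'fixed' = uniform over feature dims
    q0 = np.full(N, 1.0 / N)  # assumption: 'fixed' = uniform over set elements
    # Cost C = -X, so large activations are cheap to transport. Subtracting
    # X.max() keeps the kernel in (0, 1] and only rescales P by a scalar,
    # which the row-normalized readout below cancels.
    K = np.exp((X - X.max()) / eps)
    lam_p = rho_p / (rho_p + eps)  # -> 1 as rho_p grows (hard row marginal)
    lam_q = rho_q / (rho_q + eps)  # -> 1 as rho_q grows (hard column marginal)
    u = np.ones(D)
    v = np.ones(N)
    for _ in range(num):  # K-step feed-forward: a fixed number of iterations
        u = (p0 / (K @ v)) ** lam_p
        v = (q0 / (K.T @ u)) ** lam_q
    P = u[:, None] * K * v[None, :]  # transport plan, shape (D, N)
    # Assumed readout: per-dimension weighted average of the set elements.
    weights = P / P.sum(axis=1, keepdims=True)
    return (weights * X).sum(axis=1), P
```

Raising `rho_p` and `rho_q` pushes both exponents toward 1, which turns the soft marginal penalties into hard constraints (classic balanced Sinkhorn). A very large `eps` makes the plan nearly uniform, so the readout approaches mean pooling. The same loop ports to PyTorch tensors with minor renames (`keepdims` becomes `keepdim`), which keeps the `num` steps differentiable.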

### resnet18-uotp

```data``` is the path to the dataset.

For the remaining parameters, see the PyTorch ImageNet example: https://github.com/pytorch/examples/tree/main/imagenet

```
python resnet18_uotp.py --num 4 --f_method 'badmm-e'
```