# Mutual-Inform VMoE

## How to install the environment
```bash
conda create --name vmoe3.11 "python=3.11" --yes
conda activate vmoe3.11
pip install -r requirements.txt
```

<details><summary>Compatibility details</summary>

- [`tensorflow_datasets` does not work on `python==3.12`](https://github.com/tensorflow/datasets/issues/4666#issuecomment-2021269723)
- [`orbax`/`asyncio` have lock issues on `python==3.10`](https://bugs.python.org/issue45416)
- [`vit_jax` requires `python>=3.10`](https://github.com/google-research/vision_transformer)
- [`vmoe` uses `tf.contrib`, which was removed in TF 2.0](https://github.com/tensorflow/community/issues/148)
- [`cloud_tpu` uses the internal `keras.src.engine` API, which was removed in Keras 3.0](https://github.com/google-ai-edge/mediapipe/issues/5229)
</details>
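
Taken together, the notes above leave Python 3.11 as the only compatible minor version. A quick sanity check after activating the environment could look like the sketch below; `is_supported_python` is a hypothetical helper, not part of this repository.

```bash
# Hypothetical helper: succeeds only for the version window implied by the
# compatibility notes (newer than 3.10, older than 3.12).
is_supported_python() {
    [ "$1" = "3.11" ]
}

# Ask the active interpreter for its "major.minor" version.
py_version="$(python -c 'import sys; print("%d.%d" % sys.version_info[:2])' 2>/dev/null || echo "not found")"

if is_supported_python "$py_version"; then
    echo "Python $py_version: OK"
else
    echo "Python $py_version: expected 3.11" >&2
fi
```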

## Run the training code
Set the CUDA devices and the working directory for your machine. Note that the router patcher and the configs must always be changed together.

```bash
python -m vmoe.train.main_prior \
    --config vmoe/configs/prior/vmoe_s32_last2_ilsvrc2012_randaug_light1/base.py \
    --workdir [workdir]
```
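
One way to select the devices and the working directory is sketched below. The GPU indices and the directory path are placeholders chosen for illustration, and `run_training` is a hypothetical wrapper around the command above. `CUDA_VISIBLE_DEVICES` is the standard NVIDIA environment variable for restricting which GPUs a process can see.

```bash
# Assumption: GPUs 0 and 1 are the devices to train on; adjust to your machine.
export CUDA_VISIBLE_DEVICES=0,1

# Hypothetical working directory; replace with your own path.
WORKDIR="$HOME/vmoe_workdir"
CONFIG=vmoe/configs/prior/vmoe_s32_last2_ilsvrc2012_randaug_light1/base.py

# Hypothetical wrapper around the training command shown above.
run_training() {
    python -m vmoe.train.main_prior \
        --config "$CONFIG" \
        --workdir "$WORKDIR"
}

# Show the settings before launching anything.
echo "devices: $CUDA_VISIBLE_DEVICES"
echo "config:  $CONFIG"
echo "workdir: $WORKDIR"
```

Calling `run_training` afterwards starts the run with these settings.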

## Evaluate OOD
```bash
python -m vmoe.train.main_OOD \
    --config vmoe/configs/prior/vmoe_s32_last2_ilsvrc2012_randaug_light1/eval_OOD.py \
    --workdir [workdir] \
    --imnet_c_cfg [corruption]_[strength]
```
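
To sweep several corruptions and strengths, the `[corruption]_[strength]` values can be generated in a loop. The sketch below is a dry run that only prints the commands. The corruption names and the 1 to 5 strength range are assumptions based on the ImageNet-C naming convention, so check which values the config actually accepts. The working directory is a placeholder.

```bash
# Assumption: a small subset of ImageNet-C style corruption names.
CORRUPTIONS="gaussian_noise shot_noise"
# Assumption: strengths range from 1 to 5.
STRENGTHS="1 2 3 4 5"

# Hypothetical working directory; replace with your own path.
WORKDIR="$HOME/vmoe_workdir"
CONFIG=vmoe/configs/prior/vmoe_s32_last2_ilsvrc2012_randaug_light1/eval_OOD.py

# Build every [corruption]_[strength] combination.
OOD_CFGS=""
for corruption in $CORRUPTIONS; do
    for strength in $STRENGTHS; do
        OOD_CFGS="$OOD_CFGS ${corruption}_${strength}"
    done
done

# Dry run: print one evaluation command per combination.
for cfg in $OOD_CFGS; do
    echo python -m vmoe.train.main_OOD \
        --config "$CONFIG" \
        --workdir "$WORKDIR" \
        --imnet_c_cfg "$cfg"
done
```

Removing the leading `echo` in the second loop runs the evaluations instead of printing them.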

## Evaluate clustering
```bash
python -m vmoe.train.cluster_centroids \
    --config vmoe/configs/prior/vmoe_s32_last2_ilsvrc2012_randaug_light1/base.py \
    --workdir [workdir]
```