# Equivariant Metanetworks for MoE weights

## Install dependencies
Tested with conda, Python 3.12, and CUDA 12.1.
```
conda create -n nfn-moe python=3.12
conda activate nfn-moe
pip install -r requirements.txt
pip install -e .
```

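## Download data
The MoE and Transformer model zoos are hosted at https://huggingface.co/datasets/JohnDoe4765/MoE-Transformer-Model-Zoos. The commands below expect them in `data/` (set with `--data_path data`).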
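One way to fetch the data from the command line is the `huggingface-cli download` command from the `huggingface_hub` package. This sketch assumes that package is installed (it may not be listed in `requirements.txt`) and that the repository's files are meant to sit directly under the directory passed as `--data_path`. `download_model_zoos` is a hypothetical helper name and is not part of this repo.

```
# Assumes the huggingface_hub CLI is installed, e.g.:
#   pip install -U "huggingface_hub[cli]"
# Hypothetical helper (not part of this repo): download the dataset repo
# into a local directory, defaulting to "data" to match --data_path data.
download_model_zoos() {
  local target="${1:-data}"
  huggingface-cli download JohnDoe4765/MoE-Transformer-Model-Zoos \
    --repo-type dataset --local-dir "$target"
}
```

Run `download_model_zoos` to download into `data/`, or `download_model_zoos /path/to/zoos` to download elsewhere. If you use a different path, pass the same path as `--data_path`. If the scripts cannot find the files, check the dataset card for the expected layout.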

## Run MoE-NFN model
```
python nfn_moe/main.py --enc_mode moe_invariant --classifier_nfn_channels 100,100 --moe_nfn_channels 4 --out_dim_inv 5 --dataset mnist --data_path data --cut_off 0
python nfn_moe/main.py --enc_mode moe_invariant --emb_mode no --classifier_nfn_channels 100,100 --moe_nfn_channels 4 --out_dim_inv 5 --dataset ag_news --data_path data --cut_off 0
```
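The two MoE-NFN commands differ only in `--dataset` and in the extra `--emb_mode no` passed for `ag_news`. The following minimal sketch wraps that pattern in a function. `run_moe_nfn` is a hypothetical helper, and it reuses only the flags shown in the commands above.

```
# Hypothetical helper (not part of this repo): run the MoE-NFN model on one
# dataset. Mirrors the two commands above; ag_news also passes --emb_mode no.
run_moe_nfn() {
  local dataset="$1"
  # Build the variable part of the argument list in the positional
  # parameters, a portable substitute for a bash array.
  set -- --enc_mode moe_invariant
  if [ "$dataset" = "ag_news" ]; then
    set -- "$@" --emb_mode no
  fi
  python nfn_moe/main.py "$@" --classifier_nfn_channels 100,100 \
    --moe_nfn_channels 4 --out_dim_inv 5 \
    --dataset "$dataset" --data_path data --cut_off 0
}
```

Call it as `run_moe_nfn mnist` or `run_moe_nfn ag_news`. The function builds its arguments with `set --` instead of a bash array so that it also works in plain `sh`.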


## Run Transformer-NFN model
```
python nfn_moe/main.py --enc_mode transformer_invariant --classifier_nfn_channels 100,100 --moe_nfn_channels 12 --dataset mnist --data_path data --cut_off 0
python nfn_moe/main.py --enc_mode transformer_invariant --emb_mode no --classifier_nfn_channels 100,100 --moe_nfn_channels 12 --dataset ag_news --data_path data --cut_off 0
```

## Run MLP Baseline
```
python nfn_moe/main.py --enc_mode mlp --classifier_nfn_channels 100,100 --moe_nfn_channels 256 --num_out_classify 256 --num_out_embedding 64 --num_out_encoder 256 --dataset mnist --data_path data --cut_off 0
python nfn_moe/main.py --enc_mode mlp --emb_mode no --classifier_nfn_channels 100,100 --moe_nfn_channels 256 --num_out_classify 256 --num_out_encoder 256 --dataset ag_news --data_path data --cut_off 0
```

## Run XGBoost Baseline
```
python nfn_moe/xgb.py --dataset mnist --data_path data --cut_off 0
python nfn_moe/xgb.py --dataset ag_news --data_path data --cut_off 0
```

## Run LightGBM Baseline
```
python nfn_moe/gbm.py --model gbdt --dataset mnist --data_path data --cut_off 0
python nfn_moe/gbm.py --model gbdt --dataset ag_news --data_path data --cut_off 0
# If you run into GPU issues with LightGBM, see https://github.com/microsoft/LightGBM/issues/586#issuecomment-352845980
```

## Run Random Forest Baseline
```
python nfn_moe/gbm.py --model rf --dataset mnist --data_path data --cut_off 0
python nfn_moe/gbm.py --model rf --dataset ag_news --data_path data --cut_off 0
```
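The LightGBM and Random Forest baselines both use `nfn_moe/gbm.py` and differ only in `--model` (`gbdt` or `rf`). The following sketch runs all four model/dataset combinations with the flags shown above. `run_gbm_baselines` is a hypothetical function name.

```
# Hypothetical convenience function (not part of this repo): run both gbm.py
# baselines (--model gbdt for LightGBM, --model rf for Random Forest)
# on both datasets, stopping at the first failed run.
run_gbm_baselines() {
  local model dataset
  for model in gbdt rf; do
    for dataset in mnist ag_news; do
      python nfn_moe/gbm.py --model "$model" --dataset "$dataset" \
        --data_path data --cut_off 0 || return 1
    done
  done
}
```

After defining the function in your shell, call `run_gbm_baselines`. Because of `|| return 1`, a failing run (for example, a LightGBM GPU issue) stops the loop instead of being silently skipped.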

## Run SVM Baseline
```
python nfn_moe/svm.py --dataset mnist --data_path data --cut_off 0
python nfn_moe/svm.py --dataset ag_news --data_path data --cut_off 0
```
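The XGBoost and SVM baselines take identical arguments, so one loop can run both. This is a minimal sketch. `run_xgb_svm_baselines` is a hypothetical name, and the flags are the ones used in the commands above.

```
# Hypothetical convenience function (not part of this repo): run the XGBoost
# (xgb.py) and SVM (svm.py) baselines on both datasets, stopping at the
# first failed run.
run_xgb_svm_baselines() {
  local script dataset
  for script in xgb svm; do
    for dataset in mnist ag_news; do
      python "nfn_moe/${script}.py" --dataset "$dataset" \
        --data_path data --cut_off 0 || return 1
    done
  done
}
```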
