# Supplemental material for NeurIPS 2024 submission: <br/> "Metalearning to Continually Learn In Context"

Here we provide our "research code" as supplemental material.

We will release the "official version" of this code in a public GitHub repository upon acceptance of our paper.

This codebase was originally forked from the public repository https://github.com/IDSIA/modern-srwm,
which we modified for continual learning. For requirements and other setup details, please refer to the original repository.

Our codebase also includes code from other public repositories, e.g.,
* https://github.com/tristandeleu/pytorch-meta for standard few-shot learning data preparation and data loader implementations
(forked and slightly modified in `torchmeta_local`).
* an MLP-Mixer implementation, among others (the LICENSE can be found in the corresponding directory or file headers).

For Split-MNIST experiments of Table 3, we further used the following public code:
* https://github.com/GT-RIPL/Continual-Learning-Benchmark:
We used this to produce the results for the 2-task class-incremental setting (Table 3)

* https://github.com/khurramjaved96/mrcl:
We forked and modified their code to produce the OML numbers of Table 3. We downloaded their out-of-the-box model from the same link.
The modified code can be found under the directory: `oml_baseline_experiments`

* https://github.com/aminbana/GeMCL
As for OML above, we forked and modified their code to produce the GeMCL numbers of Table 3.
The modified code can be found under the directory: `gemcl_baseline_experiments`

**Copyright mentions of the authors have been intentionally removed for anonymity.**
(We kept the names of the authors of the original repository we forked from.)

Code for ViT experiments will be added for the final version.

## Training

* For Omniglot/Mini-ImageNet, two-task training.
```
SEED=1

export CUDA_VISIBLE_DEVICES=0

export TORCH_EXTENSIONS_DIR="/home/me/torch_extensions"

CODE=
DATA=

python3 ${CODE}/main.py \
  --data_dir ${DATA} \
  --name_dataset miniimagenet_32_norm_cache \
  --seed ${SEED} \
  --num_worker 12 \
  --test_per_class 1 \
  --model_type 'compat_stateful_srwm' \
  --work_dir save_models \
  --total_epoch 2 \
  --total_train_steps 600_000 \
  --validate_every 1_000 \
  --batch_size 32 \
  --num_layer 2 \
  --n_head 16 \
  --hidden_size 256 \
  --ff_factor 2 \
  --dropout 0.1 \
  --vision_dropout 0.1 \
  --k_shot 15 \
  --test_per_class 1 \
  --use_warmup \
  --shuffled_eval \
  --use_wandb \
  --ood_eval \
  --ood_eval2 \
  --ood_eval3 \
  --extra_label \
  --use_fs \
  --use_ab_v2 \
  --use_acl \
  --use_instance_norm \
  --loss_scale_task_a 0.1 \
  --use_cache \
  --project_name 'my_project'
```
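In the command above, `CODE` and `DATA` are left empty and have to be set to the code directory (the one containing `main.py`) and the data directory before launching. As a minimal sketch, a guard like the following could fail early when they are not set. `require_dir` is a hypothetical helper introduced here for illustration; it is not part of this codebase.

```shell
# Sketch only: `require_dir` is a hypothetical helper, not part of this codebase.
# $1: variable name (used in the error message), $2: its value.
require_dir() {
  if [ -z "$2" ]; then
    echo "error: $1 is empty; set it before running main.py" >&2
    return 1
  fi
  if [ ! -d "$2" ]; then
    echo "error: $1=$2 is not a directory" >&2
    return 1
  fi
}

# Usage, placed right before the python3 call:
#   require_dir CODE "${CODE}" && require_dir DATA "${DATA}" || exit 1
```

The usage line is left commented out so that the helper can be sourced without side effects.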

* For Omniglot/Mini-ImageNet/FC100, three-task training.
```
python3 ${CODE}/main.py \
  --data_dir ${DATA} \
  --name_dataset miniimagenet_32_norm_cache \
  --seed ${SEED} \
  --num_worker 12 \
  --test_per_class 1 \
  --model_type 'compat_stateful_srwm' \
  --work_dir save_models \
  --total_epoch 2 \
  --total_train_steps 600_000 \
  --validate_every 1_000 \
  --batch_size 32 \
  --num_layer 2 \
  --n_head 16 \
  --hidden_size 256 \
  --ff_factor 2 \
  --dropout 0.1 \
  --vision_dropout 0.1 \
  --k_shot 15 \
  --test_per_class 1 \
  --use_warmup \
  --shuffled_eval \
  --use_wandb \
  --ood_eval \
  --ood_eval2 \
  --ood_eval3 \
  --extra_label \
  --use_fs \
  --use_abc_v2 \
  --use_acl \
  --use_instance_norm \
  --loss_scale_task_a 1 \
  --use_cache \
  --project_name 'my_project'
```
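Comparing the two commands above, the three-task command differs from the two-task one only in two arguments: `--use_abc_v2` instead of `--use_ab_v2`, and `--loss_scale_task_a 1` instead of `0.1`. A minimal sketch of how this difference could be captured in one place, assuming a hypothetical helper named `task_flags` (not part of this codebase):

```shell
# Sketch only: `task_flags` is a hypothetical helper, not part of this codebase.
# It prints the arguments that differ between the two- and three-task commands.
task_flags() {
  case "$1" in
    2) echo "--use_ab_v2 --loss_scale_task_a 0.1" ;;
    3) echo "--use_abc_v2 --loss_scale_task_a 1" ;;
    *) echo "error: expected 2 or 3 tasks, got '$1'" >&2; return 1 ;;
  esac
}

# Usage: keep the shared arguments as in the commands above and append
#   $(task_flags 3)
# unquoted, so that the shell splits it into separate arguments.
```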

* 5-task training for domain incremental settings
```
# Path to the out-of-the-box model:
OB_MODEL=

python3 ${CODE}/main.py \
  --data_dir ${DATA} \
  --train_splitmnist_style \
  --init_model_except_output_from ${OB_MODEL} \
  --name_dataset miniimagenet_32_norm_cache \
  --seed ${SEED} \
  --num_worker 12 \
  --test_per_class 1 \
  --model_type 'compat_stateful_srwm' \
  --work_dir save_models \
  --total_epoch 2 \
  --total_train_steps 600_000 \
  --validate_every 1_000 \
  --batch_size 64 \
  --num_layer 2 \
  --n_head 16 \
  --hidden_size 256 \
  --ff_factor 2 \
  --dropout 0.1 \
  --vision_dropout 0.1 \
  --k_shot 5 \
  --n_way 2 \
  --report_every 10 \
  --validate_every 50 \
  --test_per_class 1 \
  --shuffled_eval \
  --extra_label \
  --use_fs \
  --use_acl \
  --use_instance_norm \
  --use_cache \
  --project_name 'my_project'
```

* 5-task training for class incremental settings
```
python3 ${CODE}/main.py \
  --data_dir ${DATA} \
  --train_splitmnist_style_class_incremental \
  --init_model_except_output_from_class_incremental ${OB_MODEL} \
  --name_dataset miniimagenet_32_norm_cache \
  --seed ${SEED} \
  --num_worker 12 \
  --test_per_class 1 \
  --model_type 'compat_stateful_srwm' \
  --work_dir save_models \
  --total_epoch 2 \
  --total_train_steps 600_000 \
  --validate_every 1_000 \
  --batch_size 64 \
  --num_layer 2 \
  --n_head 16 \
  --hidden_size 256 \
  --ff_factor 2 \
  --dropout 0.1 \
  --vision_dropout 0.1 \
  --k_shot 5 \
  --n_way 10 \
  --report_every 10 \
  --validate_every 50 \
  --test_per_class 1 \
  --shuffled_eval \
  --extra_label \
  --use_fs \
  --use_acl \
  --use_instance_norm \
  --use_cache \
  --project_name 'my_project'
```

## Evaluation

* Split-MNIST, domain incremental
```
# Path to the trained model:
OB_MODEL=

python3 ${CODE}/main.py \
  --data_dir ${DATA} \
  --eval_splitmnist \
  --eval_only_dir ${OB_MODEL} \
  --name_dataset miniimagenet_32_norm_cache \
  --seed ${SEED} \
  --num_worker 12 \
  --test_per_class 1 \
  --model_type 'compat_stateful_srwm' \
  --work_dir save_models \
  --total_epoch 2 \
  --total_train_steps 600_000 \
  --validate_every 1_000 \
  --batch_size 32 \
  --num_layer 2 \
  --n_head 16 \
  --hidden_size 256 \
  --ff_factor 2 \
  --dropout 0.1 \
  --vision_dropout 0.1 \
  --k_shot 15 \
  --n_way 5 \
  --test_per_class 1 \
  --shuffled_eval \
  --extra_label \
  --use_fs \
  --use_ab_v2 \
  --use_acl \
  --use_instance_norm \
  --use_cache
```

* Split-MNIST, class incremental, 2-task
```
python3 ${CODE}/main.py \
  --data_dir ${DATA} \
  --eval_splitmnist_incremental_class_2task \
  --eval_only_dir ${OB_MODEL} \
  --name_dataset miniimagenet_32_norm_cache \
  --seed ${SEED} \
  --num_worker 12 \
  --test_per_class 1 \
  --model_type 'compat_stateful_srwm' \
  --work_dir save_models \
  --total_epoch 2 \
  --total_train_steps 600_000 \
  --validate_every 1_000 \
  --batch_size 32 \
  --num_layer 2 \
  --n_head 16 \
  --hidden_size 256 \
  --ff_factor 2 \
  --dropout 0.1 \
  --vision_dropout 0.1 \
  --k_shot 15 \
  --n_way 5 \
  --test_per_class 1 \
  --shuffled_eval \
  --extra_label \
  --use_fs \
  --use_ab_v2 \
  --use_acl \
  --use_instance_norm \
  --use_cache
```

* Split-MNIST, class incremental, 5-task

In the 2-task command above, replace `--eval_splitmnist_incremental_class_2task` with `--eval_splitmnist_incremental_class`.
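The three Split-MNIST evaluation commands in this section differ only in their evaluation flag. As a sketch, the mapping from setting to flag could be written as below. The helper `splitmnist_eval_flag` and its setting names (`domain`, `class2`, `class5`) are hypothetical and introduced here for illustration only; the flags themselves are the ones listed above.

```shell
# Sketch only: `splitmnist_eval_flag` and its setting names are hypothetical,
# not part of this codebase. The printed flags are those used above.
splitmnist_eval_flag() {
  case "$1" in
    domain) echo "--eval_splitmnist" ;;
    class2) echo "--eval_splitmnist_incremental_class_2task" ;;
    class5) echo "--eval_splitmnist_incremental_class" ;;
    *) echo "error: unknown setting '$1'" >&2; return 1 ;;
  esac
}

# Usage: in the evaluation command, pass
#   $(splitmnist_eval_flag class5)
# in place of the hard-coded evaluation flag.
```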

* Evaluation on the 2-task few-shot learning test sets (Omniglot/Mini-ImageNet)

```
# Path to the trained model:
OB_MODEL=

python3 ${CODE}/main.py \
  --data_dir ${DATA} \
  --eval_only_dir ${OB_MODEL} \
  --name_dataset miniimagenet_32_norm_cache \
  --seed ${SEED} \
  --num_worker 12 \
  --test_per_class 1 \
  --model_type 'compat_stateful_srwm' \
  --work_dir save_models \
  --total_epoch 2 \
  --total_train_steps 600_000 \
  --validate_every 1_000 \
  --batch_size 32 \
  --num_layer 2 \
  --n_head 16 \
  --hidden_size 256 \
  --ff_factor 2 \
  --dropout 0.1 \
  --vision_dropout 0.1 \
  --k_shot 15 \
  --test_per_class 1 \
  --shuffled_eval \
  --ood_eval \
  --ood_eval2 \
  --ood_eval3 \
  --extra_label \
  --use_fs \
  --use_ab_v2 \
  --use_acl \
  --use_instance_norm \
  --loss_scale_task_a 0.1 \
  --use_cache
```

* To the command above, add `--eval_extra_only` for evaluation on MNIST/CIFAR10, or `--eval_extra_only_3_tasks` for the 4-task evaluation (to be renamed for the official release).
