# Learning from Teaching Regularization: Generalizable Correlations Should be Easy to Imitate

## Install Requirements: 
Install a torch version that is compatible with your CUDA version (see https://pytorch.org/get-started/previous-versions/). We use torch==2.1.0+cu118; adjust the versions in the command below to match your setup.
```
conda install pytorch==2.2.1 torchvision==0.17.1 torchaudio==2.2.1 pytorch-cuda=11.8 -c pytorch -c nvidia
```

```
conda create -n LoT python=3.10
conda activate LoT
pip install -r requirements.txt

# set WANDB_API_KEY and WANDB_USER_NAME in ~/.bashrc
```
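
The two Weights & Biases variables mentioned above could be set with lines like the following in `~/.bashrc`. This is a minimal sketch: the values are placeholders, and `check_wandb_env` is a hypothetical helper (not part of this repository) that only verifies both variables are non-empty before a run.

```shell
# Placeholder values; replace them with your own credentials.
export WANDB_API_KEY="your-api-key"
export WANDB_USER_NAME="your-username"

# Hypothetical helper, not shipped with this repo:
# returns non-zero if either variable is unset or empty.
check_wandb_env() {
    missing=0
    for name in WANDB_API_KEY WANDB_USER_NAME; do
        # Look up the value of the variable whose name is stored in $name.
        eval "value=\${$name:-}"
        if [ -z "$value" ]; then
            echo "missing: $name" >&2
            missing=1
        fi
    done
    return "$missing"
}
```

Calling `check_wandb_env` before one of the `run/*.sh` scripts is one way to catch a missing key early, before a long job starts.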

## Prepare Datasets and Checkpoints:

To run the language modeling tasks, run the following script to download the WikiText-103 and Penn Treebank (PTB) datasets. Datasets for the other tasks are downloaded automatically.
```
bash getdata.sh
```

Put the cifar100, ImageNet, ptb, wikitext-103 datasets under `../data/`

Put the ViT pre-trained checkpoints under `../model/`
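
Before launching a run, the layout described above can be checked with a small shell sketch. The subdirectory names (`cifar100`, `ImageNet`, `ptb`, `wikitext-103`) are assumptions taken from the dataset list above, and `check_layout` is a hypothetical helper, not a script shipped with this repository; adjust the names to whatever the run scripts actually read.

```shell
# Hypothetical helper: report which expected subdirectories are missing.
# Usage: check_layout BASE_DIR NAME [NAME ...]
check_layout() {
    base=$1
    shift
    status=0
    for name in "$@"; do
        if [ ! -d "$base/$name" ]; then
            echo "missing: $base/$name" >&2
            status=1
        fi
    done
    return "$status"
}

# Assumed dataset directory names; the checkpoint directory is only checked for existence.
check_layout ../data cifar100 ImageNet ptb wikitext-103 || echo "some datasets are not in place yet"
check_layout .. model || echo "../model/ does not exist yet"
```

The helper returns a non-zero status when anything is missing, so it can also gate a run script with `check_layout ... && bash run/...`.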

## Run LoT

### Reinforcement Learning
For reinforcement learning tasks, run the following commands to set up the environment and run the experiments on BeamRider.
```
conda create -n LoT_RL python=3.9
conda activate LoT_RL
pip install -r rl_requirements.txt
conda install pytorch==2.1.1 torchvision==0.16.1 torchaudio==2.1.1 pytorch-cuda=11.8 -c pytorch -c nvidia

bash run/run_atari_games_LoT.sh
```

### Language Modeling
Run the following command for LSTM on PTB.
```
conda activate LoT
bash run/run_lstm_ptb_LoT.sh
```

Run the following command for Transformer-XL on WikiText-103.
```
conda activate LoT
bash run/run_transformerxl_wt103_LoT.sh
```

Run the following commands to set up the environment and train LLaMA2 on GSM8K.
```
conda create -n LoT_llama python=3.10
conda activate LoT_llama
cd llama/src
conda install pytorch==2.2.1 torchvision==0.17.1 torchaudio==2.2.1 pytorch-cuda=11.8 -c pytorch -c nvidia
pip install transformers==4.31.0
cd ..
pip install -r requirements.txt

bash run/run_llama_LoT.sh
```

Run the following command to evaluate LLaMA2 on GSM8K.
```
# install lm_eval following https://github.com/EleutherAI/lm-evaluation-harness

bash run/run_llama_eval_LoT.sh
```


### Image Classification
Run the following command for ResNet-18 on CIFAR100.
```
conda activate LoT
bash run/run_resnet_cifar100_LoT.sh
```

Run the following command for ViT on CIFAR100.
```
conda activate LoT
bash run/run_vit_cifar100_LoT.sh
```

Run the following command for ViT on ImageNet-1K.
```
conda activate LoT
bash run/run_vit_imagenet_LoT.sh
```