# Online Continual Learning for Progressive Distribution Shift (OCL-PDS)

## Overview
This repository provides benchmarks and algorithms for investigating progressive distribution shift (PDS): the subtle, gradual, and continuous distribution shift that is widespread in modern deep learning applications. This repository contains:
- 4 PDS benchmarks
- 12 OCL algorithms and baselines
- An experiment script for complete training and evaluation

Part of this repository is based on the open-source [Wilds repository](https://github.com/p-lambda/wilds/tree/main/examples), whose license can be found [here](https://github.com/p-lambda/wilds/blob/main/LICENSE).

## Table of Contents
- [Quick Start](#quick-start)
- [Benchmarks](#benchmarks)
- [Algorithms](#algorithms)
- [Reproducing Results](#reproducing-results)

## Quick Start
### Running a Complete OCL-PDS Experiment
There are two ways to create an experiment environment identical to ours on Ubuntu.

#### Option 1: Using pip (Python 3.9.13)
```shell
pip install pip --upgrade
pip install virtualenv
virtualenv envname
source envname/bin/activate
pip install -r requirements.txt
```

#### Option 2: Using conda
```shell
conda update -n base conda
conda env create --name envname --file=environment.yml
conda activate envname
```

The main experiment script is `pds_expt.py`. To run an experiment, use the following command:
```shell
python pds_expt.py --dataset [DATASET] --root_dir /path/to/dataset \
                   --alpha [ALPHA] --alg [ALG] --seed [SEED]       \
                   -i /path/to/initial_model.pth
```
Here, `[DATASET]` is the name of the dataset, `[ALPHA]` is the fraction of labeled samples in each batch under Random Label Feedback (RLF) (0 ≤ α ≤ 1), `[ALG]` is the name of the algorithm, and `[SEED]` is the random seed (optional). If you are using a dataset for the first time, add `--download` to download it. To skip the initial training (t=0) and load a pretrained model instead, pass the optional `-i [MODEL_PATH]` argument, which also accepts URLs to online model checkpoints.
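To make the role of `[ALPHA]` concrete, here is a minimal sketch of how a batch could be split into labeled and unlabeled samples under RLF. It is an illustration only, not the repository's implementation: the function name `split_random_label_feedback` is hypothetical, and rounding `alpha * num_samples` to the nearest integer is an assumption.

```python
import random


def split_random_label_feedback(num_samples, alpha, seed=None):
    """Randomly choose which sample indices of a batch receive labels.

    Hypothetical helper: labels round(alpha * num_samples) samples,
    chosen uniformly at random without replacement.
    """
    if not 0.0 <= alpha <= 1.0:
        raise ValueError("alpha must satisfy 0 <= alpha <= 1")
    rng = random.Random(seed)
    num_labeled = round(alpha * num_samples)
    indices = list(range(num_samples))
    rng.shuffle(indices)
    labeled = sorted(indices[:num_labeled])
    unlabeled = sorted(indices[num_labeled:])
    return labeled, unlabeled


# With alpha = 0.005 (0.5% labeled data), a batch of 1000 samples
# yields 5 labeled and 995 unlabeled samples.
labeled, unlabeled = split_random_label_feedback(1000, alpha=0.005, seed=0)
print(len(labeled), len(unlabeled))
```

With `alpha=1` every sample is labeled, and with `alpha=0` none are.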

Each algorithm requires a number of additional parameters, which are listed in the [Algorithms](#algorithms) section below. For example, to run ER-FIFO on CivilComments-WPDS with 0.5% labeled data, run:
```shell
python pds_expt.py --dataset civilcomments --root_dir /path/to/dataset \
                   --alpha 0.005 --alg er-fifo --epochs 1 --seed 0     \
                   -i /path/to/initial_model.pth
```

### Using a Benchmark
Our benchmarks are included in the `datasets` folder, and they are very easy to use. For example, to use the FMoW-WPDS benchmark, use the following code:
```python
from datasets.fmow_wpds import FMoWWPDS

dataset = FMoWWPDS(root_dir='/path/to/dataset', download=True)
# Specify the dataset folder in root_dir, and use download=True to download

print('Number of batches (T + 1):', dataset.num_batches)
transform = get_data_augmentation()  # placeholder: supply your own transform (not defined in this snippet)
batch_0 = dataset.get_batch(t=0, transform=transform)
# Get the training batch at time t, with transform applied. 0 <= t <= T

batch_0_test = dataset.get_test_batch(t=0, transform=transform)
# The training and test batches are separate i.i.d. sets.
# Test batches are for evaluating recent performance. See Section 3.1 of the paper.

training_reg_set = dataset.get_train_regression(transform=transform)
test_reg_set = dataset.get_test_regression(transform=transform)
# Similarly, the training and test regression sets are separate i.i.d. sets.
```

Other benchmarks can be used in the same way. The prototype of all benchmarks is the `PDSDataset` class in the file `datasets/pds_dataset.py`. Refer to this class for more information.

### Using an OCL Algorithm
Our algorithms are included in the `algorithms` folder with shortcuts in the `algs.py` file, and they are also very easy to use. Each algorithm is an object built upon a `model`, and also needs a `config` which contains the necessary attributes the algorithm requires. For example, to use GEM-PDS, use the following code:
```python
from algs import GEM_PDS
from initialize import get_parser

model = resnet()  # placeholder for your model constructor (not defined in this snippet)
parser = get_parser()
config = parser.parse_args()
# Parse the arguments from the command line

alg = GEM_PDS(model, config)
```

The OCL algorithm object is callable and takes a feedback dict as an argument. In this work we use Random Label Feedback (RLF), which can be constructed as follows:
```python
from initialize import initialize_eval_metric, initialize_feedback

eval_metric = initialize_eval_metric(config)
feedback = initialize_feedback(config, eval_metric)  # RLF Feedback
```

The feedback object is also callable and returns a dict that contains the feedback from evaluation, as well as the separated labeled and unlabeled batches. A complete experiment loop looks like this:
```python
for t in range(dataset.num_batches):
    batch_train = dataset.get_batch(t=t, transform=transform)
    batch_test = dataset.get_test_batch(t=t, transform=transform)
    fb = feedback(t, batch_train, batch_test, alg.model)
    # fb is a dict, which contains the feedback from evaluation
    if t > 0:
        print('Online performance:', fb['performance'])

    if t == 0:
        # When t = 0, also include the training reg set in the feedback
        fb['train_regression'] = dataset.get_train_regression(transform=transform)

    alg(t, fb)  # Train or fine-tune the model with the feedback
```

All supervised OCL algorithms and baselines can be used in the same way. Pseudo labeling (PL) is built upon any supervised OCL algorithm, so it also requires a supervised algorithm as an argument. For example, to use GEM-PDS-PL:
```python
from algs import PseudoLabeling
alg_pl = PseudoLabeling(model, config, GEM_PDS)
# Similarly, use alg_pl(t, fb) to train the model
```


## Benchmarks
We release the following 4 benchmarks:

| Benchmark          | [DATASET]     | Dataset Description                           |
|--------------------|---------------|-----------------------------------------------|
| CivilComments-WPDS | civilcomments | Toxic language detection                      |
| FMoW-WPDS          | fmow          | Satellite image classification                |
| Amazon-WPDS        | amazon        | Language sentiment analysis                   |
| Poverty-WPDS       | poverty       | Wealth index regression from satellite images |




## Algorithms
All algorithms are implemented in the `algorithms` folder, and shortcuts are provided in `algs.py`. The base class is `OCLAlgorithm` in `algorithms/ocl.py`. In addition to the parameters listed below, all algorithms take the following ERM parameters: `batch_size`, `lr`, `wd`, `device`, `loader_kwargs`. Default values are used for any that are not provided.

For the first batch (t=0), all methods will train an initial model with ERM over the first batch and the training regression set. Set the parameter `epochs_first_batch` to specify the number of epochs for initial training (if not set then `epochs` will be used). If the parameter `initial_model_load` is set to a checkpoint path, then instead of training, this model will be loaded.
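The first-batch rules above can be summarized in a small sketch. This is not the repository's code: `plan_first_batch` is a hypothetical helper, and treating an unset parameter as `None` is an assumption. Only the parameter names `initial_model_load`, `epochs_first_batch`, and `epochs` come from the description above.

```python
from argparse import Namespace


def plan_first_batch(config):
    """Decide what happens at t=0 (hypothetical helper).

    Returns ("load", checkpoint_path) if a checkpoint is configured,
    otherwise ("train", number_of_epochs).
    """
    checkpoint = getattr(config, "initial_model_load", None)
    if checkpoint:
        # A checkpoint path replaces the initial ERM training entirely.
        return ("load", checkpoint)
    epochs_first = getattr(config, "epochs_first_batch", None)
    # Fall back to `epochs` when `epochs_first_batch` is not set.
    epochs = epochs_first if epochs_first is not None else config.epochs
    return ("train", epochs)


print(plan_first_batch(Namespace(epochs=1)))
print(plan_first_batch(Namespace(epochs=1, epochs_first_batch=5)))
print(plan_first_batch(Namespace(epochs=1, initial_model_load="/path/to/initial_model.pth")))
```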

An algorithm object is callable with `alg(t, fb)`, where `t` is the batch number and `fb` is a feedback dict returned by calling an `OCLFeedback` object as defined in `feedback.py`.

The following algorithms (or baselines) are available for `pds_expt.py`:

### Supervised OCL Algorithms

| Algorithm      | [ALG]      | Parameters           |
|----------------|------------|----------------------|
| FBO            | fbo        | epochs               |
| NBO            | nbo        | epochs               |
| ER-FIFO        | er-fifo    | epochs               |
| ER-FIFO-RW     | er-fifo-rw | epochs               |
| GEM-PDS        | gem-pds    | epochs, kr_size      |
| MIR            | mir        | epochs, kr_size, lbd |
| MaxLoss        | maxloss    | epochs, kr_size, lbd |
| Online L2 Reg. | l2reg      | epochs, lbd          |
| EWC            | ewc        | epochs, lbd, kr_size |

### Semi-supervised OCL Algorithms
Here the `epochs` parameter is the number of epochs of the virtual update step, while `epochs_unlabeled` is that of the real fine-tuning step.

| Algorithm     | [ALG]         | Parameters               |
|---------------|---------------|--------------------------|
| ER-FIFO-PL    | er-fifo-pl    | epochs, epochs_unlabeled |
| ER-FIFO-FM    | er-fifo-fm    | epochs, epochs_unlabeled |
| ER-FIFO-RW-PL | er-fifo-rw-pl | epochs, epochs_unlabeled |
| ER-FIFO-RW-FM | er-fifo-rw-fm | epochs, epochs_unlabeled |

## Reproducing Results
Please follow the steps below to reproduce our results:

1. Our experiment environment is Ubuntu 18.04 with an NVIDIA V100 GPU. In rare cases, our results may not be exactly reproducible on other systems.
1. Build a new virtual environment using the environment files we provide, and activate the environment.
1. Download our pretrained initial models [here](#link). These models were trained by running FBO with seeds 0, 1, 2, 3, 4.
1. Run an experiment with the hyperparameters we provide in Tables 15-18 in the paper, setting the seed to `2020 + pretrained_model_seed`. For example, if the pretrained model seed is 2, then run the experiment with seed 2022.

Example: To reproduce the results of `L2Reg-10-10.0` on CivilComments-WPDS in Table 15 in the paper, run the following commands:
```shell
python pds_expt.py --dataset civilcomments --root_dir /path/to/dataset --alpha 0.005 --alg l2reg --epochs 10 --lbd 10.0 -i civilcomments_pretrained_0.pth --seed 2020
python pds_expt.py --dataset civilcomments --root_dir /path/to/dataset --alpha 0.005 --alg l2reg --epochs 10 --lbd 10.0 -i civilcomments_pretrained_1.pth --seed 2021
python pds_expt.py --dataset civilcomments --root_dir /path/to/dataset --alpha 0.005 --alg l2reg --epochs 10 --lbd 10.0 -i civilcomments_pretrained_2.pth --seed 2022
python pds_expt.py --dataset civilcomments --root_dir /path/to/dataset --alpha 0.005 --alg l2reg --epochs 10 --lbd 10.0 -i civilcomments_pretrained_3.pth --seed 2023
python pds_expt.py --dataset civilcomments --root_dir /path/to/dataset --alpha 0.005 --alg l2reg --epochs 10 --lbd 10.0 -i civilcomments_pretrained_4.pth --seed 2024
```
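The five commands above differ only in the checkpoint index and the seed, so they can be generated rather than typed by hand. The sketch below is a convenience illustration, not part of the repository: `reproduction_commands` is a hypothetical helper, and the checkpoint naming pattern `[DATASET]_pretrained_[k].pth` is taken from the CivilComments example and only assumed to hold for other datasets.

```python
import shlex


def reproduction_commands(dataset, alg, alpha, extra_args,
                          root_dir="/path/to/dataset", model_seeds=range(5)):
    """Build one pds_expt.py command per pretrained model (hypothetical helper).

    The experiment seed follows the rule 2020 + pretrained_model_seed.
    """
    commands = []
    for model_seed in model_seeds:
        args = [
            "python", "pds_expt.py",
            "--dataset", dataset,
            "--root_dir", root_dir,
            "--alpha", str(alpha),
            "--alg", alg,
            *extra_args,
            # Assumed checkpoint naming, copied from the example above.
            "-i", f"{dataset}_pretrained_{model_seed}.pth",
            "--seed", str(2020 + model_seed),
        ]
        commands.append(shlex.join(args))
    return commands


for cmd in reproduction_commands("civilcomments", "l2reg", 0.005,
                                 ["--epochs", "10", "--lbd", "10.0"]):
    print(cmd)
```

The printed lines can be pasted into a shell or saved to a script file. `shlex.join` requires Python 3.8 or later.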