# Flipped Classroom: Aligning Teacher Attention with Student in Generalized Category Discovery

### Overview of *FlipClass*
![image](./assets/teaser.jpg)

## Contents
[1. Requirements](#requirements)

[2. Running](#running)


## <a name="requirements"/> Requirements

### Environments

```
scikit_learn==1.3.0
torch==2.0.1+cu118
tqdm==4.65.0
torchvision==0.15.2+cu118
numpy==1.24.1
tensorboard==2.10.0
statistics==1.0.3.5
```

To set up the environment from `requirements.txt`:
```bash
conda create --name FlipClass python==3.9.0
conda activate FlipClass
pip install -r requirements.txt
```
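
After installation, a quick check that the pinned versions are actually active can save debugging time later. The following is a minimal sketch using only the standard library; `PINNED` and `check_pins` are illustrative names that are not part of this repository, and the pins are copied from the list above.

```python
from importlib import metadata

# Pins copied from the environment list above.
PINNED = {
    "scikit_learn": "1.3.0",
    "torch": "2.0.1+cu118",
    "tqdm": "4.65.0",
    "torchvision": "0.15.2+cu118",
    "numpy": "1.24.1",
    "tensorboard": "2.10.0",
    "statistics": "1.0.3.5",
}


def check_pins(pinned, get_version=metadata.version):
    """Return {package: (expected, installed_or_None)} for every mismatch."""
    mismatches = {}
    for name, expected in pinned.items():
        try:
            installed = get_version(name)
        except metadata.PackageNotFoundError:
            installed = None  # package is not installed at all
        if installed != expected:
            mismatches[name] = (expected, installed)
    return mismatches


if __name__ == "__main__":
    for name, (expected, installed) in check_pins(PINNED).items():
        print(f"{name}: expected {expected}, found {installed or 'not installed'}")
```

An empty result means every pinned package matches. A mismatch is not necessarily fatal, but it is worth knowing about before comparing numbers against the tables below.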

### Pretrained Model
The pretrained model used in this work can be found at [ViT-dino](https://huggingface.co/facebook/dino-vits16).

### Datasets

We use generic object recognition datasets, including:

* [CIFAR-10/100](https://pytorch.org/vision/stable/datasets.html)
* [ImageNet](https://image-net.org/download.php)

For the fine-grained benchmarks used in this paper, refer to:
* [The Semantic Shift Benchmark (SSB)](https://github.com/sgvaze/osr_closed_set_all_you_need#ssb)

Specifically, this benchmark includes:
* [CUB-200-2011](https://www.vision.caltech.edu/datasets/cub_200_2011/)
* [Stanford Cars](https://www.kaggle.com/datasets/jessicali9530/stanford-cars-dataset)
* [FGVC-Aircraft](https://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/)
* [Herbarium-19](https://www.kaggle.com/c/herbarium-2019-fgvc6)



## <a name="running"/> Running

### Config

Set the model paths in `configs/config_args.py` by filling in the entries below:
```python
model_factory = {
    'dino': ['', 'dino_vitb16'],
    'dinov2': ['', 'dinov2_vitb14'],
    'dino_static': '',
    'dinov2_static': ''
}
```
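
Empty or mistyped paths in this dictionary tend to surface only once training starts, so it can help to validate them up front. The sketch below assumes that the first element of each list entry and each `*_static` string is a filesystem path; that reading is inferred from the layout above and is not confirmed by the repository. All paths shown are placeholders, and `missing_paths` is a hypothetical helper.

```python
import os

# Placeholder paths for illustration only; replace them with your own locations.
model_factory = {
    'dino': ['/path/to/dino', 'dino_vitb16'],
    'dinov2': ['/path/to/dinov2', 'dinov2_vitb14'],
    'dino_static': '/path/to/dino_weights.pth',
    'dinov2_static': '/path/to/dinov2_weights.pth',
}


def missing_paths(factory):
    """Return the keys whose configured path is empty or does not exist on disk."""
    missing = []
    for key, value in factory.items():
        # Assumption: list entries are [path, model_name]; plain strings are paths.
        path = value[0] if isinstance(value, (list, tuple)) else value
        if not path or not os.path.exists(path):
            missing.append(key)
    return missing


if __name__ == "__main__":
    print("Entries still to fix:", missing_paths(model_factory))
```

Running the check before launching a script lists every entry that still needs attention.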

Set the dataset path in `./config.py`:
```python
DATASET_DIR = ''
```


### Scripts

**Training**:

```bash
bash ./scripts/run_cifar10.sh
bash ./scripts/run_cifar100.sh
bash ./scripts/run_imagenet100.sh
bash ./scripts/run_imagenet1k.sh
bash ./scripts/run_cub.sh
bash ./scripts/run_scars.sh
bash ./scripts/run_aircraft.sh
bash ./scripts/run_herb19.sh
```
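
To run several benchmarks back to back, the script calls above can be wrapped in a small driver. This is only a sketch: `build_commands` and `run_all` are hypothetical helpers, and the dataset keys are taken from the script names listed above. With the default `dry_run=True`, it only prints the commands.

```python
import subprocess

# Dataset keys taken from the script names listed above.
DATASETS = ["cifar10", "cifar100", "imagenet100", "imagenet1k",
            "cub", "scars", "aircraft", "herb19"]


def build_commands(datasets):
    """Map dataset keys to the matching `bash ./scripts/run_<key>.sh` commands."""
    unknown = [d for d in datasets if d not in DATASETS]
    if unknown:
        raise ValueError(f"no training script listed for: {unknown}")
    return [["bash", f"./scripts/run_{d}.sh"] for d in datasets]


def run_all(datasets, dry_run=True):
    """Print each command and, unless dry_run is set, execute it in order."""
    for cmd in build_commands(datasets):
        print(" ".join(cmd))
        if not dry_run:
            # check=True stops the sequence at the first failing run.
            subprocess.run(cmd, check=True)


if __name__ == "__main__":
    run_all(["cub", "scars", "aircraft"])
```

Call `run_all([...], dry_run=False)` from the repository root to actually launch the runs.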

### Main Results

| Methods        | Backbone | CUB       |            |            | Stanford Cars |            |            | Aircraft  |            |            | Avg.  |
|----------------|----------|-----------|------------|------------|---------------|------------|------------|-----------|------------|------------|-------|
|                |          | All       | Old        | New        | All           | Old        | New        | All       | Old        | New        |       |
| GCD [2022]     | DINO     | 51.3      | 56.6       | 48.7       | 39.0          | 57.6       | 29.9       | 45.0      | 41.1       | 46.9       | 45.1  |
| XCon [2022]    | DINO     | 52.1      | 54.3       | 51.0       | 40.5          | 58.8       | 31.7       | 47.7      | 44.4       | 49.4       | 46.8  |
| CiPR [2023]    | DINO     | 57.1      | 58.7       | 55.6       | 47.0          | 61.5       | 40.1       | -         | -          | -          | -     |
| PCAL [2023]    | DINO     | 62.9      | 64.4       | 62.1       | 50.2          | 70.1       | 40.6       | 52.2      | 52.2       | 52.3       | 55.1  |
| SimGCD [2023]  | DINO     | 60.3      | 65.6       | 57.7       | 53.8          | 71.9       | 45.0       | 54.2      | 59.1       | 51.8       | 56.1  |
| AdaptGCD [2024]| DINO     | 66.6      | 66.5       | 66.7       | 48.4          | 57.7       | 39.3       | 53.7      | 51.1       | **56.0**   | 56.2  |
| AMEND [2024]   | DINO     | 64.9      | 75.6       | 59.6       | 56.4          | 73.3       | 48.2       | 52.8      | 61.8       | 48.3       | 58.0  |
| GCA [2024]     | DINO     | 68.8      | 73.4       | 66.6       | 54.4          | 72.1       | 45.8       | 52.0      | 57.1       | 49.5       | 58.4  |
| TIDA [2024]    | DINO     | -         | -          | -          | 54.7          | 72.3       | 46.2       | 54.6      | 61.3       | 52.1       | -     |
| μGCD [2024]    | DINO     | 65.7      | 68.0       | 64.6       | 56.5          | 68.1       | 50.9       | 53.8      | 55.4       | 53.0       | 58.7  |
| CMS [2024]     | DINO     | 68.2      | 76.5       | 64.0       | 56.9          | 76.1       | 47.6       | 56.0      | 63.4       | 52.3       | 60.4  |
| InfoSieve [2024]| DINO    | 69.4      | **77.9**   | 65.1       | 55.7          | 74.8       | 46.4       | 56.3      | 63.7       | 52.5       | 60.5  |
| FlipClass (Ours)| DINO    | **71.3**  | 71.3       | **71.3**   | **63.1**      | **81.7**   | **53.8**   | **59.3**  | **66.9**   | 55.4       | **64.6**|
| Improvement    | DINO     | +1.9      | -6.6       | +6.2       | +7.4          | +6.9       | +7.4       | +3.0      | +3.2       | +2.9       | +4.1  |
| GCD [2022]     | DINOv2   | 71.9      | 71.2       | 72.3       | 65.7          | 67.8       | 64.7       | 55.4      | 47.9       | 59.2       | 64.3  |
| SimGCD [2023]  | DINOv2   | 71.5      | 78.1       | 68.3       | 71.5          | 81.9       | 64.6       | 49.9      | 60.9       | 60.0       | 63.0  |
| μGCD [2024]    | DINOv2   | 74.0      | 75.9       | 73.1       | 76.1          | **91.0**   | 68.9       | 66.3      | 68.7       | 65.1       | 72.1  |
| FlipClass (Ours)| DINOv2  | **79.3**  | **80.7**   | **78.5**   | **78.0**      | 88.0       | **73.2**   | **71.1**  | **75.1**   | **69.1**   | **76.1**|
| Improvement    | DINOv2   | +5.3      | +4.8       | +5.4       | +1.9          | -3.0       | +4.3       | +4.8      | +6.4       | +4.0       | +4.0  |


| Methods          | Backbone | CIFAR10   |            |            | CIFAR100  |            |            | ImageNet-100 |            |            | Avg.  |
|------------------|----------|-----------|------------|------------|-----------|------------|------------|--------------|------------|------------|-------|
|                  |          | All       | Old        | New        | All       | Old        | New        | All          | Old        | New        |       |
| GCD [2022]       | DINO     | 91.5      | **97.9**   | 88.2       | 73.0      | 76.2       | 65.5       | 74.1         | 89.8       | 66.3       | 79.5  |
| AdaptGCD [2024]  | DINO     | 93.2      | 94.6       | 92.8       | 71.3      | 75.7       | 66.8       | 83.3         | 90.2       | 76.5       | 82.6  |
| InfoSieve [2024] | DINO     | 94.8      | 97.2       | 93.7       | 76.9      | 78.4       | 73.9       | 80.5         | 92.8       | 74.4       | 84.1  |
| CiPR [2023]      | DINO     | 97.7      | 97.5       | 97.7       | 81.5      | 82.4       | 79.7       | 80.5         | 84.9       | 78.3       | 86.6  |
| SimGCD [2023]    | DINO     | 97.1      | 95.1       | 98.1       | 80.1      | 81.2       | 77.8       | 83.0         | 93.1       | 77.9       | 86.7  |
| GCA [2024]       | DINO     | 95.5      | 95.9       | 95.2       | 82.4      | 85.6       | 75.9       | 82.8         | 94.1       | 77.1       | 86.9  |
| TIDA [2024]      | DINO     | 98.2      | **97.9**   | 98.5       | 82.3      | 83.8       | 80.7       | -            | -          | -          | -     |
| CMS [2024]       | DINO     | -         | -          | -          | 82.3      | **85.7**   | 75.5       | 84.7         | **95.6**   | 79.2       | -     |
| AMEND [2024]     | DINO     | 96.8      | 94.6       | 97.8       | 81.0      | 79.9       | 83.8       | 83.2         | 92.9       | 78.3       | 87.0  |
| FlipClass (Ours) | DINO     | **98.5**  | 97.6       | **99.0**   | **85.2**  | 84.9       | **85.8**   | **86.7**     | 94.3       | **82.9**   | **90.1**|
| Improvement      | DINO     | +1.7      | +3.0       | +1.2       | +4.2      | +5.0       | +2.0       | +3.5         | +1.4       | +4.6       | +3.1  |
| * GCD [2022]     | DINOv2   | 95.2      | 97.8       | 93.9       | 77.3      | 82.8       | 66.1       | 81.3         | 94.3       | 74.8       | 84.6  |
| * AMEND [2024]   | DINOv2   | 97.7      | 96.6       | 98.3       | 83.5      | 83.0       | 84.5       | 87.3         | 95.1       | 83.4       | 89.5  |
| FlipClass (Ours) | DINOv2   | **99.0**  | **98.2**   | **99.4**   | **91.7**  | **90.4**   | **94.2**   | **91.0**     | **96.3**   | **88.3**   | **93.9**|
| Improvement      | DINOv2   | +1.3      | +1.6       | +1.1       | +8.2      | +7.4       | +9.7       | +3.7         | +1.2       | +4.9       | +4.3  |
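
The per-dataset entries in the last `Improvement` row can be reproduced by subtracting the AMEND (DINOv2) row from the FlipClass (DINOv2) row. This is an observation from the numbers in the table, not a statement from the authors about how the baseline was chosen. A minimal sketch of the arithmetic, with values copied from the two rows above:

```python
# Columns: CIFAR10 (All, Old, New), CIFAR100 (All, Old, New), ImageNet-100 (All, Old, New).
flipclass_dinov2 = [99.0, 98.2, 99.4, 91.7, 90.4, 94.2, 91.0, 96.3, 88.3]
amend_dinov2 = [97.7, 96.6, 98.3, 83.5, 83.0, 84.5, 87.3, 95.1, 83.4]


def improvement(ours, baseline):
    """Column-wise difference in percentage points, rounded to one decimal."""
    return [round(o - b, 1) for o, b in zip(ours, baseline)]


if __name__ == "__main__":
    print(improvement(flipclass_dinov2, amend_dinov2))
```

The `Avg.` column is left out here because it may have been computed from unrounded values.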
