# Invariant-content Feature Reconstruction 
The invariant-content feature reconstruction (__IFR__) algorithm considers both high-level and fine-grained invariant-content features simultaneously for cross-domain few-shot classification. It consists of two main steps:
- Measuring similarities between original and augmented data representations __at the pixel level__.
- Reconstructing the invariant-content features by retrieving feature pixels according to the similarities obtained above.
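
The two steps above can be sketched in plain Python. This is an illustration under stated assumptions, not the repository's implementation: each feature map is assumed to be flattened into a list of per-pixel feature vectors, the similarity is assumed to be a softmax over scaled dot products (the standard attention form), and the reconstructed pixels are assumed to be retrieved from the augmented map. The function names `pixel_similarity` and `reconstruct` are hypothetical.

```python
import math


def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    peak = max(scores)
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def pixel_similarity(original, augmented):
    """Step 1: similarity of every original pixel to every augmented pixel.

    Both inputs are lists of per-pixel feature vectors (lists of floats).
    Assumption: scaled dot product followed by softmax, as in standard attention.
    """
    scale = math.sqrt(len(original[0]))
    return [
        softmax([sum(a * b for a, b in zip(p, q)) / scale for q in augmented])
        for p in original
    ]


def reconstruct(similarity, augmented):
    """Step 2: rebuild each pixel as a similarity-weighted sum of augmented pixels."""
    dim = len(augmented[0])
    return [
        [sum(w * q[d] for w, q in zip(weights, augmented)) for d in range(dim)]
        for weights in similarity
    ]


# Toy feature maps: 2 pixels with 2 channels each (hypothetical values).
original = [[1.0, 0.0], [0.0, 1.0]]
augmented = [[0.9, 0.1], [0.1, 0.9]]
sim = pixel_similarity(original, augmented)
rebuilt = reconstruct(sim, augmented)
```

Each row of `sim` is a probability distribution over the augmented pixels, so every reconstructed pixel is a convex combination of augmented pixels weighted towards those that best match the original content.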

<p align="center">
  <img src="./figures/ifr_framework.png" style="width:60%">
</p>


## Features at a glance
- We follow [URL](https://arxiv.org/pdf/2103.13841.pdf) and train a task-specific IFR module on top of a frozen pre-trained backbone during the __meta-test__ phase, for tasks sampled from unseen data or domains. The backbone can be either a domain-specific backbone pre-trained on a single dataset or a multi-domain backbone distilled from several domain-specific backbones.

- The IFR module consists of two parts: a single attention head that extracts fine-grained invariant-content features, and a linear transformation block composed of a BN layer, an average pooling layer and a linear layer.
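
As a rough illustration of the linear transformation block only (the attention head is left out), the sketch below chains batch normalisation, average pooling and a linear layer in plain Python. The order of the three operations follows the order in which they are listed above and is an assumption, as are the omission of learned BN scale/shift parameters and all function names; the actual module in this repo is written in PyTorch and may differ.

```python
import math


def batch_norm(batch, eps=1e-5):
    """Normalise each channel over the batch and spatial positions.

    `batch` is a list of samples; each sample is a list of channels;
    each channel is a flat list of spatial values. No learned scale/shift.
    """
    n_channels = len(batch[0])
    out = [[None] * n_channels for _ in batch]
    for c in range(n_channels):
        values = [v for sample in batch for v in sample[c]]
        mean = sum(values) / len(values)
        var = sum((v - mean) ** 2 for v in values) / len(values)
        std = math.sqrt(var + eps)
        for i, sample in enumerate(batch):
            out[i][c] = [(v - mean) / std for v in sample[c]]
    return out


def average_pool(batch):
    """Collapse the spatial positions of every channel to their mean."""
    return [[sum(ch) / len(ch) for ch in sample] for sample in batch]


def linear(vectors, weight, bias):
    """Apply y = W x + b to each pooled vector."""
    return [
        [sum(w * x for w, x in zip(row, vec)) + b for row, b in zip(weight, bias)]
        for vec in vectors
    ]


# Two samples, two channels, four spatial positions (hypothetical values).
features = [
    [[1.0, 2.0, 3.0, 4.0], [0.0, 0.0, 1.0, 1.0]],
    [[5.0, 6.0, 7.0, 8.0], [1.0, 1.0, 2.0, 2.0]],
]
identity = [[1.0, 0.0], [0.0, 1.0]]  # placeholder weights for the linear layer
pooled = average_pool(batch_norm(features))
embeddings = linear(pooled, identity, [0.0, 0.0])
```

The result is one embedding vector per sample, which is the form a centroid-based classifier expects.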

- In this repo, we provide the code of IFR based on the [original URL repo](https://github.com/VICO-UoE/URL). 

## Main results on [Meta-dataset](https://github.com/google-research/meta-dataset)
- Multi-domain setting / Train on all datasets setting (meta-train on 8 datasets and meta-test on 13 datasets).

Test Datasets              |IFR (Ours)                 |URL                        |2LM                        |Tri-M                        |FLUTE                      |URT                        |SUR                        |Simple CNAPS             |CNAPS                      |Proto-MAML
---------------------------|---------------------------|---------------------------|---------------------------|---------------------------|---------------------------|---------------------------|---------------------------|---------------------------|---------------------------|---------------------------
Avg rank                   |**2.4**                    |3.2                        |3.1                        |4.6                        |4.7                        |6.0                        |6.8                        |6.4                        |-                          |-
Avg Seen                   |**79.8**                   |79.8                       |79.7                       |76.2                       |74.5                       |77.4                       |75.9                       |73.7                       |67.5
Avg Unseen                 |**70.9**                   |69.3                       |69.4                       |69.9                       |69.9                       |62.9                       |64.1                       |67.4                       | -
Avg All                    |**76.4**                   |75.7                       |75.7                       |73.8                       |72.7                       |71.8                       |71.3                       |71.2                       | -


- Single-domain setting / Train on ImageNet only setting (meta-train on ImageNet and meta-test on 13 datasets).

Test Datasets              |IFR (Ours)                 |URL                        |FLUTE                      |ALFA+fo-Proto-MAML         |fo-Proto-MAML              |BOHB                       |ProtoNets (large)          |ProtoNet                   |Finetune
---------------------------|---------------------------|---------------------------|---------------------------|---------------------------|---------------------------|---------------------------|---------------------------|---------------------------|---------------------------
Avg rank                   |**1.4**                    |3.8                        |8.4                        |3.9                        |5.8                        |4.3                        |3.8                        |7.3                        |6.1
Avg Seen                   |**56.6**                   |55.8                       |46.9                       |52.8                       |49.5                       |51.9                       |53.7                       |50.5                       |45.8
Avg Unseen                 |**68.7**                   |62.2                       |56.5                       |-                          |-                          |-                          |-                          |-                          |-
Avg All                    |**67.1**                   |61.7                       |55.8                       |-                          |-                          |-                          |-                          |-                          |-

## Model Zoo
- [Single-domain networks (one for each dataset)](https://drive.google.com/file/d/1MvUcvQ8OQtoOk1MIiJmK6_G8p4h8cbY9/view?usp=sharing)

- [A single universal network (URL) learned from 8 training datasets](https://drive.google.com/file/d/1Dv8TX6iQ-BE2NMpfd0sQmH2q4mShmo1A/view?usp=sharing)

## Dependencies
This code requires the following:
* Python 3.6 or greater (Ours: Python 3.8)
* PyTorch 1.0 or greater (Ours: torch=1.7.1, torchvision=0.8.2)
* TensorFlow 1.14 or greater (Ours: TensorFlow=2.10)
* tqdm (Ours: 4.64.1)
* tabulate (Ours: 0.8.10)
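
A quick way to check an environment against the minimum versions listed above is sketched below. It is not part of this repo: the helper names are hypothetical, it relies on `importlib.metadata` (available from Python 3.8), and it assumes the packages are installed under the PyPI distribution names `torch` and `tensorflow` (a GPU-specific TensorFlow distribution would need a different name).

```python
from importlib import metadata

# Minimum versions taken from the list above; keys are PyPI distribution names.
MINIMUM_VERSIONS = {
    "torch": "1.0",
    "tensorflow": "1.14",
}


def parse_version(text):
    """Turn '1.7.1+cu110' into (1, 7, 1); a non-numeric part ends the parse."""
    parts = []
    for piece in text.split("+")[0].split("."):
        if not piece.isdigit():
            break
        parts.append(int(piece))
    return tuple(parts)


def check_requirements(minimums):
    """Return {package: (installed_version_or_None, meets_minimum)}."""
    report = {}
    for name, minimum in minimums.items():
        try:
            installed = metadata.version(name)
        except metadata.PackageNotFoundError:
            report[name] = (None, False)
            continue
        ok = parse_version(installed) >= parse_version(minimum)
        report[name] = (installed, ok)
    return report


if __name__ == "__main__":
    for name, (installed, ok) in check_requirements(MINIMUM_VERSIONS).items():
        print(f"{name}: installed={installed} ok={ok}")
```

Versions are compared as integer tuples rather than strings, so that `1.14` correctly sorts above `1.9`.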

## Installation
* Clone or download this repository.
* Configure Meta-Dataset:
    * Follow the "User instructions" in the [Meta-Dataset repository](https://github.com/google-research/meta-dataset) for "Installation" and "Downloading and converting datasets".
    * To test unseen domain (out-of-domain) performance on additional datasets, i.e. MNIST, CIFAR-10 and CIFAR-100, follow the installation instructions in the [CNAPs repository](https://github.com/cambridge-mlg/cnaps) to get these datasets.


## Cross-domain Few-shot Learning via Invariant-content Feature Reconstruction


### Train the Universal Representation Learning Network
1. The easiest way is to download the [pre-trained URL model](https://drive.google.com/file/d/1Dv8TX6iQ-BE2NMpfd0sQmH2q4mShmo1A/view?usp=sharing) provided by URL. To download it, one can use `gdown` (installed by ```pip install gdown```) and execute the following command in the root directory of this project (the link is quoted so that shells such as zsh do not treat `?` as a wildcard):
    ```
    gdown "https://drive.google.com/uc?id=1Dv8TX6iQ-BE2NMpfd0sQmH2q4mShmo1A" && md5sum url.zip && unzip url.zip -d ./saved_results/ && rm url.zip
    ```
    This will download the URL model and place it in the ```./saved_results``` directory. The model can then be evaluated with PA or with our IFR.

2. Alternatively, one can train the model from scratch in two stages, described below: 1) train 8 single domain learning networks; 2) train the universal feature extractor.

#### Train Single Domain Learning Networks
1. The easiest way is to download the [pre-trained models](https://drive.google.com/file/d/1MvUcvQ8OQtoOk1MIiJmK6_G8p4h8cbY9/view?usp=sharing) and use them to obtain a universal set of features directly. To download the single domain learning networks, execute the following command in the root directory of this project:
    ```
    gdown "https://drive.google.com/uc?id=1MvUcvQ8OQtoOk1MIiJmK6_G8p4h8cbY9" && md5sum sdl.zip && unzip sdl.zip -d ./saved_results/ && rm sdl.zip
    ```

    This will download all single domain learning models and place them in the ```./saved_results``` directory of this project.


2. Alternatively, instead of using the pretrained models, one can train the models from scratch.
   To train 8 single domain learning networks, run:
    ```
    ./scripts/train_resnet18_sdl.sh
    ```


#### Train the Universal Feature Extractor
To learn the universal feature extractor by distilling the knowledge from pre-trained single domain learning networks, run: 
```
./scripts/train_resnet18_url.sh
```

### Meta-Testing with Invariant-content Feature Reconstruction (IFR)
This step runs the Invariant-content Feature Reconstruction module per task to learn both high-level and fine-grained invariant-content features for cross-domain few-shot classification. Run:
```
./scripts/run_ifr.sh
```

### Meta-Testing with Pre-classifier Alignment (PA)
This step runs the Pre-classifier Alignment (PA) procedure per task to adapt the features to a discriminative space, then builds a Nearest Centroid Classifier (NCC) on the support set to classify query samples. Run:

```
./scripts/test_resnet18_pa.sh
```
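
For readers unfamiliar with the Nearest Centroid Classifier mentioned above, the sketch below shows the idea in plain Python. It covers only the classification step, not the PA adaptation itself; the use of cosine similarity and the function names are assumptions made for illustration and are not taken from this repo's code.

```python
import math


def cosine(u, v):
    """Cosine similarity between two feature vectors (0.0 if either is zero)."""
    dot = sum(a * b for a, b in zip(u, v))
    norm = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))
    return dot / norm if norm else 0.0


def class_centroids(support_features, support_labels):
    """Average the support features of each class into one centroid."""
    grouped = {}
    for feat, label in zip(support_features, support_labels):
        grouped.setdefault(label, []).append(feat)
    return {
        label: [sum(col) / len(feats) for col in zip(*feats)]
        for label, feats in grouped.items()
    }


def ncc_predict(query_features, centroids):
    """Assign each query sample to the class with the most similar centroid."""
    return [
        max(centroids, key=lambda label: cosine(q, centroids[label]))
        for q in query_features
    ]


# Toy task: 2 classes, 2 support samples each (hypothetical features).
support = [[1.0, 0.1], [0.9, 0.0], [0.0, 1.0], [0.2, 0.8]]
labels = ["cat", "cat", "dog", "dog"]
centroids = class_centroids(support, labels)
predictions = ncc_predict([[0.8, 0.2], [0.1, 0.7]], centroids)
```

Because the classifier has no trainable parameters of its own, the quality of the predictions depends entirely on how well the adapted features separate the classes.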

