# Theoretical refinement of CLIP by utilizing linear structure of optimal similarity (KME-CLIP)

This repository houses the official implementation of KME-CLIP, proposed in the paper "Theoretical refinement of CLIP by utilizing linear structure of optimal similarity".
The code and this README are largely based on https://github.com/sony/wpse/tree/main.



## Installation
### Docker
```
docker build -t <image_name> installation
```
### venv+pip
```
python -m venv <env_name>
source <env_name>/bin/activate
pip install -r installation/requirements.txt
pip install -r installation/requirements_torch.txt
pip install -r installation/requirements_rapids.txt
```

## Datasets Setup
### Conceptual Captions Setup
We use HuggingFace datasets for [CC3M](https://huggingface.co/datasets/pixparse/cc3m-wds) and [CC12M](https://huggingface.co/datasets/pixparse/cc12m-wds).
Before using them, set `cache_dir` in the YAML configuration files.
If a dataset is not found in `cache_dir`, HuggingFace [`datasets.load_dataset()`](https://huggingface.co/docs/datasets/v3.2.0/en/package_reference/loading_methods#datasets.load_dataset) downloads it.

**CC3M** ([`configs/dataset/example_cc3m.yaml`](configs/dataset/example_cc3m.yaml))
```yaml
name: huggingface_dataset
path: pixparse/cc3m-wds
cache_dir: /path/to/huggingface/cache/directory
...
```

**CC12M** ([`configs/dataset/example_cc12m.yaml`](configs/dataset/example_cc12m.yaml))
```yaml
name: huggingface_dataset
path: pixparse/cc12m-wds
cache_dir: /path/to/huggingface/cache/directory
...
```
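
If both datasets share one cache directory, the `cache_dir` entries can also be updated from the shell. The function below is a hypothetical helper, not part of this repository. It assumes that `cache_dir` is a top-level key at the start of a line, as in the example files above, and it keeps a `.bak` copy of each edited file (`sed -i.bak` is accepted by both GNU and BSD `sed`).

```bash
# Hypothetical helper: set the top-level `cache_dir:` entry in each given YAML file.
# Usage: set_cache_dir <cache_dir> <yaml_file...>
set_cache_dir() {
  local dir="$1" f
  shift
  for f in "$@"; do
    # Refuse to edit files without a top-level cache_dir key.
    if ! grep -q '^cache_dir:' "$f"; then
      echo "no top-level cache_dir key in $f" >&2
      return 1
    fi
    # Replace the whole line; "|" is the sed delimiter, so paths may contain "/".
    sed -i.bak "s|^cache_dir:.*|cache_dir: ${dir}|" "$f"
  done
}
```

For example: `set_cache_dir /path/to/huggingface/cache/directory configs/dataset/example_cc3m.yaml configs/dataset/example_cc12m.yaml`.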

### ImageNet Setup
ImageNet is used as a validation dataset during training. Its path is read from `dataset_catalog.json`.
Copy [dataset_catalog_example.json](dataset_catalog_example.json) to `dataset_catalog.json` and edit the dataset paths.
```
cp dataset_catalog_example.json dataset_catalog.json
```

The scripts require the following directory structure, since they load the data with `torchvision.datasets.ImageFolder`:
* /path/to/imagenet/
  - train/
    * n01440764/
      - n01440764_10026.JPEG
      - ...
      - n01440764_9981.JPEG
    * n01443537/
    * ...
    * n15074101/
  - val/
    * n01440764/
      - ILSVRC2012_val_00002138.JPEG
      - ...
      - ILSVRC2012_val_00048969.JPEG
    * ...
    * n15074101/
  - your_own_split_if_you_have/
    * n01440764/
    * ...
    * n15074101/
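
Before training, it can help to confirm that each split follows the `ImageFolder` layout, with one sub-directory per class. The function below is a hypothetical sketch, not part of this repository. The split names default to `train` and `val`, and a custom split can be passed as an extra argument. For the full ImageNet-1k, each split should report 1000 class directories.

```bash
# Hypothetical sanity check for the layout <root>/<split>/<class_wnid>/<images>.
# Usage: check_imagenet_layout <imagenet_root> [split...]
check_imagenet_layout() {
  local root="$1" split n
  shift
  # Default to the standard splits if none are given.
  if [ "$#" -eq 0 ]; then
    set -- train val
  fi
  for split in "$@"; do
    if [ ! -d "$root/$split" ]; then
      echo "missing split directory: $root/$split" >&2
      return 1
    fi
    # Count class sub-directories (one per class, e.g. n01440764/).
    n=$(find "$root/$split" -mindepth 1 -maxdepth 1 -type d | wc -l)
    echo "$split: class directories = $((n))"
  done
}
```

For example: `check_imagenet_layout /path/to/imagenet train val your_own_split_if_you_have`.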

If you use a custom data split (such as a subset of the train split) for validation during training, specify its directory name as `imagenet_valsplit` in the training configuration YAML.
```yaml
...
imagenet_valsplit: your_own_split_if_you_have
...
```
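
One way to build such a split is to link a few images per class from `train/` into a new directory. The helper below is a hypothetical sketch; its name, arguments and per-class count are illustrative and not part of this repository. It uses symbolic links, so no images are copied, and it does not remove the selected images from `train/`, so whether they should also be excluded from training has to be decided separately.

```bash
# Hypothetical helper: create <root>/<split>/<class>/ containing symbolic links
# to the first <per_class> images (sorted by file name) of each class in train/.
# Usage: make_val_subset <imagenet_root> <split_name> <per_class>
make_val_subset() {
  local root split="$2" per_class="$3" cls_dir cls
  root=$(cd "$1" && pwd) || return 1   # absolute path keeps the links valid
  for cls_dir in "$root"/train/*/; do
    [ -d "$cls_dir" ] || continue      # skip the literal pattern if train/ is empty
    cls=$(basename "$cls_dir")
    mkdir -p "$root/$split/$cls"
    find "$cls_dir" -mindepth 1 -maxdepth 1 -type f | sort | head -n "$per_class" |
      while IFS= read -r f; do
        ln -sf "$f" "$root/$split/$cls/$(basename "$f")"
      done
  done
}
```

For example, `make_val_subset /path/to/imagenet your_own_split_if_you_have 10` followed by `imagenet_valsplit: your_own_split_if_you_have` in the training configuration.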

### Downstream Datasets Setup
The scripts read dataset paths from [dataset_catalog.json](#imagenet-setup).
The class labels and caption templates for zero-shot evaluation are read from [labels.json](labels.json) and [templates.json](templates.json).

**CIFAR10, CIFAR100, STL-10, Flowers102, DTD, Aircraft, and MNIST** are loaded by [`torchvision.datasets`](https://pytorch.org/vision/stable/datasets.html).
**For other datasets**, please use scripts from [VISSL](https://github.com/facebookresearch/vissl/tree/main/extra_scripts/datasets).

## Training
Configuration YAML files are located in `configs/`. We use [Hydra](https://hydra.cc/) as a configuration management tool.
Training creates the following files in the directory specified by `output_dir` in the configuration YAML:
* config.yaml
  - A copy of the configuration file used in the training.
* log.txt
  - A log of training progress, such as training losses and validation results.
* checkpoint.pt
  - The checkpoint file at the last epoch.
* checkpoint_best.pt
  - The checkpoint file that achieved the best score in the validation.

In addition, `mlflow` can be used as a logger if you enable it in the configuration YAML.

### Single-GPU training
The following example trains KME-CLIP (Gaussian kernel) on CC3M.
To train a different variant, replace the config name `example_cc3m_kme_clip` with:
* `example_cc3m_kme_clip_reduce_points` for KME-CLIP with a reduced point set,
* `example_cc3m_clip` for CLIP,
* `example_cc3m_wpse` for WPSE (Gaussian kernel, $\sigma = 0.5, (\alpha_1, \alpha_2) = (0.667, 0.333)$).

```bash
config_name=example_cc3m_kme_clip

python main.py --config-name $config_name
```
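
To train several of the variants mentioned above one after another, a small loop is enough. The function below is a hypothetical sketch, not part of this repository. It assumes that each configuration writes to its own `output_dir`, since runs sharing one `output_dir` would overwrite each other's `checkpoint.pt` and `log.txt`. The launcher is passed after `--`, so the same loop works with `python` or with the `torchrun` command used for multi-GPU training.

```bash
# Hypothetical helper: run main.py once per config name with the given launcher.
# Usage: train_configs <config_name...> -- <launcher...>
#   train_configs example_cc3m_kme_clip example_cc3m_clip -- python
#   train_configs example_cc3m_wpse -- torchrun --standalone --nnodes=1 --nproc_per_node=4
train_configs() {
  local configs="" config_name
  while [ "$#" -gt 0 ] && [ "$1" != "--" ]; do
    configs="$configs $1"
    shift
  done
  if [ "$#" -lt 2 ]; then
    echo "usage: train_configs <config_name...> -- <launcher...>" >&2
    return 1
  fi
  shift   # drop "--"; the remaining arguments form the launcher
  for config_name in $configs; do   # config names contain no spaces
    echo "=== $config_name ===" >&2
    "$@" main.py --config-name "$config_name" || return 1
  done
}
```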

### Single-node Multi-GPU training (4 GPUs)
```bash
config_name=example_cc3m_kme_clip

torchrun --standalone --nnodes=1 \
         --nproc_per_node=4 main.py \
         --config-name $config_name
```

### Multi-node Multi-GPU training (4 nodes)
```bash
config_name=example_cc3m_kme_clip
hostfile=<host file>
hostname=<host address>
port_num=<port num>

mpirun -np 4 -map-by ppr:1:node -hostfile $hostfile \
        python main_multi_nodes.py \
        --config-name $config_name \
        hostname=$hostname \
        port_num=$port_num
```
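
The `-map-by ppr:1:node` option is Open MPI syntax, so the host file presumably follows Open MPI's format: one host per line, optionally with a `slots=` count. With one process per node, one slot per node is enough. The function below is a hypothetical sketch; the node names are placeholders, and `hostname` is assumed to be the address of one of the listed nodes (for example, the first one) that all nodes can reach.

```bash
# Hypothetical helper: write an Open MPI host file with one slot per node,
# matching `-np 4 -map-by ppr:1:node` (one process per node) above.
# Usage: write_hostfile <output_file> <node...>
write_hostfile() {
  local out="$1" node
  shift
  : > "$out"   # create or truncate the file
  for node in "$@"; do
    printf '%s slots=1\n' "$node" >> "$out"
  done
}

# Example (node names are placeholders):
#   write_hostfile hosts.txt node01 node02 node03 node04
#   hostfile=hosts.txt
#   hostname=node01
```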

## Evaluation
### Retrieval (added for KME-CLIP)
The following script evaluates both image-to-text and text-to-image retrieval in a single run.
To keep the computation time manageable, we recommend running it in a multi-GPU environment.
In the following examples, `/path/to/model` is supposed to contain `checkpoint_best.pt` and `config.yaml`.
As a result of retrieval evaluation, `results_retrieval.csv` will be created in `/path/to/model`.

```bash
model_dir=/path/to/model
torchrun --standalone --nnodes=1 \
         --nproc_per_node=2 eval_retrieval.py \
         --output-dir $model_dir \
         --distributed
```

### Zero-shot classification
For the zero-shot classification evaluation, `eval_zeroshot.py` is used.
In the following examples, `/path/to/model` is supposed to contain `checkpoint_best.pt` and `config.yaml`.
As a result of zero-shot evaluation, `results_zeroshot.csv` will be created in `/path/to/model`.

The dataset paths are read from [dataset_catalog.json](#imagenet-setup), and the class labels and caption templates are read from [labels.json](labels.json) and [templates.json](templates.json).

```bash
model_dir=/path/to/model
gpuid=0
python eval_zeroshot.py --output-dir $model_dir --gpu $gpuid
```

You can also conduct evaluations on a subset of datasets as follows:
```bash
python eval_zeroshot.py --output-dir $model_dir --gpu $gpuid \
       --tasklist cifar10 cifar100 stl10
```
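
When comparing several trained models (for example, KME-CLIP, CLIP and WPSE), the zero-shot evaluation can be looped over their directories. The function below is a hypothetical sketch, not part of this repository. It checks that each directory contains `checkpoint_best.pt` and `config.yaml` before calling `eval_zeroshot.py` with the same arguments as above, skips directories that do not, and returns a non-zero status if any directory was skipped or failed.

```bash
# Hypothetical helper: run eval_zeroshot.py for each model directory on one GPU.
# Usage: eval_zeroshot_dirs <gpuid> <model_dir...>
eval_zeroshot_dirs() {
  local gpuid="$1" model_dir status=0
  shift
  for model_dir in "$@"; do
    if [ ! -f "$model_dir/checkpoint_best.pt" ] || [ ! -f "$model_dir/config.yaml" ]; then
      echo "skipping $model_dir: checkpoint_best.pt or config.yaml not found" >&2
      status=1
      continue
    fi
    python eval_zeroshot.py --output-dir "$model_dir" --gpu "$gpuid" || status=1
  done
  return "$status"
}
```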

### Linear probing
Linear probing evaluation is performed in two stages.
First, features are extracted from the trained model.
Then, linear classifiers are fit on the extracted features.

#### Extracting features
To extract features before the last projection layer, run the script for your model type below.
Here, `/path/to/model/dir` is expected to contain `checkpoint_best.pt` and `config.yaml`.
After running one of the following scripts, the extracted features are placed in `/path/to/workspace/frozen_feats/`.

**For KME-CLIP models (added for KME-CLIP)**
```bash
workspace=/path/to/workspace
model_dir=/path/to/model/dir
python dump_linear_feats.py --config-name example_kme_clip_before_proj \
       output_dir=$workspace \
       model.feature_extractor.model_dir=$model_dir
```
**For CLIP models**
```bash
workspace=/path/to/workspace
model_dir=/path/to/model/dir
python dump_linear_feats.py --config-name example_clip_bef_proj_dump \
       output_dir=$workspace \
       model.feature_extractor.model_dir=$model_dir
```
**For WPSE models**
```bash
workspace=/path/to/workspace
model_dir=/path/to/model/dir
python dump_linear_feats.py --config-name example_wpse_bef_proj_dump \
       output_dir=$workspace \
       model.feature_extractor.model_dir=$model_dir
```


To extract features after the last projection layer, run the following scripts.

**For KME-CLIP models (added for KME-CLIP)**
```bash
workspace=/path/to/workspace
model_dir=/path/to/model_dir
python dump_linear_feats.py --config-name example_kme_clip_after_proj \
       output_dir=$workspace \
       model.feature_extractor.model_dir=$model_dir
```
**For CLIP models**
```bash
workspace=/path/to/workspace
model_dir=/path/to/model_dir
python dump_linear_feats.py --config-name example_clip_dump \
       output_dir=$workspace \
       model.feature_extractor.model_dir=$model_dir
```
**For WPSE models**
```bash
workspace=/path/to/workspace
model_dir=/path/to/model_dir
python dump_linear_feats.py --config-name example_wpse_d1024_dump \
       output_dir=$workspace \
       model.feature_extractor.model_dir=$model_dir
```
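
The six configuration names above follow different naming patterns across model types. The functions below are a hypothetical sketch, not part of this repository: `dump_config_name` only restates that mapping from model type (`kme_clip`, `clip`, `wpse`) and feature position (`before` or `after` the last projection layer), and `dump_feats` runs `dump_linear_feats.py` with the same overrides as in the examples. Using a separate workspace per model and feature position is presumably safer, since features are written to `frozen_feats/` under the workspace.

```bash
# Hypothetical helper: map <model_type> and <position> to the config names above.
# Model types: kme_clip, clip, wpse. Positions: before, after.
dump_config_name() {
  case "$1:$2" in
    kme_clip:before) echo example_kme_clip_before_proj ;;
    clip:before)     echo example_clip_bef_proj_dump ;;
    wpse:before)     echo example_wpse_bef_proj_dump ;;
    kme_clip:after)  echo example_kme_clip_after_proj ;;
    clip:after)      echo example_clip_dump ;;
    wpse:after)      echo example_wpse_d1024_dump ;;
    *) echo "unknown model type/position: $1 $2" >&2; return 1 ;;
  esac
}

# Run feature extraction with the same overrides as the examples above.
# Usage: dump_feats <model_type> <position> <workspace> <model_dir>
dump_feats() {
  local config_name
  config_name=$(dump_config_name "$1" "$2") || return 1
  python dump_linear_feats.py --config-name "$config_name" \
         output_dir="$3" \
         model.feature_extractor.model_dir="$4"
}
```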

#### Fitting linear classifiers
After feature extraction, `rapids_linear_probe.py` fits the linear classifiers.
Information about the datasets and data splits is read from [dataset_catalog.json](#imagenet-setup).
`/path/to/workspace` is expected to contain `frozen_feats/` created by `dump_linear_feats.py`.
```bash
workspace=/path/to/workspace
task=cifar10
python rapids_linear_probe.py $workspace $task
```
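
To fit linear classifiers on several downstream datasets with the same extracted features, the command can be repeated per task. In this hypothetical sketch (not part of this repository), the task names are assumed to be the same identifiers as in the example above (for example, `cifar10`). The function first checks for `frozen_feats/`, then collects failed tasks and reports them at the end instead of stopping the loop.

```bash
# Hypothetical helper: fit linear classifiers for several tasks in one workspace.
# Usage: run_linear_probes <workspace> <task...>
run_linear_probes() {
  local workspace="$1" task failed=""
  shift
  if [ ! -d "$workspace/frozen_feats" ]; then
    echo "no frozen_feats/ in $workspace; run dump_linear_feats.py first" >&2
    return 1
  fi
  for task in "$@"; do
    python rapids_linear_probe.py "$workspace" "$task" || failed="$failed $task"
  done
  if [ -n "$failed" ]; then
    echo "failed tasks:$failed" >&2
    return 1
  fi
}

# Example: run_linear_probes /path/to/workspace cifar10 cifar100 stl10
```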

## License of original codebase
This repository is a modification of [WPSE](https://github.com/sony/wpse/tree/main), which is licensed under the MIT license. See the [original license description](https://github.com/sony/wpse/blob/main/LICENSE) for details. Additionally, the original repository (WPSE) includes work from the following repositories:
* SLIP (https://github.com/facebookresearch/SLIP)
  - Copyright (c) Meta Platforms, Inc. and affiliates.
  - Distributed under the MIT License. 
* SSL-HSIC (https://github.com/google-deepmind/ssl_hsic)
  - Copyright 2021 DeepMind Technologies Limited.
  - Distributed under the Apache License 2.0.

