# HComP-Net: Hierarchy aligned Commonality through Prototypical Networks
This repository presents the PyTorch code for **HComP-Net** (**H**ierarchy aligned **Com**monality through **P**rototypical **Net**works).

**HComP-Net** is a hierarchical, interpretable image classification framework that can be applied to discover potential evolutionary traits from images by making use of the phylogenetic tree, also called the Tree of Life. HComP-Net generates hypotheses for potential evolutionary traits by learning semantically meaningful, non-over-specific prototypes at each internal node of the hierarchy.
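
To make "prototypes at each internal node of the hierarchy" concrete, here is a minimal Python sketch of a toy phylogeny. The tree, the node names, and the helper functions are hypothetical and are not taken from this repository. The prototype count per node is an assumption based only on the name of the `--num_protos_per_child` training flag.

```python
# Hypothetical toy phylogeny: internal nodes map to their children, leaves map to None.
# All node and species names are invented for illustration only.
TOY_TREE = {
    "root": {
        "clade_A": {"species_1": None, "species_2": None},
        "clade_B": {"species_3": None, "species_4": None, "species_5": None},
    }
}


def leaf_descendants(subtree):
    """Return the names of all leaves (species) below a subtree."""
    leaves = []
    for name, children in subtree.items():
        if children is None:
            leaves.append(name)
        else:
            leaves.extend(leaf_descendants(children))
    return leaves


def internal_nodes(tree):
    """Map every internal node name to the subtree of its children."""
    nodes = {}
    for name, children in tree.items():
        if children is not None:
            nodes[name] = children
            nodes.update(internal_nodes(children))
    return nodes


def prototype_budget(tree, num_protos_per_child=10):
    """Assumed allocation: num_protos_per_child prototypes for each child of a node."""
    return {
        name: len(children) * num_protos_per_child
        for name, children in internal_nodes(tree).items()
    }


if __name__ == "__main__":
    for node, children in internal_nodes(TOY_TREE).items():
        print(node, "->", leaf_descendants(children))
    print(prototype_budget(TOY_TREE))
```

In this picture, a prototype learned at `clade_B` is meant to capture something shared by all of `species_3`, `species_4`, and `species_5`, rather than a feature specific to only one of them.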

**Paper: What Do You See in Common? Learning Hierarchical Prototypes over Tree-of-Life to Discover Evolutionary Traits**

> ***Abstract:*** 
>> *A grand challenge in biology is to discover evolutionary traits - features of organisms common to a group of species with a shared ancestor in the Tree of Life (also referred to as phylogenetic tree). With the growing availability of  image repositories in biology, there is a tremendous opportunity to discover evolutionary traits directly from images in the form of a hierarchy of prototypes. However, current prototype-based methods are mostly designed to operate over a flat structure of classes and face several challenges in discovering hierarchical prototypes, including the issue of learning over-specific features at internal nodes. To overcome these challenges, we introduce the framework of **H**ierarchy aligned **Com**monality through **P**rototypical **Net**works **(HComP-Net)**.
  We empirically show that HComP-Net learns prototypes that are accurate, semantically consistent, and generalizable to unseen species in comparison to baselines on birds, butterflies, and fishes datasets.*

![Objective of HComP-Net](assets/HComPNet_teaser.png)

## Setting up environment
Run the following commands to create and activate a new conda environment:
```
conda create -n hcomp
conda activate hcomp
```
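With the environment active, install the required packages. Because the ```conda create``` command above does not request any packages, you may first need to run ```conda install python pip``` so that ```pip``` resolves inside the new environment: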
```
pip install -r requirements.txt
```
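
After installation, a quick check that the key packages are importable can save a failed training run. The sketch below uses only the standard library. The package names `torch` and `torchvision` are assumptions (this is PyTorch code, but the actual contents of `requirements.txt` are not shown here), so adjust the list to match that file.

```python
import importlib.util


def missing_packages(names):
    """Return the top-level module names from `names` that cannot be found."""
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    # Assumed import names; adjust to match the entries in requirements.txt.
    wanted = ["torch", "torchvision"]
    missing = missing_packages(wanted)
    print("missing:", missing if missing else "none")
```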

## Setting up datasets

#### CUB-190 (Bird)

Download the CUB-200-2011 [1] dataset and save it in the ```data/``` directory of this repository.

Once downloaded, the folder structure should look something like this:
```
data/
└── CUB_200_2011/
    ├── attributes/ # Not used
    ├── images/
    ├── parts/ 
    ├── image_class_labels.txt
    ├── train_test_split.txt
    ├── images.txt
    ├── bounding_boxes.txt
    ├── classes.txt
    └── README.md
```
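
The layout above can be verified before preprocessing with a short Python sketch. The entry list is copied from the listing above. The function name `missing_entries` and the assumption that the script is run from the repository root are illustrative only.

```python
from pathlib import Path

# Entries taken from the directory listing above ("attributes/" is not used).
EXPECTED = [
    "images",
    "parts",
    "image_class_labels.txt",
    "train_test_split.txt",
    "images.txt",
    "bounding_boxes.txt",
    "classes.txt",
]


def missing_entries(root, expected=EXPECTED):
    """Return the expected files or folders that do not exist under `root`."""
    root = Path(root)
    return [name for name in expected if not (root / name).exists()]


if __name__ == "__main__":
    # Assumes the script is run from the repository root.
    print(missing_entries("data/CUB_200_2011"))
```

An empty list means every expected entry was found. After running the preprocessing step, the same function can be called with `expected=["dataset_cub190", "images_cub190"]` to confirm that the new folders were created.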

Run the following command to create the **Bird** (CUB-190) dataset. It creates the ```dataset_cub190``` and ```images_cub190``` folders.

```
python preprocess_data/prepare_cub190.py --segment
```

The folder structure should now look like this:
```
data/
└── CUB_200_2011/
    ├── attributes/ # Not used
    ├── dataset_cub190/ # Newly created
    ├── images/
    ├── images_cub190/ # Newly created
    ├── parts/ 
    ├── image_class_labels.txt
    ├── train_test_split.txt
    ├── images.txt
    ├── bounding_boxes.txt
    ├── classes.txt
    └── README.md
```

#### Butterfly
Run the following commands from the repository root to unzip the **Butterfly** dataset into the ```data/``` directory:
```
cd /path/to/HComPNet
unzip data/Butterfly_Heliconius_30.zip -d data
```

#### Fish
Run the following commands from the repository root to unzip the **Fish** dataset into the ```data/``` directory:
```
cd /path/to/HComPNet
unzip data/Fish_38.zip -d data
```
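
If `unzip` is not available on your system, the same extraction can be sketched with Python's standard `zipfile` module. The archive names come from the commands above. The helper name `extract_dataset` and the assumption that the archives sit in `data/` under the repository root are illustrative.

```python
import zipfile
from pathlib import Path


def extract_dataset(zip_path, dest_dir):
    """Extract every member of `zip_path` into `dest_dir`; return the top-level names."""
    dest = Path(dest_dir)
    dest.mkdir(parents=True, exist_ok=True)
    with zipfile.ZipFile(zip_path) as archive:
        archive.extractall(dest)
        return sorted({name.split("/")[0] for name in archive.namelist()})


if __name__ == "__main__":
    # Paths assume the script is run from the repository root.
    for archive_name in ("Butterfly_Heliconius_30.zip", "Fish_38.zip"):
        path = Path("data") / archive_name
        if path.exists():
            print(archive_name, "->", extract_dataset(path, "data"))
        else:
            print("not found:", path)
```

The returned list shows which top-level folders or files the archive placed in the destination, which is a quick way to confirm where the dataset ended up.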

## Training HComP-Net

For the Bird (CUB-190) dataset:

The following command trains the model on the ```CUB-190``` dataset. Training on ```CUB-190``` with a ```batch_size``` of 256 required two A100 GPUs, so ```gpu_ids``` is set to ```0,1```. To run on a single GPU, remove the ```gpu_ids``` argument; a single GPU is assumed by default.
```
python main.py --log_dir './runs/hcompnet_cub190_cnext26' --dataset CUB-190 --net convnext_tiny_26 --batch_size 256 --batch_size_pretrain 256 --epochs 75 --epochs_pretrain 10 --epochs_finetune_classifier 3 --epochs_finetune_mask 60 --freeze_epochs 10 --gpu_ids '0,1' --num_workers 8 --phylo_config ./configs/cub190_phylogeny.yaml --num_protos_per_child 10
```

For the Butterfly dataset:
```
python main.py --log_dir './runs/hcompnet_but30_cnext26' --dataset BUT-30 --net convnext_tiny_26 --batch_size 64 --batch_size_pretrain 128 --epochs 75 --epochs_pretrain 10 --epochs_finetune_classifier 3 --epochs_finetune_mask 60 --freeze_epochs 10 --num_workers 8 --phylo_config ./configs/but30_phylogeny.yaml --num_protos_per_child 10
```

For the Fish dataset:
```
python main.py --log_dir './runs/hcompnet_fish38_cnext26' --dataset FISH-38 --net convnext_tiny_26 --batch_size 64 --batch_size_pretrain 128 --epochs 75 --epochs_pretrain 10 --epochs_finetune_classifier 3 --epochs_finetune_mask 60 --freeze_epochs 10 --num_workers 8 --phylo_config ./configs/fish38_phylogeny.yaml --num_protos_per_child 10
```
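
The three commands above share most of their flags. As a convenience, the following Python sketch rebuilds them from one table, which makes the per-dataset differences easier to see. All flag names and values are copied from the commands above. The helper `build_command` and its `single_gpu` option are hypothetical and not part of this repository, and the sketch assumes that `main.py` does not depend on the order of its flags.

```python
import shlex

# Flags shared by all three commands above.
COMMON = {
    "--net": "convnext_tiny_26",
    "--epochs": "75",
    "--epochs_pretrain": "10",
    "--epochs_finetune_classifier": "3",
    "--epochs_finetune_mask": "60",
    "--freeze_epochs": "10",
    "--num_workers": "8",
    "--num_protos_per_child": "10",
}

# Per-dataset flags, copied from the commands above.
PER_DATASET = {
    "CUB-190": {
        "--log_dir": "./runs/hcompnet_cub190_cnext26",
        "--batch_size": "256",
        "--batch_size_pretrain": "256",
        "--gpu_ids": "0,1",
        "--phylo_config": "./configs/cub190_phylogeny.yaml",
    },
    "BUT-30": {
        "--log_dir": "./runs/hcompnet_but30_cnext26",
        "--batch_size": "64",
        "--batch_size_pretrain": "128",
        "--phylo_config": "./configs/but30_phylogeny.yaml",
    },
    "FISH-38": {
        "--log_dir": "./runs/hcompnet_fish38_cnext26",
        "--batch_size": "64",
        "--batch_size_pretrain": "128",
        "--phylo_config": "./configs/fish38_phylogeny.yaml",
    },
}


def build_command(dataset, single_gpu=False):
    """Return the argument list for training on `dataset`."""
    flags = {"--dataset": dataset, **PER_DATASET[dataset], **COMMON}
    if single_gpu:
        # Dropping --gpu_ids falls back to the single-GPU default.
        flags.pop("--gpu_ids", None)
    argv = ["python", "main.py"]
    for flag, value in flags.items():
        argv += [flag, value]
    return argv


if __name__ == "__main__":
    for name in PER_DATASET:
        print(shlex.join(build_command(name)))
```

The returned list can be passed to `subprocess.run` directly. Since a `batch_size` of 256 required two A100 GPUs for `CUB-190`, a single-GPU run will likely also need a smaller `--batch_size`.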

## Visualizing the prototypes

We create Top-K visualizations to analyze the prototypes: for a given hierarchical prototype, we visualize the Top-K nearest image patches from each of its leaf descendants. Follow the steps in ```plot_topk_visualizations.ipynb``` to create the Top-K visualizations.
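
The selection step behind such a visualization can be sketched in a few lines of Python. This is an illustration of the idea only and not the notebook's code: it assumes that a similarity score between the prototype and each image patch has already been computed, and the record layout `(species, patch_id, similarity)` and the function name are hypothetical.

```python
import heapq
from collections import defaultdict


def topk_patches_per_species(scored_patches, k=3):
    """Group (species, patch_id, similarity) records by species and keep the
    k patches with the highest similarity to the prototype for each species."""
    by_species = defaultdict(list)
    for species, patch_id, similarity in scored_patches:
        by_species[species].append((similarity, patch_id))
    return {
        species: [patch_id for _, patch_id in heapq.nlargest(k, items)]
        for species, items in by_species.items()
    }


if __name__ == "__main__":
    # Invented scores for one prototype at an internal node with two leaf descendants.
    records = [
        ("species_1", "img_a/patch_3", 0.91),
        ("species_1", "img_b/patch_7", 0.42),
        ("species_1", "img_c/patch_1", 0.77),
        ("species_2", "img_d/patch_5", 0.88),
        ("species_2", "img_e/patch_2", 0.65),
    ]
    print(topk_patches_per_species(records, k=2))
```

Selecting the top patches separately for each leaf descendant, rather than globally, makes it possible to see whether the prototype highlights the same trait in every species under the node.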

## Analyzing the semantic quality of prototypes

Follow the instructions in ```part_purity_cub.ipynb``` to quantitatively analyze the semantic quality of the prototypes.


## References

- [1] Wah, Catherine, Steve Branson, Peter Welinder, Pietro Perona, and Serge Belongie. "The Caltech-UCSD Birds-200-2011 Dataset." (2011).
