# Superclass-Conditional Gaussian Mixture Model For Learning Fine-Grained Embeddings


This code provides a demo on the BREEDS dataset and can be adapted to other datasets, including CIFAR-100 and tieredImageNet.
* Run ``train_scgm_g.py`` to train the SCGM-G model.
* Run ``train_scgm_a.py`` to train the SCGM-A model.
* Run ``test_scgm_g.py`` to evaluate SCGM-G on the cross-granularity few-shot learning task.
* Run ``test_scgm_a.py`` to evaluate SCGM-A on the cross-granularity few-shot learning task.
* Run ``test_fg_scgm_g.py`` to evaluate SCGM-G on the cross-granularity few-shot learning task in the intra-class setting.
* Run ``test_fg_scgm_a.py`` to evaluate SCGM-A on the cross-granularity few-shot learning task in the intra-class setting.

In each of the above scripts, specify the dataset through ``ds_name``, for example ``ds_name = 'living17'``.
After training finishes, the model is saved in the ``pretrain_model/`` directory.

## Dataset
### BREEDS dataset
1. Download the ImageNet dataset from https://www.image-net.org/challenges/LSVRC/.
2. Following the official BREEDS repo (https://github.com/MadryLab/BREEDS-Benchmarks/blob/master/Constructing%20BREEDS%20datasets.ipynb), run:

```python
import os
from robustness.tools.breeds_helpers import setup_breeds

info_dir = "[your_imagenet_path]/ILSVRC/BREEDS"
# Download the class hierarchy only if the directory is missing or empty.
if not (os.path.exists(info_dir) and len(os.listdir(info_dir))):
    print("Downloading class hierarchy information into `info_dir`")
    setup_breeds(info_dir)
```

### CIFAR-100
CIFAR-100 can be downloaded from https://www.cs.toronto.edu/~kriz/cifar.html

### TieredImageNet
TieredImageNet can be downloaded from https://github.com/renmengye/few-shot-ssl-public

Once downloaded, use the code in the ``dataset`` folder to generate minibatches for model training.

## Model parameters
In ``train_scgm_g.py``:
* k: number of subclasses
* lmd: lambda in the Sinkhorn-Knopp algorithm
* tau: variance of the subclass distribution
* alpha: the coefficient written as gamma in the loss function
* with_mlp: whether to use an MLP or a fully connected (FC) embedder
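
To make the role of ``lmd`` concrete, here is a minimal, dependency-free sketch of Sinkhorn-Knopp balancing. It is not taken from ``train_scgm_g.py``: the function name ``sinkhorn_knopp``, the ``n_iters`` argument, the default values, and the choice to use ``lmd`` as a multiplier inside the exponential are all assumptions made for illustration.

```python
import math


def sinkhorn_knopp(scores, lmd=5.0, n_iters=50):
    """Balanced soft assignment of n samples to k subclasses.

    scores: list of n rows, each holding k similarity scores.
    Returns n rows of k probabilities. Each row sums to 1 and each
    column sums to roughly n / k.
    NOTE: using lmd as a sharpness multiplier is an assumption.
    """
    n, k = len(scores), len(scores[0])
    # Subtract the global maximum before exponentiating, for stability.
    top = max(max(row) for row in scores)
    q = [[math.exp(lmd * (s - top)) for s in row] for row in scores]
    for _ in range(n_iters):
        # Scale columns so every subclass receives total mass 1 / k.
        col_sums = [sum(q[i][j] for i in range(n)) for j in range(k)]
        q = [[q[i][j] / (col_sums[j] * k) for j in range(k)] for i in range(n)]
        # Scale rows so every sample carries total mass 1 / n.
        row_sums = [sum(row) for row in q]
        q = [[v / (row_sums[i] * n) for v in row] for i, row in enumerate(q)]
    # Multiply by n so each row is a distribution over subclasses.
    return [[v * n for v in row] for row in q]
```

Under this assumption, a larger ``lmd`` gives sharper assignments, while the alternating normalization keeps the subclasses roughly equally populated.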

In ``train_scgm_a.py``:
* n_subclass: number of subclasses
* lmd: lambda in the Sinkhorn-Knopp algorithm
* tau1: variance of the subclass distribution
* alpha: the coefficient written as gamma in the loss function
* queue_k: queue size
* encoder_m: momentum
* metric_type: normalization type
* with_mlp: whether to use an MLP or a fully connected (FC) embedder
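
The parameters ``queue_k`` and ``encoder_m`` suggest a momentum encoder paired with a feature queue. The sketch below shows how such parameters are commonly used, assuming a MoCo-style design; the README does not confirm this, and ``momentum_update`` and ``FeatureQueue`` are hypothetical names that do not come from ``train_scgm_a.py``.

```python
from collections import deque


def momentum_update(key_params, query_params, encoder_m=0.999):
    """Exponential moving average: key <- m * key + (1 - m) * query.

    Parameters are plain lists of floats here; a real model would
    apply the same rule to every tensor of the two encoders.
    """
    return [encoder_m * k + (1.0 - encoder_m) * q
            for k, q in zip(key_params, query_params)]


class FeatureQueue:
    """Fixed-size FIFO buffer of past features (assumed design)."""

    def __init__(self, queue_k):
        # deque(maxlen=...) silently drops the oldest items when full.
        self.items = deque(maxlen=queue_k)

    def enqueue(self, batch):
        """Append a batch of features, evicting the oldest if needed."""
        self.items.extend(batch)

    def __len__(self):
        return len(self.items)
```

With ``encoder_m`` close to 1 the key encoder changes slowly, which keeps features already stored in the queue comparable with newly computed ones.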

## Model testing
In ``test_scgm_g.py``, ``test_scgm_a.py``, ``test_fg_scgm_g.py``, and ``test_fg_scgm_a.py``:
* classifier: specify which classifier to use
* n_test_runs: number of episodes
* n_ways: number of ways
* n_shots: number of shots
* n_queries: number of queries
* n_aug_support_samples: number of augmented copies for each support sample
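
The episode parameters above can be illustrated with a small sampler. This is a sketch of standard N-way K-shot episode construction, not the code used by the test scripts; ``sample_episode`` and its ``rng`` argument are hypothetical, and the default values are only examples.

```python
import random


def sample_episode(labels, n_ways=5, n_shots=1, n_queries=15, rng=None):
    """Return (support, query) lists of (sample_index, episode_label).

    labels: one class label per sample in the evaluation set.
    Each of the n_ways sampled classes contributes n_shots support
    samples and n_queries disjoint query samples.
    """
    rng = rng or random.Random()
    by_class = {}
    for idx, y in enumerate(labels):
        by_class.setdefault(y, []).append(idx)
    # Only classes with enough samples can fill both sets.
    need = n_shots + n_queries
    eligible = [c for c, idxs in by_class.items() if len(idxs) >= need]
    if len(eligible) < n_ways:
        raise ValueError("not enough classes with n_shots + n_queries samples")
    support, query = [], []
    # Classes are relabeled 0 .. n_ways - 1 within the episode.
    for new_label, c in enumerate(rng.sample(eligible, n_ways)):
        picked = rng.sample(by_class[c], need)
        support += [(i, new_label) for i in picked[:n_shots]]
        query += [(i, new_label) for i in picked[n_shots:]]
    return support, query
```

In this reading, ``n_test_runs`` would be the number of times such an episode is drawn and scored, and ``n_aug_support_samples`` the number of augmented copies made of each support sample before fitting the classifier.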

## Requirements
The experiments were run with Python 3.7 and the following packages:
* learn2learn==0.1.5
* matplotlib==3.4.2
* networkx==2.5.1
* numpy==1.20.3
* pandas==1.3.0
* robustness==1.2.1.post2
* scikit-learn==0.24.2
* scipy==1.7.0
* seaborn==0.11.1
* torch==1.4.0+cu92
* torchvision==0.5.0+cu92
