# AMA: Asymptotic Midpoint Augmentation for Margin Balancing and Moderate Broadening

This repository is the official implementation of ***AMA: Asymptotic Midpoint Augmentation for Margin Balancing and Moderate Broadening***.


<img src="figures/method_ama_whole_process.PNG" alt="AMA Main">

---
## Experimental Details
---
### General Settings

We used the following settings for all experiments except image classification on TinyImageNet:

- GPU: NVIDIA GeForce RTX 3090 x 1
- CPU cores: 64
- Memory: 256 GB
- NVIDIA Driver: 515.65.01
- CUDA version: 11.7

### Image Classification on TinyImageNet

For the TinyImageNet experiments, we used the following settings:

- GPU: NVIDIA Quadro RTX 8000 x 1
- CPU cores: 64
- Memory: 512 GB
- NVIDIA Driver: 470.141.03
- CUDA version: 11.4

---
## Requirements
---
To install the requirements:

```setup
pip install -r requirements.txt
```

For image classification on TinyImageNet, download and extract the dataset:

```
wget http://cs231n.stanford.edu/tiny-imagenet-200.zip
unzip tiny-imagenet-200.zip
```
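The archive keeps all validation images in a single `val/images` folder and lists their labels in `val/val_annotations.txt`. As a quick sanity check after extraction, the minimal sketch below reads those labels, assuming the standard `tiny-imagenet-200` layout in which each annotation line is tab-separated as file name, WordNet ID (wnid), and four bounding-box values. `load_val_labels` is a hypothetical helper, not part of this repository, whose own loader may organize the validation set differently. Calling `load_val_labels("tiny-imagenet-200")` on the standard release should return 10,000 entries covering 200 wnids.

```python
import csv
from pathlib import Path


def load_val_labels(root):
    """Map each validation image file name to its WordNet ID (wnid).

    Assumes the standard tiny-imagenet-200 layout, in which
    <root>/val/val_annotations.txt contains tab-separated lines of
    file name, wnid, and four bounding-box values.
    """
    ann_path = Path(root) / "val" / "val_annotations.txt"
    labels = {}
    with ann_path.open(newline="") as f:
        for row in csv.reader(f, delimiter="\t"):
            if row:  # ignore empty lines
                labels[row[0]] = row[1]
    return labels
```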
---
## Training and Evaluation
---
You can train either a baseline or the AMA model by setting the `net` hyperparameter in the corresponding shell script.

### Coarse-to-Fine-Grained Transfer Learning

To train and test the model(s) for this task, run the following commands in the `coarse_to_fine_grained` directory:

- CIFAR10 - resnet50_AMA / resnet50_origin / resnet50_mixup / resnet50_manifoldMixup
```
./train_CIFAR10_coarse2fine_transfer_learning.sh
```
- CIFAR100 - resnet50_AMA / resnet50_origin / resnet50_mixup / resnet50_manifoldMixup
```
./train_CIFAR100_coarse2fine_transfer_learning.sh
```
- CIFAR10 - SupCon
```
./train_SupCon_CIFAR10_coarse2fine_transfer_learning.sh
```
- CIFAR100 - SupCon
```
./train_SupCon_CIFAR100_coarse2fine_transfer_learning.sh
```
---
### Image Classification on Long-Tailed Datasets

To train and test the model(s) for this task, run the following commands in the `long_tailed` directory:

- CIFAR10 - resnet50_AMA / resnet50_origin / resnet50_mixup / resnet50_manifoldMixup
```
./train_CIFAR10LT.sh
```
- CIFAR100 - resnet50_AMA / resnet50_origin / resnet50_mixup / resnet50_manifoldMixup
```
./train_CIFAR100LT.sh
```
- CIFAR10 - SupCon
```
./train_SupCon_CIFAR10LT.sh
```
- CIFAR100 - SupCon
```
./train_SupCon_CIFAR100LT.sh
```
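Long-tailed CIFAR variants are commonly built by subsampling each class along an exponential profile, as in the LDAM-DRW code listed under References. The minimal sketch below computes that per-class count; it is not taken from this repository. The imbalance factor used by `train_CIFAR10LT.sh` and `train_CIFAR100LT.sh` is not stated in this README, so `imb_factor=0.01` (a 100:1 ratio between the largest and smallest class) is only an illustrative assumption, and `long_tailed_counts` is a hypothetical name. With 5,000 training images per CIFAR10 class, the head class keeps 5,000 images and the tail class keeps 50.

```python
def long_tailed_counts(n_max, num_classes, imb_factor):
    """Per-class sample counts for an exponentially imbalanced dataset.

    Class i keeps n_max * imb_factor ** (i / (num_classes - 1)) samples,
    so class 0 keeps n_max and the last class keeps n_max * imb_factor.
    """
    return [
        int(n_max * imb_factor ** (i / (num_classes - 1)))
        for i in range(num_classes)
    ]


# CIFAR10 has 5,000 training images per class; imb_factor=0.01 is an assumption.
counts = long_tailed_counts(5000, 10, 0.01)
print(counts)
```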
---
### Image Classification on Classic Datasets (CIFAR10/CIFAR100)

To train and test the model(s) for this task, run the following commands in the `image_classification_cifar` directory:

- CIFAR10 - vgg11_AMA / vgg11_origin / vgg11_mixup / vgg11_manifoldMixup / resnet50_AMA / resnet50_origin / resnet50_mixup / resnet50_manifoldMixup / dense_AMA / dense_origin / dense_mixup / dense_manifoldMixup
```
./train_CIFAR10.sh
```
- CIFAR100 - vgg11_AMA / vgg11_origin / vgg11_mixup / vgg11_manifoldMixup / resnet50_AMA / resnet50_origin / resnet50_mixup / resnet50_manifoldMixup / dense_AMA / dense_origin / dense_mixup / dense_manifoldMixup
```
./train_CIFAR100.sh
```
- CIFAR10 - SupCon
```
./train_SupCon_CIFAR10.sh
```
- CIFAR100 - SupCon
```
./train_SupCon_CIFAR100.sh
```
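The twelve `net` values listed above are every combination of a backbone (`vgg11`, `resnet50`, `dense`) and a training variant (`AMA`, `origin`, `mixup`, `manifoldMixup`); the TinyImageNet script lists the same twelve. The minimal sketch below rebuilds that list so a mistyped `net` value can be caught before a long run starts. `VALID_NETS` and `check_net` are hypothetical helpers, not part of the repository, and the scripts themselves may validate `net` differently or not at all.

```python
from itertools import product

# Backbones and training variants as named in this README's command lists.
BACKBONES = ("vgg11", "resnet50", "dense")
METHODS = ("AMA", "origin", "mixup", "manifoldMixup")

# The 12 `net` names listed for train_CIFAR10.sh, train_CIFAR100.sh,
# and train_TinyImageNet.sh, in <backbone>_<method> form.
VALID_NETS = [f"{b}_{m}" for b, m in product(BACKBONES, METHODS)]


def check_net(net):
    """Return `net` unchanged, or raise ValueError if it is not a listed name."""
    if net not in VALID_NETS:
        raise ValueError(f"unknown net {net!r}; expected one of {VALID_NETS}")
    return net
```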
---
### Image Classification on Classic Dataset (TinyImageNet)
To train and test the model(s) for this task, run the following commands in the `image_classification_tiny` directory:

- TinyImageNet - vgg11_AMA / vgg11_origin / vgg11_mixup / vgg11_manifoldMixup / resnet50_AMA / resnet50_origin / resnet50_mixup / resnet50_manifoldMixup / dense_AMA / dense_origin / dense_mixup / dense_manifoldMixup
```
./train_TinyImageNet.sh
```
- TinyImageNet - SupCon
```
./train_SupCon_TinyImageNet.sh
```
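Each script above is meant to be launched from its own task directory. To queue several runs, a small Python driver can set the working directory per script, as in the minimal sketch below. It assumes the scripts sit directly inside the listed directories and are executable (as the `./script.sh` commands imply); `TASKS` and `run_all` are hypothetical names. With `dry_run=True` (the default) it only prints the equivalent shell commands, and the `net` value for each run still has to be set inside the shell scripts themselves.

```python
import subprocess
from pathlib import Path

# Task directory -> training scripts, as listed in this README.
TASKS = {
    "coarse_to_fine_grained": [
        "train_CIFAR10_coarse2fine_transfer_learning.sh",
        "train_CIFAR100_coarse2fine_transfer_learning.sh",
        "train_SupCon_CIFAR10_coarse2fine_transfer_learning.sh",
        "train_SupCon_CIFAR100_coarse2fine_transfer_learning.sh",
    ],
    "long_tailed": [
        "train_CIFAR10LT.sh",
        "train_CIFAR100LT.sh",
        "train_SupCon_CIFAR10LT.sh",
        "train_SupCon_CIFAR100LT.sh",
    ],
    "image_classification_cifar": [
        "train_CIFAR10.sh",
        "train_CIFAR100.sh",
        "train_SupCon_CIFAR10.sh",
        "train_SupCon_CIFAR100.sh",
    ],
    "image_classification_tiny": [
        "train_TinyImageNet.sh",
        "train_SupCon_TinyImageNet.sh",
    ],
}


def run_all(repo_root=".", tasks=None, dry_run=True):
    """Run each script from inside its task directory, one after another.

    Returns the (directory, script) pairs that were, or would be, run.
    With dry_run=True, commands are only printed, not executed.
    """
    planned = []
    for task in tasks or TASKS:
        workdir = Path(repo_root) / task
        for script in TASKS[task]:
            planned.append((task, script))
            if dry_run:
                print(f"(cd {workdir} && ./{script})")
            else:
                # "./script" is resolved relative to cwd; check=True stops
                # the sequence at the first script that exits with an error.
                subprocess.run([f"./{script}"], cwd=workdir, check=True)
    return planned
```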
---


## References

#### VGG
[Very deep convolutional networks for large-scale image recognition](https://arxiv.org/pdf/1409.1556.pdf)

[vgg-pytorch](https://github.com/pytorch/vision/blob/main/torchvision/models/vgg.py)

[pytorch-cifar100-vgg](https://github.com/weiaicunzai/pytorch-cifar100/blob/master/models/vgg.py)

#### ResNet
[Deep Residual Learning for Image Recognition](https://arxiv.org/abs/1512.03385)

[resnet-pytorch](https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py)

[pytorch-cifar100-resnet](https://github.com/weiaicunzai/pytorch-cifar100/blob/master/models/resnet.py)

#### DenseNet
[Densely Connected Convolutional Networks](https://arxiv.org/abs/1608.06993)

[densenet.pytorch](https://github.com/bamos/densenet.pytorch/blob/master/densenet.py)

#### SupCon
[Supervised Contrastive Learning](https://arxiv.org/abs/2004.11362)

[SupContrast](https://github.com/HobbitLong/SupContrast)

#### Mixup

[mixup: Beyond Empirical Risk Minimization](https://arxiv.org/abs/1710.09412)

[Mixup-CIFAR10](https://github.com/facebookresearch/mixup-cifar10)

#### Manifold Mixup

[Manifold Mixup: Better Representations by Interpolating Hidden States](https://arxiv.org/abs/1806.05236)

[manifold_mixup](https://github.com/vikasverma1077/manifold_mixup)

#### Image Classification on Long-Tailed Datasets
[Parametric Contrastive Learning](https://arxiv.org/abs/2107.12028)

[Targeted Supervised Contrastive Learning for Long-Tailed Recognition](https://arxiv.org/abs/2111.13998)

[Balanced Contrastive Learning for Long-Tailed Visual Recognition](https://arxiv.org/abs/2207.09052)

[Learning Imbalanced Datasets with Label-Distribution-Aware Margin Loss](https://github.com/kaidic/LDAM-DRW)

#### Coarse-to-Fine-Grained Transfer Learning
[Perfectly Balanced: Improving Transfer and Robustness of Supervised Contrastive Learning](https://arxiv.org/abs/2204.07596)