# Seesaw Loss

> [Seesaw Loss for Long-Tailed Instance Segmentation](https://arxiv.org/abs/2008.10032)

<!-- [ALGORITHM] -->

## Abstract

Instance segmentation has witnessed remarkable progress on class-balanced benchmarks. However, existing methods fail to perform as accurately in real-world scenarios, where the category distribution of objects naturally comes with a long tail. Instances of head classes dominate a long-tailed dataset and serve as negative samples of tail categories. The overwhelming gradients of negative samples on tail classes lead to a biased learning process for classifiers. Consequently, objects of tail categories are more likely to be misclassified as background or head categories. To tackle this problem, we propose Seesaw Loss to dynamically re-balance the gradients of positive and negative samples for each category with two complementary factors, i.e., a mitigation factor and a compensation factor. The mitigation factor reduces the penalty on tail categories according to the ratio of cumulative training instances between different categories. Meanwhile, the compensation factor increases the penalty on misclassified instances to avoid false positives of tail categories. We conduct extensive experiments on Seesaw Loss with mainstream frameworks and different data sampling strategies. With a simple end-to-end training pipeline, Seesaw Loss obtains significant gains over Cross-Entropy Loss and achieves state-of-the-art performance on the LVIS dataset without bells and whistles.
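To make the two factors concrete, here is a minimal pure-Python sketch of the formulation as described in the paper. For a sample of ground-truth class $i$ with logits $z$, the loss is $-\log\big(e^{z_i} / (e^{z_i} + \sum_{j \neq i} S_{ij} e^{z_j})\big)$ with $S_{ij} = M_{ij} \cdot C_{ij}$. The mitigation factor is $M_{ij} = (N_j / N_i)^p$ when $N_i > N_j$ (else 1), where $N$ are cumulative instance counts. The compensation factor is $C_{ij} = (\sigma_j / \sigma_i)^q$ when $\sigma_j > \sigma_i$ (else 1), where $\sigma$ are softmax probabilities. `SeesawLossSketch` is a hypothetical illustration, not MMDetection's `SeesawLoss`. The defaults `p=0.8` and `q=2.0`, the `eps` clamp, and clamping counts to at least 1 are assumptions. The sketch also ignores the background handling a real detection head needs.

```python
import math
from typing import List, Sequence


class SeesawLossSketch:
    """Hypothetical single-label Seesaw Loss; illustration only, not MMDetection code."""

    def __init__(self, num_classes: int, p: float = 0.8, q: float = 2.0,
                 eps: float = 1e-2) -> None:
        self.num_classes = num_classes
        self.p = p      # mitigation exponent (assumed default)
        self.q = q      # compensation exponent (assumed default)
        self.eps = eps  # assumed clamp on the true-class probability
        # N_i: cumulative number of training instances seen so far per class.
        self.cum_counts: List[float] = [0.0] * num_classes

    def seesaw_factor(self, i: int, j: int, probs: Sequence[float]) -> float:
        """S_ij = M_ij * C_ij for ground-truth class i and negative class j."""
        # Assumption: clamp counts to 1 so unseen classes do not zero out the penalty.
        n_i = max(self.cum_counts[i], 1.0)
        n_j = max(self.cum_counts[j], 1.0)
        # Mitigation: a more frequent class i penalizes a rarer class j less.
        mitigation = (n_j / n_i) ** self.p if n_i > n_j else 1.0
        # Compensation: penalize j more when it already outscores the true class i.
        p_i = max(probs[i], self.eps)
        compensation = (probs[j] / p_i) ** self.q if probs[j] > probs[i] else 1.0
        return mitigation * compensation

    def sample_loss(self, logits: Sequence[float], label: int) -> float:
        # Stable softmax: shifting by max(logits) cancels in every ratio below.
        shift = max(logits)
        exp_z = [math.exp(v - shift) for v in logits]
        total = sum(exp_z)
        probs = [e / total for e in exp_z]
        denom = exp_z[label]
        for j in range(self.num_classes):
            if j != label:
                denom += self.seesaw_factor(label, j, probs) * exp_z[j]
        # -log(exp(z_label - shift) / denom), written to avoid log(0).
        return math.log(denom) - (logits[label] - shift)

    def __call__(self, batch_logits: Sequence[Sequence[float]],
                 labels: Sequence[int]) -> float:
        # Update the cumulative counts with this batch before computing factors.
        for y in labels:
            self.cum_counts[y] += 1.0
        losses = [self.sample_loss(z, y) for z, y in zip(batch_logits, labels)]
        return sum(losses) / len(losses)


# Usage: class 0 is a head class, class 1 a tail class.
loss_fn = SeesawLossSketch(num_classes=2)
loss_fn.cum_counts = [100.0, 1.0]
print(loss_fn([[1.0, 0.0]], [0]))  # smaller than plain cross-entropy here
```

Setting `p = q = 0` makes every $S_{ij}$ equal to 1, so the sketch reduces to ordinary softmax cross-entropy. This makes a convenient sanity check when experimenting with the exponents.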

<div align=center>
<img src="https://user-images.githubusercontent.com/40661020/143974715-d181abe5-d0a2-40d3-a2bd-17d8c60b89b8.png"/>
</div>

- Please set up the [LVIS dataset](../lvis/README.md) for MMDetection.

- RFS denotes Repeat Factor Sampling, the class-balanced oversampling strategy described [here](../../docs/tutorials/customize_dataset.md#class-balanced-dataset), applied with an oversample threshold of `1e-3`.
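Repeat Factor Sampling, as introduced with the LVIS dataset, assigns each category $c$ a factor $r_c = \max(1, \sqrt{t / f_c})$. Here $f_c$ is the fraction of training images containing $c$ and $t$ is the oversample threshold (`1e-3` above). Each image is then repeated according to the largest factor among its categories. The sketch below computes only these factors. The function names are hypothetical, and the conversion into an integer number of repeats, which MMDetection's dataset wrapper handles, is deliberately left out.

```python
import math
from typing import Dict, List, Sequence, Set


def category_repeat_factors(image_labels: Sequence[Set[int]],
                            thr: float = 1e-3) -> Dict[int, float]:
    """r_c = max(1, sqrt(thr / f_c)); f_c = fraction of images containing c."""
    num_images = len(image_labels)
    image_counts: Dict[int, int] = {}
    for labels in image_labels:
        for c in labels:  # each category counted at most once per image
            image_counts[c] = image_counts.get(c, 0) + 1
    return {c: max(1.0, math.sqrt(thr / (n / num_images)))
            for c, n in image_counts.items()}


def image_repeat_factors(image_labels: Sequence[Set[int]],
                         thr: float = 1e-3) -> List[float]:
    """Per-image factor = max over the image's categories (1.0 if unannotated)."""
    cat_rf = category_repeat_factors(image_labels, thr)
    return [max((cat_rf[c] for c in labels), default=1.0)
            for labels in image_labels]


# Usage: category 1 appears in 1 of 4000 images, so f_1 = 2.5e-4 and r_1 = 2.
images = [{0}] * 3999 + [{0, 1}]
print(image_repeat_factors(images)[-1])
```

Frequent categories (with $f_c \ge t$) keep a factor of 1, so only images containing rare categories are oversampled.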

## Results and models of Seesaw Loss on the LVIS v1 dataset

|       Method       | Backbone  |  Style  | Lr schd | Data Sampler | Norm Mask | box AP | mask AP |                                           Config                                           |                                                                                                                                                              Download                                                                                                                                                              |
| :----------------: | :-------: | :-----: | :-----: | :----------: | :-------: | :----: | :-----: | :----------------------------------------------------------------------------------------: | :--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: |
|     Mask R-CNN     | R-50-FPN  | pytorch |   2x    |    random    |     N     |  25.6  |  25.0   |             [config](./mask-rcnn_r50_fpn_seesaw-loss_random-ms-2x_lvis-v1.py)              |                          [model](https://download.openmmlab.com/mmdetection/v2.0/seesaw_loss/mask_rcnn_r50_fpn_random_seesaw_loss_mstrain_2x_lvis_v1-a698dd3d.pth) \| [log](https://download.openmmlab.com/mmdetection/v2.0/seesaw_loss/mask_rcnn_r50_fpn_random_seesaw_loss_mstrain_2x_lvis_v1.log.json)                          |
|     Mask R-CNN     | R-50-FPN  | pytorch |   2x    |    random    |     Y     |  25.6  |  25.4   |       [config](./mask-rcnn_r50_fpn_seesaw-loss-normed-mask_random-ms-2x_lvis-v1.py)        |              [model](https://download.openmmlab.com/mmdetection/v2.0/seesaw_loss/mask_rcnn_r50_fpn_random_seesaw_loss_normed_mask_mstrain_2x_lvis_v1-a1c11314.pth) \| [log](https://download.openmmlab.com/mmdetection/v2.0/seesaw_loss/mask_rcnn_r50_fpn_random_seesaw_loss_normed_mask_mstrain_2x_lvis_v1.log.json)              |
|     Mask R-CNN     | R-101-FPN | pytorch |   2x    |    random    |     N     |  27.4  |  26.7   |             [config](./mask-rcnn_r101_fpn_seesaw-loss_random-ms-2x_lvis-v1.py)             |                         [model](https://download.openmmlab.com/mmdetection/v2.0/seesaw_loss/mask_rcnn_r101_fpn_random_seesaw_loss_mstrain_2x_lvis_v1-8e6e6dd5.pth) \| [log](https://download.openmmlab.com/mmdetection/v2.0/seesaw_loss/mask_rcnn_r101_fpn_random_seesaw_loss_mstrain_2x_lvis_v1.log.json)                         |
|     Mask R-CNN     | R-101-FPN | pytorch |   2x    |    random    |     Y     |  27.2  |  27.3   |       [config](./mask-rcnn_r101_fpn_seesaw-loss-normed-mask_random-ms-2x_lvis-v1.py)       |             [model](https://download.openmmlab.com/mmdetection/v2.0/seesaw_loss/mask_rcnn_r101_fpn_random_seesaw_loss_normed_mask_mstrain_2x_lvis_v1-a0b59c42.pth) \| [log](https://download.openmmlab.com/mmdetection/v2.0/seesaw_loss/mask_rcnn_r101_fpn_random_seesaw_loss_normed_mask_mstrain_2x_lvis_v1.log.json)             |
|     Mask R-CNN     | R-50-FPN  | pytorch |   2x    |     RFS      |     N     |  27.6  |  26.4   |           [config](./mask-rcnn_r50_fpn_seesaw-loss_sample1e-3-ms-2x_lvis-v1.py)            |                      [model](https://download.openmmlab.com/mmdetection/v2.0/seesaw_loss/mask_rcnn_r50_fpn_sample1e-3_seesaw_loss_mstrain_2x_lvis_v1-392a804b.pth) \| [log](https://download.openmmlab.com/mmdetection/v2.0/seesaw_loss/mask_rcnn_r50_fpn_sample1e-3_seesaw_loss_mstrain_2x_lvis_v1.log.json)                      |
|     Mask R-CNN     | R-50-FPN  | pytorch |   2x    |     RFS      |     Y     |  27.6  |  26.8   |     [config](./mask-rcnn_r50_fpn_seesaw-loss-normed-mask_sample1e-3-ms-2x_lvis-v1.py)      |          [model](https://download.openmmlab.com/mmdetection/v2.0/seesaw_loss/mask_rcnn_r50_fpn_sample1e-3_seesaw_loss_normed_mask_mstrain_2x_lvis_v1-cd0f6a12.pth) \| [log](https://download.openmmlab.com/mmdetection/v2.0/seesaw_loss/mask_rcnn_r50_fpn_sample1e-3_seesaw_loss_normed_mask_mstrain_2x_lvis_v1.log.json)          |
|     Mask R-CNN     | R-101-FPN | pytorch |   2x    |     RFS      |     N     |  28.9  |  27.6   |           [config](./mask-rcnn_r101_fpn_seesaw-loss_sample1e-3-ms-2x_lvis-v1.py)           |                     [model](https://download.openmmlab.com/mmdetection/v2.0/seesaw_loss/mask_rcnn_r101_fpn_sample1e-3_seesaw_loss_mstrain_2x_lvis_v1-e68eb464.pth) \| [log](https://download.openmmlab.com/mmdetection/v2.0/seesaw_loss/mask_rcnn_r101_fpn_sample1e-3_seesaw_loss_mstrain_2x_lvis_v1.log.json)                     |
|     Mask R-CNN     | R-101-FPN | pytorch |   2x    |     RFS      |     Y     |  28.9  |  28.2   |     [config](./mask-rcnn_r101_fpn_seesaw-loss-normed-mask_sample1e-3-ms-2x_lvis-v1.py)     |         [model](https://download.openmmlab.com/mmdetection/v2.0/seesaw_loss/mask_rcnn_r101_fpn_sample1e-3_seesaw_loss_normed_mask_mstrain_2x_lvis_v1-1d817139.pth) \| [log](https://download.openmmlab.com/mmdetection/v2.0/seesaw_loss/mask_rcnn_r101_fpn_sample1e-3_seesaw_loss_normed_mask_mstrain_2x_lvis_v1.log.json)         |
| Cascade Mask R-CNN | R-101-FPN | pytorch |   2x    |    random    |     N     |  33.1  |  29.2   |         [config](./cascade-mask-rcnn_r101_fpn_seesaw-loss_random-ms-2x_lvis-v1.py)         |                 [model](https://download.openmmlab.com/mmdetection/v2.0/seesaw_loss/cascade_mask_rcnn_r101_fpn_random_seesaw_loss_mstrain_2x_lvis_v1-71e2215e.pth) \| [log](https://download.openmmlab.com/mmdetection/v2.0/seesaw_loss/cascade_mask_rcnn_r101_fpn_random_seesaw_loss_mstrain_2x_lvis_v1.log.json)                 |
| Cascade Mask R-CNN | R-101-FPN | pytorch |   2x    |    random    |     Y     |  33.0  |  30.0   |   [config](./cascade-mask-rcnn_r101_fpn_seesaw-loss-normed-mask_random-ms-2x_lvis-v1.py)   |     [model](https://download.openmmlab.com/mmdetection/v2.0/seesaw_loss/cascade_mask_rcnn_r101_fpn_random_seesaw_loss_normed_mask_mstrain_2x_lvis_v1-8b5a6745.pth) \| [log](https://download.openmmlab.com/mmdetection/v2.0/seesaw_loss/cascade_mask_rcnn_r101_fpn_random_seesaw_loss_normed_mask_mstrain_2x_lvis_v1.log.json)     |
| Cascade Mask R-CNN | R-101-FPN | pytorch |   2x    |     RFS      |     N     |  30.0  |  29.3   |       [config](./cascade-mask-rcnn_r101_fpn_seesaw-loss_sample1e-3-ms-2x_lvis-v1.py)       |             [model](https://download.openmmlab.com/mmdetection/v2.0/seesaw_loss/cascade_mask_rcnn_r101_fpn_sample1e-3_seesaw_loss_mstrain_2x_lvis_v1-5d8ca2a4.pth) \| [log](https://download.openmmlab.com/mmdetection/v2.0/seesaw_loss/cascade_mask_rcnn_r101_fpn_sample1e-3_seesaw_loss_mstrain_2x_lvis_v1.log.json)             |
| Cascade Mask R-CNN | R-101-FPN | pytorch |   2x    |     RFS      |     Y     |  32.8  |  30.1   | [config](./cascade-mask-rcnn_r101_fpn_seesaw-loss-normed-mask_sample1e-3-ms-2x_lvis-v1.py) | [model](https://download.openmmlab.com/mmdetection/v2.0/seesaw_loss/cascade_mask_rcnn_r101_fpn_sample1e-3_seesaw_loss_normed_mask_mstrain_2x_lvis_v1-c8551505.pth) \| [log](https://download.openmmlab.com/mmdetection/v2.0/seesaw_loss/cascade_mask_rcnn_r101_fpn_sample1e-3_seesaw_loss_normed_mask_mstrain_2x_lvis_v1.log.json) |

## Citation

We provide config files to reproduce the instance segmentation performance reported in the CVPR 2021 paper [Seesaw Loss for Long-Tailed Instance Segmentation](https://arxiv.org/abs/2008.10032).

```latex
@inproceedings{wang2021seesaw,
  title={Seesaw Loss for Long-Tailed Instance Segmentation},
  author={Jiaqi Wang and Wenwei Zhang and Yuhang Zang and Yuhang Cao and Jiangmiao Pang and Tao Gong and Kai Chen and Ziwei Liu and Chen Change Loy and Dahua Lin},
  booktitle={Proceedings of the {IEEE} Conference on Computer Vision and Pattern Recognition},
  year={2021}
}
```
