### Code for Our ICLR2021 Submission “Improve Object Detection With Feature-Based Knowledge Distillation: Towards Accurate And Efficient Detectors”

This is the PyTorch implementation of our paper "Improve Object Detection with Feature-Based Knowledge Distillation: Towards Accurate And Efficient Detectors". Our code and pre-trained teacher models are built on the mmdetection2 detection framework.

### Step-1 Download MMdetection2 Framework and Dataset

First download mmdetection2 and the MS COCO2017 dataset, and make sure that you can train a baseline model successfully. Follow the mmdetection2 installation instructions for details. Note that mmcv must be version >= 0.6.2 and mmdet must be version >= 2.2.
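To confirm that an existing installation meets these minimums, the installed versions can be compared in Python. `version_tuple` and `meets_min` below are hypothetical helpers (not part of mmcv or mmdet); they compare only the leading numeric components, so suffixes such as `rc1` or `+cu101` are ignored.

```python
import re


def version_tuple(version):
    """Parse the leading numeric parts of a version string, e.g. '1.0rc1' -> (1, 0)."""
    parts = []
    for piece in version.split('.'):
        match = re.match(r'\d+', piece)
        if match is None:
            break
        parts.append(int(match.group()))
        if match.end() != len(piece):  # stop at a suffix such as 'rc1' or '+cu101'
            break
    return tuple(parts)


def meets_min(installed, minimum):
    """Return True if `installed` is at least `minimum` (numeric parts only)."""
    a, b = version_tuple(installed), version_tuple(minimum)
    width = max(len(a), len(b))
    a += (0,) * (width - len(a))  # pad so that (2, 2) compares equal to (2, 2, 0)
    b += (0,) * (width - len(b))
    return a >= b
```

For example, `meets_min(mmcv.__version__, '0.6.2')` and `meets_min(mmdet.__version__, '2.2')` should both return `True` on a suitable installation.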

### Step-2 Download the Teacher Model 

Before running the distillation code, you need to download a pre-trained teacher model. We recommend the pre-trained cascade_mask_rcnn_x101_32x4d_fpn_dconv_c3-c5_1x_c.pth and retinanet_x101_64x4d_fpn_1x_coco_20200130-366f5af.pth for knowledge distillation on two-stage and one-stage students, respectively. The download URLs of the two models can be found in mmdetection/configs/dcn/README.md and mmdetection/configs/retinanet/README.md. Then put them in the checkpoints folder as follows:

```
mmdetection
--checkpoints
----cascade_mask_rcnn_x101_32x4d_fpn_dconv_c3-c5_1x_c.pth
----retinanet_x101_64x4d_fpn_1x_coco_20200130-366f5af.pth
```
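Before launching training, it can help to confirm that both files landed in the right place. The sketch below uses a hypothetical helper, `missing_checkpoints`, and assumes it is run from the mmdetection root with the file names exactly as listed above.

```python
import os


def missing_checkpoints(checkpoint_dir, filenames):
    """Return the expected checkpoint files that are not present in checkpoint_dir."""
    return [name for name in filenames
            if not os.path.isfile(os.path.join(checkpoint_dir, name))]


# File names as listed in the folder layout above.
EXPECTED = [
    'cascade_mask_rcnn_x101_32x4d_fpn_dconv_c3-c5_1x_c.pth',
    'retinanet_x101_64x4d_fpn_1x_coco_20200130-366f5af.pth',
]
```

Calling `missing_checkpoints('checkpoints', EXPECTED)` should return an empty list once both teachers are downloaded.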

### Step-3 Change the Codes of MMdetection2 to Add Our Method

#### 3.1 mmdetection/mmdet/apis/train.py

We need to modify train.py so that mmdetection builds and loads the pre-trained teacher model during training.

3.1.1 Add the build_teacher function to this file.

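```
# Add these imports at the top of train.py if they are missing.
from mmcv import Config
from mmcv.runner import load_checkpoint
from mmdet.models import build_detector


def build_teacher():
    # Build the teacher detector from its config and load its pre-trained weights.
    teacher_cfg = Config.fromfile("configs/....")  # config file of teacher model
    teacher = build_detector(
        teacher_cfg.model, train_cfg=teacher_cfg.train_cfg, test_cfg=teacher_cfg.test_cfg)
    load_checkpoint(teacher, "path of teacher model", map_location='cpu')
    return teacher
```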
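`load_checkpoint(..., map_location='cpu')` first loads the stored weights onto the CPU and then copies them into the freshly built teacher. The plain-PyTorch sketch below only approximates that step for illustration; it is not mmcv's implementation, and `load_teacher_weights`, the `'state_dict'` key, and the `'module.'` prefix handling are assumptions about how the checkpoint was saved.

```python
import torch
import torch.nn as nn


def load_teacher_weights(teacher: nn.Module, path: str) -> nn.Module:
    """Rough plain-PyTorch analogue of load_checkpoint(teacher, path, map_location='cpu')."""
    checkpoint = torch.load(path, map_location='cpu')  # tensors are placed on the CPU first
    # Detector checkpoints usually wrap the weights in a 'state_dict' entry.
    state_dict = checkpoint.get('state_dict', checkpoint)
    # Weights saved from a (Distributed)DataParallel wrapper carry a 'module.' prefix.
    state_dict = {k[len('module.'):] if k.startswith('module.') else k: v
                  for k, v in state_dict.items()}
    teacher.load_state_dict(state_dict)
    return teacher
```

Loading on the CPU avoids allocating GPU memory on whatever device the checkpoint was saved from; the model can be moved to its own device afterwards.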

3.1.2 Build the teacher model in the function train_detector:

```
teacher = build_teacher()
```

3.1.3 Pass the teacher to the mmcv runner

```
runner = EpochBasedRunner(
    model,
    optimizer=optimizer,
    work_dir=cfg.work_dir,
    logger=logger,
    meta=meta,
    teacher=teacher  # add this line here (requires the runner change in 3.2)
)
```

#### 3.2 mmcv/runner/epoch_based_runner.py and base_runner.py

3.2.0 If you don't know where your mmcv folder is, run the following command (use `pip show mmcv-full` instead if you installed mmcv-full):

```
pip show mmcv
```
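The same folder can also be found from Python. `package_dir` below is a hypothetical helper built on the standard `importlib.util.find_spec`, which locates a top-level package without importing it.

```python
import importlib.util
import os


def package_dir(name):
    """Return the directory of an installed top-level package, or None if not found."""
    spec = importlib.util.find_spec(name)
    if spec is None or spec.origin is None:  # not installed, or a namespace package
        return None
    return os.path.dirname(spec.origin)  # origin points at the package's __init__.py
```

For example, `package_dir('mmcv')` returns the folder that contains `runner/base_runner.py` and `runner/epoch_based_runner.py`.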

3.2.1 Change the \__init\__ function in base_runner.py:

```
def __init__(self,
             model,
             batch_processor=None,
             optimizer=None,
             work_dir=None,
             logger=None,
             meta=None,
             teacher=None):   # add this line here
```

Then store the teacher in the body of \__init\__:

```
self.model = model
self.teacher = teacher # add this line here
self.batch_processor = batch_processor
```

3.2.2 Change the train function in epoch_based_runner.py:

```
# Requires `import torch` at the top of epoch_based_runner.py.
def train(self, data_loader, teacher=None, **kwargs):
    self.model.train()
    self.teacher.train()
    self.mode = 'train'
    self.data_loader = data_loader
    self._max_iters = self._max_epochs * len(data_loader)
    self.call_hook('before_train_epoch')
    time.sleep(2)  # Prevent possible deadlock during epoch transition
    for i, data_batch in enumerate(data_loader):
        self._inner_iter = i
        self.call_hook('before_train_iter')
        if self.batch_processor is None:
            # The teacher only runs forward to provide its features; no gradients are kept.
            with torch.no_grad():
                t_info = self.teacher.train_step(
                    data_batch, self.optimizer, epoch=self.epoch, iter=self._inner_iter,
                    teach=True, t_info=None, **kwargs)
            # The student receives the teacher information through t_info.
            outputs = self.model.train_step(
                data_batch, self.optimizer, epoch=self.epoch, iter=self._inner_iter,
                teach=False, t_info=t_info, **kwargs)
        # ... the rest of the original loop (the batch_processor branch, the
        # log_buffer update, the 'after_train_iter' hook, etc.) stays unchanged.
```
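The control flow of the modified loop can be tried out without mmcv. In the sketch below, `ToyDetector` and `train_one_iter` are hypothetical stand-ins: the toy `train_step` only imitates the interface from 3.3, and the optimizer update that mmcv's `OptimizerHook` performs after each iteration is written out inline.

```python
import torch
import torch.nn as nn


class ToyDetector(nn.Module):
    """Stand-in for an mmdet detector; only mimics the train_step interface."""

    def __init__(self):
        super().__init__()
        self.backbone = nn.Conv2d(3, 8, kernel_size=3, padding=1)

    def train_step(self, data, optimizer, teach=False, t_info=None, epoch=None, iter=None):
        feat = self.backbone(data['img'])
        if teach:
            return dict(feat=[feat])  # teacher: hand its features to the student
        # Student: a simple feature-imitation loss against the teacher features.
        loss = ((feat - t_info['feat'][0]) ** 2).mean()
        return dict(loss=loss, log_vars=dict(loss=loss.item()),
                    num_samples=data['img'].size(0))


def train_one_iter(student, teacher, optimizer, data_batch, epoch=0, inner_iter=0):
    """One iteration of the modified EpochBasedRunner.train loop."""
    student.train()
    teacher.train()  # mirrors self.teacher.train() in the runner
    with torch.no_grad():  # no graph is built for the teacher, so it never gets gradients
        t_info = teacher.train_step(data_batch, optimizer, epoch=epoch, iter=inner_iter,
                                    teach=True, t_info=None)
    outputs = student.train_step(data_batch, optimizer, epoch=epoch, iter=inner_iter,
                                 teach=False, t_info=t_info)
    # In mmcv this update is done by OptimizerHook in its after_train_iter step.
    optimizer.zero_grad()
    outputs['loss'].backward()
    optimizer.step()
    return outputs
```

Only the student's parameters are handed to the optimizer, and the teacher forward runs under `torch.no_grad()`, so the teacher stays fixed throughout distillation.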

#### 3.3 mmdetection/mmdet/models/detectors/base.py

Change the train_step function in base.py:

```
def train_step(self, data, optimizer, teach=False, t_info=None, epoch=None, iter=None):
    """The iteration step during training.

    This method defines an iteration step during training, except for the
    back propagation and optimizer updating, which are done in an optimizer
    hook. Note that in some complicated cases or models, the whole process
    including back propagation and optimizer updating is also defined in
    this method, such as GAN.

    Args:
        data (dict): The output of dataloader.
        optimizer (:obj:`torch.optim.Optimizer` | dict): The optimizer of
            runner is passed to ``train_step()``. This argument is unused
            and reserved.
        teach (bool): If True, the model acts as the teacher and returns
            the output of ``get_teacher_info`` instead of losses.
        t_info (dict | None): Teacher information passed on to the
            student's ``forward_train``.
        epoch (int | None): Current epoch, passed on to ``forward_train``.
        iter (int | None): Current inner iteration, passed on to
            ``forward_train``.

    Returns:
        dict: For the student, it should contain at least 3 keys:
        ``loss``, ``log_vars``, ``num_samples``.
        ``loss`` is a tensor for back propagation, which can be a
        weighted sum of multiple losses.
        ``log_vars`` contains all the variables to be sent to the
        logger.
        ``num_samples`` indicates the batch size (when the model is
        DDP, it means the batch size on each GPU), which is used for
        averaging the logs.
    """
    if teach:
        teacher_info = self.get_teacher_info(**data)
        return teacher_info
    else:
        losses = self(t_info=t_info, epoch=epoch, iter=iter, **data)
        loss, log_vars = self._parse_losses(losses)

        outputs = dict(
            loss=loss, log_vars=log_vars, num_samples=len(data['img_metas']))
        return outputs
```
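`train_step` calls `self.get_teacher_info(**data)`, which is not shown in this README. Since `forward_train` reads the teacher features as `t_info['feat'][_i]`, one per FPN level, a plausible definition simply returns the output of `extract_feat`. The sketch below is an assumption, demonstrated on a hypothetical toy detector (`ToyFPNDetector`) so that it runs without mmdet; in base.py only the `get_teacher_info` method would be needed.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ToyFPNDetector(nn.Module):
    """Hypothetical stand-in for a detector with a 5-level, 256-channel neck."""

    def __init__(self, num_levels=5, channels=256):
        super().__init__()
        self.stem = nn.Conv2d(3, channels, kernel_size=3, stride=4, padding=1)
        self.num_levels = num_levels

    def extract_feat(self, img):
        x = self.stem(img)
        feats = [x]
        for _ in range(self.num_levels - 1):
            x = F.max_pool2d(x, kernel_size=2)  # halve H and W for each coarser level
            feats.append(x)
        return tuple(feats)

    # Assumed implementation: forward_train only reads t_info['feat'][i],
    # so the teacher returns its multi-level features under the key 'feat'.
    def get_teacher_info(self, img, img_metas=None, **kwargs):
        return dict(feat=self.extract_feat(img))
```

Extra entries of `data` such as `gt_bboxes` or `gt_labels` are absorbed by `**kwargs`, since the teacher only needs the images.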

#### 3.4 mmdetection/mmdet/models/detectors/two_stage.py and single_stage.py

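Change the code in the function forward_train as follows. Because train_step calls self(t_info=t_info, epoch=epoch, iter=iter, **data), forward_train must also accept t_info, epoch and iter as named keyword arguments (e.g. t_info=None, epoch=None, iter=None), so that they are not passed on to roi_head.forward_train through **kwargs. The code below is for two_stage.py; in single_stage.py, keep the distillation part and replace the RPN/RoI part with the original bbox_head loss computation.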

```
x = self.extract_feat(img)
losses = dict()

# Temperatures of the spatial (t) and channel (c_t) attention softmax, and the
# weights of the student masks when they are averaged with the teacher masks.
t = 0.5
s_ratio = 1.0
c_t = 0.5
c_s_ratio = 1.0

# The distillation losses are only added when teacher features are available,
# because _parse_losses() expects every loss to be a tensor (or a list of tensors).
if t_info is not None:
    t_feats = t_info['feat']  # teacher FPN features, one tensor per level
    kd_feat_loss = 0
    kd_channel_loss = 0
    kd_spatial_loss = 0
    for _i in range(len(t_feats)):
        # Spatial attention: channel-mean |activation|, softmax over H x W, scaled by H x W
        t_attention_mask = torch.mean(torch.abs(t_feats[_i]), [1], keepdim=True)
        size = t_attention_mask.size()
        t_attention_mask = t_attention_mask.view(x[0].size(0), -1)
        t_attention_mask = torch.softmax(t_attention_mask / t, dim=1) * size[-1] * size[-2]
        t_attention_mask = t_attention_mask.view(size)

        s_attention_mask = torch.mean(torch.abs(x[_i]), [1], keepdim=True)
        size = s_attention_mask.size()
        s_attention_mask = s_attention_mask.view(x[0].size(0), -1)
        s_attention_mask = torch.softmax(s_attention_mask / t, dim=1) * size[-1] * size[-2]
        s_attention_mask = s_attention_mask.view(size)

        # Channel attention: spatial-mean |activation|, softmax over the 256 channels, scaled by 256
        c_t_attention_mask = torch.mean(torch.abs(t_feats[_i]), [2, 3], keepdim=True)  # B x 256 x 1 x 1
        c_size = c_t_attention_mask.size()
        c_t_attention_mask = c_t_attention_mask.view(x[0].size(0), -1)  # B x 256
        c_t_attention_mask = torch.softmax(c_t_attention_mask / c_t, dim=1) * 256
        c_t_attention_mask = c_t_attention_mask.view(c_size)  # B x 256 x 1 x 1

        c_s_attention_mask = torch.mean(torch.abs(x[_i]), [2, 3], keepdim=True)  # B x 256 x 1 x 1
        c_size = c_s_attention_mask.size()
        c_s_attention_mask = c_s_attention_mask.view(x[0].size(0), -1)  # B x 256
        c_s_attention_mask = torch.softmax(c_s_attention_mask / c_t, dim=1) * 256
        c_s_attention_mask = c_s_attention_mask.view(c_size)  # B x 256 x 1 x 1

        # Average the teacher and student masks; they are used as constant weights.
        sum_attention_mask = (t_attention_mask + s_attention_mask * s_ratio) / (1 + s_ratio)
        sum_attention_mask = sum_attention_mask.detach()

        c_sum_attention_mask = (c_t_attention_mask + c_s_attention_mask * c_s_ratio) / (1 + c_s_ratio)
        c_sum_attention_mask = c_sum_attention_mask.detach()

        # Attention-weighted feature imitation loss
        kd_feat_loss += dist2(t_feats[_i], self.adaptation_layers[_i](x[_i]),
                              attention_mask=sum_attention_mask,
                              channel_attention_mask=c_sum_attention_mask) * 7e-5
        # Losses between the channel-wise and spatial-wise mean features of teacher and student
        kd_channel_loss += torch.dist(torch.mean(t_feats[_i], [2, 3]),
                                      self.channel_wise_adaptation[_i](torch.mean(x[_i], [2, 3]))) * 4e-3
        t_spatial_pool = torch.mean(t_feats[_i], [1], keepdim=True)  # B x 1 x H x W
        s_spatial_pool = torch.mean(x[_i], [1], keepdim=True)  # B x 1 x H x W
        kd_spatial_loss += torch.dist(t_spatial_pool,
                                      self.spatial_wise_adaptation[_i](s_spatial_pool)) * 4e-3

    losses.update({'kd_feat_loss': kd_feat_loss})
    losses.update({'kd_channel_loss': kd_channel_loss})
    losses.update({'kd_spatial_loss': kd_spatial_loss})

    # Non-local (relation) distillation loss
    kd_nonlocal_loss = 0
    for _i in range(len(t_feats)):
        s_relation = self.student_non_local[_i](x[_i])
        t_relation = self.teacher_non_local[_i](t_feats[_i])
        kd_nonlocal_loss += torch.dist(self.non_local_adaptation[_i](s_relation), t_relation, p=2)
    losses.update(kd_nonlocal_loss=kd_nonlocal_loss * 7e-5)

# RPN forward and loss (two_stage.py only)
if self.with_rpn:
    proposal_cfg = self.train_cfg.get('rpn_proposal',
                                      self.test_cfg.rpn)

    rpn_losses, proposal_list = self.rpn_head.forward_train(
        x,
        img_metas,
        gt_bboxes,
        gt_labels=None,
        gt_bboxes_ignore=gt_bboxes_ignore,
        proposal_cfg=proposal_cfg,
    )
    losses.update(rpn_losses)
else:
    proposal_list = proposals

roi_losses = self.roi_head.forward_train(x, img_metas, proposal_list,
                                         gt_bboxes, gt_labels,
                                         gt_bboxes_ignore, gt_masks,
                                         **kwargs)
losses.update(roi_losses)
return losses
```
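The two attention masks computed at the top of the loop can be isolated and checked on random tensors. The helper names `spatial_attention` and `channel_attention` are ours, not from the original code, and the channel count is read from the tensor instead of being hard-coded to 256. Each spatial mask sums to H*W and each channel mask sums to C, so both average to 1.

```python
import torch


def spatial_attention(feat, t=0.5):
    """H*W * softmax over positions of the channel-mean |activation|; shape B x 1 x H x W."""
    b, _, h, w = feat.shape
    a = torch.mean(torch.abs(feat), dim=1, keepdim=True).view(b, -1)
    return (torch.softmax(a / t, dim=1) * h * w).view(b, 1, h, w)


def channel_attention(feat, t=0.5):
    """C * softmax over channels of the spatial-mean |activation|; shape B x C x 1 x 1."""
    b, c, _, _ = feat.shape
    a = torch.mean(torch.abs(feat), dim=[2, 3], keepdim=True).view(b, -1)
    return (torch.softmax(a / t, dim=1) * c).view(b, c, 1, 1)
```

Because of this normalization, a perfectly uniform feature map yields an all-ones mask, and the masked loss then reduces to the unweighted one; a lower temperature `t` concentrates the weight on the most strongly activated positions or channels.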

Change the code in the function \__init\__ (the new modules must be created after the call to super().\__init\__()):

```
# One adaptation module per FPN level (5 levels with 256 channels each).
num_levels = 5
self.channel_wise_adaptation = nn.ModuleList(
    [nn.Linear(256, 256) for _ in range(num_levels)])
self.spatial_wise_adaptation = nn.ModuleList(
    [nn.Conv2d(1, 1, kernel_size=3, stride=1, padding=1) for _ in range(num_levels)])
self.adaptation_layers = nn.ModuleList(
    [nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0) for _ in range(num_levels)])

# Non-local blocks; the two highest-resolution levels use larger pooling strides
# (8 and 4), which shrinks their (H*W) x (H*W / stride^2) relation matrices.
self.student_non_local = nn.ModuleList([
    NonLocalBlockND(in_channels=256, inter_channels=64, downsample_stride=8),
    NonLocalBlockND(in_channels=256, inter_channels=64, downsample_stride=4),
    NonLocalBlockND(in_channels=256),
    NonLocalBlockND(in_channels=256),
    NonLocalBlockND(in_channels=256),
])
self.teacher_non_local = nn.ModuleList([
    NonLocalBlockND(in_channels=256, inter_channels=64, downsample_stride=8),
    NonLocalBlockND(in_channels=256, inter_channels=64, downsample_stride=4),
    NonLocalBlockND(in_channels=256),
    NonLocalBlockND(in_channels=256),
    NonLocalBlockND(in_channels=256),
])
self.non_local_adaptation = nn.ModuleList(
    [nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0) for _ in range(num_levels)])
```
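Every adaptation module preserves the shape of its input, which is what allows the distillation losses to compare the adapted student tensors with teacher tensors of the same shape. The standalone check below rebuilds three of the module lists outside a detector; the 5-level, 256-channel layout is the one assumed above, and `check_level` is a hypothetical helper.

```python
import torch
import torch.nn as nn

NUM_LEVELS, C = 5, 256  # five FPN levels with 256 channels, as assumed by the module lists

channel_wise_adaptation = nn.ModuleList([nn.Linear(C, C) for _ in range(NUM_LEVELS)])
spatial_wise_adaptation = nn.ModuleList(
    [nn.Conv2d(1, 1, kernel_size=3, stride=1, padding=1) for _ in range(NUM_LEVELS)])
adaptation_layers = nn.ModuleList(
    [nn.Conv2d(C, C, kernel_size=1, stride=1, padding=0) for _ in range(NUM_LEVELS)])


def check_level(i, student_feat):
    """Return the output shapes of the three adaptation paths for FPN level i."""
    feat = adaptation_layers[i](student_feat)                                  # B x C x H x W
    chan = channel_wise_adaptation[i](torch.mean(student_feat, [2, 3]))        # B x C
    spat = spatial_wise_adaptation[i](torch.mean(student_feat, [1], keepdim=True))  # B x 1 x H x W
    return feat.shape, chan.shape, spat.shape
```

These modules belong to the student detector, so they are trained together with it by the same optimizer.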

Add the following class and function to the same file (at module level):

```
import torch
import torch.nn as nn


class NonLocalBlockND(nn.Module):
    """Non-local block (dot-product form) with a residual connection; the output has the input's shape."""

    def __init__(self, in_channels, inter_channels=None, dimension=2, sub_sample=True, bn_layer=True,
                 downsample_stride=2):
        super(NonLocalBlockND, self).__init__()

        assert dimension in [1, 2, 3]

        self.dimension = dimension
        self.sub_sample = sub_sample

        self.in_channels = in_channels
        self.inter_channels = inter_channels

        if self.inter_channels is None:
            self.inter_channels = in_channels // 2
            if self.inter_channels == 0:
                self.inter_channels = 1

        if dimension == 3:
            conv_nd = nn.Conv3d
            max_pool_layer = nn.MaxPool3d(kernel_size=(1, 2, 2))
            bn = nn.BatchNorm3d
        elif dimension == 2:
            conv_nd = nn.Conv2d
            max_pool_layer = nn.MaxPool2d(kernel_size=(downsample_stride, downsample_stride))
            bn = nn.BatchNorm2d
        else:
            conv_nd = nn.Conv1d
            max_pool_layer = nn.MaxPool1d(kernel_size=2)
            bn = nn.BatchNorm1d

        self.g = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
                         kernel_size=1, stride=1, padding=0)

        # The output projection is zero-initialised, so the block starts as an identity mapping (z = x).
        if bn_layer:
            self.W = nn.Sequential(
                conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels,
                        kernel_size=1, stride=1, padding=0),
                bn(self.in_channels)
            )
            nn.init.constant_(self.W[1].weight, 0)
            nn.init.constant_(self.W[1].bias, 0)
        else:
            self.W = conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels,
                             kernel_size=1, stride=1, padding=0)
            nn.init.constant_(self.W.weight, 0)
            nn.init.constant_(self.W.bias, 0)

        self.theta = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
                             kernel_size=1, stride=1, padding=0)

        self.phi = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
                           kernel_size=1, stride=1, padding=0)

        # Pooling g and phi reduces the number of key positions the relation is computed against.
        if sub_sample:
            self.g = nn.Sequential(self.g, max_pool_layer)
            self.phi = nn.Sequential(self.phi, max_pool_layer)

    def forward(self, x):
        '''
        :param x: (b, c, h, w) for dimension=2, (b, c, t, h, w) for dimension=3
        :return: tensor with the same shape as x
        '''

        batch_size = x.size(0)  # x: B x C x H x W

        g_x = self.g(x).view(batch_size, self.inter_channels, -1)  # B x C' x N' (N' positions after pooling)
        g_x = g_x.permute(0, 2, 1)  # B x N' x C'

        theta_x = self.theta(x).view(batch_size, self.inter_channels, -1)  # B x C' x N, N = H x W
        theta_x = theta_x.permute(0, 2, 1)  # B x N x C'
        phi_x = self.phi(x).view(batch_size, self.inter_channels, -1)  # B x C' x N'
        f = torch.matmul(theta_x, phi_x)  # B x N x N' pairwise dot products
        N = f.size(-1)  # N'
        f_div_C = f / N  # normalise by the number of key positions

        y = torch.matmul(f_div_C, g_x)  # B x N x C'
        y = y.permute(0, 2, 1).contiguous()  # B x C' x N
        y = y.view(batch_size, self.inter_channels, *x.size()[2:])
        W_y = self.W(y)
        z = W_y + x  # residual connection
        return z


def dist2(tensor_a, tensor_b, attention_mask=None, channel_attention_mask=None):
    # Attention-weighted L2 distance between two feature maps.
    diff = (tensor_a - tensor_b) ** 2  # B x C x H x W
    if attention_mask is not None:
        diff = diff * attention_mask  # spatial mask, B x 1 x H x W
    if channel_attention_mask is not None:
        diff = diff * channel_attention_mask  # channel mask, B x C x 1 x 1
    diff = torch.sum(diff) ** 0.5
    return diff
```
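`dist2` is a weighted Euclidean distance: with all-ones masks it reduces to `torch.dist`, and a 0/1 spatial mask restricts the distance to the selected positions. The self-contained copy below (masks treated as optional) lets both properties be checked in isolation; it is an illustration, not a replacement for the function in the detector file.

```python
import torch


def dist2(tensor_a, tensor_b, attention_mask=None, channel_attention_mask=None):
    """sqrt(sum(M_spatial * M_channel * (a - b)^2)), an attention-weighted L2 distance."""
    diff = (tensor_a - tensor_b) ** 2
    if attention_mask is not None:
        diff = diff * attention_mask  # broadcast over channels: B x 1 x H x W
    if channel_attention_mask is not None:
        diff = diff * channel_attention_mask  # broadcast over positions: B x C x 1 x 1
    return torch.sum(diff) ** 0.5
```

Because the spatial and channel masks from `forward_train` each average to 1, they redistribute the loss across positions and channels rather than rescale it wholesale.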



### Step-4 Start Knowledge Distillation

Train a student model with our knowledge distillation method (the last argument is the number of GPUs):

```
./tools/dist_train.sh configs/faster_rcnn/faster_rcnn_r50_fpn_2x_coco.py 8
# or
./tools/dist_train.sh configs/retinanet/retinanet_r50_fpn_2x_coco.py 8
```

### Others

We use the original mmdetection2 configs as our configs. Some log files from our experiments are provided in the log folder; we have not uploaded all of them because the supplementary material is limited to 100 MB. We will release all the code, logs, and pre-trained models on GitHub after the ICLR2021 review.

