# Copyright (c) OpenMMLab. All rights reserved.
from .norm import build_norm_layer

try:
    from mmdet.models.backbones import ResNet
    from mmdet.models.roi_heads.shared_heads.res_layer import ResLayer
    from mmdet.registry import MODELS

    @MODELS.register_module()
    class ResLayerExtraNorm(ResLayer):
        """``ResLayer`` with an extra norm layer applied to its output.

        All positional and keyword arguments are forwarded unchanged to
        ``ResLayer.__init__``. ``depth`` is additionally read here to look up
        the block type in ``ResNet.arch_settings``; it may be given either as
        a keyword argument or as the first positional argument.
        """

        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)

            # ``depth`` is the first parameter of ``ResLayer``, so it can
            # arrive positionally as well as by keyword. Reading only
            # ``kwargs['depth']`` raised ``KeyError`` for positional callers.
            depth = kwargs['depth'] if 'depth' in kwargs else args[0]
            block = ResNet.arch_settings[depth][0]

            # Number of features handed to the norm layer:
            # 64 * 2**stage, scaled by the block's expansion factor.
            num_features = 64 * 2**self.stage * block.expansion
            self.add_module('norm',
                            build_norm_layer(self.norm_cfg, num_features))

        def forward(self, x):
            """Apply the stage's residual layer, then the extra norm.

            Args:
                x: Input passed to the ``layer{stage + 1}`` submodule.

            Returns:
                The output of the ``norm`` submodule applied to the residual
                layer's output.
            """
            # The ``layer{stage + 1}`` attribute is not created in this
            # class; presumably the parent ``ResLayer`` registers it.
            res_layer = getattr(self, f'layer{self.stage + 1}')
            return self.norm(res_layer(x))

except ImportError:
    # An ImportError was raised in the ``try`` body (e.g. by the mmdet
    # imports): bind the name to ``None`` so that ``ResLayerExtraNorm`` is
    # still defined at module level instead of the import error propagating.
    ResLayerExtraNorm = None
