# ResNet with SASE building block
#  built on https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/resnet.py

"""PyTorch ResNet

This started as a copy of https://github.com/pytorch/vision 'resnet.py' (BSD-3-Clause) with
additional dropout and dynamic global avg/max pool.

ResNeXt, SE-ResNeXt, SENet, and MXNet Gluon stem/downsample variants, tiered stems added by Ross Wightman

Copyright 2019, Ross Wightman
"""
import math
from functools import partial

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules.batchnorm import _BatchNorm

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.models.layers import DropBlock2d, DropPath, AvgPool2dSame, BlurPool2d, GroupNorm, create_attn, get_attn, create_classifier
from timm.models.resnet import ResNet, Bottleneck, _create_resnet, create_aa, _cfg, default_cfgs
from timm.models.registry import register_model

from mmdet.models.builder import BACKBONES
from mmdet.utils import get_root_logger

from mmcv.runner import load_checkpoint, BaseModule
from mmcv.cnn import (ConvModule, build_conv_layer, build_norm_layer,
                      build_activation_layer, build_plugin_layer,
                      constant_init, kaiming_init)

__all__ = ['SaseBottleneck']

# Pretrained-config entries for the variants defined in this file.
# The factory below builds the model under the variant name 'sase_resnet50',
# so a config must exist under that exact key; the original file only
# registered 'slim_resnet50' (the previous name of the block), which leaves
# the variant lookup without an entry. The legacy key is kept so that any
# external code still referring to it keeps working.
default_cfgs.update(
    dict(
        slim_resnet50=_cfg(url='', crop_pct=0.95, interpolation='bicubic'),
        sase_resnet50=_cfg(url='', crop_pct=0.95, interpolation='bicubic'),
    )
)


class SASE(nn.Module):
    """Group-wise gated 3x3 convolution.

    The input is split along the channel axis (dim 1) into ``groups`` equal
    chunks. Every chunk owns three branches:

    * ``query`` -- 3x3 conv -> BatchNorm -> ReLU -> 3x3 conv,
    * ``key``   -- global average pool -> 1x1 conv -> BatchNorm -> ReLU
      -> 1x1 conv,
    * ``value`` -- a single 3x3 conv.

    ``query * key`` is softmax-normalised over the chunk's output channels
    and used to scale ``value``. The chunk results are concatenated along
    the channel axis and, when ``stride > 1``, passed through an average
    pool with a ``stride x stride`` window.

    Args:
        in_planes: Number of input channels; must be divisible by ``groups``.
        out_planes: Number of output channels; must be divisible by
            ``groups``.
        stride: Spatial down-sampling factor; values ``<= 1`` disable the
            pooling step.
        groups: Number of channel groups processed independently.
        ratio: Reduction ratio for the hidden width of the query/key
            branches (the hidden width is at least 4).
    """

    def __init__(self, in_planes, out_planes, stride, groups=4, ratio=4):
        super().__init__()
        assert in_planes % groups == 0 and out_planes % groups == 0

        self.in_planes = in_planes
        self.out_planes = out_planes
        self.groups = groups
        self.stride = stride

        group_in = in_planes // groups
        group_out = out_planes // groups
        hidden = max(4, group_out // ratio)

        self.query = nn.ModuleList()
        self.key = nn.ModuleList()
        self.value = nn.ModuleList()

        # Branches are created group by group (query, key, value) so that the
        # parameter initialisation order is stable for a given random seed.
        for _ in range(groups):
            query_branch = nn.Sequential(
                nn.Conv2d(group_in, hidden, 3, 1, 1, bias=False),
                nn.BatchNorm2d(hidden),
                nn.ReLU(inplace=True),
                nn.Conv2d(hidden, group_out, 3, 1, 1),
            )
            self.query.append(query_branch)

            key_branch = nn.Sequential(
                nn.AdaptiveAvgPool2d((1, 1)),
                nn.Conv2d(group_in, hidden, 1, bias=False),
                nn.BatchNorm2d(hidden),
                nn.ReLU(inplace=True),
                nn.Conv2d(hidden, group_out, 1),
            )
            self.key.append(key_branch)

            value_branch = nn.Conv2d(group_in, group_out, 3, 1, 1)
            self.value.append(value_branch)

        if stride > 1:
            self.downsample = nn.AvgPool2d((self.stride, self.stride))
        else:
            self.downsample = nn.Identity()

    @staticmethod
    def _gate(piece, q_branch, k_branch, v_branch):
        """Run one channel chunk through its query/key/value branches."""
        query = q_branch(piece)
        key = k_branch(piece)
        value = v_branch(piece)
        # The pooled key broadcasts over the spatial dims of the query.
        weights = (query * key).softmax(dim=1)
        return weights * value

    def forward(self, x):
        """Apply the gated convolution group by group, then down-sample."""
        chunk_width = self.in_planes // self.groups
        pieces = torch.split(x, chunk_width, dim=1)

        gated = [
            self._gate(piece, q_branch, k_branch, v_branch)
            for piece, q_branch, k_branch, v_branch in zip(
                pieces, self.query, self.key, self.value)
        ]

        return self.downsample(torch.cat(gated, dim=1))


class SaseBottleneck(Bottleneck):
    """``Bottleneck`` whose 3x3 ``conv2`` is replaced by a SASE-style module.

    The parent constructor builds the regular bottleneck first; afterwards
    ``self.conv2`` is overwritten with the module selected by ``slim_args``.

    Args:
        inplanes: Number of input channels of the block.
        planes: Base width of the block (before ``expansion``).
        stride: Stride of the block.
        downsample: Optional module applied to the shortcut branch.
        slim_args: Dict describing the replacement for ``conv2``. The
            ``type`` entry names a class defined in this module (``'SASE'``;
            the legacy name ``'SLIM'`` is accepted as an alias) or is the
            class itself; all remaining entries are forwarded to it as
            keyword arguments. Defaults to
            ``dict(type='SASE', groups=4, ratio=4)``.
        **kwargs: Forwarded unchanged to ``Bottleneck.__init__``. The entries
            ``cardinality``, ``base_width``, ``reduce_first``, ``dilation``,
            ``first_dilation`` and ``aa_layer`` are also read here to
            recompute the channel widths of ``conv2``.

    Raises:
        ValueError: If ``slim_args['type']`` cannot be resolved to a callable.
    """
    expansion = 4

    # Legacy spellings of the replacement type mapped to the current name.
    # The original default was 'SLIM', for which no class exists in this
    # file, so the default construction path failed with NameError.
    _SLIM_ALIASES = {'SLIM': 'SASE'}

    def __init__(self, inplanes, planes, stride=1, downsample=None,
                 slim_args=None, **kwargs):
        super().__init__(inplanes, planes, stride, downsample, **kwargs)

        # A fresh dict per call avoids sharing one mutable default object.
        if slim_args is None:
            slim_args = dict(type='SASE', groups=4, ratio=4)

        # Fallback values are meant to match the defaults of timm's
        # Bottleneck signature, so the block can also be built without the
        # full keyword set timm's ResNet passes -- TODO confirm against the
        # installed timm version.
        cardinality = kwargs.get('cardinality', 1)
        base_width = kwargs.get('base_width', 64)
        reduce_first = kwargs.get('reduce_first', 1)
        dilation = kwargs.get('dilation', 1)
        first_dilation = kwargs.get('first_dilation') or dilation
        aa_layer = kwargs.get('aa_layer')

        width = int(math.floor(planes * (base_width / 64)) * cardinality)
        first_planes = width // reduce_first
        use_aa = aa_layer is not None and (
            stride == 2 or first_dilation != dilation)

        # Replace the parent's 3x3 conv2. When anti-aliasing is active the
        # replacement keeps stride 1, as the convolution it replaces did.
        slim_args_ = dict(slim_args)
        slim_fn = self._resolve_slim(slim_args_.pop('type', 'SASE'))
        self.conv2 = slim_fn(first_planes, width,
                             stride=1 if use_aa else stride,
                             **slim_args_)

    @classmethod
    def _resolve_slim(cls, slim_type):
        """Map a ``slim_args['type']`` value to the class to instantiate."""
        if callable(slim_type):
            return slim_type
        if isinstance(slim_type, str):
            name = cls._SLIM_ALIASES.get(slim_type, slim_type)
            # Plain module-level lookup instead of eval(): the string comes
            # from a config file and must not be executed as code.
            slim_fn = globals().get(name)
            if callable(slim_fn):
                return slim_fn
        raise ValueError(
            f'Unknown slim_args type {slim_type!r}; expected the name of a '
            f'module defined in this file (e.g. "SASE") or a callable.')


@register_model
def sase_resnet50(pretrained=False, **kwargs):
    """Build the ``sase_resnet50`` variant.

    Uses ``SaseBottleneck`` blocks in a [3, 4, 6, 3] layout with a deep stem
    of width 32 and average-pool down-sampling on the shortcuts. Extra
    keyword arguments are forwarded to ``_create_resnet``.
    """
    return _create_resnet(
        'sase_resnet50',
        pretrained,
        block=SaseBottleneck,
        layers=[3, 4, 6, 3],
        stem_width=32,
        stem_type='deep',
        avg_down=True,
        **kwargs,
    )


# Register the backbone only once, so re-importing this module does not
# attempt a duplicate registration.
if 'timm_resnet_det' not in BACKBONES:
    @BACKBONES.register_module()
    class timm_resnet_det(ResNet):
        """``ResNet`` exposed as a multi-scale feature extractor.

        Keyword arguments read from ``kwargs``:

        * ``block_fn`` (str, required): name of the block class, resolved in
          this module's namespace.
        * ``layers``, ``stem_width``, ``stem_type``, ``avg_down`` (required):
          forwarded to ``ResNet.__init__``.
        * ``norm_eval`` (default ``True``): keep every batch-norm layer in
          eval mode while training.
        * ``frozen_stages`` (default ``1``): ``>= 0`` freezes the stem;
          ``k >= 1`` additionally freezes ``layer1`` .. ``layerk``.
        * ``out_indices`` (default ``(0, 1, 2, 3)``): zero-based indices of
          the stages whose outputs are returned by ``forward``.
        * ``pretrained`` (default ``None``): checkpoint path/URL to load.
        """

        def __init__(self, **kwargs):
            model_args = dict(
                # NOTE(security): eval() executes an arbitrary expression;
                # `block_fn` must only ever come from a trusted config file.
                block=eval(kwargs['block_fn']),
                layers=kwargs['layers'],
                stem_width=kwargs['stem_width'],
                stem_type=kwargs['stem_type'],
                avg_down=kwargs['avg_down'])

            super().__init__(**model_args)

            self.norm_eval = kwargs.pop('norm_eval', True)
            self.frozen_stages = kwargs.pop('frozen_stages', 1)
            self.out_indices = kwargs.pop('out_indices', (0, 1, 2, 3))
            pretrained = kwargs.pop('pretrained', None)

            # forward() below never uses the classification head, so drop it.
            del self.num_classes
            del self.global_pool
            del self.fc

            self.load_pretrained(pretrained)

            self._freeze_stages()
            self.feat_dim = self.num_features

        def load_pretrained(self, pretrained=None):
            """Load weights from ``pretrained`` when it is a string.

            Loading is non-strict; any other value (e.g. ``None``) is
            ignored.
            """
            if isinstance(pretrained, str):
                logger = get_root_logger()
                load_checkpoint(
                    self,
                    pretrained,
                    map_location='cpu',
                    strict=False,
                    logger=logger)

        def _freeze_stages(self):
            """Put the frozen stages in eval mode and disable their grads."""
            frozen = []
            if self.frozen_stages >= 0:
                frozen.extend([self.conv1, self.bn1])
            frozen.extend(
                getattr(self, f'layer{i}')
                for i in range(1, self.frozen_stages + 1))

            for module in frozen:
                module.eval()
                for param in module.parameters():
                    param.requires_grad = False

        def train(self, mode=True):
            """Switch train/eval mode while keeping frozen parts frozen.

            Frozen stages are put back into eval mode after every switch, and
            with ``norm_eval`` all batch-norm layers stay in eval mode during
            training.

            Returns:
                ``self``, like ``nn.Module.train``. ``nn.Module.eval``
                delegates to this method, so returning nothing here would
                make ``model.eval()`` evaluate to ``None``.
            """
            super().train(mode)
            self._freeze_stages()
            if mode and self.norm_eval:
                for m in self.modules():
                    # eval() on a norm layer stops running-stat updates;
                    # its affine parameters are unaffected.
                    if isinstance(m, _BatchNorm):
                        m.eval()
            return self

        def forward(self, x):
            """Return a tuple with the outputs of the selected stages."""
            x = self.conv1(x)
            x = self.bn1(x)
            x = self.act1(x)
            x = self.maxpool(x)

            outs = []
            for idx in range(4):
                x = getattr(self, f'layer{idx + 1}')(x)
                if idx in self.out_indices:
                    outs.append(x)

            return tuple(outs)


# model check

def count_parameters(model):
    """Return the number of trainable parameters of ``model``, in millions.

    Only parameters with ``requires_grad`` set are counted.
    """
    trainable = 0
    for param in model.parameters():
        if param.requires_grad:
            trainable += param.numel()
    return trainable / 1e6


if __name__ == '__main__':
    # Smoke test: build each registered variant, run a random batch through
    # it and print the architecture, the output shape and the number of
    # trainable parameters (in millions).
    dummy_batch = torch.randn(2, 3, 224, 224)

    for model_name in ['sase_resnet50']:
        builder = globals()[model_name]
        model = builder(num_classes=1000)
        out = model(dummy_batch)
        print(model)
        print(f'{model_name}:', out.shape, count_parameters(model))
