"""
FBNet model builder
"""

from __future__ import absolute_import, division, print_function, unicode_literals

import copy
import logging
import math
from collections import OrderedDict

import torch
import torch.nn as nn
from torch.nn import BatchNorm2d, SyncBatchNorm
from maskrcnn_benchmark.layers import Conv2d, interpolate
from maskrcnn_benchmark.layers import NaiveSyncBatchNorm2d, FrozenBatchNorm2d
from maskrcnn_benchmark.layers.misc import _NewEmptyTensorOp


logger = logging.getLogger(__name__)


def _py2_round(x):
    return math.floor(x + 0.5) if x >= 0.0 else math.ceil(x - 0.5)


def _get_divisible_by(num, divisible_by, min_val):
    ret = int(num)
    if divisible_by > 0 and num % divisible_by != 0:
        ret = int((_py2_round(num / divisible_by) or min_val) * divisible_by)
    return ret


class Identity(nn.Module):
    def __init__(self, C_in, C_out, stride):
        super(Identity, self).__init__()
        self.conv = (
            ConvBNRelu(
                C_in,
                C_out,
                kernel=1,
                stride=stride,
                pad=0,
                no_bias=1,
                use_relu="relu",
                bn_type="bn",
            )
            if C_in != C_out or stride != 1
            else None
        )

    def forward(self, x):
        if self.conv:
            out = self.conv(x)
        else:
            out = x
        return out


class CascadeConv3x3(nn.Sequential):
    def __init__(self, C_in, C_out, stride):
        assert stride in [1, 2]
        ops = [
            Conv2d(C_in, C_in, 3, stride, 1, bias=False),
            BatchNorm2d(C_in),
            nn.ReLU(inplace=True),
            Conv2d(C_in, C_out, 3, 1, 1, bias=False),
            BatchNorm2d(C_out),
        ]
        super(CascadeConv3x3, self).__init__(*ops)
        self.res_connect = (stride == 1) and (C_in == C_out)

    def forward(self, x):
        y = super(CascadeConv3x3, self).forward(x)
        if self.res_connect:
            y += x
        return y


class Shift(nn.Module):
    def __init__(self, C, kernel_size, stride, padding):
        super(Shift, self).__init__()
        self.C = C
        kernel = torch.zeros((C, 1, kernel_size, kernel_size), dtype=torch.float32)
        ch_idx = 0

        assert stride in [1, 2]
        self.stride = stride
        self.padding = padding
        self.kernel_size = kernel_size
        self.dilation = 1

        hks = kernel_size // 2
        ksq = kernel_size ** 2

        for i in range(kernel_size):
            for j in range(kernel_size):
                if i == hks and j == hks:
                    num_ch = C // ksq + C % ksq
                else:
                    num_ch = C // ksq
                kernel[ch_idx : ch_idx + num_ch, 0, i, j] = 1
                ch_idx += num_ch

        self.register_parameter("bias", None)
        self.kernel = nn.Parameter(kernel, requires_grad=False)

    def forward(self, x):
        if x.numel() > 0:
            return nn.functional.conv2d(
                x,
                self.kernel,
                self.bias,
                (self.stride, self.stride),
                (self.padding, self.padding),
                self.dilation,
                self.C,  # groups
            )

        output_shape = [
            (i + 2 * p - (di * (k - 1) + 1)) // d + 1
            for i, p, di, k, d in zip(
                x.shape[-2:],
                (self.padding, self.dilation),
                (self.dilation, self.dilation),
                (self.kernel_size, self.kernel_size),
                (self.stride, self.stride),
            )
        ]
        output_shape = [x.shape[0], self.C] + output_shape
        return _NewEmptyTensorOp.apply(x, output_shape)


class ShiftBlock5x5(nn.Sequential):
    def __init__(self, C_in, C_out, expansion, stride):
        assert stride in [1, 2]
        self.res_connect = (stride == 1) and (C_in == C_out)

        C_mid = _get_divisible_by(C_in * expansion, 8, 8)

        ops = [
            # pw
            Conv2d(C_in, C_mid, 1, 1, 0, bias=False),
            BatchNorm2d(C_mid),
            nn.ReLU(inplace=True),
            # shift
            Shift(C_mid, 5, stride, 2),
            # pw-linear
            Conv2d(C_mid, C_out, 1, 1, 0, bias=False),
            BatchNorm2d(C_out),
        ]
        super(ShiftBlock5x5, self).__init__(*ops)

    def forward(self, x):
        y = super(ShiftBlock5x5, self).forward(x)
        if self.res_connect:
            y += x
        return y


class ChannelShuffle(nn.Module):
    def __init__(self, groups):
        super(ChannelShuffle, self).__init__()
        self.groups = groups

    def forward(self, x):
        """Channel shuffle: [N,C,H,W] -> [N,g,C/g,H,W] -> [N,C/g,g,H,w] -> [N,C,H,W]"""
        N, C, H, W = x.size()
        g = self.groups
        assert C % g == 0, "Incompatible group size {} for input channel {}".format(
            g, C
        )
        return (
            x.view(N, g, int(C / g), H, W)
            .permute(0, 2, 1, 3, 4)
            .contiguous()
            .view(N, C, H, W)
        )


class ConvBNRelu(nn.Sequential):
    def __init__(
        self,
        input_depth,
        output_depth,
        kernel,
        stride,
        pad,
        no_bias,
        use_relu,
        bn_type,
        group=1,
        *args,
        **kwargs
    ):
        super(ConvBNRelu, self).__init__()

        assert use_relu in ["relu", None]
        if isinstance(bn_type, (list, tuple)):
            assert len(bn_type) == 2
            assert bn_type[0] == "gn"
            gn_group = bn_type[1]
            bn_type = bn_type[0]
        assert bn_type in ["bn", "nsbn", "sbn", "af", "gn", None]
        assert stride in [1, 2, 4]

        op = Conv2d(
            input_depth,
            output_depth,
            kernel_size=kernel,
            stride=stride,
            padding=pad,
            bias=not no_bias,
            groups=group,
            *args,
            **kwargs
        )
        nn.init.kaiming_normal_(op.weight, mode="fan_out", nonlinearity="relu")
        if op.bias is not None:
            nn.init.constant_(op.bias, 0.0)
        self.add_module("conv", op)

        if bn_type == "bn":
            bn_op = BatchNorm2d(output_depth)
        elif bn_type == "sbn":
            bn_op = SyncBatchNorm(output_depth)
        elif bn_type == "nsbn":
            bn_op = NaiveSyncBatchNorm2d(output_depth)
        elif bn_type == "gn":
            bn_op = nn.GroupNorm(num_groups=gn_group, num_channels=output_depth)
        elif bn_type == "af":
            bn_op = FrozenBatchNorm2d(output_depth)
        if bn_type is not None:
            self.add_module("bn", bn_op)

        if use_relu == "relu":
            self.add_module("relu", nn.ReLU(inplace=True))


class SEModule(nn.Module):
    reduction = 4

    def __init__(self, C):
        super(SEModule, self).__init__()
        mid = max(C // self.reduction, 8)
        conv1 = Conv2d(C, mid, 1, 1, 0)
        conv2 = Conv2d(mid, C, 1, 1, 0)

        self.op = nn.Sequential(
            nn.AdaptiveAvgPool2d(1), conv1, nn.ReLU(inplace=True), conv2, nn.Sigmoid()
        )

    def forward(self, x):
        return x * self.op(x)


class Upsample(nn.Module):
    def __init__(self, scale_factor, mode, align_corners=None):
        super(Upsample, self).__init__()
        self.scale = scale_factor
        self.mode = mode
        self.align_corners = align_corners

    def forward(self, x):
        return interpolate(
            x, scale_factor=self.scale, mode=self.mode,
            align_corners=self.align_corners
        )


def _get_upsample_op(stride):
    assert (
        stride in [1, 2, 4]
        or stride in [-1, -2, -4]
        or (isinstance(stride, tuple) and all(x in [-1, -2, -4] for x in stride))
    )

    scales = stride
    ret = None
    if isinstance(stride, tuple) or stride < 0:
        scales = [-x for x in stride] if isinstance(stride, tuple) else -stride
        stride = 1
        ret = Upsample(scale_factor=scales, mode="nearest", align_corners=None)

    return ret, stride


class IRFBlock(nn.Module):
    def __init__(
        self,
        input_depth,
        output_depth,
        expansion,
        stride,
        bn_type="bn",
        kernel=3,
        width_divisor=1,
        shuffle_type=None,
        pw_group=1,
        se=False,
        cdw=False,
        dw_skip_bn=False,
        dw_skip_relu=False,
    ):
        super(IRFBlock, self).__init__()

        assert kernel in [1, 3, 5, 7], kernel

        self.use_res_connect = stride == 1 and input_depth == output_depth
        self.output_depth = output_depth

        mid_depth = int(input_depth * expansion)
        mid_depth = _get_divisible_by(mid_depth, width_divisor, width_divisor)

        # pw
        self.pw = ConvBNRelu(
            input_depth,
            mid_depth,
            kernel=1,
            stride=1,
            pad=0,
            no_bias=1,
            use_relu="relu",
            bn_type=bn_type,
            group=pw_group,
        )

        # negative stride to do upsampling
        self.upscale, stride = _get_upsample_op(stride)

        # dw
        if kernel == 1:
            self.dw = nn.Sequential()
        elif cdw:
            dw1 = ConvBNRelu(
                mid_depth,
                mid_depth,
                kernel=kernel,
                stride=stride,
                pad=(kernel // 2),
                group=mid_depth,
                no_bias=1,
                use_relu="relu",
                bn_type=bn_type,
            )
            dw2 = ConvBNRelu(
                mid_depth,
                mid_depth,
                kernel=kernel,
                stride=1,
                pad=(kernel // 2),
                group=mid_depth,
                no_bias=1,
                use_relu="relu" if not dw_skip_relu else None,
                bn_type=bn_type if not dw_skip_bn else None,
            )
            self.dw = nn.Sequential(OrderedDict([("dw1", dw1), ("dw2", dw2)]))
        else:
            self.dw = ConvBNRelu(
                mid_depth,
                mid_depth,
                kernel=kernel,
                stride=stride,
                pad=(kernel // 2),
                group=mid_depth,
                no_bias=1,
                use_relu="relu" if not dw_skip_relu else None,
                bn_type=bn_type if not dw_skip_bn else None,
            )

        # pw-linear
        self.pwl = ConvBNRelu(
            mid_depth,
            output_depth,
            kernel=1,
            stride=1,
            pad=0,
            no_bias=1,
            use_relu=None,
            bn_type=bn_type,
            group=pw_group,
        )

        self.shuffle_type = shuffle_type
        if shuffle_type is not None:
            self.shuffle = ChannelShuffle(pw_group)

        self.se4 = SEModule(output_depth) if se else nn.Sequential()

        self.output_depth = output_depth

    def forward(self, x):
        y = self.pw(x)
        if self.shuffle_type == "mid":
            y = self.shuffle(y)
        if self.upscale is not None:
            y = self.upscale(y)
        y = self.dw(y)
        y = self.pwl(y)
        if self.use_res_connect:
            y += x
        y = self.se4(y)
        return y



skip = lambda C_in, C_out, stride, **kwargs: Identity(
    C_in, C_out, stride
)
basic_block = lambda C_in, C_out, stride, **kwargs: CascadeConv3x3(
    C_in, C_out, stride
)
# layer search 2
ir_k3_e1 = lambda C_in, C_out, stride, **kwargs: IRFBlock(
    C_in, C_out, 1, stride, kernel=3, **kwargs
)
ir_k3_e3 = lambda C_in, C_out, stride, **kwargs: IRFBlock(
    C_in, C_out, 3, stride, kernel=3, **kwargs
)
ir_k3_e6 = lambda C_in, C_out, stride, **kwargs: IRFBlock(
    C_in, C_out, 6, stride, kernel=3, **kwargs
)
ir_k3_s4 = lambda C_in, C_out, stride, **kwargs: IRFBlock(
    C_in, C_out, 4, stride, kernel=3, shuffle_type="mid", pw_group=4, **kwargs
)
ir_k5_e1 = lambda C_in, C_out, stride, **kwargs: IRFBlock(
    C_in, C_out, 1, stride, kernel=5, **kwargs
)
ir_k5_e3 = lambda C_in, C_out, stride, **kwargs: IRFBlock(
    C_in, C_out, 3, stride, kernel=5, **kwargs
)
ir_k5_e6 = lambda C_in, C_out, stride, **kwargs: IRFBlock(
    C_in, C_out, 6, stride, kernel=5, **kwargs
)
ir_k5_s4 = lambda C_in, C_out, stride, **kwargs: IRFBlock(
    C_in, C_out, 4, stride, kernel=5, shuffle_type="mid", pw_group=4, **kwargs
)
# layer search se
ir_k3_e1_se = lambda C_in, C_out, stride, **kwargs: IRFBlock(
    C_in, C_out, 1, stride, kernel=3, se=True, **kwargs
)
ir_k3_e3_se = lambda C_in, C_out, stride, **kwargs: IRFBlock(
    C_in, C_out, 3, stride, kernel=3, se=True, **kwargs
)
ir_k3_e6_se = lambda C_in, C_out, stride, **kwargs: IRFBlock(
    C_in, C_out, 6, stride, kernel=3, se=True, **kwargs
)
ir_k3_s4_se = lambda C_in, C_out, stride, **kwargs: IRFBlock(
    C_in,
    C_out,
    4,
    stride,
    kernel=3,
    shuffle_type=mid,
    pw_group=4,
    se=True,
    **kwargs
)
ir_k5_e1_se = lambda C_in, C_out, stride, **kwargs: IRFBlock(
    C_in, C_out, 1, stride, kernel=5, se=True, **kwargs
)
ir_k5_e3_se = lambda C_in, C_out, stride, **kwargs: IRFBlock(
    C_in, C_out, 3, stride, kernel=5, se=True, **kwargs
)
ir_k5_e6_se = lambda C_in, C_out, stride, **kwargs: IRFBlock(
    C_in, C_out, 6, stride, kernel=5, se=True, **kwargs
)
ir_k5_s4_se = lambda C_in, C_out, stride, **kwargs: IRFBlock(
    C_in,
    C_out,
    4,
    stride,
    kernel=5,
    shuffle_type="mid",
    pw_group=4,
    se=True,
    **kwargs
)
# layer search 3 (in addition to layer search 2)
ir_k3_s2 = lambda C_in, C_out, stride, **kwargs: IRFBlock(
    C_in, C_out, 1, stride, kernel=3, shuffle_type="mid", pw_group=2, **kwargs
)
ir_k5_s2 = lambda C_in, C_out, stride, **kwargs: IRFBlock(
    C_in, C_out, 1, stride, kernel=5, shuffle_type="mid", pw_group=2, **kwargs
)
ir_k3_s2_se = lambda C_in, C_out, stride, **kwargs: IRFBlock(
    C_in,
    C_out,
    1,
    stride,
    kernel=3,
    shuffle_type="mid",
    pw_group=2,
    se=True,
    **kwargs
)
ir_k5_s2_se = lambda C_in, C_out, stride, **kwargs: IRFBlock(
    C_in,
    C_out,
    1,
    stride,
    kernel=5,
    shuffle_type="mid",
    pw_group=2,
    se=True,
    **kwargs
)
# layer search 4 (in addition to layer search 3)
ir_k33_e1 = lambda C_in, C_out, stride, **kwargs: IRFBlock(
    C_in, C_out, 1, stride, kernel=3, cdw=True, **kwargs
)
ir_k33_e3 = lambda C_in, C_out, stride, **kwargs: IRFBlock(
    C_in, C_out, 3, stride, kernel=3, cdw=True, **kwargs
)
ir_k33_e6 = lambda C_in, C_out, stride, **kwargs: IRFBlock(
    C_in, C_out, 6, stride, kernel=3, cdw=True, **kwargs
)
# layer search 5 (in addition to layer search 4)
ir_k7_e1 = lambda C_in, C_out, stride, **kwargs: IRFBlock(
    C_in, C_out, 1, stride, kernel=7, **kwargs
)
ir_k7_e3 = lambda C_in, C_out, stride, **kwargs: IRFBlock(
    C_in, C_out, 3, stride, kernel=7, **kwargs
)
ir_k7_e6 = lambda C_in, C_out, stride, **kwargs: IRFBlock(
    C_in, C_out, 6, stride, kernel=7, **kwargs
)
ir_k7_sep_e1 = lambda C_in, C_out, stride, **kwargs: IRFBlock(
    C_in, C_out, 1, stride, kernel=7, cdw=True, **kwargs
)
ir_k7_sep_e3 = lambda C_in, C_out, stride, **kwargs: IRFBlock(
    C_in, C_out, 3, stride, kernel=7, cdw=True, **kwargs
)
ir_k7_sep_e6 = lambda C_in, C_out, stride, **kwargs: IRFBlock(
    C_in, C_out, 6, stride, kernel=7, cdw=True, **kwargs
)