import math
import torch
import torch.nn as nn
import torch.nn.functional as F


class Conv2dSame(nn.Conv2d):
    """2D convolution emulating TensorFlow's ``padding='SAME'``.

    The ``padding`` argument is accepted only for signature compatibility with
    ``nn.Conv2d`` and is ignored: the layer is built with zero static padding and
    the required (possibly asymmetric) zero padding is computed from the input
    size on every forward pass, so the output spatial size is
    ``ceil(input_size / stride)``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
        # Static padding is forced to 0; the real padding happens in forward().
        super(Conv2dSame, self).__init__(in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias)

    @staticmethod
    def _same_pad_amount(size, kernel, stride, dilation):
        """Total zero padding needed along one axis for a ceil(size / stride) output."""
        out_size = math.ceil(size / stride)
        needed = (out_size - 1) * stride + (kernel - 1) * dilation + 1 - size
        return needed if needed > 0 else 0

    def forward(self, x):
        in_h, in_w = x.shape[-2:]
        k_h, k_w = self.weight.shape[-2:]
        total_h = self._same_pad_amount(in_h, k_h, self.stride[0], self.dilation[0])
        total_w = self._same_pad_amount(in_w, k_w, self.stride[1], self.dilation[1])
        if total_h or total_w:
            # Like TF, an odd leftover pixel goes on the bottom / right side.
            top = total_h // 2
            left = total_w // 2
            x = F.pad(x, [left, total_w - left, top, total_h - top])
        return F.conv2d(x, self.weight, self.bias, self.stride, 0, self.dilation, self.groups)


def stem(inp, oup):
    """Build the network stem: stride-2 3x3 'SAME' conv -> BatchNorm -> ReLU.

    Args:
        inp: number of input channels.
        oup: number of output channels.

    Returns:
        An ``nn.Sequential`` of the three layers above.
    """
    layers = [
        Conv2dSame(inp, oup, 3, 2, bias=False),
        nn.BatchNorm2d(oup),
        nn.ReLU(inplace=True),
    ]
    return nn.Sequential(*layers)


def separable_conv(inp, oup):
    """Depthwise-separable convolution block.

    Depthwise 3x3 conv (stride 1, padding 1) -> BN -> ReLU, followed by a
    pointwise 1x1 conv -> BN.  No activation after the final BN.

    Args:
        inp: number of input channels (also the depthwise width).
        oup: number of output channels of the pointwise conv.

    Returns:
        An ``nn.Sequential`` with the five layers above, in order.
    """
    depthwise = [
        nn.Conv2d(inp, inp, 3, 1, padding=1, groups=inp, bias=False),
        nn.BatchNorm2d(inp),
        nn.ReLU(inplace=True),
    ]
    pointwise = [
        nn.Conv2d(inp, oup, 1, stride=1, padding=0, bias=False),
        nn.BatchNorm2d(oup),
    ]
    return nn.Sequential(*depthwise, *pointwise)


def conv_before_pooling(inp, oup):
    """Pointwise 1x1 conv -> BatchNorm -> ReLU, applied before global pooling.

    Args:
        inp: number of input channels.
        oup: number of output channels.

    Returns:
        An ``nn.Sequential`` of the three layers above.
    """
    pointwise = nn.Conv2d(inp, oup, 1, stride=1, padding=0, bias=False)
    norm = nn.BatchNorm2d(oup)
    return nn.Sequential(pointwise, norm, nn.ReLU(inplace=True))


def swish(x, inplace: bool = False):
    """Swish activation ``x * sigmoid(x)``.

    Args:
        x: input tensor.
        inplace: when True, ``x`` is overwritten with the result and returned.

    Returns:
        The activated tensor (``x`` itself when ``inplace`` is True).
    """
    gate = x.sigmoid()
    if inplace:
        return x.mul_(gate)
    return x.mul(gate)


class Swish(nn.Module):
    """Module wrapper around ``swish`` (``x * sigmoid(x)``).

    Args:
        inplace: when True, the input tensor is overwritten with the result.
    """

    def __init__(self, inplace: bool = False):
        super().__init__()
        self.inplace = inplace

    def forward(self, x):
        return swish(x, inplace=self.inplace)


class HSigmoid(nn.Module):
    """Hard sigmoid: ``relu6(x + 3) / 6``, a piecewise-linear sigmoid approximation.

    Args:
        inplace: forwarded to ``F.relu6``; it acts on the temporary ``x + 3``,
            so the caller's input tensor is never modified.
    """

    def __init__(self, inplace=True):
        super().__init__()
        self.inplace = inplace

    def forward(self, x):
        shifted = x + 3
        return F.relu6(shifted, inplace=self.inplace) / 6


def drop_connect(inputs, training=False, drop_connect_rate=0.):
    """Apply drop connect: randomly zero whole samples of a batch.

    During training each sample (index along dim 0) is kept with probability
    ``1 - drop_connect_rate``; kept samples are scaled by ``1 / keep_prob`` so
    the expected value is unchanged.  Outside training ``inputs`` is returned
    as-is.

    Args:
        inputs: tensor whose first dimension is the batch; any rank >= 1.
        training: apply the random mask only when True.
        drop_connect_rate: probability of dropping a sample, expected in [0, 1].

    Returns:
        A tensor with the same shape as ``inputs``.
    """
    if not training:
        return inputs

    if drop_connect_rate >= 1.:
        # Every sample is dropped.  Avoid inputs / 0 -> inf/NaN, and multiply
        # rather than allocate zeros so the autograd graph stays connected.
        return inputs * 0.

    keep_prob = 1 - drop_connect_rate
    # One mask value per sample, broadcast over all remaining dimensions
    # (a fixed (N, 1, 1, 1) mask would mis-broadcast for non-4D inputs).
    mask_shape = (inputs.shape[0],) + (1,) * (inputs.dim() - 1)
    random_tensor = keep_prob + torch.rand(mask_shape, dtype=inputs.dtype, device=inputs.device)
    random_tensor.floor_()  # binarize: 1 with probability keep_prob, else 0
    return inputs.div(keep_prob) * random_tensor


class SqueezeExcite(nn.Module):
    """Squeeze-and-Excitation channel gating.

    The input is global-average-pooled to one value per channel, then passed
    through a 1x1-conv bottleneck (``in_channel -> in_channel // reduction ->
    in_channel``) with ``squeeze_act`` in between.  ``excite_act`` produces a
    per-channel gate, and the input is multiplied by that gate.

    Args:
        in_channel: number of input (and output) channels.
        reduction: bottleneck reduction factor; the bottleneck keeps at least
            one channel.
        squeeze_act: activation after the squeeze conv.  When None, a fresh
            ``nn.ReLU(inplace=True)`` is created for this instance.
        excite_act: gating activation.  When None, a fresh
            ``HSigmoid(inplace=True)`` is created for this instance.
    """

    def __init__(self, in_channel,
                 reduction=4,
                 squeeze_act=None,
                 excite_act=None):
        super(SqueezeExcite, self).__init__()
        # Build defaults per instance: module objects placed directly in the
        # signature would be created once and shared by every SqueezeExcite.
        if squeeze_act is None:
            squeeze_act = nn.ReLU(inplace=True)
        if excite_act is None:
            excite_act = HSigmoid(inplace=True)
        # Clamp so a small in_channel / large reduction never yields 0 channels.
        reduced_channel = max(1, in_channel // reduction)
        self.global_pooling = nn.AdaptiveAvgPool2d(1)
        self.squeeze_conv = nn.Conv2d(in_channels=in_channel,
                                      out_channels=reduced_channel,
                                      kernel_size=1,
                                      bias=True)
        self.squeeze_act = squeeze_act
        self.excite_conv = nn.Conv2d(in_channels=reduced_channel,
                                     out_channels=in_channel,
                                     kernel_size=1,
                                     bias=True)
        self.excite_act = excite_act

    def forward(self, inputs):
        feature_pooling = self.global_pooling(inputs)
        feature_squeeze_conv = self.squeeze_conv(feature_pooling)
        feature_squeeze_act = self.squeeze_act(feature_squeeze_conv)
        feature_excite_conv = self.excite_conv(feature_squeeze_act)
        feature_excite_act = self.excite_act(feature_excite_conv)
        # Broadcast the (N, C, 1, 1) gate over the spatial dimensions.
        return inputs * feature_excite_act


class MultiPathMB_A(nn.Module):
    """Inverted-residual block whose depthwise stage is split into four paths.

    Pipeline: grouped 1x1 expansion -> BN -> act; the expanded tensor is split
    channel-wise into four equal chunks, each processed by its own depthwise
    'SAME' conv (kernel size from ``kernel_size_list``) + BN; the chunks are
    concatenated, activated, optionally gated by squeeze-excite, then projected
    by a grouped 1x1 conv + BN.  A residual connection (with optional drop
    connect) is added when ``stride == 1`` and ``inp == oup``.

    Args:
        inp: input channels.
        oup: output channels.
        kernel_size_list: four depthwise kernel sizes, one per path.
        pw_group_list: ``[expansion_groups, projection_groups]`` for the 1x1 convs.
        act: ``"re"`` for ReLU or ``"sw"`` for Swish.
        stride: 1 or 2, applied by every depthwise path.
        expand_ratio: ``hidden_dim = round(inp * expand_ratio)``; must be divisible by 4.
        se_reduction: squeeze-excite reduction; 0 disables SE.
        drop_connect_rate: drop-connect probability on the residual branch.

    Raises:
        ValueError: if ``act`` is unknown or ``hidden_dim`` is not divisible by 4.
    """

    def __init__(self, inp, oup, kernel_size_list, pw_group_list, act, stride, expand_ratio, se_reduction,
                 drop_connect_rate):
        super(MultiPathMB_A, self).__init__()
        assert stride in [1, 2]
        assert len(kernel_size_list) == 4
        assert len(pw_group_list) == 2
        self.stride = stride
        self.se_reduction = se_reduction
        hidden_dim = round(inp * expand_ratio)
        if hidden_dim % 4 != 0:
            # forward() splits into chunks of hidden_dim // 4; any remainder would
            # create extra chunks that no path consumes.
            raise ValueError(
                "MultiPathMB_A requires round(inp * expand_ratio) divisible by 4, got %d" % hidden_dim)
        multi_hidden_dim = hidden_dim // 4
        self.multi_hidden_dim = multi_hidden_dim
        self.use_res_connect = self.stride == 1 and inp == oup
        self.drop_connect_rate = drop_connect_rate
        if act == "re":
            self.act = nn.ReLU(inplace=True)
        elif act == "sw":
            self.act = Swish()
        else:
            raise ValueError('act must be "re" or "sw", got %r' % (act,))
        # Module creation order below is kept fixed: it determines parameter
        # initialization order and state_dict layout.
        # pw-up
        self.conv1 = nn.Conv2d(inp, hidden_dim, 1, stride=1, padding=0, groups=pw_group_list[0], bias=False)
        self.bn1 = nn.BatchNorm2d(hidden_dim)
        # multi-path depthwise convs, one kernel size per channel chunk
        self.conv2_0 = Conv2dSame(multi_hidden_dim, multi_hidden_dim, kernel_size_list[0], stride,
                                  groups=multi_hidden_dim, bias=False)
        self.conv2_1 = Conv2dSame(multi_hidden_dim, multi_hidden_dim, kernel_size_list[1], stride,
                                  groups=multi_hidden_dim, bias=False)
        self.conv2_2 = Conv2dSame(multi_hidden_dim, multi_hidden_dim, kernel_size_list[2], stride,
                                  groups=multi_hidden_dim, bias=False)
        self.conv2_3 = Conv2dSame(multi_hidden_dim, multi_hidden_dim, kernel_size_list[3], stride,
                                  groups=multi_hidden_dim, bias=False)
        self.bn2_0 = nn.BatchNorm2d(multi_hidden_dim)
        self.bn2_1 = nn.BatchNorm2d(multi_hidden_dim)
        self.bn2_2 = nn.BatchNorm2d(multi_hidden_dim)
        self.bn2_3 = nn.BatchNorm2d(multi_hidden_dim)
        # se (reuses the block activation as its squeeze activation)
        if self.se_reduction > 0:
            self.se = SqueezeExcite(hidden_dim, reduction=se_reduction, squeeze_act=self.act)
        # pw-down
        self.conv3 = nn.Conv2d(hidden_dim, oup, 1, stride=1, padding=0, groups=pw_group_list[1], bias=False)
        self.bn3 = nn.BatchNorm2d(oup)

    def forward(self, x):
        inputs = x
        # pw-up
        x = self.act(self.bn1(self.conv1(x)))
        # multi-path: four equal channel chunks, each with its own depthwise kernel
        chunks = torch.split(x, split_size_or_sections=self.multi_hidden_dim, dim=1)
        x = torch.cat([
            self.bn2_0(self.conv2_0(chunks[0])),
            self.bn2_1(self.conv2_1(chunks[1])),
            self.bn2_2(self.conv2_2(chunks[2])),
            self.bn2_3(self.conv2_3(chunks[3])),
        ], dim=1)
        x = self.act(x)
        if self.se_reduction > 0:
            x = self.se(x)
        # pw-down
        x = self.bn3(self.conv3(x))
        if not self.use_res_connect:
            return x
        if self.drop_connect_rate > 0.:
            x = drop_connect(x, self.training, self.drop_connect_rate)
        return inputs + x


# MixPath_A
class MixPath_A(nn.Module):
    """MixPath-A image classifier.

    Stem -> separable conv -> 18 ``MultiPathMB_A`` blocks described by
    ``block_specs`` -> 1x1 conv to ``last_channel`` (1536) channels -> spatial
    mean -> linear classifier.

    Args:
        n_class: number of output logits.
        input_size: input resolution; only asserted to be a multiple of 32.
        drop_connect_rate: drop-connect probability handed to every block.
    """

    def __init__(self, n_class=1000, input_size=224, drop_connect_rate=0.2):
        super(MixPath_A, self).__init__()
        assert input_size % 32 == 0
        # Each row: expansion, out_channel, kernel_size_list, pw_group_list, act, stride, se reduction
        block_specs = [
            [6, 32, [5, 5, 7, 5], [2, 2], "re", 2, 0],
            [3, 32, [5, 5, 5, 3], [2, 2], "re", 1, 0],

            [6, 40, [5, 5, 7, 5], [1, 1], "sw", 2, 12],
            [6, 40, [3, 7, 3, 3], [2, 2], "sw", 1, 12],
            [6, 40, [5, 5, 3, 5], [2, 2], "sw", 1, 12],
            [6, 40, [5, 7, 5, 3], [2, 2], "sw", 1, 12],

            [6, 80, [7, 5, 7, 5], [1, 1], "sw", 2, 24],
            [6, 80, [3, 7, 3, 5], [2, 2], "sw", 1, 24],
            [6, 80, [5, 7, 7, 7], [2, 2], "sw", 1, 24],
            [6, 80, [5, 7, 3, 3], [2, 2], "sw", 1, 24],
            [6, 120, [7, 7, 5, 5], [1, 1], "sw", 1, 12],
            [3, 120, [9, 5, 3, 9], [2, 2], "sw", 1, 6],
            [3, 120, [7, 7, 7, 5], [2, 2], "sw", 1, 6],
            [3, 120, [5, 7, 7, 7], [2, 2], "sw", 1, 6],

            [6, 200, [7, 7, 9, 5], [1, 1], "sw", 2, 12],
            [6, 200, [7, 5, 9, 7], [1, 2], "sw", 1, 12],
            [6, 200, [9, 7, 5, 9], [1, 2], "sw", 1, 12],
            [6, 200, [3, 5, 7, 3], [1, 2], "sw", 1, 12],
        ]
        channels = 24
        self.last_channel = 1536
        # Submodules are registered in a fixed order (stem, separable_conv,
        # mb_module, conv_before_pooling, classifier); _initialize_weights
        # visits them in that order.
        self.stem = stem(3, channels)
        self.separable_conv = separable_conv(channels, channels)

        blocks = []
        for expand, out_channels, kernels, pw_groups, act, stride, reduction in block_specs:
            blocks.append(MultiPathMB_A(channels, out_channels,
                                        kernel_size_list=kernels,
                                        pw_group_list=pw_groups,
                                        act=act,
                                        stride=stride,
                                        expand_ratio=expand,
                                        se_reduction=reduction,
                                        drop_connect_rate=drop_connect_rate))
            channels = out_channels
        self.mb_module = nn.Sequential(*blocks)

        self.conv_before_pooling = conv_before_pooling(channels, self.last_channel)
        self.classifier = nn.Sequential(nn.Linear(self.last_channel, n_class))
        self._initialize_weights()

    def forward(self, x):
        features = self.separable_conv(self.stem(x))
        features = self.conv_before_pooling(self.mb_module(features))
        pooled = features.mean(3).mean(2)  # average over width, then height
        return self.classifier(pooled)

    def _initialize_weights(self):
        """Re-initialize every conv, BatchNorm and linear layer in place.

        Conv: weight ~ N(0, sqrt(2 / (kh * kw * out_channels))), bias zeroed.
        BatchNorm: weight 1, bias 0.
        Linear: weight ~ U(-1/sqrt(out_features), 1/sqrt(out_features)), bias 0.
        """
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                fan_out = module.kernel_size[0] * module.kernel_size[1] * module.out_channels
                module.weight.data.normal_(0, math.sqrt(2. / fan_out))
                if module.bias is not None:
                    module.bias.data.zero_()
                continue
            if isinstance(module, nn.BatchNorm2d):
                module.weight.data.fill_(1)
                module.bias.data.zero_()
                continue
            if isinstance(module, nn.Linear):
                bound = 1.0 / math.sqrt(module.weight.size(0))  # fan-out
                module.weight.data.uniform_(-bound, bound)
                module.bias.data.zero_()


class MultiPathMB_B(nn.Module):
    """Inverted-residual block with one or two parallel depthwise paths.

    Pipeline: grouped 1x1 expansion -> BN -> act; then either a single
    depthwise 'SAME' conv + BN, or two depthwise convs (different kernel
    sizes) over the full hidden tensor whose BN outputs are summed; then act,
    optional squeeze-excite, and a grouped 1x1 projection + BN.  A residual
    connection (with optional drop connect) is added when ``stride == 1`` and
    ``inp == oup``.

    Args:
        inp: input channels.
        oup: output channels.
        kernel_size_list: one or two depthwise kernel sizes.
        pw_group_list: ``[expansion_groups, projection_groups]`` for the 1x1 convs.
        act: ``"re"`` for ReLU or ``"sw"`` for in-place Swish.
        stride: 1 or 2, applied by the depthwise conv(s).
        expand_ratio: ``hidden_dim = round(inp * expand_ratio)``.
        se_reduction: squeeze-excite reduction; 0 disables SE.
        drop_connect_rate: drop-connect probability on the residual branch.

    Raises:
        ValueError: if ``kernel_size_list`` does not have 1 or 2 entries, or
            ``act`` is unknown.
    """

    def __init__(self, inp, oup, kernel_size_list, pw_group_list, act, stride, expand_ratio, se_reduction,
                 drop_connect_rate):
        super(MultiPathMB_B, self).__init__()
        assert stride in [1, 2]
        assert len(pw_group_list) == 2
        if len(kernel_size_list) not in (1, 2):
            raise ValueError(
                "MultiPathMB_B supports 1 or 2 depthwise kernel sizes, got %d" % len(kernel_size_list))
        self.stride = stride
        self.se_reduction = se_reduction
        hidden_dim = round(inp * expand_ratio)
        self.use_res_connect = self.stride == 1 and inp == oup
        self.drop_connect_rate = drop_connect_rate
        if act == "re":
            self.act = nn.ReLU(inplace=True)
        elif act == "sw":
            self.act = Swish(inplace=True)
        else:
            raise ValueError('act must be "re" or "sw", got %r' % (act,))
        # Module creation order below is kept fixed: it determines parameter
        # initialization order and state_dict layout.
        # pw-up
        self.conv1 = nn.Conv2d(inp, hidden_dim, 1, stride=1, padding=0, groups=pw_group_list[0], bias=False)
        self.bn1 = nn.BatchNorm2d(hidden_dim)
        # depthwise stage: one path, or two parallel paths whose outputs are summed
        self.multi_path = len(kernel_size_list) == 2
        if self.multi_path:
            self.conv2_0 = Conv2dSame(hidden_dim, hidden_dim, kernel_size_list[0], stride,
                                      groups=hidden_dim, bias=False)
            self.conv2_1 = Conv2dSame(hidden_dim, hidden_dim, kernel_size_list[1], stride,
                                      groups=hidden_dim, bias=False)
            self.bn2_0 = nn.BatchNorm2d(hidden_dim)
            self.bn2_1 = nn.BatchNorm2d(hidden_dim)
        else:
            self.conv2 = Conv2dSame(hidden_dim, hidden_dim, kernel_size_list[0], stride,
                                    groups=hidden_dim, bias=False)
            self.bn2 = nn.BatchNorm2d(hidden_dim)
        # se (reuses the block activation as its squeeze activation)
        if self.se_reduction > 0:
            self.se = SqueezeExcite(hidden_dim, reduction=se_reduction, squeeze_act=self.act)
        # pw-down
        self.conv3 = nn.Conv2d(hidden_dim, oup, 1, stride=1, padding=0, groups=pw_group_list[1], bias=False)
        self.bn3 = nn.BatchNorm2d(oup)

    def forward(self, x):
        inputs = x
        # pw-up
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.act(x)
        # depthwise stage
        if self.multi_path:
            # Both paths see the full hidden tensor; activation is applied after the sum.
            x_0 = self.bn2_0(self.conv2_0(x))
            x_1 = self.bn2_1(self.conv2_1(x))
            x = x_0 + x_1
        else:
            x = self.bn2(self.conv2(x))
        x = self.act(x)
        if self.se_reduction > 0:
            x = self.se(x)
        # pw-down
        x = self.conv3(x)
        x = self.bn3(x)
        if not self.use_res_connect:
            return x
        if self.drop_connect_rate > 0.:
            x = drop_connect(x, self.training, self.drop_connect_rate)
        return inputs + x


# MixPath_B
class MixPath_B(nn.Module):
    """MixPath-B image classifier.

    Stem -> separable conv -> 18 ``MultiPathMB_B`` blocks described by
    ``block_specs`` -> 1x1 conv to ``last_channel`` (1536) channels -> spatial
    mean -> linear classifier.

    Args:
        n_class: number of output logits.
        input_size: input resolution; only asserted to be a multiple of 32.
        drop_connect_rate: drop-connect probability handed to every block.
    """

    def __init__(self, n_class=1000, input_size=224, drop_connect_rate=0.2):
        super(MixPath_B, self).__init__()
        assert input_size % 32 == 0
        # Each row: expansion, out_channel, kernel_size_list, pw_group_list, act, stride, se reduction
        block_specs = [
            [6, 32, [5], [2, 1], "sw", 2, 12],
            [3, 32, [5, 3], [2, 1], "sw", 1, 6],
            [6, 40, [9], [2, 1], "sw", 2, 12],
            [6, 40, [3, 5], [2, 1], "sw", 1, 12],
            [6, 40, [5], [2, 1], "sw", 1, 12],
            [6, 40, [5], [2, 1], "sw", 1, 12],
            [6, 80, [5], [2, 1], "sw", 2, 12],
            [6, 80, [5], [2, 1], "sw", 1, 12],
            [6, 80, [7], [2, 1], "sw", 1, 12],
            [6, 80, [7], [2, 1], "sw", 1, 12],
            [6, 120, [3], [2, 1], "sw", 1, 12],
            [3, 120, [5], [2, 1], "sw", 1, 6],
            [3, 120, [5], [2, 1], "sw", 1, 6],
            [3, 120, [7], [2, 1], "sw", 1, 6],
            [6, 200, [5], [2, 1], "sw", 2, 12],
            [6, 200, [3, 5], [2, 1], "sw", 1, 12],
            [6, 200, [3, 7], [2, 1], "sw", 1, 12],
            [6, 200, [3], [2, 1], "sw", 1, 12],
        ]
        channels = 24
        self.last_channel = 1536
        # Submodules are registered in a fixed order (stem, separable_conv,
        # mb_module, conv_before_pooling, classifier); _initialize_weights
        # visits them in that order.
        self.stem = stem(3, channels)
        self.separable_conv = separable_conv(channels, channels)

        blocks = []
        for expand, out_channels, kernels, pw_groups, act, stride, reduction in block_specs:
            blocks.append(MultiPathMB_B(channels, out_channels,
                                        kernel_size_list=kernels,
                                        pw_group_list=pw_groups,
                                        act=act,
                                        stride=stride,
                                        expand_ratio=expand,
                                        se_reduction=reduction,
                                        drop_connect_rate=drop_connect_rate))
            channels = out_channels
        self.mb_module = nn.Sequential(*blocks)

        self.conv_before_pooling = conv_before_pooling(channels, self.last_channel)
        self.classifier = nn.Sequential(nn.Linear(self.last_channel, n_class))
        self._initialize_weights()

    def forward(self, x):
        features = self.separable_conv(self.stem(x))
        features = self.conv_before_pooling(self.mb_module(features))
        pooled = features.mean(3).mean(2)  # average over width, then height
        return self.classifier(pooled)

    def _initialize_weights(self):
        """Re-initialize every conv, BatchNorm and linear layer in place.

        Conv: weight ~ N(0, sqrt(2 / (kh * kw * out_channels))), bias zeroed.
        BatchNorm: weight 1, bias 0.
        Linear: weight ~ U(-1/sqrt(out_features), 1/sqrt(out_features)), bias 0.
        """
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                fan_out = module.kernel_size[0] * module.kernel_size[1] * module.out_channels
                module.weight.data.normal_(0, math.sqrt(2. / fan_out))
                if module.bias is not None:
                    module.bias.data.zero_()
                continue
            if isinstance(module, nn.BatchNorm2d):
                module.weight.data.fill_(1)
                module.bias.data.zero_()
                continue
            if isinstance(module, nn.Linear):
                bound = 1.0 / math.sqrt(module.weight.size(0))  # fan-out
                module.weight.data.uniform_(-bound, bound)
                module.bias.data.zero_()
