'''MobileNetV2 in PyTorch.

See the paper "Inverted Residuals and Linear Bottlenecks:
Mobile Networks for Classification, Detection and Segmentation" for more details.
This variant (MobileNetV2_Multi) lets stages 3-6 each run in an alternate "skip" mode.
'''

import torch
import torch.nn as nn
import torch.nn.functional as F

# Names exported by ``from <this module> import *``: only the full network.
# The building blocks (Block, Block_Skip, the Sequential containers) are not listed.
__all__ = [
    "MobileNetV2_Multi",
]

class Block(nn.Module):
    '''Inverted-residual block: 1x1 expand -> 3x3 depthwise -> 1x1 linear projection.

    Every conv is followed by its own BatchNorm. When ``use_skip_bn`` is True, a
    second, independent BatchNorm set (``bn1_skip``/``bn2_skip``/``bn3_skip``) is
    allocated. ``forward(x, skip=True)`` then normalises with that set while
    sharing the same conv weights. Without it, ``skip`` has no effect.

    Args:
        in_planes: number of input channels.
        out_planes: number of output channels.
        expansion: hidden-width multiplier (hidden = expansion * in_planes).
        stride: stride of the depthwise conv.
        use_skip_bn: allocate the alternate BatchNorm set.
    '''

    def __init__(self, in_planes, out_planes, expansion, stride, use_skip_bn=False):
        super().__init__()
        self.stride = stride
        self.use_skip_bn = use_skip_bn
        self.expansion = expansion
        hidden = in_planes * expansion

        # Keep this registration order: it fixes the state_dict key order.
        # NOTE: with expansion == 1, conv1/bn1 (and bn1_skip) are built but never
        # used by forward(); removing them would change the state_dict keys.
        self.conv1 = nn.Conv2d(in_planes, hidden, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn1 = nn.BatchNorm2d(hidden)
        self.conv2 = nn.Conv2d(hidden, hidden, kernel_size=3, stride=stride, padding=1,
                               groups=hidden, bias=False)
        self.bn2 = nn.BatchNorm2d(hidden)
        self.conv3 = nn.Conv2d(hidden, out_planes, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn3 = nn.BatchNorm2d(out_planes)

        if self.use_skip_bn == True:
            self.bn1_skip = nn.BatchNorm2d(hidden)
            self.bn2_skip = nn.BatchNorm2d(hidden)
            self.bn3_skip = nn.BatchNorm2d(out_planes)

        # Identity shortcut only when input and output shapes match.
        self.use_res_connect = self.stride == 1 and in_planes == out_planes

    def forward(self, x, skip=False):
        '''Apply the block.

        ``skip=True`` selects the alternate BatchNorm set, but only if it was
        allocated (``use_skip_bn``). Otherwise the regular set is used.
        '''
        alternate = self.use_skip_bn == True and skip == True
        if alternate:
            norm1, norm2, norm3 = self.bn1_skip, self.bn2_skip, self.bn3_skip
        else:
            norm1, norm2, norm3 = self.bn1, self.bn2, self.bn3

        # With expansion == 1 the expand stage is bypassed entirely.
        hidden = x if self.expansion == 1 else F.relu6(norm1(self.conv1(x)), inplace=True)
        hidden = F.relu6(norm2(self.conv2(hidden)), inplace=True)
        y = norm3(self.conv3(hidden))  # linear bottleneck: no activation

        return x + y if self.use_res_connect else y


class Block_Skip(nn.Module):
    '''Inverted-residual block (expand + depthwise + pointwise) with a "skip" mode.

    Regular mode: conv1 -> bn1 -> ReLU -> conv2 (depthwise) -> bn2 -> ReLU
    -> conv3 -> bn3. An identity shortcut is added when stride is 1 and the
    channel count is unchanged.

    Skip mode (``forward(x, skip=True)``) reuses conv1/conv2 but normalises with
    the separate ``bn*_skip`` layers. When ``has_skip_branch`` is set, it also
    projects with its own ``conv3_skip`` instead of ``conv3``.

    Args:
        in_planes: number of input channels.
        out_planes: number of output channels.
        expansion: hidden-width multiplier (hidden = expansion * in_planes).
            The expand conv is always applied here, even for expansion == 1.
        stride: stride of the depthwise conv.
        has_skip_branch: allocate ``conv3_skip`` for the skip-mode projection.
        use_skip_bn: allocate the ``bn*_skip`` layers. Without them the block
            has no skip mode, and ``skip=True`` runs the regular path
            (previously this raised AttributeError).
    '''

    def __init__(self, in_planes, out_planes, expansion, stride, has_skip_branch=False, use_skip_bn=True):
        super(Block_Skip, self).__init__()
        self.stride = stride

        self.has_skip_branch = has_skip_branch
        self.use_skip_bn = use_skip_bn

        self.expansion = expansion
        planes = expansion * in_planes
        # Registration order is kept unchanged so state_dict keys stay stable.
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, groups=planes, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn3 = nn.BatchNorm2d(out_planes)

        if self.use_skip_bn:
            self.bn1_skip = nn.BatchNorm2d(planes)
            self.bn2_skip = nn.BatchNorm2d(planes)
            self.bn3_skip = nn.BatchNorm2d(out_planes)

        if self.has_skip_branch:
            # Experimental: an independent projection used only in skip mode.
            self.conv3_skip = nn.Conv2d(planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False)

        # Identity shortcut only when input and output shapes match.
        self.use_res_connect = self.stride == 1 and in_planes == out_planes

    def forward(self, x, skip=False):
        '''Apply the block in skip mode (``skip=True``) or regular mode.

        Skip mode requires the ``bn*_skip`` layers (``use_skip_bn``). When they
        are absent, the regular path is used. This is the same fallback as ``Block``.
        '''
        # NOTE(review): this block uses ReLU while Block and the network stem use
        # ReLU6 - confirm this is intended; switching would change the outputs of
        # already-trained weights.
        if skip and self.use_skip_bn:
            out = F.relu(self.bn1_skip(self.conv1(x)), inplace=True)
            out = F.relu(self.bn2_skip(self.conv2(out)), inplace=True)
            projection = self.conv3_skip if self.has_skip_branch else self.conv3
            out = self.bn3_skip(projection(out))
        else:
            out = F.relu(self.bn1(self.conv1(x)), inplace=True)
            out = F.relu(self.bn2(self.conv2(out)), inplace=True)
            out = self.bn3(self.conv3(out))

        if self.use_res_connect:
            out = x + out

        return out


class SkippableSequential_exp(nn.Sequential):
    '''Sequential container whose children take a ``skip`` keyword.

    In skip mode, children whose ``use_skip_bn`` equals False are bypassed
    entirely, so their input passes through unchanged. Every other child is
    called as ``child(x, skip=skip)``. ``use_skip_bn`` is only read when ``skip``
    is set.
    '''

    def forward(self, input, skip=False):
        for child in self:
            bypass = skip == True and child.use_skip_bn == False
            if not bypass:
                input = child(input, skip=skip)
        return input

class SkippableSequential(nn.Sequential):
    '''Sequential container that passes the same ``skip`` flag to every child, in order.'''

    def forward(self, input, skip=False):
        result = input
        for layer in self:
            result = layer(result, skip=skip)
        return result


class MobileNetV2_Multi(nn.Module):
    '''MobileNetV2 classifier whose stages 3-6 can each run in a "skip" mode.

    Layout: stem conv (stride 2) -> layer1 (Block, t=1) -> layer2 (2 Blocks)
    -> layer3..layer6 (skippable stages from ``_make_skip_stage``)
    -> layer7 (Block 160->320) -> 1x1 conv to 1280 -> global average pool
    -> flatten -> dropout -> linear.

    Args:
        class_num: number of output classes (width of the final Linear layer).
    '''

    def __init__(self, class_num=1000):
        super(MobileNetV2_Multi, self).__init__()

        # Stem: 3 -> 32 channels, stride 2.
        self.pre = nn.Sequential(
            nn.Conv2d(3, 32, 3, stride=2, padding=1),   # for ILSVRC
            nn.BatchNorm2d(32),
            nn.ReLU6(inplace=True)
        )

        # Plain (non-skippable) stages. layer1 has expansion 1, so its expand
        # conv is allocated but bypassed in Block.forward.
        self.layer1 = Block(32, 16, 1, 1, False)
        self.layer2 = self._make_stage(2, 16, 24, 2, 6, False)
        # Skippable stages; forward() picks each stage's mode from skip[0..3].
        self.layer3_skippable = self._make_skip_stage(3, 24, 32, 2, 6)
        self.layer4_skippable = self._make_skip_stage(4, 32, 64, 2, 6)
        self.layer5_skippable = self._make_skip_stage(3, 64, 96, 1, 6)
        self.layer6_skippable = self._make_skip_stage(3, 96, 160, 2, 6)  # 2 for ilsvrc

        self.layer7 = Block(160, 320, 6, 1, False)

        self.conv1 = nn.Sequential(
            nn.Conv2d(320, 1280, 1),
            nn.BatchNorm2d(1280),
            nn.ReLU6(inplace=True)
        )

        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.flatten = nn.Flatten()

        self.classifier = nn.Sequential(
            nn.Dropout(0.2),
            nn.Linear(1280, class_num),
        )

        # Weight initialization: Kaiming-normal convs, unit/zero norms,
        # small-normal classifier weights, zero biases.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.zeros_(m.bias)

    def forward(self, x, skip=(False, False, False, False)):
        '''Compute class logits.

        Args:
            x: image batch for the 3-channel stem conv (NCHW).
            skip: four flags, one per skippable stage (layer3..layer6, in
                order). A True flag runs that stage in skip mode.

        Returns:
            Logits of shape (batch, class_num).
        '''
        x = self.pre(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3_skippable(x, skip=skip[0])
        x = self.layer4_skippable(x, skip=skip[1])
        x = self.layer5_skippable(x, skip=skip[2])
        x = self.layer6_skippable(x, skip=skip[3])
        x = self.layer7(x)
        x = self.conv1(x)
        x = self.avg_pool(x)
        x = self.flatten(x)
        x = self.classifier(x)

        return x

    def _make_stage(self, repeat, in_channels, out_channels, stride, t, use_skip_bn=False):
        '''Build a plain stage of ``repeat`` Blocks.

        Only the first Block changes the channel count and applies ``stride``.
        The rest map ``out_channels`` -> ``out_channels`` with stride 1.

        Raises:
            ValueError: if ``repeat`` < 1.
        '''
        # A bounded range replaces ``while repeat - 1:``, which never terminated
        # for repeat <= 0.
        if repeat < 1:
            raise ValueError(f"repeat must be >= 1, got {repeat}")

        layers = [Block(in_channels, out_channels, t, stride, use_skip_bn=use_skip_bn)]
        layers += [Block(out_channels, out_channels, t, 1, use_skip_bn=use_skip_bn)
                   for _ in range(repeat - 1)]

        return nn.Sequential(*layers)

    def _make_skip_stage(self, repeat, in_channels, out_channels, stride, t):
        '''Build a skippable stage of ``repeat`` Block_Skip blocks.

        Block 0 applies ``stride`` and the channel change. It always gets
        alternate BatchNorms (``use_skip_bn=True``) and no skip branch.
        With ``n_skip = repeat - 2``, each later block b (b >= 1) gets:
          * b <  n_skip: alternate BatchNorms only;
          * b == n_skip: alternate BatchNorms plus its own skip-mode projection
                         conv (``has_skip_branch=True``);
          * b >  n_skip: no alternate BatchNorms, so SkippableSequential_exp
                         bypasses it in skip mode.

        Raises:
            ValueError: if ``repeat`` < 1.
        '''
        if repeat < 1:
            raise ValueError(f"repeat must be >= 1, got {repeat}")

        n_skip = repeat - 2
        layers = [Block_Skip(in_channels, out_channels, t, stride, use_skip_bn=True)]

        for b in range(1, repeat):
            layers.append(Block_Skip(out_channels, out_channels, t, 1,
                                     has_skip_branch=(b == n_skip),
                                     use_skip_bn=(b <= n_skip)))

        return SkippableSequential_exp(*layers)


def test():
    '''Smoke test: build the network and run one forward pass per skip setting.

    Prints the module tree and the logits size for the all-False and all-True
    skip configurations. Checks that both produce shape (1, 1000).
    '''
    net = MobileNetV2_Multi()
    net.eval()
    x = torch.randn(1, 3, 224, 224)

    print(net)
    with torch.no_grad():
        # Exercise the regular path and the skip path of every skippable stage.
        for skip in ((False, False, False, False), (True, True, True, True)):
            y = net(x, skip=skip)
            print(skip, y.size())
            assert y.size() == (1, 1000), y.size()


if __name__ == '__main__':
    test()

