import math
from utils import *
import torch.nn as nn
from thop import profile


class Inverted_Bottleneck(nn.Module):
    """Inverted residual block with a mixed depthwise stage.

    Layout: 1x1 expansion conv -> one depthwise conv branch per entry of
    ``choice['conv']`` (branch outputs are summed) -> 1x1 projection without a
    trailing activation. An identity shortcut is added when ``stride == 1``
    and ``inplanes == outplanes``.

    Args:
        inplanes: number of input channels.
        outplanes: number of output channels.
        stride: stride used by every depthwise branch.
        choice: dict with
            ``'conv'`` -- list of indices into ``kernel_list`` ([3, 5, 7, 9]);
            one depthwise branch is built per entry (duplicates allowed).
            ``'rate'`` -- expansion selector: 0 -> ratio 3, anything else -> 6.
        activation: activation class, instantiated with ``inplace=True``.
    """

    def __init__(self, inplanes, outplanes, stride, choice, activation=nn.ReLU6):
        super(Inverted_Bottleneck, self).__init__()
        self.inplanes = inplanes
        self.outplanes = outplanes
        self.stride = stride
        # Not used by forward(); retained so the module's attributes and
        # printed structure stay unchanged.
        self.activation = activation(inplace=True)

        self.kernel_list = [3, 5, 7, 9]
        self.k_ids = choice['conv']
        self.t = 3 if choice['rate'] == 0 else 6
        hidden = inplanes * self.t  # expanded channel count

        self.mix_conv = nn.ModuleList([])
        # Placeholder: registering pw_linear here, before pw, fixes its slot in
        # the submodule order (mix_conv, pw_linear, pw). Re-assigning it below
        # keeps that slot, so parameters()/state_dict() ordering stays stable
        # (optimizer state_dicts are indexed by parameter order).
        self.pw_linear = nn.ModuleList([])

        # pw: 1x1 expansion to `hidden` channels
        self.pw = nn.Sequential(
            nn.Conv2d(inplanes, hidden, kernel_size=1, bias=False),
            nn.BatchNorm2d(hidden),
            activation(inplace=True),
        )

        # dw: one depthwise branch per selected kernel; padding k // 2 keeps
        # the spatial size unchanged at stride 1 for these odd kernels.
        for k_id in self.k_ids:
            kernel = self.kernel_list[k_id]
            self.mix_conv.append(
                nn.Sequential(
                    nn.Conv2d(hidden, hidden, kernel_size=kernel, stride=stride,
                              padding=kernel // 2, bias=False, groups=hidden),
                    nn.BatchNorm2d(hidden),
                    activation(inplace=True),
                )
            )

        # pw linear: 1x1 projection back to `outplanes`, no activation after it
        self.pw_linear = nn.Sequential(
            nn.Conv2d(hidden, outplanes, kernel_size=1, bias=False),
            nn.BatchNorm2d(outplanes),
        )

    def forward(self, x, drop_path_prob=0.0):
        """Run the block.

        Args:
            x: input feature map with ``inplanes`` channels.
            drop_path_prob: probability handed to ``drop_path`` on the
                residual branch. It only takes effect in training mode and
                only when the identity shortcut is active.

        Returns:
            Feature map with ``outplanes`` channels.
        """
        residual = x
        # pw
        out = self.pw(x)
        # dw: sum of all depthwise branches
        out = sum(conv(out) for conv in self.mix_conv)
        # pw linear
        out = self.pw_linear(out)
        # residual
        if self.stride == 1 and self.inplanes == self.outplanes:
            # Drop-path is a training-time regulariser; it must not fire in
            # eval mode, or inference would be altered by it.
            # NOTE(review): assumes utils.drop_path is the usual random
            # path-dropping op -- confirm.
            if self.training and drop_path_prob > 0.0:
                out = drop_path(out, drop_prob=drop_path_prob)
            out = out + residual
        return out


class AuxiliaryHeadCIFAR(nn.Module):
    """Auxiliary classifier applied to an intermediate feature map.

    The geometry assumes an 8x8 input. ``AvgPool2d(5, stride=3)`` reduces it
    to 2x2, and the following 2x2 convolution collapses that to 1x1. The
    flattened features therefore have exactly 768 entries.

    Args:
        C: channel count of the incoming feature map.
        num_classes: number of output logits.
    """

    def __init__(self, C, num_classes):
        super().__init__()
        hidden_width = 128
        embed_width = 768
        layers = [
            # Out-of-place so the caller's input tensor is left untouched.
            nn.ReLU(inplace=False),
            nn.AvgPool2d(5, stride=3, padding=0, count_include_pad=False),
            nn.Conv2d(C, hidden_width, 1, bias=False),
            nn.BatchNorm2d(hidden_width),
            nn.ReLU(inplace=False),
            # No BatchNorm follows this conv in this variant of the head.
            nn.Conv2d(hidden_width, embed_width, 2, bias=False),
            nn.ReLU(inplace=False),
        ]
        self.features = nn.Sequential(*layers)
        self.classifier = nn.Linear(embed_width, num_classes)

    def forward(self, x):
        """Return logits, one row per batch element."""
        feats = self.features(x)
        flat = feats.view(feats.size(0), -1)
        return self.classifier(flat)


# Per-stage channel widths. Network block i maps channel[i] -> channel[i + 1],
# so a 12-block network consumes all 13 entries.
channel = [32] + [48] * 2 + [96] * 3 + [192] * 3 + [256] * 2 + [320] * 2
# Output width of the final 1x1 conv that feeds the classifier.
last_channel = 1280


class Network(nn.Module):
    """CIFAR-style classifier built from a stack of ``Inverted_Bottleneck`` blocks.

    Block ``i`` maps ``channel[i]`` -> ``channel[i + 1]`` channels. Blocks in
    ``_REDUCTION_BLOCKS`` use stride 2 and all others use stride 1. The stem
    is also stride 1, so a 32x32 input gives an 8x8 map after block 6. That is
    the size ``AuxiliaryHeadCIFAR`` is built for.

    Args:
        choice: mapping from block index to that block's choice dict
            (``{'conv': [...], 'rate': ...}``), forwarded to
            ``Inverted_Bottleneck``.
        auxiliary: if true, build an auxiliary head on the output of block
            ``_AUX_BLOCK``. Its logits are produced only in training mode.
        layers: number of inverted-bottleneck blocks, at most
            ``len(channel) - 1``.
        classes: number of output classes.
        dropout_rate: dropout probability applied before the classifier.
    """

    _REDUCTION_BLOCKS = (2, 5)  # block indices that use stride 2
    _AUX_BLOCK = 6  # block whose output feeds the auxiliary head

    def __init__(self, choice, auxiliary, layers=12, classes=10, dropout_rate=0.2):
        super(Network, self).__init__()
        self.layers = layers
        self._auxiliary = auxiliary

        self.stem = nn.Sequential(
            nn.Conv2d(3, channel[0], kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(channel[0]),
            nn.ReLU6(inplace=True)
        )

        self.Inverted_Block = nn.ModuleList([])
        for i in range(self.layers):
            stride = 2 if i in self._REDUCTION_BLOCKS else 1
            self.Inverted_Block.append(
                Inverted_Bottleneck(channel[i], channel[i + 1], stride=stride, choice=choice[i]))

        # The last block outputs channel[layers] channels. channel[-1] only
        # equals that when layers == len(channel) - 1.
        self.last_conv = nn.Sequential(
            nn.Conv2d(channel[self.layers], last_channel, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(last_channel),
            nn.ReLU6(inplace=True)
        )

        if self._auxiliary:
            # Block _AUX_BLOCK outputs channel[_AUX_BLOCK + 1] channels.
            self.auxiliary_head = AuxiliaryHeadCIFAR(channel[self._AUX_BLOCK + 1], classes)

        # On the 8x8 map produced by a 32x32 input this equals AvgPool2d(8),
        # but it also yields one vector per sample at any other resolution.
        self.global_pooling = nn.AdaptiveAvgPool2d(1)
        self.dropout = nn.Dropout(dropout_rate)
        self.classifier = nn.Linear(last_channel, classes)
        self._initialize_weights()

    def _initialize_weights(self):
        """Re-initialise parameters in ``self.modules()`` order.

        The schemes are:
        - Conv2d: normal(0, sqrt(2 / (kh * kw * out_channels))), bias zeroed.
        - BatchNorm2d: weight 1, bias 0.
        - Linear: uniform(+-1/sqrt(out_features)), bias zeroed.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1.0)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                n = m.weight.size(0)  # fan-out (out_features)
                init_range = 1.0 / math.sqrt(n)
                m.weight.data.uniform_(-init_range, init_range)
                if m.bias is not None:
                    m.bias.data.zero_()

    def forward(self, x, drop_path_prob=0.0):
        """Classify ``x``.

        Args:
            x: input image batch with 3 channels.
            drop_path_prob: forwarded to every ``Inverted_Bottleneck``.

        Returns:
            ``(logits, logits_aux)``. ``logits_aux`` is None unless the
            auxiliary head exists and the model is in training mode.
        """
        logits_aux = None
        x = self.stem(x)
        for i, block in enumerate(self.Inverted_Block):
            x = block(x, drop_path_prob)
            if i == self._AUX_BLOCK and self._auxiliary and self.training:
                logits_aux = self.auxiliary_head(x)
        x = self.last_conv(x)
        x = self.global_pooling(x)
        # Keep the batch dimension explicit. view(-1, C) would fold any
        # leftover spatial positions into the batch.
        x = x.flatten(1)
        x = self.dropout(x)
        x = self.classifier(x)
        return x, logits_aux


if __name__ == '__main__':
    # Explicit import: do not rely on `from utils import *` leaking torch.
    import torch

    # Smoke-test architecture: every block gets two 3x3 depthwise branches
    # (kernel id 0 twice) and expansion rate selector 0 (ratio 3).
    choice = {i: {'conv': [0, 0], 'rate': 0} for i in range(12)}

    model = Network(choice=choice, auxiliary=False, layers=12, classes=10)
    print(model)
    # thop.profile returns (MACs, params): it counts multiply-accumulates,
    # not FLOPs.
    macs, params = profile(model, inputs=(torch.randn(1, 3, 32, 32),), verbose=False)
    print('Params: %.2fM, MACs:%.2fM' % ((params / 1e6), (macs / 1e6)))

    dummy_input = torch.randn((1, 3, 32, 32))
    # Inference demo: eval mode disables dropout / BN statistic updates.
    model.eval()
    with torch.no_grad():
        print(model(dummy_input))
