from numpy import block, isin
import torch
import torch.nn as nn
from models.block.DepthSeperable import DepthSeperabelConv2d
from models.utils.utils import get_activation_function, make_conv_block

class MobileNet(nn.Module):
    """MobileNet-style classifier built from depthwise-separable conv blocks.

    The forward pass is ``stem -> layers -> avgpool -> flatten -> linear``.
    ``stem`` is a 3x3 conv block from ``make_conv_block`` (3 -> 32 channels).
    ``layers`` holds one ``DepthSeperabelConv2d`` per entry of ``cfg``.

    Args:
        activation_type: forwarded (with ``tau``) to ``get_activation_function``.
            The result is shared by every conv block.
        num_classes: output width of the final linear layer.
        oper_order: forwarded unchanged to every conv block. Presumably the
            order of conv / batch-norm / activation (default ``'cba'``);
            confirm in ``make_conv_block``.
        dataset: selects the stem stride and the pooling setup. It must contain
            ``'cifar'`` or ``'tinyImageNet'``, or equal ``'ImageNet'`` or
            ``'cub200'``.
        depthwise_acti: forwarded to every ``DepthSeperabelConv2d``.
        tau: forwarded to ``get_activation_function``.

    Raises:
        ValueError: if ``dataset`` is not a supported dataset name.
    """

    # Each entry is either ``planes`` (stride 1) or ``(planes, stride)``.
    # For example, (128, 2) means a block with 128 output planes and stride 2.
    cfg = [64, (128, 2), 128, (256, 2), 256, (512, 2), 512, 512, 512, 512, 512, (1024, 2), 1024]

    def __init__(self, activation_type, num_classes=10, oper_order='cba', dataset='cifar10',
                 depthwise_acti=True, tau=0.0):
        super().__init__()
        # Resolve (and validate) the dataset before building any layers.
        # An unsupported name then fails here with a clear error instead of
        # an AttributeError on the missing ``avgpool`` inside forward().
        stem_stride, pool_kernel, pooled_positions = self._resolve_dataset(dataset)

        self.activation_generator = get_activation_function(activation_type, tau=tau)
        self.oper_order = oper_order
        self.cutted_resolution = 1

        self.stem = make_conv_block(3, 32, kernel_size=3, stride=stem_stride, padding=1,
                                    activation_generator=self.activation_generator,
                                    oper_order=self.oper_order)
        self.layers = self._make_layers(in_planes=32, depthwise_acti=depthwise_acti)

        last_channel, _ = self._planes_and_stride(self.cfg[-1])
        self.avgpool = nn.AvgPool2d(pool_kernel * self.cutted_resolution)
        # The flattened pooled feature map may hold more than one spatial
        # position (see _resolve_dataset), so scale the classifier input.
        self.linear = nn.Linear(last_channel * pooled_positions, num_classes)

    @staticmethod
    def _resolve_dataset(dataset):
        """Return ``(stem_stride, pool_kernel, pooled_positions)`` for *dataset*.

        ``pool_kernel`` is the base ``AvgPool2d`` kernel size, before it is
        scaled by ``cutted_resolution``. ``pooled_positions`` multiplies the
        last conv width to give the linear layer's input size.
        tinyImageNet uses 2 * 2; presumably a 2x2 map remains after pooling
        64x64 inputs (TODO confirm).

        The same substring rule decides both the stride and the pooling.
        This keeps the two choices in agreement for names like
        ``'cifar100'`` or ``'tinyImageNet200'``.

        Raises:
            ValueError: if *dataset* is not a supported dataset name.
        """
        if 'cifar' in dataset:
            return 1, 2, 1
        if 'tinyImageNet' in dataset:
            return 1, 2, 2 * 2
        if dataset in ('ImageNet', 'cub200'):
            return 2, 7, 1
        raise ValueError(
            f"Unsupported dataset {dataset!r}: expected a name containing 'cifar' "
            f"or 'tinyImageNet', or one of 'ImageNet', 'cub200'."
        )

    @staticmethod
    def _planes_and_stride(entry):
        """Split a ``cfg`` entry into ``(out_planes, stride)``; a bare int means stride 1."""
        if isinstance(entry, int):
            return entry, 1
        return entry[0], entry[1]

    def _make_layers(self, in_planes, depthwise_acti):
        """Build one ``DepthSeperabelConv2d`` per ``cfg`` entry and chain them.

        Each block's input width is the previous block's output width,
        starting from ``in_planes``. Returns an ``nn.Sequential``.
        """
        layers = []
        for entry in self.cfg:
            out_planes, stride = self._planes_and_stride(entry)
            layers.append(DepthSeperabelConv2d(in_planes, out_planes, kernel_size=3,
                                               stride=stride, padding=1,
                                               activation_generator=self.activation_generator,
                                               oper_order=self.oper_order,
                                               depthwise_acti=depthwise_acti))
            in_planes = out_planes
        return nn.Sequential(*layers)

    def forward(self, x):
        """Return class logits of shape ``(batch, num_classes)``.

        Assumes ``x`` is an image batch of shape ``(batch, 3, H, W)``, with
        H and W matching the resolution the dataset setting expects
        (TODO confirm).
        """
        out = self.stem(x)
        out = self.layers(out)
        out = self.avgpool(out)
        # Unlike view(), flatten also works on non-contiguous tensors.
        out = torch.flatten(out, 1)
        return self.linear(out)
