from __future__ import absolute_import

import torch.nn as nn
from .mixbin import BinReLU, HardBinaryConv2d

class _BatchNorm2d(nn.BatchNorm2d):
    """``nn.BatchNorm2d`` whose defaults are ``eps=1e-6`` and ``momentum=0.05``.

    All arguments are forwarded to ``nn.BatchNorm2d``.  ``eps`` and
    ``momentum`` may still be supplied by the caller (by keyword, or
    positionally right after ``num_features``); the values above are applied
    only when the caller leaves them out.
    """

    def __init__(self, num_features, *args, **kwargs):
        # nn.BatchNorm2d takes (num_features, eps, momentum, ...) positionally,
        # so args[0] is eps and args[1] is momentum when present.  Injecting a
        # keyword for a slot that is already filled raises TypeError, hence
        # the defaults are only added for slots the caller did not fill.
        if len(args) < 1:
            kwargs.setdefault("eps", 1e-6)
        if len(args) < 2:
            kwargs.setdefault("momentum", 0.05)
        super(_BatchNorm2d, self).__init__(num_features, *args, **kwargs)


class _AlexNet(nn.Module):
    """Base module that chains five stages, ``conv1`` through ``conv5``.

    The stages are not created here; ``forward`` reads them from ``self``,
    so a subclass has to assign all five attributes.
    """

    def forward(self, x):
        """Apply ``conv1`` .. ``conv5`` in order and return the last output."""
        out = self.conv1(x)
        out = self.conv2(out)
        out = self.conv3(out)
        out = self.conv4(out)
        return self.conv5(out)

class MixBinAlexNetV1(_AlexNet):
    """Five-stage AlexNet-style backbone with a per-stage choice of conv layer.

    ``model_type`` is indexed at positions 1 to 4.  An ``"F"`` at a position
    builds the matching stage (``conv2`` .. ``conv5``) around ``nn.Conv2d``;
    any other value builds it around ``HardBinaryConv2d``.  Position 0 is
    never read: ``conv1`` always uses ``nn.Conv2d``.
    """

    # Not read inside this class.  The stride-2 pieces below are conv1's
    # convolution and the max-pools in conv1 and conv2 (2 * 2 * 2).
    output_stride = 8

    def __init__(self, model_type="FFFFF"):
        super(MixBinAlexNetV1, self).__init__()

        def conv_for(stage):
            # "F" selects the regular convolution, anything else the
            # HardBinaryConv2d layer imported from .mixbin.
            return nn.Conv2d if model_type[stage] == "F" else HardBinaryConv2d

        # Stage 1 is fixed: regular conv + BN + ReLU + max-pool.
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 96, 11, 2),
            _BatchNorm2d(96),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(3, 2),
        )

        # Each later stage looks up its conv class right before it is built,
        # so stages are created (and model_type is indexed) in stage order.
        conv2_cls = conv_for(1)
        self.conv2 = nn.Sequential(
            conv2_cls(96, 256, 5, 1, groups=2),
            _BatchNorm2d(256),
            BinReLU(),
            nn.MaxPool2d(3, 2),
        )

        conv3_cls = conv_for(2)
        self.conv3 = nn.Sequential(
            conv3_cls(256, 384, 3, 1),
            _BatchNorm2d(384),
            BinReLU(),
        )

        conv4_cls = conv_for(3)
        self.conv4 = nn.Sequential(
            conv4_cls(384, 384, 3, 1, groups=2),
            _BatchNorm2d(384),
            BinReLU(),
        )

        # The last stage is a bare convolution: no normalisation or activation.
        conv5_cls = conv_for(4)
        self.conv5 = nn.Sequential(
            conv5_cls(384, 256, 3, 1, groups=2),
        )