import torch
import torch.nn as nn
import torch.nn.functional as F

# Number of compression rates ("honey" entries) per architecture: one per
# residual block (blocks-per-stage summed) plus four stage-level rates.
# These equal the default honey lengths used by `resnet()`.
conv_num_cfg = dict(
    resnet18=2 + 2 + 2 + 2 + 4,
    resnet34=3 + 4 + 6 + 3 + 4,
    resnet50=3 + 4 + 6 + 3 + 4,
    resnet101=3 + 4 + 23 + 3 + 4,
    resnet152=3 + 8 + 36 + 3 + 4,
)


def adapt_channel(honey, cfg):
    """Convert per-layer compression rates into concrete channel counts.

    ``honey`` holds rates in tenths (10 keeps 100% of the channels), laid
    out as:

    * ``honey[0]`` -- output width of the stem convolution;
    * ``honey[1:4]`` -- output width shared by every block of stages 1..3
      (the last stage always keeps its full width);
    * ``honey[4:]`` -- inner ("mid") width of each residual block, in
      network order.

    Args:
        honey: sequence of numeric rates, at least ``4 + number_of_blocks``
            entries long.
        cfg: one of 'resnet18', 'resnet34', 'resnet50', 'resnet101',
            'resnet152'.

    Returns:
        ``(overall_channel, mid_channel)``. ``overall_channel[0]`` is the stem
        width and ``overall_channel[k]`` the output width of residual block
        ``k``; ``mid_channel[k - 1]`` is the inner width of block ``k``.

    Raises:
        ValueError: if ``cfg`` is not a supported architecture name.
        IndexError: if ``honey`` is too short for ``cfg``.
    """
    # cfg -> (block expansion, blocks per stage, full output width per stage)
    arch = {
        'resnet18': (1, [2, 2, 2, 2], [64, 128, 256, 512]),
        'resnet34': (1, [3, 4, 6, 3], [64, 128, 256, 512]),
        'resnet50': (4, [3, 4, 6, 3], [256, 512, 1024, 2048]),
        'resnet101': (4, [3, 4, 23, 3], [256, 512, 1024, 2048]),
        'resnet152': (4, [3, 8, 36, 3], [256, 512, 1024, 2048]),
    }
    if cfg not in arch:
        # Previously this fell through and failed with UnboundLocalError.
        raise ValueError('unsupported cfg %r' % (cfg,))
    expansion, stage_repeat, stage_width = arch[cfg]

    # Unpruned output width of the stem followed by every residual block.
    stage_out_channel = [64]
    for width, repeat in zip(stage_width, stage_repeat):
        stage_out_channel += [width] * repeat

    # Output keep-ratio for the stem and each block: all blocks of a stage
    # share one rate, and the final stage is never pruned.
    stage_oup_cprate = [float(honey[0] / 10)]
    for i in range(len(stage_repeat) - 1):
        stage_oup_cprate += [float(honey[i + 1] / 10)] * stage_repeat[i]
    stage_oup_cprate += [1.] * stage_repeat[-1]

    mid_scale_cprate = [float(rate / 10) for rate in honey[len(stage_repeat):]]

    overall_channel = [
        int(width * rate)
        for width, rate in zip(stage_out_channel, stage_oup_cprate)
    ]
    # Block k's inner width is its full output width // expansion scaled by
    # its own mid rate; index 0 is the stem, which has no inner width.
    # Explicit indexing keeps the IndexError for a too-short honey.
    mid_channel = [
        int(stage_out_channel[k] // expansion * mid_scale_cprate[k - 1])
        for k in range(1, len(stage_out_channel))
    ]
    return overall_channel, mid_channel


def conv3x3(in_planes, out_planes, stride=1):
    """Return a bias-free 3x3 ``nn.Conv2d`` with one pixel of zero padding.

    With ``stride=1`` the spatial size is preserved.
    """
    kernel = 3
    return nn.Conv2d(
        in_channels=in_planes,
        out_channels=out_planes,
        kernel_size=kernel,
        stride=stride,
        padding=kernel // 2,
        bias=False,
    )


def conv1x1(in_planes, out_planes, stride=1):
    """Return a bias-free, unpadded 1x1 ``nn.Conv2d``.

    It only remaps channels; a ``stride`` > 1 additionally subsamples the
    spatial grid.
    """
    layer = nn.Conv2d(
        in_channels=in_planes,
        out_channels=out_planes,
        kernel_size=1,
        stride=stride,
        bias=False,
    )
    return layer


class BasicBlock(nn.Module):
    """Residual block of two 3x3 convolutions with a prunable inner width.

    Args:
        midplanes: output channels of the first 3x3 convolution.
        inplanes: channels of the input tensor.
        planes: channels produced by the block.
        stride: stride of the first convolution and of the shortcut conv.
        is_downsample: if True, the shortcut is a (strided) 1x1 conv followed
            by BatchNorm instead of the identity.
    """
    expansion = 1

    def __init__(self,
                 midplanes,
                 inplanes,
                 planes,
                 stride=1,
                 is_downsample=False):
        super(BasicBlock, self).__init__()
        bn = nn.BatchNorm2d

        # Attribute names and registration order define the state_dict keys;
        # keep them stable.
        self.conv1 = nn.Conv2d(inplanes, midplanes, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn1 = bn(midplanes)
        self.conv2 = nn.Conv2d(midplanes, planes, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn2 = bn(planes)

        self.is_downsample = is_downsample
        if is_downsample:
            projection = conv1x1(inplanes, planes, stride=stride)
            self.downsample = nn.Sequential(projection, bn(planes))

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        # Add the shortcut: projected input when configured, else x itself.
        out += self.downsample(x) if self.is_downsample else x
        return F.relu(out)


class Bottleneck(nn.Module):
    """1x1 -> 3x3 -> 1x1 residual block with a prunable inner width.

    Args:
        midplanes: channels of the two inner convolutions.
        inplanes: channels of the input tensor.
        planes: channels produced by the block.
        stride: stride of the 3x3 convolution and of the shortcut conv.
        is_downsample: if True, the shortcut is a (strided) 1x1 conv followed
            by BatchNorm instead of the identity.
    """

    def __init__(self,
                 midplanes,
                 inplanes,
                 planes,
                 stride=1,
                 is_downsample=False):
        super(Bottleneck, self).__init__()
        bn = nn.BatchNorm2d

        # Attribute names and registration order define the state_dict keys;
        # keep them stable.
        self.conv1 = conv1x1(inplanes, midplanes)
        self.bn1 = bn(midplanes)
        self.relu1 = nn.ReLU(inplace=True)

        self.conv2 = conv3x3(midplanes, midplanes, stride)
        self.bn2 = bn(midplanes)
        self.relu2 = nn.ReLU(inplace=True)

        self.conv3 = conv1x1(midplanes, planes)
        self.bn3 = bn(planes)
        self.relu3 = nn.ReLU(inplace=True)

        # Plain configuration attributes, kept for introspection.
        self.stride = stride
        self.inplanes = inplanes
        self.planes = planes
        self.midplanes = midplanes
        self.is_downsample = is_downsample
        self.expansion = 4

        if is_downsample:
            self.downsample = nn.Sequential(
                conv1x1(inplanes, planes, stride=stride),
                bn(planes),
            )

    def forward(self, x):
        out = self.relu1(self.bn1(self.conv1(x)))
        out = self.relu2(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))

        shortcut = self.downsample(x) if self.is_downsample else x
        out += shortcut
        return self.relu3(out)


class ResNet(nn.Module):
    """ImageNet-style ResNet whose channel widths are scaled by ``honey``.

    ``honey`` holds compression rates in tenths (10 keeps the full width).
    ``adapt_channel`` turns them into per-layer channel counts.

    Args:
        block: ``BasicBlock`` or ``Bottleneck``. Any other class leaves the
            four stages empty.
        num_blocks: number of residual blocks in each of the four stages.
        cfg: architecture name understood by ``adapt_channel``.
        num_classes: output size of the final linear layer.
        honey: compression rates. ``None`` builds the unpruned network
            (every rate 10).
    """

    def __init__(self, block, num_blocks, cfg, num_classes=1000, honey=None):
        super(ResNet, self).__init__()
        self.in_planes = 64
        if honey is None:
            # One stage-level rate per stage plus one rate per residual
            # block; previously None crashed inside adapt_channel.
            honey = [10] * (len(num_blocks) + sum(num_blocks))
        self.honey = honey
        print('compress rate:', self.honey)
        overall_channel, mid_channel = adapt_channel(honey, cfg)

        # Stem: overall_channel[0] is the (possibly pruned) stem width.
        self.conv1 = nn.Conv2d(3,
                               overall_channel[0],
                               kernel_size=7,
                               stride=2,
                               padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(overall_channel[0])
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.layer1 = nn.ModuleList()
        self.layer2 = nn.ModuleList()
        self.layer3 = nn.ModuleList()
        self.layer4 = nn.ModuleList()

        if block == Bottleneck or block == BasicBlock:
            self._build_stages(block, num_blocks, overall_channel,
                               mid_channel)

        # NOTE(review): AvgPool2d(7) is sized for a 7x7 final feature map,
        # i.e. 224x224 inputs with this network's total stride of 32.
        self.avgpool = nn.Sequential(nn.AvgPool2d(7))
        self.fc = nn.Linear(overall_channel[-1], num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight,
                                        mode='fan_out',
                                        nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def _build_stages(self, block, num_blocks, overall_channel, mid_channel):
        """Append residual blocks to ``layer1`` .. ``layer4``.

        Residual block ``k`` (counting from 1 across all stages) maps
        ``overall_channel[k - 1]`` -> ``overall_channel[k]`` channels with
        ``mid_channel[k - 1]`` inner channels.
        """
        layer_num = 1
        for stage_idx, repeat in enumerate(num_blocks):
            stage = getattr(self, 'layer%d' % (stage_idx + 1))
            for block_idx in range(repeat):
                inplanes = overall_channel[layer_num - 1]
                planes = overall_channel[layer_num]
                midplanes = mid_channel[layer_num - 1]
                if block_idx == 0:
                    # Stage 1 keeps the resolution; later stages halve it.
                    stride = 1 if stage_idx == 0 else 2
                    # Bottleneck stages always project the shortcut. A
                    # BasicBlock needs a projection whenever the stride or
                    # the channel count changes -- including layer1[0] when
                    # the stem and stage-1 widths are pruned differently,
                    # which previously crashed in `out += identity`.
                    downsample = (block == Bottleneck or stride != 1
                                  or inplanes != planes)
                    stage.append(
                        block(midplanes,
                              inplanes,
                              planes,
                              stride=stride,
                              is_downsample=downsample))
                else:
                    stage.append(block(midplanes, inplanes, planes))
                layer_num += 1

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            for blk in stage:
                x = blk(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)

        return x


def resnet(cfg, honey=None, num_classes=1000):
    """Build a (possibly pruned) ImageNet ResNet by architecture name.

    Args:
        cfg: one of 'resnet18', 'resnet34', 'resnet50', 'resnet101',
            'resnet152'.
        honey: per-layer compression rates in tenths. ``None`` builds the
            unpruned network (every rate 10).
        num_classes: size of the classifier output.

    Returns:
        The constructed ``ResNet``.

    Raises:
        ValueError: if ``cfg`` is not a supported architecture name
            (previously ``None`` was returned silently).
    """
    if cfg == 'resnet18':
        block, num_blocks = BasicBlock, [2, 2, 2, 2]
    elif cfg == 'resnet34':
        block, num_blocks = BasicBlock, [3, 4, 6, 3]
    elif cfg == 'resnet50':
        block, num_blocks = Bottleneck, [3, 4, 6, 3]
    elif cfg == 'resnet101':
        block, num_blocks = Bottleneck, [3, 4, 23, 3]
    elif cfg == 'resnet152':
        block, num_blocks = Bottleneck, [3, 8, 36, 3]
    else:
        raise ValueError('unsupported cfg %r' % (cfg,))

    # `is None`, not `== None`: for an array-like honey (e.g. a NumPy
    # array) `==` is elementwise and `if` on the result raises.
    if honey is None:
        # Four stage-level rates plus one per residual block:
        # 12 / 20 / 20 / 37 / 54 entries, all 10 (= unpruned).
        honey = [10] * (len(num_blocks) + sum(num_blocks))
    return ResNet(block, num_blocks,
                  cfg,
                  num_classes=num_classes,
                  honey=honey)
