import torch
import torch.nn as nn
import torch.nn.functional as F


# Per-architecture count written as <residual blocks> + 4; the first term
# matches the standard block counts (e.g. 3+4+6+3 = 16 for resnet50).
# NOTE(review): not referenced in this part of the file; presumably the length
# of a search vector used elsewhere — confirm against callers.
conv_num_cfg = dict(
    resnet18=8 + 4,
    resnet34=16 + 4,
    resnet50=16 + 4,
    resnet101=33 + 4,
    resnet152=50 + 4,
)

def adapt_channel(honey, cfg):
    """Expand a ``honey`` vector into per-layer channel widths for a pruned ResNet.

    Only ``'resnet50'`` (Bottleneck blocks, stage depths ``[3, 4, 6, 3]``) is
    supported.

    Args:
        honey: Sequence of channel counts (each converted with ``int()``).
            ``honey[0]`` is the stem width, ``honey[1:4]`` are the output
            widths of stages 1-3 (the output width of stage 4 is fixed at
            2048), and ``honey[4:]`` holds two inner widths per Bottleneck
            block -- first ``conv1``'s, then ``conv2``'s -- in block order.
        cfg: Architecture name.

    Returns:
        Tuple ``(overall_channel, mid_channel, mid_channel_2)``.
        ``overall_channel[0]`` is the stem width and ``overall_channel[k]`` is
        the output width of the k-th block (1-based, across all stages).
        ``mid_channel[k - 1]`` / ``mid_channel_2[k - 1]`` are the first / second
        inner widths of block k.

    Raises:
        ValueError: if ``cfg`` is not a supported architecture.
        IndexError: if ``honey`` is too short for the architecture.
    """
    if cfg != 'resnet50':
        # Fail fast with a clear message instead of hitting unbound stage tables.
        raise ValueError(
            f"adapt_channel: unsupported cfg {cfg!r}; only 'resnet50' is supported")

    stage_repeat = [3, 4, 6, 3]
    final_out_channel = 2048  # the last stage's output width is not taken from honey

    # One output width per layer: the stem, then each block of each stage.
    stage_oup_cprate = [int(honey[0])]
    for stage_idx, repeat in enumerate(stage_repeat[:-1]):
        stage_oup_cprate += [int(honey[stage_idx + 1])] * repeat
    stage_oup_cprate += [final_out_channel] * stage_repeat[-1]
    print('stage_oup_cprate:', stage_oup_cprate)

    mid_cprate = [int(i) for i in honey[len(stage_repeat):]]
    print('mid_cprate:', mid_cprate)

    num_blocks = sum(stage_repeat)
    overall_channel = stage_oup_cprate[:num_blocks + 1]
    # Inner widths are stored pairwise per block: (conv1 width, conv2 width).
    mid_channel = [mid_cprate[2 * k] for k in range(num_blocks)]
    mid_channel_2 = [mid_cprate[2 * k + 1] for k in range(num_blocks)]

    return overall_channel, mid_channel, mid_channel_2

def conv3x3(in_planes, out_planes, stride=1):
    """Return a bias-free 3x3 ``nn.Conv2d`` with 1-pixel zero padding.

    With ``stride=1`` the spatial size is preserved.
    """
    conv_kwargs = dict(kernel_size=3, stride=stride, padding=1, bias=False)
    return nn.Conv2d(in_planes, out_planes, **conv_kwargs)

def conv1x1(in_planes, out_planes, stride=1):
    """Return a bias-free pointwise (1x1) ``nn.Conv2d`` with no padding."""
    conv_kwargs = dict(kernel_size=1, stride=stride, bias=False)
    return nn.Conv2d(in_planes, out_planes, **conv_kwargs)

class BasicBlock(nn.Module):
    """Two-convolution (3x3 -> 3x3) residual block with a prunable inner width.

    ``conv1`` maps ``inplanes -> midplanes`` (carrying ``stride``) and
    ``conv2`` maps ``midplanes -> planes``. With ``is_downsample`` the shortcut
    is a strided 1x1 conv + BatchNorm from ``inplanes`` to ``planes``;
    otherwise it is the identity, so the input must already have ``planes``
    channels and the block must use ``stride=1`` for the residual sum to work.
    """

    expansion = 1

    def __init__(self, midplanes, inplanes, planes, stride=1, is_downsample=False):
        super(BasicBlock, self).__init__()
        # Submodules are created in the same order as before so parameter
        # registration (and hence state_dict ordering) is unchanged.
        self.conv1 = nn.Conv2d(inplanes, midplanes, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(midplanes)
        self.conv2 = nn.Conv2d(midplanes, planes, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        self.is_downsample = is_downsample
        if self.is_downsample:
            projection = conv1x1(inplanes, planes, stride=stride)
            self.downsample = nn.Sequential(projection, nn.BatchNorm2d(planes))

    def forward(self, x):
        residual = self.downsample(x) if self.is_downsample else x
        h = F.relu(self.bn1(self.conv1(x)))
        h = self.bn2(self.conv2(h))
        h += residual
        return F.relu(h)


class Bottleneck(nn.Module):
    """Three-convolution residual block (1x1 -> 3x3 -> 1x1) with prunable inner widths.

    ``conv1`` maps ``inplanes -> midplanes_1``, ``conv2`` (3x3, carrying
    ``stride``) maps ``midplanes_1 -> midplanes_2`` and ``conv3`` maps
    ``midplanes_2 -> planes``. With ``is_downsample`` the shortcut is a strided
    1x1 conv + BatchNorm from ``inplanes`` to ``planes``; otherwise it is the
    identity, so ``inplanes`` must equal ``planes`` and ``stride`` must be 1
    for the residual sum to be shape-compatible.
    """

    def __init__(self, midplanes_1, midplanes_2, inplanes, planes, stride=1, is_downsample=False):
        super(Bottleneck, self).__init__()
        bn = nn.BatchNorm2d

        # Submodule creation order is preserved: conv/bn/relu per stage,
        # then the optional projection shortcut.
        self.conv1 = conv1x1(inplanes, midplanes_1)
        self.bn1 = bn(midplanes_1)
        self.relu1 = nn.ReLU(inplace=True)

        self.conv2 = conv3x3(midplanes_1, midplanes_2, stride)
        self.bn2 = bn(midplanes_2)
        self.relu2 = nn.ReLU(inplace=True)

        self.conv3 = conv1x1(midplanes_2, planes)
        self.bn3 = bn(planes)
        self.relu3 = nn.ReLU(inplace=True)

        # Configuration kept as plain attributes; forward only reads is_downsample.
        self.stride, self.inplanes, self.planes = stride, inplanes, planes
        self.midplanes_1, self.midplanes_2 = midplanes_1, midplanes_2
        self.is_downsample = is_downsample
        self.expansion = 4

        if is_downsample:
            self.downsample = nn.Sequential(
                conv1x1(inplanes, planes, stride=stride),
                bn(planes),
            )

    def forward(self, x):
        shortcut = self.downsample(x) if self.is_downsample else x
        y = self.relu1(self.bn1(self.conv1(x)))
        y = self.relu2(self.bn2(self.conv2(y)))
        y = self.bn3(self.conv3(y))
        y += shortcut
        return self.relu3(y)

class ResNet(nn.Module):
    """ResNet whose per-layer channel widths come from a ``honey`` vector.

    Widths are expanded by ``adapt_channel(honey, cfg)``. In the resulting
    ``overall_channel`` list, index 0 is the stem width. Block ``k`` (1-based,
    counted across all stages) maps ``overall_channel[k - 1]`` channels to
    ``overall_channel[k]`` channels.

    Args:
        block: ``Bottleneck`` or ``BasicBlock``.
        num_blocks: Number of blocks per stage (at most four stages, each with
            at least one block).
        cfg: Architecture name, forwarded to ``adapt_channel``.
        num_classes: Output size of the final linear classifier.
        honey: Channel-width vector, forwarded to ``adapt_channel``.

    Raises:
        ValueError: for an unsupported ``block`` type or an invalid
            ``num_blocks``.
    """

    def __init__(self, block, num_blocks, cfg, num_classes=1000, honey=None):
        super(ResNet, self).__init__()
        self.in_planes = 64
        self.honey = honey
        print('self.honey:', self.honey)

        if block is not Bottleneck and block is not BasicBlock:
            raise ValueError(f'ResNet: unsupported block type {block!r}')
        if len(num_blocks) > 4 or any(depth < 1 for depth in num_blocks):
            raise ValueError(
                'ResNet: num_blocks must have at most 4 stages, each with at '
                f'least one block; got {num_blocks!r}')

        # BasicBlock uses only the first inner-width list.
        overall_channel, mid_channel_1, mid_channel_2 = adapt_channel(honey, cfg)

        stem_width = overall_channel[0]
        self.conv1 = nn.Conv2d(3, stem_width, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(stem_width)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.layer1 = nn.ModuleList()
        self.layer2 = nn.ModuleList()
        self.layer3 = nn.ModuleList()
        self.layer4 = nn.ModuleList()

        # Reference the stage containers directly rather than eval()-ing names.
        stages = (self.layer1, self.layer2, self.layer3, self.layer4)
        block_idx = 1
        for stage_idx, (stage, depth) in enumerate(zip(stages, num_blocks)):
            for j in range(depth):
                is_first = j == 0
                # Stage 1 keeps the stem's resolution; later stages halve it
                # in their first block.
                stride = 2 if is_first and stage_idx > 0 else 1
                inplanes = overall_channel[block_idx - 1]
                planes = overall_channel[block_idx]
                if block is Bottleneck:
                    # The first Bottleneck of every stage gets a projection shortcut.
                    new_block = Bottleneck(
                        mid_channel_1[block_idx - 1], mid_channel_2[block_idx - 1],
                        inplanes, planes, stride=stride, is_downsample=is_first)
                else:
                    # Project the shortcut whenever the identity would not match
                    # the output shape (e.g. a pruned stem width differing from
                    # stage 1's width would otherwise fail at the residual add).
                    new_block = BasicBlock(
                        mid_channel_1[block_idx - 1], inplanes, planes, stride=stride,
                        is_downsample=stride != 1 or inplanes != planes)
                stage.append(new_block)
                block_idx += 1

        # Fixed 7x7 window (stride 7): this is global average pooling when the
        # final feature map is 7x7, e.g. 224x224 inputs with four stages
        # (overall stride 2 * 2 * 2**3 = 32).
        self.avgpool = nn.Sequential(nn.AvgPool2d(7))
        self.fc = nn.Linear(overall_channel[-1], num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        x = self.maxpool(self.relu(self.bn1(self.conv1(x))))

        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            for blk in stage:
                x = blk(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        return self.fc(x)

def resnet(cfg, honey=None, num_classes=1000):
    """Build a pruned ResNet from a ``honey`` channel vector.

    Args:
        cfg: Architecture name. Only ``'resnet50'`` is supported.
        honey: Optional sequence ``[stem_width, *inner_widths]``, where
            ``inner_widths`` holds two inner widths per Bottleneck block
            (conv1's, then conv2's). The fixed stage output widths
            256/512/1024 are inserted after the stem width before the vector
            is passed to ``ResNet``. The caller's sequence is not modified.
            When omitted, the unpruned ResNet-50 widths are used.
        num_classes: Output size of the classifier.

    Returns:
        A ``ResNet`` built from ``Bottleneck`` blocks with depths ``[3, 4, 6, 3]``.

    Raises:
        ValueError: if ``cfg`` is not supported.
    """
    if cfg != 'resnet50':
        raise ValueError(f"resnet: unsupported cfg {cfg!r}; only 'resnet50' is supported")

    # Stage output widths for stages 1-3 are fixed, not taken from honey.
    stage_widths = [256, 512, 1024]
    if honey is None:
        # Unpruned widths: stem, stage outputs, then (conv1, conv2) pairs per block.
        honey = ([64] + stage_widths
                 + [64] * 3 * 2 + [128] * 4 * 2 + [256] * 6 * 2 + [512] * 3 * 2)
    else:
        # Build a new list so the caller's honey is never mutated; in-place
        # insertion would re-insert the stage widths each time it is reused.
        honey = list(honey)
        honey = honey[:1] + stage_widths + honey[1:]
    return ResNet(Bottleneck, [3, 4, 6, 3], cfg, num_classes=num_classes, honey=honey)


