import torch.nn as nn
from collections import OrderedDict
import math
from models.layers import SeqToANNContainer, LIFSpike
from modules import *

class HardBinaryConv2d(nn.Module):
    """Bias-free 2D convolution whose kernel is passed through ``Binarize`` on every call.

    Real-valued latent weights are stored in ``self.weight``. In ``forward`` they are
    mapped through ``self.binarize`` and each output channel is rescaled by the mean
    absolute value of its latent weights before ``F.conv2d`` is applied.

    Args:
        in_chn: number of input channels.
        out_chn: number of output channels.
        bit_num: forwarded unchanged to ``Binarize``; its meaning is defined there
            (presumably a quantization bit width -- confirm in ``modules``).
        kernel_size: side length of the square kernel.
        stride: convolution stride.
        padding: zero padding applied on each spatial side.
    """

    def __init__(self, in_chn, out_chn, bit_num=None, kernel_size=3, stride=1, padding=1):
        super().__init__()
        self.stride = stride
        self.padding = padding
        self.number_of_weights = in_chn * out_chn * kernel_size * kernel_size
        self.shape = (out_chn, in_chn, kernel_size, kernel_size)
        # Latent weights start uniformly distributed in [0, 0.001).
        self.weight = nn.Parameter(torch.rand(self.shape) * 0.001, requires_grad=True)
        self.binarize = Binarize(bit_num=bit_num)

    def forward(self, x):
        """Convolve ``x`` with the rescaled binarized kernel (same layout as ``F.conv2d``)."""
        real_weights = self.weight
        # One scaling factor per output channel: mean |w| over (in_chn, kH, kW),
        # kept as shape (out_chn, 1, 1, 1) so it broadcasts against the kernel.
        scaling_factor = real_weights.abs().mean(dim=(1, 2, 3), keepdim=True)
        # Detached: no gradient flows through the scaling factor itself.
        scaling_factor = scaling_factor.detach()
        # Binarize the kernels first, then multiply by the per-channel scaling factor.
        binary_weights = scaling_factor * self.binarize(real_weights)
        return F.conv2d(x, binary_weights, stride=self.stride, padding=self.padding)

class Layer(nn.Module):
    """Spiking block: HardBinaryConv2d -> BatchNorm2d (inside SeqToANNContainer) -> LIFSpike.

    ``SeqToANNContainer`` and ``LIFSpike`` come from ``models.layers``. The container
    presumably applies the wrapped 2D modules to a sequence-shaped input -- confirm
    there.

    Args:
        in_plane: input channels.
        out_plane: output channels (also the BatchNorm2d width).
        kernel_size, stride, padding: forwarded to ``HardBinaryConv2d``.
        bit_num: forwarded to ``HardBinaryConv2d`` (and from there to ``Binarize``).
    """

    def __init__(self, in_plane, out_plane, kernel_size, stride, padding, bit_num=None):
        super().__init__()
        binary_conv = HardBinaryConv2d(
            in_plane,
            out_plane,
            bit_num=bit_num,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
        )
        self.fwd = SeqToANNContainer(binary_conv, nn.BatchNorm2d(out_plane))
        self.act = LIFSpike()

    def forward(self, x):
        """Run conv + batch-norm, then the spiking activation."""
        return self.act(self.fwd(x))

class Layer_conv(nn.Module):
    """Full-precision spiking block: nn.Conv2d -> BatchNorm2d (inside SeqToANNContainer) -> LIFSpike.

    Args:
        in_plane: input channels.
        out_plane: output channels (also the BatchNorm2d width).
        kernel_size, stride, padding: forwarded to ``nn.Conv2d``.
    """

    def __init__(self, in_plane, out_plane, kernel_size, stride, padding):
        super().__init__()
        # NOTE(review): the conv keeps nn.Conv2d's default bias even though BatchNorm2d
        # follows it; left as-is so existing checkpoints' parameter keys still match.
        full_conv = nn.Conv2d(
            in_plane,
            out_plane,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
        )
        self.fwd = SeqToANNContainer(full_conv, nn.BatchNorm2d(out_plane))
        self.act = LIFSpike()

    def forward(self, x):
        """Run conv + batch-norm, then the spiking activation."""
        return self.act(self.fwd(x))

class sbvggsnn(nn.Module):
    """VGG-style spiking network with binary convolutions after the first layer.

    Layer 1 is a full-precision ``Layer_conv``. The remaining seven conv layers are
    ``Layer`` blocks built on ``HardBinaryConv2d``. Every conv is 3x3 / stride 1 /
    padding 1, so only the four ``AvgPool2d(2)`` stages shrink the spatial size. A
    square ``input_size`` input therefore reaches the classifier as
    ``512 * (input_size // 16) ** 2`` features.

    Args:
        num_classes: number of output logits.
        bit_num: forwarded to every binary ``Layer``.
        in_channels: channels of each input frame. The default of 2 presumably
            matches two-polarity event data -- confirm against the data loader.
        input_size: spatial side length of the square input (default 48, the
            previously hard-coded value).

    Raises:
        ValueError: if ``input_size`` is too small to survive four 2x poolings.
    """

    def __init__(self, num_classes=10, bit_num=4, in_channels=2, input_size=48):
        super(sbvggsnn, self).__init__()
        # Each of the four AvgPool2d(2) stages floors the side length by 2.
        W = input_size // 16
        if W < 1:
            raise ValueError(f"input_size must be at least 16, got {input_size}")
        # AvgPool2d has no parameters, so one instance is shared by all four stages.
        pool = SeqToANNContainer(nn.AvgPool2d(2))
        self.bit_num = bit_num
        print('bit_num = ', self.bit_num)
        self.features = nn.Sequential(
            Layer_conv(in_channels, 64, 3, 1, 1),
            Layer(64, 128, 3, 1, 1, self.bit_num),
            pool,
            Layer(128, 256, 3, 1, 1, self.bit_num),
            Layer(256, 256, 3, 1, 1, self.bit_num),
            pool,
            Layer(256, 512, 3, 1, 1, self.bit_num),
            Layer(512, 512, 3, 1, 1, self.bit_num),
            pool,
            Layer(512, 512, 3, 1, 1, self.bit_num),
            Layer(512, 512, 3, 1, 1, self.bit_num),
            pool,
        )
        self.classifier = SeqToANNContainer(nn.Linear(512 * W * W, num_classes))

        # Only plain nn.Conv2d modules are re-initialized here. HardBinaryConv2d
        # subclasses nn.Module (not nn.Conv2d), so it keeps its own initialization.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')

    def forward(self, input):
        """Extract features, flatten everything after the first two dims, and classify."""
        x = self.features(input)
        # Keep the two leading dims (presumably time and batch -- confirm against
        # SeqToANNContainer) and flatten channels x height x width.
        x = torch.flatten(x, 2)
        return self.classifier(x)