import torch
import torch.nn as nn
import torch.nn.functional as F
from modules import OnlineNeuron

def conv3x3(in_planes, out_planes, stride=1):
    """Return a 3x3 ``BinaryConv`` (padding 1) mapping ``in_planes`` -> ``out_planes``."""
    return BinaryConv(in_chn=in_planes, out_chn=out_planes, kernel_size=3, stride=stride, padding=1)

def conv1x1(in_planes, out_planes, stride=1):
    """Return a 1x1 ``BinaryConv`` (no padding) mapping ``in_planes`` -> ``out_planes``."""
    return BinaryConv(in_chn=in_planes, out_chn=out_planes, kernel_size=1, stride=stride, padding=0)

    
class BinaryConv(nn.Module):
    """Bias-free 2-D convolution whose weight is binarized to ``alpha * sign(w)``.

    Args:
        in_chn: number of input channels.
        out_chn: number of output channels (filters).
        kernel_size: side length of the square kernel.
        stride: convolution stride passed to ``F.conv2d``.
        padding: zero padding passed to ``F.conv2d``.

    The latent weight has shape ``(out_chn, in_chn, kernel_size, kernel_size)``
    and starts as ``torch.rand(...) * 0.001``. That is uniform in ``[0, 0.001)``,
    so every initial weight is non-negative.

    Setting ``test_time`` to anything other than ``False`` makes ``forward``
    convolve with the stored weight unchanged, with no binarization.
    NOTE(review): presumably the weight has already been replaced by its
    binarized value before that flag is set; confirm with the deployment code.
    """

    def __init__(self, in_chn, out_chn, kernel_size=3, stride=1, padding=1):
        super().__init__()
        self.stride = stride
        self.padding = padding
        self.shape = (out_chn, in_chn, kernel_size, kernel_size)
        self.weight = nn.Parameter(torch.rand(self.shape) * 0.001)
        self.test_time = False

    def forward(self, x):
        # Inference path: plain convolution with the stored weight.
        if self.test_time is not False:
            return F.conv2d(x, self.weight, stride=self.stride, padding=self.padding)

        w = self.weight
        # Per-output-channel scale: mean |w| over the in-channel and both kernel
        # axes, reduced one axis at a time and kept as (out, 1, 1, 1) so it
        # broadcasts over the weight. It is treated as a constant for autograd.
        alpha = abs(w)
        for axis in (3, 2, 1):
            alpha = alpha.mean(dim=axis, keepdim=True)
        alpha = alpha.detach()

        w_bin = alpha * torch.sign(w)
        w_clip = torch.clamp(w, -1.0, 1.0)
        # Straight-through estimator: the forward value is w_bin, while the
        # gradient w.r.t. w is that of clamp(w, -1, 1).
        w_ste = w_bin.detach() - w_clip.detach() + w_clip
        return F.conv2d(x, w_ste, stride=self.stride, padding=self.padding)

        
class BasicBlock(nn.Module):
    """Two-branch residual block with binary convolutions and OnlineNeuron activations.

    Each branch computes ``OnlineNeuron -> 3x3 BinaryConv -> BatchNorm2d -> attention``.
    The outer shortcut is ``x``, or ``downsample(x)`` when a downsample module
    is supplied. Note that the shortcut is applied to the block input, not to
    the neuron output.

    ``use_eca`` selects the optional extras:
      * ``<= 0``: no attention (``nn.Identity``) and no inner skips.
      * ``> 0``: an ``ECAAttention`` after each conv/BN pair.
      * ``>= 2``: additionally adds an inner skip around each branch before its
        attention. The first inner skip is only used when ``inplanes == planes``,
        because conv1 changes the channel count from ``inplanes`` to ``planes``.

    ``T``, ``mem_bn`` and ``parallel_mode`` are forwarded to ``OnlineNeuron``.
    Their semantics are defined in ``modules.OnlineNeuron``, not here.
    """

    def __init__(self, T, inplanes, planes, downsample=None, use_eca=0, mem_bn=False, parallel_mode=False):
        super().__init__()
        with_attention = use_eca > 0

        # Branch 1: inplanes -> planes.
        self.conv1 = nn.Sequential(conv3x3(inplanes, planes), nn.BatchNorm2d(planes))
        self.attn1 = ECAAttention() if with_attention else nn.Identity()
        # Branch 2: planes -> planes.
        self.conv2 = nn.Sequential(conv3x3(planes, planes), nn.BatchNorm2d(planes))
        self.attn2 = ECAAttention() if with_attention else nn.Identity()

        self.downsample = downsample
        # NOTE(review): only act2 receives parallel_mode; act1 is built without
        # it. Confirm this asymmetry is intended.
        self.act1 = OnlineNeuron(T, mem_bn, inplanes)
        self.act2 = OnlineNeuron(T, mem_bn, planes, parallel_mode)

        self.inplanes = inplanes
        self.planes = planes
        self.use_eca = use_eca

    def forward(self, x):
        shortcut = x if self.downsample is None else self.downsample(x)
        inner_skip = self.use_eca >= 2

        branch = self.conv1(self.act1(x))
        if inner_skip and self.inplanes == self.planes:
            branch = x + branch
        hidden = self.attn1(branch)

        branch = self.conv2(self.act2(hidden))
        if inner_skip:
            branch = hidden + branch
        return self.attn2(branch) + shortcut
        
        
class BinaryActivation(nn.Module):
    """Sign activation trained with a piecewise-quadratic surrogate gradient.

    The forward value is ``torch.sign(x)`` in both training and eval mode. In
    training mode the backward pass uses the gradient of

        F(x) = -1            for x < -1
               x**2 + 2x     for -1 <= x < 0
               -x**2 + 2x    for 0 <= x < 1
               1             for x >= 1

    so ``dF/dx = 2 - 2|x|`` inside ``[-1, 1]`` and ``0`` outside. The output
    keeps the dtype of ``x``.
    """

    def __init__(self):
        super(BinaryActivation, self).__init__()

    def forward(self, x):
        out_forward = torch.sign(x)
        if not self.training:
            return out_forward
        # F(x) as one expression on the clipped input. With y = clamp(x, -1, 1),
        # 2y - y|y| equals y**2 + 2y for y < 0 and -y**2 + 2y for y >= 0, and it
        # saturates at +-1 outside [-1, 1].
        # Clipping first keeps y*|y| bounded, so large-magnitude inputs cannot
        # overflow to inf (inf times a zero mask would yield NaN). No float32
        # mask is mixed in, so half-precision inputs are not promoted.
        y = torch.clamp(x, -1.0, 1.0)
        approx = 2 * y - y * torch.abs(y)
        # Straight-through estimator: value of sign(x), gradient of approx.
        return out_forward.detach() - approx.detach() + approx
        
        
class ECAAttention(nn.Module):
    """Channel gate driven by a 1-D convolution over binarized pooled channel statistics.

    ``x`` is global-average-pooled to one value per channel. The per-channel
    values are binarized with ``BinaryActivation`` and then convolved along the
    channel axis by a single-channel ``nn.Conv1d``. A sigmoid turns the result
    into weights that rescale ``x`` channel-wise.

    ``x`` must be 4-D, because the pooled tensor is squeezed once and then
    transposed as a 3-D tensor. Presumably its layout is ``(batch, channels, H, W)``.
    ``kernel_size`` should be odd: padding ``(kernel_size - 1) // 2`` keeps the
    channel count unchanged only for odd kernels.
    """

    def __init__(self, kernel_size=5):
        super().__init__()
        half_width = (kernel_size - 1) // 2
        self.gap = nn.AdaptiveAvgPool2d(1)
        self.conv = nn.Conv1d(1, 1, kernel_size=kernel_size, padding=half_width)
        self.act = BinaryActivation()
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # (B, C, 1, 1) -> (B, C, 1) -> (B, 1, C): channels become the 1-D sequence axis.
        gate = self.gap(x).squeeze(-1).transpose(1, 2)
        gate = self.sigmoid(self.conv(self.act(gate)))
        # (B, 1, C) -> (B, C, 1) -> (B, C, 1, 1), broadcast back over x.
        gate = gate.transpose(1, 2).unsqueeze(-1)
        return x * gate.expand_as(x)


class ResNet(nn.Module):
    """Binary-weight spiking ResNet backbone.

    Args:
        T: forwarded to every ``block`` and to the ``OnlineNeuron`` in the
            ResNet-19 head. Presumably the number of time steps; confirm
            against ``modules.OnlineNeuron``.
        block: residual block class, instantiated as
            ``block(T, inplanes, planes, downsample, use_eca, mem_bn, parallel_mode)``.
        layers: number of blocks per stage. Four entries are needed, or three
            when ``use_resnet19`` is true.
        num_classes: classifier output size. ``1000`` selects the 7x7, stride-2
            RGB stem. Values ``>= 200`` enable the stem max-pool for non-DVS input.
        use_dvs: selects the 2-channel 3x3 stem (unless ``num_classes == 1000``)
            and skips the stem max-pool in the forward pass.
        use_resnet19: builds the three-stage ResNet-19 body with a
            ``Linear -> OnlineNeuron -> Linear`` head. Otherwise the standard
            four-stage body with a single ``Linear`` head is built.
        use_eca, mem_bn, parallel_mode: forwarded unchanged to each block.

    Boolean flags are tested by truthiness, so ``__init__`` and the forward pass
    always agree. For example, ``use_dvs=1`` both builds the DVS stem and skips
    the max-pool.
    """

    def __init__(self, T, block, layers, num_classes=1000, use_dvs=False, use_resnet19=False, use_eca=0, mem_bn=False, parallel_mode=False):
        super(ResNet, self).__init__()
        self.T = T
        self.use_dvs = use_dvs
        self.inplanes = 64
        self.num_classes = num_classes
        self.use_eca = use_eca
        self.mem_bn = mem_bn
        self.parallel_mode = parallel_mode

        # Stem: full-precision conv + BN.
        if num_classes == 1000:
            self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)
        else:
            in_channels = 2 if self.use_dvs else 3
            self.conv1 = nn.Conv2d(in_channels, self.inplanes, kernel_size=3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(self.inplanes)

        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

        if use_resnet19:
            # ResNet-19 has no 64-channel stage: layer1 is an empty Sequential
            # (identity) and layers[0..2] feed the 128/256/512 stages.
            self.layer1 = nn.Sequential()
            self.layer2 = self._make_layer(block, 128, layers[0], stride=2 if self.use_dvs else 1)
            self.layer3 = self._make_layer(block, 256, layers[1], stride=2)
            self.layer4 = self._make_layer(block, 512, layers[2], stride=2)
            self.fc = nn.Sequential(
                nn.Linear(512, 256),
                OnlineNeuron(T),
                nn.Linear(256, num_classes)
            )
        else:
            self.layer1 = self._make_layer(block, 64, layers[0])
            self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
            self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
            self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
            self.fc = nn.Linear(512, num_classes)

        # Only nn.Conv2d instances are re-initialised. With the blocks in this
        # file that is just the stem: BinaryConv subclasses nn.Module, not
        # nn.Conv2d, and ECA uses nn.Conv1d, so both keep their own init.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')

    def _make_layer(self, block, planes, blocks, stride=1):
        """Build one stage of ``blocks`` residual blocks ending at ``planes`` channels.

        The first block gets a 1x1 binary-conv + BN shortcut projection when the
        channel count changes or ``stride != 1``. Updates ``self.inplanes`` to
        ``planes`` for the next stage.

        NOTE(review): ``stride`` is not forwarded to ``block`` or to the 1x1
        shortcut conv. It only decides whether the projection is built, so no
        stage reduces spatial resolution. Verify whether this is intended.
        """
        downsample = None
        if stride != 1 or self.inplanes != planes:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes),
                nn.BatchNorm2d(planes)
            )

        layers = [block(self.T, self.inplanes, planes, downsample, self.use_eca, self.mem_bn, self.parallel_mode)]
        self.inplanes = planes
        layers += [
            block(self.T, self.inplanes, planes, None, self.use_eca, self.mem_bn, self.parallel_mode)
            for _ in range(1, blocks)
        ]
        return nn.Sequential(*layers)

    def snn_forward_impl(self, x):
        """Run stem, the four stages, global average pooling and the classifier head."""
        x = self.bn1(self.conv1(x))
        # Max-pool after the stem only for non-DVS input with >= 200 classes.
        if not self.use_dvs and self.num_classes >= 200:
            x = self.maxpool(x)
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)
        x = torch.flatten(self.avgpool(x), 1)
        return self.fc(x)

    def forward(self, x):
        return self.snn_forward_impl(x)


def _resnet(T, block, layers, **kwargs):
    """Construct a ``ResNet`` with the given block type and per-stage depths.

    Extra keyword arguments are passed straight through to ``ResNet``.
    """
    return ResNet(T, block, layers, **kwargs)

def resnet18(T, **kwargs):
    """Build a ResNet-18 layout: four stages of two ``BasicBlock``s each."""
    blocks_per_stage = [2] * 4
    return _resnet(T, BasicBlock, blocks_per_stage, **kwargs)

def resnet19(T, **kwargs):
    """Build the three-stage ResNet-19 layout: ``[3, 3, 2]`` ``BasicBlock``s.

    ``use_resnet19`` defaults to ``True`` here. The ``[3, 3, 2]`` layout has only
    three stage depths, and ``ResNet`` reads a fourth entry (``layers[3]``)
    whenever ``use_resnet19`` is false. Without this default, a plain
    ``resnet19(T)`` call raises ``IndexError``. An explicit ``use_resnet19`` in
    ``kwargs`` is still honoured.
    """
    kwargs.setdefault('use_resnet19', True)
    return _resnet(T, BasicBlock, [3, 3, 2], **kwargs)

def resnet34(T, **kwargs):
    """Build a ResNet-34 layout: ``[3, 4, 6, 3]`` ``BasicBlock``s over four stages."""
    stage_depths = [3, 4, 6, 3]
    return _resnet(T, BasicBlock, stage_depths, **kwargs)