import torch
import torch.nn as nn
import torch.nn.functional as F

class BasicBlock(nn.Module):
    """Residual block of two 3x3 conv/BN layers with ``MyTestPlace`` activation slots.

    ``x1_test`` follows the first BN layer and ``x2_test`` follows the residual
    addition. When the stride is not 1 or the channel count changes, the skip
    path is a 1x1 conv + BN projection; otherwise it is an identity
    ``nn.Sequential()``.

    Args:
        in_planes: number of input channels.
        planes: number of output channels (``expansion`` is 1).
        stride: stride of the first 3x3 convolution and of the projection.
    """

    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        out_planes = self.expansion * planes

        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        if stride == 1 and in_planes == out_planes:
            # Shapes already match: identity skip connection.
            self.shortcut = nn.Sequential()
        else:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride,
                          bias=False),
                nn.BatchNorm2d(out_planes),
            )

        self.x1_test = MyTestPlace()
        self.x2_test = MyTestPlace()

    def forward(self, x):
        hidden = self.x1_test(self.bn1(self.conv1(x)))
        main_path = self.bn2(self.conv2(hidden))
        return self.x2_test(main_path + self.shortcut(x))


class Bottleneck(nn.Module):
    """1x1 -> 3x3 -> 1x1 bottleneck residual block with ``MyTestPlace`` slots.

    The block widens its output to ``expansion * planes`` channels (4x).
    ``x1_test`` and ``x2_test`` follow the first two BN layers; ``x3_test``
    follows the residual addition. A 1x1 conv + BN projection is used on the
    skip path whenever the stride or the channel count changes.

    Args:
        in_planes: number of input channels.
        planes: bottleneck width; output channels are ``4 * planes``.
        stride: stride of the 3x3 convolution and of the projection.
    """

    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        width_out = self.expansion * planes

        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, width_out, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(width_out)

        if stride == 1 and in_planes == width_out:
            # Shapes already match: identity skip connection.
            self.shortcut = nn.Sequential()
        else:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, width_out, kernel_size=1, stride=stride,
                          bias=False),
                nn.BatchNorm2d(width_out),
            )

        self.x1_test = MyTestPlace()
        self.x2_test = MyTestPlace()
        self.x3_test = MyTestPlace()

    def forward(self, x):
        h = self.x1_test(self.bn1(self.conv1(x)))
        h = self.x2_test(self.bn2(self.conv2(h)))
        h = self.bn3(self.conv3(h))
        return self.x3_test(h + self.shortcut(x))


class ResNet(nn.Module):
    """ResNet whose activations are ``MyTestPlace`` placeholders.

    The stem is a 3x3 stride-1 conv + BN followed by ``x1_test``. Four stages
    of residual blocks follow (64/128/256/512 base widths; stages 2-4 start
    with stride 2), then a kernel-4 average pool and a linear classifier.

    Args:
        block: residual block class (e.g. ``BasicBlock`` or ``Bottleneck``).
            It must expose an ``expansion`` attribute and accept
            ``(in_planes, planes, stride)``.
        num_blocks: sequence of four block counts, one per stage.
        num_classes: output size of the final linear layer.
    """

    def __init__(self, block, num_blocks, num_classes=10):
        super(ResNet, self).__init__()
        self.in_planes = 64

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.linear = nn.Linear(512 * block.expansion, num_classes)

        self.x1_test = MyTestPlace()

    def _make_layer(self, block, planes, num_blocks, stride):
        """Stack ``num_blocks`` blocks; only the first one uses ``stride``.

        Side effect: ``self.in_planes`` is advanced to this stage's output
        width so the next stage is built with the right input channels.
        """
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for block_stride in strides:
            layers.append(block(self.in_planes, planes, block_stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward_features(self, x):
        """Run one full pass (stem, four stages, pooling, linear) and return logits.

        NOTE(review): the kernel-4 average pool yields 1x1 maps only when
        ``layer4`` outputs 4x4-7x7 maps (e.g. 32x32 inputs give 4x4). Other
        input sizes would not match ``self.linear`` — confirm expected input size.
        """
        out = self.x1_test(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = F.avg_pool2d(out, 4)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out

    def forward(self, x, T=1):
        """Run ``T`` passes over the same input and return their running sums.

        The network is run on a fresh clone of ``x`` at each step.
        Element ``i`` of the returned list is the sum of the outputs of steps
        ``0..i``. With stateful neurons (e.g. ``SFNeuron``) inserted, each pass
        acts as one time step.

        Returns:
            list of ``T`` tensors (empty when ``T`` is 0).
        """
        running = None
        outputs = []
        for _ in range(T):
            step = self.forward_features(torch.clone(x))
            # ``is None`` rather than ``== None``: ``==`` on a tensor is an
            # element-wise comparison, not an identity check.
            running = step if running is None else step + running
            outputs.append(running)
        return outputs
        

class SFNeuron(nn.Module):
    """Stateful multi-threshold neuron that stands in for a calibrated ``TestNeuron``.

    Each call integrates its input into a membrane potential ``self.neuron``.
    It then fires ``scale_p * floor(neuron / scale_p)`` clipped to
    ``[0, times]`` steps, and subtracts the fired amount from the potential
    (reset by subtraction). ``self.t`` counts calls since the last ``reset()``.

    Args:
        scale_p: threshold step for firing.
        scale_n: stored on the module but not used by ``forward``; only
            non-negative firing is implemented.
        times: maximum number of ``scale_p`` steps fired in a single call.
    """

    def __init__(self, scale_p=None, scale_n=None, times=None):
        super(SFNeuron, self).__init__()
        self.scale_p = scale_p
        self.scale_n = scale_n
        self.times = times
        self.t = 0           # number of forward calls since the last reset
        self.neuron = None   # membrane potential; allocated on the first call

    def forward_linear(self, x):
        """Return ``x`` floored to a multiple of ``scale_p``, clipped to ``[0, times]`` steps.

        NOTE(review): ``scale_p == 0`` makes the division produce inf/nan; a
        positive threshold is assumed.
        """
        steps = torch.floor(x / self.scale_p).clamp(min=0, max=self.times)
        return self.scale_p * steps

    def forward(self, x):
        """Integrate ``x``, fire, subtract what was fired, and return the output.

        The caller (the model's forward) sums the per-step outputs over time
        steps to obtain the effective activation.
        """
        if self.t == 0:
            # First step after construction/reset: start from a zero potential shaped like x.
            self.neuron = torch.zeros_like(x)
        self.neuron += x

        fire = self.forward_linear(self.neuron)

        self.neuron -= fire
        self.t += 1
        return fire

    def reset(self):
        """Clear the step counter and membrane potential before a new input sequence."""
        self.t = 0
        self.neuron = None

def replace_testneuron_by_sfneuron(model, args):
    """Recursively replace every calibrated ``TestNeuron`` in ``model`` with an ``SFNeuron``.

    Modules are matched by lower-cased class name. Each new ``SFNeuron`` takes
    the calibrated thresholds ``scale_p[0]`` / ``scale_n[0]``, scaled by
    ``args.lambda``, and fires at most ``args.linear_num`` steps per call.

    Args:
        model: module tree, modified in place.
        args: namespace with ``linear_num`` and an attribute literally named
            ``lambda``. Both are read only when a ``TestNeuron`` is found.

    Returns:
        ``model``.
    """
    # Snapshot the children so the loop never iterates a mapping it is writing to.
    for name, module in list(model._modules.items()):
        if hasattr(module, "_modules"):
            model._modules[name] = replace_testneuron_by_sfneuron(module, args)
        if module.__class__.__name__.lower() == 'testneuron':
            p, n = float(module.scale_p[0]), float(module.scale_n[0])
            times = args.linear_num
            # ``lambda`` is a Python keyword, so ``args.lambda`` is a SyntaxError;
            # the attribute (e.g. from argparse ``--lambda``) must be read via getattr.
            scale = getattr(args, 'lambda')
            model._modules[name] = SFNeuron(
                scale_p=p * scale,
                scale_n=n * scale,
                times=times,
            )
    return model
    
class MyTestPlace(nn.Module):
    """Placeholder activation marking a spot where a neuron can be swapped in.

    Until replaced it behaves exactly like ``F.relu``. The replacement helpers
    locate it by its lower-cased class name (``'mytestplace'``), so the class
    name is part of its contract.
    """

    def __init__(self):
        super(MyTestPlace, self).__init__()

    def forward(self, x):
        return F.relu(x)

class TestNeuron(nn.Module):
    """ReLU that calibrates activation thresholds as data flows through it.

    On every forward pass the input is flattened and ``k`` is set to
    ``int((1 - percent) * numel)``. The ``k``-th largest value is folded into
    the running mean ``scale_p``, and the negated ``k``-th smallest value into
    the running mean ``scale_n``. When ``k`` is 0, the maximum and minimum are
    used instead. The output is always ``F.relu(x)``.

    Args:
        percent: fraction used to pick ``k``; must be set before ``forward`` is called.
    """

    def __init__(self, percent=None):
        super(TestNeuron, self).__init__()
        self.percent = percent
        self.num = 0  # number of forward passes folded into the running means
        self.scale_p = torch.nn.Parameter(torch.FloatTensor([0.]))
        self.scale_n = torch.nn.Parameter(torch.FloatTensor([0.]))

    def _running_mean(self, current, sample):
        """Return a new Parameter averaging ``self.num`` past samples with ``sample``."""
        # A fresh Parameter keeps the attribute registered as a module parameter.
        return torch.nn.Parameter((current * self.num + sample) / (self.num + 1))

    def forward(self, x):
        flat = x.reshape(-1)
        k = int((1 - self.percent) * flat.numel())

        if k == 0:
            # Too few elements for the requested fraction: use the extremes.
            upper = torch.max(flat).item()
            lower = torch.min(flat).item()
        else:
            upper = torch.topk(flat, k, largest=True).values[-1].item()
            lower = torch.topk(flat, k, largest=False).values[-1].item()

        # Both means use the same (pre-increment) sample count.
        self.scale_p = self._running_mean(self.scale_p, upper)
        self.scale_n = self._running_mean(self.scale_n, -lower)
        self.num += 1
        return F.relu(x)

    def reset(self):
        """No-op hook for ``reset_net``; calibration statistics are not cleared."""
        pass

def replace_test_by_testneuron(model, percent=None):
    """Recursively swap every ``MyTestPlace`` in ``model`` for ``TestNeuron(percent)``.

    Children are matched by lower-cased class name. The module tree is
    modified in place and ``model`` is returned.
    """
    children = model._modules
    for key in list(children):
        child = children[key]
        if hasattr(child, "_modules"):
            children[key] = replace_test_by_testneuron(child, percent)
        if child.__class__.__name__.lower() == 'mytestplace':
            children[key] = TestNeuron(percent=percent)
    return model

def reset_net(model):
    """Call ``reset()`` on every descendant whose class name contains "neuron".

    Children are visited depth-first, and grandchildren are reset before
    their parent. The root module itself is not checked. Returns ``model``.
    """
    for child in model._modules.values():
        if hasattr(child, "_modules"):
            reset_net(child)
        if 'neuron' in child.__class__.__name__.lower():
            child.reset()
    return model

def ResNet18(num_classes=10):
    """Build a ResNet-18 (``BasicBlock`` with ``[2, 2, 2, 2]`` blocks per stage).

    Args:
        num_classes: output size of the final linear layer; defaults to 10,
            the same default as ``ResNet``.
    """
    return ResNet(BasicBlock, [2, 2, 2, 2], num_classes=num_classes)