import torch
import torch.nn as nn
from copy import deepcopy
from spikingjelly.activation_based import layer, functional
from spikingjelly.clock_driven.neuron import MultiStepParametricLIFNode, MultiStepLIFNode
import torch.nn.functional as F

try:
    from torchvision.models.utils import load_state_dict_from_url
except ImportError:
    from torchvision._internally_replaced_utils import load_state_dict_from_url

__all__ = ['spiking_resnet18', 'spiking_resnet34']  # public model constructors of this module



decay = 0.25  # membrane leak factor: mem_update.forward scales the post-reset potential by this every time step
class MultiSpikeD(nn.Module):
    """Quantize a membrane potential to integer spike counts in ``[0, D]``.

    ``forward(x, D)`` returns ``round(clamp(x, 0, D))``. The backward pass is a
    straight-through estimator: the incoming gradient is passed unchanged where
    ``0 <= x <= D`` and set to zero where ``x < 0`` or ``x > D``. No gradient is
    returned for ``D``.
    """

    class DynamicQuant(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x, D):
            # Keep the un-quantized input to build the pass-through mask later.
            ctx.save_for_backward(x)
            ctx.D = D
            clipped = x.clamp(min=0.0, max=D)
            return clipped.round()

        @staticmethod
        def backward(ctx, grad_output):
            (x,) = ctx.saved_tensors
            # Block the gradient only where the forward clamp was active.
            outside = (x < 0) | (x > ctx.D)
            return grad_output.masked_fill(outside, 0), None

    def forward(self, x, D):
        quantize = self.DynamicQuant.apply
        return quantize(x, D)


class mem_update(nn.Module):
    """Multi-level spiking neuron with soft reset and a self-adjusting bound ``D``.

    The input carries the time dimension first. ``forward`` iterates over
    ``x.shape[0]`` steps and, at step ``t``, computes::

        mem[0]   = x[0]
        mem[t]   = (mem[t-1] - spike[t-1].detach()) * decay + x[t]
        spike[t] = round(clamp(mem[t], 0, D))

    Here ``decay`` is the module-level constant, and the quantizer is
    ``MultiSpikeD``, which passes the gradient straight through inside ``[0, D]``.

    While ``self.training`` is true, ``D`` is raised after every step by
    ``alpha * (fraction of mem[t] entries greater than D)``. This module never
    decreases it.

    Args:
        act: stored as ``self.act``; not read by this class.
        D0: initial value of the buffer ``D``.
        alpha: step size of the training-time update of ``D``.
    """

    def __init__(self, act=False, D0=1.0, alpha=0.01):
        super(mem_update, self).__init__()

        self.alpha = alpha
        # Upper quantization bound. It is a buffer, not an nn.Parameter, so it is
        # stored in state_dict and moved with the module but not optimized.
        self.register_buffer('D', torch.tensor(D0))
        self.act = act  # kept for interface compatibility; forward() ignores it
        self.qtrick = MultiSpikeD()

    def forward(self, x):
        """Run the neuron over the leading (time) axis of ``x``.

        Args:
            x: input with time first; ``x[t]`` is the input at step ``t``.

        Returns:
            The per-step spikes stacked along dim 0, so the result has
            ``x``'s shape. An empty time axis yields ``torch.zeros_like(x)``
            (previously ``x[0]`` raised ``IndexError``).
        """
        time_window = x.shape[0]
        if time_window == 0:
            return torch.zeros_like(x)

        spikes = []
        mem = x[0]
        for t in range(time_window):
            if t > 0:
                # Soft reset: subtract the emitted spike count without
                # back-propagating through the reset, then leak and integrate.
                mem = (mem - spike.detach()) * decay + x[t]
            spike = self.qtrick(mem, self.D)

            if self.training:
                # Raise D by alpha times the fraction of potentials above it;
                # later steps of this same call already see the new value.
                exceed_count = torch.sum(mem > self.D).float()
                self.D = self.D + self.alpha * exceed_count / mem.numel()

            spikes.append(spike)

        return torch.stack(spikes, dim=0)


class BasicBlock(nn.Module):
    """Spiking residual block: two (spike -> 3x3 conv -> norm) stages plus a shortcut.

    Each stage first passes its input through a ``mem_update`` neuron and then
    applies a convolution and a normalization layer. Both are wrapped in
    ``layer.SeqToANNContainer`` so they are applied to every time step.
    Presumably the input is ``(T, B, C, H, W)``, as produced by
    ``SpikingResNet._forward_impl``; TODO confirm for other callers.

    The shortcut carries the block input ``x`` itself (pre-spike values),
    optionally transformed by ``downsample``, and is added to the second
    stage's output.

    Args:
        inplanes: input channels.
        planes: output channels of both convolutions.
        stride: stride of the first convolution.
        downsample: optional module applied to ``x`` on the shortcut path.
        groups, base_width: only ``1`` and ``64`` are supported.
        dilation: only ``1`` is supported.
        norm_layer: normalization constructor used after both convolutions
            (defaults to ``nn.BatchNorm2d``).

    Raises:
        ValueError: if ``groups != 1`` or ``base_width != 64``.
        NotImplementedError: if ``dilation > 1``.
    """

    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(BasicBlock, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if groups != 1 or base_width != 64:
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")

        self.sn1 = mem_update()
        self.conv1 = layer.SeqToANNContainer(
            nn.Conv2d(inplanes, planes, kernel_size=3, padding=1, stride=stride, bias=False),
            # Previously hard-coded nn.BatchNorm2d, silently ignoring `norm_layer`.
            norm_layer(planes),
        )
        self.sn2 = mem_update()

        self.conv2 = layer.SeqToANNContainer(
            nn.Conv2d(planes, planes, kernel_size=3, padding=1, stride=1, bias=False),
            norm_layer(planes),
        )

        self.downsample = downsample

    def forward(self, x):
        """Apply both spiking stages to ``x`` and add the (possibly downsampled) shortcut."""
        out = self.sn1(x)
        out = self.conv1(out)
        out = self.sn2(out)
        out = self.conv2(out)
        identity = x if self.downsample is None else self.downsample(x)
        out += identity
        return out


def zero_init_blocks(net: nn.Module):
    """Zero the weight of the second norm layer in every ``BasicBlock`` inside ``net``.

    ``conv2`` of a ``BasicBlock`` holds ``(Conv2d, norm)``, so ``module[1]`` is the
    norm layer. With its weight at zero, that branch's output reduces to the norm
    bias, and the block's output starts out dominated by the shortcut.
    """
    residual_blocks = (mod for mod in net.modules() if isinstance(mod, BasicBlock))
    for blk in residual_blocks:
        last_norm = blk.conv2.module[1]
        nn.init.zeros_(last_norm.weight)


class SpikingResNet(nn.Module):
    """ResNet-style classifier whose residual stages are built from spiking blocks.

    The non-spiking stem (7x7 conv + norm) runs once on the static image. Its
    output is repeated ``T`` times along a new leading time axis and fed through
    four stages of ``block``. A final ``mem_update`` neuron, global average
    pooling and a linear layer then produce per-time-step outputs.

    Unlike torchvision's ResNet there is no max-pool after the stem; ``layer1``
    downsamples with stride 2 instead.

    Args:
        block: residual block class (e.g. ``BasicBlock``); must define ``expansion``.
        layers: number of blocks in each of the four stages.
        num_classes: output size of the final linear layer.
        zero_init_residual: if true, zero the last norm weight in each
            ``BasicBlock`` (see ``zero_init_blocks``).
        groups, width_per_group: forwarded to every block.
        replace_stride_with_dilation: three flags for layer2..layer4. Note that
            ``BasicBlock`` rejects dilation > 1.
        norm_layer: normalization constructor for the stem and all blocks
            (defaults to ``nn.BatchNorm2d``).
        T: number of simulation time steps.

    Raises:
        ValueError: if ``replace_stride_with_dilation`` does not have 3 elements.
    """

    def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
                 groups=1, width_per_group=64, replace_stride_with_dilation=None,
                 norm_layer=None, T=1):
        super(SpikingResNet, self).__init__()
        self.T = T
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer

        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # One flag per stage (layer2..layer4): replace that stage's stride-2
            # downsampling with a dilated convolution.
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group

        # Stem, applied once to the static image before the time axis is added.
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
                               bias=False)
        # Previously hard-coded nn.BatchNorm2d, ignoring a custom `norm_layer`.
        self.bn1 = norm_layer(self.inplanes)

        self.layer1 = self._make_layer(block, 64, layers[0], stride=2)
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
                                       dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
                                       dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
                                       dilate=replace_stride_with_dilation[2])
        self.sn = mem_update()
        self.flatten = layer.SeqToANNContainer(nn.Flatten())
        self.avgpool = layer.SeqToANNContainer(nn.AdaptiveAvgPool2d((1, 1)))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
        if zero_init_residual:
            zero_init_blocks(self)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        """Build one stage of ``blocks`` residual blocks producing ``planes * expansion`` channels.

        Only the first block gets ``stride``. It also gets a downsample shortcut
        (3x3 conv with bias + norm) when the stride or channel count changes.
        Updates ``self.inplanes`` and, when ``dilate`` is set, ``self.dilation``.
        """
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = layer.SeqToANNContainer(
                nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=3, padding=1, stride=stride),
                norm_layer(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
                            self.base_width, previous_dilation, norm_layer))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, groups=self.groups,
                                base_width=self.base_width, dilation=self.dilation,
                                norm_layer=norm_layer))

        return nn.Sequential(*layers)

    def _forward_impl(self, x):
        """Run the network on an image batch ``x`` of shape ``(B, 3, H, W)``.

        Returns the ``fc`` output for every time step. Given SeqToANNContainer's
        ``(T, B, ...)`` convention this is presumably ``(T, B, num_classes)``;
        callers aggregate over time themselves.
        """
        x = self.conv1(x)
        x = self.bn1(x)
        # Add a leading time axis and repeat the static stem output T times.
        # Out-of-place: the original mutated the bn1 output with unsqueeze_.
        out = x.unsqueeze(0).repeat(self.T, 1, 1, 1, 1)
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = self.sn(out)

        out = self.avgpool(out)
        out = self.flatten(out)
        out = self.fc(out)
        return out

    def forward(self, x):
        return self._forward_impl(x)


def _spiking_resnet(arch, block, layers, pretrained, progress, **kwargs):
    """Build a ``SpikingResNet`` and optionally load weights from ``model_urls[arch]``.

    Args:
        arch: key into the ``model_urls`` table; only used when ``pretrained`` is true.
        block, layers: forwarded to ``SpikingResNet``.
        pretrained: download and load a state dict for ``arch``.
        progress: forwarded to ``load_state_dict_from_url``.
        **kwargs: forwarded to ``SpikingResNet``.

    Raises:
        NotImplementedError: if ``pretrained`` is requested but no ``model_urls``
            table is defined in this module.
        KeyError: if ``model_urls`` exists but has no entry for ``arch``.
    """
    url = None
    if pretrained:
        # Resolve the URL before building the network so a missing table fails fast.
        try:
            url = model_urls[arch]
        except NameError:
            # As written, this module never defines `model_urls`. This used to
            # surface as a bare NameError after the whole model had been built.
            raise NotImplementedError(
                "pretrained weights are not available for {!r}: "
                "no model_urls table is defined".format(arch)
            ) from None

    model = SpikingResNet(block, layers, **kwargs)
    if pretrained:
        state_dict = load_state_dict_from_url(url, progress=progress)
        model.load_state_dict(state_dict)
    return model


def spiking_resnet18(pretrained=False, progress=True, **kwargs):
    """Construct the 18-layer spiking ResNet (``BasicBlock``, stage depths 2-2-2-2).

    Args:
        pretrained: forwarded to ``_spiking_resnet`` to request pretrained weights.
        progress: forwarded to ``_spiking_resnet`` as the download progress flag.
        **kwargs: passed through to ``SpikingResNet`` (e.g. ``num_classes``, ``T``).
    """
    stage_depths = [2, 2, 2, 2]
    return _spiking_resnet('resnet18', BasicBlock, stage_depths,
                           pretrained=pretrained, progress=progress, **kwargs)


def spiking_resnet34(pretrained=False, progress=True, **kwargs):
    """Construct the 34-layer spiking ResNet (``BasicBlock``, stage depths 3-4-6-3).

    Args:
        pretrained: forwarded to ``_spiking_resnet`` to request pretrained weights.
        progress: forwarded to ``_spiking_resnet`` as the download progress flag.
        **kwargs: passed through to ``SpikingResNet`` (e.g. ``num_classes``, ``T``).
    """
    stage_depths = [3, 4, 6, 3]
    return _spiking_resnet('resnet34', BasicBlock, stage_depths,
                           pretrained=pretrained, progress=progress, **kwargs)


