import torch
import torch.nn as nn
from copy import deepcopy
from spikingjelly.activation_based import layer, functional
from spikingjelly.clock_driven.neuron import MultiStepParametricLIFNode, MultiStepLIFNode
import torch.nn.functional as F

try:
    from torchvision.models.utils import load_state_dict_from_url
except ImportError:
    from torchvision._internally_replaced_utils import load_state_dict_from_url

# Names exported by a star-import of this module: the two model factories.
__all__ = ['spiking_resnet18', 'spiking_resnet34']



# Leak factor read by mem_update.forward: from the second time step on, the
# previous membrane potential (minus the previous spike) is multiplied by
# this value before the new input is added.
decay = 0.25
class MultiSpikeD(nn.Module):
    """Quantise a tensor to integer levels in ``[0, D]`` with a pass-through gradient.

    Calling the module clamps ``x`` to ``[0, D]`` and rounds it. During
    backpropagation the incoming gradient is forwarded unchanged wherever the
    original input lay inside ``[0, D]`` (bounds included) and is zeroed
    elsewhere; no gradient is produced for ``D``.
    """

    class DynamicQuant(torch.autograd.Function):
        """Autograd function implementing the clamp-and-round with its custom gradient."""

        @staticmethod
        def forward(ctx, input, D):
            # Keep both the raw input and the bound: backward needs them to
            # decide where the gradient may pass.
            ctx.save_for_backward(input)
            ctx.D = D
            bounded = torch.clamp(input, min=0.0, max=D)
            return torch.round(bounded)

        @staticmethod
        def backward(ctx, grad_output):
            (raw_input,) = ctx.saved_tensors
            upper = ctx.D
            # Positions strictly outside [0, D] receive no gradient.
            outside = (raw_input < 0) | (raw_input > upper)
            # Second return value is the (absent) gradient with respect to D.
            return grad_output.masked_fill(outside, 0), None

    def forward(self, x, D):
        """Return ``round(clamp(x, 0, D))`` through ``DynamicQuant``."""
        return self.DynamicQuant.apply(x, D)


class mem_update(nn.Module):
    """Leaky multi-level neuron applied along the first dimension of its input.

    The first dimension of ``x`` is treated as the sequence of time steps.
    At the first step the membrane potential is the input itself; at every
    later step it is ``(previous_membrane - previous_spike) * decay + input``
    with the previous spike detached from the graph. The emitted value at each
    step is ``MultiSpikeD`` applied to the membrane potential with the current
    bound ``D``.

    While the module is in training mode, ``D`` is increased after every time
    step by ``alpha`` times the fraction of membrane entries that exceeded it.

    Args:
        act: stored on the instance; not read anywhere in this class.
        D0: initial value of the ``D`` buffer (upper quantisation bound).
        alpha: step size of the training-time update of ``D``.
    """

    def __init__(self, act=False, D0=1.0, alpha=0.01):
        super().__init__()

        self.act = act
        self.alpha = alpha
        # Registered as a buffer so it follows the module across devices and
        # is stored in the state dict without being a trainable parameter.
        self.register_buffer('D', torch.tensor(D0))
        self.qtrick = MultiSpikeD()

    def forward(self, x):
        """Run the neuron over all time steps of ``x`` and return a tensor shaped like ``x``."""
        # Membrane potential of the first step (indexing also rejects an
        # input with an empty time axis up front).
        mem = x[0]
        emitted = torch.zeros_like(x)

        for step in range(x.shape[0]):
            if step > 0:
                # Leak what is left after subtracting the last emitted value,
                # then integrate the new input.
                mem = (carried_mem - last_spike.detach()) * decay + x[step]

            last_spike = self.qtrick(mem, self.D)

            if self.training:
                exceed_count = torch.sum(mem > self.D).float()
                # Rebinds the buffer to a new tensor instead of mutating it,
                # so bounds captured by earlier quantisation calls stay
                # untouched for their backward pass.
                self.D = self.D + self.alpha * exceed_count / mem.numel()

            carried_mem = mem.clone()
            emitted[step] = last_spike

        return emitted
class BasicBlock(nn.Module):
    """Residual block: two (neuron -> 3x3 conv -> norm) stages plus a shortcut.

    ``forward`` runs ``sn1 -> conv1 -> sn2 -> conv2`` and then adds the block
    input (passed through ``downsample`` when one is given) to the result.

    When the block is in eval mode, the input has a leading dimension of 1 and
    the neuron carries a ``decomp_weights`` attribute, each convolution is
    evaluated on a 0/1 decomposition of the neuron output instead (see
    ``_apply_conv_with_decomp``).

    Args:
        inplanes: input channels of the first convolution.
        planes: output channels of both convolutions.
        stride: stride of the first convolution.
        downsample: optional module applied to the block input before the
            residual addition.
        groups: must be 1.
        base_width: must be 64.
        dilation: must not exceed 1.
        norm_layer: callable building the normalisation layer that follows
            each convolution; defaults to ``nn.BatchNorm2d``.

    Raises:
        ValueError: if ``groups != 1`` or ``base_width != 64``.
        NotImplementedError: if ``dilation > 1``.
    """
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(BasicBlock, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if groups != 1 or base_width != 64:
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")

        self.sn1 = mem_update()
        # Both convolutions are bias-free, which keeps them linear; the
        # decomposition path in _apply_conv_with_decomp relies on that.
        self.conv1 = layer.SeqToANNContainer(
            nn.Conv2d(inplanes, planes, kernel_size=3, padding=1, stride=stride, bias=False),
            # Honour the requested normalisation here too, as conv2 does.
            norm_layer(planes)
        )
        self.sn2 = mem_update()

        self.conv2 = layer.SeqToANNContainer(
            nn.Conv2d(planes, planes, kernel_size=3, padding=1, stride=1, bias=False),
            norm_layer(planes)
        )

        self.downsample = downsample

    @torch.no_grad()
    def _encode_timesteps(self, x_int: torch.Tensor, weights: torch.Tensor):
        """Greedily split ``x_int`` into one 0/1 tensor per entry of ``weights``.

        The weights are visited from the last index to the first. At each
        index the bit is 1 wherever the remaining value is at least that
        weight, and the weight is then subtracted at those positions.

        Args:
            x_int: tensor with a leading dimension of size 1, which is removed
                before decomposition.
            weights: 1-D tensor of decomposition weights.

        Returns:
            The bit tensors stacked along a new leading dimension of size
            ``weights.numel()``, in the same index order as ``weights``.
        """
        remaining = x_int.squeeze(0)
        T = weights.numel()
        bits = [None] * T  # filled by index because the loop runs backwards

        for idx in reversed(range(T)):
            w = weights[idx]
            # Keep the input dtype so the bits match the convolution weights
            # even when the model does not run in float32.
            b = (remaining >= w).to(remaining.dtype)
            remaining = remaining - b * w
            bits[idx] = b

        return torch.stack(bits, dim=0)   # (T,B,C,H,W)

    @torch.no_grad()
    def _decode_timesteps(self, bits: torch.Tensor, weights: torch.Tensor):
        """Return the weighted sum of ``bits`` over the leading dimension.

        ``weights`` is broadcast as ``(T, 1, 1, 1, 1)``, so ``bits`` is
        expected to be 5-D with ``T = weights.numel()`` entries first. The
        summed dimension is kept with size 1.
        """
        T = weights.numel()
        w = weights.view(T, 1, 1, 1, 1)
        return (bits * w).sum(dim=0, keepdim=True)   # (1,B,C,H,W)

    def _apply_conv_with_decomp(self, x, neuron: mem_update, conv_container):
        """Apply ``conv_container`` to ``x``, via the 0/1 decomposition when eligible.

        The decomposition path is taken only when the block is in eval mode,
        ``x`` has a leading dimension of 1 and ``neuron`` has a
        ``decomp_weights`` attribute. In every other case the container is
        simply called on ``x``.

        On the decomposition path ``x`` is split into 0/1 tensors, the
        convolution is run on all of them in one batch, the results are
        recombined with the same weights and the normalisation layer is
        applied last. NOTE(review): this equals the direct path only if the
        greedy decomposition reproduces every value of ``x`` exactly --
        presumably guaranteed by the tables in
        ``SpikingResNet._set_decomp_schemes``; confirm.

        Args:
            x: output of ``neuron``.
            neuron: the neuron that produced ``x``.
            conv_container: indexable container holding the convolution at
                index 0 and the normalisation layer at index 1.
        """
        use_int_trick = (not self.training) and (x.shape[0] == 1) and hasattr(neuron, 'decomp_weights')

        if not use_int_trick:
            return conv_container(x)

        weights = neuron.decomp_weights   # (T,)

        bits = self._encode_timesteps(x, weights)   # (T,B,C,H,W)

        conv = conv_container[0]    # convolution
        bn = conv_container[1]      # normalisation layer

        T, B, C, H, W = bits.shape
        # Fold the decomposition axis into the batch so one convolution call
        # processes every bit plane.
        bits_flat = bits.view(T * B, C, H, W)      # (T*B,C,H,W)

        y = conv(bits_flat)                        # (T*B,C_out,H_out,W_out)
        _, C_out, H_out, W_out = y.shape
        y = y.view(T, B, C_out, H_out, W_out)      # (T,B,C_out,H_out,W_out)

        out = self._decode_timesteps(y, weights)   # (1,B,C_out,H_out,W_out)

        # Normalise after recombination: the weighted sum is only valid on
        # the linear convolution output.
        out = bn(out.squeeze(0)).unsqueeze(0)

        return out

    def forward(self, x):
        """Run both neuron/convolution stages and add the shortcut."""
        identity = x

        # sn1 -> conv1
        out = self.sn1(x)                          # (T,B,C,H,W)
        out = self._apply_conv_with_decomp(out, self.sn1, self.conv1)

        # sn2 -> conv2
        out = self.sn2(out)
        out = self._apply_conv_with_decomp(out, self.sn2, self.conv2)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        return out


class SpikingResNet(nn.Module):
    """ResNet-style network whose residual stages run on a time-stacked tensor.

    A 7x7 convolution and normalisation layer process the image once; the
    result is repeated ``T`` times along a new leading dimension and passed
    through four residual stages, a final ``mem_update`` neuron, global
    average pooling, flattening and a linear classifier.

    Args:
        block: residual block class; must provide an ``expansion`` attribute
            and accept the arguments passed in ``_make_layer``.
        layers: number of blocks in each of the four stages.
        num_classes: output features of the final linear layer.
        zero_init_residual: if true, the weight of the last normalisation
            layer of every ``BasicBlock`` is initialised to zero.
        groups: forwarded to every block.
        width_per_group: forwarded to every block as ``base_width``.
        replace_stride_with_dilation: ``None`` or three booleans, one each for
            stages 2-4, requesting dilation instead of stride.
        norm_layer: callable building normalisation layers; defaults to
            ``nn.BatchNorm2d``.
        T: number of times the stem output is repeated along the leading
            dimension.

    Raises:
        ValueError: if ``replace_stride_with_dilation`` is given and does not
            have exactly three elements.
    """

    def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
                 groups=1, width_per_group=64, replace_stride_with_dilation=None,
                 norm_layer=None, T=1):
        super(SpikingResNet, self).__init__()
        self.T = T
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer

        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
                               bias=False)
        # Use the configured normalisation for the stem as well, like the
        # residual stages do.
        self.bn1 = norm_layer(self.inplanes)

        # No pooling layer follows the stem; the first stage is strided instead.
        self.layer1 = self._make_layer(block, 64, layers[0], stride=2)
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
                                       dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
                                       dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
                                       dilate=replace_stride_with_dilation[2])
        self.sn = mem_update()
        self.flatten = layer.SeqToANNContainer(nn.Flatten())
        self.avgpool = layer.SeqToANNContainer(nn.AdaptiveAvgPool2d((1, 1)))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
        if zero_init_residual:
            # Zero the scale of the last normalisation layer in each residual
            # branch so that every block initially outputs only its shortcut.
            # The norm sits at index 1 of the conv2 container, the same
            # indexing BasicBlock._apply_conv_with_decomp uses.
            for m in self.modules():
                if isinstance(m, BasicBlock):
                    last_norm = m.conv2[1]
                    if getattr(last_norm, 'weight', None) is not None:
                        nn.init.constant_(last_norm.weight, 0)

        self._set_decomp_schemes()

    def _set_decomp_schemes(self):
        """Assign a fixed bound ``D`` and decomposition weights to the neurons.

        The ``mem_update`` modules are visited in ``self.modules()`` order and
        paired one-to-one with the table below. Each paired neuron has its
        ``D`` buffer overwritten and receives a float32 ``decomp_weights``
        buffer whose entries are the digits of the matching weight code
        (``"1236"`` -> ``[1, 2, 3, 6]``).

        The table has 17 entries, which is exactly the number of
        ``mem_update`` modules for ``BasicBlock`` with ``layers=[2, 2, 2, 2]``
        (two per block plus the final ``sn``). With more neurons than entries
        the surplus ones keep their constructor ``D`` and get no
        ``decomp_weights``. NOTE(review): for deeper layouts the first 17
        neurons therefore still receive these values -- confirm this is
        intended.
        """
        D_list = [
            12, 7, 7, 5,
            8, 5, 5, 4,
            5, 4, 4, 3,
            4, 3, 4, 3,
            10
        ]
        weight_codes = [
            "1236", "124", "124", "123",
            "1234", "123", "123", "123",
            "123", "123", "123", "12",
            "123", "12", "123", "12",
            "1234",
        ]

        # Internal consistency check of the two hard-coded tables.
        assert len(D_list) == len(weight_codes)

        schemes = [
            (bound, [int(ch) for ch in code])
            for bound, code in zip(D_list, weight_codes)
        ]

        neurons = (m for m in self.modules() if isinstance(m, mem_update))
        # zip stops at the shorter sequence, so surplus neurons (or surplus
        # table entries) are left untouched.
        for m, (D_val, w_list) in zip(neurons, schemes):
            m.D.data = torch.tensor(float(D_val), device=m.D.device)

            w_tensor = torch.tensor(w_list, dtype=torch.float32, device=m.D.device)
            m.register_buffer('decomp_weights', w_tensor)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        """Build one residual stage of ``blocks`` blocks as an ``nn.Sequential``.

        Only the first block receives ``stride`` and, when the stride is not 1
        or the channel count changes, a shortcut made of a 3x3 convolution and
        a normalisation layer. ``self.inplanes`` is updated to the stage's
        output channels. With ``dilate`` set, the stride is folded into
        ``self.dilation`` and replaced by 1.
        """
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = layer.SeqToANNContainer(
                    nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=3, padding=1, stride=stride),
                    norm_layer(planes * block.expansion),
                )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
                            self.base_width, previous_dilation, norm_layer))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, groups=self.groups,
                                base_width=self.base_width, dilation=self.dilation,
                                norm_layer=norm_layer))

        return nn.Sequential(*layers)

    def _forward_impl(self, x):
        """Compute the network output for an image batch ``x``.

        The leading (time) dimension introduced here is not reduced before
        the classifier, so the result presumably has shape
        ``(T, batch, num_classes)`` -- TODO confirm against the caller.
        """
        x = self.conv1(x)
        x = self.bn1(x)
        # Add the time axis in place, then replicate the stem output T times.
        x.unsqueeze_(0)
        out = x.repeat(self.T, 1, 1, 1, 1)
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = self.sn(out)

        out = self.avgpool(out)
        out = self.flatten(out)
        out = self.fc(out)
        return out

    def forward(self, x):
        """Forward pass; delegates to ``_forward_impl``."""
        return self._forward_impl(x)


def _spiking_resnet(arch, block, layers, pretrained, progress, **kwargs):
    model = SpikingResNet(block, layers, **kwargs)
    if pretrained:
        state_dict = load_state_dict_from_url(model_urls[arch],
                                              progress=progress)
        model.load_state_dict(state_dict)
    return model


def spiking_resnet18(pretrained=False, progress=True, **kwargs):
    """Build a ``SpikingResNet`` of ``BasicBlock`` units with 2, 2, 2, 2 blocks per stage.

    Args:
        pretrained: forwarded to ``_spiking_resnet``; requests loading of the
            'resnet18' checkpoint.
        progress: forwarded to ``_spiking_resnet``.
        **kwargs: forwarded to ``_spiking_resnet``.
    """
    blocks_per_stage = [2, 2, 2, 2]
    return _spiking_resnet(
        'resnet18', BasicBlock, blocks_per_stage, pretrained, progress, **kwargs
    )


def spiking_resnet34(pretrained=False, progress=True, **kwargs):
    """Build a ``SpikingResNet`` of ``BasicBlock`` units with 3, 4, 6, 3 blocks per stage.

    Args:
        pretrained: forwarded to ``_spiking_resnet``; requests loading of the
            'resnet34' checkpoint.
        progress: forwarded to ``_spiking_resnet``.
        **kwargs: forwarded to ``_spiking_resnet``.
    """
    blocks_per_stage = [3, 4, 6, 3]
    return _spiking_resnet(
        'resnet34', BasicBlock, blocks_per_stage, pretrained, progress, **kwargs
    )


