import torch
import torch.nn as nn
from spikingjelly.clock_driven.neuron import MultiStepLIFNode
from timm.models.layers import to_2tuple, trunc_normal_, DropPath
from timm.models.registry import register_model
from timm.models.vision_transformer import _cfg
import torch.nn.functional as F
from functools import partial

import math
from einops import repeat
from src.models.sequence.kernels.ssm import SSMKernelDiag, SSMKernelDPLR

# Public API. Every name listed here must exist in this module, otherwise
# `from <module> import *` fails with AttributeError.
__all__ = ['spikformer', 'SpikformerInference']


def spike_activation(x, temp=1.0):
    """Ternary quantiser with a clamped straight-through gradient.

    Forward value: ``sign(x)``, forced to 0 wherever ``|x| < 0.5``, i.e. one of
    {-1, 0, +1} (returned as float).
    Backward: the gradient of ``clamp(x, -1, 1)`` (the hard part is detached).

    Args:
        x: input tensor.
        temp: accepted for API compatibility; not used.
    """
    surrogate = torch.clamp(x, -1, 1)
    hard = torch.where(torch.abs(x) < 0.5, torch.zeros_like(x), torch.sign(x))
    # Straight-through estimator: value of `hard`, gradient of `surrogate`.
    return (hard.float() - surrogate).detach() + surrogate


class TenarySpike(nn.Module):
    """Ternary integrate-and-fire neuron unrolled over the leading (time) axis.

    At every step the membrane is scaled by ``tau`` and the current input is
    added. A ternary spike ``spike_activation(mem - thresh)`` is emitted, and
    the membrane is reset to zero wherever that spike is non-zero.

    Args:
        thresh: offset subtracted from the membrane before quantisation.
        tau: multiplicative membrane decay per step.
        gama: stored but not used by this module.

    NOTE(review): with the default ``thresh=0.5``, ``spike_activation`` yields
    +1 for mem >= 1 and -1 for mem <= 0, so a zero membrane (e.g. all-zero
    input) fires -1 every step. Confirm this is intended.
    """

    def __init__(self, thresh=0.5, tau=1, gama=1.0):
        super().__init__()
        self.thresh = thresh
        self.tau = tau
        self.gama = gama

    def forward(self, x):
        """Run the neuron over ``x`` (time on dim 0); return the stacked spikes."""
        membrane = 0
        outputs = []
        for step in range(x.shape[0]):
            membrane = membrane * self.tau + x[step]
            fired = spike_activation(membrane - self.thresh)
            # Hard reset: keep the membrane only where no spike was emitted.
            membrane = membrane * (1 - torch.abs(fired))
            outputs.append(fired)
        return torch.stack(outputs, dim=0)


class StateFusion(nn.Module):
    """Multi-dilation depthwise 3x3 fusion of a (B, dim, H, W) feature map.

    Three depthwise 3x3 convolutions with dilations 1, 3 and 5 are applied to
    replicate-padded copies of the input. The padding equals the dilation, so
    H and W are preserved. The results are combined with the learned scalar
    weights ``alpha``.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        self.kernel_3 = nn.Parameter(torch.ones(dim, 1, 3, 3))
        self.kernel_3_1 = nn.Parameter(torch.ones(dim, 1, 3, 3))
        self.kernel_3_2 = nn.Parameter(torch.ones(dim, 1, 3, 3))
        self.alpha = nn.Parameter(torch.ones(3), requires_grad=True)

    @staticmethod
    def padding(input_tensor, padding):
        """Pad ``input_tensor`` by edge replication (``F.pad`` with mode='replicate')."""
        return F.pad(input_tensor, padding, mode='replicate')

    def forward(self, h):
        """Return ``sum_i alpha[i] * dwconv_i(h)``; same shape as ``h``."""
        branches = ((self.kernel_3, 1), (self.kernel_3_1, 3), (self.kernel_3_2, 5))
        fused = None
        for idx, (weight, dil) in enumerate(branches):
            padded = self.padding(h, (dil, dil, dil, dil))
            feat = F.conv2d(padded, weight, padding=0, dilation=dil, groups=self.dim)
            term = self.alpha[idx] * feat
            fused = term if fused is None else fused + term
        return fused


class SD_TCM(nn.Module):
    """SSM layer over the flattened spatial positions of a feature map.

    Training path (``forward``): the SSM kernel ``kernel1`` is materialised for
    length L = H*W and applied as a causal FFT convolution over the row-major
    flattened positions. This is followed by ``StateFusion`` and a per-channel
    skip term ``D``.

    Inference path (``setup_step_mode`` then ``forward_step_2d`` /
    ``forward_step_sequence``): the kernel is advanced recurrently through its
    ``step`` method. ``forward_step_2d`` keeps one state per spatial position;
    ``forward_step_sequence`` keeps one state for the whole sequence.
    """

    def __init__(self, d_model, d_state=8, **kernel_args):
        super().__init__()
        self.h = d_model
        self.n = d_state
        # Per-channel skip ("D") term added to the SSM output.
        self.D = nn.Parameter(torch.randn(self.h))

        self.kernel1 = SSMKernelDPLR(d_model=self.h, d_state=self.n, init="diag-lin", **kernel_args)
        self.state_fusion = StateFusion(d_model)

        # Step (recurrent inference) mode flags.
        self._step_mode = False
        self._batch_size = None

    def setup_step_mode(self, batch_size=16, **kwargs):
        """Enable step mode and let the kernel prepare its recurrent form.

        Calls the kernel's ``_setup_step`` (preferred) or ``setup_step`` with
        ``kwargs`` if either exists.
        """
        self._step_mode = True
        self._batch_size = batch_size

        if hasattr(self.kernel1, '_setup_step'):
            self.kernel1._setup_step(**kwargs)
        elif hasattr(self.kernel1, 'setup_step'):
            self.kernel1.setup_step(**kwargs)

        print(f"Step mode enabled for batch_size={batch_size}")

    def reset_step_mode(self):
        """Leave step mode so that the convolutional ``forward`` can be used again."""
        self._step_mode = False
        self._batch_size = None
        print("Switched back to training mode")

    def get_initial_state(self, batch_size, device, dtype=torch.float32):
        """Return the initial recurrent state for ``batch_size`` samples.

        Uses the kernel's ``default_state`` when available. Otherwise it returns
        zeros of shape (batch_size, d_model, d_state) on ``device`` / ``dtype``.
        """
        if hasattr(self.kernel1, 'default_state'):
            return self.kernel1.default_state(batch_size)
        return torch.zeros(batch_size, self.h, self.n, dtype=dtype, device=device)

    def _kernel_step(self, u_t, state, batch_size, channels):
        """Advance the kernel by one step and return ``(y_t reshaped to (B, C), next_state)``.

        Without a kernel ``step`` method this falls back to identity output and
        an unchanged state. The kernel output may carry extra singleton axes
        (the 2-D path previously flattened its last two axes), so the output
        is reshaped to (B, C). That single form handles both a plain (B, C)
        output and the extra-axis case.
        """
        if hasattr(self.kernel1, 'step'):
            y_t, state = self.kernel1.step(u_t, state)
        else:
            y_t = u_t
        return y_t.reshape(batch_size, channels), state

    def forward_step_2d(self, u, states_dict=None):
        """Step inference on a 2-D feature map, one recurrent state per position.

        Args:
            u: input tensor of shape (B, C, H, W).
            states_dict: mapping (i, j) -> state. Missing entries, or ``None``
                for the whole dict, start from ``get_initial_state``.
        Returns:
            (y, next_states_dict): y has shape (B, C, H, W); next_states_dict
            maps every (i, j) to its updated state.
        Raises:
            RuntimeError: if ``setup_step_mode`` has not been called.
        """
        if not self._step_mode:
            raise RuntimeError("Must call setup_step_mode() before using forward_step_2d()")

        B, C, H, W = u.shape
        device = u.device
        dtype = u.dtype

        if states_dict is None:
            states_dict = {(i, j): self.get_initial_state(B, device, dtype)
                           for i in range(H) for j in range(W)}

        y = torch.zeros_like(u)
        next_states_dict = {}
        for i in range(H):
            for j in range(W):
                state = states_dict.get((i, j))
                if state is None:
                    state = self.get_initial_state(B, device, dtype)
                y_pos, next_state = self._kernel_step(u[:, :, i, j], state, B, C)
                y[:, :, i, j] = y_pos
                next_states_dict[(i, j)] = next_state

        y = self.state_fusion(y)
        # Per-channel skip connection (D term).
        y = y + u * self.D.view(1, -1, 1, 1)
        return y, next_states_dict

    def forward_step_sequence(self, u_sequence, state=None):
        """Step inference over a 1-D sequence with a single recurrent state.

        Args:
            u_sequence: input tensor of shape (B, C, L).
            state: initial state; ``None`` uses ``get_initial_state``.
        Returns:
            (y_sequence, final_state): y_sequence has shape (B, C, L).
        Raises:
            RuntimeError: if ``setup_step_mode`` has not been called.
        """
        if not self._step_mode:
            raise RuntimeError("Must call setup_step_mode() before using forward_step_sequence()")

        B, C, L = u_sequence.shape
        if state is None:
            state = self.get_initial_state(B, u_sequence.device, u_sequence.dtype)

        y_sequence = torch.zeros_like(u_sequence)
        for t in range(L):
            y_t, state = self._kernel_step(u_sequence[:, :, t], state, B, C)
            y_sequence[:, :, t] = y_t

        # Per-channel skip connection (D term).
        y_sequence = y_sequence + u_sequence * self.D.view(1, -1, 1)
        return y_sequence, state

    def forward(self, u, **kwargs):
        """Convolutional forward pass used for training.

        Args:
            u: input tensor of shape (B, C, H, W). C must match d_model for
                the D term to broadcast.
            **kwargs: ignored.
        Returns:
            Tensor of shape (B, C, H, W) (float32 after the FFT path).
        Raises:
            RuntimeError: if step mode is enabled.
        """
        if self._step_mode:
            raise RuntimeError("Cannot use forward() in step mode. Use forward_step_*() instead.")

        B, C, H, W = u.shape
        u_flat = u.flatten(2, 3)  # (B, C, L), L = H*W, row-major positions

        L = u_flat.size(-1)
        # NOTE(review): k1 is assumed broadcastable against (B, C, L) after the FFT.
        k1, _ = self.kernel1(L=L)
        # Zero-padding to 2L makes the FFT product a linear (not circular)
        # convolution; the first L outputs are the causal result.
        k_f = torch.fft.rfft(k1.float(), n=2 * L)
        u_f = torch.fft.rfft(u_flat.float(), n=2 * L)
        y = torch.fft.irfft(u_f * k_f, n=2 * L)[..., :L]

        y = self.state_fusion(y.reshape(B, C, H, W))

        y_flat = y.flatten(2, 3) + u_flat * self.D.unsqueeze(0).unsqueeze(-1)
        return y_flat.reshape(B, C, H, W)


class MLP(nn.Module):
    """Two-layer spiking MLP built from 1x1 convolutions.

    Each layer is ``Conv2d(1x1) -> BatchNorm2d -> MultiStepLIFNode``. The
    module maps (T, B, in_features, H, W) to (T, B, out_features, H, W).

    Args:
        in_features: input channel count.
        hidden_features: hidden channel count (defaults to ``in_features``).
        out_features: output channel count (defaults to ``in_features``).
        drop: accepted for interface compatibility; not used.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1_conv = nn.Conv2d(in_features, hidden_features, kernel_size=1, stride=1)
        self.fc1_bn = nn.BatchNorm2d(hidden_features)
        self.fc1_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend='cupy')

        self.fc2_conv = nn.Conv2d(hidden_features, out_features, kernel_size=1, stride=1)
        self.fc2_bn = nn.BatchNorm2d(out_features)
        self.fc2_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend='cupy')

        self.c_hidden = hidden_features
        self.c_output = out_features

    def forward(self, x):
        """Apply both spiking layers to ``x`` of shape (T, B, C, H, W)."""
        T, B, _, H, W = x.shape
        # Convs/BN run on the merged (T*B) batch; the LIF nodes need the time axis back.
        x = self.fc1_bn(self.fc1_conv(x.flatten(0, 1)))
        x = self.fc1_lif(x.reshape(T, B, self.c_hidden, H, W).contiguous())

        x = self.fc2_bn(self.fc2_conv(x.flatten(0, 1)))
        # Use the real output width (was the input width C, which broke when
        # out_features != in_features).
        x = self.fc2_lif(x.reshape(T, B, self.c_output, H, W).contiguous())
        return x


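## NOTE: revisit this part — in practice the SNN version simply plugs the corresponding (spiking) components into this structure.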
class SSA(nn.Module):
    """Spiking token mixer built around the SD_TCM state-space layer.

    ``num_heads`` is only used for the ``dim % num_heads`` check. ``qkv_bias``,
    ``qk_scale``, ``attn_drop``, ``proj_drop`` and ``sr_ratio`` are accepted
    for interface compatibility but not used.
    """

    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1):
        super().__init__()
        assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."
        self.dim = dim
        self.num_heads = num_heads
        self.scale = 0.125

        self.conv = nn.Conv1d(dim, dim, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn = nn.BatchNorm1d(dim)

        self.state = SD_TCM(d_model=dim)
        self.state_bn = nn.LayerNorm(self.dim)
        self.state_lif = MultiStepLIFNode(tau=2.0, v_threshold=0.5, detach_reset=True, backend='cupy')

        self.proj_conv = nn.Conv1d(dim, dim, kernel_size=1, stride=1)
        self.proj_bn = nn.BatchNorm1d(dim)
        self.proj_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend='cupy')

        # Step (recurrent inference) mode flag.
        self._step_mode = False

    def setup_step_mode(self, batch_size=16, **kwargs):
        """Enable step mode here and in the SD_TCM layer."""
        self._step_mode = True
        self.state.setup_step_mode(batch_size=batch_size, **kwargs)

    def reset_step_mode(self):
        """Leave step mode here and in the SD_TCM layer."""
        self._step_mode = False
        self.state.reset_step_mode()

    def forward_step(self, x, states_dict=None):
        """Single-step inference on a map without a time axis.

        Args:
            x: input tensor of shape (B, C, H, W).
            states_dict: per-position SD_TCM states (see ``SD_TCM.forward_step_2d``).
        Returns:
            (y, next_states_dict): y has shape (B, C, H, W).
        Raises:
            RuntimeError: if ``setup_step_mode`` has not been called.
        """
        if not self._step_mode:
            raise RuntimeError("Must call setup_step_mode() before using forward_step()")

        B, C, H, W = x.shape
        x_flat = x.flatten(2, 3)  # (B, C, H*W)

        x_conv = self.bn(self.conv(x_flat)).reshape(B, C, H, W)

        # LayerNorm over the channel axis (moved last, then restored).
        x_norm = self.state_bn(x_conv.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)

        x_ssm, next_states_dict = self.state.forward_step_2d(x_norm, states_dict)

        x_proj = self.proj_bn(self.proj_conv(x_ssm.flatten(2, 3))).reshape(B, C, H, W)
        return x_proj, next_states_dict

    def forward(self, x, res_attn):
        """Multi-step training forward pass.

        Args:
            x: input tensor of shape (T, B, C, H, W).
            res_attn: ignored.
        Returns:
            (y, None) with y of shape (T, B, C, H, W).
        Raises:
            RuntimeError: if step mode is enabled.

        NOTE(review): unlike ``forward_step``, this path does not apply
        ``self.state`` (SD_TCM). It does apply ``state_lif`` / ``proj_lif``,
        which ``forward_step`` skips. Confirm which composition is intended.
        """
        if self._step_mode:
            raise RuntimeError("Cannot use forward() in step mode. Use forward_step() instead.")

        T, B, C, H, W = x.shape
        N = H * W
        x = x.flatten(3).flatten(0, 1)  # (T*B, C, N)
        x = self.bn(self.conv(x))

        # LayerNorm(dim) must see channels last: (TB, C, N) -> (TB, N, C).
        # The old code transposed a 4-D (TB, C, H, W) map and normalised over
        # H, which fails whenever H != dim.
        x = self.state_bn(x.transpose(-1, -2)).transpose(-1, -2)

        x = self.state_lif(x.reshape(T, B, C, N).contiguous())
        x = self.proj_bn(self.proj_conv(x.flatten(0, 1)))
        x = self.proj_lif(x.reshape(T, B, C, H, W).contiguous())

        return x, None


class Block(nn.Module):
    """Residual block: SSA token mixer followed by a spiking MLP.

    ``forward`` is the multi-step training path on (T, B, C, H, W) input.
    ``forward_step`` is the single-step inference path on (B, C, H, W) input
    and requires ``setup_step_mode`` first.

    NOTE(review): ``norm1``, ``norm2`` and ``drop_path`` are constructed but
    not applied in either forward path.
    """

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., norm_layer=nn.LayerNorm, sr_ratio=1):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = SSA(dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
                        attn_drop=attn_drop, proj_drop=drop, sr_ratio=sr_ratio)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        hidden = int(dim * mlp_ratio)
        self.mlp = MLP(in_features=dim, hidden_features=hidden, drop=drop)

        # Step (recurrent inference) mode flag.
        self._step_mode = False

    def setup_step_mode(self, batch_size=16, **kwargs):
        """Enable step mode here and in the attention branch."""
        self._step_mode = True
        self.attn.setup_step_mode(batch_size=batch_size, **kwargs)

    def reset_step_mode(self):
        """Leave step mode here and in the attention branch."""
        self._step_mode = False
        self.attn.reset_step_mode()

    def forward_step(self, x, states_dict=None):
        """Single-step inference.

        Args:
            x: input tensor of shape (B, C, H, W).
            states_dict: recurrent states for the attention branch.
        Returns:
            (y, next_states_dict) with y of shape (B, C, H, W).
        Raises:
            RuntimeError: if ``setup_step_mode`` has not been called.
        """
        if not self._step_mode:
            raise RuntimeError("Must call setup_step_mode() before using forward_step()")

        attn_out, next_states_dict = self.attn.forward_step(x, states_dict)
        x = x + attn_out
        # The MLP's LIF nodes are multi-step, so give it a singleton time axis.
        x = x + self.mlp(x.unsqueeze(0)).squeeze(0)
        return x, next_states_dict

    def forward(self, x, res_attn):
        """Multi-step training forward pass; returns ``(y, attn)`` from the SSA branch."""
        if self._step_mode:
            raise RuntimeError("Cannot use forward() in step mode. Use forward_step() instead.")

        attn_out, attn = self.attn(x, res_attn)
        x = x + attn_out
        return x + self.mlp(x), attn


class SPS(nn.Module):
    """Spiking patch splitting stem.

    Four stages of ``Conv3x3 -> BN -> LIF -> MaxPool(3, stride 2, pad 1)``
    widen the channels to embed_dims/8, /4, /2 and embed_dims. Each pooling
    maps a spatial size n to ceil(n/2). A residual ``rpe`` conv/BN/LIF branch
    is then added on top.
    """

    def __init__(self, img_size_h=128, img_size_w=128, patch_size=4, in_channels=2, embed_dims=256):
        super().__init__()
        self.image_size = [img_size_h, img_size_w]
        patch_size = to_2tuple(patch_size)
        self.patch_size = patch_size
        self.C = in_channels
        self.H, self.W = self.image_size[0] // patch_size[0], self.image_size[1] // patch_size[1]
        self.num_patches = self.H * self.W

        self.proj_conv = nn.Conv2d(in_channels, embed_dims // 8, kernel_size=3, stride=1, padding=1, bias=False)
        self.proj_bn = nn.BatchNorm2d(embed_dims // 8)
        self.proj_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend='cupy')
        self.maxpool = torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)

        self.proj_conv1 = nn.Conv2d(embed_dims // 8, embed_dims // 4, kernel_size=3, stride=1, padding=1, bias=False)
        self.proj_bn1 = nn.BatchNorm2d(embed_dims // 4)
        self.proj_lif1 = MultiStepLIFNode(tau=2.0, detach_reset=True, backend='cupy')
        self.maxpool1 = torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)

        self.proj_conv2 = nn.Conv2d(embed_dims // 4, embed_dims // 2, kernel_size=3, stride=1, padding=1, bias=False)
        self.proj_bn2 = nn.BatchNorm2d(embed_dims // 2)
        self.proj_lif2 = MultiStepLIFNode(tau=2.0, detach_reset=True, backend='cupy')
        self.maxpool2 = torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)

        self.proj_conv3 = nn.Conv2d(embed_dims // 2, embed_dims, kernel_size=3, stride=1, padding=1, bias=False)
        self.proj_bn3 = nn.BatchNorm2d(embed_dims)
        self.proj_lif3 = MultiStepLIFNode(tau=2.0, detach_reset=True, backend='cupy')
        self.maxpool3 = torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)

        self.rpe_conv = nn.Conv2d(embed_dims, embed_dims, kernel_size=3, stride=1, padding=1, bias=False)
        self.rpe_bn = nn.BatchNorm2d(embed_dims)
        self.rpe_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend='cupy')

        # Step (single-time-step inference) mode flag.
        self._step_mode = False

    def setup_step_mode(self, batch_size=16):
        """Enable step mode (``batch_size`` is accepted but unused)."""
        self._step_mode = True

    def reset_step_mode(self):
        """Leave step mode."""
        self._step_mode = False

    def _stem_stages(self):
        """Return the (conv, bn, lif, pool) tuples of the four downsampling stages in order."""
        return (
            (self.proj_conv, self.proj_bn, self.proj_lif, self.maxpool),
            (self.proj_conv1, self.proj_bn1, self.proj_lif1, self.maxpool1),
            (self.proj_conv2, self.proj_bn2, self.proj_lif2, self.maxpool2),
            (self.proj_conv3, self.proj_bn3, self.proj_lif3, self.maxpool3),
        )

    @staticmethod
    def _lif_single_step(lif, x):
        """Run a multi-step LIF node on a (B, C, H, W) map as one time step."""
        return lif(x.unsqueeze(0)).squeeze(0)

    def forward_step(self, x):
        """Single-step patch embedding.

        Args:
            x: input tensor of shape (B, C, H, W). It is moved to the device of
                the stem's weights (previously it was always sent to CUDA).
        Returns:
            (y, (H // patch_h, W // patch_w)) where y is the embedded map.
        Raises:
            RuntimeError: if ``setup_step_mode`` has not been called.
        """
        if not self._step_mode:
            raise RuntimeError("Must call setup_step_mode() before using forward_step()")

        _, _, H, W = x.shape
        x = x.to(self.proj_conv.weight.device)

        for conv, bn, lif, pool in self._stem_stages():
            x = pool(self._lif_single_step(lif, bn(conv(x))))

        x_feat = x
        x = self._lif_single_step(self.rpe_lif, self.rpe_bn(self.rpe_conv(x)))
        x = x + x_feat

        return x, (H // self.patch_size[0], W // self.patch_size[1])

    def forward(self, x):
        """Multi-step training forward pass.

        Args:
            x: input tensor of shape (T, B, C, H, W).
        Returns:
            (y, (H // patch_h, W // patch_w)) with y of shape (T, B, embed_dims, h, w).
        Raises:
            RuntimeError: if step mode is enabled.
        """
        if self._step_mode:
            raise RuntimeError("Cannot use forward() in step mode. Use forward_step() instead.")

        T, B, _, H, W = x.shape
        x = x.flatten(0, 1)
        for conv, bn, lif, pool in self._stem_stages():
            x = bn(conv(x))
            # Restore the time axis for the LIF node from the tensor's actual
            # shape; the previous H // 2**k reshapes broke for sizes not
            # divisible by 16, because pooling gives ceil(n/2).
            x = lif(x.reshape(T, B, *x.shape[1:]).contiguous()).flatten(0, 1).contiguous()
            x = pool(x)

        x_feat = x.reshape(T, B, *x.shape[1:]).contiguous()
        x = self.rpe_bn(self.rpe_conv(x))
        x = self.rpe_lif(x.reshape(T, B, *x.shape[1:]).contiguous())
        x = x + x_feat

        return x, (H // self.patch_size[0], W // self.patch_size[1])


class Spikformer(nn.Module):
    """Spikformer image classifier: SPS patch stem, a stack of ``Block``s, linear head.

    NOTE(review): the list defaults of ``embed_dims``, ``num_heads``,
    ``mlp_ratios``, ``depths`` and ``sr_ratios`` look like leftovers from a
    multi-stage design. The body passes them directly to ``torch.linspace``,
    ``range``, ``SPS``, ``Block`` and ``nn.Linear``, which need scalars, so
    pass scalars (as the registered ``spikformer`` factory does).
    """

    def __init__(self,
                 img_size_h=128, img_size_w=128, patch_size=16, in_channels=2, num_classes=11,
                 embed_dims=[64, 128, 256], num_heads=[1, 2, 4], mlp_ratios=[4, 4, 4], qkv_bias=False, qk_scale=None,
                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm,
                 depths=[6, 8, 6], sr_ratios=[8, 4, 2]
                 ):
        super().__init__()
        self.num_classes = num_classes
        self.depths = depths

        # Per-block drop-path rates, linearly spaced from 0 to drop_path_rate.
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depths)]

        self.patch_embed = SPS(img_size_h=img_size_h,
                               img_size_w=img_size_w,
                               patch_size=patch_size,
                               in_channels=in_channels,
                               embed_dims=embed_dims)

        self.block = nn.ModuleList([Block(
            dim=embed_dims, num_heads=num_heads, mlp_ratio=mlp_ratios, qkv_bias=qkv_bias,
            qk_scale=qk_scale, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[j],
            norm_layer=norm_layer, sr_ratio=sr_ratios)
            for j in range(depths)])

        # Classification head.
        self.head = nn.Linear(embed_dims, num_classes) if num_classes > 0 else nn.Identity()
        self.apply(self._init_weights)

        # Step (recurrent inference) mode flag.
        self._step_mode = False

    def setup_step_mode(self, batch_size=16, **kwargs):
        """Put the patch stem and every block into step mode."""
        self._step_mode = True

        self.patch_embed.setup_step_mode(batch_size=batch_size)
        for blk in self.block:
            blk.setup_step_mode(batch_size=batch_size, **kwargs)

        print(f"Model setup for step inference with batch_size={batch_size}")

    def reset_step_mode(self):
        """Return the patch stem and every block to training mode."""
        self._step_mode = False

        self.patch_embed.reset_step_mode()
        for blk in self.block:
            blk.reset_step_mode()

        print("Model reset to training mode")

    @torch.jit.ignore
    def _get_pos_embed(self, pos_embed, patch_embed, H, W):
        """Return ``pos_embed`` unchanged, or bilinearly resized to an H x W grid.

        Compares against ``patch_embed.num_patches``; the old code read the
        non-existent ``self.patch_embed1`` and always raised AttributeError.
        """
        if H * W == patch_embed.num_patches:
            return pos_embed
        return F.interpolate(
            pos_embed.reshape(1, patch_embed.H, patch_embed.W, -1).permute(0, 3, 1, 2),
            size=(H, W), mode="bilinear").reshape(1, -1, H * W).permute(0, 2, 1)

    def _init_weights(self, m):
        """Truncated-normal init for Linear weights; zero biases; unit LayerNorm."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward_step(self, x, states_dict=None):
        """Single-step inference.

        Args:
            x: input tensor of shape (B, C, H, W).
            states_dict: per-block recurrent states keyed ``'block_{i}'``
                (``None`` starts every block fresh).
        Returns:
            (logits, next_states_dict) with logits of shape (B, num_classes).
        Raises:
            RuntimeError: if ``setup_step_mode`` has not been called.
        """
        if not self._step_mode:
            raise RuntimeError("Must call setup_step_mode() before using forward_step()")

        if states_dict is None:
            states_dict = {f'block_{i}': None for i in range(len(self.block))}

        next_states_dict = {}
        x, _ = self.patch_embed.forward_step(x)

        for i, blk in enumerate(self.block):
            x, next_states_dict[f'block_{i}'] = blk.forward_step(x, states_dict.get(f'block_{i}'))

        # Global average pool over space, then classify.
        x = x.flatten(2).mean(2)
        x = self.head(x)
        return x, next_states_dict

    def forward_features(self, x):
        """Run stem and blocks on (T, B, C, H, W); return spatially pooled (T, B, C)."""
        if self._step_mode:
            raise RuntimeError("Cannot use forward_features() in step mode")

        x, _ = self.patch_embed(x)
        attn = None
        for blk in self.block:
            x, attn = blk(x, attn)
        return x.flatten(3).mean(3)

    def forward(self, x):
        """Training forward pass on (B, C, H, W) input, repeated over T = 4 time steps."""
        if self._step_mode:
            raise RuntimeError("Cannot use forward() in step mode. Use forward_step() instead.")

        T = 4
        x = x.unsqueeze(0).repeat(T, 1, 1, 1, 1)
        x = self.forward_features(x)
        # Average over time before the head.
        return self.head(x.mean(0))


@register_model
def spikformer(pretrained=False, **kwargs):
    """Build the ImageNet-sized Spikformer (224x224 RGB input, 1000 classes).

    ``pretrained`` is accepted for registry compatibility but ignored. Extra
    keyword arguments are forwarded to ``Spikformer``. Repeating a key that is
    already fixed below raises ``TypeError`` (duplicate keyword argument).
    """
    config = dict(
        img_size_h=224, img_size_w=224,
        patch_size=16, embed_dims=768, num_heads=8, mlp_ratios=4,
        in_channels=3, num_classes=1000, qkv_bias=False,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=8, sr_ratios=1,
    )
    net = Spikformer(**config, **kwargs)
    net.default_cfg = _cfg()
    return net


# 推理包装器类
class SpikformerInference:
    """Stateful wrapper that runs a Spikformer in step mode.

    Call ``setup()`` once, then call the instance on successive batches. The
    recurrent states returned by ``model.forward_step`` are fed back into the
    next call until ``reset_states()`` is called. ``cleanup()`` restores
    training mode.
    """

    def __init__(self, model, batch_size=16):
        self.model = model
        self.batch_size = batch_size
        self.states_dict = None
        self.is_setup = False

    def setup(self):
        """Switch the model to eval + step mode and clear any stored states."""
        self.model.eval()
        self.model.setup_step_mode(batch_size=self.batch_size)
        self.reset_states()
        self.is_setup = True
        print(f"Spikformer inference setup complete for batch_size={self.batch_size}")

    def reset_states(self):
        """Forget the stored recurrent states (start a new sequence)."""
        self.states_dict = None

    def __call__(self, x):
        """Run one inference step without autograd and return the model output.

        Args:
            x: input batch, expected (B, C, H, W) with B <= batch_size
                (not enforced here).
        Raises:
            RuntimeError: if ``setup()`` has not been called.
        """
        if not self.is_setup:
            raise RuntimeError("Must call setup() before inference")

        with torch.no_grad():
            result, next_states = self.model.forward_step(x, self.states_dict)
        self.states_dict = next_states
        return result

    def cleanup(self):
        """Undo ``setup()``: leave step mode and put the model back in train mode."""
        if not self.is_setup:
            return
        self.model.reset_step_mode()
        self.model.train()
        self.is_setup = False
        print("Inference mode cleaned up")


# 使用示例
def example_usage():
    """Demonstrate batched step-mode inference with ``SpikformerInference``.

    Builds the ImageNet Spikformer on CUDA and runs one batch of 16 random
    224x224 RGB images through the inferencer. Training mode is restored
    afterwards, even if inference fails. To continue a stream, call the
    inferencer again with the next batch; the stored states are reused. Call
    ``reset_states()`` before starting an unrelated sequence.
    """
    net = spikformer().cuda()

    runner = SpikformerInference(net, batch_size=16)
    runner.setup()

    batch = torch.randn(16, 3, 224, 224).cuda()

    try:
        logits = runner(batch)
        print(f"Inference output shape: {logits.shape}")  # expected (16, 1000)
    finally:
        runner.cleanup()


if __name__ == "__main__":
    # Run the demo only when this file is executed directly, not on import.
    example_usage()