import math
import torch
import torch.nn as nn
from timm.layers import SqueezeExcite
import torch.nn.functional as F

"""《EfficientViM: Efficient Vision Mamba with Hidden State Mixer based State Space Duality》 CVPR 2025
对于在资源受限的环境中部署神经网络，先前的研究已经构建了轻量级架构，分别使用卷积和注意力来捕获局部和全局依赖关系。最近，状态空间模型已成为一种有效的全局令牌交互，其计算成本在令牌数量上具有良好的线性关系。
然而，使用 SSM 构建的高效视觉主干探索较少。在本文中，我们介绍了 Efficient Vision Mamba (EfficientViM)，
这是一种基于隐藏状态混合器的状态空间对偶 (hidden state mixer-based state space duality,HSM-SSD) 构建的新型架构，可有效捕获全局依赖关系并进一步降低计算成本。
在 HSM-SSD 层中，我们重新设计了之前的 SSD 层以在隐藏状态下启用通道混合操作。此外，我们提出了多阶段隐藏状态融合以进一步增强隐藏状态的表示能力，并提供缓解内存限制操作造成的瓶颈的设计。
因此，EfficientViM 系列在 ImageNet-1k 上实现了速度与准确度之间的最佳平衡：与第二佳模型 SHViT 相比，它速度更快，且性能提升高达 0.7%。
此外，与之前的研究相比，我们在缩放图像或采用蒸馏训练时观察到吞吐量和准确度的显著提高。
"""


class LayerNorm2D(nn.Module):
    """Layer normalization across the channel axis of a ``(B, C, H, W)`` tensor.

    Every spatial location is normalized over its ``C`` channels using the
    biased (population) variance, then optionally scaled and shifted by
    per-channel learnable parameters.

    Args:
        num_channels: Number of channels ``C``.
        eps: Constant added to the variance for numerical stability.
        affine: When True, learn a per-channel ``weight`` (initialized to 1)
            and ``bias`` (initialized to 0).
    """

    def __init__(self, num_channels, eps=1e-5, affine=True):
        super().__init__()
        self.num_channels = num_channels
        self.eps = eps
        self.affine = affine

        if not affine:
            # Register as None so attribute access works uniformly either way.
            self.register_parameter('weight', None)
            self.register_parameter('bias', None)
            return

        param_shape = (1, num_channels, 1, 1)  # broadcasts over (B, C, H, W)
        self.weight = nn.Parameter(torch.ones(param_shape))
        self.bias = nn.Parameter(torch.zeros(param_shape))

    def forward(self, x):
        # Statistics over dim=1 (channels); keepdim lets them broadcast back.
        channel_mean = x.mean(dim=1, keepdim=True)
        channel_var = x.var(dim=1, keepdim=True, unbiased=False)
        out = (x - channel_mean) / torch.sqrt(channel_var + self.eps)
        return out * self.weight + self.bias if self.affine else out


class LayerNorm1D(nn.Module):
    """Layer normalization across the channel axis of a ``(B, C, L)`` tensor.

    Every sequence position is normalized over its ``C`` channels using the
    biased (population) variance, then optionally scaled and shifted by
    per-channel learnable parameters.

    Args:
        num_channels: Number of channels ``C``.
        eps: Constant added to the variance for numerical stability.
        affine: When True, learn a per-channel ``weight`` (initialized to 1)
            and ``bias`` (initialized to 0).
    """

    def __init__(self, num_channels, eps=1e-5, affine=True):
        super().__init__()
        self.num_channels = num_channels
        self.eps = eps
        self.affine = affine

        if not affine:
            # Register as None so attribute access works uniformly either way.
            self.register_parameter('weight', None)
            self.register_parameter('bias', None)
            return

        param_shape = (1, num_channels, 1)  # broadcasts over (B, C, L)
        self.weight = nn.Parameter(torch.ones(param_shape))
        self.bias = nn.Parameter(torch.zeros(param_shape))

    def forward(self, x):
        # Per-position statistics over dim=1, each of shape (B, 1, L).
        channel_mean = x.mean(dim=1, keepdim=True)
        channel_var = x.var(dim=1, keepdim=True, unbiased=False)
        out = (x - channel_mean) / torch.sqrt(channel_var + self.eps)  # (B, C, L)
        return out * self.weight + self.bias if self.affine else out


class ConvLayer2D(nn.Module):
    """Bias-free ``Conv2d`` followed by an optional norm layer and activation.

    Args:
        in_dim: Input channel count.
        out_dim: Output channel count.
        kernel_size, stride, padding, dilation: Integers applied to both
            spatial axes.
        groups: Convolution groups (``groups=in_dim`` gives a depthwise conv).
        norm: Norm-layer class, constructed as ``norm(num_features=out_dim)``;
            pass a falsy value to disable it.
        act_layer: Activation class, constructed with no arguments; pass a
            falsy value to disable it.
        bn_weight_init: Initial value of the norm's ``weight`` (its ``bias``
            starts at 0). Ignored when there is no norm.
    """

    def __init__(self, in_dim, out_dim, kernel_size=3, stride=1, padding=0, dilation=1, groups=1, norm=nn.BatchNorm2d,
                 act_layer=nn.ReLU, bn_weight_init=1):
        super().__init__()
        self.conv = nn.Conv2d(
            in_dim, out_dim, (kernel_size, kernel_size),
            stride=(stride, stride), padding=(padding, padding),
            dilation=(dilation, dilation), groups=groups, bias=False,
        )

        self.norm = None
        if norm:
            self.norm = norm(num_features=out_dim)
        self.act = act_layer() if act_layer else None

        if self.norm is not None:
            nn.init.constant_(self.norm.weight, bn_weight_init)
            nn.init.zeros_(self.norm.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = self.conv(x)
        # Apply whichever of the optional stages were configured, in order.
        for stage in (self.norm, self.act):
            if stage is not None:
                out = stage(out)
        return out


class ConvLayer1D(nn.Module):
    """Bias-free ``Conv1d`` followed by an optional norm layer and activation.

    Args:
        in_dim: Input channel count.
        out_dim: Output channel count.
        kernel_size, stride, padding, dilation: Convolution hyper-parameters.
        groups: Convolution groups (``groups=in_dim`` gives a depthwise conv).
        norm: Norm-layer class, constructed as ``norm(num_features=out_dim)``;
            pass a falsy value to disable it.
        act_layer: Activation class, constructed with no arguments; pass a
            falsy value to disable it.
        bn_weight_init: Initial value of the norm's ``weight`` (its ``bias``
            starts at 0). Ignored when there is no norm.
    """

    def __init__(self, in_dim, out_dim, kernel_size=3, stride=1, padding=0, dilation=1, groups=1, norm=nn.BatchNorm1d,
                 act_layer=nn.ReLU, bn_weight_init=1):
        super().__init__()
        self.conv = nn.Conv1d(in_dim, out_dim, kernel_size, stride=stride, padding=padding,
                              dilation=dilation, groups=groups, bias=False)

        self.norm = None
        if norm:
            self.norm = norm(num_features=out_dim)
        self.act = act_layer() if act_layer else None

        if self.norm is not None:
            nn.init.constant_(self.norm.weight, bn_weight_init)
            nn.init.zeros_(self.norm.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = self.conv(x)
        if self.norm is not None:
            out = self.norm(out)
        return out if self.act is None else self.act(out)


class FFN(nn.Module):
    """Pointwise feed-forward block: 1x1 expand to ``dim``, then 1x1 project back.

    ``fc1`` is conv + BatchNorm + ReLU (``ConvLayer2D`` defaults). ``fc2`` is
    conv + BatchNorm without activation; its norm weight starts at 0.

    Args:
        in_dim: Channel count of the input and output.
        dim: Hidden (expanded) channel count.
    """

    def __init__(self, in_dim, dim):
        super().__init__()
        self.fc1 = ConvLayer2D(in_dim, dim, 1)
        self.fc2 = ConvLayer2D(dim, in_dim, 1, act_layer=None, bn_weight_init=0)

    def forward(self, x):
        hidden = self.fc1(x)
        return self.fc2(hidden)


class Stem(nn.Module):
    """Convolutional stem of four stride-2 3x3 conv layers (16x downsampling overall).

    Channels progress ``in_dim -> dim//8 -> dim//4 -> dim//2 -> dim``. Every
    layer but the last uses ReLU; the last layer has no activation.

    Args:
        in_dim: Input image channels.
        dim: Output channel count.
    """

    def __init__(self, in_dim=3, dim=96):
        super().__init__()
        widths = [in_dim, dim // 8, dim // 4, dim // 2, dim]
        last_idx = len(widths) - 2

        layers = []
        for idx, (c_in, c_out) in enumerate(zip(widths, widths[1:])):
            layers.append(ConvLayer2D(c_in, c_out, kernel_size=3, stride=2, padding=1,
                                      act_layer=None if idx == last_idx else nn.ReLU))
        # Sequential keeps the same child indices (conv.0 ... conv.3) as before.
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        return self.conv(x)


class PatchMerging(nn.Module):
    """Stride-2 downsampling block with depthwise residual refinements.

    Pipeline: ``x + dwconv1(x)`` at input resolution, then an inverted-
    bottleneck ``conv`` (1x1 expand -> 3x3 stride-2 depthwise -> squeeze-
    excite -> 1x1 project), then ``+ dwconv2(.)`` at output resolution.

    Args:
        in_dim: Input channel count.
        out_dim: Output channel count.
        ratio: Expansion factor; the hidden width is ``int(out_dim * ratio)``.
    """

    def __init__(self, in_dim, out_dim, ratio=4.0):
        super().__init__()
        mid = int(out_dim * ratio)

        expand = ConvLayer2D(in_dim, mid, kernel_size=1)
        downsample = ConvLayer2D(mid, mid, kernel_size=3, stride=2, padding=1, groups=mid)
        excite = SqueezeExcite(mid, .25)
        project = ConvLayer2D(mid, out_dim, kernel_size=1, act_layer=None)
        self.conv = nn.Sequential(expand, downsample, excite, project)

        self.dwconv1 = ConvLayer2D(in_dim, in_dim, 3, padding=1, groups=in_dim, act_layer=None)
        self.dwconv2 = ConvLayer2D(out_dim, out_dim, 3, padding=1, groups=out_dim, act_layer=None)

    def forward(self, x):
        refined = x + self.dwconv1(x)
        merged = self.conv(refined)
        return merged + self.dwconv2(merged)


class HSMSSD1D(nn.Module):
    """HSM-SSD layer applied directly to a 1D token sequence.

    The input ``x`` of shape ``(batch, d_model, T)`` is projected into three
    ``state_dim``-channel streams ``B``, ``C`` and ``dt``. The softmax over the
    ``T`` axis of ``dt + A`` scales ``B``. The scaled ``B`` pools the tokens
    into a hidden state ``h`` of shape ``(batch, d_model, state_dim)``. ``h``
    is mixed by a SiLU-gated 1x1 projection, and ``C`` reads it back out to
    ``(batch, d_model, T)``.

    Args:
        d_model: Channel count of input and output.
        ssd_expand: Width multiplier of the gated inner projection.
        A_init_range: ``(low, high)`` bounds of the uniform init of ``A``.
        state_dim: Number of hidden-state slots.

    Returns (forward):
        ``(y, h)`` -- the read-out sequence and the mixed hidden state.
    """

    def __init__(self, d_model, ssd_expand=1, A_init_range=(1, 16), state_dim=64):
        super().__init__()
        self.ssd_expand = ssd_expand
        self.d_inner = int(ssd_expand * d_model)
        self.state_dim = state_dim
        bcdt_channels = 3 * state_dim

        # 1x1 projection to the concatenated (B, C, dt) streams, then a k=3
        # depthwise conv mixing each stream along T (norm=None, so
        # bn_weight_init does nothing here).
        self.BCdt_proj = ConvLayer1D(d_model, bcdt_channels, kernel_size=1, norm=None, act_layer=None)
        self.dw = ConvLayer1D(bcdt_channels, bcdt_channels, kernel_size=3, stride=1, padding=1,
                              groups=bcdt_channels, norm=None, act_layer=None, bn_weight_init=0)

        # Hidden-state mixer: produces (h, z) for gating, then projects back.
        self.hz_proj = ConvLayer1D(d_model, 2 * self.d_inner, kernel_size=1, norm=None, act_layer=None)
        self.out_proj = ConvLayer1D(self.d_inner, d_model, kernel_size=1, norm=None, act_layer=None,
                                    bn_weight_init=0)

        # Learnable per-state bias added to dt before the softmax.
        self.A = nn.Parameter(torch.empty(state_dim, dtype=torch.float32).uniform_(*A_init_range))

        self.act = nn.SiLU()
        # Skip gain on h. The flag is presumably read by the optimizer setup to
        # exclude it from weight decay -- confirm in the training code.
        self.D = nn.Parameter(torch.ones(1))
        self.D._no_weight_decay = True

    def forward(self, x):
        n = self.state_dim
        streams = self.dw(self.BCdt_proj(x))                      # (batch, 3n, T)
        B, C, dt = streams[:, :n], streams[:, n:2 * n], streams[:, 2 * n:]

        token_weights = torch.softmax(dt + self.A[None, :, None], dim=-1)  # over T
        h = x @ (token_weights * B).transpose(1, 2)                # (batch, d_model, n)

        h, z = self.hz_proj(h).chunk(2, dim=1)                     # two (batch, d_inner, n)
        h = self.out_proj(h * self.act(z) + h * self.D)            # (batch, d_model, n)

        y = h @ C                                                  # (batch, d_model, T)
        return y, h

class HSMSSD(nn.Module):
    """HSM-SSD layer for tokens flattened from a square ``H x H`` grid.

    ``x`` has shape ``(batch, d_model, L)`` with ``L == H * H``. ``B``, ``C``
    and ``dt`` are produced by a 1x1 projection followed by a 3x3 depthwise
    conv on the ``H x H`` grid. The softmax over the ``L`` axis of ``dt + A``
    scales ``B`` to pool tokens into a hidden state ``h``
    ``(batch, d_model, state_dim)``. ``h`` is mixed by a SiLU-gated 1x1
    projection, and ``C`` reads it back out and reshapes to
    ``(batch, d_model, H, H)``.

    Args:
        d_model: Channel count of input and output.
        ssd_expand: Width multiplier of the gated inner projection.
        A_init_range: ``(low, high)`` bounds of the uniform init of ``A``.
        state_dim: Number of hidden-state slots.

    Returns (forward):
        ``(y, h)`` with ``y`` of shape ``(batch, d_model, H, H)``.

    Raises (forward):
        ValueError: If ``L`` is not a perfect square.
    """

    def __init__(self, d_model, ssd_expand=1, A_init_range=(1, 16), state_dim=64):
        super().__init__()
        self.ssd_expand = ssd_expand
        self.d_inner = int(self.ssd_expand * d_model)
        self.state_dim = state_dim

        self.BCdt_proj = ConvLayer1D(d_model, 3 * state_dim, 1, norm=None, act_layer=None)
        conv_dim = self.state_dim * 3
        # 3x3 depthwise conv over the 2D grid (norm=None, so bn_weight_init is unused).
        self.dw = ConvLayer2D(conv_dim, conv_dim, 3, 1, 1, groups=conv_dim, norm=None, act_layer=None, bn_weight_init=0)
        self.hz_proj = ConvLayer1D(d_model, 2 * self.d_inner, 1, norm=None, act_layer=None)
        self.out_proj = ConvLayer1D(self.d_inner, d_model, 1, norm=None, act_layer=None, bn_weight_init=0)

        A = torch.empty(self.state_dim, dtype=torch.float32).uniform_(*A_init_range)
        self.A = torch.nn.Parameter(A)
        self.act = nn.SiLU()
        # Skip gain on h; the flag is presumably consumed by optimizer setup to
        # skip weight decay -- confirm in the training code.
        self.D = nn.Parameter(torch.ones(1))
        self.D._no_weight_decay = True

    @staticmethod
    def _square_side(num_tokens):
        """Return ``H`` such that ``H * H == num_tokens``.

        The layer only receives flattened tokens, so it can only recover the
        grid shape when the grid is square.

        Raises:
            ValueError: If ``num_tokens`` is not a perfect square.
        """
        side = int(math.sqrt(num_tokens))
        if side * side != num_tokens:
            raise ValueError(
                f"HSMSSD expects tokens flattened from a square H x H grid, got L={num_tokens}")
        return side

    def forward(self, x):
        batch, _, L = x.shape
        H = self._square_side(L)

        # Depthwise conv needs the 2D layout; flatten back to (batch, 3n, L).
        BCdt = self.dw(self.BCdt_proj(x).view(batch, -1, H, H)).flatten(2)
        B, C, dt = torch.split(BCdt, [self.state_dim, self.state_dim, self.state_dim], dim=1)
        A = (dt + self.A.view(1, -1, 1)).softmax(-1)  # token weights over L

        AB = (A * B)
        h = x @ AB.transpose(-2, -1)  # (batch, d_model, state_dim)

        h, z = torch.split(self.hz_proj(h), [self.d_inner, self.d_inner], dim=1)
        h = self.out_proj(h * self.act(z) + h * self.D)
        y = h @ C  # (batch, d_model, n) @ (batch, n, L) -> (batch, d_model, L)

        y = y.view(batch, -1, H, H).contiguous()  # (batch, d_model, H, H)
        return y, h


class EfficientViMBlock(nn.Module):
    """EfficientViM block: depthwise conv -> HSM-SSD mixer -> depthwise conv -> FFN.

    Each of the four branches is merged into the running feature map by a
    learned per-channel convex blend ``(1 - a) * x + a * branch(x)`` with
    ``a = sigmoid(alpha[k])``. ``alpha`` starts at 1e-4, so the blends start
    near 0.5.

    The input is a 4D ``(B, dim, H, W)`` map, because the depthwise branches are
    ``Conv2d``. ``HSMSSD`` reshapes its output to an ``H x H`` grid, so the
    spatial grid presumably must be square -- confirm against callers.

    Args:
        dim: Channel count.
        mlp_ratio: FFN hidden-width multiplier.
        ssd_expand: Passed to ``HSMSSD``.
        state_dim: Passed to ``HSMSSD``.

    Returns (forward):
        ``(x, h)`` -- updated feature map and the mixer's hidden state.
    """

    def __init__(self, dim, mlp_ratio=4., ssd_expand=1, state_dim=64):
        super().__init__()
        self.dim = dim
        self.mlp_ratio = mlp_ratio

        self.mixer = HSMSSD(d_model=dim, ssd_expand=ssd_expand, state_dim=state_dim)
        self.norm = LayerNorm1D(dim)

        self.dwconv1 = ConvLayer2D(dim, dim, 3, padding=1, groups=dim, bn_weight_init=0, act_layer=None)
        self.dwconv2 = ConvLayer2D(dim, dim, 3, padding=1, groups=dim, bn_weight_init=0, act_layer=None)

        self.ffn = FFN(in_dim=dim, dim=int(dim * mlp_ratio))

        # Per-branch, per-channel blend logits (LayerScale-style), one row per branch.
        self.alpha = nn.Parameter(1e-4 * torch.ones(4, dim), requires_grad=True)

    def forward(self, x):
        gate = self.alpha.sigmoid().view(4, -1, 1, 1)  # (4, dim, 1, 1), values in (0, 1)

        def blend(branch, base, update):
            return (1 - gate[branch]) * base + gate[branch] * update

        x = blend(0, x, self.dwconv1(x))
        # The mixer works on (B, dim, H*W) tokens and returns a 4D map.
        mixed, h = self.mixer(self.norm(x.flatten(2)))
        x = blend(1, x, mixed)
        x = blend(2, x, self.dwconv2(x))
        x = blend(3, x, self.ffn(x))
        return x, h

class CHSMSSD1D(nn.Module):
    """1D HSM-SSD variant whose token weights use a causally masked softmax.

    For each position ``i`` the weighted state input is
    ``AB[..., i] = sum_{j<=i} softmax_j(dt[..., :i+1] + A)_j * B[..., j]``.
    It is then pooled into ``h`` ``(batch, d_model, state_dim)``, mixed by a
    SiLU-gated 1x1 projection, and read out with ``C``.

    NOTE(review): ``h`` contracts over every sequence position, so ``y`` at
    position ``t`` still depends on later inputs. The mask only restricts the
    per-position softmax; it does not make the layer causal end-to-end.

    Args:
        d_model: Channel count of input and output.
        ssd_expand: Width multiplier of the gated inner projection.
        A_init_range: ``(low, high)`` bounds of the uniform init of ``A``.
        state_dim: Number of hidden-state slots.

    Returns (forward):
        ``(y, h)`` with ``y`` shaped ``(batch, d_model, T)``.
    """

    def __init__(self, d_model, ssd_expand=1, A_init_range=(1, 16), state_dim=64):
        super().__init__()
        self.ssd_expand = ssd_expand
        self.d_inner = int(self.ssd_expand * d_model)
        self.state_dim = state_dim

        self.BCdt_proj = ConvLayer1D(d_model, 3 * state_dim, kernel_size=1, norm=None, act_layer=None)
        conv_dim = 3 * state_dim
        self.dw = ConvLayer1D(conv_dim, conv_dim, kernel_size=3, stride=1, padding=1, groups=conv_dim, norm=None,
                              act_layer=None, bn_weight_init=0)

        self.hz_proj = ConvLayer1D(d_model, 2 * self.d_inner, kernel_size=1, norm=None, act_layer=None)
        self.out_proj = ConvLayer1D(self.d_inner, d_model, kernel_size=1, norm=None, act_layer=None, bn_weight_init=0)

        A = torch.empty(self.state_dim, dtype=torch.float32).uniform_(*A_init_range)
        self.A = torch.nn.Parameter(A)

        self.act = nn.SiLU()
        # Skip gain on h; flag presumably read by optimizer setup (no weight decay).
        self.D = nn.Parameter(torch.ones(1))
        self.D._no_weight_decay = True

    @staticmethod
    def _causal_softmax_mix(logits, values):
        """Causally softmax-weighted running average along the last axis.

        Returns ``out`` with
        ``out[..., i] = sum_{j<=i} softmax(logits[..., :i+1])_j * values[..., j]``.

        Args:
            logits: ``(..., T)`` unnormalized score of each source position.
            values: Tensor with the same shape as ``logits``.

        Returns:
            Tensor with the same shape as ``values``.
        """
        T = logits.shape[-1]
        keep = torch.ones(T, T, dtype=torch.bool, device=logits.device).tril()
        # scores[..., i, j] = logits[..., j]: row i scores each *source* j.
        # (Broadcasting along the other axis would make every row constant,
        # giving a uniform softmax that ignores dt and A entirely.)
        scores = logits.unsqueeze(-2).expand(*logits.shape[:-1], T, T)
        # Hide future sources; the diagonal is always kept, so no row is all -inf.
        scores = scores.masked_fill(~keep, float('-inf'))
        weights = torch.softmax(scores, dim=-1)  # (..., T, T), each row sums to 1
        return torch.matmul(weights, values.unsqueeze(-1)).squeeze(-1)

    def forward(self, x):  # x: [B, C, T]
        # Project to the (B, C, dt) streams, each [B, D, T].
        BCdt = self.dw(self.BCdt_proj(x))
        B_, C_, dt = torch.split(BCdt, [self.state_dim] * 3, dim=1)

        logits = dt + self.A.view(1, -1, 1)       # [B, D, T]
        AB = self._causal_softmax_mix(logits, B_)  # [B, D, T]

        # Pool tokens into the hidden state: [B, C, T] @ [B, T, D] -> [B, C, D].
        h = torch.matmul(x, AB.transpose(-2, -1))

        # HSM gating on the hidden state.
        h, z = torch.split(self.hz_proj(h), [self.d_inner, self.d_inner], dim=1)
        h = self.out_proj(h * self.act(z) + h * self.D)

        # Read out: [B, C, D] @ [B, D, T] -> [B, C, T].
        y = h @ C_
        return y, h


class HSMSSD1D1(nn.Module):
    """HSM-SSD layer on a 1D token sequence (same computation as ``HSMSSD1D``).

    Given ``x`` ``(batch, d_model, T)``, the layer projects to ``B``, ``C``,
    ``dt`` streams (``state_dim`` channels each). It weights ``B`` by
    ``softmax_T(dt + A)`` and pools the tokens into ``h``
    ``(batch, d_model, state_dim)``. It then gates ``h`` through a SiLU
    channel mixer and reads out ``y = h @ C`` ``(batch, d_model, T)``.

    Args:
        d_model: Channel count of input and output.
        ssd_expand: Width multiplier of the gated inner projection.
        A_init_range: ``(low, high)`` bounds of the uniform init of ``A``.
        state_dim: Number of hidden-state slots.
    """

    def __init__(self, d_model, ssd_expand=1, A_init_range=(1, 16), state_dim=64):
        super().__init__()
        self.ssd_expand = ssd_expand
        self.d_inner = int(ssd_expand * d_model)
        self.state_dim = state_dim

        stream_ch = state_dim * 3
        # Input projection to (B, C, dt), then per-channel k=3 mixing along T.
        self.BCdt_proj = ConvLayer1D(d_model, stream_ch, kernel_size=1, norm=None, act_layer=None)
        self.dw = ConvLayer1D(stream_ch, stream_ch, kernel_size=3, stride=1, padding=1, groups=stream_ch,
                              norm=None, act_layer=None, bn_weight_init=0)

        # Gated hidden-state mixer.
        self.hz_proj = ConvLayer1D(d_model, self.d_inner * 2, kernel_size=1, norm=None, act_layer=None)
        self.out_proj = ConvLayer1D(self.d_inner, d_model, kernel_size=1, norm=None, act_layer=None,
                                    bn_weight_init=0)

        a_init = torch.empty(state_dim, dtype=torch.float32)
        a_init.uniform_(*A_init_range)
        self.A = nn.Parameter(a_init)  # per-state bias on dt

        self.act = nn.SiLU()
        # Skip gain; flag presumably read by optimizer setup to skip weight decay.
        self.D = nn.Parameter(torch.ones(1))
        self.D._no_weight_decay = True

    def forward(self, x):
        n = self.state_dim
        b_s, c_s, dt = self.dw(self.BCdt_proj(x)).split(n, dim=1)   # each (batch, n, T)

        pooled = x @ (b_s * (dt + self.A.view(1, n, 1)).softmax(dim=-1)).transpose(-1, -2)

        gate_in, gate = self.hz_proj(pooled).split(self.d_inner, dim=1)
        state = self.out_proj(gate_in * self.act(gate) + gate_in * self.D)

        return state @ c_s, state


class CHSMSSD1D2(nn.Module):
    """1D HSM-SSD variant whose weighted state input is accumulated causally.

    Token weights ``w = softmax_T(dt + A)`` use the same weighting as
    ``HSMSSD1D``. ``AB[..., i] = sum_{j<=i} w[..., j] * B[..., j]`` then
    applies a diagonal-inclusive lower-triangular mask along the sequence.
    ``AB`` pools tokens into ``h``, which is gated and read out with ``C``.

    NOTE(review): the softmax normalizer spans the whole sequence, and ``h``
    later contracts over all positions. Outputs can therefore still depend on
    future tokens; the mask only restricts the per-position accumulation.

    Args:
        d_model: Channel count of input and output.
        ssd_expand: Width multiplier of the gated inner projection.
        A_init_range: ``(low, high)`` bounds of the uniform init of ``A``.
        state_dim: Number of hidden-state slots.
    """

    def __init__(self, d_model, ssd_expand=1, A_init_range=(1, 16), state_dim=64):
        super().__init__()
        self.ssd_expand = ssd_expand
        self.d_inner = int(self.ssd_expand * d_model)
        self.state_dim = state_dim

        # 1x1 projection to 3 * state_dim channels: the B, C and dt streams.
        self.BCdt_proj = ConvLayer1D(d_model, 3 * state_dim, kernel_size=1, norm=None, act_layer=None)

        # Depthwise k=3 conv mixing each stream along the sequence.
        conv_dim = 3 * state_dim
        self.dw = ConvLayer1D(conv_dim, conv_dim, kernel_size=3, stride=1, padding=1, groups=conv_dim, norm=None,
                              act_layer=None, bn_weight_init=0)

        # HSM gating.
        self.hz_proj = ConvLayer1D(d_model, 2 * self.d_inner, kernel_size=1, norm=None, act_layer=None)
        self.out_proj = ConvLayer1D(self.d_inner, d_model, kernel_size=1, norm=None, act_layer=None, bn_weight_init=0)

        # Per-state bias added to dt before the softmax.
        A = torch.empty(self.state_dim, dtype=torch.float32).uniform_(*A_init_range)
        self.A = torch.nn.Parameter(A)

        self.act = nn.SiLU()
        # Skip gain; flag presumably read by optimizer setup to skip weight decay.
        self.D = nn.Parameter(torch.ones(1))
        self.D._no_weight_decay = True

    @staticmethod
    def _causal_accumulate(weights, values):
        """Return ``out[..., i] = sum_{j<=i} weights[..., j] * values[..., j]``.

        This equals contracting ``values`` with the lower-triangular matrix
        ``M[..., i, j] = weights[..., j] * (j <= i)``. A cumulative sum
        computes it without materializing the ``T x T`` matrix.
        """
        return torch.cumsum(weights * values, dim=-1)

    def forward(self, x):  # x: [B, C, T]
        # Project and split into the three state streams, each [B, D, T].
        BCdt = self.dw(self.BCdt_proj(x))
        B, C, dt = torch.split(BCdt, [self.state_dim] * 3, dim=1)

        # Softmax of (dt + A) over T. The parentheses matter: softmax applied
        # to A.view(1, -1, 1) alone acts on a size-1 axis and is always 1,
        # which would leave A unused.
        weights = (dt + self.A.view(1, -1, 1)).softmax(-1)  # [B, D, T]

        # Causal accumulation. Note einsum('bntt,bnt->bnt') would read only
        # the diagonal of the masked matrix, so the mask would have no effect.
        AB = self._causal_accumulate(weights, B)  # [B, D, T]

        # Pool tokens into the hidden state: [B, C, T] @ [B, T, D] -> [B, C, D].
        h = torch.matmul(x, AB.transpose(-2, -1))

        # HSM gating.
        h, z = torch.split(self.hz_proj(h), [self.d_inner, self.d_inner], dim=1)
        h = self.out_proj(h * self.act(z) + h * self.D)

        # Read out: [B, C, D] @ [B, D, T] -> [B, C, T].
        y = h @ C
        return y, h


class CHSMSSD1D1(nn.Module):
    """1D HSM-SSD variant whose weighted state input is accumulated causally.

    Token weights ``w = softmax_T(dt + A)`` use the same weighting as
    ``HSMSSD1D``. ``AB[..., i] = sum_{j<=i} w[..., j] * B[..., j]`` then
    applies a diagonal-inclusive lower-triangular mask along the sequence.
    ``AB`` pools tokens into ``h``, which is gated and read out with ``C``.

    NOTE(review): the softmax normalizer spans the whole sequence, and ``h``
    later contracts over all positions. Outputs can therefore still depend on
    future tokens; the mask only restricts the per-position accumulation.

    Args:
        d_model: Channel count of input and output.
        ssd_expand: Width multiplier of the gated inner projection.
        A_init_range: ``(low, high)`` bounds of the uniform init of ``A``.
        state_dim: Number of hidden-state slots.
    """

    def __init__(self, d_model, ssd_expand=1, A_init_range=(1, 16), state_dim=64):
        super().__init__()
        self.ssd_expand = ssd_expand
        self.d_inner = int(self.ssd_expand * d_model)
        self.state_dim = state_dim

        # 1x1 projection to 3 * state_dim channels: the B, C and dt streams.
        self.BCdt_proj = ConvLayer1D(d_model, 3 * state_dim, kernel_size=1, norm=None, act_layer=None)

        # Depthwise k=3 conv mixing each stream along the sequence.
        conv_dim = 3 * state_dim
        self.dw = ConvLayer1D(conv_dim, conv_dim, kernel_size=3, stride=1, padding=1, groups=conv_dim, norm=None,
                              act_layer=None, bn_weight_init=0)

        # HSM gating.
        self.hz_proj = ConvLayer1D(d_model, 2 * self.d_inner, kernel_size=1, norm=None, act_layer=None)
        self.out_proj = ConvLayer1D(self.d_inner, d_model, kernel_size=1, norm=None, act_layer=None, bn_weight_init=0)

        # Per-state bias added to dt before the softmax.
        A = torch.empty(self.state_dim, dtype=torch.float32).uniform_(*A_init_range)
        self.A = torch.nn.Parameter(A)

        self.act = nn.SiLU()
        # Skip gain; flag presumably read by optimizer setup to skip weight decay.
        self.D = nn.Parameter(torch.ones(1))
        self.D._no_weight_decay = True

    @staticmethod
    def _causal_accumulate(weights, values):
        """Return ``out[..., i] = sum_{j<=i} weights[..., j] * values[..., j]``.

        This equals contracting ``values`` with the lower-triangular matrix
        ``M[..., i, j] = weights[..., j] * (j <= i)``. A cumulative sum
        computes it without materializing the ``T x T`` matrix.
        """
        return torch.cumsum(weights * values, dim=-1)

    def forward(self, x):  # x: [B, C, T]
        # Project and split into the three state streams, each [B, D, T].
        BCdt = self.dw(self.BCdt_proj(x))
        B, C, dt = torch.split(BCdt, [self.state_dim] * 3, dim=1)

        # Softmax of (dt + A) over T. The parentheses matter: softmax applied
        # to A.view(1, -1, 1) alone acts on a size-1 axis and is always 1,
        # which would leave A unused.
        weights = (dt + self.A.view(1, -1, 1)).softmax(-1)  # [B, D, T]

        # Causal accumulation. Note einsum('bntt,bnt->bnt') would read only
        # the diagonal of the masked matrix, so the mask would have no effect.
        AB = self._causal_accumulate(weights, B)  # [B, D, T]

        # Pool tokens into the hidden state: [B, C, T] @ [B, T, D] -> [B, C, D].
        h = torch.matmul(x, AB.transpose(-2, -1))

        # HSM gating.
        h, z = torch.split(self.hz_proj(h), [self.d_inner, self.d_inner], dim=1)
        h = self.out_proj(h * self.act(z) + h * self.D)

        # Read out: [B, C, D] @ [B, D, T] -> [B, C, T].
        y = h @ C
        return y, h


class CausalHSMSSD1D(nn.Module):
    """1D HSM-SSD variant whose token weights use a causally masked softmax.

    ``AB[..., i] = sum_{j<=i} softmax_j(dt[..., :i+1] + A)_j * B[..., j]``
    pools tokens into ``h`` ``(batch, d_model, state_dim)``. ``h`` is mixed
    by a SiLU-gated 1x1 projection and read out with ``C``.

    NOTE(review): ``h`` contracts over every sequence position, so ``y`` at
    position ``t`` still depends on later inputs. The mask does not make the
    layer causal end-to-end.

    Args:
        d_model: Channel count of input and output.
        ssd_expand: Width multiplier of the gated inner projection (defaults
            to 1, as in the sibling HSM-SSD classes).
        A_init_range: ``(low, high)`` bounds of the uniform init of ``A``.
        state_dim: Number of hidden-state slots.

    Returns (forward):
        ``(y, h)`` with ``y`` shaped ``(batch, d_model, T)``.
    """

    def __init__(self, d_model, ssd_expand=1, A_init_range=(1, 16), state_dim=64):
        super().__init__()
        self.ssd_expand = ssd_expand
        self.d_inner = int(self.ssd_expand * d_model)
        self.state_dim = state_dim

        self.BCdt_proj = ConvLayer1D(d_model, 3 * state_dim, kernel_size=1, norm=None, act_layer=None)
        conv_dim = 3 * state_dim
        self.dw = ConvLayer1D(conv_dim, conv_dim, kernel_size=3, stride=1, padding=1, groups=conv_dim,
                               norm=None, act_layer=None, bn_weight_init=0)

        self.hz_proj = ConvLayer1D(d_model, 2 * self.d_inner, kernel_size=1, norm=None, act_layer=None)
        self.out_proj = ConvLayer1D(self.d_inner, d_model, kernel_size=1, norm=None, act_layer=None,
                                    bn_weight_init=0)

        A = torch.empty(self.state_dim, dtype=torch.float32).uniform_(*A_init_range)
        self.A = torch.nn.Parameter(A)

        self.act = nn.SiLU()
        # Skip gain; flag presumably read by optimizer setup to skip weight decay.
        self.D = nn.Parameter(torch.ones(1))
        self.D._no_weight_decay = True

    @staticmethod
    def _causal_softmax_mix(logits, values):
        """Causally softmax-weighted running average along the last axis.

        Returns ``out`` with
        ``out[..., i] = sum_{j<=i} softmax(logits[..., :i+1])_j * values[..., j]``.
        ``logits`` and ``values`` share the shape ``(..., T)``.
        """
        T = logits.shape[-1]
        keep = torch.ones(T, T, dtype=torch.bool, device=logits.device).tril()
        # scores[..., i, j] = logits[..., j]: row i scores each source j.
        # (Expanding along the other axis would give constant rows, hence a
        # uniform softmax that ignores dt and A.)
        scores = logits.unsqueeze(-2).expand(*logits.shape[:-1], T, T)
        # Diagonal is always kept, so no row is entirely -inf.
        scores = scores.masked_fill(~keep, float('-inf'))
        weights = torch.softmax(scores, dim=-1)
        # A proper contraction over j; einsum("bdtt,...") would take the diagonal.
        return torch.matmul(weights, values.unsqueeze(-1)).squeeze(-1)

    def forward(self, x):  # x: [B, C, T]
        # 1. Generate the B, C, dt streams, each [B, D, T].
        BCdt = self.dw(self.BCdt_proj(x))
        B_state, C_state, dt = torch.split(BCdt, [self.state_dim] * 3, dim=1)

        # 2. Causal softmax weighting of B.
        raw_A = dt + self.A.view(1, -1, 1)                # [B, D, T]
        AB = self._causal_softmax_mix(raw_A, B_state)     # [B, D, T]

        # 3. Pool tokens into the hidden state: [B, C, D].
        h = torch.matmul(x, AB.transpose(-2, -1))

        # 4. Channel mixing with gating.
        h, z = torch.split(self.hz_proj(h), [self.d_inner, self.d_inner], dim=1)
        h = self.out_proj(h * self.act(z) + h * self.D)   # [B, C, D]

        # 5. Read out: [B, C, T].
        y = torch.matmul(h, C_state)

        return y, h

class CausalHSMSSD1D_SimpleMask(nn.Module):
    """1D HSM-SSD variant with a causal softmax over past positions.

    For each position ``i``, ``A[..., i, :]`` is the softmax of
    ``dt[..., :i+1] + A`` padded with zeros for ``j > i``. Then
    ``AB[..., i] = sum_j A[..., i, j] * B[..., j]``. ``AB`` pools tokens into
    ``h``, which is gated and read out with ``C``.

    NOTE(review): ``h`` contracts over every sequence position, so the layer
    output is not causal end-to-end.

    Args:
        d_model: Channel count of input and output.
        ssd_expand: Width multiplier of the gated inner projection.
        A_init_range: ``(low, high)`` bounds of the uniform init of ``A``.
        state_dim: Number of hidden-state slots.
    """

    def __init__(self, d_model, ssd_expand=1, A_init_range=(1, 16), state_dim=64):
        super().__init__()
        self.ssd_expand = ssd_expand
        self.d_inner = int(self.ssd_expand * d_model)
        self.state_dim = state_dim

        self.BCdt_proj = ConvLayer1D(d_model, 3 * state_dim, kernel_size=1, norm=None, act_layer=None)
        conv_dim = 3 * state_dim
        self.dw = ConvLayer1D(conv_dim, conv_dim, kernel_size=3, stride=1, padding=1, groups=conv_dim,
                               norm=None, act_layer=None, bn_weight_init=0)

        self.hz_proj = ConvLayer1D(d_model, 2 * self.d_inner, kernel_size=1, norm=None, act_layer=None)
        self.out_proj = ConvLayer1D(self.d_inner, d_model, kernel_size=1, norm=None, act_layer=None,
                                    bn_weight_init=0)

        A = torch.empty(self.state_dim, dtype=torch.float32).uniform_(*A_init_range)
        self.A = torch.nn.Parameter(A)

        self.act = nn.SiLU()
        # Skip gain; flag presumably read by optimizer setup to skip weight decay.
        self.D = nn.Parameter(torch.ones(1))
        self.D._no_weight_decay = True

    @staticmethod
    def _causal_softmax_mix(logits, values):
        """Vectorized form of "softmax over logits[..., :i+1], zero-pad, mix values".

        Returns ``out[..., i] = sum_{j<=i} softmax(logits[..., :i+1])_j * values[..., j]``.
        Masking future positions with -inf before the softmax gives each row the
        same weights as the padded per-position softmax, without a Python loop.
        """
        T = logits.shape[-1]
        keep = torch.ones(T, T, dtype=torch.bool, device=logits.device).tril()
        scores = logits.unsqueeze(-2).expand(*logits.shape[:-1], T, T)  # [..., i, j] = logits[..., j]
        weights = torch.softmax(scores.masked_fill(~keep, float('-inf')), dim=-1)
        # Contract over the source axis j. einsum("bdtt,bdt->bdt") would only
        # read the diagonal weights[i, i], discarding all past positions.
        return torch.matmul(weights, values.unsqueeze(-1)).squeeze(-1)

    def forward(self, x):  # x: [B, C, T]
        # Step 1: the three state streams, each [B, D, T].
        BCdt = self.dw(self.BCdt_proj(x))
        B_state, C_state, dt = torch.split(BCdt, self.state_dim, dim=1)

        # Step 2: causal weights; token i only sees positions <= i.
        AB = self._causal_softmax_mix(dt + self.A.view(1, -1, 1), B_state)  # [B, D, T]

        # Step 3: tokens x states -> hidden state h [B, C, D].
        h = torch.matmul(x, AB.transpose(-2, -1))

        # Step 4: HSM gated mixing.
        h, z = torch.split(self.hz_proj(h), [self.d_inner, self.d_inner], dim=1)
        h = self.out_proj(h * self.act(z) + h * self.D)

        # Step 5: read out [B, C, T].
        y = torch.matmul(h, C_state)

        return y, h



class CausalHSMSSD1D_test(nn.Module):
    """Experimental 1D HSM-SSD variant with an unnormalized causal accumulation.

    ``AB[..., i] = sum_{j<=i} dt[..., j] * B[..., j]``: a lower-triangular
    (diagonal-inclusive) mask without softmax. ``AB`` pools tokens into ``h``,
    which is gated and read out with ``C``.

    NOTE(review): the learnable ``A`` parameter is created but never used in
    ``forward`` -- confirm whether it should be added to ``dt``. Also, ``h``
    contracts over all positions, so the output is not causal end-to-end.

    Args:
        d_model: Channel count of input and output.
        ssd_expand: Width multiplier of the gated inner projection.
        A_init_range: ``(low, high)`` bounds of the uniform init of ``A``.
        state_dim: Number of hidden-state slots.
    """

    def __init__(self, d_model, ssd_expand=1, A_init_range=(1, 16), state_dim=64):
        super().__init__()
        self.ssd_expand = ssd_expand
        self.d_inner = int(self.ssd_expand * d_model)
        self.state_dim = state_dim

        # 1x1 projection to 3 * state_dim channels: the B, C and dt streams.
        self.BCdt_proj = ConvLayer1D(d_model, 3 * state_dim, kernel_size=1, norm=None, act_layer=None)

        # Depthwise k=3 conv mixing each stream along the sequence.
        conv_dim = 3 * state_dim
        self.dw = ConvLayer1D(conv_dim, conv_dim, kernel_size=3, stride=1, padding=1, groups=conv_dim, norm=None, act_layer=None, bn_weight_init=0)

        # HSM gating.
        self.hz_proj = ConvLayer1D(d_model, 2 * self.d_inner, kernel_size=1, norm=None, act_layer=None)
        self.out_proj = ConvLayer1D(self.d_inner, d_model, kernel_size=1, norm=None, act_layer=None, bn_weight_init=0)

        # Per-state parameter (currently unused in forward; kept for checkpoint compatibility).
        A = torch.empty(self.state_dim, dtype=torch.float32).uniform_(*A_init_range)
        self.A = torch.nn.Parameter(A)

        self.act = nn.SiLU()
        # Skip gain; flag presumably read by optimizer setup to skip weight decay.
        self.D = nn.Parameter(torch.ones(1))
        self.D._no_weight_decay = True

    @staticmethod
    def _causal_accumulate(weights, values):
        """Return ``out[..., i] = sum_{j<=i} weights[..., j] * values[..., j]``.

        This equals multiplying ``values`` by the masked matrix
        ``M[..., i, j] = weights[..., j] * (j <= i)``. The cumulative sum avoids
        materializing the ``T x T`` tensor.
        """
        return torch.cumsum(weights * values, dim=-1)

    def forward(self, x):  # x: [B, C, T]
        # Step 1-2: project and split into the three streams, each [B, D, T].
        BCdt = self.dw(self.BCdt_proj(x))
        B, C, dt = torch.split(BCdt, [self.state_dim] * 3, dim=1)

        # Step 3-4: causal (lower-triangular) accumulation of dt-weighted B.
        AB = self._causal_accumulate(dt, B)  # [B, D, T]

        # Step 5: pool tokens into the hidden state: [B, C, T] @ [B, T, D] -> [B, C, D].
        h = torch.matmul(x, AB.transpose(-2, -1))

        # Step 6: HSM gating.
        h, z = torch.split(self.hz_proj(h), [self.d_inner, self.d_inner], dim=1)
        h = self.out_proj(h * self.act(z) + h * self.D)

        # Step 7: read out: [B, C, D] @ [B, D, T] -> [B, C, T].
        y = h @ C
        return y, h

if __name__ == '__main__':
    # Use the GPU when available; otherwise fall back to CPU so this smoke
    # test also runs on machines without CUDA.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # block = EfficientViMBlock(dim=96).to(device)
    #
    # input_tensor = torch.rand(4, 96, 32, 32).to(device)
    #
    # output, h = block(input_tensor)
    #
    # print("Input size:", input_tensor.size())
    # print("Output size:", output.size())
    # print("Intermediate h size:", h.size())

    block = CHSMSSD1D(d_model=7, ssd_expand=1).to(device)
    input_tensor = torch.rand(4, 7, 32).to(device)
    output, h = block(input_tensor)
    print("Input size:", input_tensor.size())
    print("Output size:", output.size())
    print("Intermediate h size:", h.size())

