import torch
import torch.nn as nn
from mamba_ssm import Mamba


class MambaEmbedding(nn.Module):
    """Embed ``[x | mask]`` with two residual Mamba layers.

    Pipeline: concatenate the input and its mask on the last axis, apply
    LayerNorm and a linear projection to ``mamba_dim``, run two ``Mamba``
    layers (each with a skip connection, dropout in between), then apply
    LayerNorm and a linear projection to ``latent_dim``.
    """

    def __init__(self, input_dim: int, latent_dim: int, mamba_dim: int = 64,
                 dropout: float = 0.1, use_mean_pooling: bool = False,
                 seq_len=None, verbose: bool = True):
        """
        Args:
            input_dim: Size of the last axis of the input.
            latent_dim: Size of the last axis of the output.
            mamba_dim: Width used by the two Mamba layers.
            dropout: Dropout probability applied between the two Mamba layers.
            use_mean_pooling: Unused; kept only for interface compatibility
                with the original project.
            seq_len: Unused; kept only for interface compatibility with the
                original project.
            verbose: When True (default, the historical behavior) ``forward``
                prints the tensor shape after each stage. Pass False to keep
                stdout quiet, e.g. inside a training loop.
        """
        super().__init__()
        self.use_mean_pooling = use_mean_pooling  # unused, stored only
        self.seq_len = seq_len  # unused, stored only
        self.verbose = verbose

        # The mask is concatenated to the input, hence the doubled width.
        self.input_norm = nn.LayerNorm(input_dim * 2)
        self.input_proj = nn.Linear(input_dim * 2, mamba_dim)

        self.mamba1 = Mamba(d_model=mamba_dim)
        self.mamba2 = Mamba(d_model=mamba_dim)

        self.dropout = nn.Dropout(dropout)
        self.norm = nn.LayerNorm(mamba_dim)
        self.output_proj = nn.Linear(mamba_dim, latent_dim)

    def _trace(self, label: str, tensor: torch.Tensor) -> None:
        """Print ``label`` followed by the tensor's shape when verbose is on."""
        if self.verbose:
            print(label, tensor.shape)

    def forward(self, x: torch.Tensor, mask=None) -> torch.Tensor:
        """
        Args:
            x: Tensor (B, T, input_dim).
            mask: Tensor (B, T, input_dim), (B, T, 1), (B, T), or None.
                None is treated as an all-ones mask. Any dtype is accepted
                (bool, integer, float); the mask may live on another device.
        Returns:
            Tensor (B, T, latent_dim)
        """
        if self.verbose:
            print("\n===== [MambaEmbedding] =====")
        self._trace("【输入 x】: 原始输入 ->", x)

        if mask is None:
            mask = torch.ones_like(x)
        elif mask.dim() == x.dim() - 1:
            # Per-timestep mask without a feature axis: add one so it can be
            # broadcast across the features below.
            mask = mask.unsqueeze(-1)
        if mask.shape[-1] == 1:
            mask = mask.expand_as(x)
        self._trace("【输入 mask】: ->", mask)

        # Bring both operands to the projection's parameter dtype (and the
        # mask to x's device) so the concat never promotes to a dtype that
        # LayerNorm / Linear would reject, e.g. a float64 or bool mask.
        param_dtype = self.input_proj.weight.dtype
        mask = mask.to(device=x.device, dtype=param_dtype)
        x = torch.cat([x.to(param_dtype), mask], dim=-1)  # (B, T, 2 * input_dim)
        self._trace("【拼接后】: [x | mask] ->", x)

        x = self.input_proj(self.input_norm(x))
        self._trace("【线性映射后】: ->", x)

        # First Mamba layer with a residual connection.
        res1 = x
        x = self.mamba1(x)
        self._trace("【经过 Mamba1】: ->", x)
        x = x + res1

        x = self.dropout(x)

        # Second Mamba layer with a residual connection.
        res2 = x
        x = self.mamba2(x)
        self._trace("【经过 Mamba2】: ->", x)
        x = x + res2

        x = self.output_proj(self.norm(x))
        self._trace("【最终输出】: ->", x)
        return x  # (B, T, latent_dim)
