import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass
from typing import Union
import warnings

@dataclass
class ModelArgs:
    """Configuration for the Mamba time-series model.

    Attributes:
        d_model: Dimension of the model's hidden states.
        n_layer: Number of Mamba layers in the model.
        input_dim: Dimension of input features.
        output_dim: Dimension of output features.
        input_len: Length of input sequence.
        output_len: Length of output sequence.
        d_state: Dimension of SSM states. Defaults to 16.
        expand: Expansion factor for inner dimension. Defaults to 2.
        dt_rank: Rank of delta projection. Either an int or the string
            'auto'. Defaults to 'auto'.
        d_conv: Kernel size for 1D convolution. Defaults to 4.
        bias: Whether to use bias in linear layers. Defaults to False.
        conv_bias: Whether to use bias in convolution layers. Defaults to True.
        d_inner: Derived in ``__post_init__`` as ``int(expand * d_model)``;
            it is not a constructor argument.

    Raises:
        ValueError: If ``dt_rank`` is a string other than 'auto'.
    """
    d_model: int          # model (hidden) dimension
    n_layer: int          # number of layers
    input_dim: int        # input feature dimension
    output_dim: int       # output feature dimension
    input_len: int        # input sequence length
    output_len: int       # output sequence length
    d_state: int = 16     # state dimension
    expand: int = 2       # expansion factor
    dt_rank: Union[int, str] = 'auto'
    d_conv: int = 4
    bias: bool = False
    conv_bias: bool = True

    def __post_init__(self):
        """Derive ``d_inner`` and resolve ``dt_rank='auto'`` to a concrete int."""
        self.d_inner = int(self.expand * self.d_model)
        if isinstance(self.dt_rank, str):
            # Reject unknown keywords here; otherwise the string would only
            # fail much later when used as a layer size.
            if self.dt_rank != 'auto':
                raise ValueError(
                    f"dt_rank must be an int or 'auto', got {self.dt_rank!r}"
                )
            # At least 16, d_model // 16 when that is larger, never above d_model.
            self.dt_rank = min(max(16, self.d_model // 16), self.d_model)

class Mamba(nn.Module):
    """Sliding-window time-series predictor built from residual Mamba blocks.

    Pipeline: input projection -> ``n_layer`` ``ResidualBlock`` layers ->
    RMSNorm -> linear mix over the time axis -> output projection.

    Args:
        args (ModelArgs): Configuration parameters for the model.
    """
    def __init__(self, args: "ModelArgs"):
        super().__init__()
        self.args = args

        # Input projection: input_dim -> d_model.
        self.input_proj = nn.Linear(args.input_dim, args.d_model)

        # Mamba blocks with residual connections.
        self.layers = nn.ModuleList([
            ResidualBlock(args) for _ in range(args.n_layer)
        ])

        # Final normalization.
        self.norm_f = RMSNorm(args.d_model)

        # Collapses the time axis: input_len -> 1.
        self.time_mixer = nn.Linear(args.input_len, 1)

        # Output projection: d_model -> output_dim.
        self.output_proj = nn.Linear(args.d_model, args.output_dim)

    def forward(self, x, timesteps, chunk_size: int = 128):
        """Predict one output step per sliding window of the input.

        Window ``k`` (``0 <= k < output_len``) covers ``x[:, k:k + input_len]``
        together with ``timesteps[:, k:k + input_len + 1]``. The windows are
        cut from ``x`` itself; predictions are not fed back as inputs.

        Args:
            x (torch.Tensor): Input of shape (batch_size, total_len, input_dim)
                with ``total_len >= input_len + output_len - 1``.
            timesteps (torch.Tensor): Timestamps of shape
                (batch_size, n_times) with
                ``n_times >= input_len + output_len``.
            chunk_size (int, optional): Number of samples processed per pass
                through the layers; lower it to reduce peak memory.
                Defaults to 128.

        Returns:
            torch.Tensor: Predictions of shape
                (batch_size, output_len, output_dim).

        Raises:
            ValueError: If ``chunk_size`` is not positive or the inputs are
                too short / mis-shaped for the configured window lengths.
        """
        input_len = self.args.input_len
        output_len = self.args.output_len

        if chunk_size < 1:
            raise ValueError(f"chunk_size must be >= 1, got {chunk_size}")
        if x.dim() != 3:
            raise ValueError(
                f"x must be 3-D (batch, total_len, input_dim), got shape {tuple(x.shape)}"
            )
        b, total_len, n_features = x.shape

        # Without this check a too-short sequence yields truncated windows that
        # the reshape below could silently merge across samples.
        min_len = input_len + output_len - 1
        if total_len < min_len:
            raise ValueError(
                f"x has {total_len} time steps but at least {min_len} are required "
                f"(input_len={input_len}, output_len={output_len})"
            )
        if timesteps.dim() != 2 or timesteps.size(0) != b or timesteps.size(1) < min_len + 1:
            raise ValueError(
                f"timesteps must have shape ({b}, >= {min_len + 1}), "
                f"got {tuple(timesteps.shape)}"
            )

        if b == 0:
            # Nothing to process; torch.cat would reject an empty list.
            return x.new_zeros((0, output_len, self.args.output_dim))

        outputs = []
        for start in range(0, b, chunk_size):
            x_chunk = x[start:start + chunk_size]
            t_chunk = timesteps[start:start + chunk_size]
            n_rows = x_chunk.size(0)

            # (n_rows, output_len, input_len, features) -> one row per window.
            windows = torch.stack(
                [x_chunk[:, k:k + input_len, :] for k in range(output_len)],
                dim=1,
            ).reshape(-1, input_len, n_features)

            # Matching time windows; one extra stamp so every step has a dt.
            time_windows = torch.stack(
                [t_chunk[:, k:k + input_len + 1] for k in range(output_len)],
                dim=1,
            ).reshape(-1, input_len + 1)

            # Input projection.
            curr_x = self.input_proj(windows)

            # Initial SDE state: projection of the first time step, tiled
            # `expand` times along the feature axis (d_model * expand).
            initial_state = curr_x[:, 0, :].clone().repeat(1, self.args.expand)

            # Mamba layers.
            for layer in self.layers:
                curr_x = layer(curr_x, initial_state, time_windows)

            # Final normalization.
            curr_x = self.norm_f(curr_x)

            # Mix over time: (N, input_len, d_model) -> (N, 1, d_model).
            curr_x = curr_x.transpose(1, 2)
            curr_x = self.time_mixer(curr_x)
            curr_x = curr_x.transpose(1, 2)

            # Output projection.
            pred = self.output_proj(curr_x)

            # Back to (n_rows, output_len, output_dim).
            outputs.append(pred.reshape(n_rows, output_len, -1))

        # Concatenate all chunks: (b, output_len, output_dim).
        return torch.cat(outputs, dim=0)

class ResidualBlock(nn.Module):
    """Pre-norm residual wrapper around a ``MambaBlockWithSDE`` mixer.

    Computes ``mixer(norm(x), initial_state, timesteps) + x``.

    Args:
        args (ModelArgs): Configuration parameters for the block.
    """
    def __init__(self, args: "ModelArgs"):
        super().__init__()
        self.args = args
        self.mixer = MambaBlockWithSDE(
            d_model=args.d_model,
            d_inner=args.d_inner,
            d_state=args.d_state,
            dt_rank=args.dt_rank,
            d_conv=args.d_conv,
            bias=args.bias,
            conv_bias=args.conv_bias
        )
        self.norm = RMSNorm(args.d_model)

    def forward(self, x, initial_state, timesteps):
        """Process input through the residual block.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, seq_len, d_model).
            initial_state (torch.Tensor | None): Initial state of shape
                (batch_size, d_inner). ``None`` falls back to zeros and emits
                a ``RuntimeWarning``.
            timesteps (torch.Tensor | None): Timestamps of shape
                (batch_size, seq_len + 1). ``None`` falls back to a uniform
                grid on [0, 1] and emits a ``RuntimeWarning``.

        Returns:
            torch.Tensor: Mixer output plus the residual ``x``.
        """
        b, l, _ = x.shape

        # No initial_state given: warn and fall back to zeros. The fallback
        # follows x's dtype so non-float32 models do not hit a dtype mismatch.
        if initial_state is None:
            warnings.warn(
                "未提供initial_state，将使用零初始化。这可能会影响模型性能。",
                RuntimeWarning
            )
            initial_state = torch.zeros(
                b, self.args.d_inner, device=x.device, dtype=x.dtype
            )

        # No timesteps given: warn and fall back to uniform steps on [0, 1].
        if timesteps is None:
            warnings.warn(
                "未提供timesteps，将使用均匀时间步[0,1]。这可能不符合实际数据的时间分布。",
                RuntimeWarning
            )
            timesteps = torch.linspace(
                0.0, 1.0, l + 1, device=x.device, dtype=x.dtype
            )
            timesteps = timesteps.unsqueeze(0).expand(b, -1)  # (b, l+1)

        output = self.mixer(self.norm(x), initial_state, timesteps) + x
        return output

class MambaBlockWithSDE(nn.Module):
    """Mamba-style block whose output sequence is produced by an SDE rollout.

    A gated, depthwise-convolved input drives a diagonal state recurrence.
    The final recurrent state parameterizes a drift (``f_theta``) and a
    diffusion (``g_theta``) term, which are integrated step by step
    (Euler-Maruyama style) starting from ``initial_state``.

    Args:
        d_model (int): Model dimension.
        d_inner (int): Inner dimension.
        d_state (int): Hidden width of the ``f_theta`` / ``g_theta`` networks.
        dt_rank (int): Delta projection dimension.
        d_conv (int): Convolution kernel size.
        bias (bool, optional): Use bias in linear layers. Defaults to False.
        conv_bias (bool, optional): Use bias in conv layers. Defaults to False.
    """
    def __init__(
        self,
        d_model: int,      # model dimension
        d_inner: int,      # inner dimension
        d_state: int,      # state dimension
        dt_rank: int,      # delta projection dimension
        d_conv: int,       # convolution kernel size
        bias: bool = False,
        conv_bias: bool = False
    ):

        super().__init__()

        # 1) Input projection: d_model -> 2 * d_inner (main path + gate).
        self.in_proj = nn.Linear(d_model, 2 * d_inner, bias=bias)

        # 2) Depthwise 1D convolution over (b, d_inner, l). The padding of
        #    d_conv - 1 on both sides is trimmed back to length l in forward().
        self.conv1d = nn.Conv1d(
            in_channels=d_inner,
            out_channels=d_inner,
            kernel_size=d_conv,
            groups=d_inner,
            padding=d_conv - 1,
            bias=conv_bias,
        )

        # 3) x_proj maps features to (dt_rank + d_inner), later split into
        #    the raw delta and B; dt_proj lifts delta back to d_inner.
        self.x_proj = nn.Linear(d_inner, dt_rank + d_inner, bias=False)
        self.dt_proj = nn.Linear(dt_rank, d_inner, bias=True)

        # 4) Learnable diagonal A (stored as log) and diagonal skip weight D.
        self.A_log = nn.Parameter(torch.log(torch.arange(1, d_inner + 1).float()))
        self.D = nn.Parameter(torch.ones(d_inner))

        # 5) Output projection: d_inner -> d_model.
        self.out_proj = nn.Linear(d_inner, d_model, bias=bias)

        # 6) Small networks f_theta (drift) and g_theta (diffusion).
        self.f_theta = nn.Sequential(
            nn.Linear(d_inner, d_state),
            nn.ReLU(),
            nn.Linear(d_state, d_inner)
        )
        self.g_theta = nn.Sequential(
            nn.Linear(d_inner, d_state),
            nn.ReLU(),
            nn.Linear(d_state, d_inner)
        )

    def forward(self, x, initial_state, timesteps):
        """Process input through the Mamba block with SDE integration.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, seq_len, d_model).
            initial_state (torch.Tensor): Initial SDE state of shape
                (batch_size, d_inner).
            timesteps (torch.Tensor): Timestamps of shape
                (batch_size, seq_len + 1); column ``t`` is read as
                ``timesteps[:, t]``.

        Returns:
            torch.Tensor: Output tensor of shape (batch_size, seq_len, d_model).
        """
        b, l, d_in = x.shape
        assert d_in == self.in_proj.in_features, \
            f"输入维度 {d_in} 和 in_proj 所需的 {self.in_proj.in_features} 不匹配!"

        # 1) Project and split into the main path and the gate.
        x_res = self.in_proj(x)  # (b, l, 2*d_inner)
        x_main, x_gate = x_res.chunk(2, dim=-1)  # (b, l, d_inner), (b, l, d_inner)

        # 2) Depthwise convolution along time, trimmed to the first l outputs.
        x_main = x_main.permute(0, 2, 1)  # (b, d_inner, l)
        x_main = self.conv1d(x_main)[..., :l]
        x_main = x_main.permute(0, 2, 1)  # back to (b, l, d_inner)
        x_main = F.silu(x_main)

        # 3) SSM + SDE. h0 follows x_main's dtype/device so the block also
        #    works when the module is not float32.
        h0 = x_main.new_zeros(b, x_main.size(-1))
        y = self.ssm_with_sde(x_main, initial_state, h0, timesteps)

        # 4) Gating and output projection.
        y = y * F.silu(x_gate)
        out = self.out_proj(y)  # (b, l, d_model)
        return out

    def ssm_with_sde(self, x_main, initial_state, h0, timesteps,
                    dt_min=1e-6, dt_max=10.0,
                    val_min=-1e6, val_max=1e6):
        """Run the diagonal SSM recurrence, then roll out the SDE.

        The recurrent state only affects the output through ``f_theta`` and
        ``g_theta`` evaluated at the final state; the resulting drift and
        diffusion coefficients are held constant over the rollout.

        NOTE: Gaussian noise is sampled on every call, regardless of
        ``self.training``, so outputs are stochastic unless the RNG is seeded.

        Args:
            x_main (torch.Tensor): Main features of shape (b, l, d_inner).
            initial_state (torch.Tensor): Initial SDE state of shape (b, d_inner).
            h0 (torch.Tensor): Initial hidden state for the SSM, (b, d_inner).
            timesteps (torch.Tensor): Timestamps of shape (b, l + 1) or wider;
                only columns ``0..l`` are read.
            dt_min (float, optional): Minimum time delta. Defaults to 1e-6.
            dt_max (float, optional): Maximum time delta. Defaults to 10.0.
            val_min (float, optional): Minimum clamp value. Defaults to -1e6.
            val_max (float, optional): Maximum clamp value. Defaults to 1e6.

        Returns:
            torch.Tensor: Sequence of shape (b, l, d_inner): SDE trajectory
                plus the ``D``-weighted skip connection, clamped to
                [val_min, val_max].
        """
        b, l, d_in = x_main.shape
        A_diag = -torch.exp(self.A_log)

        # Input-dependent delta and B for all time steps.
        x_dbl = self.x_proj(x_main)
        dt_dim = self.dt_proj.in_features
        delta_raw, B_raw = x_dbl.split([dt_dim, d_in], dim=-1)
        delta = F.softplus(self.dt_proj(delta_raw))

        # Clamped step sizes, each (b, 1). Computed once and shared by both
        # loops below instead of being recomputed in each of them.
        dt_steps = [
            torch.clamp(timesteps[:, t + 1] - timesteps[:, t],
                        min=dt_min, max=dt_max).unsqueeze(-1)
            for t in range(l)
        ]

        # Phase 1: advance the SSM state through the whole sequence.
        h_t = h0
        for t, dt in enumerate(dt_steps):
            delta_t = delta[:, t, :] * dt  # (b, d_in); dt broadcasts over features
            deltaA = torch.clamp(torch.exp(A_diag * delta_t), min=1e-12, max=1e12)
            deltaB_u = B_raw[:, t, :] * x_main[:, t, :] * delta_t
            h_t = torch.clamp(h_t * deltaA + deltaB_u, min=val_min, max=val_max)

        # Drift and diffusion coefficients from the final state.
        f_final = self.f_theta(h_t)  # (b, d_in)
        g_final = self.g_theta(h_t)  # (b, d_in)

        # Phase 2: y_{t+1} = y_t + f * dt + g * N(0, 1) * sqrt(dt).
        y_t = initial_state
        outputs = []
        for dt in dt_steps:
            drift = torch.clamp(f_final * dt, min=val_min, max=val_max)
            diffusion = torch.clamp(g_final * torch.randn_like(y_t) * torch.sqrt(dt),
                                  min=val_min, max=val_max)

            y_t = torch.clamp(y_t + drift + diffusion, min=val_min, max=val_max)
            outputs.append(y_t)

        outputs = torch.stack(outputs, dim=1)
        # Skip connection weighted by the learnable diagonal D.
        outputs = torch.clamp(outputs + x_main * self.D.unsqueeze(0),
                            min=val_min, max=val_max)

        return outputs

class RMSNorm(nn.Module):
    """Root-mean-square normalization over the last dimension.

    Each feature vector is divided by the square root of its mean square
    (stabilized by ``eps``) and then scaled by a learnable per-feature weight.

    Args:
        d_model (int): Size of the last (feature) dimension.
        eps (float, optional): Constant added to the mean square before the
            reciprocal square root. Defaults to 1e-5.
    """
    def __init__(self, d_model: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        # Per-feature gain, initialized to one.
        self.weight = nn.Parameter(torch.ones(d_model))

    def forward(self, x):
        """Return ``x`` scaled by ``rsqrt(mean(x**2) + eps)`` and ``weight``."""
        mean_square = x.pow(2).mean(-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps)
        return x * inv_rms * self.weight

def create_mamba_model(
    input_dim: int,
    output_dim: int,
    input_len: int,
    output_len: int,
    d_model: int = 768,
    n_layer: int = 12,
    d_state: int = 16,
    expand: int = 2,
    dt_rank: Union[int, str] = 'auto',
    d_conv: int = 4,
    bias: bool = False,
    conv_bias: bool = True
) -> Mamba:
    """Build a ``Mamba`` model from individual hyperparameters.

    All arguments are forwarded unchanged into a ``ModelArgs`` instance,
    which is then handed to the ``Mamba`` constructor.

    Args:
        input_dim (int): Dimension of input features.
        output_dim (int): Dimension of output predictions.
        input_len (int): Length of input sequence.
        output_len (int): Length of output sequence.
        d_model (int, optional): Model dimension. Defaults to 768.
        n_layer (int, optional): Number of layers. Defaults to 12.
        d_state (int, optional): State dimension. Defaults to 16.
        expand (int, optional): Expansion factor. Defaults to 2.
        dt_rank (Union[int, str], optional): Delta projection rank.
            Defaults to 'auto'.
        d_conv (int, optional): Convolution kernel size. Defaults to 4.
        bias (bool, optional): Use bias in linear layers. Defaults to False.
        conv_bias (bool, optional): Use bias in conv layers. Defaults to True.

    Returns:
        Mamba: Configured Mamba model instance.
    """
    # Collect the hyperparameters by name, then expand them into ModelArgs.
    settings = {
        "d_model": d_model,
        "n_layer": n_layer,
        "input_dim": input_dim,
        "output_dim": output_dim,
        "input_len": input_len,
        "output_len": output_len,
        "d_state": d_state,
        "expand": expand,
        "dt_rank": dt_rank,
        "d_conv": d_conv,
        "bias": bias,
        "conv_bias": conv_bias,
    }
    return Mamba(ModelArgs(**settings))

def test_mamba_model():
    """Smoke-test the Mamba model on synthetic time-series data.

    Builds a small model, feeds it random inputs with integer-spaced
    timestamps, checks the output shape, and prints a few sample values.
    The forward pass runs under ``torch.no_grad()``: only the output is
    inspected, so no autograd graph needs to be kept for the many sequential
    steps inside the model.

    Raises:
        AssertionError: If the output shape is not
            (batch_size, output_len, output_dim).
    """
    # Build the time-series test model.
    model = create_mamba_model(
        input_dim=1,      # input feature dimension
        output_dim=1,     # output prediction dimension
        input_len=100,    # input sequence length
        output_len=50,    # output sequence length
        d_model=256,      # model dimension
        n_layer=4         # number of layers
    )

    # Generate test data.
    batch_size = 16
    total_len = model.args.input_len + model.args.output_len
    input_features = torch.randn(batch_size, total_len, 1)  # (batch, input_len+output_len, input_dim)

    # Timestamps for every point, one unit apart (e.g. one sample per hour).
    timesteps = torch.arange(0, total_len + 1, dtype=torch.float32)  # covers all time points
    timesteps = timesteps.unsqueeze(0).expand(batch_size, -1)  # (batch_size, total_len+1)

    print("\n模型配置:")
    print(f"Input dimension: {model.args.input_dim}")
    print(f"Output dimension: {model.args.output_dim}")
    print(f"Input length: {model.args.input_len}")
    print(f"Output length: {model.args.output_len}")
    print(f"Model dimension: {model.args.d_model}")
    print(f"Inner dimension: {model.args.d_inner}")

    print("\n数据维度:")
    print(f"Input features shape: {input_features.shape}")
    print(f"Timesteps shape: {timesteps.shape}")

    # Forward pass; gradients are not needed for a shape/sanity check.
    with torch.no_grad():
        output = model(input_features, timesteps)

    print("\n输出维度:")
    print(f"Output shape: {output.shape}")
    assert output.shape == (batch_size, model.args.output_len, model.args.output_dim)

    # Sample values.
    print("\n输入数据示例 (第一个批次的前5个时间点):")
    print(input_features[0, :5, 0].tolist())

    print("\n输出数据示例 (第一个批次的前5个时间点):")
    print(output[0, :5, 0].tolist())

    print("\n时间戳示例 (一个滑动窗口的时间步):")
    print(timesteps[0, :5].tolist())

    # Show the shapes of the first and last input_len-long slices.
    print("\n验证第一个滑动窗口:")
    first_window = input_features[:, :model.args.input_len, :]
    print(f"First window shape: {first_window.shape}")
    print("\n验证最后一个滑动窗口:")
    last_window = input_features[:, -model.args.input_len:, :]
    print(f"Last window shape: {last_window.shape}")

    print("\nTest passed!")

# if __name__ == "__main__":
#     test_mamba_model()