import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass
from typing import Union
import warnings

@dataclass
class ModelArgs:
    """Configuration class for Mamba model parameters.

    Attributes:
        d_model: Dimension of the model's hidden states.
        n_layer: Number of Mamba layers in the model.
        input_dim: Dimension of input features.
        output_dim: Dimension of output features.
        input_len: Length of input sequence.
        output_len: Length of output sequence.
        d_state: Dimension of SSM states. Defaults to 16.
        expand: Expansion factor for inner dimension. Defaults to 2.
        dt_rank: Rank of delta projection. Either an int or the string 'auto';
            'auto' is resolved in ``__post_init__`` to
            ``min(max(16, d_model // 16), d_model)``. Defaults to 'auto'.
        d_conv: Kernel size for 1D convolution. Defaults to 4.
        bias: Whether to use bias in linear layers. Defaults to False.
        conv_bias: Whether to use bias in convolution layers. Defaults to True.
        d_inner: Derived attribute (not a constructor argument), set in
            ``__post_init__`` to ``int(expand * d_model)``.

    Raises:
        ValueError: If ``dt_rank`` is a string other than 'auto'.
    """
    d_model: int
    n_layer: int
    input_dim: int
    output_dim: int
    input_len: int
    output_len: int
    d_state: int = 16
    expand: int = 2
    dt_rank: Union[int, str] = 'auto'
    d_conv: int = 4
    bias: bool = False
    conv_bias: bool = True

    def __post_init__(self):
        """Derive ``d_inner`` and resolve ``dt_rank='auto'`` to a concrete int."""
        self.d_inner = int(self.expand * self.d_model)
        if isinstance(self.dt_rank, str):
            # Reject typos such as 'Auto' here; otherwise the string would
            # survive until it is used as a layer size and fail far from
            # the actual mistake.
            if self.dt_rank != 'auto':
                raise ValueError(
                    f"dt_rank must be an int or 'auto', got {self.dt_rank!r}"
                )
            # At least 16 (or d_model // 16 when that is larger), but never
            # more than d_model itself.
            self.dt_rank = min(max(16, self.d_model // 16), self.d_model)

class Mamba(nn.Module):
    """Mamba model for time series prediction with SDE integration.

    The forward pass slides ``output_len`` windows of length ``input_len``
    over the input sequence, runs every window through the residual stack,
    collapses each window's time axis to a single step with ``time_mixer``
    and projects the result to ``output_dim``.

    Args:
        args (ModelArgs): Configuration parameters for the model.
    """
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args

        self.input_proj = nn.Linear(args.input_dim, args.d_model)
        self.layers = nn.ModuleList([
            ResidualBlock(args) for _ in range(args.n_layer)
        ])
        self.norm_f = RMSNorm(args.d_model)
        # Maps the input_len time steps of a window onto one step.
        self.time_mixer = nn.Linear(args.input_len, 1)
        self.output_proj = nn.Linear(args.d_model, args.output_dim)

    def forward(self, x, timesteps, chunk_size: int = 128):
        """
        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, total_len, input_dim).
                ``total_len`` must be at least ``input_len + output_len - 1`` so
                that every sliding window is complete.
            timesteps (torch.Tensor): Timestamps for each sequence point of shape
                (batch_size, total_len + 1). At least ``input_len + output_len``
                columns are required.
            chunk_size (int, optional): Number of batch rows processed at a time
                to bound peak memory. Defaults to 128.

        Returns:
            torch.Tensor: Predicted output of shape (batch_size, output_len, output_dim).

        Raises:
            ValueError: If ``chunk_size`` is not positive, or if ``x`` or
                ``timesteps`` is too short to build all windows.
        """
        b, total_len, _ = x.shape
        output_len = self.args.output_len
        input_len = self.args.input_len

        if chunk_size < 1:
            raise ValueError(f"chunk_size must be >= 1, got {chunk_size}")

        # The last window starts at output_len - 1 and spans input_len points
        # (input_len + 1 timestamps). Checking up front avoids an opaque
        # torch.stack failure or a silently wrong reshape.
        min_points = input_len + output_len - 1
        if total_len < min_points:
            raise ValueError(
                f"x has {total_len} time steps but at least {min_points} are "
                f"required (input_len={input_len}, output_len={output_len})"
            )
        if timesteps.size(-1) < min_points + 1:
            raise ValueError(
                f"timesteps has {timesteps.size(-1)} points but at least "
                f"{min_points + 1} are required "
                f"(input_len={input_len}, output_len={output_len})"
            )

        if b == 0:
            # Nothing to process; torch.cat would reject an empty list.
            return x.new_zeros((0, output_len, self.args.output_dim))

        outputs = []
        for start in range(0, b, chunk_size):
            x_chunk = x[start:start + chunk_size]
            t_chunk = timesteps[start:start + chunk_size]
            rows = x_chunk.size(0)

            # (rows, output_len, input_len, input_dim): one window per offset.
            windows = torch.stack([
                x_chunk[:, offset:offset + input_len, :]
                for offset in range(output_len)
            ], dim=1)
            # Each window needs input_len + 1 timestamps to form input_len steps.
            time_windows = torch.stack([
                t_chunk[:, offset:offset + input_len + 1]
                for offset in range(output_len)
            ], dim=1)

            # Fold the window axis into the batch axis.
            windows = windows.reshape(-1, input_len, x_chunk.size(-1))
            time_windows = time_windows.reshape(-1, input_len + 1)

            hidden = self.input_proj(windows)
            # First projected step of each window, tiled `expand` times so its
            # width matches d_inner; repeat() already returns a new tensor.
            initial_state = hidden[:, 0, :].repeat(1, self.args.expand)

            for layer in self.layers:
                hidden = layer(hidden, initial_state, time_windows)
            hidden = self.norm_f(hidden)
            # Mix along time: (N, input_len, d_model) -> (N, 1, d_model).
            hidden = self.time_mixer(hidden.transpose(1, 2)).transpose(1, 2)
            pred = self.output_proj(hidden)
            outputs.append(pred.reshape(rows, output_len, -1))

        return torch.cat(outputs, dim=0)

class ResidualBlock(nn.Module):
    """Residual block combining normalization and Mamba block with SDE.

    Computes ``mixer(norm(x), initial_state, timesteps) + x``.

    Args:
        args (ModelArgs): Configuration parameters for the block.
    """
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.mixer = MambaBlockWithSDE(
            d_model=args.d_model,
            d_inner=args.d_inner,
            d_state=args.d_state,
            dt_rank=args.dt_rank,
            d_conv=args.d_conv,
            bias=args.bias,
            conv_bias=args.conv_bias
        )
        self.norm = RMSNorm(args.d_model)

    def forward(self, x, initial_state=None, timesteps=None):
        """Apply the pre-normalized mixer and add the residual connection.

        Args:
            x (torch.Tensor): Input of shape (batch, length, d_model).
            initial_state (torch.Tensor, optional): Start value handed to the
                mixer. When None, a warning is issued and zeros of shape
                (batch, d_inner) are used.
            timesteps (torch.Tensor, optional): Timestamps handed to the mixer.
                When None, a warning is issued and ``length + 1`` evenly spaced
                points on [0, 1] are used for every batch row.

        Returns:
            torch.Tensor: Tensor with the same shape as ``x``.
        """
        b, l, _ = x.shape
        if initial_state is None:
            warnings.warn(
                "No initial_state provided. Using zeros. This may affect model performance.",
                RuntimeWarning
            )
            # Match x's dtype as well as its device so the fallback also
            # works when the model does not run in the default dtype.
            initial_state = torch.zeros(
                b, self.args.d_inner, device=x.device, dtype=x.dtype
            )
        if timesteps is None:
            warnings.warn(
                "No timesteps provided. Using uniform [0,1]. This may not match your data.",
                RuntimeWarning
            )
            # Build the grid in the default dtype, then cast to x's dtype.
            timesteps = torch.linspace(0.0, 1.0, l + 1, device=x.device).to(x.dtype)
            timesteps = timesteps.unsqueeze(0).expand(b, -1)
        return self.mixer(self.norm(x), initial_state, timesteps) + x

class MambaBlockWithSDE(nn.Module):
    """Mamba block implementation with Stochastic Differential Equation integration.

    The block projects the input to two ``d_inner``-wide branches. The main
    branch goes through a depthwise convolution and SiLU, drives a diagonal
    state recurrence, and the final recurrent state parameterizes the drift
    and diffusion of an SDE that is integrated step by step. The result is
    gated by the second branch and projected back to ``d_model``.

    Fresh Gaussian noise is drawn on every call, regardless of
    ``self.training``, so outputs are stochastic.

    Args:
        d_model (int): Model dimension.
        d_inner (int): Inner dimension.
        d_state (int): State dimension. Used as the hidden width of the
            drift and diffusion networks.
        dt_rank (int): Delta projection dimension.
        d_conv (int): Convolution kernel size.
        bias (bool, optional): Use bias in linear layers. Defaults to False.
        conv_bias (bool, optional): Use bias in conv layers. Defaults to False.
    """
    def __init__(
        self,
        d_model: int,
        d_inner: int,
        d_state: int,
        dt_rank: int,
        d_conv: int,
        bias: bool = False,
        conv_bias: bool = False
    ):
        super().__init__()
        # One projection yields both the main branch and the gate branch.
        self.in_proj = nn.Linear(d_model, 2 * d_inner, bias=bias)
        # Depthwise conv (one filter per channel). With padding d_conv - 1 and
        # the output truncated to the input length in forward(), each step
        # only sees the current and earlier steps.
        self.conv1d = nn.Conv1d(
            in_channels=d_inner,
            out_channels=d_inner,
            kernel_size=d_conv,
            groups=d_inner,
            padding=d_conv - 1,
            bias=conv_bias,
        )
        # Produces the low-rank delta input (dt_rank) and B (d_inner) together.
        self.x_proj = nn.Linear(d_inner, dt_rank + d_inner, bias=False)
        self.dt_proj = nn.Linear(dt_rank, d_inner, bias=True)
        # Stored as a log so that A = -exp(A_log) is always negative.
        self.A_log = nn.Parameter(torch.log(torch.arange(1, d_inner + 1).float()))
        self.D = nn.Parameter(torch.ones(d_inner))
        self.out_proj = nn.Linear(d_inner, d_model, bias=bias)
        # Drift network f(h).
        self.f_theta = nn.Sequential(
            nn.Linear(d_inner, d_state),
            nn.ReLU(),
            nn.Linear(d_state, d_inner)
        )
        # Diffusion network g(h).
        self.g_theta = nn.Sequential(
            nn.Linear(d_inner, d_state),
            nn.ReLU(),
            nn.Linear(d_state, d_inner)
        )

    def forward(self, x, initial_state, timesteps):
        """Run the block on one batch of sequences.

        Args:
            x (torch.Tensor): Input of shape (batch, length, d_model).
            initial_state (torch.Tensor): Start value of the SDE trajectory,
                broadcastable to (batch, d_inner).
            timesteps (torch.Tensor): Timestamps with at least ``length + 1``
                entries along the last axis, indexed as ``timesteps[:, t]``.

        Returns:
            torch.Tensor: Output of shape (batch, length, d_model).
        """
        b, l, d_in = x.shape
        assert d_in == self.in_proj.in_features, \
            f"Input dim {d_in} does not match in_proj ({self.in_proj.in_features})!"
        x_main, x_gate = self.in_proj(x).chunk(2, dim=-1)
        # Conv1d expects (batch, channels, length); keep the first l outputs.
        x_main = self.conv1d(x_main.permute(0, 2, 1))[..., :l].permute(0, 2, 1)
        x_main = F.silu(x_main)
        # Create the zero state in the activations' dtype. The default dtype
        # would promote the recurrence and then clash with the weights of
        # f_theta / g_theta whenever the two differ.
        h0 = torch.zeros(
            b, x_main.size(-1), device=x_main.device, dtype=x_main.dtype
        )
        y = self.ssm_with_sde(x_main, initial_state, h0, timesteps)
        y = y * F.silu(x_gate)
        return self.out_proj(y)

    def ssm_with_sde(self, x_main, initial_state, h0, timesteps,
                     dt_min=1e-6, dt_max=10.0,
                     val_min=-1e6, val_max=1e6):
        """Scan the diagonal recurrence, then integrate the SDE read-out.

        Phase 1 runs ``h <- h * exp(A * delta * dt) + B * x * delta * dt`` over
        all steps. Phase 2 evaluates the drift ``f_theta`` and diffusion
        ``g_theta`` once at the final ``h`` and performs Euler-Maruyama-style
        updates ``y <- y + f * dt + g * noise * sqrt(dt)`` starting from
        ``initial_state``. Intermediate values are clamped to
        ``[val_min, val_max]``.

        Args:
            x_main (torch.Tensor): Activations of shape (batch, length, d_inner).
            initial_state (torch.Tensor): Start value of ``y``.
            h0 (torch.Tensor): Start value of the recurrent state ``h``.
            timesteps (torch.Tensor): Timestamps; consecutive differences of
                the first ``length + 1`` columns give the step sizes.
            dt_min (float): Lower clamp for each step size.
            dt_max (float): Upper clamp for each step size.
            val_min (float): Lower clamp for states and outputs.
            val_max (float): Upper clamp for states and outputs.

        Returns:
            torch.Tensor: Tensor of shape (batch, length, d_inner).

        Raises:
            ValueError: If ``timesteps`` has fewer than ``length + 1`` columns.
        """
        b, l, d_in = x_main.shape
        if timesteps.size(-1) < l + 1:
            raise ValueError(
                f"timesteps needs at least {l + 1} points for a sequence of "
                f"length {l}, got {timesteps.size(-1)}"
            )

        A_diag = -torch.exp(self.A_log)
        x_dbl = self.x_proj(x_main)
        dt_dim = self.dt_proj.in_features
        delta_raw, B_raw = x_dbl.split([dt_dim, d_in], dim=-1)
        delta = F.softplus(self.dt_proj(delta_raw))

        # Step sizes are the same for both phases, so compute them once.
        dt_all = torch.clamp(
            timesteps[:, 1:l + 1] - timesteps[:, :l], min=dt_min, max=dt_max
        )

        # Everything that does not depend on the running state is evaluated
        # for all steps at once; only the recurrence itself is sequential.
        delta_scaled = delta * dt_all.unsqueeze(-1)
        deltaA = torch.clamp(torch.exp(A_diag * delta_scaled), min=1e-12, max=1e12)
        deltaB_u = B_raw * x_main * delta_scaled

        h_t = h0
        for t in range(l):
            h_t = torch.clamp(
                h_t * deltaA[:, t, :] + deltaB_u[:, t, :],
                min=val_min, max=val_max
            )

        # Drift and diffusion are frozen at the final recurrent state.
        f_final = self.f_theta(h_t)
        g_final = self.g_theta(h_t)

        sqrt_dt_all = torch.sqrt(dt_all)
        y_t = initial_state
        outputs = []
        for t in range(l):
            dt = dt_all[:, t:t + 1]
            drift = torch.clamp(f_final * dt, min=val_min, max=val_max)
            # One noise draw per step, in step order.
            diffusion = torch.clamp(
                g_final * torch.randn_like(y_t) * sqrt_dt_all[:, t:t + 1],
                min=val_min, max=val_max
            )
            y_t = torch.clamp(y_t + drift + diffusion, min=val_min, max=val_max)
            outputs.append(y_t)
        outputs = torch.stack(outputs, dim=1)
        # Skip connection from the activations, scaled per channel by D.
        outputs = torch.clamp(outputs + x_main * self.D.unsqueeze(0),
                              min=val_min, max=val_max)
        return outputs

class RMSNorm(nn.Module):
    """Normalize the last axis by its root mean square and apply a learned gain.

    Args:
        d_model (int): Size of the last axis; one gain value is learned per entry.
        eps (float, optional): Added to the mean square before taking the
            reciprocal square root. Defaults to 1e-5.
    """
    def __init__(self, d_model: int, eps: float = 1e-5):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(d_model))
        self.eps = eps

    def forward(self, x):
        """Return ``x / sqrt(mean(x**2, last axis) + eps) * weight``."""
        mean_square = torch.mean(torch.pow(x, 2), dim=-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps)
        normalized = x * inv_rms
        return normalized * self.weight

def create_mamba_model(
    input_dim: int,
    output_dim: int,
    input_len: int,
    output_len: int,
    d_model: int = 768,
    n_layer: int = 12,
    d_state: int = 16,
    expand: int = 2,
    dt_rank: Union[int, str] = 'auto',
    d_conv: int = 4,
    bias: bool = False,
    conv_bias: bool = True
) -> Mamba:
    """Build a Mamba model from individual hyper-parameters.

    Every argument is forwarded unchanged to the ``ModelArgs`` field of the
    same name, and the resulting configuration is used to construct the model.

    Returns:
        Mamba: Configured Mamba model instance.
    """
    config = {
        "d_model": d_model,
        "n_layer": n_layer,
        "input_dim": input_dim,
        "output_dim": output_dim,
        "input_len": input_len,
        "output_len": output_len,
        "d_state": d_state,
        "expand": expand,
        "dt_rank": dt_rank,
        "d_conv": d_conv,
        "bias": bias,
        "conv_bias": conv_bias,
    }
    return Mamba(ModelArgs(**config))

def test_mamba_model():
    """Tests the Mamba model with synthetic time series data.

    Builds a small model, runs a single forward pass on random inputs with
    integer-spaced timestamps, asserts the output shape and prints the
    configuration, tensor shapes and a few sample values.
    """
    model = create_mamba_model(
        input_dim=1,
        output_dim=1,
        input_len=100,
        output_len=50,
        d_model=256,
        n_layer=4
    )
    batch_size = 16
    total_len = model.args.input_len + model.args.output_len
    input_features = torch.randn(batch_size, total_len, 1)
    # One more timestamp than data points, shared by every batch row.
    timesteps = torch.arange(0, total_len + 1, dtype=torch.float32)
    timesteps = timesteps.unsqueeze(0).expand(batch_size, -1)
    print("\nModel config:")
    print(f"Input dimension: {model.args.input_dim}")
    print(f"Output dimension: {model.args.output_dim}")
    print(f"Input length: {model.args.input_len}")
    print(f"Output length: {model.args.output_len}")
    print(f"Model dimension: {model.args.d_model}")
    print(f"Inner dimension: {model.args.d_inner}")
    print("\nData shapes:")
    print(f"Input features shape: {input_features.shape}")
    print(f"Timesteps shape: {timesteps.shape}")
    # Only the forward result is inspected. Disabling autograd stops the
    # step-by-step loops in every layer from retaining their intermediates.
    with torch.no_grad():
        output = model(input_features, timesteps)
    print("\nOutput shape:")
    print(f"Output shape: {output.shape}")
    assert output.shape == (batch_size, model.args.output_len, model.args.output_dim)
    print("\nInput example (first batch, first 5 time points):")
    print(input_features[0, :5, 0].tolist())
    print("\nOutput example (first batch, first 5 time points):")
    print(output[0, :5, 0].tolist())
    print("\nTimestep example (first window):")
    print(timesteps[0, :5].tolist())
    print("\nFirst window shape:")
    first_window = input_features[:, :model.args.input_len, :]
    print(f"First window shape: {first_window.shape}")
    print("\nLast window shape:")
    last_window = input_features[:, -model.args.input_len:, :]
    print(f"Last window shape: {last_window.shape}")
    print("\nTest passed!")

# if __name__ == "__main__":
#     test_mamba_model()
