import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat

from torch._higher_order_ops import associative_scan

# Associative binary operator for the scan. Given an earlier element
# (A_i, u_i) and a later element (A_j, u_j), it returns
# (A_j * A_i, A_j * u_i + u_j).
def ssm_combine_fn(left, right):
    """Compose two steps of the linear recurrence ``h = a * h_prev + b``.

    Args:
        left: ``(decay, input)`` pair for the earlier segment.
        right: ``(decay, input)`` pair for the later segment.

    Returns:
        The ``(decay, input)`` pair of the single step equivalent to applying
        ``left`` first and then ``right``.
    """
    decay_prev, inp_prev = left
    decay_next, inp_next = right
    # The later decay scales both the earlier decay and the earlier input.
    return decay_next * decay_prev, decay_next * inp_prev + inp_next

class Mamba(nn.Module):
    """Mamba (selective state-space) block with a parallel-scan forward pass.

    Sequence path: input projection -> depthwise causal conv -> SiLU ->
    input-dependent (dt, B, C) -> discretized diagonal SSM evaluated with
    ``associative_scan`` -> skip term ``D * x`` -> SiLU(z) gating -> output
    projection. ``step`` evaluates the same recurrence one token at a time.
    """

    def __init__(
        self,
        d_model,
        d_state=16,
        d_conv=4,
        expand=2,
        dt_rank="auto",
        dt_min=0.001,
        dt_max=0.1,
        dt_init="random",
        dt_scale=1.0,
        dt_init_floor=1e-4,
        conv_bias=True,
        bias=False,
        layer_idx=None,
        device=None,
        dtype=None,
    ):
        """Build the projections, the convolution and the SSM parameters.

        Args:
            d_model: Width of the input/output features.
            d_state: SSM state size N per inner channel.
            d_conv: Kernel width of the depthwise causal convolution.
            expand: Inner width multiplier; ``d_inner = int(expand * d_model)``.
            dt_rank: Rank of the dt projection; ``"auto"`` means
                ``ceil(d_model / 16)``.
            dt_min, dt_max: Range from which the initial
                ``softplus(dt_proj.bias)`` is sampled log-uniformly.
            dt_init: ``"constant"`` or ``"random"`` init of ``dt_proj.weight``.
            dt_scale: Multiplier on the dt weight init scale ``dt_rank**-0.5``.
            dt_init_floor: Lower clamp applied to the sampled initial dt.
            conv_bias: Whether the depthwise conv has a bias.
            bias: Whether ``in_proj`` and ``out_proj`` have biases.
            layer_idx: Stored on the module; not used by this class.
            device, dtype: Factory kwargs for parameter creation (``A_log``
                and ``D`` are always created in fp32).

        Raises:
            NotImplementedError: If ``dt_init`` is not a supported mode.
        """
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state
        self.d_conv = d_conv
        self.expand = expand
        self.d_inner = int(self.expand * self.d_model)
        self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank
        self.layer_idx = layer_idx

        # Produces both the SSM branch x and the gate branch z.
        self.in_proj = nn.Linear(self.d_model, self.d_inner * 2, bias=bias, **factory_kwargs)

        # Depthwise (groups == channels) conv; padding d_conv - 1 plus the
        # truncation to seqlen in forward() makes it causal.
        self.conv1d = nn.Conv1d(
            in_channels=self.d_inner,
            out_channels=self.d_inner,
            bias=conv_bias,
            kernel_size=d_conv,
            groups=self.d_inner,
            padding=d_conv - 1,
            **factory_kwargs,
        )

        self.activation = "silu"
        self.act = nn.SiLU()

        # Projects x to the concatenation [dt (dt_rank), B (d_state), C (d_state)].
        self.x_proj = nn.Linear(
            self.d_inner, self.dt_rank + self.d_state * 2, bias=False, **factory_kwargs
        )
        self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True, **factory_kwargs)

        # Initialize special dt projection to preserve variance at initialization
        dt_init_std = self.dt_rank**-0.5 * dt_scale
        if dt_init == "constant":
            nn.init.constant_(self.dt_proj.weight, dt_init_std)
        elif dt_init == "random":
            nn.init.uniform_(self.dt_proj.weight, -dt_init_std, dt_init_std)
        else:
            raise NotImplementedError

        # Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max
        dt = torch.exp(
            torch.rand(self.d_inner, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min))
            + math.log(dt_min)
        ).clamp(min=dt_init_floor)
        # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
        inv_dt = dt + torch.log(-torch.expm1(-dt))
        with torch.no_grad():
            self.dt_proj.bias.copy_(inv_dt)
        # Our initialization would set all Linear.bias to zero, need to mark this one as _no_reinit
        self.dt_proj.bias._no_reinit = True

        # S4D real initialization: A[d, n] = n + 1, stored as log so that
        # A = -exp(A_log) stays strictly negative (decaying state).
        A = repeat(
            torch.arange(1, self.d_state + 1, dtype=torch.float32, device=device),
            "n -> d n",
            d=self.d_inner,
        ).contiguous()
        A_log = torch.log(A)  # Keep A_log in fp32
        self.A_log = nn.Parameter(A_log)
        self.A_log._no_weight_decay = True

        # D "skip" parameter
        self.D = nn.Parameter(torch.ones(self.d_inner, device=device))  # Keep in fp32
        self.D._no_weight_decay = True

        self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs)

    def forward(self, hidden_states, inference_params=None):
        """Run the block over a full sequence using a parallel associative scan.

        Args:
            hidden_states: (B, L, D) input.
            inference_params: Accepted for interface compatibility; unused.

        Returns:
            Tensor with the same shape (and dtype) as ``hidden_states``.
        """
        dtype = hidden_states.dtype
        # Unpacking also enforces a 3-D input.
        _, seqlen, _ = hidden_states.shape

        # 1. Project inputs: (B, L, D) -> x, z each (B, L, D_inner).
        xz = self.in_proj(hidden_states)
        x, z = xz.chunk(2, dim=-1)

        # 2. Causal depthwise convolution over L (Conv1d wants (B, C, L)).
        x = x.transpose(1, 2)
        x = self.conv1d(x)[:, :, :seqlen]  # drop right padding -> causal
        x = self.act(x)
        x = x.transpose(1, 2)  # back to (B, L, D_inner)

        # 3. Input-dependent SSM parameters.
        x_dbl = self.x_proj(x)  # (B, L, dt_rank + 2 * d_state)
        dt, B, C = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1)
        dt = self.dt_proj(dt)  # (B, L, D_inner), includes the dt bias

        # 4. Discretization in at least fp32. A is fp32, so dA is promoted
        #    anyway; u must be cast to the same dtype, otherwise
        #    ssm_combine_fn's `a_right * b_left + b_right` would promote the
        #    input leaf and its outputs would no longer match its inputs.
        scan_dtype = torch.promote_types(dt.dtype, torch.float32)
        dt = F.softplus(dt.to(scan_dtype))
        A = -torch.exp(self.A_log.float())  # (D_inner, N)
        dA = torch.exp(dt.unsqueeze(-1) * A)  # (B, L, D_inner, N)

        # Input term u = dt * B * x, broadcast to (B, L, D_inner, N).
        u = (
            dt.unsqueeze(-1)
            * B.to(scan_dtype).unsqueeze(2)
            * x.to(scan_dtype).unsqueeze(-1)
        )

        # 5. Parallel scan over L for h_t = dA_t * h_{t-1} + u_t (h_{-1} = 0).
        # combine_mode="pointwise" only supports CUDA tensors; use the
        # "generic" implementation elsewhere instead of failing outright.
        combine_mode = "pointwise" if dA.is_cuda else "generic"
        scan_output = associative_scan(
            ssm_combine_fn,
            (dA, u),
            dim=1,
            combine_mode=combine_mode,
        )
        # scan_output mirrors the input pytree: (cumulative decay, h).
        h = scan_output[1]  # (B, L, D_inner, N)

        # 6. Readout y_t = sum_n C_t[n] * h_t[:, n], plus skip and gating.
        y = (h * C.to(scan_dtype).unsqueeze(2)).sum(dim=-1)  # (B, L, D_inner)
        y = y + x.to(scan_dtype) * self.D.to(scan_dtype)
        y = y * self.act(z.to(scan_dtype))

        # Cast back so out_proj sees the caller's dtype, matching step().
        return self.out_proj(y.to(dtype))

    def step(self, hidden_states, conv_state, ssm_state):
        """Advance the recurrence by one token (sequential decoding path).

        Args:
            hidden_states: (B, 1, D) input for a single timestep.
            conv_state: (B, D_inner, d_conv) rolling window of conv inputs;
                updated in place.
            ssm_state: (B, D_inner, d_state) SSM state; updated in place.

        Returns:
            ``(out, conv_state, ssm_state)`` where ``out`` is (B, 1, D).
        """
        dtype = hidden_states.dtype
        assert hidden_states.shape[1] == 1, "Only support decoding with 1 token at a time for now"
        xz = self.in_proj(hidden_states.squeeze(1))  # (B 2D)
        x, z = xz.chunk(2, dim=-1)  # (B D)

        # Conv step: shift the window left by one and append the new input.
        conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1))  # Update state (B D W)
        conv_state[:, :, -1] = x
        x = torch.sum(conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1)  # (B D)
        if self.conv1d.bias is not None:
            x = x + self.conv1d.bias
        x = self.act(x).to(dtype=dtype)

        x_db = self.x_proj(x)  # (B dt_rank+2*d_state)
        dt, B, C = torch.split(x_db, [self.dt_rank, self.d_state, self.d_state], dim=-1)
        # Bias is added separately below, inside the softplus.
        dt = F.linear(dt, self.dt_proj.weight)  # (B d_inner)
        A = -torch.exp(self.A_log.float())  # (d_inner, d_state)

        # SSM step: h = exp(dt * A) * h + (dt * B) * x
        dt = F.softplus(dt + self.dt_proj.bias.to(dtype=dt.dtype))
        dA = torch.exp(torch.einsum("bd,dn->bdn", dt, A))
        dB = torch.einsum("bd,bn->bdn", dt, B)
        ssm_state.copy_(ssm_state * dA + rearrange(x, "b d -> b d 1") * dB)
        y = torch.einsum("bdn,bn->bd", ssm_state.to(dtype), C)
        y = y + self.D.to(dtype) * x
        y = y * self.act(z)  # (B D)

        out = self.out_proj(y)
        return out.unsqueeze(1), conv_state, ssm_state