import math
import torch

from torch._higher_order_ops import associative_scan, scan as torch_scan

from scan import mamba_scan as scan
# from fused_scan import fused_mamba_scan as scan


class Mamba(torch.nn.Module):
    """
    Mamba kernel module.
    """

    def __init__(
        self,
        input_dim: int,
        qk_dim: int,  # d_state
        v_dim: int,  # d_inner
        output_dim: int | None = None,
    ):
        super(Mamba, self).__init__()
        self.input_dim = input_dim
        self.qk_dim = qk_dim or input_dim
        self.v_dim = v_dim or input_dim
        self.dt_dim = math.ceil(self.input_dim / 16)
        self.output_dim = output_dim or input_dim

        # Projection to input and gating
        self.in_proj = torch.nn.Linear(input_dim, 2 * self.v_dim, bias=False)

        # Depthwise conv1d for local mixing
        self.conv1d_kernel_size = 4
        self.conv1d_bias = True
        self.conv1d = torch.nn.Conv1d(
            in_channels=self.v_dim,
            out_channels=self.v_dim,
            bias=self.conv1d_bias,
            kernel_size=self.conv1d_kernel_size,
            groups=self.v_dim,
            padding=self.conv1d_kernel_size - 1,
        )

        # Projection to dt, k, v
        self.dtqk_proj = torch.nn.Linear(
            self.v_dim, self.dt_dim + 2 * self.qk_dim, bias=False
        )
        self.dt_proj = torch.nn.Linear(self.dt_dim, self.v_dim)

        # Initialize special dt projection to preserve variance at initialization
        dt_init_std = self.dt_dim**-0.5 * 1.0
        torch.nn.init.uniform_(self.dt_proj.weight, -dt_init_std, dt_init_std)

        # Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max
        dt_min = 0.001
        dt_max = 0.1
        dt_init_floor = 1e-4
        dt = torch.exp(
            torch.rand(self.v_dim) * (math.log(dt_max) - math.log(dt_min))
            + math.log(dt_min)
        ).clamp(min=dt_init_floor)
        # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
        inv_dt = dt + torch.log(-torch.expm1(-dt))
        with torch.no_grad():
            self.dt_proj.bias.copy_(inv_dt)
        # Our initialization would set all Linear.bias to zero, need to mark this one as _no_reinit
        self.dt_proj.bias._no_reinit = True

        # A_log (v_dim, qk_dim) ssm gating parameter
        # NOTE: This is mamba initialization
        A_init = (
            torch.arange(1, self.qk_dim + 1, dtype=torch.float32)
            .unsqueeze(0)
            .repeat(self.v_dim, 1)
            .contiguous()
        )
        self.A_log = torch.nn.Parameter(torch.log(A_init))
        self.A_log._no_weight_decay = True

        # D (v_dim) "skip" parameter
        self.D = torch.nn.Parameter(torch.ones(self.v_dim))  # Keep in fp32
        self.D._no_weight_decay = True

        # Output projection
        self.out_proj = torch.nn.Linear(self.v_dim, self.output_dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the Mamba kernel.

        :param x: Input tensor
        :type x: torch.Tensor
        :return: Output tensor
        :rtype: torch.Tensor
        """
        # Input shapes
        batch_size, seq_len, _ = x.shape

        # x, g: (batch_size, seq_len, v_dim)
        x, g = self.in_proj(x).chunk(2, dim=-1)

        # Depthwise conv1d
        # x: (batch_size, seq_len, v_dim)
        x = self.conv1d(x.transpose(-1, -2))[..., :seq_len].transpose(-1, -2)
        x = torch.nn.functional.silu(x)

        # dt: (batch_size, seq_len, dt_dim)
        # k: (batch_size, seq_len, qk_dim)
        # q: (batch_size, seq_len, qk_dim)
        dt, k, q = self.dtqk_proj(x).split(
            [self.dt_dim, self.qk_dim, self.qk_dim], dim=-1
        )
        # weight: (v_dim, dt_dim) @ (batch_size, dt_dim, seq_len).T -> (batch_size, v_dim, seq_len).T
        # dt: (batch_size, seq_len, v_dim)
        # dt = (self.dt_proj.weight @ dt.transpose(-1, -2)).transpose(-1, -2)
        # dt = torch.nn.functional.softplus(dt + self.dt_proj.bias)
        dt = self.dt_proj(dt)
        dt = torch.nn.functional.softplus(dt)

        # # Discretization
        A = -torch.exp(self.A_log.float())  # (v_dim, qk_dim)
        # dA: (batch_size, seq_len, v_dim, qk_dim)
        dA = torch.exp(dt.unsqueeze(-1) * A)
        # dBx: (batch_size, seq_len, v_dim, qk_dim)
        dBx = dt.unsqueeze(-1) * k.unsqueeze(-2) * x.unsqueeze(-1)

        # Perform associative scan over sequence dimension
        # h: (batch_size, seq_len, v_dim, qk_dim)
        h = scan(dA, dBx)

        # Fused scan
        # h = scan(x, dt, A, k, q, self.D, g)

        # Contract with q (Not using tensor core so none gemm is faster)
        # h: (batch_size, seq_len, v_dim, qk_dim
        # h = (h @ q.unsqueeze(-1)).squeeze(-1)
        h = (h * q.unsqueeze(-2)).sum(-1)
        # h = torch.einsum("blvq,blq->blv", h, q)

        # Add skip connection
        # D: (v_dim)
        x = h + x * self.D

        # Gating: (batch_size, seq_len, v_dim)
        x = x * torch.nn.functional.silu(g)

        # Output projection: (batch_size, seq_len, output_dim)
        x = self.out_proj(x)
        # x = self.out_proj(h)

        return x
