"""
Experimental variant: diagonal SSD-style SSM with fixed (input-independent) A, B, C, D.

Differs from mamba_SSD_diag.py:
- A, B, C, D are learned but NOT input-dependent (no selective Δ/B/C).
- Uses the discrete diagonal SSM form from the SSD experiments (explicit scan).
"""

from __future__ import annotations

import torch
import torch.nn as nn
import torch.nn.functional as F


class SSDMambaBlockExp(nn.Module):
    """
    Drop-in block using fixed diagonal A/B/C/D (no input-dependent selectivity).

    Pipeline (see ``forward``): input projection -> depthwise causal conv ->
    SiLU -> diagonal SSM scan -> SiLU gate -> output projection.

    ``args`` must expose the attributes read below: ``d_model``, ``d_inner``,
    ``d_state``, ``d_conv``, ``bias`` and ``conv_bias``.
    """

    def __init__(self, args):
        super().__init__()
        self.args = args

        # One projection produces both the SSM branch and the gate branch.
        self.in_proj = nn.Linear(args.d_model, args.d_inner * 2, bias=args.bias)
        # groups == channels makes the conv depthwise; padding d_conv - 1 plus
        # the truncation in ``forward`` makes it causal.
        self.conv1d = nn.Conv1d(
            in_channels=args.d_inner,
            out_channels=args.d_inner,
            bias=args.conv_bias,
            kernel_size=args.d_conv,
            groups=args.d_inner,
            padding=args.d_conv - 1,
        )

        # --- SSD-style diagonal SSM params (discrete, input-independent) ---
        # A_diag[k, n] = exp(A_log[k, n]) is the n-th diagonal element for
        # feature k. exp keeps it positive; nothing here bounds it by 1, and
        # the zero initialisation gives A_diag == 1 for every mode.
        self.A_log = nn.Parameter(torch.zeros(args.d_inner, args.d_state))
        self.B = nn.Parameter(torch.randn(args.d_inner, args.d_state) * 0.1)
        self.C = nn.Parameter(torch.randn(args.d_inner, args.d_state) * 0.1)
        self.D = nn.Parameter(torch.ones(args.d_inner))

        self.out_proj = nn.Linear(args.d_inner, args.d_model, bias=args.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply the block to a batch of sequences.

        Args:
            x: tensor of shape (b, l, d_model).

        Returns:
            Tensor of shape (b, l, d_model).
        """
        _, seq_len, _ = x.shape
        x_proj, res = self.in_proj(x).split(self.args.d_inner, dim=-1)

        x_proj = x_proj.transpose(1, 2)  # (b, d_inner, l)
        # The symmetric padding yields l + d_conv - 1 outputs; keeping the
        # first l makes output t depend only on inputs <= t.
        x_proj = self.conv1d(x_proj)[..., :seq_len]
        x_proj = F.silu(x_proj.transpose(1, 2))  # (b, l, d_inner)

        y = self.ssm(x_proj) * F.silu(res)
        return self.out_proj(y)

    def ssm(self, X: torch.Tensor) -> torch.Tensor:
        """
        Diagonal SSD-style SSM, matching Algorithm 1 notation:

            Z^n <- f(b^n, X)      // input projection per mode
            H^n <- g(a^n, Z^n)    // diagonal SSM recurrence per mode
            Y^n <- f(c^n, H^n)    // output projection per mode
            Y   <- sum_n Y^n      // sum over modes (plus D skip)

        Here we implement:
            h_t^{(k,n)} = a_{k,n} * h_{t-1}^{(k,n)} + b_{k,n} * X_t^{(k)}
            Y_t^{(k)}   = sum_n c_{k,n} * h_t^{(k,n)} + d_k * X_t^{(k)}

        with h_0 = 0 before the first step.

        Shapes:
            X      : (B, T, d_in)
            a, b, c: (d_in, N)
            h, Z   : (B, T, d_in, N)
            Y      : (B, T, d_in)

        The scan runs in at least float32 (the parameters are upcast), and the
        result is cast back to the dtype of ``X`` when ``X`` is floating point,
        so low-precision callers get a tensor that matches their other
        activations. An empty sequence (T == 0) returns an empty result.
        """
        B, _, d_in = X.shape

        # Parameters: a^n, b^n, c^n, d (per feature k and mode n)
        a = torch.exp(self.A_log.float())   # (d_in, N)
        b_vec = self.B.float()              # (d_in, N)
        c_vec = self.C.float()              # (d_in, N)
        d_skip = self.D.float()             # (d_in,)

        # ---- Z^n <- f(b^n, X) ----
        # (B, T, d_in, 1) * (d_in, N) broadcasts to (B, T, d_in, N).
        Z = X.unsqueeze(-1) * b_vec

        # ---- H^n <- g(a^n, Z^n) ----
        # h_t = a ⊙ h_{t-1} + Z_t, starting from h_0 = 0 in the compute dtype.
        h = Z.new_zeros(B, d_in, a.shape[-1])
        H_steps = []
        for Z_t in Z.unbind(dim=1):        # each Z_t: (B, d_in, N)
            h = a * h + Z_t
            H_steps.append(h)
        if H_steps:
            H = torch.stack(H_steps, dim=1)    # (B, T, d_in, N)
        else:
            # torch.stack rejects an empty list; for T == 0, Z already has the
            # required empty (B, 0, d_in, N) shape.
            H = Z

        # ---- Y <- sum_n c^n ⊙ H^n + D ⊙ X ----
        Y = (H * c_vec).sum(dim=-1) + X * d_skip   # (B, T, d_in)

        return Y.to(X.dtype) if X.is_floating_point() else Y



class SSDResidualBlockExp(nn.Module):
    """
    Pre-norm residual wrapper around ``SSDMambaBlockExp``.

    Computes ``mixer(norm(x)) + x``; ``args`` must expose ``d_model`` plus
    whatever ``SSDMambaBlockExp`` reads.
    """

    def __init__(self, args):
        super().__init__()
        # Imported inside the constructor, so the dependency is resolved only
        # when a block is actually built.
        from mamba_experiments import mamba_SSD as ssd_ref  # type: ignore

        self.norm = ssd_ref.RMSNorm(args.d_model)
        self.mixer = SSDMambaBlockExp(args)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalise ``x``, run the mixer, and add the untouched input back."""
        normed = self.norm(x)
        mixed = self.mixer(normed)
        return mixed + x


class SSDMambaModelExp(nn.Module):
    """
    Top-level model mirroring mamba_SSD.Mamba but using the fixed diagonal SSD block.

    Stack: token embedding -> ``n_layer`` residual SSD blocks -> final norm ->
    tied LM head. ``forward`` returns logits for the last position only.

    ``args`` must expose ``vocab_size``, ``d_model`` and ``n_layer`` in
    addition to whatever the residual blocks read.
    """

    def __init__(self, args):
        super().__init__()
        from mamba_experiments import mamba_SSD  # type: ignore

        self.args = args
        self.embedding = nn.Embedding(args.vocab_size, args.d_model)
        self.layers = nn.ModuleList([SSDResidualBlockExp(args) for _ in range(args.n_layer)])
        self.norm_f = mamba_SSD.RMSNorm(args.d_model)
        self.lm_head = nn.Linear(args.d_model, args.vocab_size, bias=False)
        # Weight tying: the output projection shares the embedding matrix.
        self.lm_head.weight = self.embedding.weight

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        """
        Compute next-token logits for the final position of each sequence.

        Args:
            input_ids: integer token ids of shape (batch, seq_len).

        Returns:
            Tensor of shape (batch, vocab_size) with the last-token logits.
        """
        x = self.embedding(input_ids)
        for layer in self.layers:
            x = layer(x)
        x = self.norm_f(x)
        # Only the last position is returned, so slice before the vocabulary
        # projection. nn.Linear acts on the trailing dimension independently
        # per position, so this avoids computing logits that would be discarded.
        return self.lm_head(x[:, -1, :])  # last-token logits
