import sys

sys.path.append("../..")

import torch

# from ref import ref_selective_scan
# from triton_parallel_scan import triton_selective_scan
from modeling.mamba.pscan_triton_seq import triton_selective_scan_sequential


import torch
from einops import einsum, rearrange, repeat


# credit: https://github.com/johnma2006/mamba-minimal/blob/master/model.py#L275
def ref_selective_scan(u, delta, A, B, C, D, initial_state=None):
    """Run the selective scan sequentially in plain PyTorch (reference implementation).

    See:
        - Section 2 State Space Models in the Mamba paper [1]
        - Algorithm 2 in Section 3.2 in the Mamba paper [1]
        - run_SSM(A, B, C, u) in The Annotated S4 [2]

    This is the classic discrete state space formula:
        x(t + 1) = Ax(t) + Bu(t)
        y(t)     = Cx(t) + Du(t)
    except B and C (and the step size delta, which is used for discretization) are dependent on the input x(t).

    Shape glossary: (b, l, d_in) are the three dimensions of `u` (batch, sequence position, channel) and
    n is the second dimension of `A` (size of the hidden state kept per batch element and channel).

    All tensor inputs are upcast to float32 for the computation.

    Args:
        u: shape (b, l, d_in)
        delta: shape (b, l, d_in)
        A: shape (d_in, n)
        B: shape (b, l, n)
        C: shape (b, l, n)
        D: shape (d_in,)
        initial_state: hidden state before the first step, broadcastable to (b, d_in, n).
            `None` (the default) starts from an all-zero state.

    Returns:
        A tuple `(output, final_state)`:
            output: shape (b, l, d_in), cast back to the dtype `u` had on entry
            final_state: shape (b, d_in, n), float32, the hidden state after the last step
                (equal to the initial state when l == 0)

    Official Implementation:
        selective_scan_ref(), https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/selective_scan_interface.py#L86
        Note: some parts were refactored out of `selective_scan_ref`, so the functionality doesn't match exactly.

    References:
        [1] Gu & Dao, "Mamba: Linear-Time Sequence Modeling with Selective State Spaces"
        [2] Rush & Karamcheti, "The Annotated S4"
    """
    original_dtype = u.dtype
    u, delta, A, B, C, D = (t.float() for t in (u, delta, A, B, C, D))
    b, l, d_in = u.shape
    n = A.shape[1]

    # Discretize continuous parameters (A, B)
    # - A is discretized using zero-order hold (ZOH) discretization (see Section 2 Equation 4 in the Mamba paper [1])
    # - B is discretized using a simplified Euler discretization instead of ZOH. From a discussion with authors:
    #   "A is the more important term and the performance doesn't change much with the simplification on B"
    # Both are pure outer products (nothing is summed), so plain broadcasting is enough.
    step = delta.unsqueeze(-1)  # (b, l, d_in, 1)
    deltaA = torch.exp(step * A)  # (b, l, d_in, n)
    deltaB_u = step * B.unsqueeze(2) * u.unsqueeze(-1)  # (b, l, d_in, n)

    # The state is always float32, independent of torch's global default dtype.
    x = torch.zeros((b, d_in, n), dtype=torch.float32, device=deltaA.device)
    if initial_state is not None:
        # Upcast like every other input; the in-place add keeps the (b, d_in, n) shape
        # fixed, so an initial state that cannot broadcast into it still raises.
        x += initial_state.float()

    # Perform selective scan (see scan_SSM() in The Annotated S4 [2])
    # Note that the below is sequential, while the official implementation does a much faster parallel scan that
    # is additionally hardware-aware (like FlashAttention).
    ys = []
    for dA_t, dBu_t, C_t in zip(deltaA.unbind(1), deltaB_u.unbind(1), C.unbind(1)):
        x = dA_t * x + dBu_t
        ys.append(torch.einsum("bdn,bn->bd", x, C_t))

    # torch.stack rejects an empty list, so build the empty output explicitly when l == 0.
    y = torch.stack(ys, dim=1) if ys else x.new_zeros((b, 0, d_in))  # shape (b, l, d_in)

    # Skip connection: add the input scaled per channel by D.
    y = y + u * D[None, None, :]

    return y.to(original_dtype), x


if __name__ == "__main__":
    # Manual comparison script: run the imported `triton_selective_scan_sequential`
    # and `ref_selective_scan` on the same random CUDA inputs, then print the
    # largest absolute difference of the outputs, the final states, and the
    # gradients w.r.t. C, x, B2, delta and A.
    n_batch, seq_len, n_channels, n_state = 2, 512, 512, 16
    dtype = torch.float32

    def _leaf(t):
        # Move to the GPU first so that the CUDA tensor itself is the autograd leaf.
        return t.cuda().requires_grad_(True)

    # The tensors are sampled in a fixed order (A, x, delta, B2, C, D2, initial_state).
    A = (-(torch.rand(n_channels, n_state, dtype=dtype).exp().cuda())).requires_grad_(True)
    x = _leaf(torch.randn(n_batch, seq_len, n_channels, dtype=dtype))
    delta = _leaf(torch.randn(n_batch, seq_len, n_channels, dtype=dtype).sigmoid())
    B2 = _leaf(torch.randn(n_batch, seq_len, n_state, dtype=dtype))
    C = _leaf(torch.randn(n_batch, seq_len, n_state, dtype=dtype))
    D2 = _leaf(torch.randn(n_channels, dtype=dtype))

    initial_state = (
        torch.randn(n_batch, n_channels, n_state, dtype=dtype).cuda().requires_grad_(False)
    )

    scan_args = (x, delta, A, B2, C, D2, initial_state)

    # Printed label -> tensor whose gradient is compared. D2's gradient is neither
    # reset nor compared by this script.
    grad_targets = {"dc": C, "dx": x, "db": B2, "delta": delta, "A": A}

    def _take_grads():
        # Snapshot every tracked gradient and clear it, so the next backward pass
        # starts from an empty .grad instead of accumulating into the old one.
        taken = {}
        for label, tensor in grad_targets.items():
            taken[label] = tensor.grad.clone()
            tensor.grad = None
        return taken

    # Implementation under test: forward, then backward with a random upstream gradient.
    tri, tri_final = triton_selective_scan_sequential(*scan_args)
    do = torch.randn_like(tri)
    tri.backward(do)
    tri_grads = _take_grads()

    # Reference implementation on the very same inputs.
    ref, ref_final = ref_selective_scan(*scan_args)

    print("y", (tri - ref).abs().max())
    print("h_last", (tri_final - ref_final).abs().max())

    # Same upstream gradient for the reference backward pass.
    ref.backward(do)
    ref_grads = _take_grads()

    for label in grad_targets:
        print(label, (tri_grads[label] - ref_grads[label]).abs().max())

    # Drop into the debugger so the tensors above can be inspected interactively.
    breakpoint()
