import torch
import torch.nn.functional as F
from fla.ops.gla.recurrent_fuse import fused_recurrent_gla


def ceildiv(a, b):
    """Return ``a`` divided by ``b``, rounded up, using only floor division.

    Floor-dividing by the negated divisor and then negating the result turns
    the round-down behaviour of ``//`` into round-up.
    """
    negated_quotient = a // -b
    return -negated_quotient


def _gla_scan(q, k, v, decay, h, positions):
    """Run the gated recurrence over ``positions`` and return ``(o, h)``.

    ``q`` must already be scaled and ``decay`` must already be exponentiated.
    ``positions`` yields the sequence indices in the order they are visited.
    """
    o = torch.zeros_like(v)
    for i in positions:
        k_i = k[:, :, i]
        v_i = v[:, :, i, :]
        # Decay the running state per key channel, then add the outer
        # product of the current key and value.
        h = h * decay[:, :, i][..., None] + k_i[..., None] * v_i[..., None, :]
        # Contract the key dimension of the state with the query.
        o[:, :, i] = (q[:, :, i, :][..., None] * h).sum(-2)
    return o, h


def naive_recurrent_gla(
    q, k, v, gk, initial_state=None, output_final_state=False, causal=True
):
    """Reference step-by-step implementation of gated linear attention.

    All arithmetic is done in float32. At every position the state is updated
    as ``h = h * exp(gk_i) + k_i v_i^T`` and the output is ``(q_i * scale) @ h``
    with ``scale = d_head_k ** -0.5``.

    Args:
        q: Queries, 4-D with the sequence on dim 2 and ``d_head_k`` on dim 3.
        k: Keys, indexed like ``q``.
        v: Values, 4-D with the sequence on dim 2 and ``d_head_v`` on dim 3.
        gk: Log-space gates for the key channels (exponentiated here),
            indexed like ``q``.
        initial_state: Optional tensor added to the zero-initialised state of
            the forward pass. The reverse pass always starts from zeros.
        output_final_state: Accepted for signature compatibility; currently
            ignored (the causal branch always returns the final state).
        causal: Selects what is returned, see below.

    Returns:
        If ``causal`` is true: ``(o, h)`` where ``o`` is the forward output
        cast back to the dtype of ``q`` and ``h`` is the final float32 state.
        Otherwise: ``(o, o_reverse)``, the outputs of the forward scan and of
        a scan running from the last position to the first, both cast back to
        the dtype of ``q``.
    """
    orig_dtype = q.dtype
    q, k, v, gk = (x.float() for x in (q, k, v, gk))
    batch_size, n_heads, seq_len, d_head_k = q.shape
    _, _, _, d_head_v = v.shape

    # Loop-invariant work: scale the queries and exponentiate the gates once.
    q = q * d_head_k**-0.5
    decay = gk.exp()

    def zero_state():
        return torch.zeros(
            batch_size,
            n_heads,
            d_head_k,
            d_head_v,
            dtype=torch.float32,
            device=q.device,
        )

    h = zero_state()
    if initial_state is not None:
        h += initial_state

    o, h = _gla_scan(q, k, v, decay, h, range(seq_len))
    if causal:
        return o.to(orig_dtype), h

    o_reverse, _ = _gla_scan(
        q, k, v, decay, zero_state(), range(seq_len - 1, -1, -1)
    )
    # Cast both outputs back so this branch agrees with the causal one
    # (previously they were returned as float32 regardless of input dtype).
    return o.to(orig_dtype), o_reverse.to(orig_dtype)


if __name__ == "__main__":
    # Self-check: compare the fused kernel against the naive reference, in
    # both scan directions, for outputs and for the gradients of q, k, v, g.
    # Requires a CUDA device.
    B = 4
    H = 4
    L = 512
    D = 128
    dtype = torch.float32
    q = torch.randn(B, H, L, D).cuda().to(dtype).requires_grad_(True)
    k = torch.randn(B, H, L, D).cuda().to(dtype).requires_grad_(True)
    v = torch.randn(B, H, L, D).cuda().to(dtype).requires_grad_(True)
    # Log-space gates: the reference exponentiates them before use.
    g = (
        F.logsigmoid(torch.rand(B, H, L, D))
        .cuda()
        .clamp_min(-1)
        .to(torch.float32)
        .requires_grad_(True)
    )

    # Upstream gradients for the forward and the reverse outputs. `v` already
    # lives on the GPU, so `rand_like` does too.
    do = torch.rand_like(v)
    do2 = torch.rand_like(v)

    def _take_grads():
        """Return clones of the accumulated grads of q, k, v, g and reset them."""
        leaves = (q, k, v, g)
        grads = [t.grad.clone() for t in leaves]
        for t in leaves:
            t.grad = None
        return grads

    ref, ref_rev = naive_recurrent_gla(q, k, v, g, causal=False)
    # Both backward calls accumulate into the same .grad fields, so the
    # harvested gradients are the sum over the two directions.
    ref.backward(do, retain_graph=True)
    ref_rev.backward(do2, retain_graph=True)
    ref_dq, ref_dk, ref_dv, ref_dg = _take_grads()

    tri, tri_rev = fused_recurrent_gla(
        q,
        k,
        v,
        g,
        initial_state=None,
        scale=D**-0.5,
        output_final_state=False,
        causal=False,
    )
    tri.backward(do, retain_graph=True)
    tri_rev.backward(do2, retain_graph=True)
    tri_dq, tri_dk, tri_dv, tri_dg = _take_grads()

    # (label, reference, fused result, absolute tolerance); rtol is 0.
    checks = [
        ("o", ref, tri, 1e-5),
        ("o_reverse", ref_rev, tri_rev, 1e-5),
        ("dq", ref_dq, tri_dq, 1e-5),
        ("dk", ref_dk, tri_dk, 1e-5),
        ("dv", ref_dv, tri_dv, 1e-5),
        ("dg", ref_dg, tri_dg, 1e-4),
    ]
    for label, expected, actual, atol in checks:
        # Raise explicitly instead of using `assert`, so the comparison still
        # runs under `python -O` and a failure reports how far off it is.
        if not expected.allclose(actual, 0, atol):
            max_err = (expected - actual).abs().max().item()
            raise AssertionError(
                f"{label} mismatch: max abs error {max_err:.3e} "
                f"exceeds atol={atol:g}"
            )

    # NOTE: an equivalent comparison against `fused_chunk_gla` (forward-only,
    # same tolerances) used to be sketched here as commented-out code.
    print("Pass")
