import jax
from jax import numpy as jnp, lax
from functools import partial
from jax.experimental import pallas as pl
from jax.experimental.pallas import triton as plgpu
import math
import numpy as np

# Large-magnitude but finite negative score used to mask out attention entries and as
# the `other` fill value for masked loads. At 0.7x the float32 maximum it stays finite
# under small additive shifts.
# NOTE(review): multiplying it by any factor above ~1.43, such as log2(e) ~= 1.4427,
# overflows float32 to -inf.
DEFAULT_MASK_VALUE = float(np.finfo(np.float32).max) * -0.7


def spatial_preprocess_ret(K: int, Q: int, ret_idx, qlabels, valid_idx):
    """Group every (query row, retrieval slot) pair by the (label, key) cell it reads.

    Pair ``(t, r)`` goes to bucket ``qlabels[t] * K + ret_idx[t, r]``. Pairs whose
    ``valid_idx`` entry is False go to the sentinel bucket ``Q * K``. That bucket is not
    counted in ``sizes``, and (assuming ``qlabels`` lie in [0, Q) and ``ret_idx`` in
    [0, K)) it sorts after every real bucket.

    Args:
        K: number of keys per label (the bucket stride).
        Q: number of labels.
        ret_idx: (T, R) retrieval indices.
        qlabels: (T,) label of every query row.
        valid_idx: (T, R) boolean mask of pairs to keep.

    Returns:
        ``(sorted_qfwd, sorted_ridx, sizes, offsets)``:
        - ``sorted_qfwd``, ``sorted_ridx``: query row and slot of each of the T * R pairs,
          in bucket order.
        - ``sizes``: count of pairs in each of the Q * K buckets.
        - ``offsets``: start position of each bucket in the sorted arrays.
    """
    num_slots = ret_idx.shape[1]
    num_buckets = Q * K
    bucket = jnp.where(valid_idx, qlabels[:, None] * K + ret_idx, num_buckets).flatten()
    # With `length=num_buckets`, the sentinel value num_buckets is not counted.
    sizes = jnp.bincount(bucket, length=num_buckets)
    # Exclusive prefix sum: bucket b starts after all smaller buckets.
    offsets = jnp.concatenate([jnp.array([0]), jnp.cumsum(sizes[:-1])])
    # argsort is stable, so pairs inside one bucket keep their row-major (t, r) order.
    order = jnp.argsort(bucket)
    sorted_qfwd, sorted_ridx = jnp.divmod(order, num_slots)
    return sorted_qfwd, sorted_ridx, sizes, offsets


def ref_spatial_retrieval_postprocess(m_in, l_in, v_in, m_ret, l_ret, v_ret):
    """Merge per-slot partial softmax statistics into the incoming running statistics.

    ``m`` values are running maxima in log2 units; partials are rescaled with ``exp2``.

    Args:
        m_in, l_in: (N,) incoming running max and normaliser sum.
        v_in: (N, V) incoming unnormalised output.
        m_ret, l_ret: (N, R) per-slot running max and normaliser sum.
        v_ret: (N, R, V) per-slot unnormalised output.

    Returns:
        ``(m_total, l_total, v_total)`` with shapes (N,), (N,), (N, V): the merged max,
        and the normaliser sum and output rescaled to that max.
    """
    # The merged max must cover m_in as well as every slot. Otherwise
    # exp2(m_in - m_total) exceeds 1, and overflows to inf when m_in dominates.
    # The DEFAULT_MASK_VALUE floor keeps m_total finite even for -inf inputs.
    m_total = jnp.maximum(jnp.maximum(m_in, jnp.max(m_ret, axis=1)), DEFAULT_MASK_VALUE)  # (N,)
    correction_in = jnp.exp2(m_in - m_total)  # (N,)
    correction_ret = jnp.exp2(m_ret - m_total[:, None])  # (N, R)
    v_in_corr = v_in * correction_in[:, None]  # (N, V)
    l_in_corr = l_in * correction_in  # (N,)
    v_ret_corr = v_ret.astype(correction_ret.dtype) * correction_ret[:, :, None]  # (N, R, V)
    l_ret_corr = l_ret * correction_ret  # (N, R)
    v_total = v_in_corr + jnp.sum(v_ret_corr, axis=1)  # (N, V)
    l_total = l_in_corr + jnp.sum(l_ret_corr, axis=1)  # (N,)
    return m_total, l_total, v_total


def spatial_retrieval_kernel(
    q_ref, q_bar_ref, k_ref, v_ref, mu_ref,
    sorted_qfwd_ref, sorted_ridx_ref, sizes_ref, offsets_ref, kidx_ref, qidx_ref,
    m_in_ref,
    m_out_ref, l_out_ref, v_out_ref, blk_idx_out_ref,
    *,
    causal_stride: int, block_t: int, max_c: int, bidiagonal: bool, meta_stride: int,
):
    """Pallas forward kernel for one (label, key) bucket of retrieval pairs.

    Each program walks its bucket's slice ``[offset, offset + size)`` of
    ``sorted_qfwd``/``sorted_ridx`` in chunks of ``block_t``. For each pair, it scores
    the query residual ``q - q_bar[label]`` against the ``C`` entries of
    ``k[:, label, key]`` plus bias ``mu[:, label, key]``. Only meta-blocks strictly
    below the query's meta-block are visible; ``bidiagonal`` hides one more.

    The ``P`` best-scoring blocks go to ``blk_idx_out[qfwd, ridx, :]``, with -1 when
    a pick is not causally visible. They are excluded from the summary attention. The
    remaining entries are softmax-accumulated (log2 units, ``exp2``), and the partial
    ``(max, sum, weighted value)`` is written to ``m_out``/``l_out``/``v_out`` at
    ``[qfwd, ridx]``.
    """
    sm_scale = 1.0 / math.sqrt(q_ref.shape[-1])
    # Scores are kept in log2 units so exp2 can be used instead of exp.
    qk_scale = math.log2(math.e) * sm_scale
    mu_scale = math.log2(math.e)

    qidx = qidx_ref[...]  # () label of this program's bucket
    q_bar = q_bar_ref[qidx, :]  # (D,)

    size_t = sizes_ref[...]  # () number of pairs in the bucket
    offset_t = offsets_ref[...]  # () first sorted position of the bucket
    kidx = kidx_ref[...]  # () key of this program's bucket

    C, _, _, _ = k_ref.shape  # padded block count (>= max_c)
    _, _, P = blk_idx_out_ref.shape  # blocks retrieved per pair
    col = jnp.arange(C)
    mask_k = col < max_c  # rows at or beyond max_c are padding
    k = pl.load(k_ref, (slice(None), qidx, kidx, slice(None)), mask=mask_k[:, None], other=0.0)  # (C, D)
    v = pl.load(v_ref, (slice(None), qidx, kidx, slice(None)), mask=mask_k[:, None], other=0.0)  # (C, V)
    mu = pl.load(mu_ref, (slice(None), qidx, kidx), mask=mask_k, other=DEFAULT_MASK_VALUE)  # (C,)
    kblk = col // meta_stride  # meta-block of every key block

    def t_body(start_t, _):
        base_t = start_t * block_t + offset_t
        curr_t_slice = pl.dslice(base_t, block_t)
        mask_t = jnp.arange(block_t) + base_t < offset_t + size_t  # rows inside the bucket
        qfwd = pl.load(sorted_qfwd_ref, (curr_t_slice,), mask=mask_t, other=0)
        ridx = pl.load(sorted_ridx_ref, (curr_t_slice,), mask=mask_t, other=0)
        qblk = qfwd // (causal_stride * meta_stride)  # meta-block of each query
        m = pl.load(m_in_ref, (qfwd,), mask=mask_t, other=DEFAULT_MASK_VALUE)
        q = pl.load(q_ref, (qfwd, slice(None)), mask=mask_t[:, None], other=0.0)  # (t, D)
        qres = q - q_bar[None, :]

        # Key meta-blocks strictly below `limit` are visible; bidiagonal hides one more.
        limit = qblk - 1 if bidiagonal else qblk  # (t,)
        causal_mask = kblk[None, :] < limit[:, None]  # (t, C)

        attn_scores = pl.dot(qres, k, trans_b=True, allow_tf32=True) * qk_scale + mu[None, :] * mu_scale
        attn_scores = jnp.where(causal_mask, attn_scores, DEFAULT_MASK_VALUE)

        # Pick the P best blocks one at a time and remove each from the summary attention.
        killed = jnp.zeros_like(attn_scores, dtype=bool)
        for pidx in range(P):
            best = jnp.argmax(attn_scores, axis=-1)  # (t,)
            hit = col[None, :] == best[:, None]
            killed = killed | hit
            attn_scores = jnp.where(hit, DEFAULT_MASK_VALUE, attn_scores)
            # Report -1 when the pick is not causally visible.
            best = jnp.where(best // meta_stride < limit, best, -1)
            pl.store(blk_idx_out_ref, (qfwd, ridx, pidx), best.astype(blk_idx_out_ref.dtype), mask=mask_t)

        m_curr = jnp.max(attn_scores, axis=-1)  # (t,)
        m_next = jnp.maximum(m, m_curr)
        s = jnp.exp2(attn_scores - m_next[:, None])  # (t, C)
        # Drop the retrieved blocks explicitly, as the backward kernel does. If m_next is
        # itself DEFAULT_MASK_VALUE, exp2(0) == 1 would otherwise leak them back in.
        s = jnp.where(causal_mask & ~killed, s, 0.0)

        l_curr = jnp.sum(s, axis=-1)  # (t,)
        ov_curr = pl.dot(s.astype(v.dtype), v, allow_tf32=True)  # (t, V)

        pl.store(m_out_ref, (qfwd, ridx), m_next.astype(m_out_ref.dtype), mask=mask_t)
        pl.store(l_out_ref, (qfwd, ridx), l_curr.astype(l_out_ref.dtype), mask=mask_t)
        pl.store(v_out_ref, (qfwd, ridx, slice(None)), ov_curr.astype(v_out_ref.dtype), mask=mask_t[:, None])

    jax.lax.fori_loop(0, pl.cdiv(size_t, block_t), t_body, None)


def spatial_retrieval(num_spatial: int, causal_block: int, bidiagonal: bool, meta_multiplier: int, q, q_bar, k, v, m, qmeta, ret_idx, m_in, l_in, v_in):
    """figure out which spatial index to retrieve from, attend to summaries of others

    For every query row ``t`` and retrieval slot ``r``, ``q[t]`` (relative to its
    label's ``q_bar``) attends to the ``C`` entries of cell
    ``(qmeta.lab[t], ret_idx[t, r])`` of ``k``/``v``/``m``. The ``num_spatial`` best
    causally visible blocks are recorded in ``blk_idx``. The remaining summary attention
    is folded into the running statistics ``(m_in, l_in, v_in)``.

    Rows in causal blocks 0-1 (0-2 when ``bidiagonal``) contribute nothing to the
    summaries, and their ``blk_idx`` is set to fixed values.

    Returns:
        ``(m_out, l_out, v_out, blk_idx, spatial_retrieval_meta)``.
        - ``blk_idx`` is (T, R, num_spatial).
        - ``spatial_retrieval_meta`` is the bucketing that ``spatial_retrieval_bwd``
          consumes.
    """
    assert ret_idx.dtype == jnp.int32, "ret_idx must be int32"
    assert qmeta.lab.dtype == jnp.int32, "qmeta.lab must be int32"
    T, R = ret_idx.shape
    Q, D = q_bar.shape
    C, Q, K, V = v.shape
    P = num_spatial
    if bidiagonal:
        # Checked up front so an unsupported configuration fails before the kernel launch.
        assert P == 1, "bidiagonal spatial retrieval only supports P=1 for now"

    assert C <= 256, "C too large for spatial retrieval with no blocking over C"
    if C >= 128:
        print(f"Warning: C={C} is large (>=128) for spatial retrieval with no blocking over C, performance may be suboptimal")
    # Pallas' Triton backend needs power-of-two array extents. Pad the block axis to the
    # next power of two (at least 16); the kernel masks rows >= C via max_c.
    block_C = 1 << max(4, (C - 1).bit_length())

    block_t = 16
    qblk = jnp.arange(T) // causal_block
    first_valid_qblk = 3 if bidiagonal else 2
    valid_idx = jnp.broadcast_to((qblk >= first_valid_qblk)[:, None], (T, R))
    flat_labels = qmeta.lab.flatten()
    sorted_qfwd, sorted_ridx, sizes, offsets = spatial_preprocess_ret(K, Q, ret_idx, flat_labels, valid_idx)
    # Grid program i handles bucket i = label * K + key, matching the bucketing above.
    subseg_labels_q = jnp.broadcast_to(jnp.arange(Q)[:, None], (Q, K)).flatten()
    subseg_labels_k = jnp.broadcast_to(jnp.arange(K)[None, :], (Q, K)).flatten()

    padded_k_shape = (block_C,) + k.shape[1:]
    padded_v_shape = (block_C,) + v.shape[1:]
    padded_m_shape = (block_C,) + m.shape[1:]
    m_ret, l_ret, v_ret, blk_idx = pl.pallas_call(
        partial(spatial_retrieval_kernel, causal_stride=causal_block, block_t=block_t, max_c=C, bidiagonal=bidiagonal, meta_stride=meta_multiplier),
        out_shape=[
            jax.ShapeDtypeStruct((*m_in.shape, R), m_in.dtype),  # m_out
            jax.ShapeDtypeStruct((*l_in.shape, R), l_in.dtype),  # l_out
            jax.ShapeDtypeStruct((v_in.shape[0], R, *v_in.shape[1:]), v_in.dtype),  # v_out
            jax.ShapeDtypeStruct((T, R, P), ret_idx.dtype),  # blk_idx_out
        ],
        grid=(len(sizes),),
        in_specs=[
            pl.BlockSpec(q.shape, lambda i: (0, 0)),  # q_ref
            pl.BlockSpec(q_bar.shape, lambda i: (0, 0)),  # q_bar_ref
            pl.BlockSpec(padded_k_shape, lambda i: (0, 0, 0, 0)),  # k_ref
            pl.BlockSpec(padded_v_shape, lambda i: (0, 0, 0, 0)),  # v_ref
            pl.BlockSpec(padded_m_shape, lambda i: (0, 0, 0)),  # mu_ref
            pl.BlockSpec(sorted_qfwd.shape, lambda i: (0,)),  # sorted_qfwd_ref
            pl.BlockSpec(sorted_ridx.shape, lambda i: (0,)),  # sorted_ridx_ref
            pl.BlockSpec((None,), lambda i: (i,)),  # sizes_ref
            pl.BlockSpec((None,), lambda i: (i,)),  # offsets_ref
            pl.BlockSpec((None,), lambda i: (i,)),  # kidx_ref
            pl.BlockSpec((None,), lambda i: (i,)),  # qidx_ref
            pl.BlockSpec(m_in.shape, lambda i: (0,)),  # m_in_ref
        ],
        out_specs=[
            pl.BlockSpec((T, R), lambda i: (0, 0)),  # m_out_ref
            pl.BlockSpec((T, R), lambda i: (0, 0)),  # l_out_ref
            pl.BlockSpec((T, R, V), lambda i: (0, 0, 0)),  # v_out_ref
            pl.BlockSpec((T, R, P), lambda i: (0, 0, 0)),  # blk_idx_out_ref
        ],
        compiler_params=plgpu.CompilerParams(num_warps=1, num_stages=2),
        name="spatial_retrieval_forward",
    )(q, q_bar, k, v, m, sorted_qfwd, sorted_ridx, sizes, offsets, subseg_labels_k, subseg_labels_q, m_in)
    # Pairs outside valid_idx fall in the sentinel bucket, which no grid program visits.
    # Give them neutral statistics.
    m_ret = jnp.where(valid_idx, m_ret, m_in[:, None])
    l_ret = jnp.where(valid_idx, l_ret, 0.0)
    v_ret = jnp.where(valid_idx[:, :, None], v_ret, 0.0)
    # Rows in the leading causal blocks get fixed retrieval targets.
    if bidiagonal:
        blk_idx = jnp.where((qblk == 0)[:, None, None], -1, blk_idx)
        blk_idx = jnp.where((qblk == 1)[:, None, None], -1, blk_idx)
        blk_idx = jnp.where((qblk == 2)[:, None, None], 0, blk_idx)
    else:
        blk_idx = jnp.where((qblk == 0)[:, None, None], -1, blk_idx)
        qblk_1_blk_idx = jnp.full((P,), -1, dtype=blk_idx.dtype).at[0].set(0)
        blk_idx = jnp.where((qblk == 1)[:, None, None], qblk_1_blk_idx[None, None, :], blk_idx)

    spatial_retrieval_meta = (sorted_qfwd, sorted_ridx, sizes, offsets)

    m_out, l_out, v_out = ref_spatial_retrieval_postprocess(m_in, l_in, v_in, m_ret, l_ret, v_ret)
    return m_out, l_out, v_out, blk_idx, spatial_retrieval_meta


def spatial_retrieval_bwd_kernel(
    q_ref, q_bar_ref, k_ref, v_ref, mu_ref,
    sorted_qfwd_ref, sorted_ridx_ref, sizes_ref, offsets_ref, kidx_ref, qidx_ref,
    lse_ref, dlse_ref, do_ref,
    dq_out_ref, dk_out_ref, dv_out_ref, dmu_out_ref, dq_bar_out_ref,
    *,
    causal_stride: int, block_t: int, max_c: int, bidiagonal: bool, num_spatial: int,
):
    """Pallas backward kernel for one (label, key) bucket of retrieval pairs.

    Recomputes the summary attention of every pair in the bucket using the final
    ``lse``. ``lse`` is in natural-log units and is converted to log2 here. The kernel
    re-derives the same ``num_spatial`` excluded blocks, then writes:
    - ``dq`` per pair into ``dq_out[ridx, qfwd]``;
    - ``dk``/``dv``/``dmu`` for this bucket's ``[:, label, key]`` slice;
    - ``dq_bar`` into ``dq_bar_out[key, label]``.

    ``dlse_ref`` is presumably the caller's ``dlse - rowsum(dO * O)``.

    NOTE(review): unlike the forward kernel there is no ``meta_stride`` here. The
    causal mask therefore matches the forward pass only when the forward ran with
    ``meta_multiplier == 1``.
    """
    sm_scale = 1.0 / math.sqrt(q_ref.shape[-1])
    qk_scale = math.log2(math.e) * sm_scale
    mu_scale = math.log2(math.e)
    lse_scale = math.log2(math.e)

    qidx = qidx_ref[...]  # () label of this program's bucket
    q_bar = q_bar_ref[qidx, :]  # (D,)

    size_t = sizes_ref[...]  # () number of pairs in the bucket
    offset_t = offsets_ref[...]  # () first sorted position of the bucket
    kidx = kidx_ref[...]  # () key of this program's bucket

    C, _, _, _ = k_ref.shape  # padded block count (>= max_c)
    kblk = jnp.arange(C)
    mask_k = kblk < max_c  # rows at or beyond max_c are padding
    k = pl.load(k_ref, (slice(None), qidx, kidx, slice(None)), mask=mask_k[:, None], other=0.0)  # (C, D)
    v = pl.load(v_ref, (slice(None), qidx, kidx, slice(None)), mask=mask_k[:, None], other=0.0)  # (C, V)
    mu = pl.load(mu_ref, (slice(None), qidx, kidx), mask=mask_k, other=DEFAULT_MASK_VALUE)  # (C,)

    def t_body(start_t, carry):
        dk_prev, dv_prev, dmu_prev, dq_bar_prev = carry

        base_t = start_t * block_t + offset_t
        curr_t_slice = pl.dslice(base_t, block_t)
        mask_t = jnp.arange(block_t) + base_t < offset_t + size_t  # rows inside the bucket
        qfwd = pl.load(sorted_qfwd_ref, (curr_t_slice,), mask=mask_t, other=0)
        ridx = pl.load(sorted_ridx_ref, (curr_t_slice,), mask=mask_t, other=0)
        qblk = qfwd // causal_stride
        q = pl.load(q_ref, (qfwd, slice(None)), mask=mask_t[:, None], other=0.0)  # (t, D)
        lse = pl.load(lse_ref, (qfwd,), mask=mask_t, other=0.0)
        do = pl.load(do_ref, (qfwd, slice(None)), mask=mask_t[:, None], other=0.0)
        dlse = pl.load(dlse_ref, (qfwd,), mask=mask_t, other=0.0)
        qres = q - q_bar[None, :]

        # Key blocks strictly below `limit` are visible; bidiagonal hides one more.
        limit = qblk - 1 if bidiagonal else qblk  # (t,)
        causal_mask = kblk[None, :] < limit[:, None]  # (t, C)

        attn_scores = pl.dot(qres, k, trans_b=True, allow_tf32=True) * qk_scale + mu[None, :] * mu_scale
        attn_scores = attn_scores - lse[:, None] * lse_scale
        attn_scores = jnp.where(causal_mask, attn_scores, DEFAULT_MASK_VALUE)
        attn_to_kill = jnp.zeros_like(attn_scores, dtype=bool)
        for _ in range(num_spatial):
            blk_idx = jnp.argmax(attn_scores, axis=-1)
            # Parenthesised on purpose: `|` binds tighter than `==` in Python. Without the
            # parentheses the mask is OR-ed into the column indices before comparing.
            attn_to_kill = attn_to_kill | (kblk[None, :] == blk_idx[:, None])
            attn_scores = jnp.where(attn_to_kill, DEFAULT_MASK_VALUE, attn_scores)

        s = jnp.exp2(attn_scores)  # probabilities normalised by the final lse
        s = jnp.where(causal_mask & ~attn_to_kill, s, 0.0)

        dv = pl.dot(s.astype(do.dtype), do, trans_a=True, allow_tf32=True)  # (C, V)
        dv_next = dv_prev + dv  # (C, V)
        gv = pl.dot(do.astype(v.dtype), v, trans_b=True, allow_tf32=True)  # (t, C)
        ds = gv + dlse[:, None]
        # Gradient w.r.t. the natural-log score z = qres.k * sm_scale + mu.
        dqk = ds * s  # (t, C)

        # dz/dmu == 1. The log2(e) factor on mu only converts z to exp2 units in the
        # scores above, so it must not scale the gradient (dk/dq likewise use sm_scale).
        dmu = jnp.sum(dqk, axis=0)  # (C,)
        dmu_next = dmu_prev + dmu  # (C,)
        dk = pl.dot(dqk.astype(q.dtype), qres, trans_a=True, allow_tf32=True) * sm_scale  # (C, D)
        dk_next = dk_prev + dk  # (C, D)
        dq = pl.dot(dqk.astype(k.dtype), k, allow_tf32=True) * sm_scale  # (t, D)
        pl.store(dq_out_ref, (ridx, qfwd, slice(None)), dq.astype(dq_out_ref.dtype), mask=mask_t[:, None])
        # qres = q - q_bar, so q_bar receives the negated sum of dq.
        dq_bar_next = dq_bar_prev - jnp.sum(dq, axis=0)

        return dk_next, dv_next, dmu_next, dq_bar_next

    dk = jnp.zeros(k.shape, dtype=jnp.float32)  # (C, D)
    dv = jnp.zeros(v.shape, dtype=jnp.float32)  # (C, V)
    dmu = jnp.zeros(mu.shape, dtype=jnp.float32)  # (C,)
    dq_bar = jnp.zeros(q_bar.shape, dtype=jnp.float32)  # (D,)

    dk, dv, dmu, dq_bar = jax.lax.fori_loop(0, pl.cdiv(size_t, block_t), t_body, (dk, dv, dmu, dq_bar))

    pl.store(dk_out_ref, (slice(None), qidx, kidx, slice(None)), dk.astype(dk_out_ref.dtype), mask=mask_k[:, None])
    pl.store(dv_out_ref, (slice(None), qidx, kidx, slice(None)), dv.astype(dv_out_ref.dtype), mask=mask_k[:, None])
    pl.store(dmu_out_ref, (slice(None), qidx, kidx), dmu.astype(dmu_out_ref.dtype), mask=mask_k)
    pl.store(dq_bar_out_ref, (kidx, qidx, slice(None),), dq_bar.astype(dq_bar_out_ref.dtype))


def spatial_retrieval_bwd(num_spatial: int, causal_block: int, bidiagonal: bool, q, q_bar, k, v, m, qmeta, ret_idx, lse, v_out, dlse, dv_out, spatial_retrieval_meta):
    """figure out which spatial index to retrieve from, attend to summaries of others

    Backward pass of ``spatial_retrieval``. It reuses the bucketing in
    ``spatial_retrieval_meta``; ``qmeta`` is not read here.

    Returns:
        ``(dq, dk, dv, dm, dq_bar)``: gradients for ``q``, ``k``, ``v``, ``m`` and
        ``q_bar``.

    NOTE(review): there is no ``meta_multiplier`` argument, and the kernel's causal mask
    has no meta stride. The result therefore matches the forward pass only for
    ``meta_multiplier == 1``.
    """
    T, R = ret_idx.shape
    Q, D = q_bar.shape
    C, Q, K, V = v.shape

    # Softmax-backward correction: dlse - rowsum(dO * O).
    dlse_modified = dlse - jnp.sum(dv_out * v_out, axis=-1).astype(dlse.dtype)  # (T,)

    assert C <= 64, "C too large for spatial retrieval with no blocking over C"
    # Pallas' Triton backend needs power-of-two array extents. Pad the block axis to the
    # next power of two (at least 16); the kernel masks rows >= C via max_c.
    block_C = 1 << max(4, (C - 1).bit_length())

    block_t = 16

    qblk = jnp.arange(T) // causal_block
    first_valid_qblk = 3 if bidiagonal else 2
    valid_qblk = qblk >= first_valid_qblk

    sorted_qfwd, sorted_ridx, sizes, offsets = spatial_retrieval_meta
    # Grid program i handles bucket i = label * K + key.
    subseg_labels_q = jnp.broadcast_to(jnp.arange(Q)[:, None], (Q, K)).flatten()
    subseg_labels_k = jnp.broadcast_to(jnp.arange(K)[None, :], (Q, K)).flatten()

    padded_k_shape = (block_C,) + k.shape[1:]
    padded_v_shape = (block_C,) + v.shape[1:]
    padded_m_shape = (block_C,) + m.shape[1:]
    dqs, dk, dv, dm, dq_bars = pl.pallas_call(
        partial(spatial_retrieval_bwd_kernel, causal_stride=causal_block, block_t=block_t, max_c=C, bidiagonal=bidiagonal, num_spatial=num_spatial),
        out_shape=[
            jax.ShapeDtypeStruct((R, T, D), q.dtype),  # dq_out
            jax.ShapeDtypeStruct(k.shape, k.dtype),  # dk_out
            jax.ShapeDtypeStruct(v.shape, v.dtype),  # dv_out
            jax.ShapeDtypeStruct(m.shape, m.dtype),  # dmu_out
            jax.ShapeDtypeStruct((K, Q, D), q_bar.dtype),  # dq_bar_out
        ],
        grid=(len(sizes),),
        in_specs=[
            pl.BlockSpec(q.shape, lambda i: (0, 0)),  # q_ref
            pl.BlockSpec(q_bar.shape, lambda i: (0, 0)),  # q_bar_ref
            pl.BlockSpec(padded_k_shape, lambda i: (0, 0, 0, 0)),  # k_ref
            pl.BlockSpec(padded_v_shape, lambda i: (0, 0, 0, 0)),  # v_ref
            pl.BlockSpec(padded_m_shape, lambda i: (0, 0, 0)),  # mu_ref
            pl.BlockSpec(sorted_qfwd.shape, lambda i: (0,)),  # sorted_qfwd_ref
            pl.BlockSpec(sorted_ridx.shape, lambda i: (0,)),  # sorted_ridx_ref
            pl.BlockSpec((None,), lambda i: (i,)),  # sizes_ref
            pl.BlockSpec((None,), lambda i: (i,)),  # offsets_ref
            pl.BlockSpec((None,), lambda i: (i,)),  # kidx_ref
            pl.BlockSpec((None,), lambda i: (i,)),  # qidx_ref
            pl.BlockSpec(lse.shape, lambda i: (0,)),  # lse_ref
            pl.BlockSpec(dlse_modified.shape, lambda i: (0,)),  # dlse_ref
            pl.BlockSpec(dv_out.shape, lambda i: (0, 0)),  # do_ref
        ],
        out_specs=[
            pl.BlockSpec((R, T, D), lambda i: (0, 0, 0)),  # dq_out_ref
            pl.BlockSpec(padded_k_shape, lambda i: (0, 0, 0, 0)),  # dk_out_ref
            pl.BlockSpec(padded_v_shape, lambda i: (0, 0, 0, 0)),  # dv_out_ref
            pl.BlockSpec(padded_m_shape, lambda i: (0, 0, 0)),  # dmu_out_ref
            pl.BlockSpec((K, Q, D), lambda i: (0, 0, 0)),  # dq_bar_out_ref
        ],
        compiler_params=plgpu.CompilerParams(num_warps=1, num_stages=2),
        name="spatial_retrieval_backward",
    )(q, q_bar, k, v, m, sorted_qfwd, sorted_ridx, sizes, offsets, subseg_labels_k, subseg_labels_q, lse, dlse_modified, dv_out)
    # dq slots of rows in the leading causal blocks are never written by the kernel.
    dqs = jnp.where(valid_qblk[None, :, None], dqs, 0.0)
    dq = jnp.sum(dqs, axis=0)  # (T, D)
    dq_bar = jnp.sum(dq_bars, axis=0)  # (Q, D)
    return dq, dk, dv, dm, dq_bar
