import jax
from jax import numpy as jnp, vmap
from jax.scipy.special import logsumexp
import numpy as np
from fma.single_level_attention import MultiLevelClustering, MultiLevelAttention
from functools import partial
from fma.flash_attention import attn as flash_attention
from fma.pallas_monopole import cluster_attn as fma_attention
from fma.pallas_retrieval import ref_causal_attn as causal_retrieval_attention
import math

# Dtype the cuDNN attention wrappers below cast their inputs to before calling
# jax.nn.dot_product_attention (cudnn_window8192 does so only for float32 queries).
cudnn_dtype = jnp.bfloat16

def make_approx(attention_kwargs=None, **cluster_kwargs):
    """Build a clustering-based approximate attention callable.

    Args:
        attention_kwargs: optional mapping of extra keyword arguments for
            ``MultiLevelAttention`` (in addition to ``clustering``). ``None``
            (the default) means no extra arguments. A ``None`` sentinel is used
            instead of a ``dict()`` default so no mutable object is shared
            between calls.
        **cluster_kwargs: forwarded unchanged to ``MultiLevelClustering``.

    Returns:
        The bound ``cluster_then_attention`` method of the constructed
        ``MultiLevelAttention`` instance.
    """
    if attention_kwargs is None:
        attention_kwargs = {}
    clustering = MultiLevelClustering(**cluster_kwargs)
    attention = MultiLevelAttention(clustering=clustering, **attention_kwargs)
    return attention.cluster_then_attention

# Clustering configuration for the module-level approximate attention used by
# main(). Built once at import time. Qs/Ks are presumably per-level cluster
# counts (one level here) — confirm against MultiLevelClustering.
_APPROX_CLUSTER_KWARGS = dict(
    Qs=(2**6,),
    Ks=(2**6,),
    coupled_clustering=False,
    inner_iters=3,
    outer_iters=3,
    max_cluster_scale=1.5,
    compute_indices=True,
)
_approx_dense_attention = make_approx(**_APPROX_CLUSTER_KWARGS)


def cudnn_window8192(queries, keys, values):
    """Causal sliding-window attention via cuDNN with local_window_size=(8192, 0).

    If the queries are float32, all three inputs are cast to ``cudnn_dtype``
    before the call. The output is cast back to the dtype ``values`` had on
    entry. Previously the cast used the already-converted ``values``, so it was
    a no-op and float32 inputs came back as ``cudnn_dtype``.

    Uses jax.nn.dot_product_attention's default scaling (no ``scale`` argument
    is passed). Prints a message on every call (this wrapper is not jitted here).
    """
    window = (8192, 0)
    print(f"Compiling cudnn attention with window size {window}")
    out_dtype = values.dtype  # remember before any cast below
    if queries.dtype == jnp.float32:
        queries = queries.astype(cudnn_dtype)
        keys = keys.astype(cudnn_dtype)
        values = values.astype(cudnn_dtype)
    out = jax.nn.dot_product_attention(
        queries,
        keys,
        values,
        is_causal=True,
        local_window_size=window,
        implementation="cudnn",
    )
    return out.astype(out_dtype)

def naive_attention(queries, keys, values, *, window):
    """Dense reference attention with a banded mask, returning (lse, output).

    Key j is visible to query i when ``-window[1] <= i - j <= window[0]``, so
    ``window=(L, 0)`` keeps the current key plus up to L earlier keys.
    Logits are NOT scaled by 1/sqrt(d). Masked logits are replaced by -1e30
    rather than -inf.

    Args:
        queries: (n, d) array.
        keys: (m, d) array.
        values: (m, d_v) array.
        window: (left, right) offsets; keyword-only.

    Returns:
        (total_score, out): the per-query log-sum-exp of the masked logits,
        shape (n,), and the attention output, shape (n, d_v).
    """
    n_queries, n_keys = queries.shape[0], keys.shape[0]
    logits = jnp.einsum("nd,md->nm", queries, keys)  # deliberately unscaled
    offset = jnp.arange(n_queries)[:, None] - jnp.arange(n_keys)[None, :]
    visible = (offset <= window[0]) & (offset >= -window[1])
    logits = jnp.where(visible, logits, -1e30)
    probs = jax.nn.softmax(logits, axis=-1)
    lse = logsumexp(logits, axis=-1)
    return lse, jnp.einsum("nm,md->nd", probs, values)

def is_power_of_two(x: int):
    """Return True iff ``x`` is a positive power of two (1, 2, 4, 8, ...)."""
    if x <= 0:
        return False
    # A power of two has exactly one set bit, and x - 1 clears it.
    return (x & (x - 1)) == 0

def merge(score0, score1, out0, out1):
    """Combine two partial attention outputs using their log-sum-exp scores.

    The weights are a softmax over the two scores, taken element-wise. Each
    weight is broadcast over the last axis of the corresponding output, so the
    scores are expected to have shape ``out.shape[:-1]``.
    """
    both = jnp.stack([score0, score1], axis=0)
    w = jax.nn.softmax(both, axis=0)
    first, second = w
    return first[..., None] * out0 + second[..., None] * out1

def recursive_blocked_attn(block_size, queries, keys, values):
    """Causal attention by recursive halving: exact inside blocks, approximate across.

    The sequence is split in two, and each half is handled recursively (vmapped).
    The second half's queries also attend to every key of the first half
    through ``fma_attention``. The two partial results for the second half are
    combined with ``merge`` using their log-sum-exp scores. Recursion stops at
    ``N == block_size``, where ``flash_attention`` is called with
    ``window=(N, 0)``.

    Args:
        block_size: leaf size; must be a power of two.
        queries, keys, values: (N, D) arrays. N must be a power of two and a
            multiple of ``block_size``. Because both are powers of two, this
            means ``N >= block_size``.

    Returns:
        (scores, out): per-query log-sum-exp scores and the (N, D) attention
        output. The scores are presumably shape (N,), matching how ``merge``
        broadcasts them against ``out``; confirm against flash_attention.

    Raises:
        AssertionError: if the power-of-two / divisibility preconditions fail.
    """
    assert is_power_of_two(block_size)
    N, D = queries.shape
    sm_scale = float(1.0 / math.sqrt(D))
    assert is_power_of_two(N)
    assert N % block_size == 0, f"N={N} must be divisible by block_size={block_size}"
    if N <= block_size:
        window = (N, 0)
        scores, out = flash_attention(queries, keys, values, window=window, sm_scale=sm_scale)
        return scores, out

    half = N // 2
    qs = queries.reshape(2, half, D)
    ks = keys.reshape(2, half, D)
    vs = values.reshape(2, half, D)

    # NOTE(review): unlike the flash_attention calls, fma_attention gets no
    # sm_scale here. Confirm it applies 1/sqrt(D) itself; otherwise its scores
    # are on a different scale from the recursive ones they are merged with.
    approx_attention = fma_attention
    # Alternative used for comparison:
    # approx_attention = partial(flash_attention, window=(N//2, N//2), sm_scale=sm_scale)

    (yscore0, yscore1), (yout0, yout1) = jax.vmap(partial(recursive_blocked_attn, block_size))(qs, ks, vs)

    # Second-half queries vs. all first-half keys. Every such key precedes
    # every such query, so no mask is applied.
    _, q1 = qs
    k0, _ = ks
    v0, _ = vs
    zscore1, zout1 = approx_attention(q1, k0, v0)

    # Order matters: merge must see the scores of the two *separate* partial
    # results. Updating yscore1 first would weight by logaddexp(y, z) vs z.
    yout1 = merge(yscore1, zscore1, yout1, zout1)
    yscore1 = jnp.logaddexp(yscore1, zscore1)

    final_score = jnp.concatenate([yscore0, yscore1], axis=0)
    final_out = jnp.concatenate([yout0, yout1], axis=0)
    return final_score, final_out

def one_block_causal_attn(queries, keys, values):
    """Local attention with a look-back window capped at 256 positions.

    For (N, D) inputs, calls flash_attention once with
    ``window=(min(N, 256), 0)`` and ``sm_scale = 1/sqrt(D)``.
    NOTE(review): calling this "causal" assumes flash_attention's
    (left, right) window follows the same convention as naive_attention's
    mask — confirm.

    Returns:
        (scores, out) as produced by flash_attention.
    """
    seq_len, head_dim = queries.shape
    lookback = 256 if seq_len > 256 else seq_len
    result_scores, result_out = flash_attention(
        queries,
        keys,
        values,
        window=(lookback, 0),
        sm_scale=float(1.0 / math.sqrt(head_dim)),
    )
    return result_scores, result_out

def _over_batch_and_heads(fn):
    """Lift a per-(sequence, head) attention function to batched multi-head inputs.

    The inner vmap maps ``fn`` over axis -2 of every input and output. The
    outer vmap maps over axis 0. With the (B, N, H, D) layout used by main(),
    that is presumably heads inside batches — confirm at call sites.
    """
    return vmap(vmap(fn, in_axes=-2, out_axes=-2))


recursive_blocked_attn_8192_mha = _over_batch_and_heads(partial(recursive_blocked_attn, 8192))
recursive_blocked_attn_4096_mha = _over_batch_and_heads(partial(recursive_blocked_attn, 4096))
one_block_causal_attn_mha = _over_batch_and_heads(one_block_causal_attn)

@vmap
@partial(vmap, in_axes=-2, out_axes=-2)
def prefix_sum_retrieval_attention_mha(queries, keys, values):
    """Batched multi-head wrapper around causal_retrieval_attention.

    The inner vmap maps over axis -2 of the inputs and output, and the outer
    vmap maps over axis 0. Of the (lse, values) pair returned by
    causal_retrieval_attention, only the values are returned; the log-sum-exp
    is discarded.
    """
    _lse, attended = causal_retrieval_attention(queries, keys, values)
    return attended






@jax.jit
@jax.vmap
@partial(jax.vmap, in_axes=-2, out_axes=-2)
def jax_window8192(queries, keys, values):
    """Blocked sliding-window attention assembled from flash_attention calls.

    Each (N, D) slice (inner vmap over axis -2, outer over axis 0) is split
    into ``n = N // S`` blocks of S rows. Query block i combines up to three
    partial results via ``merge`` on their log-sum-exp scores:

    - its own block, with ``window=(S, 0)``;
    - block i-1, with no window argument (flash_attention's default);
    - block i-2, with ``window=(0, S)``.

    The i-1 / i-2 terms are skipped statically when there are not enough
    blocks, instead of running flash_attention on zero-size batches.

    NOTE(review): if flash_attention's window=(left, right) uses the same mask
    convention as naive_attention (keep -right <= q_idx - k_idx <= left), these
    terms cover key distances 0..2*S. That is 16384 for S=8192, while the name,
    the trace-time print and cudnn_window8192 use an 8192 window (S=4096 would
    match exactly). Confirm before trusting the "window8192" comparison.

    Returns:
        The merged attention output, reshaped back to (N, D).

    Raises:
        AssertionError: if N is not divisible by S.
    """
    print(f"Compiling JAX attention with window size (8192, 0)")
    S = 8192  # block size; other values tried: 32768, 4096, 2048
    N, D = queries.shape
    assert N % S == 0, f"Number of queries {N} must be divisible by window size {S}"
    n = N // S
    sm_scale = float(1.0 / math.sqrt(D))

    bqueries = queries.reshape(n, S, D)
    bkeys = keys.reshape(n, S, D)
    bvalues = values.reshape(n, S, D)

    # Diagonal blocks.
    bscores, bout = jax.vmap(partial(flash_attention, window=(S, 0), sm_scale=sm_scale))(bqueries, bkeys, bvalues)

    # Block i vs. block i-1. Alternatives tried here: naive_attention with
    # window=(S, S), _approx_dense_attention, and fma_attention.
    if n > 1:
        bscores1, bout1 = jax.vmap(partial(flash_attention, sm_scale=sm_scale))(bqueries[1:], bkeys[:-1], bvalues[:-1])
        # Merge using the pre-update scores, then accumulate them.
        bout = bout.at[1:].set(merge(bscores1, bscores[1:], bout1, bout[1:]))
        bscores = bscores.at[1:].set(jnp.logaddexp(bscores1, bscores[1:]))

    # Block i vs. block i-2.
    if n > 2:
        bscores2, bout2 = jax.vmap(partial(flash_attention, window=(0, S), sm_scale=sm_scale))(bqueries[2:], bkeys[:-2], bvalues[:-2])
        bout = bout.at[2:].set(merge(bscores2, bscores[2:], bout2, bout[2:]))
        # The final scores are not returned, so they are not accumulated here.

    return bout.reshape(N, D)


def cudnn_acausal_unscaled(queries, keys, values):
    """Full attention (no causal mask, no window) via cuDNN with scale=1.0.

    No 1/sqrt(d) factor is applied, so callers pre-scale the queries
    themselves. All inputs are cast to ``cudnn_dtype``, and the output is cast
    back to the incoming dtype of ``values``.
    """
    result_dtype = values.dtype
    q, k, v = (t.astype(cudnn_dtype) for t in (queries, keys, values))
    attended = jax.nn.dot_product_attention(
        q,
        k,
        v,
        is_causal=False,
        scale=1.0,
        local_window_size=None,
        implementation="cudnn",
    )
    return attended.astype(result_dtype)


def load_data(path="qkv_data_64_T1/qkv_step_15200.npz"):
    """Load cached key/value/query tensors from an ``.npz`` archive.

    Args:
        path: archive to read. Defaults to the step-15200 dump; another
            checkpoint used during experiments is
            "qkv_data_64_T1/qkv_step_0.npz".

    Returns:
        (ks, vs, qs): NumPy arrays read from the archive entries "k", "v" and
        "q", in that order. The layout is presumably (batch, seq, heads, dim):
        main() slices axis 1 and unpacks B, N, H, D.
    """
    # For .npz archives, jnp.load hands back numpy's NpzFile unchanged, and the
    # original code never closed it. Reading inside a with-block materializes
    # the arrays and then releases the underlying file handle.
    with np.load(path) as data:
        ks, vs, qs = data["k"], data["v"], data["q"]
    return ks, vs, qs  # BNHD

def main():
    """Print per-(batch, head) SNRs of the approximate attentions vs. cuDNN.

    1. Dense: ``_approx_dense_attention`` vs. full non-causal cuDNN attention.
       Both receive queries pre-scaled by 1/sqrt(D).
    2. Windowed: ``jax_window8192`` vs. ``cudnn_window8192`` on the raw queries.

    SNR is ||exact|| / ||exact - approx||, taken over all (sequence, dim)
    entries of each (batch, head) pair.
    """
    block_approx = _approx_dense_attention
    batch_multi_head_block_approx = jax.vmap(jax.vmap(block_approx, in_axes=-2, out_axes=-2))
    ks, vs, qs = load_data()
    N = 2**14
    input_dtype = jnp.bfloat16
    # Slice to the first N positions before casting, so only the part that is
    # actually used gets converted.
    ks = ks[:, :N, :, :].astype(input_dtype)
    vs = vs[:, :N, :, :].astype(input_dtype)
    qs = qs[:, :N, :, :].astype(input_dtype)
    print(qs.shape)
    B, N, H, D = qs.shape

    def per_head_norm(x):
        # (B, N, H, D) -> (B, H): 2-norm over all sequence and feature entries.
        return jnp.linalg.norm(x.transpose(0, 2, 1, 3).reshape(B, H, -1), axis=-1)

    scaled_qs = qs / jnp.sqrt(D)
    exact_out = cudnn_acausal_unscaled(scaled_qs, ks, vs)
    _approx_logmass, approx_out = batch_multi_head_block_approx(scaled_qs, ks, vs)
    exact_norm = per_head_norm(exact_out)
    error_norm = per_head_norm(exact_out - approx_out)
    print("Signal-to-Noise Ratio (SNR):", exact_norm / error_norm)

    window_exact_out = cudnn_window8192(qs, ks, vs)
    window_approx_out = jax_window8192(qs, ks, vs)
    exact_norm_window = per_head_norm(window_exact_out)
    error_norm_window = per_head_norm(window_exact_out - window_approx_out)
    print("Windowed Signal-to-Noise Ratio (SNR):", exact_norm_window / error_norm_window)


if __name__ == "__main__":
    main()



