import os
#os.environ['XLA_FLAGS'] = '--xla_dump_to=/tmp/xla_dump --xla_dump_hlo_as_text --xla_dump_hlo_pass_re=.*'

import jax
#jax.config.update("jax_log_compiles", True)
from jax import numpy as jnp
#from jax import profiler
import numpy as np
from .single_level_attention import SimpleAttention, CausalAttention, FastSimpleAttention, StubClustering, MultiLevelAttention, MultiLevelClustering
import time
from functools import partial
import nvtx

from flash_attention_jax import flash_attention as flash_attention_jax_lucidrains
#from flash_attn_jax import flash_mha as flash_attention_cuda_nshepperd
import math

@jax.jit
def flash_attention_pure_jax(queries, keys, values):
    """Run the lucidrains pure-JAX flash attention on unbatched 2-D inputs.

    Args:
        queries: 2-D array; its first axis length is used as the sequence
            length ``N``.
        keys: array passed through with two leading singleton axes added.
        values: 2-D array; its last axis length is the expected output width.

    Returns:
        ``(None, out)`` where ``out`` has shape ``(N, V)``. The ``None`` fills
        the first slot, which callers in this file unpack as a log-mass for
        the other attention callables.

    Raises:
        ValueError: if ``queries`` or ``values`` is not 2-D (tuple unpacking
            of ``.shape`` fails).
    """
    # Unpacking to exactly two names doubles as a rank check.
    num_queries, _ = queries.shape
    _, value_dim = values.shape

    # All-true mask: nothing is masked out.
    # NOTE(review): the (N, N) mask shape is taken as-is from the original;
    # confirm it matches what the library expects for its mask argument.
    all_visible = jnp.ones((num_queries, num_queries), dtype=jnp.bool_)

    # The library call is fed 4-D inputs: two singleton axes are prepended.
    q4, k4, v4 = (arr[None, None] for arr in (queries, keys, values))
    result = flash_attention_jax_lucidrains(q4, k4, v4, all_visible)

    expected_shape = (1, 1, num_queries, value_dim)
    assert result.shape == expected_shape, (
        f"Expected output shape (1, 1, {num_queries}, {value_dim}), got {result.shape}"
    )
    return None, result[0, 0]


@jax.jit
def flash_attention_cudnn(queries, keys, values):
    """Run ``jax.nn.dot_product_attention`` with the cuDNN implementation.

    Inputs are cast to bfloat16 for the kernel call and the result is cast
    back to ``queries.dtype``. No extra scaling is applied (``scale=1.0``)
    and the attention is non-causal with no local window.

    Args:
        queries: 2-D array.
        keys: array reshaped the same way as ``queries``.
        values: 2-D array.

    Returns:
        ``(None, out)`` where ``out`` is 2-D with dtype ``queries.dtype``.
        The ``None`` fills the first slot, which callers in this file unpack
        as a log-mass for the other attention callables.

    Raises:
        ValueError: if ``queries`` or ``values`` is not 2-D (tuple unpacking
            of ``.shape`` fails).
    """
    # Unpacking to exactly two names doubles as a rank check.
    _num_queries, _qk_dim = queries.shape
    _num_values, _value_dim = values.shape

    def as_bf16_4d(arr):
        # Insert singleton axes at positions 0 and 2, then cast.
        return arr[None, :, None, :].astype(jnp.bfloat16)

    attended = jax.nn.dot_product_attention(
        as_bf16_4d(queries),
        as_bf16_4d(keys),
        as_bf16_4d(values),
        scale=1.0,
        is_causal=False,
        local_window_size=None,
        implementation='cudnn',
    )
    restored = attended.astype(queries.dtype)
    # Drop the two singleton axes that were inserted above.
    return None, restored[0, :, 0, :]



def get_algos(Qs=(16,), Ks=(16,)):
    """Build the attention callables exercised by this benchmark.

    Args:
        Qs: forwarded to ``MultiLevelClustering`` as ``Qs``.
        Ks: forwarded to ``MultiLevelClustering`` as ``Ks``; ``Ks[0]`` is
            also used as ``K`` for ``FastSimpleAttention``.

    Returns:
        A 5-tuple ``(exact, approx, cluster, attend, flash)``:

        * ``exact``   -- ``FastSimpleAttention.exact_attention``
        * ``approx``  -- ``MultiLevelAttention.cluster_then_attention``
        * ``cluster`` -- ``do_clustering`` of the multi-level attention's
          clustering object
        * ``attend``  -- ``MultiLevelAttention.attend``
        * ``flash``   -- ``FastSimpleAttention.flash_attention``
    """
    simple_attention = FastSimpleAttention(
        K=Ks[0],
        dipole=True,
        importance=False,
        cluster_importance=False,
        logmass_dipole=False,
        mean_query_importance=False,
    )
    clustering = MultiLevelClustering(
        Qs=Qs,
        Ks=Ks,
        coupled_clustering=False,
        inner_iters=3,
        outer_iters=3,
        max_cluster_scale=1.5,
        compute_indices=True,
    )
    multi_attention = MultiLevelAttention(
        clustering=clustering,
    )

    # Only the selections that were actually returned are kept. The
    # CausalAttention-based alternatives and the vmapped (GQA / multi-head /
    # batched) wrappers were assigned and then immediately overwritten or
    # never used, so they are omitted rather than left as misleading dead
    # stores.
    exact = simple_attention.exact_attention
    approx = multi_attention.cluster_then_attention
    cluster = multi_attention.clustering.do_clustering
    attend = multi_attention.attend
    flash = simple_attention.flash_attention
    return exact, approx, cluster, attend, flash


def get_data(index=0, d=256, n=1024,
             path="examples/minigpt/qkv_data_64_T1/qkv_step_15200.npz"):
    """Load one key/value/query slice from an ``.npz`` dump and normalise it.

    The archive must contain arrays under the names ``"k"``, ``"v"`` and
    ``"q"``. Each is indexed as ``arr[0, :, index, :]``, so the arrays are
    expected to be 4-D.
    NOTE(review): axis 2 is presumably the attention-head axis -- confirm
    against the code that writes the dump.

    Processing steps, in order:

    1. Keys and queries are each divided by ``last_dim ** 0.25`` so that their
       dot product carries a ``1 / sqrt(last_dim)`` factor.
    2. The sequence is doubled by self-concatenation until it has at least
       ``n`` rows.
    3. Keys and values keep their first ``n`` rows, queries their last ``n``
       rows; all keep their first ``d`` columns and are cast to float32.
    4. Each array is centred by subtracting its per-column mean.
    5. Value rows whose norm exceeds the 95th-percentile norm are scaled down
       to that norm.

    Args:
        index: position along axis 2 of the stored arrays to extract.
        d: number of leading feature columns to keep.
        n: number of rows to return.
        path: location of the ``.npz`` archive.

    Returns:
        ``(ks, vs, qs)`` -- three float32 JAX arrays.

    Raises:
        ValueError: if the stored sequence axis is empty, since it could
            never be extended to ``n`` rows.
    """
    # NpzFile keeps the archive open; read the members inside the context so
    # the handle is released deterministically.
    with np.load(path) as archive:
        ks, vs, qs = archive["k"], archive["v"], archive["q"]

    # Fourth root on both operands gives an overall 1/sqrt(dim) on q . k.
    quarter_root = jnp.sqrt(jnp.sqrt(ks.shape[-1]))
    ks = ks / quarter_root
    qs = qs / quarter_root
    ks, vs, qs = ks[0, :, index, :], vs[0, :, index, :], qs[0, :, index, :]

    if n > 0 and ks.shape[-2] == 0:
        # Doubling an empty sequence never grows it; fail instead of spinning.
        raise ValueError(f"{path!r} holds an empty sequence; cannot extend it to n={n}")
    while n > ks.shape[-2]:
        # Extend ks, vs and qs by appending a copy of themselves.
        ks = np.concatenate([ks, ks], axis=-2)
        vs = np.concatenate([vs, vs], axis=-2)
        qs = np.concatenate([qs, qs], axis=-2)

    dtype = jnp.float32
    ks = jnp.asarray(ks[:n, :d], dtype=dtype)
    vs = jnp.asarray(vs[:n, :d], dtype=dtype)
    # Queries are taken from the END of the sequence, unlike keys and values.
    qs = jnp.asarray(qs[-n:, :d], dtype=dtype)

    def center(arr):
        return arr - jnp.mean(arr, axis=0, keepdims=True)

    ks = center(ks)
    vs = center(vs)
    qs = center(qs)

    # Experiment knob: constant offset added to every query. The magnitude is
    # currently zero, so this is a no-op.
    perturb_vec = jax.random.normal(jax.random.PRNGKey(0), shape=(d,), dtype=dtype)
    perturb_mag = 0e2
    perturb_vec = perturb_vec / jnp.linalg.norm(perturb_vec) * perturb_mag
    qs = qs + perturb_vec

    # Rescale the values to have a maximum magnitude of quant95.
    vmags = jnp.linalg.norm(vs, axis=-1)
    quant95 = jnp.quantile(vmags, 0.95)
    rescale = quant95 / jnp.maximum(quant95, vmags)
    vs = vs * rescale[:, None]

    # Experiment knob: uniform query scaling. Currently 1, so a no-op.
    upscale_qs_factor = 1e0
    qs = qs * upscale_qs_factor

    return ks, vs, qs

def analyze_clusters(qlabels, klabels):
    """Print size statistics for the query and key cluster assignments.

    For each label array, the number of members of every cluster id from 0 up
    to the largest label present is counted. The mean size, the min/max and
    the 10/25/50/75/90 % quantiles of those counts are printed.

    Args:
        qlabels: integer cluster id per query.
        klabels: integer cluster id per key.
    """
    def sizes_per_cluster(labels):
        # One count per id in [0, max label]; ids with no members count as 0.
        cluster_ids = jnp.arange(jnp.max(labels) + 1)
        return jax.vmap(lambda cluster_id: jnp.sum(labels == cluster_id))(cluster_ids)

    # Compute both size vectors before printing anything.
    sizes_by_role = (
        ("Query", sizes_per_cluster(qlabels)),
        ("Key", sizes_per_cluster(klabels)),
    )
    quantiles = np.array([0.1, 0.25, 0.5, 0.75, 0.9])
    for role, sizes in sizes_by_role:
        print(f"{role} clusters avg {np.mean(sizes):.0f}")
        print(f"Quantiles: {np.min(sizes)} {np.quantile(sizes, quantiles)} {np.max(sizes)}")

def report_accuracy(ex_lmass, ex_values, ap_lmass, ap_values):
    """Print how well an approximate attention result tracks a reference.

    Two numbers are reported, both computed after subtracting the mean along
    axis 0:

    * the Pearson correlation of the flattened log-masses
      (via ``np.corrcoef``);
    * a normalised cross-moment of the value matrices,
      ``mean(a * b) / sqrt(mean(a * a) * mean(b * b))``.

    Args:
        ex_lmass: reference log-mass array.
        ex_values: reference output values, 2-D.
        ap_lmass: approximate log-mass array, same shape as ``ex_lmass``.
        ap_values: approximate output values, same shape as ``ex_values``.

    Raises:
        AssertionError: on shape mismatch, non-2-D values, or when the centred
            values contain NaN/Inf.
    """
    assert ex_lmass.shape == ap_lmass.shape, f"Expected log-mass shapes to match: {ex_lmass.shape} vs {ap_lmass.shape}"
    assert ex_values.shape == ap_values.shape, f"Expected values shapes to match: {ex_values.shape} vs {ap_values.shape}"
    assert ap_values.ndim == 2

    def remove_mean(x):
        return x - jnp.mean(x, axis=0, keepdims=True)

    def normalised_cross_moment(a, b):
        rows_a, cols_a = a.shape
        rows_b, cols_b = b.shape
        assert rows_a == rows_b and cols_a == cols_b, f"Shapes must match: {rows_a}x{cols_a} vs {rows_b}x{cols_b}"
        a0, b0 = remove_mean(a), remove_mean(b)
        self_a = jnp.mean(a0 * a0)
        cross = jnp.mean(a0 * b0)
        self_b = jnp.mean(b0 * b0)
        # The tiny epsilon keeps the ratio defined when either input is constant.
        return np.array(cross / (jnp.sqrt(self_a * self_b) + 1e-30))

    # Log-mass: ordinary Pearson correlation over the flattened arrays.
    flat_exact = remove_mean(ex_lmass).flatten()
    flat_approx = remove_mean(ap_lmass).flatten()
    lm_corr = np.corrcoef(flat_exact, flat_approx)[0, 1]

    assert jnp.all(jnp.isfinite(remove_mean(ap_values))), "Approx values contain NaNs or Infs"
    assert jnp.all(jnp.isfinite(remove_mean(ex_values))), "Exact values contain NaNs or Infs"
    my_val_corr = normalised_cross_moment(ex_values, ap_values)

    print(
        "--- Accuracy Report ---",
        f"Log-mass correlation: {lm_corr:.3f}",
        f"Values correlation: {float(my_val_corr):.3f}",
        "--- End of Report ---",
        sep="\n",
    )


def _time_calls(label, fn, reps, warmup=0):
    """Call ``fn()`` repeatedly and print its mean wall-clock time.

    Every call is wrapped in ``jax.block_until_ready`` so that asynchronous
    dispatch does not hide the real execution time. ``time.perf_counter`` is
    used because it is monotonic and has the highest available resolution.

    Args:
        label: text printed in front of ``" took ... seconds"``.
        fn: zero-argument callable to benchmark.
        reps: number of timed calls (must be >= 1).
        warmup: number of untimed calls made first.

    Returns:
        The result of the last call.
    """
    result = None
    for _ in range(warmup):
        result = jax.block_until_ready(fn())
    start = time.perf_counter()
    for _ in range(reps):
        result = jax.block_until_ready(fn())
    elapsed = time.perf_counter() - start
    print(f"{label} took {elapsed / reps:.6f} seconds")
    return result


def _unpack_clustering(result):
    """Split one side of the clustering output into labels and index tables.

    Returns ``(labels, fwd_indices, bwd_indices)``. When ``result`` is a bare
    label array rather than a 4-tuple, both index entries are ``None``.
    """
    if isinstance(result, tuple):
        _counts, labels, fwd_indices, bwd_indices = result
        return labels, fwd_indices, bwd_indices
    return result, None, None


def analyze_live(N=8192, D=32, reps=1000, warmup_reps=100, **algo_kwargs):
    """Benchmark the attention variants on real data and print accuracy reports.

    Loads data with ``get_data``, builds the algorithms with ``get_algos``,
    ahead-of-time compiles each forward function for the loaded shapes, then
    times clustering, approximate, flash, single-level, cuDNN, lucidrains and
    exact attention (plus vjp / backward / grad for several of them). Accuracy
    comparisons are printed through ``report_accuracy``.

    Args:
        N: number of queries/keys to load.
        D: feature dimension to load.
        reps: timed iterations per benchmark. The flash and lucidrains
            benchmarks use at most 10.
        warmup_reps: untimed iterations run before the clustering pass and
            before each vjp / backward / grad benchmark.
        **algo_kwargs: forwarded to ``get_algos`` (e.g. ``Qs``, ``Ks``).

    Raises:
        ValueError: if ``reps`` is smaller than 1.
    """
    if reps < 1:
        raise ValueError(f"reps must be at least 1, got {reps}")
    slow_reps = min(reps, 10)

    # Target directory for the (currently disabled) jax.profiler trace.
    profile_dir = "/concept/jax_profiles"
    os.makedirs(profile_dir, exist_ok=True)

    print("Loading data...")
    keys, values, queries = get_data(index=0, d=D, n=N)
    print("Data loaded.")
    print("getting algorithms...")
    exact, approx, cluster, attend, flash = get_algos(**algo_kwargs)
    inputs = (queries, keys, values)

    # Fixed random cotangent so that every grad benchmark differentiates the
    # same scalar functional of the attention output.
    grad_probe = jax.random.normal(jax.random.PRNGKey(0), shape=values.shape, dtype=values.dtype)

    def make_vjp(attention):
        # Jitted forward pass that also returns the vjp closure.
        @jax.jit
        def forward_and_vjp(queries, keys, values):
            return jax.vjp(lambda q, k, v: attention(q, k, v)[1], queries, keys, values)
        return forward_and_vjp

    def make_grad(attention):
        # Jitted gradient (w.r.t. the queries, jax.grad's default argnums=0)
        # of sum(output * grad_probe).
        def loss(queries, keys, values):
            _, out = attention(queries, keys, values)
            return jnp.sum(out * grad_probe)
        return jax.jit(jax.grad(loss))

    cudnn_vjp = make_vjp(flash_attention_cudnn)
    approx_vjp = make_vjp(approx)
    cudnn_grad = make_grad(flash_attention_cudnn)
    approx_grad = make_grad(approx)
    exact_grad = make_grad(exact)

    def compile_for_inputs(fn):
        return jax.jit(fn).lower(*inputs).compile()

    print("Compiling approx...")
    approx_comp = compile_for_inputs(approx)
    print("Compiling exact...")
    exact_comp = compile_for_inputs(exact)
    print("Compiling clustering...")
    cluster_comp = compile_for_inputs(cluster)
    print("Compiling single level attention...")
    # Placeholder arguments: only their shapes and dtypes matter for lowering.
    # NOTE(review): the 1.5 factor presumably mirrors max_cluster_scale in
    # get_algos, and the query-side table reuses Ks[0] rather than Qs[0] --
    # confirm both if Qs and Ks ever differ.
    K = algo_kwargs.get("Ks", (16,))[0]
    cluster_capacity = math.ceil(N / K * 1.5)
    dummy_fwd_indices = jnp.zeros((K, cluster_capacity), dtype=jnp.int32)
    dummy_bwd_indices = jnp.zeros((N,), dtype=jnp.int32)
    dummy_labels = jnp.zeros((N,), dtype=jnp.int32)
    attend_comp = jax.jit(attend).lower(
        queries, dummy_labels, keys, dummy_labels, values,
        qfwd_indices=dummy_fwd_indices, qbwd_indices=dummy_bwd_indices,
        kfwd_indices=dummy_fwd_indices, kbwd_indices=dummy_bwd_indices,
    ).compile()
    print("Compiling flash attention...")
    flash_comp = compile_for_inputs(flash)
    print("Compiling lucidrains flash attention...")
    lucidrains_flash_comp = compile_for_inputs(flash_attention_pure_jax)
    print("Compiling cudnn flash attention...")
    flash_cudnn_comp = compile_for_inputs(flash_attention_cudnn)

    print("Warming up...")
    for _ in range(warmup_reps):
        jax.block_until_ready(cluster_comp(*inputs))

    # ---- first pass: clustering, approximate, flash -------------------------
    with nvtx.annotate("clustering"):
        print("Timing clustering...")
        q_clusters, k_clusters = _time_calls("Clustering", partial(cluster_comp, *inputs), reps)
    qlabels, _, _ = _unpack_clustering(q_clusters)
    klabels, _, _ = _unpack_clustering(k_clusters)
    # Sanity checks run once, outside the timed loop, so they do not distort
    # the measurement.
    assert jnp.all(jnp.isfinite(qlabels)), "Cluster labels contain NaNs or Infs"
    assert jnp.all(jnp.isfinite(klabels)), "Cluster labels contain NaNs or Infs"
    analyze_clusters(qlabels, klabels)

    print("Timing approximate attention...")
    with nvtx.annotate("cluster_then_attention"):
        ap_lmass, ap_values = _time_calls("Approx attention", partial(approx_comp, *inputs), reps)

    print("Timing flash attention...")
    fl_lmass, fl_values = _time_calls("Flash attention", partial(flash_comp, *inputs), slow_reps)
    assert jnp.all(jnp.isfinite(fl_lmass)), "Flash log-mass contains NaNs or Infs"
    assert jnp.all(jnp.isfinite(fl_values)), "Flash values contain NaNs or Infs"

    # ---- second pass: clustering and approximate attention are re-timed -----
    with nvtx.annotate("clustering"):
        print("Timing clustering...")
        q_clusters, k_clusters = _time_calls("Clustering", partial(cluster_comp, *inputs), reps)
    qlabels, qfwd_indices, qbwd_indices = _unpack_clustering(q_clusters)
    klabels, kfwd_indices, kbwd_indices = _unpack_clustering(k_clusters)

    print("Timing approximate attention...")
    with nvtx.annotate("cluster_then_attention"):
        ap_lmass, ap_values = _time_calls("Approx attention", partial(approx_comp, *inputs), reps)

    with nvtx.annotate("single_level_attention"):
        print("Timing single level attention...")
        if qfwd_indices is None or kfwd_indices is None:
            # The compiled executable needs the index tables; without them
            # this benchmark cannot run.
            print("Single level attention skipped: clustering returned no index tables")
        else:
            _time_calls(
                "Single level attention",
                partial(
                    attend_comp, queries, qlabels, keys, klabels, values,
                    qfwd_indices=qfwd_indices, qbwd_indices=qbwd_indices,
                    kfwd_indices=kfwd_indices, kbwd_indices=kbwd_indices,
                ),
                reps,
            )

    # ---- cuDNN flash attention ----------------------------------------------
    print("Timing cudnn flash attention...")
    _, cudnn_values = _time_calls("Cudnn flash attention", partial(flash_cudnn_comp, *inputs), reps)

    print("Approx vs cudnn flash attention accuracy:")
    # The cuDNN path returns no log-mass, so the flash log-mass stands in.
    report_accuracy(fl_lmass, cudnn_values, ap_lmass, ap_values)

    print("Timing cudnn flash attention vjp...")
    # Rebinds cudnn_values to the forward output of the vjp call; the later
    # accuracy reports use this value.
    cudnn_values, cudnn_vjp_fn = _time_calls(
        "Cudnn flash attention vjp", partial(cudnn_vjp, *inputs), reps, warmup=warmup_reps)

    print("Timing cudnn backward...")
    cudnn_vjp_fn = jax.jit(cudnn_vjp_fn)
    # Cotangent is built once so its allocation is not part of the timing.
    cudnn_cotangent = jnp.zeros_like(cudnn_values)
    _time_calls(
        "Cudnn flash attention backward", partial(cudnn_vjp_fn, cudnn_cotangent), reps, warmup=warmup_reps)

    print("Timing cudnn grad...")
    _time_calls("Cudnn flash attention grad", partial(cudnn_grad, *inputs), reps, warmup=warmup_reps)

    # ---- approximate attention derivatives ----------------------------------
    print("Timing approximate attention vjp...")
    # Rebinds ap_values to the forward output of the vjp call; the later
    # accuracy reports use this value.
    ap_values, ap_vjp_fn = _time_calls(
        "Approx attention vjp", partial(approx_vjp, *inputs), reps, warmup=warmup_reps)

    print("Timing approximate backward...")
    ap_vjp_fn = jax.jit(ap_vjp_fn)
    ap_cotangent = jnp.zeros_like(ap_values)
    _time_calls("Approx attention backward", partial(ap_vjp_fn, ap_cotangent), reps, warmup=warmup_reps)

    print("Timing approximate grad...")
    for _ in range(warmup_reps):
        jax.block_until_ready(approx_grad(*inputs))
    # Warm-up stays outside the NVTX range so the range covers timed work only.
    with nvtx.annotate("approximate_grad"):
        _time_calls("Approx attention grad", partial(approx_grad, *inputs), reps)

    # ---- lucidrains flash attention -----------------------------------------
    print("Timing lucidrains flash attention...")
    _time_calls("Lucidrains flash attention", partial(lucidrains_flash_comp, *inputs), slow_reps)

    analyze_clusters(qlabels, klabels)
    print("Approx vs Flash attention accuracy:")
    report_accuracy(fl_lmass, fl_values, ap_lmass, ap_values)
    print("Approx vs cudnn flash attention accuracy:")
    report_accuracy(fl_lmass, cudnn_values, ap_lmass, ap_values)

    # ---- exact attention ----------------------------------------------------
    print("Timing exact attention...")
    ex_lmass, ex_values = _time_calls("Exact attention", partial(exact_comp, *inputs), reps)

    print("Timing exact attention grad...")
    _time_calls("Exact attention grad", partial(exact_grad, *inputs), reps, warmup=warmup_reps)

    print("Flash vs Exact attention accuracy:")
    report_accuracy(ex_lmass, ex_values, fl_lmass, fl_values)
    print("Cudnn Flash vs Exact attention accuracy:")
    # No cuDNN log-mass exists, so the exact log-mass is compared with itself.
    report_accuracy(ex_lmass, ex_values, ex_lmass, cudnn_values)
    print("Approx vs Exact attention accuracy:")
    report_accuracy(ex_lmass, ex_values, ap_lmass, ap_values)



def main():
    """Script entry point: run the live benchmark with fixed settings.

    Runs ``analyze_live`` with 2**13 tokens, 2**6 features and 2**6 query/key
    clusters, then returns normally.
    """
    print("Starting jax stuff...")
    # Create and materialise a tiny array first, presumably so that backend
    # initialisation happens before any benchmark output.
    x = jnp.array([1, 2, 3])
    jax.block_until_ready(x)
    print("Running analyze_live")
    analyze_live(
        N=2**13,
        D=2**6,
        Qs=(2**6,),
        Ks=(2**6,),
    )
    # The function now simply returns. The previous `exit()` relied on the
    # site-injected builtin, and the `analyze_algos(...)` call after it was
    # unreachable and referred to a name not defined in this file.


if __name__ == "__main__":
    main()
