import os
#os.environ['XLA_FLAGS'] = '--xla_dump_to=/tmp/xla_dump --xla_dump_hlo_as_text --xla_dump_hlo_pass_re=.*'

import jax
#jax.config.update("jax_log_compiles", True)
from jax import numpy as jnp
#from jax import profiler
import numpy as np
from .single_level_attention import SimpleAttention, CausalAttention, FastSimpleAttention, StubClustering, MultiLevelAttention, MultiLevelClustering
import time
from functools import partial
import nvtx

def get_algos(Qs=(16,), Ks=(16,)):
    """Build the attention objects and return the callables used by the analyses.

    Args:
        Qs: Passed to ``MultiLevelClustering`` as ``Qs``.
        Ks: Passed to ``MultiLevelClustering`` as ``Ks``; its first entry is
            also passed to ``FastSimpleAttention`` as ``K``, so it must be
            non-empty.

    Returns:
        The 5-tuple ``(exact, approx, cluster, attend, flash)`` where

        * ``exact``   is ``FastSimpleAttention.exact_attention``,
        * ``approx``  is ``MultiLevelAttention.cluster_then_attention``,
        * ``cluster`` is ``MultiLevelClustering.do_clustering``,
        * ``attend``  is ``MultiLevelAttention.attend``,
        * ``flash``   is ``FastSimpleAttention.flash_attention``,

        each looked up on the instance constructed here.
    """
    simple_attention = FastSimpleAttention(
        K=Ks[0],
        dipole=True,
        importance=False,
        cluster_importance=False,
        logmass_dipole=False,
        mean_query_importance=False,
    )
    clustering = MultiLevelClustering(
        Qs=Qs,
        Ks=Ks,
        coupled_clustering=False,
        inner_iters=3,
        outer_iters=3,
        max_cluster_scale=1.5,
    )
    multi_attention = MultiLevelAttention(
        clustering=clustering,
    )

    exact = simple_attention.exact_attention
    approx = multi_attention.cluster_then_attention
    cluster = multi_attention.clustering.do_clustering
    attend = multi_attention.attend
    flash = simple_attention.flash_attention
    return exact, approx, cluster, attend, flash


def _print_compiled_analysis(label, fn, *arg_specs):
    """Lower and compile ``fn`` for ``arg_specs`` and print its analyses.

    Prints the compiled executable's ``cost_analysis()`` and
    ``memory_analysis()`` results, each prefixed with ``label``.
    """
    compiled = jax.jit(fn).lower(*arg_specs).compile()
    print(f"{label} cost_analysis:\n{compiled.cost_analysis()}")
    print(f"{label} memory_analysis:\n{compiled.memory_analysis()}")


def analyze_algos(N: int, D: int, B: int = 1, H: int = 4, G: int = 2, *,
                  batched: bool = False, **algo_args):
    """Compile the attention callables ahead of time and print their analyses.

    Only shape/dtype specs are used, so no input data is loaded.

    Args:
        N: Number of rows in the query/key/value specs and in the label specs.
        D: Feature dimension of the query/key/value specs.
        B: Leading (batch) dimension of the batched specs.
        H: Third dimension of the batched specs (heads).
        G: Fourth dimension of the batched query spec (query groups).
        batched: When true, additionally analyse ``approx`` vmapped over the
            group, head and batch axes. Off by default.
        **algo_args: Forwarded to ``get_algos``.
    """
    # get_algos returns five callables; the exact and flash variants are not
    # analysed here.
    _exact, approx, cluster, attend, _flash = get_algos(**algo_args)

    q_spec = jax.ShapeDtypeStruct(shape=(N, D), dtype=jnp.float32)
    k_spec = jax.ShapeDtypeStruct(shape=(N, D), dtype=jnp.float32)
    v_spec = jax.ShapeDtypeStruct(shape=(N, D), dtype=jnp.float32)
    label_spec = jax.ShapeDtypeStruct(shape=(N,), dtype=jnp.int32)

    _print_compiled_analysis("My attention", approx, q_spec, k_spec, v_spec)
    _print_compiled_analysis("My cluster", cluster, q_spec, k_spec, v_spec)
    _print_compiled_analysis(
        "My attend", attend, q_spec, label_spec, k_spec, label_spec, v_spec
    )

    if not batched:
        return

    # Same vmap stack as the ``vectorize`` helper inside get_algos.
    approx_v = jax.vmap(approx, in_axes=(1, None, None), out_axes=1)  # gqa
    approx_v = jax.vmap(approx_v, in_axes=(1, 1, 1), out_axes=1)  # multi-head
    approx_v = jax.vmap(approx_v, in_axes=(0, 0, 0), out_axes=0)  # batched

    q_spec_v = jax.ShapeDtypeStruct(shape=(B, N, H, G, D), dtype=jnp.float32)
    k_spec_v = jax.ShapeDtypeStruct(shape=(B, N, H, D), dtype=jnp.float32)
    v_spec_v = jax.ShapeDtypeStruct(shape=(B, N, H, D), dtype=jnp.float32)
    _print_compiled_analysis(
        "Batched attention", approx_v, q_spec_v, k_spec_v, v_spec_v
    )

def get_data(index=0, d=256, n=1024, path="fma/data/data.npz"):
    """Load one sample of keys/values/queries and preprocess it.

    Args:
        index: Index along the first axis of the stored arrays.
        d: Number of leading feature columns to keep.
        n: Number of rows to return. If the sample has fewer rows, the arrays
            are repeatedly doubled (concatenated with themselves) first.
        path: ``.npz`` archive holding the ``"keys"``, ``"values"`` and
            ``"queries"`` arrays.

    Returns:
        ``(ks, vs, qs)`` as float32 JAX arrays. Keys and values are the first
        ``n`` rows, centred per column; value rows whose norm exceeds the 95th
        percentile of norms are scaled down to that norm. Queries are the
        last ``n`` rows, not centred.

    Raises:
        ValueError: If ``n`` is positive but the selected sample has no rows
            (doubling could never reach ``n``).
    """
    # np.load on an .npz returns an NpzFile that keeps the archive open until
    # it is closed, so read everything needed inside the with-block.
    with np.load(path) as data:
        ks = data["keys"][index]
        vs = data["values"][index]
        qs = data["queries"][index]

    if n > 0 and ks.shape[-2] == 0:
        raise ValueError(f"Cannot extend an empty sample to {n} rows")
    while n > ks.shape[-2]:
        # Extend ks, vs and qs in lockstep by appending a copy of themselves.
        ks = np.concatenate([ks, ks], axis=-2)
        vs = np.concatenate([vs, vs], axis=-2)
        qs = np.concatenate([qs, qs], axis=-2)

    dtype = jnp.float32
    ks = jnp.asarray(ks[:n, :d], dtype=dtype)
    vs = jnp.asarray(vs[:n, :d], dtype=dtype)
    qs = jnp.asarray(qs[-n:, :d], dtype=dtype)

    def center(arr):
        return arr - jnp.mean(arr, axis=0, keepdims=True)

    ks = center(ks)
    vs = center(vs)

    # Optional constant offset added to every query along a fixed random
    # direction. The magnitude 0e2 evaluates to 0.0, so this is currently a
    # no-op.
    perturb_vec = jax.random.normal(jax.random.PRNGKey(0), shape=(d,), dtype=dtype)
    perturb_mag = 0e2
    perturb_vec = perturb_vec / jnp.linalg.norm(perturb_vec) * perturb_mag
    qs = qs + perturb_vec

    # Rescale the values to have a maximum magnitude of quant95: rows at or
    # below the 95th-percentile norm get a factor of 1, longer rows shrink.
    vmags = jnp.linalg.norm(vs, axis=-1)
    quant95 = jnp.quantile(vmags, 0.95)
    rescale = quant95 / jnp.maximum(quant95, vmags)
    vs = vs * rescale[:, None]

    # Global query scale; 1e0 leaves the queries unchanged.
    upscale_qs_factor = 1e0
    qs = qs * upscale_qs_factor

    return ks, vs, qs

def analyze_clusters(qlabels, klabels):
    """Print cluster-size statistics for the query and key label arrays.

    For each label array, counts how many entries carry each id in
    ``0..max(labels)`` and prints the mean size followed by the minimum, the
    10/25/50/75/90 % quantiles and the maximum.
    """
    def count_members(labels):
        # One count per candidate id, from 0 up to the largest label present.
        candidate_ids = jnp.arange(jnp.max(labels) + 1)
        return jax.vmap(lambda cid: jnp.sum(labels == cid))(candidate_ids)

    # Compute both size vectors before printing anything.
    q_sizes = count_members(qlabels)
    k_sizes = count_members(klabels)

    quantiles = np.array([0.1, 0.25, 0.5, 0.75, 0.9])
    for title, sizes in (("Query", q_sizes), ("Key", k_sizes)):
        print(f"{title} clusters avg {np.mean(sizes):.0f}")
        print(f"Quantiles: {np.min(sizes)} {np.quantile(sizes, quantiles)} {np.max(sizes)}")

def report_accuracy(ex_lmass, ex_values, ap_lmass, ap_values):
    """Print how well an approximation tracks a reference attention result.

    Two correlations are reported, both after subtracting the mean along
    axis 0: the Pearson correlation (``np.corrcoef``) of the flattened
    log-mass arrays, and a single pooled correlation over all entries of the
    two 2-D value arrays.

    Args:
        ex_lmass: Reference log-mass array.
        ex_values: Reference values, 2-D.
        ap_lmass: Approximate log-mass array, same shape as ``ex_lmass``.
        ap_values: Approximate values, same shape as ``ex_values``.
    """
    assert ex_lmass.shape == ap_lmass.shape, f"Expected log-mass shapes to match: {ex_lmass.shape} vs {ap_lmass.shape}"
    assert ex_values.shape == ap_values.shape, f"Expected values shapes to match: {ex_values.shape} vs {ap_values.shape}"
    assert ap_values.ndim == 2

    def demean(x):
        return x - jnp.mean(x, axis=0, keepdims=True)

    # Log-mass: ordinary Pearson coefficient on the flattened, centred arrays.
    lm_pair = np.corrcoef(demean(ex_lmass).flatten(), demean(ap_lmass).flatten())
    lm_corr = lm_pair[0, 1]

    ap_centered = demean(ap_values)
    assert jnp.all(jnp.isfinite(ap_centered)), "Approx values contain NaNs or Infs"
    ex_centered = demean(ex_values)
    assert jnp.all(jnp.isfinite(ex_centered)), "Exact values contain NaNs or Infs"

    # Values: pooled second moments over every (row, column) entry.
    rows, cols = ex_values.shape
    rows2, cols2 = ap_values.shape
    assert rows == rows2 and cols == cols2, f"Shapes must match: {rows}x{cols} vs {rows2}x{cols2}"
    var_ex = jnp.mean(ex_centered * ex_centered)
    cov = jnp.mean(ex_centered * ap_centered)
    var_ap = jnp.mean(ap_centered * ap_centered)
    # The tiny constant keeps the division defined when either side is constant.
    my_val_corr = np.array(cov / (jnp.sqrt(var_ex * var_ap) + 1e-30))

    print("--- Accuracy Report ---")
    print(f"Log-mass correlation: {lm_corr:.3f}")
    print(f"Values correlation: {my_val_corr:.3f}")
    print("--- End of Report ---")


def _compile_blocking(fn, *example_args):
    """Compile ``fn`` ahead of time for ``example_args``.

    The returned callable runs the compiled executable and waits for its
    outputs with ``jax.block_until_ready``, so wall-clock timings around it
    include the device work.
    """
    compiled = jax.jit(fn).lower(*example_args).compile()

    def run(*args, **kwargs):
        return jax.block_until_ready(compiled(*args, **kwargs))

    return run


def _time_repeated(fn, args, reps):
    """Call ``fn(*args)`` ``reps`` times; return ``(last_result, seconds_per_call)``."""
    start = time.perf_counter()
    for _ in range(reps):
        result = fn(*args)
    return result, (time.perf_counter() - start) / reps


def analyze_live(N=8192, D=32, *, reps=100, **algo_kwargs):
    """Time the attention variants on loaded data and print accuracy reports.

    Loads data with ``get_data``, compiles every callable from ``get_algos``
    ahead of time, times clustering, approximate, flash, single-level and
    exact attention, then prints cluster statistics and pairwise accuracy.

    Args:
        N: Number of rows requested from ``get_data``; also the length of the
            placeholder label arrays used to compile ``attend``.
        D: Feature dimension requested from ``get_data``.
        reps: Number of timed calls per measurement; must be at least 1.
        **algo_kwargs: Forwarded to ``get_algos``.

    Raises:
        ValueError: If ``reps`` is smaller than 1.
    """
    if reps < 1:
        raise ValueError(f"reps must be at least 1, got {reps}")

    # Output directory for JAX profiler traces; created up front although the
    # profiler trace itself is currently not enabled in this function.
    profile_dir = "/concept/jax_profiles"
    os.makedirs(profile_dir, exist_ok=True)

    print("Loading data...")
    keys, values, queries = get_data(index=0, d=D, n=N)
    print("Data loaded.")
    print("getting algorithms...")
    exact, approx, cluster, attend, flash = get_algos(**algo_kwargs)

    qkv = (queries, keys, values)
    print("Compiling approx...")
    approx_comp = _compile_blocking(approx, *qkv)
    print("Compiling exact...")
    exact_comp = _compile_blocking(exact, *qkv)
    print("Compiling clustering...")
    cluster_comp = _compile_blocking(cluster, *qkv)
    print("Compiling single level attention...")
    # Only shape and dtype matter for lowering, so all-zero labels suffice.
    placeholder_labels = jnp.zeros((N,), dtype=jnp.int32)
    attend_comp = _compile_blocking(
        attend, queries, placeholder_labels, keys, placeholder_labels, values
    )
    print("Compiling flash attention...")
    flash_comp = _compile_blocking(flash, *qkv)

    print("Warming up...")
    for _ in range(10):
        cluster_comp(*qkv)

    with nvtx.annotate("clustering"):
        print("Timing clustering...")
        (qlabels, klabels), per_call = _time_repeated(cluster_comp, qkv, reps)
        print(f"Clustering took {per_call:.4f} seconds")

    print("Timing approximate attention...")
    with nvtx.annotate("cluster_then_attention"):
        (ap_lmass, ap_values), per_call = _time_repeated(approx_comp, qkv, reps)
        print(f"Approx attention took {per_call:.4f} seconds")

    print("Timing flash attention...")
    (fl_lmass, fl_values), per_call = _time_repeated(flash_comp, qkv, reps)
    print(f"Flash attention took {per_call:.4f} seconds")

    # Clustering is measured a second time, after the attention kernels ran.
    with nvtx.annotate("clustering"):
        print("Timing clustering...")
        (qlabels, klabels), per_call = _time_repeated(cluster_comp, qkv, reps)
        print(f"Clustering took {per_call:.4f} seconds")

    with nvtx.annotate("single_level_attention"):
        print("Timing single level attention...")
        _, per_call = _time_repeated(
            attend_comp, (queries, qlabels, keys, klabels, values), reps
        )
        print(f"Single level attention took {per_call:.4f} seconds")

    analyze_clusters(qlabels, klabels)
    print("Approx vs Flash attention accuracy:")
    report_accuracy(fl_lmass, fl_values, ap_lmass, ap_values)

    print("Timing exact attention...")
    (ex_lmass, ex_values), per_call = _time_repeated(exact_comp, qkv, reps)
    print(f"Exact attention took {per_call:.4f} seconds")

    print("Flash vs Exact attention accuracy:")
    report_accuracy(ex_lmass, ex_values, fl_lmass, fl_values)
    print("Approx vs Exact attention accuracy:")
    report_accuracy(ex_lmass, ex_values, ap_lmass, ap_values)


def main(static_analysis: bool = False):
    """Run the live benchmark, optionally followed by the compile-time analysis.

    Args:
        static_analysis: When true, also run ``analyze_algos`` after the live
            benchmark. Off by default.
    """
    print("Starting jax stuff...")
    # Run one trivial computation to completion first; presumably this gets
    # backend start-up out of the way before the benchmark prints begin.
    x = jnp.array([1, 2, 3])
    jax.block_until_ready(x)
    print("Running analyze_live")
    analyze_live(
        N=2**12,
        D=2**6,
        Qs=(2**6,),
        Ks=(2**6,),
    )
    if static_analysis:
        analyze_algos(
            N=2**20,
            D=16,
        )


if __name__ == "__main__":
    main()
