import jax
from chex import Array
from jax._src import config
jax.config.update("jax_platforms", "cpu")  # Use CPU for this example (public jax.config API rather than private jax._src)
import numpy as np
from jax import numpy as jnp
from jax import random
from jax.scipy.special import logsumexp
from matplotlib import pyplot as plt
from absl import app, flags

from kmeans import hierarchical_kmeans, kmeans

FLAGS = flags.FLAGS
# Data selection. Sizes must be positive: zero/negative values would silently
# slice out empty or end-relative arrays in get_data.
flags.DEFINE_integer("i", 0, "Index of the attention head to examine")
flags.DEFINE_integer("d", 256, "Dimension of the keys, values, and queries", lower_bound=1)
flags.DEFINE_integer("n", 4096, "Context length", lower_bound=1)
flags.DEFINE_integer("seed", 42, "Random seed for reproducibility")
# Attention-approximation settings.
flags.DEFINE_integer("k", 16, "Number of clusters for k-means", lower_bound=1)
flags.DEFINE_integer("block_size", 1024, "Block size for causal attention", lower_bound=1, short_name="b")
flags.DEFINE_boolean("dipole", False, "Use dipole attention")
flags.DEFINE_boolean("show", True, "Show plots")
flags.DEFINE_boolean("importance_sampling", False, "Use importance sampling for attention", short_name="s")
flags.DEFINE_boolean("importance_clustering", False, "Use importance clustering for attention", short_name="c")
# Numeric precision of the loaded data.
flags.DEFINE_boolean("bf16", False, "Use bfloat16 for keys, values, and queries")

def get_data(index=0, d=256, n=256, *, path="fma/data/data.npz", bf16=None):
    """Load one head's keys, values and queries from an ``.npz`` archive.

    The archive must contain ``keys``, ``values`` and ``queries`` arrays; each is
    sliced as ``[index, :n, :d]``.

    Args:
        index: index into the first axis of each stored array (the attention head).
        d: number of leading entries to keep along the last axis.
        n: number of leading entries to keep along the second axis.
        path: location of the ``.npz`` archive.
        bf16: if true, convert the slices to ``jnp.bfloat16`` JAX arrays. ``None``
            (the default) falls back to ``FLAGS.bf16``.

    Returns:
        ``(keys, values, queries)`` as NumPy arrays, or bfloat16 JAX arrays when
        ``bf16`` is true.
    """
    if bf16 is None:
        bf16 = FLAGS.bf16
    # NpzFile holds an open file handle; the context manager guarantees it is closed.
    # The sliced arrays were fully read into memory, so they stay valid afterwards.
    with np.load(path) as data:
        ks = data["keys"][index, :n, :d]
        vs = data["values"][index, :n, :d]
        qs = data["queries"][index, :n, :d]
    if bf16:
        ks, vs, qs = (jnp.asarray(x).astype(jnp.bfloat16) for x in (ks, vs, qs))
    return ks, vs, qs

def label_variance(keys: Array, values: Array, klabels: Array, label: int = None, qmean: Array = None) -> None:
    """Print variance diagnostics for an optionally weighted, optionally filtered point set.

    Starts from uniform log-weights (all zero), then:

    * if ``qmean`` is given, reweights each point by ``exp(keys[i] . qmean)``. The
      result is rescaled so the weights still sum to the number of points.
    * if ``label`` is given, sets the log-weight of every point whose
      ``klabels`` entry differs from ``label`` to ``-inf``, excluding it.

    The resulting log-weights are handed to ``analyze_variances``.
    """
    num_points = keys.shape[0]
    log_w = jnp.zeros((num_points,), dtype=jnp.float32)

    if qmean is not None:
        scores = jnp.einsum('ij,j->i', keys, qmean)
        # Softmax-normalize the scores, then add back logsumexp of the all-zero
        # baseline (= log(num_points)) so the total mass is unchanged.
        log_w = log_w + scores - logsumexp(scores, axis=-1, keepdims=True) + logsumexp(log_w, axis=-1, keepdims=True)

    if label is not None:
        # log(1) = 0 keeps members; log(0) = -inf drops everything else.
        in_label = (klabels == label).astype(jnp.float32)
        log_w = log_w + jnp.log(in_label)

    analyze_variances(keys, values, log_w)


def analyze_variances(keys: Array, values: Array, logmass: Array) -> None:
    """Print weighted variance diagnostics for a set of keys and values.

    Row ``i`` of ``keys`` / ``values`` gets weight ``exp(logmass[i])``. The weights
    need not be normalized, and a log-weight of ``-inf`` excludes the row. Prints:

    * ``Cluster size``: the total weight;
    * ``Keys variance`` / ``Values variance``: the total (trace) weighted variance;
    * ``Values schur variance``: the trace of ``VV - VC CC^+ VC^T``. This is the
      value variance left after the best linear prediction of the centered values
      from the centered keys.

    Args:
        keys: array of shape (n, d_k).
        values: array of shape (n, d_v).
        logmass: array of shape (n,) of log-weights.
    """
    # Accumulate statistics in at least float32 (squaring/summing bf16 loses precision).
    dtype = jnp.promote_types(jnp.result_type(keys, values), jnp.float32)
    keys = jnp.asarray(keys, dtype=dtype)
    values = jnp.asarray(values, dtype=dtype)
    w = jnp.exp(jnp.asarray(logmass, dtype=dtype))
    total = jnp.sum(w)

    kmean = (w @ keys) / total
    vmean = (w @ values) / total
    ckeys = keys - kmean
    cvalues = values - vmean
    wkeys = w[:, None] * ckeys
    wvalues = w[:, None] * cvalues

    print("Cluster size:", total)
    # Variance from centered data rather than E[x^2] - E[x]^2, which suffers
    # catastrophic cancellation (and can even go negative) when the mean is large.
    print("Keys variance:", jnp.sum(wkeys * ckeys) / total)
    print("Values variance:", jnp.sum(wvalues * cvalues) / total)

    # Weighted (cross-)covariances as matmuls; this avoids materializing an
    # (n, d, d) tensor of per-row outer products.
    CC = wkeys.T @ ckeys / total
    VV = wvalues.T @ cvalues / total
    VC = wvalues.T @ ckeys / total
    VVpred = VC @ jnp.linalg.pinv(CC) @ VC.T
    print("Values schur variance:", jnp.trace(VV - VVpred))


def normalize_values(values: Array, logmass: Array, *, ridge: float = 1.0) -> Array:
    """Whiten ``values`` about their weighted mean using a ridge-regularized covariance.

    With weights ``w = exp(logmass)``, the weighted mean ``m``, and the weighted
    covariance ``C`` of ``values``, this factors ``C + ridge * I = L L^T`` (Cholesky,
    ``L`` lower-triangular) and returns ``(values - m) @ L^{-T} + m``.

    The centered output has weighted covariance ``L^{-1} C L^{-T}``. That is
    exactly the identity when ``ridge == 0``, and has eigenvalues
    ``lambda / (lambda + ridge)`` in general.

    Args:
        values: array of shape (n, d).
        logmass: array of shape (n,) of log-weights.
        ridge: amount added to the covariance diagonal before factoring. The
            default 1.0 matches the previous hard-coded regularizer.

    Returns:
        Array of shape (n, d).
    """
    w = jnp.exp(logmass)
    total = jnp.sum(w)
    vmean = jnp.sum(w[:, None] * values, axis=0) / total
    cvalues = values - vmean
    # Weighted covariance as a matmul (no (n, d, d) intermediate).
    VV = (w[:, None] * cvalues).T @ cvalues / total
    VV = VV + ridge * jnp.eye(VV.shape[0], dtype=VV.dtype)
    rootVV = jnp.linalg.cholesky(VV)  # lower-triangular L with L @ L.T == VV
    # Rows must be right-multiplied by L^{-T}; solving L Y^T = X^T gives Y = X L^{-T}
    # without forming an explicit inverse. (Right-multiplying by L^{-1} does not whiten.)
    normalized_values = jnp.linalg.solve(rootVV, cvalues.T).T
    return normalized_values + vmean

def centroid(xs: Array, xlabels: Array, label: int) -> Array:
    """Return the mean of the rows of ``xs`` whose ``xlabels`` entry equals ``label``.

    The boolean membership mask gets a trailing axis so it broadcasts across the
    feature dimension of ``xs``. An empty selection presumably yields NaN (mean over
    zero elements).
    """
    member = xlabels == label
    return jnp.mean(xs, axis=0, where=jnp.expand_dims(member, -1))
    


def _report_correlations(reference, candidate, start=0):
    """Print logmass and value correlations of ``candidate`` against ``reference``.

    ``reference`` and ``candidate`` are ``(logmass, values)`` pairs. Both value
    arrays are centered by the reference's mean over axis 0 (all rows), so a shared
    constant offset does not dominate the flattened correlation. Rows before
    ``start`` are then excluded from both correlations.
    """
    ref_logmass, ref_values = reference
    cand_logmass, cand_values = candidate
    center = jnp.mean(ref_values, axis=0, keepdims=True)
    ref_values = ref_values - center
    cand_values = cand_values - center
    print("logmass correlation:", jnp.corrcoef(ref_logmass[start:].flatten(), cand_logmass[start:].flatten())[0, 1])
    print("values correlation:", jnp.corrcoef(ref_values[start:].flatten(), cand_values[start:].flatten())[0, 1])


def main(argv):
    """Cross-check approximate attention implementations against exact ones.

    Loads one head's keys/values/queries (selected by FLAGS), then prints
    correlations of logmass and output values for:
      1. non-causal cluster-then-attention vs. exact attention;
      2. causal block attention vs. exact causal attention;
      3. approximate (fast) causal block attention vs. exact causal attention.
    Each comparison centers values by the exact (reference) output's mean; the
    causal comparisons skip row 0.
    """
    del argv  # Unused; all configuration comes from FLAGS.

    keys, values, queries = get_data(FLAGS.i, FLAGS.d, FLAGS.n)
    from single_level_attention import SimpleAttention, CausalAttention
    attn_algo = SimpleAttention(
        K=FLAGS.k,
        dipole=FLAGS.dipole,
        importance=FLAGS.importance_sampling,
        cluster_importance=FLAGS.importance_clustering,
    )

    slc = attn_algo.cluster_then_attention(queries, keys, values)
    exact = attn_algo.exact_attention(queries, keys, values)
    print("----- Exact Attention Cross-Check -----")
    _report_correlations(exact, slc)

    causal_algo = CausalAttention(
        simple_attention=attn_algo,
        block_size=FLAGS.block_size,
    )
    # Reference output for both causal comparisons. Previously these were centered
    # by the mean of the (already centered) non-causal output, which is ~0.
    causal = causal_algo.exact_attention(queries, keys, values)

    print("----- Causal Block Attention Cross-Check -----")
    block = causal_algo.block_attention(queries, keys, values)
    _report_correlations(causal, block, start=1)

    print("----- Approximate Block Attention Cross-Check -----")
    approx = causal_algo.fast_attention(queries, keys, values)
    _report_correlations(causal, approx, start=1)
    

if __name__ == "__main__":
    # Hand control to absl, which parses the command-line FLAGS and then calls main(argv).
    app.run(main)
