import jax
from chex import Array
from jax._src import config
# Pin JAX to the CPU backend for this example. This goes through the public
# ``jax.config`` handle: ``jax._src`` is private API whose module layout is
# not guaranteed to expose ``update`` across JAX releases.
jax.config.update("jax_platforms", "cpu")
import numpy as np
from jax import numpy as jnp
from jax import random
from jax.scipy.special import logsumexp
from matplotlib import pyplot as plt
from absl import app, flags

from kmeans import hierarchical_kmeans, kmeans

# Command-line configuration; values are read from ``FLAGS`` inside ``main``.
FLAGS = flags.FLAGS

# Data selection.
flags.DEFINE_integer(name="i", default=0, help="Index of the attention head to examine")
flags.DEFINE_integer(name="d", default=256, help="Dimension of the keys, values, and queries")
flags.DEFINE_integer(name="n", default=256, help="Context length")

# Clustering setup.
flags.DEFINE_integer(name="seed", default=42, help="Random seed for reproducibility")
flags.DEFINE_integer(name="k", default=16, help="Number of clusters for k-means")

# Behaviour switches.
flags.DEFINE_boolean(name="dipole", default=False, help="Use dipole attention")
flags.DEFINE_boolean(name="show", default=True, help="Show plots")
flags.DEFINE_boolean(
    name="importance_sampling",
    default=False,
    help="Use importance sampling for attention",
    short_name="s",
)
flags.DEFINE_boolean(
    name="importance_clustering",
    default=False,
    help="Use importance clustering for attention",
    short_name="c",
)

def get_data(index=0, d=256, n=256, path="data/data.npz"):
    """Load the keys, values and queries of one attention head.

    The archive must contain the arrays ``keys``, ``values`` and ``queries``,
    each indexed as ``[head, position, feature]``.

    Args:
        index: Position along the leading axis of the stored arrays.
        d: Number of leading feature columns to keep.
        n: Number of rows to keep: the first ``n`` rows of keys and values,
            and the last ``n`` rows of queries.
        path: Location of the ``.npz`` archive. Defaults to the file this
            script has always read.

    Returns:
        A ``(keys, values, queries)`` tuple of NumPy arrays.
    """
    # ``np.load`` on an .npz returns a lazy NpzFile that holds the file open;
    # the context manager closes it once the three arrays have been read.
    with np.load(path) as data:
        ks, vs, qs = data["keys"], data["values"], data["queries"]

    # ``qs[index, -0:]`` would select *every* row, so handle n == 0 explicitly
    # to stay consistent with the (empty) key and value slices.
    q_rows = qs[index, -n:] if n else qs[index, :0]
    return ks[index, :n, :d], vs[index, :n, :d], q_rows[:, :d]

def label_variance(keys: Array, values: Array, klabels: Array, label: int = None, qmean: Array = None) -> None:
    """Build per-key log-weights and print variance diagnostics for them.

    Every key starts with log-mass 0 (uniform weighting). The optional
    arguments then reshape that weighting before it is handed to
    ``analyze_variances``.

    Args:
        keys: Key rows; the log-mass vector has one entry per row.
        values: Value rows, forwarded unchanged.
        klabels: Per-key labels, compared against ``label``.
        label: When given, keys whose label differs receive log-mass ``-inf``
            (zero weight).
        qmean: When given, ``keys @ qmean`` is added to the log-mass and the
            result is shifted so its total mass matches the uniform weighting.
    """
    num_keys = keys.shape[0]
    logmass = jnp.zeros((num_keys,), dtype=jnp.float32)

    if qmean is not None:
        query_bias = jnp.einsum('ij,j->i', keys, qmean)
        bias_total = logsumexp(query_bias, axis=-1, keepdims=True)
        mass_total = logsumexp(logmass, axis=-1, keepdims=True)
        logmass = logmass + query_bias - bias_total + mass_total

    if label is not None:
        # log(1) = 0 keeps members untouched; log(0) = -inf removes the rest.
        in_cluster = (klabels == label).astype(jnp.float32)
        logmass = logmass + jnp.log(in_cluster)

    analyze_variances(keys, values, logmass)


def analyze_variances(keys: Array, values: Array, logmass: Array) -> None:
    """Print weighted variance diagnostics for a set of keys and values.

    Row ``i`` of ``keys`` and ``values`` carries weight ``exp(logmass[i])``.
    Four numbers are printed:

    * ``Cluster size``: the total weight.
    * ``Keys variance`` / ``Values variance``: weighted mean squared norm
      minus the squared norm of the weighted mean (total variance summed
      over features).
    * ``Values schur variance``: ``trace(VV - VC @ pinv(CC) @ VC.T)``, the
      value variance left after the best linear prediction from the keys.

    If the total weight is zero every ratio is 0/0, so the printed
    statistics are not meaningful in that case.

    Args:
        keys: Array of shape (N, Dk).
        values: Array of shape (N, Dv).
        logmass: Array of shape (N,) holding log-weights.
    """
    w = jnp.exp(logmass)
    total = jnp.sum(w)  # reused by every normalisation below
    col_w = w[:, None]

    kmeansq = jnp.sum(col_w * keys**2) / total
    kmean = jnp.sum(col_w * keys, axis=-2) / total
    ksqmean = (kmean**2).sum()

    vmeansq = jnp.sum(col_w * values**2) / total
    vmean = jnp.sum(col_w * values, axis=-2) / total
    vsqmean = (vmean**2).sum()

    print("Cluster size:", total)
    print("Keys variance:", kmeansq - ksqmean)
    print("Values variance:", vmeansq - vsqmean)

    ckeys = keys - kmean
    cvalues = values - vmean

    # Weighted (cross-)covariances as matrix products. This avoids building
    # the (N, D, D) outer-product tensors that a broadcast-and-sum needs.
    wkeys = col_w * ckeys
    wvalues = col_w * cvalues
    CC = wkeys.T @ ckeys / total
    VV = wvalues.T @ cvalues / total
    VC = wvalues.T @ ckeys / total

    VVpred = VC @ jnp.linalg.pinv(CC) @ VC.T
    print("Values schur variance:", jnp.trace(VV - VVpred))


def normalize_values(values: Array, logmass: Array) -> Array:
    """Rescale values around their weighted mean using a Cholesky factor.

    The rows are centred on the ``exp(logmass)``-weighted mean, multiplied by
    the pseudo-inverse of the Cholesky factor of (weighted covariance +
    identity), and shifted back by the same mean.

    Args:
        values: Array of shape (N, D).
        logmass: Array of shape (N,) holding log-weights.

    Returns:
        Array of shape (N, D) with the rescaled values.
    """
    weights = jnp.exp(logmass)
    mass = jnp.sum(weights)

    center = jnp.sum(weights[:, None] * values, axis=0) / mass
    centered = values - center

    weighted = weights[:, None, None] * centered[:, :, None]
    cov = jnp.sum(weighted * centered[:, None, :], axis=0) / mass
    # The identity term keeps the matrix positive definite for the Cholesky.
    cov = cov + jnp.eye(cov.shape[0])

    chol = jnp.linalg.cholesky(cov)
    inv_chol = jnp.linalg.pinv(chol)
    # NOTE(review): with the lower factor L (cov = L @ L.T) this multiplies
    # rows by inv(L), not inv(L).T -- confirm that is the intended transform.
    return centered @ inv_chol + center

def centroid(xs: Array, xlabels: Array, label: int) -> Array:
    """Return the mean of the rows of ``xs`` whose label equals ``label``.

    Args:
        xs: Array of shape (N, D).
        xlabels: Array of shape (N,) with one label per row of ``xs``.
        label: The label whose rows are averaged.

    Returns:
        Array of shape (D,) holding the per-feature mean of the selected rows.
    """
    # Column-shaped mask so it broadcasts across the feature axis.
    member_rows = (xlabels == label)[:, None]
    return jnp.mean(xs, axis=0, where=member_rows)
    


def _flat_corr(a, b):
    """Return the Pearson correlation between the flattened ``a`` and ``b``."""
    return jnp.corrcoef(a.flatten(), b.flatten())[0, 1]


def _print_agreement(header, logmass, outputs, true_logmass, true_outputs):
    """Print how closely one attention variant tracks the exact reference."""
    print(header)
    print("logmass correlation:", _flat_corr(logmass, true_logmass))
    print("values correlation:", _flat_corr(outputs, true_outputs))


def _show_score_matrix(scores, title):
    """Show a query-by-key score matrix as a heat map clipped to [-6, 6]."""
    plt.figure(figsize=(10, 8))
    plt.imshow(scores, cmap='viridis', aspect='auto', interpolation='nearest', vmin=-6, vmax=6)
    plt.colorbar(label='Attention Score')
    plt.title(title)
    plt.xlabel('Keys')
    plt.ylabel('Queries')
    plt.show()


def main(argv):
    """Compare exact attention with cluster-based approximations for one head.

    Steps:
      1. Load keys, values and queries for head ``--i`` and cluster queries
         and keys separately with k-means (``--k`` clusters each).
      2. Sort rows by cluster label and print per-cluster variance diagnostics.
      3. Split every query and key into centroid + residual and report how
         much score variance the residual-residual term carries.
      4. Compare exact softmax attention with the approximation that drops
         the residual-residual term, and with the ``SimpleAttention``
         variants.
      5. Unless ``--noshow`` is given, plot the approximate and residual
         score matrices.

    Args:
        argv: Positional command-line arguments left over by absl; unused.
    """
    index = FLAGS.i
    D = FLAGS.d
    N = FLAGS.n
    K = FLAGS.k
    use_dipole = FLAGS.dipole
    use_importance_sampling = FLAGS.importance_sampling
    use_importance_clustering = FLAGS.importance_clustering
    klabkey, qlabkey = random.split(random.PRNGKey(FLAGS.seed))

    keys, values, queries = get_data(index, D, N)

    # Element [1] of the kmeans result is used as the per-row cluster label.
    qlabels = kmeans(xs=queries, key=qlabkey, k=K, iters=10)[1]

    # Features used to cluster the keys. Alternatives explored previously:
    #   jnp.concatenate([keys, values], axis=-1)
    #   keys * jnp.linalg.norm(values, axis=-1, keepdims=True)
    klabel_xs = keys
    kimplogmass = jnp.zeros((keys.shape[0],), dtype=jnp.float32)
    if use_importance_clustering:
        # Weight each key by (a floor-shifted) norm of its value.
        kimplogmass = kimplogmass + jnp.log(jnp.linalg.norm(values, axis=-1) + 1e-1)
    klabels = kmeans(xs=klabel_xs, key=klabkey, k=K, iters=10, logmass=kimplogmass)[1]

    # Sort rows by cluster label so each cluster is a contiguous block.
    kindices = jnp.argsort(klabels)
    keys = keys[kindices]
    values = values[kindices]
    klabels = klabels[kindices]
    qindices = jnp.argsort(qlabels)
    queries = queries[qindices]
    qlabels = qlabels[qindices]

    # Variance diagnostics: all keys first, then one cluster at a time.
    label_variance(keys, values, klabels)
    for label in range(K):
        print(f"----- Label {label} -----")
        label_variance(keys, values, klabels, label=label)
    print("----- Done -----")

    # Centroid + residual decomposition of keys, queries and values.
    kcentroids = jnp.array([centroid(keys, klabels, label) for label in range(K)])
    qcentroids = jnp.array([centroid(queries, qlabels, label) for label in range(K)])
    vcentroids = jnp.array([centroid(values, klabels, label) for label in range(K)])

    basekeys = kcentroids[klabels]
    basequeries = qcentroids[qlabels]
    basevalues = vcentroids[klabels]

    reskeys = keys - basekeys
    resqueries = queries - basequeries
    resvalues = values - basevalues

    # q.k = base.base + base.res + res.base + res.res; compute each term.
    res_scores = jnp.einsum('ik,jk->ij', resqueries, reskeys)
    print("Residual Score var:", jnp.var(res_scores))

    kdetail_scores = jnp.einsum('ik,jk->ij', basequeries, reskeys)
    qdetail_scores = jnp.einsum('ik,jk->ij', resqueries, basekeys)
    base_scores = jnp.einsum('ik,jk->ij', basequeries, basekeys)

    # Exact (unmasked) attention scores and the part explained without the
    # residual-residual term.
    attention_scores = jnp.einsum('ik,jk->ij', queries, keys)
    scores_to_plot = attention_scores - res_scores
    print("Main Score var:", jnp.var(scores_to_plot))
    print("Total Score var:", jnp.var(attention_scores))

    true_scores = attention_scores
    approx_scores = base_scores + kdetail_scores + qdetail_scores

    true_weights = jax.nn.softmax(true_scores, axis=-1)
    approx_weights = jax.nn.softmax(approx_scores, axis=-1)

    # Every output is centred on the mean of the exact outputs so the
    # correlations below are not dominated by a shared offset.
    true_outputs = jnp.einsum('ij,jk->ik', true_weights, values)
    true_output_mean = jnp.mean(true_outputs, axis=0, keepdims=True)
    true_outputs = true_outputs - true_output_mean
    approx_outputs = jnp.einsum('ij,jk->ik', approx_weights, values) - true_output_mean

    true_logmass = logsumexp(true_scores, axis=-1)

    print("----- Output Statistics -----")
    print("True outputs var:", jnp.var(true_outputs))
    print("Approx outputs var:", jnp.var(approx_outputs))
    print("Error var:", jnp.var(true_outputs - approx_outputs))
    print("Correlation:", _flat_corr(true_outputs, approx_outputs))

    print("----- Simple Attention -----")
    from single_level_attention import SimpleAttention
    attn_algo = SimpleAttention(
        K=K,
        dipole=use_dipole,
        importance=use_importance_sampling,
        cluster_importance=use_importance_clustering
    )
    sl_logmass, sl_values = attn_algo.single_level_attention(queries, qlabels, keys, klabels, values)
    sl_values = sl_values - true_output_mean
    _print_agreement(
        "----- Simple Attention Statistics -----",
        sl_logmass, sl_values, true_logmass, true_outputs,
    )

    # The remaining variants receive the data in its original (unsorted)
    # order; their per-query results are permuted with ``qindices`` so they
    # line up with the sorted reference computed above.
    keys, values, queries = get_data(index, D, N)
    slc_logmass, slc_values = attn_algo.cluster_then_attention(queries, keys, values)
    slc_logmass = slc_logmass[qindices]
    slc_values = slc_values[qindices] - true_output_mean

    exact_logmass, exact_values = attn_algo.exact_attention(queries, keys, values)
    exact_logmass = exact_logmass[qindices]
    exact_values = exact_values[qindices] - true_output_mean

    _print_agreement(
        "----- Exact Attention Cross-Check -----",
        exact_logmass, exact_values, true_logmass, true_outputs,
    )
    _print_agreement(
        "----- Cluster then Attention Statistics -----",
        slc_logmass, slc_values, true_logmass, true_outputs,
    )

    if not FLAGS.show:
        return

    # First figure: scores without the residual-residual term.
    _show_score_matrix(scores_to_plot, f'Attention Matrix for Head {index}')
    # Second figure: the residual-residual term on its own; titled
    # differently so the two windows can be told apart.
    _show_score_matrix(res_scores, f'Residual Attention Matrix for Head {index}')


# Script entry point: hand control to absl, which calls ``main``.
if __name__ == "__main__":
    app.run(main)
