import os
#from jax import numpy as jnp
import jax
import jax.numpy as jnp
import numpy as np
from collections import defaultdict

import matplotlib.pyplot as plt
from geometric_mean_transform import compute_transforms

def transform(ks, qs):
    """Apply the paired linear maps from ``compute_transforms`` to keys and queries.

    For each input, the second-moment matrix ``X^T X / n`` is formed by summing
    over axis -2, without any mean subtraction. Both matrices are handed to
    ``compute_transforms``, and the two returned matrices are applied row-wise
    as ``X @ T.T``.

    NOTE(review): the einsum subscripts fix ``ks`` and ``qs`` to rank 2
    ``(n, d)``. ``main`` calls this (via vmap) on ``(16384, 256)`` slices.
    Confirm the shapes before reusing it elsewhere.

    Returns:
        Tuple ``(ks @ Tk.T, qs @ Tq.T)``.
    """

    def second_moment(xs):
        # Uncentred second moment; subtracting the mean, if wanted, is the caller's job.
        return jnp.einsum('jk,jl->kl', xs, xs) / xs.shape[-2]

    t_keys, t_queries = compute_transforms(second_moment(ks), second_moment(qs))
    return ks @ t_keys.T, qs @ t_queries.T

D = 256  # Default number of leading directions kept by truncate_spectrum.


def truncate_spectrum(xs, D=D):
    """Project ``xs`` onto the top-``D`` eigen-directions of its second-moment matrix.

    The matrix is ``xs^T xs / n``, taken over the rows (axis -2) of ``xs``.
    It is *not* mean-centred here.

    Args:
        xs: Array of shape ``(..., n, d)``. Any leading batch dimensions are
            processed independently.
        D: Number of directions to keep. If ``D >= d``, all ``d`` directions
            are kept, i.e. ``xs`` is simply expressed in the eigenbasis.

    Returns:
        Array of shape ``(..., n, min(D, d))``. Columns are ordered by
        decreasing eigenvalue. Each column's sign is whatever the SVD returns.
    """
    second_moment = jnp.einsum('...jk,...jl->...kl', xs, xs) / xs.shape[-2]
    # For a symmetric PSD matrix, the left singular vectors are its
    # eigenvectors, already sorted by decreasing eigenvalue.
    u, _, _ = jnp.linalg.svd(second_moment)
    return xs @ u[..., :, :D]


# Files this script writes into the data directory. They are excluded by exact
# name when collecting input books, so a re-run never reads its own output.
_RAW_OUTPUT_FILE = "raw_data.npz"
_OUTPUT_FILE = "data.npz"
_OUTPUT_FILES = frozenset({_RAW_OUTPUT_FILE, _OUTPUT_FILE})

# Required trailing shape (rows, features) of keys/queries/values.
_SEQ_LEN = 16384
_HEAD_DIM = 256


def _list_book_files(data_dir):
    """Return the sorted ``.npz`` file names in ``data_dir``, skipping this script's outputs."""
    return sorted(
        f for f in os.listdir(data_dir)
        if f.endswith('.npz') and f not in _OUTPUT_FILES
    )


def _load_arrays_by_key(data_dir, book_files):
    """Load each book file and concatenate arrays sharing a name along axis 0.

    For every array name, prints the shape of the first book's array
    (debug output). Returns a dict ``{array_name: concatenated jax array}``.
    """
    grouped = defaultdict(list)
    for book_file in book_files:
        with jnp.load(os.path.join(data_dir, book_file)) as data:
            for key in data.keys():
                grouped[key].append(jnp.array(data[key]))

    merged = {}
    for key, arrays in grouped.items():
        print(arrays[0].shape)
        merged[key] = jnp.concatenate(arrays, axis=0)
    return merged


def _print_summary(book_names, arrays_by_key):
    """Print the indexed book names, then the shape and dtype of each merged array."""
    print("Book names:")
    for i, name in enumerate(book_names):
        print(f"{i}: {name}")
    print()

    for key, array in arrays_by_key.items():
        print(f"Key: {key}")
        print("Shapes:")
        print(array.shape)
        print(array.dtype)
        print()


def _as_heads(array, name):
    """Reshape ``array`` to ``(-1, _SEQ_LEN, _HEAD_DIM)`` after validating its trailing dims.

    Raises:
        ValueError: if the last two dimensions are not ``(_SEQ_LEN, _HEAD_DIM)``.
            Without this check, ``reshape`` could silently regroup the rows of
            a differently shaped array.
    """
    if tuple(array.shape[-2:]) != (_SEQ_LEN, _HEAD_DIM):
        raise ValueError(
            f"{name} should have trailing shape ({_SEQ_LEN}, {_HEAD_DIM}), "
            f"got {array.shape}"
        )
    return array.reshape(-1, _SEQ_LEN, _HEAD_DIM)


def _plot_key_spectrum(keys, num_heads=4):
    """Plot cumulative histograms of -log eigenvalues of key second-moment matrices.

    Uses the first ``num_heads`` entries along axis 0 of ``keys``. It draws one
    step curve per entry on a log-scaled y axis, then calls ``plt.show()``.
    """
    heads = keys[:num_heads]
    ck = jnp.einsum('ijk,ijl->ikl', heads, heads) / heads.shape[-2]
    eig_k = jnp.linalg.eigvalsh(ck)
    # eigvalsh can return tiny negative eigenvalues for a PSD matrix because of
    # round-off. Map them to 0 so the log gives -inf (clipped to -10) instead of NaN.
    eig_k = jnp.where(eig_k > 0, eig_k, 0.0)
    neg_log_eig = -(jnp.log(eig_k).clip(min=-10))

    # One row of eigenvalues per entry. Iterate rows directly: reshaping by D
    # is only correct when the feature dimension happens to equal D.
    for head_values in neg_log_eig:
        plt.hist(np.array(head_values), bins=100, cumulative=True,
                 density=False, histtype='step')

    plt.yscale('log')
    plt.show()


def main(data_dir="data"):
    """Build the keys/queries/values dataset from per-book ``.npz`` files and plot key spectra.

    Steps:
      1. Load every ``*.npz`` file in ``data_dir``, except this script's own
         outputs, and concatenate same-named arrays along axis 0.
      2. Reshape ``keys``, ``queries`` and ``values`` to ``(-1, 16384, 256)``.
         Save them unmodified to ``raw_data.npz``.
      3. Per entry along axis 0: subtract the mean from keys and queries, apply
         ``transform``, and add the means back. Then project keys, queries and
         values with ``truncate_spectrum``.
      4. Save the result to ``data.npz`` and plot -log eigenvalue CDFs for the
         first four key entries.

    Args:
        data_dir: Directory holding the input ``.npz`` files. Outputs are
            written there too.

    Raises:
        KeyError: if no input file provides ``keys``, ``queries`` or ``values``.
        ValueError: if any of those arrays has the wrong trailing shape.
    """
    book_files = _list_book_files(data_dir)
    book_names = [os.path.splitext(f)[0] for f in book_files]

    arrays_by_key = _load_arrays_by_key(data_dir, book_files)
    _print_summary(book_names, arrays_by_key)

    keys = _as_heads(arrays_by_key["keys"], "Keys")
    queries = _as_heads(arrays_by_key["queries"], "Queries")
    values = _as_heads(arrays_by_key["values"], "Values")
    jnp.savez(os.path.join(data_dir, _RAW_OUTPUT_FILE),
              keys=keys, queries=queries, values=values)

    kmean = keys.mean(axis=-2, keepdims=True)
    qmean = queries.mean(axis=-2, keepdims=True)
    keys, queries = jax.vmap(transform, in_axes=(0, 0))(keys - kmean, queries - qmean)
    # NOTE(review): `transform` returns data mapped by Tk/Tq (ks @ Tk.T), but
    # the means are added back in the ORIGINAL basis. Unless Tk/Tq are the
    # identity, the result is not `keys @ Tk.T`. If that is the intent, the
    # means should be mapped as well (kmean @ Tk.T). Confirm.
    keys = keys + kmean
    queries = queries + qmean

    # Reduce dimension for testing
    truncate = jax.vmap(truncate_spectrum, in_axes=0)
    keys = truncate(keys)
    queries = truncate(queries)
    values = truncate(values)

    print("Keys shape:", keys.shape)
    print("Queries shape:", queries.shape)
    print("Values shape:", values.shape)

    jnp.savez(os.path.join(data_dir, _OUTPUT_FILE),
              keys=keys, queries=queries, values=values)

    _plot_key_spectrum(keys, num_heads=4)


# Run the pipeline only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
