import os
#from jax import numpy as jnp
import numpy as jnp
from collections import defaultdict

import matplotlib.pyplot as plt

# Layout required of the "keys" array: (..., _ROWS, _COLS).
_ROWS = 16384
_COLS = 256
# Row index splitting each (_ROWS, _COLS) slab: keys use the rows before it,
# queries and values use the rows from it onwards.
_SPLIT = 8192


def _load_arrays(data_dir):
    """Read every ``.npz`` archive in *data_dir* and merge its arrays by key.

    Returns:
        ``(book_names, arrays)``: the sorted file names without their
        extension, and a plain dict mapping each archive key to the arrays of
        all files concatenated along axis 0.

    Raises:
        ValueError: If *data_dir* contains no ``.npz`` file.
    """
    # Sorted so the concatenation order is reproducible across runs.
    book_files = sorted(f for f in os.listdir(data_dir) if f.endswith('.npz'))
    if not book_files:
        raise ValueError(f"No .npz files found in {data_dir!r}")
    book_names = [os.path.splitext(f)[0] for f in book_files]

    chunks_by_key = defaultdict(list)
    for book_file in book_files:
        with jnp.load(os.path.join(data_dir, book_file)) as data:
            for key in data.files:
                # Copy out of the archive before it is closed.
                chunks_by_key[key].append(jnp.array(data[key]))

    # A plain dict is returned on purpose: a missing key must raise instead of
    # silently materialising an empty list.
    arrays = {}
    for key, chunks in chunks_by_key.items():
        print(chunks[0].shape)
        arrays[key] = jnp.concatenate(chunks, axis=0)
    return book_names, arrays


def _centered_half(array, rows):
    """Reshape to (-1, _ROWS, _COLS), keep *rows*, subtract the per-column mean."""
    part = array.reshape(-1, _ROWS, _COLS)[:, rows, :]
    return part - part.mean(axis=-2, keepdims=True)


def _second_moment(x):
    """Return the batched matrix ``x^T x / rows`` for an array of shape (b, rows, cols)."""
    return jnp.einsum('ijk,ijl->ikl', x, x) / x.shape[-2]


def _neg_log_spectrum(Ck, Qk):
    """Return the clipped negative log-eigenvalues of ``L^T Ck L``.

    ``L`` is the lower Cholesky factor of ``Qk + 1e-6 * I``, so the symmetric
    product ``L^T Ck L`` has the same eigenvalues as ``Ck (Qk + 1e-6 * I)``.
    """
    Qk_chol = jnp.linalg.cholesky(Qk + 1e-6 * jnp.eye(Qk.shape[-1]))
    QK = jnp.einsum('ikj,ikl,ilm->ijm', Qk_chol, Ck, Qk_chol)
    eig_k = jnp.linalg.eigvalsh(QK)
    # Round-off can push eigenvalues of a (near-)singular matrix to zero or
    # slightly below; log() would then yield -inf/NaN and NaN survives clip(),
    # which makes the histogram fail. Flooring at the smallest normal float
    # leaves every positive eigenvalue untouched and maps the rest onto the
    # -10 clip bound below.
    eig_k = jnp.maximum(eig_k, jnp.finfo(eig_k.dtype).tiny)
    # The clip applies to log(eig) and the result is negated afterwards, so
    # the returned values are bounded above by 10.
    return -(jnp.log(eig_k).clip(min=-10, max=1e8))


def main(data_dir: str = "data") -> None:
    """Plot cumulative histograms of a key/query spectrum from ``.npz`` dumps.

    Loads every ``.npz`` file in *data_dir*, concatenates the arrays per key,
    prints a summary, and then, for the first few leading slabs, shows the
    cumulative histogram of the negative log-eigenvalues computed from the
    centred "keys" (first ``_SPLIT`` rows) and "queries" (remaining rows).

    Args:
        data_dir: Directory holding the ``.npz`` files. Defaults to ``"data"``.

    Raises:
        FileNotFoundError: If *data_dir* does not exist.
        ValueError: If no ``.npz`` file is found, if one of the "keys",
            "queries" or "values" arrays is missing, or if "keys" does not end
            in dimensions ``(16384, 256)``.
    """
    book_names, arrays_by_key = _load_arrays(data_dir)

    # Print book names
    print("Book names:")
    for i, name in enumerate(book_names):
        print(f"{i}: {name}")
    print()

    # Print the shape and dtype of every concatenated array
    for key, array in arrays_by_key.items():
        print(f"Key: {key}")
        print("Shapes:")
        print(array.shape)
        print(array.dtype)
        print()

    missing = [k for k in ("keys", "queries", "values") if k not in arrays_by_key]
    if missing:
        raise ValueError(
            f"Missing arrays {missing} in the .npz files of {data_dir!r}"
        )

    # Explicit raises instead of assert: the checks must survive `python -O`.
    keys = arrays_by_key["keys"]
    if keys.shape[-1] != _COLS:
        raise ValueError("Keys should have last dimension of size 256")
    if keys.ndim < 2 or keys.shape[-2] != _ROWS:
        raise ValueError("Keys should have second to last dimension of size 16384")

    keys = _centered_half(keys, slice(None, _SPLIT))
    queries = _centered_half(arrays_by_key["queries"], slice(_SPLIT, None))
    values = _centered_half(arrays_by_key["values"], slice(_SPLIT, None))

    print("Keys shape:", keys.shape)
    print("Queries shape:", queries.shape)
    print("Values shape:", values.shape)

    # Only the first NH slabs along axis 0 are analysed.
    # NOTE(review): NH presumably stands for "number of heads" - confirm.
    NH = 4
    Ck = _second_moment(keys[:NH])
    Qk = _second_moment(queries[:NH])
    ln_eig_k = _neg_log_spectrum(Ck, Qk)

    print("Ck shape:", Ck.shape)

    # One cumulative histogram per analysed slab.
    for lek in ln_eig_k.reshape(-1, 1 * 256):
        plt.hist(lek, bins=100, cumulative=True, density=False, histtype='step')

    plt.yscale('log')
    plt.show()
    


# Run the analysis only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()
