"""Script for inspecting an npz file as used in this project.

The file must contain a "keys" array and a "vecs" array.

Example invocation:

python -m fmri2music.scripts.inspect_npz \
    --file_path=data/music-emb/gtzan-emb-window10s-stride1_5s-mv101-avg.npz

"""

import argparse
import os
import sys

import numpy as np


def get_npz_content(file_path: str) -> tuple[list[str], np.ndarray]:
    """Returns the keys and vectors of an NPZ file.

    Args:
        file_path: Path to an NPZ archive containing a "keys" array and a
            "vecs" array.

    Returns:
        A ``(keys, vecs)`` tuple: ``keys`` converted to native Python values
        via ``tolist()``, and ``vecs`` as the loaded NumPy array.

    Raises:
        KeyError: If the archive lacks the "keys" or "vecs" array.
    """
    # Using the NpzFile as a context manager closes its zip archive and the
    # file handle np.load opened, once the arrays have been read.
    with np.load(file_path) as content:
        missing = sorted({"keys", "vecs"} - set(content.files))
        if missing:
            raise KeyError(
                f"'{file_path}' is missing array(s) {missing}; "
                f"found {content.files}."
            )
        # tolist() turns NumPy string scalars into built-in str, so previews
        # print plain strings instead of np.str_(...) reprs.
        keys = content["keys"].tolist()
        vecs = content["vecs"]
    return keys, vecs


def print_npz_preview(keys: list[str], vecs: np.ndarray) -> None:
    """Prints a summary preview of the NPZ file content.

    Shows the number of keys, the shape of ``vecs``, the first five keys and
    the first ten entries of the first element along ``vecs``' first axis.

    Args:
        keys: Keys read from the NPZ file.
        vecs: Array read from the NPZ file.
    """
    print(f"Number of keys: {len(keys)}")
    print(f"Shape of vectors: {vecs.shape}")
    print(f"First key values: {keys[:5]}")
    # Indexing row 0 would raise on a 0-d or empty array, so report it instead.
    if vecs.ndim == 0 or vecs.shape[0] == 0:
        print("First vector value ([:10]): <no vectors>")
        return
    # atleast_1d lets a 1-D ``vecs`` (one scalar per entry) be sliced as well;
    # for 2-D input this is identical to vecs[0, :10].
    first_vec = np.atleast_1d(vecs[0])
    print(f"First vector value ([:10]): {first_vec[:10]}")


def print_full_npz_file(keys: list[str], vecs: np.ndarray) -> None:
    """Prints all vectors and keys of the NPZ file, one per line.

    Vectors are printed first (one per line), then the keys, followed by a
    trailing blank separator.

    Args:
        keys: Keys read from the NPZ file.
        vecs: Array read from the NPZ file; iterated along its first axis.
    """
    # NumPy summarises arrays with more than ``threshold`` elements using
    # "..."; lift the limit so every value is actually printed. The context
    # manager restores the previous print options afterwards.
    with np.printoptions(threshold=sys.maxsize):
        print("All vectors:")
        print(*vecs, sep="\n")

    print("All keys:")
    print(*keys, sep="\n")
    print("\n")


def main(args):
    """Runs the inspection for the parsed command-line arguments.

    Prints a preview of the file and, if ``args.print_all`` is truthy, its
    full content.

    Args:
        args: Parsed arguments exposing ``file_path`` and ``print_all``.

    Raises:
        ValueError: If ``file_path`` is not an existing regular file.
    """
    path = args.file_path

    # os.path.isfile is False both for missing paths and for directories.
    if not os.path.isfile(path):
        raise ValueError(f"File path '{path}' does not exist.")

    # An unexpected extension is only reported; loading is still attempted.
    has_npz_suffix = path.endswith(".npz")
    if not has_npz_suffix:
        print(f"File path '{path}' does not end with .npz.")

    print(f"Reading file content from '{path}'...")
    content_keys, content_vecs = get_npz_content(path)
    print_npz_preview(content_keys, content_vecs)

    if not args.print_all:
        print("To print all content, use --print_all=True.")
        return
    print_full_npz_file(content_keys, content_vecs)


def str_to_bool(value) -> bool:
    """Parses a command-line boolean value such as "True" or "false".

    ``argparse`` with ``type=bool`` treats every non-empty string, including
    "False", as True, so flag values are parsed explicitly here.

    Args:
        value: The raw flag value; a bool is passed through unchanged.

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If the value is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in {"true", "t", "yes", "y", "1"}:
        return True
    if normalized in {"false", "f", "no", "n", "0"}:
        return False
    raise argparse.ArgumentTypeError(
        f"Expected a boolean value (e.g. True/False), got '{value}'."
    )


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Inspect NPZ file.")

    parser.add_argument(
        "--file_path", type=str, required=True, help="Path to the .npz file."
    )

    parser.add_argument(
        "--print_all",
        type=str_to_bool,
        nargs="?",
        const=True,  # a bare ``--print_all`` means True
        default=False,
        help="Whether to print all .npz file content.",
    )

    main(parser.parse_args())
