import os
import sys
import torch
from argparse import ArgumentParser

# Append this file's parent directory to the module search path. Presumably
# this is what lets the `model`, `utils` and `plot` imports below resolve when
# the script is run directly -- confirm against the repository layout.
sys.path.append(os.path.join(os.path.dirname(__file__), ".."))

from model import OnlineZZModel
from utils import setup_seed, real_size_of_codebook
from plot import visualize_attention_grid


# Prompt used when --prompt is not supplied (the value that used to be
# hard-coded inside main()).
DEFAULT_PROMPT = "squeeze the juice squeeze the juice"


def _build_arg_parser():
    """Build the command-line parser for this inspection script.

    Returns:
        An ``ArgumentParser`` exposing ``--adapter`` (required),
        ``--hub-adapter``, ``--prompt`` (defaults to ``DEFAULT_PROMPT``) and
        ``--extra-vocab-size``.
    """
    parser = ArgumentParser(
        description="Inspect hypertoken embeddings and visualize attention maps."
    )
    parser.add_argument(
        "--adapter", type=str, required=True, help="Path to model adapter."
    )
    parser.add_argument(
        "--hub-adapter", type=str, required=False, help="Optional hub adapter path."
    )
    parser.add_argument(
        "--prompt",
        type=str,
        required=False,
        default=DEFAULT_PROMPT,
        help="Input prompt text (default: %(default)s).",
    )
    parser.add_argument(
        "--extra-vocab-size", type=int, default=None, help="Size of extra vocabulary."
    )
    return parser


def main():
    """Inspect hypertoken embeddings and visualize attention weights.

    Steps, in order: seed the RNGs, parse the command line, load the model,
    switch the embedding encoder's attention implementation to ``"debug"``,
    tokenize the prompt, LZW-compress the token ids, compute the codebook and
    hypertoken embeddings, print a few statistics, and finally hand the
    attention weights/masks of the non-empty codebook entries to
    ``visualize_attention_grid``.

    The prompt comes from ``--prompt``; when the flag is omitted the script
    falls back to ``DEFAULT_PROMPT``.
    """
    setup_seed()

    args = _build_arg_parser().parse_args()

    # Load model
    model = OnlineZZModel.load_pretrained(
        args.adapter, args.hub_adapter, args.extra_vocab_size
    )

    # Enable debug mode for attention weights
    model.config.embedding_encoder.unsafe_config["attn_implementation"] = "debug"

    # Print vocabulary size
    print(f"Base Vocabulary Size: {model.config.initial_vocab_size}")
    print(f"Extra Vocabulary Size: {model.config.extra_vocab_size}")

    # Tokenize the prompt as a batch containing a single sequence.
    input_text = args.prompt
    input_ids = model.tokenizer([input_text])["input_ids"]
    print("Input IDs:", input_ids)

    # Nothing here is back-propagated, so skip building the autograd graph.
    with torch.no_grad():
        # Compress input using LZW
        lzw_input_ids, codebook_tensor = model.lzw_compress(input_ids)

        # `.item()` only succeeds on a single-element result; int() makes sure
        # the value is usable as a slice bound further down.
        real_size = int(real_size_of_codebook(codebook_tensor).item())

        # Compute embeddings
        # NOTE(review): `base_token_embeddings` is not used afterwards. The
        # call is kept because it is unclear from this file whether it has
        # side effects -- confirm and drop it if it is pure.
        base_token_embeddings = model.compute_codebook_embeddings(
            codebook_tensor
        )  # (B, V_E, M, D)
        all_hypertoken_embeddings, metadata = model.compute_all_hypertoken_embeddings(
            codebook_tensor
        )  # (B, V_E, D)

    # Print LZW token statistics (shape notes carried over from the original
    # annotations).
    print(f"LZW Input Shape: {lzw_input_ids.shape}")  # (B, L)
    print(f"Codebook Tensor Shape: {codebook_tensor.shape}")  # (B, V_E, M)
    print(f"Number of Hyper Tokens: {real_size}")

    print(
        "First Hyper Token Embedding:", all_hypertoken_embeddings[0, 0, :]
    )  # shape (D,)

    # Keep only the first `real_size` entries along the leading axis, i.e. the
    # non-empty codebook slots.
    # NOTE(review): originally annotated as (B, S, S); the leading axis is
    # sliced by the hypertoken count, so confirm its actual meaning.
    non_empty_attn_tensor = metadata["attn_weight"][:real_size, ...]
    non_empty_attn_mask = metadata["attn_mask"][:real_size, ...]

    # Visualize attention
    visualize_attention_grid(
        attn_tensor=non_empty_attn_tensor, attn_mask=non_empty_attn_mask
    )


# Run the inspection only when this file is executed as a script, not when it
# is imported.
if __name__ == "__main__":
    main()
