#!/usr/bin/env python
"""
quick_hf_forward.py

Usage:
    python quick_hf_forward.py gpt2
    python quick_hf_forward.py meta-llama/Meta-Llama-3-8B --device cuda
"""

import argparse
import torch
from transformers import AutoTokenizer, AutoModel, AutoModelForMaskedLM
from transformers.image_utils import load_image
from huggingface_hub import login  # Optional; only needed for gated/private repos.

# Cloze-style sentences, each with a literal "[MASK]" placeholder.
# main() swaps "[MASK]" for the loaded tokenizer's own mask token before
# encoding, so the same sentence works whatever mask string the model uses.
POOL_EXAMPLES = [
    "The [MASK] barks loudly.",
    "The capital of France is [MASK].",
    "The [MASK] is the largest mammal on Earth.",
    "The cat is [MASK] on the mat.",
]

def main(model_id: str, device: str, prompt_index: int = 1, k: int = 10) -> None:
    """Run one masked-LM forward pass and print the top-``k`` fills for the mask.

    The prompt is ``POOL_EXAMPLES[prompt_index]``. Its ``[MASK]`` placeholder
    is replaced by the tokenizer's own mask token.

    Args:
        model_id: Hub identifier (or local path) of a masked-language model,
            e.g. ``"bert-base-uncased"``.
        device: Torch device string such as ``"cpu"``, ``"cuda"`` or ``"cuda:1"``.
        prompt_index: Index into ``POOL_EXAMPLES``. Defaults to 1, the prompt
            this script always used.
        k: How many candidate tokens to print (clamped to the vocabulary size).

    Raises:
        ValueError: If the tokenizer defines no mask token (e.g. causal LMs such
            as GPT-2), or if the mask token is absent from the encoded prompt.
    """
    # 1️⃣  Load the tokenizer first. It is cheap and tells us whether this
    # checkpoint can do masked-LM at all, before the weights are downloaded.
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    if tokenizer.mask_token is None:
        raise ValueError(
            f"Tokenizer for {model_id!r} defines no mask token; this script needs "
            "a masked-language model such as 'bert-base-uncased' or 'roberta-base'."
        )

    model = AutoModelForMaskedLM.from_pretrained(
        model_id,
        torch_dtype=torch.float32,
        # NOTE(review): trust_remote_code executes Python code shipped with the
        # repo. It stays enabled so community models with custom code load;
        # only point this script at repositories you trust.
        trust_remote_code=True,
    ).to(device)
    print(f"Model loaded: {model_id}\n")

    # 2️⃣  Build and encode the prompt using the model-specific mask token.
    text = POOL_EXAMPLES[prompt_index].replace("[MASK]", tokenizer.mask_token)
    inputs = tokenizer(text, return_tensors="pt").to(device)
    input_ids = inputs.input_ids[0]

    print("=" * 20, "Tokenized sequence", "=" * 20)
    print(f"\n{tokenizer.decode(input_ids, skip_special_tokens=False)}\n")
    print("=" * 60)

    # Locate the (first) mask position before the forward pass, so a
    # tokenization problem fails fast with a clear message.
    token_ids = input_ids.tolist()
    if tokenizer.mask_token_id not in token_ids:
        raise ValueError(
            f"Mask token {tokenizer.mask_token!r} not found in encoded prompt {text!r}."
        )
    masked_index = token_ids.index(tokenizer.mask_token_id)

    # 3️⃣  Single forward pass. This is inference only, so autograd is disabled.
    with torch.no_grad():
        outputs = model(**inputs)

    # 4️⃣  Report the logits shape and the highest-scoring fills for the mask slot.
    print("\n✅ Forward pass succeeded!")
    print(f"Logits tensor shape: {tuple(outputs.logits.shape)}")
    masked_str = tokenizer.decode(token_ids[masked_index])
    print(f"Masked token index: {masked_index} (token: {masked_str})")

    # Batch size is 1, so row 0 holds the scores over the vocabulary at the mask.
    mask_logits = outputs.logits[0, masked_index]
    k = min(k, mask_logits.shape[-1])
    top = mask_logits.topk(k=k)
    print(f"\nTop {k} predictions:")
    for rank, (score, token_id) in enumerate(
        zip(top.values.tolist(), top.indices.tolist()), start=1
    ):
        print(f"{rank}. {tokenizer.decode(token_id)} (logit: {score:.4f})")
    print("\nDone!")

if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Quick test: load an HF model and do one forward pass."
    )
    parser.add_argument(
        "model_id",
        help="Model identifier on the Hub (e.g. 'gpt2', "
             "'EleutherAI/gpt-j-6B', 'meta-llama/Meta-Llama-3-8B')."
    )
    parser.add_argument(
        "--device",
        default="cuda" if torch.cuda.is_available() else "cpu",
        help="Device to run on: 'cpu', 'cuda', or a specific GPU id (e.g. 'cuda:1')."
    )
    parser.add_argument(
        "--hf_token",
        help="Optional HF access token for gated/private models."
    )
    args = parser.parse_args()

    if args.hf_token:
        login(token=args.hf_token)

    main(args.model_id, args.device)
