#!/usr/bin/env python
"""
quick_hf_forward.py

Usage:
    python quick_hf_forward.py gpt2
    python quick_hf_forward.py meta-llama/Meta-Llama-3-8B --device cuda
"""

import argparse
import torch
import sys
from transformers import AutoProcessor, AutoModelForMaskedLM, AutoModelForVision2Seq, AutoModelForCausalLM
from transformers.image_utils import load_image
from huggingface_hub import login  # Optional; only needed for gated/private repos.

# ------------------- CONFIG ---------------------- #
# Image that is passed to the processor together with the prompt.
IMAGE_PATH: str = "./assets/images/rococo.jpg"

# Text instruction sent alongside the image. Alternatives kept for quick swapping:
#   "What is the word in the image. Answer only the word"
#   "Complete the sentence as OCR system."
PROMPT: str = "Describe this image in one word."

# -------------------------------------------------- #

# Default compute device: the GPU when torch can see one, otherwise the CPU.
DEVICE: str = "cpu" if not torch.cuda.is_available() else "cuda"

def _load_model_and_processor(model_id: str, device: str):
    """Load the masked-LM model (float32, eval mode, on *device*) and its processor."""
    print("🔄  Loading model …", file=sys.stderr)
    # Load without ``device_map="auto"`` and move the model explicitly below:
    # combining accelerate dispatch with a manual ``.to()`` conflicts whenever
    # the model gets spread over more than one device.
    model = AutoModelForMaskedLM.from_pretrained(
        model_id,
        trust_remote_code=True,
    )
    # The repo is already trusted for the model; use the same setting for its
    # processor so both halves load consistently.
    processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
    model.to(device=device, dtype=torch.float32)
    model.eval()
    print(f"✅  Model loaded: {model_id}", file=sys.stderr)
    print(f"⚙️  Device: {model.device}")
    return model, processor


def _build_prompt(processor) -> str:
    """Return the chat-formatted prompt, ending with the mask token when the tokenizer has one."""
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": PROMPT},
            ],
        },
    ]
    prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
    prompt += " This is "
    mask_token = processor.tokenizer.mask_token
    if mask_token is not None:
        prompt += f"{mask_token}."
    return prompt


def _select_top_k(logits, input_ids, attention_mask, pooling: str, k: int, mask_token_id):
    """Return ``(token_ids, scores)`` of the *k* best vocabulary entries for sample 0 under *pooling*.

    Raises:
        ValueError: unknown *pooling*, or ``"mask"`` pooling with no mask token in *input_ids*.
    """
    if pooling == "mask":
        positions = (input_ids[0] == mask_token_id).nonzero(as_tuple=True)[0]
        if positions.numel() == 0:
            raise ValueError("Mask token not found in the tokenized input.")
        # First occurrence, matching ``list.index`` semantics.
        scores = logits[0, positions[0].item(), :]
    elif pooling == "last":
        scores = logits[0, -1, :]
    elif pooling == "cls":
        scores = logits[0, 0, :]
    elif pooling == "mean":
        if attention_mask is None:
            scores = logits[0].mean(dim=0)
        else:
            # Average only over non-padding positions.
            weights = attention_mask.unsqueeze(-1).to(logits.dtype)
            scores = ((logits * weights).sum(dim=1) / weights.sum(dim=1))[0]
    else:
        raise ValueError(f"Unknown pooling method: {pooling}")
    top = scores.topk(k=k)
    return top.indices, top.values


def main(model_id: str, device: "str | None" = None):
    """Load *model_id*, run one image+text forward pass, and print the top-10 predicted tokens.

    Args:
        model_id: Hugging Face Hub identifier (or local path) of the model.
        device: torch device string to run on; defaults to the module-level ``DEVICE``.
    """
    device = device or DEVICE

    # 1️⃣  Download & load processor + model
    model, processor = _load_model_and_processor(model_id, device)
    tokenizer = processor.tokenizer

    # 2️⃣  Prepare a minimal prompt
    image = load_image(IMAGE_PATH)
    prompt = _build_prompt(processor)
    print("=" * 50, "Prompt", "=" * 50)
    print(f"\n{prompt}\n")
    print("=" * 120)

    inputs = processor(text=prompt, images=[image], return_tensors="pt").to(device)

    print(f"Image size: {image.size}")
    image_tokid = tokenizer.convert_tokens_to_ids("<image>")
    if image_tokid is None or image_tokid == tokenizer.unk_token_id:
        # Without a dedicated <image> token the comparison below would count
        # unknown tokens instead, which is misleading.
        print("Tokenizer has no <image> token; cannot count image tokens.")
    else:
        num_image_tokens = inputs.input_ids.eq(image_tokid).sum().item()
        print(f"Number of <image> tokens in the input: {num_image_tokens}")

    # 3️⃣  Single forward pass (no gradient computation needed)
    with torch.no_grad():
        outputs = model(**inputs)

    # 4️⃣  Verify everything worked
    k = 10
    # Read the prediction at the mask slot when one was appended to the prompt.
    pooling = "mask" if tokenizer.mask_token is not None else "last"
    print("\n✅ Forward pass succeeded!")
    print(f"Logits tensor shape: {tuple(outputs.logits.shape)}")
    top_ids, top_scores = _select_top_k(
        outputs.logits,
        inputs.input_ids,
        inputs.get("attention_mask"),
        pooling,
        k,
        tokenizer.mask_token_id,
    )
    print(f"\nTop {k} tokens (pooling={pooling}):")
    for rank, (token_id, score) in enumerate(zip(top_ids.tolist(), top_scores.tolist()), start=1):
        token_str = processor.decode([token_id], skip_special_tokens=True)
        print(f"{rank:2d}. {token_str} (logit: {score:.4f})")
    print("\nDone! 🎉")

if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Quick test: load an HF model and do one forward pass."
    )
    parser.add_argument(
        "model_id",
        help="Model identifier on the Hub (e.g. 'gpt2', "
             "'EleutherAI/gpt-j-6B', 'meta-llama/Meta-Llama-3-8B')."
    )
    parser.add_argument(
        "--device",
        default="cuda" if torch.cuda.is_available() else "cpu",
        help="Device to run on: 'cpu', 'cuda', or a specific GPU id (e.g. 'cuda:1')."
    )
    parser.add_argument(
        "--hf_token",
        help="Optional HF access token for gated/private models."
    )
    args = parser.parse_args()

    if args.hf_token:
        login(token=args.hf_token)

    main(args.model_id)
