"""
verify_prismatic.py

Given an HF-exported Prismatic model, attempt to load via AutoClasses, and verify forward() and generate().
"""

import time

import requests
import torch
from PIL import Image
from transformers import AutoModelForVision2Seq, AutoProcessor

# === Verification Arguments ===
MODEL_PATH = "TRI-ML/prismatic-siglip-224px-7b"
DEFAULT_IMAGE_URL = (
    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/beignets-task-guide.png"
)

if "-prism-" in MODEL_PATH:
    SAMPLE_PROMPTS_FOR_GENERATION = [
        "In: What is sitting in the coffee?\nOut:",
        "In: What's the name of the food on the plate?\nOut:",
        "In: caption.\nOut:",
        "In: how many beinets..?\nOut:",
        "In: Can you give me a lyrical description of the scene\nOut:",
    ]
else:
    SYSTEM_PROMPT = (
        "A chat between a curious user and an artificial intelligence assistant. "
        "The assistant gives helpful, detailed, and polite answers to the user's questions."
    )
    SAMPLE_PROMPTS_FOR_GENERATION = [
        f"{SYSTEM_PROMPT} USER: What is sitting in the coffee? ASSISTANT:",
        f"{SYSTEM_PROMPT} USER: What's the name of the food on the plate? ASSISTANT:",
        f"{SYSTEM_PROMPT} USER: caption. ASSISTANT:",
        f"{SYSTEM_PROMPT} USER: how many beinets..? ASSISTANT:",
        f"{SYSTEM_PROMPT} USER: Can you give me a lyrical description of the scene ASSISTANT:",
    ]


@torch.inference_mode()
def verify_prismatic() -> None:
    print(f"[*] Verifying PrismaticForConditionalGeneration using Model `{MODEL_PATH}`")
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    # Load Processor & VLM
    print("[*] Instantiating Processor and Pretrained VLM")
    processor = AutoProcessor.from_pretrained(MODEL_PATH, trust_remote_code=True)

    # === AUTOCAST MODE ===
    # print("[*] Loading in BF16 Autocast Mode")
    # vlm = AutoModelForVision2Seq.from_pretrained(MODEL_PATH, low_cpu_mem_usage=True, trust_remote_code=True).to(
    #     device, dtype=torch.bfloat16
    # )

    # === NATIVE BFLOAT16 MODE ===
    # print("[*] Loading in BF16")
    # vlm = AutoModelForVision2Seq.from_pretrained(
    #     MODEL_PATH, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True, trust_remote_code=True
    # ).to(device)

    # === BFLOAT16 + FLASH-ATTN MODE :: [~14GB of VRAM Passive || 18GB of VRAM Active] ===
    print("[*] Loading in BF16 with Flash-Attention Enabled")
    vlm = AutoModelForVision2Seq.from_pretrained(
        MODEL_PATH,
        attn_implementation="flash_attention_2",
        torch_dtype=torch.bfloat16,
        low_cpu_mem_usage=True,
        trust_remote_code=True,
    ).to(device)

    # === 8-BIT QUANTIZATION MODE (`pip install bitsandbytes`) :: [~9GB of VRAM Passive || 10GB of VRAM Active] ===
    # print("[*] Loading in 8-Bit Quantization Mode")
    # vlm = AutoModelForVision2Seq.from_pretrained(
    #     MODEL_PATH,
    #     attn_implementation="flash_attention_2",
    #     torch_dtype=torch.float16,
    #     quantization_config=BitsAndBytesConfig(load_in_8bit=True),
    #     low_cpu_mem_usage=True,
    #     trust_remote_code=True,
    # )

    # === 4-BIT QUANTIZATION MODE (`pip install bitsandbytes`) :: [~6GB of VRAM Passive || 7GB of VRAM Active] ===
    # print("[*] Loading in 4-Bit Quantization Mode")
    # vlm = AutoModelForVision2Seq.from_pretrained(
    #     MODEL_PATH,
    #     attn_implementation="flash_attention_2",
    #     torch_dtype=torch.float16,
    #     quantization_config=BitsAndBytesConfig(load_in_4bit=True),
    #     low_cpu_mem_usage=True,
    #     trust_remote_code=True,
    # )

    # Iterate over Sample Prompts =>> Generate
    image = Image.open(requests.get(DEFAULT_IMAGE_URL, stream=True).raw).convert("RGB")
    num_tokens, total_time = 0, 0.0

    print("[*] Iterating over Sample Prompts\n===\n")
    for idx, prompt in enumerate(SAMPLE_PROMPTS_FOR_GENERATION):
        # === AUTOCAST MODE (Reproduces Prismatic `scripts/generate.py`) ===
        # inputs = processor(prompt, image).to(device)
        #
        # # Using "autocast" to evaluate bit-wise equivalence to `scripts/generate.py`
        # #   =>> Running in native BF16 is also fine (but leads to slightly different generations)
        # with torch.autocast("cuda", dtype=torch.bfloat16, enabled=True):
        #     gen_ids = vlm.generate(**inputs, do_sample=False, min_length=1, max_length=512)

        # === BFLOAT16 MODE ===
        inputs = processor(prompt, image).to(device, dtype=torch.bfloat16)

        # === 8-BIT/4-BIT QUANTIZATION MODE ===
        # inputs = processor(prompt, image).to(device, dtype=torch.float16)

        # Run Inference
        gen_ids = None
        for _ in range(5):
            start_time = time.time()
            gen_ids = vlm.generate(**inputs, do_sample=False, min_length=1, max_length=512)
            total_time += time.time() - start_time

            gen_ids = gen_ids[0, inputs.input_ids.shape[1] :]
            num_tokens += len(gen_ids)

        # ===
        gen_text = processor.decode(gen_ids, skip_special_tokens=True).strip()
        print(f"[{idx + 1}] Input Prompt => {prompt}\n    Generated    => {gen_text}\n")

    # Compute Tokens / Second
    print(f"[*] Generated Tokens per Second = {num_tokens / total_time} w/ {num_tokens = } and {total_time = }")


if __name__ == "__main__":
    verify_prismatic()
