#!/usr/bin/env python3
"""
Inference utility for Gemini Flash model on image captioning + tagging **via Vertex AI**.

Changes vs. the original version
────────────────────────────────
1. New CLI flag `--batch-size` (default 8) – limits the number of concurrent
   API calls.
2. The request/response path is now fully asynchronous.  We use
   `model.generate_content_async(…)` and an `asyncio.Semaphore` so that at
   most `batch_size` requests are in-flight at any time.
3. `main()` is now an `async def` executed via `asyncio.run(main())`.

Dependencies remain the same; just make sure the installed `google-generativeai`
package (imported below as `genai`; `google-cloud-aiplatform` is not imported by
this script) is recent enough to expose `GenerativeModel.generate_content_async`.
"""

import os
import argparse
import asyncio
import json
from pathlib import Path
from typing import Any, Dict, List

from datasets import load_dataset
from PIL import Image           # noqa: F401  (kept for completeness)
from tqdm.asyncio import tqdm   # async-aware progress bar
from pydantic import BaseModel, Field

# ───────────────────────────────────────────────────────────────────────────────
# Vertex AI / Gemini
# ───────────────────────────────────────────────────────────────────────────────
import google.generativeai as genai
# Authenticate the SDK at import time with the key read from the
# GOOGLE_API_KEY environment variable (os.getenv yields None when it is unset).
# NOTE(review): this API-key setup belongs to the google-generativeai (Gemini
# API) client, while the module docstring says "via Vertex AI" — confirm which
# backend is intended.
genai.configure(api_key=os.getenv("GOOGLE_API_KEY"))

from utils import save_jsonl, load_jsonl_mapping, strip_unwanted_keys

# Prompt templates, read once at import time from paths relative to the current
# working directory. run_inference_async fills them with str.format():
# DEFAULT_PROMPT gets `label`; TASK_AWARE_PROMPT gets `task_info` and `label`.
# Decode explicitly as UTF-8 so the result does not depend on the platform's
# locale encoding (e.g. cp1252 on Windows).
DEFAULT_PROMPT = Path("prompts/default_prompt.txt").read_text(encoding="utf-8")
TASK_AWARE_PROMPT = Path("prompts/task_aware_prompt.txt").read_text(encoding="utf-8")


class CaptioningOutput(BaseModel):
    # Structured answer requested from Gemini: run_inference_async sends this
    # model's JSON schema (passed through strip_unwanted_keys) as the request's
    # `response_schema`. Every field is required (`Field(...)`).
    # NOTE: documented with comments rather than a class docstring on purpose —
    # pydantic copies a class docstring into the schema's "description", which
    # would change the schema sent to the model.
    caption: str = Field(..., description="Descriptive caption produced by Gemini")
    class_tags: List[str] = Field(
        ..., description="Tags helpful to identify the class"
    )
    other_tags: List[str] = Field(
        ..., description="Tags that set it apart from others in the same class"
    )
    is_image_class_explicit: bool = Field(
        ..., description="Whether the image class is understandable by the model"
    )


# ───────────────────────────────────────────────────────────────────────────────
# Args
# ───────────────────────────────────────────────────────────────────────────────
def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(
        description="Run multi-modal inference with Gemini Flash on a HF dataset (Vertex AI version)",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    parser.add_argument("dataset", help="HF path or identifier, e.g. 'imagenet-1k' or './my_dataset'")
    parser.add_argument("--split", default="train", help="Dataset split to process (train/val/test)")
    parser.add_argument(
        "--output",
        default="./results",
        help="Destination directory; a JSONL file is created inside.",
    )
    parser.add_argument(
        "--model",
        default="gemini-2.5-flash-lite-preview-06-17",
        help="Gemini model name deployed on Vertex AI.",
    )
    parser.add_argument("--temperature", type=float, default=0.4, help="Sampling temperature")
    parser.add_argument("--max-output-tokens", type=int, default=2048, help="Max tokens to generate")
    parser.add_argument("--batch-size", type=int, default=8, help="Max number of concurrent requests")
    return parser.parse_args()


# ───────────────────────────────────────────────────────────────────────────────
# Async helpers
# ───────────────────────────────────────────────────────────────────────────────
def _parse_model_output(response: Any) -> Dict[str, Any]:
    """Decode a Gemini response into a JSON object (dict) without raising.

    Args:
        response: Object returned by ``generate_content_async``; only its
            ``.text`` accessor is used.

    Returns:
        The decoded JSON object, or a dict with an ``_error`` message (plus the
        offending ``_raw`` text when there was any) if no object could be
        extracted.
    """
    try:
        text: str = response.text.strip()
    except (AttributeError, ValueError) as err:
        # google-generativeai's `.text` accessor raises ValueError when the
        # response carries no text part (e.g. a blocked candidate); treat that
        # as a per-sample failure rather than crashing the whole batch.
        return {"_error": f"No text in response: {err}"}

    if text.startswith("```"):
        # Drop Markdown fence lines (```json ... ```) wrapped around the payload.
        text = "\n".join(line for line in text.splitlines() if not line.startswith("```"))

    try:
        output = json.loads(text)
    except json.JSONDecodeError as err:
        return {"_error": f"JSON parse failed: {err}", "_raw": text}

    if not isinstance(output, dict):
        # The caller adds "id"/"label" keys to the result, so it must be a dict.
        return {"_error": f"Expected a JSON object, got {type(output).__name__}", "_raw": text}
    return output


async def run_inference_async(
    model: genai.GenerativeModel,
    semaphore: asyncio.Semaphore,
    sample: Dict[str, Any],
    task_info: str | None,
) -> Dict[str, Any]:
    """Caption and tag one dataset sample with Gemini.

    Args:
        model: Gemini model used for the request.
        semaphore: Shared limiter; it is held only while the API call is in
            flight, bounding the number of concurrent requests.
        sample: Dataset row; ``sample["label"]`` fills the prompt template and
            ``sample["image"]`` is sent alongside the prompt.
        task_info: Optional task description; when truthy the task-aware
            prompt is used instead of the default one.

    Returns:
        The model's JSON answer as a dict, or a dict containing an ``_error``
        key if the request failed or the answer could not be decoded.

    Raises:
        KeyError: if ``sample`` lacks ``"label"`` or ``"image"``.
    """
    if task_info:
        prompt = TASK_AWARE_PROMPT.format(task_info=task_info, label=sample["label"])
    else:
        prompt = DEFAULT_PROMPT.format(label=sample["label"])
    contents = [prompt, sample["image"]]
    request_config = {
        "response_mime_type": "application/json",
        "response_schema": strip_unwanted_keys(CaptioningOutput.model_json_schema()),
    }

    # Only the network call needs the semaphore; prompt building and response
    # parsing run outside it so they never hold a concurrency slot.
    async with semaphore:
        try:
            response = await model.generate_content_async(contents, generation_config=request_config)
        except Exception as err:  # per-sample boundary: record the failure, keep the batch going
            return {"_error": f"Request failed: {type(err).__name__}: {err}"}

    return _parse_model_output(response)


# ───────────────────────────────────────────────────────────────────────────────
# Main
# ───────────────────────────────────────────────────────────────────────────────
async def main() -> None:
    """Caption every sample of the requested dataset split and save a JSONL file.

    Concurrency and memory are both bounded. At most ``--batch-size`` API
    calls are in flight, enforced by the semaphore handed to
    ``run_inference_async``. At most ``2 * --batch-size`` samples are
    scheduled at any time, so a streaming dataset is never fully loaded into
    memory.

    Each output record carries the sample's positional ``id`` and its
    ``label``. A failure on one sample is stored as an ``_error`` record
    instead of aborting the run. Records are written in ``id`` order.
    """
    args = parse_args()

    # Model-level generation defaults from the CLI. The SDK merges the
    # per-request config (response MIME type / schema) on top of these.
    model = genai.GenerativeModel(
        args.model,
        generation_config={
            "temperature": args.temperature,
            "max_output_tokens": args.max_output_tokens,
        },
    )

    # Load the dataset in streaming mode; the bounded scheduling below keeps
    # only a small window of samples in memory at once.
    print(f"Loading dataset '{args.dataset}' ({args.split}) …")
    ds = load_dataset(args.dataset, split=args.split, streaming=True).take(16)  # TODO: REMOVE AFTER TESTING

    task_info = load_jsonl_mapping("task_metadata.jsonl").get(args.dataset)
    if task_info:
        print(f"ℹ️ Found task scope: {task_info}")
    else:
        print("ℹ️ No task info found, using default prompt.")

    semaphore = asyncio.Semaphore(args.batch_size)
    # Cap on scheduled-but-unfinished samples (in flight + waiting on the semaphore).
    max_pending = 2 * args.batch_size

    async def annotate(idx: int, sample: Dict[str, Any]) -> Dict[str, Any]:
        """Run one request and tag its result with the sample id and label."""
        try:
            output = await run_inference_async(
                model=model,
                semaphore=semaphore,
                sample=sample,
                task_info=task_info,
            )
        except Exception as err:  # one bad sample must not abort the whole run
            output = {"_error": f"{type(err).__name__}: {err}"}
        if not isinstance(output, dict):
            output = {"_error": f"Expected a JSON object, got {type(output).__name__}", "_raw": output}
        output.update({"id": idx, "label": sample.get("label")})
        return output

    results: List[Dict[str, Any]] = []

    async def reap(tasks: set[asyncio.Task], progress: Any) -> set[asyncio.Task]:
        """Wait until at least one task finishes, store its result, return the rest."""
        done, still_running = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
        for finished in done:
            results.append(finished.result())  # annotate() never raises on request errors
        progress.update(len(done))
        return still_running

    pending: set[asyncio.Task] = set()
    print("Processing samples …")
    with tqdm(desc="Processing samples", unit="sample") as progress:
        for idx, sample in enumerate(ds):
            pending.add(asyncio.create_task(annotate(idx, sample)))
            if len(pending) >= max_pending:
                pending = await reap(pending, progress)
        while pending:
            pending = await reap(pending, progress)

    # Completion order is nondeterministic; write records in dataset order.
    results.sort(key=lambda record: record["id"])

    # Model names can contain "/" (e.g. "models/gemini-…"); keep the file name flat.
    model_tag = args.model.replace("/", "_")
    output_path = (
        Path(args.output) / Path(args.dataset).name / f"annotations_{model_tag}.jsonl"
    )
    output_path.parent.mkdir(parents=True, exist_ok=True)
    save_jsonl(output_path, results)
    print(f"✓ Saved all predictions to {output_path}")


if __name__ == "__main__":
    # main() is a coroutine function; asyncio.run creates an event loop,
    # runs it to completion, and closes the loop.
    asyncio.run(main())
