#!/usr/bin/env python3
"""
Incremental, resumable inference utility that captions and tags images from a Hugging Face dataset with a Gemini Flash model via Vertex AI.
"""

import argparse
import inspect
import json
import os
from pathlib import Path
from typing import Any, Dict, List

from datasets import load_dataset
from google import genai
from tqdm import tqdm

from utils import load_jsonl_mapping, append_jsonl, run_annotation

# ────────────────────────────────────────────────────────────────────────────────
# Argument parsing
# ────────────────────────────────────────────────────────────────────────────────

def parse_args(argv: "list[str] | None" = None) -> argparse.Namespace:
    """Parse the command-line options for this script.

    Args:
        argv: Argument list to parse, without the program name. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, exactly as the
            previous zero-argument form did; passing an explicit list allows
            programmatic invocation and testing.

    Returns:
        Namespace with ``dataset``, ``split``, ``output``, ``model``,
        ``temperature`` and ``max_output_tokens``.

    Raises:
        SystemExit: raised by argparse on invalid arguments or ``--help``.
    """
    parser = argparse.ArgumentParser(
        description="Run multi-modal inference with Gemini Flash on a HF dataset (Vertex AI version)",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    parser.add_argument(
        "dataset",
        help="HuggingFace dataset path or identifier, e.g. 'imagenet-1k' or './my_dataset'",
    )
    parser.add_argument(
        "--split", default="train", help="Dataset split to process (train/val/test)"
    )
    parser.add_argument(
        "--output",
        default="./results",
        help="Destination directory for the model outputs (one JSONL file per model).",
    )
    parser.add_argument(
        "--model",
        default="gemini-2.5-flash-lite-preview-06-17",
        help="Gemini model name deployed on Vertex AI. For example 'gemini-2.5-flash'.",
    )
    parser.add_argument(
        "--temperature", type=float, default=0.4, help="Sampling temperature for the model."
    )
    parser.add_argument(
        "--max-output-tokens",
        type=int,
        default=2048,
        help="Maximum tokens to generate per request.",
    )
    # argparse treats None as "use sys.argv[1:]", preserving the original behaviour.
    return parser.parse_args(argv)

# ────────────────────────────────────────────────────────────────────────────────
# Main entrypoint
# ────────────────────────────────────────────────────────────────────────────────

def _resolve_output_path(output_dir: str, dataset: str, model: str) -> Path:
    """Return ``<output_dir>/<dataset basename>/annotations_<model>.jsonl``."""
    return Path(output_dir) / Path(dataset).name / f"annotations_{model}.jsonl"


def _load_processed_ids(output_path: Path) -> set[int]:
    """Return the sample indices already recorded in *output_path* (empty if it is absent)."""
    if not output_path.exists():
        return set()
    mapping = load_jsonl_mapping(output_path, key="id", value="label")
    # Keys may come back as strings from JSON; normalise to int to match enumerate().
    return {int(k) for k in mapping}


def _load_task_info(dataset: str, metadata_file: str = "task_metadata.jsonl") -> Any:
    """Look up optional task metadata for *dataset*; ``None`` when the file or entry is missing."""
    # NOTE(review): the path is resolved relative to the current working directory.
    if not Path(metadata_file).exists():
        return None
    return load_jsonl_mapping(metadata_file).get(dataset)


def _generation_kwargs(temperature: float, max_output_tokens: int) -> dict[str, Any]:
    """Return the sampling options that ``run_annotation`` explicitly accepts.

    ``--temperature`` and ``--max-output-tokens`` used to be parsed but never
    forwarded. ``run_annotation`` is defined in ``utils`` and its signature is
    not visible here, so only parameters it declares by name (and accepts as
    keywords) are passed. Any option it does not declare is reported rather
    than risking a ``TypeError`` in the middle of a long run.
    """
    requested = {"temperature": temperature, "max_output_tokens": max_output_tokens}
    try:
        params = inspect.signature(run_annotation).parameters
    except (TypeError, ValueError):
        # Signature cannot be introspected: forward nothing rather than guess.
        params = {}

    keyword_kinds = (
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
        inspect.Parameter.KEYWORD_ONLY,
    )
    accepted = {
        name: value
        for name, value in requested.items()
        if name in params and params[name].kind in keyword_kinds
    }

    ignored = [f"--{name.replace('_', '-')}" for name in requested if name not in accepted]
    if ignored:
        print(
            f"⚠️ run_annotation() does not accept {', '.join(ignored)}; "
            "these options are ignored."
        )
    return accepted


def main() -> None:
    """Annotate every sample of the requested split, appending results to a JSONL file.

    Each annotated sample is appended as soon as it is produced. Re-running with
    the same ``--output``/``dataset``/``--model`` therefore resumes: indices already
    present in the output file are skipped.
    """
    args = parse_args()

    # Resolve output path early so we can determine already processed IDs.
    output_path = _resolve_output_path(args.output, args.dataset, args.model)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    # Detect previously processed samples for resume capability.
    processed_ids = _load_processed_ids(output_path)
    if processed_ids:
        print(f"🔄 Resuming run – {len(processed_ids)} samples already processed.")

    # Init Gemini client.
    # NOTE(review): genai.Client() takes no explicit backend arguments here; it
    # presumably selects Vertex AI from environment configuration
    # (e.g. GOOGLE_GENAI_USE_VERTEXAI / project / location) — confirm.
    print("Initializing Vertex AI client …")
    client = genai.Client()
    generation_kwargs = _generation_kwargs(args.temperature, args.max_output_tokens)

    # Load dataset – works for both hosted and local datasets. Streaming avoids
    # materialising the whole split; its length is unknown, hence total=None below.
    print(f"Loading dataset '{args.dataset}' ({args.split}) …")
    ds = load_dataset(args.dataset, split=args.split, streaming=True)

    task_info = _load_task_info(args.dataset)
    if task_info:
        print(f"ℹ️ Found task scope: {task_info}")
    else:
        print("ℹ️ No task info found for this dataset, using default prompt.")

    # Process samples – incremental save + resume. Resume relies on the stream
    # yielding samples in the same order on every run (no shuffling is applied).
    with tqdm(total=None, desc="Processing samples", unit="sample") as pbar:
        for idx, sample in enumerate(ds):
            if idx in processed_ids:
                pbar.update()
                continue  # Skip already processed items

            output = run_annotation(
                client,
                sample,
                task_info=task_info,
                model=args.model,
                **generation_kwargs,
            )
            output.update({"id": idx, "label": sample.get("label")})

            # Append immediately so we can recover if interrupted.
            append_jsonl(output_path, output)

            pbar.update()

    print(f"✅ Finished. All predictions are stored at: {output_path}\n")


# Run the CLI only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
