#!/usr/bin/env python3
"""
Create Gemini batch request JSONL shards (<=200 MB each), upload shards to GCS,
submit Vertex batch jobs, and write a tracking file to fetch+merge later.

Example:
python bin/annotation_llm/run_batch_llm_annotation.py \
  --input-jsonl data/qa_items.jsonl \
  --out-dir work/batches_local \
  --batch-tracking-file work/batch_tracking.json \
  --project-id YOUR_GCP_PROJECT \
  --location us-central1 \
  --gcs-bucket your-bucket \
  --gcs-prefix llm-judge-$(date +%Y%m%d-%H%M%S) \
  --google-application-credentials ./google_api.json \
  --model gemini-2.5-pro \
  --temperature 0.0
"""
import argparse
import base64
import json
import logging
import mimetypes
import os
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

# Configure the root logger once at import time: INFO level, timestamped lines.
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)

# -------------------------
# Prompts (unchanged)
# -------------------------
# System instruction attached to every batch request: the judging rubric plus
# the strict JSON output schema. It is passed verbatim (never through
# str.format), so the literal braces in the schema below need no escaping.
SYSTEM_PROMPT = r"""
You are an expert annotation judge for an image-based MCQA dataset. You will evaluate four questions per image:
- 1 open-ended question (with its answer + rationale),
- 1 multiple-choice question (with options, the selected answer + rationale),
- 2 true/false statements (with the selected label + rationale).

Your ONLY sources of truth are: (a) the image and (b) the image description provided. Avoid any speculation beyond what is inferable from these. If a question cannot be answered without outside facts, mark external_knowledge="requires".

Follow these rubrics strictly:

1) Location relevancy (image-level):
   - Decide if the image/questions are related to the specified [LOCATION].
   - Labels: "relevant", "not_relevant", "not_sure".
   - Use "not_sure" sparingly (only when genuinely ambiguous).

2) Question quality (per question, 1–5):
   - Clarity, unambiguity, and relevance to the image.
   - 1 = Poor; 5 = Excellent.
   - If <4, set question_revision_reasons (one or more) from:
     ["unclear_or_ambiguous","not_relevant_to_image","hard_to_understand"].
   - If ≥4, set question_revision_reasons = [].

3) Answer quality (per question):
   - OPEN/MCQ: 1–5 (correctness, completeness, image support).
   - TRUE/FALSE (selected label correctness): 1–3 (correctness, image support).
   - If <4 for OPEN/MCQ or <2 for TRUE/FALSE, set answer_revision_reasons (one or more):
     ["incorrect_or_unsupported","incomplete_or_missing_info","speculative_or_assumptive","options_overlap","irrelevant_or_implausible_options","vague_or_confusing"].
     Notes:
       • "options_overlap","irrelevant_or_implausible_options","vague_or_confusing" are MCQ-specific (use only when applicable).
       • For T/F, do NOT use the MCQ-only reasons.
   - If the threshold is met, set answer_revision_reasons = [].

4) Rationale quality (per question; evaluate the provided rationale text, not your own):
   - rational_clarity_info: 1–5 (clarity & informativeness).
   - rational_plausibility_faithfulness: 1–5 (plausible, faithful, grounded in the image).
   - If either dimension <4, that does NOT automatically imply answer_revision_reasons or question_revision_reasons; score independently.

5) External knowledge flag (per question):
   - external_knowledge: "requires" or "does_not_require".
   - "requires" if the necessary information is not visible/inferable from the image/description.

Important constraints:
- Be strict. Prefer lower scores if evidence is weak or absent.
- Never add facts not grounded in the image/description.
- For T/F, "question_quality" still uses 1–5 but "answer_quality" uses 1–3.
- Return EXACTLY four items in "questions", preserving the original index order of [QUESTIONS_JSON].
- Do NOT reorder, merge, or drop items. The two "true_false" items must remain two separate entries.
- For every i, set output questions[i].type to EXACTLY the input questions[i].type.
- If the input questions include an "id" field, echo it back in the corresponding output question as "id".
- Use exactly the field names and JSON shape below.
- Respond with VALID JSON ONLY (no explanations, no markdown).

EXPECTED OUTPUT (strict JSON schema):

{
  "location_relevancy": "relevant | not_relevant | not_sure",
  "questions": [
    {
      "type": "open_ended | multiple_choice | true_false",
      "question_text": "string",
      "answer_text": "string or null",
      "options": ["only for MCQ, else []"],
      "selected_answer": "for MCQ/T-F; else null",
      "rationale_text": "string or null",

      "annotation": {
        "question_quality": 1-5,
        "answer_quality": (1-5 for open/mcq; 1-3 for true_false),
        "rational_clarity_info": 1-5,
        "rational_plausibility_faithfulness": 1-5,
        "external_knowledge": "requires | does_not_require",
        "question_revision_reasons": ["unclear_or_ambiguous","not_relevant_to_image","hard_to_understand"] or [],
        "answer_revision_reasons": ["incorrect_or_unsupported","incomplete_or_missing_info","speculative_or_assumptive","options_overlap","irrelevant_or_implausible_options","vague_or_confusing"] or []
      }
    },
    "... three more question objects ..."
  ]
}
"""

# Per-item user message template. build_user_prompt() fills {location},
# {image_description} and {questions_json} through str.format, so any other
# literal brace added to this text would have to be doubled ({{ }}).
USER_PROMPT_TEMPLATE = r"""
You are given an image and its textual description. Evaluate FOUR questions about this image according to the scoring policy.

[LOCATION]: {location}

[IMAGE_URL]: Attached image (base64-encoded).

[IMAGE_DESCRIPTION]:
{image_description}

[QUESTIONS_JSON]:
{questions_json}

Return a "questions" array of exactly four items that aligns INDEX-WISE with [QUESTIONS_JSON].
Do not reorder or collapse items. The two true_false questions remain two separate entries.

Notes:
- Each question dict MUST include:
  - "type": "open_ended" | "multiple_choice" | "true_false"
  - "question_text": string
  - For open_ended:
      "answer_text": string
  - For multiple_choice:
      "options": [list of strings],
      "selected_answer": the chosen option string,
      "rationale_text": string (if provided, else empty string)
  - For true_false:
      "selected_answer": "True" or "False",
      "rationale_text": string (if provided, else empty string)

Return ONLY the JSON specified in the system message under EXPECTED OUTPUT.
"""


# -------------------------
# Helpers
# -------------------------
def _now_iso() -> str:
    return datetime.now(timezone.utc).isoformat()


def guess_mime(path_or_name: str, default: str = "image/jpeg") -> str:
    """
    Guess an image MIME type from a file name or path.

    Returns the type reported by ``mimetypes.guess_type`` when it is an
    ``image/*`` type; otherwise (unknown or non-image) returns ``default``.
    """
    guessed = mimetypes.guess_type(path_or_name)[0]
    is_image = bool(guessed) and guessed.startswith("image/")
    return guessed if is_image else default


def upload_image_to_gcs(
    storage_client,
    bucket_name: str,
    prefix: str,
    *,
    local_path: Optional[str] = None,
    image_b64: Optional[str] = None,
    object_name_hint: str = "image",
) -> str:
    """
    Upload an image to GCS and return its gs:// URI.

    Exactly one source is used: ``local_path`` wins when both are given,
    otherwise ``image_b64`` (raw base64 or a full ``data:`` URL) is decoded
    and uploaded.

    Args:
        storage_client: Client object exposing ``bucket(name)``.
        bucket_name: Destination bucket.
        prefix: Object-name prefix; the image is stored under ``<prefix>/images/``.
        local_path: Path of a local image file to upload.
        image_b64: Base64 payload, optionally wrapped in a ``data:`` URL.
        object_name_hint: Base name (without extension) of the stored object.

    Returns:
        ``gs://<bucket_name>/<prefix>/images/<object_name_hint><ext>``.

    Raises:
        ValueError: If neither ``local_path`` nor ``image_b64`` is provided.
    """
    # Fail early with a clear message instead of an AttributeError on None.
    if not local_path and not image_b64:
        raise ValueError("Either local_path or image_b64 must be provided.")

    bucket = storage_client.bucket(bucket_name)

    # Keep the local file's extension when there is one; otherwise use .jpg.
    ext = (Path(local_path).suffix or ".jpg") if local_path else ".jpg"
    blob_name = f"{prefix}/images/{object_name_hint}{ext}"
    blob = bucket.blob(blob_name)

    if local_path:
        # Set the content type explicitly rather than computing and dropping it.
        blob.upload_from_filename(local_path, content_type=guess_mime(local_path))
    else:
        raw = image_b64
        mime = "image/jpeg"
        if raw.startswith("data:"):
            # data URL -> separate "data:<mime>;base64" header from the payload.
            header, _, payload = raw.partition(",")
            raw = payload or raw
            mime = header.split(";")[0].replace("data:", "") or "image/jpeg"
        # Pass the content type with the upload; no follow-up patch() needed.
        blob.upload_from_string(base64.b64decode(raw), content_type=mime)

    return f"gs://{bucket_name}/{blob_name}"


def image_data_url_from_item(item: Dict[str, Any], mime: str) -> str:
    """
    Return the item's image as a ``data:`` URL.

    Sources are tried in this order:
      1. ``image_path``   - the file is read and base64-encoded.
      2. ``image_base64`` - used as-is if it is already a ``data:`` URL,
         otherwise wrapped as ``data:<mime>;base64,<payload>``.
      3. ``image_url``    - returned unchanged (empty string when absent);
         a warning is logged because no local image data was found.

    Args:
        item: Input record; only the three keys above are read.
        mime: MIME type to declare when this function builds the data URL.
    """
    path = item.get("image_path")
    if path:
        with open(path, "rb") as f:
            encoded = base64.b64encode(f.read()).decode("utf-8")
        return f"data:{mime};base64,{encoded}"

    inline_b64 = item.get("image_base64")
    if inline_b64:
        if inline_b64.startswith("data:"):
            return inline_b64
        return f"data:{mime};base64,{inline_b64}"

    # The original message was a plain string with an unformatted "{path}".
    logging.warning(
        "No image_base64 or image_path found in item %r; falling back to image_url.",
        item.get("image_id"),
    )
    return item.get("image_url", "")


def build_user_prompt(item: Dict[str, Any]) -> str:
    """
    Render USER_PROMPT_TEMPLATE for one input item.

    Reads ``country``, ``image_desc_meta.en_description`` and ``QA_meta`` from
    ``item``. Before serialising ``QA_meta``, the internal metadata keys listed
    below are dropped from every question under ``open_ended``,
    ``multiple_choice`` and ``true_false``.

    The caller's ``item`` is not modified: the filtered questions are built as
    new dicts. Missing or ``None`` ``image_desc_meta`` / ``QA_meta`` are
    treated as empty instead of raising ``AttributeError``.
    """
    hidden_keys = {
        "semantic_focus",
        "cognitive_focus",
        "msa_question",
        "msa_answer",
        "msa_rationale",
    }

    location = item.get("country", "")
    desc_meta = item.get("image_desc_meta") or {}
    image_desc = desc_meta.get("en_description", "")

    qa_meta = item.get("QA_meta") or {}
    if not qa_meta:
        logging.warning("Item %r has no QA_meta; prompt will contain no questions.", item.get("image_id"))

    # Shallow copy keeps key order; only the per-type question lists are rebuilt.
    questions_obj = dict(qa_meta)
    for qtype in ("open_ended", "multiple_choice", "true_false"):
        if qa_meta.get(qtype):
            questions_obj[qtype] = [
                {k: v for k, v in question.items() if k not in hidden_keys}
                for question in qa_meta[qtype]
            ]

    return USER_PROMPT_TEMPLATE.format(
        location=location,
        image_description=image_desc,
        questions_json=json.dumps(questions_obj, ensure_ascii=False),
    )


def default_safety_settings() -> List[Dict[str, str]]:
    """
    Return the default safetySettings list: one entry per harm category,
    each with threshold ``BLOCK_ONLY_HIGH``.
    """
    harm_categories = (
        "HARM_CATEGORY_HATE_SPEECH",
        "HARM_CATEGORY_DANGEROUS_CONTENT",
        "HARM_CATEGORY_SEXUALLY_EXPLICIT",
        "HARM_CATEGORY_HARASSMENT",
    )
    return [
        {"category": name, "threshold": "BLOCK_ONLY_HIGH"}
        for name in harm_categories
    ]


def get_mime_type(image_path: str) -> str:
    """
    Map an image file name to a MIME type by its (case-insensitive) suffix.

    Recognises .jpg/.jpeg, .png and .webp; anything else yields "image/jpeg".
    """
    lowered = image_path.lower()
    suffix_table = (
        ((".jpg", ".jpeg"), "image/jpeg"),
        ((".png",), "image/png"),
        ((".webp",), "image/webp"),
    )
    for suffixes, mime in suffix_table:
        if lowered.endswith(suffixes):
            return mime
    return "image/jpeg"  # Default


def inline_part_from_item(
    b64: str, mime: Optional[str] = None
) -> Tuple[str, Optional[str]]:
    """
    Split image data into ``(base64_payload, mime_type)``.

    If ``b64`` is a base64 ``data:`` URL (``data:<mime>;base64,<payload>``),
    the header is stripped and the MIME type declared in it is returned.
    Any other input is returned unchanged.

    Args:
        b64: Raw base64 text or a ``data:`` URL.
        mime: Fallback MIME type returned when none can be read from ``b64``.
            Previously this name was an unbound local, so every input that was
            not a base64 data URL raised ``UnboundLocalError``.

    Returns:
        Tuple of the base64 payload and the MIME type (``mime`` if undetected).
    """
    if b64.startswith("data:"):
        header, _, payload = b64.partition(",")
        if "base64" in header:
            # Header looks like "data:image/png;base64"; take the part between
            # "data:" and the first ";".
            declared = header.split(";", 1)[0][len("data:"):]
            mime = declared or mime
            b64 = payload  # strip the prefix, keep only base64 bytes

    return b64, mime


def make_request_object(
    user_prompt_text: str,
    system_prompt: str,
    image_data: str,
    mimetype: str,
    temperature: float,
    use_json: bool = True,
    safety_settings: Optional[List[Dict[str, str]]] = None,
    max_output_tokens: Optional[int] = None,
) -> Dict[str, Any]:
    """
    Build one Vertex batch request body (camelCase keys).

    The request carries the system instruction and a single user turn with
    two parts: the prompt text and the image as an ``inlineData`` part.

    Args:
        user_prompt_text: Rendered user prompt.
        system_prompt: Text for ``systemInstruction``.
        image_data: Image as raw base64 or as a ``data:`` URL; any ``data:``
            header is stripped by ``inline_part_from_item``.
        mimetype: MIME type to use when ``image_data`` does not declare one.
        temperature: Value for ``generationConfig.temperature``.
        use_json: When true, request ``application/json`` responses.
        safety_settings: Optional ``safetySettings`` list; omitted when falsy.
        max_output_tokens: Optional ``generationConfig.maxOutputTokens``.

    Returns:
        The request dict, ready to be serialised under a ``"request"`` key.
    """
    b64, detected_mime = inline_part_from_item(image_data)
    # Prefer the type declared in the data URL, then the caller's hint.
    mimetype = detected_mime or mimetype or "image/jpeg"

    generation_config: Dict[str, Any] = {"temperature": temperature}
    if max_output_tokens is not None:
        # Vertex expects camelCase
        generation_config["maxOutputTokens"] = int(max_output_tokens)
    if use_json:
        generation_config["responseMimeType"] = "application/json"

    req: Dict[str, Any] = {
        "systemInstruction": {"parts": [{"text": system_prompt}]},
        "contents": [
            {
                "role": "user",
                "parts": [
                    {"text": user_prompt_text},
                    {
                        "inlineData": {
                            "mimeType": mimetype,
                            "data": b64,  # base64-encoded image data
                        }
                    },
                ],
            }
        ],
        "generationConfig": generation_config,
    }
    if safety_settings:
        req["safetySettings"] = safety_settings
    return req


def shard_writer_vertex_batch(
    out_dir: Path,
    max_bytes: int,
    items: List[Dict[str, Any]],
    storage_client,
    bucket_name: str,
    gcs_prefix: str,
    gcp_dir: str,
    temperature: float,
    include_safety: bool,
    max_output_tokens: Optional[int],
) -> Tuple[List[Path], List[Tuple[int, int]], List[str], List[str]]:
    """
    Write batch requests to local JSONL shards and upload the shards to GCS.

    One request line (``{"key": <image_id>, "request": {...}}``) is built per
    item, with the image inlined as base64. Lines are appended to the current
    shard until adding the next one would exceed ``max_bytes``; then a new
    shard ``requests_shard_NNNNN.jsonl`` is started. A single line larger than
    ``max_bytes`` still gets written (alone in its shard) and is logged.

    Args:
        out_dir: Local directory for the shard files (created if missing).
        max_bytes: Target upper bound for each shard's size in bytes.
        items: Input records, in the order they should be emitted.
        storage_client: Client object exposing ``bucket(name)``.
        bucket_name: Destination bucket.
        gcs_prefix: First path component(s) of the GCS object names.
        gcp_dir: Second path component(s); shards go to
            ``<gcs_prefix>/<gcp_dir>/input/`` and outputs are expected under
            ``<gcs_prefix>/<gcp_dir>/output/<shard stem>/``.
        temperature: Forwarded to every request.
        include_safety: Attach ``default_safety_settings()`` to every request.
        max_output_tokens: Forwarded to every request.

    Returns:
        ``(local_shard_paths, index_ranges, gcs_input_uris, gcs_output_prefixes)``
        where each index range is the inclusive ``(first, last)`` item index of
        the shard. All four lists are empty when ``items`` is empty, so no
        empty shard is uploaded and no zero-request job gets submitted.
    """
    out_dir.mkdir(parents=True, exist_ok=True)
    shard_paths: List[Path] = []
    shard_ranges: List[Tuple[int, int]] = []
    gcs_input_uris: List[str] = []
    gcs_output_prefixes: List[str] = []

    safety = default_safety_settings() if include_safety else None

    shard_file = None  # handle of the shard currently being written
    shard_bytes = 0
    shard_start = 0
    shard_count = 0

    try:
        for index, item in enumerate(items):
            user_prompt = build_user_prompt(item)
            image_path = item.get("image_path")
            mime_type = get_mime_type(image_path) if image_path else "image/jpeg"
            image_data = image_data_url_from_item(item, mime_type)

            request = make_request_object(
                user_prompt_text=user_prompt,
                system_prompt=SYSTEM_PROMPT,
                image_data=image_data,
                mimetype=mime_type,
                temperature=temperature,
                use_json=True,
                safety_settings=safety,
                max_output_tokens=max_output_tokens,
            )
            # Official format: one object per line with a "request" field
            line_obj = {"key": item.get("image_id"), "request": request}
            line = json.dumps(line_obj, ensure_ascii=False)
            line_len = len(line.encode("utf-8")) + 1  # newline

            if line_len > max_bytes:
                logging.warning(
                    "Request %d is %d bytes, larger than max shard size %d.",
                    index,
                    line_len,
                    max_bytes,
                )

            shard_is_full = shard_bytes > 0 and shard_bytes + line_len > max_bytes
            if shard_file is None or shard_is_full:
                if shard_file is not None:
                    shard_file.close()
                    shard_ranges.append((shard_start, shard_start + shard_count - 1))
                shard_path = out_dir / f"requests_shard_{len(shard_paths):05d}.jsonl"
                shard_paths.append(shard_path)
                # newline="\n" keeps the on-disk size equal to the byte count
                # tracked here on every platform.
                shard_file = open(shard_path, "w", encoding="utf-8", newline="\n")
                shard_bytes = 0
                shard_count = 0
                # Every item is written, so the shard starts at this item's index.
                shard_start = index

            shard_file.write(line + "\n")
            shard_bytes += line_len
            shard_count += 1
    finally:
        if shard_file is not None:
            shard_file.close()

    if not shard_paths:
        return shard_paths, shard_ranges, gcs_input_uris, gcs_output_prefixes

    # finalize last shard's range
    shard_ranges.append((shard_start, shard_start + shard_count - 1))

    # Upload shards to GCS input and prepare output prefixes
    bucket = storage_client.bucket(bucket_name)
    for shard_path in shard_paths:
        input_blob_name = f"{gcs_prefix}/{gcp_dir}/input/{shard_path.name}"
        output_prefix = (
            f"gs://{bucket_name}/{gcs_prefix}/{gcp_dir}/output/{shard_path.stem}/"
        )
        bucket.blob(input_blob_name).upload_from_filename(str(shard_path))
        gcs_input_uris.append(f"gs://{bucket_name}/{input_blob_name}")
        gcs_output_prefixes.append(output_prefix)

    return shard_paths, shard_ranges, gcs_input_uris, gcs_output_prefixes


# -------------------------
# CLI
# -------------------------
def parse_args() -> argparse.Namespace:
    """
    Parse the command line (``sys.argv``) for the batch-submission script.

    Required: --input-jsonl, --out-dir, --batch-tracking-file, --project-id,
    --gcs-bucket. All other options have defaults.
    """
    p = argparse.ArgumentParser(
        description="Create & submit Gemini Vertex batch jobs for LLM-as-a-judge annotations."
    )
    p.add_argument(
        "--input-jsonl",
        required=True,
        help="Input JSONL with items (image+description+questions...).",
    )
    p.add_argument(
        "--out-dir",
        required=True,
        help="Local dir to write JSONL shards before uploading to GCS.",
    )
    p.add_argument(
        "--batch-tracking-file",
        required=True,
        help="Where to write tracking info JSON.",
    )

    p.add_argument(
        "--project-id", required=True, help="GCP Project ID (Vertex AI backend)."
    )
    p.add_argument(
        "--location", default="us-central1", help="GCP region (e.g., us-central1)."
    )
    p.add_argument(
        "--google-application-credentials",
        default=None,
        help="Path to service account JSON.",
    )

    p.add_argument(
        "--gcs-bucket",
        required=True,
        help="Existing GCS bucket to store inputs/outputs.",
    )
    p.add_argument(
        "--gcs-prefix",
        default=None,
        # main() builds the default as "llm-annotation-<timestamp>".
        help="Path prefix within bucket; default llm-annotation-<timestamp>.",
    )

    p.add_argument(
        "--model", default="gemini-2.5-pro", help="Gemini model (e.g., gemini-2.5-pro)."
    )
    p.add_argument("--temperature", type=float, default=0.0)
    p.add_argument(
        "--max-output-tokens",
        type=int,
        default=1024,
        # Always set because of the default, so "If set" was misleading.
        help="Value for generationConfig.maxOutputTokens (default: 1024).",
    )
    p.add_argument(
        "--max-shard-bytes",
        type=int,
        default=200 * 1024 * 1024,
        help="Max size per JSONL shard (bytes).",
    )

    p.add_argument(
        "--include-safety-settings",
        action="store_true",
        help="Include default safetySettings in each request.",
    )
    return p.parse_args()


def main() -> None:
    """
    Entry point: shard the input, upload shards, submit one batch job per
    shard and write the tracking file.

    Steps:
      1. Parse CLI args; export GOOGLE_APPLICATION_CREDENTIALS if a path was given.
      2. Load the input JSONL (blank lines skipped).
      3. Build and upload request shards via ``shard_writer_vertex_batch``.
      4. Call ``client.batches.create`` once per shard.
      5. Write the tracking JSON and print a short status summary to stdout.

    The tracking file is written even when a submission fails part-way, as
    long as at least one job was created, so the names of already-submitted
    jobs are not lost. The original exception still propagates.

    Raises:
        ValueError: If a non-blank input line is not valid JSON (the message
            includes the file name and line number).
    """
    args = parse_args()

    # Auth (ADC). Respect explicit service account path.
    if args.google_application_credentials:
        os.environ[
            "GOOGLE_APPLICATION_CREDENTIALS"
        ] = args.google_application_credentials

    # Load items
    items: List[Dict[str, Any]] = []
    with open(args.input_jsonl, "r", encoding="utf-8") as f:
        for line_no, line in enumerate(f, start=1):
            if not line.strip():
                continue
            try:
                items.append(json.loads(line))
            except json.JSONDecodeError as err:
                raise ValueError(
                    f"{args.input_jsonl}:{line_no}: invalid JSON: {err}"
                ) from err

    # Clients
    from google import genai
    from google.cloud import storage as gcs
    from google.genai.types import CreateBatchJobConfig

    storage_client = gcs.Client(project=args.project_id)
    client = genai.Client(
        vertexai=True, project=args.project_id, location=args.location
    )

    # One timestamp for the whole run so every derived name agrees.
    timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")

    # GCS prefix
    gcs_prefix = args.gcs_prefix or f"llm-annotation-{timestamp}"

    # Sub-folder inside the prefix, derived from the input file's directory.
    # The directory is absolute, so convert to forward slashes and drop the
    # leading "/"; otherwise object names contain an empty "//" path segment.
    dir_name = os.path.dirname(os.path.abspath(args.input_jsonl))
    gcp_dir = Path(dir_name, timestamp).as_posix().lstrip("/")

    # Create shards locally and upload them to GCS input/
    out_dir = Path(args.out_dir) / timestamp
    out_dir.mkdir(parents=True, exist_ok=True)
    shard_paths, shard_ranges, input_uris, output_prefixes = shard_writer_vertex_batch(
        out_dir=out_dir,
        max_bytes=args.max_shard_bytes,
        items=items,
        storage_client=storage_client,
        bucket_name=args.gcs_bucket,
        gcs_prefix=gcs_prefix,
        gcp_dir=gcp_dir,
        temperature=args.temperature,
        include_safety=args.include_safety_settings,
        max_output_tokens=args.max_output_tokens,
    )

    # Submit a batch job per shard (official pattern)
    batches_meta: List[Dict[str, Any]] = []
    all_submitted = False
    try:
        for shard_idx, (input_uri, out_prefix, (start, end)) in enumerate(
            zip(input_uris, output_prefixes, shard_ranges)
        ):
            batch_job = client.batches.create(
                model=args.model,
                src=input_uri,
                config=CreateBatchJobConfig(dest=out_prefix),
            )
            batches_meta.append(
                {
                    "shard_index": shard_idx,
                    "input_shard_path": str(shard_paths[shard_idx]),
                    "input_gcs_uri": input_uri,
                    "output_gcs_prefix": out_prefix,
                    "batch_job_name": batch_job.name,  # e.g., batches/123456789
                    "start_index": start,
                    "end_index": end,
                    "num_requests": end - start + 1,
                }
            )
        all_submitted = True
    finally:
        # Persist what was submitted. On failure, write only if there is
        # something to record, so an earlier tracking file is not clobbered
        # with an empty one.
        if all_submitted or batches_meta:
            if not all_submitted:
                logging.warning(
                    "Submission stopped after %d of %d shard(s); writing partial tracking file.",
                    len(batches_meta),
                    len(input_uris),
                )
            # Tracking JSON
            tracking = {
                "created_at": _now_iso(),
                "project_id": args.project_id,
                "location": args.location,
                "model": args.model,
                "gcs_bucket": args.gcs_bucket,
                "gcs_prefix": gcs_prefix,
                "input_jsonl": os.path.abspath(args.input_jsonl),
                "batches": batches_meta,
            }
            Path(args.batch_tracking_file).parent.mkdir(parents=True, exist_ok=True)
            with open(args.batch_tracking_file, "w", encoding="utf-8") as f:
                json.dump(tracking, f, ensure_ascii=False, indent=2)

    print(
        json.dumps(
            {
                "status": "ok",
                "batches_created": len(batches_meta),
                "batch_tracking_file": args.batch_tracking_file,
                "gcs_bucket": args.gcs_bucket,
                "gcs_prefix": gcs_prefix,
            },
            indent=2,
        )
    )


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
