#!/usr/bin/env python3
"""
Wait for the Gemini batch jobs listed in the tracking file, download their JSONL
outputs from GCS, parse the responses, and merge the LLM annotations back into the
original input JSONL.

Example:
python bin/annotation_llm/fetch_and_merge_batch_results.py \\
  --batch-tracking-file work/batch_tracking.json \\
  --input-jsonl data/qa_items.jsonl \\
  --merged-output-jsonl work/qa_items_with_llm_annotations.jsonl \\
  --project-id your-gcp-project \\
  --location us-central1 \\
  --google-application-credentials ./google_api.json \\
  --save-output-dir work/batch_outputs
"""
import argparse
import json
import logging
import os
import re
import sys
import time
from pathlib import Path
from typing import Any, Dict, List, Optional

from google.cloud import storage as gcs

# Configure the root logger at INFO so per-poll progress messages from main()
# (job name, state, seconds waited) are emitted.
logging.basicConfig(level=logging.INFO)

# Matches Markdown code-fence markers that models often wrap JSON in:
#   - an opening fence "```" or "```json" (case-insensitive) at the start of a
#     line, plus any whitespace after it;
#   - a closing fence "```" at the end of a line, plus any whitespace before it.
# MULTILINE makes ^ and $ apply to every line, not just the whole string.
CODE_FENCE_RE = re.compile(r"^```(?:json)?\s*|\s*```$", re.IGNORECASE | re.MULTILINE)


def parse_args() -> argparse.Namespace:
    """Define the command-line interface and parse ``sys.argv``.

    Returns:
        A namespace holding the tracking/input/output paths, GCP project and
        location, optional credentials path, polling interval and overall
        wait budget, the local directory for raw downloads, and the
        ``latest_file_only`` switch.
    """
    parser = argparse.ArgumentParser(
        description="Fetch batch results and merge with input JSONL."
    )

    # (flag, keyword arguments) pairs, registered in this exact order so the
    # generated --help output is unchanged.
    option_specs = (
        ("--batch-tracking-file", {"required": True}),
        (
            "--input-jsonl",
            {
                "required": True,
                "help": "Original input JSONL used to create batches.",
            },
        ),
        (
            "--merged-output-jsonl",
            {
                "required": True,
                "help": "Output JSONL with annotations merged.",
            },
        ),
        ("--project-id", {"required": True}),
        ("--location", {"default": "us-central1"}),
        ("--google-application-credentials", {"default": None}),
        (
            "--poll-seconds",
            {
                "type": int,
                "default": 20,
                "help": "Polling interval while waiting for jobs.",
            },
        ),
        (
            "--max-wait-seconds",
            {
                "type": int,
                "default": 24 * 3600,
                "help": "Maximum total wait per job.",
            },
        ),
        (
            "--save-output-dir",
            {
                "required": True,
                "help": "Local directory to save raw JSONL files downloaded from GCS (e.g., your --out-dir).",
            },
        ),
        (
            "--latest-file-only",
            {
                "action": "store_true",
                "help": "If set, downloads only the newest .jsonl file (instead of all .jsonl in the newest subfolder).",
            },
        ),
    )
    for flag, options in option_specs:
        parser.add_argument(flag, **options)

    parsed = parser.parse_args()
    return parsed


def clean_json_text(txt: str) -> str:
    """Drop Markdown code-fence markers (```/```json) from model output and trim it.

    Args:
        txt: Raw text from a model response part.

    Returns:
        The text with every fence marker matched by ``CODE_FENCE_RE`` removed
        and surrounding whitespace stripped.
    """
    without_fences = CODE_FENCE_RE.sub("", txt)
    return without_fences.strip()


def extract_response_json_from_generate_content_response(
    obj: Dict[str, Any]
) -> Optional[Dict[str, Any]]:
    """
    Parse the JSON annotation payload out of one decoded batch output line.

    The GenerateContentResponse is read from ``obj["response"]``; when that key
    is absent, ``obj`` itself is treated as the response. The first text part of
    the first candidate has Markdown code fences stripped and is parsed as JSON.
    Only that first text part is considered.

    Args:
        obj: One decoded JSONL line from the batch output.

    Returns:
        The parsed JSON object, or ``None`` when there is no candidate, no text
        part, the text is not valid JSON, or it decodes to something other
        than a JSON object.
    """
    response = obj.get("response", obj) if isinstance(obj, dict) else None
    if not isinstance(response, dict):
        return None

    candidates = response.get("candidates")
    if not isinstance(candidates, list) or not candidates:
        return None
    first = candidates[0]
    if not isinstance(first, dict):
        return None

    content = first.get("content")
    if not isinstance(content, dict):
        return None
    parts = content.get("parts") or []

    for part in parts:
        text = part.get("text") if isinstance(part, dict) else None
        if not isinstance(text, str):
            continue
        try:
            parsed = json.loads(clean_json_text(text))
        except ValueError:  # json.JSONDecodeError subclasses ValueError
            logging.exception("Failed to parse response JSON from text part: %r", text)
            return None
        if not isinstance(parsed, dict):
            # Callers index this like a mapping; anything else is unusable.
            logging.warning(
                "Response JSON is a %s, expected an object", type(parsed).__name__
            )
            return None
        return parsed
    return None


def _load_jsonl(path: str) -> List[Dict[str, Any]]:
    """Read a JSONL file into a list of decoded objects, skipping blank lines."""
    with open(path, "r", encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]


def _mark_range_error(
    items: List[Dict[str, Any]],
    start_idx: int,
    end_idx: int,
    error: Dict[str, Any],
) -> None:
    """Set ``llm_annotation_error`` to a copy of ``error`` on items[start_idx..end_idx] (inclusive)."""
    for i in range(start_idx, end_idx + 1):
        items[i]["llm_annotation_error"] = dict(error)


def _wait_for_job(
    client: Any, name: str, poll_seconds: int, max_wait_seconds: int
) -> tuple[Any, str]:
    """Poll batch job ``name`` until it reaches a terminal state or the wait budget runs out.

    Returns:
        ``(job, state_name)`` from the last poll. On timeout the state is
        whatever non-terminal state the job was last seen in.
    """
    terminal_states = (
        "JOB_STATE_SUCCEEDED",
        "JOB_STATE_FAILED",
        "JOB_STATE_CANCELLED",
        "JOB_STATE_EXPIRED",
    )
    # A non-positive interval would never advance `waited`, polling forever.
    interval = max(1, poll_seconds)
    waited = 0
    while True:
        job = client.batches.get(name=name)
        state = getattr(job, "state", None)
        state_name = getattr(state, "name", str(state))
        logging.info(
            f"Waiting for {name}, state={state_name}, waited={waited} seconds"
        )
        if state_name in terminal_states:
            return job, state_name
        if waited >= max_wait_seconds:
            print(
                f"[WARN] Timed out waiting for {name}, state={state_name}",
                file=sys.stderr,
            )
            return job, state_name
        time.sleep(interval)
        waited += interval


def _resolve_output_prefix(job: Any, meta: Dict[str, Any]) -> Optional[str]:
    """Return the job's GCS output URI, falling back to the prefix recorded at submission."""
    dest = (
        getattr(job, "dest", None)
        or getattr(job, "output", None)
        or getattr(job, "config", None)
    )
    out_prefix = getattr(dest, "gcs_uri", None) if dest else None
    return out_prefix or meta.get("output_gcs_prefix")


def _first_subdir(blob_name: str, base_prefix: str) -> Optional[str]:
    """Return the first path component of ``blob_name`` below ``base_prefix``, or None for a direct child."""
    suffix = blob_name[len(base_prefix):]
    if "/" in suffix:
        return suffix.split("/", 1)[0]
    return None


def _blob_sort_key(blob: Any) -> tuple:
    """Order blobs by modification time (missing timestamps first), then by name.

    The leading presence flag keeps ``None`` from ever being compared against a
    datetime, which would raise TypeError.
    """
    ts = getattr(blob, "updated", None) or getattr(blob, "time_created", None)
    return (ts is not None, ts, blob.name)


def _download_result_lines(
    storage_client: Any,
    out_prefix: str,
    save_output_dir: str,
    latest_file_only: bool,
) -> List[str]:
    """Download a job's result files under ``out_prefix`` and return their non-blank lines.

    Files are saved under ``save_output_dir``. The newest timestamped subfolder
    is used when one exists; otherwise the shard's own name is used.
    Per-file download/read failures are reported to stderr and skipped.
    """
    _, _, rest = out_prefix.partition("gs://")
    bucket_name, _, prefix = rest.partition("/")
    # Treat the prefix as a directory. Without the trailing slash,
    # _first_subdir sees a leading "/" and detects no subfolders, so every
    # timestamped run gets merged. Sibling prefixes (".../shard_1" vs
    # ".../shard_10") would also be listed.
    if prefix and not prefix.endswith("/"):
        prefix += "/"
    bucket = storage_client.bucket(bucket_name)

    # Results may be nested under a timestamped dir (e.g. 2025-08-22-09-38-59/).
    all_blobs = list(storage_client.list_blobs(bucket, prefix=prefix))
    subdirs = {s for s in (_first_subdir(b.name, prefix) for b in all_blobs) if s}

    if subdirs:
        # Lexicographically last is typically the latest timestamp.
        newest_subdir = max(subdirs)
        newest_prefix = f"{prefix}{newest_subdir}/"
        blobs = [b for b in all_blobs if b.name.startswith(newest_prefix)]
        local_shard_dir = Path(save_output_dir) / newest_subdir
    else:
        blobs = all_blobs
        shard_name = out_prefix.rstrip("/").split("/")[-1]  # e.g. requests_shard_00000
        local_shard_dir = Path(save_output_dir) / shard_name

    local_shard_dir.mkdir(parents=True, exist_ok=True)

    # Skip "directory" placeholder objects; they cannot be saved as files.
    blobs = [b for b in blobs if not b.name.endswith("/")]
    # Prefer .jsonl files; if none carry the extension, read everything.
    jsonl_blobs = [b for b in blobs if b.name.endswith(".jsonl")] or list(blobs)
    jsonl_blobs.sort(key=_blob_sort_key)
    if latest_file_only and jsonl_blobs:
        jsonl_blobs = jsonl_blobs[-1:]

    lines: List[str] = []
    for b in jsonl_blobs:
        rel_name = (
            b.name.split("/")[-1] if subdirs else b.name[len(prefix):].lstrip("/")
        )
        local_path = local_shard_dir / rel_name
        try:
            local_path.parent.mkdir(parents=True, exist_ok=True)
            b.download_to_filename(str(local_path))
            with open(local_path, "r", encoding="utf-8") as f:
                lines.extend(ln for ln in f if ln.strip())
        except Exception as e:
            print(f"[WARN] Failed to download {b.name}: {e}", file=sys.stderr)
    return lines


def _merge_annotation(item: Dict[str, Any], parsed: Dict[str, Any]) -> None:
    """Copy ``location_relevancy`` and per-question annotations from ``parsed`` into ``item``.

    Raises:
        KeyError/IndexError/TypeError when ``parsed`` or ``item`` does not have
        the expected shape; the caller records this as ``merge_failed``.
    """
    if "location_relevancy" in parsed:
        item["location_relevancy"] = parsed["location_relevancy"]
    # NOTE(review): QA_meta keys mix hyphens ("open-ended", "multiple-choice")
    # with underscores ("true_false"); kept as-is. Confirm against the input schema.
    true_false_index = 0
    for question in parsed.get("questions", {}):
        qtype = question["type"]
        if qtype == "open_ended":
            item["QA_meta"]["open-ended"][0]["annotation"] = question.get(
                "annotation", None
            )
        elif qtype == "multiple_choice":
            item["QA_meta"]["multiple-choice"][0]["annotation"] = question.get(
                "annotation", None
            )
        elif qtype == "true_false":
            item["QA_meta"]["true_false"][true_false_index][
                "annotation"
            ] = question.get("annotation", None)
            true_false_index += 1


def _annotate_item_from_line(item: Dict[str, Any], line: str, name: str) -> None:
    """Parse one batch result line and merge it into ``item``, or record why it failed."""

    def fail(error: Any) -> None:
        item["llm_annotation_error"] = {"batch_job_name": name, "error": error}

    if not line.strip():
        fail("empty_line")
        return
    try:
        resp_obj = json.loads(line)
    except Exception as e:
        fail(f"json_decode_failed: {e}")
        return
    if not isinstance(resp_obj, dict):
        fail("unexpected_result_type")
        return

    response = resp_obj.get("response")
    has_candidates = bool(resp_obj.get("candidates")) or (
        isinstance(response, dict) and bool(response.get("candidates"))
    )
    if "error" in resp_obj and not has_candidates:
        fail(resp_obj["error"])
        return
    # NOTE(review): Vertex batch output lines appear to carry a "status" string
    # that is non-empty on failure. Surface it instead of a generic error.
    status = resp_obj.get("status")
    if not has_candidates and isinstance(status, str) and status:
        fail(status)
        return

    parsed = extract_response_json_from_generate_content_response(resp_obj)
    if parsed is None:
        fail("no_text_json_found")
        return
    try:
        _merge_annotation(item, parsed)
    except Exception as e:
        fail(f"merge_failed: {e}")


def _merge_result_lines(
    items: List[Dict[str, Any]],
    lines: List[str],
    start_idx: int,
    end_idx: int,
    name: str,
) -> None:
    """Merge result ``lines`` into items[start_idx..end_idx] by position.

    Extra lines are ignored with a warning. Items with no corresponding line
    get a ``missing_result_line`` error instead of being silently skipped.
    """
    # NOTE(review): results are matched to inputs purely by position. Confirm
    # that the batch service preserves input order (including across multiple
    # output files). If not, match on the echoed "request" instead.
    expected = max(0, end_idx - start_idx + 1)
    if len(lines) > expected:
        logging.warning(
            "Batch %s returned %d result lines for %d items; extra lines ignored",
            name,
            len(lines),
            expected,
        )
    for offset, line in enumerate(lines[:expected]):
        _annotate_item_from_line(items[start_idx + offset], line, name)
    for i in range(start_idx + min(len(lines), expected), end_idx + 1):
        items[i]["llm_annotation_error"] = {
            "batch_job_name": name,
            "error": "missing_result_line",
        }


def main() -> None:
    """Wait for each tracked batch job, download its results, and write the merged JSONL.

    Each item covered by a batch either receives its LLM annotations or gets an
    ``llm_annotation_error`` entry explaining why not. Finally, every item is
    written to ``--merged-output-jsonl``.
    """
    args = parse_args()

    if args.google_application_credentials:
        os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = (
            args.google_application_credentials
        )

    with open(args.batch_tracking_file, "r", encoding="utf-8") as f:
        tracking = json.load(f)
    items = _load_jsonl(args.input_jsonl)

    # Imported lazily so the rest of the module is usable without the SDK.
    from google import genai

    client = genai.Client(
        vertexai=True, project=args.project_id, location=args.location
    )
    storage_client = gcs.Client(project=args.project_id)

    for meta in tracking["batches"]:
        name = meta["batch_job_name"]
        start_idx, end_idx = meta["start_index"], meta["end_index"]

        job, state_name = _wait_for_job(
            client, name, args.poll_seconds, args.max_wait_seconds
        )
        if state_name != "JOB_STATE_SUCCEEDED":
            _mark_range_error(
                items, start_idx, end_idx, {"batch_job_name": name, "state": state_name}
            )
            continue

        out_prefix = _resolve_output_prefix(job, meta)
        if not out_prefix or not out_prefix.startswith("gs://"):
            _mark_range_error(
                items,
                start_idx,
                end_idx,
                {
                    "batch_job_name": name,
                    "error": f"missing_or_invalid_output_prefix: {out_prefix!r}",
                },
            )
            continue

        lines = _download_result_lines(
            storage_client, out_prefix, args.save_output_dir, args.latest_file_only
        )
        _merge_result_lines(items, lines, start_idx, end_idx, name)

    out_path = Path(args.merged_output_jsonl)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    with open(out_path, "w", encoding="utf-8") as f:
        for obj in items:
            f.write(json.dumps(obj, ensure_ascii=False) + "\n")

    print(
        json.dumps(
            {"status": "ok", "merged_output_jsonl": args.merged_output_jsonl}, indent=2
        )
    )


if __name__ == "__main__":
    main()
