#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Fetch and merge results for Azure OpenAI Batch — TRANSLATION EDITION.

- Parses custom_id "{id}::path=...::lang=...".
- Writes each translation to a sibling field (same parent object) whose "en_"
  prefix is replaced by the target prefix (--lang_prefix; "ajp" shown here):
    en_question                    -> ajp_question
    en_answer                      -> ajp_answer
    en_rationale                   -> ajp_rationale
    image_desc_meta.en_description -> image_desc_meta.ajp_description
    image_desc_meta.en_reason      -> image_desc_meta.ajp_reason
  MCQ options: en_options -> {lang}_options (aligned by index).

- For QA_meta.true_false items, answers are not machine-translated. Instead,
  {lang}_answer is set from en_answer ("true"/"false", case-insensitive) to
  fixed Arabic strings, overriding any model output:
      True  -> "صحيح"
      False -> "غير صحيح"
"""

from __future__ import annotations

import argparse
import json
import logging
import os
import re
from typing import Any, Dict, List, Optional, Tuple

from dotenv import load_dotenv

from openai import AzureOpenAI, OpenAI


def configure_logging():
    """Configure the root logger with one console handler at INFO level.

    Records are formatted as ``<asctime> - <LEVEL> - <message>``. As with any
    ``logging.basicConfig`` call, this is a no-op if the root logger already
    has handlers.
    """
    console = logging.StreamHandler()
    line_format = "%(asctime)s - %(levelname)s - %(message)s"
    logging.basicConfig(handlers=[console], format=line_format, level=logging.INFO)


def load_env_azure(env_path: str):
    """Load Azure OpenAI connection settings from a .env file.

    Values from ``env_path`` override variables already in the process
    environment (``override=True``). A missing variable raises ``KeyError``.

    Returns:
        ``(api_key, api_base, api_version)``. Any trailing "/" is removed
        from ``api_base``.
    """
    load_dotenv(dotenv_path=env_path, override=True)
    env = os.environ
    endpoint = env["AZURE_API_URL"].rstrip("/")
    secret = env["AZURE_API_KEY"]
    version = env["AZURE_API_VERSION"]
    return secret, endpoint, version


def load_env_openai(env_path: str):
    """Load the OpenAI API key from a .env file and return an ``OpenAI`` client.

    Values from ``env_path`` override variables already in the process
    environment (``override=True``).

    Raises:
        KeyError: if ``OPENAI_API_KEY`` is not set after loading the file.
    """
    load_dotenv(dotenv_path=env_path, override=True)
    openai_api_key = os.environ["OPENAI_API_KEY"]
    # The v1 client's constructor keyword is ``api_key``. Passing ``key=``
    # raises TypeError and breaks the ``--api openai`` path.
    client = OpenAI(api_key=openai_api_key)
    return client


class BatchManager:
    """Download the result files of batch jobs listed in a tracking file.

    The tracking file holds one ``<batch_id>,<local_batch_jsonl>`` entry per
    line. Only the batch id is used here.
    """

    def __init__(self, client, batch_file_name: str):
        # client: expected to be an AzureOpenAI / OpenAI client. Only
        # ``batches.retrieve`` and ``files.content`` are used.
        self.client = client
        self.batch_file_name = batch_file_name

    def _read_batch_ids(self) -> List[str]:
        """Return the batch ids in the tracking file, in file order.

        Blank lines are ignored. Lines without a comma or with an empty id are
        logged and skipped. Surrounding whitespace is removed from each id.
        """
        ids: List[str] = []
        with open(self.batch_file_name, "r", encoding="utf-8") as f:
            for lineno, line in enumerate(f, 1):
                stripped = line.strip()
                if not stripped:
                    continue
                batch_id, sep, _ = stripped.partition(",")
                batch_id = batch_id.strip()
                if not sep or not batch_id:
                    logging.warning(
                        f"Skipping malformed tracking line {lineno} in "
                        f"{self.batch_file_name}: {stripped!r}"
                    )
                    continue
                ids.append(batch_id)
        return ids

    def _download(
        self, batch_id: str, file_id: str, out_path: str, label: str
    ) -> Optional[str]:
        """Save the text content of API file ``file_id`` to ``out_path``.

        Returns ``out_path`` on success. Returns ``None`` when fetching or
        writing fails, and logs the failure; one bad file does not stop the
        remaining batches.
        """
        try:
            content = self.client.files.content(file_id).text
            with open(out_path, "w", encoding="utf-8") as w:
                w.write(content)
        except Exception as e:
            logging.error(f"Failed to download {label} for {batch_id}: {e}")
            return None
        logging.info(f"Wrote {label} -> {out_path}")
        return out_path

    def retrieve_all_submitted_batches(self, batch_output_dir: str):
        """Download output and error files for every completed batch.

        Batches whose status is not "completed" are logged and skipped. API
        failures are logged per batch, so processing continues with the next id.

        Args:
            batch_output_dir: directory that receives
                ``batch_output_<id>.jsonl`` and
                ``batch_output_<id>_error.jsonl``. It is created if missing.

        Returns:
            ``(output_paths, error_paths)`` for the files actually written.
            Both lists are empty when the tracking file does not exist.
        """
        os.makedirs(batch_output_dir, exist_ok=True)
        outputs, errors = [], []
        if not os.path.exists(self.batch_file_name):
            logging.warning("No tracking file found; nothing to retrieve.")
            return outputs, errors
        for batch_id in self._read_batch_ids():
            try:
                b = self.client.batches.retrieve(batch_id)
            except Exception as e:
                logging.error(f"Failed to retrieve batch {batch_id}: {e}")
                continue
            status = getattr(b, "status", None)
            logging.info(f"Batch {batch_id} status={status}")
            if status != "completed":
                continue
            ofid = getattr(b, "output_file_id", None)
            if ofid:
                saved = self._download(
                    batch_id,
                    ofid,
                    os.path.join(batch_output_dir, f"batch_output_{batch_id}.jsonl"),
                    "output",
                )
                if saved:
                    outputs.append(saved)
            efid = getattr(b, "error_file_id", None)
            if efid:
                saved = self._download(
                    batch_id,
                    efid,
                    os.path.join(
                        batch_output_dir, f"batch_output_{batch_id}_error.jsonl"
                    ),
                    "error file",
                )
                if saved:
                    errors.append(saved)
        return outputs, errors


def parse_response_file(jsonl_path: str) -> Dict[str, Dict]:
    """Collect assistant message text from a batch output JSONL file.

    Each line is read as a batch result record carrying ``custom_id`` and
    ``response.body.choices[0].message.content``.

    Returns:
        Mapping ``custom_id -> {"response_raw": <str>, "model": <str or None>}``.

    The following lines are skipped:
      - lines that are not valid JSON objects (logged at ERROR);
      - lines whose custom_id is not a string (logged at ERROR);
      - lines with no custom_id;
      - lines without string message content, e.g. a null/missing
        ``response`` or an error body (logged at DEBUG).
    """
    results: Dict[str, Dict] = {}
    with open(jsonl_path, "r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue
            try:
                rec = json.loads(line)
            except json.JSONDecodeError as e:
                logging.error(f"Failed to parse line in {jsonl_path}: {e}")
                continue
            if not isinstance(rec, dict):
                logging.error(
                    f"Failed to parse line in {jsonl_path}: not a JSON object"
                )
                continue
            cid = rec.get("custom_id")
            if not cid:
                continue
            if not isinstance(cid, str):
                # Non-string ids would break parse_custom_id later on.
                logging.error(
                    f"Failed to parse line in {jsonl_path}: "
                    f"non-string custom_id {cid!r}"
                )
                continue
            # Guard each level: ``response`` / ``body`` may be null or absent.
            resp = rec.get("response")
            body = resp.get("body") if isinstance(resp, dict) else None
            if not isinstance(body, dict):
                body = {}
            try:
                content = body["choices"][0]["message"]["content"]
            except (KeyError, IndexError, TypeError):
                content = None
            if not isinstance(content, str):
                logging.debug(
                    f"No text content for custom_id={cid} in {jsonl_path}"
                )
                continue
            results[cid] = {"response_raw": content, "model": body.get("model")}
    return results


def _strip_fences(text: str) -> str:
    text = (text or "").strip()
    text = re.sub(r"^```(?:json)?\\n|\\n```$", "", text)
    return text.strip()


def parse_as_string(raw: str) -> str:
    """Extract the translated text from a raw model response.

    Markdown code fences are stripped first. If the remainder is a JSON
    object, the first string-valued field (in document order) is returned,
    whitespace-trimmed. In every other case the fence-stripped text is
    returned as-is:
      - the remainder is not JSON;
      - it is JSON but not an object;
      - it is an object with no string field.
    """
    txt = _strip_fences(raw)
    try:
        obj = json.loads(txt)
    except (ValueError, RecursionError):  # JSONDecodeError is a ValueError
        return txt
    if isinstance(obj, dict):
        for v in obj.values():
            if isinstance(v, str):
                return v.strip()
    return txt


_PATH_RE = re.compile(r"::path=(?P<path>[^:]+)")
_LANG_RE = re.compile(r"::lang=(?P<lang>[^:]+)")


def parse_custom_id(custom_id: str) -> Tuple[str, Optional[str], Optional[str]]:
    """Split a batch ``custom_id`` into ``(base_id, path, lang)``.

    ``base_id`` is everything before the first "::". ``path`` and ``lang``
    come from the first "::path=" / "::lang=" segment. Each value runs up to
    the next ':', so it cannot itself contain a colon. Either value is
    ``None`` when its segment is absent.
    """
    base_id = custom_id.partition("::")[0]
    extracted = []
    for pattern, group_name in ((_PATH_RE, "path"), (_LANG_RE, "lang")):
        hit = pattern.search(custom_id)
        extracted.append(hit.group(group_name) if hit is not None else None)
    path, lang = extracted
    return base_id, path, lang


def _navigate(root: Dict[str, Any], tokens: List[str]) -> Any:
    cur: Any = root
    for t in tokens:
        if isinstance(cur, list):
            idx = int(t)
            cur = cur[idx]
        elif isinstance(cur, dict):
            cur = cur.get(t)
        else:
            return None
    return cur


def _navigate_parent(root: Dict[str, Any], tokens: List[str]) -> Tuple[Any, str]:
    if not tokens:
        return None, ""
    parent_tokens = tokens[:-1]
    last = tokens[-1]
    parent = _navigate(root, parent_tokens) if parent_tokens else root
    return parent, last


def write_translation(record: Dict[str, Any], path: str, lang_prefix: str, value: str):
    """Store a translated string next to its English source field in ``record``.

    ``path`` is a dotted path ending at the English field. Two shapes exist:

    * ``<...>.en_options.<i>``: option ``i`` of an MCQ. ``value`` is placed
      at index ``i`` of the sibling ``{lang_prefix}_options`` list. The list
      is created (sized like ``en_options``) or padded with ``None`` as needed.
    * ``<...>.en_<field>``: ``value`` is stored under
      ``{lang_prefix}_<field>`` in the same parent dict.

    Nothing is written, and a warning is logged, in these cases:
      - the parent cannot be found;
      - the option index is not a non-negative integer;
      - the leaf does not start with ``en_``.
    """
    tokens = path.split(".")
    if len(tokens) >= 2 and tokens[-2] == "en_options":
        parent = _navigate(record, tokens[:-2])
        if not isinstance(parent, dict):
            logging.warning(f"write_translation: parent not found for path={path}")
            return
        try:
            idx = int(tokens[-1])
        except ValueError:
            idx = -1
        # A negative index would otherwise write into the wrong slot from the end.
        if idx < 0:
            logging.warning(
                f"write_translation: invalid option index in path={path}"
            )
            return
        tgt_key = f"{lang_prefix}_options"
        src_list = parent.get("en_options")
        if not isinstance(src_list, list):
            src_list = []
        tgt_list = parent.get(tgt_key)
        if not isinstance(tgt_list, list):
            tgt_list = [None] * len(src_list)
        if idx >= len(tgt_list):
            tgt_list.extend([None] * (idx - len(tgt_list) + 1))
        tgt_list[idx] = value
        parent[tgt_key] = tgt_list
        return

    parent, leaf = _navigate_parent(record, tokens)
    if not isinstance(parent, dict):
        logging.warning(f"write_translation: parent not found for path={path}")
        return
    if not leaf.startswith("en_"):
        logging.warning(
            f"write_translation: leaf does not start with 'en_': {leaf} (path={path})"
        )
        return
    # Swap only the leading "en_" prefix for the target language prefix.
    tgt_key = f"{lang_prefix}_{leaf[len('en_'):]}"
    parent[tgt_key] = value


def set_true_false_fixed_answers(record: Dict[str, Any], lang_prefix: str):
    """Set ``{lang_prefix}_answer`` on QA_meta.true_false items to fixed Arabic labels.

    True/false answers are not machine-translated. The mapping is:
      True  -> "صحيح"
      False -> "غير صحيح"

    It applies to every dict item in ``record["QA_meta"]["true_false"]``
    whose ``en_answer`` is either:
      - "true"/"false" (case-insensitive, surrounding whitespace ignored), or
      - a JSON boolean.
    Other answers, non-dict items, and a missing or non-dict ``QA_meta`` are
    ignored.

    When ``lang_prefix`` is "en", nothing is written: the target key would be
    ``en_answer`` itself, and the English source would be overwritten with
    Arabic.
    """
    tgt_key = f"{lang_prefix}_answer"
    if tgt_key == "en_answer":
        return
    qa = record.get("QA_meta")
    if not isinstance(qa, dict):
        return
    tf_list = qa.get("true_false")
    if not isinstance(tf_list, list):
        return
    labels = {"true": "صحيح", "false": "غير صحيح"}
    for item in tf_list:
        if not isinstance(item, dict):
            continue
        en_ans = item.get("en_answer")
        if isinstance(en_ans, bool):
            key = "true" if en_ans else "false"
        elif isinstance(en_ans, str):
            key = en_ans.strip().lower()
        else:
            continue
        label = labels.get(key)
        if label is not None:
            item[tgt_key] = label


def _build_arg_parser() -> argparse.ArgumentParser:
    """Return the command-line parser for this script."""
    parser = argparse.ArgumentParser(
        description="Fetch batch results and merge translations into original JSONL."
    )
    parser.add_argument(
        "--batch_file",
        required=True,
        help="Tracking file with lines: <batch_id>,<local_batch_jsonl>",
    )
    parser.add_argument("--env_file", required=True, help="Path to .env")
    parser.add_argument(
        "--output_dir", required=True, help="Dir to store fetched outputs"
    )
    parser.add_argument(
        "--output_file",
        required=True,
        help="Merged results JSONL (only updated records)",
    )
    parser.add_argument(
        "--output_error_file", required=True, help="Original rows that errored (JSONL)"
    )
    parser.add_argument(
        "--retrieve",
        default="True",
        help="True/False to actively fetch from API; else process local",
    )
    parser.add_argument("--original_file", required=True, help="Original input JSONL")
    parser.add_argument("--id_field", default="image_id")
    parser.add_argument(
        "--lang_prefix",
        choices=["en", "msa", "ajp", "arz"],
        required=True,
        help="Target language prefix",
    )
    parser.add_argument(
        "--api",
        choices=["azure", "openai"],
        default="azure",
        help="API endpoint ",
    )
    return parser


def _tracking_batch_ids(batch_file: str) -> List[str]:
    """Return batch ids from a ``<batch_id>,<local_batch_jsonl>`` tracking file.

    Blank lines are ignored. Lines without a comma or with an empty id are
    logged and skipped instead of aborting the run.
    """
    ids: List[str] = []
    with open(batch_file, "r", encoding="utf-8") as f:
        for line in f:
            stripped = line.strip()
            if not stripped:
                continue
            batch_id, sep, _ = stripped.partition(",")
            batch_id = batch_id.strip()
            if not sep or not batch_id:
                logging.warning(f"Skipping malformed tracking line: {stripped!r}")
                continue
            ids.append(batch_id)
    return ids


def _collect_error_custom_ids(err_file: str) -> List[str]:
    """Return the custom_ids listed in a batch error JSONL file.

    A ``custom_id`` given as a non-empty list contributes its first element.
    Only string ids are kept. A missing file yields []. Undecodable lines are
    logged and skipped.
    """
    ids: List[str] = []
    if not os.path.exists(err_file):
        return ids
    with open(err_file, "r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue
            try:
                obj = json.loads(line)
            except json.JSONDecodeError as e:
                logging.warning(f"Skipping unparsable line in {err_file}: {e}")
                continue
            if not isinstance(obj, dict):
                continue
            cid = obj.get("custom_id")
            if isinstance(cid, list) and cid:
                cid = cid[0]
            if isinstance(cid, str):
                ids.append(cid)
    return ids


def main():
    """CLI entry point: fetch batch outputs and merge translations.

    Steps:
      1. Obtain result/error JSONL paths. With ``--retrieve True`` they are
         downloaded via the API; otherwise previously downloaded files are
         looked up in ``--output_dir``.
      2. Index ``--original_file`` records by ``--id_field`` (as strings).
      3. For every result, write the parsed translation into its record at
         the path encoded in the custom_id. Then apply the fixed true/false
         answers to the updated records.
      4. Write the updated records, in original-file order, to
         ``--output_file``. Write every original record referenced by an
         error file to ``--output_error_file``.
    """
    args = _build_arg_parser().parse_args()
    configure_logging()

    retrieve_flag = str(args.retrieve).lower() == "true"
    os.makedirs(args.output_dir, exist_ok=True)

    file_list: List[str] = []
    err_file_list: List[str] = []
    if retrieve_flag:
        if args.api == "azure":
            api_key, api_base, api_version = load_env_azure(args.env_file)
            client = AzureOpenAI(
                azure_endpoint=api_base, api_key=api_key, api_version=api_version
            )
        else:
            client = load_env_openai(args.env_file)
        manager = BatchManager(
            client,
            batch_file_name=args.batch_file,
        )
        file_list, err_file_list = manager.retrieve_all_submitted_batches(
            batch_output_dir=args.output_dir
        )
        logging.info(
            f"Retrieved {len(file_list)} result files, {len(err_file_list)} error files."
        )
    else:
        # Reuse files saved by an earlier run, using the same naming scheme.
        for batch_id in _tracking_batch_ids(args.batch_file):
            fpath = os.path.join(args.output_dir, f"batch_output_{batch_id}.jsonl")
            if os.path.exists(fpath):
                file_list.append(fpath)
            err_path = os.path.join(
                args.output_dir, f"batch_output_{batch_id}_error.jsonl"
            )
            if os.path.exists(err_path):
                err_file_list.append(err_path)

    # Keys are stringified so numeric ids match the string base id parsed
    # out of custom_id.
    base_index: Dict[str, Dict[str, Any]] = {}
    with open(args.original_file, "r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue
            try:
                ex = json.loads(line)
            except json.JSONDecodeError as e:
                logging.error(f"Error reading original line: {e}")
                continue
            if not isinstance(ex, dict):
                logging.error("Error reading original line: not a JSON object")
                continue
            base_id = ex.get(args.id_field)
            if base_id is not None:
                base_index[str(base_id)] = ex

    updated_ids = set()
    for p in file_list:
        if not os.path.exists(p):
            continue
        for cid, payload in parse_response_file(p).items():
            base_id, path, _cid_lang = parse_custom_id(cid)
            if not path:
                logging.warning(f"No path in custom_id: {cid}")
                continue
            rec = base_index.get(str(base_id))
            if rec is None:
                logging.warning(
                    f"No original record for id='{base_id}' (custom_id={cid})"
                )
                continue
            text = parse_as_string(payload.get("response_raw"))
            write_translation(rec, path, args.lang_prefix, text)
            updated_ids.add(str(base_id))

    # Emit in original-file order. Iterating the set directly would make the
    # output order vary between runs, because str hashing is randomised per
    # process.
    ordered_ids = [bid for bid in base_index if bid in updated_ids]

    # Set fixed Arabic answers for true/false items BEFORE writing results
    for bid in ordered_ids:
        set_true_false_fixed_answers(base_index[bid], args.lang_prefix)

    with open(args.output_file, "w", encoding="utf-8") as w:
        for bid in ordered_ids:
            w.write(json.dumps(base_index[bid], ensure_ascii=False) + "\n")
    logging.info(f"Wrote {len(ordered_ids)} merged records -> {args.output_file}")

    # Error subset: every original record referenced by any errored custom_id.
    error_custom_ids: List[str] = []
    for ep in err_file_list:
        error_custom_ids.extend(_collect_error_custom_ids(ep))
    targets = {parse_custom_id(cid)[0] for cid in error_custom_ids}

    with open(args.output_error_file, "w", encoding="utf-8") as w:
        for bid, ex in base_index.items():
            if bid in targets:
                w.write(json.dumps(ex, ensure_ascii=False) + "\n")
    logging.info(f"Wrote error subset -> {args.output_error_file}")


if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()
