#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Batch build/submit script for the Azure OpenAI (or OpenAI) Batch API — TRANSLATION EDITION.

- One batch task per translatable leaf:
  * image_desc_meta: en_description, en_reason
  * QA_meta.open-ended[i]: en_question, en_answer, en_rationale
  * QA_meta.multiple-choice[i]: en_question, en_answer, en_rationale, and EVERY en_options[j]
  * QA_meta.true_false[i]: en_question, en_rationale (answer optional; see --skip_tf_answer)
- custom_id encodes JSON path so results can be merged:
    {id}::path=QA_meta.open-ended.0.en_question::lang=ajp
- Uses prompts/translate_prompt.py (make_messages) expecting {"text": "..."}
"""

from __future__ import annotations

import argparse
import importlib.util
import json
import logging
import os
from typing import Any, Callable, Dict, List, Optional, Union

from dotenv import load_dotenv
from openai import AzureOpenAI, OpenAI


def configure_logging():
    """Configure root logging at INFO level with timestamped lines on a StreamHandler."""
    stream_handler = logging.StreamHandler()
    line_format = "%(asctime)s - %(levelname)s - %(message)s"
    logging.basicConfig(
        level=logging.INFO, format=line_format, handlers=[stream_handler]
    )


def load_env_azure(env_path: str):
    """Load the .env file at *env_path* and read the Azure settings from the environment.

    Values from the .env file override variables that are already set.

    Returns:
        Tuple ``(api_key, api_base, api_version, engine)``. ``api_base`` has
        trailing slashes removed; ``engine`` comes from ``AZURE_ENGINE_NAME``
        and falls back to ``"gpt-4.1-global-batch"``.

    Raises:
        KeyError: If ``AZURE_API_URL``, ``AZURE_API_KEY`` or
            ``AZURE_API_VERSION`` is not set.
    """
    load_dotenv(dotenv_path=env_path, override=True)
    env = os.environ
    endpoint = env["AZURE_API_URL"].rstrip("/")
    key = env["AZURE_API_KEY"]
    version = env["AZURE_API_VERSION"]
    deployment = env.get("AZURE_ENGINE_NAME", "gpt-4.1-global-batch")
    return key, endpoint, version, deployment


def load_env_openai(env_path: str) -> "OpenAI":
    """Load the .env file at *env_path* and build an OpenAI client.

    Values from the .env file override variables that are already set.

    Returns:
        An ``OpenAI`` client authenticated with ``OPENAI_API_KEY``.

    Raises:
        KeyError: If ``OPENAI_API_KEY`` is not set.
    """
    load_dotenv(dotenv_path=env_path, override=True)
    openai_api_key = os.environ["OPENAI_API_KEY"]
    # The client's keyword is ``api_key``; ``key=`` is rejected with a TypeError.
    return OpenAI(api_key=openai_api_key)


def load_prompt_module(prompt_name: str, prompt_path: Optional[str]) -> Any:
    """Import and return a prompt module from a file path.

    The module is expected to expose
    ``make_messages(example: dict, **kwargs) -> List[dict]``; that is checked
    by the caller, not here.

    Args:
        prompt_name: Used only when *prompt_path* is empty; the module is then
            looked up at ``prompts/<prompt_name>_prompt.py`` next to this file.
        prompt_path: Optional explicit path to the module file. Takes
            precedence over *prompt_name*.

    Returns:
        The executed module object.

    Raises:
        FileNotFoundError: If the resolved file does not exist.
        ImportError: If no import spec/loader can be built for the file
            (for example, the file does not have a Python source suffix).
    """
    if prompt_path:
        module_path = os.path.abspath(prompt_path)
        if not os.path.exists(module_path):
            raise FileNotFoundError(f"Prompt file not found: {module_path}")
    else:
        here = os.path.dirname(os.path.abspath(__file__))
        module_path = os.path.join(here, "prompts", f"{prompt_name}_prompt.py")
        if not os.path.exists(module_path):
            raise FileNotFoundError(
                f"No prompt file for '{prompt_name}'. Expected at: {module_path}\n"
                f"Or pass --prompt_path /path/to/custom_prompt.py"
            )

    spec = importlib.util.spec_from_file_location("user_prompt_module", module_path)
    # Validate BEFORE module_from_spec: a None spec would otherwise surface as
    # an opaque AttributeError, and an assert would vanish under ``python -O``.
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot load prompt module from: {module_path}")
    mod = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(mod)
    return mod


class LLMBatchBuilder:
    """Turn records of an input JSONL file into Batch API task files.

    Every non-empty translatable leaf of a record becomes one task whose
    ``custom_id`` encodes the record id, the leaf's dotted JSON path and the
    target language. Tasks are written to ``batch_<n>.jsonl`` files inside
    ``output_dir``; a new file is started whenever adding the next task would
    exceed the byte limit or the task-count limit.

    Args:
        input_jsonl: Path to the input file, one JSON record per line.
        output_dir: Directory for the generated batch files (created if missing).
        model_name: Value written to ``body.model`` of every task.
        endpoint: Value written to the ``url`` field of every task.
        max_completion_tokens: Value written to ``body.max_completion_tokens``.
        batch_file_size_limit: Maximum bytes per batch file; a falsy value
            disables the size limit.
        id_field: Record key holding the base id used in ``custom_id``.
        max_tasks_per_file: Maximum tasks per batch file; a falsy value
            disables the count limit.
        custom_id_format: ``str.format_map`` template with the fields ``id``,
            ``path`` and ``lang_prefix``.
        skip_tf_answer: When true, ``en_answer`` of true/false items is not
            emitted as a task.
    """

    # Leaves shared by every QA item, in emission order.
    _QA_LEAVES = ("en_question", "en_answer", "en_rationale")

    def __init__(
        self,
        input_jsonl: str,
        output_dir: str,
        model_name: str = "gpt-4.1-global-batch",
        endpoint: str = "/chat/completions",
        max_completion_tokens: int = 4000,
        batch_file_size_limit: int = 180 * 1024 * 1024,
        id_field: str = "image_id",
        max_tasks_per_file: int = 100000,
        custom_id_format: str = "{id}::path={path}::lang={lang_prefix}",
        skip_tf_answer: bool = True,
    ):
        self.input_jsonl = input_jsonl
        self.output_dir = output_dir
        self.model_name = model_name
        self.endpoint = endpoint
        self.max_completion_tokens = max_completion_tokens
        self.batch_file_size_limit = batch_file_size_limit
        self.id_field = id_field
        self.max_tasks_per_file = max_tasks_per_file
        self.custom_id_format = custom_id_format
        self.skip_tf_answer = skip_tf_answer
        os.makedirs(self.output_dir, exist_ok=True)

    def _calc_bytes(self, obj: Union[str, dict, list]) -> int:
        """Return the UTF-8 byte length of *obj* (JSON-serialised unless it is a str)."""
        text = obj if isinstance(obj, str) else json.dumps(obj, ensure_ascii=False)
        return len(text.encode("utf-8"))

    def _save_batch(self, tasks: List[dict], idx: int) -> str:
        """Write *tasks* as JSON lines to ``batch_<idx>.jsonl`` and return its path."""
        path = os.path.join(self.output_dir, f"batch_{idx}.jsonl")
        with open(path, "w", encoding="utf-8") as w:
            for t in tasks:
                w.write(json.dumps(t, ensure_ascii=False) + "\n")
        logging.info(f"Saved batch {idx} -> {path}")
        return path

    def _render_custom_id(
        self, base_id: str, path: str, lang_prefix: Optional[str]
    ) -> str:
        """Render ``custom_id`` from the configured template.

        Falls back to ``"{base_id}::path={path}"`` when the template cannot be
        rendered (for example, it references an unknown field).
        """
        ctx = {"id": base_id, "path": path, "lang_prefix": lang_prefix or ""}
        try:
            return self.custom_id_format.format_map(ctx)
        except Exception:
            # Deliberate best-effort: a bad user-supplied template must not
            # abort the whole build; the fallback still carries id and path.
            return f"{base_id}::path={path}"

    @staticmethod
    def _append(
        out: List[Dict[str, Any]], base_id: Any, path: str, text: Optional[str]
    ):
        """Append one work item to *out* unless *text* is None or blank after stripping."""
        if text is None:
            return
        text = str(text).strip()
        if not text:
            return
        out.append({"base_id": base_id, "path": path, "payload": {"text": text}})

    @staticmethod
    def _indexed_dicts(section: Any) -> List[tuple]:
        """Return ``(index, item)`` pairs for the dict entries of a list section.

        Anything that is not a list yields no pairs. Non-dict entries are
        skipped but the indices of the remaining entries are their original
        positions, so emitted paths still address the source list.
        """
        if not isinstance(section, list):
            return []
        return [(i, item) for i, item in enumerate(section) if isinstance(item, dict)]

    def _append_leaves(
        self,
        out: List[Dict[str, Any]],
        base_id: Any,
        prefix: str,
        item: Dict[str, Any],
        leaves: tuple,
    ) -> None:
        """Append ``<prefix>.<leaf>`` for each of *leaves* that has non-blank text in *item*."""
        for leaf in leaves:
            # A missing key gives None, which _append ignores.
            self._append(out, base_id, f"{prefix}.{leaf}", item.get(leaf))

    def _iter_qa_items(self, rec: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Collect every translatable leaf of *rec* as a work item.

        Emission order: ``image_desc_meta`` leaves, then ``QA_meta`` sections
        ``open-ended``, ``multiple-choice`` (including each ``en_options``
        entry) and ``true_false``. Sections or entries of an unexpected type
        are skipped rather than raising.

        Returns:
            List of dicts with keys ``base_id``, ``path`` and ``payload``
            (``{"text": ...}``).
        """
        base_id = rec.get(self.id_field)
        out: List[Dict[str, Any]] = []

        img = rec.get("image_desc_meta") or {}
        if isinstance(img, dict):
            self._append_leaves(
                out, base_id, "image_desc_meta", img, ("en_description", "en_reason")
            )

        qa = rec.get("QA_meta") or {}
        if not isinstance(qa, dict):
            # Same tolerance as for image_desc_meta: a malformed section is
            # ignored instead of crashing on ``.get``.
            return out

        for i, item in self._indexed_dicts(qa.get("open-ended")):
            self._append_leaves(
                out, base_id, f"QA_meta.open-ended.{i}", item, self._QA_LEAVES
            )

        for i, item in self._indexed_dicts(qa.get("multiple-choice")):
            prefix = f"QA_meta.multiple-choice.{i}"
            self._append_leaves(out, base_id, prefix, item, self._QA_LEAVES)
            opts = item.get("en_options") or []
            if isinstance(opts, list):
                for j, opt in enumerate(opts):
                    self._append(out, base_id, f"{prefix}.en_options.{j}", opt)

        # True/false answers are only translated when explicitly requested.
        tf_leaves = tuple(
            leaf
            for leaf in self._QA_LEAVES
            if not (self.skip_tf_answer and leaf == "en_answer")
        )
        for i, item in self._indexed_dicts(qa.get("true_false")):
            self._append_leaves(
                out, base_id, f"QA_meta.true_false.{i}", item, tf_leaves
            )

        return out

    def create_batches(
        self,
        make_messages: Callable[..., List[Dict[str, Any]]],
        prompt_kwargs: Optional[Dict[str, Any]] = None,
        lang_prefix: Optional[str] = None,
    ) -> List[str]:
        """Build all tasks from the input file and write them to batch files.

        Args:
            make_messages: Called as ``make_messages({"text": ...}, **prompt_kwargs)``
                to produce the chat messages of a task.
            prompt_kwargs: Extra keyword arguments for *make_messages*.
            lang_prefix: Target language; used in ``custom_id`` and, unless
                already present, added to *prompt_kwargs* as ``lang_prefix``.

        Returns:
            Paths of the batch files written, in creation order.

        Raises:
            json.JSONDecodeError: If a non-blank input line is not valid JSON.
        """
        prompt_kwargs = dict(prompt_kwargs or {})
        if lang_prefix:
            prompt_kwargs.setdefault("lang_prefix", lang_prefix)

        batch_idx, current_size = 1, 0
        current_tasks: List[dict] = []
        created_files: List[str] = []

        with open(self.input_jsonl, "r", encoding="utf-8") as f:
            for line in f:
                if not line.strip():
                    continue
                rec = json.loads(line)

                for it in self._iter_qa_items(rec):
                    example_for_prompt = dict(it["payload"])  # {"text": ...}
                    messages = make_messages(example_for_prompt, **prompt_kwargs)

                    cid = self._render_custom_id(
                        base_id=str(it["base_id"]),
                        path=str(it["path"]),
                        lang_prefix=lang_prefix,
                    )
                    task = {
                        "custom_id": cid,
                        "method": "POST",
                        "url": self.endpoint,
                        "body": {
                            "model": self.model_name,
                            "messages": messages,
                            "max_completion_tokens": self.max_completion_tokens,
                            "temperature": 0.0,
                        },
                    }

                    # +1 accounts for the newline _save_batch writes per task.
                    task_size = self._calc_bytes(task) + 1
                    over_size = bool(
                        self.batch_file_size_limit
                        and current_size + task_size > self.batch_file_size_limit
                    )
                    over_count = bool(
                        self.max_tasks_per_file
                        and len(current_tasks) >= self.max_tasks_per_file
                    )
                    # Only flush a non-empty batch: an oversized single task
                    # (or a non-positive count limit) must not create empty files.
                    if current_tasks and (over_size or over_count):
                        created_files.append(self._save_batch(current_tasks, batch_idx))
                        current_tasks, current_size = [], 0
                        batch_idx += 1

                    current_tasks.append(task)
                    current_size += task_size

        if current_tasks:
            created_files.append(self._save_batch(current_tasks, batch_idx))

        logging.info(f"Created {len(created_files)} batch file(s).")
        return created_files


class BatchManager:
    """Upload batch JSONL files through a client and record the resulting batch ids.

    Each successful submission appends one ``<batch_id>,<absolute local path>``
    line to the tracking file ``batch_file_name``.
    """

    def __init__(self, client, batch_file_name: str):
        self.client = client
        self.batch_file_name = batch_file_name

    def _append_tracking(self, batch_id: str, local_file: str):
        """Append one tracking line, creating the tracking file's directory if needed."""
        parent_dir = os.path.dirname(self.batch_file_name)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
        with open(self.batch_file_name, "a", encoding="utf-8") as tracker:
            absolute_path = os.path.abspath(local_file)
            tracker.write(f"{batch_id},{absolute_path}\n")
        logging.info(f"Tracked batch {batch_id} -> {self.batch_file_name}")

    def submit_batch_jsonl(
        self,
        batch_jsonl: str,
        completion_window: str = "24h",
        endpoint: str = "/chat/completions",
    ) -> str:
        """Upload *batch_jsonl*, create a batch job for it, track it, and return the batch id."""
        with open(batch_jsonl, "rb") as payload:
            uploaded = self.client.files.create(file=payload, purpose="batch")

        job = self.client.batches.create(
            input_file_id=uploaded.id,
            endpoint=endpoint,
            completion_window=completion_window,
        )
        self._append_tracking(job.id, batch_jsonl)
        logging.info(
            f"Submitted {batch_jsonl} -> batch_id={job.id} status={job.status}"
        )
        return job.id

    def submit_all_batches_in_directory(
        self, directory: str, verbose: bool = True
    ) -> List[str]:
        """Submit every ``batch_*.jsonl`` in *directory* in sorted order.

        A failing file is logged and skipped; the ids of the successful
        submissions are returned.
        """
        import glob

        pattern = os.path.join(directory, "batch_*.jsonl")
        batch_ids = []
        for jsonl_path in sorted(glob.glob(pattern)):
            try:
                batch_id = self.submit_batch_jsonl(jsonl_path)
            except Exception as exc:
                logging.error(f"Failed to submit {jsonl_path}: {exc}")
            else:
                batch_ids.append(batch_id)
        if verbose:
            logging.info(f"Submitted {len(batch_ids)} batch file(s).")
        return batch_ids


def main():
    """CLI entry point: build per-leaf translation batch files, optionally submit them.

    Steps: parse arguments, build the API client from the .env file, load the
    prompt module, write ``batch_*.jsonl`` files to ``--output_dir`` and, only
    when ``--submit`` is given, upload them and record the batch ids in
    ``--batch_file``.

    Model resolution order: ``--model``, then ``AZURE_ENGINE_NAME`` (Azure
    only), then ``"gpt-4.1-global-batch"``.

    Raises:
        AttributeError: If the prompt module does not define ``make_messages``.
        SystemExit: On invalid command-line arguments, including a
            ``--prompt_kwargs`` value that is not a JSON object.
    """
    parser = argparse.ArgumentParser(
        description="Submit Azure OpenAI Batch with one task per translatable leaf."
    )
    parser.add_argument(
        "--input", required=True, help="Path to input JSONL (one example per line)"
    )
    parser.add_argument("--env_file", required=True, help="Path to .env")
    parser.add_argument(
        "--output_dir", required=True, help="Where to write batch JSONL files"
    )
    parser.add_argument(
        "--batch_file", required=True, help="Tracking file for batch IDs"
    )

    parser.add_argument(
        "--prompt", default="translate", help="Prompt name (translate|custom)"
    )
    parser.add_argument(
        "--prompt_path", default=None, help="Optional explicit path to prompt module"
    )
    parser.add_argument(
        "--prompt_kwargs",
        default=None,
        help="Optional JSON string to pass into make_messages",
    )

    # No truthy default here: otherwise AZURE_ENGINE_NAME could never take
    # effect and the "overrides" in the help text would be meaningless.
    parser.add_argument(
        "--model",
        default=None,
        help="Model for body.model (overrides AZURE_ENGINE_NAME)",
    )
    parser.add_argument("--max_tokens", type=int, default=4000)

    parser.add_argument(
        "--id_field", default="image_id", help="Field name for base record id"
    )
    parser.add_argument(
        "--custom_id_format",
        default="{id}::path={path}::lang={lang_prefix}",
        help="Format for custom_id",
    )
    parser.add_argument(
        "--lang_prefix",
        choices=["en", "msa", "ajp", "arz"],
        required=True,
        help="Target language prefix",
    )
    parser.add_argument(
        "--skip_tf_answer",
        dest="skip_tf_answer",
        action="store_true",
        default=True,
        help="Skip translating true/false answers",
    )
    # store_true with default=True can never be switched off, so provide the
    # explicit opposite flag writing to the same destination.
    parser.add_argument(
        "--include_tf_answer",
        dest="skip_tf_answer",
        action="store_false",
        help="Also translate true/false answers (turns off --skip_tf_answer)",
    )
    parser.add_argument(
        "--api",
        choices=["azure", "openai"],
        default="azure",
        help="API endpoint ",
    )
    parser.add_argument(
        "--submit",
        action="store_true",
        help="After writing the batch files, upload them and create batch jobs",
    )

    args = parser.parse_args()
    configure_logging()

    prompt_kwargs = {}
    if args.prompt_kwargs:
        try:
            prompt_kwargs = json.loads(args.prompt_kwargs)
        except json.JSONDecodeError as exc:
            parser.error(f"--prompt_kwargs is not valid JSON: {exc}")
        if not isinstance(prompt_kwargs, dict):
            # It is expanded with ** into make_messages, so it must be an object.
            parser.error("--prompt_kwargs must be a JSON object")

    # Stays None on the OpenAI path, where no engine name comes from the env.
    env_model = None
    if args.api == "azure":
        api_key, api_base, api_version, env_model = load_env_azure(args.env_file)
        client = AzureOpenAI(
            azure_endpoint=api_base, api_key=api_key, api_version=api_version
        )
    else:
        client = load_env_openai(args.env_file)
    model_name = args.model or env_model or "gpt-4.1-global-batch"

    prompt_mod = load_prompt_module(args.prompt, args.prompt_path)
    if not hasattr(prompt_mod, "make_messages"):
        raise AttributeError(
            "Prompt module must define make_messages(example: dict, **kwargs) -> List[dict]"
        )

    builder = LLMBatchBuilder(
        input_jsonl=args.input,
        output_dir=args.output_dir,
        model_name=model_name,
        max_completion_tokens=args.max_tokens,
        id_field=args.id_field,
        custom_id_format=args.custom_id_format,
        skip_tf_answer=args.skip_tf_answer,
    )
    builder.create_batches(
        make_messages=prompt_mod.make_messages,
        prompt_kwargs=prompt_kwargs,
        lang_prefix=args.lang_prefix,
    )

    if args.submit:
        manager = BatchManager(
            client,
            batch_file_name=args.batch_file,
        )
        # NOTE: this submits every batch_*.jsonl found in output_dir, which
        # includes leftovers from earlier runs into the same directory.
        manager.submit_all_batches_in_directory(args.output_dir, verbose=True)
    else:
        logging.info(
            "Batch files written to %s; pass --submit to upload them.",
            args.output_dir,
        )


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
