#!/usr/bin/env python3
"""
CaptionAugment: use LLM to augment descriptions (captions) of geometry problems.
"""

import json
import os
import re
import time
from typing import Any, Dict, Optional

from openai import OpenAI
from ...utils.model_urls import get_base_url_for_model


class CaptionAugment:
    """Use an LLM to augment geometric problem descriptions (captions).

    The class wraps an OpenAI-compatible chat-completions client. Given a
    caption it asks the model for several paraphrases, expects a JSON object
    of the form ``{"C1": "...", "C2": "..."}`` back, and falls back to the
    original caption whenever the call or the parsing fails.

    Instance attributes set by ``__init__``:
        api_key: Key used to authenticate against the API.
        model: Model name sent with every request.
        base_url: Endpoint the client talks to.
        max_retries: Number of attempts ``call_llm`` makes (at least one).
        client: The ``OpenAI`` client instance.
        last_token_usage: Usage dict of the most recent successful
            ``augment`` call, or ``None``.
    """

    # Fallbacks used when neither arguments nor environment provide a value.
    DEFAULT_MODEL = "gemini-2.5-flash"
    DEFAULT_BASE_URL = "https://generativelanguage.googleapis.com/v1beta/openai/"

    # Request settings sent with every chat-completion call.
    REQUEST_TIMEOUT = 600.0
    TEMPERATURE = 0.7
    MAX_TOKENS = 32767

    _USAGE_FIELDS = ("prompt_tokens", "completion_tokens", "total_tokens")

    # Opening code fence (optionally tagged "json") followed by an object.
    _FENCED_OBJECT_RE = re.compile(r"```(?:json)?\s*(\{)")

    # One `"key": "value"` pair. The value ends at the first quote that is
    # followed by the next `"key":` or by the closing brace of the object, so
    # unescaped quotes inside the value stay part of the value.
    _STRING_FIELD_RE = re.compile(
        r'("[^"\\]+"\s*:\s*")'
        r"(.*?)"
        r'("(?=\s*(?:,\s*"[^"\\]+"\s*:|\}\s*\Z)))',
        re.DOTALL,
    )

    # Tokens inside a string value that may need repair: a backslash (with an
    # optional valid JSON escape captured in group 1), a bare double quote, or
    # a raw control character. A single-letter escape such as \t only counts
    # as valid when no further letter follows, so LaTeX commands like
    # \triangle or \neq are treated as text rather than as tab/newline.
    _REPAIR_TOKEN_RE = re.compile(
        r'\\(\\|["/]|u[0-9a-fA-F]{4}|[bfnrt](?![a-zA-Z]))?|"|[\x00-\x1f]'
    )
    _CONTROL_ESCAPES = {"\n": "\\n", "\t": "\\t", "\r": "\\r"}

    def __init__(
        self,
        api_key: Optional[str] = None,
        model: Optional[str] = None,
        base_url: Optional[str] = None,
        max_retries: int = 3,
    ):
        """
        Initialize the caption augmenter.

        Args:
            api_key: API key (if None, read from ``OPENAI_API_KEY``).
            model: Model name (if None, read from ``OPENAI_MODEL``, then
                ``DEFAULT_MODEL``).
            base_url: API base URL (if None, inferred from the model name via
                ``get_base_url_for_model``; if that yields None, read from
                ``OPENAI_BASE_URL``, then ``DEFAULT_BASE_URL``).
            max_retries: Maximum number of attempts for API calls.

        Raises:
            ValueError: If no API key is available.
        """
        if model is None:
            model = os.getenv("OPENAI_MODEL") or self.DEFAULT_MODEL

        if base_url is None:
            base_url = get_base_url_for_model(model)
            if base_url is None:
                base_url = os.getenv("OPENAI_BASE_URL") or self.DEFAULT_BASE_URL

        self.api_key = api_key or os.getenv("OPENAI_API_KEY")
        if not self.api_key:
            raise ValueError(
                "An API key is required (provide via argument or environment variable)."
            )

        self.model = model
        self.base_url = base_url
        self.max_retries = max_retries

        client_kwargs: Dict[str, Any] = {
            "api_key": self.api_key,
            "timeout": self.REQUEST_TIMEOUT,
        }
        # An explicitly passed empty base_url means "use the client default".
        if self.base_url:
            client_kwargs["base_url"] = self.base_url
        self.client = OpenAI(**client_kwargs)

        # Token usage of the most recent augment() call (None if unavailable).
        self.last_token_usage: Optional[Dict[str, int]] = None

    def format_caption_prompt(self, caption: str, num_versions: int = 2) -> str:
        """Format the data-augmentation prompt for a geometric image description.

        Args:
            caption: Original description text, embedded verbatim.
            num_versions: Number of paraphrases requested; the output template
                lists one ``"C<i>"`` key per version.

        Returns:
            The complete prompt string.
        """
        # One line per key so the template is itself a valid JSON object and
        # matches the {"C1": ..., "C2": ...} shape that augment() reads back.
        output_fields = ",\n".join(
            f'  "C{i}": "augmented description text"'
            for i in range(1, num_versions + 1)
        )
        prompt = f"""You are a professional editor of geometric figure descriptions. Please perform data augmentation on the following description of a geometric figure and generate {num_versions} versions that have the same semantics but different wording.

Strict requirements:
1. You must keep all geometric elements, structural relationships, positional relationships, and mathematical properties in the figure exactly the same.
- Do not add any new elements (new points, lines, angles, circles, etc.).
- Do not remove any elements.
- Do not change any relationships (such as "perpendicular", "parallel", "midpoint", "angle marking", "intersecting", "on the circle", etc.).
2. You may change the manner of expression, sentence structure, order of description, and choice of wording; more colloquial, more formal, or more narrative styles are all allowed.
3. All names, symbols, point labels, angle labels, and numeric values must remain strictly unchanged.
4. Each version must have a clearly different writing style (for example: concise, instructional, detailed, natural-language style, etc.).
5. Output format: you must output only a single JSON object, and MUST NOT include any extra text, explanations, or code block markers.

Original description:
{caption}

Output format:
{{
{output_fields}
}}"""
        return prompt

    def call_llm(self, prompt: str) -> Dict[str, Any]:
        """Call the LLM API, retrying with exponential backoff on failure.

        A truncated response (``finish_reason == "length"``) is treated as a
        failure and retried like any other error.

        Args:
            prompt: User message sent as the only chat message.

        Returns:
            A dict with ``content`` (always a string, possibly empty),
            ``usage`` (the raw usage object from the API) and
            ``finish_reason``.

        Raises:
            Exception: If every attempt failed; the last error is chained as
                the cause.
        """
        # Always make at least one request, even if max_retries is 0 or less.
        attempts = max(1, self.max_retries)
        for attempt in range(1, attempts + 1):
            try:
                completion = self.client.chat.completions.create(
                    model=self.model,
                    messages=[{"role": "user", "content": prompt}],
                    temperature=self.TEMPERATURE,
                    max_tokens=self.MAX_TOKENS,
                )
                choice = completion.choices[0]
                finish_reason = getattr(choice, "finish_reason", None)
                if finish_reason == "length":
                    raise Exception(
                        "LLM response was truncated (finish_reason=length). "
                        f"Current max_tokens={self.MAX_TOKENS}. Please increase the global "
                        "max_tokens setting or shorten the prompt."
                    )
                message = choice.message
                return {
                    # The API may return a message whose content is None.
                    "content": (message.content if message else None) or "",
                    "usage": completion.usage,
                    "finish_reason": finish_reason,
                }
            except Exception as e:
                if attempt == attempts:
                    raise Exception(
                        f"API call failed (retried {attempts} times): {e}"
                    ) from e
                wait_time = 2 ** (attempt - 1)
                print(
                    f"Attempt {attempt} failed, retrying in {wait_time} seconds... ({e})"
                )
                time.sleep(wait_time)

        # Defensive only: the loop above always returns or raises.
        raise Exception("Reached maximum number of retries.")

    @staticmethod
    def _serialize_usage(usage_obj) -> Optional[Dict[str, int]]:
        """
        Convert the usage object to a serializable dict.

        Args:
            usage_obj: Usage object (may be a CompletionUsage-like object
                exposing token counts as attributes, or a dict).

        Returns:
            A dict with ``prompt_tokens``, ``completion_tokens`` and
            ``total_tokens`` (missing values become 0), or None if the input
            is None, exposes none of these attributes, or cannot be converted.
        """
        if usage_obj is None:
            return None

        fields = CaptionAugment._USAGE_FIELDS

        if isinstance(usage_obj, dict):
            return {name: usage_obj.get(name, 0) or 0 for name in fields}

        try:
            raw = {name: getattr(usage_obj, name, None) for name in fields}
            if all(value is None for value in raw.values()):
                return None
            return {
                name: int(value) if value is not None else 0
                for name, value in raw.items()
            }
        except Exception:
            # Usage reporting is best-effort; never let it break augmentation.
            return None

    @staticmethod
    def _decode_object(json_text: str) -> Optional[Dict[str, Any]]:
        """Decode a JSON object at the start of *json_text*; trailing text is ignored."""
        try:
            obj, _ = json.JSONDecoder().raw_decode(json_text)
        except json.JSONDecodeError:
            return None
        return obj if isinstance(obj, dict) else None

    @staticmethod
    def _repair_json_strings(json_text: str) -> str:
        """Heuristically escape invalid characters inside JSON string values.

        Within every ``"key": "value"`` pair this escapes raw control
        characters, bare double quotes, and backslashes that do not start a
        valid JSON escape (typically LaTeX commands such as ``\\angle``).
        Anything after the last ``}`` is dropped.
        """
        end = json_text.rfind("}")
        if end == -1:
            return json_text
        candidate = json_text[: end + 1]

        def fix_token(token_match):
            token = token_match.group(0)
            if token[0] == "\\":
                # Keep valid escapes; double a stray backslash.
                return token if token_match.group(1) is not None else "\\\\"
            if token == '"':
                return '\\"'
            return CaptionAugment._CONTROL_ESCAPES.get(
                token, f"\\u{ord(token):04x}"
            )

        def fix_field(field_match):
            opening, value, closing = field_match.groups()
            fixed = CaptionAugment._REPAIR_TOKEN_RE.sub(fix_token, value)
            return opening + fixed + closing

        return CaptionAugment._STRING_FIELD_RE.sub(fix_field, candidate)

    @staticmethod
    def extract_json_from_response(response_text: str) -> Optional[Dict[str, Any]]:
        """Extract a JSON object from the LLM response text.

        Strategies, in order: parse the whole response; parse the object in a
        fenced code block (repairing string values if needed); scan for the
        first decodable object; finally repair and parse the span from the
        first ``{`` to the last ``}``.

        Args:
            response_text: Raw text returned by the model.

        Returns:
            The decoded dict, or None if no JSON object could be recovered.
            Non-object JSON values (lists, strings, numbers) are never returned.
        """
        if not response_text:
            return None

        # 1) The whole response is a JSON document.
        try:
            whole = json.loads(response_text.strip())
        except json.JSONDecodeError:
            whole = None
        if isinstance(whole, dict):
            return whole

        # 2) A fenced code block containing an object.
        fence = CaptionAugment._FENCED_OBJECT_RE.search(response_text)
        if fence:
            fenced_text = response_text[fence.start(1):]
            closing = fenced_text.find("```", 1)
            if closing > 0:
                fenced_text = fenced_text[:closing].rstrip()
            obj = CaptionAugment._decode_object(fenced_text)
            if obj is None:
                obj = CaptionAugment._decode_object(
                    CaptionAugment._repair_json_strings(fenced_text)
                )
            if obj is not None:
                return obj

        # 3) The first substring that decodes as an object.
        first_brace = response_text.find("{")
        start = first_brace
        while start != -1:
            obj = CaptionAugment._decode_object(response_text[start:])
            if obj is not None:
                return obj
            start = response_text.find("{", start + 1)

        # 4) Last resort: repair unfenced output with bad escapes or quotes.
        if first_brace != -1:
            return CaptionAugment._decode_object(
                CaptionAugment._repair_json_strings(response_text[first_brace:])
            )
        return None

    @staticmethod
    def _collect_versions(result_json: Any, num_versions: int) -> Dict[str, str]:
        """Pick the non-empty string values stored under ``C1`` .. ``C<num_versions>``."""
        if not isinstance(result_json, dict):
            return {}
        versions: Dict[str, str] = {}
        for i in range(1, num_versions + 1):
            key = f"C{i}"
            value = result_json.get(key)
            # Skip malformed entries instead of discarding the valid ones.
            if isinstance(value, str) and value.strip():
                versions[key] = value.strip()
        return versions

    def augment(
        self, caption: str, index: Optional[int] = None, num_versions: int = 2
    ) -> Dict[str, str]:
        """
        Augment a description text.

        Never raises: any error falls back to the original caption. After the
        call, ``last_token_usage`` holds the usage of this call on success and
        None otherwise (including when the caption is blank and no request is
        made).

        Args:
            caption: Original description text.
            index: Optional index (for logging).
            num_versions: Number of augmented versions to generate.

        Returns:
            A dictionary containing the augmented versions that were returned,
            in the form {"C1": "...", "C2": "...", ...}, or
            {"C1": caption} if augmentation was skipped or failed.
        """
        # Do not let usage from a previous call leak into this one.
        self.last_token_usage = None

        if not caption or not caption.strip():
            return {"C1": caption}

        log_prefix = f"[Sample {index}] " if index is not None else ""
        try:
            if index is not None:
                print(
                    f"[Sample {index}] Augmenting caption, generating {num_versions} versions..."
                )

            prompt = self.format_caption_prompt(caption, num_versions)
            response = self.call_llm(prompt)
            response_text = (response.get("content") or "").strip()

            usage_dict = self._serialize_usage(response.get("usage"))
            self.last_token_usage = usage_dict
            if usage_dict:
                print(
                    f"{log_prefix}Token usage: "
                    f"prompt={usage_dict['prompt_tokens']}, "
                    f"completion={usage_dict['completion_tokens']}, "
                    f"total={usage_dict['total_tokens']}"
                )

            result_json = self.extract_json_from_response(response_text)
            versions = self._collect_versions(result_json, num_versions)
            if versions:
                return versions

            if index is not None:
                print(
                    f"[Sample {index}] Caption augmentation failed, using original caption."
                )
        except Exception as e:
            if index is not None:
                print(
                    f"[Sample {index}] Error during caption augmentation: {e}, "
                    "using original caption."
                )

        # Failure path: clear token statistics and return the original caption.
        self.last_token_usage = None
        return {"C1": caption}
