#!/usr/bin/env python3
"""
CoTAugment: Use LLM to augment the reasoning process (CoT) for geometry problems.
"""

import json
import os
import re
import time
from typing import Any, Dict, Optional

from openai import OpenAI
from ...utils.model_urls import get_base_url_for_model


class CoTAugment:
    """Use an LLM to augment the reasoning process (CoT) for geometry problems.

    The class wraps an OpenAI-compatible chat-completions client. ``augment``
    is the main entry point: it builds a prompt, calls the model with retries,
    extracts a JSON object from the reply and returns the paraphrased versions
    keyed ``"CoT1"``, ``"CoT2"``, ...  On any failure it falls back to
    ``{"CoT1": <original text>}``.
    """

    # Defaults for the sampling / transport parameters accepted by __init__.
    DEFAULT_TEMPERATURE = 0.7
    DEFAULT_MAX_TOKENS = 32767
    DEFAULT_TIMEOUT = 600.0

    # Opening code fence (optionally tagged ``json``) followed by the first "{".
    _FENCED_OBJECT_RE = re.compile(r"```(?:json)?\s*(\{)")
    # A backslash that is not itself escaped and is followed by a letter,
    # e.g. the start of an unescaped LaTeX command such as ``\angle``.
    _BARE_BACKSLASH_RE = re.compile(r"(?<!\\)\\(?=[a-zA-Z])")

    def __init__(
        self,
        api_key: Optional[str] = None,
        model: Optional[str] = None,
        base_url: Optional[str] = None,
        max_retries: int = 3,
        *,
        temperature: float = DEFAULT_TEMPERATURE,
        max_tokens: int = DEFAULT_MAX_TOKENS,
        timeout: float = DEFAULT_TIMEOUT,
    ):
        """
        Initialize the reasoning process augmenter.

        Args:
            api_key: API key (if None, read from the ``OPENAI_API_KEY`` environment variable).
            model: Model name (if None, ``OPENAI_MODEL`` or ``"gemini-2.5-flash"`` is used).
            base_url: API base URL (if None, inferred from the model via
                ``get_base_url_for_model``; if that yields None, ``OPENAI_BASE_URL``
                or the Gemini OpenAI-compatible endpoint is used).
            max_retries: Maximum number of attempts made by ``call_llm``.
            temperature: Sampling temperature sent with every request.
            max_tokens: Completion token limit sent with every request.
            timeout: Client timeout passed to ``OpenAI`` (seconds).

        Raises:
            ValueError: If no API key is available.
        """
        # Set default values
        if model is None:
            model = os.getenv("OPENAI_MODEL") or "gemini-2.5-flash"

        # Automatically infer base_url based on model name
        if base_url is None:
            base_url = get_base_url_for_model(model)
            if base_url is None:
                base_url = os.getenv("OPENAI_BASE_URL") or "https://generativelanguage.googleapis.com/v1beta/openai/"

        self.api_key = api_key or os.getenv("OPENAI_API_KEY")

        if not self.api_key:
            raise ValueError("API key is required (via argument or environment variable).")

        self.model = model
        self.base_url = base_url
        self.max_retries = max_retries
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.timeout = timeout

        client_kwargs = {"api_key": self.api_key, "timeout": self.timeout}
        if self.base_url:
            client_kwargs["base_url"] = self.base_url
        self.client = OpenAI(**client_kwargs)
        # Token usage of the most recent successful ``augment`` call (None otherwise).
        self.last_token_usage = None

    def format_cot_prompt(self, cot: str, num_versions: int = 3) -> str:
        """Format the prompt for CoT augmentation.

        Args:
            cot: Original reasoning process text, embedded verbatim in the prompt.
            num_versions: Number of paraphrased versions to request.

        Returns:
            The full prompt. Its "Output format" section is a valid JSON object
            with one ``"CoTi"`` key per requested version.
        """
        # One key/value pair per requested version so the example is valid JSON
        # and shows the model exactly which keys are expected.
        example_pairs = ",\n".join(
            f'"CoT{i}": "augmented reasoning process text"' for i in range(1, num_versions + 1)
        )
        prompt = f"""You are an expert editor of reasoning processes for geometry problems. Please perform data augmentation on the following Chain of Thought (CoT) for a geometry problem and generate {num_versions} versions that are semantically equivalent but use different expressions.

Strict requirements:
1. You must keep the overall reasoning logic, mathematical meaning, and final conclusion completely consistent.
- Key mathematical derivation steps and intermediate conclusions must not be deleted or tampered with.
- You may adjust the order of reasoning steps, but the logic must remain rigorous and readable.
2. You may change the way of expression, sentence structure, wording, and narrative style.
3. All calculation results, numerical values, symbolic relationships, and final answers must remain unchanged.
- Equations/inequalities can be rewritten as long as they are mathematically equivalent transformations.
- You must not change any numerical or judgment conclusion (e.g., "∠ABC = 60°", "AB = AC", "point D lies on BC", etc.).
4. Each version should have a clearly different narrative style (for example: more concise, more detailed, more didactic/teaching style, more colloquial, etc.).
5. The reasoning process must be written step by step using the format "Step 1: ...", "Step 2: ...", etc., but you can freely adjust the granularity of steps:
- You may merge multiple steps into one step (as long as the content is complete).
- You may split one step into multiple steps (to add more detailed explanations).
- You may add or remove purely explanatory or transitional textual steps.
However, you must not remove any substantive mathematical derivation or intermediate mathematical conclusion.
6. Output format: output only one JSON object, without any extra text, explanation, or code block markers.

Original reasoning process:
{cot}

Output format:
{{
{example_pairs}
}}"""
        return prompt

    def call_llm(self, prompt: str) -> Dict[str, Any]:
        """Call the LLM API, retrying with exponential backoff.

        Args:
            prompt: The user message to send.

        Returns:
            A dict with ``"content"`` (always a string, possibly empty),
            ``"usage"`` (the raw usage object from the completion, or None) and
            ``"finish_reason"``.

        Raises:
            Exception: If every attempt fails (API error, empty ``choices`` or a
                truncated reply), or if ``max_retries`` allows no attempt at all.
        """
        for attempt in range(self.max_retries):
            try:
                completion = self.client.chat.completions.create(
                    model=self.model,
                    messages=[{"role": "user", "content": prompt}],
                    temperature=self.temperature,
                    max_tokens=self.max_tokens,
                )
                choices = getattr(completion, "choices", None)
                if not choices:
                    # Report a missing/empty choice list explicitly instead of
                    # letting it surface as an opaque IndexError/TypeError.
                    raise RuntimeError("LLM response contained no choices.")

                choice = choices[0]
                finish_reason = getattr(choice, "finish_reason", None)
                if finish_reason == "length":
                    raise RuntimeError(
                        "LLM response was truncated (finish_reason=length). "
                        f"Current max_tokens={self.max_tokens}. "
                        "Please increase global max_tokens or shorten the prompt."
                    )

                message = getattr(choice, "message", None)
                # ``content`` may be None; normalise to "" so callers can always
                # treat it as a string.
                content = getattr(message, "content", None) or ""
                return {
                    "content": content,
                    "usage": getattr(completion, "usage", None),
                    "finish_reason": finish_reason,
                }

            except Exception as e:
                if attempt < self.max_retries - 1:
                    wait_time = 2 ** attempt
                    print(f"Attempt {attempt + 1} failed, retrying in {wait_time} seconds... ({e})")
                    time.sleep(wait_time)
                    continue
                # Keep the original exception as the cause for debugging.
                raise RuntimeError(f"API call failed (retried {self.max_retries} times): {e}") from e

        # Only reachable when max_retries <= 0 (the loop body never ran).
        raise RuntimeError("Maximum retry limit reached.")

    @staticmethod
    def _serialize_usage(usage_obj) -> Optional[Dict[str, int]]:
        """
        Convert the usage object to a serializable dict.

        Args:
            usage_obj: The usage object (a dict, or an object exposing
                ``prompt_tokens`` / ``completion_tokens`` / ``total_tokens``).

        Returns:
            A dict with the three token counters (missing ones become 0), or None
            if ``usage_obj`` is None, exposes none of the counters, or a counter
            cannot be converted to ``int``.
        """
        if usage_obj is None:
            return None

        fields = ("prompt_tokens", "completion_tokens", "total_tokens")

        # If already a dict, just normalize its fields
        if isinstance(usage_obj, dict):
            return {name: usage_obj.get(name, 0) or 0 for name in fields}

        # Otherwise read the counters as attributes
        raw = {name: getattr(usage_obj, name, None) for name in fields}
        if all(value is None for value in raw.values()):
            return None
        try:
            return {name: int(value) if value is not None else 0 for name, value in raw.items()}
        except (TypeError, ValueError):
            # A counter that cannot be turned into an int: treat usage as unavailable.
            return None

    @staticmethod
    def _decode_object_at(decoder: json.JSONDecoder, text: str, start: int = 0) -> Optional[Dict[str, Any]]:
        """Decode a JSON value beginning at ``text[start]``; return it only if it is a dict."""
        try:
            obj, _ = decoder.raw_decode(text, start)
        except ValueError:
            return None
        return obj if isinstance(obj, dict) else None

    @staticmethod
    def _first_object(decoder: json.JSONDecoder, text: str) -> Optional[Dict[str, Any]]:
        """Return the first JSON object that decodes at some "{" in ``text``, else None."""
        pos = text.find("{")
        while pos != -1:
            obj = CoTAugment._decode_object_at(decoder, text, pos)
            if obj is not None:
                return obj
            pos = text.find("{", pos + 1)
        return None

    @staticmethod
    def _escape_bare_backslashes(text: str) -> str:
        """Double every unescaped backslash that precedes a letter.

        This rescues unescaped LaTeX-like sequences (``\\angle``, ``\\sqrt``).
        It is a heuristic: a genuine JSON escape such as ``\\n`` is doubled too,
        which is why callers only use it after a normal parse has failed.
        """
        return CoTAugment._BARE_BACKSLASH_RE.sub(r"\\\\", text)

    @staticmethod
    def extract_json_from_response(response_text: str) -> Optional[Dict[str, Any]]:
        """Extract a JSON object from the LLM response text.

        Strategies, in order:
          1. Parse the whole (stripped) response.
          2. Parse the object inside a fenced code block (```json ... ```),
             retrying after escaping bare backslashes.
          3. Parse the first "{ ... }" substring that decodes as an object.
          4. Repeat 3 after escaping bare backslashes in the whole response.

        All parsing uses ``strict=False`` so raw control characters (for example
        literal newlines) inside string values are accepted as-is.

        Args:
            response_text: Raw text returned by the model.

        Returns:
            The decoded dict, or None if no JSON object could be recovered.
        """
        if not response_text:
            return None

        # 1) Try parsing directly; only a JSON *object* is an acceptable result.
        try:
            parsed = json.loads(response_text.strip(), strict=False)
        except ValueError:
            parsed = None
        if isinstance(parsed, dict):
            return parsed

        decoder = json.JSONDecoder(strict=False)

        # 2) Fenced code block: start at the first "{" after the opening fence.
        fence = CoTAugment._FENCED_OBJECT_RE.search(response_text)
        if fence:
            fenced = response_text[fence.start(1):]
            closing = fenced.find("```", 1)
            if closing > 0:
                fenced = fenced[:closing].rstrip()
            for candidate in (fenced, CoTAugment._escape_bare_backslashes(fenced)):
                obj = CoTAugment._decode_object_at(decoder, candidate)
                if obj is not None:
                    return obj

        # 3) Fallback: first "{ ... }" substring that decodes as an object.
        obj = CoTAugment._first_object(decoder, response_text)
        if obj is not None:
            return obj

        # 4) Same scan after repairing bare backslashes (bare JSON with LaTeX).
        repaired = CoTAugment._escape_bare_backslashes(response_text)
        if repaired != response_text:
            return CoTAugment._first_object(decoder, repaired)

        return None

    def augment(self, cot: str, index: Optional[int] = None, num_versions: int = 3) -> Dict[str, str]:
        """
        Augment the reasoning process (CoT) text.

        Args:
            cot: Original reasoning process text.
            index: Optional index used as a ``[sample N]`` prefix in log output.
            num_versions: Number of augmented versions to generate.

        Returns:
            A dict containing the augmented versions that came back as non-empty
            strings, in the format {"CoT1": "...", "CoT2": "...", ...}. If the
            input is empty, ``num_versions`` is below 1, or augmentation fails,
            ``{"CoT1": cot}`` is returned instead.

        Side effects:
            ``self.last_token_usage`` holds the usage dict of this call on
            success and None in every other case.
        """
        # Reset up front so a skipped or failed call never exposes the usage
        # recorded for a previous sample.
        self.last_token_usage = None
        fallback = {"CoT1": cot}

        # Nothing to paraphrase, or no version requested: skip the API call.
        if not cot or not cot.strip() or num_versions < 1:
            return fallback

        tag = f"[sample {index}] " if index is not None else ""
        try:
            if index is not None:
                print(f"[sample {index}] Augmenting reasoning process, generating {num_versions} versions...")

            prompt = self.format_cot_prompt(cot, num_versions)
            response = self.call_llm(prompt)
            response_text = (response.get("content") or "").strip()

            # Collect and store token usage
            usage_dict = self._serialize_usage(response.get("usage"))
            self.last_token_usage = usage_dict
            if usage_dict:
                print(
                    f"{tag}Token usage: "
                    f"prompt={usage_dict['prompt_tokens']}, "
                    f"completion={usage_dict['completion_tokens']}, "
                    f"total={usage_dict['total_tokens']}"
                )

            # Extract JSON from response and keep only non-empty string versions;
            # a malformed entry must not discard its valid siblings.
            result_json = self.extract_json_from_response(response_text)
            versions: Dict[str, str] = {}
            if result_json:
                for i in range(1, num_versions + 1):
                    key = f"CoT{i}"
                    value = result_json.get(key)
                    if isinstance(value, str) and value.strip():
                        versions[key] = value.strip()

            if versions:
                return versions

            # No usable version could be extracted: show the raw reply for debugging.
            print(response_text)
            print(f"{tag}CoT augmentation failed, using the original reasoning process.")
        except Exception as e:
            # Best-effort by design: any failure falls back to the original text.
            print(f"{tag}Error during CoT augmentation: {e}, using the original reasoning process.")

        # Clear token usage statistics on failure
        self.last_token_usage = None
        return fallback
