#!/usr/bin/env python3
"""
QuestionAugment: Use LLM to perform text augmentation for geometry problems
"""

import json
import os
import re
import time
from typing import Any, Dict, Optional

from openai import OpenAI
from ...utils.model_urls import get_base_url_for_model


class QuestionAugment:
    """Use an LLM to perform text augmentation for geometry problems.

    The LLM is asked to return a JSON object {"Q1": ..., "Q2": ..., ...} holding
    paraphrases of the input problem. ``augment`` parses that object and falls
    back to {"Q1": <original question>} on any failure.
    """

    # Raw control characters that must be written as escapes inside JSON strings.
    _CONTROL_ESCAPES = {"\n": "\\n", "\r": "\\r", "\t": "\\t"}
    # Exactly four hex digits: the payload of a JSON \uXXXX escape.
    _HEX4 = re.compile(r"[0-9a-fA-F]{4}")

    def __init__(
        self,
        api_key: Optional[str] = None,
        model: Optional[str] = None,
        base_url: Optional[str] = None,
        max_retries: int = 3,
        max_tokens: int = 32767,
        temperature: float = 0.7,
        timeout: float = 600.0,
    ):
        """
        Initialize the question augmenter.

        Args:
            api_key: API Key (if None, read from the OPENAI_API_KEY environment variable)
            model: Model name (if None, OPENAI_MODEL or "gemini-2.5-flash" is used)
            base_url: API base URL (if None, inferred from the model via
                get_base_url_for_model; if that yields None, OPENAI_BASE_URL or the
                Gemini OpenAI-compatible endpoint is used)
            max_retries: Maximum number of API attempts per call_llm invocation
            max_tokens: ``max_tokens`` sent with every completion request
            temperature: Sampling temperature sent with every completion request
            timeout: Request timeout in seconds passed to the OpenAI client

        Raises:
            ValueError: If no API key is given and OPENAI_API_KEY is unset or empty.
        """
        # Set default values
        if model is None:
            model = os.getenv("OPENAI_MODEL") or "gemini-2.5-flash"

        # Automatically get base_url based on the model name
        if base_url is None:
            base_url = get_base_url_for_model(model)
            if base_url is None:
                base_url = os.getenv("OPENAI_BASE_URL") or "https://generativelanguage.googleapis.com/v1beta/openai/"

        self.api_key = api_key or os.getenv("OPENAI_API_KEY")

        if not self.api_key:
            raise ValueError("API Key is required (via parameter or environment variable)")

        self.model = model
        self.base_url = base_url
        self.max_retries = max_retries
        self.max_tokens = max_tokens
        self.temperature = temperature
        self.timeout = timeout
        client_kwargs = {"api_key": self.api_key, "timeout": timeout}
        if self.base_url:
            client_kwargs["base_url"] = self.base_url
        self.client = OpenAI(**client_kwargs)
        self.last_token_usage = None  # Token usage dict of the last successful augment() call

    def format_question_prompt(self, question: str, num_versions: int = 3) -> str:
        """Build the augmentation prompt for *question*.

        The prompt asks for ``num_versions`` paraphrases returned as a single JSON
        object with keys "Q1" .. "Q<num_versions>". The output template lists one
        key per line, so the example itself is valid JSON.
        """
        template = ",\n".join(
            f'  "Q{i}": "augmented problem text"' for i in range(1, num_versions + 1)
        )
        lines = [
            "You are a professional geometry problem editor. Please perform data augmentation "
            f"on the following geometry problem and generate {num_versions} versions that are "
            "semantically equivalent but expressed differently.",
            "",
            "Strict requirements:",
            "1. The mathematical meaning, all geometric relationships, implicit conditions, "
            "and the solving objective must remain completely consistent.",
            "   - It is forbidden to add any new geometric conditions.",
            "   - It is forbidden to remove or weaken any existing conditions.",
            '   - It is forbidden to change known relationships (such as "midpoint", '
            '"perpendicular", "parallel", "cyclic quadrilateral", etc.).',
            "2. The expression can change: adjust sentence patterns, word order, wording, "
            "or narrative style.",
            "3. All mathematical symbols, point names, segment names, angle labels, and "
            "numerical values must remain exactly the same.",
            "4. Each version must have clearly distinguishable differences in expression style.",
            "5. Output format: output only a single JSON object, without any additional text, "
            "explanations, or code block markers.",
            # LaTeX such as \angle is invalid inside a JSON string unless the
            # backslash is doubled; ask the model to do it at the source.
            "6. The output must be valid JSON: inside string values, write every backslash "
            "as \\\\ (for example, \\angle ABC must be written as \\\\angle ABC).",
            "",
            "Original problem:",
            question,
            "",
            "Output format:",
            "{",
            template,
            "}",
        ]
        return "\n".join(lines)

    def call_llm(self, prompt: str) -> Dict[str, Any]:
        """Send *prompt* as a single user message and return the reply.

        Retries up to ``self.max_retries`` attempts with exponential backoff
        (1s, 2s, 4s, ...) on any exception, including a truncated reply
        (finish_reason == "length").

        Returns:
            {"content": str (never None; "" if the message had no text),
             "usage": the SDK usage object, "finish_reason": str or None}

        Raises:
            Exception: When every attempt fails (chained to the last error), or
                when ``max_retries`` < 1 so that no attempt is made.
        """
        for attempt in range(self.max_retries):
            try:
                completion = self.client.chat.completions.create(
                    model=self.model,
                    messages=[{"role": "user", "content": prompt}],
                    temperature=self.temperature,
                    max_tokens=self.max_tokens,
                )
                choice = completion.choices[0]
                message = choice.message
                finish_reason = getattr(choice, "finish_reason", None)
                if finish_reason == "length":
                    raise Exception(
                        f"LLM response was truncated (finish_reason=length), current max_tokens={self.max_tokens}. "
                        "Please increase the global max_tokens or shorten the prompt."
                    )
                # message.content may be None (e.g. refusals); normalize to "" so
                # callers can always call .strip() on it.
                content = (message.content if message else None) or ""
                return {
                    "content": content,
                    "usage": completion.usage,
                    "finish_reason": finish_reason,
                }
            except Exception as e:
                if attempt < self.max_retries - 1:
                    wait_time = 2 ** attempt
                    print(f"Attempt {attempt + 1} failed, retrying after {wait_time} seconds... ({e})")
                    time.sleep(wait_time)
                    continue
                raise Exception(f"API call failed (retried {self.max_retries} times): {e}") from e

        raise Exception("Maximum number of retries reached")

    @staticmethod
    def _serialize_usage(usage_obj) -> Optional[Dict[str, int]]:
        """
        Convert the usage object into a serializable dictionary format.

        Args:
            usage_obj: usage object (a CompletionUsage-like object with
                prompt_tokens / completion_tokens / total_tokens attributes, or a dict)

        Returns:
            {"prompt_tokens", "completion_tokens", "total_tokens"} with missing
            values replaced by 0, or None if *usage_obj* is None or exposes none
            of the three fields.
        """
        if usage_obj is None:
            return None

        # If it is already a dictionary, normalize missing/None values to 0
        if isinstance(usage_obj, dict):
            return {
                "prompt_tokens": usage_obj.get("prompt_tokens", 0) or 0,
                "completion_tokens": usage_obj.get("completion_tokens", 0) or 0,
                "total_tokens": usage_obj.get("total_tokens", 0) or 0,
            }

        # Otherwise read attributes; int() may fail on exotic values, in which
        # case usage is treated as unavailable (best effort).
        try:
            prompt_tokens = getattr(usage_obj, "prompt_tokens", None)
            completion_tokens = getattr(usage_obj, "completion_tokens", None)
            total_tokens = getattr(usage_obj, "total_tokens", None)

            if prompt_tokens is not None or completion_tokens is not None or total_tokens is not None:
                return {
                    "prompt_tokens": int(prompt_tokens) if prompt_tokens is not None else 0,
                    "completion_tokens": int(completion_tokens) if completion_tokens is not None else 0,
                    "total_tokens": int(total_tokens) if total_tokens is not None else 0,
                }
        except Exception:
            pass

        return None

    @staticmethod
    def _repair_json_strings(text: str) -> str:
        r"""Escape characters that make JSON string literals invalid.

        Walks *text* left to right, tracking whether it is inside a double-quoted
        string. Inside strings it:
          * replaces raw newline / carriage return / tab with the escapes \n / \r / \t,
            and any other raw control character (< U+0020) with \u00XX;
          * doubles every backslash that does not start a valid JSON escape
            (e.g. LaTeX \angle -> \\angle), so it decodes to a literal backslash.
        Text outside strings and valid escapes are copied unchanged, so input
        that is already valid JSON is returned as-is.

        Limitations: an unescaped " inside a value cannot be told apart from the
        closing quote. LaTeX commands that start with a valid escape letter
        (\frac, \theta, \triangle, ...) are left alone and decode to control
        characters.
        """
        out = []
        in_string = False
        i = 0
        length = len(text)
        while i < length:
            ch = text[i]
            if not in_string:
                if ch == '"':
                    in_string = True
                out.append(ch)
                i += 1
                continue

            if ch == '"':
                in_string = False
                out.append(ch)
            elif ch == "\\":
                nxt = text[i + 1] if i + 1 < length else ""
                if nxt and nxt in '"\\/bfnrt':
                    out.append(ch + nxt)  # valid two-character escape
                    i += 2
                    continue
                if nxt == "u" and QuestionAugment._HEX4.fullmatch(text, i + 2, i + 6):
                    out.append(text[i:i + 6])  # valid \uXXXX escape
                    i += 6
                    continue
                out.append("\\\\")  # invalid escape -> literal backslash
            elif ch in QuestionAugment._CONTROL_ESCAPES:
                out.append(QuestionAugment._CONTROL_ESCAPES[ch])
            elif ch < " ":
                out.append(f"\\u{ord(ch):04x}")
            else:
                out.append(ch)
            i += 1
        return "".join(out)

    @staticmethod
    def extract_json_from_response(response_text: str) -> Optional[Dict[str, Any]]:
        """Extract the first JSON object (dict) from an LLM response.

        Strategies, in order:
          1. Parse the whole stripped text as JSON.
          2. Parse the body of a ```json ... ``` (or bare ```) fence, first as-is,
             then after ``_repair_json_strings``.
          3. ``raw_decode`` starting at every "{" in the text.
          4. ``raw_decode`` the repaired text starting at the first "{".

        Only dict results are accepted. A top-level list, string or number is
        skipped, so callers can rely on dict operations.

        Returns:
            The parsed dict, or None if no strategy yields one.
        """
        if not response_text:
            return None

        decoder = json.JSONDecoder()

        def decode_dict(text: str) -> Optional[Dict[str, Any]]:
            # Decode a leading JSON value; accept it only if it is an object.
            try:
                obj, _ = decoder.raw_decode(text)
            except json.JSONDecodeError:
                return None
            return obj if isinstance(obj, dict) else None

        # 1) Whole response is JSON
        try:
            parsed = json.loads(response_text.strip())
        except json.JSONDecodeError:
            parsed = None
        if isinstance(parsed, dict):
            return parsed

        # 2) Fenced code block; cut at the closing fence if there is one
        fence = re.search(r"```(?:json)?\s*(\{)", response_text)
        if fence:
            body = response_text[fence.start(1):]
            closing = body.find("```")
            if closing > 0:
                body = body[:closing].rstrip()
            obj = decode_dict(body)
            if obj is None:
                obj = decode_dict(QuestionAugment._repair_json_strings(body))
            if obj is not None:
                return obj

        # 3) First "{ ... }" substring that decodes to a valid object
        for idx, ch in enumerate(response_text):
            if ch == "{":
                obj = decode_dict(response_text[idx:])
                if obj is not None:
                    return obj

        # 4) Repair unfenced output (e.g. raw newlines or LaTeX backslashes in values)
        first_brace = response_text.find("{")
        if first_brace >= 0:
            return decode_dict(QuestionAugment._repair_json_strings(response_text[first_brace:]))
        return None

    def augment(self, question: str, index: Optional[int] = None, num_versions: int = 3) -> Dict[str, str]:
        """
        Augment the question text.

        Args:
            question: Original question text
            index: Optional index (used only as a log prefix)
            num_versions: Number of versions to request

        Returns:
            {"Q1": "...", "Q2": "...", ...} containing each requested key whose
            value is a non-blank string (stripped). On an empty question, a parse
            failure, no usable versions, or any exception, returns
            {"Q1": question}.

        Side effects:
            Sets ``self.last_token_usage`` to the serialized usage of this call on
            success, or to None on failure.
        """
        if not question or not question.strip():
            return {"Q1": question}

        tag = f"[Sample {index}] " if index is not None else ""
        try:
            if index is not None:
                print(f"{tag}Augmenting question, generating {num_versions} versions...")

            prompt = self.format_question_prompt(question, num_versions)
            response = self.call_llm(prompt)
            response_text = (response.get("content") or "").strip()

            # Store token usage of this call
            usage_dict = self._serialize_usage(response.get("usage"))
            self.last_token_usage = usage_dict
            if usage_dict:
                print(
                    f"{tag}Token usage: "
                    f"prompt={usage_dict['prompt_tokens']}, "
                    f"completion={usage_dict['completion_tokens']}, "
                    f"total={usage_dict['total_tokens']}"
                )

            result_json = self.extract_json_from_response(response_text)
            if result_json:
                # Keep only non-blank string values; a malformed entry (e.g. a
                # number or nested object) is skipped instead of discarding all.
                versions = {}
                for i in range(1, num_versions + 1):
                    key = f"Q{i}"
                    value = result_json.get(key)
                    if isinstance(value, str) and value.strip():
                        versions[key] = value.strip()
                if versions:
                    return versions

            # If JSON parsing fails, return the original question as Q1
            if index is not None:
                print(f"[Sample {index}] Question augmentation failed, using the original question")
            self.last_token_usage = None  # Clear token statistics on failure
            return {"Q1": question}
        except Exception as e:
            # Best-effort boundary: any failure falls back to the original question.
            if index is not None:
                print(f"[Sample {index}] Error during question augmentation: {e}, using the original question")
            self.last_token_usage = None  # Clear token statistics on failure
            return {"Q1": question}

