#!/usr/bin/env python3
"""
VisualizeAugment: use an LLM to generate QA pairs about basic geometric
elements (points, lines, circles) and their relationships.
"""

import json
import os
import re
import time
from typing import Any, Dict, Optional

from openai import OpenAI
from ...utils.model_urls import get_base_url_for_model


class VisualizeAugment:
    """Use an LLM to generate QA pairs about basic geometric elements."""
    
    def __init__(
        self,
        api_key: Optional[str] = None,
        model: Optional[str] = None,
        base_url: Optional[str] = None,
        max_retries: int = 3,
    ):
        """
        Resolve the configuration and build the OpenAI-compatible client.

        Args:
            api_key: API key; when falsy, the ``OPENAI_API_KEY`` environment
                variable is used instead.
            model: model name; when None, ``OPENAI_MODEL`` is used, falling
                back to "gemini-2.5-flash".
            base_url: API base URL; when None it is looked up with
                ``get_base_url_for_model(model)``, then ``OPENAI_BASE_URL``,
                then the Gemini OpenAI-compatible endpoint.
            max_retries: stored as ``self.max_retries``; ``call_llm`` uses it
                as the number of attempts made before giving up.

        Raises:
            ValueError: if no API key is available from the argument or the
                environment.
        """
        # Model: explicit argument > environment > built-in default.
        chosen_model = model
        if chosen_model is None:
            chosen_model = os.getenv("OPENAI_MODEL") or "gemini-2.5-flash"

        # Base URL: explicit argument > model-based lookup > environment >
        # built-in default.
        chosen_url = base_url
        if chosen_url is None:
            chosen_url = get_base_url_for_model(chosen_model)
        if chosen_url is None:
            chosen_url = (
                os.getenv("OPENAI_BASE_URL")
                or "https://generativelanguage.googleapis.com/v1beta/openai/"
            )

        self.api_key = api_key or os.getenv("OPENAI_API_KEY")
        if not self.api_key:
            raise ValueError(
                "API key is required (provide via argument or environment variable)."
            )

        self.model = chosen_model
        self.base_url = chosen_url
        self.max_retries = max_retries

        # Only forward base_url when one is set so the client keeps its own
        # default otherwise.
        client_options = {"api_key": self.api_key, "timeout": 600.0}
        if self.base_url:
            client_options["base_url"] = self.base_url
        self.client = OpenAI(**client_options)

        # Token usage reported by the most recent call (None until one is made).
        self.last_token_usage = None
        
    def format_visualize_prompt(
        self,
        plotting_code: Dict[str, Any],
        caption: str,
        num_versions: int = 2,
    ) -> str:
        """
        Build the prompt for visual QA generation (VL alignment).

        Args:
            plotting_code: structured figure description; it is embedded in
                the prompt as JSON, so it must be JSON-serializable.
            caption: textual description of the figure, embedded verbatim.
            num_versions: number of QA sets to request; the prompt asks for
                the keys "V1" ... "V{num_versions}".

        Returns:
            The complete prompt text.
        """
        version_keys = [f"V{i}" for i in range(1, num_versions + 1)]
        versions = ", ".join(f'"{key}"' for key in version_keys)

        # Show the model a real, parseable JSON skeleton with one entry per
        # requested version instead of a pseudo-template it has to interpret.
        example = {
            key: {"question": "question text", "answer": "answer text"}
            for key in version_keys
        }
        output_example = "\n".join(
            "   " + line for line in json.dumps(example, indent=4).splitlines()
        )

        prompt = f"""You are an expert in geometry education and visual understanding.
Based on the given geometric plotting code and description, generate
{num_versions} sets of QA pairs about **basic geometric elements** for
vision-language alignment training.

Geometric plotting code (`plotting_code`) includes:
- `points`: point names and their coordinates
- `segments-chains`: chains of segments that indicate which points lie on
  the same segment
- `circles`: circles with the following formats (the first argument is the
  circle ID):
    ["C1", "O", 5]                # circle C1, center O, radius 5
    ["C2", "O", "P"]              # circle C2, center O, OP is the radius
    ["C3", "A", "B", "diameter"]  # circle C3, AB is a diameter
    ["C4", "A", "B", "C"]         # circle C4 passes through A, B, C

Figure description (`caption`):
{caption}

### Task
Generate questions and answers based on the given geometric plotting code
and description. All questions must be **answerable purely by observing
the figure**, and should not involve complex calculations or proofs.

### Content requirements
1. Each QA pair must focus on **basic geometric elements and their
   relationships**, including but not limited to:
   - Points: positions, relative locations, whether a point lies on a
     segment or a circle
   - Lines: which points are connected, whether segments intersect,
     whether they are perpendicular or parallel, etc.
   - Circles: center, circles with known radius, which points lie on the
     circle, and basic relationships between circles and lines
   - Basic relations: whether a point lies on a segment or circle,
     whether segments intersect, and other directly observable properties

2. Requirements for each QA pair:
   - **Question**:
       - Use natural English; you can start with phrases such as
         "In this figure..." or "From the diagram we can see..."
       - Ask only one specific fact; do not combine multiple sub-questions
       - Strictly use the names that appear in the figure (points A, B,
         C; segment AB; circle C1)
       - Do not require numerical calculations (e.g., area, perimeter,
         angle computation); only ask classification/judgment questions

   - **Answer**:
       - Be concise and explicit, giving the conclusion directly in one
         sentence with a brief explanation
       - Answers should be basic judgments such as
         yes/no, belongs/does not belong, intersect/do not intersect,
         above/below, etc.

3. Diversity requirements:
   - The {num_versions} QA sets should cover different elements:
       - Some focusing on relationships among points
       - Some focusing on relationships among segments
       - Some focusing on circles and their related relationships
       - Keep a balance between positive and negative examples

4. Constraints:
   - Do NOT invent any points, lines, or circles that do not appear in
     `plotting_code`
   - Do NOT fabricate specific lengths, coordinates, or angles
   - Each question should focus on a single concept
   - You may refer to the `caption`, but you should rely primarily on
     the structured data in `plotting_code`

5. Output format:
   - Output a single JSON object only; do not add any extra text or code
     fences
   - The object must contain exactly these top-level keys: {versions}
   - Use the following format:

{output_example}

Geometric plotting code:
{json.dumps(plotting_code, ensure_ascii=False)}
"""
        return prompt


    def call_llm(self, prompt: str) -> Dict[str, Any]:
        """
        Send ``prompt`` as a single user message and return the reply.

        The request uses temperature 0.7 and max_tokens 32767. Any failure,
        including a reply truncated by the token limit, is retried with
        exponential backoff (1s, 2s, 4s, ...) until the attempts are used up.

        Args:
            prompt: full prompt text.

        Returns:
            A dict with keys ``content`` (reply text, "" when the API returns
            no text), ``usage`` (the usage object as returned by the API) and
            ``finish_reason``.

        Raises:
            Exception: when every attempt failed; the last underlying error is
                attached as ``__cause__``.
        """
        # Always try at least once, even if max_retries was configured as 0
        # or a negative number.
        attempts = max(1, self.max_retries)

        for attempt in range(attempts):
            try:
                completion = self.client.chat.completions.create(
                    model=self.model,
                    messages=[
                        {
                            "role": "user",
                            "content": prompt,
                        },
                    ],
                    temperature=0.7,
                    max_tokens=32767,
                )
                choice = completion.choices[0]
                message = choice.message
                finish_reason = getattr(choice, "finish_reason", None)
                if finish_reason == "length":
                    # A truncated reply is unusable for JSON extraction, so
                    # treat it as a failure and let the retry logic handle it.
                    raise Exception(
                        "LLM response was truncated (finish_reason=length). "
                        "Current max_tokens=32767; increase the global "
                        "max_tokens or shorten the prompt."
                    )
                return {
                    # message.content may be None (e.g. refusals); normalize
                    # to "" so callers can treat it as a string.
                    "content": (message.content or "") if message else "",
                    "usage": completion.usage,
                    "finish_reason": finish_reason,
                }

            except Exception as e:
                if attempt < attempts - 1:
                    wait_time = 2 ** attempt
                    print(
                        f"Attempt {attempt + 1} failed, retrying in "
                        f"{wait_time} seconds... ({e})"
                    )
                    time.sleep(wait_time)
                    continue
                # Keep the original error chained for debugging.
                raise Exception(
                    f"API call failed (retried {self.max_retries} times): {e}"
                ) from e

        # Defensive: the loop above always returns or raises.
        raise Exception("Reached the maximum number of retries")
    
    @staticmethod
    def _serialize_usage(usage_obj) -> Optional[Dict[str, int]]:
        """
        Convert a usage object into a JSON-serializable dictionary.
        
        Args:
            usage_obj: usage object (either a CompletionUsage-like object
                       or a dictionary).
            
        Returns:
            A dictionary form of usage; returns None if it cannot be
            converted.
        """
        if usage_obj is None:
            return None
        
        # If it's already a dictionary, normalize and return
        if isinstance(usage_obj, dict):
            return {
                "prompt_tokens": usage_obj.get("prompt_tokens", 0) or 0,
                "completion_tokens": usage_obj.get("completion_tokens", 0) or 0,
                "total_tokens": usage_obj.get("total_tokens", 0) or 0,
            }
        
        # If it's an object, extract attributes
        try:
            prompt_tokens = getattr(usage_obj, 'prompt_tokens', None)
            completion_tokens = getattr(usage_obj, 'completion_tokens', None)
            total_tokens = getattr(usage_obj, 'total_tokens', None)
            
            # If attributes exist, convert them to a dictionary
            if prompt_tokens is not None or completion_tokens is not None or total_tokens is not None:
                return {
                    "prompt_tokens": int(prompt_tokens) if prompt_tokens is not None else 0,
                    "completion_tokens": int(completion_tokens) if completion_tokens is not None else 0,
                    "total_tokens": int(total_tokens) if total_tokens is not None else 0,
                }
        except Exception:
            pass
        
        return None
    
    @staticmethod
    def extract_json_from_response(response_text: str) -> Optional[Dict[str, Any]]:
        """Extract a JSON object from raw response text."""
        if not response_text:
            return None

        stripped = response_text.strip()

        # 1) Try parsing the whole string directly
        try:
            return json.loads(stripped)
        except json.JSONDecodeError:
            pass

        # 2) Try parsing a ```json ... ``` code block
        code_block_match = re.search(r"```(?:json)?\s*(\{)", response_text, re.DOTALL)
        if code_block_match:
            start_pos = code_block_match.start(1)  # JSON object start position
            json_text = response_text[start_pos:]
            
            # Find the closing fence (if any)
            if "```" in json_text[1:]:  # search from 2nd char to avoid the opening
                end_marker = json_text.find("```", 1)
                if end_marker > 0:
                    json_text = json_text[:end_marker].rstrip()
            
            # Try decoding from the beginning of the JSON text
            decoder = json.JSONDecoder()
            try:
                obj, end_pos = decoder.raw_decode(json_text)
                if isinstance(obj, dict):
                    return obj
            except json.JSONDecodeError:
                pass

        # 3) Fallback: scan for the first decodable `{ ... }` substring
        decoder = json.JSONDecoder()
        for idx, ch in enumerate(response_text):
            if ch != "{":
                continue
            try:
                obj, _ = decoder.raw_decode(response_text[idx:])
                if isinstance(obj, dict):
                    return obj
            
            except json.JSONDecodeError:
                    continue

        return None
    
    def augment(
        self,
        plotting_code: Dict[str, Any],
        caption: str,
        index: Optional[int] = None,
        num_versions: int = 2,
    ) -> Dict[str, Dict[str, str]]:
        """
        Generate visual QA pairs for one figure.

        Never raises: any failure (API error, unparseable reply, no valid
        version in the reply) yields a single empty QA pair and resets
        ``self.last_token_usage`` to None. On success ``self.last_token_usage``
        holds the usage of this call (or None if the API reported none).

        Args:
            plotting_code: structured figure description passed to the
                prompt; when empty, no request is made.
            caption: textual description of the figure (required).
            index: optional sample index, used only as a log prefix.
            num_versions: number of QA versions to request and to look for
                in the reply (keys "V1" ... "V{num_versions}").

        Returns:
            A dict of the versions that were present and well-formed, e.g.
            {"V1": {"question": "...", "answer": "..."}, "V2": {...}}, or
            {"V1": {"question": "", "answer": ""}} on failure.
        """
        if not plotting_code:
            # No request is made, so make sure usage from an earlier call is
            # not mistaken for usage of this one.
            self.last_token_usage = None
            return {"V1": {"question": "", "answer": ""}}

        log_prefix = f"[Sample {index}] " if index is not None else ""
        try:
            if index is not None:
                print(
                    f"[Sample {index}] Generating visual QA pairs "
                    f"({num_versions} versions)..."
                )

            prompt = self.format_visualize_prompt(plotting_code, caption, num_versions)
            response = self.call_llm(prompt)
            # The content may be missing or None; treat both as empty text.
            response_text = (response.get("content") or "").strip()

            # Store and report token usage
            usage_dict = self._serialize_usage(response.get("usage"))
            self.last_token_usage = usage_dict
            if usage_dict:
                print(
                    f"{log_prefix}Token usage: "
                    f"prompt={usage_dict['prompt_tokens']}, "
                    f"completion={usage_dict['completion_tokens']}, "
                    f"total={usage_dict['total_tokens']}"
                )

            # Extract JSON from the raw response text and keep every
            # well-formed version V1, V2, ...
            result_json = self.extract_json_from_response(response_text)
            versions = {}
            if isinstance(result_json, dict):
                for i in range(1, num_versions + 1):
                    key = f"V{i}"
                    qa_pair = self._coerce_qa_pair(result_json.get(key))
                    if qa_pair is not None:
                        versions[key] = qa_pair

            if versions:
                return versions

            # Nothing usable in the reply
            if index is not None:
                print(
                    f"[Sample {index}] Failed to generate visual QA pairs; "
                    "using an empty QA pair."
                )
        except Exception as e:
            if index is not None:
                print(
                    f"[Sample {index}] Error while generating visual QA pairs: "
                    f"{e}; using an empty QA pair."
                )

        self.last_token_usage = None  # Clear token stats on failure
        return {"V1": {"question": "", "answer": ""}}

    @staticmethod
    def _coerce_qa_pair(value: Any) -> Optional[Dict[str, str]]:
        """Normalize one version entry to {"question", "answer"} or None.

        Accepts either a dict or a string holding a JSON-encoded dict; both
        keys must be present. Values are converted to stripped strings.
        """
        if isinstance(value, str):
            # Some models return the pair as a JSON string; decode it.
            try:
                value = json.loads(value)
            except ValueError:
                return None
        if isinstance(value, dict) and "question" in value and "answer" in value:
            return {
                "question": str(value["question"]).strip(),
                "answer": str(value["answer"]).strip(),
            }
        return None

