#!/usr/bin/env python3
"""
VisualizeQA:

- Given a geometry problem, its chain-of-thought (CoT), and an image caption,
  rewrite the problem and CoT into a vision-QA friendly form.
- Optionally performs a **two-step text sanitization** using annotations:
  1) Pure text simplification.
  2) Use annotations to further remove conditions that are purely visual
     and already expressed in the diagram.
"""

import json
import os
import re
import sys
import time
from typing import Any, Dict, Optional

from openai import OpenAI

# Import the annotations translator (used to convert structured annotations
# into short natural-language descriptions).
from ..utils.annotation_translator import translate_annotations


class VisualizeQA:
    """
    Caption-aware question and CoT rewriter.

    - Optionally sanitizes the problem statement to remove conditions that are
      redundant given the diagram.
    - Rewrites the CoT so that it explicitly distinguishes between information
      coming from the text and information that is visually evident in the
      caption/image.
    """
    
    def __init__(
        self,
        api_key: Optional[str] = None,
        model: Optional[str] = None,
        base_url: Optional[str] = None,
        max_retries: int = 3,
    ):
        """
        Initialize VisualizeQA and build the OpenAI-compatible client.

        Each setting is resolved in this order: explicit argument, then
        environment variable, then a built-in default.

        Args:
            api_key: API key. Falls back to ``OPENAI_API_KEY`` and finally to
                a placeholder key, so construction never fails for lack of a
                key. NOTE(review): the placeholder presumably targets local
                OpenAI-compatible servers that ignore the key -- confirm.
            model: Model name. Falls back to ``OPENAI_MODEL`` and then to the
                built-in gpt-oss-120b snapshot path.
            base_url: API base URL. Falls back to ``OPENAI_BASE_URL`` and then
                to ``http://localhost:8001/v1``. An explicit empty string
                leaves the client's own default endpoint in place.
            max_retries: Maximum number of attempts for each API call.
        """
        if model is None:
            model = os.getenv("OPENAI_MODEL") or "models/gpt-oss-120b/snapshots/8b193b0ef83bd41b40eb71fee8f1432315e02a3e/"
        if base_url is None:
            base_url = os.getenv("OPENAI_BASE_URL") or "http://localhost:8001/v1"

        # The placeholder at the end of the chain is always truthy, so a
        # "missing API key" check after this line could never fire.
        self.api_key = api_key or os.getenv("OPENAI_API_KEY") or "sk-proj-1234567890"
        self.model = model
        self.base_url = base_url
        self.max_retries = max_retries

        # Generous timeout: a single rewrite request may generate many tokens.
        client_kwargs: Dict[str, Any] = {"api_key": self.api_key, "timeout": 600.0}
        if self.base_url:
            client_kwargs["base_url"] = self.base_url
        self.client = OpenAI(**client_kwargs)
    
    @staticmethod
    def _clean_actual_or_plotting_data(data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Clean actual_data / plotting_code into a structural representation.

        - Keep topological structure (points, circles, annotation types, etc.).
        - Optionally keep points (coordinate information).
        - Drop segment list and annotation_summary.
        - Strip numeric radius information from circles so it cannot leak
          answer-related values into the LLM.
        """
        cleaned = json.loads(json.dumps(data))  # Deep copy to avoid mutating the original

        if "segments" in cleaned:
            del cleaned["segments"]
        if "annotation_summary" in cleaned:
            del cleaned["annotation_summary"]

        if "circles" in cleaned and isinstance(cleaned["circles"], list):
            new_circles = []
            for circle in cleaned["circles"]:
                if isinstance(circle, list) and len(circle) == 3:
                    if (
                        isinstance(circle[2], (int, float))
                        or (isinstance(circle[2], str) and circle[2].isdigit())
                    ):
                        new_circles.append(circle[:2])
                    else:
                        new_circles.append(circle)
                else:
                    new_circles.append(circle)
            cleaned["circles"] = new_circles

        return cleaned

    @staticmethod
    def _extract_annotations(
        plotting_code: Optional[Dict[str, Any]],
        actual_data: Optional[Dict[str, Any]],
    ) -> str:
        """
        Extract annotations from plotting_code or actual_data and translate
        them into a natural-language string.
        """
        annotations = None
        if actual_data and isinstance(actual_data, dict):
            annotations = actual_data.get("annotations")
        
        if annotations is None and plotting_code and isinstance(plotting_code, dict):
            annotations = plotting_code.get("annotations")
        
        if annotations is None:
            return ""
        
        try:
            translated = translate_annotations(annotations)
            return translated if translated else ""
        except Exception as e:
            # If translation fails, fall back to raw JSON text.
            print(f"[WARN] Failed to translate annotations, falling back to JSON: {e}", file=sys.stderr)
            try:
                return json.dumps(annotations, ensure_ascii=False)
            except TypeError:
                return str(annotations)

    def _build_prompt_step1(self, index: int, question: str) -> str:
        """
        Step 1: text-only simplification of the original problem statement
        (ignore annotations).

        Args:
            index: Sample index. Not used in the prompt text; kept so the
                signature matches the step-2 builder.
            question: Original problem statement, inserted verbatim.

        Returns:
            The complete step-1 prompt string.
        """
        # NOTE: this is an f-string, so every literal brace in the template
        # (the JSON skeleton and the LaTeX frac arguments in Example 2) must
        # be doubled; a single-braced "{AO}" would be evaluated as a Python
        # expression and raise NameError.
        prompt = f"""
You are a professional mathematics editor. Your task is to rewrite a geometry
problem statement into a **concise, natural, and fluent** version in LaTeX.

Core assumption (must follow):
There is a perfect geometric diagram attached to the problem. All point
existence, positions, and topological structure (who intersects whom, who lies
on which segment, collinearity, etc.) are already perfectly visible in the
figure. Therefore, any textual description whose only purpose is to define
where a point lies is redundant and **must be removed**, even if that makes
the point look "undefined" from text alone.

1. Remove purely visual/topological definitions:
   - Intersection definitions such as "point G is the intersection of lines
     AB and CD", "they intersect at P" -> delete.
   - Positional descriptions such as "point D lies on segment BC",
     "A, B, C are collinear", "as shown in the figure", "in the plane" -> delete.
   - Construction commands such as "draw segment AD", "connect BD" -> delete.

2. Keep true geometric constraints / metrics:
   - Metrics: numeric lengths, angles, areas, ratios.
   - Relations: parallel (``\\parallel``), perpendicular (``\\perp``), equal (``=``).
   - High-level geometric notions that implicitly encode metrics or relations:
     centroid, incenter, orthocenter, midpoint (1:1), angle bisector, tangent,
     regular polygon, diameter, etc.

3. Rewrite into a natural-flow statement:
   - Use connective words such as "Given", "Suppose", "satisfies", "where",
     "and" to join the remaining conditions.
   - Keep sentences fluent; do not output a list of broken keywords.
   - **Preserve all LaTeX math formatting**.

Few-shot examples:

Example 1 (delete intersection definition)
Input:
Given $AB=3$, $AC=6$, $\\angle BAC=90^\\circ$, point $D$ is the midpoint of $AB$,
and point $G$ is the intersection of $CD$ and $AE$. Find the length of $GD$.
Output:
Given $AB=3$, $AC=6$, $\\angle BAC=90^\\circ$, and $D$ is the midpoint of $AB$.
Find the length of $GD$.

Example 2 (delete construction and collinearity)
Input:
In the plane, points $A$, $B$, $C$, $D$ satisfy $AB \\parallel CD$ and
$AB = CD = 4$. Segments $AC$ and $BD$ intersect at $O$. If $A$, $O$, $C$ are
collinear, find $\\frac{{AO}}{{OC}}$.
Output:
Given $AB \\parallel CD$ and $AB = CD = 4$. Find $\\frac{{AO}}{{OC}}$.

Example 3 (keep circles and tangents)
Input:
As shown in the figure, $PA$ and $PB$ are tangents to circle $O$ at $A$ and $B$
respectively. Line $PO$ meets circle $O$ again at $C$ and $D$. If $PA=4$, find
the length of $PB$.
Output:
Segments $PA$ and $PB$ are tangents to circle $O$, and $PA=4$. Find the length
of $PB$.

================= Problem statement to sanitize (original) =================
{question}

================= Output format =================
Return a single valid JSON object:
{{
    "question_sanitized": "Simplified problem statement (must keep LaTeX formatting)"
}}
"""
        return prompt

    def _build_prompt_step2(
        self,
        index: int,
        question_simplified: str,
        annotations: str,
    ) -> str:
        """
        Step 2 prompt: second filtering pass driven by the figure annotations.

        The prompt asks the model to delete from the step-1 text (Q1) every
        condition that the annotations (right-angle, length and angle labels)
        already express, and to answer with a JSON object holding
        ``question_sanitized``.

        Args:
            index: Sample index; it does not appear in the prompt text.
            question_simplified: The step-1 output (Q1), inserted verbatim.
            annotations: Natural-language annotation text, inserted verbatim.

        Returns:
            The complete step-2 prompt string.
        """
        # f-string template: literal braces of the JSON skeleton are doubled.
        return f"""
You are a professional mathematics editor, performing a **second filtering
step** on a geometry problem. The output must use LaTeX.

Important:
- You are given the Step 1 simplified text, called Q1.
- You are also given the problem's **annotations** (visual labels on the
  figure) already translated into natural language.
- Annotations only include three types: right-angle labels (perpendicular
  relations), length labels, and angle labels.
- Your task is to **delete** from Q1 every condition that is explicitly covered
  by annotations, while preserving LaTeX formatting for what remains.

Core rules:
1. Deletion rule:
   - If Q1 explicitly states a condition that appears in annotations (right
     angle, specific length or angle measure), delete that condition from Q1.
2. Equivalence recognition (important):
   You must recognize equivalent formulations, even if the expressions differ.
   - Perpendicular vs. right angle:
     * ``AD \\perp CD`` is equivalent to ``\\angle ADC = 90^\\circ``,
       ``\\angle DAC = 90^\\circ``, or ``\\angle ACD = 90^\\circ`` depending on
       the vertex.
     * ``AB \\perp BC`` is equivalent to ``\\angle ABC = 90^\\circ``.
     * If annotations include "\\angle DAC = 90^\\circ", then ``AD \\perp CD``
       in Q1 should also be removed.
   - Equal lengths: ``AB = CD`` and ``CD = AB`` are equivalent.
   - Equal angles: if two forms clearly refer to the same geometric angle
     (e.g. due to symmetry), treat them as equivalent.
3. Retention rule:
   - If an annotation does not fully cover the semantics of a phrase in Q1
     (e.g. Q1 says "ABCD is a rectangle", but annotations only have a single
     right-angle label), keep the phrase in Q1.
4. Numeric retention:
   - If annotations only indicate label *type* (e.g. "right-angle at ∠BAC")
     but Q1 contains additional numeric or structural information (e.g.
     "AB=3"), keep Q1's text unless the explicit value is already in
     annotations.

Examples:

Example 1 (delete conditions fully covered by annotations)
Input Q1:
Given $AB=3$, $AC=6$, $\\angle BAC=90^\\circ$, and $D$ is the midpoint of $AB$.
Find the length of $GD$.
Annotations:
Right-angle: ∠BAC = 90°; length labels: AB = 3, AC = 6.
Output:
Given $D$ is the midpoint of $AB$. Find the length of $GD$.

Example 2 (equivalence: perpendicular vs. right angle)
Input Q1:
Given $AD \\perp CD$, and $AD=5$, $CD=12$. Find the length of $AC$.
Annotations:
Right-angle label: ∠DAC = 90°.
Output:
Given $AD=5$, $CD=12$. Find the length of $AC$.

Example 3 (equivalence: different angle notations)
Input Q1:
Given $\\angle ABC = 45^\\circ$ and $AB=5$. Find the length of $BC$.
Annotations:
Angle label: ∠CBA = 45°.
Output:
Given $AB=5$. Find the length of $BC$.

Example 4 (keep descriptors that cannot be fully removed by annotations)
Input Q1:
Given rectangle $ABCD$ with $AB=5$ and $BC=12$. Find the length of diagonal
$AC$.
Annotations:
Right-angle label: ∠ABC = 90°.
Output:
Given rectangle $ABCD$ with $AB=5$ and $BC=12$. Find the length of diagonal
$AC$.

================= Input: simplified statement Q1 (from Step 1) =================
{question_simplified}

================= Input: corresponding annotations (visual labels) =================
{annotations}

================= Output format =================
Return a single valid JSON object:
{{
    "question_sanitized": "Q1 with all conditions covered by annotations removed (must keep LaTeX formatting)"
}}
"""

    def _sanitize_question_with_two_steps(
        self,
        question: str,
        plotting_code: Optional[Dict[str, Any]] = None,
        actual_data: Optional[Dict[str, Any]] = None,
        index: Optional[int] = None,
    ) -> str:
        """
        Perform a two-step sanitization of the problem statement:
        1) Pure text simplification (ignoring annotations).
        2) Use annotations to further remove conditions already shown by
           visual labels. This step is skipped when no annotations can be
           extracted from ``actual_data`` / ``plotting_code``.

        Args:
            question: Original problem statement. An empty value short-circuits
                to ``""`` without calling the LLM.
            plotting_code: Optional plotting code (fallback annotation source).
            actual_data: Optional actual_data (preferred annotation source).
            index: Optional sample index; when given, intermediate results are
                printed for debugging.

        Returns:
            The final sanitized question text.

        Raises:
            ValueError: If an LLM response cannot be parsed into a JSON object
                or lacks a non-empty ``question_sanitized`` field. Exceptions
                from the LLM call itself propagate unchanged.
        """
        if not question:
            return ""

        prompt_index = index or 0

        # Step 1: text-only simplification based on the problem statement
        q_step1 = self._run_sanitization_step(
            self._build_prompt_step1(index=prompt_index, question=question),
            step_label="STEP1",
            index=index,
        )

        # Debug output: result of step 1
        if index is not None:
            print(f"[STEP1 index={index}]")
            print("  original:", question)
            print("  step1   :", q_step1)

        # Step 2 only makes sense when there are annotations to compare against
        annotations = self._extract_annotations(plotting_code, actual_data)
        if not annotations:
            return str(q_step1)

        q_step2 = self._run_sanitization_step(
            self._build_prompt_step2(
                index=prompt_index,
                question_simplified=q_step1,
                annotations=annotations,
            ),
            step_label="STEP2",
            index=index,
        )

        # Debug output: step 2 result + annotations
        if index is not None:
            print(f"[STEP2 index={index}]")
            print("  annotations:", annotations)
            print("  step1      :", q_step1)
            print("  step2      :", q_step2)
        return str(q_step2)

    def _run_sanitization_step(
        self,
        prompt: str,
        step_label: str,
        index: Optional[int],
    ) -> Any:
        """Send one sanitization prompt and return its ``question_sanitized`` value."""
        api_resp = self._call_llm_for_sanitization(prompt)
        content = api_resp.get("content", "") or ""

        parsed = self.extract_json_from_response(content)
        # A non-dict payload (e.g. a JSON list) is as unusable as no JSON at all.
        if not parsed or not isinstance(parsed, dict):
            raise ValueError(
                f"index={index} {step_label} LLM response cannot be parsed as JSON. "
                f"Raw content: {content[:200]}..."
            )

        sanitized = parsed.get("question_sanitized")
        if not sanitized:
            raise ValueError(
                f"index={index} {step_label} LLM JSON is missing 'question_sanitized' field. "
                f"Raw content: {content[:200]}..."
            )
        return sanitized

    def _call_llm_for_sanitization(self, prompt: str) -> Dict[str, Any]:
        """Low-level API call for question sanitization with simple retries."""
        last_err: Optional[Exception] = None
        for attempt in range(self.max_retries):
            try:
                start = time.time()
                # Try using reasoning_effort if the model supports it.
                try:
                    completion = self.client.chat.completions.create(
                        model=self.model,
                        messages=[{"role": "user", "content": prompt}],
                        temperature=0.1,
                        max_tokens=10240,
                        reasoning_effort="high",
                    )
                except Exception as e:
                    # If the model does not support reasoning_effort, retry without it.
                    if "reasoning_effort" in str(e).lower() or "unexpected keyword" in str(e).lower():
                        completion = self.client.chat.completions.create(
                            model=self.model,
                            messages=[{"role": "user", "content": prompt}],
                            temperature=0.1,
                            max_tokens=10240,
                        )
                    else:
                        raise
                elapsed = time.time() - start
                choice = completion.choices[0]
                message = choice.message
                content = message.content if message and message.content else ""
                return {
                    "content": content,
                    "finish_reason": getattr(choice, "finish_reason", None),
                    "elapsed_time": elapsed,
                    "usage": getattr(completion, "usage", None),
                }
            except Exception as e:
                last_err = e
                wait = 2 ** attempt
                print(
                    f"[WARN] LLM call failed, attempt {attempt + 1}/{self.max_retries}; "
                    f"retrying in {wait}s: {e}",
                    file=sys.stderr,
                )
                time.sleep(wait)
        raise RuntimeError(f"LLM call failed after {self.max_retries} retries: {last_err}")

    @staticmethod
    def format_prompt(question: str, cot: str, caption: str) -> str:
        """
        Format the prompt text to send to the LLM.

        Args:
            question: Original problem statement.
            cot: Original chain-of-thought solution.
            caption: Image caption / description.

        Returns:
            A formatted prompt string.
        """
        prompt = f"""
You are a **professional geometry VQA rewriting expert**.

Your task is:
Given the **caption (visually observable information)**, rewrite the original
**question** and **CoT** into a form suitable for visual question answering (VQA),
producing a new question and a reasoning chain.
====================================================
Input data
====================================================
Original problem (question):
{question}
Original reasoning process (cot):
{cot}
Image caption (caption, only describing what is visually present in the figure,
not additional problem conditions):
{caption}
====================================================
[Rewriting rules: QUESTION]
====================================================
Professionally rewrite the question so that it is suitable for VQA, following:
1. **Remove all information that is directly observable in the image.**
   - Any geometric structure, connectivity of points, or visible shapes that are
     already described by the caption must not be repeated in the text.
   - All visible numeric values, angles, lengths, and parallel/perpendicular
     markings that are evident in the diagram must be removed from the text.
2. **Keep implicit geometric conditions that are not directly visible in the figure.**
   For example:
   - Midpoints
   - Angle bisectors
   - Perpendicular / parallel relations (if not explicitly indicated in the figure)
   - Internal / external division of segments
   - Ratios, similarity relations (if they are not directly obvious from the figure)
3. **Keep the problem solvable.**
   The final computational task (e.g., find an angle, length, or area) must remain.
4. **Do not introduce any new geometric relations, point names, or numeric values**
   that were not present in the original text.
5. **Use a textbook-like mathematical style**: start from the basic configuration
   and then introduce implicit conditions step by step in natural language.
====================================================
[Rewriting rules: CoT] (vision-grounded reasoning)
====================================================
Based on the original CoT and the caption, rewrite the reasoning process:
1. **You may treat any "visually observable facts" in the caption as known
   conditions and explicitly state them with phrases such as "From the figure we
   can see that ..."** in order to simplify the reasoning.
2. The reasoning must use the standard format:
   `Step 1: ...`
   `Step 2: ...`
   Each step should perform exactly one logical action.
3. **Do not hallucinate visual information that is not present in the caption.**
   Any structure, label, or relation not mentioned in the caption must not be
   assumed as a visual fact.
4. Use only the following sources of information:
   - What is explicitly described as visible in the caption.
   - The implicit conditions that remain in the rewritten question.
   - General Euclidean geometry knowledge (e.g., sum of angles in a triangle,
     definitions of similarity, etc.).
5. **Remove redundant steps to keep the reasoning compact but correct.**
6. The final numerical or symbolic answer must remain consistent with the
   original solution.

====================================================
[Output format] (must be strictly followed)
====================================================

Return a standard JSON object:
```
{{
  "question": "...",
  "cot": "Step 1: ...\\nStep 2: ..."
}}
```
"""
        return prompt
    
    def call_llm(
        self,
        prompt: str,
        index: Optional[int] = None,
    ) -> Dict[str, Any]:
        """
        Call the LLM API in plain-text mode.

        Args:
            prompt: The textual prompt.
            index: Optional index for logging.

        Returns:
            A dictionary containing the API response.
        """
        for attempt in range(self.max_retries):
            try:
                if index is not None:
                    print(f"[sample {index}] Calling LLM API for VisualizeQA...")
                
                # Call API
                start_time = time.time()
                completion = self.client.chat.completions.create(
                    model=self.model,
                    messages=[
                        {
                            "role": "user",
                            "content": prompt,
                        },
                    ],
                    temperature=0.3,
                    max_tokens=32767 * 2,
                )
                elapsed_time = time.time() - start_time
                
                if not completion or not completion.choices:
                    raise Exception("API returned invalid response: completion.choices is empty")
                
                choice = completion.choices[0]
                if not choice or not choice.message:
                    raise Exception("API returned invalid response: message is empty")
                
                response = choice.message
                finish_reason = choice.finish_reason
                
                if finish_reason == "length":
                    print("Warning: response truncated due to insufficient max_tokens (current max_tokens=4000)")
                
                content = response.content
                if not content or not content.strip():
                    raise Exception(f"LLM returned empty content (finish_reason={finish_reason})")
                
                return {
                    "content": content,
                    "role": response.role,
                    "finish_reason": finish_reason,
                    "elapsed_time": elapsed_time,
                    "usage": completion.usage,
                }
            
            except Exception as e:
                error_msg = str(e)
                if attempt < self.max_retries - 1:
                    wait_time = 2 ** attempt
                    print(
                        f"Attempt {attempt + 1}/{self.max_retries} failed: "
                        f"{error_msg}, retrying in {wait_time} seconds..."
                    )
                    time.sleep(wait_time)
                    continue
                raise Exception(f"LLM API call failed after {self.max_retries} attempts: {error_msg}")
        
        raise Exception("LLM API call failed")
    
    @staticmethod
    def extract_json_from_response(response_text: str) -> Optional[Dict[str, Any]]:
        """Extract a JSON object from an LLM response (robust parser)."""
        if not response_text:
            return None

        stripped = response_text.strip()

        # 1) Try direct parsing
        try:
            return json.loads(stripped)
        except json.JSONDecodeError:
            pass

        # 2) Parse a ```json ... ``` code block, if present.
        #    First locate the JSON object start position inside the code block.
        code_block_match = re.search(r"```(?:json)?\s*(\{)", response_text, re.DOTALL)
        if code_block_match:
            start_pos = code_block_match.start(1)  # JSON object starting position
            json_text = response_text[start_pos:]
            
            # Look for code block end marker (if any)
            if "```" in json_text[1:]:  # search from second character to skip opening ```
                end_marker = json_text.find("```", 1)
                if end_marker > 0:
                    json_text = json_text[:end_marker].rstrip()
            
            # First attempt: direct JSON decode
            decoder = json.JSONDecoder()
            try:
                obj, end_pos = decoder.raw_decode(json_text)
                if isinstance(obj, dict):
                    return obj
            except json.JSONDecodeError:
                # If this fails, it might be due to unescaped control characters
                # or backslashes inside string values. We try to repair them.
                def fix_string_value(match):
                    key_part = match.group(1)  # "key":
                    value = match.group(2)     # raw value content
                    tail = match.group(3)      # trailing quote

                    fixed_value = value
                    # 1. Escape newlines, tabs, and carriage returns
                    fixed_value = re.sub(r'(?<!\\)\n', r'\\n', fixed_value)
                    fixed_value = re.sub(r'(?<!\\)\t', r'\\t', fixed_value)
                    fixed_value = re.sub(r'(?<!\\)\r', r'\\r', fixed_value)
                    # 2. Escape single backslashes before letters (for LaTeX-like sequences)
                    fixed_value = re.sub(r'(?<!\\)\\(?=[a-zA-Z])', r'\\\\', fixed_value)
                    # 3. Escape unescaped double quotes
                    fixed_value = re.sub(r'(?<!\\)"', r'\\"', fixed_value)

                    return key_part + fixed_value + tail

                try:
                    # Match and fix all string values (only for question and cot fields)
                    fixed_json = re.sub(
                        r'("(?:question|cot)"\s*:\s*")(.*?)(")',
                        fix_string_value,
                        json_text,
                        flags=re.DOTALL
                    )
                    # Attempt to decode again
                    obj, _ = decoder.raw_decode(fixed_json)
                    if isinstance(obj, dict):
                        return obj
                except (json.JSONDecodeError, Exception):
                    pass

        # 3) Fallback: scan for the first "{...}" substring that decodes as JSON
        decoder = json.JSONDecoder()
        for idx, ch in enumerate(response_text):
            if ch != "{":
                continue
            try:
                obj, _ = decoder.raw_decode(response_text[idx:])
                if isinstance(obj, dict):
                    return obj
            
            except json.JSONDecodeError:
                    continue

        return None
    
    def visualize_qa(
        self,
        question: str,
        cot: str,
        caption: str,
        index: Optional[int] = None,
        plotting_code: Optional[Dict[str, Any]] = None,
        actual_data: Optional[Dict[str, Any]] = None,
    ) -> Dict[str, Any]:
        """
        Rewrite a problem and its CoT for vision QA (text-only interface).

        This is a dispatcher: when ``plotting_code`` or ``actual_data`` is
        supplied (truthy), the two-step de-visualization pipeline handles the
        request; otherwise the single-step caption-based rewrite does.

        Args:
            question: Original question text.
            cot: Original chain-of-thought solution.
            caption: Image caption.
            index: Optional index for logging.
            plotting_code: Optional plotting code (used to extract annotations).
            actual_data: Optional actual_data (preferred source for annotations).

        Returns:
            Whatever the selected pipeline returns: a result dictionary with
            ``success``, ``question``, ``cot``, ``usage`` and, on failure,
            ``error``.
        """
        shared = {
            "question": question,
            "cot": cot,
            "caption": caption,
            "index": index,
        }

        # No structured figure data: use the original single-step processing.
        if not (plotting_code or actual_data):
            return self._visualize_qa_legacy(**shared)

        return self._visualize_qa_with_two_steps(
            plotting_code=plotting_code,
            actual_data=actual_data,
            **shared,
        )

    def _visualize_qa_with_two_steps(
        self,
        question: str,
        cot: str,
        caption: str,
        index: Optional[int] = None,
        plotting_code: Optional[Dict[str, Any]] = None,
        actual_data: Optional[Dict[str, Any]] = None,
    ) -> Dict[str, Any]:
        """
        Two-step pipeline:
        1) De-visualize and sanitize the problem statement.
        2) Rewrite the CoT.
        """
        usage_records = []
        
        # Use the same retry count as call_llm
        for attempt in range(self.max_retries):
            try:
                # Step 1: sanitize the question (two-step de-visualization)
                sanitized_question = self._sanitize_question_with_two_steps(
                    question=question,
                    plotting_code=plotting_code,
                    actual_data=actual_data,
                    index=index,
                )
                
                # We could collect usage for the sanitization calls here if needed.
                # Currently, only the final rewrite call's usage is tracked.
                
                # Step 2: CoT rewriting (using format_prompt but with sanitized_question)
                prompt = self.format_prompt(sanitized_question, cot, caption)
                if index is not None:
                    print(
                        f"[sample {index}] VisualizeQA two-step mode: "
                        f"question de-visualized, now rewriting CoT"
                    )
                
                # Call LLM API to rewrite CoT
                api_response = self.call_llm(prompt, index=index)
                usage_records.append(api_response.get("usage"))
                
                # Parse response
                content = api_response.get("content", "")
                usage = api_response.get("usage")
                
                # Extract JSON
                result = self.extract_json_from_response(content)
                if not result:
                    error_msg = f"Failed to extract JSON from response: {content[:200]}"
                    # If this is not the last attempt, retry the whole pipeline
                    if attempt < self.max_retries - 1:
                        wait_time = 2 ** attempt
                        print(
                            f"[sample {index}] Attempt {attempt + 1}/{self.max_retries} "
                            f"failed: {error_msg}, retrying in {wait_time} seconds..."
                        )
                        time.sleep(wait_time)
                        continue
                    raise Exception(error_msg)
                
                # Validate required fields
                visualized_cot = result.get("cot", "")
                
                if not visualized_cot:
                    error_msg = "Rewritten chain-of-thought is empty"
                    # If this is not the last attempt, retry the whole pipeline
                    if attempt < self.max_retries - 1:
                        wait_time = 2 ** attempt
                        print(
                            f"[sample {index}] Attempt {attempt + 1}/{self.max_retries} "
                            f"failed: {error_msg}, retrying in {wait_time} seconds..."
                        )
                        time.sleep(wait_time)
                        continue
                    raise Exception(error_msg)
                
                # Use sanitized_question as the final question (we ignore result["question"])
                return {
                    "success": True,
                    "question": sanitized_question,
                    "cot": visualized_cot,
                    "usage": usage,
                }
            
            except Exception as e:
                error_msg = str(e)
                print(f"VisualizeQA failed: {error_msg}")
                # If this is not the last attempt, retry the whole pipeline
                if attempt < self.max_retries - 1:
                    wait_time = 2 ** attempt
                    print(
                        f"[sample {index}] Attempt {attempt + 1}/{self.max_retries} "
                        f"failed, retrying in {wait_time} seconds..."
                    )
                    time.sleep(wait_time)
                    continue
                return {
                    "success": False,
                    "question": "",
                    "cot": "",
                    "usage": None,
                    "error": error_msg,
                }
        
        # Should not reach here, but keep a safe fallback
        return {
            "success": False,
            "question": "",
            "cot": "",
            "usage": None,
            "error": "Reached maximum retry count",
        }

    def _visualize_qa_legacy(
        self,
        question: str,
        cot: str,
        caption: str,
        index: Optional[int] = None,
    ) -> Dict[str, Any]:
        """
        Original single-step logic (caption-based VQA rewriting).
        """
        # Use the same retry count as call_llm
        for attempt in range(self.max_retries):
            try:
                # Format prompt
                prompt = self.format_prompt(question, cot, caption)
                if index is not None:
                    print(f"[sample {index}] VisualizeQA in single-step text mode")
                # Call LLM API
                api_response = self.call_llm(prompt, index=index)
                
                # Parse response
                content = api_response.get("content", "")
                usage = api_response.get("usage")
                
                # Extract JSON
                result = self.extract_json_from_response(content)
                if not result:
                    error_msg = f"Failed to extract JSON from response: {content[:200]}"
                    # If this is not the last attempt, retry the whole pipeline
                    if attempt < self.max_retries - 1:
                        wait_time = 2 ** attempt
                        print(
                            f"[sample {index}] Attempt {attempt + 1}/{self.max_retries} "
                            f"failed: {error_msg}, retrying in {wait_time} seconds..."
                        )
                        time.sleep(wait_time)
                        continue
                    raise Exception(error_msg)
                
                # Validate required fields
                visualized_question = result.get("question", "")
                visualized_cot = result.get("cot", "")
                
                if not visualized_question:
                    error_msg = "Rewritten question is empty"
                    # If this is not the last attempt, retry the whole pipeline
                    if attempt < self.max_retries - 1:
                        wait_time = 2 ** attempt
                        print(
                            f"[sample {index}] Attempt {attempt + 1}/{self.max_retries} "
                            f"failed: {error_msg}, retrying in {wait_time} seconds..."
                        )
                        time.sleep(wait_time)
                        continue
                    raise Exception(error_msg)
                if not visualized_cot:
                    error_msg = "Rewritten chain-of-thought is empty"
                    # If this is not the last attempt, retry the whole pipeline
                    if attempt < self.max_retries - 1:
                        wait_time = 2 ** attempt
                        print(
                            f"[sample {index}] Attempt {attempt + 1}/{self.max_retries} "
                            f"failed: {error_msg}, retrying in {wait_time} seconds..."
                        )
                        time.sleep(wait_time)
                        continue
                    raise Exception(error_msg)
                
                return {
                    "success": True,
                    "question": visualized_question,
                    "cot": visualized_cot,
                    "usage": usage,
                }
            
            except Exception as e:
                error_msg = str(e)
                print(f"VisualizeQA failed: {error_msg}")
                # If this is not the last attempt, retry the whole pipeline
                if attempt < self.max_retries - 1:
                    wait_time = 2 ** attempt
                    print(
                        f"[sample {index}] Attempt {attempt + 1}/{self.max_retries} "
                        f"failed, retrying in {wait_time} seconds..."
                    )
                    time.sleep(wait_time)
                    continue
                return {
                    "success": False,
                    "question": "",
                    "cot": "",
                    "usage": None,
                    "error": error_msg,
                }
        
        # Should not reach here, but keep a safe fallback
        return {
            "success": False,
            "question": "",
            "cot": "",
            "usage": None,
            "error": "Reached maximum retry count",
        }

