#!/usr/bin/env python3
"""
VLImageQuality: image quality checking and caption generation with a vision-language model.
"""

import base64
import json
import os
import re
import time
from pathlib import Path
from typing import Any, Dict, Optional

from openai import OpenAI


class VLImageQuality:
    """Use a vision-language model to check image quality and generate captions."""
    
    def __init__(
        self,
        api_key: Optional[str] = None,
        model: Optional[str] = None,
        base_url: Optional[str] = None,
        max_retries: int = 3,
    ):
        """
        Build the checker and its underlying OpenAI client.

        Args:
            api_key: API key. When falsy, the OPENAI_API_KEY environment
                variable is used instead.
            model: Model name. When None, the OPENAI_MODEL environment
                variable is used, falling back to "gpt-4o-mini".
            base_url: API base URL. When None, the OPENAI_BASE_URL environment
                variable is used, falling back to "https://api.openai.com/v1".
            max_retries: Number of attempts made by the retry loops of this class.

        Raises:
            ValueError: If no API key is available from the argument or the
                environment.
        """
        # Resolve the credential first; nothing else is useful without it.
        resolved_key = api_key or os.getenv("OPENAI_API_KEY")
        self.api_key = resolved_key
        if not resolved_key:
            raise ValueError("API key must be provided via argument or environment variable.")

        # Explicit arguments win; environment variables and built-in defaults follow.
        self.model = (
            model
            if model is not None
            else (os.getenv("OPENAI_MODEL") or "gpt-4o-mini")
        )
        self.base_url = (
            base_url
            if base_url is not None
            else (os.getenv("OPENAI_BASE_URL") or "https://api.openai.com/v1")
        )
        self.max_retries = max_retries

        # An explicitly empty base_url is not forwarded, leaving the client's own default.
        options = {"api_key": self.api_key, "timeout": 60.0}
        if self.base_url:
            options["base_url"] = self.base_url
        self.client = OpenAI(**options)
    
    @staticmethod
    def remove_points_from_plotting_code(plotting_code: Dict[str, Any]) -> Dict[str, Any]:
        """
        Remove point coordinates from plotting_code while keeping other information.
        Also strip numeric parameters from circle descriptions (such as radius values)
        to prevent leaking answer-related information into the quality model.
        
        Args:
            plotting_code: Original plotting code dictionary.
            
        Returns:
            A cleaned plotting_code dictionary without points and numeric circle data.
        """
        # Deep copy to avoid mutating the original data
        cleaned_data = json.loads(json.dumps(plotting_code))
        
        # Remove points field
        if "points" in cleaned_data:
            del cleaned_data["points"]

        if "segments" in cleaned_data:
            del cleaned_data["segments"]
        
        if "annotation_summary" in cleaned_data:
            del cleaned_data["annotation_summary"]
        
        # Remove numeric information from circle descriptions
        if "circles" in cleaned_data and isinstance(cleaned_data["circles"], list):
            cleaned_circles = []
            for circle in cleaned_data["circles"]:
                if isinstance(circle, list) and len(circle) == 3:
                    # If length is 3 and the third element is numeric (radius), drop it.
                    if isinstance(circle[2], (int, float)) or (isinstance(circle[2], str) and circle[2].isdigit()):
                        cleaned_circles.append(circle[:2])
                    else:
                        cleaned_circles.append(circle)
                else:
                    cleaned_circles.append(circle)
            cleaned_data["circles"] = cleaned_circles
        
        return cleaned_data
    
    @staticmethod
    def encode_image(image_path: Path) -> str:
        """
        Encode an image as a base64 string.
        
        Args:
            image_path: Path to the image file.
            
        Returns:
            Base64-encoded image string.
        """
        with open(image_path, "rb") as image_file:
            return base64.b64encode(image_file.read()).decode('utf-8')

    def format_quality_check_prompt(self) -> str:
        """
        Build the fixed instruction text for the quality-only check (step 1).

        The text asks the model to judge the image purely on visual grounds
        and to answer with a JSON object holding "passed" and "reason".

        Returns:
            The prompt string (constant; it takes no inputs).
        """
        return """
You are an expert in evaluating the visual quality of geometric diagrams.
Decide whether the given image is suitable to be used as an illustration for a geometry problem,
based **only on the visual content of the image**.

====================================================
Image quality evaluation criteria
====================================================

Judge strictly based on the visual content of the image:

- Are the lines clear, not broken, and not overly blurry?
- Are point labels, segments, or annotations heavily occluded?
- Are annotations (labels, marks, text) readable?
- Is the overall layout too crowded or visually chaotic?

Decision rules:
- If there is any issue that *seriously impairs understanding* of the diagram, output `"passed": false`.
- Otherwise output `"passed": true`.
- Provide a very short reason; do not perform mathematical reasoning.

====================================================
Output format (strict, JSON only)
====================================================

Return **only** a JSON object:

{
  "passed": true/false,
  "reason": "Short explanation"
}

"""
    
    def format_caption_generation_prompt(
        self,
        plotting_code_without_points: Dict[str, Any]
    ) -> str:
        """
        Build the instruction text for caption generation (step 2).

        The plotting data is serialized to JSON (non-ASCII characters kept
        as-is) and embedded in the prompt as a structural reference.

        Args:
            plotting_code_without_points: Plotting code that has already had
                its point coordinates removed.

        Returns:
            The prompt string asking for a JSON object with a "caption" field.
        """
        structure_json = json.dumps(plotting_code_without_points, ensure_ascii=False)
        return f"""
You are an expert at describing geometric diagrams.
Using only the **visually observable content** of the image (together with the structural
information in plotting_code), produce an **objective, complete English description** of the diagram,
without adding any extra reasoning or inferred properties.

====================================================
Plotting code (structure reference only)
====================================================
{structure_json}

====================================================
Task requirements
====================================================

Write an English description that:
- Mentions only points, lines, circles, and annotations that are actually visible in the image.
- Does not rely on elements that may appear in plotting_code but are not clearly visible.
- Does **not** infer equalities, parallelism, perpendicularity, or other relations that are not explicitly drawn.
- Does **not** fill in geometric information that is not present in the picture.
- Does **not** use story-like or word-problem context; focus only on the diagram itself.
- Avoids hallucinations; stay strictly grounded in visible content.

====================================================
Output format (strict, JSON only)
====================================================

Return **only** a JSON object:

{{
  "caption": "Image description"
}}

"""
    
    def format_quality_and_caption_prompt(self, plotting_code_without_points: Dict[str, Any]) -> str:
        """
        Build the single prompt that requests both a quality verdict and a caption.

        The plotting data is serialized to JSON (non-ASCII characters kept
        as-is) and embedded near the top of the prompt as a structural
        reference.

        Args:
            plotting_code_without_points: Plotting code that has already had
                its point coordinates removed.

        Returns:
            The prompt string asking for a JSON object with "passed" and
            "caption" fields.
        """
        structure_json = json.dumps(plotting_code_without_points, ensure_ascii=False)
        return f"""
You are an expert in geometric diagram quality assessment and diagram description.
Follow the instructions below to complete **both tasks** and return a single JSON object.

Plotting code (for structural reference only; do **not** describe elements that are not visible):
{structure_json}

====================================================
Task 1: Image quality check
====================================================

Decide, based **only on the visual content** of the image, whether the diagram passes quality control.

Evaluate:
- Clarity of lines (no heavy blur or broken lines).
- Whether point labels, segments, or annotations are heavily occluded.
- Readability of annotations.
- Whether the overall layout is too crowded or visually confusing.

Decision rules:
- If there is any issue that seriously impairs understanding of the diagram, set `"passed": false`.
- Otherwise set `"passed": true`.

Do not perform mathematical reasoning; this is purely a visual quality decision.

====================================================
Task 2: Diagram caption (English)
====================================================

Generate an **objective, complete English description** of what is visible in the diagram.

Requirements:
- Describe only points, lines, circles, and annotations that are clearly visible.
- Do not rely on non-visible elements that might be present in plotting_code.
- Do not infer equalities, parallelism, perpendicularity, or other geometric relations
  unless they are explicitly and clearly drawn.
- Do not invent extra geometric objects or story context.

====================================================
Output format (strict, JSON only)
====================================================

Return **only** a JSON object:

{{
  "passed": true/false,
  "caption": "Image description"
}}

"""
    
    def call_vision_llm(
        self,
        image_path: Path,
        prompt: str,
        index: Optional[int] = None,
    ) -> Dict[str, Any]:
        """
        Send one image plus a text prompt to the chat-completions API.

        The image is embedded as a base64 data URL. Each failed attempt
        (unreadable image, API error, empty or malformed response) is retried
        up to ``self.max_retries`` times, sleeping ``2 ** attempt`` seconds
        between attempts.

        Args:
            image_path: Path to the image file.
            prompt: Text prompt describing the task.
            index: Optional sample index, used only for log output.

        Returns:
            A dict with keys "content" (the reply text), "role",
            "finish_reason", "elapsed_time" (seconds spent in the API call)
            and "usage" (the usage object of the completion).

        Raises:
            Exception: When every attempt fails; the last underlying error is
                attached as ``__cause__``.
        """
        # Function-scope import keeps this method self-contained.
        import mimetypes

        # Generous limit because a reply may carry both a verdict and a caption.
        max_tokens = 4096

        # Declare the real media type in the data URL; fall back to PNG when
        # the extension is unknown or is not an image type.
        guessed_type, _ = mimetypes.guess_type(str(image_path))
        if guessed_type and guessed_type.startswith("image/"):
            mime_type = guessed_type
        else:
            mime_type = "image/png"

        # Encoded lazily inside the retry loop (so read errors follow the same
        # retry/raise path as before) but cached so it is done at most once.
        base64_image: Optional[str] = None

        for attempt in range(self.max_retries):
            try:
                if index is not None:
                    print(f"[sample {index}] Calling vision API...")

                if base64_image is None:
                    base64_image = self.encode_image(image_path)

                start_time = time.time()
                completion = self.client.chat.completions.create(
                    model=self.model,
                    messages=[
                        {
                            "role": "user",
                            "content": [
                                {
                                    "type": "text",
                                    "text": prompt,
                                },
                                {
                                    "type": "image_url",
                                    "image_url": {
                                        "url": f"data:{mime_type};base64,{base64_image}",
                                    },
                                },
                            ],
                        },
                    ],
                    temperature=0.3,
                    max_tokens=max_tokens,
                )
                elapsed_time = time.time() - start_time

                if not completion or not completion.choices:
                    raise Exception("Vision API returned an invalid response: completion.choices is empty.")

                choice = completion.choices[0]
                if not choice or not choice.message:
                    raise Exception("Vision API returned an invalid response: message is empty.")

                response = choice.message
                finish_reason = choice.finish_reason

                if finish_reason == "length":
                    # Report the limit actually sent with the request.
                    print(
                        "Warning: vision response was truncated "
                        f"(max_tokens may be too small, currently {max_tokens})."
                    )

                content = response.content
                if not content or not content.strip():
                    raise Exception(f"Vision LLM returned empty content (finish_reason={finish_reason}).")

                return {
                    "content": content,
                    "role": response.role,
                    "finish_reason": finish_reason,
                    "elapsed_time": elapsed_time,
                    "usage": completion.usage,
                }

            except Exception as e:
                error_msg = str(e)
                if attempt < self.max_retries - 1:
                    wait_time = 2 ** attempt
                    print(
                        f"Attempt {attempt + 1}/{self.max_retries} failed: {error_msg}. "
                        f"Retrying in {wait_time} seconds..."
                    )
                    time.sleep(wait_time)
                    continue
                # Chain the cause so the original traceback is not lost.
                raise Exception(
                    f"Vision API call failed after {self.max_retries} retries: {error_msg}"
                ) from e

        # Only reachable when max_retries is less than 1 (no attempt was made).
        raise Exception("Vision API call failed.")
    
    @staticmethod
    def extract_json_from_response(response_text: str) -> Optional[Dict[str, Any]]:
        """
        Extract a JSON object from a free-form LLM text response
        (using the same robust parsing logic as in LLMJudge).
        
        Args:
            response_text: Raw LLM response text.
            
        Returns:
            Parsed JSON dict, or None if extraction fails.
        """
        if not response_text:
            return None

        stripped = response_text.strip()

        # Detect obviously truncated responses (e.g. only ```json with no content).
        if stripped in ["```json", "```", "```json\n", "```\n"] or len(stripped) < 10:
            return None

        # 1) Direct parse
        try:
            return json.loads(stripped)
        except json.JSONDecodeError:
            pass

        # 2) Parse JSON inside a ```json ... ``` code block
        #    First, find the start of the JSON object in the block.
        code_block_match = re.search(r"```(?:json)?\s*(\{)", response_text, re.DOTALL)
        if code_block_match:
            start_pos = code_block_match.start(1)  # start of JSON object
            json_text = response_text[start_pos:]
            
            # Look for the closing fence if present
            if "```" in json_text[1:]:  # skip the first character to avoid the opening fence
                end_marker = json_text.find("```", 1)
                if end_marker > 0:
                    json_text = json_text[:end_marker].rstrip()
            
            # Try to decode as JSON
            decoder = json.JSONDecoder()
            try:
                obj, end_pos = decoder.raw_decode(json_text)
                if isinstance(obj, dict):
                    return obj
            except json.JSONDecodeError:
                # If it fails, string values may contain unescaped control chars or backslashes.
                # We repair them by escaping newlines, tabs, carriage returns, backslashes and quotes.
                def fix_string_value(match):
                    key_part = match.group(1)  # "key":
                    value = match.group(2)     # original value
                    tail = match.group(3)      # trailing quote

                    fixed_value = value
                    # 1. Escape newlines, tabs, carriage returns.
                    fixed_value = re.sub(r'(?<!\\)\n', r'\\n', fixed_value)
                    fixed_value = re.sub(r'(?<!\\)\t', r'\\t', fixed_value)
                    fixed_value = re.sub(r'(?<!\\)\r', r'\\r', fixed_value)
                    # 2. Backslash + letter → double backslash (for LaTeX-style content).
                    fixed_value = re.sub(r'(?<!\\)\\(?=[a-zA-Z])', r'\\\\', fixed_value)
                    # 3. Unescaped double quotes → \".
                    fixed_value = re.sub(r'(?<!\\)"', r'\\"', fixed_value)

                    return key_part + fixed_value + tail

                try:
                    # Repair all string values for passed / reason / caption / filtered_question.
                    fixed_json = re.sub(
                        r'("(?:passed|reason|caption|filtered_question)"\s*:\s*")(.*?)(")',
                        fix_string_value,
                        json_text,
                        flags=re.DOTALL
                    )
                    # Try parsing again
                    obj, _ = decoder.raw_decode(fixed_json)
                    if isinstance(obj, dict):
                        return obj
                except (json.JSONDecodeError, Exception):
                    pass

        # 3) Fallback: scan for the first substring that looks like a JSON object.
        decoder = json.JSONDecoder()
        for idx, ch in enumerate(response_text):
            if ch != "{":
                continue
            try:
                obj, _ = decoder.raw_decode(response_text[idx:])
                if isinstance(obj, dict):
                    return obj
            
            except json.JSONDecodeError:
                continue

        return None
    
    def check_quality_and_generate_caption(
        self,
        image_path: Path,
        plotting_code: Dict[str, Any],
        index: Optional[int] = None,
    ) -> Dict[str, Any]:
        """
        In a single API call, check image quality and generate a caption.

        The whole flow (sanitize plotting code, build prompt, call the API,
        parse the reply) is attempted up to ``self.max_retries`` times with a
        ``2 ** attempt`` second pause between attempts. Note that
        call_vision_llm performs its own retries inside each attempt.

        Args:
            image_path: Path to the image file.
            plotting_code: Plotting code; point coordinates and other
                sensitive entries are stripped before it is put in the prompt.
            index: Optional sample index, used only for log output.

        Returns:
            A dict containing:
            - success: bool, whether a JSON result was obtained.
            - passed: bool, whether the image passed the quality check. A
              string verdict such as "false" is interpreted by its text
              rather than by truthiness.
            - reason: str, evaluation reason ("" when absent).
            - caption: str, generated caption ("" when absent).
            - usage: usage object of the API call (None on failure).
            - error: str, present only when success is False.
        """
        for attempt in range(self.max_retries):
            is_last_attempt = attempt >= self.max_retries - 1
            wait_time = 2 ** attempt
            try:
                plotting_code_cleaned = self.remove_points_from_plotting_code(plotting_code)
                prompt = self.format_quality_and_caption_prompt(plotting_code_cleaned)
                api_response = self.call_vision_llm(image_path, prompt, index=index)

                content = api_response.get("content", "")
                usage = api_response.get("usage")

                result = self.extract_json_from_response(content)
                if not result or not isinstance(result, dict):
                    error_msg = f"Failed to extract JSON from response: {content[:200]}"
                    if is_last_attempt:
                        # Handled by the except branch below, which builds the failure dict.
                        raise Exception(error_msg)
                    print(
                        f"[sample {index}] Attempt {attempt + 1}/{self.max_retries} failed: "
                        f"{error_msg} Retrying full flow in {wait_time} seconds..."
                    )
                    time.sleep(wait_time)
                    continue

                # Models sometimes emit the verdict as a string; bool("false")
                # would be True, so interpret string verdicts explicitly.
                raw_passed = result.get("passed", False)
                if isinstance(raw_passed, str):
                    passed = raw_passed.strip().lower() in ("true", "yes", "1")
                else:
                    passed = bool(raw_passed)

                # A JSON null must not leak None into fields documented as str.
                reason = result.get("reason")
                caption = result.get("caption")

                return {
                    "success": True,
                    "passed": passed,
                    "reason": "" if reason is None else reason,
                    "caption": "" if caption is None else caption,
                    "usage": usage,
                }

            except Exception as e:
                error_msg = str(e)
                print(f"Image quality check and caption generation failed: {error_msg}")
                if not is_last_attempt:
                    print(
                        f"[sample {index}] Attempt {attempt + 1}/{self.max_retries} failed; "
                        f"retrying full flow in {wait_time} seconds..."
                    )
                    time.sleep(wait_time)
                    continue
                return {
                    "success": False,
                    "passed": False,
                    "reason": "",
                    "caption": "",
                    "usage": None,
                    "error": error_msg,
                }

        # Only reachable when max_retries is less than 1 (no attempt was made).
        return {
            "success": False,
            "passed": False,
            "reason": "",
            "caption": "",
            "usage": None,
            "error": "Reached maximum number of retries.",
        }
    
    def check_image_quality(
        self,
        image_path: Path,
        index: Optional[int] = None,
    ) -> Dict[str, Any]:
        """
        Step 1: decide whether the image is visually suitable as a diagram
        for a geometry problem.

        The flow (build prompt, call the API, parse the reply) is attempted
        up to ``self.max_retries`` times with a ``2 ** attempt`` second pause
        between attempts. Note that call_vision_llm performs its own retries
        inside each attempt.

        Args:
            image_path: Path to the image file.
            index: Optional sample index, used only for log output.

        Returns:
            A dict containing:
            - success: bool, whether a JSON result was obtained.
            - passed: bool, whether the image passed the quality check. A
              string verdict such as "false" is interpreted by its text
              rather than by truthiness.
            - reason: str, evaluation reason ("" when absent or on failure).
            - usage: usage object of the API call (None on failure).
            - error: str, present only when success is False.
        """
        for attempt in range(self.max_retries):
            is_last_attempt = attempt >= self.max_retries - 1
            wait_time = 2 ** attempt
            try:
                prompt = self.format_quality_check_prompt()

                if index is not None:
                    print(f"[sample {index}] Step 1: checking image quality...")
                api_response = self.call_vision_llm(image_path, prompt, index=index)

                content = api_response.get("content", "")
                usage = api_response.get("usage")

                result = self.extract_json_from_response(content)
                if not result or not isinstance(result, dict):
                    error_msg = f"Failed to extract JSON from response: {content[:200]}"
                    if is_last_attempt:
                        # Handled by the except branch below, which builds the failure dict.
                        raise Exception(error_msg)
                    print(
                        f"[sample {index}] Attempt {attempt + 1}/{self.max_retries} failed: "
                        f"{error_msg} Retrying in {wait_time} seconds..."
                    )
                    time.sleep(wait_time)
                    continue

                # Models sometimes emit the verdict as a string; bool("false")
                # would be True, so interpret string verdicts explicitly.
                raw_passed = result.get("passed", False)
                if isinstance(raw_passed, str):
                    passed = raw_passed.strip().lower() in ("true", "yes", "1")
                else:
                    passed = bool(raw_passed)

                # A JSON null must not leak None into a field documented as str.
                reason = result.get("reason")

                return {
                    "success": True,
                    "passed": passed,
                    "reason": "" if reason is None else reason,
                    "usage": usage,
                }

            except Exception as e:
                error_msg = str(e)
                print(f"Image quality check failed: {error_msg}")
                if not is_last_attempt:
                    print(
                        f"[sample {index}] Attempt {attempt + 1}/{self.max_retries} failed; "
                        f"retrying in {wait_time} seconds..."
                    )
                    time.sleep(wait_time)
                    continue
                # "reason" is included so success and failure dicts share keys.
                return {
                    "success": False,
                    "passed": False,
                    "reason": "",
                    "usage": None,
                    "error": error_msg,
                }

        # Only reachable when max_retries is less than 1 (no attempt was made).
        return {
            "success": False,
            "passed": False,
            "reason": "",
            "usage": None,
            "error": "Reached maximum number of retries.",
        }
    
    def generate_caption(
        self,
        image_path: Path,
        plotting_code: Dict[str, Any],
        index: Optional[int] = None,
    ) -> Dict[str, Any]:
        """
        Step 2: ask the model for a caption grounded in the image and the
        sanitized plotting code.

        The flow (sanitize plotting code, build prompt, call the API, parse
        the reply) is attempted up to ``self.max_retries`` times, pausing
        ``2 ** attempt`` seconds between attempts. An unparseable reply or an
        empty caption counts as a failed attempt.

        Args:
            image_path: Path to the image file.
            plotting_code: Plotting code; point coordinates are stripped
                before it is put in the prompt.
            index: Optional sample index, used only for log output.

        Returns:
            A dict containing:
            - success: bool, whether a non-empty caption was obtained.
            - caption: the generated caption ("" on failure).
            - usage: usage object of the API call (None on failure).
            - error: str, present only when success is False.
        """
        total_attempts = self.max_retries

        for attempt in range(total_attempts):
            try:
                sanitized = self.remove_points_from_plotting_code(plotting_code)
                prompt = self.format_caption_generation_prompt(sanitized)

                if index is not None:
                    print(f"[sample {index}] Step 2: generating caption...")
                api_response = self.call_vision_llm(image_path, prompt, index=index)

                content = api_response.get("content", "")
                usage = api_response.get("usage")

                parsed = self.extract_json_from_response(content)

                # Work out whether this attempt produced a usable caption.
                error_msg = None
                caption = ""
                if not parsed:
                    error_msg = f"Failed to extract JSON from response: {content[:200]}"
                else:
                    caption = parsed.get("caption", "")
                    if not caption:
                        error_msg = "Generated caption is empty."

                if error_msg is None:
                    return {
                        "success": True,
                        "caption": caption,
                        "usage": usage,
                    }

                # Soft failure: on the last attempt, hand over to the except
                # branch below; otherwise back off and try again.
                if attempt >= total_attempts - 1:
                    raise Exception(error_msg)
                wait_time = 2 ** attempt
                print(
                    f"[sample {index}] Attempt {attempt + 1}/{self.max_retries} failed: "
                    f"{error_msg} Retrying in {wait_time} seconds..."
                )
                time.sleep(wait_time)
                continue

            except Exception as e:
                error_msg = str(e)
                print(f"Caption generation failed: {error_msg}")
                if attempt >= total_attempts - 1:
                    return {
                        "success": False,
                        "caption": "",
                        "usage": None,
                        "error": error_msg,
                    }
                wait_time = 2 ** attempt
                print(
                    f"[sample {index}] Attempt {attempt + 1}/{self.max_retries} failed; "
                    f"retrying in {wait_time} seconds..."
                )
                time.sleep(wait_time)
                continue

        # Reached only when the loop body never ran (max_retries below 1).
        return {
            "success": False,
            "caption": "",
            "usage": None,
            "error": "Reached maximum number of retries.",
        }

