#!/usr/bin/env python3
"""
LLMJudge: use an LLM as a judge to evaluate the quality of generated problems.

Primarily used to filter and clean generated problems.
"""

import json
import os
import re
import time
from typing import Any, Dict, Optional

from openai import OpenAI


class LLMJudge:
    """Use an LLM as a judge to evaluate generated problem quality.

    Wraps an OpenAI-compatible chat-completions client and exposes two
    judging entry points:

    * ``judge`` -- checks a geometry problem (plus optional reasoning and
      answer) for self-consistency and solvability, retrying on failure.
    * ``judge_plotting_code`` -- checks structured plotting code
      (``dsl`` / ``segments`` / ``circles``) in a single attempt.

    Both ask the model for a JSON verdict containing at least ``passed`` and
    ``reason`` and parse it out of the free-form reply.
    """

    # Stripped replies shorter than this carry no usable verdict; this also
    # covers truncated replies that are only an opening fence ("```json").
    _MIN_RESPONSE_LENGTH = 10

    # String spellings of ``passed`` that count as a pass. Anything else given
    # as a string (e.g. "false", "no") counts as a failure.
    _TRUE_STRINGS = frozenset({"true", "yes", "pass", "passed"})

    # strict=False lets string values contain raw control characters such as
    # literal newlines, which LLMs frequently emit inside "reason".
    _LENIENT_DECODER = json.JSONDecoder(strict=False)

    def __init__(
        self,
        api_key: Optional[str] = None,
        model: Optional[str] = None,
        base_url: Optional[str] = None,
        max_retries: int = 3,
        max_tokens: int = 32767 * 2,
        temperature: float = 0.3,
        timeout: float = 600.0,
    ):
        """
        Initialize the judge.

        Args:
            api_key: API key (if None, read from ``OPENAI_API_KEY``).
            model: Model name (if None, ``OPENAI_MODEL`` or ``"gpt-4o-mini"``).
            base_url: API base URL (if None, ``OPENAI_BASE_URL`` or the
                public OpenAI endpoint).
            max_retries: Maximum number of attempts, used both per API call
                (``call_llm``) and per judging run (``judge``).
            max_tokens: ``max_tokens`` sent with every completion request.
                NOTE(review): the default 65534 may exceed the output-token
                limit of some models, which then reject the request -- lower
                it if every call fails.
            temperature: Sampling temperature sent with every request.
            timeout: Client request timeout in seconds.

        Raises:
            ValueError: if no API key is given and ``OPENAI_API_KEY`` is unset.
        """
        if model is None:
            model = os.getenv("OPENAI_MODEL") or "gpt-4o-mini"
        if base_url is None:
            base_url = os.getenv("OPENAI_BASE_URL") or "https://api.openai.com/v1"

        self.api_key = api_key or os.getenv("OPENAI_API_KEY")
        if not self.api_key:
            raise ValueError("API key must be provided via argument or environment variable.")

        self.model = model
        self.base_url = base_url
        self.max_retries = max_retries
        self.max_tokens = max_tokens
        self.temperature = temperature

        client_kwargs = {"api_key": self.api_key, "timeout": timeout}
        # base_url can only be falsy here if the caller passed "" explicitly.
        if self.base_url:
            client_kwargs["base_url"] = self.base_url
        self.client = OpenAI(**client_kwargs)

    @staticmethod
    def format_judge_prompt(question: str, cot: str = "", answer: str = "") -> str:
        """Build the problem-judging prompt.

        The prompt asks for a JSON object with ``passed`` and ``reason`` only;
        it does not request a ``score``.
        """
        prompt = f"""
You are an expert in verifying the mathematical correctness of geometry problems.

Your task: decide whether the problem is **mathematically self-consistent**, has **sufficient conditions**,
and is **solvable**, and whether the provided reasoning and answer are consistent with the problem.

Input:

Problem: {question}

Reasoning (cot): {cot}
Answer: {answer}

[Evaluation criteria]

Judge only based on:
1. **Correctness**: Is the problem description self-consistent, with no contradictions?
2. **Solvability**: Are the conditions sufficient to deduce the answer? Are the CoT and answer
   consistent with the problem statement?

[Output]

Return a strict JSON object only, with no extra text:

{{
  "passed": true/false,
  "reason": "1–2 sentences explaining the core reason"
}}
"""
        return prompt

    def call_llm(self, prompt: str) -> Dict[str, Any]:
        """Send ``prompt`` as a single user message and return the reply.

        Any exception raised while calling the API or reading the completion
        is retried up to ``self.max_retries`` times, sleeping 1 s, 2 s, 4 s, ...
        between attempts.

        Returns:
            Dict with ``content`` (str, ``""`` when the reply has no text),
            ``usage`` (the client's raw usage object) and ``finish_reason``.

        Raises:
            Exception: when every attempt fails (chained to the last error),
                or immediately when ``max_retries`` is less than 1.
        """
        for attempt in range(self.max_retries):
            try:
                completion = self.client.chat.completions.create(
                    model=self.model,
                    messages=[{"role": "user", "content": prompt}],
                    temperature=self.temperature,
                    max_tokens=self.max_tokens,
                )
                choice = completion.choices[0]
                response = choice.message
                finish_reason = getattr(choice, "finish_reason", None)
                content = response.content if response and response.content else ""

                # Output truncated before any text was produced.
                if finish_reason == "length" and not content:
                    print(
                        "Warning: finish_reason=length but response content is empty; "
                        "max_tokens may be too small."
                    )

                return {
                    "content": content,
                    "usage": completion.usage,
                    "finish_reason": finish_reason,
                }

            except Exception as e:
                if attempt < self.max_retries - 1:
                    wait_time = 2 ** attempt
                    print(f"Attempt {attempt + 1} failed, retrying in {wait_time} seconds... ({e})")
                    time.sleep(wait_time)
                    continue
                raise Exception(f"API call failed after {self.max_retries} retries: {e}") from e

        raise Exception("Reached maximum number of retries.")

    @staticmethod
    def _escape_latex_backslashes(json_text: str) -> str:
        """Double every lone backslash that precedes an ASCII letter.

        LLMs often write LaTeX such as ``\\sqrt{2}`` inside JSON strings
        without escaping the backslash, which is invalid JSON. An
        already-escaped pair (two backslashes) is consumed as a unit and kept,
        so it is never doubled again; other escapes such as ``\\"`` are left
        untouched.
        """
        return re.sub(
            r"\\(\\|[A-Za-z])",
            lambda m: m.group(0) if m.group(1) == "\\" else "\\\\" + m.group(1),
            json_text,
        )

    @staticmethod
    def _decode_object(json_text: str, require_whole: bool = False) -> Optional[Dict[str, Any]]:
        """Decode a JSON object at the very start of ``json_text``.

        Trailing text after the object is ignored unless ``require_whole`` is
        true. If the text does not decode as-is, one more attempt is made after
        ``_escape_latex_backslashes``. Returns ``None`` unless the result is a dict.
        """

        def try_decode(candidate: str) -> Optional[Dict[str, Any]]:
            try:
                obj, end = LLMJudge._LENIENT_DECODER.raw_decode(candidate)
            except json.JSONDecodeError:
                return None
            if not isinstance(obj, dict):
                return None
            if require_whole and end != len(candidate):
                return None
            return obj

        obj = try_decode(json_text)
        if obj is None:
            obj = try_decode(LLMJudge._escape_latex_backslashes(json_text))
        return obj

    @staticmethod
    def extract_json_from_response(response_text: str) -> Optional[Dict[str, Any]]:
        """Extract a JSON object from a free-form LLM text response.

        Strategies, tried in order:

        1. The whole stripped response is a single JSON object.
        2. The object inside a fenced code block (optionally tagged ``json``).
        3. The first object that decodes starting at any ``{`` in the text;
           trailing prose or further objects are ignored.

        Every strategy accepts raw newlines inside strings and retries once
        with unescaped LaTeX backslashes repaired.

        Returns:
            The parsed dict, or ``None`` when the response is empty, shorter
            than ``_MIN_RESPONSE_LENGTH`` non-whitespace characters (e.g. a
            truncated bare fence), or contains no decodable JSON object.
        """
        if not response_text:
            return None

        stripped = response_text.strip()
        if len(stripped) < LLMJudge._MIN_RESPONSE_LENGTH:
            return None

        # 1) The response is exactly one JSON object.
        obj = LLMJudge._decode_object(stripped, require_whole=True)
        if obj is not None:
            return obj

        # 2) Object inside a ```json ... ``` (or bare ```) fence.
        fence = re.search(r"```(?:json)?\s*(\{)", response_text)
        if fence:
            json_text = response_text[fence.start(1):]
            closing = json_text.find("```")
            if closing != -1:
                json_text = json_text[:closing].rstrip()
            obj = LLMJudge._decode_object(json_text)
            if obj is not None:
                return obj

        # 3) Scan left to right; an enclosing object is tried before anything
        #    nested inside it, so a well-formed outer object always wins.
        start = response_text.find("{")
        while start != -1:
            obj = LLMJudge._decode_object(response_text[start:])
            if obj is not None:
                return obj
            start = response_text.find("{", start + 1)

        return None

    @staticmethod
    def _serialize_usage(usage_obj) -> Optional[Dict[str, int]]:
        """
        Convert a usage object into a JSON-serializable dict.

        Args:
            usage_obj: Usage object (a dict, or any object exposing
                ``prompt_tokens`` / ``completion_tokens`` / ``total_tokens``).

        Returns:
            A dict with the three token counts (missing or None counts become
            0), or None when ``usage_obj`` is None or exposes none of them.
        """
        if usage_obj is None:
            return None

        if isinstance(usage_obj, dict):
            return {
                "prompt_tokens": usage_obj.get("prompt_tokens", 0) or 0,
                "completion_tokens": usage_obj.get("completion_tokens", 0) or 0,
                "total_tokens": usage_obj.get("total_tokens", 0) or 0,
            }

        # Best-effort: usage is only recorded for diagnostics, so an odd
        # object must never break judging.
        try:
            prompt_tokens = getattr(usage_obj, "prompt_tokens", None)
            completion_tokens = getattr(usage_obj, "completion_tokens", None)
            total_tokens = getattr(usage_obj, "total_tokens", None)

            if prompt_tokens is not None or completion_tokens is not None or total_tokens is not None:
                return {
                    "prompt_tokens": int(prompt_tokens) if prompt_tokens is not None else 0,
                    "completion_tokens": int(completion_tokens) if completion_tokens is not None else 0,
                    "total_tokens": int(total_tokens) if total_tokens is not None else 0,
                }
        except Exception:
            pass

        return None

    @staticmethod
    def _parse_verdict(result_json: Dict[str, Any]) -> tuple:
        """Normalize a verdict dict into ``(passed, reason, score)``.

        ``passed`` given as a string is interpreted by value, so ``"false"``
        is a failure rather than a truthy non-empty string. ``reason`` is
        always a ``str`` (``None`` becomes ``""``). ``score`` is returned as
        given, defaulting to 0.
        """
        passed = result_json.get("passed", False)
        if isinstance(passed, str):
            passed = passed.strip().lower() in LLMJudge._TRUE_STRINGS
        else:
            passed = bool(passed)

        reason = result_json.get("reason", "")
        reason = "" if reason is None else str(reason)

        score = result_json.get("score", 0)
        return passed, reason, score

    @staticmethod
    def _failure_result(
        error: str,
        llm_response: str = "",
        usage: Any = None,
        finish_reason: Optional[str] = None,
        call_details: Optional[Dict[str, Any]] = None,
    ) -> Dict[str, Any]:
        """Build the failure dict returned by ``judge``."""
        return {
            "success": False,
            "error": error,
            "llm_response": llm_response,
            "usage": usage,
            "finish_reason": finish_reason,
            "call_details": {} if call_details is None else call_details,
        }

    def _judge_attempt(self, prompt: str, index: Optional[int]) -> tuple:
        """Run one judging call for ``judge``.

        Returns:
            ``(result, retry_detail)``. ``result`` is the dict ``judge``
            returns. ``retry_detail`` is text inserted after "failed" in the
            retry log line (empty for most failures).

        Raises:
            Whatever ``call_llm`` raises; ``judge`` converts it into a failure.
        """
        if index is not None:
            print(f"[sample {index}] Calling LLM to judge problem...")

        call_start_time = time.time()
        api_response = self.call_llm(prompt)
        call_elapsed_time = time.time() - call_start_time

        llm_response = api_response.get("content", "") or ""
        usage = api_response.get("usage")
        finish_reason = api_response.get("finish_reason")

        call_details = {
            "timestamp": time.time(),
            "elapsed_time": call_elapsed_time,
            "response_length": len(llm_response),
            "finish_reason": finish_reason,
            "usage": self._serialize_usage(usage) if usage else None,
        }

        def fail(error_msg: str) -> Dict[str, Any]:
            return self._failure_result(error_msg, llm_response, usage, finish_reason, call_details)

        # Truncated output: the verdict is missing or incomplete.
        if finish_reason == "length":
            if len(llm_response.strip()) < self._MIN_RESPONSE_LENGTH:
                error_msg = (
                    "Response was truncated and content is empty "
                    f"(finish_reason={finish_reason}, length={len(llm_response)}); "
                    "max_tokens may be too small or the prompt too long."
                )
            else:
                error_msg = (
                    "Response was truncated "
                    f"(finish_reason={finish_reason}, length={len(llm_response)}); "
                    "consider increasing max_tokens."
                )
            if index is not None:
                print(f"[sample {index}] {error_msg}")
                usage_dict = call_details["usage"]
                if usage_dict:
                    print(
                        "   Token usage: "
                        f"prompt={usage_dict.get('prompt_tokens', 0)}, "
                        f"completion={usage_dict.get('completion_tokens', 0)}, "
                        f"total={usage_dict.get('total_tokens', 0)}"
                    )
                    print(
                        "   Hint: if prompt_tokens approaches or exceeds max_tokens, "
                        "the response will be truncated."
                    )
            return fail(error_msg), ""

        # Non-empty but too short to hold a verdict.
        if llm_response and len(llm_response.strip()) < self._MIN_RESPONSE_LENGTH:
            error_msg = f"Response is too short (length={len(llm_response)}), likely invalid."
            if index is not None:
                print(f"[sample {index}] {error_msg}")
            return fail(error_msg), ""

        result_json = self.extract_json_from_response(llm_response)

        if not result_json:
            error_msg = "Failed to extract JSON from judge response."
            if llm_response:
                error_msg += f" (length={len(llm_response)}, first 100 chars: {llm_response[:100]})"
            if index is not None:
                print(f"[sample {index}] {error_msg}")
                print(f"   Raw response preview: {llm_response[:500]}...")
            return fail(error_msg), ""

        if "passed" not in result_json:
            missing = "JSON from judge is missing 'passed' field"
            return fail(missing), f": {missing}"

        passed, reason, score = self._parse_verdict(result_json)
        if index is not None:
            status = "PASSED" if passed else "FAILED"
            print(f"{status} [sample {index}] score: {score}, reason: {reason[:50]}...")

        return {
            "success": True,
            "passed": passed,
            "reason": reason,
            "score": score,
            "llm_response": llm_response,
            "usage": usage,
            "finish_reason": finish_reason,
            "call_details": call_details,
        }, ""

    def judge(
        self,
        question: str,
        cot: str = "",
        answer: str = "",
        index: Optional[int] = None,
    ) -> Dict[str, Any]:
        """
        Judge a single problem, retrying the whole flow on any failure.

        Up to ``self.max_retries`` judging attempts are made, sleeping 1 s,
        2 s, 4 s, ... between them. Because ``call_llm`` also retries API
        errors internally, a persistently failing API can be called up to
        ``max_retries ** 2`` times.

        Args:
            question: Geometry problem text.
            cot: Chain-of-thought / reasoning text (optional).
            answer: Final answer (optional).
            index: Optional index used only for log lines.

        Returns:
            On success: ``success`` (True), ``passed`` (bool), ``reason``
            (str), ``score`` (taken from the verdict, default 0; the judging
            prompt does not request one), ``llm_response``, ``usage`` (raw
            usage object), ``finish_reason`` and ``call_details`` (timing plus
            serialized usage).
            On failure: ``success`` (False), ``error``, ``llm_response``,
            ``usage``, ``finish_reason`` and ``call_details`` from the last
            attempt. A failure is also returned when ``max_retries`` < 1.
        """
        prompt = self.format_judge_prompt(question, cot, answer)
        result = self._failure_result("Reached maximum number of retries.")

        for attempt in range(self.max_retries):
            try:
                result, retry_detail = self._judge_attempt(prompt, index)
            except Exception as e:
                if index is not None:
                    print(f"[sample {index}] Judging failed: {e}")
                result = self._failure_result(str(e))
                retry_detail = ""

            if result["success"]:
                return result

            if attempt < self.max_retries - 1:
                wait_time = 2 ** attempt
                print(
                    f"[sample {index}] Attempt {attempt + 1}/{self.max_retries} failed{retry_detail}, "
                    f"retrying in {wait_time} seconds..."
                )
                time.sleep(wait_time)

        return result

    @staticmethod
    def format_judge_plotting_prompt(question: str, plotting_code: Dict[str, Any]) -> str:
        """Build the prompt for judging structured plotting code.

        ``dsl``, ``segments`` and ``circles`` are all JSON-encoded, so the
        embedded plotting-code object stays valid JSON even when the DSL
        contains quotes or newlines.
        """
        dsl = plotting_code.get("dsl", "")
        segments = plotting_code.get("segments", [])
        circles = plotting_code.get("circles", [])

        prompt = f"""You are an expert judge of geometric plotting code quality.

Your task is to evaluate the plotting code for:
- DSL correctness
- completeness
- consistency
- whether it is drawable
- overall geometric reasonableness.

Problem:
{question}

Plotting code:
{{
  "dsl": {json.dumps(dsl, ensure_ascii=False)},
  "segments": {json.dumps(segments, ensure_ascii=False)},
  "circles": {json.dumps(circles, ensure_ascii=False)}
}}

Return **only** a JSON object with no extra text:
{{
  "passed": true/false,
  "reason": "Brief explanation in 1–2 sentences",
  "score": 0-100
}}"""
        return prompt

    def judge_plotting_code(
        self,
        question: str,
        plotting_code: Dict[str, Any],
        index: Optional[int] = None,
    ) -> Dict[str, Any]:
        """
        Judge plotting (geometry drawing) code in a single judging attempt.

        Only the API call itself is retried (inside ``call_llm``); parse
        failures are returned immediately.

        Args:
            question: Geometry problem text.
            plotting_code: Dict with optional ``dsl``, ``segments`` and
                ``circles`` entries.
            index: Optional index used only for log lines.

        Returns:
            On success: ``success`` (True), ``passed`` (bool), ``reason``
            (str), ``score`` (default 0) and ``llm_response``.
            On failure: ``success`` (False), ``error`` and ``llm_response``
            (``""`` when the failure was an exception).
        """
        try:
            prompt = self.format_judge_plotting_prompt(question, plotting_code)

            if index is not None:
                print(f"[sample {index}] Calling LLM to judge plotting code...")

            api_response = self.call_llm(prompt)
            llm_response = api_response.get("content", "")

            result_json = self.extract_json_from_response(llm_response)

            if not result_json:
                return {
                    "success": False,
                    "error": "Failed to extract JSON from plotting-code judge response",
                    "llm_response": llm_response,
                }

            if "passed" not in result_json:
                return {
                    "success": False,
                    "error": "JSON from plotting-code judge is missing 'passed' field",
                    "llm_response": llm_response,
                }

            passed, reason, score = self._parse_verdict(result_json)

            if index is not None:
                status = "PASSED" if passed else "FAILED"
                print(f"{status} [sample {index}], reason: {reason[:50]}...")

            return {
                "success": True,
                "passed": passed,
                "reason": reason,
                "score": score,
                "llm_response": llm_response,
            }

        except Exception as e:
            if index is not None:
                print(f"[sample {index}] Plotting-code judging failed: {e}")
            return {
                "success": False,
                "error": str(e),
                "llm_response": "",
            }
