import re
import logging
import requests
from openai import OpenAI
import httpx

# Module-level logger named after this module.
logger = logging.getLogger(__name__)


# System message sent to the judge model. It asks for a strict semantic
# comparison between the model's answer and the standard answer, and for a
# one-word reply: "CORRECT" or "INCORRECT".
SYSTEM_PROMPT = """You are an expert evaluator. Your task is to determine if a model's answer is semantically equivalent to a provided standard answer, given a specific question.
Your evaluation must be strict. The model's answer is only correct if it fully matches the meaning of the standard answer.
You must provide your final judgement as a single word: either "CORRECT" or "INCORRECT". Do not provide any explanation or other text."""


# User message template: six worked examples followed by the task to judge.
# It is filled with str.format() using the named fields `question`,
# `ground_truth` and `prediction`, and ends with an open
# "[Your Judgement]:" slot for the judge model to complete.
USER_PROMPT_TEMPLATE = """I will provide a question, a standard answer, and a model's answer. You must evaluate if the model's answer is correct.

---
**Example 1:**
[Question]: Is the countertop tan or blue?
[Standard Answer]: The countertop is tan.
[Model's Answer]: tan
[Your Judgement]: CORRECT
---
**Example 2:**
[Question]: Is the man phone both blue and closed?
[Standard Answer]: Yes, the man phone is both blue and closed.
[Model's Answer]: No.
[Your Judgement]: INCORRECT
---
**Example 3:**
[Question]: Who wrote the novel "Pride and Prejudice"?
[Standard Answer]: Jane Austen authored Pride and Prejudice.
[Model's Answer]: The novel "Pride and Prejudice" was written by Jane Austen.
[Your Judgement]: CORRECT
---
**Example 4:**
[Question]: In which U.S. state is Silicon Valley located?
[Standard Answer]: California.
[Model's Answer]: San Jose.
[Your Judgement]: INCORRECT
---
**Example 5:**
[Question]: What is the population of the United States as of the 2020 Census?
[Standard Answer]: 331,449,281.
[Model's Answer]: 331449281
[Your Judgement]: CORRECT
---
**Example 6:**
[Question]: What is 25 multiplied by 4?
[Standard Answer]: 100
[Model's Answer]: 99
[Your Judgement]: INCORRECT
---
**Task:**
[Question]: {question}
[Standard Answer]: {ground_truth}
[Model's Answer]: {prediction}
[Your Judgement]: 
"""

class LLMJudgeClient:
    """Grade a model answer against a reference answer using an LLM judge.

    The judge is reached through the ``openai`` client pointed at
    ``base_url``. When no ``base_url`` is configured the client is disabled
    and :meth:`verify` always returns ``0.0``.
    """

    def __init__(self, base_url, model_name, temperature, api_key="EMPTY"):
        """Create the judge client.

        Args:
            base_url: Base URL of the judge endpoint. A falsy value disables
                the judge.
            model_name: Model identifier to request. When falsy, the name is
                looked up from ``{base_url}/models``, which must list exactly
                one model.
            temperature: Sampling temperature sent with every judge request.
            api_key: API key handed to the ``openai`` client.

        Raises:
            ValueError: If ``model_name`` is not given and the endpoint does
                not expose exactly one model.
        """
        self.base_url = base_url
        self.enabled = bool(base_url)
        self.temperature = temperature

        if not base_url:
            logger.warning("LLMJudgeClient initialized without base_url, LLM Judge is disabled.")
            self.client = None
            self.model_name = None
            return

        self.client = OpenAI(
            api_key=api_key,
            base_url=base_url,
            timeout=100,
        )

        if model_name:
            self.model_name = model_name
        else:
            self.model_name = self._discover_model_name(base_url)

        logger.info(f"LLMJudgeClient initialized successfully, base_url={base_url}, model_name={self.model_name}, temperature={temperature}")

    @staticmethod
    def _discover_model_name(base_url):
        """Return the id of the single model listed at ``{base_url}/models``."""
        # Avoid a double slash when base_url already ends with "/".
        models_url = f"{str(base_url).rstrip('/')}/models"
        with requests.Session() as sess:
            # Ignore proxy settings and other configuration from the environment.
            sess.trust_env = False
            # A timeout keeps construction from hanging on an unreachable host.
            resp = sess.get(models_url, timeout=30)
            resp.raise_for_status()
            models = resp.json()["data"]

        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if len(models) != 1:
            raise ValueError(
                f"Expected exactly one model deployed on the current base_url, "
                f"found {len(models)}; please specify model_name."
            )
        return models[0]["id"]

    @staticmethod
    def _parse_judgement(response):
        """Map the judge's reply to a reward.

        Returns ``0.0`` if the reply contains the word INCORRECT, ``1.0`` if it
        contains the word CORRECT, and ``None`` if it contains neither. The
        match is case-insensitive and on whole words.
        """
        # INCORRECT is tested first: a reply such as "INCORRECT, the answer is
        # not correct" also contains the standalone word "correct" and must
        # not be scored as a pass.
        if re.search(r"\bINCORRECT\b", response, re.IGNORECASE):
            return 0.0
        if re.search(r"\bCORRECT\b", response, re.IGNORECASE):
            return 1.0
        return None

    def _report_request_failure(self, e, hostname):
        """Print and log everything known about a failed judge request.

        Must be called from inside the ``except`` block handling ``e`` so the
        traceback helpers can see the active exception.
        """
        import traceback

        print(f"\n{'!'*80}")
        print("❌❌❌ [LLMJudge] 请求失败!")
        print(f"{'!'*80}")
        print(f"❌ [LLMJudge ERROR] 异常类型: {type(e).__name__}")
        print(f"❌ [LLMJudge ERROR] 异常信息: {str(e)}")
        print(f"❌ [LLMJudge ERROR] Host: {hostname}")
        print(f"❌ [LLMJudge ERROR] base_url: {self.base_url}")

        logger.error(f"[LLMJudge] ❌ Request FAILED: {type(e).__name__}: {e}")

        # Some exceptions carry the HTTP response; dump it when present.
        if hasattr(e, 'response'):
            print("❌ [LLMJudge ERROR] Exception 有 response 属性")
            logger.error("[LLMJudge] Exception has response attribute")
            try:
                print(f"❌ [LLMJudge ERROR] Response status_code: {e.response.status_code}")
                print(f"❌ [LLMJudge ERROR] Response headers: {dict(e.response.headers)}")
                print("❌ [LLMJudge ERROR] Response text (前2000字符):")
                print(f"{e.response.text[:2000]}")
                print(f"{'!'*80}")

                logger.error(f"[LLMJudge] Response status_code: {e.response.status_code}")
                logger.error(f"[LLMJudge] Response headers: {dict(e.response.headers)}")
                logger.error(f"[LLMJudge] Response text: {e.response.text}")
            except Exception as detail_err:
                print(f"❌ [LLMJudge ERROR] 无法提取 response 详情: {detail_err}")
                logger.error(f"[LLMJudge] Could not extract response details: {detail_err}")

        if hasattr(e, 'body'):
            print(f"❌ [LLMJudge ERROR] Exception body: {e.body}")
            logger.error(f"[LLMJudge] Exception body: {e.body}")

        if hasattr(e, 'message'):
            print(f"❌ [LLMJudge ERROR] Exception message: {e.message}")
            logger.error(f"[LLMJudge] Exception message: {e.message}")

        # Print full exception details
        print("❌ [LLMJudge ERROR] 完整堆栈追踪:")
        traceback.print_exc()
        print(f"{'!'*80}\n")

        logger.error(f"[LLMJudge] Full exception traceback:\n{traceback.format_exc()}")

        # Also print to stdout for immediate visibility in Ray logs
        print(f"[LLMJudge ERROR] Host={hostname}, base_url={self.base_url}")
        print(f"[LLMJudge ERROR] Exception: {type(e).__name__}: {e}")
        if hasattr(e, 'response'):
            try:
                print(f"[LLMJudge ERROR] Response status: {e.response.status_code}")
                print(f"[LLMJudge ERROR] Response body: {e.response.text[:500]}")
            except Exception:
                # Best effort only: the failure itself was already reported above.
                pass

    def verify(self, prediction, ground_truth, question):
        """Ask the judge whether ``prediction`` matches ``ground_truth``.

        Args:
            prediction: The model's answer to grade.
            ground_truth: The reference (standard) answer.
            question: The question both answers respond to.

        Returns:
            ``1.0`` when the judge replies CORRECT; ``0.0`` when it replies
            INCORRECT, when the reply contains neither word, when the request
            fails, or when the judge is disabled.
        """
        # If LLM Judge is not enabled, return 0.0
        if not self.enabled or self.client is None:
            print("LLM Judge is not enabled (no base_url configured), returning 0.0")
            logger.warning("LLM Judge is not enabled (no base_url configured), returning 0.0")
            return 0.0

        # Log request details
        import socket
        hostname = socket.gethostname()

        print(f"\n{'='*80}")
        print(f"🔍 [LLMJudge] === 开始验证 on host={hostname} ===")
        print(f"{'='*80}")
        print(f"🌐 [LLMJudge] base_url: {self.base_url}")
        print(f"🤖 [LLMJudge] model_name: {self.model_name}")
        print(f"🌡️  [LLMJudge] temperature: {self.temperature}")
        print(f"{'-'*80}")
        print("📝 [LLMJudge] 输入参数:")
        print(f"   ├─ prediction 类型: {type(prediction)}")
        print(f"   ├─ prediction 长度: {len(str(prediction))}")
        print(f"   ├─ prediction 前300字符: {str(prediction)[:300]}")
        print(f"   ├─ ground_truth 类型: {type(ground_truth)}")
        print(f"   ├─ ground_truth 值: {ground_truth}")
        print(f"   ├─ question 类型: {type(question)}")
        print(f"   ├─ question 长度: {len(str(question))}")
        print(f"   └─ question 前300字符: {str(question)[:300]}")
        print(f"{'-'*80}")

        # Stringify before slicing so a non-string question cannot raise here;
        # the prints above already treat the inputs this way.
        question_text = str(question)
        logger.info(f"[LLMJudge] === Starting verification on host={hostname} ===")
        logger.info(f"[LLMJudge] base_url={self.base_url}")
        logger.info(f"[LLMJudge] model_name={self.model_name}")
        if len(question_text) > 150:
            logger.info(f"[LLMJudge] question={question_text[:150]}...")
        else:
            logger.info(f"[LLMJudge] question={question_text}")
        logger.info(f"[LLMJudge] ground_truth={ground_truth}")
        logger.info(f"[LLMJudge] prediction={prediction}")

        system_prompt = SYSTEM_PROMPT
        user_prompt = USER_PROMPT_TEMPLATE.format(question=question, ground_truth=ground_truth, prediction=prediction)

        print("📋 [LLMJudge] 准备好的 prompts:")
        print(f"   ├─ system_prompt 长度: {len(system_prompt)}")
        print(f"   ├─ user_prompt 长度: {len(user_prompt)}")
        print(f"   └─ user_prompt 前500字符:\n{user_prompt[:500]}")
        print(f"{'-'*80}")

        logger.info(f"[LLMJudge] Prepared prompts, system_prompt length={len(system_prompt)}, user_prompt length={len(user_prompt)}")

        # Build the request parameters
        request_params = {
            "model": self.model_name,
            "messages": [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt},
            ],
            "temperature": self.temperature,
            "max_tokens": 512,
            "timeout": 60,
        }

        print("📤 [LLMJudge] 发送请求参数:")
        print(f"   ├─ model: {request_params['model']}")
        print(f"   ├─ temperature: {request_params['temperature']}")
        print(f"   ├─ max_tokens: {request_params['max_tokens']}")
        print(f"   ├─ timeout: {request_params['timeout']}")
        print(f"   └─ messages: {len(request_params['messages'])} 条消息")
        print(f"{'-'*80}")

        try:
            print(f"🚀 [LLMJudge] 正在发送请求到 {self.base_url}...")
            logger.info(f"[LLMJudge] Sending chat.completions.create request to {self.base_url}...")

            chat_response = self.client.chat.completions.create(**request_params)

            # `content` may be None; treat that as an empty reply rather than
            # letting `.strip()` raise and be reported as a request failure.
            content = chat_response.choices[0].message.content
            response = (content or "").strip()
            print(f"✅ [LLMJudge] 请求成功! response: {response}")
            print(f"{'='*80}\n")
            logger.info(f"[LLMJudge] ✅ Request SUCCESS, response={response}")

        except Exception as e:
            # Any failure while querying the judge is scored as 0.0.
            self._report_request_failure(e, hostname)
            return 0.0

        # Parse LLM judge response
        acc_reward = self._parse_judgement(response)
        if acc_reward is None:
            logger.warning(
                "Judgement format error. Expected 'CORRECT' or 'INCORRECT'. "
                "response=%r, prediction=%r, ground_truth=%r",
                response, prediction, ground_truth,
            )
            acc_reward = 0.0

        # Penalize for model trying to predict longer answer to hack llm-as-judge
        # if len(prediction) >= 1000:
        #     acc_reward = 0.0
        #     logger.warning("Prediction length >= 1000, reward set to 0.0.")

        return acc_reward


if __name__ == "__main__":
    # Manual smoke test: send one sample question to the configured judge.
    # `os` is imported here because the module does not import it at file
    # level; without this import the script fails with NameError.
    import os

    client = LLMJudgeClient(
        base_url=os.getenv("LLM_JUDGE_BASE_URL", "http://YOUR_VLLM_HOST:18901/v1"),
        model_name="judge-72b",
        temperature=0.1,
    )
    answer = "The Answer is 19."
    ground_truth = "19"
    question = "Krysta counted the number of words per page in her new book. How many pages have at least 60 words?"
    print(client.verify(answer, ground_truth, question))
