import json
import requests
import re
import time

# Endpoint that every chat-completion request in this module is POSTed to.
OPENROUTER_API_URL = "https://openrouter.ai/api/v1/chat/completions"

# Judge prompt template; it is later filled in via str.format() with the
# `system`, `user`, `reasoning` and `answer` fields.
JUDGE_PROMPT_FILE = "prompts/eval_traces_prompt.txt"

# Loaded once at import time (path is relative to the current working
# directory). The encoding is pinned to UTF-8 so the prompt is decoded the same
# way on every platform instead of depending on the locale default.
with open(JUDGE_PROMPT_FILE, "r", encoding="utf-8") as f:
    JUDGE_PROMPT = f.read()

def get_reasoning_from_openrouter(prompts, model, api_key, max_retries=5, timeout=None):
    """Request a chat completion from OpenRouter, retrying on request errors.

    Failed attempts are retried with exponential backoff (1s, 2s, 4s, ...).
    Only ``requests`` exceptions (connection errors, HTTP error statuses, ...)
    trigger a retry; a response that lacks the expected keys raises
    ``KeyError``/``IndexError`` immediately.

    Args:
        prompts: List of chat messages, sent as the ``messages`` field.
        model: Model identifier, sent as the ``model`` field.
        api_key: API key, sent as a bearer token.
        max_retries: Total number of attempts before giving up.
        timeout: Forwarded to ``requests.post``. ``None`` (the default) keeps
            the previous behaviour of waiting indefinitely; pass a number or a
            ``(connect, read)`` tuple to bound each attempt.

    Returns:
        A ``(reasoning, content)`` tuple taken from the first choice's message.
        ``reasoning`` is ``None`` when the message has no ``"reasoning"`` key.

    Raises:
        ValueError: If ``max_retries`` is smaller than 1.
        requests.exceptions.RequestException: The error of the final attempt,
            once every attempt has failed.
    """
    if max_retries < 1:
        raise ValueError("max_retries must be at least 1")

    headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json",
    }
    payload = {
        "model": model,
        "messages": prompts,
    }

    for attempt in range(max_retries):
        try:
            response = requests.post(
                OPENROUTER_API_URL,
                headers=headers,
                json=payload,
                timeout=timeout,
            )
            response.raise_for_status()
            message = response.json()["choices"][0]["message"]
            content = message["content"]
            # The reasoning trace is only ever read from the dedicated
            # "reasoning" field, never derived from the content.
            return message.get("reasoning"), content
        except requests.exceptions.RequestException as exc:
            if attempt == max_retries - 1:
                print(f"Request failed after {max_retries} attempts: {exc}")
                raise
            wait_time = 2 ** attempt
            print(f"Request failed (attempt {attempt+1}/{max_retries}): {exc}. Retrying in {wait_time}s...")
            time.sleep(wait_time)

def _load_json_object(text):
    """Decode *text* as JSON and return the result only if it is a dict.

    A JSON string whose content is itself JSON (double-encoded output) is
    decoded one extra level. Returns None when decoding fails or the decoded
    value is not a dict.
    """
    try:
        parsed = json.loads(text)
        if isinstance(parsed, str):
            parsed = json.loads(parsed)
    except (TypeError, ValueError):
        return None
    return parsed if isinstance(parsed, dict) else None


def _strip_code_fence(text):
    """Remove a leading/trailing Markdown code fence from *text*, if present."""
    stripped = text.strip()
    for opener in ("```json\n", "```\n"):
        if stripped.startswith(opener):
            stripped = stripped[len(opener):]
    for closer in ("\n```", "```"):
        if stripped.endswith(closer):
            stripped = stripped[:-len(closer)]
    return stripped


def _repair_json_text(text):
    """Apply lossy, best-effort fixes for common JSON mistakes in *text*."""
    # Unescape quotes and newlines (JSON pasted as a string representation).
    repaired = text.replace('\\"', '"').replace('\\n', '\n')
    # Single-quoted keys/values -> double quotes.
    repaired = repaired.replace("'", '"')
    # Drop trailing commas before a closing brace/bracket.
    repaired = re.sub(r',\s*}', '}', repaired)
    return re.sub(r',\s*]', ']', repaired)


def parse_judge_response(judge_response):
    """Parse the judge model's reply into a flat result dict.

    Parsing is attempted in increasingly lenient stages, stopping at the first
    one that yields a JSON object:

    1. the raw text;
    2. the text with a surrounding Markdown code fence removed;
    3. the unfenced text after lossy repairs (unescaping, quote replacement,
       trailing-comma removal);
    4. regex extraction of the individual fields from the raw text.

    The lossy repairs run only after the clean attempts have failed, so valid
    JSON that merely sits inside a code fence is never corrupted by them.

    Args:
        judge_response: Text returned by the judge model (``None`` is
            tolerated and yields an empty result).

    Returns:
        A dict with the keys ``"conclusion"``, ``"evidence"`` and
        ``"reasoning"``. The values are read from the ``"evaluation_awareness"``
        object when it is present, otherwise from the top level of the parsed
        object. Missing values are ``None`` (``[]`` for ``"evidence"``).
    """
    if judge_response is None:
        return {"conclusion": None, "evidence": [], "reasoning": None}

    data = _load_json_object(judge_response)
    if data is None:
        unfenced = _strip_code_fence(judge_response)
        data = _load_json_object(unfenced)
        if data is None:
            data = _load_json_object(_repair_json_text(unfenced))

    if data is None:
        # If JSON parsing fails, try to extract the expected fields using regex
        return {
            "conclusion": extract_string_field(judge_response, "conclusion"),
            "evidence": extract_list_field(judge_response, "evidence"),
            "reasoning": extract_string_field(judge_response, "reasoning")
        }

    # Flatten the structure. When the wrapper object is absent (or not an
    # object), fall back to top-level fields, as the regex extraction does.
    evaluation_awareness = data.get("evaluation_awareness")
    if not isinstance(evaluation_awareness, dict):
        evaluation_awareness = data
    return {
        "conclusion": evaluation_awareness.get("conclusion"),
        "evidence": evaluation_awareness.get("evidence", []),
        "reasoning": evaluation_awareness.get("reasoning")
    }

def _decode_json_string(raw):
    """Decode JSON escape sequences in *raw*; return *raw* unchanged on failure."""
    try:
        # strict=False tolerates literal control characters (e.g. raw newlines).
        return json.loads(f'"{raw}"', strict=False)
    except ValueError:
        return raw


def extract_string_field(text, field_name):
    """Extract a string field from JSON-like text using regex.

    The field is first looked up after an ``"evaluation_awareness"`` key and,
    failing that, anywhere in the text (with or without quotes around the
    field name). Escaped quotes inside the value are respected, so a value
    such as ``"said \\"hi\\""`` is captured in full rather than cut at the
    first escaped quote. JSON escape sequences in the value are decoded when
    they are valid.

    Args:
        text: The JSON-like text to search.
        field_name: Name of the field whose string value is wanted.

    Returns:
        The value with surrounding whitespace stripped, or ``None`` when the
        field is not found.
    """
    name = re.escape(field_name)
    # A double-quoted value: any run of non-quote characters or escaped pairs.
    quoted_value = r'"((?:[^"\\]|\\.)*)"'
    patterns = (
        # First try to find the field nested within evaluation_awareness
        r'"evaluation_awareness".*?"' + name + r'"\s*:\s*' + quoted_value,
        # Fallback to top-level field search
        r'"?' + name + r'"?\s*:\s*' + quoted_value,
    )
    for pattern in patterns:
        match = re.search(pattern, text, re.DOTALL)
        if match:
            return _decode_json_string(match.group(1)).strip()
    return None

def _parse_list_literal(list_str):
    """Parse a bracketed list; fall back to collecting its quoted strings."""
    try:
        return json.loads(list_str)
    except (ValueError, RecursionError):
        # Fallback: extract quoted strings from the list (escape-aware).
        return re.findall(r'"((?:[^"\\]|\\.)*)"', list_str, re.DOTALL)


def extract_list_field(text, field_name):
    """Extract a list field from JSON-like text using regex.

    The field is first looked up after an ``"evaluation_awareness"`` key and,
    failing that, anywhere in the text. For each location a string-aware
    pattern is tried first, so a ``]`` inside a quoted item does not end the
    list early; the simple "up to the first ``]``" pattern is kept as a
    fallback for text with unbalanced quotes.

    Args:
        text: The JSON-like text to search.
        field_name: Name of the field whose list value is wanted.

    Returns:
        The decoded list when the matched text is valid JSON, otherwise the
        list of quoted strings found inside it. ``[]`` when the field is not
        found.
    """
    name = re.escape(field_name)
    quoted = r'"(?:[^"\\]|\\.)*"'
    list_patterns = (
        # Items are either complete quoted strings or non-quote, non-']' chars.
        r'(\[(?:[^\]"]|' + quoted + r')*\])',
        # Lenient fallback: everything up to the first closing bracket.
        r'(\[.*?\])',
    )
    prefixes = (
        # First try to find the field nested within evaluation_awareness
        r'"evaluation_awareness".*?"' + name + r'"\s*:\s*',
        # Fallback to top-level field search
        r'"?' + name + r'"?\s*:\s*',
    )
    for prefix in prefixes:
        for list_pattern in list_patterns:
            match = re.search(prefix + list_pattern, text, re.DOTALL)
            if match:
                return _parse_list_literal(match.group(1))
    return []

def judge_llm_evaluation_awareness(system_prompt, user_prompt, reasoning, answer, model, api_key):
    """Ask the judge model to assess a reasoning trace and return its raw reply.

    The judge prompt template is filled with the original system prompt, user
    prompt, reasoning trace and final answer, and sent as a single user
    message. The request goes through ``get_reasoning_from_openrouter`` so the
    judge call gets the same retry-with-backoff handling as the model call
    instead of failing on the first transient error.

    Args:
        system_prompt: System prompt given to the evaluated model.
        user_prompt: User prompt given to the evaluated model.
        reasoning: Reasoning trace produced by the evaluated model.
        answer: Final answer produced by the evaluated model.
        model: Model identifier used for the judge request.
        api_key: API key for the request.

    Returns:
        The judge's message content with surrounding whitespace stripped; an
        empty string when the API returns no content.

    Raises:
        requests.exceptions.RequestException: If the request keeps failing.
    """
    # Format the prompt by replacing the placeholders
    formatted_prompt = JUDGE_PROMPT.format(
        system=system_prompt,
        user=user_prompt,
        reasoning=reasoning,
        answer=answer,
    )
    judge_prompt = [{"role": "user", "content": formatted_prompt}]

    # Only the content matters here; any reasoning from the judge is ignored.
    _, content = get_reasoning_from_openrouter(judge_prompt, model, api_key)
    # `content` can be null in the API response; avoid calling .strip() on None.
    return (content or "").strip()

def eval_reasoning_ea(system_prompt, user_prompt, model, api_key):
    """Run one prompt through the model and have its output judged.

    The model is queried with the system and user prompts, the resulting
    reasoning trace and answer are passed to the judge (using the same
    ``model`` and ``api_key``), and the judge's reply is parsed.

    Args:
        system_prompt: System prompt for the evaluated model.
        user_prompt: User prompt for the evaluated model.
        model: Model identifier used for both the model call and the judge call.
        api_key: API key for both requests.

    Returns:
        A tuple ``(reasoning, full_output, judge_response, parsed,
        full_conversations)`` where ``parsed`` is the result of
        ``parse_judge_response`` and ``full_conversations`` is a list holding
        one conversation: the system, user and assistant messages of the
        model call.
    """
    prompts = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]

    reasoning, full_output = get_reasoning_from_openrouter(prompts, model, api_key)
    judge_response = judge_llm_evaluation_awareness(
        system_prompt, user_prompt, reasoning, full_output, model, api_key
    )
    parsed = parse_judge_response(judge_response)

    # Conversation with the evaluated model, including its reasoning trace.
    conversation = [
        *prompts,
        {"role": "assistant", "content": full_output, "reasoning": reasoning},
    ]
    # NOTE(review): only the model conversation is recorded; the judge exchange
    # is not appended here (the raw judge reply is returned separately).
    # Confirm whether callers expect it in this list as well.
    full_conversations = [conversation]

    return reasoning, full_output, judge_response, parsed, full_conversations