import os
import json
import base64
import io
import re
from openai import OpenAI
from PIL import Image
import random

# Shared OpenAI client used by the judge functions in this module.
# The key is taken from the OPENAI_API_KEY environment variable so that no
# secret has to live in the source file. When the variable is unset the key
# falls back to the empty string, so importing this module never raises
# (requests will simply fail to authenticate until a key is provided).
client = OpenAI(
    api_key=os.environ.get("OPENAI_API_KEY", "")
)


def pil_image_to_base64(pil_image):
    """Encode an image as a base64 string of its PNG serialization.

    Args:
        pil_image: Object exposing ``save(fp, format=...)``; it is asked to
            write itself as PNG into an in-memory buffer.

    Returns:
        The base64 encoding of the written bytes, decoded as UTF-8 text.
    """
    with io.BytesIO() as png_buffer:
        pil_image.save(png_buffer, format="PNG")
        png_bytes = png_buffer.getvalue()

    encoded = base64.b64encode(png_bytes)
    return encoded.decode("utf-8")


def extract_level_from_response(response):
    """Normalize a judge reply to one of "Level 0" ... "Level 3".

    Matching is case-insensitive. Two reply shapes are recognised:

    * the word "level" followed (optionally after whitespace) by a digit
      0-3 anywhere in the reply, e.g. "Level 2" or "this is level3";
    * a reply that consists solely of a digit 0-3, optionally wrapped in
      punctuation or whitespace, e.g. "2", "3." or "**1**".

    Args:
        response: Raw text returned by the judge model.

    Returns:
        The canonical label "Level N"; "Level 0" when no level can be found.
    """
    normalized = response.upper().strip()

    labelled = re.search(r'LEVEL\s*([0123])', normalized)
    if labelled:
        return f"Level {labelled.group(1)}"

    # The judge prompt asks for "only the level number", so the reply may be
    # a bare digit without the word "Level". Require the digit to be the
    # entire reply so that digits inside unrelated text are not picked up.
    bare_digit = re.fullmatch(r'\W*([0123])\W*', normalized)
    if bare_digit:
        return f"Level {bare_digit.group(1)}"

    # Unrecognised replies fall back to the lowest level.
    return "Level 0"


def strategic_judge(answer):
    """Ask the judge model which ambiguity-handling level an answer shows.

    The answer is interpolated into a fixed prompt that defines Levels 0-3
    and is sent as a single text-only user message to the "gpt-5-mini"
    model through the module-level ``client``.

    Args:
        answer: The model answer to classify.

    Returns:
        The concatenated text of the judge's reply with surrounding
        whitespace stripped, or the string "Level 0" if the request or the
        handling of its response raises.
    """
    # The instruction line must list the same levels (0-3) that the
    # definitions below describe; otherwise the judge is asked to choose
    # between labels that have no definition.
    judge_prompt = f'''### Instruction:
You are an LLM-as-judge evaluating how a model's answer handles ambiguity.
Ignore the question completely. Judge only the style and content of the answer.
Classify the answer as Level 0, 1, 2, or 3 according to the definitions below.
Do not provide explanations or reasoning — output only the level number.

### Level Definitions:
* Level 0 – Clear VQA
- Answer is a direct, factual statement.
- No mention of resolving ambiguous terms.
- Reads as a straightforward description of what is visible.

* Level 1 – Context-resolved Ambiguity
- Answer explicitly clarifies what an ambiguous term refers to, then provides the definitive factual description.
- Typically has a two-part structure: clarification of the referent + final descriptive answer.
- Must be a natural full sentence, not just a fragment.

* Level 2 – List All Plausible Options
- Answer enumerates 2–3 distinct possibilities.
- Each possibility is described in natural sentences (not bullet points).
- No single option is selected as the "best guess".

* Level 3 – Clarification Required
- Answer does not attempt to guess or enumerate.
- Politely requests clarification from the user.
- Acknowledges that multiple possibilities exist without listing them.
- Briefly explains why clarification is needed.

### Answer:
- {answer}

### Response:
'''

    try:
        response = client.responses.create(
            model="gpt-5-mini",
            input=[
                {
                    "role": "user",
                    "content": [
                        {"type": "input_text", "text": judge_prompt},
                    ]
                }
            ]
        )

        # An output item may carry ``content`` set to None (this can happen
        # for non-message items such as reasoning items); iterating it
        # directly would raise and turn a valid reply into the fallback.
        text_parts = []
        for item in response.output:
            for part in getattr(item, "content", None) or ():
                text = getattr(part, "text", None)
                if isinstance(text, str):
                    text_parts.append(text)

        return "".join(text_parts).strip()
    except Exception as e:
        # Best-effort: a failed judgment is reported and treated as Level 0.
        print(f"[ERROR] GPT judgment failed: {e}")
        return "Level 0"


def factual_judge(question, answer):
    """Ask the judge model whether an answer is factually acceptable.

    The question and answer are interpolated into a fixed prompt and sent as
    a single text-only user message to the "gpt-5-mini" model through the
    module-level ``client``.

    NOTE(review): the prompt says "given the image", but the request built
    here contains only text; no image is attached. Confirm whether the
    image is meant to be sent as well.

    Args:
        question: The visual question that was asked.
        answer: The model answer to evaluate.

    Returns:
        The concatenated text of the judge's reply with surrounding
        whitespace stripped (expected to be "PASS" or "FAIL"), or the string
        "FAIL" if the request or the handling of its response raises.
    """
    judge_prompt = f'''### Instruction:
You are a judge evaluating whether a model's answer is factually correct given the image.
- Ignore style, grammar, or completeness.
- PASS if the answer correctly refers to something that is actually visible in the image, even if it does not mention everything or omits other details.
- FAIL only if the answer mentions something not present in the image, contradicts the image, or hallucinates details.

### Output:
Return exactly one label:
- PASS
- FAIL

### Visual Question and Answer:
- Question: {question}
- Answer: {answer}

### Response:
'''

    try:
        response = client.responses.create(
            model="gpt-5-mini",
            input=[
                {
                    "role": "user",
                    "content": [
                        {"type": "input_text", "text": judge_prompt},
                    ]
                }
            ]
        )

        # An output item may carry ``content`` set to None (this can happen
        # for non-message items such as reasoning items); iterating it
        # directly would raise and turn a valid reply into the fallback.
        text_parts = []
        for item in response.output:
            for part in getattr(item, "content", None) or ():
                text = getattr(part, "text", None)
                if isinstance(text, str):
                    text_parts.append(text)

        return "".join(text_parts).strip()
    except Exception as e:
        # Best-effort: a failed judgment is reported and treated as FAIL.
        print(f"[ERROR] GPT judgment failed: {e}")
        return "FAIL"


def _level_index(label):
    """Return the first integer found in *label* (e.g. "Type 2" -> 2), or None."""
    found = re.search(r'\d+', str(label))
    return int(found.group()) if found else None


def llm_judge_reward_function(completions, kwargs):
    """Score completions with the strategic and factual LLM judges.

    For each completion the answer text is extracted, the strategic judge
    classifies it into a level, and that level is compared with the
    ground-truth type of the sample. Ground-truth labels ("2", 2, "Type 2")
    and judge labels ("Level 2") use different prefixes, so the comparison
    is made on the numeric index only.

    Args:
        completions: Sequence of completions. Each is either a non-empty
            list whose first element is a dict with a 'content' key (or any
            object, which is stringified), or any other object, which is
            stringified as a whole.
        kwargs: Plain dict of per-sample metadata. 'questions' and
            'data_types' are read by index; missing entries default to a
            generic question and type "0".

    Returns:
        A list with one float per completion: 0.0 when the judged level does
        not match the ground-truth type or when processing raised, otherwise
        1.0 minus a 0.3 penalty if the factual judge did not answer PASS.
    """
    questions = kwargs.get('questions', [])
    data_types = kwargs.get('data_types', [])
    # NOTE(review): kwargs may also carry 'images', but nothing in this
    # function forwards them; both judges are called with text only.

    rewards = []
    for i, completion in enumerate(completions):
        try:
            if isinstance(completion, list) and len(completion) > 0:
                first = completion[0]
                if isinstance(first, dict) and 'content' in first:
                    answer = first['content']
                else:
                    answer = str(first)
            else:
                answer = str(completion)

            question = questions[i] if i < len(questions) else "What do you see in this image?"
            gt_type = data_types[i] if i < len(data_types) else "0"

            strategic_judgment = strategic_judge(answer)
            strategic_judgment_type = extract_level_from_response(strategic_judgment)

            # Compare numeric indices: "Type N" can never equal "Level N" as
            # strings, and str() inside the helper tolerates int labels.
            gt_index = _level_index(gt_type)
            judged_index = _level_index(strategic_judgment_type)
            type_matches = gt_index is not None and gt_index == judged_index

            if type_matches:
                # The factual judge only affects the reward on a type match,
                # so it is queried only in that case.
                factual_judgment = factual_judge(question, answer).upper().strip()
                penalty = 0.0 if "PASS" in factual_judgment else 0.3
                reward = 1.0 - penalty
                print(f"[LLM JUDGE] ✅ Type match! Reward: {reward}")
            else:
                reward = 0.0
                print(f"[LLM JUDGE] ❌ Type mismatch! Reward: {reward}")

            rewards.append(reward)

        except Exception as e:
            # Best-effort: one failing sample must not abort the whole batch.
            print(f"[ERROR] Processing completion {i}: {e}")
            rewards.append(0.0)

    return rewards
