from PIL import Image
import io
import base64
import re
from typing import Dict, Any, Optional
from openai import AzureOpenAI

class GPT4oReward:
    """Binary image/prompt alignment judge backed by an Azure OpenAI chat model.

    The model is asked to answer with a single token: ``0`` when the image
    fully satisfies the prompt, ``-1`` otherwise. ``get_reward`` returns
    ``None`` on any request/parse failure; ``judge_answer`` maps that to ``-1``.
    """

    # A standalone "0" or "-1" token. Digits that belong to a larger number
    # ("10", "-10", "0.5", "5-1") are rejected so that an off-format reply is
    # never misread as a label - in particular never as the passing label 0.
    _LABEL_RE = re.compile(r"(?<![\w.-])(-1|0)(?!\w|\.\d)")

    def __init__(
        self,
        api_key: str,
        model: str,
        api_base: Optional[str] = None,
        region: str = "",
        api_version: str = "",
    ):
        """Create the Azure OpenAI client used for judging.

        Args:
            api_key: API key forwarded to ``AzureOpenAI``.
            model: Model/deployment name sent with every chat request.
            api_base: Base URL of the service; ``None`` is treated as "".
            region: Path segment appended to ``api_base`` to build the
                endpoint as ``f"{api_base}/{region}"``.
            api_version: Forwarded to ``AzureOpenAI(api_version=...)``.
                Defaults to "" (sent as-is).
                NOTE(review): Azure OpenAI presumably needs a real version
                string here (e.g. "2024-02-01") - confirm for your deployment.
        """
        # Strip a trailing "/" so "https://host/" does not yield "https://host//region".
        base = (api_base or "").rstrip("/")
        endpoint = f"{base}/{region}"

        self.client = AzureOpenAI(
            api_key=api_key,
            api_version=api_version,
            azure_endpoint=endpoint,
        )
        self.model = model

    def _encode_image(self, image: Image.Image) -> str:
        """Serialize *image* as PNG and return it base64-encoded as an ASCII str."""
        with io.BytesIO() as buffer:
            image.save(buffer, format="PNG")
            png_bytes = buffer.getvalue()
        return base64.b64encode(png_bytes).decode("ascii")

    def _build_messages(
        self,
        prompt_data: Dict[str, Any],
        img64: str
    ) -> list:
        """Build the chat messages for one judgement.

        Args:
            prompt_data: Mapping that must contain a ``'prompt'`` key (a
                missing key raises ``KeyError``).
            img64: Base64-encoded PNG, embedded as a ``data:`` URL.

        Returns:
            A system message (judge role) and a user message holding the
            grading instructions plus the image.

        NOTE(review): the prompt is interpolated inside double quotes without
        escaping, so a prompt containing quotes can blur the delimiter.
        """
        return [
            {
                "role": "system",
                "content": [
                    {
                        "type": "text",
                        "text": (
                            "You are a rigorous image–text alignment judge. "
                            "Decide ONLY whether the image correctly satisfies the prompt."
                        )
                    }
                ]
            },
            {
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "text": f"""Evaluate the image against the PROMPT strictly from three angles:
1) Consistency (faithfulness to every required element in the prompt; no contradictions or omissions),
2) Realism (physical plausibility and visual coherence; avoid obvious artifacts),
3) Aesthetic Quality (composition, color harmony, clarity).

Decision rule (EXTREMELY STRICT):
- Output **0** if AND ONLY IF the image clearly and fully satisfies the prompt with strong realism and acceptable aesthetics. Minor doubts → downgrade to -1.
- Otherwise, output **-1**.

IMPORTANT:
- Your entire response MUST be a single token: `0` or `-1`.
- Do NOT add any words, punctuation, spaces, or explanations.

PROMPT: "{prompt_data['prompt']}"
"""
                    },
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": f"data:image/png;base64,{img64}"
                        }
                    }
                ]
            }
        ]

    def _extract_label(self, text: str) -> Optional[int]:
        """Parse the judge's reply into a label.

        Returns 0 or -1 for the first standalone ``0`` / ``-1`` token in
        *text*. Surrounding markup such as backticks, ``**`` or a trailing
        period is tolerated.

        Returns None when no such token exists. Numbers that merely contain
        those digits ("10", "-10", "0.5") are not accepted, so a malformed
        reply cannot be misread as the passing label 0.
        """
        if not text:
            return None
        # Some models emit U+2212 MINUS SIGN instead of ASCII "-".
        normalized = text.replace("\u2212", "-")
        match = self._LABEL_RE.search(normalized)
        if match is None:
            return None
        return int(match.group(1))

    def get_reward(
        self,
        image: Image.Image,
        prompt_data: Dict[str, Any]
    ) -> Optional[int]:
        """Ask the judge model whether *image* satisfies ``prompt_data['prompt']``.

        Returns:
            0 if the model judged the image acceptable, -1 if not, or None
            when the request fails or the reply cannot be parsed. Failures
            are printed and swallowed (best-effort), so callers must handle
            None.
        """
        try:
            img64 = self._encode_image(image)
            messages = self._build_messages(prompt_data, img64)
            response = self.client.chat.completions.create(
                model=self.model,
                messages=messages,
                temperature=0.0,  # minimize sampling randomness for judging
                max_tokens=4,  # the expected reply is just "0" or "-1"
            )
            content = response.choices[0].message.content or ""
        except Exception as e:  # best-effort boundary: any failure -> None
            print(f"[ERR]: {e}")
            return None

        print(f"[LLM raw] {content!r}")
        label = self._extract_label(content)
        if label is None:
            print("[WARN] Could not parse a valid label (0 or -1).")
        return label

    def judge_answer(self, image: Image.Image, data: Dict[str, Any]) -> int:
        """Return the judge's label (0 or -1), treating any failure as -1."""
        label = self.get_reward(image, data)
        return label if label in (0, -1) else -1