"""Simplified evaluation system for GUI Agent"""
import json
import os
import time
from pathlib import Path
from typing import Any, Dict, List, Union
import base64
import re
from types import SimpleNamespace
from PIL import Image

import requests
from beartype import beartype
from playwright.sync_api import CDPSession, Page
from openai import OpenAI

from browser_env import Action, Trajectory, StateInfo
from .helper_functions import (
    clean_answer,
    clean_url,
)
from .helper_functions import encode_image as _encode_image_util

class Evaluator:
    """Common interface shared by every evaluator in this module.

    Subclasses override ``__call__`` to turn a finished trajectory into a
    score. The two static helpers pull the trailing entries out of a
    trajectory and validate that they are dicts.
    """

    def __init__(self, eval_tag: str = "") -> None:
        self.eval_tag = eval_tag

    def __call__(
        self,
        trajectory: Trajectory,
        config_file: Path | str,
        page: Page,
        client: CDPSession,
    ) -> float:
        """Score ``trajectory``; the base class has no implementation."""
        raise NotImplementedError

    @staticmethod
    def get_last_action(trajectory: Trajectory) -> Action:
        """Return the final trajectory element, which must be a dict.

        Raises:
            ValueError: if the trajectory is empty or its last element is
                not a dict.
        """
        final_item = trajectory[-1] if trajectory else None
        if isinstance(final_item, dict):
            return final_item
        raise ValueError("The last element of trajectory should be an action")

    @staticmethod
    def get_last_state(trajectory: Trajectory) -> StateInfo:
        """Return the second-to-last trajectory element, which must be a dict.

        Raises:
            ValueError: if the trajectory has fewer than two elements or the
                second-to-last element is not a dict.
        """
        if len(trajectory) >= 2:
            penultimate = trajectory[-2]
            if isinstance(penultimate, dict):
                return penultimate
        raise ValueError("The second last element of trajectory should be a state")
    

class LLMJudgeEvaluator(Evaluator):
    """
    LLM-based evaluator:
    1) Extract Key Points from the intent
    2) Score candidate screenshots, select high-scoring screenshots and their reasoning
    3) Summarize action history + Key Points + important screenshots, determine success/failure

    Args:
        vllm_client: object exposing ``chat(messages=..., stream=False)`` whose
            result carries the reply text under ``content`` (attribute or key).
        score_threshold: minimum per-image score for a screenshot to be kept.
        max_images: upper bound on screenshots gathered and kept.
    """

    # Tolerate markdown emphasis ("**Score**:") and a bracketed digit ("[4]"):
    # the system prompt literally says "Score: [1-5]", so both forms show up.
    _SCORE_RE = re.compile(r"Score\**\s*:\s*\**\s*\[?\s*([1-5])\b", re.IGNORECASE)
    # Capture the whole (possibly multi-line) reasoning, stopping at the
    # "Score:" line when there is one.
    _REASONING_RE = re.compile(
        r"Reasoning\**\s*:\**\s*(.+?)(?=\n\s*\**\s*Score\**\s*:|\Z)",
        re.IGNORECASE | re.DOTALL,
    )
    _STATUS_RE = re.compile(r'Status\**\s*:\s*\**\s*"?\s*(success|failure)', re.IGNORECASE)

    def __init__(self, vllm_client=None, score_threshold: int = 4, max_images: int = 50):
        super().__init__()
        self.vllm_client = vllm_client
        self.score_threshold = int(score_threshold)
        self.max_images = int(max_images)

    # ------------------------ helpers ------------------------
    def _llm_chat_text(self, messages: List[Dict[str, Any]]) -> str:
        """Send ``messages`` to the client and return the reply text.

        Works for both dict-shaped and object-shaped responses; a missing or
        falsy ``content`` yields an empty string.
        """
        resp = self.vllm_client.chat(messages=messages, stream=False)
        if isinstance(resp, dict):
            return resp.get("content", "") or ""
        return getattr(resp, "content", "") or ""

    @staticmethod
    def _shorten(value: Any, limit: int = 80) -> str:
        """Stringify ``value`` and cut it to ``limit`` characters with a trailing ellipsis."""
        text = str(value)
        return text if len(text) <= limit else text[: limit - 3] + "..."

    def _encode_image_from_path(self, path: str) -> str:
        """Encode the image file at ``path`` as base64.

        Prefers the project's ``encode_image`` helper on a PIL image; if that
        fails for any reason, falls back to base64 of the raw file bytes.

        Raises:
            OSError: if the fallback cannot open ``path``.
        """
        path = str(path)
        if _encode_image_util and Image is not None:
            try:
                # Context manager so the underlying file handle is released.
                with Image.open(path) as img:
                    return _encode_image_util(img)
            except Exception:
                pass  # best effort: fall through to the raw-bytes encoding
        # fallback
        with open(path, "rb") as f:
            return base64.b64encode(f.read()).decode("utf-8")

    def _encode_image_any(self, obj: Any) -> str | None:
        """Convert path/bytes/base64/data-URI/PIL.Image to a base64 string (without data: prefix).

        Returns ``None`` when ``obj`` is ``None``, of an unsupported type, a
        string that is neither an existing file nor valid base64, or when any
        step fails.
        """
        try:
            if obj is None:
                return None
            if isinstance(obj, Path):
                obj = str(obj)
            if isinstance(obj, str):
                # If already data:image/...;base64,xxx
                if obj.startswith("data:image/"):
                    return obj.split("base64,", 1)[-1]
                # Check the filesystem before guessing "pure base64": a path
                # made only of base64-alphabet characters (letters, digits,
                # "/", "+") would otherwise be returned as if it were image data.
                if os.path.isfile(obj):
                    return self._encode_image_from_path(obj)
                try:
                    base64.b64decode(obj, validate=True)
                except Exception:
                    return None
                return obj
            if isinstance(obj, (bytes, bytearray)):
                return base64.b64encode(bytes(obj)).decode("utf-8")
            if Image is not None and isinstance(obj, Image.Image):
                import io as _io
                buf = _io.BytesIO()
                obj.save(buf, format="PNG")
                return base64.b64encode(buf.getvalue()).decode("utf-8")
        except Exception:
            return None
        return None

    def _extract_actions_text(self, trajectory: Trajectory) -> List[str]:
        """Render every action-like trajectory item as a one-line summary.

        An item counts as an action when it is a dict with a ``name``, ``op``
        or ``action`` key. Values are truncated to 80 characters.
        """
        actions: List[str] = []
        for item in trajectory:
            # Experience: odd/last items are usually actions; but to be safe, check fields
            if not (isinstance(item, dict) and any(k in item for k in ("name", "op", "action"))):
                continue
            name = item.get("name") or item.get("op") or item.get("action") or "action"
            args = item.get("arguments") or item.get("args") or {}
            if isinstance(args, dict):
                # Extract a few key fields
                brief = [
                    f'{k}="{self._shorten(args[k])}"'
                    for k in ("text", "description", "field_description", "answer")
                    if args.get(k)
                ]
                if not brief:
                    # Show at least one key
                    for k, v in args.items():
                        if v:
                            brief = [f'{k}="{self._shorten(v)}"']
                            break
            else:
                # Non-dict arguments cannot be indexed by field name; show
                # them verbatim instead of raising on ``args[k]``.
                brief = [f'args="{self._shorten(args)}"']
            actions.append(f"{name}: " + (", ".join(brief) if brief else "(no-args)"))
        return actions

    def _gather_images_base64(self, configs: Dict[str, Any], trajectory: Trajectory, page: Page | None) -> List[str]:
        """Collect candidate screenshots as base64 strings.

        Sources, in order: ``configs['eval']['image_paths']``, image-like
        fields of dict items in ``trajectory``, and — only if nothing was
        found — a screenshot of ``page``. The result is de-duplicated and
        capped at ``max(1, self.max_images)``.
        """
        b64s: List[str] = []

        # 1) Paths explicitly provided in config
        eval_cfg = configs.get("eval") if isinstance(configs, dict) else None
        image_paths = eval_cfg.get("image_paths") if isinstance(eval_cfg, dict) else None
        if isinstance(image_paths, (str, Path)):
            image_paths = [image_paths]
        for p in image_paths or []:
            # Guard each path on its own so one unreadable file does not
            # discard the remaining ones.
            try:
                b = self._encode_image_any(p) or self._encode_image_from_path(p)
            except Exception:
                continue
            if b:
                b64s.append(b)

        # 2) Extract from states in trajectory
        for item in trajectory:
            if not isinstance(item, dict):
                # Only dicts can carry the named fields below.
                continue
            # Common field name guesses
            for k in ("screenshot_path", "image_path", "screenshot_b64", "screenshot", "image"):
                if k in item and item[k]:
                    b = self._encode_image_any(item[k])
                    if b:
                        b64s.append(b)

        # 3) Fallback: capture current page
        if not b64s and page is not None:
            try:
                png_bytes = page.screenshot(type="png", full_page=False)
                b64s.append(base64.b64encode(png_bytes).decode("utf-8"))
            except Exception:
                pass  # best effort: judging proceeds without images

        # Deduplicate (order-preserving) & limit quantity
        uniq = list(dict.fromkeys(b64s))
        return uniq[: max(1, self.max_images)]

    # ------------------ key steps with LLM -------------------
    def _identify_key_points(self, task: str) -> str:
        """Ask the LLM for the task's key points.

        Returns the text after a ``Key Points:`` / ``**Key Points**:`` marker
        when the reply has one, otherwise the whole stripped reply.
        """
        system_msg = (
            "You are an expert tasked with analyzing a given task to identify the key points explicitly stated in the task description.\n\n"
            "Objective: extract critical elements explicitly mentioned in the task.\n"
            "Important: do not infer unstated info. If there are comparatives like 'cheapest/highest/latest/...', the key point should indicate using sort/filter."
        )
        user_text = f"Task: {task}"
        messages = [
            {"role": "system", "content": system_msg},
            {"role": "user", "content": [{"type": "text", "text": user_text}]},
        ]
        reply = self._llm_chat_text(messages)
        # Try to keep only the part after "Key Points:" (if exists)
        m = re.split(r"\*\*Key Points\*\*:\s*|Key Points:\s*", reply, maxsplit=1, flags=re.IGNORECASE)
        key_points = m[-1].strip() if len(m) > 1 else reply.strip()
        return key_points

    def _score_image(self, task: str, key_points: str, image_b64: str) -> tuple[int, str]:
        """
        Returns (score[1..5] or 0, reasoning_text)

        The score is 0 and/or the reasoning is empty when the corresponding
        line cannot be parsed from the reply.
        """
        system_msg = (
            "You are an expert evaluator tasked with determining whether an image contains information about the necessary steps to complete a task.\n"
            "Respond with two lines:\n"
            "Reasoning: <your reasoning>\n"
            "Score: [1-5]"
        )
        prompt = (
            f"**Task**: {task}\n\n"
            f"**Key Points for Task Completion**:\n{key_points}\n\n"
            f"The snapshot of the web page is shown in the image."
        )
        messages = [
            {"role": "system", "content": system_msg},
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": prompt},
                    {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{image_b64}", "detail": "high"}},
                ],
            },
        ]
        reply = self._llm_chat_text(messages)
        # Parse score and reasoning
        m_r = self._REASONING_RE.search(reply)
        reasoning = m_r.group(1).strip() if m_r else ""
        m_s = self._SCORE_RE.search(reply)
        score = int(m_s.group(1)) if m_s else 0
        return score, reasoning

    def _final_judgement(self, task: str, key_points: str, last_actions: List[str],
                         selected_images: List[str], thoughts: List[str]) -> bool:
        """
        Returns True for success, False for failure

        ``thoughts`` is listed in the prompt only when ``selected_images`` is
        non-empty. A reply without a parsable ``Status`` line counts as failure.
        """
        system_msg = (
            "You are an expert in evaluating a web navigation agent.\n"
            "Strictly output two lines:\n"
            'Thoughts: <your reasoning>\n'
            'Status: "success" or "failure"'
        )
        if selected_images:
            prompt = (
                f"User Task: {task}\n\n"
                f"Key Points:\n{key_points}\n\n"
                "Action History:\n" + "\n".join(f"{i+1}. {a}" for i, a in enumerate(last_actions)) + "\n\n"
                "The potentially important snapshots of the webpage in the agent's trajectory and their reasons:\n" +
                "\n".join(f"{i+1}. {t}" for i, t in enumerate(thoughts))
            )
        else:
            prompt = (
                f"User Task: {task}\n\n"
                f"Key Points:\n{key_points}\n\n"
                "Action History:\n" + "\n".join(f"{i+1}. {a}" for i, a in enumerate(last_actions))
            )

        content = [{"type": "text", "text": prompt}]
        for b in selected_images:
            content.append({"type": "image_url", "image_url": {"url": f"data:image/png;base64,{b}", "detail": "high"}})

        messages = [
            {"role": "system", "content": system_msg},
            {"role": "user", "content": content},
        ]
        reply = self._llm_chat_text(messages)

        m = self._STATUS_RE.search(reply)
        if not m:
            # If parsing fails, judge as failure
            return False
        return m.group(1).lower() == "success"

    # ------------------------- main --------------------------
    def __call__(
        self,
        trajectory: Trajectory,
        config_file: Path | str,
        page: Page | None = None,
        client: CDPSession | None = None,
    ) -> float:
        """Judge the trajectory with the LLM and return 1.0 (success) or 0.0.

        Args:
            trajectory: items scanned for actions and image-like fields.
            config_file: JSON file; ``intent`` and ``eval.image_paths`` are read.
            page: optional page, screenshotted only when no other image exists.
            client: unused here; accepted for interface compatibility.

        Raises:
            ValueError: if no ``vllm_client`` was supplied.
        """
        if self.vllm_client is None:
            raise ValueError("LLMJudgeEvaluator requires `vllm_client` (an object with .chat(messages, stream=False)).")

        with open(config_file, "r", encoding="utf-8") as f:
            configs = json.load(f)

        task_intent = configs.get("intent", "") or ""
        # 1) Key Points
        key_points = self._identify_key_points(task_intent)

        # 2) Action history text
        last_actions = self._extract_actions_text(trajectory)

        # 3) Collect images and score them
        all_b64 = self._gather_images_base64(configs, trajectory, page)
        selected_images: List[str] = []
        thoughts: List[str] = []
        for b in all_b64:
            score, reasoning = self._score_image(task_intent, key_points, b)
            if score >= self.score_threshold:
                selected_images.append(b)
                # Always add an entry so the numbered reasons in the final
                # prompt stay aligned with the order of the attached images.
                thoughts.append(reasoning or "(no reasoning provided)")
            if len(selected_images) >= self.max_images:
                break

        # 4) Final determination
        is_success = self._final_judgement(task_intent, key_points, last_actions, selected_images, thoughts)
        return 1.0 if is_success else 0.0



class StringEvaluator(Evaluator):
    """Check whether the answer is correct with exact match, must include, and fuzzy match"""
    
    def __init__(self, vllm_client=None):
        super().__init__()
        self.vllm_client = vllm_client

    def __call__(
        self,
        trajectory: Trajectory,
        config_file: Path | str,
        page: Page | None = None,
        client: CDPSession | None = None,
    ) -> float:
        with open(config_file, "r") as f:
            configs = json.load(f)

        # Use helper function for cleaning answers

        last_action = self.get_last_action(trajectory)
        pred = clean_answer(last_action.get("answer", ""))
        
        score = 1.0
        for approach, value in configs["eval"]["reference_answers"].items():
            match approach:
                case "exact_match":
                    assert isinstance(value, str)
                    ref_answer = clean_answer(value)
                    score = score * (pred == ref_answer)
                case "must_include":
                    url = page.url
                    pred += str(url)
                    assert isinstance(value, list)
                    for must_value in value:
                        must_value = clean_answer(must_value)
                        score = score * (must_value in pred)
                case "fuzzy_match":
                    intent = configs.get("intent", "")
                    assert isinstance(value, list)
                    for reference in value:
                        fuzzy_score = self._llm_fuzzy_match(pred, reference, intent)
                        score = score * fuzzy_score
        return score

    def _llm_fuzzy_match(self, pred: str, reference: str, question: str) -> float:
        """Use vLLM with Qwen2.5-VL-Instruct for binary yes/no matching"""
        try:
            
            # Create the prompt for binary matching
            prompt = f"""You are an evaluator that determines if a predicted answer is correct for a given question.

Question: {question}
Reference Answer: {reference}
Predicted Answer: {pred}

Please evaluate if the predicted answer is correct. Consider:
1. Semantic similarity to the reference answer
2. Key information overlap
3. Factual accuracy
Note: If the predicted answer is like "I'm sorry, I can't answer that question.", you should return "no".

Respond with only "yes" or "no":
- "yes": if the predicted answer is correct or equivalent to the reference answer
- "no": if the predicted answer is incorrect, incomplete, or irrelevant"""

            # Call the model
            response = self.vllm_client.chat(
                messages=[
                    {"role": "system", "content": "You are a helpful evaluator that provides binary yes/no responses."},
                    {"role": "user", "content": prompt}
                ],
                temperature=0.8,
                max_tokens=5
            )
            
            # Extract the response
            answer_text = response.content.strip().lower()
            
            # Parse the yes/no response
            if 'yes' in answer_text:
                return 1.0
            elif 'no' in answer_text:
                return 0.0
            else:
                print(f"Could not parse yes/no from response: '{answer_text}'")
                # Fallback: return 0.0 if parsing fails
                return 0.0
                
        except Exception as e:
            print(f"Error in fuzzy matching: {e}")
            # Fallback: return 0.0 if there's an error
            return 0.0
        


class URLExactEvaluator(Evaluator):
    """Check whether the URL is exactly the same as the reference URLs"""
    
    def __call__(
        self,
        trajectory: Trajectory,
        config_file: Path | str,
        page: Page,
        client: CDPSession | None = None,
    ) -> float:
        with open(config_file, "r") as f:
            configs = json.load(f)

        # Use helper function for cleaning URLs

        pred = clean_url(page.url)
        ref_urls = configs["eval"]["reference_url"].split(" |OR| ")
        ref_urls = [clean_url(url) for url in ref_urls]
        matching_rule = configs["eval"].get("url_note", "EXACT")
        
        if matching_rule == "EXACT":
            if pred in ref_urls:
                return 1.0
            else:
                return 0.0
        elif matching_rule == "GOLD in PRED":
            if any([ref in pred for ref in ref_urls]):
                return 1.0
            else:
                return 0.0
        else:
            raise ValueError(f"Unknown matching rule: {matching_rule}")


class HTMLContentEvaluator(Evaluator):
    """Check whether the contents appear in the page"""

    def __call__(
        self,
        trajectory: Trajectory,
        config_file: Path | str,
        page: Page,
        client: CDPSession | None = None,
    ) -> float:
        """Score 1.0 only if every ``eval.program_html`` target is satisfied.

        Each target has:
          * ``url`` — ``"last"`` (stay on the current page), a URL to visit, or
            ``"func:<python expression>"`` that is evaluated to produce the
            URL (``__last_url__`` is replaced by the URL the page had when
            evaluation started);
          * ``locator`` — empty for the whole page HTML, or a ``document.…``
            JavaScript expression evaluated in the page;
          * ``required_contents`` — `` |OR| ``-separated alternatives, matched
            case-insensitively as substrings.

        Raises:
            ValueError: if a locator is neither empty nor a ``document.``
                expression.
        """
        def clean(text: str) -> str:
            return str(text).strip().lower()

        with open(config_file, "r", encoding="utf-8") as f:
            configs = json.load(f)

        targets = configs["eval"]["program_html"]
        score = 1.0
        # Captured before any navigation below: ``page.url`` changes once an
        # earlier target calls ``page.goto``, which would make later
        # ``__last_url__`` substitutions point at the wrong page.
        last_url: str | None = None

        for target in targets:
            if last_url is None:
                last_url = page.url

            target_url: str = target["url"]
            if target_url.startswith("func"):
                # maxsplit=1 keeps an expression that itself contains "func:" intact.
                func = target_url.split("func:", 1)[1]
                func = func.replace("__last_url__", last_url)
                # SECURITY: eval() executes arbitrary Python taken from the
                # config file (with the page URL spliced in). Only use with
                # trusted task configs.
                # NOTE(review): the callables such expressions reference must
                # be resolvable from this module's scope — confirm they are.
                target_url = eval(func)

            required_contents: str = target["required_contents"]
            locator: str = target["locator"]

            # Navigate to that URL
            if target_url != "last":
                page.goto(target_url)
                time.sleep(2)  # Wait for page to load

            # Get the element content
            if not locator.strip():
                selected_element = page.content()
            elif locator.startswith("document."):
                try:
                    selected_element = page.evaluate(f"() => {locator}")
                    if not selected_element:
                        selected_element = ""
                    selected_element = str(selected_element)
                except Exception:
                    # A failing locator counts as "no content found".
                    selected_element = ""
            else:
                raise ValueError(f"Unknown locator: {locator}")

            required_contents_or = [
                clean(x) for x in required_contents.split(" |OR| ")
            ]
            selected_element = clean(selected_element)
            score *= any(
                content in selected_element
                for content in required_contents_or
            )

        return score


class EvaluatorComb:
    """Run several evaluators and multiply their scores together."""

    def __init__(self, evaluators: List[Evaluator]) -> None:
        self.evaluators = evaluators

    def __call__(
        self,
        trajectory: Trajectory,
        config_file: Path | str,
        page: Page,
        client: CDPSession,
    ) -> float:
        """Return the product of all evaluator scores (1.0 when there are none).

        Every evaluator is invoked, in order, with the same four arguments.
        """
        product = 1.0
        for judge in self.evaluators:
            product = product * judge(trajectory, config_file, page, client)
        return product


def evaluator_router(config_file: Path | str, vllm_client=None) -> EvaluatorComb:
    """Router to get the evaluator class based on config file"""
    
    with open(config_file, "r") as f:
        configs = json.load(f)
    try:
        eval_types = configs["eval"]["eval_types"]
    except:
        eval_types = ["LLM_eval"]
    evaluators: List[Evaluator] = []
    
    for eval_type in eval_types:
        match eval_type:
            case "string_match":
                evaluators.append(StringEvaluator(vllm_client))
            case "url_match":
                evaluators.append(URLExactEvaluator())
            case "program_html":
                evaluators.append(HTMLContentEvaluator())
            case "LLM_eval":
                evaluators.append(LLMJudgeEvaluator(vllm_client=vllm_client, score_threshold=4, max_images=50))
            case _:
                raise ValueError(f"eval_type {eval_type} is not supported")

    return EvaluatorComb(evaluators) 