"""base class for evaluation"""
import asyncio
import collections
import html
import inspect
import time
import urllib
import urllib.parse
from typing import Any

from playwright.sync_api import CDPSession
from playwright.sync_api import Page
from termcolor import colored

from Agent_E.ae.utils.logger import logger
from Agent_E.test.test_utils import clean_answer
from Agent_E.test.test_utils import evaluate_exact_match
from Agent_E.test.test_utils import evaluate_fuzzy_match
from Agent_E.test.test_utils import evaluate_must_include
from Agent_E.test.test_utils import evaluate_ua_match


class Evaluator:
    """Common parent of every evaluation strategy in this module.

    Subclasses provide the actual scoring logic by overriding ``__call__``.

    Attributes:
        eval_tag (str): Free-form label attached to the evaluator instance.
    """

    def __init__(self, eval_tag: str = "") -> None:
        """Remember the (optional) tag given to this evaluator."""
        self.eval_tag: str = eval_tag

    async def __call__(self, task_config: dict[str, Any], page: Page, client: CDPSession, answer: str) -> dict[str, float|str]:
        """Evaluate a task; the base implementation only signals that it is missing.

        Parameters:
            task_config (dict[str, Any]): The task configuration.
            page (Page): The Playwright page object.
            client (CDPSession): The Chrome DevTools Protocol session object.
            answer (str): The answer to evaluate.

        Raises:
            NotImplementedError: Always; subclasses must override this method.
        """
        message = "This method should be overridden by subclasses."
        raise NotImplementedError(message)


class StringEvaluator(Evaluator):
    """Evaluates string-based answers using various matching criteria.

    Supports exact matches, some matches, fuzzy matching using LLM, and unachievable task matching.
    """

    async def __call__(
        self,
        task_config: dict[str, Any],
        page: Page | None = None,
        client: CDPSession | None = None,
        answer: str | None = None,

    ) -> dict[str, float|str]:
        last_action = answer or ""
        pred = clean_answer(last_action)

        score = 1.0
        for approach, value in task_config["eval"]["reference_answers"].items():

            match approach:
                case "exact_match":
                    logger.info(f"Evaluating exact_match for answer: Predicted: {pred} , Reference: {value}")
                    score *= evaluate_exact_match(ref=value, pred=pred)

                case "must_include":
                    logger.info(f"Evaluating must_include for answer: \"{answer}\" to see if it includes the expeced values: \"{value}\"\n")
                    assert isinstance(value, list)
                    for must_value in value: # type: ignore
                        score *= evaluate_must_include(
                            ref=must_value, # type: ignore
                            pred=pred,
                            tokenize=(len(value) == 1), # type: ignore
                        )
                case "some_matches":
                    min_required_matches = value.get("min_required", 1)
                    matches = sum(evaluate_must_include(ref=phrase, pred=pred, tokenize=False) for phrase in value["phrases"])
                    score *= float(matches >= min_required_matches)
                case "fuzzy_match":
                    logger.info(f"Evaluating fuzzy_match for answer: {answer}")
                    intent = task_config["intent"]
                    if value == "N/A":
                        # if the instruction only asks the model to generate N/A when encountering an unachievable task
                        # without more concrete reasons
                        score *= evaluate_exact_match(ref=value, pred=pred)
                        # if the instruction also asks the model to generate the reason why the task is unachievable
                        # this should be the default as it will prevent false positive N/A`
                        if score != 1:
                            score = 1.0 * evaluate_ua_match(
                                intent=task_config["intent"],
                                ref=task_config["eval"]["string_note"],
                                pred=pred,
                            )
                    else:
                        logger.info(f"Evaluating generic for answer: {answer}")
                        assert isinstance(value, list)
                        for reference in value: # type: ignore
                            score *= evaluate_fuzzy_match(
                                ref=reference, pred=pred, intent=intent # type: ignore
                            )
                case _:
                    logger.info(f"Unknown approach value received: {approach}")
        return {"score": score}


class URLEvaluator(Evaluator):
    """Evaluates if the given URL matches the expected URL criteria defined in the configuration.

    This includes checking if the base path of the URL and its query parameters match those specified in the reference URLs.
    """

    async def __call__(
        self,
        task_config: dict[str, Any],
        page: Page,
        client: CDPSession | None = None,
        answer: str | None = None
    ) -> dict[str, float|str]:
        """Evaluates the current page URL against reference URLs specified in the config file.

        The reference string ``task_config["eval"]["reference_url"]`` may hold
        several alternatives separated by ``" |OR| "``. A reference matches when
        its ``netloc + path`` is a substring of the page's ``netloc + path`` and,
        for every query key of the reference, at least one of the reference
        values is present among the page's values for that key. All URLs are
        lower-cased and stripped of trailing slashes before comparison.

        Parameters:
            task_config (dict[str, Any]): The task configuration containing evaluation criteria.
            page (Page): The Playwright page object for the current webpage.
            client (CDPSession | None, optional): The Chrome DevTools Protocol session object. Not used in this evaluator.
            answer (str | None, optional): Not used in this evaluator.

        Returns:
            dict[str, float|str]: "score" 1.0 if the page URL matches any of the reference URLs, considering the matching rule; otherwise 0.0.

        Raises:
            ValueError: If an unknown matching rule is specified in the config file.
        """

        def clean_url(url: str) -> str:
            """Normalise a URL: stringify, drop trailing slashes, lower-case."""
            return str(url).rstrip("/").lower()

        def parse_url(url: str) -> tuple[str, dict[str, list[str]]]:
            """Split a URL into ``netloc + path`` and its parsed query mapping."""
            parsed_url = urllib.parse.urlparse(url)
            return parsed_url.netloc + parsed_url.path, urllib.parse.parse_qs(parsed_url.query)

        pred = clean_url(page.url)
        ref_urls = [clean_url(url) for url in task_config["eval"]["reference_url"].split(" |OR| ")]
        matching_rule = task_config["eval"].get("url_note", "GOLD in PRED")
        if matching_rule != "GOLD in PRED":
            raise ValueError(f"Unknown matching rule: {matching_rule}")

        # The predicted URL does not change between references: parse it once.
        pred_base_path, pred_query = parse_url(pred)

        for ref_url in ref_urls:
            ref_base_path, ref_query = parse_url(ref_url)

            # "GOLD in PRED": the reference base path must occur inside the predicted one.
            if ref_base_path not in pred_base_path:
                continue

            # parse_qs never yields empty value lists, so a reference key that is
            # missing from the predicted query can never be satisfied.
            query_matches = all(
                any(ref_value in pred_query.get(key, []) for ref_value in ref_values)
                for key, ref_values in ref_query.items()
            )
            if query_matches:
                return {"score": 1.0}

        return {"score": 0.0}


class HTMLContentEvaluator(Evaluator):
    """Evaluates if specified HTML content or elements appear on the webpage.

    This involves navigating to URLs specified in the configuration and checking for the presence of HTML elements or content using various strategies.
    """

    @staticmethod
    async def _resolve(value: Any) -> Any:
        """Return ``value``, awaiting it first when it is awaitable.

        Lets the evaluator drive page objects whose methods return coroutines
        as well as ones that return plain values.
        """
        if inspect.isawaitable(value):
            return await value
        return value

    async def __call__(
        self,
        task_config: dict[str, Any],
        page: Page,
        client: CDPSession | None = None,
        answer: str | None = None
    ) -> dict[str, float|str]:
        """Evaluates the presence of specified HTML content on the webpage.

        For every target in ``task_config["eval"]["program_html"]`` the method
        optionally navigates to the target URL, extracts text according to the
        target's locator (whole page, a JavaScript expression, or a ``func:``
        expression) and checks it against ``required_contents`` using either
        ``exact_match`` or ``must_include``. The per-check results are
        multiplied into the final score.

        Parameters:
            task_config (dict[str, Any]): The task configuration containing evaluation criteria.
            page (Page): The Playwright page object for the current webpage.
            client (CDPSession | None, optional): The Chrome DevTools Protocol session object. Not used in this evaluator.
            answer (str | None, optional): Not used in this evaluator.

        Returns:
            dict[str, float|str]: "score" A score between 0.0 and 1.0 representing the presence of required HTML content on the webpage.

        Raises:
            ValueError: If an unknown locator strategy or an unknown ``required_contents`` kind is specified in the config file.
            AssertionError: If the ``must_include`` requirement is not a list.
        """
        targets = task_config["eval"]["program_html"]

        score = 1.0
        for target in targets:
            target_url: str = target["url"]  # which url to check
            if target_url.startswith("func"):
                func = target_url.split("func:")[1]
                func = func.replace("__last_url__", page.url)
                # NOTE(security): eval() executes text taken from the task config;
                # only run this evaluator on trusted configuration files.
                target_url = eval(func)

            locator: str = target["locator"]  # js element locator

            # navigate to that url
            if target_url != "last":
                await self._resolve(page.goto(target_url))
                # Pause after navigation without blocking the event loop.
                await asyncio.sleep(3)

            # empty, use the full page
            if not locator.strip():
                selected_element = await self._resolve(page.content())
            # use JS to select the element
            elif locator.startswith(("document.", "[...document.", "jsblock:")):
                if "prep_actions" in target:
                    # Prep actions are best effort: a failure stops the remaining
                    # prep actions but must not abort the evaluation.
                    try:
                        for prep_action in target["prep_actions"]:
                            await self._resolve(page.evaluate(f"() => {prep_action}"))
                    except Exception as exc:
                        logger.info(f"Ignoring failed prep action: {exc}")
                try:
                    if locator.startswith("jsblock:"):
                        locator = locator.split("jsblock:")[1]

                    selected_element = str(await self._resolve(page.evaluate(f"() => {locator}")))
                except Exception:
                    # the page is wrong, return empty
                    selected_element = ""
            # run program to call API
            elif locator.startswith("func:"):  # a helper function
                func = locator.split("func:")[1]
                # The expression refers to the local variable ``page`` by name.
                func = func.replace("__page__", "page")
                # NOTE(security): eval() executes text taken from the task config;
                # only run this evaluator on trusted configuration files.
                selected_element = eval(func)
            else:
                raise ValueError(f"Unknown locator: {locator}")

            selected_element = html.unescape(selected_element)

            required = target["required_contents"]
            if "exact_match" in required:
                cur_score = evaluate_exact_match(
                    ref=required["exact_match"], pred=selected_element
                )
                score *= float(cur_score)
            elif "must_include" in required:
                required_contents = required["must_include"]
                assert isinstance(required_contents, list)
                for content in required_contents: # type: ignore
                    # Each entry may list alternatives separated by " |OR| ";
                    # one matching alternative is enough.
                    content_or = content.split(" |OR| ") # type: ignore
                    cur_score = any(
                        evaluate_must_include(
                            ref=option, # type: ignore
                            pred=selected_element,
                            tokenize=False,
                        )
                        for option in content_or # type: ignore
                    )
                    score *= float(cur_score)
            else:
                raise ValueError(
                    f"Unknown required_contents: {target['required_contents'].keys()}"
                )
        return {"score": score}

class ManualContentEvaluator(Evaluator):
    """Evaluation Route for Manual Evaluation."""

    async def __call__(
        self,
        task_config: dict[str, Any],
        page: Page,
        client: CDPSession | None = None,
        answer: str | None = None
    ) -> dict[str, float|str]:
        """Pauses execution to get a manual evaluation score from the user.

        Prints the task, the agent's answer and the reference answer to the
        terminal, then reads the annotator's verdict from standard input. The
        verdict is matched case-insensitively after surrounding whitespace is
        removed. For a fail or skip the annotator is also asked for a reason.

        Parameters:
            task_config (dict[str, Any]): The task configuration. Reads ``intent``, ``task_id``,
                ``task_index`` and ``["eval"]["reference_answers"]["manual_check"]`` (``answer`` and ``type``).
            page (Page): The Playwright page object for the current webpage. Not used in this evaluator.
            client (CDPSession | None, optional): The Chrome DevTools Protocol session object. Not used in this evaluator.
            answer (str | None, optional): The agent's answer shown to the annotator.

        Returns:
            dict[str, float|str]: "score" is 1.0 for pass, 0.0 for fail and -0.1 for skip. "reason" is present only for fail/skip.

        Raises:
            ValueError: If the annotator's response is not 'Pass', 'Fail' or 'Skip'.
        """
        task = task_config["intent"]
        manual_check = task_config["eval"]["reference_answers"]["manual_check"]
        reference_answer = manual_check["answer"]
        answer_type = manual_check["type"].strip().lower()
        task_id = str(task_config["task_id"])
        task_index = str(task_config["task_index"])

        print(colored("\n\n***************************\n", "green", attrs=["bold"]))
        print(colored("Task ID: ", "blue", attrs=["bold"]) + task_id + "\n")
        print(colored("Task Index: ", "blue", attrs=["bold"]) + task_index + "\n")
        print(colored("Task: ", "blue", attrs=["bold"]) + task + "\n")
        print(colored("Agent answer: ", "blue", attrs=["bold"]) + str(answer or "") + "\n")

        if answer_type == "possible":
            print(colored("Possible answer (reference): ", "yellow") + f"~~~{reference_answer}~~~")
        elif answer_type == "golden":
            # str() keeps non-string reference answers (e.g. numbers) printable,
            # matching the f-string formatting used for "possible" answers.
            print(colored("Golden answer (reference): ", "yellow") + str(reference_answer))

        user_response = input(colored("Annotate the task as Pass, Fail or Skip (please use Skip sparingly)? ", "magenta", attrs=["bold"]))

        verdict_scores = {"pass": 1.0, "fail": 0.0, "skip": -0.1}
        # Tolerate stray whitespace around the typed verdict.
        verdict = user_response.strip().lower()
        if verdict not in verdict_scores:
            print(colored(f"Received response: {user_response}", "red"))
            raise ValueError("Invalid user response. Please enter 'Pass', 'Fail' or 'Skip'.")

        score = verdict_scores[verdict]
        eval_response: dict[str, float|str] = {"score": score}

        # Only failed or skipped tasks need an explanation.
        if score <= 0:
            eval_response["reason"] = input("Reason for rating: ")

        return eval_response

class EvaluatorComb(Evaluator):
    """Combines multiple evaluators to perform a comprehensive evaluation based on different criteria.

    Attributes:
        evaluators (list[Evaluator]): A list of evaluator instances to be used for evaluation.
    """

    def __init__(self, evaluators: list[Evaluator]) -> None:
        """Initializes the composite evaluator with a list of individual evaluators.

        Parameters:
            evaluators (list[Evaluator]): The list of evaluators to include in the composite evaluation.
        """
        # Run the base initializer so the composite also carries ``eval_tag``.
        super().__init__()
        self.evaluators = evaluators


    async def __call__(
        self,
        task_config: dict[str, Any],
        page: Page,
        client: CDPSession,
        answer: str,
    ) -> dict[str, float|str]:
        """Performs the evaluation using all included evaluators and aggregates their scores.

        The evaluators are awaited one after another in list order. Their
        scores are multiplied together (starting from 1.0) and any reasons
        they report are joined with newlines.

        Parameters:
            task_config (dict[str, Any]): The task configuration containing evaluation criteria.
            page (Page): The Playwright page object for the current webpage.
            client (CDPSession): The Chrome DevTools Protocol session object.
            answer (str): The answer or content to be evaluated.

        Returns:
            dict[str, float|str]: "score" - The product of the scores from all evaluators. "reason" - The reasons reported by the evaluators joined by newlines, or None when no evaluator reported one.
        """
        score: float = 1.0
        reasons: list[str] = []
        for evaluator in self.evaluators:
            eval_result = await evaluator(task_config, page, client, answer)
            score *= eval_result["score"] # type: ignore
            if "reason" in eval_result:
                reasons.append(str(eval_result["reason"]))
        reason = "\n".join(reasons) if reasons else None
        return {"score": score, "reason": reason} # type: ignore


def evaluator_router(task_config: dict[str, Any]) -> EvaluatorComb:
    """Build the composite evaluator requested by a task configuration.

    Walks ``task_config["eval"]["eval_types"]`` in order and instantiates one
    evaluator per entry.

    Parameters:
        task_config dict[str, Any]: configuration specifying the evaluation types to use.

    Returns:
        EvaluatorComb: A composite wrapping the instantiated evaluators, in the order they were requested.

    Raises:
        ValueError: If an unsupported evaluation type is specified in the configuration file.
    """
    requested_types = task_config["eval"]["eval_types"]
    selected: list[Evaluator] = []

    for requested in requested_types:
        if requested == "string_match":
            logger.info("Adding string evaluator")
            selected.append(StringEvaluator())
        elif requested == "url_match":
            logger.info("Adding URL evaluator")
            selected.append(URLEvaluator())
        elif requested == "program_html":
            logger.info("Adding HTML evaluator")
            selected.append(HTMLContentEvaluator())
        elif requested == "manual":
            logger.info("Adding manual evaluator")
            selected.append(ManualContentEvaluator())
        else:
            raise ValueError(f"eval_type {requested} is not supported")

    return EvaluatorComb(selected)
