import glob
import hashlib
import json
import os
import shutil
import subprocess
from pathlib import Path
from typing import Any, Callable, List

import numpy as np  # pip install numpy if you don’t already have it
from matplotlib import pyplot as plt

from utils.debug_utils import set_env_variables

# NOTE(review): deliberately called before the imports below, presumably so that the
# environment variables it sets are in place before those modules are imported — confirm.
set_env_variables()

import re
import tempfile
import time

from PIL import Image

from browser_env import Action, ActionTypes, Trajectory
from browser_env.actions import is_equivalent
from browser_env.auto_login import get_site_comb_from_filepath, is_expired_for_sites
from utils.image_utils import get_image_from_url


# ===============================================================================
# Visualization aux functions
# ===============================================================================
def show_obs_img(trajectory: Trajectory):
    """Display the screenshot held by the last observation in ``trajectory``.

    Opens a 12x12-inch matplotlib figure without axes and shows it non-blocking,
    so the caller keeps running while the window is open.
    """
    fig_size = (12, 12)
    plt.figure(figsize=fig_size)
    latest_observation = trajectory[-1]["observation"]
    plt.imshow(latest_observation["image"])
    plt.axis("off")
    plt.show(block=False)


# ===============================================================================
# Data loaders
# ===============================================================================
def load_task_config(config_file: str):
    """Load a JSON task configuration together with its intent images.

    Args:
        config_file: Path to a JSON task config containing at least "intent" and
            "task_id", plus an optional "image" entry (a single path/URL string or a
            list of them).

    Returns:
        Tuple ``(task_config, intent, task_id, intent_images)``. Image entries that
        start with "http" are fetched with ``get_image_from_url``; every other entry
        is opened from disk with ``Image.open``.
    """
    with open(config_file) as fp:
        task_config = json.load(fp)

    intent, task_id = task_config["intent"], task_config["task_id"]

    sources = task_config.get("image", None)
    if sources is None:
        sources = []
    elif isinstance(sources, str):
        # A lone path/URL is accepted as shorthand for a one-element list.
        sources = [sources]

    intent_images: List[Image.Image] = [
        get_image_from_url(src) if src.startswith("http") else Image.open(src) for src in sources
    ]
    return task_config, intent, task_id, intent_images


def get_tasks_with_trajectory(trajectory_html_path: str, test_config_base_dir: str, test_file_list: list[str] = []):
    html_files = glob.glob(os.path.join(trajectory_html_path, "*.html"))
    matches = [re.search(r"(\d+)\.html$", html_file) for html_file in html_files]
    task_ids = sorted([int(match.group(1)) for match in matches if match])
    test_files_with_trajectory = [os.path.join(test_config_base_dir, f"{i}.json") for i in task_ids]
    if test_file_list:
        test_file_list = [t for t in test_files_with_trajectory if t in test_file_list]
    else:
        test_file_list = test_files_with_trajectory
    return test_file_list


# ===============================================================================
# Helpers for evaluation loop
# ===============================================================================
def check_fuzzy_match(config=None):
    """Return True if the task's reference answers request a ``fuzzy_match`` evaluation.

    Args:
        config: Parsed task config with an "eval" section. ``None`` (the default) is
            treated as "no fuzzy match" instead of raising.

    Returns:
        True when ``config["eval"]["reference_answers"]`` is present, not None and
        contains "fuzzy_match"; False otherwise.
    """
    if config is None:
        return False
    reference_answers = config["eval"].get("reference_answers")
    return reference_answers is not None and "fuzzy_match" in reference_answers


def early_stop(trajectory: Trajectory, max_steps: int, thresholds: dict[str, int]) -> tuple[bool, str]:
    """Decide whether the episode should terminate before the agent's next step.

    Actions are read from the odd indices of ``trajectory`` (``trajectory[1::2]``).
    The stopping rules are checked in this order:
      1. ``(len(trajectory) - 1) / 2`` steps have been taken and that is >= ``max_steps``;
      2. the last ``thresholds["parsing_failure"]`` actions all have type ``ActionTypes.NONE``;
      3. for a non-typing latest action, the last ``thresholds["repeating_action"]``
         actions are all equivalent to it; for a typing latest action, at least that
         many equivalent actions appear anywhere in the action sequence.

    Returns:
        ``(should_stop, reason)``; ``reason`` is "" when ``should_stop`` is False.
    """
    steps_taken = (len(trajectory) - 1) / 2
    if steps_taken >= max_steps:
        return True, f"Reach max steps {max_steps}"

    actions: list[Action] = trajectory[1::2]  # type: ignore[assignment]

    # Rule 2: repeated parsing failures.
    fail_k = thresholds["parsing_failure"]
    recent = actions[-fail_k:]
    if len(recent) >= fail_k and all([a["action_type"] == ActionTypes.NONE for a in recent]):
        return True, f"Failed to parse actions for {fail_k} times"

    # Rule 3: the agent is repeating itself.
    repeat_k = thresholds["repeating_action"]
    recent = actions[-repeat_k:]
    if not actions:
        return False, ""

    latest: Action = actions[-1]
    if latest["action_type"] == ActionTypes.TYPE:
        # Typing is counted over the whole sequence, not just the trailing window.
        if sum([is_equivalent(a, latest) for a in actions]) >= repeat_k:
            return True, f"Same typing action for {repeat_k} times"
    elif len(recent) >= repeat_k and all([is_equivalent(a, latest) for a in recent]):
        return True, f"Same action for {repeat_k} times"
    return False, ""


def check_na_ref_answer(config_file):
    """Return True if any reference answer contains "N/A".

    Args:
        config_file: Parsed task config dict (despite the name, not a path) with an
            "eval" section. A missing or None ``reference_answers`` entry yields False,
            matching how ``check_fuzzy_match`` treats it.

    Returns:
        True if ``"N/A" in answer`` holds for any value of ``reference_answers``
        (a substring test for strings, an element test for lists).
    """
    reference_answers = config_file["eval"].get("reference_answers")
    if reference_answers is None:
        return False
    return any("N/A" in answer for answer in reference_answers.values())


# ===============================================================================
# Helpers for env setup, start, reset
# ===============================================================================
def get_expired_sites(sites_list: list[str] | str, auth_folder: str = "./.auth", exc_comb: bool = False):
    """Check which of the given sites have expired cookies in ``auth_folder``.

    A single site name may be passed as a plain string. Delegates to
    ``is_expired_for_sites`` and returns its ``(expired_sites, expired_cookies)`` pair.
    """
    sites = [sites_list] if isinstance(sites_list, str) else sites_list
    expired, cookies = is_expired_for_sites(sites, auth_folder, exc_comb)
    return expired, cookies


def auto_login(config_dict, config_file, auth_folder: str = "./.auth", out_dir: str = "", renew=True):
    """Prepare per-run login cookies and a config file that points at them.

    When ``config_dict["storage_state"]`` is set:
      * every regular file in ``auth_folder`` is copied into ``out_dir`` (a fresh
        temporary directory when ``out_dir`` is empty; created if missing);
      * if ``renew`` is True, cookies of the sites encoded in the cookie file name
        that ``is_expired_for_sites`` reports as expired are renewed by running
        ``browser_env.auto_login`` in a subprocess with ``--auth_folder out_dir``;
      * ``config_dict["storage_state"]`` is rewritten (in place) to the copy in
        ``out_dir`` and the dict is dumped to ``out_dir/<basename of config_file>``.

    Args:
        config_dict: Parsed task config; its "storage_state" entry is mutated.
        config_file: Path of the original config file; only its basename is reused.
        auth_folder: Directory holding the cached cookie files.
        out_dir: Destination directory for cookies and the rewritten config.
        renew: Whether to check for and renew expired cookies.

    Returns:
        Path of the config file to use: the rewritten copy inside ``out_dir`` when a
        storage state is configured, otherwise ``config_file`` unchanged.
    """
    import sys  # local import: only needed to locate the current interpreter

    if not config_dict["storage_state"]:
        # Nothing to log into; keep using the original config file.
        return config_file

    cookie_file_name = os.path.basename(config_dict["storage_state"])
    comb = get_site_comb_from_filepath(cookie_file_name)
    if not out_dir:
        out_dir = tempfile.mkdtemp()
    os.makedirs(out_dir, exist_ok=True)

    # Copy the cached cookies into out_dir (skipped when out_dir *is* auth_folder,
    # where shutil.copy would raise SameFileError). Subdirectories are skipped
    # because shutil.copy cannot copy them.
    if os.path.realpath(out_dir) != os.path.realpath(auth_folder):
        for file in os.listdir(auth_folder):
            src = os.path.join(auth_folder, file)
            if os.path.isfile(src):
                shutil.copy(src, os.path.join(out_dir, file))

    if renew:
        # Check which sites in `comb` have expired cookies.
        expired_sites, _expired_cookies = is_expired_for_sites(sites=comb, auth_folder=out_dir)
        if expired_sites:
            # Renew in a subprocess using the *current* interpreter so the same
            # environment/virtualenv is used; a bare "python" may resolve elsewhere.
            subprocess.run(
                [
                    sys.executable or "python",
                    "-m",
                    "browser_env.auto_login",
                    "--auth_folder",
                    out_dir,
                    "--site_list",
                    *expired_sites,
                ]
            )

    config_dict["storage_state"] = f"{out_dir}/{cookie_file_name}"
    assert os.path.exists(config_dict["storage_state"])

    # Write the updated config next to the copied cookies.
    new_config_file = f"{out_dir}/{os.path.basename(config_file)}"
    with open(new_config_file, "w") as f:
        json.dump(config_dict, f)
    return new_config_file


def create_autologin_cookies():
    """Generate login cookies by running the autologin shell script for ``local_vwebarena``.

    Raises:
        subprocess.CalledProcessError: if the script exits with a non-zero status.
    """
    script = "./scripts/environments/autologin_cookies.sh"
    print("[INFO] Creating autologin cookies...")
    subprocess.run([script, "local_vwebarena"], check=True)
    print("[INFO] Cookies created.")


# ===============================================================================
# Other helpers
# ===============================================================================
def get_domain_from_test_config_dir(test_config_dir: str) -> str:
    """Return the domain encoded in a test-config directory name.

    The last path component is expected to look like ``test_<domain>``
    (e.g. ``configs/test_shopping`` -> ``shopping``). Only a *leading* ``test_`` is
    stripped, so a domain that itself contains ``test_`` is kept intact; names
    without the prefix are returned unchanged.
    """
    return Path(test_config_dir).name.removeprefix("test_")


# ===============================================================================
# Browser page helpers (waiting for loads, spinners, DOM stability; network throttling)
# ===============================================================================


def wait_for_spinners(
    page: Any, selector: str, max_timeout_ms: float = 2 * 1000, wait_to_appear_ms: int = 0, state="detached"
) -> None:
    """Wait for a loading indicator matching ``selector`` to reach ``state``.

    If the element is currently present, block (timeout ``max_timeout_ms``) until it
    reaches ``state``. When ``wait_to_appear_ms`` is non-zero, additionally sleep that
    many milliseconds so a late spinner can show up, then repeat the same wait if it did.
    """

    def wait_if_present() -> None:
        if page.query_selector(selector):
            page.wait_for_selector(selector, state=state, timeout=max_timeout_ms)

    wait_if_present()
    if wait_to_appear_ms:
        time.sleep(wait_to_appear_ms / 1000)
        wait_if_present()


def get_dom_hash(page: Any) -> str:
    """Return the hex MD5 digest of ``page.content()`` (UTF-8 encoded).

    Used as a cheap fingerprint to detect whether the page's HTML changed.
    """
    digest = hashlib.md5()
    digest.update(page.content().encode("utf-8"))
    return digest.hexdigest()


def wait_for_dom_hash_stability(page: Any, interval_ms: float = 500, checks: int = 3) -> None:
    """Verify that the DOM hash does not change over ``checks`` sampling intervals.

    The hash is taken once on entry and then re-sampled ``checks`` times, sleeping
    ``interval_ms`` milliseconds before each sample. Despite the name, this does not
    keep waiting until the page settles: a single change anywhere in the window makes
    the call fail (callers such as ``wait_for_page_to_stabilize`` retry it through
    ``wait_with_timeouts``).

    Raises:
        Exception: if any sample differs from the previous hash.
    """
    baseline = get_dom_hash(page)
    unchanged_streak = 0

    for _ in range(checks):
        time.sleep(interval_ms / 1000)
        latest = get_dom_hash(page)
        if latest != baseline:
            baseline = latest
            unchanged_streak = 0
            continue
        unchanged_streak += 1

    if unchanged_streak != checks:
        raise Exception(f"DOM hash is not stable after {checks} checks")


def wait_with_timeouts(wait_func: Callable, timeouts: list[int], logger=None, description="operation") -> bool:
    """
    Attempts to perform a waiting operation using a series of timeouts.

    Each timeout is tried in order. The first call to ``wait_func`` that returns
    without raising counts as success; any exception counts as a failed attempt and
    the next timeout is tried.

    Args:
        wait_func: A callable that receives a timeout (in ms) and performs the waiting operation.
        timeouts: A list of timeout values (in ms) to try sequentially.
        logger: Optional logger for reporting status messages (debug per attempt,
            warning when every attempt failed).
        description: A textual description of the operation for logging purposes.

    Returns:
        bool: True if one attempt was successful, False if all attempts failed
        (including when ``timeouts`` is empty).
    """
    # NOTE: the previous version reassigned `logger = None` here, which silently
    # discarded the caller's logger; the parameter is now honoured.
    for t in timeouts:
        try:
            wait_func(t)
        except Exception as e:
            # Any failure (timeout or otherwise) just moves on to the next timeout.
            if logger is not None:
                logger.debug(f"{description} failed with timeout {t}ms: {e}")
            continue
        if logger is not None:
            logger.debug(f"{description} succeeded with timeout {t}ms")
        return True
    if logger is not None:
        logger.warning(f"All attempts for {description} failed. Tried timeouts: {timeouts}")
    return False


def wait_for_page_to_stabilize(
    page: Any,
    max_timeout_ms: float = 2 * 1000,
    logger: Any | None = None,
    min_num_trues: int = 3,
    return_early: bool = False,
    return_after: int | None = None,
    hard_sleep: float = 0.0,
) -> bool:
    """
    Wait for a page to fully stabilize using multiple waiting mechanisms.
    Stops early if min_num_trues checks pass.

    Args:
        page: Playwright page object (sync API)
        max_timeout_ms: Maximum timeout per wait function
        logger: Optional logger for reporting
        min_num_trues: Minimum number of successful checks to consider the page stable
        return_early: If True, return after min_num_trues checks pass.
        return_after: If not None, return after `return_after` successful checks; overrides return_early.
    Returns:
        bool: True if at least min_num_trues checks passed; False otherwise.
    """
    if hard_sleep > 0.0:
        time.sleep(hard_sleep)
        return True
    successful_checks = 0

    checks = [
        # Check 1: DOMContentLoaded
        # --------------------------
        # Wait for the DOMContentLoaded event which fires after the initial HTML
        # is loaded and parsed. This indicates that the core DOM structure is available.
        (
            lambda t: page.wait_for_load_state("domcontentloaded", timeout=min(t, max_timeout_ms)),
            [500, 500, 500],  # Wait for a max of 1500ms
            "DOMContentLoaded",
        ),
        # Check 2: Document Ready State
        # -----------------------------
        # Wait until document.readyState is 'complete', ensuring that all sub-resources,
        # such as images and stylesheets, have been fully loaded.
        (
            lambda t: page.wait_for_function("document.readyState === 'complete'", timeout=min(t, max_timeout_ms)),
            [500, 500, 500],
            "document ready state",
        ),
        # Check 3: Network Idle
        # ----------------------
        # Wait for the network to become idle. This check ensures that asynchronous
        # background tasks (e.g., API calls or lazy-loaded resources) have mostly finished.
        (
            lambda t: page.wait_for_load_state("networkidle", timeout=min(t, max_timeout_ms)),
            [500, 500, 500, 500, 500],  # Wait for a total of 2500ms
            "network idle",
        ),
        # Check 4: DOM Stability
        # -----------------------
        # Utilizes a MutationObserver to monitor changes in the DOM.
        # If no DOM mutations occur for 300ms, it is assumed that the page has stabilized.
        (
            lambda t: page.wait_for_function(
                """() => {
                    return new Promise(resolve => {
                        let timeout;
                        const observer = new MutationObserver(() => {
                            clearTimeout(timeout);
                            timeout = setTimeout(resolve, 300);
                        });
                        observer.observe(document.body, {
                            childList: true,
                            subtree: true,
                            attributes: true,
                            characterData: true
                        });
                        timeout = setTimeout(resolve, 300);
                    });
                }""",
                timeout=min(t, max_timeout_ms),
            ),
            [500, 500, 500],  # Wait for a total of 1500ms
            "DOM stability",
        ),
        (
            lambda t: wait_for_dom_hash_stability(page, interval_ms=t, checks=3),
            [50, 100],
            "DOM stability",
        ),
        # Check 5: Animation Frames
        # --------------------------
        # Uses two consecutive requestAnimationFrame calls to ensure that all rendering updates,
        # such as animations and transitions, have been flushed.
        (
            lambda t: page.wait_for_function(
                """() => new Promise(resolve => {
                    requestAnimationFrame(() => requestAnimationFrame(resolve));
                })""",
                timeout=min(t, max_timeout_ms),
            ),
            [200, 200, 100],  # Wait for a max of 500ms
            "animation frames",
        ),
        (
            lambda t: wait_for_spinners(
                page, selector="#checkout-loader", max_timeout_ms=min(t, max_timeout_ms), wait_to_appear_ms=5
            ),
            [500, 500, 500],  # Wait for a max of 2000ms
            "spinner disappearance",
        ),
    ]

    # Regularize params
    min_num_trues = min(min_num_trues, len(checks))

    if not return_after:
        # If return_after is not set, use return_early or default to len(checks)
        if return_early:
            return_after = min_num_trues
        else:
            return_after = len(checks)
    else:
        # If return_after is set, use it and ensure it's not greater than len(checks)
        return_after = min(return_after, len(checks))

    # emulate_slow_network(page)

    try:
        for wait_func, timeouts, description in checks:
            try:
                if wait_with_timeouts(wait_func, timeouts, logger=logger, description=description):
                    successful_checks += 1

                # Stop early if the minimum number of successful checks has been reached.
                if successful_checks >= return_after:
                    if logger:
                        logger.debug(f"Page stabilized early after {successful_checks} successful checks.")
                    return True
            except Exception as e:
                if logger:
                    logger.warning(f"Error in timeout check {description}: {e}")
                continue

    except Exception as e:
        if logger:
            logger.warning(f"Error in wait_for_page_to_stabilize: {e}")

    if logger:
        logger.debug(f"Page stabilized with {successful_checks} successful checks.")
    return successful_checks >= min_num_trues


# useful for debugging purposes
def emulate_slow_network(page):
    """Throttle ``page``'s network through a Chrome DevTools Protocol session.

    Enables the CDP Network domain, then adds 500ms of latency and caps download and
    upload throughput at ``50 * 1024 / 8`` (CDP expects bytes/s, i.e. ~50 kbit/s).
    Navigations made afterwards on this page load under these throttled conditions.
    NOTE(review): Playwright CDP sessions are only available on Chromium — confirm
    the browser in use.
    """
    throttled_conditions = {
        "offline": False,
        "latency": 500,  # additional latency, ms
        "downloadThroughput": 50 * 1024 / 8,
        "uploadThroughput": 50 * 1024 / 8,
    }
    cdp_session = page.context.new_cdp_session(page)
    cdp_session.send("Network.enable")
    cdp_session.send("Network.emulateNetworkConditions", throttled_conditions)
