import base64
import os
import re
import tempfile
from dataclasses import dataclass
from io import BytesIO
from typing import Any, Dict, TypedDict, Union

import numpy as np
import numpy.typing as npt
import requests
from beartype import beartype
from PIL import Image

from browser_env.env_config import LOCAL_URLS_NORM, URL_MAPPINGS


@dataclass
class DetachedPage:
    """A page reduced to plain data: its address and its HTML markup."""

    url: str  # address the page was loaded from
    content: str  # the page's HTML source


@beartype
def png_bytes_to_numpy(png: bytes) -> npt.NDArray[np.uint8]:
    """Decode in-memory PNG bytes into a NumPy pixel array.

    The bytes are wrapped in an in-memory file, opened with PIL and the
    resulting image is converted with ``np.array``. The return annotation is
    enforced at call time by ``beartype``.

    Example (requires plotly imported as ``go`` and matplotlib as ``plt``):

    >>> fig = go.Figure(go.Scatter(x=[1], y=[1]))  # doctest: +SKIP
    >>> plt.imshow(png_bytes_to_numpy(fig.to_image('png')))  # doctest: +SKIP
    """
    buffer = BytesIO(png)
    image = Image.open(buffer)
    pixels = np.array(image)
    return pixels


# REVIEW[mandrade]: added a browser-like User-Agent header to bypass blocking of automated requests
def get_image_from_url(url: str, headers: dict | None = None, timeout: int = 60) -> Image.Image:
    """Download an image over HTTP(S) and open it with PIL.

    Args:
        url: Address of the image.
        headers: Request headers. When ``None`` or empty, a desktop-browser
            ``User-Agent`` is sent instead so that sites which block
            automated clients still serve the image.
        timeout: Seconds to wait for the server, forwarded to ``requests.get``.

    Returns:
        The opened PIL image.
    """
    if not headers:
        headers = {
            "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/58.0.3029.110 Safari/537.3"
        }
    # Read the whole body via ``content`` rather than handing PIL the raw socket
    # stream: ``response.raw`` does not undo gzip/deflate content encoding, and
    # reading the body fully lets requests release the connection.
    response = requests.get(url, headers=headers, timeout=timeout)
    return Image.open(BytesIO(response.content))


class DOMNode(TypedDict):
    nodeId: str
    nodeType: str
    nodeName: str
    nodeValue: str
    attributes: str
    backendNodeId: str
    parentId: str
    childIds: list[str]
    cursor: int
    union_bound: list[float] | None
    center: list[float] | None


class AccessibilityTreeNode(TypedDict):
    nodeId: str
    ignored: bool
    role: dict[str, Any]
    chromeRole: dict[str, Any]
    name: dict[str, Any]
    properties: list[dict[str, Any]]
    childIds: list[str]
    parentId: str
    backendDOMNodeId: int
    frameId: str
    bound: list[float] | None
    union_bound: list[float] | None
    offsetrect_bound: list[float] | None
    center: list[float] | None


class BrowserConfig(TypedDict):
    """Window geometry values (presumably the viewport in CSS pixels — confirm)."""

    win_upper_bound: float
    win_left_bound: float
    win_width: float
    win_height: float
    win_right_bound: float
    win_lower_bound: float
    device_pixel_ratio: float


class BrowserInfo(TypedDict):
    """Pairing of a DOM payload with the browser window configuration."""

    # Typed as a generic mapping here, not as the list-based ``DOMTree`` alias.
    DOMTree: dict[str, Any]
    config: BrowserConfig


# Flat node lists for the two tree representations.
DOMTree = list[DOMNode]
AccessibilityTree = list[AccessibilityTreeNode]

# A single observation is either text or a uint8 NumPy array.
Observation = str | npt.NDArray[np.uint8]


class StateInfo(TypedDict):
    """Named observations together with a free-form info mapping."""

    observation: dict[str, Observation]  # observation name -> observation
    info: dict[str, Any]  # builtin generic, consistent with the rest of the file


# REVIEW[mandrade]: added more robust methods to map url to local counterparts and vice-versa
def normalize_url(url: str) -> str:
    """Normalize ``url`` so that it starts with a lower-case ``http://`` scheme.

    Every ``https://`` occurrence (matched case-insensitively) is downgraded
    to ``http://``. An existing ``http://`` scheme is lower-cased, and a URL
    with no scheme gets ``http://`` prepended.

    Args:
        url: The url to normalize.

    Returns:
        The normalized url.
    """
    scheme = "http://"
    url = re.sub(r"https://", scheme, url, flags=re.IGNORECASE)
    # Case-insensitive check so "HTTP://x" is not turned into "http://HTTP://x".
    if re.match(r"http://", url, flags=re.IGNORECASE):
        return scheme + url[len(scheme):]
    return scheme + url


def map_url_to_local(url: str) -> str:
    """Map real urls to their local counterparts.

    Example: https://wikipedia.org -> http://localhost:8888/wikipedia_en_all_maxi_2022-05

    The url is first normalised, case-insensitively: ``https://`` becomes
    ``http://``, ``//en.`` becomes ``//`` and every ``www.`` is removed. Each
    ``(local, real)`` pair in ``URL_MAPPINGS`` is then tried in order. The
    first of these checks to match wins:

    1. the full real url appears verbatim in ``url``;
    2. for a ``.org`` real url, its ``.com`` variant appears in ``url``;
    3. the real url without its scheme appears in ``url`` (case-insensitive).

    Args:
        url: The url to map.

    Returns:
        The url with the matched real part replaced by its local counterpart
        (taken from ``LOCAL_URLS_NORM``), or the normalised url if nothing
        matches.
    """
    url = re.sub(r"https://", "http://", url, flags=re.IGNORECASE)
    # Dots are escaped so only a literal "en." / "www." is dropped; an unescaped
    # "." would also match e.g. "//ent" in "//enterprise.com".
    url = re.sub(r"//en\.", "//", url, flags=re.IGNORECASE)
    url = re.sub(r"www\.", "", url, flags=re.IGNORECASE)

    for local_url, real_url in URL_MAPPINGS.items():
        # This is relevant for Wikipedia, which has an additional path for the local url
        local_url = LOCAL_URLS_NORM[local_url]

        # If real_url counterpart in `url`, map it to the `local` equivalent
        if real_url in url:
            return url.replace(real_url, local_url)

        # Scheme-less form of the real url. ``str.strip("http://")`` would remove
        # any of the characters h, t, p, ':' and '/' from BOTH ends (turning
        # "example.net" into "example.ne"), so strip the scheme explicitly.
        real_url_part = re.sub(r"^https?://", "", real_url, flags=re.IGNORECASE).rstrip("/")

        # If the mapping is for a `.org` URL, also check for its `.com` variant.
        if real_url.endswith(".org"):
            alt_url = real_url[:-4] + ".com"
            if alt_url in url:
                # First, fix the URL so that it uses .org instead of .com.
                url = url.replace(alt_url, real_url)
                # Now perform the usual mapping.
                return url.replace(real_url, local_url)

        # An empty pattern would match every url.
        if not real_url_part:
            continue

        # Match the host literally (escaped) and case-insensitively, and use the
        # same pattern to substitute, so a case-only difference is still mapped
        # instead of returning early with the url unchanged.
        pattern = re.compile(re.escape(real_url_part), re.IGNORECASE)
        if pattern.search(url):
            local_url_part = re.sub(r"^https?://", "", local_url, flags=re.IGNORECASE).rstrip("/")
            # Function replacement: the local url is inserted literally (no
            # backslash/group-reference interpretation).
            mapped_url = pattern.sub(lambda _match: local_url_part, url)
            return re.sub(r"www\.", "", mapped_url)

    # If no local urls found, return the original url
    return url


def map_url_to_real(url: str) -> str:
    """Map local urls to their real world counterparts.

    The url is first normalised, case-insensitively: ``https://`` becomes
    ``http://``, every ``www.`` is removed and ``//en.`` becomes ``//``. Then,
    for every ``(local, real)`` pair in ``URL_MAPPINGS``, occurrences of the
    normalised local url (``LOCAL_URLS_NORM[local]``) are replaced by the real
    url. The raw local url is replaced only if the normalised one is not
    present. Replacements accumulate across pairs.

    Args:
        url: The url to map.

    Returns:
        The mapped url, or the normalised url if no local url occurs in it.
    """
    url = re.sub(r"https://", "http://", url, flags=re.IGNORECASE)
    # Escaped dots: only a literal "www." / "en." is removed, not e.g. "wwwa".
    url = re.sub(r"www\.", "", url, flags=re.IGNORECASE)
    url = re.sub(r"//en\.", "//", url, flags=re.IGNORECASE)
    for local_url, real_url in URL_MAPPINGS.items():
        local_url_norm = LOCAL_URLS_NORM[local_url]
        # Prefer the normalised (possibly longer, e.g. Wikipedia) local url so
        # its extra path is replaced as well.
        if local_url_norm in url:
            url = url.replace(local_url_norm, real_url)
        elif local_url in url:
            url = url.replace(local_url, real_url)
    return url


# REVIEW[mandrade]: This is needed for String match evals if hosting the benchmark on other machines.
def map_endpoint_to_local(url: str, local_endpoint: str = "127.0.0.1") -> str:
    """Map the endpoint of the url to a local counterpart.
    This is needed for String match evals if hosting the benchmark on other machines.
    Example: http://143.215.128.18:8888 -> http://127.0.0.1:8888

    Args:
        url (str): The url to map the endpoint of.
        local_endpoint (str, optional): The local endpoint to map the url to. Defaults to "127.0.0.1".

    Returns:
        str: The url with the endpoint mapped to the local counterpart. If no
        key of ``URL_MAPPINGS`` occurs in ``url``, the url is returned with
        ``https://`` downgraded to ``http://``.
    """
    url = re.sub(r"https://", "http://", url, flags=re.IGNORECASE)
    for local_url in URL_MAPPINGS:
        if local_url in url:
            # Host of the configured endpoint: drop the scheme, then cut at the
            # first ':' (port) or '/' (path). The old "strip from the last ':'"
            # kept the path when there was no port ("host/path") and kept the
            # port when the path contained a colon.
            without_scheme = re.sub(r"^https?://", "", local_url, flags=re.IGNORECASE)
            base_endpoint = re.split(r"[:/]", without_scheme, maxsplit=1)[0]
            if not base_endpoint:
                # Replacing "" would splice local_endpoint between every char.
                return url
            return url.replace(base_endpoint, local_endpoint)
    return url
