import json
import lxml
import re
from collections import defaultdict
from typing import Any, TypedDict, Union

import numpy as np
import numpy.typing as npt
from gymnasium import spaces
from playwright.sync_api import CDPSession, Page, ViewportSize

from .browser_constants import (
    ASCII_CHARSET,
    FREQ_UNICODE_CHARSET,
    IGNORED_ACTREE_PROPERTIES,
    UTTERANCE_MAX_LENGTH,
)

from .utils import (
    AccessibilityTree,
    AccessibilityTreeNode,
    BrowserConfig,
    BrowserInfo,
    DOMNode,
    DOMTree,
    Observation,
    png_bytes_to_numpy,
)

from .html_tools import HtmlParser, basic_attrs, print_html_object

# Nodes whose in-viewport ratio (the value returned by
# ``TextObervationProcessor.get_element_in_viewport_ratio``) falls below this
# threshold are removed when only the current viewport is observed.
IN_VIEWPORT_RATIO_THRESHOLD = 0.6

def merge_consecutive_static_text_nodes(accessibility_tree):
    """Merge runs of adjacent ``StaticText`` siblings into the first node of the run.

    For every node, consecutive children whose role is ``StaticText`` are
    collapsed: their ``name`` values are concatenated (without a separator)
    into the first child of the run, and the remaining children of the run
    are dropped from the parent's ``childIds``. The merged-away nodes stay in
    ``accessibility_tree`` itself; they are simply no longer referenced.

    The tree is modified in place and nothing is returned.

    Child IDs that do not resolve to a node in ``accessibility_tree`` and
    nodes without a ``role`` entry are kept as-is and treated as
    non-``StaticText`` children (they end the current run), instead of
    raising ``KeyError``.

    Args:
        accessibility_tree: list of accessibility-tree node dicts, each with
            at least a ``nodeId`` key.
    """
    # Quick lookup from nodeId to the node dict itself.
    nodes_by_id = {node["nodeId"]: node for node in accessibility_tree}

    def _static_text_node(node_id):
        """Return the node for ``node_id`` if it is a known StaticText node."""
        candidate = nodes_by_id.get(node_id)
        if candidate is None:
            return None
        if candidate.get("role", {}).get("value") != "StaticText":
            return None
        return candidate

    def _text_of(node):
        """Return the node's name text, or an empty string if it has none."""
        return node.get("name", {}).get("value", "")

    for node in accessibility_tree:
        children = node.get("childIds", [])
        if not children:
            continue

        kept_children = []
        i = 0
        while i < len(children):
            head = _static_text_node(children[i])
            # The first child of a run (or any non-StaticText child) is kept.
            kept_children.append(children[i])
            i += 1
            if head is None:
                continue

            # Collect the text of every StaticText sibling that directly
            # follows; those siblings are not added to kept_children.
            texts = [_text_of(head)]
            while i < len(children):
                follower = _static_text_node(children[i])
                if follower is None:
                    break
                texts.append(_text_of(follower))
                i += 1

            if len(texts) > 1:
                # Store the merged text on the first node of the run.
                head.setdefault("name", {})["value"] = "".join(texts)

        # Replace the child list with the shortened one.
        node["childIds"] = kept_children


class ObservationProcessor:
    """Base interface for objects that turn a browser page into an observation."""

    def process(
            self, page: Page, client: CDPSession, context: str = ""
    ) -> Observation:
        """Produce an observation for ``page``; subclasses must override this.

        Args:
            page: the Playwright page to observe.
            client: the CDP session attached to the page.
            context: extra context string passed through by the caller. It
                is accepted here so that the base signature matches the
                subclasses, which all take this third argument.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError


class ObservationMetadata(TypedDict):
    """Bookkeeping data stored alongside an observation."""

    # Information about the nodes that appear in the observation, keyed by a
    # string identifier.
    obs_nodes_info: dict[str, Any]


def create_empty_metadata() -> ObservationMetadata:
    """Return a fresh metadata record that holds no node information yet."""
    empty_metadata: ObservationMetadata = {"obs_nodes_info": {}}
    return empty_metadata


class TextObervationProcessor(ObservationProcessor):
    def __init__(
            self,
            observation_type: str,
            current_viewport_only: bool,
            viewport_size: ViewportSize,
    ):
        """Remember the observation settings and start with empty metadata.

        Args:
            observation_type: which textual representation ``process``
                builds; ``process`` accepts ``"html"`` and
                ``"accessibility_tree"``.
            current_viewport_only: whether nodes outside the current
                viewport are filtered out of the observation.
            viewport_size: mapping with the ``width`` and ``height`` of the
                viewport.
        """
        # Metadata describing the most recent observation of this type.
        self.meta_data = create_empty_metadata()
        self.observation_tag = "text"
        self.viewport_size = viewport_size
        self.current_viewport_only = current_viewport_only
        self.observation_type = observation_type

    def fetch_browser_info(
            self,
            page: Page,
            client: CDPSession,
    ) -> BrowserInfo:
        """Capture a DOM snapshot of the page together with its window geometry.

        Args:
            page: page whose scroll offsets, screen size and device pixel
                ratio are queried.
            client: CDP session used to request ``DOMSnapshot.captureSnapshot``.

        Returns:
            A dict with the (rescaled) snapshot under ``"DOMTree"`` and the
            window geometry under ``"config"``.

        Raises:
            AssertionError: if ``window.devicePixelRatio`` is not 1.0.
        """
        snapshot_request = {
            "computedStyles": [],
            "includeDOMRects": True,
            "includePaintOrder": True,
        }
        tree = client.send("DOMSnapshot.captureSnapshot", snapshot_request)

        # The reported bounds are sometimes scaled. Rescale every bound so
        # that the width of the first one equals the configured viewport width.
        raw_bounds = tree["documents"][0]["layout"]["bounds"]
        scale = raw_bounds[0][2] / self.viewport_size["width"]
        tree["documents"][0]["layout"]["bounds"] = [
            [value / scale for value in bound] for bound in raw_bounds
        ]

        # Window geometry, queried from the page itself.
        scroll_top = page.evaluate("window.pageYOffset")
        scroll_left = page.evaluate("window.pageXOffset")
        screen_width = page.evaluate("window.screen.width")
        screen_height = page.evaluate("window.screen.height")
        right_edge = scroll_left + screen_width
        lower_edge = scroll_top + screen_height
        pixel_ratio = page.evaluate("window.devicePixelRatio")
        assert pixel_ratio == 1.0, "devicePixelRatio is not 1.0"

        config: BrowserConfig = {
            "win_top_bound": scroll_top,
            "win_left_bound": scroll_left,
            "win_width": screen_width,
            "win_height": screen_height,
            "win_right_bound": right_edge,
            "win_lower_bound": lower_edge,
            "device_pixel_ratio": pixel_ratio,
        }

        browser_info: BrowserInfo = {"DOMTree": tree, "config": config}
        return browser_info

    @staticmethod
    def get_bounding_client_rect(
            client: CDPSession, backend_node_id: str
    ) -> dict[str, Any]:
        try:
            remote_object = client.send(
                "DOM.resolveNode", {"backendNodeId": int(backend_node_id)}
            )
            remote_object_id = remote_object["object"]["objectId"]
            response = client.send(
                "Runtime.callFunctionOn",
                {
                    "objectId": remote_object_id,
                    "functionDeclaration": """
                        function() {
                            if (this.nodeType == 3) {
                                var range = document.createRange();
                                range.selectNode(this);
                                var rect = range.getBoundingClientRect().toJSON();
                                range.detach();
                                return rect;
                            } else {
                                return this.getBoundingClientRect().toJSON();
                            }
                        }
                    """,
                    "returnByValue": True,
                },
            )
            return response
        except Exception as e:
            return {"result": {"subtype": "error"}}

    @staticmethod
    def get_element_in_viewport_ratio(
            elem_left_bound: float,
            elem_top_bound: float,
            width: float,
            height: float,
            config: BrowserConfig,
    ) -> float:
        elem_right_bound = elem_left_bound + width
        elem_lower_bound = elem_top_bound + height

        win_left_bound = 0
        win_right_bound = config["win_width"]
        win_top_bound = 0
        win_lower_bound = config["win_height"]

        # Compute the overlap in x and y axes
        overlap_width = max(
            0,
            min(elem_right_bound, win_right_bound)
            - max(elem_left_bound, win_left_bound),
        )
        overlap_height = max(
            0,
            min(elem_lower_bound, win_lower_bound)
            - max(elem_top_bound, win_top_bound),
        )

        # Compute the overlap area
        ratio = overlap_width * overlap_height / width * height
        return ratio

    def fetch_page_html(
            self,
            info: BrowserInfo,
            page: Page,
            client: CDPSession,
            current_viewport_only: bool,
    ) -> DOMTree:
        # adopted from [natbot](https://github.com/nat/natbot)
        tree = info["DOMTree"]
        config = info["config"]
        strings = tree["strings"]
        document = tree["documents"][0]
        nodes = document["nodes"]
        layout = document["layout"]

        # print(len(nodes["nodeName"]), len(layout["nodeIndex"]), len(layout["bounds"]))

        import time
        stt = time.time()
        # make a dom tree that is easier to navigate
        dom_tree: DOMTree = []
        graph = defaultdict(list)
        print(nodes.keys())
        for node_idx in range(len(nodes["nodeName"])):
            cur_node: DOMNode = {
                "nodeId": "",
                "nodeType": "",
                "nodeName": "",
                "nodeValue": "",
                "attributes": "",
                "backendNodeId": "",
                "parentId": "",
                "childIds": [],
                "cursor": 0,
                "union_bound": None,
            }

            node_type_idx = nodes["nodeType"][node_idx]
            node_type = "generic"
            if node_type_idx >= 0 and node_type_idx < len(strings):
                node_type = strings[node_type_idx]

            node_name = strings[nodes["nodeName"][node_idx]]

            node_value_idx = nodes["nodeValue"][node_idx]
            node_value = ""
            if node_value_idx >= 0 and node_value_idx < len(strings):
                node_value = " ".join(strings[node_value_idx].split())

            node_attributes = [
                strings[i] for i in nodes["attributes"][node_idx]
            ]
            node_attributes_str = ""
            for i in range(0, len(node_attributes), 2):
                a = node_attributes[i]
                b = node_attributes[i + 1]
                # b = " ".join(b.split())
                import re
                b = re.sub(r"{\s*opacity:\s*.*;*\s*}", " ", b)
                b = [b_item for b_item in b.split() if b_item.count('vimium') == 0]
                b = " ".join(b)
                node_attributes_str += f'{a}="{b}" '

            node_attributes_str = node_attributes_str.strip()

            cur_node["nodeId"] = str(node_idx)
            cur_node["nodeType"] = node_type
            cur_node["nodeName"] = node_name
            cur_node["nodeValue"] = node_value
            cur_node["attributes"] = node_attributes_str
            cur_node["backendNodeId"] = str(nodes["backendNodeId"][node_idx])
            cur_node["parentId"] = str(nodes["parentIndex"][node_idx])

            if cur_node["parentId"] != "-1":
                graph[cur_node["parentId"]].append(str(cur_node["nodeId"]))

            # get the bound
            if cur_node["parentId"] == "-1":
                cur_node["union_bound"] = [0.0, 0.0, 10.0, 10.0]
            else:
                # method 1
                # response = self.get_bounding_client_rect(
                #     client, cur_node["backendNodeId"]
                # )

                # if response.get("result", {}).get("subtype", "") == "error":
                #     cur_node["union_bound"] = None
                # else:
                #     x = response["result"]["value"]["x"]
                #     y = response["result"]["value"]["y"]
                #     width = response["result"]["value"]["width"]
                #     height = response["result"]["value"]["height"]
                #     cur_node["union_bound"] = [x, y, width, height]

                # method 2
                bound = [0.0, 0.0, 0.0, 0.0]
                if node_idx in layout["nodeIndex"]:
                    bound = layout["bounds"][layout["nodeIndex"].index(node_idx)]
                    bound[0] -= config["win_left_bound"]
                    bound[1] -= config["win_top_bound"]

                cur_node["union_bound"] = bound

            dom_tree.append(cur_node)
        print('[build]', time.time() - stt)

        stt = time.time()
        # add parent children index to the node
        for parent_id, child_ids in graph.items():
            dom_tree[int(parent_id)]["childIds"] = child_ids
        print('[graph]', time.time() - stt)

        # with open('output/dom_tree.json', 'w') as f:
        #     f.write(json.dumps(dom_tree, ensure_ascii=False))

        stt = time.time()
        # remove the nodes that are not in the current viewport
        if current_viewport_only:

            def remove_node_in_graph(node: DOMNode) -> None:
                # update the node information in the accessibility tree
                node_id = node["nodeId"]
                parent_id = node["parentId"]
                child_ids = node["childIds"]

                # update the children of the parent node
                assert dom_tree[int(parent_id)]["parentId"] != "[REMOVED]"
                # remove the nodeid from parent
                index = dom_tree[int(parent_id)]["childIds"].index(node_id)
                dom_tree[int(parent_id)]["childIds"].pop(index)

                # Insert children_nodeids in the same location
                for child_id in child_ids:
                    dom_tree[int(parent_id)]["childIds"].insert(
                        index, child_id
                    )
                    index += 1

                # update children node's parent
                for child_id in child_ids:
                    dom_tree[int(child_id)]["parentId"] = parent_id
                # mark as removed
                dom_tree[int(node_id)]["parentId"] = "[REMOVED]"

            config = info["config"]
            for cursor, node in enumerate(dom_tree):
                if not node["union_bound"]:
                    remove_node_in_graph(node)
                    continue

                [x, y, width, height] = node["union_bound"]

                # invisible node
                if width == 0.0 or height == 0.0:
                    parent_id = node["parentId"]
                    if node["nodeName"] not in ['OPTION'] or dom_tree[int(parent_id)]["nodeName"] not in ["SELECT"]:
                        remove_node_in_graph(node)
                    continue

                in_viewport_ratio = self.get_element_in_viewport_ratio(
                    elem_left_bound=float(x),
                    elem_top_bound=float(y),
                    width=float(width),
                    height=float(height),
                    config=config,
                )

                if in_viewport_ratio < IN_VIEWPORT_RATIO_THRESHOLD:
                    remove_node_in_graph(node)

            dom_tree = [
                node
                for node in dom_tree
                if node.get("parentId", "-1") != "[REMOVED]"
            ]

        print('[filter]', time.time() - stt)
        return dom_tree

    @staticmethod
    def parse_my_html(dom_tree: DOMTree) -> tuple[str, str, dict[str, Any], Any]:
        """Serialize the DOM tree to HTML text and run it through ``HtmlParser``.

        The tree is walked depth-first from the first node. A node is
        emitted when it has attributes or a node value; ``#text`` nodes are
        renamed to ``text`` (the rename is written back into ``dom_tree``).
        Each emitted element carries a ``backend-id="bid-<backendNodeId>"``
        attribute. Nodes that cannot be rendered or parsed are skipped, but
        their children are still visited.

        The labelled element ids, the parsed page and the parser timings are
        printed to stdout.

        Args:
            dom_tree: list of DOM node dicts as produced by ``fetch_page_html``.

        Returns:
            A tuple ``(raw_html, page_html, obs_nodes_info, parser)``:
            the serialized HTML string, the ``"html"`` entry of the parser's
            result, a dict keyed by node cursor (as a string) with
            ``backend_id``, ``union_bound``, ``text`` and ``label`` for every
            emitted node, and the ``HtmlParser`` instance.
        """
        # ``import lxml`` on its own does not load the ``lxml.html``
        # submodule. Import it explicitly so the lookup cannot fail with an
        # AttributeError that the handler below would silently turn into
        # "every node is invalid".
        from lxml.html import fromstring as parse_fragment

        obs_nodes_info = {}
        nodeid_to_cursor = {
            node["nodeId"]: idx for idx, node in enumerate(dom_tree)
        }

        def dfs(node_cursor: int, depth: int) -> tuple[str, list[str]]:
            """Return the HTML and the labelled backend ids of one subtree."""
            tree_str, labeled_elems = '', []
            node = dom_tree[node_cursor]
            valid_node = True
            try:
                if node['nodeName'] == '#text':
                    node['nodeName'] = 'text'

                node_str = f"<{node['nodeName']}"
                if node["attributes"]:
                    node_str += f" {node['attributes']}"
                node_str += f" backend-id=\"bid-{node['backendNodeId']}\"> {node['nodeValue']}"

                valid_node = bool(node["attributes"] or node["nodeValue"])

                if valid_node:
                    # Parsing validates the fragment and exposes its attributes.
                    node_html = parse_fragment(node_str)
                    label = node_html.attrib.get('data-testid', '')
                    if len(label) > 0:
                        labeled_elems.append(node["backendNodeId"])
                    obs_nodes_info[str(node_cursor)] = {
                        "backend_id": node["backendNodeId"],
                        "union_bound": node["union_bound"],
                        "text": node['nodeValue'],
                        "label": label,
                    }
                    tree_str += f"{node_str}"

            except Exception:
                # Best effort: a node that cannot be rendered or parsed is
                # left out of the output.
                valid_node = False

            for child_ids in node["childIds"]:
                child_cursor = nodeid_to_cursor[child_ids]
                child_depth = depth + 1 if valid_node else depth
                child_str, elems = dfs(child_cursor, child_depth)
                tree_str += child_str
                labeled_elems.extend(elems)

            if valid_node:
                tree_str += f"</{node['nodeName']}>"

            return tree_str, labeled_elems

        html, labeled_elems = dfs(0, 0)

        print(labeled_elems)

        args = {
            'use_position': False,
            'id_attr': 'backend-id',
            'label_generator': 'order',
            'label_attr': 'data-testid',
            'attr_list': basic_attrs,
            'prompt': 'refine',
        }

        hp = HtmlParser(html, args)
        packet = hp.parse_tree()
        page_html = packet['html']

        print(print_html_object(page_html))

        it, pt = packet.get('init_time', 0), packet.get('parse_time', 0)
        print(f'[Time] {it:.3f} {pt:.3f}')

        return html, page_html, obs_nodes_info, hp

    @staticmethod
    def parse_html(dom_tree: DOMTree) -> tuple[str, dict[str, Any]]:
        """Parse the html tree into a string text"""

        obs_nodes_info = {}
        nodeid_to_cursor = {
            node["nodeId"]: idx for idx, node in enumerate(dom_tree)
        }

        def dfs(node_cursor: int, depth: int) -> str:
            tree_str = ""
            node = dom_tree[node_cursor]
            indent = "\t" * depth
            valid_node = True
            try:
                node_str = f"[{node_cursor}] <{node['nodeName']}"
                if node["attributes"]:
                    node_str += f" {node['attributes']}"
                node_str += f"> {node['nodeValue']}"
                valid_node = bool(node["attributes"] or node["nodeValue"])

                if valid_node:
                    obs_nodes_info[str(node_cursor)] = {
                        "backend_id": node["backendNodeId"],
                        "union_bound": node["union_bound"],
                        "text": node_str,
                    }
                    tree_str += f"{indent}{node_str}\n"

            except Exception as e:
                valid_node = False

            for child_ids in node["childIds"]:
                child_cursor = nodeid_to_cursor[child_ids]
                child_depth = depth + 1 if valid_node else depth
                child_str = dfs(child_cursor, child_depth)
                tree_str += child_str

            return tree_str

        html = dfs(0, 0)
        return html, obs_nodes_info

    def fetch_page_accessibility_tree(
            self,
            info: BrowserInfo,
            client: CDPSession,
            current_viewport_only: bool,
    ) -> AccessibilityTree:
        """Fetch the full accessibility tree and attach a bound to every node.

        The tree is requested with ``Accessibility.getFullAXTree``, nodes
        with a repeated ``nodeId`` are dropped (the first occurrence wins),
        and every node receives a ``union_bound`` entry: ``None`` when the
        node has no ``backendDOMNodeId`` or its rect could not be fetched, a
        fixed ``[0.0, 0.0, 10.0, 10.0]`` for the ``RootWebArea``, and
        otherwise ``[x, y, width, height]`` from
        ``get_bounding_client_rect``.

        Args:
            info: browser info; only its ``"config"`` entry is read, and
                only when ``current_viewport_only`` is true.
            client: CDP session used for all browser requests.
            current_viewport_only: when true, nodes without a bound, nodes
                with zero width or height, and nodes whose in-viewport ratio
                is below ``IN_VIEWPORT_RATIO_THRESHOLD`` are removed; their
                children are re-attached to the removed node's parent.

        Returns:
            The list of (remaining) accessibility tree nodes.
        """
        accessibility_tree: AccessibilityTree = client.send(
            "Accessibility.getFullAXTree", {}
        )["nodes"]

        # a few nodes are repeated in the accessibility tree
        seen_ids = set()
        _accessibility_tree = []
        for node in accessibility_tree:
            if node["nodeId"] not in seen_ids:
                _accessibility_tree.append(node)
                seen_ids.add(node["nodeId"])
        accessibility_tree = _accessibility_tree

        # Map every nodeId to its position in the list and compute the
        # node's bound in the same pass.
        nodeid_to_cursor = {}
        for cursor, node in enumerate(accessibility_tree):
            nodeid_to_cursor[node["nodeId"]] = cursor
            # usually because the node is not visible etc
            if "backendDOMNodeId" not in node:
                node["union_bound"] = None
                continue
            backend_node_id = str(node["backendDOMNodeId"])
            if node["role"]["value"] == "RootWebArea":
                # always inside the viewport
                node["union_bound"] = [0.0, 0.0, 10.0, 10.0]
            else:
                response = self.get_bounding_client_rect(
                    client, backend_node_id
                )
                # get_bounding_client_rect reports failures with an
                # error-subtype result instead of raising.
                if response.get("result", {}).get("subtype", "") == "error":
                    node["union_bound"] = None
                else:
                    x = response["result"]["value"]["x"]
                    y = response["result"]["value"]["y"]
                    width = response["result"]["value"]["width"]
                    height = response["result"]["value"]["height"]
                    node["union_bound"] = [x, y, width, height]

        # filter nodes that are not in the current viewport
        if current_viewport_only:

            def remove_node_in_graph(node: AccessibilityTreeNode) -> None:
                """Detach ``node`` and splice its children into its parent.

                The node itself stays in the list but is marked by setting
                its ``parentId`` to ``"[REMOVED]"``.
                """
                # update the node information in the accessibility tree
                nodeid = node["nodeId"]
                node_cursor = nodeid_to_cursor[nodeid]
                parent_nodeid = node["parentId"]
                children_nodeids = node["childIds"]
                parent_cursor = nodeid_to_cursor[parent_nodeid]
                # update the children of the parent node
                assert (
                        accessibility_tree[parent_cursor].get("parentId", "Root")
                        is not None
                )
                # remove the nodeid from parent's childIds
                index = accessibility_tree[parent_cursor]["childIds"].index(
                    nodeid
                )
                accessibility_tree[parent_cursor]["childIds"].pop(index)
                # Insert children_nodeids in the same location
                for child_nodeid in children_nodeids:
                    accessibility_tree[parent_cursor]["childIds"].insert(
                        index, child_nodeid
                    )
                    index += 1
                # update children node's parent
                for child_nodeid in children_nodeids:
                    child_cursor = nodeid_to_cursor[child_nodeid]
                    accessibility_tree[child_cursor][
                        "parentId"
                    ] = parent_nodeid
                # mark as removed
                accessibility_tree[node_cursor]["parentId"] = "[REMOVED]"

            config = info["config"]
            for node in accessibility_tree:
                # nodes without a bound are removed
                if not node["union_bound"]:
                    remove_node_in_graph(node)
                    continue

                [x, y, width, height] = node["union_bound"]

                # invisible node
                if width == 0 or height == 0:
                    remove_node_in_graph(node)
                    continue

                in_viewport_ratio = self.get_element_in_viewport_ratio(
                    elem_left_bound=float(x),
                    elem_top_bound=float(y),
                    width=float(width),
                    height=float(height),
                    config=config,
                )

                if in_viewport_ratio < IN_VIEWPORT_RATIO_THRESHOLD:
                    remove_node_in_graph(node)

            # Drop the nodes that were marked as removed above; nodes
            # without a parentId are kept.
            accessibility_tree = [
                node
                for node in accessibility_tree
                if node.get("parentId", "Root") != "[REMOVED]"
            ]

        return accessibility_tree

    @staticmethod
    def parse_accessibility_tree(
            accessibility_tree: AccessibilityTree,
    ) -> tuple[str, dict[str, Any]]:
        """Parse the accessibility tree into a string text"""
        node_id_to_idx = {}
        for idx, node in enumerate(accessibility_tree):
            node_id_to_idx[node["nodeId"]] = idx

        obs_nodes_info = {}

        def dfs(idx: int, obs_node_id: str, depth: int) -> str:
            tree_str = ""
            node = accessibility_tree[idx]
            indent = "\t" * depth
            valid_node = True
            try:
                role = node["role"]["value"]
                name = node["name"]["value"]
                node_str = f"[{obs_node_id}] {role} {repr(name)}"
                properties = []
                for property in node.get("properties", []):
                    try:
                        if property["name"] in IGNORED_ACTREE_PROPERTIES:
                            continue
                        properties.append(
                            f'{property["name"]}: {property["value"]["value"]}'
                        )
                    except KeyError:
                        pass

                if properties:
                    node_str += " " + " ".join(properties)

                # check valid
                if not node_str.strip():
                    valid_node = False

                if role == "StaticText" and len(name) < 5 and len(node["childIds"]) == 0:
                    valid_node = False

                # empty generic node
                if not name.strip():
                    if not properties:
                        if role in [
                            "generic",
                            "img",
                            "list",
                            "strong",
                            "paragraph",
                            "banner",
                            "navigation",
                            "Section",
                            "LabelText",
                            "Legend",
                            "listitem",
                        ]:
                            valid_node = False
                    elif role in ["listitem"]:
                        valid_node = False

                if valid_node:
                    tree_str += f"{indent}{node_str}"
                    obs_nodes_info[obs_node_id] = {
                        "backend_id": node["backendDOMNodeId"],
                        "union_bound": node["union_bound"],
                        "text": node_str,
                    }

            except Exception as e:
                valid_node = False

            for _, child_node_id in enumerate(node["childIds"]):
                if child_node_id not in node_id_to_idx:
                    continue
                # mark this to save some tokens
                child_depth = depth + 1 if valid_node else depth
                child_str = dfs(
                    node_id_to_idx[child_node_id], child_node_id, child_depth
                )
                if child_str.strip():
                    if tree_str.strip():
                        tree_str += "\n"
                    tree_str += child_str

            return tree_str

        tree_str = dfs(0, accessibility_tree[0]["nodeId"], 0)
        return tree_str, obs_nodes_info

    @staticmethod
    def clean_accesibility_tree(tree_str: str) -> str:
        """further clean accesibility tree"""
        clean_lines: list[str] = []
        for line in tree_str.split("\n"):
            if "statictext" in line.lower():
                prev_lines = clean_lines[-3:]
                pattern = r"\[\d+\] StaticText '([^']+)'"

                match = re.search(pattern, line)
                if match:
                    static_text = match.group(1)
                    if all(
                            static_text not in prev_line
                            for prev_line in prev_lines
                    ):
                        clean_lines.append(line)
            else:
                clean_lines.append(line)

        return "\n".join(clean_lines)

    def process(self, page: Page, client: CDPSession, context: str) -> str:
        """Build the text observation for ``page`` and refresh the metadata.

        For the ``"html"`` observation type the DOM snapshot is converted
        with ``fetch_page_html`` and ``parse_my_html``; the metadata then
        also receives the scroll position, the raw HTML and DOM tree, the
        HTML parser and the tab titles. For ``"accessibility_tree"`` the
        accessibility tree is fetched, consecutive StaticText nodes are
        merged, and the tree is rendered and cleaned. In both cases
        ``self.obs_nodes_info`` and ``self.browser_config`` are updated.

        Args:
            page: the page to observe.
            client: CDP session attached to the page.
            context: not used by this method.

        Returns:
            The observation content produced for the configured type.

        Raises:
            ValueError: if ``self.observation_type`` is neither ``"html"``
                nor ``"accessibility_tree"``.
        """
        # One line per open tab, with the current tab marked. If anything
        # fails, fall back to a placeholder title for every tab.
        open_tabs = page.context.pages
        try:
            tab_titles = [tab.title() for tab in open_tabs]
            current_tab_idx = open_tabs.index(page)
            for idx in range(len(open_tabs)):
                label = f"{idx + 1}. {open_tabs[idx].title()}"
                if idx == current_tab_idx:
                    label = f"{label} <-- current tab"
                tab_titles[idx] = label
            tab_title_str = "\n".join(tab_titles)
        except Exception:
            placeholders = [
                f"{idx + 1}. Default" for idx in range(len(open_tabs))
            ]
            tab_title_str = "\n".join(placeholders)

        # If the first attempt fails, wait briefly for the page to load and
        # try once more.
        try:
            browser_info = self.fetch_browser_info(page, client)
        except Exception:
            page.wait_for_load_state("load", timeout=500)
            browser_info = self.fetch_browser_info(page, client)

        observation_type = self.observation_type
        if observation_type == "html":
            import time

            started = time.time()
            dom_tree = self.fetch_page_html(
                browser_info,
                page,
                client,
                current_viewport_only=self.current_viewport_only,
            )
            print('[fetch]', time.time() - started)

            started = time.time()
            raw_html, content, obs_nodes_info, hp = self.parse_my_html(dom_tree)
            print('[parse]', time.time() - started)

            # Page height and scroll position, both in units of window heights.
            window_height = page.evaluate("window.innerHeight")
            page_height = page.evaluate('document.documentElement.scrollHeight') / window_height
            position = page.evaluate("window.scrollY") / window_height

            self.obs_nodes_info = obs_nodes_info
            self.meta_data.update(
                {
                    "obs_nodes_info": obs_nodes_info,
                    "position_info": {
                        "page_height": page_height,
                        "position": position,
                    },
                    "dom_info": {
                        "raw_html": raw_html,
                        "dom_tree": dom_tree,
                    },
                    "html_parser": hp,
                    "tab_title": tab_title_str,
                }
            )

        elif observation_type == "accessibility_tree":
            # 1. fetch the accessibility tree
            accessibility_tree = self.fetch_page_accessibility_tree(
                browser_info,
                client,
                current_viewport_only=self.current_viewport_only,
            )

            # 2. pre-process: merge consecutive StaticText siblings
            merge_consecutive_static_text_nodes(accessibility_tree)

            # 3. render the tree as text and clean it up
            rendered, obs_nodes_info = self.parse_accessibility_tree(
                accessibility_tree
            )
            content = self.clean_accesibility_tree(rendered)
            self.obs_nodes_info = obs_nodes_info
            self.meta_data["obs_nodes_info"] = obs_nodes_info

        else:
            raise ValueError(
                f"Invalid observatrion type: {observation_type}"
            )

        self.browser_config = browser_info["config"]
        return content

    def get_element_center(self, element_id: str) -> tuple[float, float]:
        node_info = self.obs_nodes_info[element_id]
        node_bound = node_info["union_bound"]
        x, y, width, height = node_bound
        center_x = x + width / 2
        center_y = y + height / 2
        return (
            center_x / self.viewport_size["width"],
            center_y / self.viewport_size["height"],
        )


class ImageObservationProcessor(ObservationProcessor):
    """Observation processor that captures a screenshot of the page."""

    def __init__(self, observation_type: str):
        """Store the observation type and start with empty metadata.

        Args:
            observation_type: name of the image observation type; it is only
                stored by this class.
        """
        self.observation_type = observation_type
        self.observation_tag = "image"
        self.meta_data = create_empty_metadata()

    def process(self, page: Page, client: CDPSession, context: str) -> npt.NDArray[np.uint8]:
        """Take a screenshot of ``page`` and convert it with ``png_bytes_to_numpy``.

        If the first attempt fails, the method waits for a ``load`` event
        and tries once more; a failure of the second attempt propagates.

        Args:
            page: the page to capture.
            client: not used by this method.
            context: not used by this method.

        Returns:
            The value returned by ``png_bytes_to_numpy`` for the screenshot.
        """
        try:
            screenshot = png_bytes_to_numpy(page.screenshot())
        except Exception:
            # ``Exception`` rather than a bare ``except`` so that
            # KeyboardInterrupt and SystemExit are not swallowed by the retry.
            page.wait_for_event("load")
            screenshot = png_bytes_to_numpy(page.screenshot())
        return screenshot


class ObservationHandler:
    """Main entry point to access all observation processor"""

    def __init__(
            self,
            main_observation_type: str,
            text_observation_type: str,
            image_observation_type: str,
            current_viewport_only: bool,
            viewport_size: ViewportSize,
            simple_mode: bool = False
    ) -> None:
        """Create the text and image processors.

        Args:
            main_observation_type: ``"text"`` or ``"image"``; selects the
                processor returned by ``action_processor``.
            text_observation_type: observation type handed to the text
                processor.
            image_observation_type: observation type handed to the image
                processor.
            current_viewport_only: handed to the text processor.
            viewport_size: mapping with the viewport ``width`` and ``height``.
            simple_mode: when true, no image observation is produced and the
                ``"image"`` entries are ``None``.
        """
        self.simple_mode = simple_mode
        self.main_observation_type = main_observation_type
        self.text_processor = TextObervationProcessor(
            text_observation_type, current_viewport_only, viewport_size
        )
        self.image_processor = ImageObservationProcessor(
            image_observation_type
        )
        self.viewport_size = viewport_size

    def get_observation_space(self) -> spaces.Dict:
        """Describe the observation space.

        Returns:
            A ``spaces.Dict`` with a ``"text"`` and an ``"image"`` space. In
            simple mode a plain dict is returned instead, with ``None`` in
            place of the image space.
        """
        text_space = spaces.Text(
            min_length=0,
            max_length=UTTERANCE_MAX_LENGTH,
            charset=ASCII_CHARSET + FREQ_UNICODE_CHARSET,
        )

        if self.simple_mode:
            return {"text": text_space, "image": None}

        # Each position stores the RGB values. Note the swapped axes (height first).
        image_shape = (
            self.viewport_size["height"],
            self.viewport_size["width"],
            3,
        )
        image_space = spaces.Box(
            np.zeros(image_shape, dtype=np.uint8),
            # Build the upper bound directly as uint8; multiplying a uint8
            # array by 255.0 would produce a float64 array first.
            np.full(image_shape, 255, dtype=np.uint8),
            dtype=np.uint8,
        )

        return spaces.Dict({"text": text_space, "image": image_space})

    def get_observation(
            self, page: Page, client: CDPSession, context: str = '',
    ) -> dict[str, Observation]:
        """Collect the text observation and, unless in simple mode, the image.

        Args:
            page: the page to observe.
            client: CDP session attached to the page.
            context: passed through to the processors.

        Returns:
            A dict with the ``"text"`` and ``"image"`` observations; the
            image is ``None`` in simple mode.
        """
        text_obs = self.text_processor.process(page, client, context)
        if self.simple_mode:
            return {"text": text_obs, "image": None}
        image_obs = self.image_processor.process(page, client, context)
        return {"text": text_obs, "image": image_obs}

    def get_observation_metadata(self) -> dict[str, ObservationMetadata]:
        """Return the current metadata of the text and image processors."""
        return {
            "text": self.text_processor.meta_data,
            "image": self.image_processor.meta_data,
        }

    @property
    def action_processor(self) -> ObservationProcessor:
        """Return the main processor that is associated with the action space"""
        if self.main_observation_type == "text":
            return self.text_processor
        elif self.main_observation_type == "image":
            return self.image_processor
        else:
            raise ValueError("Invalid main observation type")
