import lxml
import re
from collections import defaultdict
from typing import Any, TypedDict

import numpy as np
import numpy.typing as npt
from gymnasium import spaces
from playwright.sync_api import CDPSession, Page, ViewportSize

from browser_env.constants import (
    ASCII_CHARSET,
    FREQ_UNICODE_CHARSET,
    IGNORED_ACTREE_PROPERTIES,
    UTTERANCE_MAX_LENGTH,
)

from .utils import (
    AccessibilityTree,
    AccessibilityTreeNode,
    BrowserConfig,
    BrowserInfo,
    DOMNode,
    DOMTree,
    Observation,
    png_bytes_to_numpy,
)

from .html_tools import HtmlParser, basic_attrs, print_html_object

# Elements whose computed in-viewport ratio is below this value are treated as
# outside the viewport by the filtering code in this module.
IN_VIEWPORT_RATIO_THRESHOLD = 0.8

class TreeNode:
    """Node of a structured tree with role/name/properties and a visibility flag.

    Nodes are linked through ``parent`` and ``children``; most helpers compare a
    node with its siblings under the same parent.
    """

    def __init__(self, node_id, role, name, depth, **kwargs):
        """Create a detached, visible node.

        Args:
            node_id: Identifier of the node.
            role: Role of the node.
            name: Name of the node.
            depth: Depth of the node in the tree.
            **kwargs: Only ``properties`` is read; it defaults to ``None``.
        """
        self.visible = True
        self.node_id = node_id
        self.role = role
        self.name = name
        self.depth = depth
        self.properties = kwargs.get("properties")

        self.children = []
        self.parent = None

    def add_child(self, child):
        """Append ``child`` to this node and make this node its parent."""
        child.parent = self
        self.children.append(child)

    def copy(self):
        """Return a detached deep copy of this node (no parent, no children).

        Only the node's own attributes are copied. Deep-copying ``self`` as a
        whole would follow ``parent``/``children`` and duplicate the entire
        connected tree just to discard it.
        """
        from copy import deepcopy

        new_self = self.__class__.__new__(self.__class__)
        # One shared memo keeps aliasing between this node's attributes intact.
        memo = {}
        for attr, value in vars(self).items():
            if attr in ("children", "parent"):
                continue
            setattr(new_self, attr, deepcopy(value, memo))
        new_self.children = []
        new_self.parent = None
        return new_self

    def get_visible_node_number(self):
        """Return how many nodes in this subtree (self included) are visible."""
        own = 1 if self.visible else 0
        return own + sum(
            child.get_visible_node_number() for child in self.children
        )

    def delete_tree(self):
        """Recursively unlink every node of this subtree from its relatives."""
        for child in self.children:
            child.delete_tree()
        self.children.clear()
        self.parent = None

    def has_properties(self):
        """Return the ``properties`` attribute (truthy only when non-empty)."""
        return getattr(self, "properties", {})

    def visible_children(self):
        """Return the children whose ``visible`` flag is set."""
        return [c for c in self.children if c.visible]

    def visible_siblings(self):
        """Return visible nodes under the same parent with a different node id."""
        if not self.parent:
            return []
        return [
            n
            for n in self.parent.children
            if n.visible and n.node_id != self.node_id
        ]

    def siblings(self):
        """Return all nodes under the same parent with a different node id."""
        if not self.parent:
            return []
        return [n for n in self.parent.children if n.node_id != self.node_id]

    def search_node_by_id(self, target_id):
        """Depth-first search for a node by id.

        A node matches when its ``node_id`` equals ``target_id`` or when its
        name contains the text ``[target_id]``. Returns ``None`` if not found.
        """
        if self.node_id == target_id or (
            self.name and f"[{target_id}]" in self.name
        ):
            return self
        for child in self.children:
            result = child.search_node_by_id(target_id)
            if result is not None:
                return result
        return None

    def all_children_invisible(self):
        """Return True when the node has no visible child (or no child at all)."""
        return not any(child.visible for child in self.children)

    def has_the_same_properties_as(self, another_node):
        """Return True when both nodes have equal, or equally empty, properties."""
        own = getattr(self, "properties", "")
        other = getattr(another_node, "properties", "")
        if not own and not other:
            return True
        if bool(own) != bool(other):
            return False
        return self.properties == another_node.properties

    def is_identical_to(self, another_node):
        """Return True when ``another_node`` is childless and matches this node.

        Matching means equal role, name and properties.
        """
        if another_node.children:
            return False
        return (
            self.role == another_node.role
            and self.name == another_node.name
            and self.has_the_same_properties_as(another_node=another_node)
        )

    def last_sibling(self, visible_required=False):
        """Return the sibling preceding this node, or ``None``.

        Args:
            visible_required: When True, return the closest *visible*
                preceding sibling instead of the immediate one.
        """
        if not self.parent:
            return None
        siblings = self.parent.children
        own_idx = siblings.index(self)
        if own_idx == 0:
            return None
        if not visible_required:
            return siblings[own_idx - 1]
        # Walk backwards over the siblings that come BEFORE this node; the
        # earlier slice ``[:own_idx:-1]`` walked the ones after it instead.
        for sibling in reversed(siblings[:own_idx]):
            if sibling.visible:
                return sibling
        return None

    def next_sibling(self, visible_required=False):
        """Return the sibling following this node, or ``None``.

        Args:
            visible_required: When True, return the closest *visible*
                following sibling instead of the immediate one.
        """
        if not self.parent:
            return None
        siblings = self.parent.children
        next_sibling_idx = siblings.index(self) + 1
        if next_sibling_idx >= len(siblings):
            return None
        if not visible_required:
            return siblings[next_sibling_idx]
        for sibling in siblings[next_sibling_idx:]:
            if sibling.visible:
                return sibling
        return None

    def has_identical_siblings(self):
        """Return True when another leaf-like sibling shares this role and name.

        Only applies when this node has a parent and no visible children;
        candidate siblings must also have no visible children.
        """
        if not (self.parent and self.all_children_invisible()):
            return False
        return any(
            sibling.role == self.role and sibling.name == self.name
            for sibling in self.parent.children
            if sibling.node_id != self.node_id
            and sibling.all_children_invisible()
        )

    def has_identical_surrounding_siblings(self):
        """Return True when an adjacent sibling is identical to this node.

        Checks the immediate and the closest visible sibling on each side.
        """
        neighbours = (
            self.last_sibling(visible_required=False),
            self.last_sibling(visible_required=True),
            self.next_sibling(visible_required=False),
            self.next_sibling(visible_required=True),
        )
        return any(
            neighbour is not None and self.is_identical_to(neighbour)
            for neighbour in neighbours
        )

    def is_differentiable(self, strict=False):
        """Return whether this node can be told apart from its siblings.

        Children of a ``row`` are always differentiable. Otherwise the node is
        not differentiable when it has an identical adjacent sibling, or (unless
        ``strict``) any sibling with the same role and name.
        """
        if self.parent and self.parent.role == "row":
            return True
        if not strict and self.has_identical_siblings():
            return False
        if self.has_identical_surrounding_siblings():
            return False
        return True


class ObservationProcessor:
    """Base class for observation processors; subclasses override ``process``."""

    def process(self, page: Page, client: CDPSession) -> Observation:
        """Produce an observation for ``page``; not implemented in the base class."""
        raise NotImplementedError


class ObservationMetadata(TypedDict):
    """Metadata kept alongside an observation."""

    # Per-node information; presumably the dicts built by the parse_* methods
    # of the text processor, keyed by node id -- confirm against the caller.
    obs_nodes_info: dict[str, Any]


def create_empty_metadata() -> ObservationMetadata:
    """Return a new metadata dict whose ``obs_nodes_info`` is empty."""
    empty: ObservationMetadata = {"obs_nodes_info": {}}
    return empty


class TextObervationProcessor(ObservationProcessor):
    def __init__(
        self,
        observation_type: str,
        current_viewport_only: bool,
        viewport_size: ViewportSize,
    ):
        self.observation_type = observation_type
        self.current_viewport_only = current_viewport_only
        self.viewport_size = viewport_size
        self.observation_tag = "text"
        self.meta_data = (
            create_empty_metadata()
        )  # use the store meta data of this observation type

    def fetch_browser_info(
        self,
        page: Page,
        client: CDPSession,
    ) -> BrowserInfo:
        # extract domtree
        tree = client.send(
            "DOMSnapshot.captureSnapshot",
            {
                "computedStyles": [],
                "includeDOMRects": True,
                "includePaintOrder": True,
            },
        )

        # calibrate the bounds, in some cases, the bounds are scaled somehow
        bounds = tree["documents"][0]["layout"]["bounds"]
        b = bounds[0]
        n = b[2] / self.viewport_size["width"]
        bounds = [[x / n for x in bound] for bound in bounds]
        tree["documents"][0]["layout"]["bounds"] = bounds

        # extract browser info
        win_top_bound = page.evaluate("window.pageYOffset")
        win_left_bound = page.evaluate("window.pageXOffset")
        win_width = page.evaluate("window.screen.width")
        win_height = page.evaluate("window.screen.height")
        win_right_bound = win_left_bound + win_width
        win_lower_bound = win_top_bound + win_height
        device_pixel_ratio = page.evaluate("window.devicePixelRatio")
        assert device_pixel_ratio == 1.0, "devicePixelRatio is not 1.0"

        config: BrowserConfig = {
            "win_top_bound": win_top_bound,
            "win_left_bound": win_left_bound,
            "win_width": win_width,
            "win_height": win_height,
            "win_right_bound": win_right_bound,
            "win_lower_bound": win_lower_bound,
            "device_pixel_ratio": device_pixel_ratio,
        }

        # assert len(tree['documents']) == 1, "More than one document in the DOM tree"
        info: BrowserInfo = {"DOMTree": tree, "config": config}
        # with open('output/browser_info.json', 'w') as f:
        #     f.write(json.dumps(tree, ensure_ascii=False))
        return info

    @staticmethod
    def get_bounding_client_rect(
        client: CDPSession, backend_node_id: str
    ) -> dict[str, Any]:
        try:
            remote_object = client.send(
                "DOM.resolveNode", {"backendNodeId": int(backend_node_id)}
            )
            remote_object_id = remote_object["object"]["objectId"]
            response = client.send(
                "Runtime.callFunctionOn",
                {
                    "objectId": remote_object_id,
                    "functionDeclaration": """
                        function() {
                            if (this.nodeType == 3) {
                                var range = document.createRange();
                                range.selectNode(this);
                                var rect = range.getBoundingClientRect().toJSON();
                                range.detach();
                                return rect;
                            } else {
                                return this.getBoundingClientRect().toJSON();
                            }
                        }
                    """,
                    "returnByValue": True,
                },
            )
            return response
        except Exception:
            return {"result": {"subtype": "error"}}

    @staticmethod
    def get_element_in_viewport_ratio(
        elem_left_bound: float,
        elem_top_bound: float,
        width: float,
        height: float,
        config: BrowserConfig,
    ) -> float:
        elem_right_bound = elem_left_bound + width
        elem_lower_bound = elem_top_bound + height

        win_left_bound = 0
        win_right_bound = config["win_width"]
        win_top_bound = 0
        win_lower_bound = config["win_height"]

        # Compute the overlap in x and y axes
        overlap_width = max(
            0,
            min(elem_right_bound, win_right_bound)
            - max(elem_left_bound, win_left_bound),
        )
        overlap_height = max(
            0,
            min(elem_lower_bound, win_lower_bound)
            - max(elem_top_bound, win_top_bound),
        )

        # Compute the overlap area
        ratio = overlap_width * overlap_height / width * height
        return ratio
    
    def element_is_visible(self, page, element_id):
        def _get_element_in_viewport_ratio(
            elem_left_bound: float,
            elem_top_bound: float,
            width: float,
            height: float,
            config: BrowserConfig,
        ) -> float:
            def calculate_overlap(start1, end1, start2, end2):
                # Calculate overlap
                overlap_start = max(start1, start2)
                overlap_end = min(end1, end2)
                
                # Check if there's overlap
                if overlap_start < overlap_end:
                    overlap = overlap_end - overlap_start
                else:
                    overlap = 0
                
                return overlap
            elem_right_bound = elem_left_bound + width
            elem_lower_bound = elem_top_bound + height

            win_left_bound = 0
            win_right_bound = config["win_width"]
            win_top_bound = 0
            win_lower_bound = config["win_height"]

            overlap_width = calculate_overlap(elem_left_bound, elem_right_bound, win_left_bound, win_right_bound)
            overlap_height = calculate_overlap(elem_top_bound, elem_lower_bound, win_top_bound, win_lower_bound)

            try:
                ratio = (overlap_width * overlap_height) / (width * height)
                return ratio
            except:
                return 1 #TODO
        try:
            browser_info = self.fetch_browser_info(page, page.client)
        except Exception:
            page.wait_for_load_state("load", timeout=500)
            browser_info = self.fetch_browser_info(page, page.client)
        
        response = self.get_bounding_client_rect(
            page.client, self.obs_nodes_info[element_id]["backend_id"]
        )

        x = response["result"]["value"]["x"]
        y = response["result"]["value"]["y"]
        width = response["result"]["value"]["width"]
        height = response["result"]["value"]["height"]


        in_viewport_ratio = _get_element_in_viewport_ratio(
            elem_left_bound=float(x),
            elem_top_bound=float(y),
            width=float(width),
            height=float(height),
            config=browser_info["config"],
        )

        if in_viewport_ratio < IN_VIEWPORT_RATIO_THRESHOLD:
            return False
        
        return True

    def fetch_page_html(
        self,
        info: BrowserInfo,
        page: Page,
        client: CDPSession,
        current_viewport_only: bool,
    ) -> DOMTree:
        # adopted from [natbot](https://github.com/nat/natbot)
        tree = info["DOMTree"]
        config = info["config"]
        strings = tree["strings"]
        document = tree["documents"][0]
        nodes = document["nodes"]
        layout = document["layout"]
        
        import time
        stt = time.time()
        # make a dom tree that is easier to navigate
        dom_tree: DOMTree = []
        graph = defaultdict(list)
        print(nodes.keys())
        for node_idx in range(len(nodes["nodeName"])):
            cur_node: DOMNode = {
                "nodeId": "",
                "nodeType": "",
                "nodeName": "",
                "nodeValue": "",
                "attributes": "",
                "backendNodeId": "",
                "parentId": "",
                "childIds": [],
                "cursor": 0,
                "union_bound": None,
            }

            node_type_idx = nodes["nodeType"][node_idx]
            node_type = "generic"
            if node_type_idx >= 0 and node_type_idx < len(strings):
                node_type = strings[node_type_idx]

            node_name = strings[nodes["nodeName"][node_idx]]

            node_value_idx = nodes["nodeValue"][node_idx]
            node_value = ""
            if node_value_idx >= 0 and node_value_idx < len(strings):
                node_value = " ".join(strings[node_value_idx].split())

            node_attributes = [
                strings[i] for i in nodes["attributes"][node_idx]
            ]
            node_attributes_str = ""
            for i in range(0, len(node_attributes), 2):
                a = node_attributes[i]
                b = node_attributes[i + 1]
                # b = " ".join(b.split())
                import re
                b = re.sub(r"{\s*opacity:\s*.*;*\s*}", " ", b)
                b = [b_item for b_item in b.split() if b_item.count('vimium') == 0]
                b = " ".join(b)
                node_attributes_str += f'{a}="{b}" '
            
            node_attributes_str = node_attributes_str.strip()

            cur_node["nodeId"] = str(node_idx)
            cur_node["nodeType"] = node_type
            cur_node["nodeName"] = node_name
            cur_node["nodeValue"] = node_value
            cur_node["attributes"] = node_attributes_str
            cur_node["backendNodeId"] = str(nodes["backendNodeId"][node_idx])
            cur_node["parentId"] = str(nodes["parentIndex"][node_idx])
            
            if cur_node["parentId"] != "-1":
                graph[cur_node["parentId"]].append(str(cur_node["nodeId"]))

            # get the bound
            if cur_node["parentId"] == "-1":
                cur_node["union_bound"] = [0.0, 0.0, 10.0, 10.0]
            else:
                # method 1
                # response = self.get_bounding_client_rect(
                #     client, cur_node["backendNodeId"]
                # )
                
                # if response.get("result", {}).get("subtype", "") == "error":
                #     cur_node["union_bound"] = None
                # else:
                #     x = response["result"]["value"]["x"]
                #     y = response["result"]["value"]["y"]
                #     width = response["result"]["value"]["width"]
                #     height = response["result"]["value"]["height"]
                #     cur_node["union_bound"] = [x, y, width, height]

                # method 2
                bound = [0.0, 0.0, 0.0, 0.0]
                if node_idx in layout["nodeIndex"]:
                    bound = layout["bounds"][layout["nodeIndex"].index(node_idx)]
                    bound[0] -= config["win_left_bound"]
                    bound[1] -= config["win_top_bound"]
                    
                cur_node["union_bound"] = bound

            dom_tree.append(cur_node)
        print('[build]', time.time() - stt)
        
        stt = time.time()
        # add parent children index to the node
        for parent_id, child_ids in graph.items():
            dom_tree[int(parent_id)]["childIds"] = child_ids
        print('[graph]', time.time() - stt)
        
        # with open('output/dom_tree.json', 'w') as f:
        #     f.write(json.dumps(dom_tree, ensure_ascii=False))
            
        stt = time.time()
        # remove the nodes that are not in the current viewport
        if current_viewport_only:

            def remove_node_in_graph(node: DOMNode) -> None:
                # update the node information in the accessibility tree
                node_id = node["nodeId"]
                parent_id = node["parentId"]
                child_ids = node["childIds"]

                # update the children of the parent node
                assert dom_tree[int(parent_id)]["parentId"] != "[REMOVED]"
                # remove the nodeid from parent
                index = dom_tree[int(parent_id)]["childIds"].index(node_id)
                dom_tree[int(parent_id)]["childIds"].pop(index)

                # Insert children_nodeids in the same location
                for child_id in child_ids:
                    dom_tree[int(parent_id)]["childIds"].insert(
                        index, child_id
                    )
                    index += 1

                # update children node's parent
                for child_id in child_ids:
                    dom_tree[int(child_id)]["parentId"] = parent_id
                # mark as removed
                dom_tree[int(node_id)]["parentId"] = "[REMOVED]"

            config = info["config"]
            for cursor, node in enumerate(dom_tree):
                if not node["union_bound"]:
                    remove_node_in_graph(node)
                    continue

                [x, y, width, height] = node["union_bound"]

                # invisible node
                if width == 0.0 or height == 0.0:
                    parent_id = node["parentId"]
                    if node["nodeName"] not in ['OPTION'] or dom_tree[int(parent_id)]["nodeName"] not in ["SELECT"]:
                        remove_node_in_graph(node)
                    continue

                in_viewport_ratio = self.get_element_in_viewport_ratio(
                    elem_left_bound=float(x),
                    elem_top_bound=float(y),
                    width=float(width),
                    height=float(height),
                    config=config,
                )

                if in_viewport_ratio < IN_VIEWPORT_RATIO_THRESHOLD:
                    remove_node_in_graph(node)

            dom_tree = [
                node
                for node in dom_tree
                if node.get("parentId", "-1") != "[REMOVED]"
            ]

        print('[filter]', time.time() - stt)
        return dom_tree

    @staticmethod
    def parse_my_html(dom_tree: DOMTree) -> tuple[str, str, dict[str, Any], Any]:
        """Serialise the DOM tree to HTML and run it through ``HtmlParser``.

        Nodes named ``#text`` are renamed to ``text`` in ``dom_tree`` itself.
        Only nodes with attributes or a value are emitted; every emitted
        element gets a ``backend-id="bid-<backendNodeId>"`` attribute. Debug
        output is printed to stdout.

        Args:
            dom_tree: Flat list of DOM nodes; the node at index 0 is the root.

        Returns:
            A tuple ``(html, page_html, obs_nodes_info, hp)``: the raw
            serialised HTML, the ``"html"`` entry of the parser result, the
            per-node info keyed by the node's index in ``dom_tree`` (as a
            string), and the ``HtmlParser`` instance.
        """
        # ``import lxml`` does not load the ``lxml.html`` submodule. Without
        # this import the attribute lookup below can fail, and the broad
        # ``except`` would then silently discard every node.
        import lxml.html

        obs_nodes_info = {}
        nodeid_to_cursor = {
            node["nodeId"]: idx for idx, node in enumerate(dom_tree)
        }

        def dfs(node_cursor: int, depth: int) -> tuple[str, list[str]]:
            tree_str, labeled_elems = '', []
            node = dom_tree[node_cursor]
            valid_node = True
            try:
                if node['nodeName'] == '#text':
                    node['nodeName'] = 'text'

                node_str = f"<{node['nodeName']}"
                if node["attributes"]:
                    node_str += f" {node['attributes']}"
                node_str += f" backend-id=\"bid-{node['backendNodeId']}\"> {node['nodeValue']}"

                valid_node = bool(node["attributes"] or node["nodeValue"])

                if valid_node:
                    # Parse the opening tag only to read its data-testid.
                    node_html = lxml.html.fromstring(node_str)
                    label = node_html.attrib.get('data-testid', '')
                    if len(label) > 0:
                        labeled_elems.append(node["backendNodeId"])
                    obs_nodes_info[str(node_cursor)] = {
                        "backend_id": node["backendNodeId"],
                        "union_bound": node["union_bound"],
                        "text": node['nodeValue'],
                        "label": label,
                    }
                    tree_str += f"{node_str}"

            except Exception:
                # Best effort: a node that cannot be rendered or parsed is
                # skipped, but its children are still visited.
                valid_node = False

            child_depth = depth + 1 if valid_node else depth
            for child_id in node["childIds"]:
                child_str, elems = dfs(nodeid_to_cursor[child_id], child_depth)
                tree_str += child_str
                labeled_elems.extend(elems)

            if valid_node:
                tree_str += f"</{node['nodeName']}>"

            return tree_str, labeled_elems

        html, labeled_elems = dfs(0, 0)

        print(labeled_elems)

        args = {
            'use_position': False,
            'id_attr': 'backend-id',
            'label_generator': 'order',
            'label_attr': 'data-testid',
            'attr_list': basic_attrs,
            'prompt': 'refine',
        }

        hp = HtmlParser(html, args)
        packet = hp.parse_tree()
        page_html = packet['html']

        print(print_html_object(page_html))

        it, pt = packet.get('init_time', 0), packet.get('parse_time', 0)
        print(f'[Time] {it:.3f} {pt:.3f}')

        return html, page_html, obs_nodes_info, hp
    
    @staticmethod
    def parse_html(dom_tree: DOMTree) -> tuple[str, dict[str, Any]]:
        """Parse the html tree into a string text"""

        obs_nodes_info = {}
        nodeid_to_cursor = {
            node["nodeId"]: idx for idx, node in enumerate(dom_tree)
        }

        def dfs(node_cursor: int, depth: int) -> str:
            tree_str = ""
            node = dom_tree[node_cursor]
            indent = "\t" * depth
            valid_node = True
            try:
                node_str = f"[{node_cursor}] <{node['nodeName']}"
                if node["attributes"]:
                    node_str += f" {node['attributes']}"
                node_str += f"> {node['nodeValue']}"
                valid_node = bool(node["attributes"] or node["nodeValue"])

                if valid_node:
                    obs_nodes_info[str(node_cursor)] = {
                        "backend_id": node["backendNodeId"],
                        "union_bound": node["union_bound"],
                        "text": node_str,
                    }
                    tree_str += f"{indent}{node_str}\n"

            except Exception:
                valid_node = False

            for child_ids in node["childIds"]:
                child_cursor = nodeid_to_cursor[child_ids]
                child_depth = depth + 1 if valid_node else depth
                child_str = dfs(child_cursor, child_depth)
                tree_str += child_str

            return tree_str

        html = dfs(0, 0)
        return html, obs_nodes_info

    def fetch_page_accessibility_tree(
        self,
        info: BrowserInfo,
        client: CDPSession,
        current_viewport_only: bool,
    ) -> AccessibilityTree:
        """Fetch the full accessibility tree and attach a bound to every node.

        Duplicate node ids are dropped (first occurrence wins). Each node gets
        a ``union_bound`` of ``[x, y, width, height]`` from the browser, a
        fixed placeholder for the ``RootWebArea``, or ``None`` when the node
        has no backend DOM node or its rect cannot be fetched.

        Args:
            info: Browser info; only ``info["config"]`` is read, and only when
                filtering by viewport.
            client: CDP session used for all browser queries.
            current_viewport_only: When True, nodes without a bound, with zero
                width or height, or with an in-viewport ratio below the
                threshold are spliced out and their children re-attached to
                the removed node's parent.

        Returns:
            The list of remaining accessibility-tree nodes.
        """
        accessibility_tree: AccessibilityTree = client.send(
            "Accessibility.getFullAXTree", {}
        )["nodes"]

        # a few nodes are repeated in the accessibility tree
        seen_ids = set()
        _accessibility_tree = []
        for node in accessibility_tree:
            if node["nodeId"] not in seen_ids:
                _accessibility_tree.append(node)
                seen_ids.add(node["nodeId"])
        accessibility_tree = _accessibility_tree
        # Maps a node id to its index in the de-duplicated list.
        nodeid_to_cursor = {}
        for cursor, node in enumerate(accessibility_tree):
            nodeid_to_cursor[node["nodeId"]] = cursor
            # usually because the node is not visible etc
            if "backendDOMNodeId" not in node:
                node["union_bound"] = None
                continue
            backend_node_id = str(node["backendDOMNodeId"])
            if node["role"]["value"] == "RootWebArea":
                # always inside the viewport
                node["union_bound"] = [0.0, 0.0, 10.0, 10.0]
            else:
                response = self.get_bounding_client_rect(
                    client, backend_node_id
                )
                # get_bounding_client_rect reports failures via this subtype.
                if response.get("result", {}).get("subtype", "") == "error":
                    node["union_bound"] = None
                else:
                    x = response["result"]["value"]["x"]
                    y = response["result"]["value"]["y"]
                    width = response["result"]["value"]["width"]
                    height = response["result"]["value"]["height"]
                    node["union_bound"] = [x, y, width, height]

        # filter nodes that are not in the current viewport
        if current_viewport_only:

            def remove_node_in_graph(node: AccessibilityTreeNode) -> None:
                # Splice ``node`` out of the tree in place: its children take
                # its position in the parent's child list and are re-parented.
                # NOTE(review): reads node["parentId"] unconditionally, so a
                # node without that key would raise KeyError here -- confirm
                # such nodes never reach this helper.
                # update the node information in the accessibility tree
                nodeid = node["nodeId"]
                node_cursor = nodeid_to_cursor[nodeid]
                parent_nodeid = node["parentId"]
                children_nodeids = node["childIds"]
                parent_cursor = nodeid_to_cursor[parent_nodeid]
                # update the children of the parent node
                assert (
                    accessibility_tree[parent_cursor].get("parentId", "Root")
                    is not None
                )
                # remove the nodeid from parent's childIds
                index = accessibility_tree[parent_cursor]["childIds"].index(
                    nodeid
                )
                accessibility_tree[parent_cursor]["childIds"].pop(index)
                # Insert children_nodeids in the same location
                for child_nodeid in children_nodeids:
                    accessibility_tree[parent_cursor]["childIds"].insert(
                        index, child_nodeid
                    )
                    index += 1
                # update children node's parent
                for child_nodeid in children_nodeids:
                    child_cursor = nodeid_to_cursor[child_nodeid]
                    accessibility_tree[child_cursor][
                        "parentId"
                    ] = parent_nodeid
                # mark as removed
                accessibility_tree[node_cursor]["parentId"] = "[REMOVED]"

            config = info["config"]
            # The tree is edited while it is being walked; removed nodes are
            # only marked here and filtered out after the loop.
            for node in accessibility_tree:
                if not node["union_bound"]:
                    remove_node_in_graph(node)
                    continue

                [x, y, width, height] = node["union_bound"]

                # invisible node
                if width == 0 or height == 0:
                    remove_node_in_graph(node)
                    continue

                in_viewport_ratio = self.get_element_in_viewport_ratio(
                    elem_left_bound=float(x),
                    elem_top_bound=float(y),
                    width=float(width),
                    height=float(height),
                    config=config,
                )

                if in_viewport_ratio < IN_VIEWPORT_RATIO_THRESHOLD:
                    remove_node_in_graph(node)

            # Drop everything that was marked as removed above.
            accessibility_tree = [
                node
                for node in accessibility_tree
                if node.get("parentId", "Root") != "[REMOVED]"
            ]

        return accessibility_tree

    @staticmethod
    def parse_accessibility_tree(
        accessibility_tree: AccessibilityTree,
    ) -> tuple[str, dict[str, Any], TreeNode]:
        """Render the accessibility tree as text and as a ``TreeNode`` tree.

        Every kept node becomes a line ``[<id>] <role> <repr(name)> <props>``
        indented with one tab per level. Nodes that are skipped (blank name
        with an uninformative role, or any error while rendering) do not add
        a level: their children are rendered at the skipped node's depth and
        attached to the nearest kept ancestor.

        Returns:
            A tuple ``(tree_str, obs_nodes_info, root)``: the rendered text,
            a dict from node id to backend id / bound / rendered line, and the
            ``TreeNode`` built for the first node (``None`` if it was skipped).
        """
        cursor_by_node_id = {
            entry["nodeId"]: cursor
            for cursor, entry in enumerate(accessibility_tree)
        }

        obs_nodes_info = {}

        # Roles dropped when the node has a blank name and no kept properties.
        droppable_when_bare = (
            "generic",
            "img",
            "list",
            "strong",
            "paragraph",
            "banner",
            "navigation",
            "Section",
            "LabelText",
            "Legend",
            "listitem",
        )

        def dfs(
            idx: int, obs_node_id: str, depth: int, active_node_dict: dict
        ) -> "tuple[str, TreeNode | None]":
            # ``active_node_dict`` maps a depth to the most recent kept node
            # at that depth, so children can find the ancestor to attach to.
            node = accessibility_tree[idx]
            tree_str = ""
            valid_node = True
            try:
                role = node["role"]["value"]
                name = node["name"]["value"]
                node_str = f"[{obs_node_id}] {role} {repr(name)}"
                properties = []
                structured_properties = {}
                for prop in node.get("properties", []):
                    try:
                        prop_name = prop["name"]
                        if prop_name in IGNORED_ACTREE_PROPERTIES:
                            continue
                        prop_value = prop["value"]["value"]
                    except KeyError:
                        # Properties without a name or value are skipped.
                        continue
                    properties.append(f"{prop_name}: {prop_value}")
                    structured_properties[prop_name] = prop_value

                if properties:
                    node_str += " " + " ".join(properties)

                # check valid
                if not node_str.strip():
                    valid_node = False

                # empty generic node
                if not name.strip():
                    if properties:
                        if role in ("listitem",):
                            valid_node = False
                    elif role in droppable_when_bare:
                        valid_node = False

                if valid_node:
                    tree_str += "\t" * depth + node_str
                    obs_nodes_info[obs_node_id] = {
                        "backend_id": node["backendDOMNodeId"],
                        "union_bound": node["union_bound"],
                        "text": node_str,
                    }

            except Exception:
                valid_node = False

            if valid_node:
                structured_node = TreeNode(
                    node_id=int(obs_node_id),
                    role=role,
                    name=name,
                    depth=depth,
                    properties=structured_properties,
                )
                active_node_dict[depth] = structured_node
            else:
                structured_node = None
                active_node_dict[depth] = active_node_dict.get(depth, None)

            # Skipped nodes do not add a level, which saves some tokens.
            child_depth = depth + 1 if valid_node else depth
            for child_node_id in node["childIds"]:
                if child_node_id not in cursor_by_node_id:
                    continue
                child_str, child_node = dfs(
                    cursor_by_node_id[child_node_id],
                    child_node_id,
                    child_depth,
                    active_node_dict=active_node_dict,
                )
                if child_str.strip():
                    if tree_str.strip():
                        tree_str += "\n"
                    tree_str += child_str
                if child_depth > 0 and child_node is not None:
                    active_node_dict[child_depth - 1].add_child(child_node)

            return tree_str, structured_node

        tree_str, structured_node = dfs(
            0, accessibility_tree[0]["nodeId"], 0, active_node_dict={}
        )
        return tree_str, obs_nodes_info, structured_node

    @staticmethod
    def clean_accesibility_tree(tree_str: str) -> str:
        """Drop StaticText lines that only repeat text shown just above them.

        A line is a StaticText candidate when it contains "statictext"
        (case-insensitive). A candidate is kept only if it parses as
        ``[<id>] StaticText <quoted name>``, the name is non-empty, and the
        name does not occur in any of the three most recently kept lines.
        Candidates that do not parse are dropped; every other line is kept
        unchanged.

        Node names are serialized with ``repr``, which delimits the text with
        double quotes when it contains an apostrophe, so both quoting styles
        are accepted here. Matching only single-quoted names would silently
        discard such lines.

        Args:
            tree_str: Newline-separated accessibility-tree text.

        Returns:
            The filtered lines joined with newlines.
        """
        # A repr-style string literal: either quote style, backslash escapes
        # allowed inside.
        single_quoted = r"'(?:[^'\\]|\\.)*'"
        double_quoted = r'"(?:[^"\\]|\\.)*"'
        static_text_re = re.compile(
            r"\[\d+\] StaticText (" + single_quoted + "|" + double_quoted + ")"
        )

        kept_lines: list[str] = []
        for line in tree_str.split("\n"):
            if "statictext" not in line.lower():
                kept_lines.append(line)
                continue

            match = static_text_re.search(line)
            if match is None:
                continue

            # Strip the surrounding quotes added by repr.
            static_text = match.group(1)[1:-1]
            if static_text and all(
                static_text not in recent for recent in kept_lines[-3:]
            ):
                kept_lines.append(line)

        return "\n".join(kept_lines)

    def process(
        self, page: Page, client: CDPSession, context: str
    ) -> "tuple[str, TreeNode | None]":
        """Build the text observation for ``page`` and refresh cached state.

        Depending on ``self.observation_type`` the page is rendered either
        from its parsed HTML ("html") or from its cleaned accessibility tree
        ("accessibility_tree").

        Side effects: updates ``self.obs_nodes_info``, ``self.meta_data`` and
        ``self.browser_config``. In accessibility-tree mode it also sets
        ``self.node_root`` and, when the page carries a ``dialog_message``,
        stores it on the root node's properties and resets it on the page.

        Args:
            page: Page to observe.
            client: CDP session handed to the fetch helpers.
            context: Not used by this method.

        Returns:
            ``(content, node_root)``. ``node_root`` is the structured tree
            root in accessibility-tree mode and ``None`` in HTML mode, where
            no structured tree is built. ``content`` is the cleaned tree
            string in accessibility-tree mode and whatever
            ``self.parse_my_html`` yields in HTML mode (presumably a string).

        Raises:
            ValueError: If ``self.observation_type`` is not supported.
        """
        open_tabs = page.context.pages
        try:
            current_tab_idx = open_tabs.index(page)
            tab_lines = []
            # Query each title once; every call is a round-trip to the browser.
            for idx, tab in enumerate(open_tabs):
                tab_line = f"{idx+1}. {tab.title()}"
                if idx == current_tab_idx:
                    tab_line += " <-- current tab"
                tab_lines.append(tab_line)
            tab_title_str = "\n".join(tab_lines)
        except Exception:
            # Best effort: fall back to placeholder titles if any tab fails.
            tab_title_str = "\n".join(
                [f"{idx+1}. Default" for idx in range(len(open_tabs))]
            )

        try:
            browser_info = self.fetch_browser_info(page, client)
        except Exception:
            # Give the page a short moment to finish loading, then retry once.
            page.wait_for_load_state("load", timeout=500)
            browser_info = self.fetch_browser_info(page, client)

        # Only the accessibility-tree branch produces a structured root; bind
        # the name up front so the HTML branch can still return a tuple.
        node_root = None

        if self.observation_type == "html":
            import time

            started = time.time()
            dom_tree = self.fetch_page_html(
                browser_info,
                page,
                client,
                current_viewport_only=self.current_viewport_only,
            )
            print('[fetch]', time.time() - started)

            started = time.time()
            raw_html, content, obs_nodes_info, hp = self.parse_my_html(dom_tree)
            print('[parse]', time.time() - started)

            # Page length and scroll offset, both in units of window heights.
            window_height = page.evaluate("window.innerHeight")
            page_height = (
                page.evaluate('document.documentElement.scrollHeight')
                / window_height
            )
            position = page.evaluate("window.scrollY") / window_height

            self.obs_nodes_info = obs_nodes_info
            self.meta_data["obs_nodes_info"] = obs_nodes_info
            self.meta_data["position_info"] = {
                "page_height": page_height,
                "position": position,
            }
            self.meta_data["dom_info"] = {
                "raw_html": raw_html,
                "dom_tree": dom_tree,
            }
            self.meta_data["html_parser"] = hp
            self.meta_data["tab_title"] = tab_title_str

        elif self.observation_type == "accessibility_tree":
            accessibility_tree = self.fetch_page_accessibility_tree(
                browser_info,
                client,
                current_viewport_only=self.current_viewport_only,
            )
            content, obs_nodes_info, node_root = self.parse_accessibility_tree(
                accessibility_tree
            )
            content = self.clean_accesibility_tree(content)
            self.obs_nodes_info = obs_nodes_info

            page_dialog_message = getattr(page, "dialog_message", "")
            # Attach and consume the message only when there is a root to
            # carry it; otherwise leave it on the page for a later call.
            if page_dialog_message and node_root is not None:
                node_root.properties["page_dialog_message"] = (
                    page_dialog_message + " Retry."
                )
                page.dialog_message = None

            self.node_root = node_root
            self.meta_data["obs_nodes_info"] = obs_nodes_info

        else:
            raise ValueError(
                f"Invalid observatrion type: {self.observation_type}"
            )

        self.browser_config = browser_info["config"]
        return (content, node_root)

    def get_node_info_by_element_id(self, AXTreeId):
        """Return the result of searching the stored tree root for ``AXTreeId``.

        Delegates to ``search_node_by_id`` on ``self.node_root``.
        """
        tree_root = self.node_root
        found = tree_root.search_node_by_id(AXTreeId)
        return found

    def get_element_center(self, element_id: str, page) -> tuple[float, float]:
        """Return an element's center as fractions of the viewport size.

        The element's backend id is read from ``self.obs_nodes_info`` and
        passed, as a string, to ``self.get_bounding_client_rect`` together
        with ``page.client``. The center of the returned rectangle is then
        divided by the viewport width and height respectively.

        Args:
            element_id: Key into ``self.obs_nodes_info``.
            page: Object exposing a ``client`` attribute.

        Returns:
            ``(center_x / viewport_width, center_y / viewport_height)``.
        """
        backend_node_id = str(self.obs_nodes_info[element_id]["backend_id"])
        rect = self.get_bounding_client_rect(page.client, backend_node_id)[
            "result"
        ]["value"]

        origin = (rect["x"], rect["y"])
        extent = (rect["width"], rect["height"])
        # Midpoint along each axis: start plus half of the span.
        mid_x, mid_y = (start + span / 2 for start, span in zip(origin, extent))

        viewport = self.viewport_size
        return (mid_x / viewport["width"], mid_y / viewport["height"])


class ImageObservationProcessor(ObservationProcessor):
    """Observation processor that returns a screenshot of the page as an array."""

    def __init__(self, observation_type: str, current_viewport_only: bool):
        """Store the configuration and start with empty metadata.

        Args:
            observation_type: Stored as-is on the instance.
            current_viewport_only: When true, only the visible viewport is
                captured; otherwise a full-page screenshot is requested.
        """
        self.observation_type = observation_type
        self.current_viewport_only = current_viewport_only
        self.observation_tag = "image"
        self.meta_data = create_empty_metadata()

    def _take_screenshot(self, page: Page) -> npt.NDArray[np.uint8]:
        """Capture the page and cap the number of rows of the result."""
        screenshot = png_bytes_to_numpy(
            page.screenshot(full_page=(not self.current_viewport_only))
        )
        # Keep at most twice as many rows as the size of the second axis
        # (presumably the image width), so very tall pages stay bounded.
        return screenshot[: 2 * screenshot.shape[1], :, :]

    def process(self, page: Page, client: CDPSession, context: str) -> npt.NDArray[np.uint8]:
        """Return a screenshot of ``page``, retrying once after a load event.

        Both the first attempt and the retry go through the same capture
        helper, so the returned array is cropped identically either way.

        Args:
            page: Page to capture.
            client: Not used by this method.
            context: Not used by this method.

        Returns:
            The screenshot as decoded by ``png_bytes_to_numpy``, cropped along
            its first axis.
        """
        try:
            return self._take_screenshot(page)
        except Exception:
            # The first capture failed; wait for the page's load event and
            # try once more. Errors from the retry propagate to the caller.
            page.wait_for_event("load")
            return self._take_screenshot(page)


class ObservationHandler:
    """Main entry point to access all observation processors.

    Owns one text processor and one image processor and exposes their
    combined observation, metadata and observation space.
    """

    def __init__(
        self,
        main_observation_type: str,
        text_observation_type: str,
        image_observation_type: str,
        current_viewport_only: bool,
        viewport_size: ViewportSize,
    ) -> None:
        """Create the text and image processors.

        Args:
            main_observation_type: "text" or "image"; selects the processor
                returned by ``action_processor``.
            text_observation_type: Forwarded to the text processor.
            image_observation_type: Forwarded to the image processor.
            current_viewport_only: Forwarded to both processors.
            viewport_size: Mapping with "width" and "height"; forwarded to the
                text processor and used to shape the image space.
        """
        self.main_observation_type = main_observation_type
        self.text_processor = TextObervationProcessor(
            text_observation_type, current_viewport_only, viewport_size
        )
        self.image_processor = ImageObservationProcessor(
            image_observation_type, current_viewport_only
        )
        self.viewport_size = viewport_size

    def get_observation_space(self) -> spaces.Dict:
        """Return the dict space with a "text" and an "image" entry."""
        text_space = spaces.Text(
            min_length=0,
            max_length=UTTERANCE_MAX_LENGTH,
            charset=ASCII_CHARSET + FREQ_UNICODE_CHARSET,
        )

        # Each position stores the RGB values. Note the swapped axes (height first).
        frame_shape = (
            self.viewport_size["height"],
            self.viewport_size["width"],
            3,
        )
        # Build both bounds directly as uint8 so the space does not have to
        # downcast a float array for the upper bound.
        image_space = spaces.Box(
            low=np.zeros(frame_shape, dtype=np.uint8),
            high=np.full(frame_shape, 255, dtype=np.uint8),
            dtype=np.uint8,
        )

        return spaces.Dict({"text": text_space, "image": image_space})

    def get_observation(
        self, page: Page, client: CDPSession, context: str = '',
    ) -> dict[str, Observation]:
        """Run both processors on ``page`` and return their outputs.

        The text processor runs first, then the image processor.

        Returns:
            ``{"text": <text processor output>, "image": <image processor output>}``.
        """
        text_obs = self.text_processor.process(page, client, context)
        image_obs = self.image_processor.process(page, client, context)
        return {"text": text_obs, "image": image_obs}

    def get_observation_metadata(self) -> dict[str, ObservationMetadata]:
        """Return the current ``meta_data`` of both processors, keyed by kind."""
        return {
            "text": self.text_processor.meta_data,
            "image": self.image_processor.meta_data,
        }

    @property
    def action_processor(self) -> ObservationProcessor:
        """Return the main processor that is associated with the action space.

        Raises:
            ValueError: If ``main_observation_type`` is neither "text" nor
                "image".
        """
        if self.main_observation_type == "text":
            return self.text_processor
        if self.main_observation_type == "image":
            return self.image_processor
        raise ValueError("Invalid main observation type")
