import json
import random
import uuid
from typing import Any, TypedDict

from bs4 import BeautifulSoup, NavigableString

# Seed the global RNG at import time so that the random role choices made by
# the callable entries of ROLE_MAPPING and the random.shuffle() call in the
# __main__ block are reproducible from run to run.
random.seed(9966)

# Maps a raw role key -- an element's explicit ``role`` attribute when it has
# one, otherwise its tag name (see create_accessibility_tree) -- to the role
# label stored in the accessibility tree.  A value is either the label itself
# or a zero-argument callable returning the label; the callables draw from the
# module-level ``random`` state.  Keys that are not listed here are kept
# verbatim as the role and recorded in ``missed``.
ROLE_MAPPING = {
    "html": "RootWebArea",
    "body": "generic",
    "button": "button",
    "a": "link",
    "h1": "heading",
    "h2": "heading",
    "h3": "heading",
    "h4": "heading",
    "h5": "heading",
    "h6": "heading",
    # Picked at random per element, so the result depends on the RNG state.
    "title": lambda: random.choice(["heading", "document"]),
    "text": "StaticText",
    "p": "StaticText",
    "label": "StaticText",
    "ul": "list",
    "ol": "list",
    "li": "listitem",
    "table": "table",
    "span": "StaticText",
    "tr": "row",
    "td": "cell",
    "th": "columnheader",
    # NOTE(review): with a single option this always yields "combobox", but
    # the call presumably still advances the RNG -- replacing it with a plain
    # string could change later random picks; confirm before simplifying.
    "input": lambda: random.choice(["combobox"]),
    "img": "image",
    "image": "image",
    "nav": "navigation",
    "footer": "contentinfo",
    "header": "banner",
    "article": "article",
    "section": "region",
    "select": "combobox",
    "option": "menuitem",
    "div": "generic",
    "aside": "complementary",
    "main": "main",
    "dialog": "dialog",
    "status": "status",
    "bottom-nav-item": "navigation",
    "tooltip": "tooltip",
    "ngc-login": "button",  # Assuming it's a login form
    "text-list": "list",
    "listbox": "listbox",
    "ngc-app-navigation-links": "navigation",
    "header-app": "banner",
    "switch": "switch",
    "columnheader": "columnheader",
    "tablist": "tablist",
    "summary": "summary",
    "ngc-information": "document",  # If it's informational
    # NOTE(review): "img"/"image"/"svg" map to "image", but the next two
    # entries map to "img" -- confirm whether the difference is intended.
    "ngc-logo": "img",  # Assuming it's an image
    "picture": "img",
    "app-home": "main",  # Assuming it's the main content
    "ngc-global-nav": "navigation",
    "tabpanel": "tabpanel",
    "contentinfo": "contentinfo",
    "ngc-search": "search",  # Assuming it's a search function
    "ngc-notification": "alert",
    "slot": "generic",
    "search": "search",
    "row": "row",
    "small": "generic",  # Text styling element
    "b": "generic",  # Text styling element
    "smt-gcovwidget": "generic",  # Unknown, default to generic
    "em": "generic",  # Text styling element
    "heading": "heading",
    "tab": "tab",
    "i": "generic",  # Text styling element
    "rowgroup": "rowgroup",
    "feed": "feed",
    "menubar": "menubar",
    "dd": "definition",  # Within a definition list
    "checkbox": "checkbox",
    "dl": "list",
    "document": "document",
    "complementary": "complementary",
    "video-player-overlay": "generic",  # Assuming it's a UI element
    "search-algolia-results": "list",
    "list": "list",
    "legend": "legend",
    "alertdialog": "alertdialog",
    "lable": "StaticText",  # Assuming it's a typo for "label"
    "tbody": "rowgroup",
    "svg": "image",  # If used for images
    "search-algolia": "search",
    "details": "group",
    "pre": "generic",  # Text formatting
    "region": "region",
    "rt-header": "banner",
    "rt-header-nav-item": "navigation",
    "tile-dynamic": "generic",
    "ngc-footer-column": "contentinfo",
    "figure": "figure",
    "mark": "generic",  # Text styling element
    "banner": "banner",
    "thead": "rowgroup",
    "ngc-social-icons": "navigation",
    "grid": "grid",
    "gridcell": "gridcell",
    "listitem": "listitem",
    "iframe": "document",  # If it contains a document
    "ngc-search-options": "group",  # Assuming it groups search options
    "menu": "menu",
    "app-root": "generic",
    "rt-header-nav": "navigation",
    "source": "generic",  # Used within media elements like <video> and <audio>
    "search-algolia-results-category": "list",
    "sup": "generic",  # Text styling element
    "video": "video",
    "ad-unit": "generic",  # Assuming it's for advertisements
    "search-algolia-controls": "group",  # Assuming it groups controls
    "menuitem": "menuitem",
    "object": "application",  # For embedded content
    "ngc-book": "document",  # Assuming it's a book or document-like content
    "cell": "cell",
    "radiogroup": "radiogroup",
    # NOTE(review): a form is labelled "textbox" here -- confirm this is intended.
    "form": "textbox",
    "ngc-language-selector": "combobox",  # Assuming it's for language selection
    "radioitemcheckbox": "checkbox",  # Assuming it's a custom checkbox
    "tiles-carousel-responsive-item": "generic",
    "editorial-spotlight": "generic",
    "progressbar": "progressbar",
    "navigation": "navigation",
    "amp-fit-text": "generic",
    "canvas": "generic",
    "slider": "slider",
    "date-selection-view": "generic",
    "group": "group",
    "optgroup": "group",  # Grouping within <select>
    "toolbar": "toolbar",
    "fieldset": "group",  # Used to group related elements in a form
    "combobox": "combobox",
    "alert": "alert",
}

# Raw role keys (tag names or ``role`` attribute values) that were looked up
# in ROLE_MAPPING by create_accessibility_tree but had no entry.  Printed at
# the end of the __main__ block.
missed: set[str] = set()


class Node(TypedDict):
    """One node of the simplified accessibility tree.

    Nodes are plain dicts produced by ``create_accessibility_tree`` and
    rewritten in place by ``prune_tree``.
    """

    # Role label: the ROLE_MAPPING value for the element, or the raw tag name /
    # ``role`` attribute when the mapping has no entry for it.
    role: str
    # Accessible name: the first of a fixed list of attributes present on the
    # element, otherwise the element's own direct text, stripped.
    name: str
    # Either {"focused": "true"} or an empty dict.
    states: dict[str, str]
    # The element's ``backend_node_id`` attribute, a random uuid4 hex string
    # when that attribute is absent, or "" for the synthetic StaticText child
    # created from ``input_value``.
    id: str
    # Child nodes in document order; each entry is itself a Node dict.
    children: list[Any]

def create_accessibility_tree(element) -> "Node | None":
    """Convert a parsed HTML element into a simplified accessibility tree.

    The element and its descendants are visited recursively.  Every tag
    becomes one node dict; children without a tag name produce no node of
    their own, and ``NavigableString`` children only contribute to the
    ``name`` of their parent.

    Args:
        element: A BeautifulSoup page element (normally a ``Tag``).

    Returns:
        The ``Node`` dict for ``element``, or ``None`` when
        ``element.name`` is ``None`` (i.e. it is not a tag).  The return
        annotation is quoted so the ``|`` union is never evaluated at
        definition time.

    Side effects:
        Role keys missing from ``ROLE_MAPPING`` are added to the module-level
        ``missed`` set, and callable mapping entries may draw from the global
        ``random`` state.
    """
    if element.name is None:
        return None

    # An explicit ``role`` attribute takes precedence over the tag name.
    raw_role = element.attrs.get("role", element.name)
    if raw_role in ROLE_MAPPING:
        role = ROLE_MAPPING[raw_role]
        if callable(role):
            role = role()
    else:
        # Unknown keys are kept verbatim and remembered for later reporting.
        missed.add(raw_role)
        role = raw_role

    # The name comes from the first attribute present, in priority order.
    # NOTE(review): the underscore-style attribute names (aria_label,
    # input_value, ...) presumably come from upstream HTML cleaning -- confirm.
    for possible_attr in ("aria_label", "label", "value", "title", "alt"):
        if possible_attr in element.attrs:
            name = element.attrs[possible_attr]
            break
    else:
        # No naming attribute: fall back to the element's own direct text
        # (text of nested tags is not included).
        name = "".join(
            child for child in element if isinstance(child, NavigableString)
        ).strip()

    if "option_selected" in element.attrs or "input_checked" in element.attrs:
        states = {"focused": "true"}
    else:
        states = {}

    tree = {
        "role": role,
        "name": name,
        "states": states,
        # Fall back to a random id when the element carries none.
        "id": element.get("backend_node_id", uuid.uuid4().hex),
        "children": [],
    }

    # Expose the current input value as a synthetic StaticText child.
    if "input_value" in element.attrs:
        tree["children"].append(
            {
                "role": "StaticText",
                "name": element.attrs["input_value"],
                "states": {},
                "id": "",
                "children": [],
            }
        )

    for child in element.children:
        child_tree = create_accessibility_tree(child)
        if child_tree:
            tree["children"].append(child_tree)

    return tree


def prune_tree(
    tree: "Node", target_id: str, parent: "Node | None" = None, replaced_role=""
) -> "tuple[Node | None, str]":
    """Simplify an accessibility tree in place while tracking a target node.

    Three rewrites are applied recursively:

    * a ``generic`` node with exactly one child is replaced by that child;
    * an unnamed node whose only child is a ``StaticText`` node absorbs the
      child's name and adopts the child's children;
    * a node left with no children and no name is dropped, unless it is the
      target.

    Whenever the node identified by ``target_id`` is merged away, the id of
    the node that takes its place is returned instead, so the caller can keep
    pointing at the same logical element.

    Args:
        tree: Root of the (sub)tree to prune.  It is mutated in place.
        target_id: Id of the node that must survive pruning.
        parent: Passed through unchanged to recursive calls; not otherwise
            used.
        replaced_role: When non-empty, overwrites ``tree["role"]`` before
            pruning.

    Returns:
        ``(pruned_tree, target_id)`` where ``pruned_tree`` is ``None`` if the
        whole subtree was dropped and ``target_id`` is the possibly updated
        target id.  Annotations are quoted so the ``|`` unions are never
        evaluated at definition time.
    """
    if replaced_role:
        tree["role"] = replaced_role

    children = tree["children"]
    if len(children) == 1:
        only_child = children[0]

        if tree["role"] == "generic":
            # Collapse the wrapper; if it was the target, the child inherits
            # that status.
            if tree["id"] == target_id:
                target_id = only_child["id"]
            return prune_tree(only_child, target_id, parent, "")

        if not tree["name"] and only_child["role"] == "StaticText":
            # Fold the text child into this node; if the text child was the
            # target, this node becomes the target.
            tree["name"] = only_child["name"]
            if target_id == only_child["id"]:
                target_id = tree["id"]
            tree["children"] = only_child["children"]
            # Re-run on the same node: it may now match another rule.
            return prune_tree(tree, target_id, parent, "")

    kept = []
    for child in children:
        # target_id is threaded through because a child may re-point it.
        pruned, target_id = prune_tree(child, target_id, parent, "")
        if pruned is not None:
            kept.append(pruned)
    tree["children"] = kept

    # An empty, unnamed node carries no information; keep it only when it is
    # the target itself.
    if not kept and not tree["name"] and tree["id"] != target_id:
        return None, target_id

    return tree, target_id


def find_select_id(
    tree: "Node", target_id: str, content: str, correct_subtree: bool
) -> "str | None":
    """Find the first descendant of the target node whose name contains text.

    The tree is searched depth-first in child order.  Only nodes strictly
    below the node whose id equals ``target_id`` are candidates; the target
    node itself is not matched unless ``correct_subtree`` is already true.

    Args:
        tree: Root of the (sub)tree to search.
        target_id: Id of the node whose descendants should be searched.
        content: Substring to look for in a node's ``name``.
        correct_subtree: Whether ``tree`` is already inside the target's
            subtree.  Top-level callers pass ``False``.

    Returns:
        The id of the first matching node, or ``None`` when nothing matches.
        A matching node whose id is the empty string is skipped, because
        falsy results are treated as "not found".
    """
    if correct_subtree and content in tree["name"]:
        return tree["id"]

    # Children are inside the target subtree if this node is, or if this node
    # is the target itself.  The flag is the same for every child.
    children_in_subtree = tree["id"] == target_id or correct_subtree

    for child in tree["children"]:
        result = find_select_id(child, target_id, content, children_in_subtree)
        if result:
            return result
    return None


def print_tree(tree: "Node", level=0) -> str:
    """Render an accessibility tree as tab-indented text.

    Each node produces one line of the form ``[id] role "name"`` terminated
    by a newline and prefixed with one tab per depth level; children follow
    their parent in order.  Node states are not included in the output.

    Args:
        tree: Root of the (sub)tree to render.
        level: Indentation depth of ``tree`` itself.

    Returns:
        The rendered text for ``tree`` and all of its descendants.
    """
    lines = []

    def _render(node, depth) -> None:
        # Collect lines in a list and join once, instead of growing a string
        # with ``+=`` at every recursion level.
        indent = "\t" * depth
        lines.append(
            f'{indent}[{node["id"]}] {node["role"]} "{node["name"]}"\n'
        )
        for child in node["children"]:
            _render(child, depth + 1)

    _render(tree, level)
    return "".join(lines)


def main() -> None:
    """Attach a pruned accessibility tree to every action of the dumps.

    Reads ``train_0.json`` .. ``train_9.json`` from the input folder.  Each
    file holds a list of trajectories with parallel ``actions`` and
    ``action_reprs`` lists.  For every action with at least one positive
    candidate, the ``cleaned_html`` is parsed, converted to an accessibility
    tree and pruned around the first candidate's ``backend_node_id``.
    ``SELECT`` actions are rewritten into a ``CLICK`` on the matching option.
    Successful actions gain an ``ax_tree`` entry and lose ``raw_html``;
    skipped actions are left untouched and counted in the error summary.
    The shuffled dump is written to the output folder under the same name.

    The two folders were left without a value in the original script (a
    syntax error), so they are taken from the command line:
    ``script.py RAW_FOLDER SAVE_FOLDER``.
    """
    # Local import: only the script entry point needs argument parsing.
    import argparse

    parser = argparse.ArgumentParser(
        description="Attach pruned accessibility trees to trajectory dumps."
    )
    parser.add_argument(
        "raw_folder", help="folder containing the train_<i>.json input dumps"
    )
    parser.add_argument(
        "save_folder", help="folder the converted train_<i>.json files go to"
    )
    args = parser.parse_args()
    raw_folder = args.raw_folder
    save_folder = args.save_folder

    tot = 0
    errors = {"no_locator": 0, "no_target_id": 0}
    for dump_idx in range(10):
        with open(
            f"{raw_folder}/train_{dump_idx}.json", encoding="utf-8"
        ) as f:
            dump = json.load(f)
        print(f"Number of trajectories: {len(dump)}")
        random.shuffle(dump)
        for traj in dump:
            for a_idx, action in enumerate(traj["actions"]):
                action_repr = traj["action_reprs"][a_idx]
                if not action["pos_candidates"]:
                    errors["no_locator"] += 1
                    continue
                selected_id = str(
                    action["pos_candidates"][0]["backend_node_id"]
                )
                soup = BeautifulSoup(action["cleaned_html"], "html.parser")
                if soup.html is None:
                    # No <html> root to build a tree from, so the target
                    # cannot be located.
                    errors["no_target_id"] += 1
                    continue
                ax_tree = create_accessibility_tree(soup.html)
                ax_tree, target_id = prune_tree(ax_tree, selected_id)
                if ax_tree is None:
                    # Everything was pruned away; nothing left to point at.
                    errors["no_target_id"] += 1
                    continue

                # A SELECT is turned into a CLICK on the option whose name
                # contains the selected text.
                if "-> SELECT" in action_repr:
                    content = action_repr.split("-> SELECT:")[1].strip()
                    target_id = find_select_id(
                        ax_tree, target_id, content, False
                    )
                    action_repr = action_repr.replace("-> SELECT", "-> CLICK")

                ax_tree_str = print_tree(ax_tree)

                # target_id is None when no option matched; otherwise make
                # sure the id really appears in the rendered tree.
                if target_id is None or f"[{target_id}]" not in ax_tree_str:
                    errors["no_target_id"] += 1
                    continue
                action["ax_tree"] = {
                    "tree": ax_tree,
                    "target_id": target_id,
                    "action_repr": action_repr,
                }

                # Tolerate actions that carry no raw HTML.
                action.pop("raw_html", None)
                tot += 1
        with open(
            f"{save_folder}/train_{dump_idx}.json", "w", encoding="utf-8"
        ) as f:
            json.dump(dump, f)

    print(f"Number of actions: {tot}")
    print(f"Error count: {errors}")
    print("Missed roles:====================")
    print(missed)


if __name__ == "__main__":
    main()
