import ast
import logging
import numpy as np
import PIL.Image
import PIL.ImageDraw
import PIL.ImageFont
import re
import os
import io
import requests
import json

from PIL import Image

from benchmark.webarenasafe.page_understanding.smart_parse import parse_tree_to_json, json_to_ax_tree_string, filter_tree_json
from benchmark.browsergym.core.src.browsergym.utils.obs import _process_bid, _remove_redundant_static_text
from .som_utils import analyze_image, DashedImageDraw

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Address of the enrichment server, read from the environment at import time.
# NOTE: os.getenv returns None for an unset variable, in which case the URL
# below literally contains the text "None".
SERVER_HOST = os.getenv("SERVER_HOST")
SERVER_PORT = os.getenv("SERVER_PORT")
SERVER_API_URL = "http://{}:{}".format(SERVER_HOST, SERVER_PORT)

# Default value of `ignored_roles` in flatten_axtree_to_str.
IGNORED_AXTREE_ROLES = ["LineBreak"]

# Default value of `ignored_properties` in flatten_axtree_to_str.
IGNORED_AXTREE_PROPERTIES = (
    "editable",
    "readonly",
    "level",
    "settable",
    "multiline",
    "invalid",
    "focusable",
)

# Roles / element types that this module treats as clickable.
CLICKABLE_ELEMENTS = [
    'link',
    'button',
    'input',
    'select_toggle',
    'checkbox',
    'combobox',
    'gridcell',
    'listbox',
    'menu',
    'menuitem',
    'menuitemcheckbox',
    'menuitemradio',
    'option',
    'radio',
    'scrollbar',
    'searchbox',
    'slider',
    'spinbutton',
    'switch',
    'tab',
    'textbox',
    'img',
]

def flatten_axtree_to_str(
    AX_tree,
    extra_properties: dict = None,
    with_visible: bool = False,
    with_clickable: bool = False,
    with_center_coords: bool = False,
    with_bounding_box_coords: bool = False,
    with_som: bool = False,
    skip_generic: bool = True,
    filter_visible_only: bool = False,
    filter_with_bid_only: bool = False,
    filter_som_only: bool = False,
    coord_decimals: int = 0,
    ignored_roles=IGNORED_AXTREE_ROLES,
    ignored_properties=IGNORED_AXTREE_PROPERTIES,
    remove_redundant_static_text: bool = True,
    hide_bid_if_invisible: bool = False,
    hide_all_children: bool = False,
    filter_by_type: bool = False,
    keep_parents: bool = True,
    indent_base: str = "   ", 
) -> str:
    """Format an accessibility tree as an indented string, one node per line.

    The tree is walked depth-first starting at ``AX_tree["nodes"][0]``. A
    printed line looks like ``[bid] role 'name' value=..., attr, attr=...``
    and is indented by ``indent_base`` once per printed ancestor (children of
    a skipped node keep their parent's depth).

    Args:
        AX_tree: Mapping with a ``"nodes"`` list. Every node needs ``nodeId``
            and ``role``; ``name``, ``value``, ``properties``, ``childIds``
            and ``browsergym_id`` are optional.
        extra_properties: Per-bid properties. Forwarded unchanged to
            ``_process_bid`` and read here for ``hide_bid_if_invisible``.
        with_visible, with_clickable, with_center_coords,
        with_bounding_box_coords, with_som, filter_visible_only,
        filter_with_bid_only, filter_som_only, coord_decimals: Forwarded
            unchanged to ``_process_bid``, which decides whether a node is
            filtered and which extra attributes are printed.
        skip_generic: Skip ``generic`` nodes that have no printable property.
        ignored_roles: Roles whose nodes are never printed.
        ignored_properties: Property names never printed.
        remove_redundant_static_text: Skip ``StaticText`` nodes whose name is
            contained in the parent's name, and post-process the result with
            ``_remove_redundant_static_text``.
        hide_bid_if_invisible: Print an empty bid when the node's
            ``visibility`` in ``extra_properties`` is below 0.5.
        hide_all_children: Skip non-StaticText nodes whose parent was
            filtered by ``_process_bid``.
        filter_by_type: Soft-skip nodes whose role is not in
            ``CLICKABLE_ELEMENTS``.
        keep_parents: Still print soft-skipped nodes that have children
            (with an empty bid). ``StaticText`` nodes are skipped unless they
            have children and this flag is set.
        indent_base: String repeated once per depth level.

    Returns:
        The flattened tree, or ``""`` when the tree has no nodes.
    """
    nodes = AX_tree["nodes"]
    if not nodes:
        # Nothing to walk; dfs(0, ...) would otherwise index an empty list.
        return ""

    node_id_to_idx = {node["nodeId"]: idx for idx, node in enumerate(nodes)}

    def dfs(node_idx: int, depth: int, parent_node_filtered: bool, parent_node_name: str) -> str:
        tree_str = ""
        node = nodes[node_idx]
        indent = indent_base * depth
        skip_node = False    # hard skip: never printed, children are not indented
        filter_node = False  # verdict of _process_bid, handed down to the children
        soft_skip = False    # dropped unless kept as a parent; printed without a bid
        node_role = node["role"]["value"]
        node_name = ""

        # Leaf nodes may come without a "childIds" key at all.
        child_ids = node.get("childIds", [])
        keep_a_parent_node = bool(child_ids and keep_parents)

        if node_role in ignored_roles or "name" not in node:
            skip_node = True
        else:
            node_name = node["name"]["value"]
            if "value" in node and "value" in node["value"]:
                node_value = node["value"]["value"]
            else:
                node_value = None

            # extract bid
            bid = node.get("browsergym_id", None)

            # extract node attributes
            attributes = []
            for prop in node.get("properties", []):
                if "value" not in prop or "value" not in prop["value"]:
                    continue

                prop_name = prop["name"]
                prop_value = prop["value"]["value"]

                if prop_name in ignored_properties:
                    continue
                elif prop_name in ("required", "focused", "atomic"):
                    # these flags are printed bare, and only when truthy
                    if prop_value:
                        attributes.append(prop_name)
                else:
                    attributes.append(f"{prop_name}={repr(prop_value)}")

            if (node_role not in CLICKABLE_ELEMENTS) and filter_by_type:
                soft_skip = True

            if skip_generic and node_role == "generic" and not attributes:
                skip_node = True

            if node_role == "StaticText":
                if parent_node_filtered:
                    skip_node = True
                elif remove_redundant_static_text and node_name in parent_node_name:
                    skip_node = True
                if not keep_a_parent_node:
                    skip_node = True
            else:
                filter_node, extra_attributes_to_print = _process_bid(
                    bid,
                    extra_properties=extra_properties,
                    with_visible=with_visible,
                    with_clickable=with_clickable,
                    with_center_coords=with_center_coords,
                    with_bounding_box_coords=with_bounding_box_coords,
                    with_som=with_som,
                    filter_visible_only=filter_visible_only,
                    filter_with_bid_only=filter_with_bid_only,
                    filter_som_only=filter_som_only,
                    coord_decimals=coord_decimals,
                )

                # if either is True, skip the node
                skip_node = skip_node or (hide_all_children and parent_node_filtered)
                soft_skip = soft_skip or filter_node or skip_node

                # insert extra attributes before regular attributes
                attributes = extra_attributes_to_print + attributes

            # A soft-skipped node is still printed when it is kept as a parent;
            # a hard-skipped node never is.
            if ((not soft_skip) or keep_a_parent_node) and (not skip_node):
                node_str = f"{node_role} {repr(node_name.strip())}"

                # extra_properties defaults to None, so guard the lookup.
                hide_bid = (
                    bid is None
                    or (
                        hide_bid_if_invisible
                        and (extra_properties or {}).get(bid, {}).get("visibility", 0) < 0.5
                    )
                    or soft_skip
                )
                bid_str = '' if hide_bid else bid

                node_str = f"[{bid_str}] " + node_str

                if node_value is not None:
                    node_str += f' value={repr(node_value)}'

                if attributes:
                    node_str += ", ".join([""] + attributes)

                tree_str += f"{indent}{node_str}"

        for child_node_id in child_ids:
            # ignore dangling references and self-references
            if child_node_id not in node_id_to_idx or child_node_id == node["nodeId"]:
                continue
            # children of a skipped node are not indented further, to save tokens
            child_depth = depth if skip_node else (depth + 1)
            child_str = dfs(
                node_id_to_idx[child_node_id],
                child_depth,
                parent_node_filtered=filter_node,
                parent_node_name=node_name,
            )
            if child_str:
                if tree_str:
                    tree_str += "\n"
                tree_str += child_str

        return tree_str

    tree_str = dfs(0, 0, False, "")
    if remove_redundant_static_text:
        tree_str = _remove_redundant_static_text(tree_str)
    return tree_str

def postprocess_axtree_str(
        llm_provider,
        axtree_txt,
        page_understanding,
        extra_element_properties, 
        screenshot,
        filter_by_visible = True,
        filter_by_clickable = True, 
        filter_by_bbox = False,
        utterance = "", 
        current_step_plan = "",
        request_timeout = None,
    ):
    """Filter an AXTree string and optionally enrich it through the vision server.

    Side effects: every item of ``page_understanding`` gets a ``'clickable'``
    key, and every value of ``extra_element_properties`` gets ``'clickable'``
    (plus ``'type'`` when a page-understanding item with the same id exists).

    Args:
        llm_provider: Provider name sent to the server. When falsy, no request
            is made and the screenshot is not touched.
        axtree_txt: Flattened AXTree string, parsed with ``parse_tree_to_json``.
        page_understanding: List of dicts with at least ``'id'`` and ``'tagName'``.
        extra_element_properties: Dict of bid -> property dict.
        screenshot: Array accepted by ``PIL.Image.fromarray``.
        filter_by_visible: Keep only elements whose ``visibility`` is > 0.99.
        filter_by_clickable: Keep only elements marked clickable.
        filter_by_bbox: Keep only elements with a truthy ``bbox``.
        utterance: Forwarded to the server.
        current_step_plan: Forwarded to the server.
        request_timeout: Passed as ``timeout`` to ``requests.post``; ``None``
            (the default) waits indefinitely, as before.

    Returns:
        ``(page_summary, enriched_axtree_str)``. Without a provider, or when
        the server reply is unusable, this is ``("", filtered_axtree_str)``.
    """
    # Update each PU element with 'clickable' value
    for element in page_understanding:
        element['clickable'] = element['tagName'] in CLICKABLE_ELEMENTS
    id2pu = {item['id']: item for item in page_understanding}  # id to PU element

    # Mirror the PU 'clickable' flag (and the element type) onto the bid properties
    for key, properties in extra_element_properties.items():
        pu_element = id2pu.get(key)
        if pu_element:
            properties['type'] = pu_element['tagName']
        properties['clickable'] = bool(key in id2pu and id2pu[key]['clickable'])

    filtered_extra_element_properties = {
        key: element for key, element in extra_element_properties.items()
        if (
            (not filter_by_visible or element.get('visibility', 0) > 0.99) and
            (not filter_by_bbox or element.get('bbox')) and
            (not filter_by_clickable or element.get('clickable'))
        )
    }

    json_tree = parse_tree_to_json(axtree_txt)  # aXtree to json
    filtered_axtree_json = filter_tree_json(json_tree, list(filtered_extra_element_properties))  # filter json tree
    filtered_axtree_str = json_to_ax_tree_string(filtered_axtree_json)  # filtered json tree to aXtree string

    if not llm_provider:
        return "", filtered_axtree_str

    ## Enrich filtered aXtree using vision ##

    image = Image.fromarray(screenshot)  # ndarray to PIL Image
    if image.mode not in ("L", "RGB", "CMYK"):
        # JPEG cannot store an alpha channel (e.g. RGBA screenshots)
        image = image.convert("RGB")

    # Encode the image as JPEG into an in-memory buffer
    image_bytes = io.BytesIO()
    image.save(image_bytes, format='JPEG')
    image_bytes.seek(0)  # rewind so the upload starts at the first byte

    url = SERVER_API_URL+'/enrich_axtree'
    files = {
        "image": ("image.jpg", image_bytes, "image/jpeg"),
    }
    data = {
        "llm_provider": llm_provider,
        "filtered_axtree_str": filtered_axtree_str, 
        "bid_elements_str": json.dumps(filtered_extra_element_properties, indent = 2),
        "utterance": utterance, 
        "current_step_plan": current_step_plan,
    }

    response = requests.post(url, files=files, data=data, timeout=request_timeout)

    # Check the status before trying to parse the body
    if response.status_code != 200:
        logger.error("Error response: %s - %s", response.status_code, response.text)
        return "", filtered_axtree_str

    try:
        output = json.loads(response.text)['summary']
        # The summary is split on '---'; pieces 1 and 3 carry the page summary
        # and the enriched tree. Pieces 0 and 2 are presumably headers emitted
        # by the server -- TODO confirm against the server's prompt.
        sections = output.split('---')
        page_summary = sections[1].strip()
        enriched_axtree_str = sections[3].strip()
    except (ValueError, KeyError, IndexError, TypeError, AttributeError):
        logger.exception("Unusable /enrich_axtree response: %s", response.text)
        return "", filtered_axtree_str

    return page_summary, enriched_axtree_str

def extract_pu_attributes(pu):
    """Collect one descriptor dict per entry of ``pu.map`` that has a ``bid``.

    Entries whose ``html.attributes`` has no truthy ``'bid'`` are ignored.
    The bounding rect of every kept entry is truncated to integers in place
    (``x``, ``y``, ``width``, ``height``) before it is copied into the result.

    Returns:
        A list of dicts with the keys ``element``, ``tagName``, ``id``,
        ``include``, ``area``, ``rects`` and ``text``.
    """
    collected = []
    for _key, entry in pu.map.items():
        bid = entry.html.attributes.get('bid', None)
        if not bid:
            continue

        rect = entry.html.boundingRect
        # Truncate the rect on the source object itself, one field at a time.
        for field in ('x', 'y', 'width', 'height'):
            setattr(rect, field, int(getattr(rect, field)))

        left, top = rect.x, rect.y
        width, height = rect.width, rect.height
        descriptor = {
            'element': 'ref <Node>',
            'tagName': entry.match.rule.type,
            'id': bid,
            'include': True,
            'area': width * height,
            'rects': [
                {
                    'left': left,
                    'top': top,
                    'right': left + width,
                    'bottom': top + height,
                    'width': width,
                    'height': height,
                }
            ],
            'text': entry.text.value,
        }
        collected.append(descriptor)
    return collected

def bbox_intersection(bbox1, bbox2):
    """Tell whether two ``(x, y, width, height)`` boxes overlap.

    The overlap must have positive extent on both axes, so boxes that merely
    share an edge or a corner do not count as intersecting.
    """
    left_a, top_a, width_a, height_a = bbox1
    left_b, top_b, width_b, height_b = bbox2

    # Corners of the candidate overlap rectangle.
    overlap_left = max(left_a, left_b)
    overlap_top = max(top_a, top_b)
    overlap_right = min(left_a + width_a, left_b + width_b)
    overlap_bottom = min(top_a + height_a, top_b + height_b)

    return bool(overlap_left < overlap_right and overlap_top < overlap_bottom)

def bbox_annotator(img, elements, negative=False, thickness=4, dash=(6,3), fontsize=14, dynamic=False, keep_ids=True, font_path=None):
    """Outline every element's bounding box on ``img`` and label it.

    All outlines are drawn first; labels are drawn in a second pass so a
    label can be placed with knowledge of every outline (and of the labels
    placed before it).

    Args:
        img: PIL image to draw on. Drawing goes through ``DashedImageDraw(img)``
            and the same ``img`` object is returned (presumably modified in
            place like ``PIL.ImageDraw`` -- confirm in ``som_utils``).
        elements: Dict of key -> dict holding ``'bbox'`` as
            ``(x, y, width, height)``. Elements without a usable bbox are skipped.
        negative: Yellow boxes with black text when true, black boxes with
            white text otherwise.
        thickness: Outline width.
        dash: Dash pattern handed to ``dashed_rectangle``.
        fontsize: Label font size.
        dynamic: Move a label to the right of its box when its default
            top-left spot overlaps an already registered rectangle.
        keep_ids: Label with the element key when true, otherwise with the
            element's position in ``elements``. Labels are zero-padded to
            two characters.
        font_path: TrueType font file for the labels. ``None`` keeps the
            previously hard-coded path; if the file cannot be loaded the
            PIL default font is used.

    Returns:
        ``img``.
    """
    if font_path is None:
        font_path = '/Users/benwiesel/Downloads/tt_hoves_pro/TT Hoves Pro Trial Regular.ttf'

    W, H = img.size
    rectangle_color = 'yellow' if negative else 'black'
    text_color = 'black' if negative else 'white'

    annotation_positions = []
    img1 = DashedImageDraw(img)

    # Pass 1: dashed outlines.
    for ele in elements.values():
        try:
            boundingRect = ele['bbox']
            x0, y0, w0, h0 = boundingRect[0], boundingRect[1], boundingRect[2], boundingRect[3]
            annotation_positions.append((x0, y0, w0, h0))
            img1.dashed_rectangle([(x0, y0), (x0+w0, y0+h0)],
                    dash=dash, outline=rectangle_color, width=thickness)
        except Exception:
            # Best effort: one bad element must not abort the whole overlay.
            logger.debug("Skipping outline for an element without a drawable bbox", exc_info=True)
            continue

    # Pass 2: labels. The font is loaded once, and only if a label is drawn.
    font = None
    for idx, (key, ele) in enumerate(elements.items()):
        try:
            boundingRect = ele['bbox']
            x0, y0, w0, h0 = boundingRect[0], boundingRect[1], boundingRect[2], boundingRect[3]
        except (KeyError, IndexError, TypeError):
            continue

        if font is None:
            try:
                font = PIL.ImageFont.truetype(font_path, size=fontsize)
            except OSError:
                font = PIL.ImageFont.load_default(size=fontsize)  # custom font unavailable

        text = str(key if keep_ids else idx).zfill(2)
        # With the origin at (0, 0) the right/bottom of the text bbox are used as its size
        text_width, text_height = img1.textbbox((0, 0), text, font=font)[2:]

        # Default label spot: just outside the top-left corner, clamped to the image
        text_position = (max(x0-1.0*text_width,0), max(y0-1.0*text_height,0))  # TOP LEFT
        text_bbox = [text_position[0] - 0.13*text_width, text_position[1]- 0.13*text_height, 1.3*text_width, 1.5*text_height]
        if dynamic:
            # NOTE(review): annotation_positions already holds this element's own
            # outline, and the padded label box reaches past (x0, y0), so for
            # positive-size boxes this test looks like it always succeeds --
            # confirm whether the element's own box should be excluded.
            if any(bbox_intersection(bbox, text_bbox) for bbox in annotation_positions):
                text_position = (min(max(x0 + w0 + 0.2*text_width, 0), W), min(max(y0+0.1*text_height, 0), H))  # CENTER RIGHT
                text_bbox = [text_position[0] - 0.2*text_width, text_position[1]- 0.2*text_height, 1.4*text_width, 1.6*text_height]

        # Background rectangle as (left, top, right, bottom)
        background_bbox = [text_bbox[0], text_bbox[1], text_bbox[0]+text_bbox[2], text_bbox[1]+text_bbox[3]]

        # Register the label only after it is placed, so later labels can avoid it
        annotation_positions.append((text_bbox[0], text_bbox[1], text_bbox[2], text_bbox[3]))

        img1.rectangle(background_bbox, fill=rectangle_color)
        img1.text(text_position, text, fill=text_color, font=font)

    return img


def overlay_som(
    screenshot: np.typing.ArrayLike,
    extra_properties: dict,
    ids_to_annotate,
    dynamic:bool = False,
    keep_ids:bool = False,
    fontsize: int = 12,
    linewidth: int = 2,
    dash = (6,3),
):
    """Return a copy of ``screenshot`` with the selected elements annotated.

    Only ids from ``ids_to_annotate`` that are present in ``extra_properties``
    are drawn, in the order given by ``ids_to_annotate``. The colour scheme
    flag comes from ``analyze_image`` and the drawing is done by
    ``bbox_annotator``. The result is an RGB numpy array.
    """
    # Work on an RGBA copy so the caller's array is left untouched
    rgba_image = PIL.Image.fromarray(screenshot).copy().convert(mode="RGBA")

    negative = analyze_image(rgba_image)

    selected_properties = {}
    for key in ids_to_annotate:
        if key in extra_properties:
            selected_properties[key] = extra_properties[key]

    annotated_image = bbox_annotator(
        rgba_image.copy(),
        selected_properties,
        negative=negative,
        thickness=linewidth,
        dash=dash,
        fontsize=fontsize,
        dynamic=dynamic,
        keep_ids=keep_ids,
    )

    # Drop the alpha channel (3 channels) and hand back a numpy array
    return np.array(annotated_image.convert(mode="RGB"))
