import io
import os
import logging
import numpy as np
import playwright.sync_api
import requests
import time
from PIL import Image
import json

from typing import List, Dict, Any


# Not referenced in the code visible here; presumably a retry limit for frame
# marking used elsewhere in the module.
MARK_FRAMES_MAX_TRIES = 3


logger = logging.getLogger(__name__)

# Address of the page-understanding server, read from the environment once at
# import time. If either variable is unset, os.environ.get returns None and the
# URL literally contains "None" (e.g. "http://None:None").
SERVER_HOST = os.environ.get("SERVER_HOST")
SERVER_PORT = os.environ.get("SERVER_PORT")
SERVER_API_URL = f'http://{SERVER_HOST}:{SERVER_PORT}'

def extract_dom_elements(
    page: playwright.sync_api.Page,
    context: playwright.sync_api.BrowserContext,
    extension: bool = True,
    fix_color: bool = True,
    rules_path=None,
):
    """
    Extract DOM elements from the current page using one of two back ends.

    * ``extension=False``: evaluate the bundled JavaScript extraction script
      (``js_script_template``) in ``page`` and return its result unchanged.
    * ``extension=True``: run ``analyze_current_page_sync`` on ``context`` with
      the rules loaded from a YAML file, and convert every matched element that
      carries a truthy ``bid`` attribute into an item dict.

    Args:
        page: Playwright page the JavaScript script is evaluated in (used only
            when ``extension`` is False).
        context: Playwright browser context handed to the page analyzer (used
            only when ``extension`` is True).
        extension: Selects the back end described above.
        fix_color: Currently unused; kept for backward compatibility.
        rules_path: Path to the analyzer's rules YAML file. ``None`` (the
            default) uses the developer-local path this function originally
            hard-coded.

    Returns:
        A list of item dicts. For the analyzer back end each item has the keys
        ``element``, ``tagName``, ``id``, ``include``, ``area``, ``rects`` and
        ``text``, with pixel values truncated to ints. For the JavaScript back
        end the shape is whatever the script returns.
    """
    if not extension:
        # Imported lazily so the analyzer path does not depend on this module.
        from benchmark.webarenasafe.page_understanding.js_element_extraction import js_script_template

        return page.evaluate(js_script_template)

    import yaml
    from pu_enhanced.main import analyze_current_page_sync

    if rules_path is None:
        # NOTE(review): developer-specific absolute path; pass rules_path (or move
        # this into configuration) to run on any other machine.
        rules_path = "/Users/benwiesel/Projects/General/NL2UI-Runtime/server/page_understanding/rules.yaml"
    with open(rules_path, encoding="utf-8") as rules_file:
        rules = yaml.safe_load(rules_file)

    result = analyze_current_page_sync(context, rules=rules)

    items_raw = []
    for _, node in result.map.items():
        bid = node.html.attributes.get('bid', None)
        if not bid:
            continue
        rect = node.html.boundingRect
        # Truncate to whole pixels into locals instead of overwriting the
        # analyzer's own rect object.
        x, y = int(rect.x), int(rect.y)
        width, height = int(rect.width), int(rect.height)
        items_raw.append({
            'element': 'ref <Node>',
            'tagName': node.match.rule.type,
            'id': bid,
            'include': True,
            'area': width * height,
            'rects': [{
                'left': x,
                'top': y,
                'right': x + width,
                'bottom': y + height,
                'width': width,
                'height': height,
            }],
            'text': node.text.value,
        })

    return items_raw



def extract_visual_hierarchy(
    items_raw: List[Dict[str, Any]],
    screenshot: np.ndarray,
    vision: bool,
    extra_properties: Dict,
    timeout=None,
):
    """
    Ask the page-understanding server for a page summary and enriched hierarchy.

    The screenshot is JPEG-encoded and POSTed to ``SERVER_API_URL + '/js_inventory'``
    together with the JSON-encoded element inventory. The server's JSON
    ``summary`` field is split on ``'---'``, and the second and fourth pieces
    are returned.

    Args:
        items_raw: Element inventory (e.g. from ``extract_dom_elements``); must
            be JSON-serializable.
        screenshot: Page screenshot as an image array accepted by
            ``PIL.Image.fromarray``. Modes JPEG cannot store (e.g. RGBA) are
            converted to RGB before encoding.
        vision: When True, entries of ``extra_properties`` that have a truthy
            ``bbox`` and a ``visibility`` above 0.99 are sent as
            ``bid_elements_str``. When False, an empty JSON object is sent.
        extra_properties: Mapping of element id -> property dict; only the
            ``bbox`` and ``visibility`` keys are read here.
        timeout: Optional ``requests`` timeout in seconds. ``None`` (the
            default) waits indefinitely, as before.

    Returns:
        ``(page_summary, hierarchy_json)`` as stripped strings.

    Raises:
        requests.HTTPError: If the server answers with a 4xx/5xx status.
        ValueError: If the summary contains fewer than four ``'---'``-separated
            sections.
    """
    image = Image.fromarray(screenshot)
    # JPEG has no alpha channel or palette, so such modes would fail to save.
    if image.mode not in ("RGB", "L"):
        image = image.convert("RGB")

    image_bytes = io.BytesIO()
    image.save(image_bytes, format='JPEG')
    image_bytes.seek(0)  # rewind so requests uploads from the start

    if vision:
        filtered = {
            k: v for k, v in extra_properties.items()
            if v.get('bbox') and (v.get('visibility') or 0) > 0.99
        }
    else:
        # Previously left bid_elements unbound (UnboundLocalError) when vision=False.
        filtered = {}
    bid_elements = json.dumps(filtered, indent=2)

    files = {
        "image": ("image.jpg", image_bytes, "image/jpeg"),
    }
    data = {
        "llm_provider": "openai",
        "inventory": json.dumps(items_raw),
        # NOTE(review): always "true" regardless of the `vision` argument;
        # confirm whether the server flag should follow it.
        "vision": "true",
        "verify_all_ids": "false",
        "bid_elements_str": bid_elements,
        "axtree_format": True,
    }

    response = requests.post(
        SERVER_API_URL + '/js_inventory', files=files, data=data, timeout=timeout
    )

    # Check the status before parsing: an error body usually lacks 'summary'
    # and would otherwise fail with an unhelpful KeyError/JSONDecodeError.
    if response.status_code != 200:
        logger.error("Error response: %s - %s", response.status_code, response.text)
        response.raise_for_status()

    output = json.loads(response.text)['summary']

    sections = output.split('---')
    if len(sections) < 4:
        raise ValueError(
            "Unexpected summary format: expected at least 4 '---'-separated "
            f"sections, got {len(sections)}"
        )

    # Judging by the indices used, the even-numbered pieces are presumably
    # headings. Piece 1 holds the page summary and piece 3 holds the enriched
    # AXTree / hierarchy string.
    page_summary = sections[1].strip()
    hierarchy_json = sections[3].strip()

    return page_summary, hierarchy_json


def extract_interactable_elements(page: playwright.sync_api.Page,
                                  extra_properties,
                                  grid_step=20,
                                  pixel_ratio=2,
                                  max_elements=5000):
    """
    Probe the centre of each element's bbox for its cursor style and mark it.

    For every entry of ``extra_properties`` with a truthy ``bbox``, the bbox
    values are divided by ``pixel_ratio`` and the centre point is computed.
    If that point lies inside the viewport, the function reads the CSS cursor
    there with ``check_cursor_style_at_location`` and draws a marker with
    ``highlight_location``. It then prints ``"<cursor> (<x>, <y>)"``.

    Args:
        page: Playwright page to probe.
        extra_properties: Mapping of element id -> property dict; only the
            optional ``bbox`` entry ``(x, y, width, height)`` is read.
        grid_step: Unused; kept for backward compatibility.
        pixel_ratio: Divisor applied to every bbox value. Defaults to 2, the
            previously hard-coded value. Presumably this is the device pixel
            ratio the bboxes were measured at (confirm against the producer).
        max_elements: Maximum number of bbox-bearing elements considered. The
            count includes those skipped for lying outside the viewport. The
            previous counter stopped one short, after 4999 elements.

    Returns:
        Always ``1``.
    """
    viewport = page.evaluate('''() => {
        return {
            width: window.innerWidth,
            height: window.innerHeight
        }
    }''')
    view_width = viewport['width']
    view_height = viewport['height']

    considered = 0
    for props in extra_properties.values():
        bbox = props.get('bbox', None)
        if not bbox:
            continue

        considered += 1
        if considered > max_elements:
            break

        x0, y0, width, height = (value / pixel_ratio for value in bbox)
        x, y = x0 + width / 2, y0 + height / 2

        # elementFromPoint only sees the visible viewport; skip off-screen centres.
        if not (0 <= x <= view_width and 0 <= y <= view_height):
            continue

        cursor_style = check_cursor_style_at_location(page, x, y)
        highlight_location(page, x, y)

        print(f"{cursor_style} ({x}, {y})")

    return 1

def check_cursor_style_at_location(page, x: float, y: float) -> str:
    """
    Return the computed CSS ``cursor`` of the element at viewport point (x, y).

    The mouse is not moved. The element is located with
    ``document.elementFromPoint`` and its computed style is read in the page.

    Args:
        page: Playwright page object (anything exposing
            ``evaluate(expression, arg)``).
        x: Horizontal viewport coordinate in CSS pixels (int or float).
        y: Vertical viewport coordinate in CSS pixels (int or float).

    Returns:
        The cursor value (e.g. ``'pointer'``, ``'auto'``), or ``'none'`` when no
        element is found at that point. ``'none'`` is also a valid CSS cursor
        value, so callers cannot distinguish the two cases.
    """
    # The coordinates are passed as a serialized argument rather than formatted
    # into the script text. Any numeric value then reaches JS intact, whereas
    # Python's `inf` would be a JS ReferenceError, and the script stays constant
    # across calls.
    return page.evaluate(
        """([x, y]) => {
            const element = document.elementFromPoint(x, y);
            return element ? window.getComputedStyle(element).cursor : 'none';
        }""",
        [x, y],
    )

def highlight_location(page, x, y, width=10, height=10):
    """
    Draw a semi-transparent red box whose top-left corner is at viewport point (x, y).

    A new ``div`` with id ``highlight-box`` is appended to ``document.body`` on
    every call and is never removed by this function. Repeated calls
    accumulate boxes that all share that id.

    Args:
        page: Playwright page object (anything exposing
            ``evaluate(expression, arg)``).
        x: Left edge in viewport CSS pixels.
        y: Top edge in viewport CSS pixels.
        width: Box width in CSS pixels (the 2px border is drawn outside it).
        height: Box height in CSS pixels (the 2px border is drawn outside it).
    """
    page.evaluate(
        """({x, y, width, height}) => {
            const highlightDiv = document.createElement('div');
            // Viewport coordinates (as used by elementFromPoint) need fixed
            // positioning; 'absolute' would drift once the page is scrolled.
            highlightDiv.style.position = 'fixed';
            highlightDiv.style.left = `${x}px`;
            highlightDiv.style.top = `${y}px`;
            highlightDiv.style.width = `${width}px`;
            highlightDiv.style.height = `${height}px`;
            highlightDiv.style.backgroundColor = 'rgba(255, 0, 0, 0.5)';
            highlightDiv.style.zIndex = '10000';
            highlightDiv.style.border = '2px solid red';
            // Without this, the overlay itself is returned by later
            // elementFromPoint probes and swallows clicks near the marked point.
            highlightDiv.style.pointerEvents = 'none';
            highlightDiv.id = 'highlight-box';
            document.body.appendChild(highlightDiv);
        }""",
        {"x": x, "y": y, "width": width, "height": height},
    )