import ast
import base64
import json
import re
from typing import Any, Optional, Union

import cv2
import numpy
import numpy as np
from PIL import Image
from absl.flags import FLAGS
from android_world.env import representation_utils
from openai.types.completion_usage import CompletionUsage

from . import models


def _logical_to_physical(
  logical_coordinates: tuple[int, int],
  logical_screen_size: tuple[int, int],
  physical_frame_boundary: tuple[int, int, int, int],
  orientation: int,
) -> tuple[int, int]:
  """Convert logical coordinates to physical coordinates.

  Args:
    logical_coordinates: The logical coordinates for the point.
    logical_screen_size: The logical screen size.
    physical_frame_boundary: The physical coordinates in portrait orientation
      for the upper left and lower right corner for the frame.
    orientation: The current screen orientation.

  Returns:
    The physical coordinate for the point in portrait orientation.

  Raises:
    ValueError: If the orientation is not valid.
  """
  x, y = logical_coordinates
  px0, py0, px1, py1 = physical_frame_boundary
  px, py = px1 - px0, py1 - py0
  lx, ly = logical_screen_size
  if orientation == 0:
    return (int(x * px / lx) + px0, int(y * py / ly) + py0)
  if orientation == 1:
    return (px - int(y * px / ly) + px0, int(x * py / lx) + py0)
  if orientation == 2:
    return (px - int(x * px / lx) + px0, py - int(y * py / ly) + py0)
  if orientation == 3:
    return (int(y * px / ly) + px0, py - int(x * py / lx) + py0)
  print('Invalid orientation.')
  raise ValueError('Unsupported orientation.')


def _ui_element_logical_corner(
  ui_element: representation_utils.UIElement, orientation: int
) -> list[tuple[int, int]]:
  """Get logical coordinates for corners of a given UI element.

  Args:
    ui_element: The corresponding UI element.
    orientation: The current orientation.

  Returns:
    Logical coordinates for upper left and lower right corner for the UI
    element.

  Raises:
    ValueError: If bounding box is missing.
    ValueError: If orientation is not valid.
  """
  if ui_element.bbox_pixels is None:
    raise ValueError('UI element does not have bounding box.')
  if orientation == 0:
    return [
      (int(ui_element.bbox_pixels.x_min), int(ui_element.bbox_pixels.y_min)),
      (int(ui_element.bbox_pixels.x_max), int(ui_element.bbox_pixels.y_max)),
    ]
  if orientation == 1:
    return [
      (int(ui_element.bbox_pixels.x_min), int(ui_element.bbox_pixels.y_max)),
      (int(ui_element.bbox_pixels.x_max), int(ui_element.bbox_pixels.y_min)),
    ]
  if orientation == 2:
    return [
      (int(ui_element.bbox_pixels.x_max), int(ui_element.bbox_pixels.y_max)),
      (int(ui_element.bbox_pixels.x_min), int(ui_element.bbox_pixels.y_min)),
    ]
  if orientation == 3:
    return [
      (int(ui_element.bbox_pixels.x_max), int(ui_element.bbox_pixels.y_min)),
      (int(ui_element.bbox_pixels.x_min), int(ui_element.bbox_pixels.y_max)),
    ]
  raise ValueError('Unsupported orientation.')


def add_ui_element_mark(
  screenshot: np.ndarray,
  ui_element: representation_utils.UIElement,
  index: int,
  logical_screen_size: tuple[int, int],
  physical_frame_boundary: tuple[int, int, int, int],
  orientation: int,  # Landscape or portrait
):
  """Add mark (a bounding box plus index) for a UI element in the screenshot.

  The screenshot is modified in place. Elements without a bounding box are
  left unmarked.

  Args:
    screenshot: The screenshot as a numpy ndarray (height, width, channels).
    ui_element: The UI element to be marked.
    index: The numeric index for the UI element.
    logical_screen_size: The logical screen size.
    physical_frame_boundary: The physical coordinates in portrait orientation
      for the upper left and lower right corner for the frame.
    orientation: The current screen orientation.

  Raises:
    ValueError: If the orientation is not valid.
  """
  if not ui_element.bbox_pixels:
    return

  # Convert logical (layout) corners to physical (screen pixel) corners.
  upper_left_logical, lower_right_logical = _ui_element_logical_corner(
    ui_element, orientation
  )
  upper_left_physical = _logical_to_physical(
    upper_left_logical,
    logical_screen_size,
    physical_frame_boundary,
    orientation,
  )
  lower_right_physical = _logical_to_physical(
    lower_right_logical,
    logical_screen_size,
    physical_frame_boundary,
    orientation,
  )

  # Bounding box.
  cv2.rectangle(
    screenshot,
    upper_left_physical,
    lower_right_physical,
    color=(0, 255, 0),
    thickness=2,
  )

  # White 34x24 (w x h) background for the index at the box's top-left.
  # Slice starts are clamped at 0: validate_ui_element accepts boxes that only
  # partially overlap the screen, and a negative start would make numpy count
  # from the opposite edge and paint an unrelated region.
  left, top = upper_left_physical
  screenshot[
    max(top + 1, 0): max(top + 25, 0),
    max(left + 1, 0): max(left + 35, 0),
    :,
  ] = (255, 255, 255)

  # Index text (cv2 clips anything drawn outside the image).
  cv2.putText(
    screenshot,
    str(index),
    (left + 1, top + 20),
    cv2.FONT_HERSHEY_SIMPLEX,
    0.7,
    (0, 0, 0),
    thickness=2,
  )


def add_coordinate_mark(
  screenshot: np.ndarray,
  x: int,
  y: int,
  logical_screen_size: tuple[int, int],
  physical_frame_boundary: tuple[int, int, int, int],
  orientation: int,
):
  """Draw a target-style marker (concentric circles) at a logical coordinate.

  The screenshot is modified in place.

  Args:
    screenshot: The screenshot as a numpy ndarray.
    x: X coordinate in logical coordinates.
    y: Y coordinate in logical coordinates.
    logical_screen_size: The logical screen size.
    physical_frame_boundary: The physical coordinates in portrait orientation
      for the upper left and lower right corner for the frame.
    orientation: The current screen orientation.
  """
  center = _logical_to_physical(
    (x, y),
    logical_screen_size,
    physical_frame_boundary,
    orientation,
  )

  # Rings are drawn largest first so each one paints over the previous one.
  # Colour names assume an RGB array (encode_image_for_html swaps channels
  # before JPEG encoding, and PIL is used to save these arrays).
  rings = (
    (28, (255, 255, 255), 4),  # white outline ring
    (22, (255, 255, 0), -1),  # filled yellow disc
    (10, (0, 0, 0), -1),  # filled black disc for contrast
    (4, (0, 255, 255), -1),  # filled cyan centre dot
  )
  for radius, color, thickness in rings:
    cv2.circle(
      screenshot,
      center,
      radius=radius,
      color=color,
      thickness=thickness,
    )


def _draw_swipe_label(
  screenshot: np.ndarray,
  text: str,
  anchor: tuple[int, int],
  above: bool,
  font_scale: float = 1.2,
  font_thickness: int = 3,
) -> None:
  """Draw `text` in a black-bordered white box next to `anchor` (for add_swipe_mark).

  The label is centred horizontally on the anchor. When `above` is True its
  baseline sits 40 px above the anchor; otherwise its top sits 40 px below it.
  The position is clamped so the text stays inside the screenshot.
  """
  height, width = screenshot.shape[:2]
  padding, border = 10, 3
  text_w, text_h = cv2.getTextSize(
    text, cv2.FONT_HERSHEY_DUPLEX, font_scale, font_thickness
  )[0]

  x = anchor[0] - text_w // 2
  y = anchor[1] - 40 if above else anchor[1] + 40 + text_h
  x = max(8, min(x, width - text_w - 8))
  y = max(text_h + 12, min(y, height - 12))

  # Black border first, then the white background inset by `border` pixels.
  for inset, fill in ((padding + border, (0, 0, 0)), (padding, (255, 255, 255))):
    cv2.rectangle(
      screenshot,
      (x - inset, y - text_h - inset),
      (x + text_w + inset, y + inset),
      fill,
      -1,
    )

  cv2.putText(
    screenshot,
    text,
    (x, y),
    cv2.FONT_HERSHEY_DUPLEX,
    font_scale,
    (0, 0, 0),
    thickness=font_thickness,
  )


def add_swipe_mark(
  screenshot: np.ndarray,
  start_x: int,
  start_y: int,
  end_x: int,
  end_y: int,
  logical_screen_size: tuple[int, int],
  physical_frame_boundary: tuple[int, int, int, int],
  orientation: int,
  action_type: str = 'swipe',  # 'swipe' or 'drag_and_drop'
):
  """Add markers for swipe/drag operations showing start point, end point, and arrow.

  The drawing happens in place: first an outlined arrow from start to end,
  then a circular marker on each endpoint, then a boxed text label above the
  start point and below the end point.

  Args:
    screenshot: The screenshot as a numpy ndarray.
    start_x: Start X coordinate in logical coordinates.
    start_y: Start Y coordinate in logical coordinates.
    end_x: End X coordinate in logical coordinates.
    end_y: End Y coordinate in logical coordinates.
    logical_screen_size: The logical screen size.
    physical_frame_boundary: The physical coordinates in portrait orientation
      for the upper left and lower right corner for the frame.
    orientation: The current screen orientation.
    action_type: 'drag_and_drop' selects the FROM/TO labels and colour scheme;
      any other value uses the START/END swipe scheme.
  """
  start_physical = _logical_to_physical(
    (start_x, start_y),
    logical_screen_size,
    physical_frame_boundary,
    orientation,
  )
  end_physical = _logical_to_physical(
    (end_x, end_y),
    logical_screen_size,
    physical_frame_boundary,
    orientation,
  )

  # Colour names assume an RGB array (encode_image_for_html swaps channels
  # before JPEG encoding, and PIL is used to save these arrays).
  if action_type == 'drag_and_drop':
    start_color = (255, 255, 0)  # yellow
    end_color = (255, 0, 255)  # magenta
    line_color = (255, 128, 128)  # light red / pink
    start_label, end_label = "FROM", "TO"
  else:
    start_color = (255, 255, 0)  # yellow
    end_color = (0, 255, 255)  # cyan
    line_color = (128, 255, 255)  # light cyan
    start_label, end_label = "START", "END"

  # A wide white arrow underneath a narrower coloured one keeps the path
  # visible on both dark and light backgrounds.
  for color, thickness in (((255, 255, 255), 10), (line_color, 6)):
    cv2.arrowedLine(
      screenshot,
      start_physical,
      end_physical,
      color=color,
      thickness=thickness,
      tipLength=0.15,
    )

  # Endpoint markers: white ring, filled coloured disc, black centre.
  for point, fill in ((start_physical, start_color), (end_physical, end_color)):
    cv2.circle(screenshot, point, radius=25, color=(255, 255, 255), thickness=4)
    cv2.circle(screenshot, point, radius=20, color=fill, thickness=-1)
    cv2.circle(screenshot, point, radius=8, color=(0, 0, 0), thickness=-1)

  # Boxed labels: above the start point, below the end point.
  _draw_swipe_label(screenshot, start_label, start_physical, above=True)
  _draw_swipe_label(screenshot, end_label, end_physical, above=False)


def add_direction_arrow(
  screenshot: np.ndarray,
  direction: str,  # 'up', 'down', 'left', 'right'
  logical_screen_size: tuple[int, int],
  physical_frame_boundary: tuple[int, int, int, int],
  orientation: int,
):
  """Draw a large arrow through the screenshot centre showing a swipe direction.

  The arrow is 35% of the screenshot height long (split evenly around the
  centre) and drawn as three stacked layers, followed by a boxed
  "SWIPE <DIRECTION>" label below the centre. If the direction is not
  recognised, only a filled circle is drawn at the centre.

  The marker is always placed relative to the screenshot itself;
  `logical_screen_size`, `physical_frame_boundary` and `orientation` are
  accepted for signature parity with the other mark helpers but are unused.

  Args:
    screenshot: The screenshot as a numpy ndarray (modified in place).
    direction: Direction of swipe ('up', 'down', 'left', 'right').
    logical_screen_size: The logical screen size (unused).
    physical_frame_boundary: The physical frame boundary (unused).
    orientation: The current screen orientation (unused).
  """
  height, width = screenshot.shape[:2]
  cx, cy = width // 2, height // 2
  half = int(height * 0.35) // 2

  # name -> (unit step from start to end, label text)
  direction_table = (
    ('up', (0, -1), "SWIPE UP"),
    ('down', (0, 1), "SWIPE DOWN"),
    ('left', (-1, 0), "SWIPE LEFT"),
    ('right', (1, 0), "SWIPE RIGHT"),
  )
  spec = next(
    ((step, text) for name, step, text in direction_table if direction == name),
    None,
  )
  if spec is None:
    # Unknown direction: just mark the centre.
    cv2.circle(screenshot, (cx, cy), 30, (0, 255, 255), -1)
    return

  (dx, dy), label_text = spec
  start_point = (cx - dx * half, cy - dy * half)
  end_point = (cx + dx * half, cy + dy * half)
  # Vertical arrows push the label below the arrow's lower end.
  label_offset = half + 50 if dy else 50

  # White base, cyan middle, yellow core (colour names assume an RGB array).
  for color, thickness in (
    ((255, 255, 255), 18),
    ((0, 255, 255), 12),
    ((255, 255, 0), 6),
  ):
    cv2.arrowedLine(
      screenshot,
      start_point,
      end_point,
      color=color,
      thickness=thickness,
      tipLength=0.25,
    )

  font_scale, font_thickness = 1.5, 3
  text_w, text_h = cv2.getTextSize(
    label_text, cv2.FONT_HERSHEY_DUPLEX, font_scale, font_thickness
  )[0]
  label_x = max(12, min(cx - text_w // 2, width - text_w - 12))
  label_y = max(text_h + 15, min(cy + label_offset, height - 15))

  # Thick black border, then the white background inset inside it.
  padding, border = 15, 4
  for inset, fill in ((padding + border, (0, 0, 0)), (padding, (255, 255, 255))):
    cv2.rectangle(
      screenshot,
      (label_x - inset, label_y - text_h - inset),
      (label_x + text_w + inset, label_y + inset),
      fill,
      -1,
    )

  cv2.putText(
    screenshot,
    label_text,
    (label_x, label_y),
    cv2.FONT_HERSHEY_DUPLEX,
    font_scale,
    (0, 0, 0),
    thickness=font_thickness,
  )


def add_screenshot_label(screenshot: np.ndarray, label: str):
  """Add a text label to the right bottom of the screenshot.

  A white 150x30 (w x h) background is painted in the bottom-right corner and
  the label is drawn on it in black. The screenshot is modified in place.

  Args:
    screenshot: The screenshot as a numpy ndarray (height, width, channels).
    label: The text label to add, just a single word.
  """
  height, width, _ = screenshot.shape
  # Clamp at 0 so images smaller than the background area do not produce a
  # negative slice start, which numpy would count from the opposite edge.
  screenshot[max(height - 30, 0): height, max(width - 150, 0): width, :] = (
    255,
    255,
    255,
  )
  cv2.putText(
    screenshot,
    label,
    (width - 120, height - 5),
    cv2.FONT_HERSHEY_SIMPLEX,
    1,
    (0, 0, 0),
    thickness=2,
  )


def encode_image_for_html(image: np.ndarray) -> str:
  """JPEG-encode an image array and return it as a base64 string for HTML.

  cv2.imencode treats 3-channel input as BGR, so the red and blue channels are
  swapped first; an RGB array therefore comes out with correct colours.

  Args:
    image: Image as a numpy ndarray.

  Returns:
    Base64 text of the JPEG bytes, suitable for a data: URI.
  """
  swapped = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
  _, jpeg_bytes = cv2.imencode('.jpeg', swapped)
  encoded = base64.b64encode(jpeg_bytes)
  return encoded.decode('utf-8')


def clean_output(response: str) -> str:
  """Strip whitespace and a surrounding Markdown code fence from an LLM reply.

  Both ```json ... ``` and bare ``` ... ``` fences are removed; anything else
  is returned stripped but otherwise unchanged.

  Args:
    response: Raw model output.

  Returns:
    The reply with surrounding whitespace and any enclosing fence removed.
  """
  response = response.strip()
  if response.endswith("```"):
    # Check the tagged fence first so "```json" is not left as a "json" prefix.
    for fence in ("```json", "```"):
      # The length guard keeps a lone "```" from counting as both fences.
      if response.startswith(fence) and len(response) >= len(fence) + 3:
        return response[len(fence):-3].strip()
  return response


def _generate_screenshot_table(task_result: dict[str, Any], i: int) -> str:
  """Generate html string for the screenshot analysis table.

  Args:
    task_result: Task run result by Agent_base1.
    i: The index of the step.

  Returns:
    Html string for the screenshot analysis table.
  """
  html_str = (
    "<table style='width:100%;'><caption"
    " style='caption-side:top;text-align:left;'>Screenshot Analysis</caption>"
  )
  
  # Column for the raw screenshot
  if task_result['episode_data']['raw_screenshot'][i] is not None:
    encoded_raw_screenshot = encode_image_for_html(
      task_result['episode_data']['raw_screenshot'][i]
    )
    html_str += f"""
      <tr>
        <td style='text-align:center;'>
          Before Screenshot (raw):<br>
          <img src="data:image/png;base64,{encoded_raw_screenshot}" alt="Raw Screenshot" width="324" height="720">
        </td>
    """
  
  # Column for the screenshot before actions with marks
  if task_result['episode_data']['before_screenshot_with_som'][i] is not None:
    encoded_before_screenshot = encode_image_for_html(
      task_result['episode_data']['before_screenshot_with_som'][i]
    )
    html_str += f"""
        <td style='text-align:center;'>
          Before Screenshot with marks:<br>
          <img src="data:image/png;base64,{encoded_before_screenshot}" alt="Before Screenshot with Marks" width="324" height="720">
        </td>
    """
  
  # Column for the screenshot after actions with marks
  if task_result['episode_data']['after_screenshot_with_som'][i] is not None:
    encoded_after_screenshot = encode_image_for_html(
      task_result['episode_data']['after_screenshot_with_som'][i]
    )
    html_str += f"""
        <td style='text-align:center;'>
          After Screenshot with marks:<br>
          <img src="data:image/png;base64,{encoded_after_screenshot}" alt="After Screenshot with Marks" width="324" height="720">
        </td>
      </tr>
    """
  
  html_str += '</table>'
  return html_str


def validate_ui_element(
  ui_element: representation_utils.UIElement,
  screen_width_height_px: tuple[int, int],
) -> bool:
  """Return whether a UI element should be kept.

  An element is rejected when it is not visible, or when it has a bounding box
  that is empty/inverted or lies entirely outside the screen. A visible element
  without a bounding box is kept.

  Args:
    ui_element: The UI element to check.
    screen_width_height_px: Screen (width, height) in pixels.

  Returns:
    True if the element is valid, False otherwise.
  """
  width, height = screen_width_height_px

  if not ui_element.is_visible:
    return False

  box = ui_element.bbox_pixels
  if not box:
    return True

  left, right = box.x_min, box.x_max
  top, bottom = box.y_min, box.y_max
  rejected = (
    left >= right
    or left >= width
    or right <= 0
    or top >= bottom
    or top >= height
    or bottom <= 0
  )
  return not rejected


def find_element_at_coordinate(
  x: int,
  y: int,
  ui_content_full_dict: list[models.UIElementDict],
  screen_size: Optional[tuple[int, int]] = None
) -> Optional[models.UIElementDict]:
  """Return the UI element whose bounding box contains the point (x, y).

  When several elements contain the point, the winner is chosen by, in order:
    1. clickable elements before non-clickable ones;
    2. smaller bounding-box area (the more specific element);
    3. larger 'index' value (a missing index counts as -1).
  All matching candidates are printed in ranked order for debugging.

  Elements are skipped when they have no 'bbox_pixels', or when any of
  x_min / y_min / x_max / y_max is missing. The bbox may be a dict or an object
  exposing those attributes. Box edges are inclusive.

  Args:
    x: X coordinate.
    y: Y coordinate.
    ui_content_full_dict: UI element dictionaries (from obs.ui_content_full_dict).
    screen_size: Optional screen (width, height); currently not used.

  Returns:
    The best matching element dict, or None if no element contains the point.
  """
  if not ui_content_full_dict:
    return None

  corner_keys = ('x_min', 'y_min', 'x_max', 'y_max')
  ranked = []
  for element in ui_content_full_dict:
    box = element.get('bbox_pixels')
    if not box:
      continue

    if isinstance(box, dict):
      left, top, right, bottom = (box.get(k) for k in corner_keys)
    else:
      left, top, right, bottom = (getattr(box, k, None) for k in corner_keys)
    if any(v is None for v in (left, top, right, bottom)):
      continue

    if not (left <= x <= right and top <= y <= bottom):
      continue

    ranked.append({
      'element': element,
      'area': (right - left) * (bottom - top),
      'is_clickable': element.get('is_clickable', False),
      'index': element.get('index', -1),
    })

  if not ranked:
    return None

  # Clickable first (False sorts before True), then smallest area, then the
  # highest index (negated for descending order).
  ranked.sort(key=lambda c: (not c['is_clickable'], c['area'], -c['index']))

  print_with_color('\n==== Element Candidates (for Coordinate Matching) ====', 'cyan')
  for position, c in enumerate(ranked, start=1):
    print_with_color(
      f"  [{position}] Index: {c['index']} | Clickable: {c['is_clickable']} | Area: {c['area']} | Element: {c['element']}",
      'cyan'
    )
  print_with_color('===============================================\n', 'cyan')

  return ranked[0]['element']


def is_triple_equal(
  str1: str,
  str2: str,
  str3: str
) -> bool:
  """Return True when all three values compare equal to each other."""
  return str1 == str2 and str2 == str3


def store_image(image_data: numpy.ndarray, image_name: str, file_path) -> str:
  """Save an image array into a directory and return the written path.

  Args:
    image_data: Pixel data. It is converted with np.uint8 before saving; that
      cast does not clip, so values must already be in 0-255.
    image_name: File name, including the extension (PIL infers the output
      format from it).
    file_path: Directory to save into.

  Returns:
    The path the image was written to (os.path.join(file_path, image_name)).
  """
  # Local import: `os` is not among this module's top-level imports, and
  # without it this function raised NameError.
  import os

  image = Image.fromarray(np.uint8(image_data))
  image_path = os.path.join(file_path, image_name)
  image.save(image_path)
  return image_path


def load_image_as_ndarray(image_path):
  """Load an image as a numpy array.

  Relative paths are tried against these roots, in order:
    1. the parent directory of FLAGS.log_folder_exp, when that flag is set
       and readable;
    2. the project root, Path(__file__).parent.parent.parent;
    3. the current working directory.
  The first candidate that exists is used.

  Args:
    image_path: Path to the image file (absolute or relative).

  Returns:
    Numpy array of the image, or None if the file is missing or cannot be
    read (the reason is printed).
  """
  from pathlib import Path

  try:
    image_path_obj = Path(image_path)

    # Handle relative paths with multiple resolution strategies.
    # IMPORTANT: runtime log artifacts may be written relative to the process CWD
    # (e.g., when running from ./examples), while older code assumed project_root.
    # We keep storage location unchanged and only improve path resolution here.
    if not image_path_obj.is_absolute():
      candidates: list[Path] = []

      # 1) Prefer resolving relative to the experiment log root parent directory.
      # Example:
      #   FLAGS.log_folder_exp = "examples/log/autorpa_.../"
      #   image_path = "log/.../step_4_....png"
      # => parent("examples/log/...") == "examples", so "examples/log/..." exists.
      try:
        from absl import flags as _flags  # local import to avoid import cycles
        _FLAGS = _flags.FLAGS
        log_folder_exp = getattr(_FLAGS, "log_folder_exp", "") or ""
        if log_folder_exp:
          candidates.append(Path(log_folder_exp).parent / image_path_obj)
      except Exception:
        # Flags may not be initialized in some contexts; ignore and fall back.
        pass

      # 2) Backward-compatible fallback: resolve relative to project root (main.py dir).
      project_root = Path(__file__).parent.parent.parent
      candidates.append(project_root / image_path_obj)

      # 3) If the path happens to be relative to current working directory, try that too.
      candidates.append(Path.cwd() / image_path_obj)

      resolved = next((p for p in candidates if p.exists()), None)
      if resolved is None:
        # List every candidate that was tried, in resolution order, to make
        # debugging easier.
        print_with_color(
          "Image file does not exist (tried): "
          + ", ".join(str(p) for p in candidates),
          'red'
        )
        return None
      image_path_obj = resolved

    if not image_path_obj.exists():
      print_with_color(f"Image file does not exist: {image_path_obj}", 'red')
      return None

    # Image.open is lazy and keeps the file open; the context manager closes
    # it once np.array has forced the pixel data to load.
    with Image.open(str(image_path_obj)) as image:
      return np.array(image)
  except Exception as e:
    print_with_color(f"Error loading image {image_path}: {e}", 'red')
    import traceback
    traceback.print_exc()
    return None


def save_image(image_array: np.ndarray, save_path: str):
  """Write a numpy image array to disk with PIL.

  Args:
    image_array: Pixel data accepted by PIL's Image.fromarray.
    save_path: Destination file path; PIL infers the format from its
      extension.
  """
  Image.fromarray(image_array).save(save_path)


def _generate_ui_element_description_from_element(
  ui_element: representation_utils.UIElement, index: int
) -> str:
  """Generate a description for a given UI element with important information.

  Args:
    ui_element: UI elements for the current screen.
    index: The numeric index for the UI element.

  Returns:
    The description for the UI element.
  """
  element_description = f'{{"index": {index}, '
  if ui_element.text:  # Text shown by the UI element
    value = ui_element.text.replace('"', "'")
    element_description += f'"text": "{value}", '
  if ui_element.content_description:  # For images, icons, etc.; describes meaning or function
    value = ui_element.content_description.replace('"', "'")
    element_description += (
      f'"content_description": "{value}", '
    )
  if ui_element.hint_text:
    value = ui_element.hint_text.replace('"', "'")
    element_description += f'"hint_text": "{value}", '
  if ui_element.tooltip:
    value = ui_element.tooltip.replace('"', "'")
    element_description += f'"tooltip": "{value}", '
  
  actions = []
  if ui_element.is_clickable: actions.append('click')  # Default clickable; noted in prompt
  if ui_element.is_long_clickable: actions.append('long_press')
  if ui_element.is_editable: actions.append('input_text')
  if ui_element.is_scrollable: actions.append('swipe')
  if actions:
    element_description += '"actions": ' + str(actions).replace("'", '"') + ', '
  
  if ui_element.is_selected:
    element_description += '"is_selected": True, '
  if ui_element.is_checkable:  # Add is_checked only when checkable
    element_description += (
      f'"is_checked": {"True" if ui_element.is_checked else "False"}, '
    )
  return element_description[:-2] + '}'  # Remove trailing comma and space


def _generate_ui_element_description_from_dict(
  elem_dict: dict[str, Any], index: Optional[int] = None
) -> str:
  """Generate a description for a UI element from dictionary format.

  Args:
    elem_dict: UI element dictionary.
    index: Optional index override. If not provided, uses elem_dict.get('index').

  Returns:
    The description for the UI element.
  """
  # Use provided index or get from dict
  elem_index = index if index is not None else elem_dict.get('index', -1)
  element_description = f'{{"index": {elem_index}, '
  
  # Text-like fields
  if elem_dict.get('text'):
    value = str(elem_dict['text']).replace('"', "'")
    element_description += f'"text": "{value}", '
  if elem_dict.get('content_description'):
    value = str(elem_dict['content_description']).replace('"', "'")
    element_description += f'"content_description": "{value}", '
  if elem_dict.get('hint_text'):
    value = str(elem_dict['hint_text']).replace('"', "'")
    element_description += f'"hint_text": "{value}", '
  if elem_dict.get('tooltip'):
    value = str(elem_dict['tooltip']).replace('"', "'")
    element_description += f'"tooltip": "{value}", '
  
  # Additional actions (from dict or derive from flags)
  actions = elem_dict.get('actions', [])
  if not actions:
    # Derive from flags if not present
    if elem_dict.get('is_long_clickable'):
      actions.append('long_press')
    if elem_dict.get('is_editable'):
      actions.append('input_text')
    if elem_dict.get('is_scrollable'):
      actions.append('swipe')
  if actions:
    element_description += '"actions": ' + str(actions).replace("'", '"') + ', '
  
  # State flags
  if elem_dict.get('is_selected'):
    element_description += '"is_selected": True, '
  if elem_dict.get('is_checkable'):  # Add is_checked only when checkable
    is_checked = elem_dict.get('is_checked', False)
    element_description += f'"is_checked": {"True" if is_checked else "False"}, '
  
  return element_description[:-2] + '}'  # Remove trailing comma and space


def _generate_ui_elements_description_str(
  ui_elements: Union[list[representation_utils.UIElement], list[dict[str, Any]]],
  screen_width_height_px: Optional[tuple[int, int]] = None,
  target_index: Optional[int] = None,
) -> str:
  """Generates a concise description of UI elements.

  Supports both UIElement objects and dictionary format (ui_content_full_dict).
  Iterates through the given list of UI elements, validates them,
  and generates a corresponding description for each valid element.
  The function skips certain elements from the Android virtual keyboard
  (e.g., Shift key, punctuation symbols) to avoid redundant information.

  Args:
      ui_elements: List of UI elements, either:
          - list[representation_utils.UIElement]: Raw UI element objects
          - list[dict]: UI element dictionaries (from ui_content_full_dict)
      screen_width_height_px: Screen width and height in pixels.
          Required when ui_elements is list[UIElement], optional for dict format.

  Returns:
      str: A string containing descriptions of valid UI elements,
           separated by new lines.
  """
  tree_info = ''
  
  # Check if input is list of dicts or list of UIElement objects
  if not ui_elements:
    return ''
  
  is_dict_format = isinstance(ui_elements[0], dict)
  
  if is_dict_format:
    # Handle dictionary format (ui_content_full_dict)
    for elem_dict in ui_elements:
      if not isinstance(elem_dict, dict):
        continue
      
      index = elem_dict.get('index', -1)
      if target_index is not None and index != target_index:
        continue

      # Skip specific elements from the Android Gboard virtual keyboard
      package_name = elem_dict.get('package_name', '')
      content_description = elem_dict.get('content_description', '')
      if package_name == 'com.google.android.inputmethod.latin' and content_description:
        con = str(content_description)
        # Ignore single letters (uppercase or lowercase), Shift key, Symbol keyboard, comma, and period
        if (len(con) == 1 and con.isalpha()) or con in ['Shift', 'Symbol keyboard', ',', '.']:
          continue
      
      # Generate description from dict
      tree_info += _generate_ui_element_description_from_dict(elem_dict, index) + '\n'
  else:
    # Handle UIElement objects format
    if screen_width_height_px is None:
      raise ValueError("screen_width_height_px is required when ui_elements is list[UIElement]")
    
    for index, ui_element in enumerate(ui_elements):
      if target_index is not None and index != target_index:
        continue
      # Validate the UI element
      if validate_ui_element(ui_element, screen_width_height_px):
        # Skip specific elements from the Android Gboard virtual keyboard
        if ui_element.package_name == 'com.google.android.inputmethod.latin' and ui_element.content_description:
          con = ui_element.content_description
          # Ignore single letters (uppercase or lowercase), Shift key, Symbol keyboard, comma, and period
          if (len(con) == 1 and con.isalpha()) or con in ['Shift', 'Symbol keyboard', ',', '.']:
            continue
        # Generate the UI element's description and append it to tree_info
        tree_info += _generate_ui_element_description_from_element(ui_element, index) + '\n'
  
  return tree_info


def project_ui_elements_to_full_dict(
  ui_elements: list[Any],
  screen_width_height_px: tuple[int, int],
) -> list[models.UIElementDict]:
  """Convert raw UI element objects into JSON-serializable dicts.

  Each element that is kept becomes a dict containing:
    * ``index``: its position in ``ui_elements`` (skipped elements leave gaps,
      indices are never renumbered);
    * every text-like attribute (``text``, ``content_description``,
      ``hint_text``, ``tooltip``, ``resource_id``, ``class_name``,
      ``package_name``) that is neither ``None`` nor ``""``;
    * every state flag that is not ``None``, coerced to ``bool`` (``False``
      values are kept);
    * ``actions``: ``long_press`` / ``input_text`` / ``swipe`` derived from the
      long-clickable / editable / scrollable flags, only when non-empty;
    * ``bbox_pixels``: integer ``x_min``/``y_min``/``x_max``/``y_max`` read from
      the element's ``bbox_pixels`` (missing coordinates default to 0), only
      when the element has a bbox.

  Elements rejected by ``validate_ui_element`` are skipped, as are individual
  Gboard keys (single letters, Shift, Symbol keyboard, comma, period), mirroring
  the filtering applied to the string UI description.

  Args:
    ui_elements: Raw UI element objects; attributes are read with ``getattr``
      so missing ones are tolerated.
    screen_width_height_px: Screen size forwarded to ``validate_ui_element``.

  Returns:
    One dict per kept element, in input order.
  """
  text_attrs = ("text", "content_description", "hint_text", "tooltip", "resource_id", "class_name", "package_name")
  flag_attrs = (
    "is_clickable", "is_long_clickable", "is_scrollable", "is_checkable", "is_checked",
    "is_enabled", "is_focusable", "is_focused", "is_password", "is_selected", "is_editable", "is_visible",
  )
  # (flag attribute, action name) pairs, in the order actions are listed.
  flag_actions = (
    ("is_long_clickable", "long_press"),
    ("is_editable", "input_text"),
    ("is_scrollable", "swipe"),
  )
  gboard_package = 'com.google.android.inputmethod.latin'
  gboard_named_keys = ('Shift', 'Symbol keyboard', ',', '.')

  def is_gboard_key(element: Any) -> bool:
    """True for a single Gboard key that should not be exposed to the agent."""
    if getattr(element, "package_name", None) != gboard_package:
      return False
    label = getattr(element, "content_description", None)
    if not label:
      return False
    return (len(label) == 1 and label.isalpha()) or label in gboard_named_keys

  projected: list[dict] = []
  for position, element in enumerate(ui_elements):
    if not validate_ui_element(element, screen_width_height_px):
      continue
    if is_gboard_key(element):
      continue

    record: dict[str, Any] = {"index": position}

    for attr in text_attrs:
      value = getattr(element, attr, None)
      if value is None or value == "":
        continue
      record[attr] = value

    for attr in flag_attrs:
      value = getattr(element, attr, None)
      if value is None:
        continue
      record[attr] = bool(value)

    available = [name for flag, name in flag_actions if getattr(element, flag, False)]
    if available:
      record["actions"] = available

    box = getattr(element, "bbox_pixels", None)
    if box is not None:
      record["bbox_pixels"] = {
        corner: int(getattr(box, corner, 0))
        for corner in ("x_min", "y_min", "x_max", "y_max")
      }

    projected.append(record)
  return projected


def write_to_file(
  file_path: str,
  file_name: str,
  content: Any
):
  """Write ``str(content)`` to ``<file_path>/<file_name>`` as UTF-8 text.

  The directory ``file_path`` is created if it does not exist, and any existing
  file with the same name is overwritten.

  Args:
    file_path: Directory that will contain the file.
    file_name: Name of the file inside ``file_path``.
    content: Any object; it is converted with ``str()`` before writing.
  """
  target = os.path.join(file_path, file_name)
  os.makedirs(file_path, exist_ok=True)
  with open(target, 'w', encoding="utf-8") as handle:
    handle.write(str(content))
    handle.flush()
    file.flush()


def save_json(
  obj: Any,
  save_path: str,
  file_name: str
):
  """Serialize ``obj`` as pretty-printed UTF-8 JSON to ``<save_path>/<file_name>``.

  Objects the ``json`` module cannot encode natively are converted through their
  ``model_dump()`` method (Pydantic v2) or, failing that, ``dict()`` (Pydantic
  v1 style). Anything else raises ``TypeError``.

  Args:
    obj: The object to save.
    save_path: Directory to save the file in; created if missing.
    file_name: Name of the JSON file.
  """
  def _fallback(value):
    if hasattr(value, 'model_dump'):
      return value.model_dump()
    if hasattr(value, 'dict'):
      return value.dict()
    raise TypeError(f'Object of type {value.__class__.__name__} is not JSON serializable')

  os.makedirs(save_path, exist_ok=True)
  destination = os.path.join(save_path, file_name)
  with open(destination, 'w', encoding='utf-8') as stream:
    json.dump(obj, stream, indent=2, default=_fallback, ensure_ascii=False)
    stream.flush()


import pandas as pd
import os


# Append one LLM call's token usage to the per-run step_tokens.csv file.
def _usage_token_counts(usage: Any) -> tuple[int, int, int, int]:
  """Return ``(input, output, total, cached)`` token counts for a usage object.

  Handles OpenAI ``CompletionUsage`` and, otherwise, any object exposing
  ``input_tokens`` / ``output_tokens`` plus optional
  ``cache_creation_input_tokens`` / ``cache_read_input_tokens`` (presumably an
  Anthropic-style ``Usage``; confirm against callers). Cache counters that are
  missing or ``None`` count as 0.
  """
  if isinstance(usage, CompletionUsage):
    details = usage.prompt_tokens_details
    # cached_tokens is optional in the SDK, so None is treated as 0.
    cached = (details.cached_tokens or 0) if details else 0
    return usage.prompt_tokens, usage.completion_tokens, usage.total_tokens, cached

  input_tokens = usage.input_tokens
  output_tokens = usage.output_tokens
  # The cache counters can be None when prompt caching is not used; adding them
  # directly would raise TypeError.
  cached = (
    (getattr(usage, 'cache_creation_input_tokens', None) or 0)
    + (getattr(usage, 'cache_read_input_tokens', None) or 0)
  )
  return input_tokens, output_tokens, input_tokens + output_tokens, cached


def record_cost_tokens(record_token: models.RecordToken):
  """Append one LLM call's token usage as a row of ``step_tokens.csv``.

  The CSV lives in ``record_token.file_path``, which is created if missing.
  Columns: Task Type, Task Num, Attempt, Stage, Step, Agent, Input, Output,
  Total, Cached, LLM. ``Attempt`` comes from the ``cur_attempt_cnt`` absl flag.

  Args:
    record_token: Carries the output directory, task/step metadata, the LLM
      name and the provider usage object in ``step_tokens``.
  """
  os.makedirs(record_token.file_path, exist_ok=True)
  full_file_path = os.path.join(record_token.file_path, 'step_tokens.csv')

  input_tokens, output_tokens, total_tokens, cached_tokens = _usage_token_counts(
    record_token.step_tokens
  )

  new_row = pd.DataFrame([{
    'Task Type': record_token.task_type,
    'Task Num': record_token.task_num,
    'Attempt': str(FLAGS.cur_attempt_cnt),
    'Stage': record_token.stage,
    'Step': record_token.step,
    'Agent': record_token.agent,
    'Input': input_tokens,
    'Output': output_tokens,
    'Total': total_tokens,
    'Cached': cached_tokens,
    'LLM': record_token.llm,
  }])

  # Read-modify-write, not append, so an existing file with a different column
  # set stays aligned: pd.concat unions the columns.
  existing = None
  if os.path.exists(full_file_path) and os.path.getsize(full_file_path) > 0:
    existing = pd.read_csv(full_file_path)

  if existing is None or existing.empty:
    # Skip concatenating with an empty frame (deprecated dtype behaviour in pandas).
    table = new_row
  else:
    table = pd.concat([existing, new_row], ignore_index=True)

  table.to_csv(full_file_path, index=False)


def match_actions_to_code(executed_actions: str, code: str) -> str:
  """Return the prefix of ``code`` that accounts for the executed actions.

  Both the executed actions and the code lines are compared loosely: the
  contents of each ``(...)`` group are erased, and an action counts as matched
  when its loose form is a substring of the loose code line. Lines containing
  a triple-backtick fence are dropped.

  The scan keeps code lines in order and stops:
    * just before an ``env_op`` line that does not match the next pending action;
    * right after the line on which all actions have been matched. When there
      are no executed actions, this happens after the first kept line.

  Args:
    executed_actions: Newline-separated action calls that actually ran; blank
      lines are ignored.
    code: Generated code, possibly wrapped in markdown fences.

  Returns:
    The kept code lines joined by newlines, with original indentation.
  """
  def _loose(call: str) -> str:
    return re.sub(r"\(.*?\)", "()", call)

  wanted = [_loose(entry.strip()) for entry in executed_actions.strip().split("\n") if entry.strip()]
  total = len(wanted)
  kept: list[str] = []
  cursor = 0

  for raw in code.strip().split("\n"):
    if "```" in raw:
      continue
    stripped = raw.strip()
    loose_line = _loose(stripped)
    has_pending = cursor < total

    # An env_op call that is not the next expected action ends the prefix.
    if has_pending and stripped.startswith('env_op') and wanted[cursor] not in loose_line:
      break

    kept.append(raw)

    if has_pending and wanted[cursor] in loose_line:
      cursor += 1
    if cursor == total:
      break

  return "\n".join(kept)


# Matches a planner ``kwargs = ...`` line. ``value`` is the literal up to any
# trailing ``#`` comment. ``end`` is the line break, or "" when the kwargs line
# is the last line of the text. There is deliberately no re.DOTALL, so a
# trailing comment can never swallow the lines that follow it.
_KWARGS_LINE_RE = re.compile(
  r"kwargs\s*=\s*(?P<value>[^\n#]+?)[ \t]*(?:#[^\n]*)?(?P<end>\n|$)"
)


def _extract_kwargs_literal(soft_action: str) -> Optional[str]:
  """Return the dict literal from the first ``kwargs = ...`` line, or None.

  Double braces (``{{...}}``) are reduced to single braces. If the value does
  not start with ``{``, the outermost ``{...}`` span inside it is used instead.
  """
  match = _KWARGS_LINE_RE.search(soft_action)
  if match is None:
    return None
  kwargs_str = match.group("value").strip()
  if kwargs_str.startswith("{{") and kwargs_str.endswith("}}"):
    return kwargs_str[1:-1]
  if not kwargs_str.startswith("{"):
    brace_match = re.search(r"(\{.*\})", kwargs_str, re.DOTALL)
    if brace_match:
      kwargs_str = brace_match.group(1)
      if kwargs_str.startswith("{{") and kwargs_str.endswith("}}"):
        kwargs_str = kwargs_str[1:-1]
  return kwargs_str


def _parse_kwargs_literal(kwargs_str: str) -> Any:
  """Parse a JSON or Python dict literal, tolerating mixed true/True styles.

  Raises:
    ValueError: If no parsing strategy succeeds.
  """
  # Exact parsers first: they never alter string values.
  try:
    return json.loads(kwargs_str)
  except json.JSONDecodeError:
    pass
  try:
    return ast.literal_eval(kwargs_str)
  except (ValueError, SyntaxError, TypeError):
    pass
  # Mixed styles (e.g. single quotes AND lowercase true) need textual keyword
  # rewriting. That rewriting also touches matching words inside string
  # values, so it is only a last resort.
  python_style = kwargs_str.replace("true", "True").replace("false", "False").replace("null", "None")
  try:
    return ast.literal_eval(python_style)
  except (ValueError, SyntaxError, TypeError) as e:
    json_style = python_style.replace("True", "true").replace("False", "false").replace("None", "null")
    try:
      return json.loads(json_style)
    except json.JSONDecodeError:
      raise ValueError(f"Failed to parse kwargs string: {python_style}\nOriginal error: {e}") from e


def extract_ui_value(soft_action: str, action_related_elements: str, index: Optional[int] = None, verbose: bool = False) -> str:
  """Rewrite the ``kwargs = {...}`` line of a planner action using a UI element.

  The planner emits a "soft-coded" action whose ``kwargs`` dict describes the
  target UI element. This looks up the element whose ``index`` equals
  ``index`` in ``action_related_elements`` and rebuilds ``kwargs`` as follows:
    * ``actions`` is always taken from the planner's kwargs;
    * any other key that also exists on the target element takes the
      element's value;
    * ``target_description`` keeps the planner's value if the element lacks it;
    * every other planner key is dropped.

  Args:
    soft_action: Planner action text containing a ``kwargs = ...`` line
      (``{{...}}`` double braces are accepted).
    action_related_elements: UI elements as JSON objects, one per line
      (wrapped objects allowed), parsed with ``parse_str_to_jsonlist``.
    index: Index of the target UI element. When ``None``, ``soft_action`` is
      returned unchanged. Only a single index is currently supported.
    verbose: Whether to print the UI elements being searched.

  Returns:
    ``soft_action`` with each ``kwargs`` line replaced by the rebuilt dict,
    pretty-printed as JSON.

  Raises:
    ValueError: If no ``kwargs`` line is found, or its value cannot be parsed
      into a dict.
  """
  if index is None:
    return soft_action

  # 1. Extract and parse kwargs from the action string.
  kwargs_str = _extract_kwargs_literal(soft_action)
  if not kwargs_str:
    raise ValueError("Failed to extract kwargs from action string.")
  kwargs = _parse_kwargs_literal(kwargs_str)
  if not isinstance(kwargs, dict):
    raise ValueError(f"Parsed kwargs is not a dict: {kwargs_str}")

  if verbose:
    print_with_color('🔍 Extracting UI values from related elements:', 'cyan')
    print(f'   {action_related_elements}')

  # 2. Locate the target element. Ignore non-dict or index-less entries instead
  #    of failing on them.
  ui_elements = parse_str_to_jsonlist(action_related_elements.strip())
  target_element = next(
    (item for item in ui_elements if isinstance(item, dict) and item.get("index") == index),
    {},
  )

  # 3. Rebuild kwargs. "actions" is trusted from the LLM-generated kwargs.
  updated_kwargs: dict[str, Any] = {}
  for key, value in kwargs.items():
    if key == "actions":
      updated_kwargs[key] = value
    elif key in target_element:
      updated_kwargs[key] = target_element[key]
    elif key == "target_description":
      updated_kwargs[key] = value

  # 4. Substitute the kwargs line(s). A callable replacement inserts the JSON
  #    verbatim. A template string would turn escapes emitted by json.dumps
  #    (\n, \\) into literal characters and reject \u.
  replacement = f"kwargs = {json.dumps(updated_kwargs, indent=4, ensure_ascii=False)}"
  return _KWARGS_LINE_RE.sub(lambda m: replacement + m.group("end"), soft_action)


def parse_str_to_jsonlist(ui_content: str) -> list[dict]:
  """Parse newline-separated JSON objects, tolerating objects wrapped over lines.

  Bare Python booleans (``True`` / ``False``) are converted to JSON booleans
  before parsing. Text inside double-quoted strings is left untouched, so a
  value such as ``"True story"`` is preserved. Lines are accumulated, joined
  by single spaces, until the accumulated text parses as JSON. That value is
  then emitted and accumulation restarts.

  Args:
    ui_content: The serialized UI element list.

  Returns:
    The parsed values in input order. Content left over at the end that never
    parses is reported on stdout and discarded.
  """
  # Match a whole JSON string token OR a bare Python boolean. String tokens are
  # returned unchanged and only bare booleans are lowercased.
  ui_content = re.sub(
    r'"(?:\\.|[^"\\])*"|\b(?:True|False)\b',
    lambda m: m.group(0).lower() if m.group(0) in ('True', 'False') else m.group(0),
    ui_content,
  )

  parsed: list[dict] = []
  pending = ""
  last_error: Optional[json.JSONDecodeError] = None
  for line in ui_content.strip().split('\n'):
    pending += line.strip()
    try:
      value = json.loads(pending)
    except json.JSONDecodeError as e:
      # Not complete yet: keep accumulating with the next line.
      last_error = e
      pending += " "
      continue
    parsed.append(value)
    pending = ""

  leftover = pending.strip()
  if leftover:
    print(f"Error decoding line: {leftover}")
    print(f"Error details: {last_error}")

  return parsed


def print_with_color(message: Any, color: str) -> None:
  """Print ``message`` wrapped in an ANSI color escape sequence.

  Args:
    message: The message to print; it is formatted with ``str()``.
    color: One of 'red', 'green', 'yellow', 'blue', 'magenta', 'cyan', 'white',
      'light_gray', 'dark_gray', 'light_red', 'light_green'. An unknown name
      prints the message without color (wrapped in reset codes).
  """
  color_codes = {
    'red': '\033[91m',
    'green': '\033[92m',
    'yellow': '\033[93m',
    'blue': '\033[94m',
    'magenta': '\033[95m',
    'cyan': '\033[96m',
    # 'white' was documented but missing, so it silently printed uncolored.
    'white': '\033[97m',
    'light_gray': '\033[37m',
    'dark_gray': '\033[90m',
    'light_red': '\033[91;1m',
    'light_green': '\033[92;1m',
  }
  reset_code = '\033[0m'
  color_code = color_codes.get(color, reset_code)
  print(f"{color_code}{message}{reset_code}")


def record_exp_result(file_path: str, exp_result_data: dict):
  """Append one experiment result as a CSV row.

  The header (the dict's keys) is written only when ``file_path`` does not
  exist yet. Rows are appended positionally, so callers should pass dicts with
  the same keys in the same order each time.

  Args:
    file_path: Path of the CSV file to append to.
    exp_result_data: Mapping of column name to value for the new row.
  """
  is_new_file = not os.path.exists(file_path)
  pd.DataFrame([exp_result_data]).to_csv(
    file_path, mode='a', index=False, header=is_new_file
  )


def extract_function_names(text: str) -> str:
  """Return the name of the first function call found in ``text``.

  Blank lines, ``#`` comment lines and markdown fence lines (starting with
  triple backticks) are skipped. On each remaining line, the call name is the
  run of non-whitespace, non-``(`` characters immediately before a ``(``, so
  dotted names such as ``env_op.tap`` are returned whole.

  Args:
    text: Source text, e.g. LLM-generated code possibly wrapped in fences.

  Returns:
    The first call name found.

  Raises:
    IndexError: If no function call is found. The type is kept for callers
      that relied on the previous ``results[0]`` failure.
  """
  for raw_line in text.splitlines():
    line = raw_line.strip()
    if not line or line.startswith(("#", "```")):
      continue
    match = re.search(r'([^\s(]+)\(', line)
    if match:
      # Only the first call is ever used, so stop scanning here.
      return match.group(1)
  raise IndexError("No function call found in text")