import ast
import json
import os
import re
import time
import traceback
from copy import deepcopy
from typing import Optional

import numpy as np
from absl import flags
from lxml import etree, html
from selenium.webdriver.common.keys import Keys

from UIAgents.Agent_RPA.utils.JSON_models import ScreenObs
from computergym.miniwob.base_env import MiniWoBEnvironment
from computergym.miniwob.miniwob_interface.action import (
  MiniWoBType,
  MiniWoBElementClickXpath,
  MiniWoBElementClickOption,
  MiniWoBMoveXpath,
)
from .utils import agent_utils, JSON_models
from .utils.JSON_API import get_llm_wrapper
from .utils.agent_utils import print_with_color

FLAGS = flags.FLAGS


def normalize_html_for_comparison(html_content: str) -> str:
  """
  Return a canonical form of an HTML/XML string for change detection.

  Volatile attributes (``data-*``, ``bounds``, ``checksum``, ``timestamp``) written with
  double-quoted values are stripped, every whitespace run is collapsed to one space and the
  result is trimmed. Falsy input (``None`` or ``""``) yields ``""``.
  """
  if not html_content:
    return ""
  
  # Applied one after another, in this order, on the progressively cleaned text.
  volatile_attr_patterns = (
    r'\s+data-[a-zA-Z0-9_-]+="[^"]*"',  # data-* (web frameworks like MiniWoB)
    r'\s+bounds="[^"]*"',                # bounds (Android coordinates)
    r'\s+checksum="[^"]*"',              # checksum (Android)
    r'\s+timestamp="[^"]*"',             # timestamp attributes
  )
  
  cleaned = html_content
  for attr_regex in volatile_attr_patterns:
    cleaned = re.compile(attr_regex).sub('', cleaned)
  
  # Any run of spaces/newlines/tabs becomes a single space.
  collapsed = re.compile(r'\s+').sub(' ', cleaned)
  return collapsed.strip()


class ActionExecutionError(Exception):
  """
  Raised when executing a single environment action fails.

  Attributes:
    action_code: Code string of the action that failed.
    error_code: The kwargs associated with the failing action (may be None).
    original_exception: The exception that triggered this error.
  """
  
  def __init__(self, action_code, error_code, original_exception):
    self.action_code = action_code
    self.error_code = error_code  # self.kwargs
    self.original_exception = original_exception
    # Each field goes on its own line; previously the action code ran straight into
    # "Related kwargs:" with no separator.
    message = (
      f"Error Action: {action_code}\n"
      f"Related kwargs: {error_code}\n"
      f"Error: {original_exception}"
    )
    super().__init__(message)


class ReachedMaxStepsError(Exception):
  """
  Raised when the number of executed actions exceeds the maximum allowed steps.

  Attributes:
    action_code: Code string of the action that hit the limit.
    error_code: ``error_code`` when it is truthy, otherwise ``action_code``.
  """
  
  def __init__(self, action_code, error_code, message=None):
    # Keep the action code available, mirroring ActionExecutionError.
    self.action_code = action_code
    self.error_code = error_code if error_code else action_code
    if message is None:
      message = "Maximum number of action steps reached."
    super().__init__(message)


from selenium.webdriver.common.by import By


class EnvOperation:  # create in agent.reset()
  def __init__(self, raw_env: MiniWoBEnvironment, task_type: str):
    """
    Store the wrapped environment and task type, and create the LLM wrapper
    (selected by ``FLAGS.default_llm``) and a default token record.
    """
    self.raw_env, self.task_type = raw_env, task_type
    default_llm_name = FLAGS.default_llm
    self.llm = get_llm_wrapper(default_llm_name)
    self.record_token = JSON_models.RecordToken()
  
  def reset(self, task_idx, save_path, to_init_task: bool = True, max_action_step: int = 20):
    """
    Prepare this object for a task and capture the current observation.

    Args:
      task_idx: Task index; used as the environment seed and in the token-record label.
      save_path: Directory for saved artifacts; created if missing.
      to_init_task: When True, the raw environment is reset and all per-task
        bookkeeping (action records, trajectory, history, done flags, reward) is cleared.
      max_action_step: Maximum number of executed actions allowed for the task.
    """
    self.task_idx = task_idx
    
    if to_init_task:
      first_obs = self.raw_env.reset(seeds=[task_idx], record_screenshots=True)[0]
      self.raw_obs = first_obs
      self.task = first_obs.utterance
      
      self.before_obs = None
      # Parallel per-action records: executed code, target xpath, target element markup.
      self.executed_actions = []
      self.executed_element_xpath = []
      self.related_elements = []
      # (temporary) EnvExecStepInfo entries, and the action strings used in prompts.
      self.env_op_traj = []
      self.action_history = []
      self.agent_done = False
      self.done = False
      self.reward = 0
    
    self.answer_return = None
    self.max_action_step = max_action_step
    self.save_path = save_path
    os.makedirs(save_path, exist_ok=True)
    
    step_prefix = f'step_{len(self.executed_actions)}'
    self.cur_obs = self.get_obs(file_prefix=step_prefix)
    
    self.record_token = JSON_models.RecordToken(
      file_path=FLAGS.log_folder_exp,
      task_type=self.task_type,
      task_num=f'Task {task_idx}',
      stage='Env OP',
    )
  
  def execute_code(self, code: str, vars: dict, save_path: str, flag_exec_rpa: bool = False):
    """
    Execute a generated code snippet with ``exec`` and summarize what happened.

    Args:
      code: Source code to run. The final code is also written to ``rpa_code.py`` under ``save_path``.
      vars: Extra inputs. If it has a truthy ``"function_call"`` entry, that call is appended
        after ``code``; otherwise every entry is exposed to the snippet as a variable.
      save_path: Directory used for files written during this execution.
      flag_exec_rpa: Accepted for compatibility; not read by this method.

    Returns:
      ``(traj, exec_result)`` where ``traj`` is ``self.env_op_traj`` (or a single empty
      ``EnvExecStepInfo`` when that list is empty) and ``exec_result`` is a
      ``JSON_models.ExecResult`` describing the executed code, any error and the done flags.
    """
    print_with_color('============================================', 'cyan')
    print_with_color("Executing code...", 'cyan')
    
    self.save_path = save_path
    self.kwargs = None
    self.previous_kwargs = None
    self.previous_index = None
    self.before_obs = deepcopy(self.cur_obs)
    
    function_code = code
    
    local_vars = {'env_op': self, 'time': time, 're': re, 'json': json, 'error_message': None}
    function_call = vars.get("function_call")
    if function_call:
      function_code += f'\n\n{function_call}'  # function_call must go after function def to avoid errors
    else:
      local_vars.update(vars)
    print_with_color(function_code, 'cyan')
    print_with_color('--------------------------------', 'cyan')
    agent_utils.write_to_file(file_path=save_path, file_name='rpa_code.py', content=function_code)
    
    error_statement = None
    error_message = None
    executed_code = ''
    
    try:
      # NOTE(security): this runs generated code with full interpreter privileges;
      # only feed it code from a trusted source or run inside a sandbox.
      exec(function_code, local_vars, local_vars)  # Execute the entire code block
      print_with_color('The code was not interrupted.', 'green')
      executed_code = function_code
    except Exception as e:
      print_with_color('The code was interrupted.', 'red')
      error_message, error_statement, executed_code = self._handle_execution_error(e, function_code, vars)
    finally:
      # update exec.done and exec_feedback according to answer_return
      if self.answer_return is None:
        exec_feedback = error_message
      else:
        print('Agent indicates task is done.')
        self.done = True
        if self.answer_return == 'N/A':
          exec_feedback = "The stop action with N/A (indicates the task unfeasible) is executed. "
        else:
          exec_feedback = "The stop action with complete is executed. "
      
      if self.done:
        # exec_feedback is None when the environment ended the task without an error and
        # without a stop action; avoid emitting the literal text "None" in that case.
        prefix = f"{exec_feedback} " if exec_feedback else ""
        exec_feedback = f"{prefix}The task is done."
        # Since "tic-tac-toe" doesn't display the opponent's three-in-a-row, additional feedback
        # is provided to remind the agent of the reason for its failure.
        if self.reward <= 0 and self.task_type == "tic-tac-toe":
          exec_feedback += " You failed to prevent your opponent from forming a line."
      
      # Consider merging exec_feedback and error_message
      exec_result = JSON_models.ExecResult(
        executed_code=executed_code,
        error_statement=str(error_statement),
        error_message=error_message,
        exec_feedback=exec_feedback,
        answer_return=self.answer_return,
        agent_done=self.agent_done,
        done=self.done,
      )
      
      print_with_color('\n****self.executed_element_xpath:', 'cyan')
      print_with_color(self.executed_element_xpath, 'cyan')
      print_with_color('\n****self.executed_actions:', 'cyan')
      print_with_color(self.executed_actions, 'cyan')
      print_with_color('\n****self.related_elements:', 'cyan')
      print_with_color(self.related_elements, 'cyan')
      
      traj = self.env_op_traj
      if len(traj) == 0:
        traj = [JSON_models.EnvExecStepInfo()]
    return traj, exec_result
  
  def _handle_execution_error(self, e, function_code, vars):
    """
    Turn an exception raised while exec'ing ``function_code`` into report fields.

    Args:
      e: The exception caught by ``execute_code``.
      function_code: The full source that was passed to ``exec``.
      vars: The ``vars`` dict given to ``execute_code``; only ``"function_call"`` is read.

    Returns:
      ``(error_message, error_statement, executed_code)`` where ``executed_code`` is a string
      containing the code up to the failing line.
    """
    if isinstance(e, (ActionExecutionError, ReachedMaxStepsError)):
      # error_code carries the action kwargs and may be None; fall back to the action's
      # code string so the failing statement can still be located.
      error_ref = e.error_code
      if error_ref is None:
        error_ref = getattr(e, 'action_code', None)
      executed_code, error_line = extract_code_before_error(function_code, error_ref)
      error_statement = error_ref
    else:
      code_lines = function_code.splitlines()
      frames = traceback.extract_tb(e.__traceback__)
      # Code run through exec() is reported under the filename "<string>". Prefer the
      # innermost such frame: the last frame overall may belong to another module, whose
      # line number is meaningless for function_code.
      exec_frames = [frame for frame in frames if frame.filename == '<string>']
      if exec_frames:
        error_line = exec_frames[-1].lineno
      elif isinstance(e, SyntaxError) and e.lineno:
        # Compilation failed before any "<string>" frame existed; the exception itself
        # carries the offending line.
        error_line = e.lineno
      elif frames:
        error_line = frames[-1].lineno
      else:
        error_line = 0
      
      if 1 <= error_line <= len(code_lines):
        error_statement = code_lines[error_line - 1]
      else:
        error_statement = "Unable to retrieve the error code"
      executed_code = code_lines[:error_line]
      # Re-attach the trailing call only when one was supplied; appending None would make
      # the join below fail.
      function_call = vars.get("function_call")
      if error_line < len(code_lines) and function_call:
        executed_code.append(function_call)
    
    error_message = f"An error occurred! Error line: {error_line}, Error code: {error_statement}, " \
                    f"Error type: {e.__class__.__name__}, Error message: {e}"
    print_with_color(f"❌ {error_message}", 'cyan')
    
    executed_code = '\n'.join(executed_code) if isinstance(executed_code, list) else executed_code
    return error_message, error_statement, executed_code
  
  def execute_action(self, action_script, action_code):
    """
    Run one action and record it in the per-step bookkeeping.

    Args:
      action_script: Action object passed to ``raw_env.step``; ``None`` (or any value when
        ``action_code`` contains ``'ask_mllm'``) means no environment step is taken.
      action_code: Code string describing the action; stored in ``executed_actions``.

    Raises:
      ActionExecutionError: If the task is already done or the environment step raises.
      ReachedMaxStepsError: If the number of executed actions reaches ``max_action_step``.

    Regardless of success, the ``finally`` section appends one entry each to
    ``executed_actions``, ``executed_element_xpath`` and ``related_elements``, refreshes
    ``cur_obs`` and extends ``env_op_traj`` and ``action_history``.
    """
    print_with_color(action_script, 'cyan')
    
    self.before_obs = deepcopy(self.cur_obs)
    action_feedback = ''
    
    try:
      # Explicit raise instead of `assert`, which is removed when Python runs with -O.
      if self.done:
        raise AssertionError("Task has been stopped.")
      action_feedback = 'Task has been stopped.'
      if 'ask_mllm' in action_code:
        action_feedback = "Action has been performed.\n"
      elif action_script is not None:
        raw_obs, reward, done, info = self.raw_env.step([action_script])
        self.raw_obs, self.reward, self.done = raw_obs[0], reward[0], all(done)
        if self.reward > 0:
          self.agent_done = True  # MiniWoB reward > 0 means task done
        if 'failed' in str(info['run_info']).lower():
          action_feedback = "The action failed to execute.\n" + f"{info['run_info']}"
        else:
          action_feedback = "Action has been performed.\n" + f"{info['run_info']}"
        print('action_script')
        print(action_script)
        print(action_feedback)
    except Exception as e:
      raise ActionExecutionError(action_code, self.kwargs, e)  # Caught by execute_code()
    finally:
      time.sleep(0.1)  # Prevent tap from becoming double-tap when execution is too fast
      # Ensure action, xpath, element for each page interaction are stored together; no attribute missing
      before_html = self.before_obs.ui_content
      self.executed_actions.append(action_code)
      target_xpath = getattr(action_script, 'xpath', None)
      if target_xpath is not None:
        self.executed_element_xpath.append(target_xpath)
        # Get HTML line(s) for target element. A parse or XPath failure must not escape
        # here, otherwise the three per-action lists would get out of step.
        try:
          query_elems = html.fromstring(before_html).xpath(target_xpath)
          if query_elems:
            elem_line = etree.tostring(query_elems[0], pretty_print=False, encoding="unicode")
          else:
            elem_line = f"No matched element: {target_xpath}"
        except (etree.LxmlError, TypeError, ValueError, IndexError) as lookup_err:
          elem_line = f"Failed to resolve element for xpath: {target_xpath} ({lookup_err})"
        self.related_elements.append(elem_line)
      else:
        self.executed_element_xpath.append(None)
        self.related_elements.append("")
      
      self.cur_obs = self.get_obs(file_prefix=f'step_{len(self.executed_actions)}')
      after_html = self.cur_obs.ui_content
      # Normalize HTML before comparison to ignore dynamic attributes
      is_screen_changed = (normalize_html_for_comparison(before_html) !=
                           normalize_html_for_comparison(after_html))
      
      # update self.env_op_traj
      step_info = JSON_models.EnvExecStepInfo(
        before_ui_content=before_html,
        before_screenshot_path=self.before_obs.screenshot_path,
        executed_action=self.executed_actions[-1],
        related_target=self.executed_element_xpath[-1],
        related_elements=self.related_elements[-1],
        action_feedback=action_feedback,
        after_ui_content=after_html,
        after_screenshot_path=self.cur_obs.screenshot_path,
        is_screen_changed=is_screen_changed,
      )
      self.env_op_traj.append(step_info)
      
      # update self.action_history
      action_info_str = f'Step-{len(self.action_history) + 1}:\nExecuted action: {step_info.executed_action}\n'
      if step_info.related_elements:
        action_info_str += f' Related elements: {step_info.related_elements} '
      action_info_str += step_info.action_feedback
      if not step_info.is_screen_changed:
        action_info_str += '\nNo screen changes.'
      self.action_history.append(action_info_str)
      
      if len(self.executed_actions) >= self.max_action_step:
        self.done = True
        print_with_color(f'Reached max action steps: {len(self.executed_actions)}/{self.max_action_step}', 'red')
        raise ReachedMaxStepsError(action_code, self.kwargs)  # Caught by execute_code()
  
  ## -----start: action space
  # Each action below builds an action object and its code string, then runs them via `self.execute_action(action_script, action_code)`
  def click_xpath(self, xpath: str):
    """Run a ``MiniWoBElementClickXpath`` action for ``xpath`` through ``execute_action``."""
    self.execute_action(
      MiniWoBElementClickXpath(xpath),
      f'env_op.click_xpath("{xpath}")',
    )
  
  def click_option(self, xpath: str):
    """Run a ``MiniWoBElementClickOption`` action for ``xpath`` through ``execute_action``."""
    self.execute_action(
      MiniWoBElementClickOption(xpath),
      f'env_op.click_option("{xpath}")',
    )
  
  def type(self, xpath: str, characters: str, clear_existing=True, press_enter_after=True):
    """
    Run a ``MiniWoBType`` action that sends ``characters`` to the element at ``xpath``.

    ``clear_existing`` and ``press_enter_after`` are forwarded unchanged to ``MiniWoBType``.
    """
    typing_action = MiniWoBType(
      xpath=xpath,
      text=characters,
      press_key=False,
      clear_existing=clear_existing,
      press_enter_after=press_enter_after,
    )
    code_repr = (
      f'env_op.type("{xpath}", "{characters}", '
      f'clear_existing={clear_existing}, press_enter_after={press_enter_after})'
    )
    self.execute_action(typing_action, code_repr)
  
  def press_key(self, key: str):
    """
    Press a single named key through a ``MiniWoBType`` action with ``press_key=True``.

    Supported names: ``enter``, ``space``, ``arrow_left``, ``arrow_right``, ``arrow_up``,
    ``arrow_down`` and ``backspace``.

    Raises:
      ValueError: If ``key`` is not one of the supported names.
    """
    supported_keys = (
      ('enter', '\n'),
      ('space', ' '),
      ('arrow_left', Keys.LEFT),
      ('arrow_right', Keys.RIGHT),
      ('arrow_up', Keys.UP),
      ('arrow_down', Keys.DOWN),
      ('backspace', Keys.BACKSPACE),
    )
    for key_name, key_value in supported_keys:
      if key == key_name:
        miniwob_key = key_value
        break
    else:
      raise ValueError("Unknowned key pressed.")
    
    self.execute_action(
      MiniWoBType(xpath=None, text=miniwob_key, press_key=True),
      f'env_op.press_key("{key}")',
    )
  
  def move_mouse_on(self, xpath: str):
    """Run a ``MiniWoBMoveXpath`` action for ``xpath`` through ``execute_action``."""
    self.execute_action(
      MiniWoBMoveXpath(xpath),
      f'env_op.move_mouse_on("{xpath}")',
    )
  
  # stop() records the answer itself and only calls `self.execute_action(None, action_code)` when the task is not already done
  def stop(self, goal_status: str):
    """
    Record the agent's final answer and mark the agent as done.

    ``goal_status`` is stored in ``answer_return``. The stop is also recorded as an
    executed action unless ``self.done`` is already True.
    """
    self.answer_return = goal_status
    self.agent_done = True
    
    # If task is already done (e.g., MiniWoB auto-detected completion), skip execute_action
    # to avoid duplicate stop error
    if self.done:
      print_with_color('Task already completed, skipping stop action execution.', 'yellow')
      return
    
    self.execute_action(None, f'env_op.stop(goal_status="{goal_status}")')
  
  ## -----end: action space
  
  def ask_mllm(self, question):
    """
    Ask the multimodal LLM a question about the current page and return its answer.

    The prompt contains the action history, the current UI content and the question; the
    current screenshot is sent alongside. Prompt and raw output are written to
    ``self.save_path``, token usage is recorded, and the query is logged as an executed
    action through ``execute_action``.

    Args:
      question: The question to ask.

    Returns:
      The text inside ``<answer>...</answer>`` (stripped), or the whole stripped output when
      no answer tag is present. A ``{...}`` answer that parses as a Python literal is
      re-serialized as JSON.
    """
    history_text = '\n'.join(self.action_history)
    prompt = ("Carefully examine the page information and answer the question.\n"
              "1. thought: You must output your brief thought within <thought>...</thought>.\n"
              "2. answer: You must output the answer to the question within <answer>...</answer>.\n"
              "In your answer, just directly return the required information and do not include any other words.\n"
              "IMPORTANT: If the answer is a JSON object, you MUST use double quotes (\") for both keys and string values, NOT single quotes (').\n"
              "[Example Output 1 - Simple Answer]\n"
              "<thought>I need to find the number of comments that have received more downvotes than upvotes for the user who made the latest post on the current page.</thought>\n"
              "<answer>10</answer>\n"
              "[Example Output 2 - JSON Answer]\n"
              "<thought>I need to find synonyms from the options list.</thought>\n"
              '<answer>{"matches": ["brave", "courageous"]}</answer>\n'
              # The newline after the history keeps "html:" on its own line instead of
              # glued to the last history entry.
              f"\nExecution History:\n{history_text}\n"
              f"html:\n{self.cur_obs.ui_content}\n"
              f"Question: {question}\n")
    agent_utils.write_to_file(file_path=self.save_path,
                              file_name=f'step-{len(self.executed_actions)}_ask_mllm_prompt.txt', content=prompt)
    output, raw_response = self.llm.predict_mm(prompt, [self.cur_obs.screenshot], output_format=JSON_models.StringOutput)
    agent_utils.write_to_file(file_path=self.save_path,
                              file_name=f'step-{len(self.executed_actions)}_ask_mllm_output.txt', content=output.str)
    print_with_color(f'ask_mllm response: {output.str}', 'cyan')
    
    # Extract the answer from the output. The thought+answer form is tried first; if the
    # model skipped the thought tag, a bare <answer> block is still honoured.
    full_match = re.search(r"<thought>(.*?)</thought>.*?<answer>(.*?)</answer>", output.str, re.DOTALL)
    if full_match:
      output_str = full_match.group(2).strip()
    else:
      answer_only = re.search(r"<answer>(.*?)</answer>", output.str, re.DOTALL)
      output_str = answer_only.group(1).strip() if answer_only else output.str.strip()
    
    # Try to convert Python dict format (single quotes) to JSON format (double quotes)
    # This handles cases where LLM returns {'key': 'value'} instead of {"key": "value"}
    if output_str.startswith('{') and output_str.endswith('}'):
      try:
        # Try to parse as Python literal first
        parsed = ast.literal_eval(output_str)
        # If successful, convert to proper JSON string
        output_str = json.dumps(parsed, ensure_ascii=False)
        print_with_color(f'Converted Python dict to JSON: {output_str}', 'yellow')
      except (ValueError, SyntaxError):
        # If it fails, keep the original string (might already be valid JSON)
        pass
    
    cost_tokens = raw_response.usage
    self.record_token.step = str(len(self.executed_actions))
    self.record_token.agent = 'ask mllm'
    self.record_token.step_tokens = cost_tokens
    agent_utils.record_cost_tokens(self.record_token)
    
    # json.dumps escapes quotes/backslashes in the question, so the logged script stays
    # valid JSON; for plain questions the text is identical to the hand-built form.
    action_script = json.dumps({"action_type": "ask_mllm", "question": str(question)}, ensure_ascii=False)
    self.execute_action(action_script=action_script,
                        action_code=f"env_op.ask_mllm(question=\"{question}\")")
    return output_str
  
  def get_ui_content(self):
    return self.cur_obs.ui_content
  
  def process_ob_html(self, raw_obs):
    """
    Not using selenium parsing; ui_elements unused for now, left as-is
    Args:
      raw_obs:

    Returns:

    """
    # If the step does not change the html, there is no html in the observation
    if not raw_obs:
      return None, None
    ui_content = raw_obs.html_body
    if raw_obs.html_extra != '':
      ui_content += raw_obs.html_extra
    
    tree = html.fromstring(f"<html><body>{raw_obs.html_body}</body></html>")
    elements = []
    for elem in tree.iter():
      element_info = {
        "index": len(elements),
        "xpath": tree.getroottree().getpath(elem),
        "tag": elem.tag,
        "text": elem.text.strip() if elem.text else None
      }
      element_info.update(dict(elem.attrib))
      elements.append(element_info)
    return ui_content, elements
  
  def check_xpath(self, xpath: str) -> bool:
    """Helper function to check if an XPath finds any elements."""
    try:
      driver = self.raw_env.instances[0].driver
      elements = driver.find_elements(By.XPATH, xpath)
      return bool(elements)
    except Exception:
      return False
  
  def is_element_visible_and_interactable(self, element) -> bool:
    """
    Check if a WebElement is visible and interactable.
    Returns False if element is hidden by display:none, visibility:hidden, aria-hidden=true, or parent is hidden.
    """
    try:
      # Method 1: Selenium's is_displayed() - checks CSS display, visibility, opacity, size, etc.
      if not element.is_displayed():
        return False
      
      # Method 2: Check aria-hidden attribute
      aria_hidden = element.get_attribute('aria-hidden')
      if aria_hidden == 'true':
        return False
      
      # Method 3: Additional style checks for edge cases
      style = element.get_attribute('style')
      if style:
        if 'display: none' in style or 'display:none' in style:
          return False
        if 'visibility: hidden' in style or 'visibility:hidden' in style:
          return False
      
      return True
    except Exception as e:
      # If we can't determine visibility, assume it's not interactable
      print(f'Warning: Could not check element visibility: {e}')
      return False
  
  def _check_element_in_scope(self, element_xpath: str, original_xpath: str) -> bool:
    """
    Check if an element is within the expected scope defined by the original xpath.
    
    Args:
      element_xpath: The actual xpath of the found element
      original_xpath: The original search xpath that may contain scope restrictions
    
    Returns:
      True if element is within scope or no scope restriction exists, False otherwise
    """
    # Check if original xpath has scope restriction (contains parent_path]//child)
    if ']//' not in original_xpath:
      return True  # No scope restriction
    
    driver = self.raw_env.instances[0].driver
    scope_prefix = original_xpath.split(']//')[0] + ']'
    print(f'🔒 Checking if element is within scope: {scope_prefix}')
    
    try:
      scope_elements = driver.find_elements(By.XPATH, scope_prefix)
      if scope_elements:
        scope_element = scope_elements[0]
        scope_element_xpath = self.dom_to_xpath(scope_element)
        
        if not element_xpath.startswith(scope_element_xpath):
          print_with_color(f'⚠️ Element is OUTSIDE the expected scope.', 'yellow')
          print_with_color(f'   Expected scope: {scope_element_xpath}', 'yellow')
          print_with_color(f'   Found element: {element_xpath}', 'yellow')
          return False
        else:
          print(f'✅ Element is within scope.')
          return True
      else:
        print_with_color(f'⚠️ Could not verify scope (scope element not found). Accepting element.', 'yellow')
        return True
    except Exception as e:
      print_with_color(f'⚠️ Error checking scope: {e}. Accepting element.', 'yellow')
      return True
  
  def dom_to_xpath(self, element) -> str:
    """
    Convert a Selenium WebElement to its XPath using JavaScript.
    """
    driver = self.raw_env.instances[0].driver
    try:
      # JavaScript to get XPath of element
      return driver.execute_script("""
      function absoluteXPath(element) {
        if (element === document.body)
          return '/html/body';
        if (!element || element.nodeType !== 1)
          return '';
        var ix= 0;
        var siblings= element.parentNode ? element.parentNode.childNodes : [];
        for (var i= 0; i < siblings.length; i++) {
          var sibling= siblings[i];
          if (sibling === element)
            return absoluteXPath(element.parentNode) + '/' + element.tagName.toLowerCase() + '[' + (ix+1) + ']';
          if (sibling.nodeType === 1 && sibling.tagName === element.tagName)
            ix++;
        }
        return '';
      }
      return absoluteXPath(arguments[0]);
      """, element)
    except Exception:
      return None
  
  def find_element(self, xpath: str, target_description: str = "") -> Optional[str]:
    """
    Locate an element using XPath. If the xpath matches 0 or multiple elements, use MLLM to find the best match.

    Matches inside the element with ``id="query"`` and matches that are not
    visible/interactable are discarded before deciding.

    Args:
      xpath: XPath expression to locate the element
      target_description: Optional description of the target element for MLLM matching

    Returns:
      str: The xpath of the best-matching element, or None if no match found
    """
    print(f'\nfinding element with xpath: {xpath}')
    print(f'target_description: {target_description}')
    
    driver = self.raw_env.instances[0].driver
    
    # Filter out elements in the query area (MiniWoB specific)
    query_elements = driver.find_elements(By.XPATH, '//*[@id="query"]')
    query_xpath = self.dom_to_xpath(query_elements[0]) if query_elements else None
    
    def _not_in_query(el):
      if not query_xpath:
        return True
      el_xpath = self.dom_to_xpath(el)
      if not el_xpath:
        return True
      # Match on a path-segment boundary so a sibling such as ".../div[10]" is not
      # mistaken for a descendant of ".../div[1]".
      return not (el_xpath == query_xpath or el_xpath.startswith(query_xpath + '/'))
    
    # Try to find elements matching the xpath
    try:
      matched_elements = driver.find_elements(By.XPATH, xpath)
      # Filter out elements in query area
      matched_elements = [el for el in matched_elements if _not_in_query(el)]
      
      print(f'Found {len(matched_elements)} matching elements (excluding query area)')
      
      # Filter out non-visible/non-interactable elements
      visible_elements = [el for el in matched_elements if self.is_element_visible_and_interactable(el)]
      print(f'After visibility filter: {len(visible_elements)} visible/interactable elements')
      
      matched_elements = visible_elements
    except Exception as e:
      matched_elements = []
      print(f'Error finding element: {e}')
    
    # If exactly one element found, verify it's within scope (if applicable) before returning
    if len(matched_elements) == 1:
      result_xpath = self.dom_to_xpath(matched_elements[0])
      if result_xpath and self.check_xpath(result_xpath):
        print(f'Unique element found: {result_xpath}')
        # Check scope if the original xpath had scope restriction
        if self._check_element_in_scope(result_xpath, xpath):
          return result_xpath
        else:
          # Element is outside scope, treat as not found
          matched_elements = []
    
    # Use MLLM to select from candidates or search globally.
    # Resolve each candidate's xpath once; every dom_to_xpath call runs a script in the browser.
    candidate_xpaths = []
    for el in matched_elements:
      el_xpath = self.dom_to_xpath(el)
      if el_xpath:
        candidate_xpaths.append(el_xpath)
    
    target_info = {
      'xpath': xpath,
      'target_description': target_description
    }
    return self.match_element_with_mllm(target_element_info=target_info, candidate_elements=candidate_xpaths)
  
  def match_element_with_mllm(self, target_element_info: dict, candidate_elements: Optional[list] = None) -> Optional[str]:
    """
    Use MLLM to select the most appropriate UI element from candidate elements
    based on the provided target element description.

    The prompt (target info, current UI content and optional candidates) and the current
    screenshot are sent to the model named by ``FLAGS.grounding_llm``. Prompt and output
    are written to ``self.save_path`` and token usage is recorded. An answer is accepted
    only with a confidence score of at least 8, and only if its xpath resolves in the live
    page to a visible/interactable element within the scope of the original xpath.

    Args:
      target_element_info: dict with 'xpath' and 'target_description'
      candidate_elements: Optional list of candidate xpath strings

    Returns:
      str: The xpath of the best-matching element, or None if no confident match is found.
    """
    # Built when the same target is searched twice in a row. NOTE(review): the line that
    # would insert it into the prompt is commented out below, so it is currently unused.
    retry_guideline = ''
    if str(target_element_info) == str(self.previous_kwargs):
      retry_guideline = f'Previous search returned xpath = {self.previous_index}. Element not found or xpath invalid. Please recheck and search again carefully. Output a differrent xpath.'
    self.previous_kwargs = target_element_info
    
    # agent_utils.write_to_file(self.save_path, 'candidate_elements.txt', candidate_elements)
    # The candidate section is omitted entirely when no candidates were supplied.
    candidate_elements_str = f"Candidate elements (possible matches):\n{candidate_elements}\n\n" if candidate_elements else ''
    
    prompt = (
      "You are a UI automation assistant. Your task is to match the user-provided description with the most appropriate UI element that is VISIBLE and INTERACTABLE.\n"
      "Text-type attributes must match the element content exactly (after normalize-space).\n\n"
      # f"{retry_guideline}"
      "[Critical Requirements]\n"
      "1. The element MUST be VISIBLE (NOT hidden by display:none, visibility:hidden, or aria-hidden=true)\n"
      "2. The returned xpath MUST point to an INTERACTABLE element (clickable, typeable, etc.)\n"
      "3. If target content is in a non-interactable element (e.g., <span>, <div>), find the nearest interactable ancestor (<a>, <button>, or elements with @onclick, @role='button')\n\n"
      "[Output Format]\n"
      "Select the best-matching element and return its xpath with a confidence score (1–10).\n"
      "Use the image to decide when elements are similar — trust your judgment.\n"
      "Only results with a score of 8 or higher will be accepted. If no suitable match exists, return xpath: None.\n\n"
      "When writing XPath expressions, always use //*[name()='tag' and ...] instead of //tag[...] to avoid namespace issues.\n"
      "Use ancestor/descendant relationships when needed (e.g., //*[name()='a' and .//span[text()='Click']]).\n\n"
      "Example Output 1:\n"
      "{\n"
      '  "target": "//*[@data-wob_ref=\'5\']",\n'
      '  "confidence_score": 10\n'
      "}\n\n"
      "Example Output 2:\n"
      "{\n"
      '  "target": "//*[name()=\'text\' and text()=\'Save\']",\n'
      '  "confidence_score": 6\n'
      "}\n\n"
      "Example Output 3:\n"
      "{\n"
      '  "target": "//*[name()=\'button\' and text()=\'ok\']",\n'
      '  "confidence_score": 10\n'
      "}\n\n"
      "Example Output 4 - not found:\n"
      "{\n"
      '  "target": None,\n'
      '  "confidence_score": 10\n'
      "}\n\n"
      "Only return a JSON object with fields: 'target' and 'confidence_score'.\n\n"
      "[Input]\n"
      f"Target element info:\n{target_element_info}\n\n"
      "The input xpath may not work directly (e.g., matches 0 or multiple elements, or points to non-interactable element). "
      "Analyze the HTML structure carefully and find the correct interactable element.\n"
      f"html:\n{self.cur_obs.ui_content}\n"
      f"{candidate_elements_str}"
    )
    agent_utils.write_to_file(self.save_path, f'step-{len(self.executed_actions)}_match_element_with_mllm_prompt.txt', prompt)
    # A separate wrapper for the grounding model is created on every call (self.llm is not used here).
    llm = get_llm_wrapper(model_name=FLAGS.grounding_llm)
    output, raw_response = llm.predict_mm(prompt, [self.cur_obs.screenshot], output_format=JSON_models.MllmMatchTarget)
    agent_utils.write_to_file(self.save_path, f'step-{len(self.executed_actions)}_match_element_with_mllm_output.txt', output)

    # Record token usage for this grounding call.
    cost_tokens = raw_response.usage
    self.record_token.step = str(len(self.executed_actions))
    self.record_token.agent = 'match_element_with_mllm'
    self.record_token.step_tokens = cost_tokens
    self.record_token.llm = FLAGS.grounding_llm
    agent_utils.record_cost_tokens(self.record_token)
    
    # Accept the model's target only when it is non-empty and the score is at least 8.
    xpath = output.target if output.target and output.confidence_score >= 8 else None
    score = output.confidence_score
    
    print_with_color(f'mllm_match response: xpath={xpath}, confidence score={score}', 'cyan')
    
    # Verify that the returned xpath points to a visible and interactable element
    # Also verify scope if the original xpath had scope restrictions
    if xpath:
      driver = self.raw_env.instances[0].driver
      try:
        elements = driver.find_elements(By.XPATH, xpath)
        if elements:
          # Only the first match is validated.
          element = elements[0]
          if not self.is_element_visible_and_interactable(element):
            print_with_color(f'⚠️ MLLM returned element is NOT visible/interactable. Rejecting xpath.', 'yellow')
            xpath = None
          else:
            # Check scope restriction
            original_xpath = target_element_info.get('xpath', '')
            element_xpath = self.dom_to_xpath(element)
            if self._check_element_in_scope(element_xpath, original_xpath):
              print_with_color(f'✅ MLLM returned element is visible/interactable and within scope.', 'green')
            else:
              xpath = None
        else:
          print_with_color(f'⚠️ MLLM returned xpath matches no elements. Rejecting.', 'yellow')
          xpath = None
      except Exception as e:
        # Any failure during validation rejects the xpath.
        print_with_color(f'⚠️ Error validating MLLM xpath: {e}', 'yellow')
        xpath = None
    
    # if score < 10: xpath = None
    print('actual xpath:', xpath)
    # Remembered so the next call can build the retry guideline above.
    self.previous_index = xpath
    
    return xpath
  
  def get_obs(self, log_task_path: str = None, file_prefix: str = '', save: bool = True) -> ScreenObs:
    """
    Build a ``ScreenObs`` from ``self.raw_obs``; called at env start and after each action.

    Waits two seconds first so the screen can stabilize. When the raw observation yields no
    UI content, the UI content, elements and screenshot of ``self.before_obs`` are reused.

    Args:
      log_task_path: Directory for saved files; defaults to ``self.save_path``.
      file_prefix: Prefix for the saved file names.
      save: When True, the UI content, element list and screenshot are written to disk and
        ``screenshot_path`` is filled in on the returned observation.
    """
    time.sleep(2)  # Wait for the screen to stabilize
    output_dir = self.save_path if log_task_path is None else log_task_path
    
    self.ui_content, self.ui_elements = self.process_ob_html(self.raw_obs)
    if self.ui_content is not None:
      screenshot_source = self.raw_obs
    else:
      # If the step does not change the html, there is no html in the observation
      previous = self.before_obs
      self.ui_content, self.ui_elements = previous.ui_content, previous.ui_elements
      screenshot_source = previous
    screenshot = np.array(screenshot_source.screenshot)
    
    obs = ScreenObs(
      ui_content=self.ui_content,
      ui_elements=self.ui_elements,
      screenshot=screenshot,
      screenshot_path='',
    )
    if not save:
      return obs
    
    agent_utils.write_to_file(file_path=output_dir, file_name=file_prefix + '_html.txt', content=self.ui_content)
    agent_utils.write_to_file(file_path=output_dir, file_name=file_prefix + '_dom_elements.txt',
                              content=self.ui_elements)
    obs.screenshot_path = agent_utils.store_image(screenshot, file_prefix + '_screenshot.png', file_path=output_dir)
    return obs


# **Extract code up to the point where the error occurred**
def extract_code_before_error(code_str: str, error_obj) -> tuple[str, int]:
  """
  Extracts code from the beginning of code_str up to and including the part that matches error_obj.

  - If error_obj is a string that parses as a function call: a call in the code that is
    identical to it (same function and arguments) is preferred; otherwise the first call
    with the same function name is used; otherwise the raw text is searched line by line.
  - If error_obj is a dict: finds the ``name = {...}`` assignment whose literal entries
    match the most entries of the dict.

  Args:
    code_str: The full source code.
    error_obj: Action code string, kwargs dict, or None.

  Returns:
    (executed_code, error_line_num). On failure the first item is an explanatory message
    and the line number is -1.
  """
  lines = code_str.splitlines()
  
  def _called_name(call_node):
    """Return the bare function/method name of a Call node, or None."""
    func = call_node.func
    if isinstance(func, ast.Name):
      return func.id
    if isinstance(func, ast.Attribute):
      return func.attr
    return None
  
  def _code_through(node):
    """Return the code up to the node's last line, and the node's first line number."""
    last_line = getattr(node, 'end_lineno', None) or node.lineno
    return '\n'.join(lines[:last_line]), node.lineno
  
  # ---------- Handle string-form error (e.g., "env_op.swipe(...)") ----------
  if isinstance(error_obj, str):
    try:
      tree = ast.parse(code_str)
      error_call = ast.parse(error_obj, mode='eval').body
      if isinstance(error_call, ast.Call):
        error_func_name = _called_name(error_call)
        if error_func_name:
          # ast.dump ignores positions, so equal dumps mean same function and arguments.
          error_dump = ast.dump(error_call)
          first_name_match = None
          for node in ast.walk(tree):
            if not isinstance(node, ast.Call) or _called_name(node) != error_func_name:
              continue
            if ast.dump(node) == error_dump:
              return _code_through(node)
            if first_name_match is None:
              first_name_match = node
          # No identical call (e.g. the code passes a variable): fall back to the name.
          if first_name_match is not None:
            return _code_through(first_name_match)
    except Exception:
      # Deliberate best effort: unparsable code or error text falls through to the
      # plain-text search below.
      pass
    
    # Fallback: raw string search
    needle = error_obj.strip()
    for i, line in enumerate(lines):
      if needle in line:
        return '\n'.join(lines[:i + 1]), i + 1
    
    return "Error string not found in code.", -1
  
  # ---------- Handle dict-form error (e.g., kwargs = {...}) ----------
  try:
    tree = ast.parse(code_str)
  except SyntaxError:
    return "Invalid code syntax.", -1
  
  if error_obj is None:
    # Nothing to match against.
    return "Relevant code block not found.", -1
  
  best_node = None
  best_score = 0
  
  def dict_match_score(dict_node: ast.Dict) -> int:
    """Count the literal entries of dict_node that agree with error_obj."""
    score = 0
    for k_ast, v_ast in zip(dict_node.keys, dict_node.values):
      try:
        key = ast.literal_eval(k_ast)
        val = ast.literal_eval(v_ast)
        if key in error_obj and error_obj[key] == val:
          score += 1
        elif key in error_obj and isinstance(error_obj[key], list) and isinstance(val, list):
          if set(error_obj[key]).issubset(set(val)):
            score += 1
      except Exception:
        # Non-literal or unhashable entries simply do not count.
        continue
    return score
  
  for node in ast.walk(tree):
    if isinstance(node, ast.Assign) and isinstance(node.value, ast.Dict):
      score = dict_match_score(node.value)
      if score > best_score:
        best_score = score
        best_node = node
  
  if best_node:
    start = best_node.lineno
    end = getattr(best_node, 'end_lineno', start)
    return '\n'.join(lines[:end]), start
  
  return "Relevant code block not found.", -1
