import ast
import os
import re
import sys
import json
import time
import tempfile
import traceback
import subprocess
from copy import deepcopy
from typing import Optional

import numpy as np
from absl import flags
from lxml import etree, html

from UIAgents.Agent_RPA.utils.JSON_models import ScreenObs
from .utils import agent_utils, JSON_models
from .utils.llm_client import get_llm_wrapper
from .utils.agent_utils import print_with_color

from browser_env import ScriptBrowserEnv, ActionTypes, action2str, create_stop_action
from browser_env.helper_functions import RenderHelper
from browser_env.actions import create_id_based_action, create_none_action, ActionParsingError
from evaluation_harness import evaluator_router

# Host name of the machine serving the simulated sites; a missing variable raises KeyError at import time.
AWS_HOSTNAME = os.environ["AWS_HOSTNAME"]

# Maps the URL root of every simulated site to its site label.
# The (port/path suffix, label) pairs keep the lookup order used by get_cur_site().
URL_to_SITE = {
    f"http://{AWS_HOSTNAME}:{suffix}": label
    for suffix, label in (
        ("7770", "SHOPPING"),
        ("7780/admin", "SHOPPING_ADMIN"),
        ("9999", "REDDIT"),
        ("8023", "GITLAB"),
        ("3000", "MAP"),
        ("8888", "WIKIPEDIA"),
        ("4399", "HOMEPAGE"),
    )
}

# absl flag registry; individual flags are read lazily where they are needed.
FLAGS = flags.FLAGS


class ActionExecutionError(Exception):
  """Raised when an action fails while it is being executed.

  Attributes:
    action_code: The recorded call string of the action that failed.
    error_code: Context supplied by the raiser (in this file: the kwargs of the
      most recent find_element() call, possibly None).
    original_exception: The underlying exception that caused the failure.
  """

  def __init__(self, action_code, error_code, original_exception):
    self.action_code = action_code
    self.error_code = error_code
    self.original_exception = original_exception
    # One field per line; previously the action code ran straight into "Related kwargs".
    message = (
      f"Error Action: {action_code}\n"
      f"Related kwargs: {error_code}\n"
      f"Error: {original_exception}"
    )
    super().__init__(message)


class ReachedMaxStepsError(Exception):
  """Signals that the number of executed actions hit the configured maximum.

  `error_code` keeps the supplied error_code, or action_code when error_code is falsy.
  """

  def __init__(self, action_code, error_code, message=None):
    self.error_code = error_code or action_code
    super().__init__("Maximum number of action steps reached." if message is None else message)

def get_cur_site(url):
  """Return the label of the first site whose URL root occurs in `url`, or None if none does."""
  matching_sites = (site for url_root, site in URL_to_SITE.items() if url_root in url)
  return next(matching_sites, None)

# Text that get_obs() strips from the raw observation text and from every element text.
EXTRA_TXT = "The information in this tab has been changed. This tab contains invalid data. Please resolve this before saving."
class EnvOperation:  # create in agent.reset()
  def __init__(self, raw_env: ScriptBrowserEnv, task_type: str):
    """Keep the raw browser environment and create the LLM wrappers used by this object.

    Args:
      raw_env: The browser environment that actions are executed against.
      task_type: Initial task type label (replaced by reset() when a task is loaded).
    """
    self.raw_env = raw_env
    self.task_type = task_type

    log_llm_calls = FLAGS.enable_llm_logging
    self.default_llm = get_llm_wrapper(FLAGS.default_llm, enable_logging=log_llm_calls)
    # Share the default wrapper when both flags name the same model.
    if FLAGS.ask_mllm_llm == FLAGS.default_llm:
      self.ask_mllm_llm = self.default_llm
    else:
      self.ask_mllm_llm = get_llm_wrapper(FLAGS.ask_mllm_llm, enable_logging=log_llm_calls)
    self.llm = self.default_llm  # For backward compatibility

    self.record_token = JSON_models.RecordToken()
    self.render_helper = None
  
  def reset(self, task_idx, save_path, to_init_task: bool = True, max_action_step: int = 20):
    """Prepare this object for a task and take the first observation.

    Args:
      task_idx: Task index; selects `config_files/{task_idx}.json` when a task is initialized.
      save_path: Directory (created if missing) where observations of this run are written.
      to_init_task: When True, load the task config, refresh login cookies if the config
        names a storage state, reset the raw environment and clear all trajectory state.
        When False, only the per-run fields at the end of this method are refreshed, so the
        attributes created by an earlier initializing call (e.g. `executed_actions`,
        `state_info`) must already exist.
      max_action_step: Maximum number of executed actions allowed.
    """
    self.task_idx = task_idx
    if to_init_task:
      # get the task (intent)
      self.config_file = f"config_files/{task_idx}.json"
      with open(self.config_file) as f:
        _c = json.load(f)
        task_id = _c["task_id"]
        sites = _c["sites"]
        task = _c["intent"]
        task_type = _c["intent_template"]

      ### Initialize the environment
      # automatically login
      if _c["storage_state"]:
        from browser_env.auto_login import get_site_comb_from_filepath
        cookie_file_name = os.path.basename(_c["storage_state"])
        comb = get_site_comb_from_filepath(cookie_file_name)
        # NOTE(review): this temporary directory is never removed by this method.
        temp_dir = tempfile.mkdtemp()
        # subprocess to renew the cookie
        # NOTE(review): the return code is not checked; a failed login only surfaces through the assert below.
        subprocess.run(["python", "../webarena/browser_env/auto_login.py", "--auth_folder", temp_dir, "--site_list", *comb])
        _c["storage_state"] = f"{temp_dir}/{cookie_file_name}"
        assert os.path.exists(_c["storage_state"])
        # update the config file
        # The patched config (pointing at the fresh cookie file) is written next to the cookies.
        self.config_file = f"{temp_dir}/{os.path.basename(self.config_file)}"
        with open(self.config_file, "w") as f:
            json.dump(_c, f)

      print(f"[Config file]: {self.config_file}\n[Task_id]: {task_id}")
      print(f"[Sites]: {','.join(sites)}\n[Intent]: {task} ({task_type})")

      # Close the helper of the previous task before creating one for the new config.
      if self.render_helper is not None:
        self.render_helper.close()
      self.render_helper = RenderHelper(self.config_file, FLAGS.log_folder_exp, "id_accessibility_tree")
      self.evaluator = evaluator_router(self.config_file)
      # Default stop action handed to the evaluator until the agent issues its own stop.
      self.answer_action = create_stop_action("")

      self.sites = sites
      self.task = task
      self.task_type = task_type

      # Reset the environment and get the initial observation
      obs, info = self.raw_env.reset(options={"config_file": self.config_file})
      obs, _, _, _, info = self.raw_env.step(create_none_action()) # ensure the page is refreshed
      
      self.state_info = {"observation": obs, "info": info}
      self.before_obs = None
      self.executed_actions = []  # Record the actually executed code.
      self.related_elements = []  # Store the elements involved in each executed code.
      self.executed_element_index = []  # Store the index of the elements involved in each executed code
      self.env_op_traj = []  # (temporary) store EnvExecStepInfo
      self.action_history = []  # store action_info_str, used in the prompt
      self.agent_done = False
      self.done = False
      self.reward = 0
    
    # Per-run state, refreshed on every call regardless of to_init_task.
    self.answer_return = None
    self.save_path = save_path
    os.makedirs(self.save_path, exist_ok=True)
    self.max_action_step = max_action_step
    # Observation files are prefixed with the number of actions executed so far.
    self.cur_obs = self.get_obs(file_prefix=f'step_{len(self.executed_actions)}')
    
    self.record_token = JSON_models.RecordToken(file_path=FLAGS.log_folder_exp, task_type=self.task_type,
                                                task_num=f'Task {task_idx}', stage='Env OP')

  def execute_code(self, code: str, vars: dict, save_path: str, flag_exec_rpa: bool = False):
    print_with_color('============================================', 'cyan')
    print_with_color("Executing code...", 'cyan')
    
    self.save_path = save_path
    self.kwargs = None
    self.previous_kwargs = None
    self.previous_index = None
    self.before_obs = deepcopy(self.cur_obs)
    self.error_in_finding = None
    
    function_code = code # re.sub(r"^```(?:python)?\s*|\s*```$", "", code, flags=re.MULTILINE)
    
    # Convert lowercase true/false to Python boolean values (True/False)
    # Use word boundaries to avoid replacing true/false inside strings or variable names
    function_code = re.sub(r'\btrue\b', 'True', function_code)
    function_code = re.sub(r'\bfalse\b', 'False', function_code)
    
    local_vars = {'env_op': self, 'time': time, 're': re, 'json': json, "error_message": None}
    if vars.get("function_call"):
      function_code += f'\n\n{vars.get("function_call")}'  # function_call must go after function def to avoid errors
    else:
      local_vars.update(vars)
    print_with_color(function_code, 'cyan')
    print_with_color('--------------------------------', 'cyan')
    agent_utils.write_to_file(file_path=save_path, file_name='rpa_code.py', content=function_code)
    
    error_statement = None
    error_message = None
    executed_code = ''
    
    try:
      exec(function_code, local_vars, local_vars)  # Execute the entire code block
      print_with_color('The code was not interrupted.', 'green')
      executed_code = function_code
    except Exception as e:
      print_with_color('The code was interrupted.', 'red')
      error_message, error_statement, executed_code = self._handle_execution_error(e, function_code, vars)
    finally:
      # update exec.done and exec_feedback according to answer_return
      if self.answer_return is None:
        exec_feedback = error_message
      else:
        print('Agent indicates task is done.')
        self.done = True
        if self.answer_return == 'N/A':
          exec_feedback = "The stop action with N/A (indicates the task unfeasible) is executed. "
        else:
          exec_feedback = "The stop action with complete is executed. "
      
      if self.done:
        exec_feedback = f"{exec_feedback} The task is done."
      
      if self.error_in_finding:
        if error_message:
          error_message = f'{error_message}\n{self.error_in_finding}'
        else:
          error_message = self.error_in_finding
      
      # Consider merging exec_feedback and error_message
      exec_result = JSON_models.ExecResult(
        executed_code=executed_code,
        error_statement=str(error_statement),
        error_message=error_message,
        exec_feedback=exec_feedback,
        answer_return=self.answer_return,
        agent_done=self.agent_done,
        done=self.done,
      )
      
      print_with_color('\n****self.executed_element_index:', 'cyan')
      print_with_color(self.executed_element_index, 'cyan')
      print_with_color('\n****self.executed_actions:', 'cyan')
      print_with_color(self.executed_actions, 'cyan')
      print_with_color('\n****self.related_elements:', 'cyan')
      print_with_color(self.related_elements, 'cyan')
      
      traj = self.env_op_traj
      if len(traj) == 0:
        traj = [JSON_models.EnvExecStepInfo()]
    return traj, exec_result
  
  def _handle_execution_error(self, e, function_code, vars):
    if isinstance(e, (ActionExecutionError, ReachedMaxStepsError)):
      executed_code, error_line = extract_code_before_error(function_code, e.error_code)
      error_statement = e.error_code
    else:
      tb = traceback.extract_tb(e.__traceback__)
      last_trace = tb[-1]
      error_line = last_trace.lineno
      code_lines = function_code.splitlines()
      if 1 <= error_line <= len(code_lines):
        error_statement = code_lines[error_line - 1]
      else:
        error_statement = "Unable to retrieve the error code"
        # error_statement = "(Unable to retrieve the error code, index out of range)"
      executed_code = code_lines[:error_line]
      if error_line < len(code_lines):
        executed_code.append(vars.get("function_call"))
    
    error_message = f"An error occurred! Error line: {error_line}, Error code: {error_statement}, " \
                    f"Error type: {e.__class__.__name__}, Error message: {e}"
    print_with_color(f"❌ {error_message}", 'cyan')
    
    executed_code = '\n'.join(executed_code) if isinstance(executed_code, list) else executed_code
    return error_message, error_statement, executed_code
  
  def execute_action(self, action_dict, action_code):
    print_with_color(action_dict, 'cyan')
    
    self.before_obs = deepcopy(self.cur_obs)
    target_index = action_dict.get('index', None)
    action_feedback = ''
    
    try:
      assert not self.done, "Task has been stopped."
      action_feedback = 'Task has been stopped.'
      if action_dict['action_type'] == 'ask_mllm':
        action_feedback = "Action has been performed.\n"
      else:
        ### Parse the action
        try:
          action = create_id_based_action(action_dict['script'])
        except ActionParsingError as e:
          raise ValueError(f"Invalid action: {action_code}")
        ### Get the action result
        action_feedback = get_action_description(action, self.state_info["info"]["observation_metadata"])
        # TODO: fix timeout error
        # if "Please check" in action_info:
        #    raise ValueError(f"Error: {action_info}")
        
        ### Execute the action
        obs, _, self.done, _, info = self.raw_env.step(action)
        obs, _, self.done, _, info = self.raw_env.step(create_none_action()) # ensure that the action is complete and the page is refreshed
        ### Get the observation
        self.state_info = {"observation": obs, "info": info}

        print(f'action_script: {action_dict["script"]}')
        print(f'action_feedback: {action_feedback}')
    except Exception as e:
      raise ActionExecutionError(action_code, self.kwargs, e)  # Caught by execute_code()
    finally:
      time.sleep(0.1)  # Prevent tap from becoming double-tap when execution is too fast
      self.executed_actions.append(action_code)
      self.executed_element_index.append(target_index)
      
      self.cur_obs = self.get_obs(file_prefix=f'step_{len(self.executed_actions)}')

      # Ensure action, index, element for each page interaction are stored together; no attribute missing
      before_ui_content = self.before_obs.ui_content
      self.related_elements.append(self.before_obs.ui_elements[target_index] if target_index is not None else '')
      
      after_ui_content = self.cur_obs.ui_content
      is_screen_changed = True if self.before_obs.ui_elements != self.cur_obs.ui_elements else False
      
      # update self.env_op_traj
      step_info = JSON_models.EnvExecStepInfo(
        before_ui_content=before_ui_content,
        before_screenshot_path=self.before_obs.screenshot_path,
        executed_action=self.executed_actions[-1],
        related_target=self.executed_element_index[-1],
        related_elements=self.related_elements[-1],
        action_feedback=action_feedback,
        after_ui_content=after_ui_content,
        after_screenshot_path=self.cur_obs.screenshot_path,
        is_screen_changed=is_screen_changed,
      )
      self.env_op_traj.append(step_info)
      
      # update self.action_history
      action_info_str = f'Step-{len(self.action_history) + 1}:\nExecuted action: {step_info.executed_action}\n'
      if step_info.related_elements:
        action_info_str += f' Related elements: {step_info.related_elements} '
      action_info_str += step_info.action_feedback
      if not step_info.is_screen_changed:
        action_info_str += '\nNo screen changes.'
      self.action_history.append(action_info_str)
      
      # handle the done and reward
      if action_dict["action_type"] == "stop":
        self.done = True
        self.answer_action = action
      elif len(self.executed_actions) >= self.max_action_step:
        self.done = True
        print_with_color(f'Reached max action steps: {len(self.executed_actions)}/{self.max_action_step}', 'red')
        raise ReachedMaxStepsError(action_code, self.kwargs)  # Caught by execute_code()
      if self.done:
        self.reward = bool(self.evaluator(
          trajectory=[self.answer_action],
          config_file=self.config_file,
          page=self.raw_env.page,
          client=self.raw_env.get_page_client(self.raw_env.page)))
        with open(self.config_file, "r") as f:
          configs = json.load(f)
          eval_types = configs["eval"]["eval_types"]
        feedback = f"Success: {self.reward}"
        # The feedback for url_match. Since some tasks require matching URL instead of giving answer, this needs to be clarified to the agent.
        if eval_types == ["url_match"] and not self.reward:
          feedback += "; You are not stopping at the required web page."
        self.action_history[-1] += f" This epoch is done. {feedback}"

  ## -----start: action space
  # Here are the admissible actions:
  ### Page Operation Actions:
  # clicks on an element with a specific id on the webpage. If the element has popup options, input the option to select one
  def click(self, id: int, option=""):
    if option:
      action_dict = {"action_type": "select", "index": id, "script": f"select [{id}] [{option}]"}
      return self.execute_action(action_dict, f"env_op.click({id}, \"{option}\")")
    else:
      action_dict = {"action_type": "click", "index": id, "script": f"click [{id}]"}
      return self.execute_action(action_dict, f"env_op.click({id})")
  
  # type the content into the field with id. By default, the existing content in the field will be cleared unless clear_existing is set to False, and the "Enter" key is pressed after typing unless press_enter_after is set to False
  def type(self, id: int, content: str, clear_existing=True, press_enter_after=True):
    # if bool parameter is passed as string, convert it to actual boolean value
    if isinstance(clear_existing, str):
    # remove possible quotes, then check
      clear_existing = clear_existing.strip('"').strip("'").lower() in ('true', '1')
    if isinstance(press_enter_after, str):
      press_enter_after = press_enter_after.strip('"').strip("'").lower() in ('true', '1')
  
    action_dict = {"action_type": "type", "index": id, "script": f"type [{id}] [{content}] [{int(clear_existing)}] [{int(press_enter_after)}]"}
    return self.execute_action(action_dict, f"env_op.type({id},\"{content}\",\"{clear_existing}\",\"{press_enter_after}\")")
  
  # hover over an element with id
  def hover(self, id: int):
    action_dict = {"action_type": "hover", "index": id, "script": f"hover [{id}]"}
    return self.execute_action(action_dict, f"env_op.hover({id})")
  
  # simulates the pressing of a key combination on the keyboard (e.g., Ctrl+v)
  def press(self, key_comb: str):
    action_dict = {"action_type": "press", "index": None, "script": f"press [{key_comb}]"}
    return self.execute_action(action_dict, f"env_op.press(\"{key_comb}\")")
  
  # scroll the page up or down.
  def scroll(self, direction: str="down"):
    action_dict = {"action_type": "scroll", "index": None, "script": f"scroll [{direction}]"}
    return self.execute_action(action_dict, f"env_op.scroll(\"{direction}\")")
  
  ### Tab Management Actions:
  # open a new, empty browser tab
  def new_tab(self):
    action_dict = {"action_type": "new_tab", "index": None, "script": "new_tab"}
    return self.execute_action(action_dict, "env_op.new_tab()")
      
  # switch the browser's focus to a specific tab using its index
  def tab_focus(self, tab_index: int):
    action_dict = {"action_type": "tab_focus", "index": None, "script": f"tab_focus [{tab_index}]"}
    return self.execute_action(action_dict, f"env_op.tab_focus({tab_index})")

  # close the currently active tab.
  def close_tab(self):
    action_dict = {"action_type": "close_tab", "index": None, "script": "close_tab"}
    return self.execute_action(action_dict, "env_op.close_tab()")
  
  ### URL Navigation Actions:
  # navigate to a specific URL
  def goto(self, url: str):
    action_dict = {"action_type": "goto", "index": None, "script": f"goto [{url}]"}
    return self.execute_action(action_dict, f"env_op.goto(\"{url}\")")
  
  # navigate to the previously viewed page
  def go_back(self):
    action_dict = {"action_type": "go_back", "index": None, "script": "go_back"}
    return self.execute_action(action_dict, "env_op.go_back()")
  
  # navigate to the next page (if a previous 'go_back' action was performed).
  def go_forward(self):
    action_dict = {"action_type": "go_forward", "index": None, "script": "go_forward"}
    return self.execute_action(action_dict, "env_op.go_forward()")

  ### Completion Action:
  # Issue this action when you believe the task is complete. If the objective is to find a text-based answer, provide the answer in the bracket. If you believe the task is impossible to complete, provide the answer="N/A" in the bracket.
  def stop(self, answer: str=""):
    self.answer_return = answer
    self.agent_done = True
    action_dict = {"action_type": "stop", "index": None, "script": f"stop [{answer}]"}
    return self.execute_action(action_dict, f"env_op.stop(\"{answer}\")")
  
  ## -----end: action space
  
  def ask_mllm(self, question):
    """Ask the multimodal LLM a question about the current page and return its answer.

    The prompt combines the action history, the current UI content and the question; the
    current screenshot is sent as the image. The prompt and the model output are written
    to `save_path`, token usage is recorded, and the query is registered as an "ask_mllm"
    step through execute_action().

    Args:
      question: The question to answer from the current page.

    Returns:
      The `answer` field of the structured model output.
    """
    system_prompt = ("Carefully examine the page information and answer the question.\n"
          "[Output Format]\n"
          "1. thought: output your brief thought (under 100 words).\n"
          "2. answer: output the answer to the question. Do not wrap with Markdown tags.\n"
          "In your answer, just directly return the required information and do not include any other words. Please respond using English typography conventions only.\n"
          "[Example Output 1 - Simple Answer]\n"
          "thought: I need to find the number of comments that have received more downvotes than upvotes for the user who made the latest post on the current page.\n"
          "answer: 10\n"
          "[Example Output 2 - JSON Answer]\n"
          "thought: I need to find synonyms from the options list.\n"
          'answer: {"matches": ["brave", "courageous"]}\n')
    user_prompt = (f"\nExecution History:\n" + '\n'.join(self.action_history) +
              f"\n{self.cur_obs.ui_content}\n"
              f"Question: {question}\n")
    # Persist the exact prompt for later inspection; file names carry the current step count.
    agent_utils.write_to_file(file_path=self.save_path,
                              file_name=f'step-{len(self.executed_actions)}_ask_mllm_prompt.txt', 
                              content=f"[system]\n{system_prompt}\n\n[user]\n{user_prompt}")
    output, raw_response = self.ask_mllm_llm.predict_mm(
      user_prompt=user_prompt, 
      images=[self.cur_obs.screenshot],
      system_prompt=system_prompt,
      output_format=JSON_models.AskMLLMOutput
    )
    agent_utils.write_to_file(file_path=self.save_path,
                              file_name=f'step-{len(self.executed_actions)}_ask_mllm_output.txt', 
                              content=output.model_dump_json())
    print_with_color(f'ask_mllm response: thought={output.thought}, answer={output.answer}', 'cyan')

    # Use structured output directly
    output_ans = output.answer
    
    # Record the token usage of this call.
    cost_tokens = raw_response.usage
    self.record_token.step = str(len(self.executed_actions))
    self.record_token.agent = 'ask mllm'
    self.record_token.step_tokens = cost_tokens
    self.record_token.llm = FLAGS.ask_mllm_llm
    agent_utils.record_cost_tokens(self.record_token)
    
    # Register the query as a step; execute_action() does not step the browser for "ask_mllm".
    # NOTE(review): the question is interpolated without quotes, so the recorded call is not valid Python.
    action_dict = {"action_type": "ask_mllm", "index": None, "script": None}
    self.execute_action(action_dict, f"env_op.ask_mllm(question={question})")
    return output_ans
  
  def get_ui_content(self):
    """Return `ui_content_dict` of the current observation (element attributes plus the "meta" entry added by get_obs())."""
    return self.cur_obs.ui_content_dict
  
  def find_element(self, **kwargs) -> int:
    """
    Find an element in the UI list based on the given filtering criteria.

    Parameters:
        **kwargs: Filtering criteria, such as keyword, text, content_description, is_clickable, etc.

    Returns:
        int: The index of the first element that matches the criteria.

    Exceptions:
        ValueError: If no element matching the criteria is found.
    """
    print_with_color('\nfind_element(**kwargs)', 'cyan')
    print_with_color(f'kwargs: {kwargs}', 'cyan')
    self.kwargs = kwargs
    
    candidate_elements = []
    exclude_keys = ["target_description", "meta"]  # Exclude semantic description and metadata keys
    for element in self.cur_obs.ui_content_dict.values():
      # Check if all key-value pairs match the element, excluding special keys
      if all(
        element.get(key) is not None and element.get(key) == value
        for key, value in kwargs.items()
        if key not in exclude_keys
      ):
        candidate_elements.append(element)
    
    if len(candidate_elements) != 1:
      return self.match_element_with_mllm(kwargs, candidate_elements)
    
    # Return the index of the single matched element
    index = candidate_elements[0]["index"]
    print_with_color(f'matched index: {index}', 'cyan')
    return index  # when cnt == 1
  
  def match_element_with_mllm(self, target_element_info: dict, candidate_elements: list) -> int:
    """
      Ask the multimodal LLM named by FLAGS.grounder_llm to select the most appropriate
      UI element for the provided target element description.

      The prompt contains the target info, the current UI content and (if any) the
      candidate elements; the resized screenshot with set-of-marks is sent as the image.
      Semantic matching (no candidates, or a target_description containing "matches" or
      "contains") accepts a confidence score of 8 or higher; exact matching requires 10.

      Side effects: writes the prompt and the model output to `save_path`, records token
      usage, and stores the final index in `self.previous_index`.

      Returns:
        int: The index of the best-matching element, or -1 if no confident match is found.
    """
    candidate_elements_str = f"Candidate elements (filtered by basic attribute matching):\n{candidate_elements}\n\n" if candidate_elements else ''
    ui_content, screenshot_som = self.cur_obs.ui_content, self.cur_obs.screenshot_with_som_resized
    
    # Determine if this is a semantic matching case:
    # 1. If candidate_elements is empty, exact match failed - allow semantic matching
    # 2. If target_description contains "matches" or "contains", explicitly requesting semantic match
    is_semantic_match = False
    if len(candidate_elements) == 0:
      # Exact match failed, allow semantic matching with lower threshold
      is_semantic_match = True
    else:
      # Check if explicitly requesting semantic match via target_description
      target_desc = target_element_info.get('target_description', '')
      if target_desc and ('matches' in target_desc.lower() or 'contains' in target_desc.lower()):
        is_semantic_match = True
    
    # The two prompts differ only in the matching rule, the accepted score and the example score.
    if is_semantic_match:
      prompt = (
        "You are a UI automation assistant. Your task is to match the user-provided description with the most appropriate UI element. "
        "This is a semantic matching case where exact text match is not required—find the element that best matches the semantic description.\n\n"
        "[Output Format]\n"
        "Select the best-matching element and return its index with a confidence score (1–10).\n"
        "For semantic matching, results with a score of 8 or higher will be accepted. If no suitable match exists, return -1.\n\n"
        "Example Output:\n"
        "{'thought': '...', 'target_index': 0, 'confidence_score': 8}\n"
        "Only return with a brief thought, the target index, and the confidence score.\n\n"
        "[Input]\n"
        f"Target element info:\n{target_element_info}\n\n"
        f"UI content:\n{ui_content}\n\n"
        f"{candidate_elements_str}"
      )
    else:
      prompt = (
        "You are a UI automation assistant. Your task is to match the user-provided description with the most appropriate UI element. "
        "Text-type attributes must match the element content exactly.\n\n"
        "[Output Format]\n"
        "Select the best-matching element and return its index with a confidence score (1–10).\n"
        "Only results with a score of 10 will be accepted. If no suitable match exists, return -1.\n\n"
        "Example Output:\n"
        "{'thought': '...', 'target_index': 0, 'confidence_score': 6}\n"
        "Only return with a brief thought, the target index, and the confidence score.\n\n"
        "[Input]\n"
        f"Target element info:\n{target_element_info}\n\n"
        f"UI content:\n{ui_content}\n\n"
        f"{candidate_elements_str}"
      )
    agent_utils.write_to_file(self.save_path, f'step-{len(self.executed_actions)}_match_element_with_mllm_prompt.txt', prompt)
    # NOTE(review): a new wrapper is created on every call rather than reused from __init__.
    llm = get_llm_wrapper(model_name=FLAGS.grounder_llm, enable_logging=FLAGS.enable_llm_logging)
    output, raw_response = llm.predict_mm(prompt, [screenshot_som], output_format=JSON_models.MllmMatchTarget)

    # NOTE(review): the structured output object itself is passed as content here (ask_mllm passes model_dump_json()).
    agent_utils.write_to_file(self.save_path, f'step-{len(self.executed_actions)}_match_element_with_mllm_output.txt', output)
    
    # Record the token usage of this call.
    cost_tokens = raw_response.usage
    self.record_token.step = str(len(self.executed_actions))
    self.record_token.agent = 'match_element_with_mllm'
    self.record_token.step_tokens = cost_tokens
    self.record_token.llm = FLAGS.grounder_llm
    agent_utils.record_cost_tokens(self.record_token)
    
    index = output.target_index
    score = output.confidence_score
    
    print_with_color(f'mllm_match response: index={index}, confidence score={score}', 'cyan')
    
    # Apply different thresholds: semantic matching allows score >= 8, exact matching requires score == 10
    if is_semantic_match:
      if score < 8: index = -1
    else:
      if score < 10: index = -1
    
    print('actual index:', index)
    self.previous_index = index
    
    return index
  
  def get_obs(self, log_task_path: Optional[str] = None, file_prefix: str = '', save: bool = True) -> ScreenObs:
    """
    get_obs() is called at the start of the env and after each action.
    """
    time.sleep(2)  # Wait for the screen to stabilize
    if log_task_path is None:
      log_task_path = self.save_path
    
    self.raw_obs = self.state_info["observation"]["text"].replace(EXTRA_TXT, "")
    self.cur_url = self.state_info["info"]["page"].url
    self.cur_site = get_cur_site(self.cur_url)
    self.ui_content = f"OBSERVATION:\n{self.raw_obs}\nCurrent URL of Simulated {self.cur_site}: {self.cur_url}\n"

    obs_nodes_info = self.state_info["info"]["observation_metadata"]["text"]["obs_nodes_info"]
    self.ui_elements = {
      int(k): v['text'].replace(EXTRA_TXT, "")
      for k, v in obs_nodes_info.items()
    }
    # Get attributes of each element
    self.ui_content_dict = {
      k: agent_utils.parse_str_to_dict(v)
      for k, v in self.ui_elements.items()
    }
    
    # Add metadata to the UI content dictionary
    self.ui_content_dict["meta"] = {"url": self.cur_url}
    
    tab_match = re.search(r"^(Tab .*)$", self.raw_obs, re.MULTILINE)
    if tab_match:
      self.ui_content_dict["meta"]["tab"] = tab_match.group(1).strip()

    scroll_match = re.search(r"^(Scroll Bar: .*)$", self.raw_obs, re.MULTILINE)
    if scroll_match:
      self.ui_content_dict["meta"]["scroll_bar"] = scroll_match.group(1).strip()

    screenshot = self.state_info["observation"]["image"]

    obs = ScreenObs(
      ui_content=self.ui_content,
      ui_elements=self.ui_elements,
      ui_content_dict=self.ui_content_dict,
      screenshot=screenshot,
      screenshot_path = ''
    )
    
    if save:
      agent_utils.write_to_file(file_path=log_task_path, file_name=file_prefix + '_ui_content.txt', content=self.ui_content)
      agent_utils.write_to_file(file_path=log_task_path, file_name=file_prefix + '_ui_content_dict.txt',
                                content=self.ui_content_dict)
      obs.screenshot_path = agent_utils.store_image(screenshot, file_prefix + '_screenshot.png',
                              file_path=log_task_path)
    return obs

def get_action_description(action, observation_metadata, action_set_tag="id_accessibility_tree"):
    """Return the text form of an action for the action history.

    For click/hover/type/select actions the text names the targeted element, or reports
    that the element id is not part of the observation (a hint for recovering from the
    failure). Every other action type yields a generic confirmation.
    """
    text_meta_data = observation_metadata["text"]["obs_nodes_info"]
    element_action_types = [ActionTypes.CLICK, ActionTypes.HOVER, ActionTypes.TYPE, ActionTypes.SELECT_OPTION]

    if action["action_type"] not in element_action_types:
        return "Action has been performed."

    action_name = str(action["action_type"]).split(".")[1].lower()
    if action["element_id"] not in text_meta_data:
        return f"Attempt to perfom \"{action_name}\" on element \"[{action['element_id']}]\" but no matching element found. Please check the observation more carefully."

    # Drop the first whitespace-separated token of the node text before describing the node.
    words = text_meta_data[action["element_id"]]["text"].split()
    node_content = " ".join(words[1:])
    action_str = action2str(action, action_set_tag, node_content)
    if node_content.startswith("checkbox"):
        action_str += " Notice: You cannot click a checkbox element, and you should click the StaticText above it instead."
    return action_str


# **Extract code up to the point where the error occurred**
def _called_name(call):
  """Return the bare function/attribute name a Call node invokes, or None for other callee shapes."""
  func = call.func
  if isinstance(func, ast.Name):
    return func.id
  if isinstance(func, ast.Attribute):
    return func.attr
  return None


def _locate_call(code_str: str, call_text: str):
  """Find the call in `code_str` that corresponds to `call_text`.

  A structurally identical call (same callee and arguments) is preferred; otherwise the
  first call (in ast.walk order) with the same function name is used.

  Returns (start_line, end_line), or None when either text cannot be parsed, `call_text`
  is not a call expression, or no call matches.
  """
  try:
    tree = ast.parse(code_str)
    target = ast.parse(call_text, mode='eval').body
    if not isinstance(target, ast.Call):
      return None
    calls = [node for node in ast.walk(tree) if isinstance(node, ast.Call)]
    target_dump = ast.dump(target)
    match = next((call for call in calls if ast.dump(call) == target_dump), None)
    if match is None:
      target_name = _called_name(target)
      if target_name:
        match = next((call for call in calls if _called_name(call) == target_name), None)
  except (SyntaxError, ValueError, RecursionError):
    # Best effort only: the caller falls back to a raw text search.
    return None
  if match is None:
    return None
  start = match.lineno
  return start, (getattr(match, 'end_lineno', None) or start)


# **Extract code up to the point where the error occurred**
def extract_code_before_error(code_str: str, error_obj) -> tuple[str, int]:
  """
  Extract the code from the beginning of code_str up to and including the part that
  matches error_obj.

  - If error_obj is a string that parses as a call: the matching call is located via AST
    (exact structural match first, then same function name). Multi-line calls are
    included completely.
  - If error_obj is any other string: the first line containing it (stripped) is used.
    An empty/whitespace-only string never matches.
  - If error_obj is a dict: the dict assignment sharing the most key/value pairs with it
    is used.

  Returns (executed_code, error_line_num), where error_line_num is the 1-based first line
  of the matched statement. When nothing matches, an explanatory message and -1 are returned.
  """
  lines = code_str.splitlines()

  # ---------- Handle string-form error (e.g., "env_op.swipe(...)") ----------
  if isinstance(error_obj, str):
    span = _locate_call(code_str, error_obj)
    if span is not None:
      start, end = span
      return '\n'.join(lines[:end]), start

    # Fallback: raw string search
    needle = error_obj.strip()
    if needle:
      for line_no, line in enumerate(lines, start=1):
        if needle in line:
          return '\n'.join(lines[:line_no]), line_no

    return "Error string not found in code.", -1

  # ---------- Handle dict-form error (e.g., kwargs = {...}) ----------
  try:
    tree = ast.parse(code_str)
  except (SyntaxError, ValueError):
    return "Invalid code syntax.", -1

  # Anything that is not a dict (e.g. None when no kwargs were recorded) cannot match.
  if not isinstance(error_obj, dict):
    return "Relevant code block not found.", -1

  def dict_match_score(dict_node: ast.Dict) -> int:
    score = 0
    for k_ast, v_ast in zip(dict_node.keys, dict_node.values):
      try:
        key = ast.literal_eval(k_ast)
        val = ast.literal_eval(v_ast)
        if key in error_obj and error_obj[key] == val:
          score += 1
        elif key in error_obj and isinstance(error_obj[key], list) and isinstance(val, list):
          if set(error_obj[key]).issubset(set(val)):
            score += 1
      except Exception:
        # Non-literal or incomparable entries simply do not count towards the score.
        continue
    return score

  best_node = None
  best_score = 0
  for node in ast.walk(tree):
    if isinstance(node, ast.Assign) and isinstance(node.value, ast.Dict):
      score = dict_match_score(node.value)
      if score > best_score:
        best_score = score
        best_node = node

  if best_node is not None:
    start = best_node.lineno
    end = getattr(best_node, 'end_lineno', None) or start
    return '\n'.join(lines[:end]), start

  return "Relevant code block not found.", -1
