"""
Input step info per step:
obs, reason for code, code, execution of code
"""
import ast
import json
import os
import re
from typing import Optional, Tuple

from absl import flags

from .prompts.builder_prompt import get_rpa_builder_prompt
from .utils import agent_utils, JSON_models
from .utils.JSON_API import OpenAIWrapper
from .utils.agent_utils import print_with_color

# Handle to absl's global flag registry; `use_fetch_info` is read from it
# in RPA_Builder_Agent.generate_rpa_code.
FLAGS = flags.FLAGS

# Maximum number of syntax error retries. The retry loop iterates
# MAX_SYNTAX_RETRY + 1 times, i.e. one initial attempt plus this many retries.
MAX_SYNTAX_RETRY = 2


def validate_python_syntax(code: str, log_path: Optional[str] = None) -> Tuple[bool, str]:
  """
  Validate Python code syntax using AST parsing.

  The code is only parsed with ast.parse(); it is never executed.

  Args:
    code: The Python code string to validate
    log_path: Optional path handed to agent_utils.write_to_file() as
      file_path; when given, syntax error details plus the full code are
      saved there as 'syntax_error.txt'

  Returns:
    Tuple of (is_valid, error_message); error_message is "" on success
  """
  # Only the parse itself sits in the try body, so a failure while printing
  # the success banner can never be misreported as a validation error.
  try:
    ast.parse(code)
  except SyntaxError as e:
    # e.text may be None and may end with a newline; normalise it so the
    # caret line below lands directly underneath the offending source line.
    source_line = (e.text or '').rstrip('\r\n')
    # e.offset is 1-based; pad by offset - 1 so the caret marks that column.
    caret_pad = ' ' * (e.offset - 1) if e.offset else ''
    error_msg = (
      f"❌ Syntax Error in RPA Code:\n"
      f"  Line {e.lineno}: {e.msg}\n"
      f"  Text: {source_line}\n"
      # Eight spaces == len("  Text: "), keeping the caret aligned with the text.
      f"        {caret_pad}^\n"
    )
    print_with_color(error_msg, 'red')

    if log_path:
      agent_utils.write_to_file(
        file_path=log_path,
        file_name='syntax_error.txt',
        content=f"{error_msg}\n\nFull Code:\n{code}"
      )

    return False, error_msg
  except Exception as e:
    # ast.parse() can fail in other ways (e.g. a non-string argument); report
    # those as invalid code instead of crashing the builder.
    error_msg = f"❌ Unexpected error during syntax validation: {e}"
    print_with_color(error_msg, 'red')
    return False, error_msg
  else:
    print_with_color('✅ RPA code syntax validation passed', 'green')
    return True, ""


def get_syntax_error_feedback(error_msg: str) -> str:
  """
  Build the corrective note appended to the builder prompt after a syntax error.

  Args:
    error_msg: The error message produced by syntax validation

  Returns:
    A block of text that starts with two newlines, embeds error_msg, lists
    common fixes and ends with a newline
  """
  feedback_lines = [
    # Two leading empty entries yield the "\n\n" separator from the prompt.
    "",
    "",
    "[CRITICAL - Syntax Error Detected]",
    "The previously generated RPA code has a syntax error. Please fix it and regenerate:",
    str(error_msg),
    "Common fixes:",
    "- For matching multiple element types in XPath, use 'or' condition: //*[(name()='span' or name()='a') and text()='OK']",
    "- Escape special characters in strings properly",
    "- Use single quotes for strings containing double quotes",
    "- Ensure all brackets and parentheses are matched",
    "Please regenerate the RPA code with the syntax error fixed.",
    # Trailing empty entry yields the final newline.
    "",
  ]
  return "\n".join(feedback_lines)


class RPA_Builder_Agent:
  def __init__(
    self,
    llm: OpenAIWrapper,
  ):
    self.rpa_builder_conslusion = ''  # Lives for the duration of each task type
    self.llm = llm
    self.reflection = None
    self.record_token = JSON_models.RecordToken()
  
  def _print_rpa_info(self, builder_output, rpa_info):
    """Print RPA builder output information."""
    print(f'Thought:\n{builder_output.thought}\n')
    if hasattr(builder_output, 'info_to_clarify'):
      print(f'info_to_clarify:\n{builder_output.info_to_clarify}\n')
    print(f'Task Type:\n{rpa_info.task_type}\n')
    print(f'Params:\n{rpa_info.parameters}\n')
    print(f'RPA Description:\n{rpa_info.rpa_description}\n')
    print(f'Code:\n{rpa_info.rpa_code}\n')
    print(f'Example Usage:\n{rpa_info.example_usage}\n')
    print(f'Conclusion:\n{rpa_info.conclusion}\n')
  
  def _generate_with_syntax_retry(self, base_prompt: str, images: list, log_task_path: str, 
                                   round_name: str, task_type: str, task_template: str,
                                   list_react_traj: list, pre_rpa_exec_traj, 
                                   fetched_info: dict) -> tuple:
    """
    Generate RPA code with automatic syntax error retry.
    
    Returns:
      tuple: (builder_output, rpa_info, total_token_cost)
    """
    syntax_error_feedback = ""
    total_tokens = 0
    
    for syntax_retry in range(MAX_SYNTAX_RETRY + 1):
      # Build prompt
      rpa_builder_prompt = base_prompt + syntax_error_feedback
      
      # Save prompt
      retry_suffix = f'_retry{syntax_retry}' if syntax_retry > 0 else ''
      agent_utils.write_to_file(file_path=log_task_path, 
                                file_name=f'rpa_builder_prompt_{round_name}{retry_suffix}.txt',
                                content=rpa_builder_prompt)
      
      # Call LLM
      builder_output, raw_response = self.llm.predict_mm(text_prompt=rpa_builder_prompt, 
                                                         images=images,
                                                         output_format=JSON_models.RPABuilderOutput)
      
      # Save output
      agent_utils.write_to_file(file_path=log_task_path, 
                                file_name=f'rpa_builder_output_{round_name}{retry_suffix}.txt',
                                content=builder_output)
      
      # Record tokens
      cost_tokens = raw_response.usage
      total_tokens += cost_tokens.total_tokens if hasattr(cost_tokens, 'total_tokens') else 0
      self.record_token.step = '-'
      self.record_token.agent = f'RPA Builder {round_name} (retry {syntax_retry})' if syntax_retry > 0 else f'RPA Builder {round_name}'
      self.record_token.step_tokens = cost_tokens
      agent_utils.record_cost_tokens(self.record_token)
      
      # Extract and print RPA info
      rpa_info = builder_output.output
      self._print_rpa_info(builder_output, rpa_info)
      
      # Validate syntax
      is_valid, error_msg = validate_python_syntax(rpa_info.rpa_code, log_task_path)
      
      if is_valid:
        return builder_output, rpa_info, total_tokens
      
      # Handle syntax error
      if syntax_retry < MAX_SYNTAX_RETRY:
        print_with_color(f'🔄 Syntax error detected, retrying... ({syntax_retry + 1}/{MAX_SYNTAX_RETRY})', 'yellow')
        syntax_error_feedback = get_syntax_error_feedback(error_msg)
        # Regenerate base prompt for next iteration
        base_prompt = get_rpa_builder_prompt(task_type=task_type, task_template=task_template,
                                             react_trajs=list_react_traj,
                                             pre_rpa_exec_traj=pre_rpa_exec_traj,
                                             fetched_info=fetched_info, use_tool=False,
                                             rpa_builder_conslusion=self.rpa_builder_conslusion)
      else:
        print_with_color(f'⚠️  Max syntax retries reached. Proceeding with code that has syntax errors.', 'yellow')
    
    return builder_output, rpa_info, total_tokens
  
  def generate_rpa_code(
    self,
    log_task_path: str,
    task_type: str,
    task_template: str,
    list_react_traj: Optional[list[JSON_models.ReActTraj]] = None,
    pre_rpa_exec_traj: Optional[JSON_models.RPAExecTraj] = None,
    encountered_task_goals: Optional[list[str]] = None,
  ) -> tuple[JSON_models.RPAInfo, int]:
    print('============================================')
    print("Current Agent: RPA_Builder_Agent\n")
    print(f'model: {self.llm.model_name}\n')
    os.makedirs(log_task_path, exist_ok=True)
    self.react_trajs = list_react_traj
    self.pre_rpa_exec_traj = pre_rpa_exec_traj
    
    cnt_fetch_info = 0
    MAX_FETCH_force = 0
    MAX_FETCH_optional = 3 if FLAGS.use_fetch_info else 0
    MAX_FETCH_total = MAX_FETCH_optional + MAX_FETCH_force
    fetched_info = {}
    fetched_screenshot = []
    if MAX_FETCH_total == 0:
      print_with_color(f'\nBuilder Round 0\n', 'blue')
      base_prompt = get_rpa_builder_prompt(task_type=task_type, task_template=task_template,
                                           react_trajs=list_react_traj,
                                           encountered_task_goals=encountered_task_goals,
                                           fetched_info=fetched_info,
                                           pre_rpa_exec_traj=pre_rpa_exec_traj,
                                           rpa_builder_conslusion=self.rpa_builder_conslusion)
      
      builder_output, rpa_info, _ = self._generate_with_syntax_retry(
        base_prompt=base_prompt,
        images=[],
        log_task_path=log_task_path,
        round_name='0',
        task_type=task_type,
        task_template=task_template,
        list_react_traj=list_react_traj,
        pre_rpa_exec_traj=pre_rpa_exec_traj,
        fetched_info=fetched_info
      )
      
      agent_utils.write_to_file(file_path=log_task_path, file_name='rpa_builder_output.txt', content=builder_output)
    
    for fetch_cnt in range(MAX_FETCH_total):
      print_with_color(f'\nBuilder Round {fetch_cnt}\n', 'blue')
      output_format = JSON_models.RPABuilderOutput_optional
      if fetch_cnt < MAX_FETCH_force:
        output_format = JSON_models.RPABuilderOutput_tool
      rpa_builder_prompt = get_rpa_builder_prompt(task_type=task_type, task_template=task_template,
                                                      react_trajs=list_react_traj,
                                                      pre_rpa_exec_traj=pre_rpa_exec_traj,
                                                      encountered_task_goals=encountered_task_goals,
                                                      fetched_info=fetched_info, use_tool=True,
                                                      rpa_builder_conslusion=self.rpa_builder_conslusion)
      agent_utils.write_to_file(file_path=log_task_path, file_name=f'rpa_builder_prompt_{fetch_cnt}.txt',
                                content=rpa_builder_prompt)
      builder_output, raw_response = self.llm.predict_mm(text_prompt=rpa_builder_prompt,
                                                        images=fetched_screenshot,
                                                        output_format=output_format)
      agent_utils.write_to_file(file_path=log_task_path, file_name=f'rpa_builder_output_{fetch_cnt}.txt',
                                content=builder_output)
      agent_utils.write_to_file(file_path=log_task_path, file_name=f'rpa_builder_raw_response_{fetch_cnt}.txt',
                                content=raw_response)
      cost_tokens = raw_response.usage
      self.record_token.step = '-'
      self.record_token.agent = f'RPA Builder {fetch_cnt}'
      self.record_token.step_tokens = cost_tokens
      agent_utils.record_cost_tokens(self.record_token)
      if isinstance(builder_output.output, JSON_models.FetchInfoTool):
        # count fetch times
        cnt_fetch_info += 1
        
        print(f'Thought:\n{builder_output.thought}\n')
        print(f'info_to_clarify:\n{builder_output.info_to_clarify}\n')
        print(f'output:\n{builder_output.output}\n')
        
        fetched_info = self.fetch_info(builder_output.output.traj_id, builder_output.output.step_n)  # Local screenshot + UI extraction
        if fetched_info['screenshot'] is not None:
          fetched_screenshot = [fetched_info['screenshot']]
          agent_utils.store_image(fetched_info['screenshot'], f'screenshot_{fetch_cnt}.png', log_task_path)
        if fetched_info['ui_content']:
          agent_utils.write_to_file(log_task_path, f'ui_content_{fetch_cnt}.txt', content=fetched_info['ui_content'])
      else:
        # RPA function generation (not fetch_info tool call)
        # First attempt already done, check syntax and retry if needed
        rpa_info = builder_output.output
        self._print_rpa_info(builder_output, rpa_info)
        
        # Validate syntax - if invalid, regenerate with retry logic
        is_valid, error_msg = validate_python_syntax(rpa_info.rpa_code, log_task_path)
        
        if not is_valid:
          # Use unified retry method
          print_with_color(f'🔄 Syntax error detected in initial generation, retrying...', 'yellow')
          base_prompt = get_rpa_builder_prompt(task_type=task_type, task_template=task_template,
                                               react_trajs=list_react_traj,
                                               pre_rpa_exec_traj=pre_rpa_exec_traj,
                                               encountered_task_goals=encountered_task_goals,
                                               fetched_info=fetched_info, use_tool=False,
                                               rpa_builder_conslusion=self.rpa_builder_conslusion)
          
          builder_output, rpa_info, _ = self._generate_with_syntax_retry(
            base_prompt=base_prompt,
            images=fetched_screenshot,
            log_task_path=log_task_path,
            round_name=f'{fetch_cnt}_syntax',
            task_type=task_type,
            task_template=task_template,
            list_react_traj=list_react_traj,
            pre_rpa_exec_traj=pre_rpa_exec_traj,
            fetched_info=fetched_info
          )
        
        break  # No more tool needed, exit loop
      
      # If we reach here, all attempts so far were tool calls
      if fetch_cnt == MAX_FETCH_total - 1:
        print_with_color(f'\nBuilder Round {fetch_cnt + 1} (Final, force generation)\n', 'blue')
        base_prompt = get_rpa_builder_prompt(task_type=task_type, task_template=task_template,
                                             react_trajs=list_react_traj,
                                             encountered_task_goals=encountered_task_goals,
                                             fetched_info=fetched_info,
                                             pre_rpa_exec_traj=pre_rpa_exec_traj,
                                             rpa_builder_conslusion=self.rpa_builder_conslusion)
        
        builder_output, rpa_info, _ = self._generate_with_syntax_retry(
          base_prompt=base_prompt,
          images=[],
          log_task_path=log_task_path,
          round_name=f'{fetch_cnt + 1}',
          task_type=task_type,
          task_template=task_template,
          list_react_traj=list_react_traj,
          pre_rpa_exec_traj=pre_rpa_exec_traj,
          fetched_info=fetched_info
        )
        
        agent_utils.write_to_file(file_path=log_task_path, file_name='rpa_builder_output.txt', content=builder_output)
    
    rpa_info = builder_output.output
    self.rpa_builder_conslusion = rpa_info.conclusion
    
    rpa_info.task_type=task_type
    
    return rpa_info, cnt_fetch_info
  
  def fetch_info(self, traj_id: str, step_n: int):
    if 'fix_react_traj' in traj_id:
      traj = self.pre_rpa_exec_traj.fix_react_traj
      env_step = traj[
        step_n - len(self.pre_rpa_exec_traj.env_op_traj) - 1].exec_step_info
      if len(traj) == step_n:
        ui_content = "Task has been finished"
      else:
        ui_content = env_step.after_ui_content
    else:  # Map traj_id name to actual trajectory object
      traj_lookup = {
        "pre_rpa_exec_traj": self.pre_rpa_exec_traj,
        "successful_react_traj": self.react_trajs[-1] if self.react_trajs else None,
        "failed_react_traj": self.react_trajs[0] if len(self.react_trajs) > 1 else None,
      }
      
      traj = traj_lookup.get(traj_id)
      if not traj:
        return {"traj_id": traj_id, "step_n": step_n, "ui_content": f"Invalid traj_id '{traj_id}'.", "screenshot": None}
      if step_n < 1 or step_n > len(traj.env_op_traj):
        return {"traj_id": traj_id, "step_n": step_n, "ui_content": f"Invalid step_n {step_n}.", "screenshot": None}
      if (not traj) and (step_n < 1 or step_n > len(traj.env_op_traj)):
        return {"traj_id": traj_id, "step_n": step_n, "ui_content": f"Invalid traj_id '{traj_id}' and step_n {step_n}.", "screenshot": None}
      
      env_step = traj.env_op_traj[step_n - 1]
      
      if len(traj.env_op_traj) == step_n:
        ui_content = "Task has been finished"
      else:
        ui_content = env_step.after_ui_content
    
    screenshot = env_step.after_screenshot_w_som_path if env_step.after_screenshot_w_som_path else env_step.after_screenshot_path
    
    # agent_utils.store_image(agent_utils.load_image_as_ndarray(screenshot), f'{traj_id}_{step_n}.png', '')
    
    result = {"traj_id": traj_id, "step_n": step_n, "ui_content": ui_content,
              "screenshot": agent_utils.load_image_as_ndarray(screenshot) if screenshot else None}
    
    return result
  
  def extract_json_call(self, output: str) -> dict:
    try:
      match = re.search(r'\{[\s\S]+\}', output)
      return json.loads(match.group(0)) if match else {}
    except Exception:
      return {}
