"""
Input step info per step:
obs, reason for code, code, execution of code
"""
import json
import os
import re
from typing import Optional

from absl import flags

from .prompts.builder_prompt import get_rpa_builder_prompt
from .utils import agent_utils, JSON_models
from .utils.llm_client import OpenAIWrapper
from .utils.agent_utils import print_with_color
from .utils.code_validation import validate_python_syntax

# absl flag registry; this module reads `use_fetch_info` and `builder_llm` from it.
FLAGS = flags.FLAGS

# Maximum number of syntax error retries, i.e. how many times the RPA code is
# regenerated after an output has failed syntax validation.
MAX_SYNTAX_RETRY = 2


def get_syntax_error_feedback(error_msg: str) -> str:
  """
  Build the feedback text shown to the LLM after a syntax error was detected.

  The text starts with two newlines (so it can be appended directly to an
  existing prompt), embeds the validator's message, lists common fixes and
  ends with a request to regenerate the code.

  Args:
    error_msg: The error message from syntax validation

  Returns:
    Feedback string to append to the prompt
  """
  common_fixes = (
    "Escape special characters in strings properly",
    "Use single quotes for strings containing double quotes",
    "Ensure all brackets and parentheses are matched",
    "Check that all function calls have matching parentheses",
    "Verify proper indentation",
  )
  feedback_lines = [
    "",  # the two leading empty entries produce the "\n\n" prefix
    "",
    "[CRITICAL - Syntax Error Detected]",
    "The previously generated RPA code has a syntax error. Please fix it and regenerate:",
    f"{error_msg}",
    "Common fixes:",
  ]
  feedback_lines.extend(f"- {fix}" for fix in common_fixes)
  feedback_lines.append("Please regenerate the RPA code with the syntax error fixed.")
  feedback_lines.append("")  # produces the trailing "\n"
  return "\n".join(feedback_lines)


class RPA_Builder_Agent:
  """
  Agent that asks an LLM to write RPA code for a task type.

  The agent builds a prompt with `get_rpa_builder_prompt`, queries the LLM,
  validates the returned code with `validate_python_syntax` and, on a syntax
  error, re-queries the LLM with feedback up to `MAX_SYNTAX_RETRY` times.
  When `FLAGS.use_fetch_info` is set, the LLM may first request the UI
  content / screenshot of a trajectory step (see `fetch_info`).
  """

  def __init__(
    self,
    llm: "OpenAIWrapper",
  ):
    """
    Args:
      llm: LLM client; `model_name` and `predict_mm` are the members used here.
    """
    self.rpa_builder_conslusion = ''  # Lives for the duration of each task type
    self.llm = llm
    self.reflection = None
    self.record_token = JSON_models.RecordToken()
    # Set by generate_rpa_code; initialised here so fetch_info never hits a
    # missing attribute when it is called first.
    self.react_trajs = None
    self.pre_rpa_exec_traj = None

  def generate_rpa_code(
    self,
    log_task_path: str,
    task_type: str,
    task_template: str,
    list_react_traj: "Optional[list[JSON_models.ReActTraj]]" = None,
    pre_rpa_exec_traj: "Optional[JSON_models.RPAExecTraj]" = None,
    encountered_task_goals: Optional[list[str]] = None,
  ) -> "tuple[JSON_models.RPAInfo, int]":
    """
    Generate RPA code for a task type, optionally fetching step info first.

    Args:
      log_task_path: Directory (created if missing) receiving prompts, LLM
        outputs, fetched screenshots and UI content.
      task_type: Forwarded to the prompt builder and stamped on the result.
      task_template: Forwarded to the prompt builder.
      list_react_traj: ReAct trajectories forwarded to the prompt builder and
        made available to `fetch_info`.
      pre_rpa_exec_traj: Previous RPA execution trajectory, forwarded to the
        prompt builder and made available to `fetch_info`.
      encountered_task_goals: Forwarded to the prompt builder.

    Returns:
      `(rpa_info, cnt_fetch_info)`: the builder's final output (with
      `task_type` overwritten by the argument) and the number of fetch-info
      tool calls the LLM made. If every syntax retry fails, the last output
      is returned even though its code has a syntax error.
    """
    print('============================================')
    print("Current Agent: RPA_Builder_Agent\n")
    print(f'model: {self.llm.model_name}\n')
    os.makedirs(log_task_path, exist_ok=True)
    self.react_trajs = list_react_traj
    self.pre_rpa_exec_traj = pre_rpa_exec_traj

    # Prompt arguments shared by every round; only fetched_info / use_tool vary.
    prompt_kwargs = dict(
      task_type=task_type,
      task_template=task_template,
      react_trajs=list_react_traj,
      pre_rpa_exec_traj=pre_rpa_exec_traj,
      rpa_builder_conslusion=self.rpa_builder_conslusion,
      encountered_task_goals=encountered_task_goals,
    )

    cnt_fetch_info = 0
    max_fetch_force = 0
    max_fetch_optional = 3 if FLAGS.use_fetch_info else 0
    max_fetch_total = max_fetch_optional + max_fetch_force
    fetched_info = {}
    fetched_screenshot = []
    builder_output = None

    if max_fetch_total == 0:
      # Fetching disabled: a single code-generation round.
      print_with_color(f'\nBuilder Round 0\n', 'blue')
      builder_output = self._generate_final_round(0, log_task_path, prompt_kwargs, fetched_info)

    for fetch_cnt in range(max_fetch_total):
      print_with_color(f'\nBuilder Round {fetch_cnt}\n', 'blue')
      if fetch_cnt < max_fetch_force:
        output_format = JSON_models.RPABuilderOutput_tool
      else:
        output_format = JSON_models.RPABuilderOutput_optional
      rpa_builder_prompt = get_rpa_builder_prompt(fetched_info=fetched_info, use_tool=True, **prompt_kwargs)
      agent_utils.write_to_file(file_path=log_task_path, file_name=f'rpa_builder_prompt_{fetch_cnt}.txt',
                                content=rpa_builder_prompt)
      builder_output, raw_response = self.llm.predict_mm(user_prompt=rpa_builder_prompt,
                                                         images=fetched_screenshot,
                                                         output_format=output_format)
      agent_utils.write_to_file(file_path=log_task_path, file_name=f'rpa_builder_output_{fetch_cnt}.txt',
                                content=builder_output)
      agent_utils.write_to_file(file_path=log_task_path, file_name=f'rpa_builder_raw_response_{fetch_cnt}.txt',
                                content=raw_response)
      self._record_tokens(raw_response, f'RPA Builder {fetch_cnt}')

      if not isinstance(builder_output.output, JSON_models.FetchInfoTool):
        # The LLM produced code instead of a tool call: validate and finish.
        self._print_builder_output(builder_output)
        is_valid, error_msg = validate_python_syntax(builder_output.output.rpa_code, log_task_path)
        if not is_valid:
          builder_output = self._retry_after_syntax_error(fetch_cnt, log_task_path, prompt_kwargs,
                                                          fetched_info, fetched_screenshot, error_msg,
                                                          builder_output)
        # Persist the final output here as well, like the non-tool rounds do.
        agent_utils.write_to_file(file_path=log_task_path, file_name='rpa_builder_output.txt', content=builder_output)
        break  # No more tool needed, exit loop

      # count fetch times
      cnt_fetch_info += 1

      print(f'Thought:\n{builder_output.thought}\n')
      print(f'info_to_clarify:\n{builder_output.info_to_clarify}\n')
      print(f'output:\n{builder_output.output}\n')

      fetched_info = self.fetch_info(builder_output.output.traj_id, builder_output.output.step_n)  # Local screenshot + UI extraction
      screenshot = fetched_info.get('screenshot')
      if screenshot is not None:
        fetched_screenshot = [screenshot]
        agent_utils.store_image(screenshot, f'screenshot_{fetch_cnt}.png', log_task_path)
      else:
        # Do not keep sending a screenshot that belongs to an earlier fetch.
        fetched_screenshot = []
      if fetched_info.get('ui_content'):
        agent_utils.write_to_file(log_task_path, f'ui_content_{fetch_cnt}.txt', content=fetched_info['ui_content'])

      # If we reach here, all attempts so far were tool calls
      if fetch_cnt == max_fetch_total - 1:
        print_with_color(f'\nBuilder Round{fetch_cnt + 1}\n', 'blue')
        builder_output = self._generate_final_round(fetch_cnt + 1, log_task_path, prompt_kwargs, fetched_info)

    rpa_info = builder_output.output
    self.rpa_builder_conslusion = rpa_info.conclusion

    rpa_info.task_type = task_type

    return rpa_info, cnt_fetch_info

  def _generate_final_round(self, round_idx: int, log_task_path: str, prompt_kwargs: dict, fetched_info: dict):
    """Run one tool-free generation round with syntax retries and persist the final output.

    Makes up to `MAX_SYNTAX_RETRY + 1` LLM calls (no images attached). Each
    attempt's prompt and output are logged under `round_idx`; the output of the
    last attempt is written to `rpa_builder_output.txt` and returned, whether
    or not it passed validation.
    """
    syntax_error_feedback = ""
    builder_output = None

    for syntax_retry in range(MAX_SYNTAX_RETRY + 1):
      rpa_builder_prompt = get_rpa_builder_prompt(fetched_info=fetched_info, **prompt_kwargs)

      # Add syntax error feedback if retrying
      if syntax_retry > 0:
        rpa_builder_prompt += syntax_error_feedback

      retry_suffix = f'_retry{syntax_retry}' if syntax_retry > 0 else ''
      agent_utils.write_to_file(file_path=log_task_path, file_name=f'rpa_builder_prompt_{round_idx}{retry_suffix}.txt',
                                content=rpa_builder_prompt)
      builder_output, raw_response = self.llm.predict_mm(user_prompt=rpa_builder_prompt, images=[],
                                                         output_format=JSON_models.RPABuilderOutput)
      agent_utils.write_to_file(file_path=log_task_path, file_name=f'rpa_builder_output_{round_idx}{retry_suffix}.txt',
                                content=builder_output)

      self._print_builder_output(builder_output)
      if syntax_retry > 0:
        agent_label = f'RPA Builder {round_idx} (retry {syntax_retry})'
      else:
        agent_label = f'RPA Builder {round_idx}'
      self._record_tokens(raw_response, agent_label)

      # Validate syntax
      is_valid, error_msg = validate_python_syntax(builder_output.output.rpa_code, log_task_path)
      if is_valid:
        break

      # Handle syntax error
      if syntax_retry < MAX_SYNTAX_RETRY:
        print_with_color(f'🔄 Syntax error detected, retrying... ({syntax_retry + 1}/{MAX_SYNTAX_RETRY})', 'yellow')
        syntax_error_feedback = get_syntax_error_feedback(error_msg)
      else:
        print_with_color('⚠️  Max syntax retries reached. Proceeding with code that has syntax errors.', 'yellow')

    agent_utils.write_to_file(file_path=log_task_path, file_name='rpa_builder_output.txt', content=builder_output)
    return builder_output

  def _retry_after_syntax_error(self, round_idx: int, log_task_path: str, prompt_kwargs: dict,
                                fetched_info: dict, images: list, error_msg: str, builder_output):
    """Regenerate code after a tool-enabled round returned code with a syntax error.

    Makes up to `MAX_SYNTAX_RETRY` further LLM calls with `use_tool=False`,
    attaching `images` and the syntax feedback. Returns the output of the last
    attempt made (or the `builder_output` passed in if no attempt is made).
    """
    # Use retry logic for syntax error
    print_with_color('🔄 Syntax error detected in initial generation, retrying...', 'yellow')
    syntax_error_feedback = get_syntax_error_feedback(error_msg)

    for syntax_retry in range(MAX_SYNTAX_RETRY):
      rpa_builder_prompt = get_rpa_builder_prompt(fetched_info=fetched_info, use_tool=False, **prompt_kwargs)
      rpa_builder_prompt += syntax_error_feedback

      retry_suffix = f'_syntax_retry{syntax_retry + 1}'
      agent_utils.write_to_file(file_path=log_task_path, file_name=f'rpa_builder_prompt_{round_idx}{retry_suffix}.txt',
                                content=rpa_builder_prompt)
      builder_output, raw_response = self.llm.predict_mm(user_prompt=rpa_builder_prompt,
                                                         images=images,
                                                         output_format=JSON_models.RPABuilderOutput)
      agent_utils.write_to_file(file_path=log_task_path, file_name=f'rpa_builder_output_{round_idx}{retry_suffix}.txt',
                                content=builder_output)

      self._record_tokens(raw_response, f'RPA Builder {round_idx} (syntax retry {syntax_retry + 1})')
      self._print_builder_output(builder_output)

      is_valid, error_msg = validate_python_syntax(builder_output.output.rpa_code, log_task_path)
      if is_valid:
        break

      if syntax_retry < MAX_SYNTAX_RETRY - 1:
        print_with_color(f'🔄 Syntax error still present, retrying... ({syntax_retry + 2}/{MAX_SYNTAX_RETRY})', 'yellow')
        syntax_error_feedback = get_syntax_error_feedback(error_msg)
      else:
        print_with_color('⚠️  Max syntax retries reached. Proceeding with code that has syntax errors.', 'yellow')

    return builder_output

  def _record_tokens(self, raw_response, agent_label: str) -> None:
    """Record the token usage of one LLM call under the given agent label."""
    self.record_token.step = '-'
    self.record_token.agent = agent_label
    self.record_token.step_tokens = raw_response.usage
    self.record_token.llm = FLAGS.builder_llm
    agent_utils.record_cost_tokens(self.record_token)

  @staticmethod
  def _print_builder_output(builder_output) -> None:
    """Print the thought and the generated RPA info of a code-producing output."""
    rpa_info = builder_output.output
    print(f'Thought:\n{builder_output.thought}\n')
    # Not every output format is guaranteed to carry this field.
    if hasattr(builder_output, 'info_to_clarify'):
      print(f'info_to_clarify:\n{builder_output.info_to_clarify}\n')
    print(f'Task Type:\n{rpa_info.task_type}\n')
    print(f'Params:\n{rpa_info.parameters}\n')
    print(f'RPA Description:\n{rpa_info.rpa_description}\n')
    print(f'Code:\n{rpa_info.rpa_code}\n')
    print(f'Example Usage:\n{rpa_info.example_usage}\n')
    print(f'Conclusion:\n{rpa_info.conclusion}\n')

  def fetch_info(self, traj_id: str, step_n: int):
    """
    Look up the UI content and screenshot recorded after a trajectory step.

    Args:
      traj_id: A string containing 'fix_react_traj' (uses the last fix
        trajectory of `pre_rpa_exec_traj`), or one of 'pre_rpa_exec_traj',
        'successful_react_traj' (last ReAct trajectory) and
        'failed_react_traj' (first ReAct trajectory, only when more than one
        exists).
      step_n: Step number. For fix trajectories it is matched against each
        step's `step_n`; otherwise it is a 1-based index into `env_op_traj`.

    Returns:
      A dict that always has the keys 'traj_id', 'step_n', 'ui_content' and
      'screenshot'. For an unknown trajectory or step, 'ui_content' holds an
      explanatory message and 'screenshot' is None. For the last step of a
      trajectory 'ui_content' is "Task has been finished".
    """
    def unavailable(message: str) -> dict:
      # Every failure result carries the same keys as a successful one, so
      # callers can index 'screenshot' unconditionally.
      return {"traj_id": traj_id, "step_n": step_n, "ui_content": message, "screenshot": None}

    if 'fix_react_traj' in traj_id:
      # fix_react_traj is list[ReActTraj], usually the last one
      fix_trajs = getattr(self.pre_rpa_exec_traj, 'fix_react_traj', None)
      if not fix_trajs:
        return unavailable("Fix react trajectory not available")
      target_traj = fix_trajs[-1]
      # Find the step by matching step_n
      target_step_info = next((info for info in target_traj.traj if info.step_n == step_n), None)
      if target_step_info is None:
        return unavailable(f"Step {step_n} not found in fix_react_traj")
      env_step = target_step_info.exec_step_info
      is_last_step = len(target_traj.traj) == step_n
    else:  # Map traj_id name to actual trajectory object
      react_trajs = self.react_trajs or []  # list_react_traj defaults to None
      traj_lookup = {
        "pre_rpa_exec_traj": self.pre_rpa_exec_traj,
        "successful_react_traj": react_trajs[-1] if react_trajs else None,
        "failed_react_traj": react_trajs[0] if len(react_trajs) > 1 else None,
      }

      traj = traj_lookup.get(traj_id)
      if not traj:
        return unavailable(f"Invalid traj_id '{traj_id}'.")
      if step_n < 1 or step_n > len(traj.env_op_traj):
        return unavailable(f"Invalid step_n {step_n}.")

      env_step = traj.env_op_traj[step_n - 1]
      is_last_step = len(traj.env_op_traj) == step_n

    ui_content = "Task has been finished" if is_last_step else env_step.after_ui_content

    # Prefer the screenshot variant stored under the `_w_som` path when present.
    screenshot = env_step.after_screenshot_w_som_path or env_step.after_screenshot_path

    return {"traj_id": traj_id, "step_n": step_n, "ui_content": ui_content,
            "screenshot": agent_utils.load_image_as_ndarray(screenshot) if screenshot else None}

  def extract_json_call(self, output: str) -> dict:
    """
    Extract the first parseable JSON object embedded in `output`.

    Scans each '{' from left to right and returns the first one from which a
    JSON object can be decoded, ignoring any text before or after it.

    Args:
      output: Text that may contain a JSON object.

    Returns:
      The decoded dict, or an empty dict when `output` is not a string or
      contains no parseable JSON object.
    """
    if not isinstance(output, str):
      return {}
    decoder = json.JSONDecoder()
    start = output.find('{')
    while start != -1:
      try:
        # raw_decode tolerates trailing text after the object.
        parsed, _ = decoder.raw_decode(output, start)
      except ValueError:  # json.JSONDecodeError is a ValueError subclass
        parsed = None
      if isinstance(parsed, dict):
        return parsed
      start = output.find('{', start + 1)
    return {}
