import json
import os
from typing import Optional

from UIAgents.Agent_RPA.utils import JSON_models
from .agent_utils import print_with_color


class ReactTrajBank:
  """
  In-memory bank of ReAct trajectories with JSON persistence.

  Trajectories live in ``react_trajs_dict`` as a nested mapping:
  task_type -> task -> list[ReActTraj].
  """

  def __init__(self, save_path="./", file_name="react_trajs_bank.json", load_local_bank=True):
    """
    Initialize the TrajBank with a directory to store/load trajectories.
    :param save_path: Directory path to save and load traj files.
    :param file_name: Name of the JSON bank file inside ``save_path``.
    :param load_local_bank: When True, load the bank file if it exists.
    """
    self.react_trajs_dict = {}
    self.save_path = save_path
    self.file_path = os.path.join(save_path, file_name)

    if not load_local_bank:
      print_with_color("\nNOT USE_REACT_LIBRARY\n", 'yellow')
      return

    # Load the local bank file; a missing file is reported, not fatal.
    if os.path.exists(self.file_path):
      self.react_trajs_dict = self.load_trajs(self.file_path)
      print_with_color(f"react_trajs_dict load from {self.file_path}.", 'green')
    else:
      print_with_color(f"{self.file_path} is not found.", 'red')

  def clear(self):
    """Drop every trajectory held in memory (the file on disk is untouched)."""
    self.react_trajs_dict = {}

  @staticmethod
  def _parse_traj(raw_traj: dict) -> JSON_models.ReActTraj:
    """Build one ReActTraj (with nested step / env-step models) from its raw JSON dict."""
    parsed_steps = []
    for raw_step in raw_traj.get("traj", []):
      # Nested env_op_traj entries become EnvExecStepInfo objects; a missing
      # key is treated as an empty list.
      env_steps = [JSON_models.EnvExecStepInfo(**s) for s in raw_step.get("env_op_traj", [])]
      # Build from a copy so the raw JSON dict is not mutated.
      parsed_steps.append(JSON_models.ReActStepInfo(**{**raw_step, "env_op_traj": env_steps}))
    return JSON_models.ReActTraj(**{**raw_traj, "traj": parsed_steps})

  def load_trajs(self, file_path: str) -> dict[str, dict[str, list[JSON_models.ReActTraj]]]:
    """
    Load trajectories from JSON and parse into nested model objects.

    :param file_path: Path of the JSON file to read (UTF-8).
    Returns:
        A dictionary: task_type -> task -> list[ReActTraj]
    """
    with open(file_path, "r", encoding="utf-8") as f:
      data = json.load(f)

    parsed_data = {}
    for task_type, task_dict in data.items():
      parsed_data[task_type] = {
        task: [self._parse_traj(raw_traj) for raw_traj in traj_list]
        for task, traj_list in task_dict.items()
      }
    return parsed_data

  def save(self, dict_file=None):
    """
    Write the bank to ``self.file_path`` as JSON, top-level keys sorted.

    :param dict_file: Mapping to save; when empty or None the in-memory
        ``react_trajs_dict`` is saved instead.
    """
    dict_file = dict_file if dict_file else self.react_trajs_dict

    def default_serializer(obj):
      # ReActTraj objects are dumped via model_dump(); anything else that
      # json cannot handle is an error.
      if isinstance(obj, JSON_models.ReActTraj):
        return obj.model_dump()
      raise TypeError(f'Object of type {obj.__class__.__name__} is not JSON serializable')

    sorted_dict = {key: dict_file[key] for key in sorted(dict_file)}

    # Create the target directory like save_temp does; an empty dirname means
    # the current directory and must not be passed to os.makedirs.
    target_dir = os.path.dirname(self.file_path)
    if target_dir:
      os.makedirs(target_dir, exist_ok=True)
    # Explicit UTF-8 to match how load_trajs reads the file back.
    with open(self.file_path, 'w', encoding='utf-8') as json_file:
      json.dump(sorted_dict, json_file, indent=2, default=default_serializer)

  def add_react_traj(self, task_type: str, task: str, list_react_traj: list[JSON_models.ReActTraj]):
    """Store (or overwrite) the trajectory list for ``task`` under ``task_type``."""
    self.react_trajs_dict.setdefault(task_type, {})[task] = list_react_traj

  def get_react_traj(self, task_type: str, task: str) -> Optional[list[JSON_models.ReActTraj]]:
    """Return the trajectory list for ``task`` under ``task_type``, or None if absent."""
    return self.react_trajs_dict.get(task_type, {}).get(task)

  def get_last_traj(self, task_type: str):
    """
    Return the most recently inserted trajectory list for ``task_type``.

    Returns None when the task type is unknown or has no tasks yet (e.g. a
    bank file that contains an empty object for that task type).
    """
    tasks = self.react_trajs_dict.get(task_type)
    if not tasks:
      return None
    # dicts keep insertion order, so the last value is the latest one added.
    return next(reversed(tasks.values()))

  @staticmethod
  def _format_traj(react_traj) -> str:
    """Render one trajectory as the human-readable text used by save_temp."""
    parts = [
      f"Task: {react_traj.task}\n",
      '\nAction History:\n' + '\n'.join(react_traj.action_history) + '\n',
    ]

    # RPA section only appears when the trajectory carries executed RPA code.
    if getattr(react_traj, "executed_rpa_code", None):
      parts.append(
        f'\nRPA Execution:\n'
        f'Executed rpa code:\n{react_traj.executed_rpa_code}\n'
        f'Error message:\n{react_traj.rpa_error_message}\n'
      )

    for idx, step in enumerate(react_traj.traj, start=1):
      parts.append(
        f"Step {idx}:\n"
        f"  Observation: {step.obs_description}\n"
        f"  Reason: {step.action_reason}\n"
        f"  Action: {step.action}\n"
        f"  Soft Coded Action: {step.soft_coded_action}\n"
        f"  Related Elements: {step.related_elements}\n"
        f"  Execution Summary: {step.execution_summary}\n"
        f"  UI Content: {step.ui_content}\n"
      )

    # EnvOp Trajectory
    # TODO: output ordering is off; this section should follow each react
    # step instead of trailing all of them. Low impact, kept as-is for now.
    # NOTE(review): reads env_op_traj from the trajectory object itself;
    # confirm ReActTraj exposes that attribute.
    if react_traj.env_op_traj:
      parts.append("  EnvOp Trajectory:\n")
      for i, env_step in enumerate(react_traj.env_op_traj):
        env_line = f"    [{i}] {env_step.action_feedback} | Target: {env_step.related_target}"
        if not env_step.is_screen_changed:
          env_line += " | No screen change"
        parts.append(env_line + "\n")
      parts.append("-------------------------\n")

    parts.append(f"Success: {react_traj.success}\n\n")
    return ''.join(parts)

  def save_temp(self, list_react_traj: Optional[list[JSON_models.ReActTraj]],
                save_path=None, file_name='list_react_traj.txt'):
    """
    Save a list of ReAct-style trajectories into a human-readable .txt file.

    :param list_react_traj: Trajectories to dump; None is treated as empty.
    :param save_path: Target directory; falls back to ``self.save_path``.
    :param file_name: Name of the text file to write.
    """
    successful_chunks = []
    failed_chunks = []

    # The parameter is Optional, so tolerate None instead of crashing.
    for react_traj in list_react_traj or []:
      react_traj_str = self._format_traj(react_traj)
      if react_traj.success:
        successful_chunks.append(
          f"Successful Trajectory {len(successful_chunks) + 1}:\n{react_traj_str}"
        )
      else:
        failed_chunks.append(
          f"Failed Trajectory {len(failed_chunks) + 1}:\n{react_traj_str}"
        )

    successful_trajs_str = ''.join(successful_chunks)
    failed_trajs_str = ''.join(failed_chunks)
    react_trajs_str = (
      f"Successful Trajectories:\n{successful_trajs_str}\n\n"
      f"Failed Trajectories:\n{failed_trajs_str}"
    )

    save_path = save_path if save_path else self.save_path
    # os.makedirs("") raises, so an empty path (current directory) is skipped.
    if save_path:
      os.makedirs(save_path, exist_ok=True)
    file_path = os.path.join(save_path, file_name)
    with open(file_path, 'w', encoding='utf-8') as f:
      f.write(react_trajs_str)


class RPAExecTrajBank:
  """
  In-memory bank of RPA execution trajectories, grouped by RPA code version.

  ``rpa_exec_dicts`` maps a version key ("v_0", "v_1", ...) to a dict with
  the keys "rpa_code", "trajs" and "success".
  """

  def __init__(self, save_path="", file_name="rpa_exec_trajs_bank.json"):
    """
    Initialize the TrajBank with a directory to store/load trajectories for one task type.
    :param save_path: Directory path to save and load traj files.
    :param file_name: Name of the bank file inside ``save_path``.
    """
    self.rpa_version = "v_0"
    self.rpa_exec_dicts = {}
    self.file_path = os.path.join(save_path, file_name)
    self.save_path = save_path

  def clear(self):
    """Drop every recorded version and its trajectories."""
    self.rpa_exec_dicts = {}

  def add_rpa_exec_traj(self, rpa_code: str, rpa_exec_traj: JSON_models.RPAExecTraj):
    """
    Record one execution trajectory under the version that owns ``rpa_code``.

    A new version "v_<n>" is created when the code has not been seen before.
    ``self.rpa_version`` is set to the version the trajectory was stored in,
    so a trajectory is never filed under a different code's version.

    :param rpa_code: Source of the RPA code that was executed.
    :param rpa_exec_traj: Trajectory produced by running that code.
    """
    version = next(
      (
        key
        for key, entry in self.rpa_exec_dicts.items()
        if isinstance(entry, dict) and 'rpa_code' in entry and entry['rpa_code'] == rpa_code
      ),
      None,
    )
    if version is None:
      version = f"v_{len(self.rpa_exec_dicts)}"
      self.rpa_exec_dicts[version] = {"rpa_code": rpa_code, "trajs": [], "success": []}

    self.rpa_version = version
    entry = self.rpa_exec_dicts[version]
    entry["trajs"].append(rpa_exec_traj)
    entry["success"].append(rpa_exec_traj.success)

  # get for rpa builder
  def get_relevant_trajs(self, task: str) -> Optional[list[JSON_models.RPAExecTraj]]:
    """
    Find a trajectory recorded for ``task``, searching newest version first.

    :return: A one-element list holding the first match in the newest
        version that contains one, or None when no trajectory matches.
    """
    for rpa_exec_dict in reversed(self.rpa_exec_dicts.values()):
      for traj in rpa_exec_dict["trajs"]:
        if traj.task == task:
          return [traj]
    return None

  def save_temp(self, rpa_exec_traj: JSON_models.RPAExecTraj,
                save_path: Optional[str] = None, file_name='rpa_exec_traj.txt'):
    """
    Save one RPA execution trajectory into a human-readable .txt file.

    :param rpa_exec_traj: Trajectory to dump.
    :param save_path: Target directory; falls back to ``self.save_path``.
    :param file_name: Name of the text file to write.
    """
    exec_result = rpa_exec_traj.exec_result

    def result_field(name):
      # A missing exec_result renders every field as an empty string.
      return getattr(exec_result, name) if exec_result else ''

    parts = [
      "RPA Execution Trajectory:\n"
      f"Task: {rpa_exec_traj.task}\n"
      f"Function Call: {rpa_exec_traj.function_call}\n"
      f"RPA Code: \n{rpa_exec_traj.rpa_code}\n"
      f"Success: {rpa_exec_traj.success}\n"
      "\nTrajectory Steps:\n"
    ]
    for idx, step in enumerate(rpa_exec_traj.env_op_traj, 1):
      parts.append(
        f"Step {idx}:\n"
        f"  Executed Action: {step.executed_action}\n"
        f"  Related Elements: {step.related_elements}\n"
        f"  Related Target: {step.related_target}\n"
        f"  Action Feedback: {step.action_feedback}\n"
        f"  Is Screen Changed: {step.is_screen_changed}\n"
        "-------------------------\n"
      )
    parts.append(
      "\nExecution Result:\n"
      f"  - Executed Code: {result_field('executed_code')}\n\n"
      f"  - Error Statement: {result_field('error_statement')}\n"
      f"  - Error Message: {result_field('error_message')}\n"
      f"  - Answer Return: {result_field('answer_return')}\n"
      f"  - Exec Feedback: {result_field('exec_feedback')}\n"
      f"  - Done: {result_field('done')}\n"
    )
    action_history_str = '\n\n'.join(rpa_exec_traj.action_history)
    parts.append(f"\nAction History:\n{action_history_str}\n")
    traj_str = ''.join(parts)

    save_path = save_path if save_path else self.save_path
    # The constructor default is "", and os.makedirs("") raises; an empty
    # path means the current directory, which needs no creation.
    if save_path:
      os.makedirs(save_path, exist_ok=True)
    file_path = os.path.join(save_path, file_name)
    with open(file_path, 'w', encoding='utf-8') as f:
      f.write(traj_str)
