from autorpa.utils.models import EnvExecStepInfo
import json
import os
import shutil
import time
from pathlib import Path
from typing import Any, Optional

from . import models
from .agent_utils import print_with_color


class ReactTrajBank:
  def __init__(self, save_path="./", file_name="react_trajs_bank.json", load_local_bank=True,
               screenshot_dir="screenshots", backup_dir="backups", auto_backup=True, start_timestamp=None):
    """
    Set up a trajectory bank that is isolated per agent configuration.

    With a relative ``file_name`` the bank lives under
    ``{save_path}/react_trajs_banks/{agent_type}/{config_id}/`` and always uses
    the fixed names ``react_trajs_bank.json``, ``screenshots`` and ``backups``.
    With an absolute ``file_name`` ("legacy mode") that file is used as-is and
    ``screenshot_dir`` / ``backup_dir`` are created next to it.

    :param save_path: Base directory; unused when file_name is absolute.
    :param file_name: Absolute bank path (legacy mode), or any relative name
                      (which only selects the config-separated mode).
    :param load_local_bank: Load the bank file if it already exists.
    :param screenshot_dir: Screenshot directory name, honoured in legacy mode only.
    :param backup_dir: Backup directory name, honoured in legacy mode only.
    :param auto_backup: Whether save() also maintains backups.
    :param start_timestamp: Run timestamp used in backup names; defaults to "now".
    """
    from absl import flags
    FLAGS = flags.FLAGS

    self.react_trajs_dict = {}

    # Identity of the agent setup; config_id is also used as a directory name.
    self.agent_config = self._generate_agent_config(FLAGS)
    self.config_id = self._generate_config_id(FLAGS)
    self.agent_type = FLAGS.gui_agent_type

    bank_file = Path(file_name)
    if not bank_file.is_absolute():
      # Config-separated layout: {save_path}/react_trajs_banks/{agent_type}/{config_id}/
      config_root = Path(save_path) / 'react_trajs_banks' / self.agent_type / self.config_id
      config_root.mkdir(parents=True, exist_ok=True)

      self.save_path = config_root
      self.file_path = config_root / 'react_trajs_bank.json'
      self.screenshot_dir = config_root / 'screenshots'
      self.backup_dir = config_root / 'backups'
    else:
      # Legacy single-bank mode: everything sits beside the given file.
      self.file_path = bank_file
      self.save_path = bank_file.parent
      self.screenshot_dir = self.save_path / screenshot_dir
      self.backup_dir = self.save_path / backup_dir
      print_with_color("⚠️  Using legacy mode (absolute path provided)", 'yellow')

    # One backup file per program run, named after the run's start timestamp.
    self.auto_backup = auto_backup
    if start_timestamp is None:
      import datetime
      start_timestamp = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    self.start_timestamp = start_timestamp
    self.backup_file_path = self.backup_dir / f"react_trajs_bank_{self.start_timestamp}.json"

    for directory in (self.save_path, self.screenshot_dir, self.backup_dir):
      directory.mkdir(parents=True, exist_ok=True)

    # Per-run cache of translated (soft-coded action) trajectories,
    # keyed by (task_type, task_goal), to avoid repeated ActionTranslator calls.
    self._translation_cache = {}

    rule = "=" * 80
    banner = (
      rule,
      "📁 ReactTrajBank Initialized",
      rule,
      f"  Agent Type: {self.agent_type}",
      f"  Config ID: {self.config_id}",
      f"  Agent Config: {self.agent_config}",
      f"  Bank Path: {self.file_path}",
      f"  Screenshot Dir: {self.screenshot_dir}",
      f"  Backup Dir: {self.backup_dir}",
      rule,
    )
    for line in banner:
      print_with_color(line, 'cyan')

    # Optionally populate the bank from the file written by an earlier run.
    if not load_local_bank:
      print_with_color("\n📦 NOT LOADING LOCAL REACT LIBRARY\n", 'yellow')
    elif self.file_path.exists():
      self.react_trajs_dict = self.load_trajs(str(self.file_path))
      print_with_color(f"📥 Loaded react_trajs_bank from {self.file_path}.", 'green')
    else:
      print_with_color(f"📝 {self.file_path} not found, starting fresh.", 'yellow')
  
  def _generate_agent_config(self, FLAGS) -> dict:
    """Generate complete agent configuration that affects trajectory quality."""
    config = {'gui_agent_type': FLAGS.gui_agent_type}
    
    if FLAGS.gui_agent_type == 'react_star':
      config.update({
        'planner_llm': FLAGS.planner_llm,
        'summarizer_llm': FLAGS.summarizer_llm,
        'action_space': FLAGS.react_star_action_space,
        'ui_info': FLAGS.react_star_ui_info,
        'reflection_rounds': FLAGS.reflection_rounds,
      })
    elif FLAGS.gui_agent_type == 'droidrun':
      config.update({
        'model': getattr(FLAGS, 'droidrun_model', 'default'),
        'action_space': getattr(FLAGS, 'droidrun_action_space', 'default'),
      })
    elif FLAGS.gui_agent_type == 'askui':
      config.update({
        'model': getattr(FLAGS, 'askui_model', 'default'),
      })
    
    return config
  
  def _generate_config_id(self, FLAGS) -> str:
    """Generate unique configuration identifier based on all parameters that affect trajectory."""
    if FLAGS.gui_agent_type == 'react_star':
      planner = self._simplify_model_name(FLAGS.planner_llm)
      summarizer = self._simplify_model_name(FLAGS.summarizer_llm)
      action_space = FLAGS.react_star_action_space
      ui_info = self._simplify_model_name(FLAGS.react_star_ui_info)  # Also simplify ui_info
      # reflection = FLAGS.reflection_rounds
      
      # return f"{planner}_{summarizer}_{action_space}_{ui_info}_ref{reflection}"
      return f"{planner}_{summarizer}_{action_space}_{ui_info}"

    
    elif FLAGS.gui_agent_type == 'droidrun':
      model = getattr(FLAGS, 'droidrun_model', 'default')
      action_space = getattr(FLAGS, 'droidrun_action_space', 'default')
      return f"{self._simplify_model_name(model)}_{action_space}"
    
    elif FLAGS.gui_agent_type == 'askui':
      model = getattr(FLAGS, 'askui_model', 'default')
      return self._simplify_model_name(model)
    
    else:
      return "default"
  
  def _simplify_model_name(self, model: str) -> str:
    """Simplify model name for use in paths: gpt-5-low -> gpt5low"""
    return model.replace('-', '').replace('.', '').replace('_', '')
  
  def clear(self):
    self.react_trajs_dict = {}
  
  def _count_total_trajectories(self) -> int:
    """Count total number of trajectories across all tasks."""
    total = 0
    for task_type_dict in self.react_trajs_dict.values():
      for traj_list in task_type_dict.values():
        total += len(traj_list)
    return total
  
  def _sanitize_filename(self, filename: str) -> str:
    """Sanitize string for use as filename."""
    invalid_chars = '<>:"/\\|?*. '
    for char in invalid_chars:
      filename = filename.replace(char, '_')
    return filename[:50]
  
  def load_trajs(self, file_path: str) -> dict[str, dict[str, list[models.ReActTraj]]]:
    """
    Read a bank JSON file and rebuild the nested model objects.

    A missing file, an empty file, invalid JSON, or a non-dict root all yield
    an empty bank (with a warning) instead of raising.

    Returns:
        A dictionary: task_type -> task -> list[ReActTraj]
    """
    print_with_color(f"🔄 Loading react_trajs_bank from: {file_path}", 'cyan')

    # Step 1: read the raw text (a missing file is normal on the first run).
    try:
      with open(file_path, "r", encoding="utf-8") as bank_file:
        raw_text = bank_file.read().strip()
    except FileNotFoundError:
      print_with_color("   ⚠️  react_trajs_bank.json not found; initializing empty bank.", 'yellow')
      return {}

    if not raw_text:
      print_with_color("   ⚠️  react_trajs_bank.json is empty; initializing empty bank.", 'yellow')
      return {}

    # Step 2: decode JSON.
    try:
      data = json.loads(raw_text)
    except json.JSONDecodeError as e:
      print_with_color(f"   ⚠️  react_trajs_bank.json is not valid JSON ({e}); initializing empty bank.", 'yellow')
      return {}

    # Expected root shape: {task_type: {task_goal: [ReActTraj...]}}
    if not isinstance(data, dict):
      print_with_color("   ⚠️  react_trajs_bank.json root is not a dict; initializing empty bank.", 'yellow')
      return {}

    # Step 3: turn plain dicts into ReActTraj / ReActStepInfo / EnvExecStepInfo.
    bank: dict[str, dict[str, list[models.ReActTraj]]] = {}
    loaded_count = 0
    for task_type, goals in data.items():
      per_goal: dict[str, list[models.ReActTraj]] = {}
      bank[task_type] = per_goal
      for task, raw_trajs in goals.items():
        trajs: list[models.ReActTraj] = []
        for raw_traj in raw_trajs:
          steps: list[models.ReActStepInfo] = []
          for raw_step in raw_traj.get("traj", []):
            # The nested exec_step_info dict becomes an EnvExecStepInfo first (V2).
            if "exec_step_info" in raw_step and isinstance(raw_step["exec_step_info"], dict):
              raw_step["exec_step_info"] = models.EnvExecStepInfo(**raw_step["exec_step_info"])
            steps.append(models.ReActStepInfo(**raw_step))

          raw_traj["traj"] = steps
          trajs.append(models.ReActTraj(**raw_traj))
          loaded_count += 1

        per_goal[task] = trajs
        # Debug: per-task load status
        sample_type = type(trajs[0]).__name__ if trajs else 'N/A'
        print_with_color(
          f"   📥 Loaded {len(trajs)} trajs for {task_type}/{task}; "
          f"sample_type={sample_type}",
          'cyan'
        )

    print_with_color(f"✅ Finished loading react_trajs_bank, total {loaded_count} trajectories.", 'green')
    return bank
  
  def save(self, dict_file=None):
    """
    Write the bank to ``self.file_path`` as indented JSON with sorted top-level keys.

    :param dict_file: Mapping to persist. ``None`` (the default) saves
                      ``self.react_trajs_dict``; any other value, including an
                      empty dict, is saved exactly as given.
    :raises TypeError: if the mapping contains an object that is neither
                       JSON-native nor a ReActTraj. This is raised before the
                       bank file is opened, so the existing file is not truncated.

    Steps (backups only when ``self.auto_backup`` is set):
      1. copy the current bank file, if any, to the run's backup path;
      2. serialize the mapping and overwrite the bank file;
      3. copy the new bank file over that same backup and back up screenshots.
    """
    # Only None means "use the in-memory bank"; an empty dict is a valid payload.
    if dict_file is None:
      dict_file = self.react_trajs_dict

    # Backup old file before overwriting (if exists); best-effort only.
    if self.auto_backup and self.file_path.exists():
      try:
        shutil.copy2(self.file_path, self.backup_file_path)
        print_with_color(f"📦 Backed up old version to: {self.backup_file_path.name}", 'cyan')
      except Exception as e:
        print_with_color(f"⚠️  Failed to backup old file: {e}", 'yellow')

    def default_serializer(obj):
      # ReActTraj objects are dumped via their model_dump(); anything else is an error.
      if isinstance(obj, models.ReActTraj):
        return obj.model_dump()
      raise TypeError(f'Object of type {obj.__class__.__name__} is not JSON serializable')

    sorted_dict = dict(sorted(dict_file.items(), key=lambda x: x[0]))
    # Serialize fully in memory first: opening with 'w' truncates the file, so a
    # serialization error must surface before that point.
    payload = json.dumps(sorted_dict, indent=2, default=default_serializer)
    # utf-8 matches the encoding load_trajs() reads with.
    with open(self.file_path, 'w', encoding='utf-8') as json_file:
      json_file.write(payload)

    print_with_color(f"💾 Saved react_trajs_bank to: {self.file_path}", 'green')

    # Create backup after saving (ensures we always have a backup)
    if self.auto_backup:
      self._create_post_save_backup()
  
  def _create_post_save_backup(self):
    """Copy the freshly saved bank JSON to the run's backup file, then back up screenshots.

    Best-effort: any failure is reported as a warning and not re-raised.
    """
    try:
      backup_target = self.backup_file_path
      shutil.copy2(self.file_path, backup_target)
      print_with_color(f"📦 Backup created: {backup_target.name}", 'cyan')

      # Screenshots travel with the JSON backup.
      self._backup_screenshots()
    except Exception as err:
      print_with_color(f"⚠️  Failed to create backup: {err}", 'yellow')
  
  def _backup_screenshots(self):
    """Mirror the screenshot directory into ``backups/screenshots_{start_timestamp}``.

    Does nothing when there is no screenshot directory. An earlier copy made
    under the same timestamp is removed first. Failures only produce a warning.
    """
    source_dir = self.screenshot_dir
    if not source_dir.exists():
      return

    target_dir = self.backup_dir / f"screenshots_{self.start_timestamp}"
    try:
      # Remove the previous copy before copying the tree again.
      if target_dir.exists():
        shutil.rmtree(target_dir)
      shutil.copytree(source_dir, target_dir)
      file_count = len(list(target_dir.rglob('*.png')))
      print_with_color(f"   📸 Screenshots backup updated: {file_count} files", 'cyan')
    except Exception as err:
      print_with_color(f"⚠️  Failed to backup screenshots: {err}", 'yellow')
  
  def _centralize_screenshot(self, screenshot_path: str, task_type: str, task_goal: str,
                             instance_id: int, step_n: int, prefix: str = "before") -> Optional[str]:
    """Copy one screenshot into the bank's screenshot tree and return its new path.

    Destination: ``{screenshot_dir}/{task_type}/task{instance_id}/step{step_n}_{prefix}.png``
    (``task_goal`` is not part of the path).

    Returns:
      - None when ``screenshot_path`` is empty or the file does not exist;
      - the new relative path string when the copy succeeds;
      - the unchanged ``screenshot_path`` when the copy fails.
    """
    if not screenshot_path:
      return None

    # Relative inputs are interpreted against the current working directory.
    source = Path(screenshot_path)
    if not source.is_absolute():
      source = Path.cwd() / screenshot_path

    if not source.exists():
      print_with_color(f"⚠️  Screenshot not found: {screenshot_path}", 'yellow')
      return None

    instance_folder = f"task{instance_id}"
    new_filename = f"step{step_n}_{prefix}.png"
    destination_dir = self.screenshot_dir / task_type / instance_folder
    destination_dir.mkdir(parents=True, exist_ok=True)

    try:
      shutil.copy2(source, destination_dir / new_filename)
      # NOTE(review): the returned path is spelled out with a fixed "data" prefix
      # rather than derived from self.screenshot_dir, so it presumably matches the
      # real location only when the bank lives under ./data — confirm with callers.
      relative_path = os.path.join(
        "data", "react_trajs_banks", self.agent_type, self.config_id,
        "screenshots", task_type, instance_folder, new_filename
      )
      print_with_color(f"  📸 Centralized screenshot: {screenshot_path} → {relative_path}", 'cyan')
      return relative_path
    except Exception as err:
      print_with_color(f"⚠️  Failed to copy screenshot {screenshot_path}: {err}", 'yellow')
      return screenshot_path
  
  def _centralize_trajectory_screenshots(self, task_type: str, task_goal: str,
                                         instance_id: int, traj: models.ReActTraj) -> None:
    """Centralize every screenshot referenced by ``traj`` and rewrite the paths in place.

    For each step that has an ``exec_step_info``, the four screenshot fields
    (before / before with SOM / after / after with SOM) are passed to
    ``_centralize_screenshot``; a field is updated only when that call returns
    a non-empty path.
    """
    print_with_color(f"📦 Centralizing screenshots for {task_type}/{task_goal} instance {instance_id}...", 'cyan')

    # (exec_step_info attribute, filename prefix), handled in this order.
    # NOTE(review): filenames carry only the step number and prefix, so several
    # rounds of one instance appear to write to the same files — confirm intent.
    screenshot_fields = (
      ("before_screenshot_path", "before"),
      ("before_screenshot_w_som_path", "before_som"),
      ("after_screenshot_path", "after"),
      ("after_screenshot_w_som_path", "after_som"),
    )

    for step_info in traj.traj:
      exec_info = step_info.exec_step_info
      if not exec_info:
        continue

      for field_name, prefix in screenshot_fields:
        current_path = getattr(exec_info, field_name)
        if not current_path:
          continue
        new_path = self._centralize_screenshot(
          current_path, task_type, task_goal, instance_id, step_info.step_n, prefix
        )
        if new_path:
          setattr(exec_info, field_name, new_path)
  
  def _delete_instance_screenshots(self, task_type: str, task_goal: str, instance_id: int) -> None:
    """
    Delete ``screenshots/{task_type}/task{instance_id}/`` with everything in it.

    ``task_goal`` does not take part in the path. A missing directory is
    ignored; a failed deletion is reported as a warning and not re-raised.
    """
    instance_screenshot_dir = self.screenshot_dir / task_type / f"task{instance_id}"

    if instance_screenshot_dir.exists():
      try:
        # Count the PNGs first so the log can say how many were removed.
        file_count = len(list(instance_screenshot_dir.rglob('*.png')))
        shutil.rmtree(instance_screenshot_dir)
        print_with_color(f"🗑️  Deleted instance screenshot directory: {instance_screenshot_dir} ({file_count} files)", 'yellow')
      except Exception as err:
        print_with_color(f"⚠️  Failed to delete instance screenshot directory {instance_screenshot_dir}: {err}", 'yellow')
  
  def _validate_trajectory_screenshots(self, task_type: str, task_goal: str, 
                                       traj: models.ReActTraj, debug: bool = False) -> dict:
    """
    Validate screenshot availability in a trajectory.
    
    Returns:
      Dict: {'total': int, 'missing': int, 'missing_paths': list[str]}
    """
    total = 0
    missing = 0
    missing_paths = []
    
    if debug:
      print_with_color(f"\n🔍 Debug: Validating screenshots for {task_type}/{task_goal}", 'yellow')
      print_with_color(f"   save_path: {self.save_path}", 'yellow')
    
    for step_info in traj.traj:
      exec_info = step_info.exec_step_info
      if not exec_info:
        continue
      
      # Check before screenshot
      if exec_info.before_screenshot_w_som_path:
        total += 1
        screenshot_path = exec_info.before_screenshot_w_som_path
        
        # Convert to absolute path for validation
        path_obj = Path(screenshot_path)
        if not path_obj.is_absolute():
          if screenshot_path.startswith('data/'):
            # Get project root (go up from save_path to find 'data')
            # save_path could be data/ or data/subfolder/
            project_root = self.save_path
            while project_root.name != 'data' and project_root.parent != project_root:
              project_root = project_root.parent
            if project_root.name == 'data':
              project_root = project_root.parent
            full_path = project_root / screenshot_path
          else:
            full_path = self.save_path / screenshot_path
        else:
          full_path = path_obj
        
        if not full_path.exists():
          missing += 1
          missing_paths.append(str(exec_info.before_screenshot_w_som_path))
          if debug:
            print_with_color(f"   ❌ Missing: {screenshot_path}", 'red')
            print_with_color(f"      Resolved to: {full_path}", 'red')
        elif debug:
          print_with_color(f"   ✓ Found: {screenshot_path}", 'green')
      
      # Check after screenshot
      if exec_info.after_screenshot_w_som_path:
        total += 1
        screenshot_path = exec_info.after_screenshot_w_som_path
        
        path_obj = Path(screenshot_path)
        if not path_obj.is_absolute():
          if screenshot_path.startswith('data/'):
            project_root = self.save_path
            while project_root.name != 'data' and project_root.parent != project_root:
              project_root = project_root.parent
            if project_root.name == 'data':
              project_root = project_root.parent
            full_path = project_root / screenshot_path
          else:
            full_path = self.save_path / screenshot_path
        else:
          full_path = path_obj
        
        if not full_path.exists():
          missing += 1
          missing_paths.append(str(exec_info.after_screenshot_w_som_path))
    
    return {'total': total, 'missing': missing, 'missing_paths': missing_paths}
  
  def add_react_traj(self, task_type: str, task_goal: str, instance_id: int, list_react_traj: list[models.ReActTraj],
                     centralize_screenshots: bool = True, force_update: bool = False):
    """
    Store one session (one or more rounds) of trajectories for a task instance.

    Behaviour by situation:
    - the instance is not in the bank yet: the new trajectories are appended;
    - the instance exists and ``force_update`` is set: its old trajectories are
      replaced without any quality comparison;
    - the instance exists otherwise: ``_quality_based_merge`` decides whether
      the new session replaces the stored one.

    Args:
      task_type: Task type name
      task_goal: Task goal
      instance_id: Instance ID
      list_react_traj: ReActTraj objects of one session; each gets instance_id,
                       timestamp, num_steps and agent_config written onto it, and
                       ``round`` is set to its list position when it is None
      centralize_screenshots: Whether to copy screenshots into the bank when the
                              new session is adopted
      force_update: If True, the bank is first reloaded from disk (when the bank
                    file exists), replacing the in-memory bank, and the existing
                    trajectories of this instance are always replaced
    """
    # Reloading keeps trajectories of other instances that exist only on disk.
    if force_update and self.file_path.exists():
      print_with_color(
        f"   🔄 Force update mode: reloading full bank from disk to ensure data integrity",
        'cyan'
      )
      self.react_trajs_dict = self.load_trajs(str(self.file_path))

    # Make sure the task_type / task_goal slots exist.
    goals_for_type = self.react_trajs_dict.setdefault(task_type, {})
    goals_for_type.setdefault(task_goal, [])

    # Stamp metadata only; screenshots are centralized after the merge decision.
    new_trajs_processed = []
    for position, traj in enumerate(list_react_traj):
      traj.instance_id = instance_id
      traj.timestamp = self.start_timestamp
      traj.num_steps = len(traj.traj)
      traj.agent_config = self.agent_config

      # Rounds default to the position within the submitted list.
      if traj.round is None:
        traj.round = position

      new_trajs_processed.append(traj)

    existing_trajs = goals_for_type[task_goal]
    # Debug: shape of what is currently stored for this goal
    print_with_color(
      f"   ✅ Existing trajectories len={len(existing_trajs)}; "
      f"type={type(existing_trajs).__name__}; "
      f"elem_type={type(existing_trajs[0]).__name__ if existing_trajs else 'N/A'}",
      'green'
    )

    has_existing_instance = any(
      getattr(stored, 'instance_id', None) == instance_id for stored in existing_trajs
    )

    use_new_session = True  # stays True unless the quality merge prefers the stored session
    if not has_existing_instance:
      # Unknown instance (or empty goal): plain append, nothing to compare against.
      if existing_trajs:
        print_with_color(
          f"   ✅ Adding new instance_id {instance_id} directly (no quality comparison needed)",
          'green'
        )
      existing_trajs.extend(new_trajs_processed)
      merged_trajs = existing_trajs
    elif force_update:
      print_with_color(
        f"   🔄 Force update: replacing existing instance {instance_id} trajectories (quality comparison skipped)",
        'yellow'
      )
      # Keep other instances, drop this one's old rounds, then add the new rounds.
      merged_trajs = [
        stored for stored in existing_trajs
        if getattr(stored, 'instance_id', None) != instance_id
      ]
      merged_trajs.extend(new_trajs_processed)
      goals_for_type[task_goal] = merged_trajs
    else:
      merged_trajs, use_new_session = self._quality_based_merge(
        existing_trajs, new_trajs_processed, instance_id
      )
      goals_for_type[task_goal] = merged_trajs

    # Screenshots follow the merge decision: when the new session is adopted, the
    # instance's old screenshot directory is wiped and all of its rounds are
    # centralized; when the stored session wins, nothing is touched.
    if centralize_screenshots and new_trajs_processed and use_new_session:
      self._delete_instance_screenshots(task_type, task_goal, instance_id)
      for traj in merged_trajs:
        if getattr(traj, "instance_id", instance_id) == instance_id:
          self._centralize_trajectory_screenshots(task_type, task_goal, instance_id, traj)

    # Summary line
    round_count = len(new_trajs_processed)
    rounds_info = f"{round_count} rounds" if round_count > 1 else "1 round"
    if new_trajs_processed:
      final_success = new_trajs_processed[-1].final_success_bool
    else:
      final_success = False
    print_with_color(
      f"✅ Added trajectory: {task_type}/{task_goal} (instance {instance_id}, {rounds_info}, "
      f"{'success' if final_success else 'failed'})",
      'green'
    )
  
  def _quality_based_merge(self, existing_trajs: list[models.ReActTraj],
                           new_trajs: list[models.ReActTraj], new_instance_id: int) -> tuple[list[models.ReActTraj], bool]:
    """
    Decide whether a new session replaces the stored session(s) of one instance.

    Existing trajectories are grouped by instance_id and then by timestamp
    (one timestamp = one session). For ``new_instance_id`` exactly one session
    survives: the new one if its quality score is strictly higher than the best
    stored session, otherwise the best stored one. Sessions of all other
    instances are passed through untouched.

    Returns:
      merged_trajs: Flat list of all trajectories after the decision
      use_new_session: True when the new session was kept
    """
    # instance_id -> timestamp -> [trajectories]
    sessions_by_instance = {}
    for traj in existing_trajs:
      inst_id = getattr(traj, 'instance_id', 0)
      session_ts = getattr(traj, 'timestamp', 'unknown')
      sessions_by_instance.setdefault(inst_id, {}).setdefault(session_ts, []).append(traj)

    new_session_ts = new_trajs[0].timestamp if new_trajs else 'unknown'
    existing_sessions = sessions_by_instance.get(new_instance_id, {})

    if not existing_sessions:
      # Nothing stored for this instance: no comparison to make.
      print_with_color(f"   ✅ Adding new session (no existing data)", 'green')
      return self._flatten_sessions(sessions_by_instance, new_instance_id, new_trajs), True

    # Score every stored session; max() keeps the first one among equal scores.
    scored_sessions = [
      (self._evaluate_session_quality(session_trajs), session_key, session_trajs)
      for session_key, session_trajs in existing_sessions.items()
    ]
    best_existing_quality, best_existing_key, best_existing_session = max(
      scored_sessions, key=lambda entry: entry[0]
    )

    new_quality = self._evaluate_session_quality(new_trajs)
    use_new_session = new_quality > best_existing_quality

    if use_new_session:
      print_with_color(
        f"   🔄 Replacing with new session: better quality "
        f"(new: {new_quality:.2f} vs best existing: {best_existing_quality:.2f})",
        'cyan'
      )
      surviving_session = {new_session_ts: new_trajs}
    else:
      # A tie also lands here: the stored session is kept, the new one discarded.
      print_with_color(
        f"   ✓ Keeping best existing session: already better "
        f"(existing: {best_existing_quality:.2f} vs new: {new_quality:.2f})",
        'green'
      )
      surviving_session = {best_existing_key: best_existing_session}

    # The instance now holds exactly one session.
    sessions_by_instance[new_instance_id] = surviving_session
    return self._flatten_sessions(sessions_by_instance), use_new_session
  
  def _evaluate_session_quality(self, session_trajs: list[models.ReActTraj]) -> float:
    """
    Evaluate quality of a session (multiple rounds).
    
    Quality score calculation:
    - Final attempt must succeed: base score from final_success_score
    - Fewer steps is better: bonus for efficiency
    - Fewer rounds is better: bonus for quick success
    
    Returns:
      Quality score (higher is better). Returns 0 if final attempt failed.
    """
    if not session_trajs:
      return 0.0
    
    # Get final attempt (last trajectory)
    final_attempt = session_trajs[-1]
    
    # If final attempt failed, quality is 0 (should not be stored)
    if not final_attempt.final_success_bool:
      return 0.0
    
    # Base score from success score
    quality = final_attempt.final_success_score * 100
    
    # Bonus for fewer steps (normalize: assume 20 steps is baseline, fewer is better)
    steps = final_attempt.num_steps or len(final_attempt.traj)
    steps_bonus = max(0, 20 - steps)  # Up to +20 points
    
    # Bonus for fewer rounds (normalize: 1 round = +10, 2 rounds = +5, 3+ rounds = 0)
    rounds = len(session_trajs)
    rounds_bonus = max(0, 15 - rounds * 5)  # Up to +10 points
    
    quality += steps_bonus + rounds_bonus
    
    return quality
  
  def _flatten_sessions(self, sessions_by_instance: dict, 
                        target_instance: int = None, 
                        new_trajs: list[models.ReActTraj] = None) -> list[models.ReActTraj]:
    """Flatten sessions dict back to list of trajectories."""
    result = []
    
    for inst_id, sessions in sessions_by_instance.items():
      if target_instance is not None and inst_id != target_instance:
        # Keep other instances as-is
        for session_trajs in sessions.values():
          result.extend(session_trajs)
      elif target_instance is not None and inst_id == target_instance and new_trajs is not None:
        # For target instance, only add if explicitly provided
        result.extend(new_trajs)
      else:
        # Normal case: flatten all sessions
        for session_trajs in sessions.values():
          result.extend(session_trajs)
    
    return result
  
  def get_react_traj(self, task_type: str, task_goal: str, instance_id: int = None, fallback_to_other_instances: bool = False,
                     validate_screenshots: bool = False, return_all_rounds: bool = False) -> list[models.ReActTraj] | None:
    """
    Look up the stored rounds for one task goal, preferring the requested instance.

    Trajectories stored under (task_type, task_goal) are grouped by their
    `instance_id` attribute (0 when absent) and then by their `timestamp`
    attribute ('unknown' when absent). Each such group is treated as one
    session and is ordered by its `round` attribute (0 when absent).

    Selection order:
    1. If the requested instance has any session, the one with the greatest
       timestamp key is used. Its success flag is NOT inspected on this path.
    2. Otherwise, and only when fallback_to_other_instances is True, sessions
       of other instances whose last round has `final_success_bool` set are
       considered. The one whose last round has the fewest steps wins; on a
       tie the first one encountered is kept.

    Args:
      task_type: Task type key in the bank.
      task_goal: Task goal key under that task type.
      instance_id: Instance to look up. Must not be None once the task type
        and goal are known to exist.
      fallback_to_other_instances: Allow borrowing a session from another instance.
      validate_screenshots: If True, run the screenshot validator on every
        returned trajectory and print a warning when something is missing.
        The returned value is not affected by the validation result.
      return_all_rounds: If True, return every round of the chosen session;
        otherwise return a one-element list holding its last round.

    Returns:
      List of ReActTraj, or None when the task type / goal is unknown or no
      session qualifies.

    Raises:
      ValueError: If the task type and goal exist but instance_id is None.
    """
    bank = self.react_trajs_dict
    if task_type not in bank or task_goal not in bank[task_type]:
      return None

    if instance_id is None:
      raise ValueError("instance_id is required")

    # instance id -> session timestamp -> rounds of that session
    grouped = {}
    for traj in bank[task_type][task_goal]:
      owner = getattr(traj, 'instance_id', 0)
      stamp = getattr(traj, 'timestamp', 'unknown')
      grouped.setdefault(owner, {}).setdefault(stamp, []).append(traj)

    # Put every session in round order so that [-1] is its last round.
    for per_instance in grouped.values():
      for rounds in per_instance.values():
        rounds.sort(key=lambda t: getattr(t, 'round', 0))

    # Preferred source: the newest session recorded for the requested instance.
    chosen = None
    own_sessions = grouped.get(instance_id)
    if own_sessions:
      chosen = own_sessions[max(own_sessions)]

    # Fallback source: the shortest successful session of any other instance.
    if not chosen and fallback_to_other_instances:
      for other_id, other_sessions in grouped.items():
        if other_id == instance_id:
          continue

        for rounds in other_sessions.values():
          if not (rounds and rounds[-1].final_success_bool):
            continue
          # Strict comparison: an equally short later candidate does not replace the current one.
          if chosen is None or len(chosen[-1].traj) > len(rounds[-1].traj):
            chosen = rounds
            fallback_instance = other_id

      if chosen:
        print_with_color(
          f"🔄 No trajectory found for instance {instance_id}, "
          f"using instance {fallback_instance} trajectory "
          f"({len(chosen)} rounds, final: {len(chosen[-1].traj)} steps) as fallback",
          'cyan'
        )

    if not chosen:
      return None

    selected = chosen if return_all_rounds else [chosen[-1]]

    if validate_screenshots:
      for traj in selected:
        report = self._validate_trajectory_screenshots(task_type, task_goal, traj, debug=False)

        if report['missing'] > 0:
          traj_instance_id = getattr(traj, 'instance_id', instance_id)
          traj_round = getattr(traj, 'round', 0)
          print_with_color(
            f"⚠️  Warning: {report['missing']}/{report['total']} "
            f"screenshots missing in {task_type}/{task_goal} instance {traj_instance_id} round {traj_round}",
            'yellow'
          )

    return selected
  
  def get_last_traj(self, task_type: str, return_all_rounds: bool = False):
    """
    Return rounds from the newest session of the last task goal stored for a task type.

    "Last" means the final key in the iteration order of the task type's
    goal mapping. That goal's trajectories are grouped by their `timestamp`
    attribute ('unknown' when absent); the group with the greatest key is
    picked and ordered by the `round` attribute (0 when absent).

    Args:
      task_type: Task type name
      return_all_rounds: If True, return every round of that session; if
        False (default), return a one-element list with its last round

    Returns:
      List of ReActTraj, or None when the task type is unknown, has no
      goals, or its last goal has no trajectories
    """
    goals = self.react_trajs_dict.get(task_type)
    if not goals:
      return None

    trajs = next(reversed(goals.values()))
    if not trajs:
      return None

    # timestamp -> trajectories recorded under that timestamp
    by_session = {}
    for traj in trajs:
      by_session.setdefault(getattr(traj, 'timestamp', 'unknown'), []).append(traj)

    newest = sorted(by_session[max(by_session)], key=lambda t: getattr(t, 'round', 0))

    return newest if return_all_rounds else newest[-1:]
  
  # =========================================================================
  # Runtime Translation Cache (for soft-coded actions)
  # =========================================================================
  
  def cache_translated_traj(
      self,
      task_type: str,
      task_goal: str,
      instance_id: int,
      translated_trajs: list[models.ReActTraj]
  ):
    """
    Remember translated trajectories for one (task type, task goal, instance).

    The list is stored by reference in `self._translation_cache` under the
    tuple (task_type, task_goal, instance_id), replacing any previous entry
    for that key. Nothing in this method writes the cache to disk.

    Args:
        task_type: Task type (e.g., "CameraTakePhoto")
        task_goal: Task goal (e.g., "Take one photo.")
        instance_id: Instance ID
        translated_trajs: Trajectories to remember (presumably carrying
            soft-coded actions; this method does not check that)
    """
    self._translation_cache[(task_type, task_goal, instance_id)] = translated_trajs

    notice = f"💾 Cached translated trajectory: {task_type}/{task_goal} (instance {instance_id})"
    print_with_color(notice, 'cyan')
  
  def get_cached_translation(
      self,
      task_type: str,
      task_goal: str,
      instance_id: int
  ) -> list[models.ReActTraj] | None:
    """
    Fetch the translated trajectories cached for one (task type, task goal, instance).

    A message is printed only on a cache hit.

    Returns:
        The cached object for that key, or None when the key is not cached
    """
    lookup = (task_type, task_goal, instance_id)
    if lookup not in self._translation_cache:
      return None

    print_with_color(
        f"✅ Using cached translation: {task_type}/{task_goal} (instance {instance_id})",
        'green'
    )
    return self._translation_cache[lookup]
  
  def clear_translation_cache(self, task_type: str, task_goal: str, instance_id: int = None):
    """
    Drop entries from the runtime translation cache.

    With an instance_id, only the entry keyed by
    (task_type, task_goal, instance_id) is removed (no error if absent).
    Without one, the cache attribute is rebound to a new empty dict.

    NOTE(review): in the instance_id=None case every cached entry is dropped
    and task_type / task_goal are not consulted. Confirm whether clearing
    only the entries of this task was intended.
    """
    if instance_id is None:
      # Rebinds the attribute rather than emptying the existing dict in place.
      self._translation_cache = {}
    else:
      self._translation_cache.pop((task_type, task_goal, instance_id), None)

    # Printed in both cases, including when nothing was actually removed.
    print_with_color("🗑️  Translation cache cleared", 'yellow')
  
  def save_temp(self, list_react_traj: Optional[list[models.ReActTraj]],
                save_path=None, file_name='list_react_traj.txt'):
    """
    Write a list of ReAct-style trajectories to a human-readable text file.

    Trajectories are split into a "Successful" and a "Failed" section
    according to `final_success_bool`, each numbered from 1 in input order.
    The target file is overwritten.

    Args:
      list_react_traj: Trajectories to dump. None is treated as an empty
        list, which produces a file containing only the two section headers.
      save_path: Directory to write into; falls back to self.save_path when
        falsy. Created if it does not exist.
      file_name: Name of the text file inside save_path.
    """
    # The parameter is declared Optional, so accept None instead of failing in the loop.
    if list_react_traj is None:
      list_react_traj = []

    successful_blocks = []
    failed_blocks = []

    for react_traj in list_react_traj:
      # Collect the pieces and join once instead of growing a string with +=.
      pieces = [
        f"Task: {react_traj.task}\n",
        '\nAction History:\n' + '\n'.join(react_traj.action_history) + '\n',
      ]

      # The RPA section is only emitted when some RPA code was executed.
      if getattr(react_traj, "executed_rpa_code", None):
        pieces.append(
          f'\nRPA Execution:\n'
          f'Executed rpa code:\n{react_traj.executed_rpa_code}\n'
          f'Error message:\n{react_traj.rpa_error_message}\n'
        )

      for idx, step in enumerate(react_traj.traj, start=1):
        pieces.append(
          f"Step {idx}:\n"
          f"  Observation: {step.obs_description}\n"
          f"  Reason: {step.action_reason}\n"
          f"  Hard Coded Action: {step.hard_coded_action}\n"
          f"  Soft Coded Action: {step.soft_coded_action}\n"
          f"  Related Elements: {getattr(step.exec_step_info, 'related_elements', '')}\n"
          f"  Execution Summary: {step.execution_summary}\n"
          f"  UI Content Full Dict: {getattr(step.exec_step_info, 'before_ui_content_full_dict', None)}\n"
        )

      pieces.append(
        f"Final Success: {react_traj.final_success_bool} (Score: {react_traj.final_success_score})\n\n"
      )
      react_traj_str = ''.join(pieces)

      if react_traj.final_success_bool:
        successful_blocks.append(
          f"Successful Trajectory {len(successful_blocks) + 1}:\n{react_traj_str}"
        )
      else:
        failed_blocks.append(
          f"Failed Trajectory {len(failed_blocks) + 1}:\n{react_traj_str}"
        )

    successful_trajs_str = ''.join(successful_blocks)
    failed_trajs_str = ''.join(failed_blocks)
    react_trajs_str = (
      f"Successful Trajectories:\n{successful_trajs_str}\n\n"
      f"Failed Trajectories:\n{failed_trajs_str}"
    )

    save_path = save_path if save_path else self.save_path
    os.makedirs(save_path, exist_ok=True)
    file_path = os.path.join(save_path, file_name)
    # Leaving the with-block closes the file, which also flushes it.
    with open(file_path, 'w', encoding='utf-8') as f:
      f.write(react_trajs_str)


class RPAExecTrajBank:
  """
  In-memory store of RPA execution trajectories, grouped by the RPA code that produced them.

  `rpa_exec_dicts` maps a version label ("v_0", "v_1", ...) to a dict with
  the keys "rpa_code", "trajs" and "success". `rpa_version` is the label of
  the version that received the most recently added trajectory.
  """

  def __init__(self, save_path="", file_name="rpa_exec_trajs_bank.json"):
    """
    Create an empty bank and record where its files belong.

    No file is read or written here.

    :param save_path: Directory used by save_temp when no explicit path is given.
    :param file_name: File name joined onto save_path to form self.file_path.
    """
    self.rpa_version = "v_0"
    self.rpa_exec_dicts = {}
    self.file_path = os.path.join(save_path, file_name)
    self.save_path = save_path

  def clear(self):
    """Forget every stored version. `rpa_version` is left as it was."""
    self.rpa_exec_dicts = {}

  def add_rpa_exec_traj(self, rpa_code: str, rpa_exec_traj: models.RPAExecTraj):
    """
    File a trajectory under the version whose code equals rpa_code.

    A new version "v_<number of existing versions>" is created when no
    stored version holds this code. Either way `self.rpa_version` ends up
    naming the version that received the trajectory.

    :param rpa_code: Source of the RPA program that was executed.
    :param rpa_exec_traj: Trajectory to store; its `success` attribute is
      recorded in the parallel "success" list.
    """
    # Find the version that already holds this code. Without this lookup a
    # re-run of older code would be appended to whichever version happened
    # to be current, i.e. stored next to the wrong rpa_code.
    matching_version = next(
      (
        version
        for version, rpa_exec_dict in self.rpa_exec_dicts.items()
        if isinstance(rpa_exec_dict, dict)
        and 'rpa_code' in rpa_exec_dict
        and rpa_exec_dict['rpa_code'] == rpa_code
      ),
      None,
    )

    if matching_version is None:
      self.rpa_version = f"v_{len(self.rpa_exec_dicts)}"
      self.rpa_exec_dicts[self.rpa_version] = {"rpa_code": rpa_code, "trajs": [], "success": []}
    else:
      self.rpa_version = matching_version

    entry = self.rpa_exec_dicts[self.rpa_version]
    entry["trajs"].append(rpa_exec_traj)
    entry["success"].append(rpa_exec_traj.success)

  # get for rpa builder
  def get_relevant_trajs(self, task: str) -> list[models.RPAExecTraj] | None:
    """
    Find a stored trajectory for the given task, searching newest versions first.

    Within a version the trajectories are scanned in insertion order and
    the first one whose `task` equals the argument is used.

    :return: A one-element list with the match, or None if nothing matches.
    """
    # Plain builtin calls: subscripting `reversed` at runtime is not supported.
    for rpa_exec_dict in reversed(list(self.rpa_exec_dicts.values())):
      for traj in rpa_exec_dict["trajs"]:
        if task == traj.task:
          return [traj]
    return None

  def save_temp(self, rpa_exec_traj: models.RPAExecTraj,
                save_path: str = None, file_name='rpa_exec_traj.txt'):
    """
    Write one RPA execution trajectory to a human-readable text file.

    The target file is overwritten. Fields taken from `exec_result` are
    written as empty strings when `exec_result` is falsy.

    :param rpa_exec_traj: Trajectory to dump.
    :param save_path: Directory to write into; falls back to self.save_path
      when falsy. If both are empty the file goes to the current directory.
    :param file_name: Name of the text file inside save_path.
    """
    sections = [
      "RPA Execution Trajectory:\n"
      f"Task: {rpa_exec_traj.task}\n"
      f"Function Call: {rpa_exec_traj.function_call}\n"
      f"RPA Code: \n{rpa_exec_traj.rpa_code}\n"
      f"Success: {rpa_exec_traj.success}\n"
      "\nTrajectory Steps:\n"
    ]
    for idx, step in enumerate(rpa_exec_traj.traj, 1):
      sections.append(
        f"Step {idx}:\n"
        f"  Executed Action: {step.executed_action}\n"
        f"  Related Elements: {step.related_elements}\n"
        f"  Related Target: {step.related_target}\n"
        f"  Action Feedback: {step.action_feedback}\n"
        f"  Is Screen Changed: {step.is_screen_changed}\n"
        "-------------------------\n"
      )
    exec_result = rpa_exec_traj.exec_result
    sections.append(
      "\nExecution Result:\n"
      f"  - Executed Code: {exec_result.executed_code if exec_result else ''}\n\n"
      f"  - Error Statement: {exec_result.error_statement if exec_result else ''}\n"
      f"  - Error Message: {exec_result.error_message if exec_result else ''}\n"
      f"  - Answer Return: {exec_result.answer_return if exec_result else ''}\n"
      f"  - Exec Feedback: {exec_result.exec_feedback if exec_result else ''}\n"
      f"  - Done: {exec_result.done if exec_result else ''}\n"
    )
    action_history_str = '\n\n'.join(rpa_exec_traj.action_history)
    # Same guard as above: a missing exec_result must not crash the dump.
    sections.append(
      f"\nAction History:\n{action_history_str}\n"
      f"\nExec Feedback:\n{exec_result.exec_feedback if exec_result else ''}"
    )
    traj_str = ''.join(sections)

    save_path = save_path if save_path else self.save_path
    # os.makedirs('') raises, and '' is this class's default save_path.
    if save_path:
      os.makedirs(save_path, exist_ok=True)
    file_path = os.path.join(save_path, file_name)
    # Leaving the with-block closes the file, which also flushes it.
    with open(file_path, 'w', encoding='utf-8') as f:
      f.write(traj_str)
