'''
Reflection Agent - 用于分析评测结果并反思错误原因
适用于 batch_evaluation_results 格式
'''

import json
import re
import random
import base64
import math
from typing import List, Dict, Any

class ReflectionAgent:
    """Agent responsible for analyzing evaluation results and reflecting on errors"""
    
    def __init__(self, prompts_file: str = "prompts.json"):
        """Create the agent and resolve its reflection prompt template.

        Args:
            prompts_file: Path of the JSON prompts file consulted for the
                reflection template.
        """
        # The loader reads ``self.prompts_file``, so the path is stored first.
        self.prompts_file = prompts_file
        template = self._load_reflection_prompt()
        self.reflection_prompt = template
        
    def _load_reflection_prompt(self) -> str:
        """Load the reflection prompt template, falling back to the default.

        The prompts file is expected to be a JSON object of the form
        ``{"reflection_agent": {"prompt": "<template>"}}``.

        Returns:
            The configured template when the file can be read and holds a
            string at ``reflection_agent.prompt``; otherwise the value of
            ``_get_default_reflection_prompt()``.
        """
        default_prompt = self._get_default_reflection_prompt()

        try:
            with open(self.prompts_file, 'r', encoding='utf-8') as f:
                prompts = json.load(f)
        except (OSError, ValueError, TypeError):
            # Missing/unreadable file, invalid JSON or bad encoding (both are
            # ValueError subclasses), or an unusable path object: the prompts
            # file is optional, so fall back instead of failing.
            return default_prompt

        # Validate the shape explicitly instead of relying on AttributeError.
        if not isinstance(prompts, dict):
            return default_prompt
        section = prompts.get("reflection_agent", {})
        if not isinstance(section, dict):
            return default_prompt

        prompt = section.get("prompt", default_prompt)
        # A non-string template (null, number, list, ...) would only fail
        # later when ``.format`` is called on it, so reject it here.
        return prompt if isinstance(prompt, str) else default_prompt
    
    def _get_default_reflection_prompt(self) -> str:
        """Get default reflection prompt - fallback if prompts.json is not available.

        The template carries the five named placeholders that
        ``generate_reflection_prompt`` passes to ``str.format`` (``goal``,
        ``caption_summaries``, ``action_summaries``, ``predicted_action`` and
        ``ground_truth_action``). Without them the formatted fallback prompt
        would contain none of the data the model is asked to analyze.

        Returns:
            A ``str.format`` template string.
        """
        return (
            "You are an AI reflection expert. Please analyze the provided information and identify error causes."
            "\n\nGlobal Instruction:\n{goal}"
            "\n\nPast Caption Summaries:\n{caption_summaries}"
            "\nPast Action Summaries:\n{action_summaries}"
            "\nPredicted Action:\n{predicted_action}"
            "\n\nGround Truth Action:\n{ground_truth_action}"
        )

    def _image_to_base64(self, image_path: str) -> str:
        """Convert a local image file to a base64 ``data:`` URL.

        The MIME type is chosen from the file suffix (case-insensitive).
        JPEG, PNG, GIF and WebP suffixes are recognised; any other suffix is
        labelled ``image/jpeg``.

        Args:
            image_path: Path of the image file to read.

        Returns:
            ``"data:<mime>;base64,<payload>"``, or an empty string when the
            file cannot be read (a warning is printed in that case).
        """
        # Suffix -> MIME type. GIF and WebP get their own entries so they are
        # not mislabelled as JPEG in the data URL.
        mime_by_suffix = {
            '.jpg': 'image/jpeg',
            '.jpeg': 'image/jpeg',
            '.png': 'image/png',
            '.gif': 'image/gif',
            '.webp': 'image/webp',
        }
        try:
            with open(image_path, 'rb') as image_file:
                image_data = image_file.read()
            base64_data = base64.b64encode(image_data).decode('utf-8')

            lowered_path = image_path.lower()
            mime_type = 'image/jpeg'  # default for unrecognised suffixes
            for suffix, candidate in mime_by_suffix.items():
                if lowered_path.endswith(suffix):
                    mime_type = candidate
                    break
            return f"data:{mime_type};base64,{base64_data}"
        except Exception as e:
            # Deliberately best-effort: callers treat "" as "no screenshot".
            print(f"Warning: Could not load image {image_path}: {e}")
            return ""

    def load_evaluation_results(self, results_file: str) -> Dict[str, Any]:
        """Read a JSON results file and return its decoded content.

        Args:
            results_file: Path of the UTF-8 encoded JSON file.

        Returns:
            Whatever object the JSON document decodes to.

        Raises:
            Exception: If the file cannot be opened or decoded; the message
                names the file and includes the underlying error text.
        """
        try:
            with open(results_file, 'r', encoding='utf-8') as handle:
                raw_text = handle.read()
                return json.loads(raw_text)
        except Exception as e:
            raise Exception(f"Error loading evaluation results from {results_file}: {e}")

    def reconstruct_reasoning_prompt(self, 
                                  goal: str, 
                                  action_summaries: List[str], 
                                  caption_summaries: List[str] = None) -> str:
        """Rebuild the reasoning prompt from the history and the goal.

        Args:
            goal: The global instruction, placed last in the prompt.
            action_summaries: Past action summaries, listed as ``Action 1..n``.
            caption_summaries: Optional past caption summaries, listed as
                ``Caption 1..n``; a falsy value leaves that section empty.

        Returns:
            The prompt text: caption section, action section, then the goal.
        """
        # One numbered (1-based) line per past action.
        action_lines = "".join(
            f"Action {number}: {entry}\n"
            for number, entry in enumerate(action_summaries, start=1)
        )

        # Captions are optional; with none the section header is still emitted.
        caption_lines = ""
        if caption_summaries:
            caption_lines = "".join(
                f"Caption {number}: {entry}\n"
                for number, entry in enumerate(caption_summaries, start=1)
            )

        sections = [
            f"Past Caption Summaries: \n{caption_lines}",
            f"Past Action Summaries: \n{action_lines}",
            f"Global Instruction: \n{goal}",
        ]
        return "".join(sections)

    def generate_reflection_prompt(self, 
                                 goal: str, 
                                 caption_summaries: List[str],
                                 action_summaries: List[str],
                                 predicted_action: str, 
                                 ground_truth_action: Dict, 
                                 screenshot_path: str = "") -> List[Dict[str, Any]]:
        """Build the error-analysis request as an OpenAI-style chat message list.

        ``self.reflection_prompt`` is filled via ``str.format`` with the goal,
        the numbered (0-based) caption and action history, the predicted
        action and the ground-truth action rendered as indented JSON.

        Args:
            goal: The global instruction.
            caption_summaries: Past caption summaries.
            action_summaries: Past action summaries.
            predicted_action: The predicted action text.
            ground_truth_action: The ground-truth action, serialised as JSON.
            screenshot_path: Optional image URL; when non-empty it is attached
                as an ``image_url`` content part.

        Returns:
            A one-element list holding a ``user`` message whose content is the
            text part, followed by the image part when a screenshot is given.
        """
        # History sections: one "<label> <index>. <text>" line per entry.
        captions_block = "".join(
            f"Caption {index}. {text}\n"
            for index, text in enumerate(caption_summaries)
        )
        actions_block = "".join(
            f"Action {index}. {text}\n"
            for index, text in enumerate(action_summaries)
        )

        text_content = self.reflection_prompt.format(
            goal=goal,
            caption_summaries=captions_block,
            action_summaries=actions_block,
            predicted_action=predicted_action,
            ground_truth_action=json.dumps(ground_truth_action, ensure_ascii=False, indent=2)
        )

        # The text part always comes first; the screenshot part is optional.
        message_parts = [{"type": "text", "text": text_content}]
        if screenshot_path:
            message_parts.append({"type": "image_url", "image_url": {"url": screenshot_path}})

        return [{"role": "user", "content": message_parts}]

    def debug_sample_action(self, results_file: str) -> Dict[str, Any]:
        """Debug helper: sample one episode step and build both prompts for it.

        An episode is picked at random from the results file, then a step
        index is picked at random within that episode's ``action_summaries``.
        The reasoning prompt and the reflection prompt are built from the
        history *before* that step.

        Args:
            results_file: Path of a JSON file holding a list of episode dicts.

        Returns:
            A dict describing the sampled step (ids, history, extracted
            reasoning/action, ground truth, both prompts, screenshot info).
            On any failure a dict with a single ``"error"`` key is returned
            instead; this method does not raise.
        """
        try:
            # Load evaluation results - the file contains a list of episodes directly
            episodes = self.load_evaluation_results(results_file)

            if not episodes:
                return {"error": "No episodes found in results"}
            if not isinstance(episodes, list):
                # random.choice on a dict/str would fail with an opaque message.
                return {"error": f"Expected a list of episodes in results, got {type(episodes).__name__}"}

            # Randomly sample episode
            episode = random.choice(episodes)
            if not isinstance(episode, dict):
                return {"error": f"Expected each episode to be a dict, got {type(episode).__name__}"}

            # Per-step data of the episode; a missing or null list counts as empty
            action_summaries = episode.get("action_summaries") or []
            caption_summaries = episode.get("captions") or []
            predicted_actions = episode.get("predicted_actions") or []
            original_actions = episode.get("original_actions") or []
            screenshots = episode.get("screenshots") or []

            if not action_summaries:
                return {"error": "No action_summaries found in episode"}

            # Randomly sample an action index (always valid for action_summaries)
            action_idx = random.randint(0, len(action_summaries) - 1)

            goal = episode.get("goal", "")
            # The other lists may be shorter than action_summaries, so guard each lookup
            screenshot = screenshots[action_idx] if action_idx < len(screenshots) else ""
            current_predicted_action = predicted_actions[action_idx] if action_idx < len(predicted_actions) else ""
            current_ground_truth = original_actions[action_idx] if action_idx < len(original_actions) else {}

            # History strictly before the sampled step
            prev_action_summaries = action_summaries[:action_idx]
            prev_caption_summaries = caption_summaries[:action_idx]

            # Split the raw prediction into its <think> and <action> parts
            reasoning = self._extract_reasoning(current_predicted_action)
            extracted_action = self._extract_action(current_predicted_action)

            reasoning_prompt = self.reconstruct_reasoning_prompt(goal, prev_action_summaries, prev_caption_summaries)

            # Convert local image file to base64 data URL for OpenAI API
            screenshot_path = self._image_to_base64(screenshot) if screenshot else ""
            reflection_prompt = self.generate_reflection_prompt(
                goal, prev_caption_summaries, prev_action_summaries, extracted_action, current_ground_truth, screenshot_path=screenshot_path
            )

            return {
                "episode_id": episode.get("episode_id"),
                "action_idx": action_idx,
                "goal": goal,
                "action_summaries_available": len(action_summaries),
                "prev_action_summaries": prev_action_summaries,
                "current_action_summary": action_summaries[action_idx],
                "reasoning": reasoning,
                "extracted_action": extracted_action,
                "ground_truth_action": current_ground_truth,
                "reasoning_prompt": reasoning_prompt,
                "reflection_prompt": reflection_prompt,
                "full_predicted_action": current_predicted_action,
                "screenshot_path": screenshot,
                "screenshot_base64_available": bool(screenshot_path)
            }

        except Exception as e:
            # Debug entry point: report the failure instead of raising.
            return {"error": f"Error in debug_sample_action: {e}"}

    def _extract_reasoning(self, pred_action: str) -> str:
        """Return the stripped text of the first ``<think>`` block, or ``""``.

        Args:
            pred_action: Raw predicted action string.

        Returns:
            The content between the first ``<think>`` / ``</think>`` pair with
            surrounding whitespace removed, or an empty string when there is
            no such pair.
        """
        # Look for the content inside the <think> tag (may span several lines).
        found = re.search(r'<think>(.*?)</think>', pred_action, re.DOTALL)
        if found is None:
            return ""
        inner_text = found.group(1)
        return inner_text.strip()

    def _extract_action(self, pred_action: str) -> str:
        """Return the stripped text of the first ``<action>`` block.

        Args:
            pred_action: Raw predicted action string.

        Returns:
            The content between the first ``<action>`` / ``</action>`` pair
            with surrounding whitespace removed. When there is no such pair
            the input is returned unchanged.
        """
        # Look for the content inside the <action> tag (may span several lines).
        found = re.search(r'<action>(.*?)</action>', pred_action, re.DOTALL)
        if found is None:
            # No tags: the whole string is treated as the action.
            return pred_action
        inner_text = found.group(1)
        return inner_text.strip()

    def is_action_match(self, gt_action: Dict[str, Any], pred_action: str, error_margin: float = 1000) -> tuple[bool, bool]:
        """
        Check if predicted action matches ground-truth action

        The prediction is expected to look like ``func(key='value', ...)``,
        optionally wrapped in ``<think>``/``<action>`` tags. Parameter values
        may be quoted (and may then contain commas or parentheses, e.g.
        ``start_box='(120,340)'``) or unquoted.

        Args:
            gt_action: Ground truth action dictionary. ``action_type`` is
                always read; depending on the type, ``x``/``y``, ``text``,
                ``app_name`` or ``direction`` is read as well.
            pred_action: Predicted action string (may contain <think> and <action> tags)
            error_margin: Error margin for coordinate-based actions; a click or
                long press matches when the Euclidean distance to the ground
                truth point is at most ``0.14 * error_margin``.

        Returns:
            Tuple of (action_type_match, complete_match)
            - action_type_match: True if action types match
            - complete_match: True if both action type and content match

            This method never raises: input that cannot be interpreted yields
            ``(False, False)``; when the action type matches but the content
            cannot be parsed or compared, the result is ``(True, False)``.
        """
        # Action type mapping from ground truth to predicted format
        action_mapping = {
            "click": "click",
            "open_app": "open_app",
            "scroll": "scroll",
            "long_press": "long_press",
            "navigate_back": "press_back",
            "input_text": "type",
            "wait": "wait"
        }

        # Direct actions that don't require coordinate/text matching
        direct_actions = ("navigate_back", "wait")

        # key=value pairs. Alternatives are tried in order:
        #   1. a quoted value, ending at the quote that is followed by the
        #      next "key=" or the end of the parameter list (so the value may
        #      contain commas, parentheses and unescaped apostrophes),
        #   2. an unquoted (...) or [...] group such as (120,340),
        #   3. a bare token up to the next comma or closing parenthesis.
        param_pattern = (
            r"(\w+)\s*=\s*(?:"
            r"(['\"])(.*?)\2(?=\s*(?:,\s*\w+\s*=|$))"
            r"|(\([^()]*\)|\[[^\[\]]*\])"
            r"|([^,)]+)"
            r")"
        )

        def loosely_equal(expected: str, actual: str) -> bool:
            """Case-insensitive containment in either direction."""
            expected, actual = expected.lower(), actual.lower()
            if not expected or not actual:
                # '' is a substring of every string; without this guard an
                # empty prediction would match any ground truth.
                return expected == actual
            return expected in actual or actual in expected

        # ---- Stage 1: action type -------------------------------------
        try:
            gt_action_type = gt_action['action_type']

            if not pred_action:
                return False, False

            # Drop <think>/<action> wrappers if present (no-op otherwise).
            action_text = self._extract_action(pred_action).strip()

            # Parse action function and parameters; DOTALL lets a quoted
            # value span several lines.
            match = re.match(r"(\w+)\((.*)\)", action_text, re.DOTALL)
            if not match:
                print("no match")
                return False, False

            pred_func = match.group(1)
            params_str = match.group(2)

            # Unknown ground-truth types raise KeyError -> no match.
            if action_mapping[gt_action_type] != pred_func:
                return False, False
        except Exception:
            # Deliberately broad: malformed input must never raise.
            return False, False

        # Direct actions (no coordinate/text matching needed)
        if gt_action_type in direct_actions:
            return True, True

        # ---- Stage 2: action content ----------------------------------
        content_match = False
        try:
            params = {}
            for found in re.finditer(param_pattern, params_str, re.DOTALL):
                key = found.group(1)
                if found.group(2) is not None:
                    # Quoted value: undo backslash-escaped quotes.
                    value = found.group(3).replace("\\'", "'").replace('\\"', '"')
                elif found.group(4) is not None:
                    value = found.group(4)
                else:
                    value = found.group(5).strip().strip("'\"")

                if key in ("start_box", "end_box"):
                    coords = value.strip().strip("()[]").split(",")
                    if len(coords) == 2:
                        try:
                            value = [float(coords[0].strip()), float(coords[1].strip())]
                        except ValueError:
                            pass  # keep the raw text; it will simply not match
                params[key] = value

            # Coordinate-based actions
            if gt_action_type in ("click", "long_press"):
                gt_x, gt_y = gt_action['x'], gt_action['y']
                box = params.get('start_box')
                if isinstance(box, list):
                    pred_x, pred_y = box
                    distance = math.sqrt((pred_x - gt_x)**2 + (pred_y - gt_y)**2)
                    content_match = distance <= 0.14 * error_margin

            # Text-based actions
            elif gt_action_type == "input_text":
                content_match = loosely_equal(gt_action['text'], params.get('content', ''))

            elif gt_action_type == "open_app":
                content_match = loosely_equal(gt_action['app_name'], params.get('app_name', ''))

            # Direction-based actions
            elif gt_action_type == "scroll":
                gt_direction = gt_action['direction'].lower()
                pred_direction = params.get('direction', '').lower()
                content_match = gt_direction == pred_direction
        except Exception:
            # The type already matched; unusable content is just "no match".
            content_match = False

        return True, bool(content_match)