'''
GUI Agent System - Caption and Action Agents (Static Dataset Version)
'''

import os
import json
from typing import List, Dict, Any, Optional, Tuple
from PIL import Image
import logging

class CaptionAgent:
    """Builds chat-style caption requests for individual GUI screenshots.

    The prompt text is read once, at construction time, from the
    ``["caption_agent"]["prompt"]`` entry of a JSON prompts file.
    """

    def __init__(self, prompts_file: str = "prompts.json"):
        """Load the caption prompt.

        Args:
            prompts_file: Path to a UTF-8 JSON file containing a
                ``caption_agent`` object with a ``prompt`` string.

        Raises:
            OSError: If the file cannot be opened.
            KeyError: If the expected keys are missing.
        """
        with open(prompts_file, 'r', encoding='utf-8') as handle:
            loaded = json.load(handle)
        caption_section = loaded["caption_agent"]
        self.caption_prompt = caption_section["prompt"]

    def generate_messages(self, screenshot_path: str) -> List[Dict[str, Any]]:
        """Return the single-turn message list asking for a caption.

        Args:
            screenshot_path: Path of the screenshot to describe. Only its
                existence is checked; the image is not opened here.

        Returns:
            A one-element list holding a ``user`` message whose content is
            the caption prompt followed by the image reference.

        Raises:
            FileNotFoundError: If ``screenshot_path`` does not exist.
        """
        try:
            if os.path.exists(screenshot_path):
                text_part = {"type": "text", "text": self.caption_prompt}
                image_part = {"type": "image", "image": screenshot_path}
                user_turn = {"role": "user", "content": [text_part, image_part]}
                return [user_turn]

            # Nothing on disk at that path: report it to the caller.
            raise FileNotFoundError(f"Screenshot not found: {screenshot_path}")

        except Exception as e:
            # Log for diagnostics, then let the original exception propagate.
            logging.error(f"Error generating messages for screenshot {screenshot_path}: {e}")
            raise


class ActionAgent:
    """Builds chat-style requests asking what effect an action had on the GUI.

    The prompt template is read once, at construction time, from the
    ``["action_agent"]["prompt"]`` entry of a JSON prompts file and is later
    filled in with ``str.format`` using ``instruction`` and ``action``.
    """

    def __init__(self, prompts_file: str = "prompts.json"):
        """Load the action-effect prompt template.

        Args:
            prompts_file: Path to a UTF-8 JSON file containing an
                ``action_agent`` object with a ``prompt`` string.

        Raises:
            OSError: If the file cannot be opened.
            KeyError: If the expected keys are missing.
        """
        with open(prompts_file, 'r', encoding='utf-8') as handle:
            loaded = json.load(handle)
        action_section = loaded["action_agent"]
        self.action_prompt = action_section["prompt"]

    def generate_messages(self,
                         before_screenshot: str,
                         after_screenshot: str,
                         action: str,
                         instruction: str) -> List[Dict[str, Any]]:
        """Return the single-turn message list describing one transition.

        Args:
            before_screenshot: Path of the screenshot taken before the action.
            after_screenshot: Path of the screenshot taken after the action.
            action: Action string substituted into the prompt template.
            instruction: Instruction text substituted into the prompt template.

        Returns:
            A one-element list holding a ``user`` message whose content is
            the formatted prompt followed by the before and after images.

        Raises:
            FileNotFoundError: If either screenshot path does not exist.
        """
        try:
            # The "before" image is checked first, so it is the one reported
            # when both files are missing.
            if not os.path.exists(before_screenshot):
                raise FileNotFoundError(f"Before screenshot not found: {before_screenshot}")
            if not os.path.exists(after_screenshot):
                raise FileNotFoundError(f"After screenshot not found: {after_screenshot}")

            prompt_text = self.action_prompt.format(
                action=action,
                instruction=instruction,
            )
            content_parts = [
                {"type": "text", "text": prompt_text},
                {"type": "image", "image": before_screenshot},
                {"type": "image", "image": after_screenshot},
            ]
            return [{"role": "user", "content": content_parts}]

        except Exception as e:
            # Log for diagnostics, then let the original exception propagate.
            logging.error(f"Error generating messages for action effect: {e}")
            raise


class DecisionAgentStatic:
    """Agent that walks a static dataset episode and predicts one action per screenshot.

    For every step the agent asks the inference backend for (1) a caption of
    the current screenshot, (2) for steps after the first, a summary of what
    the previous ground-truth action changed, and (3) a predicted next action
    given the goal and the accumulated action summaries.
    """

    def __init__(self, prompts_file: str = "prompts.json"):
        """Load the decision prompt and create the helper agents.

        Args:
            prompts_file: Path to a UTF-8 JSON file containing a
                ``decision_agent`` object with a ``prompt`` string. The same
                path is handed to ``CaptionAgent`` and ``ActionAgent``.
        """
        with open(prompts_file, 'r', encoding='utf-8') as f:
            prompts = json.load(f)
        self.decision_prompt = prompts["decision_agent"]["prompt"]

        # Episode metadata. The defaults mirror the fallbacks used by
        # load_dataset_episode so that get_current_state() works (instead of
        # raising AttributeError) before any episode has been loaded.
        self.episode_id: Any = 'unknown'
        self.goal: str = ''
        self.step_instructions: List[Any] = []
        self.ground_truth_actions: List[Dict[str, Any]] = []
        self.screenshot_ids: List[Any] = []
        self.heights: List[Any] = []
        self.widths: List[Any] = []

        # State history (screenshots and actions are filled from the dataset)
        self.screenshot_history: List[str] = []
        self.caption_history: List[str] = []
        self.action_history: List[str] = []
        self.action_summary_history: List[str] = []

        # Current step pointer
        self.current_step: int = 0

        # Predicted actions list
        self.predicted_action_list: List[str] = []

        # Helper agents that build the caption / action-effect requests
        self.caption_agent = CaptionAgent(prompts_file)
        self.action_agent = ActionAgent(prompts_file)

    def load_dataset_episode(self, episode_data: Dict[str, Any], screenshots_dir: str):
        """Load one episode from the static dataset and reset per-run state.

        The keys read from ``episode_data`` are ``episode_id``, ``goal``,
        ``step_instructions``, ``actions``, ``screenshots``, ``heights`` and
        ``widths``; each falls back to a default when absent. Screenshot
        paths are built as ``<screenshots_dir>/<screenshot_id>.jpg``.

        Everything is prepared in local variables first and committed only
        once conversion succeeded, so a malformed episode does not leave the
        agent holding a mixture of old and new episode data.

        Args:
            episode_data: Mapping describing the episode.
            screenshots_dir: Directory containing the episode's ``.jpg`` files.

        Raises:
            Exception: Any error raised while reading or converting the
                episode is logged and re-raised unchanged.
        """
        try:
            # Extract data from episode
            episode_id = episode_data.get('episode_id', 'unknown')
            goal = episode_data.get('goal', '')
            step_instructions = episode_data.get('step_instructions', [])
            ground_truth_actions = episode_data.get('actions', [])
            screenshot_ids = episode_data.get('screenshots', [])
            heights = episode_data.get('heights', [])
            widths = episode_data.get('widths', [])

            # Build screenshot paths
            screenshot_paths = [
                os.path.join(screenshots_dir, f"{screenshot_id}.jpg")
                for screenshot_id in screenshot_ids
            ]

            # Convert ground truth actions to string format
            action_strings = [
                self._convert_action_to_string(action_data)
                for action_data in ground_truth_actions
            ]

            # Commit: nothing below this point is expected to fail.
            self.episode_id = episode_id
            self.goal = goal
            self.step_instructions = step_instructions
            self.ground_truth_actions = ground_truth_actions
            self.screenshot_ids = screenshot_ids
            self.heights = heights
            self.widths = widths
            self.screenshot_history = screenshot_paths
            self.action_history = action_strings

            # Reset state
            self.current_step = 0
            self.caption_history.clear()
            self.action_summary_history.clear()
            self.predicted_action_list.clear()

            logging.info(f"Loaded episode {self.episode_id} with {len(self.screenshot_history)} screenshots and {len(self.action_history)} actions")

        except Exception as e:
            logging.error(f"Error loading dataset episode: {e}")
            raise

    def _convert_action_to_string(self, action_data: Dict[str, Any]) -> str:
        """Render a dataset action dict as an action-call string.

        Recognised ``action_type`` values are ``click``, ``navigate_back``,
        ``type``, ``scroll``, ``long_press`` and ``open_app``; any other
        value is rendered as ``"<action_type>()"``. Missing fields fall back
        to ``0`` for coordinates, ``''`` for text and ``'down'`` for the
        scroll direction. Values are interpolated verbatim (no quoting or
        escaping is applied).
        """
        action_type = action_data.get('action_type', '')

        if action_type == 'click':
            x = action_data.get('x', 0)
            y = action_data.get('y', 0)
            return f"click(start_box='({x},{y})')"

        if action_type == 'navigate_back':
            return "press_back()"

        if action_type == 'type':
            content = action_data.get('content', '')
            return f"type(content='{content}')"

        if action_type == 'scroll':
            direction = action_data.get('direction', 'down')
            # The scroll start point is a fixed constant; it is not read
            # from the action data.
            return f"scroll(start_box='(0.5,0.8)', direction='{direction}')"

        if action_type == 'long_press':
            x = action_data.get('x', 0)
            y = action_data.get('y', 0)
            return f"long_press(start_box='({x},{y})')"

        if action_type == 'open_app':
            app_name = action_data.get('app_name', '')
            return f"open_app(app_name='{app_name}')"

        return f"{action_type}()"

    def get_current_screenshot(self) -> str:
        """Return the screenshot path at the current step pointer.

        Raises:
            IndexError: If the step pointer is past the last screenshot.
        """
        if self.current_step < len(self.screenshot_history):
            return self.screenshot_history[self.current_step]
        raise IndexError(f"Step {self.current_step} exceeds screenshot history length {len(self.screenshot_history)}")

    def get_previous_screenshot(self) -> str:
        """Return the screenshot path one step before the current pointer.

        Raises:
            IndexError: If the pointer is at step 0 or past the last screenshot.
        """
        if 0 < self.current_step < len(self.screenshot_history):
            return self.screenshot_history[self.current_step - 1]
        raise IndexError(f"Cannot get previous screenshot for step {self.current_step}")

    def get_previous_action(self) -> str:
        """Return the ground-truth action string that led to the current step.

        Raises:
            IndexError: If the pointer is at step 0 or no action is recorded
                for the previous step.
        """
        if self.current_step > 0 and self.current_step - 1 < len(self.action_history):
            return self.action_history[self.current_step - 1]
        raise IndexError(f"Cannot get previous action for step {self.current_step}")

    def generate_messages(self, instruction: str, current_screenshot: str) -> List[Dict[str, Any]]:
        """Build the single-turn decision request.

        The decision prompt is filled with ``instruction`` and with one
        ``"Action <n>: <summary>"`` line per stored action summary
        (numbered from 1), followed by the current screenshot reference.

        Args:
            instruction: Instruction text substituted into the prompt.
            current_screenshot: Path of the screenshot to attach.

        Returns:
            A one-element list holding the ``user`` message.
        """
        try:
            # Format action summaries
            action_summaries = "".join(
                f"Action {number}: {summary}\n"
                for number, summary in enumerate(self.action_summary_history, start=1)
            )

            messages = [
                {"role": "user", "content": [
                    {"type": "text", "text": self.decision_prompt.format(
                        action_summaries=action_summaries,
                        instruction=instruction
                    )},
                    {"type": "image", "image": current_screenshot},
                ]},
            ]

            return messages

        except Exception as e:
            logging.error(f"Error generating messages for decision: {e}")
            raise

    def _infer_single(self,
                      vllm_inference,
                      messages: List[Dict[str, Any]],
                      max_tokens: int,
                      temperature: float,
                      fallback: str) -> str:
        """Run one request through ``inference_batch`` and return its stripped text.

        ``fallback`` is returned when the backend yields an empty result.
        """
        result = vllm_inference.inference_batch(
            dataset=[{"messages": messages}],
            batch_size=1,
            max_tokens=max_tokens,
            temperature=temperature,
            top_p=0.9
        )
        return result[0].strip() if result else fallback

    def step(self, vllm_inference) -> str:
        """Execute one step: caption, optional action summary, then decision.

        If anything fails part-way, the histories are rolled back to the
        lengths they had on entry so that the step can be retried without
        leaving duplicate captions or summaries behind.

        Args:
            vllm_inference: Object exposing
                ``inference_batch(dataset=..., batch_size=..., max_tokens=...,
                temperature=..., top_p=...)``; its result is indexed at
                ``[0]`` and ``.strip()`` is called on that element.

        Returns:
            The predicted action string for the current step.

        Raises:
            IndexError: If the episode has already been fully processed.
            Exception: Any other failure is logged and re-raised unchanged.
        """
        # Remember how long each history was so a failure can be undone.
        caption_count = len(self.caption_history)
        summary_count = len(self.action_summary_history)
        prediction_count = len(self.predicted_action_list)

        try:
            # Check if we've reached the end
            if self.current_step >= len(self.screenshot_history):
                raise IndexError(f"Already at the end of episode (step {self.current_step})")

            # 1. Get current screenshot
            current_screenshot = self.get_current_screenshot()

            # 2. Generate caption for current screenshot
            caption_messages = self.caption_agent.generate_messages(current_screenshot)
            current_caption = self._infer_single(
                vllm_inference,
                caption_messages,
                max_tokens=256,
                temperature=0.3,
                fallback="Failed to generate caption",
            )
            self.caption_history.append(current_caption)

            # 3. Generate action summary if not initial state
            if self.current_step > 0:
                action_messages = self.action_agent.generate_messages(
                    before_screenshot=self.get_previous_screenshot(),
                    after_screenshot=current_screenshot,
                    action=self.get_previous_action(),
                    instruction=self.goal
                )
                action_summary = self._infer_single(
                    vllm_inference,
                    action_messages,
                    max_tokens=256,
                    temperature=0.3,
                    fallback="Failed to generate action summary",
                )
                self.action_summary_history.append(action_summary)

            # 4. Generate decision (sees the summary appended just above)
            decision_messages = self.generate_messages(self.goal, current_screenshot)
            predicted_action = self._infer_single(
                vllm_inference,
                decision_messages,
                max_tokens=1024,
                temperature=0.7,
                fallback="wait()",
            )
            self.predicted_action_list.append(predicted_action)

            # 5. Move to next step
            self.current_step += 1

            return predicted_action

        except Exception as e:
            # Undo partial appends so histories stay aligned with current_step.
            del self.caption_history[caption_count:]
            del self.action_summary_history[summary_count:]
            del self.predicted_action_list[prediction_count:]
            logging.error(f"Error in decision step: {e}")
            raise

    def run_episode(self, vllm_inference) -> List[str]:
        """Run ``step`` until every screenshot has been processed.

        Starts from the current step pointer, so only the actions predicted
        during this call are returned.
        """
        try:
            predicted_actions = []

            # Run until we've processed all screenshots
            while self.current_step < len(self.screenshot_history):
                predicted_actions.append(self.step(vllm_inference))

            return predicted_actions

        except Exception as e:
            logging.error(f"Error running episode: {e}")
            raise

    def get_current_state(self) -> Dict[str, Any]:
        """Return a summary of episode identity and progress counters."""
        return {
            'episode_id': self.episode_id,
            'goal': self.goal,
            'current_step': self.current_step,
            'total_steps': len(self.screenshot_history),
            'captions_generated': len(self.caption_history),
            'action_summaries_generated': len(self.action_summary_history),
            'actions_predicted': len(self.predicted_action_list)
        }

    def get_history(self) -> Dict[str, List[str]]:
        """Return shallow copies of all history lists, keyed by name."""
        return {
            'screenshots': self.screenshot_history.copy(),
            'captions': self.caption_history.copy(),
            'ground_truth_actions': self.action_history.copy(),
            'predicted_actions': self.predicted_action_list.copy(),
            'action_summaries': self.action_summary_history.copy()
        }

    def reset(self):
        """Reset the agent's run state (but keep loaded episode data)"""
        self.current_step = 0
        self.caption_history.clear()
        self.action_summary_history.clear()
        self.predicted_action_list.clear()