#!/usr/bin/env python3
"""
Batch processing agent for Android GUI Control Task Evaluation
Supports processing multiple episodes in parallel using vLLM batch inference
"""

import json
import os
import logging
from typing import Dict, List, Any, Tuple
from dataclasses import dataclass
from agents.agents_static import DecisionAgentStatic

@dataclass
class EpisodeState:
    """Mutable per-episode state tracked by ``BatchDecisionAgent``.

    ``current_step`` indexes into ``screenshot_history``; the caption,
    action-summary and prediction lists grow as batch steps are run.
    """
    episode_id: str
    goal: str
    # Filesystem paths of the episode's screenshots, in step order.
    screenshot_history: List[str]
    # Ground-truth actions rendered as action strings (e.g. "press_back()").
    action_history: List[str]
    # Index of the next screenshot to process.
    current_step: int
    caption_history: List[str]
    action_summary_history: List[str]
    predicted_action_list: List[str]
    step_instructions: List[str]
    # Raw action dicts as read from the episode data.
    original_actions: List[Dict[str, Any]]
    heights: List[int]
    widths: List[int]
    # Screenshot identifiers exactly as read from the episode data. The loader
    # treats them as file names (strings), so they are not necessarily ints.
    screenshot_ids: List[Any]

class BatchDecisionAgent:
    """Agent responsible for making decisions on multiple episodes in parallel.

    Episodes are loaded with :meth:`load_episodes_batch`. Each call to
    :meth:`run_batch_step` advances every active episode (one whose
    ``current_step`` is still inside its screenshot list) by one step. A step
    issues up to three batched ``vllm_inference.inference_batch`` calls:
    screenshot captions, summaries of the previous ground-truth action, and
    next-action decisions.
    """

    # Screen size substituted when an episode has no width/height entry for
    # the current step. Subclasses may override these.
    DEFAULT_WIDTH = 1092
    DEFAULT_HEIGHT = 2408

    def __init__(self, prompts_file: str = "prompts.json"):
        """Initialize the batch decision agent with prompts.

        Args:
            prompts_file: Path to a UTF-8 JSON file. Its
                ``["mobile_assistant"]["prompt"]`` entry is the decision
                prompt template. The same path is passed to
                ``DecisionAgentStatic`` to build the caption and action agents.
        """
        with open(prompts_file, 'r', encoding='utf-8') as f:
            prompts = json.load(f)
        self.decision_prompt = prompts["mobile_assistant"]["prompt"]

        # Build the static agent once and reuse both of its sub-agents,
        # instead of constructing it (and re-reading the prompts file) twice.
        static_agent = DecisionAgentStatic(prompts_file)
        self.caption_agent = static_agent.caption_agent
        self.action_agent = static_agent.action_agent

        # Batch processing state, keyed by stringified episode id.
        self.episode_states: Dict[str, EpisodeState] = {}

    @staticmethod
    def _build_screenshot_path(screenshots_dir: str, screenshot_id: Any) -> str:
        """Return the path of one screenshot inside *screenshots_dir*.

        Ids that already end in ``.jpg`` or ``.png`` are used verbatim as file
        names. Any other id (including a non-string one) gets ``.jpg``
        appended.
        """
        name = str(screenshot_id)
        if not name.endswith(('.jpg', '.png')):
            name = f"{name}.jpg"
        return os.path.join(screenshots_dir, name)

    def load_episodes_batch(self, episodes_data: List[Dict[str, Any]], screenshots_dir: str):
        """Load multiple episodes for batch processing.

        Args:
            episodes_data: Episode dicts. The keys read are ``episode_id``,
                ``goal``, ``step_instructions``, ``actions``, ``screenshots``,
                ``heights`` and ``widths``; each is optional.
            screenshots_dir: Directory the screenshot ids are resolved against.

        Raises:
            Any exception raised while building a state is logged and re-raised.
            Episodes loaded before the failure stay in ``episode_states``.
        """
        try:
            for episode_data in episodes_data:
                episode_id = str(episode_data.get('episode_id', 'unknown'))
                if episode_id in self.episode_states:
                    # Missing or repeated ids would otherwise overwrite silently.
                    logging.warning("Episode %s already loaded; replacing its state", episode_id)

                ground_truth_actions = episode_data.get('actions', [])
                screenshot_ids = episode_data.get('screenshots', [])

                self.episode_states[episode_id] = EpisodeState(
                    episode_id=episode_id,
                    goal=episode_data.get('goal', ''),
                    screenshot_history=[
                        self._build_screenshot_path(screenshots_dir, sid) for sid in screenshot_ids
                    ],
                    action_history=[
                        self._convert_action_to_string(action) for action in ground_truth_actions
                    ],
                    current_step=0,
                    caption_history=[],
                    action_summary_history=[],
                    predicted_action_list=[],
                    step_instructions=episode_data.get('step_instructions', []),
                    original_actions=ground_truth_actions,
                    heights=episode_data.get('heights', []),
                    widths=episode_data.get('widths', []),
                    screenshot_ids=screenshot_ids,
                )

            logging.info("Loaded %d episodes for batch processing", len(self.episode_states))

        except Exception as e:
            logging.exception(f"Error loading episodes batch: {e}")
            raise

    @staticmethod
    def _escape_quoted(text: Any) -> str:
        """Escape *text* so it can sit inside a single-quoted action argument."""
        return str(text).replace('\\', '\\\\').replace("'", "\\'").replace('\n', '\\n')

    def _convert_action_to_string(self, action_data: Dict[str, Any]) -> str:
        """Convert one ground-truth action dict to its action-string form.

        Unknown ``action_type`` values become ``"<action_type>()"``. Free-text
        arguments (``content``, ``app_name``) are escaped, so a quote,
        backslash or newline cannot break the quoted argument.
        """
        action_type = action_data.get('action_type', '')

        if action_type == 'click':
            x = action_data.get('x', 0)
            y = action_data.get('y', 0)
            return f"click(start_box='({x},{y})')"

        if action_type == 'navigate_back':
            return "press_back()"

        if action_type == 'type':
            content = self._escape_quoted(action_data.get('content', ''))
            return f"type(content='{content}')"

        if action_type == 'scroll':
            direction = action_data.get('direction', 'down')
            return f"scroll(start_box='(0.5,0.8)', direction='{direction}')"

        if action_type == 'long_press':
            x = action_data.get('x', 0)
            y = action_data.get('y', 0)
            return f"long_press(start_box='({x},{y})')"

        if action_type == 'open_app':
            app_name = self._escape_quoted(action_data.get('app_name', ''))
            return f"open_app(app_name='{app_name}')"

        return f"{action_type}()"

    def _get_active_episodes(self) -> List[str]:
        """Return ids, in load order, of episodes with screenshots left to process."""
        return [
            episode_id
            for episode_id, state in self.episode_states.items()
            if state.current_step < len(state.screenshot_history)
        ]

    @staticmethod
    def _format_history(label: str, entries: List[str]) -> str:
        """Render *entries* as ``"<label> 1: ...\\n<label> 2: ...\\n"``."""
        return "".join(f"{label} {j}: {entry}\n" for j, entry in enumerate(entries, start=1))

    def _infer_and_map(
        self,
        vllm_inference,
        episode_ids: List[str],
        batch_data: List[Dict[str, Any]],
        max_tokens: int,
        temperature: float,
        fallback: str,
        top_p: float = 0.9,
    ) -> Dict[str, str]:
        """Run one batched inference call and map the outputs back to episode ids.

        ``batch_data[i]`` must be the request for ``episode_ids[i]``. An output
        that is missing (a short result list) or ``None`` is replaced with
        *fallback*. Every other output is stripped of surrounding whitespace.
        """
        results = vllm_inference.inference_batch(
            dataset=batch_data,
            batch_size=len(batch_data),
            max_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
        ) or []

        mapped = {}
        for idx, episode_id in enumerate(episode_ids):
            output = results[idx] if idx < len(results) else None
            mapped[episode_id] = output.strip() if output is not None else fallback
        return mapped

    def _batch_generate_captions(self, vllm_inference, active_episodes: List[str]) -> Dict[str, str]:
        """Caption the current screenshot of every episode in *active_episodes*."""
        if not active_episodes:
            return {}

        batch_data = []
        for episode_id in active_episodes:
            state = self.episode_states[episode_id]
            current_screenshot = state.screenshot_history[state.current_step]
            batch_data.append({"messages": self.caption_agent.generate_messages(current_screenshot)})

        return self._infer_and_map(
            vllm_inference, active_episodes, batch_data,
            max_tokens=256, temperature=0.3, fallback="Failed to generate caption",
        )

    def _batch_generate_action_summaries(self, vllm_inference, active_episodes: List[str]) -> Dict[str, str]:
        """Summarize each episode's previous ground-truth action in one batch.

        Episodes on their first step have no previous action and are skipped.
        So are episodes whose ground-truth action list has no entry for the
        previous step. Only the episodes that were summarized appear in the
        result.
        """
        if not active_episodes:
            return {}

        summarized_ids = []
        batch_data = []
        for episode_id in active_episodes:
            state = self.episode_states[episode_id]
            step = state.current_step
            if step <= 0:
                continue  # nothing happened before the first screenshot
            if step - 1 >= len(state.action_history):
                logging.warning(
                    "Episode %s has no ground-truth action for step %d; skipping action summary",
                    episode_id, step - 1,
                )
                continue

            action_messages = self.action_agent.generate_messages(
                before_screenshot=state.screenshot_history[step - 1],
                after_screenshot=state.screenshot_history[step],
                action=state.action_history[step - 1],
                instruction=state.goal,
            )
            summarized_ids.append(episode_id)
            batch_data.append({"messages": action_messages})

        if not summarized_ids:
            return {}

        return self._infer_and_map(
            vllm_inference, summarized_ids, batch_data,
            max_tokens=256, temperature=0.3, fallback="Failed to generate action summary",
        )

    def _batch_generate_decisions(self, vllm_inference, active_episodes: List[str]) -> Dict[str, str]:
        """Generate the next action for every episode in *active_episodes*.

        Each prompt is ``self.decision_prompt`` formatted with the accumulated
        action summaries and captions, the episode goal and the current screen
        size. The current screenshot is attached as an image. A missing output
        falls back to ``"wait()"``.
        """
        if not active_episodes:
            return {}

        batch_data = []
        for episode_id in active_episodes:
            state = self.episode_states[episode_id]
            step = state.current_step
            width = state.widths[step] if step < len(state.widths) else self.DEFAULT_WIDTH
            height = state.heights[step] if step < len(state.heights) else self.DEFAULT_HEIGHT

            prompt_text = self.decision_prompt.format(
                action_summaries=self._format_history("Action", state.action_summary_history),
                caption_summaries=self._format_history("Caption", state.caption_history),
                instruction=state.goal,
                width=width,
                height=height,
            )
            batch_data.append({"messages": [
                {"role": "user", "content": [
                    {"type": "text", "text": prompt_text},
                    {"type": "image", "image": state.screenshot_history[step]},
                ]},
            ]})

        return self._infer_and_map(
            vllm_inference, active_episodes, batch_data,
            max_tokens=1024, temperature=0.7, fallback="wait()",
        )

    def run_batch_step(self, vllm_inference) -> Dict[str, str]:
        """Advance every active episode by one step.

        Returns:
            A mapping from episode id to the predicted action for this step.
            It is empty when every episode has finished.
        """
        try:
            active_episodes = self._get_active_episodes()
            if not active_episodes:
                return {}  # All episodes completed

            captions = self._batch_generate_captions(vllm_inference, active_episodes)
            action_summaries = self._batch_generate_action_summaries(vllm_inference, active_episodes)
            # NOTE(review): the caption and summary produced above are appended to
            # the histories only after the decisions are generated. The decision at
            # step t therefore sees summaries of actions 0..t-2, not t-1.
            # Confirm this lag is intended.
            decisions = self._batch_generate_decisions(vllm_inference, active_episodes)

            results = {}
            for episode_id in active_episodes:
                state = self.episode_states[episode_id]
                if episode_id in captions:
                    state.caption_history.append(captions[episode_id])
                if episode_id in action_summaries:
                    state.action_summary_history.append(action_summaries[episode_id])
                if episode_id in decisions:
                    predicted_action = decisions[episode_id]
                    state.predicted_action_list.append(predicted_action)
                    results[episode_id] = predicted_action
                state.current_step += 1

            return results

        except Exception as e:
            logging.exception(f"Error in batch step: {e}")
            raise

    def run_batch_episodes(self, vllm_inference) -> Dict[str, List[str]]:
        """Run batch steps until no episode is active.

        Returns:
            The predicted actions of each episode, in step order.
        """
        try:
            all_results: Dict[str, List[str]] = {}
            while True:
                step_results = self.run_batch_step(vllm_inference)
                if not step_results:  # All episodes completed
                    break
                for episode_id, action in step_results.items():
                    all_results.setdefault(episode_id, []).append(action)
            return all_results

        except Exception as e:
            logging.exception(f"Error running batch episodes: {e}")
            raise

    def get_episode_history(self, episode_id: str) -> Dict[str, List[str]]:
        """Return copies of the recorded histories of one episode.

        Raises:
            KeyError: if *episode_id* was never loaded.
        """
        if episode_id not in self.episode_states:
            raise KeyError(f"Episode {episode_id} not found")

        state = self.episode_states[episode_id]
        return {
            'screenshots': state.screenshot_history.copy(),
            'captions': state.caption_history.copy(),
            'ground_truth_actions': state.action_history.copy(),
            'predicted_actions': state.predicted_action_list.copy(),
            'action_summaries': state.action_summary_history.copy(),
        }

    def get_all_episode_histories(self) -> Dict[str, Dict[str, List[str]]]:
        """Return :meth:`get_episode_history` for every loaded episode."""
        return {episode_id: self.get_episode_history(episode_id) for episode_id in self.episode_states}

    def get_batch_statistics(self) -> Dict[str, Any]:
        """Return episode counts and ground-truth and predicted action totals."""
        states = list(self.episode_states.values())
        total_episodes = len(states)
        active_episodes = len(self._get_active_episodes())
        total_actions = sum(len(state.action_history) for state in states)
        total_predictions = sum(len(state.predicted_action_list) for state in states)

        return {
            'total_episodes': total_episodes,
            'active_episodes': active_episodes,
            'completed_episodes': total_episodes - active_episodes,
            'total_ground_truth_actions': total_actions,
            'total_predicted_actions': total_predictions,
            'average_actions_per_episode': total_actions / total_episodes if total_episodes > 0 else 0,
        }