#!/usr/bin/env python3
"""
Android GUI Control Task Evaluation Script
Optimized version of evaluate.ipynb for comprehensive evaluation
Compatible with test.py output format
"""

import json
import math
import numpy as np
from collections import defaultdict
from typing import Dict, List, Tuple, Any
from tqdm import tqdm
import argparse
import os
from datetime import datetime
import re

from utils import calculate_f1, R1ActionParser

class AndroidControlEvaluator:
    """Comprehensive evaluator for Android GUI control tasks"""
    
    def __init__(self):
        """Create the fallback parser, the action-type tables and empty result stores."""
        def blank_scores():
            # One independent pair of score lists per call.
            return {'opp_acc': [], 'ele_acc': []}

        # Parser used when a response contains no usable mobile_use tool call.
        self.action_parser = R1ActionParser()

        # Ground-truth action type -> action type expected from the parsed prediction.
        self.action_mapping = dict(
            click="click",
            open_app="open_app",
            scroll="scroll",
            long_press="long_press",
            navigate_back="press_back",
            input_text="type",
            wait="wait",
            navigate_home="navigate_home",
            left_click="click",
        )

        # Ground-truth types that are fully scored by the action type alone.
        self.direct_actions = ["navigate_back", "wait", "navigate_home"]

        # Accumulators that grow while episodes are evaluated.
        self.results = {
            'overall': blank_scores(),
            'by_action_type': defaultdict(blank_scores),
            'episode_level': [],
            'detailed_results': [],
        }
    
    def parse_mobile_use_action(self, pred_action: str) -> List[Dict[str, Any]]:
        """Extract a ``mobile_use`` tool call from a model response.

        The first ``<tool_call>{...}</tool_call>`` span is decoded as JSON and
        its ``arguments`` are converted into a one-element list holding a dict
        with ``action_type`` and ``action_inputs``.

        Args:
            pred_action: Raw response text to search.

        Returns:
            ``[action_dict]`` on success. An empty list when there is no tool
            call, the tool is not ``mobile_use``, the system button is neither
            ``Back`` nor ``Home``, a swipe has no usable direction, or anything
            raises while parsing.
        """
        try:
            found = re.search(r'<tool_call>\s*(\{.*?\})\s*</tool_call>', pred_action, re.DOTALL)
            if found is None:
                return []

            payload = json.loads(found.group(1))
            if payload.get('name') != 'mobile_use':
                return []

            call_args = payload.get('arguments', {})
            kind = call_args.get('action', '')

            # A swipe is reported as a scroll whose direction is derived from
            # its two end points; it keeps the raw start point under 'coordinate'.
            if kind == 'swipe':
                heading = self._get_swipe_direction(
                    call_args.get('coordinate', []),
                    call_args.get('coordinate2', []),
                )
                if not heading:
                    return []
                return [{
                    'action_type': 'scroll',
                    'action_inputs': {
                        'direction': heading,
                        'coordinate': call_args.get('coordinate', []),
                    },
                }]

            if kind == 'system_button':
                pressed = call_args.get('button', '')
                if pressed == 'Back':
                    kind = 'press_back'
                elif pressed == 'Home':
                    kind = 'navigate_home'
                else:
                    # Any other system button has no counterpart here.
                    return []
            else:
                # Rename tool-call action names to the names used for scoring.
                for alias, canonical in (('open', 'open_app'),
                                         ('left_click', 'click'),
                                         ('terminate', 'wait')):
                    if kind == alias:
                        kind = canonical
                        break

            inputs = {}
            if kind == 'type':
                inputs['content'] = call_args.get('text', '')
            elif kind == 'open_app':
                inputs['app_name'] = call_args.get('text', '')
            elif kind in ('click', 'long_press', 'scroll'):
                if kind == 'scroll':
                    inputs['direction'] = call_args.get('text', '')
                point = call_args.get('coordinate', [])
                # Only a two-element coordinate is kept as the target point.
                if len(point) == 2:
                    inputs['start_box'] = point

            return [{'action_type': kind, 'action_inputs': inputs}]

        except Exception:
            # Best effort: every parsing problem is reported as "nothing parsed".
            return []
    
    def _get_swipe_direction(self, start_coord: List[int], end_coord: List[int]) -> str:
        """Classify a swipe by its dominant axis.

        Args:
            start_coord: ``[x, y]`` where the swipe begins.
            end_coord: ``[x, y]`` where the swipe ends.

        Returns:
            ``'left'``/``'right'`` when the horizontal displacement is strictly
            larger in magnitude, otherwise ``'up'``/``'down'``. An empty string
            when either argument does not have exactly two elements.
        """
        if not (len(start_coord) == 2 and len(end_coord) == 2):
            return ''

        horizontal = end_coord[0] - start_coord[0]
        vertical = end_coord[1] - start_coord[1]

        if abs(horizontal) > abs(vertical):
            if horizontal > 0:
                return 'right'
            return 'left'

        # Ties and zero-length swipes are classified on the vertical axis.
        # NOTE(review): a positive y delta yields 'up' while a positive x delta
        # yields 'right' — confirm both match the ground-truth scroll convention.
        if vertical > 0:
            return 'up'
        return 'down'
    

    
    def evaluate_action(self, gt_action: Dict[str, Any], pred_action: str, 
                       error_margin: float) -> Tuple[float, float]:
        """
        Evaluate a single action prediction against ground truth
        
        The response is parsed as a mobile_use tool call first; if that yields
        nothing, the R1ActionParser is used instead. Only the first parsed
        action is scored.
        
        Args:
            gt_action: Ground truth action dictionary (must contain 'action_type')
            pred_action: Raw predicted response text
            error_margin: Error margin for coordinate-based actions
            
        Returns:
            Tuple of (opportunity accuracy, element accuracy). (0.0, 0.0) is
            returned for an empty or unparsable prediction, for an action-type
            mismatch, and when any error is raised while scoring.
        """
        try:
            gt_action_type = gt_action['action_type']
            
            if not pred_action:
                return 0.0, 0.0
            
            # Try new mobile_use format first, then fallback to R1ActionParser
            parsed = self.parse_mobile_use_action(pred_action)
            if not parsed:
                parsed, _ = self.action_parser.parse_llm_response(pred_action)
                if not parsed or not parsed[0]:
                    return 0.0, 0.0
            
            predicted = parsed[0]
            
            # Check action type match (unknown ground-truth types raise KeyError
            # and are reported by the handler below)
            if self.action_mapping[gt_action_type] != predicted['action_type']:
                return 0.0, 0.0
            
            # Direct actions (no coordinate/text matching needed)
            if gt_action_type in self.direct_actions:
                return 1.0, 1.0
            
            # Coordinate-based actions. 'left_click' is mapped to 'click' in
            # action_mapping, so it is scored on its coordinates like a click
            # instead of falling through to the unscored default.
            if gt_action_type in ("click", "long_press", "left_click"):
                return self._evaluate_coordinate_action(gt_action, predicted, error_margin)
            
            # Text-based actions
            if gt_action_type == "input_text":
                return self._evaluate_text_action(gt_action, predicted, 'text', 'content')
            
            if gt_action_type == "open_app":
                return self._evaluate_text_action(gt_action, predicted, 'app_name', 'app_name')
            
            # Direction-based actions
            if gt_action_type == "scroll":
                return self._evaluate_direction_action(gt_action, predicted)
            
            # Type matched but there is no rule for scoring its target
            return 1.0, 0.0
                
        except Exception as e:
            # Per-sample boundary: a malformed sample must not abort the run
            print(f"Error evaluating action: {e}")
            return 0.0, 0.0
    
    def _evaluate_coordinate_action(self, gt_action: Dict[str, Any], 
                                   pred_action: Dict[str, Any], 
                                   error_margin: float) -> Tuple[float, float]:
        """Evaluate coordinate-based actions (click, long_press)

        Args:
            gt_action: Ground truth action; its 'x' and 'y' give the target point.
            pred_action: Parsed prediction; 'action_inputs'['start_box'] must
                unpack into the predicted (x, y).
            error_margin: Scale applied to both tolerance ratios.

        Returns:
            (1.0, 1.0) when the prediction is within tolerance, otherwise
            (1.0, 0.0). Missing or malformed coordinates also give (1.0, 0.0):
            the caller has already matched the action type.
        """
        try:
            gt_x, gt_y = gt_action['x'], gt_action['y']
            res_x, res_y = pred_action['action_inputs']['start_box']
            
            distance = math.sqrt((res_x - gt_x)**2 + (res_y - gt_y)**2)
            threshold = 0.14 * error_margin
            individual_threshold = 0.035 * error_margin
            
            # Original distance-based evaluation
            if distance <= threshold:
                return 1.0, 1.0
            
            # Additional condition: if either x or y is close enough individually
            # NOTE(review): one close axis is enough here, so a point far away
            # along the other axis still counts as a hit — confirm intended.
            if abs(res_x - gt_x) <= individual_threshold or abs(res_y - gt_y) <= individual_threshold:
                return 1.0, 1.0
            
            return 1.0, 0.0
        except Exception:
            # Unscorable target; KeyboardInterrupt/SystemExit are not swallowed
            return 1.0, 0.0
    
    def _evaluate_text_action(self, gt_action: Dict[str, Any], 
                             pred_action: Dict[str, Any], 
                             gt_key: str, pred_key: str) -> Tuple[float, float]:
        """Evaluate text-based actions (input_text, open_app)

        Comparison is case-insensitive. A prediction is accepted when the two
        texts are equal, when one non-empty text contains the other, or when
        calculate_f1 returns at least 0.5.

        Args:
            gt_action: Ground truth action holding the expected text under gt_key.
            pred_action: Parsed prediction holding the text under
                'action_inputs'[pred_key].
            gt_key: Key of the expected text in gt_action.
            pred_key: Key of the predicted text in the action inputs.

        Returns:
            (1.0, 1.0) on a match, otherwise (1.0, 0.0). Missing keys or
            non-string values also give (1.0, 0.0).
        """
        try:
            gt_text = gt_action[gt_key].lower()
            res_text = pred_action['action_inputs'][pred_key].lower()
            
            if gt_text == res_text:
                return 1.0, 1.0
            
            # An empty string is a substring of every string, so without this
            # guard an empty prediction would match any expected text.
            if not gt_text or not res_text:
                return 1.0, 0.0
            
            # Check for substring match
            if gt_text in res_text or res_text in gt_text:
                return 1.0, 1.0
            
            # Check F1 score
            f1_score = calculate_f1(res_text, gt_text)
            if f1_score >= 0.5:
                return 1.0, 1.0
            return 1.0, 0.0
        except Exception:
            # Unscorable text; KeyboardInterrupt/SystemExit are not swallowed
            return 1.0, 0.0
    
    def _evaluate_direction_action(self, gt_action: Dict[str, Any], 
                                  pred_action: Dict[str, Any]) -> Tuple[float, float]:
        """Evaluate direction-based actions (scroll)

        Args:
            gt_action: Ground truth action holding the expected 'direction'.
            pred_action: Parsed prediction holding 'action_inputs'['direction'].

        Returns:
            (1.0, 1.0) when both directions are equal ignoring case, otherwise
            (1.0, 0.0). Missing keys or non-string values also give (1.0, 0.0).
        """
        try:
            gt_direction = gt_action['direction'].lower()
            res_direction = pred_action['action_inputs']['direction'].lower()
        except Exception:
            # Unscorable direction; KeyboardInterrupt/SystemExit are not swallowed
            return 1.0, 0.0
        
        if gt_direction == res_direction:
            return 1.0, 1.0
        return 1.0, 0.0
    
    def evaluate_episode(self, gt_episode: Dict[str, Any], 
                        pred_episode: Dict[str, Any]) -> Dict[str, Any]:
        """Evaluate a single episode using test.py format"""
        episode_results = {
            'episode_id': gt_episode.get('episode_id', 'unknown'),
            'goal': gt_episode.get('goal', ''),
            'actions': [],
            'episode_opp_acc': 0.0,
            'episode_ele_acc': 0.0,
            'total_actions': 0,
            'alignment_issue': False
        }
        
        episode_opp_acc = []
        episode_ele_acc = []
        
        # Get ground truth and predicted actions
        gt_actions = gt_episode.get('actions', [])
        pred_actions = pred_episode.get('predicted_actions', [])
        
        # Only evaluate the first len(gt_actions) predicted actions to ensure alignment
        # This means we ignore the last predicted action if there are more predictions than ground truth
        num_actions_to_evaluate = min(len(gt_actions), len(pred_actions))
        
        # Check for alignment issue (but we'll handle it by truncating)
        if len(gt_actions) != len(pred_actions):
            episode_results['alignment_issue'] = True
            print(f"⚠️  Episode {episode_results['episode_id']}: Truncating predictions - "
                  f"{len(gt_actions)} GT actions vs {len(pred_actions)} predicted actions, "
                  f"evaluating first {num_actions_to_evaluate} actions")
        
        # Evaluate each action (only up to the number of ground truth actions)
        for action_idx in range(num_actions_to_evaluate):
            gt_action = gt_actions[action_idx]
            pred_action = pred_actions[action_idx]
            
            # Calculate error margin
            if action_idx < len(gt_episode.get('widths', [])) and action_idx < len(gt_episode.get('heights', [])):
                error_margin = (gt_episode['widths'][action_idx] + gt_episode['heights'][action_idx]) / 2
            else:
                error_margin = 1000  # Default error margin
            
            # Evaluate action
            opp_acc, ele_acc = self.evaluate_action(gt_action, pred_action, error_margin)
            
            episode_opp_acc.append(opp_acc)
            episode_ele_acc.append(ele_acc)
            
            # Parse the predicted action to get the parsed action
            parsed_action = self.parse_mobile_use_action(pred_action)
            if not parsed_action:
                # Fallback to original R1ActionParser
                parsed_action, _ = self.action_parser.parse_llm_response(pred_action)
            
            # Store detailed action result
            action_result = {
                'action_idx': action_idx,
                'gt_action': gt_action,
                'pred_action': pred_action,
                'parsed_action': parsed_action,
                'opp_acc': opp_acc,
                'ele_acc': ele_acc,
                'action_type': gt_action['action_type']
            }
            
            # Add screenshot information if available
            if action_idx < len(gt_episode.get('screenshots', [])):
                screenshot_id = gt_episode['screenshots'][action_idx]
                action_result['screenshot_id'] = screenshot_id
                action_result['screenshot_path'] = f"{screenshot_id}.jpg"
            episode_results['actions'].append(action_result)
            
            # Update overall results
            self.results['overall']['opp_acc'].append(opp_acc)
            self.results['overall']['ele_acc'].append(ele_acc)
            self.results['by_action_type'][gt_action['action_type']]['opp_acc'].append(opp_acc)
            self.results['by_action_type'][gt_action['action_type']]['ele_acc'].append(ele_acc)
        
        # Calculate episode-level metrics
        if episode_opp_acc:
            episode_results['episode_opp_acc'] = np.mean(episode_opp_acc)
            episode_results['episode_ele_acc'] = np.mean(episode_ele_acc)
            episode_results['total_actions'] = len(episode_opp_acc)
        
        self.results['episode_level'].append(episode_results)
        return episode_results
    
    def evaluate_test_results(self, gt_data: List[Dict[str, Any]], 
                             test_results: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Evaluate results from test.py format

        Predictions are matched to ground-truth episodes by 'episode_id'.
        Ground-truth episodes without a matching prediction are reported and
        skipped.

        Args:
            gt_data: Ground truth episodes.
            test_results: Prediction records, each with an 'episode_id'.

        Returns:
            The aggregated metrics from _calculate_final_metrics().

        Raises:
            KeyError: If a prediction record has no 'episode_id'.
        """
        print(f"Evaluating {len(test_results)} episodes from test.py results...")
        
        # Key the map by str(id): the lookup below also uses str(id), so ids
        # stored as integers on either side still match.
        test_results_map = {str(result['episode_id']): result for result in test_results}
        
        for gt_episode in tqdm(gt_data, desc="Evaluating episodes"):
            episode_id = str(gt_episode.get('episode_id', 'unknown'))
            
            test_episode = test_results_map.get(episode_id)
            if test_episode is None:
                print(f"⚠️  Episode {episode_id} not found in test results")
                continue
            self.evaluate_episode(gt_episode, test_episode)
        
        return self._calculate_final_metrics()
    
    def evaluate_dataset(self, gt_data: List[Dict[str, Any]], 
                        pred_data: List[List[Dict[str, Any]]]) -> Dict[str, Any]:
        """Evaluate episodes paired with predictions by position (legacy format).

        Ground-truth episodes beyond the end of ``pred_data`` are skipped.

        NOTE(review): each ``pred_data`` entry is handed to evaluate_episode(),
        which calls ``.get('predicted_actions', [])`` on it; that does not fit
        the ``List[List[...]]`` annotation — confirm the expected entry shape.

        Returns:
            The aggregated metrics from _calculate_final_metrics().
        """
        print(f"Evaluating {len(gt_data)} episodes...")

        for position in tqdm(range(len(gt_data)), desc="Evaluating episodes"):
            if position >= len(pred_data):
                continue
            self.evaluate_episode(gt_data[position], pred_data[position])

        return self._calculate_final_metrics()
    
    def _calculate_final_metrics(self) -> Dict[str, Any]:
        """Calculate final evaluation metrics

        Aggregates the scores collected in self.results into four sections:
        'overall' and 'by_action_type' (mean, std, total, count per accuracy),
        'episode_level' (mean, std, median of the per-episode accuracies) and
        'alignment_statistics'.

        When nothing has been scored, every mean/std/median is 0.0 instead of
        NaN, so the report stays printable and serialises to valid JSON.

        Returns:
            The metrics dictionary described above.
        """
        def summarize(scores, with_counts: bool) -> Dict[str, Any]:
            # Reduce one score list to the statistics reported for it.
            values = np.asarray(scores, dtype=float)
            if values.size:
                mean = float(np.mean(values))
                std = float(np.std(values))
                median = float(np.median(values))
            else:
                # NumPy returns NaN (with a RuntimeWarning) for empty input
                mean = std = median = 0.0
            if with_counts:
                return {
                    'mean': mean,
                    'std': std,
                    'total': int(np.sum(values)),
                    'count': len(values)
                }
            return {'mean': mean, 'std': std, 'median': median}

        def accuracy_pair(opp_scores, ele_scores, with_counts: bool) -> Dict[str, Any]:
            # Both accuracies are always reported side by side.
            return {
                'opportunity_accuracy': summarize(opp_scores, with_counts),
                'element_accuracy': summarize(ele_scores, with_counts)
            }

        episodes = self.results['episode_level']
        overall = self.results['overall']

        metrics = {
            # Overall metrics
            'overall': accuracy_pair(overall['opp_acc'], overall['ele_acc'], True),
            # Action type specific metrics
            'by_action_type': {
                action_type: accuracy_pair(acc_data['opp_acc'], acc_data['ele_acc'], True)
                for action_type, acc_data in self.results['by_action_type'].items()
            },
            # Episode level metrics
            'episode_level': accuracy_pair(
                [ep['episode_opp_acc'] for ep in episodes],
                [ep['episode_ele_acc'] for ep in episodes],
                False
            ),
            'alignment_statistics': {}
        }
        
        # Alignment statistics
        # Since we truncate predictions to match ground truth, all episodes
        # are reported as aligned and the rate is a constant 1.0
        total_episodes = len(episodes)
        metrics['alignment_statistics'] = {
            'aligned_episodes': total_episodes,
            'total_episodes': total_episodes,
            'alignment_rate': 1.0,
            'truncated_predictions': sum(1 for ep in episodes if ep.get('alignment_issue', False))
        }
        
        return metrics
    
    def print_results(self, metrics: Dict[str, Any]):
        """Print evaluation results in a formatted way"""
        print("\n" + "="*60)
        print("ANDROID GUI CONTROL TASK EVALUATION RESULTS")
        print("="*60)
        
        # Overall results
        print(f"\nOVERALL RESULTS:")
        print(f"Opportunity Accuracy: {metrics['overall']['opportunity_accuracy']['mean']:.4f} "
              f"({metrics['overall']['opportunity_accuracy']['total']}/{metrics['overall']['opportunity_accuracy']['count']})")
        print(f"Element Accuracy: {metrics['overall']['element_accuracy']['mean']:.4f} "
              f"({metrics['overall']['element_accuracy']['total']}/{metrics['overall']['element_accuracy']['count']})")
        
        # Action type specific results
        print(f"\nRESULTS BY ACTION TYPE:")
        for action_type, action_metrics in metrics['by_action_type'].items():
            opp_acc = action_metrics['opportunity_accuracy']
            ele_acc = action_metrics['element_accuracy']
            print(f"{action_type:12} - Opp: {opp_acc['mean']:.4f} ({opp_acc['total']:3d}/{opp_acc['count']:3d}) | "
                  f"Ele: {ele_acc['mean']:.4f} ({ele_acc['total']:3d}/{ele_acc['count']:3d})")
        
        # Episode level results
        print(f"\nEPISODE LEVEL RESULTS:")
        print(f"Opportunity Accuracy: {metrics['episode_level']['opportunity_accuracy']['mean']:.4f} ± "
              f"{metrics['episode_level']['opportunity_accuracy']['std']:.4f}")
        print(f"Element Accuracy: {metrics['episode_level']['element_accuracy']['mean']:.4f} ± "
              f"{metrics['episode_level']['element_accuracy']['std']:.4f}")
        
        # Alignment statistics
        print(f"\nALIGNMENT STATISTICS:")
        print(f"Aligned Episodes: {metrics['alignment_statistics']['aligned_episodes']}/{metrics['alignment_statistics']['total_episodes']}")
        print(f"Alignment Rate: {metrics['alignment_statistics']['alignment_rate']:.4f}")
        print(f"Episodes with Truncated Predictions: {metrics['alignment_statistics']['truncated_predictions']}")
        print(f"Note: Predictions are truncated to match ground truth action count for fair evaluation")
        
        print("="*60)
    
    def save_results(self, metrics: Dict[str, Any], output_path: str):
        """Save evaluation results to JSON file"""
        results = {
            'metrics': metrics,
            'detailed_results': self.results['episode_level'],
            'evaluation_timestamp': datetime.now().isoformat(),
            'total_episodes': len(self.results['episode_level']),
            'total_actions': len(self.results['overall']['opp_acc'])
        }
        
        os.makedirs(os.path.dirname(output_path), exist_ok=True)
        with open(output_path, 'w', encoding='utf-8') as f:
            json.dump(results, f, indent=2, ensure_ascii=False)
        
        print(f"\nResults saved to: {output_path}")

def main():
    """Main evaluation function

    Command line:
        --gt_path      Ground truth JSON file (required)
        --pred_path    Prediction results JSON file in test.py format (required)
        --output_path  Where the evaluation report is written
                       (default: eval_results/evaluation_results.json)

    Loads both files, evaluates the predictions, prints the report and saves it.
    """
    parser = argparse.ArgumentParser(description='Evaluate Android GUI Control Task Results')
    parser.add_argument('--gt_path', type=str, required=True,
                       help='Path to ground truth JSON file')
    parser.add_argument('--pred_path', type=str, required=True,
                       help='Path to prediction results JSON file (test.py format)')
    parser.add_argument('--output_path', type=str, default='eval_results/evaluation_results.json',
                       help='Path to save evaluation results')
    
    args = parser.parse_args()
    
    # Load data. The files are read as UTF-8, matching how results are written,
    # so non-ASCII text does not depend on the platform's default encoding.
    print(f"Loading ground truth data from: {args.gt_path}")
    with open(args.gt_path, 'r', encoding='utf-8') as f:
        gt_data = json.load(f)
    
    print(f"Loading prediction data from: {args.pred_path}")
    with open(args.pred_path, 'r', encoding='utf-8') as f:
        pred_data = json.load(f)
    
    # Initialize evaluator
    evaluator = AndroidControlEvaluator()
    
    # Run evaluation (assuming test.py format)
    metrics = evaluator.evaluate_test_results(gt_data, pred_data)
    
    # Print and save results
    evaluator.print_results(metrics)
    evaluator.save_results(metrics, args.output_path)

# Run the command-line evaluation only when executed as a script, not on import.
if __name__ == "__main__":
    main() 