#!/usr/bin/env python3
"""
Android GUI Control Task Evaluation Script
Converted from Jupyter notebook to standalone Python script
Processes the selected dataset (capped at a per-dataset sample limit) in batches and saves results to JSON
"""

import json
import os
import sys
import time
import argparse
from datetime import datetime

# Per-dataset settings: dataset name -> (data_path, imgs_dir, max_samples).
# max_samples caps how many episodes are evaluated (falsy = no cap).
# NOTE(review): android_control has an empty data_path, so selecting it fails
# with "Data file not found" until a path is filled in; the empty imgs_dir
# values are passed through to the agent unchanged — confirm they are intended.
_DATASET_CONFIG = {
    'android_control': ('', "", 100),
    'mind2web': ('data/mind2web_website.json', "", 300),
    'gui_odyssey': ('data/gui_odyssey.json', '', 100),
    'aitw': ('data/aitw.json', '', 300),
}


def _load_episodes(data_path, max_samples=None):
    """Load the episode list from a JSON file, keeping at most *max_samples* entries.

    The file is read as UTF-8 so loading does not depend on the machine's
    locale (results are also written as UTF-8). The top-level JSON value is
    expected to be a list of episode dicts.

    Raises:
        FileNotFoundError: if *data_path* does not exist (including an empty path).
        json.JSONDecodeError: if the file is not valid JSON.
    """
    with open(data_path, 'r', encoding='utf-8') as f:
        data = json.load(f)
    if max_samples:
        data = data[:max_samples]
    return data


def _make_error_result(episode_data, dataset, error):
    """Build the result entry for an episode that produced no usable history.

    Args:
        episode_data: The raw episode dict from the dataset file.
        dataset: Name of the dataset being evaluated.
        error: Human-readable reason the episode has no results.

    Returns:
        A dict carrying an ``'error'`` key (which the summary uses to count
        failures) and empty action/caption lists.
    """
    return {
        'episode_id': str(episode_data.get('episode_id', 'unknown')),
        'dataset': dataset,
        'goal': episode_data.get('goal', ''),
        'error': error,
        'screenshots': episode_data.get('screenshots', []),
        'ground_truth_actions': [],
        'predicted_actions': [],
        'captions': [],
        'action_summaries': [],
    }


def _make_success_result(episode_data, dataset, history):
    """Merge an agent episode history with the original episode metadata.

    Args:
        episode_data: The raw episode dict from the dataset file.
        dataset: Name of the dataset being evaluated.
        history: The agent's history for this episode; must provide the keys
            'screenshots', 'captions', 'ground_truth_actions',
            'predicted_actions' and 'action_summaries' (KeyError otherwise).
    """
    return {
        'episode_id': str(episode_data.get('episode_id', 'unknown')),
        'dataset': dataset,
        'goal': episode_data.get('goal', ''),
        'screenshots': history['screenshots'],
        'captions': history['captions'],
        'ground_truth_actions': history['ground_truth_actions'],
        'predicted_actions': history['predicted_actions'],
        'action_summaries': history['action_summaries'],
        'step_instructions': episode_data.get('step_instructions', []),
        'original_actions': episode_data.get('actions', []),
        'heights': episode_data.get('heights', []),
        'widths': episode_data.get('widths', []),
        'screenshot_ids': episode_data.get('screenshots', []),
    }


def _collect_batch_results(batch_data, dataset, episode_histories):
    """Build one result entry per episode in *batch_data*, in input order.

    Episodes whose (stringified) id is missing from *episode_histories* get an
    error entry instead of a success entry.
    """
    results = []
    for episode_data in batch_data:
        episode_id = str(episode_data.get('episode_id', 'unknown'))
        if episode_id in episode_histories:
            results.append(
                _make_success_result(episode_data, dataset, episode_histories[episode_id])
            )
            print(f"✅ Episode {episode_id} completed successfully")
        else:
            print(f"❌ Episode {episode_id} not found in results")
            results.append(
                _make_error_result(episode_data, dataset, 'Episode not found in batch results')
            )
    return results


def _save_results(results, output_filename):
    """Write *results* as pretty-printed UTF-8 JSON, creating the parent directory if needed."""
    output_dir = os.path.dirname(output_filename)
    if output_dir:
        # Without this, a missing eval_results/ directory makes the final
        # write fail after the entire (expensive) evaluation has run.
        os.makedirs(output_dir, exist_ok=True)
    with open(output_filename, 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=2, ensure_ascii=False)


def main():
    """Evaluate the batch decision agent on one dataset and save results to JSON.

    Command-line options:
        --dataset: one of the keys of ``_DATASET_CONFIG`` (default: aitw).

    Exits with status 1 if the model/agent modules cannot be imported, the
    data file cannot be read or parsed, the model or agent fails to
    initialize, or the results file cannot be written. A failure inside a
    batch is recorded as one error entry per episode of that batch rather
    than aborting the run.
    """
    # Parse command line arguments
    parser = argparse.ArgumentParser(description='Android GUI Control Task Evaluation Script')
    parser.add_argument('--dataset', type=str, default='aitw',
                        choices=list(_DATASET_CONFIG),
                        help='Dataset to evaluate (default: aitw)')
    args = parser.parse_args()

    # Configuration
    BATCH_SIZE = 32  # Number of episodes to process in parallel

    # Set up environment
    print("Setting up environment...")
    # Set before importing the model modules below, presumably so the GPU
    # runtime picks it up at initialization. Four devices matches
    # tensor_parallel_size=4 used for VLLMInference.
    os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"

    # Import required modules
    try:
        from utils import VLLMInference
        from agents.agents_batch import BatchDecisionAgent
    except ImportError as e:
        print(f"Error importing required modules: {e}")
        print("Please ensure all dependencies are installed and modules are available.")
        sys.exit(1)

    # Load data
    print("Loading evaluation data...")
    data_path, imgs_dir, max_samples = _DATASET_CONFIG[args.dataset]
    try:
        data = _load_episodes(data_path, max_samples)
    except FileNotFoundError:
        print(f"Error: Data file not found at {data_path}")
        sys.exit(1)
    except json.JSONDecodeError:
        print(f"Error: Invalid JSON format in {data_path}")
        sys.exit(1)
    print(f"Successfully loaded {args.dataset} data with {len(data)} episodes")

    # Initialize model
    print("Initializing VLLM model...")
    model_path = ""  # NOTE(review): empty model path — must be filled in before running

    try:
        vllm_inference = VLLMInference(
            model_path=model_path,
            tensor_parallel_size=4,
            gpu_memory_utilization=0.8
        )
        print("Model initialized successfully")
    except Exception as e:
        print(f"Error initializing model: {e}")
        sys.exit(1)

    # Initialize batch agent
    print("Initializing batch decision agent...")
    try:
        batch_agent = BatchDecisionAgent()
        print("Batch agent initialized successfully")
    except Exception as e:
        print(f"Error initializing batch agent: {e}")
        sys.exit(1)

    # Process episodes in batches
    print(f"\n=== Processing {len(data)} Episodes in Batches of {BATCH_SIZE} ===")
    all_results = []
    num_batches = (len(data) + BATCH_SIZE - 1) // BATCH_SIZE

    # Start overall timing
    overall_start_time = time.time()
    batch_times = []

    for batch_start in range(0, len(data), BATCH_SIZE):
        batch_end = min(batch_start + BATCH_SIZE, len(data))
        batch_data = data[batch_start:batch_end]

        print(f"\nProcessing batch {batch_start // BATCH_SIZE + 1}/{num_batches}")
        print(f"Episodes {batch_start + 1}-{batch_end} (IDs: {[ep.get('episode_id', 'unknown') for ep in batch_data]})")

        batch_start_time = time.time()

        try:
            batch_agent.load_episodes_batch(batch_data, imgs_dir)
            batch_agent.run_batch_episodes(vllm_inference)
            episode_histories = batch_agent.get_all_episode_histories()
            # Collect into a local list first. If anything raises partway
            # through, the except branch records exactly one error entry per
            # episode, and no partial successes are left behind as duplicates.
            batch_results = _collect_batch_results(batch_data, args.dataset, episode_histories)
        except Exception as e:
            print(f"❌ Error processing batch: {e}")

            # Calculate and report batch timing even for failed batches
            batch_duration = time.time() - batch_start_time
            batch_times.append(batch_duration)
            print(f"⏱️  Failed batch execution time: {batch_duration:.2f} seconds ({batch_duration/60:.2f} minutes)")

            # Add error entries for all episodes in this batch
            all_results.extend(
                _make_error_result(episode_data, args.dataset, str(e))
                for episode_data in batch_data
            )
            continue

        all_results.extend(batch_results)

        # Batch statistics are informational only; a failure here must not
        # discard (or duplicate) the results already collected for this batch.
        try:
            batch_stats = batch_agent.get_batch_statistics()
            print(f"Batch completed: {batch_stats['completed_episodes']}/{batch_stats['total_episodes']} episodes")
            print(f"Total actions: {batch_stats['total_ground_truth_actions']} GT, {batch_stats['total_predicted_actions']} Pred")
        except Exception as e:
            print(f"⚠️  Could not retrieve batch statistics: {e}")

        # Calculate and report batch timing
        batch_duration = time.time() - batch_start_time
        batch_times.append(batch_duration)
        print(f"⏱️  Batch execution time: {batch_duration:.2f} seconds ({batch_duration/60:.2f} minutes)")
        print(f"⏱️  Average time per episode: {batch_duration/len(batch_data):.2f} seconds")

    # Calculate overall timing statistics
    total_duration = time.time() - overall_start_time

    print("\n=== Timing Summary ===")
    print(f"⏱️  Total execution time: {total_duration:.2f} seconds ({total_duration/60:.2f} minutes)")
    print(f"⏱️  Total execution time: {total_duration/3600:.2f} hours")

    if batch_times:
        avg_batch_time = sum(batch_times) / len(batch_times)
        min_batch_time = min(batch_times)
        max_batch_time = max(batch_times)
        print(f"⏱️  Average batch time: {avg_batch_time:.2f} seconds ({avg_batch_time/60:.2f} minutes)")
        print(f"⏱️  Fastest batch: {min_batch_time:.2f} seconds ({min_batch_time/60:.2f} minutes)")
        print(f"⏱️  Slowest batch: {max_batch_time:.2f} seconds ({max_batch_time/60:.2f} minutes)")
        print(f"⏱️  Total batches processed: {len(batch_times)}")

    if all_results:
        avg_episode_time = total_duration / len(all_results)
        print(f"⏱️  Average time per episode: {avg_episode_time:.2f} seconds")

    # Save results to JSON file
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    output_filename = f"eval_results/batch_evaluation_results_{args.dataset}_{timestamp}.json"

    print("\n=== Saving Results ===")
    print(f"Saving {len(all_results)} episode results to {output_filename}")

    try:
        _save_results(all_results, output_filename)
        print(f"✅ Results saved successfully to {output_filename}")
    except Exception as e:
        print(f"❌ Error saving results: {e}")
        sys.exit(1)

    # Print summary statistics
    print("\n=== Summary Statistics ===")
    successful_episodes = [r for r in all_results if 'error' not in r]
    failed_episodes = [r for r in all_results if 'error' in r]

    print(f"Total episodes processed: {len(all_results)}")
    print(f"Successful episodes: {len(successful_episodes)}")
    print(f"Failed episodes: {len(failed_episodes)}")
    print(f"Batch size used: {BATCH_SIZE}")

    if successful_episodes:
        total_gt_actions = sum(len(r['ground_truth_actions']) for r in successful_episodes)
        total_pred_actions = sum(len(r['predicted_actions']) for r in successful_episodes)
        print(f"Total ground truth actions: {total_gt_actions}")
        print(f"Total predicted actions: {total_pred_actions}")
        print(f"Average actions per episode: {total_gt_actions/len(successful_episodes):.2f}")

        # An episode is "aligned" when it has as many predicted as ground-truth actions
        aligned_episodes = sum(
            1 for r in successful_episodes
            if len(r['ground_truth_actions']) == len(r['predicted_actions'])
        )
        print(f"Episodes with aligned actions: {aligned_episodes}/{len(successful_episodes)}")

    print("\n=== Evaluation Complete ===")
    print(f"Results saved to: {output_filename}")

# Run the evaluation only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
    
    
# Example usage: python run.py --dataset gui_odyssey