#!/usr/bin/env python3
"""
Simplified Evaluation Error Extraction Tool

Helps evolution agents quickly extract error cases from evaluation files.
Focuses on status field filtering with JSON output.
"""

import json
import random
from pathlib import Path
from typing import Dict, List, Optional, Any, Union
import argparse
import sys


def find_most_recent_iteration(base_path: Optional[Path] = None) -> Optional[str]:
    """
    Find the most recent iteration directory under ``base_path``.

    Iteration directories are named ``iteration_<N>`` where the second
    ``_``-separated component parses as an integer (e.g. ``iteration_031``).
    Non-directories and names whose component is not an integer are ignored.

    Args:
        base_path: Directory to search (normally the research directory).
            Defaults to the current working directory at call time.

    Returns:
        Name of the directory with the highest iteration number
        (e.g. "iteration_031"), or None if none is found.
    """
    # Resolve the default at call time: a ``Path.cwd()`` default argument is
    # evaluated once at import and goes stale if the process changes directory.
    if base_path is None:
        base_path = Path.cwd()

    candidates = []
    for path in base_path.glob("iteration_*"):
        if not path.is_dir():
            continue
        try:
            iter_num = int(path.name.split('_')[1])
        except (IndexError, ValueError):
            continue
        candidates.append((iter_num, path.name))

    if not candidates:
        return None

    # Highest iteration number wins; ties break on the directory name.
    return max(candidates)[1]


def find_evaluation_files(agent: str,
                          iteration: Optional[str] = None,
                          databases: Optional[List[str]] = None,
                          base_path: Optional[Path] = None) -> List[tuple[Path, str]]:
    """
    Find evaluation files for the specified agent and filters.

    Searches ``<base_path>/<iteration>/<agent dir>/<database>/results/evaluation.json``.
    If ``agent`` already starts with ``"agent_"`` it is used verbatim as the
    agent directory pattern; otherwise it is treated as the prefix
    ``agent_<agent>*`` (so ``"foo"`` also matches ``agent_foo_v2``).

    Args:
        agent: Agent name or agent directory name.
        iteration: Iteration directory name; defaults to the most recent one
            under ``base_path``.
        databases: If non-empty, only these database directory names are kept.
        base_path: Research directory to search; defaults to the current
            working directory at call time.

    Returns:
        List of (path, database_name) tuples, sorted by path so the output is
        deterministic. Empty if no iteration directory exists (an error is
        printed to stderr) or nothing matches.
    """
    # Resolve at call time; a ``Path.cwd()`` default would be frozen at import.
    if base_path is None:
        base_path = Path.cwd()

    if not iteration:
        iteration = find_most_recent_iteration(base_path)
        if not iteration:
            print("Error: No iteration directories found", file=sys.stderr)
            return []

    agent_pattern = agent if agent.startswith("agent_") else f"agent_{agent}*"
    pattern = f"{iteration}/{agent_pattern}/*/results/evaluation.json"

    # An empty/None ``databases`` means "no database filter".
    wanted = set(databases) if databases else None

    eval_files = []
    # Path.glob yields filesystem order; sort for reproducible output.
    for eval_file in sorted(base_path.glob(pattern)):
        # .../<database>/results/evaluation.json -> <database>
        db_name = eval_file.parent.parent.name
        if wanted is not None and db_name not in wanted:
            continue
        eval_files.append((eval_file, db_name))

    return eval_files


def load_evaluation(file_path: Path) -> Dict[str, Any]:
    """
    Load an evaluation JSON file.

    Loading is best-effort. If the file cannot be read or parsed, or its
    top-level JSON value is not an object, a warning is printed to stderr and
    an empty dict is returned so the caller can skip the file.

    Args:
        file_path: Path to the evaluation JSON file.

    Returns:
        The parsed top-level JSON object, or {} on failure.
    """
    try:
        # JSON text is UTF-8 (RFC 8259); don't depend on the locale encoding.
        with open(file_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
    except (OSError, ValueError) as e:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        print(f"Warning: Failed to load {file_path}: {e}", file=sys.stderr)
        return {}

    # Callers use dict methods (.get); reject lists/scalars at the boundary.
    if not isinstance(data, dict):
        print(f"Warning: Failed to load {file_path}: expected a JSON object, "
              f"got {type(data).__name__}", file=sys.stderr)
        return {}
    return data


def extract_errors_from_file(evaluation: Dict[str, Any],
                             database: str,
                             status_filter: List[str],
                             max_rows: Optional[int] = 3) -> List[Dict[str, Any]]:
    """
    Extract entries whose status is in ``status_filter`` from one evaluation.

    Only the dict form of ``evaluation['results']`` (question id -> result
    dict) is handled. Any other shape yields no entries, and individual
    result values that are not dicts are skipped.

    A result's status is its ``'status'`` field. When that field is absent but
    a ``'matches'`` field exists, the status is ``'match'`` or ``'mismatch'``
    according to its truthiness. Otherwise the status is ``'unknown'``.

    Args:
        evaluation: Loaded evaluation data.
        database: Database name, copied into every entry for context.
        status_filter: Statuses to include.
        max_rows: Maximum number of predicted / ground-truth result rows kept
            per entry (default 3). ``None`` keeps all rows.

    Returns:
        List of error entries, in the iteration order of ``results``.
    """
    results = evaluation.get('results', {})
    if not isinstance(results, dict):
        return []

    errors = []
    for q_id, result in results.items():
        if not isinstance(result, dict):
            # Malformed entry; nothing meaningful to extract.
            continue

        if 'status' in result:
            status = result['status']
        elif 'matches' in result:
            status = 'match' if result['matches'] else 'mismatch'
        else:
            status = 'unknown'

        if status not in status_filter:
            continue

        # Accept both naming conventions for the question and SQL fields.
        error_entry = {
            'database': database,
            'question_id': q_id,
            'status': status,
            'question': result.get('question', result.get('question_text', '')),
            'evidence': result.get('evidence', ''),
            'predicted_sql': result.get('predicted_sql', result.get('predicted', '')),
            'ground_truth_sql': result.get('ground_truth_sql', result.get('ground_truth', '')),
        }

        # Only include the predicted error message when it is non-empty.
        if result.get('predicted_error'):
            error_entry['predicted_error'] = result['predicted_error']

        # Truncate result rows to keep the report compact. A present-but-empty
        # value is normalized to [].
        pred_results = result.get('predicted_results')
        gt_results = result.get('ground_truth_results')
        if pred_results is not None:
            error_entry['predicted_results'] = pred_results[:max_rows] if pred_results else []
        if gt_results is not None:
            error_entry['ground_truth_results'] = gt_results[:max_rows] if gt_results else []

        errors.append(error_entry)

    return errors


def _split_csv(value: Optional[str]) -> List[str]:
    """Split a comma-separated CLI value into stripped, non-empty items."""
    if not value:
        return []
    return [item.strip() for item in value.split(',') if item.strip()]


def main():
    """
    Command-line entry point.

    Parses arguments and collects entries matching the status filter from
    every matching evaluation file. At most ``--max-samples`` of them are
    randomly sampled, and a JSON report is written to ``--output-json``.

    Returns:
        Process exit code: 0 on success, 1 if no evaluation files were found
        or the output could not be written. Invalid arguments make argparse
        exit with status 2.
    """
    parser = argparse.ArgumentParser(
        description='Extract evaluation errors for evolution agent analysis'
    )
    parser.add_argument('--agent', required=True,
                        help='Agent name to analyze (required)')
    parser.add_argument('--iteration',
                        help='Iteration to analyze (defaults to most recent)')
    parser.add_argument('--database',
                        help='Comma-separated list of databases (defaults to all)')
    parser.add_argument('--status',
                        help='Status to filter: match, mismatch, pred_error, pred_timeout, or comma-separated list '
                             '(defaults to mismatch,pred_error,pred_timeout)')
    parser.add_argument('--max-samples', type=int, default=30,
                        help='Maximum number of samples to return (default: 30)')
    parser.add_argument('--output-json', required=True,
                        help='Output JSON file path (required)')
    parser.add_argument('--seed', type=int, default=None,
                        help='Random seed for reproducible sampling (default: unseeded)')

    args = parser.parse_args()

    # random.sample raises on a negative size; reject it up front.
    if args.max_samples < 0:
        parser.error('--max-samples must be non-negative')

    # Empty items (e.g. "a,,b" or a trailing comma) are dropped. No databases
    # means "all"; no statuses falls back to the default error statuses.
    databases = _split_csv(args.database) or None
    status_filter = _split_csv(args.status) or ['mismatch', 'pred_error', 'pred_timeout']

    eval_files = find_evaluation_files(
        agent=args.agent,
        iteration=args.iteration,
        databases=databases
    )

    if not eval_files:
        print(f"Error: No evaluation files found for agent '{args.agent}'", file=sys.stderr)
        if args.iteration:
            print(f"  Iteration: {args.iteration}", file=sys.stderr)
        if databases:
            print(f"  Databases: {databases}", file=sys.stderr)
        return 1

    all_errors = []
    status_counts: Dict[str, int] = {}
    total_questions = 0
    databases_analyzed = set()

    for eval_file, db_name in eval_files:
        evaluation = load_evaluation(eval_file)
        # Skip unloadable, empty, or non-object files; they were not analyzed.
        if not evaluation or not isinstance(evaluation, dict):
            continue
        databases_analyzed.add(db_name)

        # Only the dict form of 'results' is counted and extracted.
        results = evaluation.get('results', {})
        if isinstance(results, dict):
            total_questions += len(results)

        errors = extract_errors_from_file(evaluation, db_name, status_filter)
        all_errors.extend(errors)

        for error in errors:
            status = error['status']
            status_counts[status] = status_counts.get(status, 0) + 1

    # Random sample if we have too many; --seed makes the sample reproducible.
    sampled_errors = all_errors
    if len(all_errors) > args.max_samples:
        sampled_errors = random.Random(args.seed).sample(all_errors, args.max_samples)

    output = {
        'agent': args.agent,
        'iteration': args.iteration or find_most_recent_iteration(),
        'filters': {
            # Sorted so the report is deterministic (sets are unordered).
            'databases': sorted(databases_analyzed),
            'status': status_filter
        },
        'summary': {
            'databases_analyzed': len(databases_analyzed),
            'total_questions': total_questions,
            'total_errors_found': len(all_errors),
            'errors_returned': len(sampled_errors),
            'breakdown': status_counts
        },
        'errors': sampled_errors
    }

    # Keep the try limited to the write itself. Previously a failure in the
    # success print (e.g. a non-UTF-8 stdout) was misreported as a write error.
    output_path = Path(args.output_json)
    try:
        output_path.parent.mkdir(parents=True, exist_ok=True)
        with open(output_path, 'w', encoding='utf-8') as f:
            json.dump(output, f, indent=2)
    except (OSError, TypeError, ValueError) as e:
        print(f"Error writing output: {e}", file=sys.stderr)
        return 1

    print(f"✅ Extracted {len(sampled_errors)} error samples from {len(eval_files)} files")
    print(f"   Written to: {args.output_json}")
    return 0


if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status.
    raise SystemExit(main())