#!/usr/bin/env python3
"""
Analysis Script for Evaluation Results
Generates comprehensive analysis reports from evaluation results
"""

import json
import os
import argparse
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Any


def load_evaluation(eval_file: str) -> Dict:
    """Load evaluation results from a JSON file.

    Args:
        eval_file: Path to the evaluation JSON file.

    Returns:
        The parsed JSON document.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    # JSON is UTF-8 by specification; do not rely on the platform's locale
    # encoding, which would mis-decode non-ASCII SQL or result values.
    with open(eval_file, 'r', encoding='utf-8') as f:
        return json.load(f)


def load_dev_data(dev_data_file: str) -> Dict[str, Dict]:
    """Load the dev set and index it by question id.

    Args:
        dev_data_file: Path to a JSON file containing a list of items, each
            with a ``question_id`` entry.

    Returns:
        Mapping from ``str(question_id)`` to the item. Ids are stringified so
        they can be compared with the string keys of the evaluation results.
        If an id occurs more than once, the last item wins.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
        KeyError: If an item has no ``question_id``.
    """
    # JSON is UTF-8 by specification; do not rely on the platform's locale
    # encoding, which would mis-decode non-ASCII question text.
    with open(dev_data_file, 'r', encoding='utf-8') as f:
        dev_data = json.load(f)
    return {str(item['question_id']): item for item in dev_data}


def analyze_by_database(results: Dict[str, Dict]) -> List[Dict]:
    """Summarise evaluation outcomes per database.

    Results without a truthy ``database`` entry are ignored. For every other
    result the owning database's counters are updated and, when the result
    did not match, a failure record is stored under ``failed_questions``.

    Args:
        results: Mapping of question id to its evaluation result.

    Returns:
        One stats dict per database, ordered by accuracy from best to worst
        (ties keep first-seen order). Each dict carries raw counts, the
        derived rates, the two aggregate issue counters and a
        ``performance`` label.
    """
    # (status, counter to bump, result field holding the error text)
    error_routes = (
        ("pred_error", "pred_errors", "predicted_error"),
        ("pred_timeout", "pred_timeouts", "predicted_error"),
        ("gt_error", "gt_errors", "ground_truth_error"),
        ("gt_timeout", "gt_timeouts", "ground_truth_error"),
    )
    # Legacy/fallback categories, all booked as prediction errors.
    legacy_statuses = ("processing_error", "db_not_found", "both_error")
    # (derived rate key, counter it is computed from), in output order
    rate_fields = (
        ("accuracy", "correct"),
        ("mismatch_rate", "mismatches"),
        ("pred_error_rate", "pred_errors"),
        ("pred_timeout_rate", "pred_timeouts"),
        ("gt_error_rate", "gt_errors"),
        ("gt_timeout_rate", "gt_timeouts"),
    )
    # (minimum accuracy, label), checked from the highest band down
    performance_bands = (
        (0.8, "excellent"),
        (0.7, "very_good"),
        (0.6, "good"),
        (0.5, "fair"),
    )

    per_db = {}
    for qid, outcome in results.items():
        name = outcome.get("database")
        if not name:
            continue

        entry = per_db.get(name)
        if entry is None:
            entry = per_db[name] = {
                "database": name,
                "total": 0,
                "correct": 0,
                "mismatches": 0,
                "pred_errors": 0,
                "pred_timeouts": 0,
                "gt_errors": 0,
                "gt_timeouts": 0,
                "failed_questions": [],
            }

        entry["total"] += 1
        if outcome["matches"]:
            entry["correct"] += 1
            continue

        # Everything below records the details of a failed question.
        status = outcome["status"]
        failure = {
            "question_id": qid,
            "status": status,
            "predicted_sql": outcome["predicted_sql"],
            "ground_truth_sql": outcome["ground_truth_sql"],
        }

        if status == "mismatch":
            entry["mismatches"] += 1
            failure["predicted_results"] = outcome["predicted_results"]
            failure["ground_truth_results"] = outcome["ground_truth_results"]
        else:
            route = next((r for r in error_routes if r[0] == status), None)
            if route is not None:
                _, counter_key, error_field = route
                entry[counter_key] += 1
                failure["error"] = outcome[error_field]
            elif status in legacy_statuses:
                entry["pred_errors"] += 1
                failure["error"] = outcome.get("predicted_error", "Unknown error")
            # Any other status is recorded as a failure without a counter.

        entry["failed_questions"].append(failure)

    # Derive rates, aggregate counters and the performance label.
    ranked = []
    for entry in per_db.values():
        total = entry["total"]
        for rate_key, count_key in rate_fields:
            entry[rate_key] = entry[count_key] / total if total > 0 else 0

        entry["data_quality_issues"] = entry["gt_errors"] + entry["gt_timeouts"]
        entry["prediction_issues"] = entry["pred_errors"] + entry["pred_timeouts"]

        entry["performance"] = next(
            (label for floor, label in performance_bands if entry["accuracy"] >= floor),
            "challenging",
        )
        ranked.append(entry)

    # Best accuracy first; the sort is stable, so ties keep insertion order.
    ranked.sort(key=lambda item: item["accuracy"], reverse=True)
    return ranked


def analyze_by_difficulty(results: Dict[str, Dict], dev_data: Dict[str, Dict]) -> Dict:
    """Break accuracy down by question difficulty.

    Only results whose question id is present in ``dev_data`` are counted.
    The difficulty label is lower-cased; a missing or unrecognised label is
    counted as ``moderate``.

    Args:
        results: Mapping of question id to its evaluation result.
        dev_data: Mapping of question id to the dev-set item.

    Returns:
        A dict keyed by ``simple``, ``moderate`` and ``challenging``; each
        value holds ``total``, ``correct``, ``sample_size``, ``accuracy`` and
        a one-line ``analysis`` string.
    """
    levels = ("simple", "moderate", "challenging")
    buckets = {level: {"total": 0, "correct": 0, "sample_size": 0} for level in levels}

    for qid, outcome in results.items():
        if qid in dev_data:
            label = dev_data[qid].get("difficulty", "moderate").lower()
            # Unknown labels fall back to the moderate bucket.
            bucket = buckets.get(label, buckets["moderate"])

            bucket["total"] += 1
            bucket["sample_size"] += 1
            if outcome["matches"]:
                bucket["correct"] += 1

    # Attach accuracy and a short verdict to every bucket.
    for level, bucket in buckets.items():
        if not bucket["total"] > 0:
            bucket["accuracy"] = 0
            bucket["analysis"] = f"No {level} questions in sample"
            continue

        ratio = bucket["correct"] / bucket["total"]
        bucket["accuracy"] = ratio
        if ratio >= 0.7:
            verdict = f"Strong performance on {level} questions"
        elif ratio >= 0.6:
            verdict = f"Reasonable performance on {level} complexity"
        else:
            verdict = f"{level.capitalize()} questions remain challenging"
        bucket["analysis"] = verdict

    return buckets


def identify_error_patterns(results: Dict[str, Dict]) -> List[Dict]:
    """Identify common error patterns among mismatched results.

    Only results with status ``mismatch`` and non-empty predicted and ground
    truth result sets are inspected. Each is tested against two heuristics,
    of which at most one applies:

    * Case sensitivity: the string forms of the two result sets differ but
      are equal once lower-cased.
    * Column selection: the first rows of the two result sets have different
      lengths.

    Args:
        results: Mapping of question id to its evaluation result. Result
            sets are assumed to be sequences of sized rows (e.g. a list of
            lists).

    Returns:
        One dict per pattern that occurred at least once, case sensitivity
        first. Each carries the occurrence count, a coarse frequency label
        and up to three examples.
    """
    max_examples = 3

    case_sensitivity_count = 0
    column_selection_count = 0
    case_examples = []
    column_examples = []

    for question_id, result in results.items():
        if result.get("status") != "mismatch":
            continue

        # Records lacking either result set cannot be compared; skip them
        # instead of failing the whole analysis.
        pred_results = result.get("predicted_results")
        gt_results = result.get("ground_truth_results")
        if not pred_results or not gt_results:
            continue

        pred_str = str(pred_results)
        gt_str = str(gt_results)

        # Case sensitivity: same content, different case.
        if pred_str != gt_str and pred_str.lower() == gt_str.lower():
            case_sensitivity_count += 1
            if len(case_examples) < max_examples:
                case_examples.append({
                    "question_id": question_id,
                    "predicted": pred_str[:100],
                    "ground_truth": gt_str[:100]
                })
            continue

        # Column count difference suggests selection issues. Both result
        # sets are non-empty here, so the first row always exists.
        pred_cols = len(pred_results[0])
        gt_cols = len(gt_results[0])
        if pred_cols != gt_cols:
            column_selection_count += 1
            if len(column_examples) < max_examples:
                column_examples.append({
                    "question_id": question_id,
                    "predicted_columns": pred_cols,
                    "ground_truth_columns": gt_cols,
                    # The database name is optional on a result record.
                    "database": result.get("database")
                })

    # Report only the patterns that actually occurred.
    patterns = []
    if case_sensitivity_count > 0:
        patterns.append({
            "pattern": "Case sensitivity issues",
            "frequency": "high" if case_sensitivity_count > 5 else "medium",
            "count": case_sensitivity_count,
            "impact": "medium",
            "fix_complexity": "low",
            "examples": case_examples,
            "description": "Predicted results match ground truth when case-normalized"
        })

    if column_selection_count > 0:
        patterns.append({
            "pattern": "Column selection differences",
            "frequency": "high" if column_selection_count > 10 else "medium",
            "count": column_selection_count,
            "impact": "medium",
            "fix_complexity": "medium",
            "examples": column_examples,
            "description": "Different number of columns in result set"
        })

    return patterns


def generate_executive_summary(analysis_data: Dict) -> str:
    """Render the executive summary as a Markdown document.

    Args:
        analysis_data: Dict with the keys ``metadata`` (the evaluation
            metadata), ``database_analysis`` (per-database stats, best
            first), ``difficulty_analysis`` (stats for ``simple``,
            ``moderate`` and ``challenging``) and ``error_patterns``.

    Returns:
        The Markdown text; every line is newline-terminated.

    Raises:
        KeyError: If a required metadata entry (``total_questions``,
            ``accuracy``, ``error_counts``, ``cache_stats``, ``timestamp``)
            is missing.
    """
    metadata = analysis_data["metadata"]
    db_analysis = analysis_data["database_analysis"]
    difficulty_analysis = analysis_data["difficulty_analysis"]
    error_patterns = analysis_data["error_patterns"]
    error_counts = metadata["error_counts"]
    cache_stats = metadata["cache_stats"]

    def count(status: str) -> int:
        # A category that never occurred may be absent from the counts;
        # report it as zero instead of failing the whole report.
        return error_counts.get(status, 0)

    # Get top performing databases
    top_dbs = db_analysis[:3]
    challenging_dbs = [db for db in db_analysis if db["performance"] == "challenging"]

    # The document is assembled line by line so that significant trailing
    # whitespace (two spaces = Markdown hard line break) stays visible.
    lines = [
        "# Evaluation Results: Executive Summary",
        "",
        f"**Date:** {datetime.now().strftime('%B %d, %Y')}  ",
        f"**Dataset:** BIRD Dev Set ({metadata['total_questions']} questions)  ",
        f"**Overall Accuracy:** {metadata['accuracy']:.1%} ({count('match')} correct)",
        "",
        "## 🎯 Key Results",
        "",
        "### Performance Breakdown",
        f"- **Correct Answers:** {count('match')}",
        # Trailing space kept so the output stays byte-compatible.
        f"- **Mismatches:** {count('mismatch')} ",
        f"- **Prediction Errors:** {count('pred_error')}",
        f"- **Prediction Timeouts:** {count('pred_timeout')}",
        f"- **Ground Truth Errors:** {count('gt_error')}",
        f"- **Ground Truth Timeouts:** {count('gt_timeout')}",
        "",
        "### Database Performance Ranking",
    ]

    for i, db in enumerate(top_dbs, 1):
        lines.append(f"{i}. **{db['database']}** - {db['accuracy']:.1%} ({db['performance']})")

    if challenging_dbs:
        lines.append("")
        lines.append(f"**Challenging Databases:** {', '.join(db['database'] for db in challenging_dbs)}")

    lines.append("")
    lines.append("### Difficulty Analysis")
    for label, level in (("Simple", "simple"), ("Moderate", "moderate"), ("Challenging", "challenging")):
        stats = difficulty_analysis[level]
        lines.append(
            f"- **{label} Questions:** {stats['accuracy']:.1%} ({stats['correct']}/{stats['total']})"
        )

    lines.extend([
        "",
        "## 🔍 Error Analysis",
        "",
        "### Common Failure Patterns",
    ])
    for pattern in error_patterns:
        lines.append(
            f"- **{pattern['pattern']}** ({pattern['count']} occurrences, {pattern['fix_complexity']} complexity fix)"
        )

    if count('gt_error') > 0:
        gt_validation = "Investigate ground truth errors"
    else:
        gt_validation = "Ground truth execution successful"

    lines.extend([
        "",
        "### Cache Performance",
        f"- **Ground Truth Cache Hit Rate:** {cache_stats['ground_truth_cache']['hit_rate']:.1%}",
        f"- **Prediction Cache Hit Rate:** {cache_stats['prediction_cache']['hit_rate']:.1%}",
        "",
        "## 💡 Recommendations",
        "",
        "### Immediate Actions",
        "1. **Address Error Patterns:** Focus on high-frequency, low-complexity fixes",
        "2. **Database-Specific Optimization:** Target challenging databases for improvement",
        f"3. **Ground Truth Validation:** {gt_validation}",
        "",
        "---",
        "",
        f"*Generated from evaluation results on {metadata['timestamp']}*",
    ])

    return "".join(line + "\n" for line in lines)


def generate_comprehensive_analysis(analysis_data: Dict) -> Dict:
    """Assemble the comprehensive analysis report.

    Args:
        analysis_data: Dict with the keys ``metadata``,
            ``database_analysis``, ``difficulty_analysis`` and
            ``error_patterns``.

    Returns:
        A JSON-serialisable dict (given JSON-serialisable inputs) combining
        report metadata, a performance summary, the supplied analyses and a
        fixed list of recommendations.

    Raises:
        KeyError: If ``metadata`` lacks ``total_questions``, ``accuracy``,
            ``error_counts`` or ``cache_stats``.
    """
    metadata = analysis_data["metadata"]
    error_counts = metadata["error_counts"]
    cache_stats = metadata["cache_stats"]

    # Read the clock once so the date and the timestamp can never disagree
    # (two separate calls could straddle midnight).
    now = datetime.now()

    return {
        "evaluation_metadata": {
            "report_title": "Comprehensive Cached Evaluation Analysis",
            "evaluation_date": now.strftime('%Y-%m-%d'),
            "report_timestamp": now.isoformat(),
            "total_questions": metadata["total_questions"],
            "overall_accuracy": metadata["accuracy"],
            "evaluation_method": "SQL execution with dual caching"
        },
        "performance_summary": {
            "accuracy": metadata["accuracy"],
            "total_questions": metadata["total_questions"],
            # A run without a single match may omit the category entirely.
            "correct_answers": error_counts.get("match", 0),
            "error_breakdown": error_counts,
            "cache_performance": cache_stats
        },
        "database_analysis": analysis_data["database_analysis"],
        "difficulty_analysis": analysis_data["difficulty_analysis"],
        "error_patterns": analysis_data["error_patterns"],
        "cache_statistics": cache_stats,
        "recommendations": {
            "immediate_actions": [
                "Focus on high-frequency error patterns with low fix complexity",
                "Investigate databases with high prediction error rates",
                "Validate ground truth queries if any execution errors occurred"
            ],
            "performance_optimization": [
                "Target challenging databases for schema-specific improvements",
                "Analyze column selection patterns in failed queries",
                "Consider few-shot examples for complex query patterns"
            ]
        }
    }


def main():
    """Command-line entry point.

    Loads the evaluation results and the dev data, runs the per-database,
    per-difficulty and error-pattern analyses, and writes three files into
    the output directory: ``executive_summary.md``,
    ``comprehensive_analysis_report.json`` and ``detailed_analysis.json``.
    """
    parser = argparse.ArgumentParser(description='Analyze Evaluation Results')
    parser.add_argument('--eval', required=True, help='Path to evaluation JSON file')
    parser.add_argument('--dev_data', required=True, help='Path to dev.json for difficulty analysis')
    parser.add_argument('--output_dir', required=True, help='Output directory for analysis files')
    parser.add_argument('--run_name', help='Run name for output files (default: extract from cached_eval filename)')

    args = parser.parse_args()

    # Load data
    print("Loading evaluation results...")
    evaluation = load_evaluation(args.eval)
    dev_data = load_dev_data(args.dev_data)

    # Extract run name: explicit argument, else the evaluation file's stem
    # with the '_evaluation' marker removed.
    if args.run_name:
        run_name = args.run_name
    else:
        run_name = Path(args.eval).stem.replace('_evaluation', '')

    print(f"Analyzing results for run: {run_name}")

    # Perform analysis
    print("Analyzing by database...")
    database_analysis = analyze_by_database(evaluation["results"])

    print("Analyzing by difficulty...")
    difficulty_analysis = analyze_by_difficulty(evaluation["results"], dev_data)

    print("Identifying error patterns...")
    error_patterns = identify_error_patterns(evaluation["results"])

    # Combine analysis data
    analysis_data = {
        "metadata": evaluation["metadata"],
        "database_analysis": database_analysis,
        "difficulty_analysis": difficulty_analysis,
        "error_patterns": error_patterns
    }

    # Generate output files
    os.makedirs(args.output_dir, exist_ok=True)

    # 1. Executive Summary
    print("Generating executive summary...")
    executive_summary = generate_executive_summary(analysis_data)
    exec_summary_file = os.path.join(args.output_dir, "executive_summary.md")
    # The summary contains emoji; a locale-default encoding such as cp1252
    # cannot represent them, so always write UTF-8.
    with open(exec_summary_file, 'w', encoding='utf-8') as f:
        f.write(executive_summary)

    # 2. Comprehensive Analysis Report
    print("Generating comprehensive analysis...")
    comprehensive_analysis = generate_comprehensive_analysis(analysis_data)
    comprehensive_file = os.path.join(args.output_dir, "comprehensive_analysis_report.json")
    with open(comprehensive_file, 'w', encoding='utf-8') as f:
        json.dump(comprehensive_analysis, f, indent=2)

    # 3. Detailed Analysis (raw cached evaluation + analysis)
    print("Generating detailed analysis...")
    detailed_analysis = {
        **evaluation,
        "analysis": analysis_data
    }
    detailed_file = os.path.join(args.output_dir, "detailed_analysis.json")
    with open(detailed_file, 'w', encoding='utf-8') as f:
        json.dump(detailed_analysis, f, indent=2)

    print("\n=== ANALYSIS COMPLETE ===")
    print(f"Files generated in {args.output_dir}:")
    print("  - executive_summary.md")
    print("  - comprehensive_analysis_report.json")
    print("  - detailed_analysis.json")
    print(f"\nOverall accuracy: {evaluation['metadata']['accuracy']:.1%}")


# Run the command-line interface only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()