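#!/usr/bin/env python3
"""
Aggregate benchmark results from the immediate subdirectories of the current
directory.

Each subdirectory is treated as one domain. The script reads its
benchmark_summary.json, detailed_report.txt and query_*_result.json files,
prints a combined report, and writes the aggregated statistics to
aggregated_benchmark_analysis.json in the current directory.
"""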

import glob
import json
import os
import re
import statistics
from collections import defaultdict
from typing import Any, Dict, List, Optional

def find_benchmark_summaries(base_dir: str) -> List[str]:
    """
    Find all benchmark_summary.json files one directory level below base_dir.

    Args:
        base_dir: Base directory to search in. It is escaped before being used
            in the glob pattern, so directory names containing glob
            metacharacters such as ``[`` or ``*`` are matched literally.

    Returns:
        Sorted list of paths to benchmark_summary.json files. Sorting makes the
        processing order (and therefore report/JSON output order) deterministic,
        since glob's own ordering depends on the file system.
    """
    pattern = os.path.join(glob.escape(base_dir), "*", "benchmark_summary.json")
    return sorted(glob.glob(pattern))

def find_detailed_reports(base_dir: str) -> List[str]:
    """
    Find all detailed_report.txt files one directory level below base_dir.

    Args:
        base_dir: Base directory to search in. It is escaped before being used
            in the glob pattern, so directory names containing glob
            metacharacters such as ``[`` or ``*`` are matched literally.

    Returns:
        Sorted list of paths to detailed_report.txt files (deterministic order
        regardless of file system).
    """
    pattern = os.path.join(glob.escape(base_dir), "*", "detailed_report.txt")
    return sorted(glob.glob(pattern))

def find_query_result_files(base_dir: str) -> List[str]:
    """
    Find all query_*_result.json files one directory level below base_dir.

    Args:
        base_dir: Base directory to search in. It is escaped before being used
            in the glob pattern, so directory names containing glob
            metacharacters such as ``[`` or ``*`` are matched literally.

    Returns:
        Sorted list of paths to query result JSON files. Sorting by full path
        keeps files from the same subdirectory together in a deterministic order.
    """
    pattern = os.path.join(glob.escape(base_dir), "*", "query_*_result.json")
    return sorted(glob.glob(pattern))

def load_summary_data(file_path: str) -> Dict[str, Any]:
    """
    Read one benchmark_summary.json file.

    Args:
        file_path: Location of the JSON document on disk.

    Returns:
        The decoded JSON value. If opening, reading or decoding fails for any
        reason, the error is printed and an empty dict is returned so that a
        single unreadable file does not stop the run.
    """
    try:
        with open(file_path, 'r', encoding='utf-8') as handle:
            raw_text = handle.read()
        decoded = json.loads(raw_text)
    except Exception as exc:
        print(f"Error loading {file_path}: {exc}")
        decoded = {}
    return decoded

def load_query_result_data(file_path: str) -> Dict[str, Any]:
    """
    Read one query_*_result.json file.

    Args:
        file_path: Location of the per-query JSON result.

    Returns:
        The decoded JSON value, or an empty dict (after printing the error)
        when the file cannot be opened or parsed.
    """
    failure = None
    try:
        with open(file_path, encoding='utf-8') as stream:
            parsed = json.load(stream)
    except Exception as problem:
        failure = problem

    if failure is not None:
        print(f"Error loading {file_path}: {failure}")
        return {}
    return parsed

def extract_quality_dimensions_from_queries(query_files: List[str]) -> Dict[str, List[float]]:
    """
    Extract quality dimension scores from individual query result files.

    Scores are read from ``detailed_evaluation -> quality_scores`` of each file.
    Only a fixed list of known dimension fields is collected, and only numeric
    (int/float, but not bool) values are kept, converted to float. Files that
    fail to load, or whose nested sections are missing or are not mappings,
    are skipped instead of raising.

    Args:
        query_files: List of paths to query result JSON files

    Returns:
        Dictionary mapping quality dimension names to lists of scores, each list
        in the order the files were given.
    """
    # Known quality dimension fields (loop-invariant, so defined once).
    quality_fields = (
        'personalization_fidelity', 'factuality', 'citation_quality',
        'fluency', 'structure', 'temporal_task_accuracy', 'temporal_accuracy',
        'task_accuracy', 'overall_score',
    )
    quality_dimensions = defaultdict(list)
    scored_queries = 0

    print(f"Processing {len(query_files)} query files...")

    for file_path in query_files:
        query_data = load_query_result_data(file_path)
        if not query_data or not isinstance(query_data, dict):
            continue

        # Quality scores are nested in detailed_evaluation; guard each level so a
        # null or non-object section skips the file rather than raising TypeError.
        evaluation = query_data.get('detailed_evaluation')
        quality_scores = evaluation.get('quality_scores') if isinstance(evaluation, dict) else None
        if not isinstance(quality_scores, dict):
            continue
        scored_queries += 1

        for field in quality_fields:
            value = quality_scores.get(field)
            # bool is a subclass of int; exclude it so flags are not counted as scores.
            if isinstance(value, (int, float)) and not isinstance(value, bool):
                quality_dimensions[field].append(float(value))

    print(f"Extracted quality dimensions from {scored_queries} queries")
    return dict(quality_dimensions)

# Section header, a dashed underline, then everything up to the next blank
# line (or end of file).
_QUALITY_SECTION_RE = re.compile(
    r'QUALITY DIMENSIONS ANALYSIS\n-+\n(.*?)(?=\n\n|\Z)',
    re.DOTALL,
)
# "mean ± std_dev". The mean may be signed, and both numbers may be written
# with or without a fractional part (e.g. "0.85 ± 0.10", "1 ± 0", "-0.25 ± 0.05").
_MEAN_STD_RE = re.compile(
    r'([-+]?(?:\d+\.?\d*|\.\d+))\s*±\s*(\d+\.?\d*|\.\d+)'
)


def parse_quality_dimensions(file_path: str) -> Dict[str, Dict[str, float]]:
    """
    Parse quality dimensions from detailed_report.txt file.

    The report is expected to contain a ``QUALITY DIMENSIONS ANALYSIS`` header,
    a dashed underline, and then one ``name: mean ± std_dev`` line per dimension
    until the next blank line. The name is everything before the first colon.
    Lines that do not have that shape are ignored, and so is any text after the
    std dev (e.g. a trailing ``(n: 30)``).

    Args:
        file_path: Path to the detailed report text file

    Returns:
        Mapping of dimension name to ``{'mean': float, 'std_dev': float}``; an
        empty dict if the file cannot be read or has no such section.
    """
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            content = f.read()
    except Exception as e:
        # Best-effort: a missing or unreadable report must not abort the run.
        print(f"Error parsing quality dimensions from {file_path}: {e}")
        return {}

    section = _QUALITY_SECTION_RE.search(content)
    if not section:
        return {}

    quality_metrics = {}
    for line in section.group(1).strip().split('\n'):
        dimension_name, colon, value_part = line.partition(':')
        if not colon:
            continue
        match = _MEAN_STD_RE.match(value_part.strip())
        if match:
            quality_metrics[dimension_name.strip()] = {
                'mean': float(match.group(1)),
                'std_dev': float(match.group(2)),
            }

    return quality_metrics

def _summary_statistics(values: List[float]) -> Dict[str, float]:
    """
    Describe a non-empty list of numbers.

    Returns:
        Dict with mean, median, std_dev (sample standard deviation, reported as
        0.0 for a single value because stdev needs two points), min, max, count.
    """
    return {
        'mean': statistics.mean(values),
        'median': statistics.median(values),
        'std_dev': statistics.stdev(values) if len(values) > 1 else 0.0,
        'min': min(values),
        'max': max(values),
        'count': len(values),
    }


def compute_average_scores(summaries: List[Dict[str, Any]],
                           quality_dimensions_from_queries: Dict[str, List[float]],
                           quality_reports: Optional[List[Dict[str, Dict[str, float]]]] = None,
                           domain_names: Optional[List[str]] = None) -> Dict[str, Any]:
    """
    Compute average scores across all benchmark summaries and quality dimensions.

    Args:
        summaries: Benchmark summary dictionaries, one per domain; empty entries
            are skipped.
        quality_dimensions_from_queries: Per-query quality scores keyed by
            dimension name. This is the primary source for quality statistics.
        quality_reports: Optional per-domain quality metrics parsed from detailed
            reports. Their means are used only for dimensions that have no
            per-query scores, so domain-level means are never pooled together
            with query-level scores.
        domain_names: Optional domain names aligned index-by-index with
            ``summaries``; a missing name defaults to ``domain_<i+1>``.

    Returns:
        Dictionary of cross-domain statistics plus a ``per_domain_results``
        breakdown, or an empty dict when ``summaries`` is empty.
    """
    if not summaries:
        return {}

    score_metrics = defaultdict(list)              # metric -> one value per domain
    total_queries_list = []
    context_retrieval_metrics = defaultdict(list)  # metric -> one value per query
    intent_metrics = {
        'per_field_precision': defaultdict(list),
        'per_field_recall': defaultdict(list),
        'per_field_f1': defaultdict(list),
    }
    per_domain_results = {}
    distribution_totals = defaultdict(int)

    for i, summary in enumerate(summaries):
        if not summary:  # Skip empty summaries
            continue

        domain_name = domain_names[i] if domain_names and i < len(domain_names) else f"domain_{i+1}"
        domain_entry = per_domain_results.setdefault(domain_name, {
            'total_queries': 0,
            'average_scores': {},
            'context_retrieval_metrics': {},
            'intent_detailed_metrics': {},
            'score_distribution': {},
            'quality_dimensions': {},
        })

        if 'total_queries' in summary:
            total_queries_list.append(summary['total_queries'])
            domain_entry['total_queries'] = summary['total_queries']

        if 'average_scores' in summary:
            domain_entry['average_scores'] = summary['average_scores'].copy()
            for metric, value in summary['average_scores'].items():
                if isinstance(value, (int, float)):
                    score_metrics[metric].append(value)

        # Context retrieval precision / recall / F1 come from the per-query detailed results.
        domain_context = defaultdict(list)
        if 'detailed_results' in summary:
            for result in summary['detailed_results']:
                if 'detailed_evaluation' in result and 'context_retrieval' in result['detailed_evaluation']:
                    cr = result['detailed_evaluation']['context_retrieval']
                    for metric in ('precision', 'recall', 'f1_score'):
                        if metric in cr:
                            context_retrieval_metrics[metric].append(cr[metric])
                            domain_context[metric].append(cr[metric])
        for metric, values in domain_context.items():
            if values:
                domain_entry['context_retrieval_metrics'][metric] = {
                    'mean': statistics.mean(values),
                    'count': len(values),
                }

        if 'intent_detailed_metrics' in summary:
            intent_detailed = summary['intent_detailed_metrics']
            domain_entry['intent_detailed_metrics'] = intent_detailed.copy()
            for metric_type, collected in intent_metrics.items():
                if metric_type in intent_detailed:
                    for field_name, value in intent_detailed[metric_type].items():
                        if isinstance(value, (int, float)):
                            collected[field_name].append(value)

        if 'score_distribution' in summary:
            domain_entry['score_distribution'] = summary['score_distribution'].copy()
            for category, count in summary['score_distribution'].items():
                distribution_totals[category] += count

    # Per-query scores are the primary (more accurate) source. Report means are
    # per-domain aggregates; pooling them with per-query values would skew the
    # distribution and inflate counts, so they only fill dimensions that no
    # query file provided.
    quality_dimensions = defaultdict(list)
    for dimension_name, values in quality_dimensions_from_queries.items():
        quality_dimensions[dimension_name].extend(values)
    dimensions_with_query_scores = {
        name for name, values in quality_dimensions_from_queries.items() if values
    }
    for quality_report in quality_reports or []:
        for dimension_name, metrics in quality_report.items():
            if dimension_name not in dimensions_with_query_scores and 'mean' in metrics:
                quality_dimensions[dimension_name].append(metrics['mean'])

    averaged_scores = {m: _summary_statistics(v) for m, v in score_metrics.items() if v}
    context_retrieval_detailed = {
        m: _summary_statistics(v) for m, v in context_retrieval_metrics.items() if v
    }
    quality_dimensions_averages = {
        d: _summary_statistics(v) for d, v in quality_dimensions.items() if v
    }
    intent_detailed_averages = {
        metric_type: {f: _summary_statistics(v) for f, v in fields.items() if v}
        for metric_type, fields in intent_metrics.items()
    }

    # NOTE(review): per-query scores carry no domain label, so they are split
    # into consecutive, equally sized slices in ``domain_names`` order. This is
    # only accurate if the values are grouped by domain in that same order and
    # each domain contributes the same number of queries — TODO confirm, or pass
    # per-domain query scores explicitly.
    if domain_names:
        overall_count = len(quality_dimensions_from_queries.get('overall_score') or [])
        queries_per_domain = overall_count // len(domain_names)
        if queries_per_domain:
            last_idx = len(domain_names) - 1
            for domain_idx, domain_name in enumerate(domain_names):
                start_idx = domain_idx * queries_per_domain
                # The last domain also takes the division remainder so no values are dropped.
                end_idx = None if domain_idx == last_idx else start_idx + queries_per_domain
                domain_quality = {}
                for dimension_name, all_values in quality_dimensions_from_queries.items():
                    domain_values = all_values[start_idx:end_idx]
                    if domain_values:
                        domain_quality[dimension_name] = _summary_statistics(domain_values)
                if domain_name in per_domain_results:
                    per_domain_results[domain_name]['quality_dimensions'] = domain_quality

    distribution_total = sum(distribution_totals.values())
    return {
        'number_of_domains': sum(1 for s in summaries if s),
        'total_queries_across_domains': sum(total_queries_list),
        'queries_per_domain': {
            'mean': statistics.mean(total_queries_list) if total_queries_list else 0,
            'median': statistics.median(total_queries_list) if total_queries_list else 0,
            'values': total_queries_list,
        },
        'averaged_scores': averaged_scores,
        'context_retrieval_detailed_metrics': context_retrieval_detailed,
        'intent_detailed_averages': intent_detailed_averages,
        'quality_dimensions_averages': quality_dimensions_averages,
        'aggregated_score_distribution': dict(distribution_totals),
        'overall_distribution_percentage': {
            category: (count / distribution_total * 100)
            for category, count in distribution_totals.items()
        } if distribution_total > 0 else {},
        'per_domain_results': per_domain_results,
    }

def _format_label(name: str) -> str:
    """Turn a snake_case key into a Title Case label (e.g. ``f1_score`` -> ``F1 Score``)."""
    return name.replace('_', ' ').title()


def _format_stats_lines(stats: Dict[str, Any], count_label: str, indent: str = "  ",
                        include_median: bool = True, include_std_dev: bool = True) -> List[str]:
    """Render a statistics dict (mean/median/std_dev/min/max/count) as indented report lines."""
    lines = [f"{indent}Mean: {stats['mean']:.4f}"]
    if include_median:
        lines.append(f"{indent}Median: {stats['median']:.4f}")
    if include_std_dev:
        lines.append(f"{indent}Std Dev: {stats['std_dev']:.4f}")
    lines.append(f"{indent}Range: [{stats['min']:.4f}, {stats['max']:.4f}]")
    lines.append(f"{indent}{count_label}: {stats['count']}")
    return lines


def format_results(results: Dict[str, Any]) -> str:
    """
    Format the results for display.

    Args:
        results: Dictionary containing computed results (the shape produced by
            compute_average_scores).

    Returns:
        Multi-line human-readable report, or "No results to display." when
        ``results`` is empty.
    """
    if not results:
        return "No results to display."

    output = [
        "=" * 80,
        "BENCHMARK SUMMARY ANALYSIS",
        "=" * 80,
        "",
        f"Number of domains analyzed: {results['number_of_domains']}",
        f"Total queries across all domains: {results['total_queries_across_domains']}",
        f"Average queries per domain: {results['queries_per_domain']['mean']:.1f}",
        "",
        "AVERAGE SCORES ACROSS ALL DOMAINS:",
        "-" * 50,
    ]

    # One value per domain summary that reported the metric.
    for metric, stats in results.get('averaged_scores', {}).items():
        output.append(f"\n{_format_label(metric)}:")
        output.extend(_format_stats_lines(stats, "Domains"))

    output.append("\n" + "=" * 50)
    output.append("INTENT EVALUATION DETAILED METRICS:")
    output.append("-" * 50)
    for metric_type, fields in results.get('intent_detailed_averages', {}).items():
        output.append(f"\n{_format_label(metric_type)}:")
        for field_name, stats in fields.items():
            output.append(f"  {_format_label(field_name)}:")
            output.extend(_format_stats_lines(stats, "Domains", indent="    ", include_median=False))

    output.append("\n" + "=" * 50)
    output.append("CONTEXT RETRIEVAL DETAILED METRICS:")
    output.append("-" * 50)
    for metric, stats in results.get('context_retrieval_detailed_metrics', {}).items():
        output.append(f"\n{_format_label(metric)}:")
        output.extend(_format_stats_lines(stats, "Total Queries"))

    output.append("\n" + "=" * 50)
    output.append("QUALITY DIMENSIONS AVERAGES:")
    output.append("-" * 50)
    for dimension, stats in results.get('quality_dimensions_averages', {}).items():
        output.append(f"\n{_format_label(dimension)}:")
        # The count is the number of pooled scores (mostly per-query), not domains.
        output.extend(_format_stats_lines(stats, "Samples"))

    output.append("\n" + "=" * 50)
    output.append("AGGREGATED SCORE DISTRIBUTION:")
    output.append("-" * 50)
    percentages = results.get('overall_distribution_percentage', {})
    for category, count in results.get('aggregated_score_distribution', {}).items():
        percentage = percentages.get(category, 0)
        output.append(f"{category.capitalize()}: {count} queries ({percentage:.1f}%)")

    output.append("\n" + "=" * 80)
    output.append("PER-DOMAIN DETAILED RESULTS")
    output.append("=" * 80)

    for domain_name, domain_data in results.get('per_domain_results', {}).items():
        output.append(f"\n{domain_name.upper()} DOMAIN:")
        output.append("-" * 40)
        output.append(f"Total queries: {domain_data.get('total_queries', 0)}")

        average_scores = domain_data.get('average_scores')
        if average_scores:
            output.append("\nAverage Scores:")
            for metric, value in average_scores.items():
                if isinstance(value, (int, float)):
                    output.append(f"  {_format_label(metric)}: {value:.4f}")

        intent_detailed = domain_data.get('intent_detailed_metrics')
        if intent_detailed:
            output.append("\nIntent Evaluation Metrics:")
            for metric_type, fields in intent_detailed.items():
                if isinstance(fields, dict) and fields:
                    output.append(f"  {_format_label(metric_type)}:")
                    for field_name, value in fields.items():
                        if isinstance(value, (int, float)):
                            output.append(f"    {_format_label(field_name)}: {value:.4f}")
                elif isinstance(fields, (int, float)):
                    output.append(f"  {_format_label(metric_type)}: {fields:.4f}")

        context_metrics = domain_data.get('context_retrieval_metrics')
        if context_metrics:
            output.append("\nContext Retrieval Metrics:")
            for metric, stats in context_metrics.items():
                # Same Title Case labels as the rest of the report (f1_score -> "F1 Score").
                output.append(f"  {_format_label(metric)}: {stats['mean']:.4f} (n={stats['count']})")

        quality_dims = domain_data.get('quality_dimensions')
        if quality_dims:
            output.append("\nQuality Dimensions:")
            for dimension, stats in quality_dims.items():
                output.append(f"  {_format_label(dimension)}:")
                output.extend(_format_stats_lines(
                    stats, "Queries", indent="    ", include_median=False, include_std_dev=False))

        distribution = domain_data.get('score_distribution')
        if distribution:
            output.append("\nScore Distribution:")
            total_domain_queries = sum(distribution.values())
            for category, count in distribution.items():
                percentage = (count / total_domain_queries * 100) if total_domain_queries > 0 else 0
                output.append(f"  {category.capitalize()}: {count} queries ({percentage:.1f}%)")

    return "\n".join(output)

def main():
    """
    Aggregate every domain subdirectory of the working directory.

    Discovers the summary, detailed-report and per-query result files, loads
    them, prints a combined report, and writes the aggregated results to
    aggregated_benchmark_analysis.json in the working directory.
    """
    def domain_of(path: str) -> str:
        # Each discovered file sits directly inside its domain directory.
        return os.path.basename(os.path.dirname(path))

    root = os.getcwd()
    print(f"Searching for benchmark_summary.json files in: {root}")

    summary_paths = find_benchmark_summaries(root)
    report_paths = find_detailed_reports(root)

    if not summary_paths:
        print("No benchmark_summary.json files found in subdirectories.")
        return

    print(f"Found {len(summary_paths)} benchmark summary files:")
    for path in summary_paths:
        print(f"  - {domain_of(path)}: {path}")

    print(f"\nFound {len(report_paths)} detailed report files:")
    for path in report_paths:
        print(f"  - {domain_of(path)}: {path}")

    print("\nLoading and analyzing data...")

    # Keep summaries and their domain names index-aligned; failed loads are dropped from both.
    loaded_summaries = []
    loaded_domains = []
    for path in summary_paths:
        data = load_summary_data(path)
        name = domain_of(path)
        if not data:
            print(f"  ✗ Failed to load {name}")
            continue
        loaded_summaries.append(data)
        loaded_domains.append(name)
        print(f"  + Loaded {name} ({data.get('total_queries', 0)} queries)")

    query_paths = find_query_result_files(root)
    print(f"\nFound {len(query_paths)} query result files")
    per_query_quality = extract_quality_dimensions_from_queries(query_paths)

    print(f"Extracted quality dimensions: {list(per_query_quality)}")
    for dimension, scores in per_query_quality.items():
        print(f"  - {dimension}: {len(scores)} values")

    # Report-derived quality metrics are a secondary source.
    parsed_reports = []
    for path in report_paths:
        parsed = parse_quality_dimensions(path)
        if parsed:
            parsed_reports.append(parsed)
            print(f"  + Loaded quality dimensions for {domain_of(path)}")
        else:
            print(f"  ✗ Failed to load quality dimensions for {domain_of(path)}")

    results = compute_average_scores(loaded_summaries, per_query_quality, parsed_reports, loaded_domains)
    print("\n" + format_results(results))

    output_file = "aggregated_benchmark_analysis.json"
    with open(output_file, 'w', encoding='utf-8') as handle:
        json.dump(results, handle, indent=2, ensure_ascii=False)
    print(f"\nDetailed results saved to: {output_file}")

if __name__ == "__main__":
    # Run the analysis only when executed as a script, not when imported.
    main()