#!/usr/bin/env python3
"""
Script to generate a summary report from evaluation results.
This script processes every `*_evaluated.jsonl` file in the given output directory and writes a YAML summary report (`report.yaml`) to that same directory.
"""

import json
import yaml
import os
import sys
import argparse
from pathlib import Path
from collections import defaultdict
import statistics
from datetime import datetime


def load_jsonl(file_path):
    """Load a JSONL file and return its records as a list.

    Blank lines are skipped. A line that is not valid JSON is reported (with
    its 1-based line number) and skipped, so a single corrupted line does not
    discard every record that follows it.

    Args:
        file_path: Path (str or os.PathLike) of the JSONL file.

    Returns:
        List of decoded JSON values in file order. If the file cannot be
        opened or decoded as UTF-8, an error is printed and the records read
        so far (possibly none) are returned.
    """
    records = []
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            for line_no, line in enumerate(f, start=1):
                line = line.strip()
                if not line:
                    continue
                try:
                    records.append(json.loads(line))
                except json.JSONDecodeError as e:
                    # Skip only the bad line instead of abandoning the rest of the file.
                    print(f"Error parsing {file_path} line {line_no}: {e}")
    except (OSError, UnicodeDecodeError) as e:
        print(f"Error loading {file_path}: {e}")
    return records


def _accumulate_score(acc, score):
    """Add one raw score to an accumulator ``{'values': [...], 'null_count': n}``.

    Numbers are stored as floats and None increments ``null_count``. Any other
    value (e.g. a string) is ignored. bool is a subclass of int, so True/False
    are counted as 1.0/0.0.
    """
    if score is None:
        acc['null_count'] += 1
    elif isinstance(score, (int, float)):
        acc['values'].append(float(score))


def _summarize_metrics(data):
    """Turn ``{name: accumulator}`` into ``{name: {count, null_count, mean, std}}``.

    Metrics with neither numeric values nor nulls are dropped. ``mean`` and
    ``std`` are rounded to 4 decimals. ``std`` is the sample standard deviation
    (``statistics.stdev``): 0.0 for a single value, None when there are no
    numeric values.
    """
    summaries = {}
    for name, acc in data.items():
        values = acc['values']
        null_count = acc['null_count']
        if not values and null_count == 0:
            continue
        summaries[name] = {
            'count': len(values),
            'null_count': null_count,
            'mean': round(statistics.mean(values), 4) if values else None,
            'std': round(statistics.stdev(values), 4) if len(values) > 1 else (0.0 if values else None),
        }
    return summaries


def calculate_metrics(records):
    """Aggregate per-thinker and prediction metrics over evaluation records.

    Records without an ``evaluations`` dict are skipped.

    * Thinker metrics: every key of ``evaluations`` except
      ``predictions_evaluation``, ``judge_ref_score`` and
      ``judge_inner_voice_adherence`` whose value is a dict of
      thinker id -> list of scores, e.g. ``{"0": [0.0], "1": [0.0]}``. Only the
      LAST score of each non-empty list (the most recent evaluation) is used.
    * Prediction metrics: ``evaluations["predictions_evaluation"]``, a dict of
      metric name -> list of scores. EVERY score in each non-empty list is used.

    Numeric scores are aggregated, None scores are counted in ``null_count``,
    and other values are ignored.

    Args:
        records: List of decoded JSONL records.

    Returns:
        ``{'thinker_metrics': {...}, 'prediction_metrics': {...}}``. Each maps a
        metric name to ``{'count', 'null_count', 'mean', 'std'}``. Metrics
        appear in the order they are first seen.
    """
    if not records:
        return {'thinker_metrics': {}, 'prediction_metrics': {}}

    # Evaluation keys that are not per-thinker metrics.
    excluded_metrics = {'predictions_evaluation', 'judge_ref_score', 'judge_inner_voice_adherence'}

    # metric name -> {'values': [...], 'null_count': n}. Plain dicts keep
    # first-seen order so the YAML report is stable across runs (iterating a
    # set of strings varies with per-process hash randomization).
    thinker_data = {}
    prediction_data = {}

    for record in records:
        evaluations = record.get('evaluations') if isinstance(record, dict) else None
        if not isinstance(evaluations, dict):
            continue

        for metric, metric_data in evaluations.items():
            if metric in excluded_metrics or not isinstance(metric_data, dict):
                continue
            acc = thinker_data.setdefault(metric, {'values': [], 'null_count': 0})
            for scores in metric_data.values():
                if isinstance(scores, list) and scores:
                    # Take the last score (most recent evaluation).
                    _accumulate_score(acc, scores[-1])

        pred_data = evaluations.get('predictions_evaluation')
        if isinstance(pred_data, dict):
            for metric_name, scores in pred_data.items():
                if isinstance(scores, list) and scores:
                    acc = prediction_data.setdefault(metric_name, {'values': [], 'null_count': 0})
                    for score in scores:
                        _accumulate_score(acc, score)

    return {
        'thinker_metrics': _summarize_metrics(thinker_data),
        'prediction_metrics': _summarize_metrics(prediction_data),
    }


def _parse_evaluated_filename(filename):
    """Split an ``*_evaluated.jsonl`` filename into ``(run_eval_folder, dataset_name)``.

    ``run_eval_<YYYYMMDD>_<HHMMSS>_<dataset>_evaluated.jsonl`` is split right
    after the timestamp, so dataset names containing underscores (e.g.
    ``math_500``) stay intact. Example:
    ``run_eval_20250918_234323_aime24_evaluated.jsonl`` ->
    ``("run_eval_20250918_234323", "aime24")``.
    Other names with at least four ``_``-separated parts treat the last part
    as the dataset. Shorter names yield ``("unknown", <stem>)``.
    """
    import re

    stem = filename.removesuffix('_evaluated.jsonl')
    match = re.fullmatch(r'(run_eval_\d{8}_\d{6})_(.+)', stem)
    if match:
        return match.group(1), match.group(2)
    parts = stem.split('_')
    if len(parts) >= 4:
        return '_'.join(parts[:-1]), parts[-1]
    return "unknown", stem


def generate_summary_report(output_dir):
    """Write ``report.yaml`` summarizing every ``*_evaluated.jsonl`` in *output_dir*.

    Each evaluated file with at least one record becomes one dataset entry. An
    entry holds the metrics from ``calculate_metrics`` and the run settings
    from ``load_eval_settings``. Files with no records are skipped with a
    warning. Entries are grouped by ``num_paths`` and, within a group, ordered
    by ``max_path_tokens`` (non-numeric values sort as 0) and then by
    ``<run_eval_folder>_<dataset>``.

    Progress and warnings are printed. An error while writing the report is
    printed rather than raised.

    Args:
        output_dir: Directory containing the evaluated JSONL files; the report
            is written there as ``report.yaml``.
    """
    from datetime import timezone

    output_path = Path(output_dir)
    report_file = output_path / "report.yaml"

    print(f"Creating summary report: {report_file}")

    # Sorted so the report's group order does not depend on filesystem order.
    evaluated_files = sorted(output_path.glob("*_evaluated.jsonl"))

    if not evaluated_files:
        print("No evaluated files found!")
        return

    summary = {
        'evaluation_summary': {
            # datetime.utcnow() is deprecated; an aware UTC "now" formats identically.
            'timestamp': datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ"),
            'total_datasets': 0,  # filled in after processing (skipped files excluded)
            'total_records': 0,
            'description': {
                'thinker_metrics': 'Metrics evaluated on individual thinker traces from group_traces (per-thinker scores)',
                'prediction_metrics': 'Metrics evaluated on the final predictions (aggregated from group_traces)'
            },
            'datasets_by_num_paths': {}
        }
    }

    total_records = 0
    total_datasets = 0
    datasets_by_num_paths = defaultdict(list)

    for evaluated_file in evaluated_files:
        filename = evaluated_file.name
        print(f"Processing: {filename}")

        run_eval_folder, dataset_name = _parse_evaluated_filename(filename)
        # Unique key for each run_eval + dataset combination.
        unique_key = f"{run_eval_folder}_{dataset_name}"

        records = load_jsonl(evaluated_file)
        if not records:
            print(f"Warning: No records found in {filename}")
            continue

        total_records += len(records)
        total_datasets += 1

        metrics_summary = calculate_metrics(records)
        source_file = find_source_file(run_eval_folder, dataset_name)

        eval_settings = load_eval_settings(run_eval_folder)
        num_paths = eval_settings.get('num_paths', 'unknown')
        max_path_tokens = eval_settings.get('max_path_tokens', 0)

        dataset_entry = {
            'run_eval_folder': run_eval_folder,
            'dataset_name': dataset_name,
            'source_file': str(source_file) if source_file else "unknown",
            'evaluated_file': str(evaluated_file),
            'record_count': len(records),
            'model_path': eval_settings.get('model_path'),
            'num_paths': num_paths,
            'shift': eval_settings.get('shift'),
            'max_path_tokens': max_path_tokens,
            # Metrics evaluated on individual thinker traces from group_traces
            'thinker_metrics': metrics_summary['thinker_metrics'],
            # Metrics evaluated on the final predictions (aggregated from group_traces)
            'prediction_metrics': metrics_summary['prediction_metrics']
        }

        datasets_by_num_paths[num_paths].append((max_path_tokens, unique_key, dataset_entry))

    grouped = summary['evaluation_summary']['datasets_by_num_paths']
    for num_paths, entries in datasets_by_num_paths.items():
        # max_path_tokens can be None (setting absent); sort such entries as 0.
        entries.sort(key=lambda item: (item[0] if isinstance(item[0], (int, float)) else 0, item[1]))
        grouped[str(num_paths)] = {
            'num_paths': num_paths,
            'datasets': {key: entry for _, key, entry in entries}
        }

    summary['evaluation_summary']['total_datasets'] = total_datasets
    summary['evaluation_summary']['total_records'] = total_records

    try:
        with open(report_file, 'w', encoding='utf-8') as f:
            yaml.dump(summary, f, default_flow_style=False, sort_keys=False, indent=2)
        print(f"Summary report saved to: {report_file}")
    except Exception as e:
        # Best-effort: report the failure instead of aborting the CLI with a traceback.
        print(f"Error writing report: {e}")


def find_source_file(run_eval_folder, dataset_name):
    """Return the source ``<dataset_name>_group_think_eval.jsonl`` file for a run.

    The file is looked up inside the run's ``eval_outputs_*`` directory as
    resolved by ``get_eval_outputs_dir``. That is the same resolution
    ``load_eval_settings`` uses, so the source file and the settings always
    come from the same directory.

    Args:
        run_eval_folder: Name of the run folder under the experiments directory.
        dataset_name: Dataset name used as the file prefix.

    Returns:
        The source file's ``Path``, or None if the run folder, its
        ``eval_outputs_*`` directory, or the file itself does not exist.
    """
    eval_outputs_dir = get_eval_outputs_dir(run_eval_folder)
    if eval_outputs_dir is None:
        return None
    source_file = eval_outputs_dir / f"{dataset_name}_group_think_eval.jsonl"
    return source_file if source_file.exists() else None


def get_eval_outputs_dir(run_eval_folder: str) -> Path | None:
    experiments_dir = Path("/Users/fengtingliao/external/group_think_work/group_think_data/experiments")
    run_eval_path = experiments_dir / run_eval_folder
    if run_eval_path.exists():
        eval_outputs_dirs = list(run_eval_path.glob("eval_outputs_*"))
        if eval_outputs_dirs:
            return eval_outputs_dirs[0]
    return None


def load_eval_settings(run_eval_folder: str) -> dict:
    """Read selected run settings from the run's ``eval_setting.yaml``.

    Args:
        run_eval_folder: Name of the run folder, resolved via ``get_eval_outputs_dir``.

    Returns:
        ``{}`` when the ``eval_outputs_*`` directory or ``eval_setting.yaml`` is
        missing. Also ``{}`` when the file cannot be read or parsed, or is not
        a YAML mapping; a warning is printed in those cases. Otherwise a dict
        with keys ``model_path``, ``num_paths``, ``shift`` and
        ``max_path_tokens``, where a key absent from the file maps to None.
    """
    eval_outputs_dir = get_eval_outputs_dir(run_eval_folder)
    if not eval_outputs_dir:
        return {}
    settings_file = eval_outputs_dir / "eval_setting.yaml"
    if not settings_file.exists():
        return {}
    try:
        with open(settings_file, "r", encoding="utf-8") as f:
            data = yaml.safe_load(f) or {}
    except (OSError, UnicodeDecodeError, yaml.YAMLError) as e:
        # Settings are optional metadata: warn and carry on rather than abort the report.
        print(f"Warning: could not read {settings_file}: {e}")
        return {}
    if not isinstance(data, dict):
        print(f"Warning: {settings_file} is not a YAML mapping; ignoring it")
        return {}
    return {key: data.get(key) for key in ("model_path", "num_paths", "shift", "max_path_tokens")}


def main(argv=None):
    """Parse the command line and generate the summary report.

    Args:
        argv: Argument list to parse; None (the default) means ``sys.argv[1:]``.

    Exits with status 1 if ``output_dir`` is not an existing directory.
    """
    parser = argparse.ArgumentParser(description="Generate summary report from evaluation results")
    parser.add_argument("output_dir", help="Directory containing evaluated JSONL files")

    args = parser.parse_args(argv)

    # Require a directory: a regular file would pass an existence check and then
    # silently produce "No evaluated files found!".
    if not os.path.isdir(args.output_dir):
        print(f"Error: Output directory {args.output_dir} does not exist or is not a directory",
              file=sys.stderr)
        sys.exit(1)

    generate_summary_report(args.output_dir)


# Run the command-line entry point only when executed as a script, not on import.
if __name__ == "__main__":
    main()
