#!/usr/bin/env python3
import argparse
import csv
import json
from glob import glob
from pathlib import Path

def _infer_dataset_name(path_str):
    """Guess a dataset name from a result-file path when the JSON omits it.

    Matching is case-insensitive and checked in a fixed priority order, so a
    path containing several keywords resolves to the first one listed.
    Returns "unknown" when no keyword matches.
    """
    lowered = path_str.lower()
    for keyword, name in (
        ("brainteaser", "brainteaser"),
        ("mmlu", "mmlu_pro"),
        ("explore", "explore_tom"),
    ):
        if keyword in lowered:
            return name
    return "unknown"


def extract_metrics_from_paths(base_path_str):
    """Collect metrics from every extracted_eval_scores.json under a base path.

    Args:
        base_path_str: Directory to search recursively. It is interpolated into
            a ``glob`` pattern, so it may itself contain wildcards
            (e.g. ``"experiments/run_eval_20250923_17*/"``).

    Returns:
        A list of dicts, one per successfully parsed file, with keys
        ``number_of_thinkers`` (string, or ``"unknown"``), ``dataset_name``,
        ``judge_gold_score``, ``collaborativeness``, ``repetition`` (each
        ``None`` when absent from the JSON) and ``file_path``. Files that cannot
        be read or decoded, or whose top-level JSON value is not an object, are
        reported on stdout and skipped rather than aborting the whole run.
    """
    # Sort so the output order does not depend on filesystem enumeration order.
    extracted_files = sorted(
        glob(f"{base_path_str}/**/extracted_eval_scores.json", recursive=True)
    )
    print(f"Found {len(extracted_files)} extracted_eval_scores.json file(s)")

    extracted_results = []
    # Paths already handled; a set keeps the duplicate check O(1) per file.
    seen_paths = set()

    for file_str in extracted_files:
        file_path = Path(file_str)
        path_key = str(file_path)
        if path_key in seen_paths:
            print(f"Skipping duplicate entry: {file_path}")
            continue
        seen_paths.add(path_key)

        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                scores = json.load(f)
        except (OSError, UnicodeDecodeError, json.JSONDecodeError) as e:
            print(f"Error processing {file_path}: {e}")
            continue

        # A valid JSON file may still hold a list/number/string; .get would fail.
        if not isinstance(scores, dict):
            print(
                f"Error processing {file_path}: expected a JSON object, "
                f"got {type(scores).__name__}"
            )
            continue

        dataset_name = scores.get("dataset_name")
        if dataset_name is None:
            # Older score files lack dataset_name; fall back to the path.
            dataset_name = _infer_dataset_name(path_key)

        num_thinkers_value = scores.get("num_thinkers")
        num_thinkers = (
            str(num_thinkers_value) if num_thinkers_value is not None else "unknown"
        )

        extracted_results.append({
            'number_of_thinkers': num_thinkers,
            'dataset_name': dataset_name,
            'judge_gold_score': scores.get('judge_score_max_avg'),
            'collaborativeness': scores.get('collaborative_max_avg'),
            'repetition': scores.get('repetition_max_avg'),
            'file_path': path_key,
        })

    return extracted_results

def _format_score(value):
    """Render a metric for the console table.

    None becomes "N/A", ints/floats get four decimals, and anything else
    (e.g. a stray string in a score file) is shown via str() instead of
    crashing the whole report.
    """
    if value is None:
        return "N/A"
    if isinstance(value, (int, float)):
        return f"{value:.4f}"
    return str(value)


def create_table(table_results, output_file=None):
    """Print a fixed-width metrics table, save it as CSV, and print a summary.

    Args:
        table_results: Row dicts with keys ``number_of_thinkers``,
            ``dataset_name``, ``judge_gold_score``, ``collaborativeness``,
            ``repetition`` and ``file_path``.
        output_file: Destination of the CSV file. Defaults to
            ``metrics_table.csv`` in the directory containing this script.

    Rows are ordered by thinker count parsed as an integer; rows whose count
    is not an integer (e.g. "unknown") go last, keeping their relative order.
    When ``table_results`` is empty only a notice is printed and no CSV is
    written.
    """
    if not table_results:
        print("No results found.")
        return

    def sort_key(result):
        try:
            return int(result['number_of_thinkers'])
        except (ValueError, TypeError):
            return float('inf')  # Put non-numeric values at the end

    sorted_results = sorted(table_results, key=sort_key)

    # Display the table
    print("\nExtracted Metrics Table:")
    print("=" * 120)
    print(f"{'Thinkers':<10} {'Dataset':<30} {'Judge Gold Score':<18} {'Collaborativeness':<18} {'Repetition':<15} {'File Path':<30}")
    print("-" * 120)

    for result in sorted_results:
        judge_score = _format_score(result['judge_gold_score'])
        collab = _format_score(result['collaborativeness'])
        rep = _format_score(result['repetition'])
        dataset_name = result['dataset_name'] or "unknown"
        print(f"{result['number_of_thinkers']:<10} {dataset_name:<30} {judge_score:<18} {collab:<18} {rep:<15} {result['file_path']:<30}")

    if output_file is None:
        output_file = Path(__file__).parent / "metrics_table.csv"
    output_file = str(output_file)

    # csv.writer quotes fields containing commas/quotes (e.g. file paths);
    # None is written as an empty field, matching the original format.
    with open(output_file, 'w', encoding='utf-8', newline='') as f:
        writer = csv.writer(f, lineterminator="\n")
        writer.writerow([
            "number_of_thinkers", "dataset_name", "judge_gold_score",
            "collaborativeness", "repetition", "file_path",
        ])
        for result in sorted_results:
            writer.writerow([
                result['number_of_thinkers'],
                result['dataset_name'] or "unknown",
                result['judge_gold_score'],
                result['collaborativeness'],
                result['repetition'],
                result['file_path'],
            ])

    print(f"\nTable saved to: {output_file}")

    # Summary statistics
    unique_datasets = len(set(r['dataset_name'] for r in table_results if r['dataset_name']))
    unique_thinkers = len(set(r['number_of_thinkers'] for r in table_results))
    print("\nSummary:")
    print(f"Total files processed: {len(table_results)}")
    print(f"Unique datasets: {unique_datasets}")
    print(f"Unique thinker counts: {unique_thinkers}")

def main(argv=None):
    """Command-line entry point: scan a results tree and report its metrics.

    Args:
        argv: Argument list to parse; None means argparse reads sys.argv[1:].
    """
    cli = argparse.ArgumentParser(
        description="Extract metrics from extracted_eval_scores.json files"
    )
    cli.add_argument(
        "--base_path",
        "-b",
        default="/Users/fengtingliao/external/group_think_work/group_think_data/experiments",
        help="Base path to search for extracted_eval_scores.json files",
    )
    options = cli.parse_args(argv)
    create_table(extract_metrics_from_paths(options.base_path))


if __name__ == "__main__":
    main()

# python3 extract_metrics_table.py -b "../../group_think_data/experiments/run_eval_20250923_17*/"
# python3 extract_metrics_table.py -b "../../group_think_data/experiments/250924_gt_sim_results/final_res_gt_traces/"
# python3 extract_metrics_table.py -b "../../group_think_data/experiments/run_eval_20250924_103345/"