#!/usr/bin/env python3
"""
Script to collect and aggregate results from evaluation runs.
"""

import argparse
import json
import os
import pandas as pd


def eval_to_score(eval_name):
    """Return the metric keys to extract from an evaluation's result JSON.

    The evaluation family is detected by substring match on ``eval_name``
    (e.g. ``"healthbench_hard"`` matches the healthbench family), checked in
    the order healthbench, hle, browsecomp. A fresh list is returned on every
    call, so callers may mutate it freely.

    Raises:
        ValueError: if ``eval_name`` matches none of the known families.
    """
    if "healthbench" in eval_name:
        return ["overall_score", "axis:accuracy", "axis:completeness", "theme:complex_responses", "theme:global_health", "theme:health_data_tasks"]
    if "hle" in eval_name:
        return ["score", "calibration_error"]
    if "browsecomp" in eval_name:
        return ["score"]
    # Fail loudly here rather than returning None, which would only surface
    # later as an opaque "'NoneType' object is not iterable" in the caller.
    raise ValueError(
        f"Unknown evaluation {eval_name!r}: expected a name containing "
        f"'healthbench', 'hle' or 'browsecomp'"
    )


def find_and_load_results(output_dir, models, evals, tag, seeds=1):
    """Load per-seed result JSON files for every (model, eval, seed) combination.

    Files are looked up at
    ``{output_dir}/{model}/{eval_name}_{model}_{tag}_{seed}.json`` for each
    ``seed`` in ``range(seeds)``. Missing files and files that fail to load
    are reported on stdout and skipped.

    Args:
        output_dir: Root directory containing one sub-directory per model.
        models: Model names to collect.
        evals: Evaluation names to collect; the metric keys read from each
            file are chosen by ``eval_to_score(eval_name)``.
        tag: Run tag embedded in the file name.
        seeds: Number of seed indices (``0 .. seeds-1``) to look for.

    Returns:
        A list of dicts with keys ``eval_name``, ``model_name``, ``score``
        (metric name -> value, ``None`` when the metric is absent from the
        file), ``file_path`` and ``all_metrics`` (the raw JSON payload).
    """
    results = []

    for model in models:
        for eval_name in evals:
            for seed in range(seeds):
                filename = f"{eval_name}_{model}_{tag}_{seed}.json"
                file_path = os.path.join(output_dir, model, filename)

                if not os.path.exists(file_path):
                    print(f"Warning: No result file found for {filename}")
                    continue

                try:
                    with open(file_path, 'r', encoding='utf-8') as f:
                        data = json.load(f)

                    scores = eval_to_score(eval_name)
                    score = {k: data.get(k) for k in scores}
                    # Divided by 100 here; the aggregation step multiplies
                    # every metric by 100, restoring the file's scale.
                    # Skip when the key is missing so the other metrics of
                    # this seed are still kept.
                    if score.get("calibration_error") is not None:
                        score["calibration_error"] = score["calibration_error"] / 100
                except Exception as e:
                    # Best-effort collection: one unreadable or malformed file
                    # must not abort the whole report.
                    print(f"Error loading {file_path}: {e}")
                    continue

                results.append({
                    "eval_name": eval_name,
                    "model_name": model,
                    "score": score,
                    "file_path": file_path,
                    "all_metrics": data
                })

    return results


def _format_cell(mean, std, count, include_count):
    """Format one aggregated cell as ``mean±std (n)``, scaling values by 100.

    ``±std`` is omitted when ``std`` is NaN (e.g. only one value contributed),
    and `` (n)`` is appended only when ``include_count`` is true and ``count``
    is not NaN.
    """
    cell = '{:.2f}'.format(mean * 100)
    if not pd.isna(std):
        cell += '±' + '{:.2f}'.format(std * 100)
    if include_count and not pd.isna(count):
        cell += ' ({})'.format(int(count))
    return cell


def aggregate_results(results, include_count=True):
    """Aggregate per-seed results into a model x (eval, metric) summary table.

    Args:
        results: Dicts as produced by ``find_and_load_results``; only the
            ``eval_name``, ``model_name`` and ``score`` keys are read.
        include_count: Append the number of non-missing values as `` (n)``.

    Returns:
        A DataFrame indexed by model name with one string column per
        ``"{eval_name}-{metric}"`` holding ``mean±std (n)`` on a 0-100 scale,
        or ``None`` when there is nothing to aggregate.
    """
    if not results:
        print("No results to aggregate.")
        return None

    # Long format: one row per (result file, metric).
    df = pd.DataFrame([
        {
            "eval_name": result["eval_name"],
            "model_name": result["model_name"],
            "metric": metric,
            "score": value,
        }
        for result in results
        for metric, value in result["score"].items()
    ])
    if df.empty:
        # Every result had an empty score dict; pivot_table would raise KeyError.
        print("No results to aggregate.")
        return None

    # Columns come out keyed by (statistic, eval_name, metric).
    pivot_df = df.pivot_table(
        index="model_name",
        columns=["eval_name", "metric"],
        values="score",
        sort=False,
        aggfunc={
            'score': ['mean', 'std', 'count']
        }
    )

    # (eval_name, metric) pairs in first-seen order, used for column ordering.
    eval_metric_pairs = df[['eval_name', 'metric']].drop_duplicates().values.tolist()

    # pivot_table drops all-NaN columns (e.g. std when every model has a
    # single seed). Reindexing restores every (statistic, eval, metric)
    # column, NaN-filled where missing, so the lookups below never fail.
    stat_columns = [
        (stat, eval_name, metric)
        for eval_name, metric in eval_metric_pairs
        for stat in ('mean', 'std', 'count')
    ]
    pivot_df = pivot_df.reindex(columns=stat_columns)

    final_df = pd.DataFrame(index=pivot_df.index)
    for eval_name, metric in eval_metric_pairs:
        means = pivot_df[('mean', eval_name, metric)]
        stds = pivot_df[('std', eval_name, metric)]
        counts = pivot_df[('count', eval_name, metric)]
        final_df[f"{eval_name}-{metric}"] = [
            _format_cell(m, s, c, include_count)
            for m, s, c in zip(means, stds, counts)
        ]

    return final_df


def main():
    """Command-line entry point: load, optionally detail, aggregate and print results.

    The aggregated table is printed as text and as CSV on stdout, and is
    additionally written to ``--output-csv`` when that option is given.
    """
    parser = argparse.ArgumentParser(
        description="Collect and aggregate evaluation results from JSON files."
    )
    parser.add_argument(
        "--models",
        type=str,
        required=True,
        help="Comma-separated list of model names to collect results for"
    )
    parser.add_argument(
        "--evals",
        type=str,
        default="hle_text,browsecomp,healthbench_hard",
        help="Comma-separated list of evaluation names to collect results for"
    )
    parser.add_argument(
        "--tag",
        type=str,
        default="v1_300",
        help="Tag to filter results by"
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        default="simple-evals/outputs",
        help="Directory containing result files (default: simple-evals/outputs)"
    )
    parser.add_argument(
        "--detailed",
        action="store_true",
        help="Show detailed metrics for each result"
    )
    parser.add_argument(
        "--include_count",
        action="store_true",
        help="Include number of seeds"
    )
    parser.add_argument(
        "--output-csv",
        type=str,
        help="Save aggregated results to CSV file"
    )
    parser.add_argument(
        "--seeds",
        type=int,
        default=3,
        help="Number of seeds to collect results for"
    )

    args = parser.parse_args()
    if args.seeds < 1:
        # range(0) would silently look for no files at all.
        parser.error("--seeds must be at least 1")

    # Parse comma-separated arguments; drop empty entries from stray commas.
    models = [m.strip() for m in args.models.split(",") if m.strip()]
    evals = [e.strip() for e in args.evals.split(",") if e.strip()]

    print("Collecting results for:")
    print(f"  Models: {models}")
    print(f"  Evaluations: {evals}")
    print(f"  Tag: {args.tag}")
    print(f"  Output directory: {args.output_dir}")
    print(f"  Seeds: {args.seeds}")

    # Find and load results
    results = find_and_load_results(args.output_dir, models, evals, args.tag, args.seeds)
    print(f"\nSuccessfully loaded {len(results)} results")

    if not results:
        print("No valid results could be loaded.")
        return

    # Show detailed results if requested
    if args.detailed:
        print("\nDetailed results:")
        for result in results:
            print(f"\n{result['eval_name']}_{result['model_name']}:")
            print(f"  Score: {result['score']}")
            print(f"  File: {result['file_path']}")
            if 'all_metrics' in result:
                metrics = result['all_metrics']
                for key, value in metrics.items():
                    if key not in ['score', 'f1_score']:
                        print(f"  {key}: {value}")

    # Aggregate and display results
    aggregated = aggregate_results(results, include_count=args.include_count)
    if aggregated is None:
        return

    print("\nAggregated Results:")
    print(aggregated.to_string(float_format='%.4f'))

    # Save to CSV if requested
    if args.output_csv:
        aggregated.to_csv(args.output_csv)
        print(f"\nResults saved to {args.output_csv}")

    # Also emit the table as CSV on stdout; kept inside the None-guard so a
    # failed aggregation does not crash with AttributeError.
    print(aggregated.to_csv())


# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()
