#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
Summarize model capabilities from the mixed-benchmark IRT results.
Write a CSV with models as rows and capabilities as columns, plus a Markdown report.
"""

import os
import pandas as pd
import numpy as np

def load_model_parameters(base_path="yourpath/result_improved_mixed_benchmark/05_model_parameters"):
    """
    Load the per-model ability estimates produced by the mixed-benchmark run.

    Args:
        base_path: Directory containing ``theta_hat_psi.csv`` and the
            per-benchmark ``theta_hat_zeta_CEVAL.csv``,
            ``theta_hat_zeta_CSQA.csv`` and ``theta_hat_zeta_MMLU.csv`` files.
            Defaults to the original hard-coded location, so existing
            no-argument callers behave exactly as before.

    Returns:
        Tuple ``(psi_df, ceval_df, csqa_df, mmlu_df)`` of DataFrames exactly
        as returned by ``pandas.read_csv``.
    """
    # Overall ability (psi) estimates.
    psi_df = pd.read_csv(os.path.join(base_path, "theta_hat_psi.csv"))
    # Benchmark-specific (zeta) estimates.
    ceval_df = pd.read_csv(os.path.join(base_path, "theta_hat_zeta_CEVAL.csv"))
    csqa_df = pd.read_csv(os.path.join(base_path, "theta_hat_zeta_CSQA.csv"))
    mmlu_df = pd.read_csv(os.path.join(base_path, "theta_hat_zeta_MMLU.csv"))

    return psi_df, ceval_df, csqa_df, mmlu_df

def create_capability_summary(psi_df, ceval_df, csqa_df, mmlu_df, output_dir):
    """
    Create a summary CSV with models as rows and capabilities as columns.

    Every input frame must have ``model`` and ``theta_hat`` columns. The rows
    are the sorted union of models across all four frames. A model missing
    from a frame gets NaN in that capability column. If a model appears more
    than once in a frame, its first row is used.

    Args:
        psi_df: Overall-ability (psi) estimates.
        ceval_df: CEVAL-specific (zeta) estimates.
        csqa_df: CSQA-specific (zeta) estimates.
        mmlu_df: MMLU-specific (zeta) estimates.
        output_dir: Existing directory that receives
            ``model_capabilities_summary.csv``.

    Returns:
        The summary DataFrame sorted by overall ability, descending. Models
        without a psi value come last (pandas' default ``na_position``).
    """
    # Output column name -> source frame; dict order fixes the CSV column order.
    sources = {
        'Overall Ability (psi)': psi_df,
        'CEVAL Ability (zeta)': ceval_df,
        'CSQA Ability (zeta)': csqa_df,
        'MMLU Ability (zeta)': mmlu_df,
    }

    # Sorted union of model names across every source.
    all_models = sorted(set().union(*(df['model'] for df in sources.values())))
    # Building from a column dict (not row dicts) keeps the schema even when
    # there are no models, so the sort below never hits a missing column.
    summary_df = pd.DataFrame({'model': all_models})

    for column, df in sources.items():
        # One model -> theta_hat lookup per source instead of re-scanning the
        # frame for every model. keep='first' mirrors taking the first match.
        lookup = df.drop_duplicates('model', keep='first').set_index('model')['theta_hat']
        summary_df[column] = summary_df['model'].map(lookup)

    # Sort by overall ability
    summary_df = summary_df.sort_values('Overall Ability (psi)', ascending=False)

    # Save to CSV
    output_file = os.path.join(output_dir, "model_capabilities_summary.csv")
    summary_df.to_csv(output_file, index=False)
    print(f"Saved model capabilities summary to: {output_file}")

    return summary_df

def _top_model(summary_df, column):
    """
    Return the model with the highest value in ``column``, or 'N/A'.

    NaN entries are ignored. 'N/A' is returned when the column has no
    non-missing value (including an empty frame), instead of letting
    ``idxmax`` fail or ``.loc`` raise on a NaN label.
    """
    scores = summary_df[column].dropna()
    if scores.empty:
        return 'N/A'
    return summary_df.loc[scores.idxmax(), 'model']


def create_detailed_capability_report(summary_df, output_dir):
    """
    Create a detailed Markdown report with analysis.

    Args:
        summary_df: DataFrame with a ``model`` column plus the columns
            'Overall Ability (psi)', 'CEVAL Ability (zeta)',
            'CSQA Ability (zeta)' and 'MMLU Ability (zeta)'. Rows are listed
            in their existing order. The ranking section assumes that order
            is descending by overall ability, as produced by
            ``create_capability_summary``.
        output_dir: Existing directory that receives
            ``model_capabilities_report.md`` (written as UTF-8).
    """
    capability_columns = [
        'Overall Ability (psi)',
        'CEVAL Ability (zeta)',
        'CSQA Ability (zeta)',
        'MMLU Ability (zeta)',
    ]

    # Collect fragments and join once at the end rather than repeated `+=`.
    parts = ["""# Model Capabilities Summary Report

## Overview

This report summarizes the capabilities of different models based on the improved mixed benchmark IRT analysis. Each model's ability is measured across different benchmarks (CEVAL, CSQA, MMLU) as well as an overall ability score.

## Capability Definitions

- **Overall Ability (psi)**: General model capability across all benchmarks
- **CEVAL Ability (zeta)**: Model capability specifically on CEVAL benchmark
- **CSQA Ability (zeta)**: Model capability specifically on CSQA benchmark
- **MMLU Ability (zeta)**: Model capability specifically on MMLU benchmark

## Model Rankings by Overall Ability

"""]

    # Add model rankings (1-based, in the frame's current row order)
    for rank, (_, row) in enumerate(summary_df.iterrows(), 1):
        parts.append(f"{rank}. **{row['model']}**: Overall Ability = {row['Overall Ability (psi)']:.4f}\n")

    parts.append("\n## Detailed Capability Scores\n\n")

    # Add detailed table; missing scores render as "nan".
    parts.append("| Model | Overall Ability (psi) | CEVAL Ability (zeta) | CSQA Ability (zeta) | MMLU Ability (zeta) |\n")
    parts.append("|-------|----------------------|---------------------|--------------------|--------------------|\n")
    for _, row in summary_df.iterrows():
        scores = " | ".join(f"{row[column]:.4f}" for column in capability_columns)
        parts.append(f"| {row['model']} | {scores} |\n")

    parts.append("\n## Key Insights\n\n")

    # Find top models in each category; NaN-only/empty columns yield 'N/A'.
    insight_labels = ['Overall', 'CEVAL', 'CSQA', 'MMLU']
    for label, column in zip(insight_labels, capability_columns):
        parts.append(f"- **Top {label} Model**: {_top_model(summary_df, column)}\n")

    # Save report; explicit UTF-8 so non-ASCII model names work on any locale.
    report_file = os.path.join(output_dir, "model_capabilities_report.md")
    with open(report_file, 'w', encoding='utf-8') as f:
        f.write("".join(parts))
    print(f"Saved model capabilities report to: {report_file}")

def main():
    """
    Run the pipeline end to end.

    Loads the psi/zeta parameter files, builds and prints the capability
    summary, and writes the Markdown report. Outputs go under a fixed
    results directory, which is created if missing.
    """
    results_dir = "yourpath/comparison_results"
    os.makedirs(results_dir, exist_ok=True)

    print("Loading model parameters...")
    overall, ceval, csqa, mmlu = load_model_parameters()

    print("Creating model capabilities summary...")
    summary = create_capability_summary(overall, ceval, csqa, mmlu, results_dir)

    # Echo the whole table to stdout before writing the report.
    print("\nModel Capabilities Summary:")
    print(summary.to_string(index=False))

    print("Creating detailed capability report...")
    create_detailed_capability_report(summary, results_dir)

    print(f"\nAll model capabilities results saved to: {results_dir}")

if __name__ == "__main__":
    # Run the pipeline only when executed as a script, not on import.
    main()