#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
Visualize model capabilities from mix benchmark results
Create charts showing model abilities across different benchmarks
"""

import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

# Set style for plots.
# sns.set_style() must run first: seaborn's style presets include
# 'font.family', so calling it afterwards would silently replace the font
# list configured below with seaborn's generic ['sans-serif'].
sns.set_style("whitegrid")
plt.rcParams['font.family'] = ['Arial', 'DejaVu Sans', 'Liberation Sans']
# Render minus signs with the ASCII hyphen-minus instead of the Unicode minus glyph.
plt.rcParams['axes.unicode_minus'] = False

def load_capability_data(file_path="yourpath/comparison_results/model_capabilities_summary.csv"):
    """
    Load the model capabilities summary table from a CSV file.

    Args:
        file_path: Path to the summary CSV. The default is the script's
            original hard-coded location ('yourpath' appears to be a
            placeholder meant to be replaced).

    Returns:
        A pandas DataFrame with the CSV contents, or None when file_path is
        not an existing regular file (a message is printed in that case).
    """
    # isfile (not exists) so a directory path is reported instead of making
    # read_csv raise.
    if not os.path.isfile(file_path):
        print(f"File not found: {file_path}")
        return None
    return pd.read_csv(file_path)

def create_overall_ability_chart(cap_df, output_dir, top_n=15):
    """
    Create a bar chart of the models with the highest overall ability.

    Args:
        cap_df: DataFrame with a 'model' column and an
            'Overall Ability (psi)' column.
        output_dir: Directory the PNG is written to.
        top_n: Number of models to show (default 15).

    Side effects:
        Writes 'top_models_overall_ability.png' into output_dir and prints
        the saved path.
    """
    score_col = 'Overall Ability (psi)'
    # Rank explicitly instead of trusting the CSV's row order.
    top_models = cap_df.nlargest(top_n, score_col)
    scores = top_models[score_col].to_numpy()
    positions = np.arange(len(top_models))

    plt.figure(figsize=(14, 10))

    # Create bar chart
    bars = plt.bar(positions, scores,
                   color='red', alpha=0.7, edgecolor='black', linewidth=1)

    # Label each bar at its tip: above positive bars, below negative ones,
    # so labels never land inside a bar that extends below zero.
    for bar, value in zip(bars, scores):
        above = value >= 0
        plt.text(bar.get_x() + bar.get_width() / 2,
                 value + (0.01 if above else -0.01),
                 f'{value:.2f}', ha='center', va='bottom' if above else 'top',
                 fontsize=12)

    # Formatting
    plt.xlabel("Models", fontsize=16)
    plt.ylabel("Overall Ability (psi)", fontsize=16)
    plt.title(f"Top {len(top_models)} Models by Overall Ability", fontsize=20, pad=20)
    # Only the part of the model name before the first '_' is shown.
    plt.xticks(positions, [str(model).split('_')[0] for model in top_models['model']],
               rotation=45, ha='right', fontsize=12)
    plt.yticks(fontsize=14)
    plt.grid(True, alpha=0.3)

    # Reference line at y=0 (abilities may be negative)
    plt.axhline(y=0, color='black', linestyle='-', linewidth=0.5)

    # Save plot
    plt.tight_layout()
    output_path = os.path.join(output_dir, "top_models_overall_ability.png")
    plt.savefig(output_path, dpi=300, bbox_inches="tight")
    plt.close()
    print(f"Saved overall ability chart to: {output_path}")

def create_benchmark_comparison_chart(cap_df, output_dir, top_n=10):
    """
    Create a grouped bar chart comparing the top models across benchmarks.

    For each of the ``top_n`` models with the highest overall ability, four
    side-by-side bars are drawn: overall ability plus the CEVAL, CSQA and
    MMLU ability scores.

    Args:
        cap_df: DataFrame with a 'model' column and the four ability columns
            listed in ``benchmarks`` below.
        output_dir: Directory the PNG is written to.
        top_n: Number of models to include (default 10).

    Side effects:
        Writes 'benchmark_comparison.png' into output_dir and prints the path.
    """
    benchmarks = ['Overall Ability (psi)', 'CEVAL Ability (zeta)',
                  'CSQA Ability (zeta)', 'MMLU Ability (zeta)']
    benchmark_labels = ['Overall', 'CEVAL', 'CSQA', 'MMLU']
    colors = ['red', 'blue', 'green', 'purple']

    # Rank explicitly instead of trusting the CSV's row order.
    top_models = cap_df.nlargest(top_n, 'Overall Ability (psi)')

    plt.figure(figsize=(16, 10))

    # One group per model; bars within a group are offset by `width`.
    x_pos = np.arange(len(top_models))
    width = 0.2
    for i, (benchmark, label, color) in enumerate(zip(benchmarks, benchmark_labels, colors)):
        plt.bar(x_pos + i * width, top_models[benchmark].to_numpy(), width,
                label=label, color=color, alpha=0.8)

    # Offset that puts each tick under the middle of its group of bars.
    group_center = width * (len(benchmarks) - 1) / 2

    # Formatting
    plt.xlabel("Models", fontsize=16)
    plt.ylabel("Ability Score", fontsize=16)
    plt.title(f"Model Abilities Across Different Benchmarks (Top {len(top_models)} Models)",
              fontsize=20, pad=20)
    # Only the part of the model name before the first '_' is shown.
    plt.xticks(x_pos + group_center,
               [str(model).split('_')[0] for model in top_models['model']],
               rotation=45, ha='right', fontsize=12)
    plt.yticks(fontsize=14)
    plt.legend(loc='upper right', fontsize=14)
    plt.grid(True, alpha=0.3)

    # Reference line at y=0 (abilities may be negative)
    plt.axhline(y=0, color='black', linestyle='-', linewidth=0.5)

    # Save plot
    plt.tight_layout()
    output_path = os.path.join(output_dir, "benchmark_comparison.png")
    plt.savefig(output_path, dpi=300, bbox_inches="tight")
    plt.close()
    print(f"Saved benchmark comparison chart to: {output_path}")

def create_ability_correlation_heatmap(cap_df, output_dir):
    """
    Plot pairwise correlations between the ability measures as a heatmap.

    The correlation matrix comes from ``DataFrame.corr()`` (default settings)
    over the overall score and the CEVAL, CSQA and MMLU scores. The figure is
    written to 'ability_correlation_heatmap.png' in output_dir, and the saved
    path is printed.
    """
    columns = [
        'Overall Ability (psi)',
        'CEVAL Ability (zeta)',
        'CSQA Ability (zeta)',
        'MMLU Ability (zeta)',
    ]
    correlations = cap_df[columns].corr()

    # Annotated, diverging palette centred on zero correlation.
    heatmap_options = dict(
        annot=True,
        cmap='coolwarm',
        center=0,
        square=True,
        linewidths=0.5,
        cbar_kws={"shrink": 0.8},
    )

    plt.figure(figsize=(10, 8))
    sns.heatmap(correlations, **heatmap_options)

    plt.title("Correlation Between Different Ability Measures", fontsize=18, pad=20)
    plt.xticks(rotation=45, ha='right', fontsize=12)
    plt.yticks(rotation=0, fontsize=12)

    plt.tight_layout()
    destination = os.path.join(output_dir, "ability_correlation_heatmap.png")
    plt.savefig(destination, dpi=300, bbox_inches="tight")
    plt.close()
    print(f"Saved ability correlation heatmap to: {destination}")

def _md_cell(value):
    """Render a value as Markdown table-cell text, escaping '|' so it cannot split the row."""
    return str(value).replace('|', '\\|')


def _ability_range_table(models):
    """Return a Markdown table of model, ability range and overall ability for each row of *models*."""
    rows = [
        "| Model | Ability Range | Overall Ability |\n",
        "|-------|---------------|----------------|\n",
    ]
    for _, row in models.iterrows():
        rows.append(
            f"| {_md_cell(row['model'])} | {row['ability_range']:.4f} "
            f"| {row['Overall Ability (psi)']:.4f} |\n"
        )
    return "".join(rows)


def create_detailed_model_analysis(cap_df, output_dir):
    """
    Write a Markdown report summarizing model capabilities.

    The report contains:
      * the 10 models with the highest overall ability, with their
        per-benchmark scores;
      * the top 5 models on each of CEVAL, CSQA and MMLU;
      * the 5 models with the smallest and the 5 with the largest spread
        (max - min) across those three benchmark scores.

    Args:
        cap_df: DataFrame with a 'model' column plus 'Overall Ability (psi)',
            'CEVAL Ability (zeta)', 'CSQA Ability (zeta)' and
            'MMLU Ability (zeta)' columns. It is not modified.
        output_dir: Directory that 'detailed_model_analysis.md' is written to.

    Side effects:
        Writes the report file (UTF-8) and prints its path.
    """
    overall_col = 'Overall Ability (psi)'
    benchmarks = ['CEVAL Ability (zeta)', 'CSQA Ability (zeta)', 'MMLU Ability (zeta)']
    benchmark_names = ['CEVAL', 'CSQA', 'MMLU']

    # Collect fragments and join once rather than growing a string with +=.
    parts = ["# Detailed Model Capability Analysis\n\n## Top Performers by Overall Ability\n\n"]

    # Rank explicitly instead of trusting the CSV's row order.
    top_10 = cap_df.nlargest(10, overall_col)
    parts.append("| Rank | Model | Overall Ability | CEVAL Ability | CSQA Ability | MMLU Ability |\n")
    parts.append("|------|-------|----------------|---------------|--------------|--------------|\n")
    for rank, (_, row) in enumerate(top_10.iterrows(), 1):
        scores = " | ".join(f"{row[col]:.4f}" for col in [overall_col, *benchmarks])
        parts.append(f"| {rank} | {_md_cell(row['model'])} | {scores} |\n")

    parts.append("\n## Top Performers by Individual Benchmarks\n\n")
    for benchmark, name in zip(benchmarks, benchmark_names):
        parts.append(f"### Top 5 Models on {name}\n\n")
        parts.append("| Rank | Model | Ability Score |\n")
        parts.append("|------|-------|---------------|\n")
        for rank, (_, row) in enumerate(cap_df.nlargest(5, benchmark).iterrows(), 1):
            parts.append(f"| {rank} | {_md_cell(row['model'])} | {row[benchmark]:.4f} |\n")
        parts.append("\n")

    parts.append("## Performance Insights\n\n")

    # Spread between each model's best and worst benchmark score. assign()
    # returns a new frame, so the caller's DataFrame does not gain an
    # 'ability_range' column as a side effect.
    ranked = cap_df.assign(
        ability_range=cap_df[benchmarks].max(axis=1) - cap_df[benchmarks].min(axis=1)
    )

    parts.append("### Most Balanced Performers (Lowest Ability Range)\n\n")
    parts.append(_ability_range_table(ranked.nsmallest(5, 'ability_range')))
    parts.append("\n### Highest Variance Performers (Highest Ability Range)\n\n")
    parts.append(_ability_range_table(ranked.nlargest(5, 'ability_range')))

    # Save analysis; explicit UTF-8 so non-ASCII model names don't depend on
    # the platform's default encoding.
    analysis_file = os.path.join(output_dir, "detailed_model_analysis.md")
    with open(analysis_file, 'w', encoding='utf-8') as f:
        f.write("".join(parts))
    print(f"Saved detailed model analysis to: {analysis_file}")

def main():
    """
    Load the capabilities summary, then write every chart and the report.

    All artifacts go to one output directory, which is created if missing.
    If the summary CSV cannot be loaded, the function returns without
    producing any artifacts.
    """
    output_dir = "yourpath/comparison_results"
    os.makedirs(output_dir, exist_ok=True)

    print("Loading model capabilities data...")
    cap_df = load_capability_data()
    if cap_df is None:
        # load_capability_data has already printed why.
        return

    # (progress message, artifact builder) pairs, run in order.
    steps = (
        ("Creating overall ability chart...", create_overall_ability_chart),
        ("Creating benchmark comparison chart...", create_benchmark_comparison_chart),
        ("Creating ability correlation heatmap...", create_ability_correlation_heatmap),
        ("Creating detailed model analysis...", create_detailed_model_analysis),
    )
    for message, build in steps:
        print(message)
        build(cap_df, output_dir)

    print(f"\nAll model capabilities visualizations saved to: {output_dir}")

# Run the pipeline only when executed as a script, not when imported.
if __name__ == "__main__":
    main()