#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
Calculate and display correlations between different ability measures
"""

import os
import pandas as pd
import numpy as np

def load_capability_data(
    file_path="/Users/bytedance/Desktop/QileZhang/llm/IRT/eval/IRT/mix_benchmark/comparison_results/model_capabilities_summary.csv",
):
    """
    Load the model capabilities summary table.

    Args:
        file_path: Path to the capabilities CSV. Defaults to the pipeline's
            usual location, so zero-argument callers keep working.

    Returns:
        A pandas DataFrame with the CSV contents, or None (after printing a
        message) when ``file_path`` is not an existing regular file.
    """
    # isfile rather than exists: a directory path would pass `exists` and then
    # make read_csv raise instead of taking the "not found" path.
    if not os.path.isfile(file_path):
        print(f"File not found: {file_path}")
        return None
    return pd.read_csv(file_path)

def calculate_correlations(cap_df):
    """
    Compute, print, and return the correlation matrix of the ability scores.

    The full matrix is printed first, then one line per unordered pair of
    ability columns (in column order), with the "(psi)"/"(zeta)" suffix
    dropped from the printed labels.

    Returns:
        The DataFrame produced by ``DataFrame.corr`` on the four ability columns.
    """
    columns = [
        'Overall Ability (psi)',
        'CEVAL Ability (zeta)',
        'CSQA Ability (zeta)',
        'MMLU Ability (zeta)',
    ]
    matrix = cap_df[columns].corr()

    print("Correlation Matrix:")
    print(matrix)
    print("\n")

    print("Specific Correlations:")
    # Walk the upper triangle so every pair is reported exactly once.
    for pos, left in enumerate(columns):
        left_label = left.rsplit(" (", 1)[0]
        for right in columns[pos + 1:]:
            right_label = right.rsplit(" (", 1)[0]
            print(f"{left_label} vs {right_label}: {matrix.loc[left, right]:.4f}")

    return matrix

# Ability columns included in the report, in table order; the first entry is
# the overall score that every benchmark-specific ability is compared against.
_REPORT_ABILITY_COLUMNS = [
    'Overall Ability (psi)',
    'CEVAL Ability (zeta)',
    'CSQA Ability (zeta)',
    'MMLU Ability (zeta)',
]


def _format_correlation_table(corr_matrix, columns):
    """Render the rows of *corr_matrix*, restricted to *columns*, as a markdown table."""
    lines = [
        "|  | " + " | ".join(columns) + " |",
        "|" + "---|" * (len(columns) + 1),
    ]
    for index in corr_matrix.index:
        row = corr_matrix.loc[index]
        cells = " | ".join(f"{row[col]:.4f}" for col in columns)
        lines.append(f"| {index} | {cells} |")
    return "\n".join(lines) + "\n"


def _interpret_correlation(value, benchmark):
    """Return the indented markdown bullet describing the strength and sign of *value*."""
    if pd.isna(value):
        # corr() yields NaN when a column is constant or has too few valid values.
        return (f"   - Correlation undefined (constant or insufficient data). "
                f"The relationship between overall performance and {benchmark} "
                f"performance cannot be assessed.\n")
    # Grade by magnitude so strong negative correlations are not labelled weak.
    strength = abs(value)
    if strength > 0.7:
        if value > 0:
            return (f"   - Strong positive correlation. Models that perform well overall "
                    f"also tend to perform well on {benchmark}.\n")
        return (f"   - Strong negative correlation. Models that perform well overall "
                f"tend to perform worse on {benchmark}.\n")
    if strength > 0.3:
        if value > 0:
            return (f"   - Moderate positive correlation. There's a noticeable relationship "
                    f"between overall performance and {benchmark} performance.\n")
        return (f"   - Moderate negative correlation. There's a noticeable inverse relationship "
                f"between overall performance and {benchmark} performance.\n")
    return (f"   - Weak correlation. Overall performance and {benchmark} performance "
            f"are not strongly related.\n")


def create_detailed_correlation_report(corr_matrix, cap_df, output_dir):
    """
    Write a markdown report describing the ability correlations.

    Args:
        corr_matrix: Correlation matrix containing (at least) the columns in
            ``_REPORT_ABILITY_COLUMNS``, e.g. from ``calculate_correlations``.
        cap_df: The capabilities DataFrame the matrix was computed from; used
            for the row count and per-column mean/std.
        output_dir: Existing directory that receives
            ``correlation_analysis_report.md`` (overwritten if present).

    Returns:
        The path of the written report file.
    """
    overall_col = _REPORT_ABILITY_COLUMNS[0]

    parts = ["# Detailed Correlation Analysis Report\n\n## Correlation Matrix\n\n"]
    parts.append(_format_correlation_table(corr_matrix, _REPORT_ABILITY_COLUMNS))

    parts.append("\n## Interpretation of Correlations\n\n")
    for number, column in enumerate(_REPORT_ABILITY_COLUMNS[1:], start=1):
        benchmark = column.split(" ", 1)[0]  # 'CEVAL Ability (zeta)' -> 'CEVAL'
        value = corr_matrix.loc[overall_col, column]
        parts.append(
            f"{number}. **Overall Ability and {benchmark} Ability**: Correlation = {value:.4f}\n"
        )
        parts.append(_interpret_correlation(value, benchmark))

    # List items, so markdown renders one statistic per line instead of
    # merging consecutive lines into a single paragraph.
    parts.append("\n## Data Statistics\n\n")
    parts.append(f"- Number of models analyzed: {len(cap_df)}\n")
    for column in _REPORT_ABILITY_COLUMNS:
        label = column.rsplit(" (", 1)[0]
        parts.append(
            f"- {label} - Mean: {cap_df[column].mean():.4f}, Std: {cap_df[column].std():.4f}\n"
        )

    report_file = os.path.join(output_dir, "correlation_analysis_report.md")
    with open(report_file, 'w', encoding='utf-8') as f:
        f.write("".join(parts))
    print(f"Saved correlation analysis report to: {report_file}")
    return report_file

def main():
    """
    Run the correlation analysis end to end.

    Loads the capabilities table, prints the ability correlations, and writes
    the markdown report into the comparison-results directory (created if
    missing). Stops quietly after the load step if the CSV cannot be found.
    """
    # Destination for the generated markdown report.
    results_dir = "/Users/bytedance/Desktop/QileZhang/llm/IRT/eval/IRT/mix_benchmark/comparison_results"
    os.makedirs(results_dir, exist_ok=True)

    print("Loading model capabilities data...")
    capabilities = load_capability_data()

    # load_capability_data already reported the problem when it returns None.
    if capabilities is not None:
        print("Calculating correlations between abilities...")
        correlations = calculate_correlations(capabilities)

        print("Creating detailed correlation analysis report...")
        create_detailed_correlation_report(correlations, capabilities, results_dir)

        print(f"\nCorrelation analysis completed. Results saved to: {results_dir}")


if __name__ == "__main__":
    main()