"""
Compute inter-annotator agreement between two judge outputs.

Usage:
    python inter_annotator_agreement.py <csv_file_1> <csv_file_2> [--output <metrics_csv>] [--save_merged True]
    
Examples:
    python inter_annotator_agreement.py results_judge1.csv results_judge2.csv
    python inter_annotator_agreement.py results_gpt4o.csv results_gpt4o_mini.csv --output agreement_stats.csv
    python inter_annotator_agreement.py results1.csv results2.csv --output metrics.csv --save_merged True

This script computes various inter-annotator agreement metrics between two CSV files
containing judge scores for the same questions.

Metrics computed:
- Pearson correlation (linear relationship)
- Spearman correlation (monotonic relationship)
- Intraclass Correlation Coefficient (ICC)
- Mean Absolute Error (MAE)
- Root Mean Squared Error (RMSE)
- Cohen's Kappa (for categorized scores)
- Percentage agreement within thresholds (±5, ±10, ±20 points)
"""

import pandas as pd
import numpy as np
from scipy import stats
from sklearn.metrics import cohen_kappa_score


def compute_icc(ratings1, ratings2):
    """
    Compute the Intraclass Correlation Coefficient (ICC) for two raters.

    The value is derived from a one-way decomposition of the ratings:
    the between-subject mean square (MSB, n - 1 degrees of freedom) and the
    within-subject mean square (MSW, n * (k - 1) degrees of freedom, with
    k = 2 raters) are combined as (MSB - MSW) / (MSB + (k - 1) * MSW).
    Systematic offsets between the raters therefore lower the result, i.e.
    it measures absolute agreement rather than mere consistency.

    Args:
        ratings1: First set of ratings (one value per subject)
        ratings2: Second set of ratings, same length and order as ratings1

    Returns:
        float: ICC value. 0.0 when both mean squares are zero (every rating
        is identical), NaN when fewer than two subjects are supplied because
        the between-subject variance is undefined in that case.
    """
    # Rows are subjects, columns are raters.
    ratings = np.column_stack([ratings1, ratings2])
    n = len(ratings1)
    k = ratings.shape[1]

    # With n < 2 the between-subject degrees of freedom (n - 1) are not
    # positive; report "undefined" explicitly instead of dividing by zero.
    if n < 2:
        return float("nan")

    subject_means = np.mean(ratings, axis=1)
    grand_mean = np.mean(ratings)

    # Sum-of-squares decomposition: total = between subjects + within subjects
    ss_total = np.sum((ratings - grand_mean) ** 2)
    ss_rows = k * np.sum((subject_means - grand_mean) ** 2)
    ss_error = ss_total - ss_rows

    # Mean squares
    ms_rows = ss_rows / (n - 1)
    ms_error = ss_error / (n * (k - 1))

    denominator = ms_rows + (k - 1) * ms_error
    if denominator == 0:
        # No variance at all (all ratings equal): ICC is reported as 0.0
        return 0.0

    return float((ms_rows - ms_error) / denominator)


def categorize_scores(scores, bins=(0, 30, 70, 100), labels=("low", "medium", "high")):
    """
    Categorize continuous scores into labelled bins for Cohen's Kappa.

    Binning is delegated to ``pd.cut`` with ``include_lowest=True``, so with
    the defaults the categories are:
        [0, 30]   -> 'low'
        (30, 70]  -> 'medium'
        (70, 100] -> 'high'
    Scores outside the outermost bin edges are returned as NaN.

    Args:
        scores: 1-dimensional array-like of numeric scores
        bins: Bin edges passed to ``pd.cut``
        labels: One label per interval; must contain ``len(bins) - 1`` items
            when ``bins`` is a sequence of edges

    Returns:
        The categorical result of ``pd.cut`` (same length as ``scores``)
    """
    # Immutable tuples are used as defaults; hand pandas plain lists.
    if isinstance(bins, tuple):
        bins = list(bins)
    if isinstance(labels, tuple):
        labels = list(labels)
    return pd.cut(scores, bins=bins, labels=labels, include_lowest=True)


def compute_agreement_metrics(df1, df2, judge_columns):
    """
    Compute inter-annotator agreement metrics for judge columns.

    The two DataFrames are inner-joined on 'question' and 'question_id'
    (plus 'answer' when both frames have it). Each judge column is then
    compared between the two sides of the join, using only the rows where
    both scores could be converted to numbers.

    Args:
        df1: First DataFrame with judge scores
        df2: Second DataFrame with judge scores
        judge_columns: List of column names containing judge scores

    Returns:
        tuple: ``(results, merged)``. ``results`` maps each judge column to
        a dict of metrics; it is empty when no rows matched or no column had
        valid scores. ``merged`` is the joined DataFrame, in which columns
        present in both inputs carry the suffixes '_judge1' / '_judge2'.
    """
    results = {}

    # Join keys: always question + question_id, and 'answer' when available,
    # so that we compare the exact same question-answer pairs.
    merge_cols = ['question', 'question_id']
    if 'answer' in df1.columns and 'answer' in df2.columns:
        merge_cols.append('answer')

    merged = pd.merge(
        df1,
        df2,
        on=merge_cols,
        suffixes=('_judge1', '_judge2'),
        how='inner'
    )

    print(f"\nMatched {len(merged)} question-answer pairs between the two files")
    print(f"Judge 1 had {len(df1)} rows, Judge 2 had {len(df2)} rows")
    print(f"Merged on columns: {', '.join(merge_cols)}")

    # Sanity check: equally sized inputs should normally join 1-to-1
    if len(df1) == len(df2) and len(merged) != len(df1):
        print(f"\nWARNING: Input files have the same number of rows ({len(df1)}),")
        print(f"         but merge resulted in {len(merged)} rows.")
        print(f"         This suggests duplicate question-answer pairs exist.")
        if 'answer' not in merge_cols:
            print(f"         Consider ensuring 'answer' column is present for 1-to-1 matching.")

    if len(merged) == 0:
        print("ERROR: No matching rows found between the two files!")
        # Same return shape as the success path so callers can always
        # unpack two values.
        return results, merged

    for col in judge_columns:
        col1 = f"{col}_judge1"
        col2 = f"{col}_judge2"

        if col1 not in merged.columns or col2 not in merged.columns:
            print(f"Warning: Column '{col}' not found in both files, skipping")
            continue

        # Convert to numeric, coercing errors to NaN
        merged[col1] = pd.to_numeric(merged[col1], errors='coerce')
        merged[col2] = pd.to_numeric(merged[col2], errors='coerce')

        # Keep only the rows where both judges produced a usable score
        scores_df = merged[[col1, col2]].dropna()

        if len(scores_df) == 0:
            print(f"Warning: No valid scores for column '{col}', skipping")
            continue

        # Explicitly convert to float64 numpy arrays
        scores1 = scores_df[col1].astype(np.float64).values
        scores2 = scores_df[col2].astype(np.float64).values
        n_samples = len(scores1)
        diffs = scores1 - scores2
        abs_diffs = np.abs(diffs)

        metrics = {}

        # 1-2. Pearson (linear) and Spearman (monotonic) correlation.
        # Correlation needs at least two pairs (stats.pearsonr rejects shorter
        # input), so report NaN instead of aborting the whole analysis.
        if n_samples >= 2:
            pearson_r, pearson_p = stats.pearsonr(scores1, scores2)
            spearman_r, spearman_p = stats.spearmanr(scores1, scores2)
        else:
            print(f"Warning: Fewer than 2 valid score pairs for column '{col}', correlations are undefined")
            pearson_r = pearson_p = spearman_r = spearman_p = float('nan')
        metrics['pearson_r'] = pearson_r
        metrics['pearson_p'] = pearson_p
        metrics['spearman_r'] = spearman_r
        metrics['spearman_p'] = spearman_p

        # 3. Intraclass Correlation Coefficient
        metrics['icc'] = compute_icc(scores1, scores2)

        # 4. Mean Absolute Error
        metrics['mae'] = np.mean(abs_diffs)

        # 5. Root Mean Squared Error
        metrics['rmse'] = np.sqrt(np.mean(diffs ** 2))

        # 6. Mean difference (bias of judge 1 relative to judge 2)
        metrics['mean_difference'] = np.mean(diffs)

        # 7. Standard deviation of differences
        metrics['std_difference'] = np.std(diffs)

        # 8. Percentage agreement within thresholds
        metrics['agreement_within_5'] = np.mean(abs_diffs <= 5) * 100
        metrics['agreement_within_10'] = np.mean(abs_diffs <= 10) * 100
        metrics['agreement_within_20'] = np.mean(abs_diffs <= 20) * 100

        # 9. Cohen's Kappa (categorized scores). Best effort: a failure here
        # must not discard the other metrics for this column.
        try:
            cat1 = categorize_scores(scores1)
            cat2 = categorize_scores(scores2)
            metrics['cohen_kappa'] = cohen_kappa_score(cat1, cat2)
        except Exception as e:
            print(f"Warning: Could not compute Cohen's Kappa for '{col}': {e}")
            metrics['cohen_kappa'] = None

        # 10. Sample size
        metrics['n_samples'] = n_samples

        results[col] = metrics

    return results, merged


def _interpret_kappa(kappa):
    """Return the qualitative label for a Cohen's Kappa value."""
    # NaN compares False against every threshold and would otherwise fall
    # through to the top label, so it is handled explicitly.
    if np.isnan(kappa):
        return "Undefined (kappa is NaN)"
    if kappa < 0:
        return "Poor (less than chance)"
    for upper_bound, label in (
        (0.20, "Slight"),
        (0.40, "Fair"),
        (0.60, "Moderate"),
        (0.80, "Substantial"),
    ):
        if kappa < upper_bound:
            return label
    return "Almost perfect"


def print_metrics(results):
    """
    Print the agreement metrics in a readable format.

    Args:
        results: Dict mapping each judge column name to its metrics dict
            (keys such as 'n_samples', 'pearson_r', 'mae', 'cohen_kappa', ...)
    """
    for col, metrics in results.items():
        print(f"\n{'='*70}")
        print(f"Agreement Metrics for: {col}")
        print(f"{'='*70}")
        print(f"  Sample Size:              {metrics['n_samples']}")
        print(f"\n  Correlation Metrics:")
        print(f"    Pearson r:              {metrics['pearson_r']:.4f} (p={metrics['pearson_p']:.4e})")
        print(f"    Spearman ρ:             {metrics['spearman_r']:.4f} (p={metrics['spearman_p']:.4e})")
        print(f"    ICC:                    {metrics['icc']:.4f}")
        print(f"\n  Error Metrics:")
        print(f"    Mean Absolute Error:    {metrics['mae']:.4f}")
        print(f"    Root Mean Squared Error:{metrics['rmse']:.4f}")
        print(f"    Mean Difference (bias): {metrics['mean_difference']:.4f}")
        print(f"    Std of Differences:     {metrics['std_difference']:.4f}")
        print(f"\n  Agreement within threshold:")
        print(f"    Within ±5 points:       {metrics['agreement_within_5']:.2f}%")
        print(f"    Within ±10 points:      {metrics['agreement_within_10']:.2f}%")
        print(f"    Within ±20 points:      {metrics['agreement_within_20']:.2f}%")

        # Kappa is None when it could not be computed; skip the section then.
        kappa = metrics.get('cohen_kappa')
        if kappa is not None:
            print(f"\n  Cohen's Kappa (categorized):{kappa:.4f}")
            print(f"    Interpretation:         {_interpret_kappa(kappa)}")

        # Sanity flags: highlight (near-)perfect agreement and suspiciously
        # low agreement, which often indicates mismatched input files.
        if metrics['mae'] < 0.001 and metrics['pearson_r'] > 0.999:
            print(f"\n  ✓ PERFECT AGREEMENT (likely same file or identical judges)")
        elif metrics['mae'] > 5 and metrics['pearson_r'] < 0.7:
            print(f"\n  ⚠ LOW AGREEMENT - verify files contain matching question-answer pairs")


def save_metrics_to_csv(results, output_file):
    """
    Write the agreement metrics to a CSV file, one row per judge column.

    Args:
        results: Dict mapping each judge column name to its metrics dict
        output_file: Destination path for the CSV
    """
    # The judge column name leads each record, followed by its metrics.
    records = [
        {'judge_column': judge_name, **judge_metrics}
        for judge_name, judge_metrics in results.items()
    ]
    pd.DataFrame(records).to_csv(output_file, index=False)
    print(f"\nMetrics saved to: {output_file}")


def analyze_agreement(file1, file2, output=None, save_merged=False):
    """
    Analyze inter-annotator agreement between two CSV files.

    Both files must contain 'question' and 'question_id' columns. Every other
    column (except 'answer') that exists in both files and holds at least one
    numeric value in each is treated as a judge score column.

    Args:
        file1: Path to first CSV file
        file2: Path to second CSV file
        output: Optional path to save metrics CSV
        save_merged: If True, save the merged dataframe. It is written next
            to ``output`` with '_merged' inserted before the extension, or to
            'merged_judges.csv' when ``output`` is not given.

    Returns:
        dict or None: Metrics per judge column, or None when validation
        fails, nothing could be computed, or an error occurred (the error
        and traceback are printed).
    """
    try:
        # Load CSV files
        print(f"Loading {file1}...")
        df1 = pd.read_csv(file1)
        print(f"Loading {file2}...")
        df2 = pd.read_csv(file2)

        # Check required columns
        for col in ('question', 'question_id'):
            if col not in df1.columns or col not in df2.columns:
                print(f"ERROR: Both CSV files must contain '{col}' column")
                return None

        # Check if 'answer' column exists (recommended for accurate matching)
        if 'answer' not in df1.columns or 'answer' not in df2.columns:
            print("WARNING: 'answer' column not found in both files.")
            print("         Matching only on question and question_id may cause issues")
            print("         if the same question appears multiple times with different answers.")

        # Find judge columns: shared, non-key columns with numeric content
        judge_cols = []
        for col in df1.columns:
            if col in ('question', 'answer', 'question_id') or col not in df2.columns:
                continue
            try:
                numeric1 = pd.to_numeric(df1[col], errors='coerce')
                numeric2 = pd.to_numeric(df2[col], errors='coerce')
            except (TypeError, ValueError):
                # Column cannot be interpreted as numbers at all: not a judge
                continue
            # Require at least some valid numeric values on both sides
            if numeric1.notna().sum() > 0 and numeric2.notna().sum() > 0:
                judge_cols.append(col)

        if not judge_cols:
            print("ERROR: No common numeric judge columns found")
            return None

        print(f"\nFound judge columns: {', '.join(judge_cols)}")

        # Compute agreement metrics
        results, merged = compute_agreement_metrics(df1, df2, judge_cols)

        if not results:
            return None

        # Print results
        print_metrics(results)

        # Save to CSV if requested
        if output:
            save_metrics_to_csv(results, output)

        # Save merged dataframe if requested
        if save_merged:
            if output:
                import os
                # Derive the name from the extension only, so an output path
                # without '.csv' can never resolve to the metrics file itself.
                root, ext = os.path.splitext(output)
                merged_output = f"{root}_merged{ext or '.csv'}"
            else:
                merged_output = 'merged_judges.csv'
            merged.to_csv(merged_output, index=False)
            print(f"Merged dataframe saved to: {merged_output}")

        return results

    except Exception as e:
        # Top-level CLI boundary: report the failure instead of crashing
        print(f"ERROR: {e}")
        import traceback
        traceback.print_exc()
        return None


if __name__ == "__main__":
    # Expose analyze_agreement as the command-line entry point; fire is
    # imported here so it is only needed when the file runs as a script.
    from fire import Fire
    Fire(analyze_agreement)

