"""
Batch compute inter-annotator agreement for all rejudged CSV files vs originals.

Usage:
    python batch_iaa.py --eval_result_dir ./eval_result --output iaa_summary.csv

This script finds all files with a "_rejudged.csv" suffix, compares each one to
its original version, and computes average IAA metrics across all pairs. The
per-pair results and the per-judge-column summary are written to separate CSVs.
"""

import os
import glob
import pandas as pd
import numpy as np
from pathlib import Path
import logging

from inter_annotator_agreement import compute_agreement_metrics


def find_rejudged_pairs(eval_result_dir):
    """
    Find all rejudged CSV files and their corresponding originals.

    A rejudged file is any path under ``eval_result_dir`` (searched
    recursively) whose name ends with ``_rejudged.csv``. Its original is the
    same path with that trailing suffix replaced by ``.csv``.

    Args:
        eval_result_dir: Root directory to search for CSV files

    Returns:
        List of tuples: [(original_path, rejudged_path), ...], ordered by
        rejudged path. Rejudged files whose original does not exist are
        logged with a warning and left out.
    """
    suffix = "_rejudged.csv"
    pairs = []

    # Find all rejudged files. glob yields matches in arbitrary order, so sort
    # them to make the processing order (and the output CSVs) reproducible.
    rejudged_pattern = os.path.join(eval_result_dir, "**/*" + suffix)
    rejudged_files = sorted(glob.glob(rejudged_pattern, recursive=True))

    logging.info(f"Found {len(rejudged_files)} rejudged files")

    for rejudged_path in rejudged_files:
        # Swap only the trailing suffix. str.replace() would also rewrite any
        # directory component that happens to contain "_rejudged.csv".
        original_path = rejudged_path[:-len(suffix)] + ".csv"

        if os.path.exists(original_path):
            pairs.append((original_path, rejudged_path))
            logging.info(f"  Pair found: {os.path.basename(original_path)} <-> {os.path.basename(rejudged_path)}")
        else:
            logging.warning(f"  Original not found for: {rejudged_path}")

    return pairs


def _find_judge_columns(df1, df2):
    """Return columns shared by both frames that hold at least one numeric value in each."""
    non_judge_cols = {'question', 'answer', 'question_id'}
    judge_cols = []
    for col in df1.columns:
        if col in non_judge_cols or col not in df2.columns:
            continue
        try:
            # errors='coerce' turns non-numeric cells into NaN, so a column
            # counts as a judge column only if something survives coercion.
            is_numeric = (
                pd.to_numeric(df1[col], errors='coerce').notna().any()
                and pd.to_numeric(df2[col], errors='coerce').notna().any()
            )
        except Exception as e:
            # Best effort: a column that cannot be inspected is skipped, but
            # KeyboardInterrupt/SystemExit are no longer swallowed.
            logging.debug(f"  Skipping column {col!r}: {e}")
            continue
        if is_numeric:
            judge_cols.append(col)
    return judge_cols


def compute_iaa_for_pair(original_path, rejudged_path):
    """
    Compute IAA metrics for a single pair of files.

    Both CSVs are loaded, the judge columns are detected (columns other than
    'question', 'answer' and 'question_id' that exist in both files and
    contain at least one numeric value in each), and the columns are passed
    to ``compute_agreement_metrics``. Each per-column metrics dict is then
    tagged in place with 'original_file', 'rejudged_file' and 'file_pair'.

    Args:
        original_path: Path to original CSV
        rejudged_path: Path to rejudged CSV

    Returns:
        dict: IAA metrics keyed by judge column, or None if no judge columns
        were found, no metrics were produced, or any exception occurred
        (the exception is logged with its traceback, not re-raised).
    """
    try:
        logging.info("\nProcessing pair:")
        logging.info(f"  Original: {original_path}")
        logging.info(f"  Rejudged: {rejudged_path}")

        # Load CSV files
        df1 = pd.read_csv(original_path)
        df2 = pd.read_csv(rejudged_path)

        judge_cols = _find_judge_columns(df1, df2)
        if not judge_cols:
            logging.warning("  No common judge columns found, skipping")
            return None

        logging.info(f"  Judge columns: {', '.join(judge_cols)}")

        # Compute metrics; the second return value is not needed here.
        results, _ = compute_agreement_metrics(df1, df2, judge_cols)

        if not results:
            logging.warning("  Failed to compute metrics")
            return None

        original_name = os.path.basename(original_path)
        rejudged_name = os.path.basename(rejudged_path)
        # Drop only the trailing extension; replace() would also strip a
        # ".csv" that appears in the middle of the file name.
        if original_name.endswith('.csv'):
            file_pair = original_name[:-len('.csv')]
        else:
            file_pair = original_name

        # Add file information to results
        for metrics in results.values():
            metrics['original_file'] = original_name
            metrics['rejudged_file'] = rejudged_name
            metrics['file_pair'] = file_pair

        return results

    except Exception as e:
        # logging.exception records the traceback through the configured
        # logging handlers instead of printing it straight to stderr.
        logging.exception(f"  Error processing pair: {e}")
        return None


def aggregate_metrics(all_results):
    """
    Aggregate metrics across all file pairs.

    Args:
        all_results: List of dicts, each mapping a judge column name to that
            column's metrics dict for one file pair

    Returns:
        tuple: ``(detailed_df, summary_df)``. ``detailed_df`` has one row per
        (file pair, judge column) with a 'judge_column' column added.
        ``summary_df`` has one row per judge column with 'file_pair' set to
        'AVERAGE', 'n_file_pairs', and ``<metric>_mean`` / ``_std`` /
        ``_min`` / ``_max`` for every known metric present in the input.
        Returns ``(None, None)`` when there is nothing to aggregate, so
        callers can always unpack two values.
    """
    all_rows = []

    # Flatten to one row per (file pair, judge column)
    for results in all_results:
        for judge_col, metrics in results.items():
            all_rows.append({'judge_column': judge_col, **metrics})

    if not all_rows:
        return None, None

    df = pd.DataFrame(all_rows)

    known_metrics = [
        'pearson_r', 'spearman_r', 'icc', 'mae', 'rmse',
        'mean_difference', 'std_difference',
        'agreement_within_5', 'agreement_within_10', 'agreement_within_20',
        'cohen_kappa', 'n_samples',
    ]
    # Only summarize the metrics that were actually produced
    numeric_cols = [col for col in known_metrics if col in df.columns]

    # One summary row per judge column, in order of first appearance
    summary_rows = []
    for judge_col in df['judge_column'].unique():
        subset = df[df['judge_column'] == judge_col]

        avg_row = {
            'judge_column': judge_col,
            'file_pair': 'AVERAGE',
            'n_file_pairs': len(subset),
        }

        for col in numeric_cols:
            values = subset[col]
            avg_row[f'{col}_mean'] = values.mean()
            # Sample std: NaN when only one file pair is available
            avg_row[f'{col}_std'] = values.std()
            avg_row[f'{col}_min'] = values.min()
            avg_row[f'{col}_max'] = values.max()

        summary_rows.append(avg_row)

    summary_df = pd.DataFrame(summary_rows)

    return df, summary_df


def print_summary(summary_df):
    """
    Print a readable summary of the results.

    Args:
        summary_df: DataFrame with one row per judge column, carrying
            'judge_column', 'n_file_pairs' and ``<metric>_mean`` /
            ``<metric>_std`` columns. May be None or empty.

    Each metric line is printed only if its ``<metric>_mean`` column exists,
    and a section header is printed only if at least one of its metrics
    exists, so a partially populated summary does not raise KeyError. A
    missing ``<metric>_std`` is shown as 0.
    """
    if summary_df is None or len(summary_df) == 0:
        print("No results to summarize")
        return

    # (section header, [(line prefix, metric name, format spec, unit), ...])
    sections = [
        ("  Correlation Metrics:", [
            ("    Pearson r:     ", 'pearson_r', '.4f', ''),
            ("    Spearman ρ:    ", 'spearman_r', '.4f', ''),
            ("    ICC:           ", 'icc', '.4f', ''),
        ]),
        ("\n  Error Metrics:", [
            ("    MAE:           ", 'mae', '.4f', ''),
            ("    RMSE:          ", 'rmse', '.4f', ''),
        ]),
        ("\n  Agreement within threshold:", [
            ("    Within ±5:     ", 'agreement_within_5', '.2f', '%'),
            ("    Within ±10:    ", 'agreement_within_10', '.2f', '%'),
            ("    Within ±20:    ", 'agreement_within_20', '.2f', '%'),
        ]),
    ]

    print("\n" + "="*80)
    print("INTER-ANNOTATOR AGREEMENT SUMMARY")
    print("="*80)

    for _, row in summary_df.iterrows():
        judge_col = row['judge_column']
        n_pairs = row['n_file_pairs']

        print(f"\n{judge_col.upper()} (averaged over {n_pairs} file pairs):")
        print("-" * 70)

        for header, entries in sections:
            # `in` on a row (a Series) tests the column labels, not the values
            available = [entry for entry in entries if f'{entry[1]}_mean' in row]
            if not available:
                continue
            print(header)
            for prefix, metric, spec, unit in available:
                mean = row[f'{metric}_mean']
                std = row.get(f'{metric}_std', 0)
                print(f"{prefix}{mean:{spec}}{unit} ± {std:{spec}}{unit}")

        # Kappa is shown only when it was actually computed (non-NaN)
        if 'cohen_kappa_mean' in row and pd.notna(row['cohen_kappa_mean']):
            print(f"\n  Cohen's Kappa: {row['cohen_kappa_mean']:.4f} ± {row.get('cohen_kappa_std', 0):.4f}")


def main(eval_result_dir="./eval_result", output="iaa_summary.csv", detailed_output="iaa_detailed.csv"):
    """
    Run the whole batch: discover file pairs, score each one, aggregate, save.

    Args:
        eval_result_dir: Root directory containing eval result CSV files
        output: Path to save summary statistics CSV
        detailed_output: Path to save detailed results for all file pairs

    Returns:
        The summary DataFrame, or None when no pairs were found or no pair
        produced any metrics.
    """
    logging.basicConfig(
        format='%(asctime)s - %(levelname)s - %(message)s',
        level=logging.INFO,
    )

    logging.info(f"Searching for rejudged files in: {eval_result_dir}")

    # Step 1: discover (original, rejudged) pairs
    file_pairs = find_rejudged_pairs(eval_result_dir)
    if not file_pairs:
        logging.error("No rejudged/original pairs found!")
        return

    logging.info(f"\nFound {len(file_pairs)} file pairs to process\n")

    # Step 2: score every pair, keeping only the ones that yielded metrics
    per_pair_metrics = list(filter(
        None,
        (compute_iaa_for_pair(orig, rejudged) for orig, rejudged in file_pairs),
    ))
    if not per_pair_metrics:
        logging.error("No results computed!")
        return

    logging.info(f"\nSuccessfully computed metrics for {len(per_pair_metrics)} pairs")

    # Step 3: aggregate and persist
    detailed_df, summary_df = aggregate_metrics(per_pair_metrics)

    if detailed_df is not None:
        detailed_df.to_csv(detailed_output, index=False)
        logging.info(f"\nDetailed results saved to: {detailed_output}")

    if summary_df is None:
        return summary_df

    summary_df.to_csv(output, index=False)
    logging.info(f"Summary statistics saved to: {output}")

    # Echo the summary to the console as well
    print_summary(summary_df)

    return summary_df


if __name__ == "__main__":
    # Script entry point: hand main() to python-fire so it can be driven from
    # the command line (see the usage example in the module docstring).
    from fire import Fire

    Fire(main)


