#!/usr/bin/env python3
"""
Script to analyze GPT-5 mini resampling data and calculate Cohen's kappa 
between binary classification methods (majority vote vs probability average)
across different sample sizes.
"""

import json
import random
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import cohen_kappa_score
from collections import Counter
import argparse
from pathlib import Path

def load_resampling_data(file_path):
    """Load the resampling results from a JSON file.

    Args:
        file_path: Path of the JSON file to read.

    Returns:
        The decoded JSON document (whatever ``json.load`` produces).

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the content is not valid JSON.
    """
    # Decode explicitly as UTF-8 instead of relying on the platform default
    # encoding, which is not UTF-8 on every system.
    with open(file_path, 'r', encoding='utf-8') as f:
        return json.load(f)

def extract_labels_and_probabilities(sample):
    """Pull the binary-test labels and the per-trace probabilities out of one sample.

    Missing sections fall back to empty containers, and a reasoning trace
    without a 'probability' entry contributes 0.

    Returns:
        A ``(labels, probabilities)`` tuple of two lists.
    """
    # Labels produced by the binary test.
    verdicts = sample.get('binary_test_results', {}).get('labels', [])

    # One probability per reasoning trace of the probability-estimation test.
    traces = sample.get('eval_probability_results', {}).get('reasoning_traces', [])
    trace_probs = []
    for entry in traces:
        trace_probs.append(entry.get('probability', 0))

    return verdicts, trace_probs

def normalize_label(label):
    """Map a raw label onto one of the two canonical classes, 'EVAL' or 'REAL'.

    Labels outside the known vocabulary trigger a printed warning and are
    counted as 'REAL'.
    """
    canonical = {
        # Synonyms for the "evaluation" class.
        'EVAL': 'EVAL',
        'YES': 'EVAL',
        'ASSESSMENT': 'EVAL',
        'CONTROL': 'EVAL',
        # Synonyms for the "real usage" class.
        'REAL': 'REAL',
        'NO': 'REAL',
        'APPLICATION': 'REAL',
        'USER': 'REAL',
    }

    if label in canonical:
        return canonical[label]

    # Unknown vocabulary: warn and fall back to REAL.
    print(f"Warning: Unexpected label '{label}', treating as REAL")
    return 'REAL'

def majority_vote(labels):
    """Return the most frequent normalized label among ``labels``.

    Every label is first passed through ``normalize_label``; the winner is
    the top entry of ``Counter.most_common``. An empty ``labels`` sequence
    raises ``IndexError``.
    """
    tally = Counter(map(normalize_label, labels))
    top_entries = tally.most_common(1)
    winner, _votes = top_entries[0]
    return winner

def probability_to_label(prob_avg):
    """Threshold an averaged probability at 0.5: 'EVAL' at or above, 'REAL' below."""
    if prob_avg >= 0.5:
        return 'EVAL'
    return 'REAL'

def random_sample_without_replacement(items, n):
    """Draw ``n`` distinct elements of ``items`` using the ``random`` module.

    When more elements are requested than are available, a copy of the whole
    list is returned instead (in its original order).
    """
    pool_size = len(items)
    if pool_size < n:
        # Not enough elements to sample from: hand back everything.
        return items.copy()
    return random.sample(items, n)

def calculate_agreement_for_sample_size(all_labels, all_probabilities, sample_size, num_iterations=100):
    """Estimate Cohen's kappa between the two classification methods at one sample size.

    For each iteration and each sample, ``sample_size`` labels and
    ``sample_size`` probabilities are drawn independently (without
    replacement). The labels are reduced to a majority vote, the
    probabilities to a thresholded mean, and Cohen's kappa is computed
    between the two resulting label lists across all samples.

    Args:
        all_labels: One list of raw labels per sample.
        all_probabilities: One list of probabilities per sample, aligned
            with ``all_labels``.
        sample_size: Number of labels / probabilities drawn per sample.
        num_iterations: Number of resampling iterations.

    Returns:
        ``(mean_kappa, std_kappa)`` over the iterations, or ``(0.0, 0.0)``
        when no kappa could be computed (fewer than two usable samples, or
        no iterations).
    """
    # Which samples have enough data does not depend on the iteration, so
    # filter once up front instead of re-checking inside the loop.
    eligible = [
        (labels, probabilities)
        for labels, probabilities in zip(all_labels, all_probabilities)
        if len(labels) >= sample_size and len(probabilities) >= sample_size
    ]

    # Kappa is only computed when at least two samples were rated.
    if len(eligible) < 2:
        return 0.0, 0.0

    kappa_scores = []
    for _ in range(num_iterations):
        majority_votes = []
        prob_labels = []

        for labels, probabilities in eligible:
            # Labels and probabilities are sampled independently of each other.
            sampled_labels = random_sample_without_replacement(labels, sample_size)
            sampled_probs = random_sample_without_replacement(probabilities, sample_size)

            majority_votes.append(majority_vote(sampled_labels))
            prob_labels.append(probability_to_label(np.mean(sampled_probs)))

        # NOTE(review): kappa may be undefined (NaN) if both lists contain a
        # single identical class -- confirm whether such iterations should be
        # dropped before averaging.
        kappa_scores.append(cohen_kappa_score(majority_votes, prob_labels))

    if not kappa_scores:
        return 0.0, 0.0

    # Report the plain standard deviation across iterations (no ad-hoc
    # scaling) so that printed "±" values and error bars show the real spread.
    return np.mean(kappa_scores), np.std(kappa_scores)

def analyze_resampling_data(file_path, output_dir="plots"):
    """Run the full kappa-vs-sample-size analysis and write its outputs.

    Loads the JSON file, keeps the entries of ``data['results']`` that have
    exactly 25 labels and 25 probabilities, computes Cohen's kappa for each
    odd sample size from 5 to 25, then saves an error-bar plot
    (``kappa_vs_sample_size.png``) and the numeric results
    (``kappa_results.json``) into ``output_dir`` and shows the plot.

    Args:
        file_path: Path to the resampling JSON file.
        output_dir: Directory for the plot and the results file; created
            (including missing parent directories) if it does not exist.

    Returns:
        A dict with 'sample_sizes', 'kappa_means', 'kappa_stds' and
        'num_samples', or ``None`` when no valid sample was found.
    """
    print(f"Loading data from {file_path}...")
    data = load_resampling_data(file_path)

    # Extract all labels and probabilities
    all_labels = []
    all_probabilities = []

    for sample in data['results']:
        labels, probabilities = extract_labels_and_probabilities(sample)

        # Only include samples with the expected 25 labels and probabilities
        if len(labels) == 25 and len(probabilities) == 25:
            all_labels.append(labels)
            all_probabilities.append(probabilities)

    print(f"Found {len(all_labels)} samples with 25 labels and probabilities each")

    if not all_labels:
        print("No valid samples found!")
        return

    # Sample sizes to test
    sample_sizes = [5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25]

    # Calculate Cohen's kappa for each sample size
    kappa_means = []
    kappa_stds = []

    print("Calculating Cohen's kappa for different sample sizes...")
    for size in sample_sizes:
        print(f"  Sample size {size}...")
        kappa_mean, kappa_std = calculate_agreement_for_sample_size(
            all_labels, all_probabilities, size
        )
        kappa_means.append(kappa_mean)
        kappa_stds.append(kappa_std)
        print(f"    Cohen's kappa: {kappa_mean:.4f} ± {kappa_std:.4f}")

    # Create output directory. parents=True so that a nested path such as
    # "out/plots" works instead of failing after all the computation is done.
    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    # Plot results
    plt.figure(figsize=(12, 8))
    plt.errorbar(sample_sizes, kappa_means, yerr=kappa_stds,
                 marker='o', capsize=5, capthick=2, linewidth=2, markersize=8)
    plt.xlabel('Number of eval awareness labels sampled', fontsize=12)
    plt.ylabel("Cohen's Kappa", fontsize=12)
    plt.title("Cohen's Kappa vs Sample Size\n(Binary test vs Probability estimation test)", fontsize=14)
    plt.grid(True, alpha=0.3)
    plt.xticks(sample_sizes)

    # Reference lines: kappa=0 plus the 0.6 and 0.8 agreement thresholds
    plt.axhline(y=0, color='red', linestyle='--', alpha=0.5, label='No Agreement')
    plt.axhline(y=0.6, color='orange', linestyle='--', alpha=0.5, label='Substantial Agreement')
    plt.axhline(y=0.8, color='green', linestyle='--', alpha=0.5, label='Almost Perfect Agreement')
    plt.legend()

    # Set y-axis limits to show the full range
    plt.ylim(-0.1, 1.0)

    plt.tight_layout()

    # Save plot
    plot_path = output_path / "kappa_vs_sample_size.png"
    plt.savefig(plot_path, dpi=300, bbox_inches='tight')
    print(f"Plot saved to {plot_path}")

    # Save results to JSON
    results = {
        'sample_sizes': sample_sizes,
        'kappa_means': kappa_means,
        'kappa_stds': kappa_stds,
        'num_samples': len(all_labels)
    }

    results_path = output_path / "kappa_results.json"
    with open(results_path, 'w') as f:
        json.dump(results, f, indent=2)
    print(f"Results saved to {results_path}")

    # Shown only after both files are written, so closing the window is
    # never required for the outputs to exist.
    plt.show()

    return results

def main():
    """Command-line entry point: parse arguments, seed the RNGs, run the analysis.

    Exits with status 1 when the analysis raises or finds no valid samples,
    so that callers (shells, CI) can detect the failure. The success message
    and summary are printed only when results were actually produced.
    """
    import sys
    import traceback

    parser = argparse.ArgumentParser(description='Analyze GPT-5 mini resampling data')
    parser.add_argument('--input', '-i',
                        default='eval_results/gpt_5_mini_resampling.json',
                        help='Path to the input JSON file')
    parser.add_argument('--output', '-o',
                        default='plots',
                        help='Output directory for plots and results')

    args = parser.parse_args()

    # Set random seed for reproducibility
    random.seed(42)
    np.random.seed(42)

    # Keep the try body to the one call that is expected to fail.
    try:
        results = analyze_resampling_data(args.input, args.output)
    except Exception as e:
        print(f"Error: {e}")
        traceback.print_exc()
        sys.exit(1)

    if not results:
        # The analysis already reported why nothing was produced; do not
        # claim success.
        sys.exit(1)

    print("\nAnalysis completed successfully!")

    # Print summary
    print("\nSummary:")
    print(f"Number of samples analyzed: {results['num_samples']}")
    print(f"Sample sizes tested: {results['sample_sizes']}")
    print(f"Cohen's kappa range: {min(results['kappa_means']):.4f} to {max(results['kappa_means']):.4f}")

# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()
