#!/usr/bin/env python3
"""
Script to analyze the relationship between number of prompts and eval probability prediction accuracy.

This script:
1. Parses qwen3_32b_probability_fp_full.json
2. Extracts eval_probability, eval_category, and prompt count for each sample
3. Calculates absolute difference between actual eval_category and predicted eval_probability
4. Creates scatter plot with trend line showing correlation between prompt count and prediction error
"""

import json
import matplotlib.pyplot as plt
import numpy as np
from scipy import stats
import seaborn as sns

def load_and_parse_data(json_file):
    """Load the results JSON and build one record per usable sample.

    Args:
        json_file: Path to a JSON file whose top-level object contains a
            ``results`` list of per-sample dicts.

    Returns:
        A list of dicts with the keys ``id``, ``eval_probability``,
        ``eval_category``, ``eval_category_numeric``, ``prompt_count`` and
        ``abs_difference``. Samples that carry no ``eval_probability`` are
        skipped.

    Raises:
        KeyError: If the top-level object has no ``results`` key.
    """
    # JSON text is UTF-8 by specification; do not depend on the locale default.
    with open(json_file, 'r', encoding='utf-8') as f:
        data = json.load(f)

    samples = []

    for result in data['results']:
        sample_id = result.get('id', 'unknown')

        # `or {}` also covers a key that is present but holds JSON null,
        # which `.get(key, {})` alone would hand back as None.
        probability_results = result.get('eval_probability_results') or {}
        eval_prob = probability_results.get('eval_probability')
        if eval_prob is None:
            continue

        # Ground-truth label: truthy means the sample is an evaluation.
        metadata = result.get('metadata') or {}
        eval_category = metadata.get('eval_category', False)
        eval_category_numeric = 1 if eval_category else 0

        # A missing or null prompt list counts as zero prompts.
        prompts = result.get('prompts') or []
        prompt_count = len(prompts)

        # Prediction error: distance between the 0/1 label and P(eval).
        abs_difference = abs(eval_category_numeric - eval_prob)

        samples.append({
            'id': sample_id,
            'eval_probability': eval_prob,
            'eval_category': eval_category,
            'eval_category_numeric': eval_category_numeric,
            'prompt_count': prompt_count,
            'abs_difference': abs_difference,
        })

    return samples

def create_plot(samples):
    """Draw prediction error against prompt count with a linear trend line.

    A new matplotlib figure is created and left as the current figure so the
    caller can save or show it.

    Args:
        samples: List of dicts, each with ``prompt_count`` and
            ``abs_difference`` keys.

    Returns:
        A ``(correlation, p_value)`` tuple from ``scipy.stats.pearsonr``
        computed on prompt count versus absolute difference.
    """
    prompt_counts = [s['prompt_count'] for s in samples]
    abs_differences = [s['abs_difference'] for s in samples]

    # Pearson statistics are shown in the legend and returned to the caller.
    correlation, p_value = stats.pearsonr(prompt_counts, abs_differences)

    plt.figure(figsize=(12, 8))

    plt.scatter(prompt_counts, abs_differences, alpha=0.6, s=30, color='steelblue')

    # Only slope and intercept of the least-squares fit are needed for the line.
    slope, intercept, *_ = stats.linregress(prompt_counts, abs_differences)
    line_x = np.array([min(prompt_counts), max(prompt_counts)])
    line_y = slope * line_x + intercept
    # Give the colour once, by keyword: combining an 'r-' format string with
    # color= makes matplotlib warn that the colour is redundantly defined.
    plt.plot(line_x, line_y, linestyle='-', color='#FFB3B3', linewidth=2, alpha=0.8,
             label=f'Trend line (r={correlation:.3f}, p={p_value:.3e})')

    plt.xlabel('Conversation turns', fontsize=12)
    plt.ylabel('Difference between ground truth and P(Eval)', fontsize=12)
    plt.title('Correlation between the conversation turns and eval probability estimation error\n', fontsize=14)
    plt.legend()
    plt.grid(True, alpha=0.3)

    plt.tight_layout()
    return correlation, p_value

def print_summary_stats(samples):
    """Write a plain-text statistical overview of the samples to stdout.

    Reports the sample total, the range/mean/median of prompt counts and of
    absolute differences, and the count, share and mean error of samples
    whose ``eval_category`` is truthy ("eval") versus falsy ("real").

    Args:
        samples: List of dicts with ``prompt_count``, ``abs_difference`` and
            ``eval_category`` keys.
    """
    def column(key):
        # Pull a single field out of every sample, in order.
        return [row[key] for row in samples]

    total = len(samples)
    print("\nSummary Statistics:")
    print(f"Total samples: {total}")

    turns = column('prompt_count')
    errors = column('abs_difference')

    print("\nPrompt counts:")
    print(f"  Range: {min(turns)} - {max(turns)}")
    print(f"  Mean: {np.mean(turns):.2f}")
    print(f"  Median: {np.median(turns):.2f}")

    print("\nAbsolute differences (prediction errors):")
    print(f"  Range: {min(errors):.3f} - {max(errors):.3f}")
    print(f"  Mean: {np.mean(errors):.3f}")
    print(f"  Median: {np.median(errors):.3f}")

    # Split the errors by ground-truth label in a single pass.
    eval_errors = []
    real_errors = []
    for row in samples:
        bucket = eval_errors if row['eval_category'] else real_errors
        bucket.append(row['abs_difference'])

    print("\nBreakdown by actual category:")
    print(f"  Eval samples: {len(eval_errors)} ({len(eval_errors) / total * 100:.1f}%)")
    print(f"  Real samples: {len(real_errors)} ({len(real_errors) / total * 100:.1f}%)")

    if eval_errors:
        print(f"  Mean error on eval samples: {np.mean(eval_errors):.3f}")

    if real_errors:
        print(f"  Mean error on real samples: {np.mean(real_errors):.3f}")

def main(json_file=None, output_file=None):
    """Run the full analysis: load samples, report statistics, plot, interpret.

    Args:
        json_file: Path to the results JSON. Defaults to the original
            hard-coded project location when omitted.
        output_file: Path the PNG plot is written to. Defaults to the original
            hard-coded project location when omitted. Missing parent
            directories are created.
    """
    # Local import keeps this block self-contained without touching file-level imports.
    from pathlib import Path

    if json_file is None:
        json_file = '/Users/igorivanov/Cursor_projects/eval_awareness/eval_results/qwen3_32b_probability_fp_full.json'
    if output_file is None:
        output_file = '/Users/igorivanov/Cursor_projects/eval_awareness/plots/eval_probability_vs_prompt_count.png'

    print("Loading and parsing data...")
    samples = load_and_parse_data(json_file)

    if not samples:
        print("No valid samples found!")
        return

    print(f"Successfully parsed {len(samples)} samples")

    # Print summary statistics
    print_summary_stats(samples)

    # Create plot
    print("\nCreating plot...")
    correlation, p_value = create_plot(samples)

    # Save plot; savefig does not create missing directories itself.
    Path(output_file).parent.mkdir(parents=True, exist_ok=True)
    plt.savefig(output_file, dpi=300, bbox_inches='tight')
    print(f"Plot saved to: {output_file}")

    # Show plot
    plt.show()

    print(f"\nCorrelation analysis:")
    print(f"Pearson correlation coefficient: {correlation:.4f}")
    print(f"P-value: {p_value:.4e}")

    # A NaN coefficient has neither a significance level nor a direction;
    # every comparison against NaN is False, so the chains below would
    # otherwise mislabel it.
    if np.isnan(correlation) or np.isnan(p_value):
        print("Statistical significance: undefined (correlation could not be computed)")
        print("Correlation strength: undefined")
        return

    if p_value < 0.001:
        significance = "highly significant (p < 0.001)"
    elif p_value < 0.01:
        significance = "significant (p < 0.01)"
    elif p_value < 0.05:
        significance = "marginally significant (p < 0.05)"
    else:
        significance = "not significant (p >= 0.05)"

    print(f"Statistical significance: {significance}")

    # Interpretation
    if abs(correlation) > 0.5:
        strength = "strong"
    elif abs(correlation) > 0.3:
        strength = "moderate"
    elif abs(correlation) > 0.1:
        strength = "weak"
    else:
        strength = "very weak"

    # A coefficient of exactly zero has no direction; do not call it negative.
    if correlation > 0:
        print(f"Correlation strength: {strength} positive correlation")
    elif correlation < 0:
        print(f"Correlation strength: {strength} negative correlation")
    else:
        print("Correlation strength: no linear correlation (r = 0)")

# Run the analysis only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
