#!/usr/bin/env python3
"""
Visualize and analyze the confidence-based filling experiment results.
"""

import json
import numpy as np
import matplotlib.pyplot as plt
import os
import sys
from pathlib import Path

def load_results(results_file):
    """Load experiment results from a JSON file.

    Args:
        results_file: Path (str or path-like) to the JSON results file.

    Returns:
        The decoded JSON document. The rest of this script indexes it as a
        list of per-sample dicts.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the contents are not valid JSON.
    """
    # JSON text is UTF-8 (RFC 8259); don't depend on the platform's locale
    # encoding, which differs on e.g. Windows.
    with open(results_file, 'r', encoding='utf-8') as f:
        return json.load(f)

def analyze_results(results):
    """Print summary statistics for confidence-based filling results.

    Each sample is indexed as a mapping with the key
    ``'bottom_half_confidences_initial'`` (always read) and the keys
    ``'bottom_half_confidences_filled'`` / ``'confidence_changes'``
    (the changes are only pooled when the filled list is non-empty).

    Args:
        results: Sequence of per-sample result dicts, as returned by
            ``load_results``.

    Returns:
        Tuple ``(all_initial, all_filled, all_changes)`` of flat lists pooled
        across all samples. Any of them may be empty.
    """
    print("\n" + "=" * 60)
    print("ANALYSIS RESULTS")
    print("=" * 60)

    # Pool per-sample values into flat lists.
    all_initial = []
    all_filled = []
    all_changes = []

    for sample in results:
        all_initial.extend(sample['bottom_half_confidences_initial'])
        filled = sample['bottom_half_confidences_filled']
        # Samples with no filled confidences (empty or None) contribute
        # neither filled values nor changes.
        if filled:
            all_filled.extend(filled)
            all_changes.extend(sample['confidence_changes'])

    print("\nDataset Statistics:")
    print(f"  Total samples analyzed: {len(results)}")
    print(f"  Total bottom-half tokens: {len(all_initial)}")

    # np.min/np.max raise on empty input and np.mean/median/std return nan
    # (with a warning), so each section is only printed when it has data.
    if all_initial:
        print("\nInitial Bottom-Half Confidences:")
        print(f"  Mean: {np.mean(all_initial):.4f}")
        print(f"  Median: {np.median(all_initial):.4f}")
        print(f"  Std: {np.std(all_initial):.4f}")
        print(f"  Min: {np.min(all_initial):.4f}")
        print(f"  Max: {np.max(all_initial):.4f}")
    else:
        print("\nNo bottom-half confidences to summarize.")

    if all_filled:
        print("\nAfter Filling Top Half:")
        print(f"  Mean confidence: {np.mean(all_filled):.4f}")
        print(f"  Median confidence: {np.median(all_filled):.4f}")
        print(f"  Std confidence: {np.std(all_filled):.4f}")

    # Guarded separately: a sample may carry filled values but an empty
    # change list, which would otherwise divide by zero below.
    if all_changes:
        total = len(all_changes)

        print("\nConfidence Changes:")
        print(f"  Mean change: {np.mean(all_changes):.4f}")
        print(f"  Median change: {np.median(all_changes):.4f}")
        print(f"  Std change: {np.std(all_changes):.4f}")

        improved = sum(1 for c in all_changes if c > 0)
        degraded = sum(1 for c in all_changes if c < 0)
        unchanged = sum(1 for c in all_changes if c == 0)

        print("\nToken Classification:")
        print(f"  Improved: {improved} ({100*improved/total:.1f}%)")
        print(f"  Degraded: {degraded} ({100*degraded/total:.1f}%)")
        print(f"  Unchanged: {unchanged} ({100*unchanged/total:.1f}%)")

        print("\nChange Distribution (Quartiles):")
        q1, q2, q3 = np.percentile(all_changes, [25, 50, 75])
        print(f"  Q1 (25%): {q1:.4f}")
        print(f"  Q2 (50%): {q2:.4f}")
        print(f"  Q3 (75%): {q3:.4f}")

    print("=" * 60 + "\n")

    return all_initial, all_filled, all_changes

def visualize_results(results, output_dir):
    """Create and save a two-panel figure summarizing confidence changes.

    Runs ``analyze_results`` (which prints statistics), then plots:
      1. how many positions increased / decreased / kept their confidence;
      2. a histogram of absolute change magnitudes.
    The figure is written to
    ``<output_dir>/confidence_analysis_visualization.png``.

    Args:
        results: Per-sample result dicts, as returned by ``load_results``.
        output_dir: Directory the PNG is written into.
    """
    _, all_filled, all_changes = analyze_results(results)

    # Percentages below divide by the number of changes, so also bail out
    # when filled data exists but no changes were recorded.
    if not all_filled or not all_changes:
        print("No filling data available. Skipping visualization.")
        return

    changes = np.asarray(all_changes, dtype=float)
    total = changes.size

    # Create figure with 2 subplots
    fig, axes = plt.subplots(1, 2, figsize=(16, 6))
    try:
        fig.suptitle('Confidence-Based Filling Analysis', fontsize=16, fontweight='bold')

        # Plot 1: number of positions per change direction
        ax1 = axes[0]
        categories = ['Confidence\nIncreased', 'Confidence\nDecreased', 'No Change']
        counts = [
            int(np.count_nonzero(changes > 0)),
            int(np.count_nonzero(changes < 0)),
            int(np.count_nonzero(changes == 0)),
        ]
        colors_cat = ['green', 'red', 'gray']

        bars = ax1.bar(categories, counts, color=colors_cat, alpha=0.7, edgecolor='black', linewidth=2)
        ax1.set_ylabel('Number of Positions', fontsize=14)
        ax1.set_title('How Many Positions Changed Confidence?', fontsize=15, fontweight='bold')
        ax1.grid(alpha=0.3, axis='y')

        # Add count and percentage labels on bars
        for bar, count in zip(bars, counts):
            percentage = 100 * count / total
            ax1.text(bar.get_x() + bar.get_width() / 2., bar.get_height(),
                     f'{count}\n({percentage:.1f}%)',
                     ha='center', va='bottom', fontweight='bold', fontsize=12)

        # Plot 2: magnitude of changes (absolute values)
        ax2 = axes[1]
        abs_changes = np.abs(changes)

        # Bin edges for magnitude ranges. np.histogram includes the right edge
        # (1.0) in the last bin; magnitudes above 1.0 would not be counted,
        # which assumes confidences lie in [0, 1] -- TODO confirm upstream.
        bins = [0, 0.01, 0.05, 0.1, 0.2, 0.5, 1.0]
        bin_labels = ['0-0.01', '0.01-0.05', '0.05-0.1', '0.1-0.2', '0.2-0.5', '0.5-1.0']
        hist, _ = np.histogram(abs_changes, bins=bins)

        bars_mag = ax2.bar(bin_labels, hist, color='purple', alpha=0.7, edgecolor='black', linewidth=2)
        ax2.set_xlabel('Magnitude of Confidence Change (Absolute)', fontsize=14)
        ax2.set_ylabel('Number of Positions', fontsize=14)
        ax2.set_title('By How Much Did Position Confidences Change?', fontsize=15, fontweight='bold')
        ax2.grid(alpha=0.3, axis='y')
        ax2.tick_params(axis='x', rotation=45)

        # Add count and percentage labels on non-empty bars
        for bar, count in zip(bars_mag, hist):
            if count > 0:
                percentage = 100 * count / total
                ax2.text(bar.get_x() + bar.get_width() / 2., bar.get_height(),
                         f'{count}\n({percentage:.1f}%)',
                         ha='center', va='bottom', fontsize=10, fontweight='bold')

        fig.tight_layout()

        # Save figure
        output_path = os.path.join(output_dir, 'confidence_analysis_visualization.png')
        fig.savefig(output_path, dpi=300, bbox_inches='tight')
        print(f"\nVisualization saved to: {output_path}")
    finally:
        # Close this specific figure even if plotting/saving fails.
        plt.close(fig)

def main():
    """Command-line entry point: load a results JSON file and visualize it.

    Usage: ``python visualize_conf_fill.py <results_json_file>``. The figure
    is written into the directory containing the input file. Exits with
    status 1 on missing arguments, a path that is not a regular file, or
    invalid JSON.
    """
    if len(sys.argv) < 2:
        print("Usage: python visualize_conf_fill.py <results_json_file>")
        print("\nExample:")
        print("  python visualize_conf_fill.py ./conf_fill_analysis/conf_fill_t0.5_*/confidence_analysis.json")
        sys.exit(1)

    results_file = sys.argv[1]

    # A shell glob (as in the usage example) may expand to several paths;
    # only the first one is processed.
    if len(sys.argv) > 2:
        print(f"Warning: ignoring {len(sys.argv) - 2} extra argument(s); "
              f"only processing {results_file}")

    # isfile (not exists) so a directory is rejected here instead of
    # crashing later inside open().
    if not os.path.isfile(results_file):
        print(f"Error: File not found: {results_file}")
        sys.exit(1)

    print(f"Loading results from: {results_file}")
    try:
        results = load_results(results_file)
    except json.JSONDecodeError as err:
        print(f"Error: Could not parse JSON in {results_file}: {err}")
        sys.exit(1)

    output_dir = os.path.dirname(results_file)
    visualize_results(results, output_dir)

if __name__ == '__main__':
    # Run the CLI only when executed as a script, not when imported.
    main()
