"""
Generate calibration comparison figure: Baseline (V_0) vs Dynamic (V_τ)

This script compares the calibration of:
1. Baseline method: value at t=0 (V_0)
2. Dynamic abstention: value at abstention time (V_τ)

For dynamic abstention, V_τ is the first trajectory value that falls strictly
below the threshold T. Abstention occurs at this first crossing, so V_τ < T, and
V_τ ≈ T whenever the trajectory declines gradually. We therefore sweep T over a
range and collect (V_τ, correctness) pairs across multiple thresholds.
"""

import os
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np

# Use a 12pt base font for every figure this script produces.
plt.rcParams['font.size'] = 12

# CSV path -> dataset/model key used throughout the script.
# NOTE: these relative paths are resolved against the current working
# directory, whereas output_dir below is anchored to this script's directory.
trajectory_files = {
    "./trajectory_values_gsm8k_qwen.csv": "gsm8k_qwen",
    "./trajectory_values_gsm8k_phi3.csv": "gsm8k_phi3",
    "./trajectory_values_olympiadMath_phi3.csv": "olympiadMath_phi3",
    "./trajectory_values_olympiadMath_qwen.csv": "olympiadMath_qwen",
}

# Dataset/model key -> human-readable subplot title.
DATASET_TITLE_MAP = {
    'gsm8k_qwen': 'GSM8K (Qwen)',
    'gsm8k_phi3': 'GSM8K (Phi-3)',
    'olympiadMath_qwen': 'OlympiadBench (Qwen)',
    'olympiadMath_phi3': 'OlympiadBench (Phi-3)',
}

# Figures go to <script directory>/output_plots, created on import if missing.
output_dir = os.path.join(os.path.split(os.path.abspath(__file__))[0], "output_plots")
os.makedirs(output_dir, exist_ok=True)


def get_abstention_time_value(trajectory_values, threshold):
    """
    Return V_τ, the first trajectory value strictly below ``threshold``.

    The abstention time is τ = min{t : V_t < threshold}. Values are scanned in
    the order in which ``trajectory_values`` yields them.

    Args:
        trajectory_values: iterable of values V_0, V_1, ... in time order.
        threshold: abstention threshold T.

    Returns:
        The first value below ``threshold``, or None when no value is below it
        (i.e. the trajectory never abstains).
    """
    below_threshold = (value for value in trajectory_values if value < threshold)
    return next(below_threshold, None)


def compute_calibration(values, correctness, n_bins=10):
    """
    Build a binned calibration table for predicted probabilities.

    Predictions are assigned to ``n_bins`` equal-width bins on [0, 1]. Bin
    indices are clipped, so values below 0 land in the first bin and values
    at or above 1 land in the last bin.

    Args:
        values: array-like of predicted probabilities.
        correctness: array-like of binary correctness labels aligned with
            ``values`` (1 = correct).
        n_bins: number of equal-width bins.

    Returns:
        DataFrame with columns ``bin_idx``, ``predicted_prob`` (mean prediction
        in the bin), ``actual_accuracy`` (mean label in the bin) and ``count``.
        It has one row per non-empty bin, in increasing bin order. Empty input
        yields an empty DataFrame.
    """
    preds = np.array(values)
    labels = np.array(correctness)

    edges = np.linspace(0, 1, n_bins + 1)
    # digitize returns 1-based positions relative to the left edges; shift to
    # 0-based and clip so out-of-range values (and exactly 1.0) stay in range.
    assignment = np.clip(np.digitize(preds, edges) - 1, 0, n_bins - 1)

    rows = []
    for b in range(n_bins):
        in_bin = assignment == b
        n_in_bin = in_bin.sum()
        if not n_in_bin:
            continue
        rows.append({
            'bin_idx': b,
            'predicted_prob': preds[in_bin].mean(),
            'actual_accuracy': labels[in_bin].mean(),
            'count': n_in_bin,
        })
    return pd.DataFrame(rows)


def get_baseline_calibration_data(results):
    """
    Extract (V_0, correctness) pairs for the baseline method.

    Only rows whose ``model`` column equals 'baseline' are used. Each row's
    ``trajectory_value`` is the prediction, and ``1 - should_abstain_label``
    is its correctness label.

    Args:
        results: DataFrame with ``model``, ``trajectory_value`` and
            ``should_abstain_label`` columns.

    Returns:
        Tuple ``(values, correctness)`` of numpy arrays, or two empty lists
        when there are no baseline rows.
    """
    is_baseline = results['model'] == 'baseline'
    if not is_baseline.any():
        return [], []

    # Boolean indexing already yields a new frame, so no explicit copy needed.
    baseline_rows = results.loc[is_baseline]
    v0 = baseline_rows['trajectory_value'].values
    is_correct = (1 - baseline_rows['should_abstain_label']).values
    return v0, is_correct


def get_dynamic_calibration_data(results, n_thresholds=50):
    """
    Collect (V_τ, correctness) pairs for the dynamic method over a threshold sweep.

    Each sample's 'full'-model trajectory is ordered by ``output_length``.
    Thresholds are spread linearly between the 2nd and 98th percentiles of the
    per-sample trajectory minima. For every threshold T and every sample whose
    trajectory ever drops strictly below T, the first such value (V_τ) is
    recorded with the sample's correctness. A sample therefore contributes one
    pair per threshold it crosses.

    Args:
        results: DataFrame with ``model``, ``sample_index``, ``output_length``,
            ``trajectory_value`` and ``should_abstain_label`` columns; only rows
            with ``model == 'full'`` are used.
        n_thresholds: number of thresholds in the sweep.

    Returns:
        Tuple ``(values, correctness)`` of lists, threshold-major and then in
        ``sample_index`` group order. Two empty lists if there are no 'full'
        rows.
    """
    full_rows = results[results['model'] == 'full']
    if full_rows.empty:
        return [], []

    # Per sample: trajectory ordered by output length, and a correctness label
    # taken from the first row after sorting.
    trajectories = {}
    labels = {}
    for sid, rows in full_rows.groupby('sample_index'):
        ordered = rows.sort_values('output_length')
        trajectories[sid] = ordered['trajectory_value'].values
        labels[sid] = 1 - ordered['should_abstain_label'].values[0]

    minima = [traj.min() for traj in trajectories.values()]
    low = np.percentile(minima, 2)
    high = np.percentile(minima, 98)
    thresholds = np.linspace(low, high, n_thresholds)

    values_out = []
    labels_out = []
    for T in thresholds:
        for sid, traj in trajectories.items():
            # V_τ with τ = min{t : V_t < T}; None means no abstention at T.
            v_tau = next((v for v in traj if v < T), None)
            if v_tau is None:
                continue
            values_out.append(v_tau)
            labels_out.append(labels[sid])
    return values_out, labels_out


def plot_calibration_comparison_matrix(all_data, output_dir):
    """
    Plot calibration of the baseline (V_0) and dynamic (V_τ) methods.

    Draws a 2x2 grid with one panel per dataset, taking at most four datasets
    in sorted key order. A warning is printed if more are supplied. Each
    panel overlays the 10-bin calibration curves of both methods on the
    diagonal of perfect calibration. Panels without a dataset are hidden.

    Args:
        all_data: mapping of dataset key -> DataFrame of trajectory rows.
        output_dir: existing directory that receives
            ``calibration_comparison_matrix.png``.
    """
    max_panels = 4
    datasets = sorted(all_data.keys())
    if len(datasets) > max_panels:
        print(f"  Warning: only the first {max_panels} of {len(datasets)} datasets are plotted")
        datasets = datasets[:max_panels]

    fig, axes = plt.subplots(2, 2, figsize=(14, 12), constrained_layout=True)
    axes = axes.flatten()

    colors = {
        'baseline': '#2ca02c',  # green
        'dynamic': '#d62728',   # red
    }

    for ax, dataset in zip(axes, datasets):
        results = all_data[dataset]

        # Baseline calibration: value at t=0 (V_0).
        baseline_values, baseline_correct = get_baseline_calibration_data(results)
        if len(baseline_values) > 0:
            cal_baseline = compute_calibration(baseline_values, baseline_correct, n_bins=10)
            if len(cal_baseline) > 0:
                ax.plot(
                    cal_baseline['predicted_prob'],
                    cal_baseline['actual_accuracy'],
                    'o-',
                    color=colors['baseline'],
                    markersize=8,
                    linewidth=2,
                    label=r'Baseline ($V_0$)'
                )

        # Dynamic calibration: value at abstention time (V_τ) over a threshold sweep.
        dynamic_values, dynamic_correct = get_dynamic_calibration_data(results, n_thresholds=50)
        if len(dynamic_values) > 0:
            cal_dynamic = compute_calibration(dynamic_values, dynamic_correct, n_bins=10)
            if len(cal_dynamic) > 0:
                ax.plot(
                    cal_dynamic['predicted_prob'],
                    cal_dynamic['actual_accuracy'],
                    's-',
                    color=colors['dynamic'],
                    markersize=8,
                    linewidth=2,
                    label=r'Dynamic ($V_\tau$)'
                )

        # Perfect calibration line
        ax.plot([0, 1], [0, 1], 'k--', label='Perfect', alpha=0.7, linewidth=1.5)

        # Unknown dataset keys fall back to the raw key instead of raising KeyError.
        ax.set_title(DATASET_TITLE_MAP.get(dataset, dataset), fontsize=16, fontweight='bold', pad=10)
        ax.set_xlabel("Predicted Probability", fontsize=14)
        ax.set_ylabel("Actual Accuracy", fontsize=14)
        ax.grid(True, alpha=0.3)
        ax.set_xlim(-0.05, 1.05)
        ax.set_ylim(-0.05, 1.05)
        ax.set_aspect('equal')
        ax.legend(fontsize=11, loc='lower right')
        ax.tick_params(axis='both', which='major', labelsize=11)

    # Hide panels that received no dataset; slicing by len(datasets) keeps this
    # correct even when all_data is empty.
    for ax in axes[len(datasets):]:
        ax.axis('off')

    save_path = os.path.join(output_dir, "calibration_comparison_matrix.png")
    fig.savefig(save_path, dpi=300, bbox_inches='tight')
    print(f"Saved calibration comparison matrix: {save_path}")
    plt.close(fig)


def print_calibration_summary(all_data):
    """
    Print the expected calibration error (ECE) of both methods per dataset.

    Datasets are visited in sorted key order. ECE is the count-weighted mean,
    over the non-empty bins of a 10-bin calibration table, of
    |mean predicted probability - accuracy|. A method's line is skipped when
    it yields no data for that dataset.
    """
    def _ece(preds, labels):
        # Return the ECE, or None when there is nothing to bin.
        if len(preds) == 0:
            return None
        table = compute_calibration(preds, labels, n_bins=10)
        if len(table) == 0:
            return None
        gaps = np.abs(table['predicted_prob'] - table['actual_accuracy'])
        return np.average(gaps, weights=table['count'])

    print("\nCalibration Summary (Expected Calibration Error)")
    print("=" * 60)

    for dataset in sorted(all_data):
        results = all_data[dataset]
        print(f"\n{dataset}:")

        ece_baseline = _ece(*get_baseline_calibration_data(results))
        if ece_baseline is not None:
            print(f"  Baseline (V_0):  ECE = {ece_baseline:.4f}")

        ece_dynamic = _ece(*get_dynamic_calibration_data(results, n_thresholds=50))
        if ece_dynamic is not None:
            print(f"  Dynamic (V_τ):   ECE = {ece_dynamic:.4f}")


# Main Execution
if __name__ == "__main__":
    # sys.exit is always available, unlike the site-injected exit() builtin
    # (missing under `python -S` or in frozen executables).
    import sys

    print("Loading data for calibration comparison...")
    all_data_buffer = {}

    for data_path, dataset_model_string in trajectory_files.items():
        if not os.path.exists(data_path):
            print(f"  Skipping {data_path} (not found)")
            continue
        results = pd.read_csv(data_path, low_memory=False)
        # Coerce numeric columns; unparsable entries become NaN.
        results['trajectory_value'] = pd.to_numeric(results['trajectory_value'], errors='coerce')
        results['should_abstain_label'] = pd.to_numeric(results['should_abstain_label'], errors='coerce')
        results['output_length'] = pd.to_numeric(results['output_length'], errors='coerce')
        # Only rows lacking a value or a label are dropped; NaN output_length is kept.
        results = results.dropna(subset=['trajectory_value', 'should_abstain_label'])
        all_data_buffer[dataset_model_string] = results
        print(f"  Loaded {dataset_model_string}: {len(results)} rows")

    if not all_data_buffer:
        print("No data files found. Exiting.")
        sys.exit(1)

    print("\nGenerating calibration comparison plot...")
    plot_calibration_comparison_matrix(all_data_buffer, output_dir)

    print_calibration_summary(all_data_buffer)

    print(f"\n✓ Calibration comparison plot saved to: {output_dir}")
