"""
Refusal Index Analysis with Confidence Intervals

This script computes the refusal index using tetrachoric correlation with bootstrap
confidence intervals. It properly handles score relaxation and uses prompt 0 
(forced answering) as the baseline for determining correctness.
"""

import json
import tempfile
import numpy as np
from scipy.stats import norm, multivariate_normal
from scipy.optimize import minimize_scalar
from pathlib import Path
import pandas as pd
from typing import Dict, List, Tuple, Optional
import matplotlib.pyplot as plt
import seaborn as sns
from collections import defaultdict, Counter
from tqdm import tqdm


def compute_bivariate_normal_tail_prob(tau_r: float, tau_w: float, rho: float) -> float:
    """
    Compute P(Z_R > tau_r, Z_W > tau_w) for a bivariate standard normal with correlation rho.

    This is the bivariate survival function bar(Phi_2)(tau_r, tau_w; rho).

    Args:
        tau_r: Threshold for Z_R (may be +/-inf).
        tau_w: Threshold for Z_W (may be +/-inf).
        rho: Correlation between Z_R and Z_W, in [-1, 1].

    Returns:
        The upper-orthant probability as a float clipped to [0, 1].
    """
    # Infinite thresholds arise when a marginal rate is exactly 0 or 1
    # (norm.ppf(0) = -inf, norm.ppf(1) = +inf); resolve them analytically.
    if tau_r == np.inf or tau_w == np.inf:
        return 0.0
    if tau_r == -np.inf:
        return float(norm.sf(tau_w))
    if tau_w == -np.inf:
        return float(norm.sf(tau_r))

    if abs(rho) >= 1 - 1e-12:
        # The covariance matrix is singular here, so handle perfect correlation exactly.
        if rho > 0:
            # rho = 1: Z_W = Z_R, so both exceed their thresholds iff Z_R > max(tau_r, tau_w).
            return float(norm.sf(max(tau_r, tau_w)))
        # rho = -1: Z_W = -Z_R, so the event is tau_r < Z_R < -tau_w,
        # which is empty unless tau_r < -tau_w.
        if tau_r + tau_w >= 0:
            return 0.0
        return float(norm.cdf(-tau_w) - norm.cdf(tau_r))

    # (-Z_R, -Z_W) is also bivariate standard normal with correlation rho, so
    # P(Z_R > tau_r, Z_W > tau_w) = Phi_2(-tau_r, -tau_w; rho). Evaluating the CDF
    # directly avoids the cancellation in 1 - Phi(tau_r) - Phi(tau_w) + Phi_2(...).
    rv = multivariate_normal(mean=[0.0, 0.0], cov=[[1.0, rho], [rho, 1.0]])
    prob = float(rv.cdf([-tau_r, -tau_w]))
    # The numerical CDF integration can overshoot [0, 1] by tiny amounts.
    return min(1.0, max(0.0, prob))


def tetrachoric_log_likelihood(rho: float, r: float, mu: float, n_counts: Dict[Tuple[int, int], int]) -> float:
    """
    Log-likelihood of a 2x2 (refused, wrong) table under a tetrachoric model.

    The latent pair (Z_R, Z_W) is bivariate standard normal with correlation
    ``rho``. R = 1 when Z_R exceeds norm.ppf(1 - r), and W = 1 when Z_W exceeds
    norm.ppf(1 - mu). The model's marginals therefore equal ``r`` and ``mu``.

    Args:
        rho: Correlation parameter in (-1, 1)
        r: Empirical refusal rate P(R=1)
        mu: Overall wrong rate P(W=1)
        n_counts: Mapping (a, b) -> n_ab, the count of items with R=a, W=b.
            Cells with non-positive counts, or keys outside the four binary
            cells, contribute nothing.

    Returns:
        Sum of n_ab * log(p_ab) over the contributing cells.
    """
    upper_r = norm.ppf(1 - r)
    upper_w = norm.ppf(1 - mu)

    joint_tail = compute_bivariate_normal_tail_prob(upper_r, upper_w, rho)
    # The other three cells follow from the fixed marginals r and mu.
    raw_cells = {
        (1, 1): joint_tail,
        (1, 0): r - joint_tail,
        (0, 1): mu - joint_tail,
        (0, 0): 1 - r - mu + joint_tail,
    }
    # Clip into (0, 1) so log() stays finite when a cell gets ~0 model mass.
    floor = 1e-10
    cell_probs = {cell: max(floor, min(1 - floor, p)) for cell, p in raw_cells.items()}

    return sum(
        (
            count * np.log(cell_probs[(a, b)])
            for (a, b), count in n_counts.items()
            if count > 0 and (a, b) in cell_probs
        ),
        0,
    )


def estimate_refusal_index(r: float, mu: float, mu_a: float, mu_r: float, n_total: int) -> float:
    """
    Estimate the refusal index (correlation rho) by maximum-likelihood tetrachoric correlation.

    The 2x2 table of (refused, wrong) counts is reconstructed from the rates.
    rho is then chosen to maximize ``tetrachoric_log_likelihood``, with the
    marginal thresholds fixed by ``r`` and ``mu``.

    Args:
        r: Refusal rate
        mu: Overall wrong rate (with no refusal)
        mu_a: Wrong rate on attempted items
        mu_r: Wrong rate on refused items (measured in forced answering)
        n_total: Total number of items

    Returns:
        Estimated correlation rho (refusal index), in [-0.999, 0.999].
        Returns 0.0 when either marginal is degenerate: no refusals or all
        refusals (r in {0, 1}), or all items right or all wrong (mu in {0, 1}).
        In those cases the likelihood is flat in rho and the correlation is
        not identified.
    """
    # Degenerate marginals: thresholds become +/-inf and rho has no effect.
    if r <= 0 or r >= 1 or mu <= 0 or mu >= 1:
        return 0.0

    # The rates are ratios of integer counts, so the products below are integers
    # up to floating-point error (e.g. 100 * 0.29 == 28.999999999999996).
    # Round instead of truncating so the reconstructed counts are not off by one.
    n_r = int(round(n_total * r))   # Number of refused items
    n_a = n_total - n_r             # Number of attempted items

    # For refused items (R=1): wrong rate is mu_r
    n_11 = int(round(n_r * mu_r))   # Refused and wrong
    n_10 = n_r - n_11               # Refused and correct

    # For attempted items (R=0): wrong rate is mu_a
    n_01 = int(round(n_a * mu_a))   # Attempted and wrong
    n_00 = n_a - n_01               # Attempted and correct

    n_counts = {
        (1, 1): n_11,
        (1, 0): n_10,
        (0, 1): n_01,
        (0, 0): n_00,
    }

    def neg_ll(rho):
        return -tetrachoric_log_likelihood(rho, r, mu, n_counts)

    # Bounded scalar search; bounds stay inside (-1, 1), away from the singular covariance.
    result = minimize_scalar(neg_ll, bounds=(-0.999, 0.999), method='bounded')
    return float(result.x)


def apply_score_relaxation(df: pd.DataFrame, prompt_idx: int, num_prompts: int) -> pd.Series:
    """
    Resolve refusal grades ('C') in one prompt's column using the other prompts.

    Follows the score-relaxation logic from evaluate.py.

    Each 'C' in column ``grade{prompt_idx}`` is replaced by the most common
    non-'C' grade found across ``grade0`` .. ``grade{num_prompts - 1}`` in the
    same row. Ties go to the grade encountered first. Rows graded 'C' on every
    prompt become 'B'. All other grades are kept unchanged.

    Returns:
        A relaxed copy of ``df[f'grade{prompt_idx}']`` (A or B only when the
        inputs use the A/B/C grades); ``df`` itself is not modified.
    """
    prompt_cols = [f'grade{k}' for k in range(num_prompts)]
    target_col = f'grade{prompt_idx}'
    relaxed = df[target_col].copy()

    refused_rows = df.loc[df[target_col] == 'C', prompt_cols]
    for row_label, row in refused_rows.iterrows():
        answered = [grade for grade in row if grade != 'C']
        if answered:
            relaxed.loc[row_label] = Counter(answered).most_common(1)[0][0]

    # Nothing to take a majority from: treat a unanimous refusal as wrong.
    unanimous_refusal = df[prompt_cols].eq('C').all(axis=1)
    relaxed.loc[unanimous_refusal] = 'B'

    return relaxed


def compute_refusal_index_from_csv(csv_path: str, prompt_idx: int, num_prompts: int = 5) -> Dict:
    """
    Compute the refusal index of one prompt from an evaluation-results CSV.

    Correctness comes from the relaxed grades of prompt 0 (forced answering):
    an item counts as wrong when its relaxed baseline grade is 'B'. Refusal
    means grade 'C' on the target prompt.

    Args:
        csv_path: Path to the CSV file with evaluation results
        prompt_idx: Index of the prompt to analyze (1-4)
        num_prompts: Total number of prompts (default 5)

    Returns:
        Dict with 'refusal_index', 'refusal_rate', 'mu' (baseline wrong rate),
        'mu_a' / 'mu_r' (baseline wrong rate among attempted / refused items;
        they fall back to 'mu' when the subset is empty), and 'n_total'.
    """
    data = pd.read_csv(csv_path)

    # Prompt 0 is the forced-answering baseline; its relaxed grades define "wrong".
    baseline_wrong = apply_score_relaxation(data, 0, num_prompts) == 'B'
    overall_wrong_rate = baseline_wrong.mean()

    target_grades = data[f'grade{prompt_idx}']
    refused = target_grades == 'C'
    attempted = target_grades != 'C'
    refusal_rate = refused.mean()

    wrong_rate_attempted = baseline_wrong[attempted].mean() if attempted.any() else overall_wrong_rate
    wrong_rate_refused = baseline_wrong[refused].mean() if refused.any() else overall_wrong_rate

    n_items = len(data)
    index_estimate = estimate_refusal_index(
        refusal_rate, overall_wrong_rate, wrong_rate_attempted, wrong_rate_refused, n_items
    )

    return {
        'refusal_index': index_estimate,
        'refusal_rate': refusal_rate,
        'mu': overall_wrong_rate,
        'mu_a': wrong_rate_attempted,
        'mu_r': wrong_rate_refused,
        'n_total': n_items
    }


def bootstrap_refusal_index(csv_path: str, prompt_idx: int, n_bootstrap: int = 10, 
                           confidence_level: float = 0.95, num_prompts: int = 5,
                           show_progress: bool = False, seed: Optional[int] = None) -> Dict:
    """
    Compute percentile-bootstrap confidence intervals for the refusal index.

    The CSV's rows are resampled with replacement ``n_bootstrap`` times. The
    refusal index is recomputed on each resample, and the CI is taken from the
    percentiles of those estimates.

    Args:
        csv_path: Path to the CSV file
        prompt_idx: Index of the prompt to analyze
        n_bootstrap: Number of bootstrap samples
        confidence_level: Confidence level for CI (default 0.95)
        num_prompts: Total number of prompts
        show_progress: If True, show a tqdm progress bar over the resamples
        seed: Optional seed for reproducible resampling. When None, NumPy's
            global random state is used.

    Returns:
        Dictionary with:
        - 'refusal_index': the point estimate.
        - 'ci_lower', 'ci_upper': percentile CI bounds, both equal to the point
          estimate if every resample failed.
        - 'bootstrap_std': std of the bootstrap estimates, NaN if none succeeded.
        - 'n_bootstrap_valid': number of successful resamples.
        - 'refusal_rate', 'mu', 'mu_a', 'mu_r': the point-estimate rates.
    """
    import io

    df = pd.read_csv(csv_path)
    n_samples = len(df)

    # Compute point estimate
    point_estimate = compute_refusal_index_from_csv(csv_path, prompt_idx, num_prompts)

    rng = np.random if seed is None else np.random.RandomState(seed)
    iterations = range(n_bootstrap)
    if show_progress:
        iterations = tqdm(iterations, desc=f"Bootstrap prompt {prompt_idx}", leave=False)

    bootstrap_estimates = []
    n_failed = 0
    for _ in iterations:
        # Resample with replacement
        bootstrap_indices = rng.choice(n_samples, size=n_samples, replace=True)
        bootstrap_df = df.iloc[bootstrap_indices].reset_index(drop=True)

        # Pass the resample through an in-memory CSV buffer (pd.read_csv accepts
        # file-like objects). This avoids disk I/O and avoids reopening a
        # NamedTemporaryFile by name, which fails on Windows.
        buffer = io.StringIO(bootstrap_df.to_csv(index=False))
        try:
            bootstrap_result = compute_refusal_index_from_csv(buffer, prompt_idx, num_prompts)
        except Exception:
            # A resample can make the estimator fail; skip it, but keep count
            # so the caller can see how many estimates the CI rests on.
            n_failed += 1
            continue
        bootstrap_estimates.append(bootstrap_result['refusal_index'])

    # Compute confidence intervals
    alpha = 1 - confidence_level
    lower_percentile = (alpha / 2) * 100
    upper_percentile = (1 - alpha / 2) * 100

    if bootstrap_estimates:
        ci_lower = np.percentile(bootstrap_estimates, lower_percentile)
        ci_upper = np.percentile(bootstrap_estimates, upper_percentile)
        bootstrap_std = float(np.std(bootstrap_estimates))
    else:
        # No valid bootstrap estimates: collapse the CI onto the point estimate.
        ci_lower = point_estimate['refusal_index']
        ci_upper = point_estimate['refusal_index']
        bootstrap_std = float('nan')

    return {
        'refusal_index': point_estimate['refusal_index'],
        'ci_lower': ci_lower,
        'ci_upper': ci_upper,
        'bootstrap_std': bootstrap_std,
        'refusal_rate': point_estimate['refusal_rate'],
        'mu': point_estimate['mu'],
        'mu_a': point_estimate['mu_a'],
        'mu_r': point_estimate['mu_r'],
        'n_bootstrap_valid': n_bootstrap - n_failed,
    }


def _print_refusal_summary(df_results: pd.DataFrame) -> None:
    """Print per-dataset and per-prompt summary statistics of refusal-index results."""
    print("\n" + "=" * 80)
    print("SUMMARY STATISTICS")
    print("=" * 80)

    # By dataset
    for dataset in df_results['dataset'].unique():
        dataset_df = df_results[df_results['dataset'] == dataset]
        print(f"\nDataset: {dataset}")
        print(f"  Number of models: {dataset_df['model'].nunique()}")
        print(f"  Mean refusal index: {dataset_df['refusal_index'].mean():.3f} ± {dataset_df['refusal_index'].std():.3f}")
        print(f"  Mean CI width: {dataset_df['ci_width'].mean():.3f}")
        print(f"  Mean refusal rate: {dataset_df['refusal_rate'].mean():.3f}")

    # By prompt
    print("\nBy Prompt Index:")
    for prompt_idx in sorted(df_results['prompt_index'].unique()):
        prompt_df = df_results[df_results['prompt_index'] == prompt_idx]
        print(f"  Prompt {prompt_idx}:")
        print(f"    Mean refusal index: {prompt_df['refusal_index'].mean():.3f} ± {prompt_df['refusal_index'].std():.3f}")
        print(f"    Mean CI width: {prompt_df['ci_width'].mean():.3f}")


def analyze_all_models_with_ci(n_bootstrap: int = 1000, confidence_level: float = 0.95, 
                              limit_models: Optional[int] = None, dataset_filter: str = "simpleqa"):
    """
    Analyze refusal indices with confidence intervals for all models and datasets.

    Reads a dataset -> model -> [entry, ...] mapping from
    'evaluation_runs_core_datasets.json'; each entry's 'data' key is a CSV path.
    For every existing CSV, bootstraps the refusal index of prompts 1-4 (prompt 0
    is the baseline). Writes one row per (dataset, entry, prompt) to
    'refusal_index_with_ci.csv' and prints summary statistics.

    Args:
        n_bootstrap: Number of bootstrap samples (default 1000)
        confidence_level: Confidence level for CI
        limit_models: Stop after this many models have been processed
            (None or 0 means no limit). A model counts once even if it has
            several run entries.
        dataset_filter: Only process this dataset (default "simpleqa"); a
            falsy value processes every dataset.

    Returns:
        DataFrame of per-prompt results, or None when nothing could be analyzed.
    """
    # Load evaluation runs configuration
    with open('evaluation_runs_core_datasets.json', 'r') as f:
        eval_runs = json.load(f)

    all_results = []
    model_count = 0     # distinct models processed so far
    limit_hit = False   # stops the outer dataset loop once limit_models is reached

    print("Computing refusal indices with confidence intervals...")
    print(f"Bootstrap samples: {n_bootstrap}")
    print(f"Dataset filter: {dataset_filter}")
    print("=" * 80)

    # Process each dataset and model
    for dataset_name, models in eval_runs.items():
        if limit_hit:
            break
        # Skip if not the target dataset
        if dataset_filter and dataset_name != dataset_filter:
            continue
        print(f"\nDataset: {dataset_name}")
        print("-" * 40)

        for model_name, entries in models.items():
            if limit_models and model_count >= limit_models:
                print(f"  Reached model limit ({limit_models}), stopping...")
                limit_hit = True
                break

            counted = False
            for entry in entries:
                csv_path = entry.get('data')

                if not csv_path or not Path(csv_path).exists():
                    print(f"  Skipping {model_name}: CSV not found at {csv_path}")
                    continue

                print(f"  Processing {model_name}...")
                if not counted:
                    # Count models, not CSV entries, so limit_models means what it says.
                    model_count += 1
                    counted = True

                try:
                    # Load CSV to check number of grade columns. Only columns named
                    # exactly grade<digits> are prompt grades; a bare
                    # startswith('grade') would also count columns like 'grader'.
                    df = pd.read_csv(csv_path)
                    grade_cols = [col for col in df.columns
                                  if col.startswith('grade') and col[len('grade'):].isdigit()]
                    num_prompts = len(grade_cols)

                    if num_prompts < 2:
                        print(f"    Skipping: only {num_prompts} prompts found")
                        continue

                    # Only analyze prompts 1-4 (skip prompt 0 which is baseline)
                    for prompt_idx in range(1, min(5, num_prompts)):
                        print(f"    Computing for prompt {prompt_idx}...", end='', flush=True)
                        result = bootstrap_refusal_index(
                            csv_path,
                            prompt_idx,
                            n_bootstrap=n_bootstrap,
                            confidence_level=confidence_level,
                            num_prompts=num_prompts
                        )

                        all_results.append({
                            'dataset': dataset_name,
                            'model': model_name,
                            'prompt_index': prompt_idx,
                            'refusal_index': result['refusal_index'],
                            'ci_lower': result['ci_lower'],
                            'ci_upper': result['ci_upper'],
                            'ci_width': result['ci_upper'] - result['ci_lower'],
                            'bootstrap_std': result['bootstrap_std'],
                            'refusal_rate': result['refusal_rate'],
                            'mu': result['mu'],
                            'mu_a': result['mu_a'],
                            'mu_r': result['mu_r']
                        })

                        print(f" RI={result['refusal_index']:.3f} "
                              f"[{result['ci_lower']:.3f}, {result['ci_upper']:.3f}], "
                              f"r={result['refusal_rate']:.3f}")

                except Exception as e:
                    # Top-level per-model boundary: report and move on to the next entry.
                    print(f"    Error processing {model_name}: {e}")
                    import traceback
                    traceback.print_exc()

    # Convert to DataFrame for analysis
    df_results = pd.DataFrame(all_results)

    if df_results.empty:
        print("\nNo data to analyze")
        return None

    # Save detailed results
    output_file = 'refusal_index_with_ci.csv'
    df_results.to_csv(output_file, index=False)
    print(f"\n\nDetailed results saved to {output_file}")

    _print_refusal_summary(df_results)

    return df_results


def plot_refusal_indices_with_ci(df_results: pd.DataFrame,
                                 output_path: str = 'refusal_index_with_ci.png'):
    """
    Create a 2x2 figure summarizing refusal indices and their confidence intervals.

    Panels:
      1. Per-model refusal index for the 'simpleqa' dataset. Index and CI bounds
         are averaged over prompts, a rough summary rather than a CI of the mean.
      2. Mean (+/- std) refusal index per prompt index across all models.
      3. Histogram of CI widths.
      4. Refusal index vs. refusal rate, one colour per dataset.

    Args:
        df_results: Output of ``analyze_all_models_with_ci``; None/empty is a no-op.
        output_path: Where to save the PNG (default 'refusal_index_with_ci.png').
    """
    if df_results is None or df_results.empty:
        print("No data to plot")
        return

    fig, axes = plt.subplots(2, 2, figsize=(14, 10))

    # 1. Refusal index by model with CI (for simpleqa dataset)
    ax = axes[0, 0]
    simpleqa_df = df_results[df_results['dataset'] == 'simpleqa'].copy()
    if not simpleqa_df.empty:
        # Group by model and compute mean with CI
        model_summary = simpleqa_df.groupby('model').agg({
            'refusal_index': 'mean',
            'ci_lower': 'mean',
            'ci_upper': 'mean'
        }).reset_index()

        x = np.arange(len(model_summary))
        # Averaged bounds need not bracket the averaged index; clamp error bars at 0.
        yerr_lower = np.maximum(0, model_summary['refusal_index'] - model_summary['ci_lower'])
        yerr_upper = np.maximum(0, model_summary['ci_upper'] - model_summary['refusal_index'])
        ax.errorbar(x, model_summary['refusal_index'],
                    yerr=[yerr_lower, yerr_upper],
                    fmt='o', capsize=5, capthick=2, markersize=8)
        ax.set_xticks(x)
        ax.set_xticklabels(model_summary['model'], rotation=45, ha='right')
        ax.set_xlabel('Model')
        ax.set_ylabel('Refusal Index')
        ax.set_title('Refusal Index by Model (SimpleQA)')
        ax.grid(True, alpha=0.3)
        ax.axhline(y=0, color='red', linestyle='--', alpha=0.5)

    # 2. Refusal index by prompt across all models
    ax = axes[0, 1]
    prompt_summary = df_results.groupby('prompt_index').agg({
        'refusal_index': ['mean', 'std'],
        'ci_width': 'mean'
    }).reset_index()
    prompt_summary.columns = ['prompt_index', 'mean_ri', 'std_ri', 'mean_ci_width']

    ax.bar(prompt_summary['prompt_index'], prompt_summary['mean_ri'],
           yerr=prompt_summary['std_ri'], capsize=5, alpha=0.7)
    ax.set_xlabel('Prompt Index')
    ax.set_ylabel('Mean Refusal Index')
    ax.set_title('Mean Refusal Index by Prompt')
    ax.grid(True, alpha=0.3)

    # 3. CI width distribution
    ax = axes[1, 0]
    ax.hist(df_results['ci_width'], bins=30, edgecolor='black', alpha=0.7)
    ax.set_xlabel('Confidence Interval Width')
    ax.set_ylabel('Count')
    ax.set_title('Distribution of CI Widths')
    ax.axvline(x=df_results['ci_width'].mean(), color='red',
               linestyle='--', label=f'Mean: {df_results["ci_width"].mean():.3f}')
    ax.legend()
    ax.grid(True, alpha=0.3)

    # 4. Refusal index vs refusal rate with CI
    ax = axes[1, 1]
    for dataset in df_results['dataset'].unique():
        dataset_df = df_results[df_results['dataset'] == dataset]
        ax.scatter(dataset_df['refusal_rate'], dataset_df['refusal_index'],
                   alpha=0.5, label=dataset, s=50)
    ax.set_xlabel('Refusal Rate')
    ax.set_ylabel('Refusal Index')
    ax.set_title('Refusal Index vs Refusal Rate')
    ax.legend()
    ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig(output_path, dpi=150, bbox_inches='tight')
    plt.show()
    # Release the figure; otherwise it stays registered with pyplot (leaks when called repeatedly).
    plt.close(fig)

    print(f"Plots saved to {output_path}")


def main(test_mode=False):
    """
    Run the refusal-index analysis end to end and plot the results.

    Args:
        test_mode: When True, use a quick configuration (10 bootstrap samples,
            at most 2 models) instead of the full 1000-sample run.
    """
    banner = "=" * 80
    header_lines = (
        "Starting Refusal Index Analysis with Confidence Intervals...",
        banner,
        "Note: Using prompt 0 as forced-answering baseline for determining correctness",
        "      Analyzing prompts 1-4 for refusal patterns",
        banner,
    )
    for line in header_lines:
        print(line)

    if not test_mode:
        analysis_kwargs = dict(n_bootstrap=1000, confidence_level=0.95)
    else:
        print("\n*** RUNNING IN TEST MODE (limited samples) ***\n")
        analysis_kwargs = dict(n_bootstrap=10, confidence_level=0.95, limit_models=2)
    df_results = analyze_all_models_with_ci(**analysis_kwargs)

    has_results = df_results is not None and not df_results.empty
    if has_results:
        plot_refusal_indices_with_ci(df_results)

    print("\n" + banner)
    print("Analysis complete!")


if __name__ == "__main__":
    main()
