"""
Simple Metric Analysis with Confidence Intervals (SimpleQA)

This script computes simple refusal/accuracy metrics for prompts 1–4 on the
SimpleQA dataset across all models listed in evaluation_runs_core_datasets.json.

Metrics per prompt (no relaxation, direct from grades):
- Accuracy (c): proportion of 'A'
- Refusal (r): proportion of 'C'
- Correct/Attempted: c / (1 - r)
- F-score: 2c / (2 - r)
- Weighted: c - p * (1 - r), where p is a penalty weight (function default 0.5; the CLI flag --p_weight defaults to 1)

For each metric we compute bootstrap percentile confidence intervals.
"""

from __future__ import annotations

import argparse
import json
from pathlib import Path
from typing import Dict, List, Optional

import numpy as np
import pandas as pd


def compute_simple_metrics_from_df(df: pd.DataFrame, grade_col: str, p_weight: float = 0.5) -> Dict[str, float]:
    """
    Derive the un-relaxed accuracy/refusal metrics from one grade column.

    The column is compared as strings against the grade letters:
      - 'A' counts as correct
      - 'C' counts as refused
      - anything else (e.g. 'B') counts as neither

    Args:
        df: DataFrame holding the grade column
        grade_col: Column to score (e.g., 'grade1')
        p_weight: Penalty weight p used by the Weighted metric, c - p(1 - r)

    Returns:
        Dict with the row count 'n' and the metrics 'c', 'r',
        'correct_attempted', 'fscore' and 'weighted' as plain floats.

    Raises:
        ValueError: if grade_col is not a column of df
    """
    if grade_col not in df.columns:
        raise ValueError(f"Column {grade_col} not found in DataFrame")

    labels = df[grade_col].astype(str)
    total = len(labels)

    # With no rows both proportions are defined as zero.
    if total > 0:
        accuracy = (labels == 'A').mean()
        refusal = (labels == 'C').mean()
    else:
        accuracy = 0.0
        refusal = 0.0

    attempted_share = 1.0 - refusal

    # Clamp the denominator so an all-refused column yields 0 instead of a division error.
    accuracy_given_attempt = accuracy / max(1e-12, attempted_share)

    # Harmonic mean of accuracy and accuracy-given-attempt simplifies to 2c / (2 - r).
    harmonic_denominator = 2.0 - refusal
    if harmonic_denominator > 1e-12:
        f_measure = (2.0 * accuracy) / harmonic_denominator
    else:
        f_measure = 0.0

    # Correct answers are rewarded; every attempted item costs p_weight.
    penalised_score = accuracy - p_weight * attempted_share

    summary = {'n': total}
    summary['c'] = float(accuracy)
    summary['r'] = float(refusal)
    summary['correct_attempted'] = float(accuracy_given_attempt)
    summary['fscore'] = float(f_measure)
    summary['weighted'] = float(penalised_score)
    return summary


def bootstrap_simple_metrics(df: pd.DataFrame, grade_col: str, *, n_bootstrap: int = 100,
                             confidence_level: float = 0.95, p_weight: float = 0.5,
                             random_state: Optional[int] = None) -> Dict[str, float]:
    """
    Compute bootstrap percentile CIs for simple metrics for a single grade column.

    Rows are resampled with replacement ``n_bootstrap`` times; each resample has
    the same number of rows as ``df``. The metrics are recomputed on every
    resample and the CI bounds are the matching percentiles of those values.

    Args:
        df: DataFrame containing the grade column
        grade_col: Name of the grade column (e.g., 'grade1')
        n_bootstrap: Number of bootstrap resamples
        confidence_level: e.g., 0.95; must lie in [0, 1]
        p_weight: Penalty weight p for the Weighted metric
        random_state: Optional random seed for reproducibility

    Returns:
        Dict with 'n' and, for each metric m in
        ('c', 'r', 'correct_attempted', 'fscore', 'weighted'), the point
        estimate ``m`` plus ``m_ci_lower``, ``m_ci_upper`` and
        ``m_bootstrap_std``. When ``df`` is empty or ``n_bootstrap <= 0`` the
        bounds equal the point estimate and the std is 0.0.

    Raises:
        ValueError: if ``grade_col`` is missing (raised by
            compute_simple_metrics_from_df), or if resampling is requested with
            a ``confidence_level`` outside [0, 1].
    """
    rng = np.random.default_rng(random_state)
    metrics = ['c', 'r', 'correct_attempted', 'fscore', 'weighted']

    point = compute_simple_metrics_from_df(df, grade_col, p_weight=p_weight)
    out: Dict[str, float] = {'n': point['n']}

    n = len(df)
    if n == 0 or n_bootstrap <= 0:
        # Degenerate case: no bootstrap, just return point estimates as CIs
        for m in metrics:
            out[m] = point[m]
            out[f'{m}_ci_lower'] = point[m]
            out[f'{m}_ci_upper'] = point[m]
            out[f'{m}_bootstrap_std'] = 0.0
        return out

    # Fail before the resampling loop: a value > 1 would otherwise only surface
    # as an np.percentile error after all resamples ran, and a negative value
    # would silently produce inverted bounds.
    if not 0.0 <= confidence_level <= 1.0:
        raise ValueError(f"confidence_level must be within [0, 1], got {confidence_level}")

    # Only the grade column is read by the metric function, so select it once
    # instead of copying every column of df on each resample.
    grades_only = df[[grade_col]]

    # One row per metric, one column per resample (keeps each metric contiguous).
    boot = np.empty((len(metrics), n_bootstrap), dtype=float)
    for b in range(n_bootstrap):
        idx = rng.integers(0, n, size=n)
        res = compute_simple_metrics_from_df(grades_only.iloc[idx], grade_col, p_weight=p_weight)
        boot[:, b] = [res[m] for m in metrics]

    alpha = 1.0 - confidence_level
    lo = (alpha / 2.0) * 100.0
    hi = (1.0 - alpha / 2.0) * 100.0

    for m, samples in zip(metrics, boot):
        out[m] = point[m]
        out[f'{m}_ci_lower'] = float(np.percentile(samples, lo))
        out[f'{m}_ci_upper'] = float(np.percentile(samples, hi))
        out[f'{m}_bootstrap_std'] = float(np.std(samples))

    return out


def _iter_prompt_rows(csv_path: str, *, dataset_name: str, model_name: str, n_bootstrap: int,
                      confidence_level: float, p_weight: float):
    """
    Yield one result row per available prompt column (grade1..grade4) of a CSV.

    Implemented as a generator so that rows produced before a failure on a
    later prompt are still received by the caller.
    """
    df = pd.read_csv(csv_path)
    num_prompts = sum(1 for col in df.columns if str(col).startswith('grade'))
    if num_prompts < 2:
        print(f"    Skipping: only {num_prompts} prompt(s) found")
        return

    # Evaluate prompts 1..4 if present (skip 0, which is forced baseline in other analyses).
    # Presence is checked per column rather than inferred from the column count,
    # so an unrelated 'grade...' column cannot trigger a missing-column error.
    for prompt_idx in range(1, 5):
        grade_col = f'grade{prompt_idx}'
        if grade_col not in df.columns:
            continue

        print(f"    Computing for prompt {prompt_idx}...", end='', flush=True)
        stats = bootstrap_simple_metrics(
            df,
            grade_col,
            n_bootstrap=n_bootstrap,
            confidence_level=confidence_level,
            p_weight=p_weight,
        )

        row: Dict[str, object] = {
            'dataset': dataset_name,
            'model': model_name,
            'prompt_index': prompt_idx,
            'n': stats['n'],
        }
        # Merge metrics and CIs ('n' is already present)
        row.update((key, val) for key, val in stats.items() if key != 'n')

        print(f" c={stats['c']:.3f}, r={stats['r']:.3f}")
        yield row


def analyze_simpleqa_all_models_with_ci(*, n_bootstrap: int = 100, confidence_level: float = 0.95,
                                        limit_models: Optional[int] = None, dataset_filter: str = 'simpleqa',
                                        p_weight: float = 0.5,
                                        output_csv: str = 'simple_metrics_with_ci.csv') -> pd.DataFrame:
    """
    Analyze simple metrics with CIs for all models in the specified dataset (default: simpleqa).

    Reads 'evaluation_runs_core_datasets.json' from the current working
    directory. Each model maps to a list of entries, and the CSV path is taken
    from each entry's 'data' key. For each CSV, metrics are computed for
    prompts 1–4 where present.

    Args:
        n_bootstrap: Number of bootstrap resamples per prompt
        confidence_level: Confidence level for the percentile CIs
        limit_models: Optional cap on the number of models processed
        dataset_filter: Only this dataset is analyzed; a falsy value disables the filter
        p_weight: Penalty weight p for the Weighted metric
        output_csv: Path of the CSV to write; missing parent directories are created

    Returns:
        A DataFrame with one row per (dataset, model entry, prompt). It is also
        written to ``output_csv`` unless it is empty.
    """
    import traceback

    with open('evaluation_runs_core_datasets.json', 'r', encoding='utf-8') as f:
        eval_runs = json.load(f)

    all_rows: List[Dict[str, object]] = []
    model_count = 0

    print("Computing simple metrics with confidence intervals...")
    print(f"Bootstrap samples: {n_bootstrap}")
    print(f"Dataset filter: {dataset_filter}")
    print(f"Weighted metric penalty p: {p_weight}")
    print("=" * 80)

    for dataset_name, models in eval_runs.items():
        if dataset_filter and dataset_name != dataset_filter:
            continue

        print(f"\nDataset: {dataset_name}")
        print("-" * 40)

        for model_name, entries in models.items():
            if limit_models is not None and model_count >= limit_models:
                print(f"  Reached model limit ({limit_models}), stopping...")
                break

            # Count each model once, however many CSV entries it has, so that
            # limit_models limits models rather than entries.
            model_counted = False

            for entry in entries:
                csv_path = entry.get('data')
                if not csv_path or not Path(csv_path).exists():
                    print(f"  Skipping {model_name}: CSV not found at {csv_path}")
                    continue

                print(f"  Processing {model_name}...")
                if not model_counted:
                    model_count += 1
                    model_counted = True

                # Best-effort per file: report the failure and carry on with the rest.
                try:
                    for row in _iter_prompt_rows(
                        csv_path,
                        dataset_name=dataset_name,
                        model_name=model_name,
                        n_bootstrap=n_bootstrap,
                        confidence_level=confidence_level,
                        p_weight=p_weight,
                    ):
                        all_rows.append(row)
                except Exception as e:
                    print(f"    Error processing {model_name}: {e}")
                    traceback.print_exc()

    df_results = pd.DataFrame(all_rows)
    if df_results.empty:
        print("\nNo data to analyze")
        return df_results

    # Create the target directory first so a missing folder does not discard
    # the results after all bootstraps have been computed.
    output_path = Path(output_csv)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    df_results.to_csv(output_path, index=False)
    print(f"\nSaved detailed results to {output_csv}")

    # Print brief summary
    print("\n" + "=" * 80)
    print("SUMMARY STATISTICS")
    print("=" * 80)

    for dataset in df_results['dataset'].unique():
        ddf = df_results[df_results['dataset'] == dataset]
        print(f"\nDataset: {dataset}")
        print(f"  Models: {ddf['model'].nunique()}")
        print(f"  Mean accuracy (c): {ddf['c'].mean():.3f} ± {ddf['c'].std():.3f}")
        print(f"  Mean refusal (r): {ddf['r'].mean():.3f} ± {ddf['r'].std():.3f}")

    print("\nBy Prompt Index:")
    for prompt_idx in sorted(df_results['prompt_index'].unique()):
        pdf = df_results[df_results['prompt_index'] == prompt_idx]
        print(f"  Prompt {prompt_idx}: c={pdf['c'].mean():.3f}, r={pdf['r'].mean():.3f}")

    return df_results


def main(argv: Optional[List[str]] = None) -> None:
    """
    Command-line entry point: parse options and run the SimpleQA analysis.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, so a plain ``main()`` call behaves as before.

    Exits with argparse's usage error (status 2) when ``--confidence`` is not
    strictly between 0 and 1.
    """
    parser = argparse.ArgumentParser(description="Simple Metric Analysis with CIs for SimpleQA")
    parser.add_argument('--n_bootstrap', type=int, default=1000, help='Number of bootstrap samples')
    parser.add_argument('--confidence', type=float, default=0.95, help='Confidence level for CIs')
    parser.add_argument('--limit_models', type=int, default=None, help='Optional limit on number of models')
    parser.add_argument('--dataset', type=str, default='simpleqa', help='Dataset to filter (default: simpleqa)')
    # NOTE: the CLI default (1) differs from the 0.5 default of the analysis
    # functions' p_weight parameter.
    parser.add_argument('--p_weight', type=float, default=1, help='Penalty weight p for Weighted metric')
    parser.add_argument('--output', type=str, default='results/simple_metrics_with_ci.csv', help='Output CSV path')

    args = parser.parse_args(argv)

    # Reject e.g. "--confidence 95" up front. Otherwise the mistake only shows
    # up as a percentile error for every file, after the bootstraps have run.
    if not 0.0 < args.confidence < 1.0:
        parser.error(f"--confidence must be strictly between 0 and 1, got {args.confidence}")

    print("Starting Simple Metric Analysis with Confidence Intervals...")
    print("=" * 80)
    print("Note: Using direct grades (A/B/C) without relaxation")
    print("      Analyzing prompts 1–4 where available")
    print("=" * 80)

    analyze_simpleqa_all_models_with_ci(
        n_bootstrap=args.n_bootstrap,
        confidence_level=args.confidence,
        limit_models=args.limit_models,
        dataset_filter=args.dataset,
        p_weight=args.p_weight,
        output_csv=args.output,
    )


# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()

