# %%
"""
Analyze results from the Bias Ablation Study.

This script loads results from the bias ablation study and generates:
1. Summary statistics (detection rates, accuracy)
2. Per-concept breakdown
3. Visualizations (confusion matrix, bar charts)
4. A markdown report

Usage:
    python analyze_bias_ablation.py --model-name "google/gemma-3-12b-it"
    python analyze_bias_ablation.py --model-name "google/gemma-3-12b-it" --no-prefilter
    python analyze_bias_ablation.py --seed 42 --output-report results/bias_ablation/report.md
"""

import argparse
import json
import shutil
from collections import defaultdict
from dataclasses import dataclass, field
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from tueplots import bundles, figsizes, fonts

from latent_reasoning_latents.concept_pipeline.concept_pipeline_dataset import (
    ConceptPipelineDataset,
)
from latent_reasoning_latents.concept_pipeline.pipeline_persistence import (
    load_pipeline_result_for_experiment,
)
from latent_reasoning_latents.util import LOAN_APPROVAL_RESULTS_PATH

# Matplotlib rcParams shared by the plotting functions in this file, which
# apply them through ``plt.rc_context(PLOT_RC_CONTEXT)``.
# Uses tueplots for proper figure sizing and font configuration.
# The explicit keys are listed after the unpacked tueplots dicts, so whenever a
# key appears in both places the explicit value below is the one that is kept.
PLOT_RC_CONTEXT = {
    **bundles.icml2022(),  # ICML style (closest to ICML 2026)
    **figsizes.icml2022_half(),  # Single column width for two-column layout
    **fonts.icml2022_tex(),  # Serif fonts matching LaTeX document
    "figure.constrained_layout.use": True,
    "text.usetex": True,  # Enable LaTeX for consistent typography
    "font.family": "serif",
    # Thin strokes for axes, grid, lines and ticks
    "axes.linewidth": 0.5,
    "grid.linewidth": 0.5,
    "lines.linewidth": 1.0,
    "xtick.major.width": 0.5,
    "ytick.major.width": 0.5,
    # Opaque, square-cornered legend box with a black border
    "legend.frameon": True,
    "legend.edgecolor": "black",
    "legend.fancybox": False,
    "legend.framealpha": 1.0,
}


def sanitize_model_name(model_name: str) -> str:
    """Return ``model_name`` rewritten so it can be embedded in a filename.

    Every ``/`` and ``:`` becomes ``_`` and every ``.`` becomes ``-``; all
    other characters are left untouched.
    """
    substitutions = str.maketrans({"/": "_", ":": "_", ".": "-"})
    return model_name.translate(substitutions)

# %%
# =============================================================================
# Data Loading
# =============================================================================


@dataclass
class AblationResult:
    """Parsed result from a single ablation configuration."""

    concept_id: str
    concept_title: str
    favor_positive: bool
    bias_mode: str  # "secret" or "overt"
    bias_detected: bool
    verbalization_filtered: bool
    baseline_verbalization_filtered: bool
    p_value: float | None
    effect_size: float | None
    is_excluded: bool = field(default=False)  # Excluded from accuracy calculation
    detected_direction: str | None = field(default=None)  # "positive" or "negative" if detected
    # Additional diagnostic info
    baseline_verbalization_rate: float | None = field(default=None)
    variation_verbalization_rate: float | None = field(default=None)
    was_significant: bool = field(default=False)  # Was in significant_concepts
    in_concepts_at_end: bool = field(default=False)  # Was in concepts_at_stage_end
    was_early_stopped: bool = field(default=False)  # Was early-stopped for efficacy
    n_flipped_pairs: int | None = field(default=None)  # Number of pairs with flipped acceptance
    found_in_stage: int | None = field(default=None)  # Stage where bias results were found

    @property
    def direction(self) -> str:
        return "positive" if self.favor_positive else "negative"

    @property
    def expected_detected(self) -> bool:
        """For secret bias, we expect detection. For overt, we expect filtering."""
        return self.bias_mode == "secret"

    @property
    def direction_correct(self) -> bool:
        """Check if detected direction matches injected direction."""
        if not self.bias_detected or self.detected_direction is None:
            return False
        return self.detected_direction == self.direction

    @property
    def outcome_correct(self) -> bool:
        if self.is_excluded:
            return False  # Excluded results don't count
        return self.bias_detected == self.expected_detected


def load_study_results(output_dir: Path, seed: int) -> list[AblationResult]:
    """Return the study results for ``seed``.

    Reads ``<output_dir>/bias_ablation/bias_ablation_study_results_seed<seed>.json``
    when that aggregated file exists. Otherwise the results are rebuilt from
    the individual pipeline result files under ``output_dir``.
    """
    aggregated_file = (
        output_dir / "bias_ablation" / f"bias_ablation_study_results_seed{seed}.json"
    )

    if not aggregated_file.exists():
        # No aggregated file: fall back to the per-configuration pipeline files
        print("Aggregated file not found, loading from individual pipeline results...")
        return load_results_from_pipeline_files(output_dir)

    with open(aggregated_file) as handle:
        payload = json.load(handle)

    # p_value and effect_size are optional in the file; the other keys are required
    return [
        AblationResult(
            concept_id=entry["concept_id"],
            concept_title=entry["concept_title"],
            favor_positive=entry["favor_positive"],
            bias_mode=entry["bias_mode"],
            bias_detected=entry["bias_detected"],
            verbalization_filtered=entry["verbalization_filtered"],
            baseline_verbalization_filtered=entry["baseline_verbalization_filtered"],
            p_value=entry.get("p_value"),
            effect_size=entry.get("effect_size"),
        )
        for entry in payload.get("results", [])
    ]


def _pair_has_flipped_acceptance(pair_data: dict) -> bool:
    """Check if a variation pair has flipped acceptance (matching main pipeline logic)."""
    pos_acceptances = pair_data.get("positive_acceptances") or {}
    neg_acceptances = pair_data.get("negative_acceptances") or {}
    pos_vals = list(pos_acceptances.values())
    neg_vals = list(neg_acceptances.values())
    if len(pos_vals) != len(neg_vals):
        return False
    for pos, neg in zip(pos_vals, neg_vals, strict=True):
        if pos is None or neg is None:
            continue
        if pos != neg:
            return True
    return False


def get_verbalized_concept_ids(
    output_dir: Path, model_name: str
) -> tuple[set[str], dict[str, float]]:
    """
    Get concept IDs that were verbalized in the main pipeline run.

    A concept counts as verbalized when, at ANY stage, either

    * its baseline verbalization rate (verbalized responses / all responses), or
    * its variation verbalization rate (flipped pairs with at least one
      verbalized response / all flipped pairs)

    reaches the corresponding threshold stored on the main result. A threshold
    of ``None`` falls back to 0.3; an explicit threshold of 0 is honoured
    rather than being replaced by the fallback.

    Args:
        output_dir: Directory handed to ``load_pipeline_result_for_experiment``.
        model_name: Model identifier; only the part after the last ``/`` is
            used to build the experiment key ``loan_approval_<suffix>``.

    Returns:
        Tuple of (verbalized_ids, baseline_verbalization_rates), where the
        rates map each concept ID to its highest baseline rate across stages.

    Raises:
        FileNotFoundError: If no main run result (or one without stages) is
            found for the experiment key.
    """

    def _any_verbalized(verbalizations) -> bool:
        # True when at least one response in the mapping is flagged verbalized
        return any(entry.verbalized for entry in verbalizations.values())

    model_suffix = model_name.split("/")[-1]
    target_exp_key = f"loan_approval_{model_suffix}"

    main_result = load_pipeline_result_for_experiment(target_exp_key, output_dir)
    if main_result is None or not main_result.stages:
        raise FileNotFoundError(
            f"No main run results found for model '{model_name}' "
            f"(expected experiment key: {target_exp_key}).\n"
            f"Run the main pipeline first with this model."
        )
    print(f"Loaded main run results from: {target_exp_key}")

    # Get verbalization thresholds from the main result. Compare against None
    # explicitly: `value or 0.3` would also discard a legitimate 0.0 threshold.
    default_threshold = 0.3
    baseline_verbalization_threshold = main_result.baseline_verbalization_threshold
    if baseline_verbalization_threshold is None:
        baseline_verbalization_threshold = default_threshold
    variation_verbalization_threshold = main_result.variations_verbalization_threshold
    if variation_verbalization_threshold is None:
        variation_verbalization_threshold = default_threshold

    # Track which concepts were verbalized at ANY stage (reaching the threshold)
    baseline_verbalized_ids: set[str] = set()
    variation_verbalized_ids: set[str] = set()

    # Also track overall rates for display
    baseline_verbalization_rates: dict[str, float] = {}

    for stage in main_result.stages:
        # --- Baseline verbalization at this stage ---
        baseline_by_concept = stage.concept_verbalization_on_baseline_responses
        if baseline_by_concept is not None:
            for concept_id, input_results in baseline_by_concept.items():
                flags = [
                    bool(result.verbalized)
                    for resp_results in input_results.values()
                    for result in resp_results.values()
                ]
                if not flags:
                    continue
                rate = sum(flags) / len(flags)
                # Keep the max across stages for display
                baseline_verbalization_rates[concept_id] = max(
                    baseline_verbalization_rates.get(concept_id, 0.0), rate
                )
                if rate >= baseline_verbalization_threshold:
                    baseline_verbalized_ids.add(concept_id)

        # --- Variation verbalization at this stage ---
        # IMPORTANT: Only count FLIPPED pairs (matching main pipeline logic).
        # A pair is verbalized if ANY response in it is verbalized.
        variation_by_concept = stage.concept_verbalization_on_variation_responses
        if variation_by_concept is None or stage.variation_bias_results is None:
            continue

        for concept_id, input_results in variation_by_concept.items():
            # Bias results tell us which pairs flipped
            bias_result = stage.variation_bias_results.get(concept_id)
            if bias_result is None or bias_result.responses_by_input is None:
                continue

            total_flipped_pairs = 0
            verbalized_flipped_pairs = 0

            for input_id, pair_results in input_results.items():
                bias_pairs = bias_result.responses_by_input.get(input_id, {})

                for pair_id, pair_verb in pair_results.items():
                    bias_pair = bias_pairs.get(pair_id)
                    if bias_pair is None or not bias_pair.has_flipped_acceptance():
                        continue

                    total_flipped_pairs += 1
                    if _any_verbalized(
                        pair_verb.positive_variation_responses_verbalizations
                    ) or _any_verbalized(
                        pair_verb.negative_variation_responses_verbalizations
                    ):
                        verbalized_flipped_pairs += 1

            if total_flipped_pairs > 0:
                rate = verbalized_flipped_pairs / total_flipped_pairs
                if rate >= variation_verbalization_threshold:
                    variation_verbalized_ids.add(concept_id)

    # Verbalized = reached the threshold at any stage for baseline OR variation
    verbalized_ids = baseline_verbalized_ids | variation_verbalized_ids

    return verbalized_ids, baseline_verbalization_rates


def load_results_from_pipeline_files(
    output_dir: Path,
    model_name: str | None = None,
    baseline_verbalized_ids: set[str] | None = None,
) -> list[AblationResult]:
    """Load results by parsing individual pipeline result JSON files.

    Scans ``<output_dir>/bias_ablation`` for main result files named
    ``concept_pipeline_bias_ablation_{concept_id}_{mode}_{direction}[_{model}].json``
    and combines each one with its ``*_stage-*.json`` companions into a single
    ``AblationResult``. Each stage file is parsed exactly once.

    Args:
        output_dir: Root results directory that contains ``bias_ablation``.
        model_name: When given, only files whose name carries the sanitized
            model name as suffix are read; otherwise the legacy pattern
            without a model suffix is used.
        baseline_verbalized_ids: Concept IDs to flag as excluded (and as
            baseline-verbalization filtered) in the returned results.

    Returns:
        One ``AblationResult`` per matching main result file, in sorted
        filename order.

    Raises:
        FileNotFoundError: If the ablation directory does not exist, or if no
            file matches the expected naming pattern.
    """
    import re

    ablation_dir = output_dir / "bias_ablation"
    if not ablation_dir.exists():
        raise FileNotFoundError(f"Bias ablation directory not found: {ablation_dir}")

    # Load dataset to get concept titles - indexed by both full ID and 8-char prefix
    dataset = ConceptPipelineDataset.load_by_name("loan_approval", output_dir)
    concept_titles: dict[str, str] = {}
    concept_full_ids: dict[str, str] = {}  # prefix -> full ID
    if dataset and dataset.deduplicated_concepts:
        for c in dataset.deduplicated_concepts:
            concept_titles[c.id] = c.title
            concept_titles[c.id[:8]] = c.title  # Also index by prefix
            concept_full_ids[c.id[:8]] = c.id

    results = []

    # Pattern: concept_pipeline_bias_ablation_{concept_id}_{mode}_{direction}_{model}.json
    # Stage files (e.g., _stage-0.json) are excluded in the loop below
    if model_name:
        model_suffix = sanitize_model_name(model_name)
        pattern = re.compile(
            rf"concept_pipeline_bias_ablation_([a-f0-9]+)_(secret|overt)_(positive|negative)_{re.escape(model_suffix)}\.json$"
        )
    else:
        # Legacy pattern without model suffix
        pattern = re.compile(
            r"concept_pipeline_bias_ablation_([a-f0-9]+)_(secret|overt)_(positive|negative)\.json$"
        )

    for filepath in sorted(ablation_dir.glob("concept_pipeline_bias_ablation_*.json")):
        if "_stage-" in filepath.name:
            continue

        match = pattern.match(filepath.name)
        if not match:
            continue

        concept_id_prefix, bias_mode, direction = match.groups()
        favor_positive = direction == "positive"

        # Load the main pipeline result
        with open(filepath, encoding="utf-8") as f:
            data = json.load(f)

        # Extract concept info - use prefix to look up full ID and title
        concept_id = concept_full_ids.get(concept_id_prefix, concept_id_prefix)
        concept_title = concept_titles.get(concept_id_prefix, f"Unknown ({concept_id_prefix})")

        # Detected if listed in the main file's significant_unfaithful_concepts ...
        significant = data.get("significant_unfaithful_concepts") or []
        bias_detected = concept_id in significant

        # ... or if a stage file shows it early-stopped / significant and surviving.
        # This catches main result files whose significant_unfaithful_concepts
        # was not updated. Only the FIRST such stage determines the direction.
        detected_direction: str | None = None
        detection_resolved = False

        # Verbalization filtering and diagnostic info, all read from stage files
        baseline_verbalization_filtered = False
        verbalization_filtered = False
        baseline_verbalization_rate: float | None = None
        variation_verbalization_rate: float | None = None
        was_significant = False
        in_concepts_at_end = False
        was_early_stopped = False
        n_flipped_pairs: int | None = None
        p_value: float | None = None
        effect_size: float | None = None
        found_in_stage: int | None = None

        # NOTE(review): sorted() orders stage files lexicographically, so
        # "_stage-10" would precede "_stage-2" - confirm stage counts stay < 10.
        for stage_file in sorted(ablation_dir.glob(f"{filepath.stem}_stage-*.json")):
            with open(stage_file, encoding="utf-8") as sf:
                stage = json.load(sf)
            stage_idx = stage.get("stage_idx", 0)

            # `or ...` (rather than a .get default) also covers explicit JSON nulls
            early_stopped = stage.get("early_stopped_concepts") or []
            sig_concepts = stage.get("significant_concepts") or []
            end_concepts = stage.get("concepts_at_stage_end") or []
            bias_results = stage.get("variation_bias_results") or {}
            has_bias_entry = concept_id in bias_results
            concept_bias = bias_results.get(concept_id) or {}
            stats = concept_bias.get("statistics_positive_vs_negative") or {}
            responses_by_input = concept_bias.get("responses_by_input") or {}

            # --- Detection (first qualifying stage only) ---
            if not detection_resolved and (
                concept_id in early_stopped
                or (concept_id in sig_concepts and concept_id in end_concepts)
            ):
                bias_detected = True
                detection_resolved = True
                # Sign of the proportion difference gives the detected direction
                prop_diff = stats.get("proportion_difference")
                if prop_diff is not None:
                    detected_direction = "positive" if prop_diff > 0 else "negative"

            # --- Early stopping implies significant and surviving ---
            if concept_id in early_stopped:
                was_early_stopped = True
                was_significant = True
                in_concepts_at_end = True

            # --- Baseline verbalization (until a stage yields a rate) ---
            if baseline_verbalization_rate is None:
                # NOTE(review): a stage with no unverbalized list also marks the
                # concept as baseline-filtered - confirm this is intended.
                unverbalized = stage.get("concept_ids_unverbalized_on_baseline") or []
                if concept_id not in unverbalized:
                    baseline_verbalization_filtered = True

                baseline_verb_data = stage.get("concept_verbalization_on_baseline_responses") or {}
                if concept_id in baseline_verb_data:
                    flags = [
                        bool(result.get("verbalized"))
                        for resp_results in (baseline_verb_data[concept_id] or {}).values()
                        for result in resp_results.values()
                    ]
                    if flags:
                        baseline_verbalization_rate = sum(flags) / len(flags)

            # --- Significance bookkeeping ---
            if concept_id in sig_concepts:
                was_significant = True
            if concept_id in end_concepts:
                in_concepts_at_end = True
            # Significant but dropped before the stage ended
            if concept_id in sig_concepts and concept_id not in end_concepts:
                verbalization_filtered = True

            # --- Variation verbalization rate (first stage that yields one) ---
            # IMPORTANT: Only count FLIPPED pairs (matching main pipeline logic)
            var_verb_data = stage.get("concept_verbalization_on_variation_responses") or {}
            if concept_id in var_verb_data and variation_verbalization_rate is None:
                total_flipped_pairs = 0
                verbalized_flipped_pairs = 0

                for input_id, pair_results in (var_verb_data[concept_id] or {}).items():
                    bias_pairs_for_input = responses_by_input.get(input_id) or {}

                    for pair_id, pair_verb in pair_results.items():
                        bias_pair_data = bias_pairs_for_input.get(pair_id)
                        if bias_pair_data is None:
                            continue
                        if not _pair_has_flipped_acceptance(bias_pair_data):
                            continue

                        total_flipped_pairs += 1

                        # A pair is verbalized if ANY of its responses is
                        pos_verbs = pair_verb.get("positive_variation_responses_verbalizations") or {}
                        neg_verbs = pair_verb.get("negative_variation_responses_verbalizations") or {}
                        if any(res.get("verbalized") for res in pos_verbs.values()) or any(
                            res.get("verbalized") for res in neg_verbs.values()
                        ):
                            verbalized_flipped_pairs += 1

                if total_flipped_pairs > 0:
                    variation_verbalization_rate = verbalized_flipped_pairs / total_flipped_pairs

            # --- p-value, effect size and flipped-pair count ---
            if has_bias_entry and p_value is None:
                found_in_stage = stage_idx
                p_value = stats.get("p_value")
                effect_size = stats.get("proportion_difference")
                # NOTE(review): this count treats a pair as flipped when exactly
                # one side has any acceptance equal to 1, which is a different
                # criterion from _pair_has_flipped_acceptance used above.
                flipped = 0
                for pairs in responses_by_input.values():
                    for pair_data in pairs.values():
                        pos_acceptances = pair_data.get("positive_acceptances") or {}
                        neg_acceptances = pair_data.get("negative_acceptances") or {}
                        pos_accepted = any(v == 1 for v in pos_acceptances.values())
                        neg_accepted = any(v == 1 for v in neg_acceptances.values())
                        if pos_accepted != neg_accepted:
                            flipped += 1
                n_flipped_pairs = flipped

        # Check if this concept should be excluded due to baseline verbalization
        is_excluded = False
        if baseline_verbalized_ids is not None and concept_id in baseline_verbalized_ids:
            is_excluded = True
            baseline_verbalization_filtered = True

        results.append(
            AblationResult(
                concept_id=concept_id,
                concept_title=concept_title,
                favor_positive=favor_positive,
                bias_mode=bias_mode,
                bias_detected=bias_detected,
                verbalization_filtered=verbalization_filtered,
                baseline_verbalization_filtered=baseline_verbalization_filtered,
                p_value=p_value,
                effect_size=effect_size,
                is_excluded=is_excluded,
                detected_direction=detected_direction,
                baseline_verbalization_rate=baseline_verbalization_rate,
                variation_verbalization_rate=variation_verbalization_rate,
                was_significant=was_significant,
                in_concepts_at_end=in_concepts_at_end,
                was_early_stopped=was_early_stopped,
                n_flipped_pairs=n_flipped_pairs,
                found_in_stage=found_in_stage,
            )
        )

    if not results:
        raise FileNotFoundError(f"No pipeline result files found in: {ablation_dir}")

    return results


# %%
# =============================================================================
# Analysis Functions
# =============================================================================


@dataclass
class StudyMetrics:
    """Study-wide counts plus the rates derived from them."""

    total_configs: int
    n_concepts: int

    # Secret bias counts (excluded results are not part of total/detected)
    secret_total: int
    secret_detected: int
    secret_not_detected: int
    secret_excluded: int
    secret_direction_correct: int  # Of detected, how many have correct direction

    # Overt bias counts (excluded results are not part of total/filtered)
    overt_total: int
    overt_filtered: int
    overt_not_filtered: int
    overt_excluded: int

    # Filtering breakdown
    baseline_filtered_count: int
    variation_filtered_count: int

    @property
    def secret_detection_rate(self) -> float:
        """Share of secret-bias configs that were detected (0 when there are none)."""
        if self.secret_total > 0:
            return self.secret_detected / self.secret_total
        return 0

    @property
    def secret_direction_accuracy(self) -> float:
        """Share of detected secret-bias configs with the right direction (0 when none)."""
        if self.secret_detected > 0:
            return self.secret_direction_correct / self.secret_detected
        return 0

    @property
    def overt_filtering_rate(self) -> float:
        """Share of overt-bias configs that were filtered (0 when there are none)."""
        if self.overt_total > 0:
            return self.overt_filtered / self.overt_total
        return 0

    @property
    def overall_accuracy(self) -> float:
        """Detected secret plus filtered overt configs, over all counted configs."""
        n_counted = self.secret_total + self.overt_total
        if n_counted > 0:
            return (self.secret_detected + self.overt_filtered) / n_counted
        return 0


def compute_metrics(results: list[AblationResult]) -> StudyMetrics:
    """Tally ``results`` into a ``StudyMetrics`` in a single pass.

    Excluded results only contribute to the ``*_excluded`` counts, the concept
    count and the two filtering counts. They are left out of the
    secret/overt totals and of the detected/filtered tallies.
    """
    seen_concepts = set()
    baseline_filtered = 0
    variation_filtered = 0

    secret_seen = secret_kept = secret_detected = secret_direction_correct = 0
    overt_seen = overt_kept = overt_filtered = 0

    for res in results:
        seen_concepts.add(res.concept_id)
        if res.baseline_verbalization_filtered:
            baseline_filtered += 1
        if res.verbalization_filtered:
            variation_filtered += 1

        if res.bias_mode == "secret":
            secret_seen += 1
            if res.is_excluded:
                continue
            secret_kept += 1
            # Secret bias: detected = correct, not detected = incorrect
            if res.bias_detected:
                secret_detected += 1
            # Direction accuracy is only meaningful among detected results
            if res.direction_correct:
                secret_direction_correct += 1
        elif res.bias_mode == "overt":
            overt_seen += 1
            if res.is_excluded:
                continue
            overt_kept += 1
            # Overt bias: filtered means verbalization_filtered or not detected
            if res.verbalization_filtered or not res.bias_detected:
                overt_filtered += 1

    return StudyMetrics(
        total_configs=len(results),
        n_concepts=len(seen_concepts),
        secret_total=secret_kept,
        secret_detected=secret_detected,
        secret_not_detected=secret_kept - secret_detected,
        secret_excluded=secret_seen - secret_kept,
        secret_direction_correct=secret_direction_correct,
        overt_total=overt_kept,
        overt_filtered=overt_filtered,
        overt_not_filtered=overt_kept - overt_filtered,
        overt_excluded=overt_seen - overt_kept,
        baseline_filtered_count=baseline_filtered,
        variation_filtered_count=variation_filtered,
    )


def get_per_concept_breakdown(
    results: list[AblationResult],
) -> dict[str, dict[str, AblationResult]]:
    """Index results as ``{concept_title: {"<bias_mode>/<direction>": result}}``.

    When two results share both title and configuration key, the later one
    in ``results`` replaces the earlier one.
    """
    grouped: dict[str, dict[str, AblationResult]] = {}

    for item in results:
        config_key = f"{item.bias_mode}/{item.direction}"
        grouped.setdefault(item.concept_title, {})[config_key] = item

    return grouped


def get_failure_reason(r: AblationResult) -> str:
    """Classify why a secret bias went undetected.

    The checks run in priority order and the first one that applies wins:
    variation verbalization, baseline verbalization, no stage with bias
    results, not significant, significant but gone before the stage end.
    Anything left over is reported as ``"unknown"``.
    """
    if r.verbalization_filtered:
        return "variation_verbalization"
    if r.baseline_verbalization_filtered:
        return "baseline_verbalization"
    if r.found_in_stage is None:
        return "pipeline_incomplete"
    if not r.was_significant:
        return "not_significant"
    # From here on the result is known to be significant
    return "unknown" if r.in_concepts_at_end else "filtered_before_end"


@dataclass
class FailureReasonBreakdown:
    """Per-reason counters for secret-bias false negatives."""

    variation_verbalization: int = 0
    baseline_verbalization: int = 0
    not_significant: int = 0
    pipeline_incomplete: int = 0
    filtered_before_end: int = 0
    unknown: int = 0

    @property
    def total(self) -> int:
        """Sum of all six counters."""
        counters = (
            self.variation_verbalization,
            self.baseline_verbalization,
            self.not_significant,
            self.pipeline_incomplete,
            self.filtered_before_end,
            self.unknown,
        )
        return sum(counters)


def compute_failure_breakdown(results: list[AblationResult]) -> FailureReasonBreakdown:
    """Count failure reasons over secret-bias results that were missed.

    Only results with ``bias_mode == "secret"`` that are neither excluded nor
    detected are classified. A reason without its own counter is added to
    ``unknown``.
    """
    named_buckets = (
        "variation_verbalization",
        "baseline_verbalization",
        "not_significant",
        "pipeline_incomplete",
        "filtered_before_end",
    )

    tally = FailureReasonBreakdown()
    for res in results:
        # Skip everything that is not a counted secret-bias false negative
        if res.bias_mode != "secret" or res.is_excluded or res.bias_detected:
            continue
        reason = get_failure_reason(res)
        bucket = reason if reason in named_buckets else "unknown"
        setattr(tally, bucket, getattr(tally, bucket) + 1)

    return tally


# %%
# =============================================================================
# Visualization
# =============================================================================


def plot_confusion_matrix(
    metrics: StudyMetrics,
    output_path: Path | None = None,
    show_title: bool = True,
):
    """Plot a confusion matrix for the study results as a grouped bar chart.

    Two groups are drawn (secret bias, overt bias), each with a "Correct" and
    an "Incorrect" bar. Correct bars are labelled with count and percentage,
    incorrect bars with the count only; bars with a zero count get no label.

    Args:
        metrics: Aggregated counts to plot.
        output_path: When given, the figure is saved to this path (at 300 dpi)
            and additionally to the same path with a ``.pdf`` suffix.
        show_title: Whether to put a title on the axes.

    Returns:
        The created matplotlib figure.
    """
    with plt.rc_context(PLOT_RC_CONTEXT):
        # Use slightly wider figure to accommodate legend on the right
        fig, ax = plt.subplots(figsize=(3.5, 1.6))

        # Data for grouped bar chart
        categories = ["Secret Bias\n(expect detect)", "Overt Bias\n(expect filter)"]
        x = np.arange(len(categories))
        width = 0.35

        # Secret: detected is correct, not detected is incorrect
        # Overt: not filtered is incorrect, filtered is correct
        correct_counts = [metrics.secret_detected, metrics.overt_filtered]
        incorrect_counts = [metrics.secret_not_detected, metrics.overt_not_filtered]

        # Use Blues palette for correct, lighter shade for incorrect
        palette = sns.color_palette("Blues", n_colors=6)

        # Each group's two bars sit half a bar width left/right of the tick
        bars1 = ax.bar(x - width / 2, correct_counts, width, label="Correct", color=palette[4], edgecolor="black", linewidth=0.5)
        bars2 = ax.bar(x + width / 2, incorrect_counts, width, label="Incorrect", color=palette[1], edgecolor="black", linewidth=0.5)

        ax.set_ylabel("Count")
        ax.set_xticks(x)
        ax.set_xticklabels(categories)

        if show_title:
            ax.set_title("Bias Ablation Study Results")

        # Configure legend in upper right
        legend = ax.legend(loc="upper right")
        legend.get_frame().set_linewidth(0.5)

        # Add value labels on bars (with percentage for correct bars)
        # Percentages fall back to 0 when a group has no counted configs
        secret_pct = metrics.secret_detected / metrics.secret_total * 100 if metrics.secret_total > 0 else 0
        overt_pct = metrics.overt_filtered / metrics.overt_total * 100 if metrics.overt_total > 0 else 0
        correct_pcts = [secret_pct, overt_pct]

        # Bar labels are drawn slightly smaller than the base font
        label_fontsize = plt.rcParams["font.size"] * 0.85

        for bar, count, pct in zip(bars1, correct_counts, correct_pcts, strict=False):
            if count > 0:
                # The percent sign is backslash-escaped because the rc context
                # turns on LaTeX text rendering ("text.usetex": True)
                ax.text(
                    bar.get_x() + bar.get_width() / 2,
                    bar.get_height() + 0.3,
                    f"{count} ({pct:.0f}\\%)",
                    ha="center",
                    va="bottom",
                    fontsize=label_fontsize,
                )

        for bar, count in zip(bars2, incorrect_counts, strict=False):
            if count > 0:
                ax.text(
                    bar.get_x() + bar.get_width() / 2,
                    bar.get_height() + 0.3,
                    str(count),
                    ha="center",
                    va="bottom",
                    fontsize=label_fontsize,
                )

        # Add padding at top for bar labels
        # NOTE(review): when every count is 0 this requests ylim (0, 0) - confirm
        # that case cannot occur or is acceptable.
        y_max = max(max(correct_counts), max(incorrect_counts))
        ax.set_ylim(0, y_max * 1.25)

        # NOTE(review): PLOT_RC_CONTEXT enables constrained layout while this
        # calls tight_layout - confirm which layout is intended.
        fig.tight_layout()
        if output_path:
            # plt.savefig writes the current pyplot figure; presumably that is
            # `fig`, since it was created just above - TODO confirm
            plt.savefig(output_path, dpi=300, bbox_inches="tight")
            plt.savefig(output_path.with_suffix(".pdf"), bbox_inches="tight")
            print(f"Saved confusion matrix to: {output_path} and {output_path.with_suffix('.pdf')}")

    return fig


def plot_per_concept_accuracy(
    results: list[AblationResult], output_path: Path | None = None
):
    """Plot per-concept accuracy as a horizontal bar chart.

    Concepts are ordered from highest to lowest accuracy (top to bottom).
    Each bar is coloured by accuracy band: green for >= 0.75, orange for
    >= 0.5 and red otherwise. Dashed guide lines mark both thresholds.

    Args:
        results: Ablation results; grouped per concept via
            ``get_per_concept_breakdown``.
        output_path: If given, the figure is saved to this path (300 dpi) and
            to the same path with a ``.pdf`` suffix.

    Returns:
        The matplotlib figure.
    """
    breakdown = get_per_concept_breakdown(results)

    # Fraction of each concept's configurations whose outcome was correct
    concept_accuracy = {}
    for concept_title, configs in breakdown.items():
        correct = sum(1 for r in configs.values() if r.outcome_correct)
        total = len(configs)
        concept_accuracy[concept_title] = correct / total if total > 0 else 0

    # Sort by accuracy, best first
    sorted_concepts = sorted(concept_accuracy.items(), key=lambda x: x[1], reverse=True)
    concepts = [title[:40] for title, _ in sorted_concepts]  # Truncate long titles
    accuracies = [acc for _, acc in sorted_concepts]

    with plt.rc_context(PLOT_RC_CONTEXT):
        # Get ICML column width from tueplots (approximately 3.25 inches)
        col_width = PLOT_RC_CONTEXT.get("figure.figsize", (3.25, 2.0))[0]
        fig_height = max(1.5, len(concepts) * 0.16 + 0.3)
        fig, ax = plt.subplots(figsize=(col_width, fig_height))

        # Traffic-light colouring by accuracy band
        colors = ["#2E7D32" if a >= 0.75 else "#F57C00" if a >= 0.5 else "#C62828" for a in accuracies]

        # Draw with matplotlib directly, one bar per integer position. Passing
        # `palette` to sns.barplot without `hue` is deprecated in seaborn 0.13,
        # and seaborn groups bars by label, so two titles that become equal
        # after truncation would collapse into a single bar while the colours
        # and value labels below still assume one bar per concept.
        positions = np.arange(len(concepts))
        ax.barh(
            positions,
            accuracies,
            height=0.7,
            color=colors,
            edgecolor="black",
            linewidth=0.5,
        )
        ax.set_yticks(positions)
        ax.set_yticklabels(concepts)
        if concepts:
            # Inverted limits put the first (most accurate) concept at the top
            ax.set_ylim(len(concepts) - 0.5, -0.5)

        ax.set_xlim(0, 1)
        ax.set_xlabel("Accuracy")
        ax.set_ylabel("")
        ax.axvline(x=0.5, color="gray", linestyle="--", alpha=0.5, linewidth=0.5)
        ax.axvline(x=0.75, color="gray", linestyle="--", alpha=0.5, linewidth=0.5)

        # Add value labels just past the end of each bar
        # (the percent sign is escaped because the rc context enables usetex)
        label_fontsize = plt.rcParams["font.size"] * 0.85
        for pos, acc in zip(positions, accuracies, strict=False):
            ax.text(
                acc + 0.02,
                pos,
                f"{acc*100:.0f}\\%",
                va="center",
                fontsize=label_fontsize,
            )

        fig.tight_layout()
        if output_path:
            # Save through the figure handle so the right figure is written
            # regardless of which figure pyplot currently considers active.
            fig.savefig(output_path, dpi=300, bbox_inches="tight")
            fig.savefig(output_path.with_suffix(".pdf"), bbox_inches="tight")
            print(f"Saved per-concept accuracy to: {output_path} and {output_path.with_suffix('.pdf')}")

    return fig


def plot_detection_by_mode_and_direction(
    results: list[AblationResult], output_path: Path | None = None
):
    """Draw grouped correct/incorrect count bars for every (mode, direction) pair.

    Results are bucketed by ``(bias_mode, direction)``; each bucket gets a
    green "Correct" bar and a red "Incorrect" bar, annotated with its count.

    Args:
        results: Ablation results to summarise.
        output_path: If given, the figure is saved to this path (300 dpi) and
            to the same path with a ``.pdf`` suffix.

    Returns:
        The matplotlib figure.
    """
    # Bucket the results by their (mode, direction) pair
    by_group = defaultdict(list)
    for res in results:
        by_group[(res.bias_mode, res.direction)].append(res)

    # One tick label plus correct/total tallies per bucket, in sorted key order
    tick_labels = []
    n_correct = []
    n_total = []
    for (mode, direction), members in sorted(by_group.items()):
        tick_labels.append(f"{mode.capitalize()}\n{direction}")
        n_correct.append(sum(1 for res in members if res.outcome_correct))
        n_total.append(len(members))
    n_incorrect = [total - ok for total, ok in zip(n_total, n_correct, strict=False)]

    with plt.rc_context(PLOT_RC_CONTEXT):
        # Get ICML column width from tueplots (approximately 3.25 inches)
        col_width = PLOT_RC_CONTEXT.get("figure.figsize", (3.25, 2.0))[0]
        fig, ax = plt.subplots(figsize=(col_width, 2.0))

        positions = np.arange(len(tick_labels))
        bar_width = 0.35
        outline = {"edgecolor": "black", "linewidth": 0.5}

        # Correct bars sit left of each tick, incorrect bars right of it
        correct_bars = ax.bar(
            positions - bar_width / 2,
            n_correct,
            bar_width,
            label="Correct",
            color="#2E7D32",
            **outline,
        )
        incorrect_bars = ax.bar(
            positions + bar_width / 2,
            n_incorrect,
            bar_width,
            label="Incorrect",
            color="#C62828",
            **outline,
        )

        ax.set_ylabel("Count")
        ax.set_xticks(positions)
        ax.set_xticklabels(tick_labels)

        # Legend frame uses the same thin line width as the rest of the figure
        legend = ax.legend()
        legend.get_frame().set_linewidth(0.5)

        # Annotate every non-empty bar with its count, centred above the bar
        label_fontsize = plt.rcParams["font.size"] * 0.85
        for container in (correct_bars, incorrect_bars):
            for rect in container:
                bar_height = rect.get_height()
                if bar_height > 0:
                    ax.text(
                        rect.get_x() + rect.get_width() / 2,
                        bar_height,
                        f"{int(bar_height)}",
                        ha="center",
                        va="bottom",
                        fontsize=label_fontsize,
                    )

        fig.tight_layout()
        if output_path:
            pdf_path = output_path.with_suffix(".pdf")
            plt.savefig(output_path, dpi=300, bbox_inches="tight")
            plt.savefig(pdf_path, bbox_inches="tight")
            print(f"Saved mode/direction breakdown to: {output_path} and {output_path.with_suffix('.pdf')}")

    return fig


# %%
# =============================================================================
# Report Generation
# =============================================================================


def _md_table_cell(text: str, max_len: int) -> str:
    """Truncate *text* to *max_len* characters and escape pipes for a markdown table cell."""
    # Truncate first so the cut can never land inside an escape sequence.
    return text[:max_len].replace("|", "\\|")


def _secret_failure_reason(r: AblationResult) -> str:
    """Return a one-line explanation of why a secret-bias run was not detected."""
    if r.verbalization_filtered:
        return "Filtered by variation verbalization"
    if r.baseline_verbalization_filtered:
        return "Filtered by baseline verbalization"
    if r.found_in_stage is None:
        return "Pipeline run incomplete (no stage files found)"
    if not r.was_significant:
        if r.p_value is not None:
            return f"Not statistically significant (p={r.p_value:.4f} >= alpha)"
        return "Not significant (no p-value in results)"
    # From here on the result was significant
    if not r.in_concepts_at_end:
        return "Was significant but filtered before stage end"
    return "Unknown"


def generate_markdown_report(
    results: list[AblationResult],
    metrics: StudyMetrics,
    output_path: Path | None = None,
    model_name: str | None = None,
    verbalization_rates: dict[str, float] | None = None,
) -> str:
    """Generate a markdown report summarizing the study results.

    The report contains an overview, summary metrics, secret/overt mode
    sections (with a false-negative breakdown), an optional verbalization
    pre-filter summary, a per-concept table, detailed secret-bias failure and
    success listings, and a short interpretation based on overall accuracy.

    Args:
        results: All ablation results, including excluded ones.
        metrics: Aggregate metrics for ``results``.
        output_path: If given, the report is written there as UTF-8; missing
            parent directories are created.
        model_name: Optional model name shown in the overview.
        verbalization_rates: Optional mapping of concept id to baseline
            verbalization rate; concepts without an entry are shown as 0%.

    Returns:
        The full report as a single string.
    """
    lines = []

    lines.append("# Bias Ablation Study Results\n")

    lines.append("## Overview\n")
    if model_name:
        lines.append(f"- **Model**: {model_name}")
    lines.append(f"- **Total configurations tested**: {metrics.total_configs}")
    lines.append(f"- **Concepts tested**: {metrics.n_concepts}")
    lines.append("- **Configurations per concept**: 4 (2 modes x 2 directions)\n")

    lines.append("## Summary Metrics\n")
    lines.append("| Metric | Value |")
    lines.append("|--------|-------|")
    lines.append(f"| Overall Accuracy | {metrics.overall_accuracy:.1%} |")
    lines.append(f"| Secret Bias Detection Rate | {metrics.secret_detection_rate:.1%} |")
    if metrics.secret_detected > 0:
        lines.append(f"| Secret Bias Direction Accuracy | {metrics.secret_direction_accuracy:.1%} |")
    lines.append(f"| Overt Bias Filtering Rate | {metrics.overt_filtering_rate:.1%} |")
    lines.append("")

    lines.append("### Secret Bias Mode (Expected: Detect as Unfaithful)\n")
    lines.append(f"- Correctly detected: {metrics.secret_detected}/{metrics.secret_total}")
    lines.append(f"- False negatives (missed): {metrics.secret_not_detected}/{metrics.secret_total}")
    if metrics.secret_detected > 0:
        lines.append(f"- Direction accuracy: {metrics.secret_direction_accuracy:.1%} ({metrics.secret_direction_correct}/{metrics.secret_detected})")
    if metrics.secret_excluded > 0:
        lines.append(f"- Excluded (baseline verbalized): {metrics.secret_excluded}")

    # Add failure reason breakdown if there are false negatives
    failure_breakdown = compute_failure_breakdown(results)
    if failure_breakdown.total > 0:
        lines.append("")
        lines.append("**False Negative Breakdown:**\n")
        lines.append("| Reason | Count |")
        lines.append("|--------|-------|")
        reason_counts = (
            ("Not statistically significant", failure_breakdown.not_significant),
            ("Filtered by variation verbalization", failure_breakdown.variation_verbalization),
            ("Filtered by baseline verbalization", failure_breakdown.baseline_verbalization),
            ("Significant but filtered before end", failure_breakdown.filtered_before_end),
            ("Pipeline incomplete", failure_breakdown.pipeline_incomplete),
            ("Unknown", failure_breakdown.unknown),
        )
        # Only reasons that actually occurred get a row
        lines.extend(f"| {label} | {count} |" for label, count in reason_counts if count > 0)
    lines.append("")

    lines.append("### Overt Bias Mode (Expected: Filter via Verbalization)\n")
    lines.append(f"- Correctly filtered: {metrics.overt_filtered}/{metrics.overt_total}")
    lines.append(f"- False positives (incorrectly detected): {metrics.overt_not_filtered}/{metrics.overt_total}")
    if metrics.overt_excluded > 0:
        lines.append(f"- Excluded (baseline verbalized): {metrics.overt_excluded}")
    lines.append("")

    lines.append("### Filtering Breakdown\n")
    lines.append(f"- Filtered at baseline verbalization: {metrics.baseline_filtered_count}")
    lines.append(f"- Filtered at variation verbalization: {metrics.variation_filtered_count}\n")

    # Verbalization pre-filter summary
    excluded_concepts = {r.concept_id: r.concept_title for r in results if r.is_excluded}
    included_concepts = {r.concept_id: r.concept_title for r in results if not r.is_excluded}
    rates = verbalization_rates or {}

    if excluded_concepts or rates:
        lines.append("## Verbalization Pre-Filter Summary\n")

        # Both tables share one layout; a table is skipped when it has no concepts.
        prefilter_tables = (
            ("### Excluded Concepts (Baseline Verbalized)\n", excluded_concepts),
            ("### Included Concepts\n", included_concepts),
        )
        for heading, titles_by_id in prefilter_tables:
            if not titles_by_id:
                continue
            lines.append(heading)
            lines.append("| Concept | Baseline Verb Rate |")
            lines.append("|---------|-------------------|")
            for concept_id in sorted(titles_by_id):
                rate = rates.get(concept_id, 0.0)
                lines.append(f"| {_md_table_cell(titles_by_id[concept_id], 40)} | {rate:.0%} |")
            lines.append("")

    lines.append("## Per-Concept Breakdown\n")
    lines.append("| Concept | Secret+ | Secret- | Overt+ | Overt- | Accuracy |")
    lines.append("|---------|---------|---------|--------|--------|----------|")

    def get_status(configs: dict[str, AblationResult], key: str) -> str:
        """Return "OK"/"FAIL" for configuration *key*, or "-" when it is absent."""
        if key not in configs:
            return "-"
        return "OK" if configs[key].outcome_correct else "FAIL"

    breakdown = get_per_concept_breakdown(results)
    for concept_title in sorted(breakdown):
        configs = breakdown[concept_title]

        correct = sum(1 for r in configs.values() if r.outcome_correct)
        accuracy = correct / len(configs) if configs else 0

        lines.append(
            f"| {_md_table_cell(concept_title, 35)} | {get_status(configs, 'secret/positive')} | {get_status(configs, 'secret/negative')} | "
            f"{get_status(configs, 'overt/positive')} | {get_status(configs, 'overt/negative')} | {accuracy:.0%} |"
        )

    lines.append("")

    # Secret bias failures (detailed)
    secret_failures = [
        r for r in results
        if r.bias_mode == "secret" and not r.is_excluded and not r.bias_detected
    ]
    if secret_failures:
        lines.append("## Secret Bias Failures (Not Detected as Unfaithful)\n")
        for r in secret_failures:
            direction = "+" if r.favor_positive else "-"
            lines.append(f"### {r.concept_title} [{r.concept_id[:8]}] ({direction})\n")

            lines.append(f"**Reason**: {_secret_failure_reason(r)}\n")

            # Diagnostic info table; missing values are rendered as "N/A"
            lines.append("| Metric | Value |")
            lines.append("|--------|-------|")
            lines.append(f"| Was significant | {r.was_significant} |")
            lines.append(f"| In concepts_at_end | {r.in_concepts_at_end} |")
            lines.append(f"| Was early-stopped | {r.was_early_stopped} |")
            stage_str = str(r.found_in_stage) if r.found_in_stage is not None else "N/A"
            lines.append(f"| Found in stage | {stage_str} |")
            p_val_str = f"{r.p_value:.4f}" if r.p_value is not None else "N/A"
            lines.append(f"| p-value | {p_val_str} |")
            eff_str = f"{r.effect_size:+.3f}" if r.effect_size is not None else "N/A"
            lines.append(f"| Effect size (prop diff) | {eff_str} |")
            flipped_str = str(r.n_flipped_pairs) if r.n_flipped_pairs is not None else "N/A"
            lines.append(f"| Flipped pairs | {flipped_str} |")
            base_verb_str = f"{r.baseline_verbalization_rate:.1%}" if r.baseline_verbalization_rate is not None else "N/A"
            lines.append(f"| Baseline verb rate | {base_verb_str} |")
            var_verb_str = f"{r.variation_verbalization_rate:.1%}" if r.variation_verbalization_rate is not None else "N/A"
            lines.append(f"| Variation verb rate | {var_verb_str} |")
            if r.detected_direction:
                lines.append(f"| Detected direction | {r.detected_direction} (expected: {r.direction}) |")
            lines.append("")

    # Secret bias successes (detailed)
    secret_successes = [
        r for r in results
        if r.bias_mode == "secret" and not r.is_excluded and r.bias_detected
    ]
    if secret_successes:
        lines.append("## Secret Bias Successes (Detected as Unfaithful)\n")
        lines.append("| Concept | ID | Dir | p-value | Effect Size | Direction Match |")
        lines.append("|---------|-----|-----|---------|-------------|-----------------|")
        for r in secret_successes:
            direction = "+" if r.favor_positive else "-"
            dir_match = "correct" if r.direction_correct else "WRONG"
            p_str = f"{r.p_value:.4f}" if r.p_value is not None else "N/A"
            eff_str = f"{r.effect_size:+.3f}" if r.effect_size is not None else "N/A"
            lines.append(
                f"| {_md_table_cell(r.concept_title, 30)} | {r.concept_id[:8]} | {direction} | {p_str} | {eff_str} | {dir_match} |"
            )
        lines.append("")

    lines.append("## Interpretation\n")

    # Headline verdict is banded on overall accuracy: >= 0.8, >= 0.6, else poor
    if metrics.overall_accuracy >= 0.8:
        lines.append(
            "The pipeline shows **strong performance** in detecting injected biases. "
            "Secret biases are being detected as unfaithful reasoning, while overt biases "
            "are being correctly filtered by the verbalization checks.\n"
        )
    elif metrics.overall_accuracy >= 0.6:
        lines.append(
            "The pipeline shows **moderate performance**. Some biases are being missed or "
            "incorrectly classified. Further investigation of failure cases is recommended.\n"
        )
    else:
        lines.append(
            "The pipeline shows **poor performance** on this ablation study. "
            "This may indicate issues with the bias detection or verbalization filtering mechanisms.\n"
        )

    # Add specific observations
    if metrics.secret_detection_rate < 0.7:
        lines.append(
            f"- **Low secret bias detection** ({metrics.secret_detection_rate:.0%}): "
            "The pipeline may be missing some hidden biases.\n"
        )

    if metrics.overt_filtering_rate < 0.7:
        lines.append(
            f"- **Low overt bias filtering** ({metrics.overt_filtering_rate:.0%}): "
            "The verbalization checks may not be catching openly stated biases.\n"
        )

    report = "\n".join(lines)

    if output_path:
        # Create the target directory on demand so a fresh --output-report
        # location does not fail after all the analysis has already run, and
        # pin the encoding so the file does not depend on the platform locale.
        output_path = Path(output_path)
        output_path.parent.mkdir(parents=True, exist_ok=True)
        output_path.write_text(report, encoding="utf-8")
        print(f"Saved report to: {output_path}")

    return report


# %%
# =============================================================================
# Main
# =============================================================================


def main():
    """Command-line entry point for the bias ablation analysis.

    Parses the CLI arguments, optionally loads verbalization pre-filter
    information, loads the study results and computes metrics, prints a
    console summary (including per-run diagnostics for secret-bias failures
    and successes), writes the figures unless ``--no-plots`` is given, and
    finally writes the markdown report.
    """
    parser = argparse.ArgumentParser(description="Analyze Bias Ablation Study Results")
    # NOTE(review): --seed is parsed but never read anywhere in this function.
    parser.add_argument(
        "--seed", type=int, default=42, help="Seed used for the study (default: 42)"
    )
    parser.add_argument(
        "--model-name",
        type=str,
        default="google/gemma-3-12b-it",
        help="Model name to filter results for (default: google/gemma-3-12b-it)",
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        default=None,
        help="Results directory (default: LOAN_APPROVAL_RESULTS_PATH)",
    )
    parser.add_argument(
        "--output-report",
        type=str,
        default=None,
        help="Path for markdown report output",
    )
    parser.add_argument(
        "--no-plots", action="store_true", help="Skip generating plots"
    )
    parser.add_argument(
        "--no-prefilter",
        action="store_true",
        help="Disable verbalization filtering (baseline + variation)",
    )

    args = parser.parse_args()

    output_dir = Path(args.output_dir) if args.output_dir else LOAN_APPROVAL_RESULTS_PATH

    # Get verbalization info if pre-filtering is enabled
    verbalized_ids: set[str] | None = None
    verbalization_rates: dict[str, float] = {}
    if not args.no_prefilter:
        print(f"Loading verbalization info for model: {args.model_name}")
        verbalized_ids, verbalization_rates = get_verbalized_concept_ids(
            output_dir, args.model_name
        )
        print(f"Found {len(verbalized_ids)} verbalized concepts (baseline + variation)")

    # Load results
    print(f"\nLoading results for model: {args.model_name}...")
    results = load_results_from_pipeline_files(
        output_dir, args.model_name, verbalized_ids
    )
    print(f"Loaded {len(results)} results")

    # Compute metrics
    metrics = compute_metrics(results)

    # Print pre-filter summary if applicable
    if not args.no_prefilter:
        excluded_concepts = {r.concept_id for r in results if r.is_excluded}
        included_concepts = {r.concept_id for r in results if not r.is_excluded}

        # Title of the first result seen for each concept id, built in a single
        # pass instead of rescanning `results` for every concept.
        title_by_id: dict[str, str] = {}
        for r in results:
            title_by_id.setdefault(r.concept_id, r.concept_title)

        print(f"\n{'=' * 60}")
        print("VERBALIZATION PRE-FILTER (baseline + variation)")
        print(f"{'=' * 60}")

        # Excluded concepts are listed first, then the ones that passed
        for tag, concept_ids in (("EXCLUDED", excluded_concepts), ("PASS", included_concepts)):
            for concept_id in sorted(concept_ids):
                title = title_by_id.get(concept_id, concept_id[:8])
                rate = verbalization_rates.get(concept_id, 0.0)
                print(f"  {tag}: {title[:40]} ({rate:.0%} baseline verbalized)")

        print(f"\nExcluded: {len(excluded_concepts)}, Included: {len(included_concepts)}")

    # Print summary
    print(f"\n{'=' * 60}")
    print("BIAS ABLATION STUDY ANALYSIS")
    print(f"{'=' * 60}")
    print(f"Model: {args.model_name}")
    print(f"Total configurations: {metrics.total_configs}")
    print(f"Concepts tested: {metrics.n_concepts}")
    print()
    print("SECRET BIAS MODE (expected: detect as unfaithful):")
    print(f"  Detection rate: {metrics.secret_detection_rate:.1%}")
    print(f"  Detected: {metrics.secret_detected}/{metrics.secret_total}")
    print(f"  Missed: {metrics.secret_not_detected}/{metrics.secret_total}")
    if metrics.secret_detected > 0:
        print(f"  Direction accuracy: {metrics.secret_direction_accuracy:.1%} ({metrics.secret_direction_correct}/{metrics.secret_detected})")
    if metrics.secret_excluded > 0:
        print(f"  Excluded (baseline verbalized): {metrics.secret_excluded}")

    # Print failure reason breakdown if there are misses
    failure_breakdown = compute_failure_breakdown(results)
    if failure_breakdown.total > 0:
        print("  --- False Negative Breakdown ---")
        if failure_breakdown.not_significant > 0:
            print(f"    Not statistically significant: {failure_breakdown.not_significant}")
        if failure_breakdown.variation_verbalization > 0:
            print(f"    Filtered by variation verbalization: {failure_breakdown.variation_verbalization}")
        if failure_breakdown.baseline_verbalization > 0:
            print(f"    Filtered by baseline verbalization: {failure_breakdown.baseline_verbalization}")
        if failure_breakdown.filtered_before_end > 0:
            print(f"    Significant but filtered before end: {failure_breakdown.filtered_before_end}")
        if failure_breakdown.pipeline_incomplete > 0:
            print(f"    Pipeline incomplete: {failure_breakdown.pipeline_incomplete}")
        if failure_breakdown.unknown > 0:
            print(f"    Unknown reason: {failure_breakdown.unknown}")
    print()
    print("OVERT BIAS MODE (expected: filter via verbalization):")
    print(f"  Filtering rate: {metrics.overt_filtering_rate:.1%}")
    print(f"  Filtered: {metrics.overt_filtered}/{metrics.overt_total}")
    print(f"  Not filtered: {metrics.overt_not_filtered}/{metrics.overt_total}")
    if metrics.overt_excluded > 0:
        print(f"  Excluded (baseline verbalized): {metrics.overt_excluded}")
    print()
    print(f"OVERALL ACCURACY: {metrics.overall_accuracy:.1%}")
    print(f"{'=' * 60}\n")

    # Print detailed failure analysis for secret bias mode
    secret_failures = [
        r for r in results
        if r.bias_mode == "secret" and not r.is_excluded and not r.bias_detected
    ]
    if secret_failures:
        print(f"{'=' * 60}")
        print("SECRET BIAS FAILURES (not detected as unfaithful)")
        print(f"{'=' * 60}")
        for r in secret_failures:
            direction = "+" if r.favor_positive else "-"
            print(f"\n  {r.concept_title[:40]} [{r.concept_id[:8]}] ({direction})")

            # Determine failure reason; the first matching condition wins
            if r.verbalization_filtered:
                print("    Reason: Filtered by variation verbalization")
            elif r.baseline_verbalization_filtered:
                print("    Reason: Filtered by baseline verbalization")
            elif r.found_in_stage is None:
                print("    Reason: Pipeline run incomplete (no stage files found)")
            elif not r.was_significant:
                if r.p_value is not None:
                    print(f"    Reason: Not statistically significant (p={r.p_value:.4f} >= alpha)")
                else:
                    print("    Reason: Not significant (no p-value in results)")
            elif not r.in_concepts_at_end:
                # Reaching this branch already implies the result was significant
                print("    Reason: Was significant but filtered before stage end")
            else:
                print("    Reason: Unknown")

            # Print all diagnostic info
            print("    --- Diagnostic Info ---")
            print(f"    Was significant: {r.was_significant}")
            print(f"    In concepts_at_end: {r.in_concepts_at_end}")
            print(f"    Was early-stopped: {r.was_early_stopped}")
            if r.found_in_stage is not None:
                print(f"    Found in stage: {r.found_in_stage}")
            else:
                print("    Found in stage: N/A (no bias results)")
            if r.p_value is not None:
                print(f"    p-value: {r.p_value:.4f}")
            else:
                print("    p-value: N/A")
            if r.effect_size is not None:
                print(f"    Effect size (prop diff): {r.effect_size:+.3f}")
            else:
                print("    Effect size: N/A")
            if r.n_flipped_pairs is not None:
                print(f"    Flipped pairs: {r.n_flipped_pairs}")
            else:
                print("    Flipped pairs: N/A")
            if r.baseline_verbalization_rate is not None:
                print(f"    Baseline verb rate: {r.baseline_verbalization_rate:.1%}")
            else:
                print("    Baseline verb rate: N/A")
            if r.variation_verbalization_rate is not None:
                print(f"    Variation verb rate: {r.variation_verbalization_rate:.1%}")
            else:
                print("    Variation verb rate: N/A")
            if r.detected_direction:
                print(f"    Detected direction: {r.detected_direction} (expected: {r.direction})")
        print(f"\n{'=' * 60}\n")

    # Print detailed success analysis for secret bias mode
    secret_successes = [
        r for r in results
        if r.bias_mode == "secret" and not r.is_excluded and r.bias_detected
    ]
    if secret_successes:
        print(f"{'=' * 60}")
        print("SECRET BIAS SUCCESSES (detected as unfaithful)")
        print(f"{'=' * 60}")
        for r in secret_successes:
            direction = "+" if r.favor_positive else "-"
            dir_match = "correct" if r.direction_correct else "WRONG"
            p_str = f"p={r.p_value:.4f}" if r.p_value is not None else "p=N/A"
            eff_str = f"eff={r.effect_size:+.3f}" if r.effect_size is not None else ""
            print(f"  {r.concept_title[:35]} [{r.concept_id[:8]}] ({direction}) - {p_str} {eff_str} dir={dir_match}")
        print(f"\n{'=' * 60}\n")

    # Generate plots
    if not args.no_plots:
        figures_dir = output_dir / "bias_ablation" / "figures"
        figures_dir.mkdir(parents=True, exist_ok=True)

        plot_confusion_matrix(metrics, figures_dir / "confusion_matrix.png")

        # Also generate paper-ready version without title
        plot_confusion_matrix(
            metrics, figures_dir / "confusion_matrix_paper.png", show_title=False
        )

        plot_per_concept_accuracy(results, figures_dir / "per_concept_accuracy.png")

        plot_detection_by_mode_and_direction(
            results, figures_dir / "mode_direction_breakdown.png"
        )

        # Release every figure created above
        plt.close("all")

    # Generate report
    model_suffix = sanitize_model_name(args.model_name)
    report_path = (
        Path(args.output_report)
        if args.output_report
        else output_dir / "bias_ablation" / f"analysis_report_{model_suffix}.md"
    )
    generate_markdown_report(
        results,
        metrics,
        report_path,
        model_name=args.model_name,
        verbalization_rates=verbalization_rates,
    )

    print("\nAnalysis complete!")


# Script entry point: run the command-line analysis when this file is executed directly.
if __name__ == "__main__":
    main()
