#!/usr/bin/env python3
import logging
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
from typing import List

# Module-level logger; the plotting functions below report progress and
# skipped plots (missing columns / too few groups) through it.
logger = logging.getLogger(__name__)


def _calculate_kaplan_meier_ablation(agent_outcomes: pd.DataFrame) -> pd.DataFrame:
    """Estimate a Kaplan-Meier survival curve with log-log 95% confidence bands.

    An agent counts as an event at ``tte_step`` when ``event_ate_early == 1``;
    every agent with ``tte_step >= step`` is counted as at risk at ``step``.
    The curve is evaluated at steps ``0 .. risk_horizon + 1``, where
    ``risk_horizon`` is read from the first row.  At the extra step past the
    horizon no hazard is applied and ``at_risk`` is reported as 0, so the last
    survival value is carried forward (useful for ``where="post"`` step plots).

    Confidence bounds use Greenwood's variance with the complementary log-log
    transform: ``S ** exp(+1.96 * se)`` (lower) and ``S ** exp(-1.96 * se)``
    (upper), where ``se = sqrt(greenwood) / |log S|``.  When no variance has
    accumulated, or ``S`` is 0 or 1, both bounds equal ``S``.

    Args:
        agent_outcomes: one row per agent with ``risk_horizon``,
            ``event_ate_early`` and ``tte_step`` columns.

    Returns:
        DataFrame with columns ``step``, ``survival``, ``events``, ``at_risk``,
        ``ci_lower`` and ``ci_upper`` (zero rows when the input is empty).
    """
    columns = ["step", "survival", "events", "at_risk", "ci_lower", "ci_upper"]
    if len(agent_outcomes) == 0:
        return pd.DataFrame({column: [] for column in columns})

    # int() so range() also accepts a float-typed column (e.g. after NaN upcasting).
    risk_horizon = int(agent_outcomes["risk_horizon"].iloc[0])
    tte_step = agent_outcomes["tte_step"]
    is_event = agent_outcomes["event_ate_early"] == 1

    survival_data = []
    current_survival = 1.0
    cumulative_variance = 0.0  # Greenwood sum of d / (n * (n - d))
    for step in range(risk_horizon + 2):
        events_at_step = int((is_event & (tte_step == step)).sum())
        if step <= risk_horizon:
            at_risk = int((tte_step >= step).sum())
            if at_risk > 0:
                current_survival *= 1 - events_at_step / at_risk
                # When every at-risk agent has the event the term is undefined
                # (S drops to 0), so it is skipped.
                if at_risk > events_at_step:
                    cumulative_variance += events_at_step / (
                        at_risk * (at_risk - events_at_step)
                    )
        else:
            at_risk = 0

        if cumulative_variance > 0 and 0 < current_survival < 1:
            # log S < 0, so take its magnitude: the SE of log(-log S) is positive.
            log_log_se = np.sqrt(cumulative_variance) / abs(np.log(current_survival))
            log_log_survival = np.log(-np.log(current_survival))
            # exp(-exp(x)) is decreasing in x, so +se yields the LOWER bound.
            ci_lower = float(np.exp(-np.exp(log_log_survival + 1.96 * log_log_se)))
            ci_upper = float(np.exp(-np.exp(log_log_survival - 1.96 * log_log_se)))
        else:
            ci_lower = ci_upper = current_survival

        survival_data.append(
            {
                "step": step,
                "survival": current_survival,
                "events": events_at_step,
                "at_risk": at_risk,
                "ci_lower": ci_lower,
                "ci_upper": ci_upper,
            }
        )
    return pd.DataFrame(survival_data, columns=columns)


def _ablation_linestyle(comm_type) -> str:
    """Line style encoding a communication type: solid for broadcast, dashed otherwise."""
    return "-" if comm_type == "broadcast" else "--"


def _ablation_pretty_name(name: str) -> str:
    """Convert a snake_case condition name into a Title Case plot label."""
    return name.replace("_", " ").title()


def _ablation_full_factor_agents(agent_outcomes: pd.DataFrame) -> pd.DataFrame:
    """Valid agents whose ``hedonic`` and ``persona_age`` are both not ``"none"``."""
    return agent_outcomes[
        (agent_outcomes["is_valid_outcome"] == 1)
        & (agent_outcomes["hedonic"] != "none")
        & (agent_outcomes["persona_age"] != "none")
    ].copy()


def _ablation_select_agents(
    agent_outcomes: pd.DataFrame,
    full_base,
    condition: dict,
    comm_type=None,
    policy_type=None,
) -> pd.DataFrame:
    """Return a copy of the agents belonging to one plotted condition.

    The baseline condition (``condition["base"]`` truthy) draws from
    *full_base*.  Every other condition starts from all valid agents
    (``is_valid_outcome == 1``) and keeps the rows whose columns equal each
    ``condition["filter"]`` value.  *full_base* is only read for the baseline,
    so callers that never select it may pass ``None``.  When *comm_type* or
    *policy_type* is not ``None``, rows are further restricted on
    ``communication_type`` / ``tool_use_policy``.
    """
    if condition.get("base"):
        agents = full_base
    else:
        agents = agent_outcomes[agent_outcomes["is_valid_outcome"] == 1]
    if comm_type is not None:
        agents = agents[agents["communication_type"] == comm_type]
    if policy_type is not None:
        agents = agents[agents["tool_use_policy"] == policy_type]
    for factor, value in condition["filter"].items():
        agents = agents[agents[factor] == value]
    return agents.copy()


def _ablation_draw_curve(ax, survival_curve: pd.DataFrame, fill_alpha: float, **step_kwargs):
    """Draw a post-step survival curve plus its CI band in the line's colour.

    Returns the ``Line2D`` so callers can reuse its colour for sibling curves.
    """
    line = ax.step(
        survival_curve["step"],
        survival_curve["survival"],
        where="post",
        **step_kwargs,
    )[0]
    ax.fill_between(
        survival_curve["step"],
        survival_curve["ci_lower"],
        survival_curve["ci_upper"],
        alpha=fill_alpha,
        color=line.get_color(),
        step="post",
    )
    return line


def _ablation_curves_by_comm(
    agent_outcomes, full_base, condition, comm_types, policy_type=None
) -> dict:
    """Kaplan-Meier tables keyed by communication type, skipping empty groups."""
    curves = {}
    for comm_type in comm_types:
        agents = _ablation_select_agents(
            agent_outcomes, full_base, condition, comm_type, policy_type
        )
        if len(agents) > 0:
            curves[comm_type] = _calculate_kaplan_meier_ablation(agents)
    return curves


def _ablation_draw_comm_group(
    ax, curves, fill_alpha, label_prefix, label_suffix="", color=None, **style
):
    """Draw one curve per communication type, all sharing a single colour.

    Each line is labelled ``"{label_prefix} ({comm_type}{label_suffix})"`` and
    drawn solid (broadcast) or dashed (other).  When *color* is ``None``, the
    first line takes the axes' next cycle colour and the remaining lines reuse it.
    """
    for comm_type, curve in curves.items():
        kwargs = dict(
            style,
            linestyle=_ablation_linestyle(comm_type),
            label=f"{label_prefix} ({comm_type}{label_suffix})",
        )
        if color:
            kwargs["color"] = color
        line = _ablation_draw_curve(ax, curve, fill_alpha, **kwargs)
        if color is None:
            color = line.get_color()


def _ablation_format_axis(ax, panel_title, survival_prob_range, legend_kwargs=None):
    """Apply the shared labels, title, legend, grid and y-range to a survival panel."""
    ax.set_xlabel("Step")
    ax.set_ylabel("Survival Probability")
    ax.set_title(panel_title)
    ax.legend(**(legend_kwargs or {}))
    ax.grid(True, alpha=0.3)
    ax.set_ylim(survival_prob_range[0], 1.05)
    ax.set_yticks(np.arange(survival_prob_range[0], 1.01, 0.05))


def _ablation_add_comm_note(ax):
    """Annotate a panel with the solid/dashed line-style key for communication types."""
    ax.text(
        0.02,
        0.98,
        "Solid = Broadcast\nDashed = Isolated",
        transform=ax.transAxes,
        fontsize=9,
        verticalalignment="top",
        bbox=dict(boxstyle="round", facecolor="wheat", alpha=0.8),
    )


def _ablation_social_survival_figure(
    agent_outcomes,
    analysis_dir,
    survival_prob_range,
    ablation_conditions,
    comm_types,
    policy_types,
    has_policy_split,
):
    """Write ``ablation_survival_curves.png`` with curves split by communication type."""
    n_rows = 4 if has_policy_split else 2
    fig, axes = plt.subplots(n_rows, 2, figsize=(16, 6 * n_rows), squeeze=False)
    try:
        axes = axes.flatten()
        full_base = _ablation_full_factor_agents(agent_outcomes)
        baseline = {"name": "Full Factors", "filter": {}, "base": True}
        all_conditions = [baseline] + ablation_conditions

        # The gray baseline is identical in every per-condition panel: compute once.
        full_curves = _ablation_curves_by_comm(
            agent_outcomes, full_base, baseline, comm_types
        )
        for ax, condition in zip(axes, ablation_conditions):
            pretty = _ablation_pretty_name(condition["name"])
            _ablation_draw_comm_group(
                ax, full_curves, 0.15, "Full Factors", color="gray", linewidth=1.5
            )
            _ablation_draw_comm_group(
                ax,
                _ablation_curves_by_comm(agent_outcomes, full_base, condition, comm_types),
                0.2,
                pretty,
                linewidth=2,
            )
            _ablation_format_axis(
                ax, f"{pretty} Ablation\n(with Social Context)", survival_prob_range
            )
            _ablation_add_comm_note(ax)

        ax = axes[3]
        for condition in all_conditions:
            is_base = bool(condition.get("base"))
            _ablation_draw_comm_group(
                ax,
                _ablation_curves_by_comm(agent_outcomes, full_base, condition, comm_types),
                0.1,
                _ablation_pretty_name(condition["name"]),
                color="gray" if is_base else None,
                linewidth=1.5 if is_base else 2,
            )
        _ablation_format_axis(
            ax,
            "Summary: All Ablation Conditions\n(with Social Context)",
            survival_prob_range,
            {"fontsize": 8},
        )
        _ablation_add_comm_note(ax)

        if has_policy_split:
            for ax, condition in zip(axes[4:], ablation_conditions):
                pretty = _ablation_pretty_name(condition["name"])
                for policy_type in policy_types:
                    suffix = f", {policy_type}"
                    _ablation_draw_comm_group(
                        ax,
                        _ablation_curves_by_comm(
                            agent_outcomes, full_base, baseline, comm_types, policy_type
                        ),
                        0.15,
                        "Full",
                        suffix,
                        marker="x",
                        linewidth=1.5,
                    )
                    _ablation_draw_comm_group(
                        ax,
                        _ablation_curves_by_comm(
                            agent_outcomes, full_base, condition, comm_types, policy_type
                        ),
                        0.2,
                        pretty,
                        suffix,
                        linewidth=2,
                    )
                _ablation_format_axis(
                    ax, f"{pretty} by Tool Policy", survival_prob_range, {"fontsize": 8}
                )

            ax = axes[7]
            palette = plt.get_cmap("tab10")
            colors = {
                condition["name"]: palette(i)
                for i, condition in enumerate(ablation_conditions)
            }
            colors["Full Factors"] = "gray"
            markers = {"must": "o", "may": "s", "strict": "D", "relaxed": "v"}
            for condition in all_conditions:
                pretty = _ablation_pretty_name(condition["name"])
                for policy_type in policy_types:
                    _ablation_draw_comm_group(
                        ax,
                        _ablation_curves_by_comm(
                            agent_outcomes, full_base, condition, comm_types, policy_type
                        ),
                        0.1,
                        pretty,
                        f", {policy_type}",
                        color=colors[condition["name"]],
                        marker=markers.get(policy_type, "x"),
                    )
            # Formatted once, after every condition has been drawn.
            _ablation_format_axis(
                ax,
                "Summary by Tool Policy",
                survival_prob_range,
                {
                    "title": "Condition (Comm, Policy)",
                    "bbox_to_anchor": (1.05, 1),
                    "loc": "upper left",
                },
            )

        fig.tight_layout()
        fig.savefig(
            analysis_dir / "ablation_survival_curves.png",
            dpi=300,
            bbox_inches="tight",
        )
    finally:
        plt.close(fig)


def _ablation_social_completion_figure(
    agent_outcomes, analysis_dir, survival_prob_range, ablation_conditions, comm_types
):
    """Write a bar chart of the mean ``censored`` flag per (condition, comm type)."""
    completion_rates = []
    bar_labels = []
    bar_colors = []
    for condition in ablation_conditions:
        for comm_type in comm_types:
            agents = _ablation_select_agents(agent_outcomes, None, condition, comm_type)
            if len(agents) == 0:
                continue
            label = f"{condition['name']} ({comm_type})"
            completion_rates.append(agents["censored"].mean())
            bar_labels.append(label)
            bar_colors.append("steelblue" if "broadcast" in label else "lightcoral")

    fig, ax = plt.subplots(1, 1, figsize=(12, 6))
    try:
        x_pos = np.arange(len(bar_labels))
        bars = ax.bar(x_pos, completion_rates, alpha=0.7)
        for bar, color in zip(bars, bar_colors):
            bar.set_color(color)
        ax.set_xlabel("Ablation Condition")
        ax.set_ylabel("Completion Rate")
        # Bars compare communication types (steelblue = broadcast), not tool policies.
        ax.set_title("Ablation Completion Rates: Broadcast vs Isolated Comparison")
        ax.set_xticks(x_pos)
        ax.set_xticklabels(bar_labels, rotation=45, ha="right")
        ax.grid(True, alpha=0.3)
        ax.set_ylim(survival_prob_range[0], 1.05)
        ax.set_yticks(np.arange(survival_prob_range[0], 1.01, 0.05))
        fig.tight_layout()
        fig.savefig(
            analysis_dir / "ablation_completion_rate_comparison.png",
            dpi=300,
            bbox_inches="tight",
        )
    finally:
        plt.close(fig)


def _ablation_non_social_survival_figure(
    agent_outcomes,
    analysis_dir,
    survival_prob_range,
    ablation_conditions,
    policy_types,
    has_policy_split,
):
    """Write ``ablation_survival_curves.png`` pooling all communication types."""
    n_rows = 4 if has_policy_split else 2
    fig, axes = plt.subplots(n_rows, 2, figsize=(16, 6 * n_rows), squeeze=False)
    try:
        axes = axes.flatten()
        full_base = _ablation_full_factor_agents(agent_outcomes)
        baseline = {"name": "Full Factors", "filter": {}, "base": True}
        all_conditions = [baseline] + ablation_conditions
        full_curve = (
            _calculate_kaplan_meier_ablation(full_base) if len(full_base) > 0 else None
        )

        for ax, condition in zip(axes, ablation_conditions):
            pretty = _ablation_pretty_name(condition["name"])
            if full_curve is not None:
                _ablation_draw_curve(
                    ax,
                    full_curve,
                    0.2,
                    marker="x",
                    linewidth=1.5,
                    linestyle="--",
                    label="Full Factors",
                    color="gray",
                )
            agents = _ablation_select_agents(agent_outcomes, full_base, condition)
            if len(agents) == 0:
                ax.text(
                    0.5,
                    0.5,
                    f"No data for {condition['name']}",
                    ha="center",
                    va="center",
                    transform=ax.transAxes,
                )
                continue
            _ablation_draw_curve(
                ax,
                _calculate_kaplan_meier_ablation(agents),
                0.3,
                marker="o",
                linewidth=2,
                label=pretty,
                color="darkblue",
            )
            _ablation_format_axis(
                ax,
                f"{pretty} Ablation\n(with 95% Confidence Intervals)",
                survival_prob_range,
            )

        ax = axes[3]
        for condition in all_conditions:
            agents = _ablation_select_agents(agent_outcomes, full_base, condition)
            if len(agents) == 0:
                continue
            is_base = bool(condition.get("base"))
            _ablation_draw_curve(
                ax,
                _calculate_kaplan_meier_ablation(agents),
                0.15,
                marker="x" if is_base else "o",
                linewidth=1.5 if is_base else 2,
                linestyle="--" if is_base else "-",
                label=_ablation_pretty_name(condition["name"]),
            )
        _ablation_format_axis(
            ax,
            "Summary: All Ablation Conditions\n(with 95% Confidence Intervals)",
            survival_prob_range,
        )

        if has_policy_split:
            for ax, condition in zip(axes[4:], ablation_conditions):
                pretty = _ablation_pretty_name(condition["name"])
                for policy_type in policy_types:
                    full_agents = _ablation_select_agents(
                        agent_outcomes, full_base, baseline, policy_type=policy_type
                    )
                    if len(full_agents) > 0:
                        # The returned line colours its own CI band (the original
                        # reused a stale `line` from another panel here).
                        _ablation_draw_curve(
                            ax,
                            _calculate_kaplan_meier_ablation(full_agents),
                            0.2,
                            marker="x",
                            linestyle="--",
                            label=f"Full Factors ({policy_type})",
                        )
                    agents = _ablation_select_agents(
                        agent_outcomes, full_base, condition, policy_type=policy_type
                    )
                    if len(agents) > 0:
                        _ablation_draw_curve(
                            ax,
                            _calculate_kaplan_meier_ablation(agents),
                            0.3,
                            marker="o",
                            label=f"{pretty} ({policy_type})",
                        )
                _ablation_format_axis(
                    ax, f"{pretty} by Tool Policy", survival_prob_range
                )

            ax = axes[7]
            for condition in all_conditions:
                pretty = _ablation_pretty_name(condition["name"])
                for policy_type in policy_types:
                    agents = _ablation_select_agents(
                        agent_outcomes, full_base, condition, policy_type=policy_type
                    )
                    if len(agents) > 0:
                        _ablation_draw_curve(
                            ax,
                            _calculate_kaplan_meier_ablation(agents),
                            0.2,
                            marker="o",
                            label=f"{pretty} ({policy_type})",
                        )
            _ablation_format_axis(ax, "Summary by Tool Policy", survival_prob_range)

        fig.tight_layout()
        fig.savefig(
            analysis_dir / "ablation_survival_curves.png",
            dpi=300,
            bbox_inches="tight",
        )
    finally:
        plt.close(fig)


def _create_ablation_survival_curves(
    agent_outcomes: pd.DataFrame,
    analysis_dir: Path,
    social_influence: bool,
    survival_prob_range: List[float],
):
    """Plot Kaplan-Meier survival curves for the factor-ablation conditions.

    Writes ``ablation_survival_curves.png`` to *analysis_dir*. The figure has
    one panel per ablation condition plus a summary panel. When
    ``tool_use_policy`` has more than one non-null value, two more rows of
    panels are added, split by tool policy.

    With *social_influence*, curves are also split by ``communication_type``
    (solid = broadcast, dashed = other). A second figure,
    ``ablation_completion_rate_comparison.png``, shows the mean ``censored``
    flag per condition and communication type. Social mode logs a warning and
    returns without plotting when ``communication_type`` is missing or has
    fewer than two non-null values.

    Args:
        agent_outcomes: per-agent outcome rows. Columns read include
            ``is_valid_outcome``, ``hedonic``, ``persona_age``, ``tte_step``,
            ``event_ate_early`` and ``risk_horizon``. When present,
            ``tool_use_policy``, ``communication_type`` and ``censored`` are
            also read.
        analysis_dir: directory the PNG files are written into.
        social_influence: whether to split curves by communication type.
        survival_prob_range: only ``survival_prob_range[0]`` is read. It sets
            the lower y-limit and first y-tick of every panel.
    """
    logger.info("Creating ablation survival curves...")
    ablation_conditions = [
        {"name": "hedonic_none", "filter": {"hedonic": "none"}},
        {"name": "policy_role_persona", "filter": {"persona_age": "none"}},
        {
            "name": "hedonic_none_and_policy_role",
            "filter": {"hedonic": "none", "persona_age": "none"},
        },
    ]
    policy_types = []
    if "tool_use_policy" in agent_outcomes.columns:
        policy_types = [
            p for p in agent_outcomes["tool_use_policy"].unique() if pd.notna(p)
        ]
    has_policy_split = len(policy_types) > 1

    if social_influence:
        if "communication_type" not in agent_outcomes.columns:
            logger.warning(
                "communication_type not found - cannot create social influence ablation"
            )
            return
        comm_types = [
            ct for ct in agent_outcomes["communication_type"].unique() if pd.notna(ct)
        ]
        if len(comm_types) < 2:
            logger.warning(
                "Need both broadcast and isolated for social influence ablation"
            )
            return
        _ablation_social_survival_figure(
            agent_outcomes,
            analysis_dir,
            survival_prob_range,
            ablation_conditions,
            comm_types,
            policy_types,
            has_policy_split,
        )
        _ablation_social_completion_figure(
            agent_outcomes,
            analysis_dir,
            survival_prob_range,
            ablation_conditions,
            comm_types,
        )
    else:
        _ablation_non_social_survival_figure(
            agent_outcomes,
            analysis_dir,
            survival_prob_range,
            ablation_conditions,
            policy_types,
            has_policy_split,
        )
    logger.info("✓ Created ablation survival curves")


def _select_rows_matching(frame: pd.DataFrame, condition_filter: dict) -> pd.DataFrame:
    """Return the rows of ``frame`` satisfying every ``factor -> value`` pair in ``condition_filter``.

    A list value keeps rows whose factor is one of the listed values; any other
    value keeps rows whose factor equals it exactly.
    """
    selected = frame
    for factor, value in condition_filter.items():
        if isinstance(value, list):
            selected = selected[selected[factor].isin(value)]
        else:
            selected = selected[selected[factor] == value]
    return selected


def _create_ablation_completion_rates(
    agent_outcomes: pd.DataFrame, analysis_dir: Path, social_influence: bool
):
    """Plot the completion rate (mean of ``censored``) for each ablation condition.

    Only rows with ``is_valid_outcome == 1`` are used. The figure is saved to
    ``analysis_dir / "ablation_completion_rates.png"``.

    Args:
        agent_outcomes: Per-agent outcomes. Reads ``is_valid_outcome``,
            ``censored``, ``hedonic``, ``persona_age`` and, when
            ``social_influence`` is true, ``communication_type``.
        analysis_dir: Directory the PNG is written into.
        social_influence: When true, draw side-by-side Broadcast / Isolated bars
            per condition. Otherwise draw one bar per condition over all valid
            agents.
    """
    logger.info("Creating ablation completion rate plots...")
    ablation_conditions = [
        {"name": "hedonic_none", "filter": {"hedonic": "none"}},
        {"name": "policy_role_persona", "filter": {"persona_age": "none"}},
        {
            "name": "hedonic_none_and_policy_role",
            "filter": {"hedonic": "none", "persona_age": "none"},
        },
    ]
    if social_influence:
        if "communication_type" not in agent_outcomes.columns:
            logger.warning(
                "communication_type not found - cannot create social influence ablation"
            )
            return
        comm_types = [
            ct for ct in agent_outcomes["communication_type"].unique() if pd.notna(ct)
        ]
        if len(comm_types) < 2 or "broadcast" not in comm_types:
            logger.warning(
                "Need both broadcast and isolated for social influence ablation"
            )
            return
        # The chart has exactly two bar series, so "broadcast" is compared against
        # a single other communication type (preferring "isolated" when present).
        other_types = [ct for ct in comm_types if ct != "broadcast"]
        isolated_type = "isolated" if "isolated" in other_types else other_types[0]
        if len(other_types) > 1:
            logger.warning(
                "Multiple non-broadcast communication types found; plotting %r as Isolated",
                isolated_type,
            )
        full_condition_filter = {
            "hedonic": ["crave", "like", "neutral"],
            "persona_age": ["child", "adult", "senior"],
        }
        labelled_filters = [("Full Factors", full_condition_filter)] + [
            (condition["name"].replace("_", " ").title(), condition["filter"])
            for condition in ablation_conditions
        ]
        x_labels = [label for label, _ in labelled_filters]
        valid_agents = agent_outcomes[agent_outcomes["is_valid_outcome"] == 1]
        broadcast_agents = valid_agents[
            valid_agents["communication_type"] == "broadcast"
        ]
        isolated_agents = valid_agents[
            valid_agents["communication_type"] == isolated_type
        ]
        # The mean of an empty selection is NaN. Each condition therefore keeps
        # exactly one entry per series, and both series stay aligned with
        # ``x_labels``. NaN bars are simply not drawn.
        broadcast_rates = [
            _select_rows_matching(broadcast_agents, condition_filter)["censored"].mean()
            for _, condition_filter in labelled_filters
        ]
        isolated_rates = [
            _select_rows_matching(isolated_agents, condition_filter)["censored"].mean()
            for _, condition_filter in labelled_filters
        ]
        x_pos = np.arange(len(x_labels))
        width = 0.35
        fig, ax = plt.subplots(1, 1, figsize=(12, 6))
        try:
            ax.bar(
                x_pos - width / 2,
                broadcast_rates,
                width,
                alpha=0.7,
                color="steelblue",
                label="Broadcast",
            )
            ax.bar(
                x_pos + width / 2,
                isolated_rates,
                width,
                alpha=0.7,
                color="lightcoral",
                label="Isolated",
            )
            ax.set_xlabel("Ablation Condition")
            ax.set_ylabel("Completion Rate")
            ax.set_title("Ablation Completion Rates: Broadcast vs Isolated")
            ax.set_xticks(x_pos)
            ax.set_xticklabels(x_labels)
            ax.legend()
            ax.grid(True, alpha=0.3)
            ax.set_ylim(0, 1.05)
            ax.set_yticks(np.arange(0, 1.01, 0.05))
            plt.tight_layout()
            plt.savefig(
                analysis_dir / "ablation_completion_rates.png",
                dpi=300,
                bbox_inches="tight",
            )
        finally:
            # Always release the figure, even if plotting or saving fails.
            plt.close(fig)
    else:
        valid_agents = agent_outcomes[agent_outcomes["is_valid_outcome"] == 1]
        completion_rates = []
        conditions_labels = []
        # NOTE(review): here "Full Factors" only excludes "none" levels, while the
        # social-influence branch whitelists specific levels. Confirm both are intended.
        full_condition_agents = valid_agents[
            (valid_agents["hedonic"] != "none")
            & (valid_agents["persona_age"] != "none")
        ]
        if len(full_condition_agents) > 0:
            completion_rates.append(full_condition_agents["censored"].mean())
            conditions_labels.append("Full Factors")
        for condition in ablation_conditions:
            condition_agents = _select_rows_matching(valid_agents, condition["filter"])
            # Empty conditions are omitted entirely, keeping labels and bars aligned.
            if len(condition_agents) > 0:
                completion_rates.append(condition_agents["censored"].mean())
                conditions_labels.append(condition["name"].replace("_", " ").title())
        fig, ax = plt.subplots(1, 1, figsize=(10, 6))
        try:
            ax.bar(
                range(len(conditions_labels)),
                completion_rates,
                alpha=0.7,
                color="darkgreen",
            )
            ax.set_xlabel("Ablation Condition")
            ax.set_ylabel("Completion Rate")
            ax.set_title("Ablation Completion Rates")
            ax.set_xticks(range(len(conditions_labels)))
            ax.set_xticklabels(conditions_labels, rotation=45, ha="right")
            ax.grid(True, alpha=0.3)
            # NOTE(review): rates below 0.4 fall outside this fixed y-range.
            ax.set_ylim(0.4, 1.05)
            ax.set_yticks(np.arange(0.4, 1.01, 0.05))
            plt.tight_layout()
            plt.savefig(
                analysis_dir / "ablation_completion_rates.png",
                dpi=300,
                bbox_inches="tight",
            )
        finally:
            plt.close(fig)
    logger.info("✓ Created ablation completion rate plots")


def _filter_step_rows(step_data: pd.DataFrame, condition_filter: dict) -> pd.DataFrame:
    """Keep rows of ``step_data`` matching every ``factor -> value(s)`` pair in ``condition_filter``.

    A list value keeps rows whose factor is in the list; any other value keeps
    rows whose factor equals it exactly.
    """
    filtered = step_data
    for factor, values in condition_filter.items():
        if isinstance(values, list):
            filtered = filtered[filtered[factor].isin(values)]
        else:
            filtered = filtered[filtered[factor] == values]
    return filtered


def _tool_usage_with_ci(condition_data: pd.DataFrame) -> pd.DataFrame:
    """Return the per-step mean of ``used_this_step`` with a 95% normal-approximation CI.

    The columns are ``step``, ``mean``, ``sem``, ``ci_lower`` and ``ci_upper``.
    The interval is ``mean ± 1.96 * sem``, with the lower bound clipped at 0.
    """
    tool_usage = (
        condition_data.groupby("step")["used_this_step"]
        .agg(["mean", "sem"])
        .reset_index()
    )
    tool_usage["ci_lower"] = (tool_usage["mean"] - 1.96 * tool_usage["sem"]).clip(
        lower=0
    )
    tool_usage["ci_upper"] = tool_usage["mean"] + 1.96 * tool_usage["sem"]
    return tool_usage


def _create_tool_usage_ablation(
    step_level_data: pd.DataFrame, analysis_dir: Path, social_influence: bool
):
    """Plot per-step mean tool usage (with 95% CIs) for the full vs ablated factor sets.

    Only non-epilogue steps (``is_epilogue_step == 0``) are used. The figure is
    saved to ``analysis_dir / "tool_usage_ablation.png"``.

    Args:
        step_level_data: Step-level records. Reads ``is_epilogue_step``, ``step``,
            ``used_this_step``, ``hedonic``, ``persona_age`` and, when
            ``social_influence`` is true, ``communication_type``.
        analysis_dir: Directory the PNG is written into.
        social_influence: When true, draw one line per condition per
            communication type (solid for "broadcast", dashed otherwise).
    """
    logger.info("Creating tool usage ablation plots...")
    risk_step_data = step_level_data[step_level_data["is_epilogue_step"] == 0]
    tool_ablation_conditions = [
        {
            "name": "full_factors",
            "filter": {
                "hedonic": ["crave", "like", "neutral"],
                "persona_age": ["child", "adult", "senior"],
            },
        },
        {
            "name": "no_hedonic_with_policy_role",
            "filter": {"hedonic": "none", "persona_age": "none"},
        },
    ]
    if social_influence:
        if "communication_type" not in risk_step_data.columns:
            logger.warning(
                "communication_type not found - cannot create social influence tool ablation"
            )
            return
        comm_types = [
            ct for ct in risk_step_data["communication_type"].unique() if pd.notna(ct)
        ]
        if len(comm_types) < 2:
            logger.warning(
                "Need both broadcast and isolated for social influence tool ablation"
            )
            return
        fig, ax = plt.subplots(1, 1, figsize=(12, 6))
        title = (
            "Tool Usage Ablation: Full vs None × Broadcast vs Isolated\n"
            "(with 95% Confidence Intervals)"
        )
    else:
        comm_types = []
        fig, ax = plt.subplots(1, 1, figsize=(10, 6))
        title = "Tool Usage Ablation: Full vs None\n(with 95% Confidence Intervals)"
    try:
        if social_influence:
            for condition in tool_ablation_conditions:
                # The communication types of one condition share a colour and
                # differ only by linestyle.
                line_color = None
                for comm_type in comm_types:
                    condition_data = _filter_step_rows(
                        risk_step_data[
                            risk_step_data["communication_type"] == comm_type
                        ],
                        condition["filter"],
                    )
                    if len(condition_data) == 0:
                        continue
                    tool_usage = _tool_usage_with_ci(condition_data)
                    plot_kwargs = {
                        "linewidth": 2,
                        "linestyle": "-" if comm_type == "broadcast" else "--",
                        "label": f"{condition['name'].replace('_', ' ').title()} ({comm_type})",
                    }
                    if line_color is not None:
                        plot_kwargs["color"] = line_color
                    line = ax.plot(
                        tool_usage["step"], tool_usage["mean"], **plot_kwargs
                    )[0]
                    if line_color is None:
                        line_color = line.get_color()
                    ax.fill_between(
                        tool_usage["step"],
                        tool_usage["ci_lower"],
                        tool_usage["ci_upper"],
                        alpha=0.3,
                        color=line.get_color(),
                    )
        else:
            for condition in tool_ablation_conditions:
                condition_data = _filter_step_rows(risk_step_data, condition["filter"])
                if len(condition_data) == 0:
                    continue
                tool_usage = _tool_usage_with_ci(condition_data)
                line = ax.plot(
                    tool_usage["step"],
                    tool_usage["mean"],
                    marker="o",
                    linewidth=2,
                    label=condition["name"].replace("_", " ").title(),
                )[0]
                ax.fill_between(
                    tool_usage["step"],
                    tool_usage["ci_lower"],
                    tool_usage["ci_upper"],
                    alpha=0.3,
                    color=line.get_color(),
                )
        ax.set_xlabel("Step")
        ax.set_ylabel("Mean Questions Used")
        ax.set_title(title)
        ax.grid(True, alpha=0.3)
        ax.axvline(x=10, color="r", linestyle="--", linewidth=1, label="Step 10")
        if social_influence:
            ax.text(
                0.02,
                0.98,
                "Solid = Broadcast\nDashed = Isolated",
                transform=ax.transAxes,
                fontsize=9,
                verticalalignment="top",
                bbox=dict(boxstyle="round", facecolor="wheat", alpha=0.8),
            )
        # The legend collects handles at call time, so build it after the
        # "Step 10" reference line has been added.
        ax.legend()
        plt.tight_layout()
        plt.savefig(
            analysis_dir / "tool_usage_ablation.png", dpi=300, bbox_inches="tight"
        )
    finally:
        # Always release the figure, even if plotting or saving fails.
        plt.close(fig)
    logger.info("✓ Created tool usage ablation plots")


def run_ablation_analysis(
    agent_outcomes: pd.DataFrame,
    step_level_data: pd.DataFrame,
    analysis_dir: Path,
    social_influence: bool,
    survival_prob_range: List[float],
) -> bool:
    """Create the ablation survival, completion-rate and tool-usage plots in order.

    Args:
        agent_outcomes: Per-agent outcomes, passed to the survival-curve and
            completion-rate plots.
        step_level_data: Step-level records, passed to the tool-usage plot.
        analysis_dir: Directory the plots are written into.
        social_influence: Whether to split plots by communication type.
        survival_prob_range: Forwarded to the survival-curve plot.

    Returns:
        True if every plot step completed. False if any step raised; the
        remaining steps are then skipped, and the error and its traceback are
        logged.
    """
    logger.info("Starting ablation analysis...")
    try:
        _create_ablation_survival_curves(
            agent_outcomes, analysis_dir, social_influence, survival_prob_range
        )
        _create_ablation_completion_rates(
            agent_outcomes, analysis_dir, social_influence
        )
        _create_tool_usage_ablation(step_level_data, analysis_dir, social_influence)
    except Exception as e:
        # Deliberate best-effort boundary: report failure to the caller via the
        # return value, but keep the traceback so the cause is diagnosable.
        logger.exception("Ablation analysis failed: %s", e)
        return False
    logger.info("Ablation analysis completed successfully!")
    return True
