import pandas as pd
import matplotlib.pyplot as plt
import logging
from pathlib import Path
import numpy as np
from typing import List

logger = logging.getLogger(__name__)


def _calculate_kaplan_meier(agent_outcomes: pd.DataFrame) -> pd.DataFrame:
    """Estimate a discrete-step Kaplan-Meier survival curve with 95% CIs.

    The curve is evaluated at every integer step from 0 through
    ``risk_horizon + 1``. The final step (one past the horizon) repeats the
    last survival value with ``at_risk == 0`` so that a ``where="post"`` step
    plot draws the last interval.

    Args:
        agent_outcomes: One row per subject. Columns read:
            ``risk_horizon`` (only the first row's value is used),
            ``tte_step`` (rows with ``tte_step >= step`` are at risk at
            ``step``) and ``event_ate_early`` (rows equal to 1 whose
            ``tte_step == step`` count as events at ``step``).

    Returns:
        DataFrame with columns ``step``, ``survival``, ``events``,
        ``at_risk``, ``ci_lower`` and ``ci_upper``. An empty input yields an
        empty frame with the same columns. Confidence bounds use Greenwood's
        variance sum with a log(-log) transform and are clipped to [0, 1];
        where no variance has accumulated (or survival is 0 or 1) both bounds
        equal the survival estimate.
    """
    columns = ["step", "survival", "events", "at_risk", "ci_lower", "ci_upper"]
    if len(agent_outcomes) == 0:
        # Same schema as the populated result so callers can index any column.
        return pd.DataFrame({name: [] for name in columns})

    # int() lets a horizon stored as a float (e.g. 4.0) drive range().
    risk_horizon = int(agent_outcomes["risk_horizon"].iloc[0])
    tte_step = agent_outcomes["tte_step"]
    is_event = agent_outcomes["event_ate_early"] == 1
    z_score = 1.96  # two-sided 95% normal quantile

    survival_data = []
    current_survival = 1.0
    cumulative_variance = 0.0  # running Greenwood sum: d / (n * (n - d))
    for step in range(risk_horizon + 2):
        events_at_step = int((is_event & (tte_step == step)).sum())
        if step <= risk_horizon:
            at_risk = int((tte_step >= step).sum())
            hazard = events_at_step / at_risk if at_risk > 0 else 0
            current_survival *= 1 - hazard
            # Skip the term when everyone at risk fails: it would divide by 0.
            if at_risk > events_at_step:
                cumulative_variance += events_at_step / (
                    at_risk * (at_risk - events_at_step)
                )
        else:
            # Padding step past the horizon: carry the estimate forward.
            at_risk = 0

        if cumulative_variance > 0 and 0 < current_survival < 1.0:
            log_survival = np.log(current_survival)  # negative on (0, 1)
            # Delta-method SE of log(-log S). The magnitude of log S is
            # required: dividing by the signed value flips the interval.
            log_log_se = np.sqrt(cumulative_variance) / abs(log_survival)
            log_log_survival = np.log(-log_survival)
            # exp(-exp(.)) is decreasing, so the upper log-log bound maps to
            # the lower survival bound and vice versa.
            ci_lower = np.exp(-np.exp(log_log_survival + z_score * log_log_se))
            ci_upper = np.exp(-np.exp(log_log_survival - z_score * log_log_se))
            ci_lower = float(min(1.0, max(0.0, ci_lower)))
            ci_upper = float(min(1.0, max(0.0, ci_upper)))
        else:
            # Degenerate cases (no variance yet, S == 0 or S == 1).
            ci_lower = ci_upper = current_survival

        survival_data.append(
            {
                "step": step,
                "survival": current_survival,
                "events": events_at_step,
                "at_risk": at_risk,
                "ci_lower": ci_lower,
                "ci_upper": ci_upper,
            }
        )

    result_df = pd.DataFrame(survival_data, columns=columns)
    if len(result_df) > 1:
        survival_values = result_df["survival"].to_numpy()
        # A KM curve can only stay flat or drop; tolerate float round-off.
        if not np.all(np.diff(survival_values) <= 1e-12):
            logger.warning("KM survival curve is not monotonic - potential data issue")
    return result_df


def _draw_survival_curve(ax, survival_curve: pd.DataFrame, **step_kwargs):
    """Draw one survival step curve and its shaded CI band on ``ax``; return the line."""
    line = ax.step(
        survival_curve["step"],
        survival_curve["survival"],
        where="post",
        **step_kwargs,
    )[0]
    ax.fill_between(
        survival_curve["step"],
        survival_curve["ci_lower"],
        survival_curve["ci_upper"],
        alpha=0.2,
        color=line.get_color(),  # band always matches its curve
        step="post",
    )
    return line


def create_kaplan_meier_plots(
    agent_outcomes: pd.DataFrame,
    analysis_dir: Path,
    social_influence: bool,
    survival_prob_range: List[float],
) -> None:
    """Plot Kaplan-Meier survival curves per factor and save them as one PNG.

    One subplot is drawn for each known factor column present in
    ``agent_outcomes`` (laid out two per row). Within a subplot, one curve
    with a 95% confidence band is drawn per non-null factor level, using only
    rows where ``is_valid_outcome == 1``. The figure is written to
    ``analysis_dir / "kaplan_meier_curves.png"``.

    Args:
        agent_outcomes: Per-agent outcome rows. Must contain
            ``is_valid_outcome`` plus the columns read by
            ``_calculate_kaplan_meier``.
        analysis_dir: Directory the PNG is written into.
        social_influence: When true, ``communication_type`` is not plotted as
            a factor of its own; if that column exists, every factor level is
            instead split into one curve per communication type (solid for
            ``"broadcast"``, dashed otherwise) sharing one colour per level.
        survival_prob_range: Only element 0 is used, as the lower y-axis
            limit and the first y tick; the upper limit is fixed at 1.05.
    """
    logger.info("Creating Kaplan-Meier survival plots...")
    all_factors = [
        "communication_type",
        "tool_use_policy",
        "visible_question_budget",
        "hedonic",
        "persona_age",
        "distraction",
        "reward_visibility",
        "agent_architecture",
        "model_id",
    ]
    if social_influence:
        candidate_factors = [f for f in all_factors if f != "communication_type"]
    else:
        candidate_factors = all_factors
    available_factors = [f for f in candidate_factors if f in agent_outcomes.columns]
    if not available_factors:
        logger.warning("No factors available for Kaplan-Meier plots")
        return

    # Loop-invariant pieces, computed once instead of per factor/level.
    split_by_comm = social_influence and "communication_type" in agent_outcomes.columns
    valid_outcomes = agent_outcomes[agent_outcomes["is_valid_outcome"] == 1]
    comm_types = []
    if split_by_comm:
        comm_types = [
            ct for ct in agent_outcomes["communication_type"].unique() if pd.notna(ct)
        ]
    title_suffix = " (with Social Context)" if social_influence else ""
    y_min = survival_prob_range[0]

    n_factors = len(available_factors)
    n_cols = min(2, n_factors)
    n_rows = (n_factors + n_cols - 1) // n_cols
    # squeeze=False always returns a 2-D array, so no special-casing of 1x1.
    fig, axes_grid = plt.subplots(
        n_rows, n_cols, figsize=(8 * n_cols, 6 * n_rows), squeeze=False
    )
    axes = axes_grid.flatten()
    try:
        # Hide the trailing cell when the factor count is odd.
        for spare_ax in axes[n_factors:]:
            spare_ax.set_visible(False)

        for ax, factor in zip(axes, available_factors):
            # Levels come from the unfiltered frame to keep first-seen order
            # (and therefore colour assignment) stable.
            factor_levels = [
                level for level in agent_outcomes[factor].unique() if pd.notna(level)
            ]
            for level in factor_levels:
                level_agents = valid_outcomes[valid_outcomes[factor] == level]
                if split_by_comm:
                    line_color = None
                    for comm_type in comm_types:
                        condition_agents = level_agents[
                            level_agents["communication_type"] == comm_type
                        ]
                        if len(condition_agents) == 0:
                            continue
                        plot_args = {
                            "marker": "o",
                            "linewidth": 2,
                            "linestyle": "-" if comm_type == "broadcast" else "--",
                            "label": f"{level} ({comm_type})",
                        }
                        # Reuse the first curve's colour for the same level.
                        if line_color is not None:
                            plot_args["color"] = line_color
                        line = _draw_survival_curve(
                            ax, _calculate_kaplan_meier(condition_agents), **plot_args
                        )
                        if line_color is None:
                            line_color = line.get_color()
                else:
                    if len(level_agents) == 0:
                        continue
                    _draw_survival_curve(
                        ax,
                        _calculate_kaplan_meier(level_agents),
                        linewidth=2,
                        label=str(level),
                    )

            ax.set_xlabel("Step")
            ax.set_ylabel("Survival Probability")
            ax.set_title(
                f'Survival by {factor.replace("_", " ").title()}{title_suffix}\n(with 95% Confidence Intervals)'
            )
            # Avoid matplotlib's "no artists with labels" warning on empty panels.
            handles, _labels = ax.get_legend_handles_labels()
            if handles:
                ax.legend()
            if split_by_comm:
                # Only explain the line styles when they were actually used.
                ax.text(
                    0.02,
                    0.98,
                    "Solid = Broadcast\nDashed = Isolated",
                    transform=ax.transAxes,
                    fontsize=9,
                    verticalalignment="top",
                    bbox=dict(boxstyle="round", facecolor="wheat", alpha=0.8),
                )
            ax.grid(True, alpha=0.3)
            ax.set_ylim(y_min, 1.05)
            ax.set_yticks(np.arange(y_min, 1.01, 0.05))
            # Integer x ticks up to the last plotted step (0 if nothing drawn).
            max_step = max(
                (max(line.get_xdata(orig=False)) for line in ax.get_lines()),
                default=0,
            )
            max_step = max(0, max_step)
            ax.set_xticks(np.arange(0, max_step + 1, 1))

        fig.tight_layout()
        fig.savefig(
            analysis_dir / "kaplan_meier_curves.png", dpi=300, bbox_inches="tight"
        )
    finally:
        # Close this specific figure even if plotting or saving failed.
        plt.close(fig)
    logger.info("✓ Created Kaplan-Meier plots")
