import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import logging
from pathlib import Path

logger = logging.getLogger(__name__)


def _create_overall_hazard_histogram(plot_data, analysis_dir):
    """Plot pooled hazard rates per step and save ``overall_hazard_rates.png``.

    The left panel is a bar chart of the hazard ``d / n`` per step, where ``d``
    and ``n`` are the per-step sums of ``d_events`` and ``n_at_risk``, with
    Wilson score intervals (z = 1.96) as error bars; the bar at step 0, if
    present, is redrawn in red. The right panel draws one hazard curve per
    non-null ``communication_type`` value, or a placeholder message when that
    column is absent.

    Args:
        plot_data: DataFrame with ``step``, ``d_events`` and ``n_at_risk``
            columns and, optionally, ``communication_type``.
        analysis_dir: Directory the PNG is written to; it is joined with the
            file name via ``/`` (e.g. a ``pathlib.Path``).
    """
    logger.info("Creating overall hazard rate histogram...")

    def hazard_table(frame):
        """Aggregate events and at-risk counts per step and add ``hazard``."""
        table = (
            frame.groupby("step")
            .agg(d=("d_events", "sum"), n=("n_at_risk", "sum"))
            .reset_index()
        )
        events = table["d"].to_numpy(dtype=float)
        at_risk = table["n"].to_numpy(dtype=float)
        has_risk = at_risk > 0
        # Vectorised on purpose: a row-wise apply on an empty frame does not
        # yield a Series and cannot be assigned to a single column.
        safe_n = np.where(has_risk, at_risk, 1.0)
        table["hazard"] = np.where(has_risk, events / safe_n, 0.0)
        return table

    hazard_by_step = hazard_table(plot_data)
    if hazard_by_step.empty:
        # Without any step there is nothing to draw and no tick range to build.
        logger.warning(
            "No step data available - skipping overall hazard rate histogram"
        )
        return

    # Wilson score interval for every step at once; steps without anybody at
    # risk get the degenerate interval (0, 0), matching their hazard of 0.
    events = hazard_by_step["d"].to_numpy(dtype=float)
    at_risk = hazard_by_step["n"].to_numpy(dtype=float)
    has_risk = at_risk > 0
    safe_n = np.where(has_risk, at_risk, 1.0)
    z = 1.96
    p = events / safe_n
    denominator = 1 + z**2 / safe_n
    centre = (p + z**2 / (2 * safe_n)) / denominator
    half_width = (
        z * np.sqrt((p * (1 - p) + z**2 / (4 * safe_n)) / safe_n) / denominator
    )
    hazard_by_step["ci_lower"] = np.where(
        has_risk, np.fmax(0.0, centre - half_width), 0.0
    )
    hazard_by_step["ci_upper"] = np.where(
        has_risk, np.fmin(1.0, centre + half_width), 0.0
    )

    # One tick per integer step across the full observed range.
    step_ticks = range(
        int(plot_data["step"].min()), int(plot_data["step"].max()) + 1
    )

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))
    try:
        ax1.bar(
            hazard_by_step["step"],
            hazard_by_step["hazard"],
            alpha=0.7,
            color="steelblue",
            edgecolor="black",
            yerr=[
                (hazard_by_step["hazard"] - hazard_by_step["ci_lower"]).clip(
                    lower=0
                ),
                (hazard_by_step["ci_upper"] - hazard_by_step["hazard"]).clip(
                    lower=0
                ),
            ],
            capsize=3,
        )
        ax1.set_xlabel("Step")
        ax1.set_ylabel("Hazard Rate")
        ax1.set_title(
            "Overall Hazard Rate by Step\n(with 95% Confidence Intervals)"
        )
        ax1.grid(True, alpha=0.3)
        ax1.set_xticks(step_ticks)

        # Highlight the step-0 bar by drawing a red bar over it.
        t0_rows = hazard_by_step.loc[hazard_by_step["step"] == 0, "hazard"]
        if not t0_rows.empty:
            t0_hazard = t0_rows.iloc[0]
            ax1.bar(
                0,
                t0_hazard,
                color="red",
                alpha=0.8,
                edgecolor="black",
                label=f"t=0 impulse: {t0_hazard:.3f}",
            )
            ax1.legend()

        if "communication_type" in plot_data.columns:
            comm_types = plot_data["communication_type"].dropna().unique()
            for comm_type in comm_types:
                curve = hazard_table(
                    plot_data[plot_data["communication_type"] == comm_type]
                )
                ax2.plot(
                    curve["step"],
                    curve["hazard"],
                    marker="o",
                    linewidth=2,
                    label=str(comm_type),
                )
            ax2.set_xlabel("Step")
            ax2.set_ylabel("Hazard Rate")
            ax2.set_title("Hazard Rate by Communication Type")
            if len(comm_types) > 0:
                # Only build a legend when at least one labelled curve exists.
                ax2.legend()
            ax2.grid(True, alpha=0.3)
            ax2.set_xticks(step_ticks)
        else:
            ax2.text(
                0.5,
                0.5,
                "Communication type\nnot available",
                ha="center",
                va="center",
                transform=ax2.transAxes,
            )

        fig.tight_layout()
        fig.savefig(
            analysis_dir / "overall_hazard_rates.png",
            dpi=300,
            bbox_inches="tight",
        )
    finally:
        # Close this specific figure even if plotting or saving failed.
        plt.close(fig)
    logger.info("✓ Created overall hazard rate histogram")


# Candidate columns for the per-factor hazard panels, in plotting order.
_HAZARD_FACTORS = (
    "communication_type",
    "tool_use_policy",
    "visible_question_budget",
    "hedonic",
    "persona_age",
    "distraction",
    "reward_visibility",
    "agent_architecture",
    "model_id",
)


def _hazard_curve(frame):
    """Sum ``d_events`` / ``n_at_risk`` per step and add ``hazard`` = d / n."""
    curve = (
        frame.groupby("step")
        .agg(d=("d_events", "sum"), n=("n_at_risk", "sum"))
        .reset_index()
    )
    events = curve["d"].to_numpy(dtype=float)
    at_risk = curve["n"].to_numpy(dtype=float)
    has_risk = at_risk > 0
    # Vectorised so that an empty frame simply yields an empty column; steps
    # with nobody at risk get a hazard of 0.
    safe_n = np.where(has_risk, at_risk, 1.0)
    curve["hazard"] = np.where(has_risk, events / safe_n, 0.0)
    return curve


def _wilson_bounds(events, at_risk, z=1.96):
    """Return Wilson score interval bounds ``(lower, upper)`` as arrays.

    Entries with nobody at risk get the degenerate interval ``(0, 0)``.
    """
    events = np.asarray(events, dtype=float)
    at_risk = np.asarray(at_risk, dtype=float)
    has_risk = at_risk > 0
    safe_n = np.where(has_risk, at_risk, 1.0)
    p = events / safe_n
    denominator = 1 + z**2 / safe_n
    centre = (p + z**2 / (2 * safe_n)) / denominator
    half_width = (
        z * np.sqrt((p * (1 - p) + z**2 / (4 * safe_n)) / safe_n) / denominator
    )
    lower = np.where(has_risk, np.fmax(0.0, centre - half_width), 0.0)
    upper = np.where(has_risk, np.fmin(1.0, centre + half_width), 0.0)
    return lower, upper


def _factor_axes(n_factors):
    """Create a grid of at most two columns and hide the unused axes."""
    n_cols = min(2, n_factors)
    n_rows = (n_factors + n_cols - 1) // n_cols
    fig, grid = plt.subplots(
        n_rows, n_cols, figsize=(8 * n_cols, 6 * n_rows), squeeze=False
    )
    # squeeze=False always gives a 2-D array, so one ravel covers every shape.
    axes = grid.ravel()
    for spare in axes[n_factors:]:
        spare.set_visible(False)
    return fig, axes


def _plot_levels_with_ci(ax, plot_data, factor):
    """Draw one hazard curve with a shaded Wilson band per level of ``factor``."""
    for level in plot_data[factor].dropna().unique():
        curve = _hazard_curve(plot_data[plot_data[factor] == level])
        lower, upper = _wilson_bounds(curve["d"], curve["n"])
        line = ax.plot(
            curve["step"],
            curve["hazard"],
            marker="o",
            linewidth=2,
            label=str(level),
        )[0]
        ax.fill_between(
            curve["step"], lower, upper, alpha=0.3, color=line.get_color()
        )
    ax.set_xlabel("Step")
    ax.set_ylabel("Hazard Rate")
    ax.set_title(
        f'Hazard Rate by {factor.replace("_", " ").title()}\n(with 95% Confidence Intervals)'
    )
    ax.legend()


def _plot_levels_by_comm_type(ax, plot_data, factor, comm_types):
    """Draw one hazard curve per (level of ``factor``, communication type)."""
    for factor_level in plot_data[factor].dropna().unique():
        # All communication types of one level share the colour that
        # matplotlib assigns to the first curve drawn for that level.
        line_color = None
        for comm_type in comm_types:
            subset = plot_data[
                (plot_data[factor] == factor_level)
                & (plot_data["communication_type"] == comm_type)
            ]
            if subset.empty:
                continue
            curve = _hazard_curve(subset)
            plot_args = {
                "marker": "o",
                "linewidth": 2,
                "linestyle": "-" if comm_type == "broadcast" else "--",
                "label": f"{factor_level} ({comm_type})",
            }
            if line_color is not None:
                plot_args["color"] = line_color
            line = ax.plot(curve["step"], curve["hazard"], **plot_args)[0]
            if line_color is None:
                line_color = line.get_color()
    ax.set_xlabel("Step")
    ax.set_ylabel("Hazard Rate")
    ax.set_title(
        f'Hazard Rate by {factor.replace("_", " ").title()} (with Social Context)'
    )
    ax.legend()
    ax.text(
        0.02,
        0.98,
        "Solid = Broadcast\nDashed = Isolated",
        transform=ax.transAxes,
        fontsize=9,
        verticalalignment="top",
        bbox=dict(boxstyle="round", facecolor="wheat", alpha=0.8),
    )


def create_hazard_plots(
    analyzer, plot_data: pd.DataFrame, analysis_dir: Path, social_influence: bool
):
    """Create the overall hazard histogram and the per-factor hazard panels.

    First delegates to ``_create_overall_hazard_histogram``; then draws one
    panel per factor column that is present in ``plot_data`` and has more than
    one distinct value, and saves the grid as ``hazard_rates.png``.

    With ``social_influence`` set, each factor level is split by
    ``communication_type`` ("broadcast" solid, anything else dashed) and
    ``communication_type`` itself is not used as a factor. Otherwise each
    level gets a single curve with a shaded Wilson 95% interval.

    Returns early, after logging a warning and without writing
    ``hazard_rates.png``, when ``communication_type`` is missing or has fewer
    than two non-null values (social-influence mode only) or when no factor
    shows variation.

    Args:
        analyzer: Not used by this function; kept in the signature for callers.
        plot_data: DataFrame with ``step``, ``d_events`` and ``n_at_risk``
            columns plus any of the factor columns.
        analysis_dir: Directory the PNG is written to.
        social_influence: Selects the communication-type split described above.
    """
    logger.info("Creating hazard rate plots...")
    _create_overall_hazard_histogram(plot_data, analysis_dir)

    comm_types = []
    if social_influence:
        if "communication_type" not in plot_data.columns:
            logger.warning(
                "communication_type not found - cannot create social influence plots"
            )
            return
        comm_types = list(plot_data["communication_type"].dropna().unique())
        if len(comm_types) < 2:
            logger.warning(
                "Need both broadcast and isolated for social influence analysis"
            )
            return
        candidates = [f for f in _HAZARD_FACTORS if f != "communication_type"]
        no_variation_msg = (
            "No factors with variation found for social influence hazard plots"
        )
    else:
        candidates = list(_HAZARD_FACTORS)
        no_variation_msg = "No factors with variation found for hazard plots"

    factors_with_variation = [
        factor
        for factor in candidates
        if factor in plot_data.columns and plot_data[factor].nunique() > 1
    ]
    if not factors_with_variation:
        logger.warning(no_variation_msg)
        return

    # One tick per integer step; computed before the figure is opened so a
    # failure here cannot leave a figure behind.
    step_ticks = range(
        int(plot_data["step"].min()), int(plot_data["step"].max()) + 1
    )

    fig, axes = _factor_axes(len(factors_with_variation))
    try:
        for ax, factor in zip(axes, factors_with_variation):
            if social_influence:
                _plot_levels_by_comm_type(ax, plot_data, factor, comm_types)
            else:
                _plot_levels_with_ci(ax, plot_data, factor)
            ax.grid(True, alpha=0.3)
            ax.set_xticks(step_ticks)
        fig.tight_layout()
        fig.savefig(
            analysis_dir / "hazard_rates.png", dpi=300, bbox_inches="tight"
        )
    finally:
        # Close this specific figure even if plotting or saving failed.
        plt.close(fig)
    logger.info("✓ Created hazard rate plots")
