import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import logging
from pathlib import Path

# Logger named after this module, so records can be filtered per module.
logger = logging.getLogger(name=__name__)


def _wilson_hazard_interval(d, n, alpha: float = 0.05):
    """Return element-wise Wilson score bounds for the proportions ``d / n``.

    Args:
        d: Event counts (anything ``np.asarray`` accepts).
        n: Trial (at-risk) counts, same length as ``d``.
        alpha: Two-sided significance level; ``0.05`` yields a 95% interval.

    Returns:
        ``(lower, upper)`` float arrays clipped to ``[0, 1]``. Entries with
        ``n <= 0`` carry no information and get ``(0.0, 0.0)``.
    """
    from statistics import NormalDist

    d = np.asarray(d, dtype=float)
    n = np.asarray(n, dtype=float)
    # Two-sided standard-normal quantile (about 1.96 for alpha=0.05).
    z = NormalDist().inv_cdf(1.0 - alpha / 2.0)
    z2 = z * z
    at_risk = n > 0
    # Use 1 as a stand-in denominator where nobody is at risk so the
    # arithmetic stays finite; those entries are masked out below.
    safe_n = np.where(at_risk, n, 1.0)
    p = d / safe_n
    denominator = 1.0 + z2 / safe_n
    centre = (p + z2 / (2.0 * safe_n)) / denominator
    half_width = (
        z * np.sqrt((p * (1.0 - p) + z2 / (4.0 * safe_n)) / safe_n) / denominator
    )
    lower = np.where(at_risk, np.maximum(0.0, centre - half_width), 0.0)
    upper = np.where(at_risk, np.minimum(1.0, centre + half_width), 0.0)
    return lower, upper


def _event_distribution_table(data: pd.DataFrame, alpha: float = 0.05) -> pd.DataFrame:
    """Aggregate per-step counts into a discrete-time event distribution.

    ``data`` must have ``step``, ``d_events`` and ``n_at_risk`` columns; rows
    sharing a step are summed. The result has one row per step (sorted by
    step, as ``groupby`` does by default) with columns:

    - ``d`` / ``n``: summed events and at-risk counts,
    - ``hazard``: ``d / n``, or 0 when nobody is at risk,
    - ``survival_prev``: product of ``1 - hazard`` over all earlier steps,
    - ``pmf``: ``hazard * survival_prev``, and ``cdf``: its running sum,
    - ``pmf_ci_lower`` / ``pmf_ci_upper``: Wilson hazard bounds at level
      ``alpha`` scaled by ``survival_prev``.

    Works on an empty frame (returns an empty table with these columns).
    """
    dist = (
        data.groupby("step")
        .agg(d=("d_events", "sum"), n=("n_at_risk", "sum"))
        .reset_index()
    )
    d = dist["d"].to_numpy(dtype=float)
    n = dist["n"].to_numpy(dtype=float)
    at_risk = n > 0
    hazard = np.where(at_risk, d / np.where(at_risk, n, 1.0), 0.0)
    # Survival at the *start* of each step: 1 for the first step, then the
    # running product of (1 - hazard) lagged by one step.
    survival_prev = np.ones_like(hazard)
    survival_prev[1:] = np.cumprod(1.0 - hazard)[:-1]
    hazard_lower, hazard_upper = _wilson_hazard_interval(d, n, alpha)

    dist["hazard"] = hazard
    dist["survival_prev"] = survival_prev
    dist["pmf"] = hazard * survival_prev
    dist["cdf"] = dist["pmf"].cumsum()
    # Approximation: the band only reflects uncertainty in this step's hazard;
    # survival_prev is treated as if it were known exactly.
    dist["pmf_ci_lower"] = hazard_lower * survival_prev
    dist["pmf_ci_upper"] = hazard_upper * survival_prev
    return dist


def _plot_pmf_with_band(ax, dist: pd.DataFrame, color=None, **line_kwargs):
    """Plot ``dist['pmf']`` against ``dist['step']`` with a shaded CI band.

    Args:
        ax: Matplotlib axes to draw on.
        dist: Table produced by ``_event_distribution_table``.
        color: Optional line colour; when ``None`` the axes' colour cycle picks one.
        **line_kwargs: Extra ``ax.plot`` keywords (e.g. ``label``, ``linestyle``).

    Returns:
        The colour actually used, so related lines can reuse it.
    """
    if color is not None:
        line_kwargs["color"] = color
    (line,) = ax.plot(
        dist["step"], dist["pmf"], marker="o", linewidth=2, **line_kwargs
    )
    ax.fill_between(
        dist["step"],
        dist["pmf_ci_lower"],
        dist["pmf_ci_upper"],
        alpha=0.3,
        color=line.get_color(),
    )
    return line.get_color()


def _create_overall_event_distribution(plot_data, analysis_dir):
    """Save ``overall_event_distribution.png`` pooling every row of ``plot_data``.

    Left panel: per-step PMF of the first event as bars with 95% error bars.
    Right panel: the cumulative distribution (running sum of the PMF).

    Args:
        plot_data: Frame with ``step``, ``d_events`` and ``n_at_risk`` columns.
        analysis_dir: ``Path`` of the directory the PNG is written into.
    """
    logger.info("Creating overall event distribution histogram...")
    dist = _event_distribution_table(plot_data)
    step_ticks = range(int(plot_data["step"].min()), int(plot_data["step"].max()) + 1)

    fig, (ax_pmf, ax_cdf) = plt.subplots(1, 2, figsize=(12, 5))
    try:
        ax_pmf.bar(
            dist["step"],
            dist["pmf"],
            alpha=0.7,
            color="purple",
            edgecolor="black",
            # Asymmetric error bars; clip guards against tiny negative
            # distances from floating-point rounding.
            yerr=[
                (dist["pmf"] - dist["pmf_ci_lower"]).clip(lower=0),
                (dist["pmf_ci_upper"] - dist["pmf"]).clip(lower=0),
            ],
            capsize=3,
        )
        ax_pmf.set_xlabel("Step")
        ax_pmf.set_ylabel("Probability of First Eat")
        ax_pmf.set_title(
            "Overall Event Distribution (PMF)\n(with 95% Confidence Intervals)"
        )
        ax_pmf.grid(True, alpha=0.3)
        ax_pmf.set_xticks(step_ticks)

        ax_cdf.plot(
            dist["step"],
            dist["cdf"],
            marker="o",
            linewidth=2,
            color="darkred",
        )
        ax_cdf.set_xlabel("Step")
        ax_cdf.set_ylabel("Cumulative Probability")
        ax_cdf.set_title("Overall Cumulative Event Distribution")
        ax_cdf.grid(True, alpha=0.3)
        ax_cdf.set_ylim(0, 1.05)
        ax_cdf.set_xticks(step_ticks)

        fig.tight_layout()
        fig.savefig(
            analysis_dir / "overall_event_distribution.png",
            dpi=300,
            bbox_inches="tight",
        )
    finally:
        # Always release the figure, even if drawing or saving failed.
        plt.close(fig)
    logger.info("✓ Created overall event distribution histogram")


# Candidate factors, in panel order; only those present with >1 level are drawn.
_EVENT_DISTRIBUTION_FACTORS = (
    "communication_type",
    "tool_use_policy",
    "visible_question_budget",
    "hedonic",
    "persona_age",
    "distraction",
    "reward_visibility",
    "agent_architecture",
    "model_id",
)


def create_event_distribution_plot(
    analyzer, plot_data: pd.DataFrame, analysis_dir: Path, social_influence: bool
):
    """Save per-factor event-distribution plots (and the pooled overall plot).

    Writes ``event_distribution.png`` into ``analysis_dir``: one panel per
    factor that takes more than one value in ``plot_data``, each showing the
    per-step PMF (with a 95% band) for every level of that factor. Also calls
    ``_create_overall_event_distribution`` to write the pooled plot.

    Args:
        analyzer: Not used by this function; kept for call-site compatibility.
        plot_data: Frame with ``step``, ``d_events``, ``n_at_risk`` and factor
            columns.
        analysis_dir: ``Path`` of the output directory.
        social_influence: When true, ``communication_type`` does not get its
            own panel; if that column exists, each level is instead split into
            one line per communication type (solid for ``"broadcast"``,
            dashed otherwise) sharing a colour.
    """
    logger.info("Creating event distribution plots...")
    excluded = {"communication_type"} if social_influence else set()
    factors_with_variation = [
        factor
        for factor in _EVENT_DISTRIBUTION_FACTORS
        if factor not in excluded
        and factor in plot_data.columns
        and plot_data[factor].nunique() > 1
    ]
    if not factors_with_variation:
        logger.warning("No factors with variation found for event distribution plots")
        return
    # NOTE(review): the pooled plot does not depend on factor variation, yet it
    # is skipped by the early return above — confirm that is intended.
    _create_overall_event_distribution(plot_data, analysis_dir)

    split_by_comm = social_influence and "communication_type" in plot_data.columns
    comm_types = (
        [ct for ct in plot_data["communication_type"].unique() if pd.notna(ct)]
        if split_by_comm
        else []
    )
    title_suffix = " (with Social Context)" if social_influence else ""
    step_ticks = range(int(plot_data["step"].min()), int(plot_data["step"].max()) + 1)

    n_factors = len(factors_with_variation)
    n_cols = min(2, n_factors)
    n_rows = (n_factors + n_cols - 1) // n_cols
    # squeeze=False always yields a 2-D array, so flatten() works uniformly.
    fig, axes = plt.subplots(
        n_rows, n_cols, figsize=(8 * n_cols, 6 * n_rows), squeeze=False
    )
    try:
        axes = axes.flatten()
        for spare_ax in axes[n_factors:]:
            spare_ax.set_visible(False)

        for ax, factor in zip(axes, factors_with_variation):
            levels = [lvl for lvl in plot_data[factor].unique() if pd.notna(lvl)]
            if split_by_comm:
                for level in levels:
                    # All communication types of one level share a colour and
                    # differ only in line style.
                    level_color = None
                    for comm_type in comm_types:
                        subset = plot_data[
                            (plot_data[factor] == level)
                            & (plot_data["communication_type"] == comm_type)
                        ]
                        if subset.empty:
                            continue
                        level_color = _plot_pmf_with_band(
                            ax,
                            _event_distribution_table(subset),
                            color=level_color,
                            linestyle="-" if comm_type == "broadcast" else "--",
                            label=f"{level} ({comm_type})",
                        )
            else:
                for level in levels:
                    subset = plot_data[plot_data[factor] == level]
                    if subset.empty:
                        continue
                    _plot_pmf_with_band(
                        ax, _event_distribution_table(subset), label=str(level)
                    )

            ax.set_xlabel("Step")
            ax.set_ylabel("Probability of First Eat")
            ax.set_title(
                f'Event Distribution by {factor.replace("_", " ").title()}{title_suffix}\n(with 95% Confidence Intervals)'
            )
            ax.legend()
            if social_influence:
                ax.text(
                    0.02,
                    0.98,
                    "Solid = Broadcast\nDashed = Isolated",
                    transform=ax.transAxes,
                    fontsize=9,
                    verticalalignment="top",
                    bbox=dict(boxstyle="round", facecolor="wheat", alpha=0.8),
                )
            ax.grid(True, alpha=0.3)
            ax.set_ylim(0, None)
            ax.set_xticks(step_ticks)

        fig.tight_layout()
        fig.savefig(
            analysis_dir / "event_distribution.png", dpi=300, bbox_inches="tight"
        )
    finally:
        # Always release the figure, even if drawing or saving failed.
        plt.close(fig)
    logger.info("✓ Created event distribution plots")
