import pandas as pd
import matplotlib.pyplot as plt
import logging
from pathlib import Path

logger = logging.getLogger(__name__)


def _summarize_by_step(data: pd.DataFrame, columns: list) -> pd.DataFrame:
    """Return the per-step mean and SEM of ``columns``.

    The result has a ``step`` column plus one ``<column>_mean`` and one
    ``<column>_sem`` column for every entry in ``columns``.
    """
    summary = data.groupby("step")[columns].agg(["mean", "sem"])
    # Flatten the (column, statistic) MultiIndex into "<column>_<statistic>".
    summary.columns = ["_".join(pair) for pair in summary.columns]
    return summary.reset_index()


def _plot_mean_with_ci(ax, summary: pd.DataFrame, column: str, *, label, color, marker):
    """Draw the per-step mean of ``column`` on ``ax`` with a shaded CI band.

    The band is mean +/- 1.96 * SEM, with the lower edge clipped at zero.
    """
    mean = summary[f"{column}_mean"]
    half_width = 1.96 * summary[f"{column}_sem"]
    (line,) = ax.plot(
        summary["step"],
        mean,
        marker=marker,
        linewidth=2,
        label=label,
        color=color,
    )
    ax.fill_between(
        summary["step"],
        (mean - half_width).clip(lower=0),
        mean + half_width,
        alpha=0.3,
        color=line.get_color(),
    )


def _style_step_axis(ax, title: str, step_ticks) -> None:
    """Apply the shared labels, legend, grid and integer step ticks to ``ax``."""
    ax.set_xlabel("Step")
    ax.set_ylabel("Mean Number of Peers")
    ax.set_title(title)
    ax.legend()
    ax.grid(True, alpha=0.3)
    ax.set_xticks(step_ticks)


def create_peer_exposure_plots(step_level_data: pd.DataFrame, analysis_dir: Path):
    """Plot peer exposure over steps for broadcast-mode data and save it as a PNG.

    Only rows with ``is_epilogue_step == 0`` and
    ``communication_type == "broadcast"`` are used. A two-panel figure is
    written to ``analysis_dir / "peer_exposure_patterns.png"``:

    * left: per-step mean of ``peers_eaten_so_far`` and ``peers_waiting_so_far``
    * right: per-step mean of ``eats_this_step`` and ``waits_this_step``

    Each line carries a mean +/- 1.96 * SEM band (lower edge clipped at 0).
    Afterwards the factor-specific plots are produced from the same
    broadcast rows via ``_create_peer_exposure_by_factors``.

    Args:
        step_level_data: Step-level frame. Must contain ``is_epilogue_step``;
            the plotted columns and ``step`` are required once broadcast rows
            are found.
        analysis_dir: Directory the PNG is written into.

    Returns:
        None. Logs a warning and returns early when there is no
        ``communication_type`` column or no broadcast rows.

    Raises:
        KeyError: If ``is_epilogue_step`` or one of the plotted columns is
            missing.
    """
    logger.info("Creating peer exposure plots...")
    # Boolean indexing already yields a new frame, so no intermediate copy
    # is needed here.
    risk_step_data = step_level_data[step_level_data["is_epilogue_step"] == 0]
    if "communication_type" not in risk_step_data.columns:
        logger.warning("No communication_type data found for peer exposure plots")
        return

    # Copy once: this frame is handed on to the factor-specific plotting.
    broadcast_data = risk_step_data[
        risk_step_data["communication_type"] == "broadcast"
    ].copy()
    if broadcast_data.empty:
        logger.warning("No broadcast mode data found for peer exposure plots")
        return

    # Aggregate before creating the figure so a missing column cannot leave
    # an orphaned figure behind.
    peer_exposure = _summarize_by_step(
        broadcast_data, ["peers_eaten_so_far", "peers_waiting_so_far"]
    )
    peer_events = _summarize_by_step(
        broadcast_data, ["eats_this_step", "waits_this_step"]
    )
    steps = broadcast_data["step"]
    step_ticks = range(int(steps.min()), int(steps.max()) + 1)

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))
    try:
        _plot_mean_with_ci(
            ax1,
            peer_exposure,
            "peers_eaten_so_far",
            label="Peers Eaten So Far",
            color="red",
            marker="o",
        )
        _plot_mean_with_ci(
            ax1,
            peer_exposure,
            "peers_waiting_so_far",
            label="Peers Waiting So Far",
            color="blue",
            marker="s",
        )
        _style_step_axis(
            ax1,
            "Peer Status Over Time\n(with 95% Confidence Intervals)",
            step_ticks,
        )

        _plot_mean_with_ci(
            ax2,
            peer_events,
            "eats_this_step",
            label="Peers Eating This Step",
            color="orange",
            marker="o",
        )
        _plot_mean_with_ci(
            ax2,
            peer_events,
            "waits_this_step",
            label="Peers Waiting This Step",
            color="green",
            marker="s",
        )
        _style_step_axis(
            ax2,
            "Peer Actions This Step\n(with 95% Confidence Intervals)",
            step_ticks,
        )

        fig.tight_layout()
        fig.savefig(
            analysis_dir / "peer_exposure_patterns.png",
            dpi=300,
            bbox_inches="tight",
        )
    finally:
        # Close this specific figure even if plotting or saving failed.
        plt.close(fig)

    logger.info("✓ Created original peer exposure plots")
    _create_peer_exposure_by_factors(broadcast_data, analysis_dir)


def _create_peer_exposure_by_factors(broadcast_data, analysis_dir):
    """Plot ``peers_eaten_so_far`` per step, split by each varying factor.

    The candidate factors are ``hedonic``, ``persona_age``,
    ``tool_use_policy`` and ``visible_question_budget``. A factor gets its
    own subplot when the column exists in ``broadcast_data`` and holds more
    than one distinct non-null value. Within a subplot, every non-null level
    is drawn as a line of per-step means with a mean +/- 1.96 * SEM band
    (lower edge clipped at 0). The figure is written to
    ``analysis_dir / "peer_exposure_by_factors.png"``.

    Args:
        broadcast_data: Frame with ``step``, ``peers_eaten_so_far`` and any
            of the candidate factor columns.
        analysis_dir: Directory the PNG is written into; must support the
            ``/`` operator (e.g. ``pathlib.Path``).

    Returns:
        None. Logs a warning and returns early when no candidate factor
        varies.
    """
    logger.info("Creating factor-specific peer exposure plots...")
    candidate_factors = (
        "hedonic",
        "persona_age",
        "tool_use_policy",
        "visible_question_budget",
    )
    factors_with_variation = [
        factor
        for factor in candidate_factors
        if factor in broadcast_data.columns and broadcast_data[factor].nunique() > 1
    ]
    if not factors_with_variation:
        logger.warning(
            "No factors with variation found for factor-specific peer exposure plots"
        )
        return

    n_factors = len(factors_with_variation)
    n_cols = min(2, n_factors)
    n_rows = (n_factors + n_cols - 1) // n_cols  # ceiling division

    # The tick range is the same for every subplot, so compute it once.
    steps = broadcast_data["step"]
    step_ticks = range(int(steps.min()), int(steps.max()) + 1)

    # squeeze=False always returns a 2-D array of axes, so a single flatten()
    # covers the 1x1, 1x2 and multi-row layouts alike.
    fig, axes_grid = plt.subplots(
        n_rows, n_cols, figsize=(8 * n_cols, 6 * n_rows), squeeze=False
    )
    try:
        axes = axes_grid.flatten()
        # Hide the trailing cell(s) of the grid that have no factor assigned.
        for spare_ax in axes[n_factors:]:
            spare_ax.set_visible(False)

        for ax, factor in zip(axes, factors_with_variation):
            # Levels are kept in order of first appearance; nulls are skipped.
            for level in broadcast_data[factor].dropna().unique():
                condition_data = broadcast_data[broadcast_data[factor] == level]
                if condition_data.empty:
                    continue
                stats = (
                    condition_data.groupby("step")["peers_eaten_so_far"]
                    .agg(["mean", "sem"])
                    .reset_index()
                )
                half_width = 1.96 * stats["sem"]
                (line,) = ax.plot(
                    stats["step"],
                    stats["mean"],
                    marker="o",
                    linewidth=2,
                    label=str(level),
                )
                ax.fill_between(
                    stats["step"],
                    (stats["mean"] - half_width).clip(lower=0),
                    stats["mean"] + half_width,
                    alpha=0.3,
                    color=line.get_color(),
                )
            ax.set_xlabel("Step")
            ax.set_ylabel("Mean Peers Eaten So Far")
            ax.set_title(
                f'Peer Exposure by {factor.replace("_", " ").title()}\n(Broadcast Mode, with 95% Confidence Intervals)'
            )
            ax.legend()
            ax.grid(True, alpha=0.3)
            ax.set_xticks(step_ticks)

        fig.tight_layout()
        fig.savefig(
            analysis_dir / "peer_exposure_by_factors.png",
            dpi=300,
            bbox_inches="tight",
        )
    finally:
        # Close this specific figure even if plotting or saving failed.
        plt.close(fig)
    logger.info("✓ Created factor-specific peer exposure plots")
