import pandas as pd
import matplotlib.pyplot as plt
import logging
from pathlib import Path

logger = logging.getLogger(__name__)


def create_tool_usage_plots(
    step_level_data: pd.DataFrame, analysis_dir: Path, social_influence: bool
):
    """Plot per-step tool usage and save it as ``tool_usage_patterns.png``.

    Rows with ``is_epilogue_step == 0`` are kept; every statistic is then
    computed only over rows with ``at_risk_t == 1``. All bands are
    ``mean ± 1.96 * SEM``.

    Args:
        step_level_data: Step-level frame. Reads ``is_epilogue_step``,
            ``at_risk_t``, ``step`` and ``used_this_step``. Also reads
            ``communication_type`` when ``social_influence`` is true, and
            ``tool_used_flag_t`` / ``cap_exhausted_t`` when it is false.
        analysis_dir: Directory the PNG files are written into.
        social_influence: If true, draw one line per communication type on a
            single axis. Otherwise draw a two-panel overview: mean questions
            used, and the proportions of steps that used tools / hit the cap.

    After saving, this also calls ``_create_tool_usage_by_policy_plot`` with the
    same filtered rows. In social-influence mode, if ``communication_type`` is
    missing or has fewer than two non-null values, a warning is logged and
    nothing is written (the by-policy plot is skipped as well).
    """
    logger.info("Creating tool usage plots...")
    risk_step_data = step_level_data[step_level_data["is_epilogue_step"] == 0].copy()
    if social_influence:
        fig = _plot_social_influence_usage(risk_step_data)
        if fig is None:
            return
    else:
        fig = _plot_overall_usage(risk_step_data)
    # Save/close the figure we built rather than pyplot's implicit "current" one.
    fig.tight_layout()
    fig.savefig(analysis_dir / "tool_usage_patterns.png", dpi=300, bbox_inches="tight")
    plt.close(fig)
    logger.info("✓ Created tool usage plots")
    _create_tool_usage_by_policy_plot(risk_step_data, analysis_dir)


def _step_mean_ci(frame, column, lower_bound=0.0, upper_bound=None):
    """Aggregate ``column`` per ``step`` into mean, sem, count and a 95% band.

    Returns a frame with columns ``step, mean, sem, count, ci_lower, ci_upper``,
    where the band is ``mean ± 1.96 * sem``. Both band edges are clipped to
    ``[lower_bound, upper_bound]``; ``None`` leaves that side unbounded. NaN
    edges (e.g. when ``sem`` is NaN) stay NaN.
    """
    stats = frame.groupby("step")[column].agg(["mean", "sem", "count"]).reset_index()
    half_width = 1.96 * stats["sem"]
    stats["ci_lower"] = (stats["mean"] - half_width).clip(
        lower=lower_bound, upper=upper_bound
    )
    stats["ci_upper"] = (stats["mean"] + half_width).clip(
        lower=lower_bound, upper=upper_bound
    )
    return stats


def _format_step_axis(ax, steps):
    """Tick every integer step spanned by ``steps`` and draw the step-10 marker.

    Tick setting is skipped when ``steps`` has no non-null values, because
    ``int(NaN)`` would raise. The marker's "Step 10" label only appears in a
    legend created after this call.
    """
    valid_steps = steps.dropna()
    if not valid_steps.empty:
        ax.set_xticks(range(int(valid_steps.min()), int(valid_steps.max()) + 1))
    ax.axvline(x=10, color="r", linestyle="--", linewidth=1, label="Step 10")


def _plot_social_influence_usage(risk_step_data):
    """Draw at-risk tool usage per communication type on one axis.

    Returns the figure, or ``None`` after logging a warning when
    ``communication_type`` is missing or has fewer than two non-null values.
    """
    if "communication_type" not in risk_step_data.columns:
        logger.warning(
            "communication_type not found - cannot create social influence plots"
        )
        return None
    comm_types = [
        ct for ct in risk_step_data["communication_type"].unique() if pd.notna(ct)
    ]
    if len(comm_types) < 2:
        logger.warning(
            "Need both broadcast and isolated for social influence analysis"
        )
        return None
    at_risk = risk_step_data[risk_step_data["at_risk_t"] == 1]
    fig, ax = plt.subplots(1, 1, figsize=(10, 6))
    for comm_type in comm_types:
        subset = at_risk[at_risk["communication_type"] == comm_type]
        if subset.empty:
            continue
        usage = _step_mean_ci(subset, "used_this_step")
        line = ax.plot(
            usage["step"],
            usage["mean"],
            marker="o",
            linewidth=3,
            markersize=8,
            # str() so non-string labels (e.g. numeric codes) don't crash .title().
            label=f"{str(comm_type).title()} Mode",
        )[0]
        ax.fill_between(
            usage["step"],
            usage["ci_lower"],
            usage["ci_upper"],
            alpha=0.3,
            color=line.get_color(),
        )
    ax.set_xlabel("Step")
    ax.set_ylabel("Mean Questions Used")
    ax.set_title(
        "Social Influence: Tool Usage (Broadcast vs Isolated)\n(with 95% Confidence Intervals)"
    )
    ax.legend(fontsize=12)
    ax.grid(True, alpha=0.3)
    _format_step_axis(ax, risk_step_data["step"])
    return fig


def _plot_overall_usage(risk_step_data):
    """Draw the two-panel overview of at-risk tool usage and return the figure.

    The left panel shows mean ``used_this_step`` per step. The right panel
    shows per-step proportions of ``tool_used_flag_t`` and ``cap_exhausted_t``,
    with bands clipped to [0, 1].
    """
    at_risk = risk_step_data[risk_step_data["at_risk_t"] == 1]
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))

    usage = _step_mean_ci(at_risk, "used_this_step")
    line1 = ax1.plot(
        usage["step"],
        usage["mean"],
        marker="o",
        linewidth=2,
        color="green",
    )[0]
    ax1.fill_between(
        usage["step"],
        usage["ci_lower"],
        usage["ci_upper"],
        alpha=0.3,
        color=line1.get_color(),
    )
    ax1.set_xlabel("Step")
    ax1.set_ylabel("Mean Questions Used")
    ax1.set_title("Tool Usage Pacing Over Time\n(with 95% Confidence Intervals)")
    ax1.grid(True, alpha=0.3)
    _format_step_axis(ax1, risk_step_data["step"])

    # (column, marker, legend label, colour) for each proportion series.
    proportion_series = (
        ("tool_used_flag_t", "o", "Used Tools", "blue"),
        ("cap_exhausted_t", "s", "Hit Budget Cap", "red"),
    )
    for column, marker, label, color in proportion_series:
        prop = _step_mean_ci(at_risk, column, lower_bound=0.0, upper_bound=1.0)
        line = ax2.plot(
            prop["step"],
            prop["mean"],
            marker=marker,
            linewidth=2,
            label=label,
            color=color,
        )[0]
        ax2.fill_between(
            prop["step"],
            prop["ci_lower"],
            prop["ci_upper"],
            alpha=0.3,
            color=line.get_color(),
        )
    ax2.set_xlabel("Step")
    ax2.set_ylabel("Proportion")
    ax2.set_title("Tool Usage Patterns\n(with 95% Confidence Intervals)")
    ax2.legend()
    ax2.grid(True, alpha=0.3)
    ax2.set_ylim(0, 1.05)
    _format_step_axis(ax2, risk_step_data["step"])
    return fig


def _create_tool_usage_by_policy_plot(risk_step_data: pd.DataFrame, analysis_dir: Path):
    """Plot mean per-step tool usage split by tool-use policy.

    Only rows with ``at_risk_t == 1`` contribute. Each policy gets one colour.
    When a ``communication_type`` column is present, each policy is split into
    a solid "broadcast" line and a dashed "isolated" line sharing that colour.
    Without that column, one pooled solid line is drawn per policy. Bands are
    ``mean ± 1.96 * SEM`` with the lower edge clipped at 0. The figure is
    written to ``analysis_dir / "tool_usage_by_policy.png"``.

    Logs a warning and writes nothing when ``tool_use_policy`` is missing or
    has fewer than two non-null values.
    """
    logger.info("Creating tool usage by policy plot...")
    if "tool_use_policy" not in risk_step_data.columns:
        logger.warning("tool_use_policy not found - cannot create by-policy plot")
        return
    # Missing policy labels cannot be compared and would break .title() below.
    policy_types = [
        pt for pt in risk_step_data["tool_use_policy"].unique() if pd.notna(pt)
    ]
    if len(policy_types) < 2:
        logger.warning("Need both must and may policies for by-policy plot")
        return
    if "communication_type" in risk_step_data.columns:
        # Computed once instead of once per (policy, comm type) iteration.
        present_comm = set(risk_step_data["communication_type"].dropna().unique())
        comm_types = [ct for ct in ("broadcast", "isolated") if ct in present_comm]
    else:
        # Non-social runs may lack this column entirely; pool all rows per policy.
        comm_types = [None]
    at_risk = risk_step_data[risk_step_data["at_risk_t"] == 1]
    fig, ax = plt.subplots(1, 1, figsize=(10, 6))
    for policy_type in policy_types:
        policy_rows = at_risk[at_risk["tool_use_policy"] == policy_type]
        policy_name = str(policy_type).title()
        line_color = None  # set by the policy's first line, reused by the rest
        for comm_type in comm_types:
            if comm_type is None:
                subset = policy_rows
                label = policy_name
            else:
                subset = policy_rows[policy_rows["communication_type"] == comm_type]
                label = f"{policy_name} ({comm_type})"
            if subset.empty:
                continue
            tool_usage = (
                subset.groupby("step")["used_this_step"]
                .agg(["mean", "sem", "count"])
                .reset_index()
            )
            half_width = 1.96 * tool_usage["sem"]
            ci_lower = (tool_usage["mean"] - half_width).clip(lower=0)
            ci_upper = tool_usage["mean"] + half_width
            plot_kwargs = {
                "marker": "o",
                "linewidth": 2,
                "linestyle": "--" if comm_type == "isolated" else "-",
                "label": label,
            }
            if line_color is not None:
                plot_kwargs["color"] = line_color
            line = ax.plot(tool_usage["step"], tool_usage["mean"], **plot_kwargs)[0]
            line_color = line.get_color()
            ax.fill_between(
                tool_usage["step"],
                ci_lower,
                ci_upper,
                alpha=0.2,
                color=line_color,
            )
    ax.set_xlabel("Step")
    ax.set_ylabel("Mean Questions Used")
    ax.set_title("Tool Usage by Policy: Must vs May\n(with 95% Confidence Intervals)")
    ax.legend()
    ax.grid(True, alpha=0.3)
    # Drawn after the legend, so the "Step 10" label is not listed in it.
    ax.axvline(x=10, color="r", linestyle="--", linewidth=1, label="Step 10")
    # Save/close the figure we built rather than pyplot's implicit "current" one.
    fig.tight_layout()
    fig.savefig(analysis_dir / "tool_usage_by_policy.png", dpi=300, bbox_inches="tight")
    plt.close(fig)
    logger.info("✓ Created tool usage by policy plot")
