import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import numpy as np
from typing import List
from logging import getLogger
from ..ablation_plotter import _calculate_kaplan_meier_ablation


def _draw_survival_curve(ax, agents_df: pd.DataFrame, color, label, linestyle=None):
    """Draw one survival step curve with its confidence band on ``ax``.

    Args:
        ax: Matplotlib axes to draw on.
        agents_df: Agent rows for a single condition; passed unchanged to
            ``_calculate_kaplan_meier_ablation``.
        color: Colour used for both the step line and the shaded band.
        label: Legend label for the step line.
        linestyle: Optional line style; when ``None`` matplotlib's default
            is used.
    """
    survival_curve = _calculate_kaplan_meier_ablation(agents_df)
    # Only forward linestyle when it was requested so the default is untouched.
    style_kwargs = {} if linestyle is None else {"linestyle": linestyle}
    ax.step(
        survival_curve["step"],
        survival_curve["survival"],
        where="post",
        color=color,
        label=label,
        linewidth=2,
        **style_kwargs,
    )
    ax.fill_between(
        survival_curve["step"],
        survival_curve["ci_lower"],
        survival_curve["ci_upper"],
        alpha=0.2,
        color=color,
        step="post",
    )


def create_suite_survival_plot(
    agent_outcomes_df: pd.DataFrame,
    analysis_dir: Path,
    social_influence: bool,
    survival_prob_range: List[float],
) -> None:
    """Plot survival curves per ``experiment_id`` and save them as a PNG.

    One curve (plus shaded confidence band) is drawn for every distinct
    ``experiment_id`` among rows with ``is_valid_outcome == 1``. When
    ``social_influence`` is true, each experiment is further split by
    ``communication_type``: ``"broadcast"`` is drawn solid, every other
    type dashed. The figure is written to
    ``analysis_dir / "suite_survival_comparison.png"``.

    Args:
        agent_outcomes_df: Agent outcomes; must contain ``experiment_id``
            and ``is_valid_outcome`` (and ``communication_type`` when
            ``social_influence`` is true).
        analysis_dir: Directory the PNG is saved into.
        social_influence: Whether to split curves by communication type.
        survival_prob_range: Sequence whose first element is the lower
            bound of the y-axis (and the first y tick).

    Returns:
        None. If a required column is missing, an error is logged and no
        plot is produced.
    """
    logger.info("Creating experiment suite survival comparison plot by model...")
    grouping_factor = "experiment_id"

    # Validate every column the plot relies on before any figure is created.
    required_columns = [grouping_factor, "is_valid_outcome"]
    if social_influence:
        required_columns.append("communication_type")
    for column in required_columns:
        if column not in agent_outcomes_df.columns:
            logger.error(f"'{column}' not found. Skipping plot.")
            return

    # Colours are indexed over all experiments (valid or not) so that an
    # experiment keeps its colour regardless of which rows are filtered out.
    unique_models = sorted(agent_outcomes_df[grouping_factor].unique())
    palette = sns.color_palette("husl", n_colors=len(unique_models))
    comm_types = (
        agent_outcomes_df["communication_type"].unique() if social_influence else ()
    )
    # The validity filter does not depend on the loop variables; apply it once.
    valid_agents = agent_outcomes_df[agent_outcomes_df["is_valid_outcome"] == 1]

    plot_filename = "suite_survival_comparison.png"
    fig = plt.figure(figsize=(20, 10))
    try:
        ax = fig.gca()
        for color, model_name in zip(palette, unique_models):
            model_agents = valid_agents[valid_agents[grouping_factor] == model_name]
            if social_influence:
                for comm_type in comm_types:
                    condition_agents = model_agents[
                        model_agents["communication_type"] == comm_type
                    ].copy()
                    if condition_agents.empty:
                        continue
                    _draw_survival_curve(
                        ax,
                        condition_agents,
                        color,
                        label=f"{model_name} ({str(comm_type).title()})",
                        linestyle="-" if comm_type == "broadcast" else "--",
                    )
            else:
                condition_agents = model_agents.copy()
                if condition_agents.empty:
                    continue
                _draw_survival_curve(ax, condition_agents, color, label=model_name)

        ax.set_title(
            "Survival Probability by Model (with 95% Confidence Intervals)",
            fontsize=20,
        )
        ax.set_xlabel("Step", fontsize=14)
        ax.set_ylabel("Survival Probability", fontsize=14)
        ax.tick_params(axis="both", which="major", labelsize=12)
        legend = ax.legend(
            title="Model", bbox_to_anchor=(1.05, 1), loc="upper left", fontsize=12
        )
        legend.get_title().set_fontsize(14)

        y_min = survival_prob_range[0]
        ax.set_ylim(y_min, 1.05)
        ax.set_yticks(np.arange(y_min, 1.01, 0.05))

        # One x tick per step up to the largest step of any plotted curve;
        # curves without data points are skipped instead of breaking max().
        max_step = 0
        for line in ax.get_lines():
            x_values = line.get_xdata(orig=False)
            if len(x_values):
                max_step = max(max_step, max(x_values))
        ax.set_xticks(np.arange(0, max_step + 1, 1))
        plt.setp(ax.get_xticklabels(), rotation=45)

        # Leave the right 15% of the canvas free for the outside legend.
        fig.tight_layout(rect=[0, 0, 0.85, 1])
        fig.savefig(analysis_dir / plot_filename, dpi=300)
    finally:
        # Always release the figure, even if plotting or saving failed.
        plt.close(fig)
    logger.info(f"✓ Created enhanced experiment suite survival plot: {plot_filename}")


from ..kaplan_meier_plotter import create_kaplan_meier_plots

# Module-level logger used by create_suite_survival_plot. The function looks
# the name up at call time, so defining it below the function still works.
logger = getLogger(__name__)
