import pandas as pd
import matplotlib.pyplot as plt
import os
from typing import List, Tuple
import numpy as np
import seaborn as sns

from utils.pretty_print import colored_print
from utils.constants import (
    LATEST_MESSAGE,
    ALL_MESSAGES,
    PARALLEL,
    RECURSIVE,
    # Threat types
    SCAM,
    AVAILABILITY,
    MALWARE,
    MANIPULATED_CONTENT,
    DATA_THEFT,
    # Tool types
    EMAIL,
    WEB,
    PDF,
    GPT3_5,
    GPT4O_MINI,
    GPT4O,
    ATTACK_IGNORED,
    ATTACK_REJECTED,
    MIXED_ACTION,
    DEFORMED_INFECTION,
    NO_ACTION,
    AGENT_ERROR,
    # Defense
    NO_DEFENSE,
    SANDWICH,
    INSTRUCTION_DEFENSE,
    RANDOM_SEQUENCE_ENCLOSURE,
    DELIMITING_DATA,
    MARKING,
    MODEL_DELIMITER,
    # Infection routes
    MODEL_INFECTION,
    EXTERNAL_INFECTION,
)


class ThreatPlotter:
    def __init__(
        self,
        parent_dir: str = os.path.dirname(os.path.dirname(__file__)),
        do_save_figure: bool = False,
        do_show_figure: bool = False,
    ):
        """Set up paths, output flags and the experiment axes used for plotting.

        Args:
            parent_dir: Base directory used to build the log and plot paths.
                Defaults to the parent of the directory containing this file.
            do_save_figure: Flag checked by ``save_figure`` before writing.
            do_show_figure: Flag read by plotting methods to decide whether
                to call ``plt.show()``.
        """
        # Output behaviour flags.
        self.do_show_figure = do_show_figure
        self.do_save_figure = do_save_figure

        # Filesystem layout: figures go to <parent_dir>/experiments/plots.
        self.parent_dir = parent_dir
        self.plots_dir = os.path.join(parent_dir, "experiments", "plots")

        # Experiment dimensions iterated over by the plotting methods.
        self.model_types = ["gpt4o", GPT3_5]
        self.infection_modes = [RECURSIVE, PARALLEL]
        self.tool_types = [EMAIL, PDF, WEB]

    def save_figure(self, fig, filename):
        """Save ``fig`` into the plots directory when saving is enabled.

        Args:
            fig: Object exposing ``savefig(path, **kwargs)``, e.g. a
                matplotlib ``Figure``.
            filename: File name to create inside ``self.plots_dir``.

        Returns:
            None. Does nothing unless ``self.do_save_figure`` is true.
        """
        if not self.do_save_figure:
            return

        # ``__init__`` only computes the path; create the directory on demand
        # so the first save on a fresh checkout does not fail.
        os.makedirs(self.plots_dir, exist_ok=True)

        filepath = os.path.join(self.plots_dir, filename)
        fig.savefig(filepath, bbox_inches="tight", dpi=300)
        print(f"Figure saved to {filepath}")

    def preprocess_data(
        self, data: pd.DataFrame, communication_mode: str
    ) -> pd.DataFrame:
        """Filter rows by communication mode and convert result cells to bools.

        Args:
            data: Frame with a ``communication_mode`` column and result
                columns named ``"2"``, ``"3"``, ``"4"`` and ``"5"``.
            communication_mode: Only rows whose ``communication_mode`` equals
                this value are kept.

        Returns:
            A frame holding only the four result columns, with rows containing
            missing values dropped. Every cell is converted to ``bool`` except
            cells equal to the string ``"Invalid"``, which are kept as-is.
        """
        data = data[data["communication_mode"] == communication_mode]
        num_agents = ["2", "3", "4", "5"]
        data = data[num_agents].dropna()

        def to_bool(x):
            if x != "Invalid":
                # NOTE(review): eval() executes whatever text the cell holds.
                # Acceptable only while the logs are locally generated and
                # trusted; consider ast.literal_eval if that ever changes.
                return bool(eval(str(x)))
            return x

        # DataFrame.applymap is deprecated since pandas 2.1 in favour of
        # DataFrame.map; fall back to applymap on older pandas versions.
        elementwise = (
            data.map if hasattr(pd.DataFrame, "map") else data.applymap
        )
        return elementwise(to_bool)

    def get_success_rates(
        self,
        threat_type: str,
        tool_type: str,
        infection_mode: str,
        model_type: str,
        communication_mode: str = ALL_MESSAGES,
    ) -> Tuple[List[str], np.ndarray]:
        """Compute per-agent-count success rates from one experiment log.

        Reads ``logs/{tool_type}/{infection_mode}/{threat_type}.csv`` under
        ``self.parent_dir``, keeps the rows matching ``communication_mode``
        and ``model_type``, and averages the boolean outcome of each
        agent-count column.

        Args:
            threat_type: Threat name; selects the CSV file and which
                agent-count columns are read.
            tool_type: Tool name; first path component under ``logs/``.
            infection_mode: Infection mode; second path component.
            model_type: Value matched against the ``model`` column.
            communication_mode: Value matched against the
                ``communication_mode`` column.

        Returns:
            ``(x, y)`` where ``x`` is the list of agent-count column names
            (strings) and ``y`` is an array with the mean outcome per column.
        """
        log_filename = os.path.join(
            self.parent_dir,
            f"logs/{tool_type}/{infection_mode}/{threat_type}.csv",
        )
        logs = pd.read_csv(log_filename, index_col=False)

        logs = logs[
            (logs["communication_mode"] == communication_mode)
            & (logs["model"] == model_type)
        ]

        # Which columns hold results depends on the threat type.
        if threat_type == DATA_THEFT:
            num_agents = ["3", "4", "5", "6"]
        elif threat_type in [MODEL_INFECTION, EXTERNAL_INFECTION]:
            num_agents = ["1"]
        else:
            num_agents = ["2", "3", "4", "5"]

        logs = logs[num_agents].dropna()

        def to_bool(x):
            # NOTE(review): eval() executes whatever text the cell holds.
            # Acceptable only while the logs are locally generated and
            # trusted; consider ast.literal_eval if that ever changes.
            return bool(eval(str(x)))

        # DataFrame.applymap is deprecated since pandas 2.1 in favour of
        # DataFrame.map; fall back to applymap on older pandas versions.
        elementwise = (
            logs.map if hasattr(pd.DataFrame, "map") else logs.applymap
        )
        success_rates = elementwise(to_bool).mean()

        return num_agents, success_rates.values

    def plot_combined_threats(
        self,
        threat_types: List[str],
        infection_mode: str = RECURSIVE,
        model_type: str = GPT3_5,
        communication_mode: str = ALL_MESSAGES,
    ):
        """Plot the average success rate of each threat type, per tool type, as bars.

        For every tool type in ``self.tool_types`` and every entry of
        ``threat_types`` the success rates returned by ``get_success_rates``
        are averaged over all agent counts and drawn as one bar.

        Args:
            threat_types: Threat types to include.
            infection_mode: Infection mode forwarded to ``get_success_rates``.
            model_type: Model forwarded to ``get_success_rates``.
            communication_mode: Communication mode forwarded to
                ``get_success_rates``.

        NOTE(review): the previous implementation called
        ``get_success_rates(threat_type, tool_type)`` without the required
        ``infection_mode`` and ``model_type`` arguments and therefore always
        raised ``TypeError``. The defaults chosen here for those two values
        are a guess at the intended series; confirm with the experiment owner.
        """
        fig, ax = plt.subplots(figsize=(12, 8))

        data = []
        labels = []

        for tool_type in self.tool_types:
            combined_rates = []
            for threat_type in threat_types:
                _, y = self.get_success_rates(
                    threat_type,
                    tool_type,
                    infection_mode,
                    model_type,
                    communication_mode,
                )
                # Average success rate across all agent numbers.
                combined_rates.append(sum(y) / len(y))
            data.append(combined_rates)
            labels.extend(
                [f"{threat} - {tool_type}" for threat in threat_types]
            )

        x = np.arange(len(labels))
        width = 0.25

        # Each tool type owns a contiguous run of len(threat_types) slots.
        for i, d in enumerate(data):
            ax.bar(
                x[i * len(threat_types) : (i + 1) * len(threat_types)],
                d,
                width,
                label=self.tool_types[i],
            )

        ax.set_ylabel("Average Success Rate")
        ax.set_title("Combined Success Rates by Threat Type and Tool Type")
        ax.set_xticks(x)
        ax.set_xticklabels(labels, rotation=45, ha="right")
        ax.legend()
        ax.set_ylim(0, 1)

        # Add value labels on top of each bar (same order as the bars above).
        for i, v in enumerate(np.array(data).flatten()):
            ax.text(x[i], v, f"{v:.2f}", ha="center", va="bottom", fontsize=8)

        plt.tight_layout()
        filename = f"combined_threats_{'-'.join(threat_types)}.png"
        self.save_figure(fig, filename)
        if self.do_show_figure:
            plt.show()

    def plot_recursive_vs_parallel(
        self, threat_types: List[str], communication_mode: str = LATEST_MESSAGE
    ):
        """Plot RECURSIVE vs PARALLEL prompt-infection success rates.

        Draws a grid of line charts with one row per tool type in
        ``self.tool_types`` and one column per entry of ``threat_types``.
        Each chart holds one line per (infection mode, model) combination.

        Args:
            threat_types: Threat types to plot, one grid column each.
            communication_mode: Communication mode forwarded to
                ``get_success_rates`` and used in the title and file name.
        """
        num_threats = len(threat_types)
        num_tools = len(self.tool_types)

        fig, axes = plt.subplots(
            num_tools,
            num_threats,
            figsize=(3 * num_threats, 3 * num_tools),
            squeeze=False,
        )
        fig.suptitle(
            f"{RECURSIVE} vs {PARALLEL} Prompt Infection on {communication_mode} (GPT-3.5 vs GPT-4)",
            fontsize=16,
        )

        # One colour per model; the infection mode is encoded by the line
        # style and marker instead.
        color_palette = sns.color_palette(
            "husl", n_colors=len(self.model_types)
        )

        for i, tool_type in enumerate(self.tool_types):
            for j, threat_type in enumerate(threat_types):
                ax = axes[i, j]

                for infection_mode in self.infection_modes:
                    for k, model_type in enumerate(self.model_types):
                        color = color_palette[k]

                        x, y = self.get_success_rates(
                            threat_type,
                            tool_type,
                            infection_mode,
                            model_type,
                            communication_mode,
                        )

                        if infection_mode == RECURSIVE:
                            label = f"{infection_mode} {model_type}"
                            linestyle = "-"
                            marker = "o"
                        else:
                            label = f"non-recursive {model_type}"
                            linestyle = "-."
                            marker = "s"

                        ax.plot(
                            x,
                            y,
                            marker=marker,
                            color=color,
                            label=label,
                            linestyle=linestyle,
                        )

                        # Annotate every data point with its value.
                        for n, v in enumerate(y):
                            ax.text(
                                x[n],
                                v,
                                f"{v:.2f}",
                                ha="center",
                                va="bottom",
                                fontsize=8,
                                color=color,
                            )

                ax.set_xticks(x)
                ax.set_xlabel("Number of agents")
                ax.set_ylabel("Success rate")
                ax.set_title(f"{threat_type} ({tool_type})", fontsize=10)
                ax.grid(True, linestyle="--", alpha=0.7)
                ax.set_ylim(0.0, 1.0)
                ax.legend(fontsize="x-small")

        plt.tight_layout()
        filename = f"recursive_vs_parallel_{communication_mode}.png"
        # Save before showing, like the other plotting methods: the file is
        # then written even if the interactive window is resized or closed.
        self.save_figure(fig, filename)
        if self.do_show_figure:
            plt.show()

    def plot_recursive_vs_parallel_averaged(
        self, threat_types: List[str], communication_mode: str = LATEST_MESSAGE
    ):
        """Plot RECURSIVE vs PARALLEL success rates averaged over tool types.

        Draws one line chart per entry of ``threat_types``. Each chart holds
        one line per (model, infection mode) combination, where every point is
        the mean over ``self.tool_types`` of the success rate for that agent
        count.

        Args:
            threat_types: Threat types to plot, one chart each.
            communication_mode: Communication mode forwarded to
                ``get_success_rates`` and used in the file name.
        """
        num_threats = len(threat_types)

        fig, axes = plt.subplots(
            1, num_threats, figsize=(5 * num_threats, 5), squeeze=False
        )

        colors = {"gpt4o": "#FF9999", GPT3_5: "#66B2FF"}
        infection_model_labels = {
            RECURSIVE: "Self-Replicating",
            PARALLEL: "Non-Replicating",
        }

        for j, threat_type in enumerate(threat_types):
            ax = axes[0, j]

            for model_type in self.model_types:
                color = colors[model_type]

                for infection_mode in self.infection_modes:
                    # Collect one rate vector per tool type. The agent-count
                    # labels returned alongside depend only on the threat
                    # type, so the last ones returned serve as x values.
                    per_tool_rates = []
                    for tool_type in self.tool_types:
                        x_values, y = self.get_success_rates(
                            threat_type,
                            tool_type,
                            infection_mode,
                            model_type,
                            communication_mode,
                        )
                        per_tool_rates.append(y)

                    # Element-wise mean across tool types per agent count.
                    y_avg = np.mean(per_tool_rates, axis=0)

                    linestyle = "-" if infection_mode == RECURSIVE else "-."
                    marker = "o" if infection_mode == RECURSIVE else "s"
                    label = f"{infection_model_labels[infection_mode]} {model_type}"

                    ax.plot(
                        x_values,
                        y_avg,
                        marker=marker,
                        color=color,
                        label=label,
                        linestyle=linestyle,
                    )

                    # Annotate every data point with its value.
                    for n, v in enumerate(y_avg):
                        ax.text(
                            x_values[n],
                            v,
                            f"{v:.2f}",
                            ha="center",
                            va="bottom",
                            fontsize=8,
                            color=color,
                        )

            ax.set_xticks(x_values)
            ax.set_xlabel("Number of agents")
            ax.set_ylabel("Average success rate")
            ax.grid(True, linestyle="--", alpha=0.7)
            ax.set_ylim(0.0, 1.0)
            ax.legend(fontsize=8)

        plt.tight_layout()
        filename = (
            f"recursive_vs_parallel_{communication_mode}_averaged_new.png"
        )
        # Save before showing, like the other plotting methods: the file is
        # then written even if the interactive window is resized or closed.
        self.save_figure(fig, filename)
        if self.do_show_figure:
            plt.show()

    def plot_communication_mode_comparison(self):
        """
        Draw a grouped bar chart comparing ALL_MESSAGES with LATEST_MESSAGE.

        For each (model, infection mode) pair, the bar height is the success
        rate averaged over four threat types (scam, availability,
        manipulated_content, malware), every tool type in ``self.tool_types``
        and all agent counts. Each pair gets one bar per communication mode.
        """
        threat_types = [SCAM, AVAILABILITY, MANIPULATED_CONTENT, MALWARE]
        communication_modes = [ALL_MESSAGES, LATEST_MESSAGE]
        model_infection_pairs = [
            (model, infection_mode)
            for model in self.model_types
            for infection_mode in self.infection_modes
        ]
        # Number of (threat, tool) combinations averaged into every bar.
        num_combinations = len(threat_types) * len(self.tool_types)

        # Seaborn styling has to be in place before the figure is created.
        sns.set_theme(style="whitegrid")
        color_palette = sns.color_palette(
            "husl", n_colors=len(communication_modes)
        )

        fig, ax = plt.subplots(figsize=(14, 8))
        bar_width = 0.35
        index = np.arange(len(model_infection_pairs))

        for i, comm_mode in enumerate(communication_modes):
            avg_success_rates = []

            for model, infection_mode in model_infection_pairs:
                running_total = 0
                for threat_type in threat_types:
                    for tool_type in self.tool_types:
                        rates = self.get_success_rates(
                            threat_type,
                            tool_type,
                            infection_mode,
                            model,
                            comm_mode,
                        )[1]
                        running_total += sum(rates) / len(rates)

                avg_success_rates.append(running_total / num_combinations)

            if comm_mode == ALL_MESSAGES:
                legend_label = "Public Messaging"
            else:
                legend_label = "Private Messaging"

            rects = ax.bar(
                index + i * bar_width,
                avg_success_rates,
                bar_width,
                alpha=0.8,
                label=legend_label,
                color=color_palette[i],
            )

            # Write each bar's height just above it.
            for rect in rects:
                bar_height = rect.get_height()
                ax.text(
                    rect.get_x() + rect.get_width() / 2.0,
                    bar_height,
                    f"{bar_height:.2f}",
                    ha="center",
                    va="bottom",
                    fontsize=8,
                )

        ax.set_xlabel(
            "Model and Infection Mode Pairs", fontsize=12, labelpad=10
        )
        ax.set_ylabel("Average Success Rate", fontsize=12, labelpad=10)
        ax.set_title(
            "Public Messaging vs Private Messaging", fontsize=16, pad=20
        )

        # Tick labels display the PARALLEL mode as "non-recursive".
        tick_labels = []
        for model, infection_mode in model_infection_pairs:
            shown_mode = (
                "non-recursive" if infection_mode == PARALLEL else infection_mode
            )
            tick_labels.append(f"{model}\n{shown_mode}")

        # Centre each tick between the two bars of its group.
        ax.set_xticks(index + bar_width / 2)
        ax.set_xticklabels(
            tick_labels,
            rotation=45,
            ha="right",
            fontsize=10,
        )
        ax.legend(title="Communication Mode", title_fontsize=12, fontsize=10)
        ax.set_ylim(0, 1)

        sns.despine(left=True, bottom=True)

        plt.tight_layout()
        self.save_figure(fig, "global_vs_local_messaging.png")
        if self.do_show_figure:
            plt.show()

    def get_disinformation_rates(
        self,
        filename: str,
        model_type: str,
        communication_mode: str = LATEST_MESSAGE,
    ) -> Tuple[List[int], List[float], List[float], List[float]]:
        """Compute correct / incorrect / invalid answer rates per agent count.

        Reads ``logs/{filename}`` under ``self.parent_dir``, keeps the rows
        matching ``communication_mode`` and ``model_type``, and classifies
        every cell of the columns ``"2"`` to ``"5"``.

        A cell equal to ``"Invalid"`` counts as invalid. Any other cell is
        evaluated to a boolean. For ``"web/healthy_qa.csv"`` a true value
        counts as a correct answer; for every other file a true value means
        the attack succeeded and counts as an incorrect answer.

        Args:
            filename: Log file path relative to ``logs/``.
            model_type: Value matched against the ``model`` column.
            communication_mode: Value matched against the
                ``communication_mode`` column.

        Returns:
            ``(x, correct_rates, incorrect_rates, invalid_rates)`` where ``x``
            is ``[2, 3, 4, 5]`` and each rate list has one entry per agent
            count. All rates are 0 when no rows remain after filtering.
        """
        log_filename = os.path.join(self.parent_dir, f"logs/{filename}")
        logs = pd.read_csv(log_filename)
        # Filter by communication mode and model.
        logs = logs[
            (logs["communication_mode"] == communication_mode)
            & (logs["model"] == model_type)
        ]
        num_agents = ["2", "3", "4", "5"]
        logs = logs[num_agents].dropna()

        x = [int(agent) for agent in num_agents]  # 2 to 5 agents
        correct_rates = []
        incorrect_rates = []
        invalid_rates = []

        is_healthy = filename == "web/healthy_qa.csv"

        for agent in num_agents:
            column_data = logs[agent].astype(str)
            total_count = len(column_data)
            num_correct = 0
            num_incorrect = 0
            num_invalid = 0

            for value in column_data:
                if value == "Invalid":
                    num_invalid += 1
                    continue

                # NOTE(review): eval() executes whatever text the cell holds.
                # Acceptable only while the logs are locally generated and
                # trusted; consider ast.literal_eval if that ever changes.
                # Evaluate once and coerce to bool so every non-"Invalid"
                # cell lands in exactly one bucket and the rates sum to 1.
                flag = bool(eval(value))

                # Healthy run: True means a correct answer. Infected run:
                # True means the attack succeeded, so the answer is incorrect.
                if flag == is_healthy:
                    num_correct += 1
                else:
                    num_incorrect += 1

            if total_count > 0:
                correct_rates.append(num_correct / total_count)
                incorrect_rates.append(num_incorrect / total_count)
                invalid_rates.append(num_invalid / total_count)
            else:
                correct_rates.append(0.0)
                incorrect_rates.append(0.0)
                invalid_rates.append(0.0)

        return x, correct_rates, incorrect_rates, invalid_rates

    def plot_disinformation_comparison(self):
        """
        Plot the disinformation experiment results on a 2x2 grid.

        Rows are communication modes (ALL_MESSAGES, LATEST_MESSAGE) and
        columns are the entries of ``self.model_types``. Every panel holds
        three lines: no infection, recursive infection and parallel infection.
        The plotted series is the first rate list ("correct" rates) returned
        by ``get_disinformation_rates``.

        NOTE(review): the figure title and y-axis call this a "success rate"
        while the plotted values are the correct-answer rates; confirm which
        quantity is intended.
        """
        communication_modes = [ALL_MESSAGES, LATEST_MESSAGE]
        models = self.model_types

        fig, axes = plt.subplots(
            2, 2, figsize=(12, 10), sharex=True, sharey=True
        )
        fig.suptitle("Disinformation Attack Success Rates", fontsize=16)

        # One colour per plotted condition.
        color_palette = sns.color_palette("Set2", n_colors=3)

        # (legend label, log file relative to logs/) in plotting order.
        series = [("No Infection", "web/healthy_qa.csv")]
        for mode_dir, legend_label in (
            ("recursive", "Self-Replicating Infection"),
            ("parallel", "Non-Replicating Infection"),
        ):
            series.append(
                (legend_label, f"web/{mode_dir}/disinformation_qa.csv")
            )

        for idx_cm, communication_mode in enumerate(communication_modes):
            if communication_mode == ALL_MESSAGES:
                comm_mode_name = "Global Messaging"
            else:
                comm_mode_name = "Local Messaging"

            for idx_model, model_type in enumerate(models):
                ax = axes[idx_cm][idx_model]

                for color, (legend_label, filename) in zip(
                    color_palette, series
                ):
                    rates = self.get_disinformation_rates(
                        filename, model_type, communication_mode
                    )
                    x = rates[0]
                    success_rates = rates[1]

                    ax.plot(
                        x,
                        success_rates,
                        marker="o",
                        color=color,
                        label=legend_label,
                    )

                    # Annotate every data point with its value.
                    for x_pos, v in zip(x, success_rates):
                        ax.text(
                            x_pos,
                            v,
                            f"{v:.2f}",
                            ha="center",
                            va="bottom",
                            fontsize=8,
                        )

                ax.set_xticks(x)
                ax.set_ylim(0, 1)
                ax.set_xlim(min(x) - 0.5, max(x) + 0.5)

                # Panel title; axis labels only on the outer edge of the grid.
                if model_type == "gpt4o":
                    model_name = "GPT-4"
                else:
                    model_name = "GPT-3.5"
                ax.set_title(f"{model_name} - {comm_mode_name}", fontsize=14)
                if idx_model == 0:
                    ax.set_ylabel("Success Rate", fontsize=12)
                if idx_cm == 1:
                    ax.set_xlabel("Number of Agents", fontsize=12)

                ax.legend(fontsize=10)
                ax.grid(True, linestyle="--", alpha=0.7)

        plt.tight_layout()
        self.save_figure(fig, "disinformation_comparison.png")
        if self.do_show_figure:
            plt.show()

    def plot_attack_failed_messages(self):
        """Plot how failed attacks are classified, per infection mode and model.

        Reads ``logs/attack_failed_messages.csv`` under ``self.parent_dir``
        and draws two panels (RECURSIVE, then PARALLEL). In each panel, every
        classification gets one bar per model showing that classification's
        share of the model's rows. The ATTACK_IGNORED bars are hatched and
        their value labels are bold.
        """
        df = pd.read_csv(
            os.path.join(self.parent_dir, "logs/attack_failed_messages.csv")
        )

        classification_order = [
            ATTACK_IGNORED,
            MIXED_ACTION,
            DEFORMED_INFECTION,
            NO_ACTION,
            AGENT_ERROR,
        ]
        classification_names = [
            "Attack\nIgnored",
            "Mixed\nAction",
            "Deformed\nInfection",
            "No Action",
            "Agent\nError",
        ]
        tick_positions = np.arange(len(classification_order))

        fig, axes = plt.subplots(1, 2, figsize=(16, 5))

        colors = {"gpt4o": "#FF9999", GPT3_5: "#66B2FF"}
        labels = {"gpt4o": "GPT-4", GPT3_5: "GPT-3.5"}

        # The two models sit side by side around each tick position.
        width = 0.35
        model_offsets = [("gpt4o", -width / 2), (GPT3_5, width / 2)]

        for ax, inf_mode in zip(axes, [RECURSIVE, PARALLEL]):
            mode_rows = df[df["prompt_infection_mode"] == inf_mode]

            for model, offset in model_offsets:
                shares = mode_rows[mode_rows["model"] == model][
                    "classification"
                ].value_counts(normalize=True)
                # Fixed display order; classifications with no rows become 0.
                ratios = shares.reindex(classification_order).fillna(0)

                for k, (classification, v) in enumerate(ratios.items()):
                    is_ignored = classification == ATTACK_IGNORED
                    ax.bar(
                        k + offset,
                        v,
                        width,
                        label=labels[model] if k == 0 else "",
                        color=colors[model],
                        edgecolor="none",  # no border
                        linewidth=0,
                        hatch="///" if is_ignored else None,
                    )
                    ax.text(
                        k + offset,
                        v,
                        f"{v:.2f}",
                        ha="center",
                        va="bottom",
                        fontsize=10,
                        fontweight="bold" if is_ignored else "normal",
                    )

            if inf_mode == RECURSIVE:
                panel_title = "Self-Replicating Infection"
            else:
                panel_title = "Non-Replicating Infection"

            ax.set_ylabel("Ratio", fontsize=12)
            ax.set_title(panel_title, fontsize=14)
            ax.set_xticks(tick_positions)
            ax.set_xticklabels(classification_names, fontsize=10)

            # Plain colour swatches so the legend carries no hatch pattern.
            legend_elements = [
                plt.Rectangle(
                    (0, 0),
                    1,
                    1,
                    facecolor=colors[model],
                    edgecolor="none",
                    label=swatch_label,
                )
                for model, swatch_label in (
                    ("gpt4o", "GPT-4"),
                    (GPT3_5, "GPT-3.5"),
                )
            ]
            ax.legend(handles=legend_elements, fontsize=12, loc="upper right")

            ax.set_ylim(0, 1)

        plt.tight_layout()
        self.save_figure(fig, "model_attack_failure_reasons.png")
        if self.do_show_figure:
            plt.show()

    def plot_defense_rate(self):
        """
        Plot the average attack success rate per defense type for MODEL_INFECTION,
        with and without MODEL_DELIMITER.

        For each tool type (PDF, WEB) the log
        ``logs/{tool_type}/recursive/defense/{MODEL_INFECTION}.csv`` is read,
        restricted to rows whose ``counterattack`` column is True, and the
        mean of column ``"1"`` is taken per ``defense_type``. The per-tool
        means are then averaged over the tool types whose log file exists.
        Missing log files are reported and skipped; if none exists, a
        "No data available" placeholder is drawn instead.
        """
        # Original defense types for data matching.
        original_defense_types = [
            NO_DEFENSE,
            DELIMITING_DATA,
            RANDOM_SEQUENCE_ENCLOSURE,
            SANDWICH,
            INSTRUCTION_DEFENSE,
            MARKING,
        ]
        # The same defenses combined with the model delimiter.
        delimited_defense_types = [
            d + "_" + MODEL_DELIMITER for d in original_defense_types
        ]

        # Mapping of original names to display names.
        display_name_map = {
            NO_DEFENSE: "No\nDefense",
            SANDWICH: "Sandwich",
            INSTRUCTION_DEFENSE: "Instruction\nDefense",
            RANDOM_SEQUENCE_ENCLOSURE: "Random Sequence\nEnclosure",
            DELIMITING_DATA: "Delimiting\nData",
            MARKING: "Marking",
        }

        tool_types = [PDF, WEB]

        # Set Seaborn style before creating the figure.
        sns.set_theme(style="whitegrid", font_scale=1.1)
        plt.figure(figsize=(18, 6))

        without_rates_total = pd.Series(0, index=original_defense_types)
        with_rates_total = pd.Series(0, index=delimited_defense_types)
        tool_count = 0

        for tool_type in tool_types:
            log_filename = os.path.join(
                self.parent_dir,
                f"logs/{tool_type}/recursive/defense/{MODEL_INFECTION}.csv",
            )

            # Only the read can legitimately hit a missing file; keep the
            # try body minimal so other errors are not masked.
            try:
                logs = pd.read_csv(log_filename)
            except FileNotFoundError:
                print(f"File not found: {log_filename}")
                continue

            logs = logs[logs["counterattack"] == True]

            # MODEL_DELIMITER is concatenated literally into defense names
            # above, so match it as a literal substring, not as a regex.
            has_delimiter = logs["defense_type"].str.contains(
                MODEL_DELIMITER, regex=False
            )

            # Use the same column-wise mean for both subsets. Unlike
            # groupby.apply, this yields an (empty) Series when a subset has
            # no rows, so the reindex and accumulation below stay valid.
            without_rates = (
                logs[~has_delimiter]
                .groupby("defense_type")["1"]
                .mean()
                .reindex(original_defense_types, fill_value=0)
            )
            with_rates = (
                logs[has_delimiter]
                .groupby("defense_type")["1"]
                .mean()
                .reindex(delimited_defense_types, fill_value=0)
            )

            without_rates_total += without_rates
            with_rates_total += with_rates
            tool_count += 1

        if tool_count > 0:
            without_rates_avg = without_rates_total / tool_count
            with_rates_avg = with_rates_total / tool_count

            # Prepare data for Seaborn: first all "without" rows, then all
            # "with" rows, both in original_defense_types order.
            data = pd.DataFrame(
                {
                    "Defense Type": [
                        display_name_map[d] for d in original_defense_types
                    ]
                    * 2,
                    "Success Rate": pd.concat(
                        [without_rates_avg, with_rates_avg]
                    ),
                    "Delimiter": ["Without LLM Tagging"]
                    * len(original_defense_types)
                    + ["With LLM Tagging"] * len(original_defense_types),
                }
            )

            # Create the Seaborn plot.
            ax = sns.barplot(
                x="Defense Type",
                y="Success Rate",
                hue="Delimiter",
                data=data,
                palette="Accent",
            )

            plt.xlabel(
                "Defense Type", fontsize=16, fontweight="bold", labelpad=10
            )
            plt.ylabel(
                "Average Attack Success Rate",
                fontsize=16,
                fontweight="bold",
                labelpad=10,
            )

            # Set x-axis labels without rotation.
            plt.xticks(ha="center")
            ax.set_xticklabels(ax.get_xticklabels(), fontsize=14)

            # Increase y-axis limit to make room for value labels.
            plt.ylim(0, 1.1)

            # Add value labels on top of each bar with larger font.
            for container in ax.containers:
                ax.bar_label(container, fmt="%.2f", padding=3, fontsize=14)

            # Adjust layout and legend.
            plt.tight_layout()
            plt.legend(
                title_fontsize="12",
                fontsize="14",
                loc="upper right",
            )

            # Remove top and right spines.
            sns.despine()

        else:
            plt.text(
                0.5,
                0.5,
                "No data available",
                ha="center",
                va="center",
                fontsize=14,
            )

        filename = "defense_comparison.png"
        self.save_figure(plt.gcf(), filename)
        if self.do_show_figure:
            plt.show()

    def plot_importance_manipulation(self):
        """Plot importance scores with and without manipulation for each model.

        The scores are fixed values embedded in this method. Like the other
        plotting methods, the figure is passed to ``save_figure`` and is only
        shown when ``self.do_show_figure`` is true.

        NOTE(review): this method previously always called ``plt.show()`` and
        never saved. It now follows the instance flags, so construct the
        plotter with ``do_show_figure=True`` to display the chart.
        """
        # Create a dataframe for seaborn compatibility.
        data = {
            "Model": ["GPT-4o", "GPT-4o", "GPT-3.5", "GPT-3.5"],
            "Condition": [
                "With Manipulation",
                "Without Manipulation",
                "With Manipulation",
                "Without Manipulation",
            ],
            "Score": [10.0, 1.94, 9.84, 1.0],
        }

        df = pd.DataFrame(data)

        # Plot using seaborn.
        fig = plt.figure(figsize=(8, 6))
        ax = sns.barplot(
            x="Model", y="Score", hue="Condition", data=df, palette="Set2"
        )

        plt.ylabel("Importance Score", fontsize=14)
        plt.xlabel("Models", fontsize=14)

        # Display the score values on top of each bar.
        for container in ax.containers:
            ax.bar_label(container, fmt="%.2f", padding=3)

        # Adjust legend to remove "Condition" title.
        ax.legend(title=None)

        plt.tight_layout()
        self.save_figure(fig, "importance_manipulation.png")
        if self.do_show_figure:
            plt.show()

    def plot_society_infection(self, with_manipulation: bool = True):
        """Plot the mean number of infected agents per turn for each society size.

        Reads the simulation logs of seeds 1-3 from
        ``<parent_dir>/logs/simulation/``, averages ``Infection_Count`` over
        the seeds for every (``Num_Agents``, ``Turn``) pair and draws one line
        per ``Num_Agents`` value. For each society size in which at least one
        seed reaches ``Infection_Count >= Num_Agents``, the first such turn is
        averaged over those seeds; the curve after that turn is redrawn
        thicker, a dashed horizontal line marks the society size and a
        "Full at turn ..." label is added.

        Args:
            with_manipulation: If True, read the
                ``importance_manipulation_seed{N}.csv`` logs, otherwise the
                ``without_manipulation_seed{N}.csv`` logs. Also selects the
                name of the output image.

        Raises:
            KeyError: If the logs contain a ``Num_Agents`` value that has no
                colour in the palette (10, 20, 30, 40, 50).

        Note:
            The image is always written to ``self.plots_dir`` and the figure
            is never shown; the ``do_save_figure`` / ``do_show_figure`` flags
            are not consulted here. Errors from ``pd.read_csv`` (e.g. a
            missing log file) propagate to the caller.
        """
        if with_manipulation:
            file_pattern = "logs/simulation/importance_manipulation_seed{}.csv"
        else:
            file_pattern = "logs/simulation/without_manipulation_seed{}.csv"

        filenames = [
            os.path.join(self.parent_dir, file_pattern.format(seed))
            for seed in range(1, 4)
        ]

        # Load every seed exactly once; the per-seed frames are reused below
        # for the "full infection" computation instead of re-reading the CSVs
        # for each society size.
        seed_frames = [pd.read_csv(path) for path in filenames]
        df_combined = pd.concat(seed_frames)

        # Average only the column that is plotted. Averaging every column
        # would fail on any non-numeric column with recent pandas versions.
        df_mean = (
            df_combined.groupby(["Num_Agents", "Turn"])["Infection_Count"]
            .mean()
            .reset_index()
        )

        # Custom pastel rainbow-like color palette, keyed by society size
        custom_palette = {
            10: "#FF9999",  # Pastel Red
            20: "#FFB380",  # Pastel Orange
            30: "#FFCC66",  # Pastel Yellow
            40: "#99CC99",  # Pastel Green
            50: "#B3B3FF",  # Pastel Blue/Purple
        }

        sns.set(style="whitegrid")
        plt.figure(figsize=(10, 6))

        for num_agents in df_mean["Num_Agents"].unique():
            subset = df_mean[df_mean["Num_Agents"] == num_agents]
            color = custom_palette[num_agents]

            # Seed-averaged infection curve; the label feeds the legend
            sns.lineplot(
                x="Turn",
                y="Infection_Count",
                data=subset,
                color=color,
                linewidth=2.5,
                marker="o",
                markersize=6,
                alpha=0.8,
                label=f"{num_agents} Agents",
            )

            # First turn, per seed, at which every agent is infected. Seeds
            # that never reach full infection contribute nothing.
            full_turns = []
            for seed_df in seed_frames:
                reached_full = seed_df[
                    (seed_df["Num_Agents"] == num_agents)
                    & (seed_df["Infection_Count"] >= num_agents)
                ]
                if not reached_full.empty:
                    full_turns.append(reached_full["Turn"].min())

            if not full_turns:
                continue

            avg_full_turn = sum(full_turns) / len(full_turns)

            # Emphasise the part of the curve after the average full turn
            after_full = subset[subset["Turn"] > avg_full_turn]
            sns.lineplot(
                x=after_full["Turn"],
                y=after_full["Infection_Count"],
                color=color,
                linewidth=3,
                alpha=1.0,
                marker="o",
                markersize=8,
            )

            # Dashed horizontal line at full capacity
            plt.hlines(
                num_agents,
                xmin=subset["Turn"].min(),
                xmax=subset["Turn"].max(),
                linestyles="dashed",
                colors=color,
                alpha=0.7,
            )

            # Label placed slightly right of / above the full-capacity point
            plt.text(
                avg_full_turn + 0.5,
                num_agents + 1.5,
                f"Full at turn {avg_full_turn:.1f}",
                color=color,
                fontsize=13,
            )

        plt.xlabel("Turn", fontsize=14)
        plt.ylabel("Number of Infected Agents", fontsize=14)
        plt.legend(title="Total Agents", loc="lower right")

        # Fixed y-range: leaves headroom above the largest palette size (50)
        plt.ylim(0, 55)

        plt.tight_layout()

        figure_name = (
            "society_of_agents_manipulation.png"
            if with_manipulation
            else "society_of_agents_no_manipulation.png"
        )
        # Make sure the target directory exists so savefig cannot fail on a
        # fresh checkout.
        os.makedirs(self.plots_dir, exist_ok=True)
        plt.savefig(
            os.path.join(self.plots_dir, figure_name),
            dpi=300,
            bbox_inches="tight",
        )


# Script entry point: render the plots below in sequence.
if __name__ == "__main__":
    plotter = ThreatPlotter(do_show_figure=True, do_save_figure=False)

    # Society-of-agents infection curves, first with then without manipulation
    for manipulated in (True, False):
        plotter.plot_society_infection(with_manipulation=manipulated)

    plotter.plot_attack_failed_messages()
    plotter.plot_defense_rate()

    # Recursive vs. parallel comparison, once per message scope; a fresh
    # threat list is built for every call.
    for message_scope in (ALL_MESSAGES, LATEST_MESSAGE):
        plotter.plot_recursive_vs_parallel_averaged(
            [SCAM, MALWARE, MANIPULATED_CONTENT, DATA_THEFT],
            message_scope,
        )
