from pathlib import Path

import pandas as pd
import seaborn as sns
from matplotlib.figure import Figure

from test_case import Method

# matplotlib.rcParams.update({"font.size": 7})


def analyze_results(
    variance: str = "Variance.LARGE",
    csv_path: str = "assets/test_case_contest_entries.csv",
    base_path: str = "../docs/manuscript/generated/",
    output: str = "kl_boxplot_large_variance.pdf",
) -> pd.DataFrame:
    """
    Plot the KL divergence per method for one variance setting and save it.

    Reads the contest entries from ``csv_path``, keeps the rows whose
    "Variance" column equals ``variance``, draws one box per method on a
    log-scaled y-axis and saves the figure as ``output`` inside ``base_path``.

    Note: ``sns.set_theme`` is called here, which changes the global
    matplotlib/seaborn defaults for the rest of the process.

    Args:
        variance: Value of the "Variance" column to select, compared for
            equality against the values pandas read from the CSV
        csv_path: Path to the CSV file containing the contest entries
        base_path: Directory for saving output files (a trailing path
            separator is optional)
        output: Output file name

    Returns:
        The complete, unfiltered DataFrame read from ``csv_path``.
    """
    font_size = 7

    # Load every contest entry; this unfiltered table is what gets returned.
    contest_entries_df = pd.read_csv(csv_path)

    # Keep only the rows recorded for the requested variance setting.
    selected_df = contest_entries_df[contest_entries_df["Variance"] == variance]

    # Create figure and axis using matplotlib's object-oriented interface
    fig = Figure(dpi=300, figsize=(5, 3), constrained_layout=True)
    sns.set_theme(style="whitegrid", rc={"font.size": font_size})
    ax = fig.add_subplot(111)

    # Single source of truth for both the left-to-right order of the boxes and
    # the label shown under each one (dicts preserve insertion order), so the
    # boxes and their tick labels cannot drift out of sync.
    method_labels = {
        Method.ANALYTIC: "analytic",
        Method.MEAN_FIELD: "mean field",
        Method.LINEAR: "linear",
        Method.UNSCENTED0: "unscented'95",
        Method.UNSCENTED1: "unscented'02",
    }
    method_order = list(method_labels)

    # One box per method, outliers drawn as markers of size 4.
    sns.boxplot(
        x="Method",
        y="KL",
        data=selected_df,
        order=method_order,
        ax=ax,
        fliersize=4,
        showfliers=True,
    )

    # Replace the default category tick labels with the short display names.
    ax.set_xticks(range(len(method_order)))
    ax.set_xticklabels(
        [method_labels[method] for method in method_order],
        rotation=0,
        ha="center",
        fontsize=font_size,
    )
    ax.tick_params(axis="both", which="major", labelsize=font_size)

    # Improve plot aesthetics
    ax.set_yscale("log")  # Use log scale for better visualization of KL values
    ax.set_xlabel("")
    ax.set_ylabel("KL Divergence", fontsize=font_size)

    # Join as path components rather than concatenating strings, so the file
    # lands inside base_path whether or not it ends with a separator.
    fig.savefig(Path(base_path) / output)

    return contest_entries_df


# Render one KL box plot per variance setting. Every run reads the same CSV
# and writes into the same manuscript directory; only the selected variance
# and the output file name differ.
_shared_plot_kwargs = {
    "csv_path": "assets/test_case_contest_entries.csv",
    "base_path": "../docs/manuscript/generated/",
}
for _variance_name, _output_name in (
    ("Variance.LARGE", "kl_boxplot_large_variance.pdf"),
    ("Variance.MEDIUM", "kl_boxplot_medium_variance.pdf"),
    ("Variance.SMALL", "kl_boxplot_small_variance.pdf"),
):
    analyze_results(
        variance=_variance_name,
        output=_output_name,
        **_shared_plot_kwargs,
    )

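# Debugging aid (disabled): uncomment the two lines below to open an
# interactive IPython shell once the plots above have been generated.
# import IPython
# IPython.embed(colors="neutral")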
