import json
from pathlib import Path
from typing import Dict, List, Optional, Union

import matplotlib as mpl
import matplotlib.pyplot as plt
import pandas as pd

from eliciting_contexts.benchmark.external.shared.utils import (
    filter_table,
    flatten_results,
    remove_entries,
)
from eliciting_contexts.utils.constants import BENCHMARK_RESULTS_DIR


def get_default_color_map_and_display_names(result_sets: List[str]):
    """Build the default color and display-name lookups for a list of methods.

    Known methods get a fixed hex color and a human-readable name. Any other
    entry of ``result_sets`` gets the matplotlib cycle color ``"C<i>"``, where
    ``i`` is its position in ``result_sets``, and is displayed under its own
    name.

    Args:
        result_sets: Method (result set) names that need a color and a label.

    Returns:
        A ``(color_map, display_name_map)`` tuple of freshly built dicts, both
        keyed by method name.
    """
    # (method key, color, display name) for every method known in advance.
    known_methods = (
        ("EPO", "#FFA500", "EPO"),  # orange
        ("EPOAssist", "#FF0000", "EPO-Assist"),  # red
        ("EPOAssist-Final", "#FF0000", "EPO-Assist-Final"),  # red
        ("EPOInpaint", "#9467bd", "EPO-Inpaint"),  # purple
        ("GPT4o", "#2ca02c", "GPT-4o"),  # green
        ("max_activating_examples", "#1f77b4", "Max Activating Examples"),  # blue
    )
    colors = {key: color for key, color, _ in known_methods}
    display_names = {key: shown for key, _, shown in known_methods}

    # Unknown methods fall back to a cycle color and their raw name; existing
    # entries (including earlier duplicates in result_sets) are left untouched.
    for position, method in enumerate(result_sets):
        colors.setdefault(method, f"C{position}")
        display_names.setdefault(method, method)

    return colors, display_names


def _metric_axis_label(metric: str, axes_names: Optional[Dict[str, str]]) -> str:
    """Return the custom axis label for ``metric`` or a title-cased fallback."""
    if axes_names and metric in axes_names:
        return axes_names[metric]
    return metric.replace("_", " ").title()


def _collect_violin_series(
    df: pd.DataFrame,
    metric: str,
    result_sets: List[str],
    color_map: Dict[str, str],
    naming_map: Dict[str, str],
):
    """Collect values, display labels and colors per result set for one metric.

    Result sets without any non-null value for ``metric`` are skipped.

    Returns:
        A ``(data, labels, colors)`` tuple of equally long lists.

    Raises:
        ValueError: If no result set has a non-null value for ``metric``.
    """
    data = []
    labels = []
    colors = []
    for i, result_set in enumerate(result_sets):
        values = df[df["result_set"] == result_set][metric].dropna().tolist()
        if not values:
            continue
        data.append(values)
        labels.append(naming_map.get(result_set, result_set))
        # Fall back to the cycle color at the result set's position when the
        # caller-supplied map does not cover it.
        colors.append(color_map.get(result_set, f"C{i}"))

    if not data:
        raise ValueError(
            f"No non-null values found for metric '{metric}'; nothing to plot."
        )
    return data, labels, colors


def _draw_violin_figure(data, labels, colors, y_label: str) -> plt.Figure:
    """Draw one styled violin figure with one violin per entry of ``data``."""
    fig = plt.figure(figsize=(10, 6))
    ax = fig.add_subplot(111)

    parts = ax.violinplot(
        data,
        showmeans=True,
        showmedians=True,
        showextrema=True,
    )

    # Fill each violin body with the color of its result set.
    for body, color in zip(parts["bodies"], colors):
        body.set_facecolor(color)
        body.set_alpha(0.6)

    # Each statistic collection holds one line per violin; passing the whole
    # color list gives every violin's lines the same color as its body.
    for part in ("cmeans", "cmedians", "cmaxes", "cmins", "cbars"):
        if part in parts:
            parts[part].set_color(colors)

    # Violins are drawn at positions 1..N.
    ax.set_xticks(range(1, len(labels) + 1))
    ax.set_xticklabels(labels)
    ax.set_ylabel(y_label, fontsize=22, fontweight="bold")

    # Grid and spines
    ax.grid(True, alpha=0.3)
    ax.spines["top"].set_visible(False)
    ax.spines["right"].set_visible(False)
    ax.spines["bottom"].set_linewidth(1.5)
    ax.spines["left"].set_linewidth(1.5)

    ax.tick_params(axis="both", which="major", labelsize=18)

    # Lay out this figure explicitly rather than whichever one is "current".
    fig.tight_layout()
    return fig


def create_violin_plots(
    df: pd.DataFrame,
    target_metric: str = "logit_diff",
    fluency_metric: str = "cross_entropy",
    color_map: Optional[Dict[str, str]] = None,
    naming_map: Optional[Dict[str, str]] = None,
    axes_names: Optional[Dict[str, str]] = None,
) -> Dict[str, plt.Figure]:
    """Create violin plots for target metric and fluency metric with consistent styling.

    One figure is produced per metric, with one violin per distinct value of
    the ``result_set`` column (sorted alphabetically). Result sets that have
    no non-null value for a metric are left out of that metric's figure.

    Note that this function changes global matplotlib state: it applies the
    ``seaborn-v0_8-whitegrid`` style and overrides several ``rcParams``.

    Args:
        df: DataFrame with evaluation results. Must contain a ``result_set``
            column plus one column per requested metric.
        target_metric: The target metric to visualize
        fluency_metric: The fluency metric to visualize
        color_map: Optional dictionary mapping result set names to colors.
            Result sets missing from it get a matplotlib cycle color.
        naming_map: Optional dictionary mapping result set names to display
            names. Result sets missing from it are shown under their own name.
        axes_names: Optional dictionary mapping metric names to axis labels.
            Metrics missing from it get a title-cased version of their name.

    Returns:
        Dictionary mapping ``"<metric>_violin"`` to the figure of that metric.
        The caller owns the figures and is responsible for closing them.

    Raises:
        KeyError: If ``df`` lacks the ``result_set`` column or a metric column.
        ValueError: If a metric has no non-null value in any result set.
    """
    # Set up matplotlib styling to match plot_results.py
    plt.style.use("seaborn-v0_8-whitegrid")
    mpl.rcParams["font.family"] = "sans-serif"
    mpl.rcParams["font.sans-serif"] = ["Arial"]
    mpl.rcParams["axes.edgecolor"] = "#333333"
    mpl.rcParams["axes.linewidth"] = 1.2
    mpl.rcParams["xtick.major.width"] = 1.2
    mpl.rcParams["ytick.major.width"] = 1.2
    mpl.rcParams["axes.grid"] = True
    mpl.rcParams["grid.alpha"] = 0.3

    # Get all methods in the data
    result_sets = sorted(df["result_set"].unique().tolist())
    default_color_map, default_display_name_map = (
        get_default_color_map_and_display_names(result_sets)
    )
    color_map = color_map or default_color_map
    naming_map = naming_map or default_display_name_map

    # Gather the data for every metric before creating any figure, so that a
    # metric without data fails early instead of leaving an orphaned figure
    # registered with pyplot. Keying by metric also avoids drawing the same
    # figure twice when both metric arguments are identical.
    series_by_metric = {
        metric: _collect_violin_series(
            df, metric, result_sets, color_map, naming_map
        )
        for metric in (target_metric, fluency_metric)
    }

    charts = {}
    for metric, (data, labels, colors) in series_by_metric.items():
        charts[f"{metric}_violin"] = _draw_violin_figure(
            data, labels, colors, _metric_axis_label(metric, axes_names)
        )

    return charts


def save_violin_plots(
    df: pd.DataFrame,
    target_metric: str = "logit_diff",
    fluency_metric: str = "cross_entropy",
    naming_map: Optional[Dict[str, str]] = None,
    color_map: Optional[Dict[str, str]] = None,
    benchmark_task_name: str = "tiny_stories",
    output_dir: Union[str, Path] = BENCHMARK_RESULTS_DIR / "plots",
    only_include_best_in_range: bool = False,
    min_threshold_fluency: Optional[float] = None,
    max_threshold_fluency: Optional[float] = None,
    axes_names: Optional[Dict[str, str]] = None,
    plot_names: Optional[Dict[str, str]] = None,
) -> None:
    """
    Generate and save violin plots.

    The plots are written as PNG files (300 dpi) into
    ``<output_dir>/<benchmark_task_name>/``, which is created if needed.

    Args:
        df: DataFrame with evaluation results
        target_metric: The target metric to visualize
        fluency_metric: The fluency metric to visualize
        naming_map: Optional dictionary mapping result set names to display
            names; forwarded to ``create_violin_plots``
        color_map: Optional dictionary mapping result set names to colors;
            forwarded to ``create_violin_plots``
        benchmark_task_name: Name of the sub-directory of ``output_dir`` that
            receives the files
        output_dir: Base directory for the plots
        only_include_best_in_range: Forwarded to ``filter_table``; when true,
            ``df`` is always passed through ``filter_table`` before plotting
        min_threshold_fluency: Lower fluency threshold forwarded to
            ``filter_table``
        max_threshold_fluency: Upper fluency threshold forwarded to
            ``filter_table``
        axes_names: Optional dictionary mapping metric names to axis labels
        plot_names: Optional dictionary mapping metric names to output file
            names (without extension). Metrics missing from it are saved as
            ``<metric>_violin.png``.
    """
    # Create output directory using standardized path
    output_dir = Path(output_dir) / benchmark_task_name
    output_dir.mkdir(parents=True, exist_ok=True)

    # NOTE(review): a single threshold on its own does not trigger filtering;
    # both must be set (or only_include_best_in_range enabled) - confirm this
    # is intended.
    should_filter = only_include_best_in_range or (
        min_threshold_fluency is not None and max_threshold_fluency is not None
    )
    if should_filter:
        df_filtered = filter_table(
            df,
            target_metric=target_metric,
            fluency_metric=fluency_metric,
            only_include_best_in_range=only_include_best_in_range,
            min_threshold_fluency=min_threshold_fluency,
            max_threshold_fluency=max_threshold_fluency,
        )
    else:
        df_filtered = df

    # Generate violin plots
    print("\nGenerating violin plots...")
    violin_plots = create_violin_plots(
        df_filtered,
        target_metric=target_metric,
        fluency_metric=fluency_metric,
        color_map=color_map,
        naming_map=naming_map,
        axes_names=axes_names,
    )

    # NOTE(review): file names do not encode the filter settings, so runs with
    # different thresholds overwrite each other's files in the same directory.
    violin_suffix = "_violin"
    try:
        for plot_name, plot_content in violin_plots.items():
            print(f"Saving {plot_name} plot...")
            # Strip only the trailing "_violin" so that a metric whose own
            # name contains "_violin" still maps back to the right key.
            if plot_name.endswith(violin_suffix):
                key_plot_name = plot_name[: -len(violin_suffix)]
            else:
                key_plot_name = plot_name

            # Use custom plot names if provided
            if plot_names and key_plot_name in plot_names:
                filename = output_dir / f"{plot_names[key_plot_name]}.png"
            else:
                filename = output_dir / f"{plot_name}.png"

            # Save using matplotlib's savefig
            plot_content.savefig(filename, dpi=300, bbox_inches="tight")
            print(f"Saved PNG to {filename}")
    finally:
        # Release every figure even when saving one of them fails.
        for plot_content in violin_plots.values():
            plt.close(plot_content)

if __name__ == "__main__":
    # Script entry point: load the benchmark results JSON, drop the excluded
    # entries, flatten the results into a DataFrame and save the violin plots.

    ############# ARGS ##############
    json_path = "/workspace/eliciting-contexts/src/eliciting_contexts/benchmark/results/results_final_latest.json"

    # (name in the results JSON, display name, color)
    result_sets_config = [
        ("EPO", "EPO", "#FF5200"),  # orange-red
        ("EPOAssist", "EPO-Assist", "#FF0000"),  # red
        ("EPOAssist-Final", "EPO-Assist-Final", "#FF0000"),  # red
        ("EPOInpaint", "EPO-Inpaint", "#9467bd"),  # purple
        ("GPT4o", "GPT-4o", "#2ca02c"),  # green
        ("max_activating_examples", "Max Activating Examples", "#1f77b4"),  # blue
        ("human", "Human", "#1f77b4"),  # blue
    ]

    # Plain values: a trailing comma here would silently create 1-tuples.
    only_include_best_in_range = True
    min_threshold_fluency = 3.0
    max_threshold_fluency = 9.0

    target_metric = "logit_diff_improvement"
    fluency_metric = "cross_entropy"

    axes_names = {
        target_metric: "Logit Diff Improvement",
        fluency_metric: "Cross Entropy",
    }
    plot_names = {
        target_metric: "stories_violin_logitdiff",
        fluency_metric: "stories_violin_crossentropy",
    }

    # Extract naming_map for flatten_results
    naming_map = {
        json_name: display_name for json_name, display_name, _ in result_sets_config
    }

    # Extract color_map for charts. It is keyed by display name because the
    # naming map is handed to flatten_results for renaming.
    color_map = {display_name: color for _, display_name, color in result_sets_config}

    # Entry ids are passed as strings.
    entries_to_remove = [str(i) for i in range(38, 77)]

    ############# ARGS END ##############

    with open(json_path, "r", encoding="utf-8") as f:
        all_results = json.load(f)

    # Remove specified entries before processing, and make sure the filtered
    # results are the ones that get flattened below.
    if entries_to_remove:
        print(f"Removing entries: {entries_to_remove}")
        all_results = remove_entries(all_results, entries_to_remove)

    # Pass the naming map to flatten_results to filter and rename in one step
    df = flatten_results(all_results, naming_map)

    save_violin_plots(
        df,
        target_metric=target_metric,
        fluency_metric=fluency_metric,
        naming_map=naming_map,
        color_map=color_map,
        benchmark_task_name="tiny_stories",
        output_dir=BENCHMARK_RESULTS_DIR / "plots",
        only_include_best_in_range=only_include_best_in_range,
        min_threshold_fluency=min_threshold_fluency,
        max_threshold_fluency=max_threshold_fluency,
        axes_names=axes_names,
        plot_names=plot_names,
    )
