import json
import logging
import warnings
from pathlib import Path
from typing import Dict, List, Optional, Union

import matplotlib as mpl
import matplotlib.pyplot as plt
import pandas as pd

from eliciting_contexts.benchmark.external.shared.utils import (
    filter_table,
    flatten_results,
    remove_entries,
)
from eliciting_contexts.utils.constants import BENCHMARK_RESULTS_DIR

# Suppress matplotlib's "findfont: ... family ... not found" noise. Matplotlib
# emits these through the "matplotlib.font_manager" *logger*, not through the
# ``warnings`` module, so the warnings filter alone has no effect; the logging
# filter below is what actually drops them. The warnings filter is kept as a
# harmless fallback.
warnings.filterwarnings("ignore", message="findfont: Font family.*not found")
logging.getLogger("matplotlib.font_manager").addFilter(
    lambda record: not record.getMessage().startswith(
        ("findfont: Font family", "findfont: Generic family")
    )
)


def get_default_color_map_and_display_names(result_sets: List[str]):
    """Build color and display-name lookups covering ``result_sets``.

    Methods with house styling get a fixed hex color and a human-readable
    name. Every other entry of ``result_sets`` is assigned the matplotlib
    cycle color ``"C<index>"`` (its position in ``result_sets``) and is
    displayed under its own name.

    Args:
        result_sets: Result set (method) names that need a color and label.

    Returns:
        A ``(color_map, display_name_map)`` tuple of freshly built dicts.
    """
    # (method key, hex color, display name) for methods with fixed styling
    styled_methods = [
        ("EPO", "#FFA500", "EPO"),  # orange
        ("EPOAssist", "#FF0000", "EPO-Assist"),  # red
        ("EPOAssist-Final", "#FF0000", "EPO-Assist-Final"),  # red
        ("EPOInpaint", "#9467bd", "EPO-Inpaint"),  # purple
        ("GPT4o", "#2ca02c", "GPT-4o"),  # green
        ("max_activating_examples", "#1f77b4", "Max Activating Examples"),  # blue
    ]
    colors = {key: hex_color for key, hex_color, _ in styled_methods}
    display_names = {key: pretty for key, _, pretty in styled_methods}

    # Unknown methods: matplotlib default cycle color + their raw name
    for position, method in enumerate(result_sets):
        colors.setdefault(method, f"C{position}")
        display_names.setdefault(method, method)

    return colors, display_names


def create_scatter_plot(
    df: pd.DataFrame,
    target_metric: str = "logit_diff",
    fluency_metric: str = "cross_entropy",
    color_map: Optional[Dict[str, str]] = None,
    naming_map: Optional[Dict[str, str]] = None,
    xlabel: Optional[str] = None,
    ylabel: Optional[str] = None,
) -> plt.Figure:
    """Create a scatter plot of a target metric against a fluency metric.

    One series is drawn per distinct value of ``df["result_set"]`` (in
    sorted order), each with its own color and marker.

    Args:
        df: DataFrame with evaluation results. Must contain a ``result_set``
            column plus the ``target_metric`` and ``fluency_metric`` columns.
        target_metric: Column plotted on the y-axis.
        fluency_metric: Column plotted on the x-axis.
        color_map: Optional mapping from result set name to color. Result sets
            missing from it fall back to the defaults produced by
            ``get_default_color_map_and_display_names``.
        naming_map: Optional mapping from result set name to legend label.
            Result sets missing from a user-supplied map use their raw name.
        xlabel: X-axis label. Defaults to "Cross Entropy" for
            ``cross_entropy``, otherwise a title-cased ``fluency_metric``.
        ylabel: Y-axis label. Defaults to "Logit Difference Improvement" for
            ``logit_diff``, otherwise a title-cased ``target_metric``.

    Returns:
        Matplotlib figure containing the scatter plot. The caller is
        responsible for closing it.

    Note:
        This mutates global matplotlib state (style and ``rcParams``).
    """
    # Set up matplotlib styling to match plot_results.py
    plt.style.use("seaborn-v0_8-whitegrid")
    mpl.rcParams["font.family"] = "sans-serif"
    # Use a list of fallback fonts instead of just Arial
    mpl.rcParams["font.sans-serif"] = [
        "Arial",
        "DejaVu Sans",
        "Helvetica",
        "Verdana",
        "Liberation Sans",
    ]
    mpl.rcParams["axes.edgecolor"] = "#333333"
    mpl.rcParams["axes.linewidth"] = 1.2
    mpl.rcParams["xtick.major.width"] = 1.2
    mpl.rcParams["ytick.major.width"] = 1.2
    mpl.rcParams["axes.grid"] = True
    mpl.rcParams["grid.alpha"] = 0.3

    # Axis labels follow the plotted columns unless explicitly overridden.
    known_axis_labels = {
        "cross_entropy": "Cross Entropy",
        "logit_diff": "Logit Difference Improvement",
    }
    if xlabel is None:
        xlabel = known_axis_labels.get(
            fluency_metric, fluency_metric.replace("_", " ").title()
        )
    if ylabel is None:
        ylabel = known_axis_labels.get(
            target_metric, target_metric.replace("_", " ").title()
        )

    # Get all methods in the data
    result_sets = sorted(df["result_set"].unique().tolist())
    default_color_map, default_display_name_map = (
        get_default_color_map_and_display_names(result_sets)
    )
    color_map = color_map or default_color_map
    naming_map = naming_map or default_display_name_map
    markers = ["s", "X", "D", "^", "o", "P", "*", "v"]

    fig = plt.figure(figsize=(10, 6))
    ax = fig.add_subplot(111)

    for i, result_set in enumerate(result_sets):
        result_set_data = df[df["result_set"] == result_set]
        # default_color_map has an entry for every result set, so a set that
        # is missing from a user color_map gets its own default color rather
        # than reusing another set's color.
        color = color_map.get(result_set, default_color_map[result_set])
        label = naming_map.get(result_set, result_set)

        ax.scatter(
            result_set_data[fluency_metric],
            result_set_data[target_metric],
            color=color,
            marker=markers[i % len(markers)],
            s=120,
            alpha=0.8,
            edgecolors="white",
            linewidth=0.8,
            label=label,
        )

    # Set labels with larger font sizes
    ax.set_xlabel(xlabel, fontsize=22, fontweight="bold")
    ax.set_ylabel(ylabel, fontsize=22, fontweight="bold")
    ax.tick_params(axis="both", which="major", labelsize=18)

    # Customize grid and spines
    ax.grid(True, alpha=0.3)
    ax.spines["top"].set_visible(False)
    ax.spines["right"].set_visible(False)
    ax.spines["bottom"].set_linewidth(1.5)
    ax.spines["left"].set_linewidth(1.5)

    # Add 20% headroom above the data so the legend does not cover points
    y_min, y_max = ax.get_ylim()
    ax.set_ylim(y_min, y_max + (y_max - y_min) * 0.2)

    # Legend in the top-right corner; skipped when there is nothing to label
    # (matplotlib otherwise warns about a legend with no artists).
    if result_sets:
        ax.legend(
            loc="upper right",
            fontsize=12,
            frameon=True,
            framealpha=0.9,
            edgecolor="#cccccc",
        )

    fig.tight_layout()

    return fig


def create_scatter_plot_with_thresholds(
    df: pd.DataFrame,
    target_metric: str = "logit_diff",
    fluency_metric: str = "cross_entropy",
    color_map: Optional[Dict[str, str]] = None,
    min_threshold_fluency: Optional[float] = 3,
    max_threshold_fluency: Optional[float] = 9,
    naming_map: Optional[Dict[str, str]] = None,
    xlabel: Optional[str] = None,
    ylabel: Optional[str] = None,
) -> plt.Figure:
    """Create a target-vs-fluency scatter plot with fluency threshold lines.

    Same plot as ``create_scatter_plot``, plus a dashed vertical line (with a
    rotated text label) at each provided fluency threshold.

    Args:
        df: DataFrame with evaluation results. Must contain a ``result_set``
            column plus the ``target_metric`` and ``fluency_metric`` columns.
        target_metric: Column plotted on the y-axis.
        fluency_metric: Column plotted on the x-axis.
        color_map: Optional mapping from result set name to color. Result sets
            missing from it fall back to the defaults produced by
            ``get_default_color_map_and_display_names``.
        min_threshold_fluency: x position of the "lower threshold" line, or
            ``None`` to omit it.
        max_threshold_fluency: x position of the "upper threshold" line, or
            ``None`` to omit it.
        naming_map: Optional mapping from result set name to legend label.
            Result sets missing from a user-supplied map use their raw name.
        xlabel: X-axis label. Defaults to "Cross Entropy" for
            ``cross_entropy``, otherwise a title-cased ``fluency_metric``.
        ylabel: Y-axis label. Defaults to "Logit Difference Improvement" for
            ``logit_diff``, otherwise a title-cased ``target_metric``.

    Returns:
        Matplotlib figure containing the scatter plot. The caller is
        responsible for closing it.

    Note:
        This mutates global matplotlib state (style and ``rcParams``).
    """
    # Set up matplotlib styling to match plot_results.py
    plt.style.use("seaborn-v0_8-whitegrid")
    mpl.rcParams["font.family"] = "sans-serif"
    # Same fallback font list as create_scatter_plot, so both plots render
    # consistently on machines without Arial.
    mpl.rcParams["font.sans-serif"] = [
        "Arial",
        "DejaVu Sans",
        "Helvetica",
        "Verdana",
        "Liberation Sans",
    ]
    mpl.rcParams["axes.edgecolor"] = "#333333"
    mpl.rcParams["axes.linewidth"] = 1.2
    mpl.rcParams["xtick.major.width"] = 1.2
    mpl.rcParams["ytick.major.width"] = 1.2
    mpl.rcParams["axes.grid"] = True
    mpl.rcParams["grid.alpha"] = 0.3

    # Axis labels follow the plotted columns unless explicitly overridden.
    known_axis_labels = {
        "cross_entropy": "Cross Entropy",
        "logit_diff": "Logit Difference Improvement",
    }
    if xlabel is None:
        xlabel = known_axis_labels.get(
            fluency_metric, fluency_metric.replace("_", " ").title()
        )
    if ylabel is None:
        ylabel = known_axis_labels.get(
            target_metric, target_metric.replace("_", " ").title()
        )

    # Get all methods in the data
    result_sets = sorted(df["result_set"].unique().tolist())
    default_color_map, default_display_name_map = (
        get_default_color_map_and_display_names(result_sets)
    )
    color_map = color_map or default_color_map
    naming_map = naming_map or default_display_name_map
    markers = ["s", "X", "D", "^", "o", "P", "*", "v"]

    fig = plt.figure(figsize=(10, 6))
    ax = fig.add_subplot(111)

    for i, result_set in enumerate(result_sets):
        result_set_data = df[df["result_set"] == result_set]
        # default_color_map has an entry for every result set, so a set that
        # is missing from a user color_map gets its own default color rather
        # than reusing another set's color.
        color = color_map.get(result_set, default_color_map[result_set])
        legend_label = naming_map.get(result_set, result_set)

        ax.scatter(
            result_set_data[fluency_metric],
            result_set_data[target_metric],
            color=color,
            marker=markers[i % len(markers)],
            s=120,
            alpha=0.8,
            edgecolors="white",
            linewidth=0.8,
            label=legend_label,
        )

    # Draw each threshold that was provided; a single threshold is drawn on
    # its own rather than being silently dropped.
    thresholds = [
        (x_pos, text)
        for x_pos, text in (
            (min_threshold_fluency, "lower threshold"),
            (max_threshold_fluency, "upper threshold"),
        )
        if x_pos is not None
    ]
    if thresholds:
        # axvline only autoscales x, so the y-limits read here are stable.
        y_low, y_high = ax.get_ylim()
        text_y = y_low + 0.1 * (y_high - y_low)
        for x_pos, text in thresholds:
            ax.axvline(
                x=x_pos, color="#555555", linestyle="--", linewidth=1.5, alpha=0.7
            )
            ax.text(
                x_pos + 0.1,
                text_y,
                text,
                fontsize=16,
                color="#555555",
                fontweight="bold",
                ha="left",
                va="bottom",
                rotation=90,
            )
    else:
        print("No thresholds provided, skipping threshold lines")

    # Set labels with larger font sizes
    ax.set_xlabel(xlabel, fontsize=22, fontweight="bold")
    ax.set_ylabel(ylabel, fontsize=22, fontweight="bold")
    ax.tick_params(axis="both", which="major", labelsize=18)

    # Customize grid and spines
    ax.grid(True, alpha=0.3)
    ax.spines["top"].set_visible(False)
    ax.spines["right"].set_visible(False)
    ax.spines["bottom"].set_linewidth(1.5)
    ax.spines["left"].set_linewidth(1.5)

    # Add 20% headroom above the data so the legend does not cover points
    y_min, y_max = ax.get_ylim()
    ax.set_ylim(y_min, y_max + (y_max - y_min) * 0.2)

    # Legend in the top-right corner; skipped when there is nothing to label
    # (matplotlib otherwise warns about a legend with no artists).
    if result_sets:
        ax.legend(
            loc="upper right",
            fontsize=12,
            frameon=True,
            framealpha=0.9,
            edgecolor="#cccccc",
        )

    fig.tight_layout()

    return fig


def save_scatter_plots(
    df: pd.DataFrame,
    target_metric: str = "logit_diff",
    fluency_metric: str = "cross_entropy",
    color_map: Optional[Dict[str, str]] = None,
    naming_map: Optional[Dict[str, str]] = None,
    benchmark_task_name: str = "tiny_stories",
    output_dir: Union[str, Path] = BENCHMARK_RESULTS_DIR / "plots",
    min_threshold_fluency: Optional[float] = 3,
    max_threshold_fluency: Optional[float] = 9,
) -> None:
    """Render and save the filtered and threshold scatter plots as PNGs.

    Two figures are written to ``<output_dir>/<benchmark_task_name>/``:

    * ``scatter_plot[_<min>_<max>].png`` -- built from
      ``filter_table(df, ..., only_include_best_in_range=True, ...)``
      (presumably the best row per result set within the fluency range --
      see ``filter_table``).
    * ``scatter_plot_with_thresholds[_<min>_<max>].png`` -- all rows of
      ``df`` with the threshold lines drawn in.

    The ``_<min>_<max>`` suffix is only added when both thresholds are set.

    Args:
        df: Full DataFrame with evaluation results.
        target_metric: The target metric to visualize (y-axis).
        fluency_metric: The fluency metric to visualize (x-axis).
        color_map: Optional dictionary mapping result set names to colors.
        naming_map: Optional dictionary mapping result set names to display names.
        benchmark_task_name: Subdirectory of ``output_dir`` for this task.
        output_dir: Root directory for plots; created if missing.
        min_threshold_fluency: Lower fluency threshold, or ``None``.
        max_threshold_fluency: Upper fluency threshold, or ``None``.
    """
    output_dir = Path(output_dir) / benchmark_task_name
    output_dir.mkdir(parents=True, exist_ok=True)

    # Thresholds are encoded in the filenames only when both are provided.
    if min_threshold_fluency is not None and max_threshold_fluency is not None:
        suffix = f"_{min_threshold_fluency}_{max_threshold_fluency}"
    else:
        suffix = ""

    # Filtered scatter plot
    print("Generating filtered scatter plot...")
    df_filtered = filter_table(
        df,
        target_metric=target_metric,
        fluency_metric=fluency_metric,
        only_include_best_in_range=True,
        min_threshold_fluency=min_threshold_fluency,
        max_threshold_fluency=max_threshold_fluency,
    )
    scatter_plot = create_scatter_plot(
        df_filtered,
        target_metric=target_metric,
        fluency_metric=fluency_metric,
        color_map=color_map,
        naming_map=naming_map,
    )
    scatter_filename = output_dir / f"scatter_plot{suffix}.png"
    # Always release the figure, even if saving fails.
    try:
        scatter_plot.savefig(scatter_filename, dpi=300, bbox_inches="tight")
        print(f"Saved scatter plot to {scatter_filename}")
    finally:
        plt.close(scatter_plot)

    # Full scatter plot with threshold lines
    print("Generating full scatter plot with thresholds drawn in...")
    scatter_plot_with_thresholds = create_scatter_plot_with_thresholds(
        df,
        target_metric=target_metric,
        fluency_metric=fluency_metric,
        color_map=color_map,
        naming_map=naming_map,
        min_threshold_fluency=min_threshold_fluency,
        max_threshold_fluency=max_threshold_fluency,
    )
    scatter_thresholds_filename = (
        output_dir / f"scatter_plot_with_thresholds{suffix}.png"
    )
    try:
        scatter_plot_with_thresholds.savefig(
            scatter_thresholds_filename, dpi=300, bbox_inches="tight"
        )
        print(f"Saved scatter plot with thresholds to {scatter_thresholds_filename}")
    finally:
        plt.close(scatter_plot_with_thresholds)


if __name__ == "__main__":
    ############# ARGS ##############
    # NOTE(review): absolute, machine-specific path -- adjust for your setup.
    json_path = "/workspace/eliciting-contexts/src/eliciting_contexts/benchmark/results/results_final_latest_last.json"

    # Each entry: (result-set key in the JSON file, display name, plot color)
    result_sets_config = [
        ("human", "Human", "#1f77b4"),  # blue
        ("GPT4o", "GPT-4o", "#2ca02c"),  # green
        ("EPO", "EPO", "#FFA500"),  # orange
        ("EPOAssistHobbled-Final", "EPO-Assist", "#FF0000"),  # red
        ("GCG", "GCG", "#9467bd"),  # purple
    ]
    # Entry IDs (as strings) to drop from the results before plotting
    entries_to_remove = [str(i) for i in range(38, 77)]

    min_threshold_fluency = 3
    max_threshold_fluency = 9

    # JSON result-set key -> display name, used by flatten_results
    naming_map = {
        json_name: display_name for json_name, display_name, _ in result_sets_config
    }

    # Colors are keyed by *display* name; this assumes flatten_results renames
    # result sets to their display names (TODO confirm against flatten_results).
    color_map = {display_name: color for _, display_name, color in result_sets_config}

    ############# ARGS END ##############

    with open(json_path, "r", encoding="utf-8") as f:
        all_results = json.load(f)

    # Start from the unfiltered results so an empty removal list still works
    # (previously this name was only bound inside the `if`, causing NameError).
    all_results_filtered = all_results
    if entries_to_remove:
        print(f"Removing entries: {entries_to_remove}")
        all_results_filtered = remove_entries(all_results, entries_to_remove)

    # Pass the naming map to flatten_results to filter and rename in one step
    df_flattened = flatten_results(all_results_filtered, naming_map)

    # Save scatter plots
    print("Generating and saving scatter plots...")
    save_scatter_plots(
        df=df_flattened,
        color_map=color_map,
        naming_map=naming_map,
        benchmark_task_name="tiny_stories",
        min_threshold_fluency=min_threshold_fluency,
        max_threshold_fluency=max_threshold_fluency,
    )
