import matplotlib.pyplot as plt
import pandas as pd
from typing import List, Dict, Tuple
import numpy as np
from .utils import (
    get_color_palette,
    add_sample_counts,
    format_accuracy,
    get_model_family_order,
    get_size_pattern_order,
    get_display_name,
    apply_display_names_to_list,
)


def create_model_performance_chart(
    data: pd.DataFrame,
    title: str,
    ylabel: str = "Accuracy",
    show_sample_counts: bool = True,
    figsize: Tuple[int, int] = (12, 8),
    color_palette: str = "models",
    no_titles: bool = False,
) -> plt.Figure:
    """
    Draw one bar per model whose height is the mean of its ``correct`` column.

    When ``color_palette`` is ``"models"`` the bars follow the order returned by
    ``get_model_family_order``; for any other palette they are sorted by accuracy,
    highest first. X-axis labels come from ``get_display_name`` while bar colors are
    looked up by the internal model name.

    Parameters:
    - data: DataFrame with 'model' and 'correct' columns
    - title: Chart title (not drawn when ``no_titles`` is True)
    - ylabel: Y-axis label
    - show_sample_counts: Whether to pass per-model row counts to ``add_sample_counts``
    - figsize: Figure size forwarded to ``plt.subplots``
    - color_palette: Palette type forwarded to ``get_color_palette``; also selects ordering
    - no_titles: Whether to suppress the title

    Returns:
    - matplotlib Figure
    """
    # One row per model: mean of `correct` and number of non-null rows.
    stats = data.groupby("model")["correct"].agg(["mean", "count"]).reset_index()
    stats.columns = ["model", "accuracy", "sample_count"]

    if color_palette != "models":
        stats = stats.sort_values("accuracy", ascending=False)
    else:
        # Family ordering replaces the accuracy ranking for the model palette.
        family_order = get_model_family_order(stats["model"].tolist())
        stats = stats.set_index("model").reindex(family_order).reset_index()

    stats["display_name"] = stats["model"].apply(get_display_name)

    fig, ax = plt.subplots(figsize=figsize)

    # Colors are keyed by internal model names, not by display names.
    palette = get_color_palette(stats["model"].tolist(), color_palette)
    per_bar_colors = [palette[name] for name in stats["model"]]

    bars = ax.bar(
        stats["display_name"],
        stats["accuracy"],
        color=per_bar_colors,
        alpha=0.8,
    )

    if not no_titles:
        ax.set_title(title, fontsize=14, pad=20)
    ax.set_ylabel(ylabel)
    ax.set_ylim(0, 1.0)
    ax.grid(True, axis="y", alpha=0.3)

    if show_sample_counts:
        counts_by_label = {
            label: n for label, n in zip(stats["display_name"], stats["sample_count"])
        }
        add_sample_counts(ax, counts_by_label)

    # Model names are intentionally left horizontal for readability.

    # Annotate each bar with its formatted accuracy just above the bar top.
    for rect, value in zip(bars, stats["accuracy"]):
        center_x = rect.get_x() + rect.get_width() / 2.0
        ax.text(
            center_x,
            rect.get_height() + 0.01,
            format_accuracy(value),
            ha="center",
            va="bottom",
            fontsize=9,
        )

    return fig


def create_comparison_chart(
    data_dict: Dict[str, pd.DataFrame],
    title: str,
    comparison_labels: List[str],
    ylabel: str = "Accuracy",
    figsize: Tuple[int, int] = (14, 8),
    no_titles: bool = False,
) -> plt.Figure:
    """
    Create a side-by-side comparison chart (e.g., input vs output) with proper model ordering and display names.

    Each key of ``data_dict`` becomes one bar series; models (x-axis) follow
    ``get_model_family_order`` and are labelled with their display names. The
    ``"input"`` category, when present, is always drawn first.

    Parameters:
    - data_dict: Dict with keys being comparison categories, values being DataFrames
      with 'model' and 'correct' columns. Empty DataFrames keep their bar slot but draw nothing.
    - title: Chart title (not drawn when ``no_titles`` is True)
    - comparison_labels: Labels passed to ``get_color_palette`` to build the category colors
    - ylabel: Y-axis label
    - figsize: Figure size
    - no_titles: Whether to suppress titles

    Returns:
    - matplotlib Figure
    """
    # Per-category accuracy (mean of `correct`) and sample count by model.
    performance_data: Dict[str, pd.DataFrame] = {}
    seen_models = set()
    for category, df in data_dict.items():
        if df.empty:
            continue
        perf = df.groupby("model")["correct"].agg(["mean", "count"]).reset_index()
        perf.columns = ["model", "accuracy", "sample_count"]
        performance_data[category] = perf
        seen_models.update(perf["model"].tolist())

    # Use family ordering for models, display names for the x-axis labels.
    all_models = get_model_family_order(list(seen_models))
    display_names = apply_display_names_to_list(all_models)

    fig, ax = plt.subplots(figsize=figsize)

    # "input" first if present, then the remaining categories in dict order.
    ordered_categories = (["input"] if "input" in data_dict else []) + [
        category for category in data_dict if category != "input"
    ]
    n_slots = len(ordered_categories)

    # Derive the bar width from the same count used for the offsets below, so
    # bars never overlap or leave gaps when comparison_labels and data_dict
    # disagree in length; max(..., 1) avoids dividing by zero on empty input.
    x = np.arange(len(display_names))
    width = 0.35 if n_slots == 2 else 0.8 / max(n_slots, 1)

    colors = get_color_palette(comparison_labels, "targets")

    for i, category in enumerate(ordered_categories):
        perf = performance_data.get(category)
        if perf is None:
            continue  # empty DataFrame: keep the slot, draw nothing

        # Align to all_models; models absent from this category get 0 / n=0.
        aligned = perf.set_index("model").reindex(all_models)
        accuracies = aligned["accuracy"].fillna(0).tolist()
        sample_counts = aligned["sample_count"].fillna(0).astype(int).tolist()

        offset = (i - (n_slots - 1) / 2) * width
        bars = ax.bar(
            x + offset,
            accuracies,
            width,
            label=category.title(),
            # Categories missing from comparison_labels fall back to
            # matplotlib's default color cycle instead of raising KeyError.
            color=colors.get(category),
            alpha=0.8,
        )

        # Label every bar backed by data (n > 0), including genuine 0.0
        # accuracies; placeholder bars for missing models stay unlabelled.
        for bar, acc, count in zip(bars, accuracies, sample_counts):
            if count > 0:
                height = bar.get_height()
                ax.text(
                    bar.get_x() + bar.get_width() / 2.0,
                    height + 0.01,
                    f"{acc:.3f}\n(n={count})",
                    ha="center",
                    va="bottom",
                    fontsize=8,
                )

    if not no_titles:
        ax.set_title(title, fontsize=14, pad=20)
    ax.set_ylabel(ylabel)
    ax.set_xlabel("Model")
    ax.set_ylim(0, 1.0)
    ax.grid(True, axis="y", alpha=0.3)
    ax.set_xticks(x)
    ax.set_xticklabels(display_names)  # model names kept horizontal
    ax.legend()

    return fig


def create_breakdown_chart(
    data: pd.DataFrame,
    groupby_col: str,
    title: str,
    ylabel: str = "Accuracy",
    color_by: str = "models",
    figsize: Tuple[int, int] = (12, 8),
    no_titles: bool = False,
) -> plt.Figure:
    """
    Create a breakdown chart grouped by a specific column with proper ordering and display names.

    Groups of ``groupby_col`` form the x-axis and each model is one bar series.
    ``size_pattern`` groups follow ``get_size_pattern_order``; other groupings are
    sorted. Models follow ``get_model_family_order`` when ``color_by`` is
    ``"models"``, otherwise they are sorted.

    Parameters:
    - data: DataFrame with ``groupby_col``, 'model' and 'correct' columns
    - groupby_col: Column whose values form the x-axis groups
    - title: Chart title (not drawn when ``no_titles`` is True)
    - ylabel: Y-axis label
    - color_by: Palette type passed to ``get_color_palette``; also selects model ordering
    - figsize: Figure size
    - no_titles: Whether to suppress titles

    Returns:
    - matplotlib Figure
    """
    # Accuracy (mean of `correct`) and sample count per (group, model).
    performance = (
        data.groupby([groupby_col, "model"])["correct"]
        .agg(["mean", "count"])
        .reset_index()
    )
    performance.columns = [groupby_col, "model", "accuracy", "sample_count"]

    fig, ax = plt.subplots(figsize=figsize)

    if groupby_col == "size_pattern":
        # Use SIZE_PATTERNS definition order for size patterns.
        groups = get_size_pattern_order(performance[groupby_col].unique())
    else:
        groups = sorted(performance[groupby_col].unique())

    models = performance["model"].unique()
    if color_by == "models":
        models = get_model_family_order(list(models))
    else:
        models = sorted(models)
    display_names = apply_display_names_to_list(models)

    x = np.arange(len(groups))
    # max(..., 1) avoids ZeroDivisionError when `data` is empty.
    width = 0.8 / max(len(models), 1)

    # Colors are keyed by internal model names.
    colors = get_color_palette(list(models), color_by)

    for i, (model, display_name) in enumerate(zip(models, display_names)):
        # Align this model's rows to the ordered groups; missing groups get 0 / n=0.
        aligned = (
            performance[performance["model"] == model]
            .set_index(groupby_col)
            .reindex(groups)
        )
        accuracies = aligned["accuracy"].fillna(0).tolist()
        sample_counts = aligned["sample_count"].fillna(0).astype(int).tolist()

        offset = (i - (len(models) - 1) / 2) * width
        bars = ax.bar(
            x + offset,
            accuracies,
            width,
            label=display_name,
            color=colors[model],
            alpha=0.8,
        )

        # Label every bar backed by data (n > 0), including genuine 0.0
        # accuracies; placeholder bars for missing groups stay unlabelled.
        for bar, acc, count in zip(bars, accuracies, sample_counts):
            if count > 0:
                height = bar.get_height()
                ax.text(
                    bar.get_x() + bar.get_width() / 2.0,
                    height + 0.01,
                    f"{acc:.3f}",
                    ha="center",
                    va="bottom",
                    fontsize=8,
                )

    if not no_titles:
        ax.set_title(title, fontsize=14, pad=20)
    ax.set_ylabel(ylabel)
    ax.set_xlabel(groupby_col.replace("_", " ").title())
    ax.set_ylim(0, 1.0)
    ax.grid(True, axis="y", alpha=0.3)
    ax.set_xticks(x)

    # Only rotate labels when there are many benchmark/task groups. ha="right"
    # is only meaningful together with a rotation: it makes each tilted label
    # end under its tick (with rotation=0 it merely shifted labels off-tick).
    if groupby_col in ["benchmark", "task"] and len(groups) > 8:
        ax.set_xticklabels(groups, rotation=45, ha="right")
    else:
        ax.set_xticklabels(groups)  # keep horizontal for most cases

    ax.legend(bbox_to_anchor=(1.05, 1), loc="upper left")

    return fig
