"""
Visualization utilities for experiment results.

Core functions:
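- get_subset: Filter and prepare data for visualization
- aggregate_by_seed: Aggregate a metric within each seed, then across seeds
- grouped_bar_chart: Create grouped bar charts
- grouped_bar_chart_by_retain: Create grouped bar charts for one stage across retain configurations
- grouped_line_chart: Create line charts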
"""

from __future__ import annotations

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy.stats as st


# ============================================================
# Display Name Mappings
# ============================================================

# Display names for stage identifiers; used as default x-tick labels and in
# default plot titles.
STAGE_NAMES = {
    "baseline": "Baseline",
    "filtering": "Filtering",
    "coreftaux": "Finetuned Core",
    "routed": "Routed",
    "rmu": "RMU",
    "maxent": "MaxEnt",
    "gradient_ascent": "Grad Ascent",
}

# Display names for DataFrame columns; used as default axis labels.
FIELD_NAMES = {
    "name": "Stage",
    "loss": "Loss",
    "compute_ratio": "Compute Ratio",
    "step_equiv": "Step Equivalent",
    "ppl": "Perplexity",
    "ppl_ratio": "Perplexity Ratio",
    "loss_ratio": "Loss Ratio",
    "model_size": "Model Size",
}

# Default colors keyed by the label classes that classify_label produces.
LABEL_CLASS_COLORS = {
    "Core": "#1f77b4",
    "Retain": "#2ca02c",
    "Forget": "#d62728",
    "Elicited Forget": "#ff7f0e",
}


# ============================================================
# Data Helpers
# ============================================================

def classify_label(row: pd.Series) -> str:
    """Classify a row as Core, Forget, Elicited Forget, or Retain.

    Args:
        row: Row-like object providing ``data_label`` and ``target`` entries
            and, optionally, an ``elicited`` flag.

    Returns:
        ``"Core"`` when ``data_label`` is ``"core"``. Otherwise, when
        ``data_label`` is contained in ``target``: ``"Elicited Forget"`` if
        the ``elicited`` flag is set, else ``"Forget"``. ``"Retain"`` in all
        other cases.
    """
    label = row["data_label"]
    target = row["target"]

    if label == "core":
        return "Core"
    if label not in target:
        return "Retain"

    # A missing flag may be stored as None/NaN; NaN is truthy in Python, so it
    # must be treated explicitly as "not elicited".
    elicited = row.get("elicited", False)
    is_elicited = False if pd.isna(elicited) else bool(elicited)
    return "Elicited Forget" if is_elicited else "Forget"

def get_subset(
    df: pd.DataFrame,
    filters: dict | None = None,
    target: tuple | list | None = None,
    stage_col: str = "name",
    add_label_class: bool = False,
    drop_na_loss: bool = True,
) -> pd.DataFrame:
    """Filter and prepare a subset of the DataFrame for visualization."""
    result = df.copy()
    
    if target is not None:
        target_set = set(target)
        mask = result["target"].apply(lambda x: set(x) == target_set)
        mask |= result[stage_col].isin(["baseline"])
        result = result[mask]
    
    if filters:
        for col, value in filters.items():
            if col not in result.columns:
                continue
            is_null = result[col].isna()
            mask = is_null | (result[col] == value)
            result = result[mask]
    
    if drop_na_loss:
        result = result[result["loss"].notna()]
    
    if add_label_class:
        result["label_class"] = result.apply(classify_label, axis=1)
    
    return result.sort_values(by=[stage_col, "data_label"]).reset_index(drop=True)

def aggregate_by_seed(
    df: pd.DataFrame,
    group_cols: list[str],
    y_col: str,
    seed_col: str = "seed",
    ci_level: float = 0.9,
) -> pd.DataFrame:
    """Aggregate ``y_col`` in two stages: within each seed, then across seeds.

    Recommended for aggregate metrics where seed is the unit of replication.

    Args:
        df: DataFrame holding the raw values.
        group_cols: Columns that define the groups to summarize.
        y_col: Column holding the metric to aggregate.
        seed_col: Column holding the seed identifier.
        ci_level: Two-sided confidence level used for the t-interval.

    Returns:
        One row per group with the columns of ``group_cols`` followed by
        ``mean``, ``std``, ``count`` (number of seeds), ``sem`` and ``ci``
        (half-width of the t-interval; 0 when only one seed is available).
    """
    # First stage: collapse every (group, seed) combination to its mean.
    per_seed_keys = group_cols + [seed_col]
    per_seed = df.groupby(per_seed_keys)[y_col].mean()
    per_seed = per_seed.reset_index().rename(columns={y_col: "seed_mean"})

    # Second stage: summarize the per-seed means of each group.
    grouped = per_seed.groupby(group_cols)["seed_mean"]
    summary = grouped.agg(["mean", "std", "count"]).reset_index()

    # The sample std is undefined for a single seed; report it as zero.
    summary["std"] = summary["std"].fillna(0.0)
    n_seeds = summary["count"]
    summary["sem"] = summary["std"] / np.sqrt(n_seeds.clip(lower=1))

    # Two-sided t critical value with n - 1 degrees of freedom.
    alpha = 1.0 - float(ci_level)
    critical = st.t.ppf(1 - alpha / 2, n_seeds - 1)
    summary["ci"] = np.where(n_seeds > 1, critical, 0.0) * summary["sem"]

    return summary

# ============================================================
# Visualization Functions
# ============================================================

def grouped_bar_chart(
    df: pd.DataFrame,
    x_col: str,
    y_col: str,
    group_col: str | None = None,
    seed_col: str = "seed",
    ci_level: float = 0.9,
    x_order: list[str] | None = None,
    group_order: list[str] | None = None,
    x_labels: dict[str, str] | None = None,
    title: str | None = None,
    x_axis_label: str | None = None,
    y_axis_label: str | None = None,
    figsize: tuple[float, float] | None = None,
    fontsize: int = 10,
    y_min: float | None = None,
    y_max: float | None = None,
    error_bars: bool = True,
    show_values: bool = True,
    fade_map: dict[str, list[str]] | None = None,
    colors: dict[str, str] | None = None,
    min_bar_height: float | None = None,
) -> tuple[plt.Axes, pd.DataFrame]:
    """Create a grouped bar chart of seed-aggregated means.

    ``y_col`` is aggregated with ``aggregate_by_seed`` over ``x_col`` (and
    ``group_col`` when given). Bar heights are the resulting means and error
    bars the confidence-interval half-widths.

    Args:
        df: DataFrame with data to plot.
        x_col: Column for x-axis categories.
        y_col: Column for y-axis values.
        group_col: Column for bars within each x-category. If None, one bar per x.
        seed_col: Column containing seed identifiers.
        ci_level: Confidence interval level (0-1).
        x_order: Order of x-axis categories; other categories are dropped.
        group_order: Order of groups within each x-category; other groups
            are dropped.
        x_labels: Dict mapping x values to display labels; entries override
            ``STAGE_NAMES``.
        title: Plot title.
        x_axis_label, y_axis_label: Axis labels.
        figsize: Figure size. Defaults to a width that grows with the number
            of x-categories.
        fontsize: Base font size.
        y_min, y_max: Y-axis limits.
        error_bars: If True, show confidence interval error bars.
        show_values: If True, show numeric values above bars.
        fade_map: Dict mapping x values to list of groups to fade.
        colors: Dict mapping group values to colors; entries override
            ``LABEL_CLASS_COLORS``.
        min_bar_height: Minimum display height for bars. The value label
            still shows the true mean.

    Returns:
        Tuple of (matplotlib Axes, aggregated DataFrame).

    Raises:
        ValueError: If ``group_col`` is given but no group is left to plot.
    """
    # Restrict to the requested categories before aggregating.
    data = df.copy()
    if x_order is not None:
        data = data[data[x_col].isin(x_order)]
    if group_col is not None and group_order is not None:
        data = data[data[group_col].isin(group_order)]

    group_cols = [x_col] if group_col is None else [x_col, group_col]
    agg = aggregate_by_seed(data, group_cols, y_col, seed_col, ci_level)

    # Determine plotting order, keeping only values that actually have data.
    if x_order is not None:
        x_vals = [v for v in x_order if v in agg[x_col].values]
    else:
        x_vals = list(agg[x_col].unique())

    if group_col is None:
        groups = [None]
    elif group_order is not None:
        groups = [g for g in group_order if g in agg[group_col].values]
    else:
        groups = list(agg[group_col].unique())

    if not groups:
        # Without this guard the bar-width computation divides by zero.
        raise ValueError(
            f"No data to plot for group column '{group_col}' after filtering."
        )

    # Pivot to (x value) x (group) tables aligned with the plotting order.
    if group_col is not None:
        pivot_mean = agg.pivot(index=x_col, columns=group_col, values="mean")
        pivot_ci = agg.pivot(index=x_col, columns=group_col, values="ci").fillna(0.0)
        pivot_count = agg.pivot(index=x_col, columns=group_col, values="count")
        pivot_mean = pivot_mean.reindex(index=x_vals, columns=groups)
        pivot_ci = pivot_ci.reindex(index=x_vals, columns=groups).fillna(0.0)
        pivot_count = pivot_count.reindex(index=x_vals, columns=groups)
    else:
        indexed = agg.set_index(x_col)
        pivot_mean = indexed["mean"].reindex(x_vals)
        pivot_ci = indexed["ci"].reindex(x_vals).fillna(0.0)
        pivot_count = indexed["count"].reindex(x_vals)

    if figsize is None:
        figsize = (max(6, len(x_vals) * 1.2), 5)
    _, ax = plt.subplots(figsize=figsize)

    n_groups = len(groups)
    bar_spacing = 0.02
    total_width = 0.85 - (n_groups - 1) * bar_spacing
    bar_width = total_width / n_groups
    x = np.arange(len(x_vals))

    # Caller-supplied colors take precedence over the module defaults.
    color_map = {**LABEL_CLASS_COLORS, **(colors or {})}

    for i, g in enumerate(groups):
        offset = -total_width / 2 + bar_width / 2 + i * (bar_width + bar_spacing)

        if group_col is not None:
            means = pivot_mean[g].values
            errors = pivot_ci[g].values if error_bars else None
            counts = pivot_count[g].values if error_bars else None
        else:
            means = pivot_mean.values
            errors = pivot_ci.values if error_bars else None
            counts = pivot_count.values if error_bars else None

        if error_bars and errors is not None and counts is not None:
            # No interval can be drawn from a single seed (or missing data).
            errors = np.where(counts > 1, errors, np.nan)

        display_means = means.copy()
        if min_bar_height is not None:
            display_means = np.where(
                ~np.isnan(means), np.maximum(means, min_bar_height), means
            )

        bar_kwargs = {"width": bar_width}
        if g is not None:
            bar_kwargs["label"] = g
            if g in color_map:
                bar_kwargs["color"] = color_map[g]
        if error_bars:
            bar_kwargs["yerr"] = errors
            bar_kwargs["capsize"] = 3

        bars = ax.bar(x + offset, display_means, **bar_kwargs)

        if fade_map is not None:
            for bar, x_val in zip(bars, x_vals):
                if g in fade_map.get(x_val, []):
                    bar.set_alpha(0.3)

        if show_values:
            for j, (bar, mean_val) in enumerate(zip(bars, means)):
                if not np.isnan(mean_val):
                    err = errors[j] if error_bars and not np.isnan(errors[j]) else 0
                    bar_top = display_means[j] if not np.isnan(display_means[j]) else mean_val
                    # Place the label slightly above the bar (and error bar).
                    ax.text(
                        bar.get_x() + bar.get_width() / 2,
                        bar_top + err + 0.02 * (ax.get_ylim()[1] - ax.get_ylim()[0] or 1),
                        f"{mean_val:.2f}",
                        ha="center", va="bottom", fontsize=fontsize * 0.8,
                    )

    ax.set_xticks(x)
    label_map = {**STAGE_NAMES, **(x_labels or {})}
    ax.set_xticklabels([label_map.get(v, v) for v in x_vals], fontsize=fontsize)

    ax.set_xlabel(x_axis_label or FIELD_NAMES.get(x_col, x_col), fontsize=fontsize, fontweight="bold")
    ax.set_ylabel(y_axis_label or FIELD_NAMES.get(y_col, y_col), fontsize=fontsize, fontweight="bold")
    ax.tick_params(axis="y", labelsize=fontsize)

    if title:
        ax.set_title(title, fontsize=fontsize)

    if y_min is not None or y_max is not None:
        ylim = ax.get_ylim()
        ax.set_ylim(y_min if y_min is not None else ylim[0],
                    y_max if y_max is not None else ylim[1])

    if group_col is not None:
        legend = ax.legend(fontsize=fontsize, bbox_to_anchor=(1.01, 1), loc="upper left")
        # Legend entries stay fully opaque even when some bars are faded.
        for handle in legend.legend_handles:
            handle.set_alpha(1.0)

    plt.tight_layout()

    return ax, agg


def grouped_bar_chart_by_retain(
    df: pd.DataFrame,
    stage_name: str,
    y_col: str,
    aux_labels: list[str],
    retain_values: list[str | None],
    seed_col: str = "seed",
    ci_level: float = 0.9,
    category_order: list[str] | None = None,
    retain_labels: dict[str | None, str] | None = None,
    title: str | None = None,
    x_axis_label: str | None = None,
    y_axis_label: str | None = None,
    figsize: tuple[float, float] | None = None,
    fontsize: int = 10,
    y_min: float | None = None,
    y_max: float | None = None,
    error_bars: bool = True,
    show_values: bool = True,
    colors: dict[str, str] | None = None,
    min_bar_height: float | None = None,
    stage_col: str = "name",
) -> tuple[plt.Axes, pd.DataFrame]:
    """Create a grouped bar chart for one model across multiple target configurations.

    This function plots a single model/stage with different target label configurations.
    Each x-axis position represents a target configuration (retain setting), and bars
    represent different data categories showing performance per category.

    Categories that are in the "forget" set (target labels) for each configuration will
    be automatically faded (alpha=0.3) to visually indicate they are being forgotten.

    Args:
        df: DataFrame with data to plot.
        stage_name: The specific stage/model to plot (e.g., "routed_moe", "filtering").
        y_col: Column for y-axis values (e.g., "compute_ratio").
        aux_labels: List of all auxiliary labels in the dataset.
        retain_values: List of retain values to compare (can include None,
            meaning that every auxiliary label is a target).
        seed_col: Column containing seed identifiers.
        ci_level: Confidence interval level (0-1).
        category_order: Order of categories in legend.
        retain_labels: Dict mapping retain values to display labels for x-axis.
        title: Plot title. A default title is generated when omitted.
        x_axis_label, y_axis_label: Axis labels.
        figsize: Figure size.
        fontsize: Base font size.
        y_min, y_max: Y-axis limits.
        error_bars: If True, show confidence interval error bars.
        show_values: If True, show numeric values above bars.
        colors: Dict mapping category names (``data_label`` values) to colors.
            Categories without an entry use matplotlib's default colors.
        min_bar_height: Minimum display height for bars.
        stage_col: Column name for stage/model names (default: "name").

    Returns:
        Tuple of (matplotlib Axes, aggregated DataFrame).

    Raises:
        ValueError: If no data is found for the stage and retain configurations.

    Example:
        ```python
        # Compare categories across different retain configurations for routed_moe
        grouped_bar_chart_by_retain(
            df,
            stage_name="routed_moe",
            y_col="compute_ratio",
            aux_labels=["biology", "chemistry", "physics"],
            retain_values=["biology", "chemistry", None],
            retain_labels={"biology": "Retain Biology",
                          "chemistry": "Retain Chemistry",
                          None: "No Retain"},
        )
        ```
    """
    # groupby drops rows whose key is None/NaN, which would silently remove
    # the "retain nothing" configuration. It is therefore keyed by a sentinel
    # string while aggregating and mapped back to None in the returned frame.
    none_key = "__retain_none__"

    all_subsets = []
    fade_map = {}  # Categories to fade (the forget set) per retain config

    for retain_val in retain_values:
        # Target labels are all aux labels except the retained one.
        if retain_val is None:
            print(f"retain_val: {retain_val}, aux_labels: {aux_labels}")
            target_labels = sorted(aux_labels)
        else:
            target_labels = sorted([x for x in aux_labels if x != retain_val])
            print(f"retain_val: {retain_val}, target_labels: {target_labels}")

        fade_map[retain_val] = target_labels

        subset = get_subset(
            df,
            target=target_labels,
            stage_col=stage_col,
        )

        # Keep only the requested stage; copy so the new column is written to
        # an independent frame rather than a filtered slice.
        subset = subset[subset[stage_col] == stage_name].copy()
        subset["retain_config"] = none_key if retain_val is None else retain_val

        all_subsets.append(subset)

    combined_df = pd.concat(all_subsets, ignore_index=True)

    if combined_df.empty:
        raise ValueError(f"No data found for stage '{stage_name}' with the specified retain configurations.")

    group_cols = ["retain_config", "data_label"]
    agg = aggregate_by_seed(combined_df, group_cols, y_col, seed_col, ci_level)

    # X-axis: retain configurations (and their aggregation keys).
    x_vals = list(retain_values)
    x_keys = [none_key if v is None else v for v in x_vals]

    # Groups (bars): categories
    if category_order is not None:
        groups = [c for c in category_order if c in agg["data_label"].values]
    else:
        # Default: core first, then alphabetical
        all_labels = list(agg["data_label"].unique())
        if "core" in all_labels:
            groups = ["core"] + sorted([l for l in all_labels if l != "core"])
        else:
            groups = sorted(all_labels)

    # Pivot to (retain config) x (category) tables aligned with plot order.
    pivot_mean = agg.pivot(index="retain_config", columns="data_label", values="mean")
    pivot_ci = agg.pivot(index="retain_config", columns="data_label", values="ci").fillna(0.0)
    pivot_count = agg.pivot(index="retain_config", columns="data_label", values="count")
    pivot_mean = pivot_mean.reindex(index=x_keys, columns=groups)
    pivot_ci = pivot_ci.reindex(index=x_keys, columns=groups).fillna(0.0)
    pivot_count = pivot_count.reindex(index=x_keys, columns=groups)

    # Report the original retain values (including None) to the caller.
    agg["retain_config"] = [None if k == none_key else k for k in agg["retain_config"]]

    if figsize is None:
        figsize = (max(6, len(x_vals) * 2.0), 5)
    _, ax = plt.subplots(figsize=figsize)

    n_groups = len(groups)
    bar_spacing = 0.02
    total_width = 0.85 - (n_groups - 1) * bar_spacing
    bar_width = total_width / n_groups
    x = np.arange(len(x_vals))

    # Default x-axis labels; caller-supplied labels take precedence.
    default_retain_labels = {
        None: "Forget All",
        **{label: f"Retain {label.title()}" for label in aux_labels}
    }
    x_label_map = {**default_retain_labels, **(retain_labels or {})}

    for i, category in enumerate(groups):
        offset = -total_width / 2 + bar_width / 2 + i * (bar_width + bar_spacing)

        means = pivot_mean[category].values
        errors = pivot_ci[category].values if error_bars else None
        counts = pivot_count[category].values if error_bars else None

        if error_bars and errors is not None and counts is not None:
            # No interval can be drawn from a single seed (or missing data).
            errors = np.where(counts > 1, errors, np.nan)

        display_means = means.copy()
        if min_bar_height is not None:
            display_means = np.where(
                ~np.isnan(means), np.maximum(means, min_bar_height), means
            )

        bar_kwargs = {
            "width": bar_width,
            "label": category.replace("-", " ").title(),
        }

        # Honor caller-supplied colors; otherwise matplotlib picks the color.
        if colors and category in colors:
            bar_kwargs["color"] = colors[category]

        if error_bars and errors is not None:
            bar_kwargs["yerr"] = errors
            bar_kwargs["capsize"] = 3

        bars = ax.bar(x + offset, display_means, **bar_kwargs)

        # Fade the categories that are being forgotten in each configuration.
        for bar, retain_val in zip(bars, x_vals):
            if category in fade_map.get(retain_val, []):
                bar.set_alpha(0.3)

        if show_values:
            for j, (bar, mean_val) in enumerate(zip(bars, means)):
                if not np.isnan(mean_val):
                    err = errors[j] if error_bars and errors is not None and not np.isnan(errors[j]) else 0
                    bar_top = display_means[j] if not np.isnan(display_means[j]) else mean_val
                    # Place the label slightly above the bar (and error bar).
                    ax.text(
                        bar.get_x() + bar.get_width() / 2,
                        bar_top + err + 0.02 * (ax.get_ylim()[1] - ax.get_ylim()[0] or 1),
                        f"{mean_val:.2f}",
                        ha="center", va="bottom", fontsize=fontsize * 0.8,
                    )

    ax.set_xticks(x)
    ax.set_xticklabels([x_label_map.get(v, str(v)) for v in x_vals], fontsize=fontsize)

    ax.set_xlabel(x_axis_label or "Target Configuration", fontsize=fontsize, fontweight="bold")
    ax.set_ylabel(y_axis_label or FIELD_NAMES.get(y_col, y_col), fontsize=fontsize, fontweight="bold")
    ax.tick_params(axis="y", labelsize=fontsize)

    if title:
        ax.set_title(title, fontsize=fontsize)
    else:
        stage_display = STAGE_NAMES.get(stage_name, stage_name)
        ax.set_title(f"{stage_display} - {FIELD_NAMES.get(y_col, y_col)} by Target Config",
                    fontsize=fontsize)

    if y_min is not None or y_max is not None:
        ylim = ax.get_ylim()
        ax.set_ylim(y_min if y_min is not None else ylim[0],
                    y_max if y_max is not None else ylim[1])

    legend = ax.legend(fontsize=fontsize, bbox_to_anchor=(1.01, 1), loc="upper left")
    # Legend entries stay fully opaque even when some bars are faded.
    for handle in legend.legend_handles:
        handle.set_alpha(1.0)

    plt.tight_layout()

    return ax, agg


def grouped_line_chart(
    df: pd.DataFrame,
    x_col: str,
    y_col: str,
    group_col: str | None = None,
    seed_col: str = "seed",
    ci_level: float = 0.9,
    group_order: list[str] | None = None,
    group_labels: dict[str, str] | None = None,
    title: str | None = None,
    x_axis_label: str | None = None,
    y_axis_label: str | None = None,
    figsize: tuple[float, float] = (6.0, 4.0),
    fontsize: int = 10,
    x_log: bool = True,
    y_log: bool = False,
    y_min: float | None = None,
    y_max: float | None = None,
    error_bars: bool = True,
    styles: dict[str, dict] | None = None,
    sci_notation_x: bool = True,
) -> tuple[plt.Axes, pd.DataFrame]:
    """Create a line chart showing how a metric scales with x.

    ``y_col`` is aggregated with ``aggregate_by_seed`` over ``x_col`` (and
    ``group_col`` when given); one line is drawn per group.

    Args:
        df: DataFrame with data to plot.
        x_col: Column for x-axis values (e.g., "model_size").
        y_col: Column for y-axis values (e.g., "loss", "compute_ratio").
        group_col: Column to group lines by (e.g., "label_class"). If None, single line.
        seed_col: Column containing seed identifiers.
        ci_level: Confidence interval level (0-1).
        group_order: Order of groups in legend; other groups are dropped. When
            omitted, the label classes come first in a fixed order, followed
            by any other groups in order of first appearance.
        group_labels: Dict mapping group values to display labels in legend.
        title: Plot title.
        x_axis_label, y_axis_label: Axis labels.
        figsize: Figure size.
        fontsize: Base font size.
        x_log: If True, use log scale for x-axis.
        y_log: If True, use log scale for y-axis.
        y_min, y_max: Y-axis limits.
        error_bars: If True, show confidence interval error bars.
        styles: Dict mapping group values to style dicts (e.g., {"color": "#fff", "marker": "o"}).
            Entries override the defaults derived from ``LABEL_CLASS_COLORS``.
        sci_notation_x: If True, format x-axis ticks as scientific notation.

    Returns:
        Tuple of (matplotlib Axes, aggregated DataFrame sorted by ``x_col``).
    """
    data = df.copy()

    # Filter to requested groups
    if group_col is not None and group_order is not None:
        data = data[data[group_col].isin(group_order)]

    group_cols = [x_col] if group_col is None else [x_col, group_col]
    agg = aggregate_by_seed(data, group_cols, y_col, seed_col, ci_level)
    agg = agg.sort_values(x_col)

    # Determine group order
    if group_col is None:
        groups = [None]
    elif group_order is not None:
        groups = [g for g in group_order if g in agg[group_col].values]
    else:
        legend_order = ["Core", "Retain", "Forget", "Elicited Forget"]
        # unique() keeps first-appearance order, so the legend order is
        # reproducible (iterating a set would not be).
        present = list(agg[group_col].unique())
        groups = [g for g in legend_order if g in present]
        groups += [g for g in present if g not in legend_order]

    # Default styles from LABEL_CLASS_COLORS; caller styles take precedence.
    default_styles = {k: {"color": v} for k, v in LABEL_CLASS_COLORS.items()}
    style_map = {**default_styles, **(styles or {})}
    label_map = group_labels or {}

    _, ax = plt.subplots(figsize=figsize, dpi=150)

    for g in groups:
        sub = agg[agg[group_col] == g] if group_col is not None else agg

        xs = sub[x_col].values
        ys = sub["mean"].values
        counts = sub["count"].values
        # No interval can be drawn from a single seed.
        yerr = np.where(counts > 1, sub["ci"].values, np.nan)

        base_kwargs = {"marker": "o", "markersize": 4, "linestyle": "-"}
        if g is not None:
            base_kwargs["label"] = label_map.get(g, g)
            if g in style_map:
                base_kwargs.update(style_map[g])

        if error_bars:
            # Line/marker settings are applied last so that styles win.
            error_kwargs = {"capsize": 4, "elinewidth": 1.2}
            error_kwargs.update(base_kwargs)
            ax.errorbar(xs, ys, yerr=yerr, **error_kwargs)
        else:
            ax.plot(xs, ys, **base_kwargs)

        print(f"group_val: {g}, means: {ys}, errors: {yerr}")

    if x_log:
        ax.set_xscale("log")
    if y_log:
        ax.set_yscale("log")

    ax.grid(True, which="major", linestyle="--", alpha=0.4)

    ax.set_xlabel(x_axis_label or x_col, fontsize=fontsize, fontweight="bold")
    ax.set_ylabel(y_axis_label or FIELD_NAMES.get(y_col, y_col), fontsize=fontsize, fontweight="bold")

    if title:
        ax.set_title(title, fontsize=fontsize)

    ax.tick_params(axis="both", labelsize=fontsize * 0.9)

    # Put one tick on every x value present in the (filtered) data.
    x_vals = np.sort(data[x_col].unique())
    ax.set_xticks(x_vals)
    if sci_notation_x:
        ax.set_xticklabels([f"{v:.1e}".replace("e+0", "E").replace("e+", "E") for v in x_vals])

    if y_min is not None or y_max is not None:
        ylim = ax.get_ylim()
        ax.set_ylim(y_min if y_min is not None else ylim[0],
                    y_max if y_max is not None else ylim[1])

    if group_col is not None:
        ax.legend(fontsize=fontsize * 0.9, bbox_to_anchor=(1.01, 1), loc="upper left")

    plt.tight_layout()

    return ax, agg