import matplotlib.pyplot as plt
import torch
import numpy as np
from typing import List, Optional
import matplotlib.gridspec as gridspec
from matplotlib.lines import Line2D


def _save_or_show_plot(parent_plotter, fig, filepath, show):
    """Save ``fig`` and/or display it, then close it.

    Args:
        parent_plotter: Object exposing ``output_dir``. The save location is
            ``parent_plotter.output_dir / filepath``, so ``output_dir`` must
            support the ``/`` operator (presumably a ``pathlib.Path`` --
            confirm against the caller).
        fig: Matplotlib figure to save, show and close.
        filepath: Location relative to ``output_dir``; ``None`` skips saving.
        show: When true, ``plt.show()`` is called.

    Saving is best-effort: any error raised while building the path or
    writing the file is reported on stdout and not re-raised.
    """
    try:
        # Write the file before displaying: the saved output then does not
        # depend on the (possibly blocking) interactive window completing.
        if filepath is not None:
            try:
                full_save_path = parent_plotter.output_dir / filepath
                print(f"Saving plot to {full_save_path}")
                fig.savefig(full_save_path, dpi=300, bbox_inches="tight")
            except Exception as e:
                print(f"Error saving plot: {e}")
        if show:
            plt.show()
    finally:
        # Always release the figure, even if showing it failed or was
        # interrupted, so repeated calls do not accumulate open figures.
        plt.close(fig)


def _plot_single_predicate(
    ax, pred, X_orig, feat_idx, weight_threshold, predicate_bar_height
):
    """
    Helper method to plot a single predicate bar.

    A continuous predicate is drawn as a thin gray bar spanning the data range
    with a thicker colored bar marking the rule interval. A discrete predicate
    is drawn as one cell per category found in ``X_orig[:, feat_idx]``; each
    cell is shaded by how often that category occurs in ``pred["values"]``,
    with the most frequent category fully opaque.

    Args:
        ax: Matplotlib axes to draw on. Its frame, grid and y-ticks are
            removed; it is hidden entirely when ``pred`` is falsy.
        pred (dict or None): Predicate description. Keys read here:
            ``"weight"``, ``"data_min"``, ``"data_max"``, ``"is_discrete"``,
            plus ``"min"``/``"max"`` (continuous) or ``"values"`` (discrete).
        X_orig: Feature data, only read for discrete predicates via
            ``X_orig[:, feat_idx].cpu().numpy()``.
        feat_idx (int): Column of ``X_orig`` the predicate refers to.
        weight_threshold (float): Predicates whose weight exceeds this value
            are drawn in the highlight color, the others in light gray.
        predicate_bar_height (float): The thickness of the main predicate bar.
    """
    ax.set_frame_on(False)
    ax.grid(False)
    ax.set_yticks([])
    ax.set_ylim(-0.5, 0.5)

    if not pred:
        ax.set_visible(False)
        return

    is_active = pred["weight"] > weight_threshold
    color = "cornflowerblue" if is_active else "#bdc3c7"
    data_min, data_max = pred["data_min"], pred["data_max"]

    if not pred["is_discrete"]:
        data_span = data_max - data_min
        has_span = data_max > data_min
        padding = data_span * 0.01 if has_span else 0.1
        ax.set_xlim(data_min - padding, data_max + padding)

        # Thin background bar: the full observed data range.
        ax.barh(
            0,
            data_span,
            left=data_min,
            height=0.1,
            color="#dddddd",
            zorder=1,
        )

        rule_min, rule_max = pred["min"], pred["max"]
        # Keep very narrow rules visible: at least 3% of the data range.
        min_visible_width = data_span * 0.03 if has_span else 0.03
        bar_width = max(rule_max - rule_min, min_visible_width)
        # Shift the bar left if it would stick out past the data maximum.
        if data_max - rule_min < bar_width:
            rule_min = data_max - bar_width
        ax.barh(
            0,
            bar_width,
            left=rule_min,
            height=predicate_bar_height,
            color=color,
            edgecolor="black",
            linewidth=0.5,
            zorder=2,
        )
        return

    unique_vals = np.unique(X_orig[:, feat_idx].cpu().numpy())
    n_cats = len(unique_vals)
    if n_cats <= 1:
        ax.set_xticks([])
        label = f"Value: {unique_vals[0]}" if n_cats == 1 else "No variation"
        ax.text(0.5, 0, label, ha="center", va="center")
        return

    # Coerce to an array so the comparison below is element-wise even when the
    # rule values arrive as a plain list (list == scalar is just ``False``).
    rule_values = np.asarray(pred["values"])
    if rule_values.size > 0:
        proportions = [float(np.mean(rule_values == val)) for val in unique_vals]
    else:
        proportions = [0.0] * n_cats
    max_prop = max(proportions)

    cell_width = 1 / n_cats
    for cat_idx, proportion in enumerate(proportions):
        left = cat_idx / n_cats
        # White backing so the alpha-blended fill is independent of whatever
        # lies behind the axes.
        ax.barh(
            0,
            cell_width,
            left=left,
            height=predicate_bar_height,
            color="#ffffff",
            edgecolor="none",
            zorder=1,
        )
        if proportion > 0:
            ax.barh(
                0,
                cell_width,
                left=left,
                height=predicate_bar_height,
                color=color,
                # Dividing by the maximum keeps alpha within [0, 1] exactly.
                alpha=proportion / max_prop,
                zorder=2,
                edgecolor="none",
            )
        # Outline drawn last so cell borders stay crisp above the fill.
        ax.barh(
            0,
            cell_width,
            left=left,
            height=predicate_bar_height,
            facecolor="none",
            edgecolor="black",
            linewidth=0.5,
            zorder=5,
        )
    ax.set_xlim(-0.01, 1.01)
    ax.set_xticks((np.arange(n_cats) + 0.5) / n_cats)
    ax.set_xticklabels([f"{int(v)}" for v in unique_vals])


def _component_density(all_densities, comp_idx, weights=None):
    """Return column ``comp_idx`` of ``all_densities``, or ``None`` if absent.

    When ``weights`` is given and has an entry for ``comp_idx``, the column is
    multiplied by that weight.
    """
    if comp_idx >= all_densities.shape[1]:
        return None
    density = all_densities[:, comp_idx]
    if weights is not None and comp_idx < len(weights):
        density = density * weights[comp_idx]
    return density


def _plot_consolidated_target_dist(
    parent_plotter,
    ax,
    model,
    Y_orig,
    active_rules,
    all_rules,
    show_prop,
    show_dens,
    scatter_s,
    subsample,
    rules_to_plot: Optional[List[int]] = None,
    show_other_rules_as_gray: bool = False,
    show_population_histogram: bool = False,
    scale_densities_by_weight: bool = False,
):
    """Helper method to plot the main target distribution panel for the combined plot.

    Handles 1-D targets (histograms and/or density curves) and 2-D targets
    (scatter and/or density contours); other dimensionalities only get a grid.

    Args:
        parent_plotter: Provides ``_get_colors(n)`` and ``name_mapping``.
        ax: Matplotlib axes to draw on.
        model: Provides ``get_expert_densities``; ``X_original`` and
            ``get_responsibilities`` are read when
            ``scale_densities_by_weight`` is set.
        Y_orig: Target tensor (``.cpu().numpy()`` is called on it).
        active_rules: Rules drawn in color; each needs ``"component"`` and
            ``"target_values"``.
        all_rules: Every rule; its order fixes the color of each component.
        show_prop: Draw the population histogram/scatter with each rule's
            share overlaid (1-D: skipped if ``show_population_histogram``).
        show_dens: Draw per-component densities from the model.
        scatter_s: Marker size of background points (rule points are 0.5
            larger).
        subsample: Maximum number of points per scatter call.
        rules_to_plot: Component indices selected by the caller; used only to
            decide which rules count as "other".
        show_other_rules_as_gray: Also draw densities of non-selected rules in
            gray (requires ``rules_to_plot``).
        show_population_histogram: 1-D only; plain population histogram
            instead of the proportional overlay.
        scale_densities_by_weight: 1-D only; scale each density curve by the
            mean responsibility of its component.
    """
    # One palette lookup for all rules instead of one per rule.
    palette = parent_plotter._get_colors(len(all_rules)) if all_rules else []
    rule_colors = {rule["component"]: palette[i] for i, rule in enumerate(all_rules)}
    y_dim = Y_orig.shape[1] if Y_orig.ndim > 1 else 1
    y_name_config = parent_plotter.name_mapping.get("y_variable", "Y")
    gray_color = "#cccccc"
    draw_others = show_other_rules_as_gray and rules_to_plot is not None

    if y_dim == 1:
        Y_cpu = Y_orig.cpu()
        Y_np = Y_cpu.numpy()
        y_min, y_max = Y_cpu.min(), Y_cpu.max()
        y_label = "Density"  # Default y-axis label

        y_unique = np.unique(Y_np)
        is_discrete_integer = len(y_unique) <= 100 and np.allclose(
            y_unique, np.round(y_unique)
        )

        if is_discrete_integer:
            # Bins centered on each integer so every value gets its own bar.
            y_bins = np.arange(y_min - 0.5, y_max + 1.5)
        else:
            y_bins = np.linspace(y_min, y_max, 50)

        if show_population_histogram:
            ax.hist(
                Y_np.flatten(),
                bins=y_bins,
                density=True,
                color="#ecf0f1",
                zorder=1,
            )
        elif show_prop:
            bg_density, bins = np.histogram(Y_np, bins=y_bins, density=True)
            bg_counts, _ = np.histogram(Y_np, bins=y_bins, density=False)
            bin_widths = np.diff(bins)
            ax.bar(
                bins[:-1],
                bg_density,
                width=bin_widths,
                color="#ecf0f1",
                align="edge",
            )
            for rule in active_rules:
                if len(rule["target_values"]) == 0:
                    continue
                rule_counts, _ = np.histogram(
                    rule["target_values"], bins=bins, density=False
                )
                # Share of each bin's population that belongs to this rule.
                props = np.divide(
                    rule_counts,
                    bg_counts,
                    out=np.zeros_like(rule_counts, dtype=float),
                    where=bg_counts != 0,
                )
                ax.bar(
                    bins[:-1],
                    props * bg_density,
                    width=bin_widths,
                    color=rule_colors[rule["component"]],
                    alpha=0.8,
                    align="edge",
                )

        if show_dens:
            if is_discrete_integer and y_max > y_min:
                y_range = np.sort(y_unique).reshape(-1, 1)
            else:
                y_range = np.linspace(y_min, y_max, 200).reshape(-1, 1)

            weights = None
            if scale_densities_by_weight:
                if model.X_original is not None:
                    responsibilities = model.get_responsibilities(model.X_original)
                    weights = np.mean(responsibilities, axis=0)
                else:
                    print(
                        "Warning: Cannot scale densities by weight, model.X_original is not available."
                    )
            try:
                all_densities = model.get_expert_densities(y_range)

                if is_discrete_integer:
                    density_sums = np.sum(all_densities, axis=0)
                    # Avoid division by zero for components with no mass
                    density_sums[density_sums == 0] = 1.0
                    all_densities = all_densities / density_sums
                    y_label = "Probability Mass"

                grid_points = y_range.flatten()
                for rule in active_rules:
                    comp_idx = rule["component"]
                    curve = _component_density(all_densities, comp_idx, weights)
                    if curve is not None:
                        ax.plot(
                            grid_points,
                            curve,
                            color=rule_colors[comp_idx],
                            linewidth=2.5,
                            zorder=2,
                        )

                if draw_others:
                    for rule in all_rules:
                        comp_idx = rule["component"]
                        if comp_idx in rules_to_plot:
                            continue
                        curve = _component_density(all_densities, comp_idx, weights)
                        if curve is not None:
                            ax.plot(
                                grid_points,
                                curve,
                                color=gray_color,
                                linewidth=1.5,
                                linestyle="-",
                                zorder=-1,
                            )

            except Exception as e:
                # Densities are optional decoration; report and carry on.
                print(f"Could not plot densities: {e}")

        ax.set_ylabel(y_label)
        xlabel = y_name_config if isinstance(y_name_config, str) else "Y"
        ax.set_xlabel(xlabel)
        ax.set_ylim(bottom=0)

    elif y_dim == 2:
        Y_np = Y_orig.cpu().numpy()
        if show_prop:
            # Subsample background points for clarity
            Y_bg_plot = Y_np
            if len(Y_np) > subsample:
                Y_bg_plot = Y_np[np.random.choice(len(Y_np), subsample, replace=False)]
            ax.scatter(
                Y_bg_plot[:, 0],
                Y_bg_plot[:, 1],
                color="#ecf0f1",
                s=scatter_s,
                rasterized=True,
                alpha=0.6,
            )
            for rule in active_rules:
                target_values = rule["target_values"]
                if len(target_values) > subsample:
                    idx = np.random.choice(len(target_values), subsample, replace=False)
                    target_values = target_values[idx]

                if len(target_values) > 0:
                    ax.scatter(
                        target_values[:, 0],
                        target_values[:, 1],
                        color=rule_colors[rule["component"]],
                        # The string "none" disables marker edges; Python
                        # ``None`` would only select the default edge style.
                        edgecolors="none",
                        s=scatter_s + 0.5,
                        alpha=0.65,
                        rasterized=True,
                    )

        if show_dens:
            y1r = np.linspace(Y_np[:, 0].min(), Y_np[:, 0].max(), 75)
            y2r = np.linspace(Y_np[:, 1].min(), Y_np[:, 1].max(), 75)
            Y1, Y2 = np.meshgrid(y1r, y2r)
            y_grid = np.c_[Y1.ravel(), Y2.ravel()]
            try:
                all_densities = model.get_expert_densities(y_grid)
                for rule in active_rules:
                    comp_idx = rule["component"]
                    density = _component_density(all_densities, comp_idx)
                    if density is not None:
                        ax.contour(
                            Y1,
                            Y2,
                            density.reshape(Y1.shape),
                            levels=4,
                            colors=[rule_colors[comp_idx]],
                            linewidths=1.0,
                            alpha=0.8,
                        )

                if draw_others:
                    for rule in all_rules:
                        comp_idx = rule["component"]
                        if comp_idx in rules_to_plot:
                            continue
                        density = _component_density(all_densities, comp_idx)
                        if density is not None:
                            ax.contour(
                                Y1,
                                Y2,
                                density.reshape(Y1.shape),
                                levels=3,
                                colors=[gray_color],
                                linewidths=1.5,
                                linestyles="--",
                            )

            except Exception as e:
                print(f"Could not plot 2D densities: {e}")

        if isinstance(y_name_config, (list, tuple)) and len(y_name_config) == 2:
            xlabel, ylabel = y_name_config
        else:
            base_name = y_name_config if isinstance(y_name_config, str) else "Y"
            xlabel = f"{base_name} Dim 1"
            ylabel = f"{base_name} Dim 2"

        ax.set_xlabel(xlabel)
        ax.set_ylabel(ylabel)
        ax.set_aspect("equal", "box")

    ax.grid(True, linestyle="--", alpha=0.5)


def plot_standard(
    parent_plotter,
    structured_rules,
    model,
    X_orig,
    Y_orig,
    filepath,
    weight_threshold=0.1,
    preview_subsample=3000,
    show=True,
    show_densities=False,
    row_height=1.0,
    density_width=1.2,
    scatter_s=1.0,
    **kwargs,
):
    """Generates a plot with a separate target distribution for each rule.

    The figure has one row per rule that has predicates. The first column
    shows the rule's share of the target distribution (1-D histogram or 2-D
    scatter); the remaining columns show one predicate bar per feature.

    Args:
        parent_plotter: Provides ``_get_colors(n)``, ``name_mapping`` and
            ``output_dir`` (the latter via ``_save_or_show_plot``).
        structured_rules: Rule dicts with ``"component"``, ``"predicates"``
            and ``"target_values"``; rules with empty predicates are skipped.
        model: Unused here; kept for a signature parallel to the other plots.
        X_orig: Feature tensor. Column ``j`` is assumed to correspond to the
            ``j``-th predicate key of the first rule.
        Y_orig: Target tensor (1-D or 2-D targets are drawn).
        filepath: Save location relative to the plotter's output directory,
            or ``None`` to skip saving.
        weight_threshold: Passed to ``_plot_single_predicate``.
        preview_subsample: Maximum number of points per 2-D scatter call.
        show: Whether to display the figure.
        show_densities: Not implemented for this plot; a warning is printed
            for 1-D targets.
        row_height: Height in inches per rule row (figure is at least 4 in).
        density_width: Width ratio of the target column relative to a
            predicate column.
        scatter_s: Marker size of background points in 2-D scatters.
        **kwargs: Ignored.
    """
    active_rules = [r for r in structured_rules if r["predicates"]]
    if not active_rules:
        print("Cannot generate plot: No active rules with defined predicates.")
        return

    all_features = list(active_rules[0]["predicates"].keys())
    n_components = len(active_rules)
    n_features = len(all_features)
    n_plot_cols = n_features + 1

    # squeeze=False always yields a 2-D axes array, whatever the grid shape.
    fig, axes = plt.subplots(
        n_components,
        n_plot_cols,
        figsize=(max(12, n_plot_cols * 2.2), max(4, n_components * row_height)),
        squeeze=False,
        gridspec_kw={
            "hspace": 0.1,
            "wspace": 0.1,
            "width_ratios": [density_width] + [1] * n_features,
        },
    )

    palette = parent_plotter._get_colors(n_components)
    rule_colors = {
        rule["component"]: palette[i % len(palette)]
        for i, rule in enumerate(active_rules)
    }

    y_name_config = parent_plotter.name_mapping.get("y_variable", "Y")
    y_dim = Y_orig.shape[1] if Y_orig.ndim > 1 else 1
    Y_np = Y_orig.cpu().numpy()

    if y_dim == 1:
        # The population histogram is identical for every row: compute once.
        y_unique = np.unique(Y_np)
        y_min, y_max = Y_np.min(), Y_np.max()
        is_discrete_integer = len(y_unique) <= 100 and np.allclose(
            y_unique, np.round(y_unique)
        )
        if is_discrete_integer:
            # Bins centered on each integer. Using the unique-value count as
            # the number of *edges* would merge the two largest values into
            # one bin (and collapse a binary target into a single bin).
            y_bins = np.arange(y_min - 0.5, y_max + 1.5)
        else:
            y_bins = np.linspace(y_min, y_max, min(50, len(y_unique)))
        background_density, bins = np.histogram(Y_np, bins=y_bins, density=True)
        background_counts, _ = np.histogram(Y_np, bins=y_bins, density=False)
        bin_widths = np.diff(bins)
        if show_densities:
            # Reported once rather than once per rule.
            print("Warning: 1D density curves not implemented in standard plot.")

    for i, rule in enumerate(active_rules):
        component_idx = rule["component"]
        rule_color = rule_colors[component_idx]
        ax_target = axes[i, 0]
        target_values = rule["target_values"]
        is_last_row = i == n_components - 1

        if y_dim == 1:
            ax_target.set_frame_on(False)
            ax_target.grid(False)
            ax_target.tick_params(
                left=False, labelleft=False, bottom=False, labelbottom=False
            )
            ax_target.bar(
                bins[:-1],
                background_density,
                width=bin_widths,
                color="#ecf0f1",
                align="edge",
            )
            if len(target_values) > 0:
                rule_counts, _ = np.histogram(target_values, bins=bins, density=False)
                # Fraction of each bin's population that this rule covers.
                proportions = np.divide(
                    rule_counts,
                    background_counts,
                    out=np.zeros_like(rule_counts, dtype=float),
                    where=background_counts != 0,
                )
                ax_target.bar(
                    bins[:-1],
                    proportions * background_density,
                    width=bin_widths,
                    color=rule_color,
                    alpha=0.8,
                    align="edge",
                )

        elif y_dim == 2:
            ax_target.set_frame_on(True)
            ax_target.grid(True, linestyle="--", alpha=0.6)

            # Subsample background points
            Y_bg_plot = Y_np
            if len(Y_np) > preview_subsample:
                Y_bg_plot = Y_np[
                    np.random.choice(len(Y_np), preview_subsample, replace=False)
                ]
            ax_target.scatter(
                Y_bg_plot[:, 0],
                Y_bg_plot[:, 1],
                color="#ecf0f1",
                s=scatter_s,
                rasterized=True,
                zorder=1,
            )

            # Subsample and plot rule points
            if len(target_values) > 0:
                target_values_plot = target_values
                if len(target_values) > preview_subsample:
                    target_values_plot = target_values[
                        np.random.choice(
                            len(target_values), preview_subsample, replace=False
                        )
                    ]
                ax_target.scatter(
                    target_values_plot[:, 0],
                    target_values_plot[:, 1],
                    color=rule_color,
                    s=scatter_s + 0.5,
                    alpha=0.8,
                    edgecolors="none",
                    rasterized=True,
                    zorder=2,
                )

            ax_target.set_aspect("equal", "box")
            ax_target.tick_params(
                left=True,
                labelleft=True,
                bottom=True,
                labelbottom=True,
                labelsize="small",
            )

            # Only add axis labels to the bottom-most plot to reduce clutter
            if is_last_row:
                if isinstance(y_name_config, (list, tuple)) and len(y_name_config) == 2:
                    xlabel, _ = y_name_config
                else:
                    base_name = y_name_config if isinstance(y_name_config, str) else "Y"
                    xlabel = f"{base_name} Dim 1"
                ax_target.set_xlabel(xlabel, fontsize="small")
            else:
                ax_target.tick_params(labelbottom=False)

        # The y-label slot carries the 1-based rule number for the row.
        ax_target.set_ylabel(
            f"{component_idx + 1}",
            rotation=0,
            size="large",
            ha="right",
            va="center",
            labelpad=25,
        )
        if y_dim == 1:
            ax_target.set_yticks([])

        if i == 0:
            title = y_name_config if isinstance(y_name_config, str) else "Y"
            ax_target.set_title(title, pad=15)
        if not is_last_row and y_dim == 1:
            ax_target.tick_params(axis="x", labelbottom=False)

        for j, feat_name in enumerate(all_features):
            ax_feat = axes[i, j + 1]
            pred = rule["predicates"].get(feat_name)
            _plot_single_predicate(ax_feat, pred, X_orig, j, weight_threshold, 0.3)

            if i == 0:
                ax_feat.set_title(feat_name, pad=15)
            if not is_last_row:
                ax_feat.tick_params(
                    axis="x",
                    which="both",
                    bottom=False,
                    top=False,
                    labelbottom=False,
                )
            else:
                # Constant discrete features show a text label instead of
                # ticks, so there is nothing to rotate for them.
                if pred and not (
                    pred.get("is_discrete")
                    and len(np.unique(X_orig[:, j].cpu().numpy())) <= 1
                ):
                    plt.setp(
                        ax_feat.get_xticklabels(),
                        rotation=45,
                        ha="right",
                        rotation_mode="anchor",
                    )

    fig.tight_layout(rect=[0.05, 0.05, 0.98, 0.95])
    _save_or_show_plot(parent_plotter, fig, filepath, show)


def plot_combined(
    parent_plotter,
    structured_rules: List[dict],
    model,
    X_orig: torch.Tensor,
    Y_orig: torch.Tensor,
    filepath: str,
    rules_to_plot: Optional[List[int]] = None,
    show_proportional_dist: bool = True,
    show_density_dist: bool = False,
    weight_threshold: float = 0.1,
    show: bool = True,
    scatter_s: float = 1.0,
    preview_subsample: int = 3000,
    row_height: float = 0.8,
    target_plot_width: float = 3.0,
    label_col_width: float = 0.5,
    predicate_col_width: float = 1.0,
    col_spacing: float = 0.2,
    row_spacing: float = 0.2,
    predicate_bar_height: float = 0.3,
    show_other_rules_as_gray: bool = False,
    show_population_histogram: bool = False,
    scale_densities_by_weight: bool = False,
    **kwargs,
):
    """
    Generates a publication-quality plot with a consolidated target distribution
    and rule predicates identified by colored indicators.

    Layout: a single target-distribution panel on the left, and on the right
    one row per plotted rule made of a colored label cell followed by one
    predicate bar per feature.

    Args:
        parent_plotter: Provides ``_get_colors(n)``, ``name_mapping`` and
            ``output_dir``.
        structured_rules: Rule dicts with ``"component"``, ``"predicates"``
            and ``"target_values"``; rules with empty predicates are dropped.
        model: Forwarded to ``_plot_consolidated_target_dist``.
        X_orig: Feature tensor. Column ``j`` is assumed to correspond to the
            ``j``-th predicate key of the first rule.
        Y_orig: Target tensor.
        filepath: Save location relative to the plotter's output directory
            (``None`` is accepted by the save helper and skips saving).
        rules_to_plot: Component indices to draw; ``None`` draws all rules.
        show_proportional_dist, show_density_dist, scatter_s,
        preview_subsample, show_other_rules_as_gray,
        show_population_histogram, scale_densities_by_weight: Forwarded to
            ``_plot_consolidated_target_dist``.
        weight_threshold, predicate_bar_height: Forwarded to
            ``_plot_single_predicate``.
        show: Whether to display the figure.
        row_height, target_plot_width, label_col_width, predicate_col_width,
        col_spacing, row_spacing: Layout sizes in inches.
        **kwargs: Ignored.
    """
    all_structured_rules = [r for r in structured_rules if r["predicates"]]
    active_rules = [
        r
        for r in all_structured_rules
        if rules_to_plot is None or r["component"] in rules_to_plot
    ]

    if not active_rules:
        print("Cannot generate plot: No active rules to visualize.")
        return

    all_features = list(all_structured_rules[0]["predicates"].keys())
    n_rules = len(active_rules)
    n_features = len(all_features)

    # Calculate total figure size in inches for direct control
    fig_height = n_rules * row_height + (n_rules - 1) * row_spacing
    predicate_area_width = (
        label_col_width
        + (n_features * predicate_col_width)
        + (n_features * col_spacing)
    )
    fig_width = target_plot_width + predicate_area_width + col_spacing

    fig = plt.figure(figsize=(fig_width, fig_height))

    # Main GridSpec: 1 row, 2 columns (target plot, predicate area)
    gs_main = gridspec.GridSpec(
        1, 2, width_ratios=[target_plot_width, predicate_area_width]
    )

    # Target distribution plot in the left column
    ax_target_main = fig.add_subplot(gs_main[0])
    _plot_consolidated_target_dist(
        parent_plotter,
        ax_target_main,
        model,
        Y_orig,
        active_rules,
        all_structured_rules,
        show_proportional_dist,
        show_density_dist,
        scatter_s,
        preview_subsample,
        rules_to_plot=rules_to_plot,
        show_other_rules_as_gray=show_other_rules_as_gray,
        show_population_histogram=show_population_histogram,
        scale_densities_by_weight=scale_densities_by_weight,
    )

    # Nested GridSpec for the predicate plots in the right column
    gs_predicates = gridspec.GridSpecFromSubplotSpec(
        n_rules,
        n_features + 1,
        subplot_spec=gs_main[1],
        width_ratios=[label_col_width] + [predicate_col_width] * n_features,
        wspace=col_spacing / predicate_col_width,  # wspace is relative to subplot width
        hspace=row_spacing / row_height,  # hspace is relative to subplot height
    )

    # Colors are indexed over *all* rules so a rule keeps its color regardless
    # of which subset is plotted. The palette is fetched once, not per rule.
    palette = parent_plotter._get_colors(len(all_structured_rules))
    rule_colors = {
        r["component"]: palette[i] for i, r in enumerate(all_structured_rules)
    }
    last_row = n_rules - 1

    for i, rule in enumerate(active_rules):
        component = rule["component"]

        # --- Label column: color swatch plus 1-based rule number ---
        ax_label = fig.add_subplot(gs_predicates[i, 0])
        ax_label.set_xlim(0, 1)
        ax_label.set_ylim(0, 1)
        ax_label.axis("off")
        ax_label.plot(
            [0.1, 0.4],
            [0.5, 0.5],
            color=rule_colors[component],
            solid_capstyle="butt",
            linewidth=2.5,
            transform=ax_label.transAxes,
        )
        ax_label.text(
            0.7,
            0.5,
            f"{component + 1}",
            ha="center",
            va="center",
            fontsize="small",
            fontweight="bold",
            transform=ax_label.transAxes,
        )

        # --- Predicate columns ---
        for j, feat_name in enumerate(all_features):
            ax_feat = fig.add_subplot(gs_predicates[i, j + 1])
            pred = rule["predicates"].get(feat_name)
            _plot_single_predicate(
                ax_feat, pred, X_orig, j, weight_threshold, predicate_bar_height
            )

            if i == 0:
                ax_feat.set_title(feat_name, pad=15)

            if i < last_row:
                # Only the bottom row keeps its x tick labels.
                ax_feat.tick_params(
                    axis="x", which="both", bottom=False, top=False, labelbottom=False
                )
                continue

            # Constant discrete features show a text label instead of ticks,
            # so there is nothing to rotate for them.
            is_constant_discrete = bool(pred) and (
                pred.get("is_discrete")
                and len(np.unique(X_orig[:, j].cpu().numpy())) <= 1
            )
            if pred and not is_constant_discrete:
                plt.setp(
                    ax_feat.get_xticklabels(),
                    rotation=45,
                    ha="right",
                    rotation_mode="anchor",
                )

    gs_main.tight_layout(fig)
    _save_or_show_plot(parent_plotter, fig, filepath, show)


def plot_condensed(
    parent_plotter,
    structured_rules,
    model,
    Y_orig,
    filepath,
    rules_to_plot: Optional[List[int]] = None,
    use_density: bool = False,
    show: bool = True,
    scatter_s: float = 1.0,
    preview_subsample: int = 3000,
    **kwargs,
):
    """
    Generates a single, condensed plot visualizing the target distributions for multiple rules.

    Args:
        parent_plotter: Provides ``_get_colors(n)``, ``name_mapping`` and
            ``output_dir``.
        structured_rules: Rule dicts with ``"component"``, ``"predicates"``
            and ``"target_values"``; rules with empty predicates are dropped.
        model: Provides ``get_expert_densities`` (used when ``use_density``).
        Y_orig: Target tensor (1-D or 2-D targets are drawn).
        filepath: Save location relative to the plotter's output directory,
            or ``None`` to skip saving.
        rules_to_plot: Component indices to draw; ``None`` draws all rules.
        use_density: Draw model densities instead of the proportional
            histogram/scatter view.
        show: Whether to display the figure.
        scatter_s: Marker size of background points in 2-D scatters.
        preview_subsample: Maximum number of points per 2-D scatter call.
        **kwargs: Ignored.
    """
    all_structured_rules = [r for r in structured_rules if r["predicates"]]
    active_rules = [
        r
        for r in all_structured_rules
        if rules_to_plot is None or r["component"] in rules_to_plot
    ]

    if not active_rules:
        print("Cannot generate plot: No active rules selected to plot.")
        return

    fig, ax = plt.subplots(figsize=(8, 8))

    # Colors are indexed over all rules so a rule keeps its color regardless
    # of which subset is plotted. The palette is fetched once, not per rule.
    palette = parent_plotter._get_colors(len(all_structured_rules))
    rule_colors = {
        r["component"]: palette[i] for i, r in enumerate(all_structured_rules)
    }
    y_dim = Y_orig.shape[1] if Y_orig.ndim > 1 else 1
    y_name_config = parent_plotter.name_mapping.get("y_variable", "Y")

    # The two views are mutually exclusive.
    show_prop = not use_density
    show_dens = use_density
    legend_elements = []

    if y_dim == 1:
        Y_cpu = Y_orig.cpu()
        Y_np = Y_cpu.numpy()
        y_min, y_max = Y_cpu.min(), Y_cpu.max()
        y_bins = np.linspace(y_min, y_max, 50)

        if show_prop:
            bg_density, bins = np.histogram(Y_np, bins=y_bins, density=True)
            bg_counts, _ = np.histogram(Y_np, bins=y_bins, density=False)
            bin_widths = np.diff(bins)
            ax.bar(
                bins[:-1],
                bg_density,
                width=bin_widths,
                color="#ecf0f1",
                align="edge",
            )
            for rule in active_rules:
                if len(rule["target_values"]) == 0:
                    continue
                rule_counts, _ = np.histogram(
                    rule["target_values"], bins=bins, density=False
                )
                # Share of each bin's population that belongs to this rule.
                props = np.divide(
                    rule_counts,
                    bg_counts,
                    out=np.zeros_like(rule_counts, dtype=float),
                    where=bg_counts != 0,
                )
                ax.bar(
                    bins[:-1],
                    props * bg_density,
                    width=bin_widths,
                    color=rule_colors[rule["component"]],
                    alpha=0.8,
                    align="edge",
                    label=f"Rule {rule['component'] + 1}",
                )
        if show_dens:
            y_range = np.linspace(y_min, y_max, 200).reshape(-1, 1)
            try:
                all_densities = model.get_expert_densities(y_range)
                grid_points = y_range.flatten()
                for rule in active_rules:
                    comp_idx = rule["component"]
                    if comp_idx < all_densities.shape[1]:
                        ax.plot(
                            grid_points,
                            all_densities[:, comp_idx],
                            color=rule_colors[comp_idx],
                            linewidth=2.5,
                            label=f"Rule {rule['component'] + 1}",
                        )
            except Exception as e:
                # Densities are optional decoration; report and carry on.
                print(f"Could not plot densities: {e}")
        ax.set_ylabel("Density")
        xlabel = y_name_config if isinstance(y_name_config, str) else "Y"
        ax.set_xlabel(xlabel)
        ax.set_ylim(bottom=0)

    elif y_dim == 2:
        Y_np = Y_orig.cpu().numpy()
        if show_prop:
            # Legend swatch for the background population (left unlabeled).
            legend_elements.append(
                Line2D(
                    [0],
                    [0],
                    marker="o",
                    color="w",
                    markerfacecolor="#ecf0f1",
                    markersize=10,
                )
            )
            Y_bg_plot = Y_np
            if len(Y_np) > preview_subsample:
                Y_bg_plot = Y_np[
                    np.random.choice(len(Y_np), preview_subsample, replace=False)
                ]
            ax.scatter(
                Y_bg_plot[:, 0],
                Y_bg_plot[:, 1],
                color="#ecf0f1",
                s=scatter_s,
                rasterized=True,
            )
            for rule in active_rules:
                target_values = rule["target_values"]
                if len(target_values) > preview_subsample:
                    idx = np.random.choice(
                        len(target_values), preview_subsample, replace=False
                    )
                    target_values = target_values[idx]

                if len(target_values) > 0:
                    rule_color = rule_colors[rule["component"]]
                    ax.scatter(
                        target_values[:, 0],
                        target_values[:, 1],
                        color=rule_color,
                        s=scatter_s + 0.5,
                        alpha=0.6,
                        edgecolors="none",
                        rasterized=True,
                    )
                    # Proxy handle: scatter markers are too small to read in
                    # a legend, so a larger marker stands in for them.
                    legend_elements.append(
                        Line2D(
                            [0],
                            [0],
                            marker="o",
                            color="w",
                            label=f"Rule {rule['component'] + 1}",
                            markerfacecolor=rule_color,
                            markersize=10,
                            alpha=0.6,
                        )
                    )
        if show_dens:
            y1r = np.linspace(Y_np[:, 0].min(), Y_np[:, 0].max(), 75)
            y2r = np.linspace(Y_np[:, 1].min(), Y_np[:, 1].max(), 75)
            Y1, Y2 = np.meshgrid(y1r, y2r)
            y_grid = np.c_[Y1.ravel(), Y2.ravel()]
            try:
                all_densities = model.get_expert_densities(y_grid)
                for rule in active_rules:
                    comp_idx = rule["component"]
                    if comp_idx < all_densities.shape[1]:
                        Z = all_densities[:, comp_idx].reshape(Y1.shape)
                        ax.contour(
                            Y1,
                            Y2,
                            Z,
                            levels=3,
                            colors=[rule_colors[comp_idx]],
                            linewidths=2.0,
                        )
                        legend_elements.append(
                            Line2D(
                                [0],
                                [0],
                                color=rule_colors[comp_idx],
                                lw=2.0,
                                label=f"Rule {comp_idx + 1}",
                            )
                        )
            except Exception as e:
                print(f"Could not plot 2D densities: {e}")

        # Axis names: an explicit pair if configured, otherwise derived from
        # the single configured name (or "Y").
        if isinstance(y_name_config, (list, tuple)) and len(y_name_config) == 2:
            xlabel, ylabel = y_name_config
        else:
            base_name = y_name_config if isinstance(y_name_config, str) else "Y"
            xlabel = f"{base_name} Dim 1"
            ylabel = f"{base_name} Dim 2"
        ax.set_xlabel(xlabel)
        ax.set_ylabel(ylabel)
        ax.set_aspect("equal", "box")

    ax.grid(True, linestyle="--", alpha=0.5)

    # Only build a legend when there is something to list; otherwise
    # matplotlib warns about missing labels or draws an empty box.
    if y_dim == 2:
        if legend_elements:
            ax.legend(handles=legend_elements)
    else:
        labelled_handles, _ = ax.get_legend_handles_labels()
        if labelled_handles:
            ax.legend()

    fig.tight_layout()
    _save_or_show_plot(parent_plotter, fig, filepath, show)
