"""
Interpretability analysis for MPF on California Housing.

Outputs: backbone/tilt/nominal plots, feature gating, local explanations,
GA2M comparison, SHAP comparison, stability experiments.
"""

import matplotlib
import numpy as np
import pandas as pd

matplotlib.use("Agg")  # Use non-interactive backend to prevent display
import json
import warnings
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union

import matplotlib.colors as mcolors
import matplotlib.pyplot as plt

# Typeset mathtext ($...$) in the regular text font; this does not enable external LaTeX (text.usetex)
plt.rcParams["mathtext.default"] = "regular"  # Use regular math text

# For cartopy (used in Figure 4)
import cartopy.crs as ccrs
import cartopy.feature as cfeature

# MPF imports
from mpf_py import MPF, TreeGrid

# Seed NumPy's global RNG at import time so anything in this module that draws
# from np.random is reproducible across runs.
np.random.seed(42)

# Configuration (reproducibility folder: data/models/figures under BASE_DIR).
# BASE_DIR is the directory containing this file; every default path below is
# derived from it.
BASE_DIR = Path(__file__).parent
DATA_DIR = BASE_DIR / "data" / "california"
MODELS_DIR = BASE_DIR / "models" / "california"

# Model selection: Set to True for blackbox model, False for interpretable model
USE_BLACKBOX = False  # Change this to switch between models

# Pick the output suffix and the model file name that match the selected model
# type. MODEL_FILENAME is resolved relative to MODELS_DIR by load_model_and_data.
if USE_BLACKBOX:
    OUTPUT_SUFFIX = "blackbox"
    MODEL_FILENAME = "mpf_blackbox.bin"
else:
    OUTPUT_SUFFIX = "interpretable"
    MODEL_FILENAME = "mpf_interpretable.bin"

# Output directory (run_all_figures overrides FIGURES_DIR when generating)
FIGURES_DIR = BASE_DIR / "figures" / "california"
LOCAL_EXPLANATIONS_DIR = FIGURES_DIR / "local_explanations"

# Create output directories. NOTE: this runs as an import-time side effect, so
# merely importing the module creates these folders if they are missing.
for dir_path in [FIGURES_DIR, LOCAL_EXPLANATIONS_DIR]:
    dir_path.mkdir(parents=True, exist_ok=True)

# Feature names for California Housing. They are assigned positionally to the
# feature columns of the CSV in load_model_and_data and indexed by feature
# position in the plotting functions, so the order must match the data file.
FEATURE_NAMES = [
    "Longitude",
    "Latitude",
    "HouseAge",
    "TotalRooms",
    "TotalBedrooms",
    "Population",
    "Households",
    "MedInc",
]

# Color scheme for stages. Plotting code indexes this list modulo its length,
# so colors repeat when a model has more stages than entries here.
STAGE_COLORS = ["#1f77b4", "#ff7f0e", "#2ca02c", "#d62728", "#9467bd", "#8c564b"]


def _save_fig(fig: plt.Figure, save_path: Optional[Union[str, Path]], **kwargs) -> None:
    """
    Write ``fig`` to ``save_path``, deriving the output format from the extension.

    The lower-cased file extension (without the dot) is passed to
    ``fig.savefig`` as ``format`` so the bytes written always match the file
    name. A path without an extension is saved with whatever defaults
    ``savefig`` applies. Missing parent directories are created first.

    Parameters:
    -----------
    fig : matplotlib.figure.Figure
        Figure to save (anything exposing a compatible ``savefig`` works).
    save_path : str or Path, optional
        Destination file. ``None`` turns the call into a no-op.
    **kwargs
        Forwarded unchanged to ``fig.savefig``.
    """
    if save_path is None:
        return

    target = Path(save_path)
    target.parent.mkdir(parents=True, exist_ok=True)

    extension = target.suffix
    if not extension:
        # No extension to infer a format from: let savefig decide.
        fig.savefig(target, **kwargs)
        return

    fig.savefig(target, format=extension[1:].lower(), **kwargs)


def load_model_and_data(
    model_path: Optional[str] = None, data_path: Optional[str] = None
) -> Tuple[MPF, np.ndarray, np.ndarray, pd.DataFrame]:
    """
    Read the serialized MPF model and the California Housing table from disk.

    The CSV is read without a header row; its last column is taken as the
    target and every preceding column as a feature. Progress messages are
    printed to stdout.

    Parameters:
    -----------
    model_path : str, optional
        Location of the MPF ``.bin`` file. Defaults to
        ``MODELS_DIR / MODEL_FILENAME``.
    data_path : str, optional
        Location of the California Housing CSV. Defaults to
        ``DATA_DIR / "44977_california_housing.csv"``.

    Returns:
    --------
    model : MPF
        Model returned by ``MPF.load``
    X : np.ndarray
        Feature matrix (n_samples, n_features)
    y : np.ndarray
        Target values (n_samples,)
    X_df : pd.DataFrame
        ``X`` wrapped in a DataFrame whose columns are ``FEATURE_NAMES``
    """
    # Fall back to the repository layout when explicit paths are not supplied.
    resolved_model = (
        str(MODELS_DIR / MODEL_FILENAME) if model_path is None else model_path
    )
    resolved_data = (
        str(DATA_DIR / "44977_california_housing.csv")
        if data_path is None
        else data_path
    )

    variant = "interpretable" if not USE_BLACKBOX else "blackbox"
    print(f"Loading {variant} MPF model from {resolved_model}...")
    model = MPF.load(resolved_model)

    print(f"Loading data from {resolved_data}...")
    table = pd.read_csv(resolved_data, header=None)
    features = table.iloc[:, :-1]
    target = table.iloc[:, -1]
    X = features.values
    y = target.values

    # Column names are attached positionally, so the CSV column order must
    # match FEATURE_NAMES.
    X_df = pd.DataFrame(X, columns=FEATURE_NAMES)

    n_samples, n_features = X.shape[0], X.shape[1]
    print(f"Loaded model with {len(model.tree_grid_families)} stages")
    print(f"Loaded data: {n_samples} samples, {n_features} features")

    return model, X, y, X_df


def extract_backbone_tilt(
    model: MPF, feature_idx: int, stage_idx: int, x_vals: np.ndarray
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Extract backbone and tilt values for a feature at given x values.

    Both functions are piecewise constant: ``splits`` partitions the feature
    axis into bins and bin ``k`` carries ``backbone[k]`` / ``tilt[k]``. A value
    equal to a split point belongs to the bin on its right, and values beyond
    the last split are assigned to the last bin.

    The outputs are always floating point, regardless of the dtype of
    ``x_vals``, so an integer-valued grid no longer truncates the results.

    Parameters:
    -----------
    model : MPF
        Trained MPF model
    feature_idx : int
        Index of feature to extract
    stage_idx : int
        Index of stage (0-indexed)
    x_vals : np.ndarray
        Feature values to evaluate at (1D array)

    Returns:
    --------
    backbone_vals : np.ndarray
        Backbone values b_j^(ℓ)(x_j) for each x in x_vals (float array)
    tilt_vals : np.ndarray
        Tilt values d_j^(ℓ)(x_j) for each x in x_vals (float array)
    """
    tgf = model.tree_grid_families[stage_idx]
    tg = TreeGrid(tgf.combined_tree_grid)

    # Per-bin values and the split points separating the bins for this feature.
    backbone = np.asarray(tg.backbone_values[feature_idx], dtype=float)
    tilt = np.asarray(tg.tilt_values[feature_idx], dtype=float)
    splits = np.asarray(tg.splits[feature_idx], dtype=float)
    x_arr = np.asarray(x_vals, dtype=float)

    # One vectorised lookup replaces the per-element loop. With no splits every
    # value lands in bin 0, which is what searchsorted returns for an empty array.
    interval_idx = np.searchsorted(splits, x_arr, side="right")
    # Guard against an index past the last stored bin value.
    interval_idx = np.minimum(interval_idx, len(backbone) - 1)

    # Fancy indexing returns fresh arrays, so callers may modify the results.
    return backbone[interval_idx], tilt[interval_idx]


def compute_nominal_contribution(
    backbone_vals: np.ndarray, tilt_vals: np.ndarray
) -> np.ndarray:
    """
    Evaluate the nominal contribution 2 * b_j * sinh(d_j) element-wise.

    This is the feature's intrinsic contribution under the assumption of
    system balance (Δ = 0).

    Parameters:
    -----------
    backbone_vals : np.ndarray
        Backbone values b_j
    tilt_vals : np.ndarray
        Tilt values d_j

    Returns:
    --------
    nominal_contrib : np.ndarray
        Nominal contribution 2 * b_j * sinh(d_j)
    """
    doubled_backbone = 2 * backbone_vals
    return doubled_backbone * np.sinh(tilt_vals)


def plot_figure_1_backbone(
    model: MPF,
    X: np.ndarray,
    X_df: pd.DataFrame,
    grid_points: int = 200,
    save_path: Optional[str] = None,
) -> plt.Figure:
    """
    Figure 1: Backbone plots b_j^(ℓ)(x_j) for all features j and all stages.

    Creates a grid with n_stages rows and n_features columns (one panel per
    stage/feature pair). Each backbone is evaluated on an evenly spaced grid
    spanning the observed range of the feature in ``X``.

    Parameters:
    -----------
    model : MPF
        Trained MPF model
    X : np.ndarray
        Feature matrix; supplies the number of features and each feature's range
    X_df : pd.DataFrame
        Feature DataFrame with column names (not used by this function; axis
        labels come from FEATURE_NAMES)
    grid_points : int
        Number of grid points for evaluation
    save_path : str, optional
        Path to save figure. Defaults to FIGURES_DIR / "figure_1_backbone.svg".

    Returns:
    --------
    fig : matplotlib.figure.Figure
        Figure object
    """
    n_features = X.shape[1]
    n_stages = len(model.tree_grid_families)

    # squeeze=False always yields a 2D axes array, including the 1x1 case where
    # plt.subplots would otherwise return a bare Axes with no .reshape.
    fig, axes = plt.subplots(
        n_stages,
        n_features,
        figsize=(4 * n_features, 4 * n_stages),
        squeeze=False,
    )

    # The evaluation grid depends only on the feature, so build it once per
    # feature instead of once per (stage, feature) panel.
    feature_grids = [
        np.linspace(X[:, feat_idx].min(), X[:, feat_idx].max(), grid_points)
        for feat_idx in range(n_features)
    ]

    for stage_idx in range(n_stages):
        stage_num = stage_idx + 1
        stage_color = STAGE_COLORS[stage_idx % len(STAGE_COLORS)]

        for feat_idx, x_vals in enumerate(feature_grids):
            ax = axes[stage_idx, feat_idx]
            feature_name = FEATURE_NAMES[feat_idx]

            backbone_vals, _ = extract_backbone_tilt(model, feat_idx, stage_idx, x_vals)

            ax.plot(
                x_vals,
                backbone_vals,
                linewidth=2,
                color=stage_color,
                alpha=0.8,
            )

            ax.set_xlabel(feature_name, fontsize=10)
            if feat_idx == 0:
                # Only the left-most column carries the stage label (mathtext).
                ax.set_ylabel(
                    f"Stage {stage_num}\nBackbone $b_j^{{{stage_num}}}$", fontsize=10
                )
            ax.set_title(f"Feature: {feature_name}", fontsize=12, fontweight="bold")
            ax.grid(True, alpha=0.3)

    plt.tight_layout()

    if save_path is None:
        save_path = FIGURES_DIR / "figure_1_backbone.svg"
    _save_fig(fig, save_path, bbox_inches="tight")
    print(f"Saved Figure 1 to {save_path}")

    return fig


def plot_figure_2_nominal_contribution(
    model: MPF,
    X: np.ndarray,
    X_df: pd.DataFrame,
    grid_points: int = 200,
    save_path: Optional[str] = None,
) -> Optional[plt.Figure]:
    """
    Figure 2: Nominal Contribution plots 2 b_j^(ℓ)(x_j) sinh(d_j^(ℓ)(x_j)) for Stage 2 only.

    This shows the feature's intrinsic contribution assuming system balance (Δ = 0).
    Only plots Stage 2 since tilt ≈ 0 in Stage 1, making nominal contribution ≈ 0.
    Analogous to GA2M's 1D partial dependence, but for MPF's multiplicative structure.

    Panels are laid out four per row; panels left over in the last row are hidden.

    Parameters:
    -----------
    model : MPF
        Trained MPF model
    X : np.ndarray
        Feature matrix; supplies the number of features and each feature's range
    X_df : pd.DataFrame
        Feature DataFrame with column names (not used by this function; axis
        labels come from FEATURE_NAMES)
    grid_points : int
        Number of grid points for evaluation
    save_path : str, optional
        Path to save figure. Defaults to
        FIGURES_DIR / "figure_2_nominal_contribution.svg".

    Returns:
    --------
    fig : matplotlib.figure.Figure or None
        Figure object, or None (after printing a warning) when the model has
        fewer than two stages and nothing is plotted or saved.
    """
    n_features = X.shape[1]

    # Only plot Stage 2 (index 1) since tilt ≈ 0 in Stage 1
    stage_idx = 1
    if len(model.tree_grid_families) <= stage_idx:
        print(
            f"Warning: Model only has {len(model.tree_grid_families)} stages, cannot plot Stage 2"
        )
        return None

    # Four panels per row, rounding the row count up.
    n_cols = 4
    n_rows = (n_features + n_cols - 1) // n_cols

    # squeeze=False guarantees a 2D array, so flattening is valid for any
    # feature count (the grid always has n_cols columns, even for one feature).
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(20, 5 * n_rows), squeeze=False)
    axes = axes.flatten()

    stage_color = STAGE_COLORS[stage_idx % len(STAGE_COLORS)]

    for feat_idx in range(n_features):
        ax = axes[feat_idx]
        feature_name = FEATURE_NAMES[feat_idx]

        # Evaluate over the observed range of this feature.
        feat_min = X[:, feat_idx].min()
        feat_max = X[:, feat_idx].max()
        x_vals = np.linspace(feat_min, feat_max, grid_points)

        backbone_vals, tilt_vals = extract_backbone_tilt(
            model, feat_idx, stage_idx, x_vals
        )
        nominal_contrib = compute_nominal_contribution(backbone_vals, tilt_vals)

        ax.plot(
            x_vals,
            nominal_contrib,
            linewidth=2,
            color=stage_color,
            alpha=0.8,
        )

        # Zero reference line: the contribution is signed.
        ax.axhline(y=0, color="black", linestyle="--", linewidth=0.5, alpha=0.5)

        ax.set_xlabel(feature_name, fontsize=10)
        ax.set_ylabel("Nominal Contribution\n2 b_j sinh(d_j)", fontsize=10)
        ax.set_title(
            f"Nominal Contribution: {feature_name}", fontsize=12, fontweight="bold"
        )
        ax.grid(True, alpha=0.3)

    # Hide unused subplots
    for idx in range(n_features, len(axes)):
        axes[idx].axis("off")

    plt.tight_layout()

    if save_path is None:
        save_path = FIGURES_DIR / "figure_2_nominal_contribution.svg"
    _save_fig(fig, save_path, bbox_inches="tight")
    print(f"Saved Figure 2 to {save_path}")

    return fig


def plot_figure_4_spatial_backbone_evolution(
    model: MPF,
    X: np.ndarray,
    X_df: pd.DataFrame,
    epochs: Optional[List[int]] = None,
    grid_points: int = 100,
    cmap: str = "seismic",
    cmap_backbone: str = "viridis",
    save_path: Optional[str] = None,
    save_each_plot_dir: Optional[Union[str, Path]] = None,
) -> plt.Figure:
    """
    Figure 4: Spatial backbone evolution showing b_lat × b_lon for each epoch.

    Based on plot_spatial_backbone_evolution from california_housing.py.
    For every selected stage two maps are drawn on a regular longitude/latitude
    grid spanning the data: the spatial backbone product
    Z = b_lat(lat) × b_lon(lon) and the 2D partial dependence of
    (longitude, latitude), taken as f+ + f- from
    ``model.compute_partial_dependence_function``.

    Parameters:
    -----------
    model : MPF
        Trained MPF model
    X : np.ndarray
        Feature matrix (passed to the partial dependence computation)
    X_df : pd.DataFrame
        Feature DataFrame with column names (must have 'Latitude' and 'Longitude').
        The positions of these columns are used as feature indices into the model.
    epochs : list of int, optional
        List of Stage indices to plot. If None, plots all epochs.
    grid_points : int
        Number of grid points per axis
    cmap : str
        Colormap for the 2D PD row (signed, diverging).
    cmap_backbone : str
        Colormap for the spatial backbone row.
    save_path : str, optional
        Path to save combined figure. Defaults to
        FIGURES_DIR / "figure_4_spatial_backbone_evolution.svg".
    save_each_plot_dir : path, optional
        If set, additionally save each subplot (backbone and PD per epoch) to its
        own PDF in this directory. Those single-plot figures are closed after saving.

    Returns:
    --------
    fig : matplotlib.figure.Figure
        The combined figure (2 rows × len(epochs) columns). It is always built
        and saved, whether or not save_each_plot_dir is given.

    Raises:
    -------
    ValueError
        If X_df lacks a 'Latitude' or 'Longitude' column.
    """
    # Validate inputs
    for col in ("Latitude", "Longitude"):
        if col not in X_df.columns:
            raise ValueError(
                f"Column '{col}' not found in X_df. Available columns: {list(X_df.columns)}"
            )

    n_epochs = len(model.tree_grid_families)
    if epochs is None:
        epochs = list(range(n_epochs))

    # Feature indices are taken from the column positions in X_df.
    lon_idx = X_df.columns.get_loc("Longitude")
    lat_idx = X_df.columns.get_loc("Latitude")

    # Coordinate bounds of the data; `margin` pads the visible map extent only.
    lon_vals_data = X_df["Longitude"].values
    lat_vals_data = X_df["Latitude"].values
    lon_min, lon_max = lon_vals_data.min(), lon_vals_data.max()
    lat_min, lat_max = lat_vals_data.min(), lat_vals_data.max()
    margin = 0.5

    # Evaluation mesh: rows follow latitude, columns follow longitude.
    lon_vals = np.linspace(lon_min, lon_max, grid_points)
    lat_vals = np.linspace(lat_min, lat_max, grid_points)
    LON, LAT = np.meshgrid(lon_vals, lat_vals)
    grid_flat = np.column_stack([LON.ravel(), LAT.ravel()])

    def _lookup_bins(values, splits, coords):
        """Evaluate a piecewise-constant function (one value per bin) at coords."""
        values = np.asarray(values)
        if len(splits) == 0:
            bin_idx = np.zeros(len(coords), dtype=int)
        else:
            # side="right": a coordinate equal to a split belongs to the bin on
            # its right; the clamp keeps the index inside the stored values.
            bin_idx = np.minimum(
                np.searchsorted(splits, coords, side="right"), len(values) - 1
            )
        return values[bin_idx]

    # Spatial backbone per stage. The product separates into a function of
    # latitude times a function of longitude, so each 1D backbone is evaluated
    # once on its axis and the mesh is filled with an outer product instead of
    # looping over all grid_points**2 mesh nodes.
    spatial_backbone_per_stage = []
    for epoch_idx in epochs:
        tg = TreeGrid(model.tree_grid_families[epoch_idx].combined_tree_grid)
        b_lon = _lookup_bins(tg.backbone_values[lon_idx], tg.splits[lon_idx], lon_vals)
        b_lat = _lookup_bins(tg.backbone_values[lat_idx], tg.splits[lat_idx], lat_vals)
        # outer(b_lat, b_lon)[i, j] = b_lat(lat_i) * b_lon(lon_j), which matches
        # the meshgrid layout and therefore the row order of grid_flat.
        spatial_backbone_per_stage.append(np.outer(b_lat, b_lon).ravel().astype(float))

    # 2D partial dependence for (lon, lat) on the same flattened mesh. The
    # returned per-epoch constants are not needed here.
    _, partial_dep = model.compute_partial_dependence_function(
        [lon_idx, lat_idx], grid_flat, X
    )
    n_epochs_pd = partial_dep.shape[1] // 2

    # Columns alternate f+, f- per epoch; the per-stage PD is their sum
    # (no additional scaling is applied here).
    f_plus_all = partial_dep[:, ::2]
    f_minus_all = partial_dep[:, 1::2]

    pd_2d_per_stage = []
    for epoch_idx in epochs:
        if epoch_idx < n_epochs_pd:
            pd_2d_per_stage.append(
                f_plus_all[:, epoch_idx] + f_minus_all[:, epoch_idx]
            )
        else:
            # Stage index beyond what the PD computation returned: plot zeros.
            pd_2d_per_stage.append(np.zeros(len(grid_flat)))

    def _setup_map_axis(ax):
        """Add coastlines, state borders, padded extent and labelled gridlines."""
        ax.add_feature(cfeature.COASTLINE)
        ax.add_feature(cfeature.STATES)
        ax.set_extent(
            [lon_min - margin, lon_max + margin, lat_min - margin, lat_max + margin],
            crs=ccrs.PlateCarree(),
        )
        gl = ax.gridlines(
            draw_labels=True, linewidth=0.5, color="gray", alpha=0.5, linestyle="--"
        )
        gl.top_labels = gl.right_labels = False

    def _diverging_norm(arr):
        """Safe diverging norm (TwoSlopeNorm) that avoids division by zero when arr is constant."""
        vmax = np.max(np.abs(arr))
        vmin = -vmax if vmax > 0 else np.min(arr)
        if vmax <= 0 or vmin >= vmax:
            return mcolors.TwoSlopeNorm(vmin=-1.0, vcenter=0.0, vmax=1.0)
        return mcolors.TwoSlopeNorm(vmin=vmin, vcenter=0.0, vmax=vmax)

    def _backbone_norm(arr):
        """Norm for nonnegative backbone: 0 at low end, max at high end (strong 0 vs >0 contrast)."""
        vmax = np.max(arr) if arr.size else 1.0
        vmax = max(vmax, 1e-10)  # avoid 0 range
        return mcolors.Normalize(vmin=0.0, vmax=vmax)

    def _draw_backbone(ax, plot_idx, epoch_idx):
        """Draw the spatial backbone map of one stage on ax."""
        _setup_map_axis(ax)
        Z = spatial_backbone_per_stage[plot_idx].reshape(LAT.shape)
        cs = ax.contourf(
            LON,
            LAT,
            Z,
            levels=20,
            cmap=cmap_backbone,
            norm=_backbone_norm(Z),
            transform=ccrs.PlateCarree(),
            alpha=0.7,
        )
        plt.colorbar(cs, ax=ax, shrink=0.6, pad=0.05, label="Backbone Magnitude")
        ax.set_title(
            f"Stage {epoch_idx + 1}: backbone $b_{{lon}} \\times b_{{lat}}$",
            fontsize=12,
        )

    def _draw_pd(ax, plot_idx, epoch_idx):
        """Draw the 2D partial dependence map of one stage on ax (zero-centred colors)."""
        _setup_map_axis(ax)
        pd_2d = pd_2d_per_stage[plot_idx].reshape(LAT.shape)
        cs = ax.contourf(
            LON,
            LAT,
            pd_2d,
            levels=20,
            cmap=cmap,
            norm=_diverging_norm(pd_2d),
            transform=ccrs.PlateCarree(),
            alpha=0.7,
        )
        plt.colorbar(
            cs, ax=ax, shrink=0.6, pad=0.05, label=r"$\mathrm{PD}_{\mathrm{lat,long}}$"
        )
        ax.set_title(
            f"Stage {epoch_idx + 1}: $\\mathrm{{PD}}_{{lat,long}}$",
            fontsize=12,
        )

    # Per-plot size: identical for every single-plot PDF so they align when
    # placed in a row (saved without bbox_inches='tight' so output dimensions
    # are identical).
    single_plot_figsize = (6, 5)

    # Optionally save each plot in its own PDF
    each_plot_dir = Path(save_each_plot_dir) if save_each_plot_dir else None
    if each_plot_dir is not None:
        each_plot_dir = each_plot_dir.resolve()
        each_plot_dir.mkdir(parents=True, exist_ok=True)
        for plot_idx, epoch_idx in enumerate(epochs):
            for kind, draw in (("backbone", _draw_backbone), ("pd", _draw_pd)):
                fig_one = plt.figure(figsize=single_plot_figsize)
                ax = fig_one.add_subplot(1, 1, 1, projection=ccrs.PlateCarree())
                draw(ax, plot_idx, epoch_idx)
                plt.tight_layout()
                out_path = (
                    each_plot_dir
                    / f"figure_4_spatial_backbone_evolution_epoch{epoch_idx + 1}_{kind}.pdf"
                )
                _save_fig(fig_one, out_path, pad_inches=0.05)
                plt.close(fig_one)
        print(f"Saved {2 * len(epochs)} spatial backbone plots to {each_plot_dir}")

    # Combined layout: 2 rows (backbone and 2D PD), one column per selected epoch
    n_cols = len(epochs)
    n_rows = 2

    fig = plt.figure(figsize=(6 * n_cols, 10))

    # Row 1: spatial backbone for each epoch
    for plot_idx, epoch_idx in enumerate(epochs):
        ax = fig.add_subplot(
            n_rows, n_cols, plot_idx + 1, projection=ccrs.PlateCarree()
        )
        _draw_backbone(ax, plot_idx, epoch_idx)

    # Row 2: 2D PD function for each epoch
    for plot_idx, epoch_idx in enumerate(epochs):
        ax = fig.add_subplot(
            n_rows, n_cols, n_cols + plot_idx + 1, projection=ccrs.PlateCarree()
        )
        _draw_pd(ax, plot_idx, epoch_idx)

    plt.tight_layout()

    if save_path is None:
        save_path = FIGURES_DIR / "figure_4_spatial_backbone_evolution.svg"
    _save_fig(fig, save_path, bbox_inches="tight")
    print(f"Saved Figure 4 to {save_path}")

    return fig


def plot_figure_3_first_order_pd(
    model: MPF,
    X: np.ndarray,
    X_df: pd.DataFrame,
    grid_points: int = 200,
    save_path: Optional[str] = None,
) -> plt.Figure:
    """
    Figure 3: First-order partial dependence curves, one panel per (stage, feature).

    For each feature j and stage, draws the total
    PD_j(x_j) = C+_j * a+_j(x_j) + C-_j * a-_j(x_j), obtained here as the sum of
    the f+ and f- columns returned by
    ``model.compute_first_order_partial_dependence_functions``. Panel titles
    show the stage constants C+ and C-.

    Parameters:
    -----------
    model : MPF
        Trained MPF model
    X : np.ndarray
        Feature matrix; supplies feature ranges, the per-feature means used for
        the non-varied features, and the data passed to the PD computation
    X_df : pd.DataFrame
        Feature DataFrame with column names (not used by this function)
    grid_points : int
        Number of grid points for evaluation
    save_path : str, optional
        Path to save figure. Defaults to FIGURES_DIR / "figure_3_first_order_pd.svg".

    Returns:
    --------
    fig : matplotlib.figure.Figure
        Figure object
    """
    n_features = X.shape[1]
    n_epochs = len(model.tree_grid_families)
    baseline = X.mean(axis=0)

    def _format_constant(value):
        """Fixed-point text for a constant; more decimals the smaller it is."""
        magnitude = abs(value)
        if magnitude >= 1e3:
            return f"{value:.2f}"
        if magnitude >= 1e-3:
            return f"{value:.4f}"
        return f"{value:.6f}"

    # One sweep per feature: that feature runs over its observed range while
    # every other feature is pinned to its mean. The sweeps are stacked, so
    # rows [j * grid_points, (j + 1) * grid_points) belong to feature j.
    feature_grids = []
    sweep_blocks = []
    for feat_idx in range(n_features):
        column = X[:, feat_idx]
        sweep = np.linspace(column.min(), column.max(), grid_points)
        sweep_block = np.tile(baseline, (grid_points, 1))
        sweep_block[:, feat_idx] = sweep
        feature_grids.append(sweep)
        sweep_blocks.append(sweep_block)

    X_grid_combined = np.vstack(sweep_blocks)  # (grid_points * n_features, n_features)

    first_order_pd = model.compute_first_order_partial_dependence_functions(
        X_grid_combined, X
    )

    # Layout: n_epochs rows, n_features columns
    fig, axes = plt.subplots(
        n_epochs, n_features, figsize=(8 * n_features, 4 * n_epochs), squeeze=False
    )

    for feat_idx in range(n_features):
        constants_per_epoch, pd_values = first_order_pd[feat_idx]
        feature_label = FEATURE_NAMES[feat_idx]
        x_vals = feature_grids[feat_idx]

        # Keep only the rows of this feature's own sweep; even columns hold f+
        # and odd columns hold f-, one pair per epoch.
        own_rows = pd_values[feat_idx * grid_points : (feat_idx + 1) * grid_points, :]
        f_plus_all = own_rows[:, ::2]
        f_minus_all = own_rows[:, 1::2]

        for epoch_idx in range(n_epochs):
            ax = axes[epoch_idx, feat_idx]
            stage_label = epoch_idx + 1
            c_plus, c_minus = constants_per_epoch[epoch_idx]

            # Total PD = f+ + f- (the columns are summed as returned).
            pd_total = f_plus_all[:, epoch_idx] + f_minus_all[:, epoch_idx]

            ax.plot(
                x_vals,
                pd_total,
                linewidth=1.5,
                color=STAGE_COLORS[epoch_idx % len(STAGE_COLORS)],
                alpha=0.8,
            )
            ax.axhline(y=0, color="black", linestyle="--", linewidth=0.5, alpha=0.5)

            ax.set_xlabel(feature_label, fontsize=10)
            if feat_idx == 0:
                ax.set_ylabel(f"PD (Stage {stage_label})", fontsize=10)

            # Constants appear in the title in plain decimal notation.
            c_plus_str = _format_constant(c_plus)
            c_minus_str = _format_constant(c_minus)
            ax.set_title(
                f"{feature_label}\nStage {stage_label}: $C_+={c_plus_str}$, $C_-={c_minus_str}$",
                fontsize=10,
                fontweight="bold",
            )
            ax.grid(True, alpha=0.3)

    plt.tight_layout(rect=[0, 0, 1, 0.99])

    if save_path is None:
        save_path = FIGURES_DIR / "figure_3_first_order_pd.svg"
    _save_fig(fig, save_path, bbox_inches="tight")
    print(f"Saved Figure 3 to {save_path}")

    return fig


def plot_figure_3_1_scaled_first_order_pd(
    model: MPF,
    X: np.ndarray,
    X_df: pd.DataFrame,
    grid_points: int = 200,
    save_path: Optional[str] = None,
) -> plt.Figure:
    """
    Figure 3.1: Scaled first-order partial dependence, f+ and f- shown separately.

    For every (stage, feature) panel the f+ column and the sign-flipped f-
    column returned by ``model.compute_first_order_partial_dependence_functions``
    are drawn as two curves. The band between them is shaded green where
    f+ - (-f-) is non-negative and orange where it is negative; that difference
    equals the total PD f+ + f-. Panel titles show the stage constants C+ and C-.

    Parameters:
    -----------
    model : MPF
        Trained MPF model
    X : np.ndarray
        Feature matrix; supplies feature ranges, the per-feature means used for
        the non-varied features, and the data passed to the PD computation
    X_df : pd.DataFrame
        Feature DataFrame with column names (not used by this function)
    grid_points : int
        Number of grid points for evaluation
    save_path : str, optional
        Path to save figure. Defaults to
        FIGURES_DIR / "figure_3_1_scaled_first_order_pd.svg".

    Returns:
    --------
    fig : matplotlib.figure.Figure
        Figure object
    """
    n_features = X.shape[1]
    n_epochs = len(model.tree_grid_families)
    baseline = X.mean(axis=0)

    def _format_constant(value):
        """Fixed-point text: four decimals below 1000 in magnitude, else two."""
        if abs(value) < 1000:
            return f"{value:.4f}"
        return f"{value:.2f}"

    # One sweep per feature (others held at their mean), stacked so that rows
    # [j * grid_points, (j + 1) * grid_points) belong to feature j.
    feature_grids = []
    sweep_blocks = []
    for feat_idx in range(n_features):
        column = X[:, feat_idx]
        sweep = np.linspace(column.min(), column.max(), grid_points)
        sweep_block = np.tile(baseline, (grid_points, 1))
        sweep_block[:, feat_idx] = sweep
        feature_grids.append(sweep)
        sweep_blocks.append(sweep_block)

    X_grid_combined = np.vstack(sweep_blocks)  # (grid_points * n_features, n_features)

    first_order_pd = model.compute_first_order_partial_dependence_functions(
        X_grid_combined, X
    )

    # Layout: n_epochs rows, n_features columns
    fig, axes = plt.subplots(
        n_epochs, n_features, figsize=(6 * n_features, 4 * n_epochs), squeeze=False
    )

    for feat_idx in range(n_features):
        constants_per_epoch, pd_values = first_order_pd[feat_idx]
        feature_label = FEATURE_NAMES[feat_idx]
        x_vals = feature_grids[feat_idx]

        # Rows of this feature's own sweep; even columns are f+, odd are f-.
        own_rows = pd_values[feat_idx * grid_points : (feat_idx + 1) * grid_points, :]
        f_plus_all = own_rows[:, ::2]
        f_minus_all = own_rows[:, 1::2]

        for epoch_idx in range(n_epochs):
            ax = axes[epoch_idx, feat_idx]
            stage_label = epoch_idx + 1
            c_plus, c_minus = constants_per_epoch[epoch_idx]

            f_plus = f_plus_all[:, epoch_idx]
            # f- is drawn with its sign flipped so both curves can be compared
            # on the same side of the axis.
            f_minus_flipped = -f_minus_all[:, epoch_idx]

            # f+ - (-f-) = f+ + f-, i.e. the total PD; its sign picks the shading.
            diff = f_plus - f_minus_flipped

            # Shading labels are attached to the first panel only; elsewhere an
            # empty label is passed.
            is_first_panel = epoch_idx == 0 and feat_idx == 0
            positive_label = "Positive difference" if is_first_panel else ""
            negative_label = "Negative difference" if is_first_panel else ""

            ax.fill_between(
                x_vals,
                f_minus_flipped,
                f_plus,
                where=(diff >= 0),
                color="green",
                alpha=0.3,
                label=positive_label,
            )
            ax.fill_between(
                x_vals,
                f_minus_flipped,
                f_plus,
                where=(diff < 0),
                color="orange",
                alpha=0.3,
                label=negative_label,
            )

            # The two component curves on top of the shaded band.
            ax.plot(
                x_vals,
                f_plus,
                linewidth=1.5,
                color="darkred",
                alpha=0.9,
                label="$PD_+$",
            )
            ax.plot(
                x_vals,
                f_minus_flipped,
                linewidth=1.5,
                color="darkblue",
                alpha=0.9,
                label="$PD_-$",
            )

            ax.axhline(y=0, color="black", linestyle="--", linewidth=0.5, alpha=0.5)
            ax.set_xlabel(feature_label, fontsize=10)
            if feat_idx == 0:
                ax.set_ylabel(f"Scaled PD (Stage {stage_label})", fontsize=10)

            # Constants appear in the title in plain decimal notation.
            c_plus_str = _format_constant(c_plus)
            c_minus_str = _format_constant(c_minus)
            ax.set_title(
                f"{feature_label}\nStage {stage_label}: $C_+={c_plus_str}$, $C_-={c_minus_str}$",
                fontsize=10,
                fontweight="bold",
            )
            ax.grid(True, alpha=0.3)
            ax.legend(loc="best", fontsize=8)

    plt.tight_layout(rect=[0, 0, 1, 0.99])

    if save_path is None:
        save_path = FIGURES_DIR / "figure_3_1_scaled_first_order_pd.svg"
    _save_fig(fig, save_path, bbox_inches="tight")
    print(f"Saved Figure 3.1 to {save_path}")

    return fig


def plot_figure_3_2_backbone_and_tilt_expression(
    model: MPF,
    X: np.ndarray,
    X_df: pd.DataFrame,
    grid_points: int = 200,
    save_path: Optional[str] = None,
) -> plt.Figure:
    """
    Figure 3.2: Backbone and tilt expression for all stages and features.

    For each stage × feature combination, plots:
    - Backbone b_j(x_j) on left y-axis
    - Expression (f_plus + f_minus) / b_j on right y-axis
    This expression, when multiplied by b_j, gives the first-order partial dependence:
    PD_j(x_j) = b_j * expression = f_plus + f_minus

    The expression can be negative (matching figure 3.1 where the difference between
    PD+ and PD- can be negative). This shows the signed contribution per unit of backbone.
    Wherever the backbone is not larger than 1e-10 the expression is set to 0
    instead of being divided.

    Layout: stages as rows, features as columns.

    Parameters:
    -----------
    model : MPF
        Trained MPF model
    X : np.ndarray
        Feature matrix; supplies the per-feature min/max for the grids, the
        column means used for the non-varied features, and the second argument
        to the model's first-order PD computation
    X_df : pd.DataFrame
        Feature DataFrame with column names (not used by this function)
    grid_points : int
        Number of grid points for evaluation
    save_path : str, optional
        Path to save figure. Defaults to
        FIGURES_DIR / "figure_3_2_backbone_and_tilt_expression.svg".

    Returns:
    --------
    fig : matplotlib.figure.Figure
        Figure object
    """

    def _format_constant(value: float) -> str:
        # Fixed-point formatting with precision chosen by magnitude (no
        # scientific notation in titles).
        magnitude = abs(value)
        if 1e-3 <= magnitude < 1e3:
            return f"{value:.4f}"
        if magnitude >= 1e3:
            return f"{value:.2f}"
        return f"{value:.6f}"

    n_features = X.shape[1]
    n_epochs = len(model.tree_grid_families)

    # Create grid of values for each feature
    feature_grids = []
    X_mean = X.mean(axis=0)
    X_grid_list = []

    for feat_idx in range(n_features):
        feat_min = X[:, feat_idx].min()
        feat_max = X[:, feat_idx].max()
        x_vals = np.linspace(feat_min, feat_max, grid_points)
        feature_grids.append(x_vals)

        # Create grid for this feature: vary this feature, others at mean
        X_grid_feat = np.tile(X_mean, (grid_points, 1))
        X_grid_feat[:, feat_idx] = x_vals
        X_grid_list.append(X_grid_feat)

    # Combine all grids
    X_grid_combined = np.vstack(
        X_grid_list
    )  # Shape: (grid_points * n_features, n_features)

    # Compute first-order PD functions to get constants C_+ and C_-
    first_order_pd = model.compute_first_order_partial_dependence_functions(
        X_grid_combined, X
    )

    # Layout: n_epochs rows, n_features columns
    fig, axes = plt.subplots(
        n_epochs, n_features, figsize=(6 * n_features, 4 * n_epochs), squeeze=False
    )

    for feat_idx in range(n_features):
        constants_per_epoch, pd_values = first_order_pd[feat_idx]

        # Only the rows of the stacked grid in which this feature was varied
        # are relevant for this feature's curves.
        start_idx = feat_idx * grid_points
        end_idx = (feat_idx + 1) * grid_points
        pd_values_feat = pd_values[start_idx:end_idx, :]

        # Extract f+ and f- for each epoch
        f_plus_all = pd_values_feat[:, ::2]  # Columns 0, 2, 4, ...
        f_minus_all = pd_values_feat[:, 1::2]  # Columns 1, 3, 5, ...

        # Get x values for this feature
        x_vals = feature_grids[feat_idx]

        for epoch_idx in range(n_epochs):
            ax = axes[epoch_idx, feat_idx]

            # Get constants for this epoch
            c_plus, c_minus = constants_per_epoch[epoch_idx]

            # Only the backbone is needed here; the tilt values are ignored.
            backbone_vals, _ = extract_backbone_tilt(
                model, feat_idx, epoch_idx, x_vals
            )
            backbone_vals = np.asarray(backbone_vals, dtype=float)

            # Get f+ and f- for this Stage (these already include b_j and all scaling)
            f_plus = f_plus_all[:, epoch_idx]
            f_minus = f_minus_all[:, epoch_idx]

            # expression = (f_plus + f_minus) / b_j, i.e. the factor that
            # multiplies b_j to give the total PD. It can be negative.
            # np.where would still evaluate the division everywhere (and warn
            # on zero backbones), so divide only where the backbone is safely
            # positive and leave 0 elsewhere.
            total_pd = f_plus + f_minus
            exp_expression = np.divide(
                total_pd,
                backbone_vals,
                out=np.zeros(np.broadcast(total_pd, backbone_vals).shape, dtype=float),
                where=backbone_vals > 1e-10,
            )

            # Plot backbone on left y-axis (single color for all plots)
            ax1 = ax
            backbone_color = "#1f77b4"  # Blue color for backbone
            line1 = ax1.plot(
                x_vals,
                backbone_vals,
                linewidth=1.5,
                color=backbone_color,
                alpha=0.8,
                label="$b_j$",
            )
            ax1.set_xlabel(FEATURE_NAMES[feat_idx], fontsize=10)
            if feat_idx == 0:
                ax1.set_ylabel(f"Stage {epoch_idx + 1}\nBackbone $b_j$", fontsize=10)
            ax1.tick_params(axis="y", labelcolor=backbone_color)

            # Plot expression on right y-axis (single color for all plots)
            ax2 = ax1.twinx()
            expression_color = "red"
            line2 = ax2.plot(
                x_vals,
                exp_expression,
                linewidth=1.5,
                color=expression_color,
                alpha=0.8,
                linestyle="--",
                label="$(f_+ + f_-) / b_j$",
            )
            ax2.set_ylabel("$(f_+ + f_-) / b_j$", fontsize=10, color=expression_color)
            ax2.tick_params(axis="y", labelcolor=expression_color)
            ax2.axhline(y=0, color="red", linestyle=":", linewidth=0.5, alpha=0.5)

            # Set y-axis limits for expression to avoid forcing 0 as minimum
            exp_min = exp_expression.min()
            exp_max = exp_expression.max()
            exp_range = exp_max - exp_min
            # Add 10% padding on each side, but don't force 0 if not needed
            if exp_range > 0:
                padding = exp_range * 0.1
                ax2.set_ylim(exp_min - padding, exp_max + padding)
            else:
                # If all values are the same, add small symmetric padding
                ax2.set_ylim(exp_min - 0.1, exp_max + 0.1)

            # Format constants for title
            c_plus_str = _format_constant(c_plus)
            c_minus_str = _format_constant(c_minus)

            ax1.set_title(
                f"{FEATURE_NAMES[feat_idx]}\nStage {epoch_idx + 1}: $C_+={c_plus_str}$, $C_-={c_minus_str}$",
                fontsize=10,
                fontweight="bold",
            )
            ax1.grid(True, alpha=0.3)

            # The two curves live on different axes (ax1 / its twin), so merge
            # their handles into a single legend on ax1.
            lines = line1 + line2
            labels = [line.get_label() for line in lines]
            ax1.legend(lines, labels, loc="best", fontsize=8)

    plt.tight_layout(rect=[0, 0, 1, 0.99])

    if save_path is None:
        save_path = FIGURES_DIR / "figure_3_2_backbone_and_tilt_expression.svg"
    _save_fig(fig, save_path, bbox_inches="tight")
    print(f"Saved Figure 3.2 to {save_path}")

    return fig


def plot_figure_3_3_ice_curves(
    model: MPF,
    X: np.ndarray,
    X_df: pd.DataFrame,
    n_observations: int = 100,
    grid_points: int = 200,
    save_path: Optional[str] = None,
) -> plt.Figure:
    """
    Figure 3.3: Individual Conditional Expectation (ICE) curves.

    Plots ICE curves for f+ and f- for ``n_observations`` randomly selected rows
    of ``X``. Shows both f+ and (sign-flipped) f- curves on the same plot with
    different colors, and colors the difference between the two curves for each
    observation.

    Layout: stages as rows, features as columns.

    Note: this function calls ``np.random.seed(42)``, i.e. it resets NumPy's
    global random state so that the same rows are selected on every call.

    Parameters:
    -----------
    model : MPF
        Trained MPF model
    X : np.ndarray
        Feature matrix
    X_df : pd.DataFrame
        Feature DataFrame with column names (not used by this function)
    n_observations : int
        Number of random observations to plot (default: 100). If larger than
        the number of rows in ``X``, a warning is issued and all rows are used.
    grid_points : int
        Number of grid points for evaluation
    save_path : str, optional
        Path to save figure. Defaults to FIGURES_DIR / "figure_3_3_ice_curves.svg".

    Returns:
    --------
    fig : matplotlib.figure.Figure
        Figure object
    """
    n_features = X.shape[1]
    n_epochs = len(model.tree_grid_families)

    # Select random observations
    n_samples = X.shape[0]
    if n_observations > n_samples:
        # Warn BEFORE clamping so the message reports the value the caller
        # actually asked for.
        warnings.warn(
            f"Requested {n_observations} observations but only {n_samples} available. Using all.",
            stacklevel=2,
        )
        n_observations = n_samples

    # Randomly select observations
    np.random.seed(42)  # For reproducibility
    selected_indices = np.random.choice(n_samples, size=n_observations, replace=False)
    selected_observations = X[selected_indices]

    # Layout: n_epochs rows, n_features columns (similar to Figure 3.1)
    fig, axes = plt.subplots(
        n_epochs, n_features, figsize=(6 * n_features, 4 * n_epochs), squeeze=False
    )

    for feat_idx in range(n_features):
        # Get feature range
        feat_min = X[:, feat_idx].min()
        feat_max = X[:, feat_idx].max()
        x_range = np.linspace(feat_min, feat_max, grid_points)

        # Compute ICE curves for selected observations
        # Returns shape: (n_obs, n_range_values, 2 * n_epochs)
        # Last dimension alternates: [f+_epoch0, f-_epoch0, f+_epoch1, f-_epoch1, ...]
        ice_values = model.compute_ice_curves(
            selected_observations, feature_index=feat_idx, x_range=x_range, data_x=X
        )

        for epoch_idx in range(n_epochs):
            ax = axes[epoch_idx, feat_idx]

            # Extract f+ and f- for this epoch
            # ice_values shape: (n_obs, n_range_values, 2 * n_epochs)
            # For epoch_idx, f+ is at index 2*epoch_idx, f- is at index 2*epoch_idx+1
            f_plus_idx = 2 * epoch_idx
            f_minus_idx = 2 * epoch_idx + 1

            f_plus_curves = ice_values[
                :, :, f_plus_idx
            ]  # Shape: (n_obs, n_range_values)
            f_minus_curves = ice_values[
                :, :, f_minus_idx
            ]  # Shape: (n_obs, n_range_values)

            # Flip f- sign to show actual contribution direction (similar to figure 3.1)
            f_minus_curves_flipped = -f_minus_curves

            # Plot ICE curves for each observation
            # Thin lines for many curves
            linewidth = 0.3
            alpha = 0.4

            # The three per-observation loops are kept separate on purpose:
            # artists are added in this order, which determines how
            # overlapping curves are layered in the output.

            # Plot f+ curves in darkred
            for obs_idx in range(n_observations):
                ax.plot(
                    x_range,
                    f_plus_curves[obs_idx, :],
                    linewidth=linewidth,
                    color="darkred",
                    alpha=alpha,
                )

            # Plot f- curves in darkblue
            for obs_idx in range(n_observations):
                ax.plot(
                    x_range,
                    f_minus_curves_flipped[obs_idx, :],
                    linewidth=linewidth,
                    color="darkblue",
                    alpha=alpha,
                )

            # Color the gap between the f+ curve and the flipped f- curve for
            # each observation: diff = f+ - (-f-) = f+ + f-.
            for obs_idx in range(n_observations):
                diff = f_plus_curves[obs_idx, :] - f_minus_curves_flipped[obs_idx, :]

                # Fill area between curves, colored by sign of difference
                # Positive difference: green, negative difference: orange
                ax.fill_between(
                    x_range,
                    f_minus_curves_flipped[obs_idx, :],
                    f_plus_curves[obs_idx, :],
                    where=(diff >= 0),
                    color="green",
                    alpha=0.1,
                    linewidth=0,
                )
                ax.fill_between(
                    x_range,
                    f_minus_curves_flipped[obs_idx, :],
                    f_plus_curves[obs_idx, :],
                    where=(diff < 0),
                    color="orange",
                    alpha=0.1,
                    linewidth=0,
                )

            # Add horizontal line at y=0 for reference
            ax.axhline(y=0, color="black", linestyle="--", linewidth=0.5, alpha=0.5)

            ax.set_xlabel(FEATURE_NAMES[feat_idx], fontsize=10)
            if feat_idx == 0:
                ax.set_ylabel(f"ICE (Stage {epoch_idx + 1})", fontsize=10)

            ax.set_title(
                f"{FEATURE_NAMES[feat_idx]}\nStage {epoch_idx + 1}: ICE Curves (n={n_observations})",
                fontsize=10,
                fontweight="bold",
            )
            ax.grid(True, alpha=0.3)

            # Add legend only for first subplot. Proxy artists are used because
            # the plotted curves/fills carry no labels of their own.
            if epoch_idx == 0 and feat_idx == 0:
                from matplotlib.lines import Line2D

                legend_elements = [
                    Line2D(
                        [0],
                        [0],
                        color="darkred",
                        linewidth=1.5,
                        label="$f_+$ ICE curves",
                    ),
                    Line2D(
                        [0],
                        [0],
                        color="darkblue",
                        linewidth=1.5,
                        label="$f_-$ ICE curves",
                    ),
                    plt.Rectangle(
                        (0, 0),
                        1,
                        1,
                        facecolor="green",
                        alpha=0.3,
                        label="Positive diff",
                    ),
                    plt.Rectangle(
                        (0, 0),
                        1,
                        1,
                        facecolor="orange",
                        alpha=0.3,
                        label="Negative diff",
                    ),
                ]
                ax.legend(handles=legend_elements, loc="best", fontsize=8)

    plt.tight_layout(rect=[0, 0, 1, 0.99])

    if save_path is None:
        save_path = FIGURES_DIR / "figure_3_3_ice_curves.svg"
    _save_fig(fig, save_path, bbox_inches="tight")
    print(f"Saved Figure 3.3 to {save_path}")

    return fig


def plot_figure_3_4_scaled_first_order_pd(
    model: MPF,
    X: np.ndarray,
    X_df: pd.DataFrame,
    grid_points: int = 200,
    save_path: Optional[str] = None,
) -> plt.Figure:
    """
    Figure 3.4: Scaled first-order partial dependence (f+ and f-) with scaled backbone.

    Similar to bikesharing figure 3.2, shows f+ and sign-flipped f- with a dotted
    line for the backbone scaled by sqrt(C_+ * (-C_-)). The scaled backbone is
    drawn for every stage except the first.
    Only plots: Latitude, Longitude, MedInc.

    Layout: stages as rows, selected features as columns.

    Parameters:
    -----------
    model : MPF
        Trained MPF model
    X : np.ndarray
        Feature matrix
    X_df : pd.DataFrame
        Feature DataFrame with column names (not used by this function)
    grid_points : int
        Number of grid points for evaluation
    save_path : str, optional
        Path to save figure. Defaults to
        FIGURES_DIR / "figure_3_4_scaled_first_order_pd.svg".

    Returns:
    --------
    fig : matplotlib.figure.Figure
        Figure object

    Raises:
    -------
    ValueError
        If none of the selected features is present in FEATURE_NAMES.
    """
    # Select specific features: Latitude, Longitude, MedInc
    selected_features = ["Latitude", "Longitude", "MedInc"]

    # Find indices for these features
    feature_indices = []
    for feat_name in selected_features:
        if feat_name in FEATURE_NAMES:
            feature_indices.append(FEATURE_NAMES.index(feat_name))
        else:
            print(f"Warning: Feature '{feat_name}' not found in FEATURE_NAMES")

    if len(feature_indices) == 0:
        raise ValueError("None of the selected features were found in the dataset")

    print(
        f"Figure 3.4: Plotting features {selected_features} at indices {feature_indices}"
    )

    n_features = len(feature_indices)
    n_epochs = len(model.tree_grid_families)

    # Create grid of values for each feature: vary one feature over its
    # observed range while every other feature is held at its column mean.
    feature_grids = []
    X_mean = X.mean(axis=0)
    X_grid_list = []

    for feat_idx in feature_indices:
        feat_min = X[:, feat_idx].min()
        feat_max = X[:, feat_idx].max()
        x_vals = np.linspace(feat_min, feat_max, grid_points)
        feature_grids.append(x_vals)

        X_grid_feat = np.tile(X_mean, (grid_points, 1))
        X_grid_feat[:, feat_idx] = x_vals
        X_grid_list.append(X_grid_feat)

    X_grid_combined = np.vstack(X_grid_list)

    # Compute first-order PD functions
    first_order_pd = model.compute_first_order_partial_dependence_functions(
        X_grid_combined, X
    )

    # Layout: n_epochs rows, n_features columns
    fig, axes = plt.subplots(
        n_epochs, n_features, figsize=(7 * n_features, 4 * n_epochs), squeeze=False
    )

    for plot_idx, feat_idx in enumerate(feature_indices):
        # Results are indexed by the original feature index, but the rows of
        # the stacked grid are ordered by plot position.
        constants_per_epoch, pd_values = first_order_pd[feat_idx]

        # Extract relevant rows
        start_idx = plot_idx * grid_points
        end_idx = (plot_idx + 1) * grid_points
        pd_values_feat = pd_values[start_idx:end_idx, :]

        # Extract f+ (even columns) and f- (odd columns)
        f_plus_all = pd_values_feat[:, ::2]
        f_minus_all = pd_values_feat[:, 1::2]

        x_vals = feature_grids[plot_idx]

        for epoch_idx in range(n_epochs):
            ax = axes[epoch_idx, plot_idx]

            # Get constants
            c_plus, c_minus = constants_per_epoch[epoch_idx]

            f_plus = f_plus_all[:, epoch_idx]
            f_minus = f_minus_all[:, epoch_idx]
            f_minus_flipped = -f_minus

            # For continuous features, use line plots
            # Compute difference
            diff = f_plus - f_minus_flipped

            # Fill areas
            ax.fill_between(
                x_vals,
                f_minus_flipped,
                f_plus,
                where=(diff >= 0),
                color="green",
                alpha=0.3,
            )
            ax.fill_between(
                x_vals,
                f_minus_flipped,
                f_plus,
                where=(diff < 0),
                color="orange",
                alpha=0.3,
            )

            # Plot lines
            ax.plot(
                x_vals,
                f_plus,
                linewidth=1.5,
                color="darkred",
                alpha=0.9,
                label="$f_+$",
            )
            ax.plot(
                x_vals,
                f_minus_flipped,
                linewidth=1.5,
                color="darkblue",
                alpha=0.9,
                label="$f_-$",
            )

            # Plot scaled backbone as dotted black line (skip for first epoch).
            # The backbone and its scaling factor are computed only here, so
            # the first stage neither pays for the extraction nor risks a
            # sqrt warning on constants that are never plotted.
            if epoch_idx > 0:
                backbone_vals, _ = extract_backbone_tilt(
                    model, feat_idx, epoch_idx, x_vals
                )

                # c_minus negated for geometric-mean sign
                # NOTE(review): the legend label below omits that negation —
                # confirm the intended notation.
                sqrt_c_product = np.sqrt(c_plus * (-c_minus))
                backbone_scaled = backbone_vals * sqrt_c_product

                ax.plot(
                    x_vals,
                    backbone_scaled,
                    linewidth=2,
                    color="black",
                    linestyle=":",
                    alpha=0.9,
                    label=r"$\sqrt{C_+ C_-} \cdot b_j$",
                )

            ax.axhline(y=0, color="black", linestyle="--", linewidth=0.5, alpha=0.5)
            ax.set_xlabel(FEATURE_NAMES[feat_idx], fontsize=11)
            if plot_idx == 0:
                ax.set_ylabel(f"Scaled PD (Stage {epoch_idx + 1})", fontsize=11)

            c_plus_str = f"{c_plus:.4f}" if abs(c_plus) < 1000 else f"{c_plus:.2e}"
            c_minus_str = f"{c_minus:.4f}" if abs(c_minus) < 1000 else f"{c_minus:.2e}"
            ax.set_title(
                f"{FEATURE_NAMES[feat_idx]}\nStage {epoch_idx + 1}: $C_+={c_plus_str}$, $C_-={c_minus_str}$",
                fontsize=11,
                fontweight="bold",
            )
            ax.grid(True, alpha=0.3)
            ax.legend(loc="best", fontsize=9)

    plt.tight_layout(rect=[0, 0, 1, 0.99])

    if save_path is None:
        save_path = FIGURES_DIR / "figure_3_4_scaled_first_order_pd.svg"
    _save_fig(fig, save_path, bbox_inches="tight")
    print(f"Saved Figure 3.4 to {save_path}")

    return fig


def plot_figure_3_5_pd_comparison(
    model: MPF,
    X: np.ndarray,
    X_df: pd.DataFrame,
    grid_points: int = 1000,
    save_path: Optional[str] = None,
) -> List[plt.Figure]:
    """
    Figure 3.5: Partial dependence comparison for Latitude and Longitude.

    Compares PD plots from:
    1. MPF (first stage only)
    2. Explainable Boosting Regressor (EBR), loaded from ebm_model.pkl
    3. XGBoost (blackbox), from xgb_model.json (falls back to xgb_model.pkl)
    4. XGBoost (interpretable), from xgb_model_interp.json

    The EBR model is required; either XGBoost curve is silently omitted when its
    model file does not exist. Creates separate plots for Latitude and
    Longitude, each saved as PDF.

    Parameters:
    -----------
    model : MPF
        Trained MPF model
    X : np.ndarray
        Feature matrix
    X_df : pd.DataFrame
        Feature DataFrame with column names (not used by this function)
    grid_points : int
        Number of grid points for evaluation
    save_path : str, optional
        Base path to save figures (will append feature name and .pdf extension)

    Returns:
    --------
    figs : List[matplotlib.figure.Figure]
        List of figure objects (one per feature). Each figure has already been
        saved and closed via ``plt.close`` when it is returned.

    Raises:
    -------
    ValueError
        If a selected feature is missing from FEATURE_NAMES.
    FileNotFoundError
        If the EBR model file does not exist.
    """
    # Import required libraries
    import glex_rust
    import joblib
    import xgboost as xgb

    # Select features: Latitude and Longitude
    selected_features = ["Latitude", "Longitude"]
    feature_indices = []
    for feat_name in selected_features:
        if feat_name in FEATURE_NAMES:
            feature_indices.append(FEATURE_NAMES.index(feat_name))
        else:
            raise ValueError(f"Feature '{feat_name}' not found in FEATURE_NAMES")

    print(
        f"Figure 3.5: Computing PD for {selected_features} at indices {feature_indices}"
    )

    # Load EBR model
    # NOTE: joblib.load unpickles the file; only load model files you trust.
    ebr_model_path = MODELS_DIR / "ebm_model.pkl"
    if not ebr_model_path.exists():
        raise FileNotFoundError(f"EBR model not found at {ebr_model_path}")

    print(f"Loading EBR model from {ebr_model_path}...")
    ebr_model = joblib.load(ebr_model_path)

    xgb_model_path_json = MODELS_DIR / "xgb_model.json"
    xgb_model_path_pkl = MODELS_DIR / "xgb_model.pkl"
    xgb_model = None
    if xgb_model_path_json.exists():
        print(f"Loading XGBoost (blackbox) model from {xgb_model_path_json}...")
        xgb_model = xgb.XGBRegressor()
        xgb_model.load_model(str(xgb_model_path_json))
        print("XGBoost (blackbox) loaded.")
    elif xgb_model_path_pkl.exists():
        print(f"Loading XGBoost model from {xgb_model_path_pkl}...")
        xgb_model = joblib.load(xgb_model_path_pkl)
        print("XGBoost loaded.")

    xgb_model_interp_path = MODELS_DIR / "xgb_model_interp.json"
    xgb_model_interp = None
    if xgb_model_interp_path.exists():
        print(f"Loading XGBoost (interpretable) model from {xgb_model_interp_path}...")
        xgb_model_interp = xgb.XGBRegressor()
        xgb_model_interp.load_model(str(xgb_model_interp_path))
        print("XGBoost (interpretable) loaded.")

    # --- Per-feature-invariant setup (done once, not once per plotted feature) ---

    # EBR models expect feature names matching what they were trained with;
    # fall back to generic names when the model does not expose them.
    if hasattr(ebr_model, "feature_names_in_"):
        ebr_feature_names = list(ebr_model.feature_names_in_)
    else:
        ebr_feature_names = [f"feature_{i:04d}" for i in range(X.shape[1])]
    # Constructing the frame requires one name per column of X, so every
    # feature index valid for X is also a valid index into ebr_feature_names.
    X_df_ebr = pd.DataFrame(X, columns=ebr_feature_names)

    # The glex explainers depend only on the model and the background sample,
    # so they are built once and queried per feature below.
    fastpd = None
    if xgb_model is not None:
        fastpd = glex_rust.FastPD.from_xgboost(
            xgb_model,
            background_samples=X,
            n_threads=10,
        )
    fastpd_interp = None
    if xgb_model_interp is not None:
        fastpd_interp = glex_rust.FastPD.from_xgboost(
            xgb_model_interp,
            background_samples=X,
            n_threads=10,
        )

    # Base output path (the feature name and ".pdf" are appended per figure)
    if save_path is None:
        base_path = FIGURES_DIR / "figure_3_5_pd_comparison"
    else:
        base_path = Path(save_path).with_suffix("")  # Remove extension

    X_mean = X.mean(axis=0)
    figures = []

    for feat_idx, feat_name in zip(feature_indices, selected_features):
        # Create separate figure for each feature
        fig, ax = plt.subplots(1, 1, figsize=(6, 4))

        # Evaluation grid: vary this feature over its observed range, all
        # other features held at their column means.
        feat_min = X[:, feat_idx].min()
        feat_max = X[:, feat_idx].max()
        x_vals = np.linspace(feat_min, feat_max, grid_points)
        X_grid_feat = np.tile(X_mean, (grid_points, 1))
        X_grid_feat[:, feat_idx] = x_vals

        # 1. MPF first-order PD (first stage only, stage_idx=0)
        print(f"Computing MPF PD for {feat_name} (Stage 1 only)...")
        first_order_pd = model.compute_first_order_partial_dependence_functions(
            X_grid_feat, X
        )
        _, pd_values = first_order_pd[feat_idx]

        # Columns alternate f+ / f- per stage; take the first stage of each.
        f_plus_all = pd_values[:, ::2]  # Columns 0, 2, 4, ...
        f_minus_all = pd_values[:, 1::2]  # Columns 1, 3, 5, ...

        # Total PD = f+ + f- (already includes constants)
        pd_mpf = f_plus_all[:, 0] + f_minus_all[:, 0]

        # 2. EBR PD, computed manually (similar to gam.py
        # compute_partial_dependence): overwrite the feature with each grid
        # value for all rows and average the predictions.
        print(f"Computing EBR PD for {feat_name}...")
        ebr_feat_name = ebr_feature_names[feat_idx]
        pd_ebr = np.zeros(grid_points)
        X_temp = X_df_ebr.copy()  # copy: the column is overwritten below
        for i, grid_val in enumerate(x_vals):
            X_temp[ebr_feat_name] = grid_val
            predictions = ebr_model.predict(X_temp)
            pd_ebr[i] = np.mean(predictions)

        # 3. XGBoost (blackbox) PD using glex
        pd_xgb = None
        if fastpd is not None:
            print(f"Computing XGBoost (blackbox) PD for {feat_name} using glex...")
            pd_xgb = fastpd.pd_function(X_grid_feat, feature_subset=[feat_idx])

        # 4. XGBoost (interpretable) PD using glex
        pd_xgb_interp = None
        if fastpd_interp is not None:
            print(f"Computing XGBoost (interpretable) PD for {feat_name} using glex...")
            pd_xgb_interp = fastpd_interp.pd_function(
                X_grid_feat, feature_subset=[feat_idx]
            )

        # Plot all PD curves
        lw = 1.25
        ax.plot(
            x_vals, pd_mpf, linewidth=lw, color="blue", label="TSL (Stage 1)", alpha=0.8
        )
        ax.plot(x_vals, pd_ebr, linewidth=lw, color="green", label="EBR", alpha=0.8)
        if pd_xgb is not None:
            ax.plot(
                x_vals,
                pd_xgb,
                linewidth=lw,
                color="red",
                label="XGBoost (blackbox)",
                alpha=0.8,
            )
        if pd_xgb_interp is not None:
            ax.plot(
                x_vals,
                pd_xgb_interp,
                linewidth=lw,
                color="orange",
                label="XGBoost (interpretable)",
                alpha=0.8,
            )

        ax.axhline(y=0, color="black", linestyle="--", linewidth=0.5, alpha=0.5)
        ax.set_xlabel(feat_name, fontsize=12)
        # LaTeX y-labels: PD_lat for latitude, PD_lon for longitude
        pd_label = (
            r"$\mathrm{PD}_{\mathrm{lat}}$"
            if feat_name == "Latitude"
            else r"$\mathrm{PD}_{\mathrm{lon}}$"
        )
        ax.set_ylabel(pd_label, fontsize=12)
        ax.grid(True, alpha=0.3)

        # Add legend inside the plot
        ax.legend(loc="best", fontsize=11, frameon=True)

        plt.tight_layout()

        # Save as PDF with feature name
        pdf_path = base_path.parent / f"{base_path.name}_{feat_name.lower()}.pdf"
        _save_fig(fig, pdf_path, bbox_inches="tight")
        print(f"Saved {feat_name} PD plot to {pdf_path}")

        figures.append(fig)
        plt.close(fig)  # Close figure to free memory

    return figures


def compute_local_explanation(model: MPF, x: np.ndarray) -> Dict[str, Any]:
    """
    Compute local explanation for a single point x.

    Returns stage contributions, backbone magnitudes, tilt sums, and feature breakdowns.

    For every stage the per-feature backbone and tilt values are looked up in
    the bin that x falls into; the backbone values are multiplied, the tilt
    values summed, and the two signed components are formed as
    lambda_plus * b * exp(d) and lambda_minus * b * exp(-d) before scaling.

    Parameters:
    -----------
    model : MPF
        Trained MPF model
    x : np.ndarray
        Single data point (1D array of length n_features)

    Returns:
    --------
    explanation : dict
        Dictionary containing:
        - stage_contributions: array of s_ℓ(x) for each stage
        - f_plus_contributions: array of scaling_plus * f_plus for each stage
        - f_minus_contributions: array of -scaling_minus * f_minus for each stage
        - backbone_magnitudes: array of b^(ℓ)(x) for each stage
        - tilt_sums: array of d^(ℓ)(x) for each stage
        - feature_breakdown: dict mapping stage_idx to feature-level breakdown
        - total_prediction: sum of stage contributions
    """
    n_stages = len(model.tree_grid_families)
    n_features = len(x)

    stage_contributions = np.zeros(n_stages)
    f_plus_contributions = np.zeros(n_stages)
    f_minus_contributions = np.zeros(n_stages)
    backbone_magnitudes = np.zeros(n_stages)
    tilt_sums = np.zeros(n_stages)
    feature_breakdown = {}

    for stage_idx in range(n_stages):
        tgf = model.tree_grid_families[stage_idx]
        tg = TreeGrid(tgf.combined_tree_grid)

        # Get lambda values
        lambda_plus = tg.lambda_plus
        lambda_minus = tg.lambda_minus

        # Scaling factors from the TreeGridFamily. The defaults (1.0 / 0.0)
        # apply only when the attribute is missing or None. An explicit check
        # is required here: `value or 1.0` would also replace a legitimate
        # scaling_plus of 0.0 with 1.0.
        scaling_plus = getattr(tgf, "scaling_plus", None)
        if scaling_plus is None:
            scaling_plus = 1.0
        scaling_minus = getattr(tgf, "scaling_minus", None)
        if scaling_minus is None:
            scaling_minus = 0.0

        # Compute backbone and tilt for each feature
        backbone_per_feature = np.zeros(n_features)
        tilt_per_feature = np.zeros(n_features)

        for feat_idx in range(n_features):
            backbone_vals = tg.backbone_values[feat_idx]
            tilt_vals = tg.tilt_values[feat_idx]
            splits = tg.splits[feat_idx]

            # Find bin index: no splits means a single bin; otherwise locate x
            # among the split points and clamp to the last available bin.
            if len(splits) == 0:
                bin_idx = 0
            else:
                bin_idx = np.searchsorted(splits, x[feat_idx], side="right")
                bin_idx = min(bin_idx, len(backbone_vals) - 1)

            backbone_per_feature[feat_idx] = backbone_vals[bin_idx]
            tilt_per_feature[feat_idx] = tilt_vals[bin_idx]

        # Compute multiplicative backbone: b^(ℓ)(x) = ∏_j b_j^(ℓ)(x_j)
        backbone_magnitude = np.prod(backbone_per_feature)
        backbone_magnitudes[stage_idx] = backbone_magnitude

        # Compute additive tilt: d^(ℓ)(x) = ∑_j d_j^(ℓ)(x_j)
        tilt_sum = np.sum(tilt_per_feature)
        tilt_sums[stage_idx] = tilt_sum

        # Compute f_plus and f_minus (unscaled)
        f_plus = lambda_plus * backbone_magnitude * np.exp(tilt_sum)
        f_minus = lambda_minus * backbone_magnitude * np.exp(-tilt_sum)

        # Apply scaling and store components separately
        f_plus_contrib = scaling_plus * f_plus
        f_minus_contrib = -scaling_minus * f_minus  # negative sign for display

        f_plus_contributions[stage_idx] = f_plus_contrib
        f_minus_contributions[stage_idx] = f_minus_contrib

        # Total stage contribution
        stage_contrib = f_plus_contrib + f_minus_contrib
        stage_contributions[stage_idx] = stage_contrib

        # Store feature breakdown (copies, so later stages cannot alias them)
        feature_breakdown[stage_idx] = {
            "backbone_per_feature": backbone_per_feature.copy(),
            "tilt_per_feature": tilt_per_feature.copy(),
            "lambda_plus": lambda_plus,
            "lambda_minus": lambda_minus,
            "scaling_plus": scaling_plus,
            "scaling_minus": scaling_minus,
        }

    total_prediction = np.sum(stage_contributions)

    return {
        "stage_contributions": stage_contributions,
        "f_plus_contributions": f_plus_contributions,
        "f_minus_contributions": f_minus_contributions,
        "backbone_magnitudes": backbone_magnitudes,
        "tilt_sums": tilt_sums,
        "feature_breakdown": feature_breakdown,
        "total_prediction": total_prediction,
    }


# Maximum number of characters drawn inside a stacked bar segment. The sizing
# heuristic (half of the segment usable for text, ~12% of the segment per
# character, hard cap of 10) does not depend on the actual segment width and
# therefore always evaluates to 4.
_FIG5_SEGMENT_LABEL_MAX_CHARS = min(max(1, int((1 - 2 * 0.25) / 0.12)), 10)

# Display label and "print as integer" flag for the feature-value info box.
# Features missing from this table fall back to a generic format.
_FIG5_INFO_FORMATS = {
    "Longitude": ("Lon", False),
    "Latitude": ("Lat", False),
    "MedInc": ("MedInc", False),
    "HouseAge": ("Age", True),
    "TotalRooms": ("Rooms", True),
    "TotalBedrooms": ("Bedrms", True),
    "Population": ("Pop", True),
    "Households": ("Households", True),
}


def _fig5_trim_feature_label(feature_name):
    """Return the first word of ``feature_name`` cut to the segment label limit."""
    return feature_name.split()[0][:_FIG5_SEGMENT_LABEL_MAX_CHARS]


def _fig5_feature_shares(backbone_per_feature):
    """Split a stage's backbone into per-feature shares on the log scale.

    A feature's weight is ``|log(b_j)|``. Features whose backbone value is not
    above 1e-15 (log undefined / unstable) or whose weight is not above 1e-4
    (multiplier ~1) are left out.

    Returns ``(feature_indices, shares)`` as two lists sorted by share,
    largest first; the shares sum to 1. Both lists are empty when no feature
    qualifies.
    """
    weights = []
    indices = []
    for feat_idx, bb_val in enumerate(backbone_per_feature):
        if bb_val > 1e-15:
            log_val = np.abs(np.log(bb_val))
            if log_val > 1e-4:
                weights.append(log_val)
                indices.append(feat_idx)

    if not weights:
        return [], []

    total_weight = sum(weights)
    # sorted() is stable, so ties keep their original feature order.
    ranked = sorted(
        zip((w / total_weight for w in weights), indices),
        key=lambda pair: pair[0],
        reverse=True,
    )
    return [idx for _, idx in ranked], [share for share, _ in ranked]


def _fig5_draw_component_bar(
    ax,
    y_pos,
    bar_height,
    anchor_x,
    component_value,
    color,
    feature_indices,
    shares,
    decompose,
):
    """Draw one f+ or f- bar starting at ``anchor_x``.

    ``component_value`` is the signed bar length (positive extends right,
    negative extends left). When ``decompose`` is true and ``shares`` is
    non-empty, the bar is drawn as stacked per-feature segments, darkest next
    to the anchor; otherwise a single plain bar is drawn. Components with
    magnitude not above 1e-10 are not drawn at all.
    """
    if not abs(component_value) > 1e-10:
        return

    if not (decompose and shares):
        ax.barh(
            y_pos,
            component_value,
            bar_height,
            left=anchor_x,
            color=color,
            alpha=0.7,
            edgecolor="black",
            linewidth=1.0,
            zorder=2,
        )
        return

    n_segments = len(shares)
    offset = 0
    for seg_idx, (feat_idx, share) in enumerate(zip(feature_indices, shares)):
        segment_width = component_value * share

        # Opacity gradient: 0.85 next to the anchor down to 0.45 at the tip.
        segment_alpha = 0.85 - (seg_idx / max(n_segments - 1, 1)) * 0.4

        ax.barh(
            y_pos,
            segment_width,
            bar_height,
            left=anchor_x + offset,
            color=color,
            alpha=segment_alpha,
            edgecolor="white",
            linewidth=0.5,
            zorder=2,
        )

        # Label only segments wider than 3% of the whole bar. Magnitudes are
        # compared so the rule is the same for left- and right-pointing bars.
        if abs(segment_width) > 0.03 * abs(component_value):
            ax.text(
                anchor_x + offset + segment_width / 2,
                y_pos,
                _fig5_trim_feature_label(FEATURE_NAMES[feat_idx]),
                ha="center",
                va="center",
                fontsize=6,
                color="white",
                zorder=3,
            )

        offset += segment_width


def _fig5_point_info_text(point):
    """Build the multi-line "name: value" text listing every feature of ``point``."""
    lines = []
    for feature_idx, feature_name in enumerate(FEATURE_NAMES):
        feature_val = point[feature_idx]
        if feature_name in _FIG5_INFO_FORMATS:
            label, as_integer = _FIG5_INFO_FORMATS[feature_name]
        else:
            # Generic fallback: print near-integers without decimals.
            label = feature_name
            as_integer = abs(feature_val - round(feature_val)) < 0.01

        if as_integer:
            lines.append(f"{label}: {int(round(feature_val))}")
        else:
            lines.append(f"{label}: {feature_val:.2f}")
    return "\n".join(lines)


def _fig5_breakdown_record(point, expl):
    """Convert one point and its explanation into a JSON-serializable dict."""
    record = {"coordinates": point.tolist()}
    for key in (
        "stage_contributions",
        "f_plus_contributions",
        "f_minus_contributions",
        "backbone_magnitudes",
        "tilt_sums",
    ):
        record[key] = expl[key].tolist()
    record["total_prediction"] = float(expl["total_prediction"])
    return record


def plot_figure_5_local_explanations(
    model: MPF,
    X: np.ndarray,
    X_df: pd.DataFrame,
    point_a: Optional[np.ndarray] = None,
    point_b: Optional[np.ndarray] = None,
    save_path: Optional[str] = None,
    show_feature_decomposition: bool = True,
    top_k_features: int = 3,
    show_zero_line: bool = True,
) -> plt.Figure:
    """
    Figure 5: Local explanations for two points (desert vs coastal) with split f+/f- bars.

    One panel per point. Stages are ordered by absolute net contribution
    (largest first) and laid out as a waterfall starting at 0. For each stage:
    - Black vertical marker: running total after adding this stage
    - Blue bar (f+): scaling_plus * f_plus, drawn rightward from the marker
    - Red bar (f-): -scaling_minus * f_minus, drawn leftward from the marker
    - Boxed number above the marker: the stage's net contribution (f+ plus f-)
    - Optional: feature-level decomposition within each bar using log-scale percentages

    The final running total is highlighted as a green x-axis tick labelled with
    the total prediction, and the point's feature values are listed in a box.

    Besides the figure, a JSON file with the per-stage breakdown of both points
    is written to LOCAL_EXPLANATIONS_DIR / "local_explanations_detailed.json".

    Parameters:
    -----------
    model : MPF
        Trained MPF model
    X : np.ndarray
        Feature matrix (column 0 = longitude, column 1 = latitude); used to
        pick the default coastal point
    X_df : pd.DataFrame
        Feature DataFrame with column names (not used by this function)
    point_a : np.ndarray, optional
        First point (desert). If None, row 2799 of the California housing CSV
        in DATA_DIR is used (all columns except the last).
    point_b : np.ndarray, optional
        Second point (coastal). If None, the row of X closest to
        (34.05 N, 118.25 W) is used.
    save_path : str, optional
        Path to save figure. Defaults to
        FIGURES_DIR / "figure_5_local_explanations.svg".
    show_feature_decomposition : bool, optional
        If True, show stacked feature contributions within each f+/f- bar
    top_k_features : int, optional
        Number of top features listed in the per-stage annotation box
    show_zero_line : bool, optional
        If True, show vertical reference line at x=0 when at least one bar
        spans zero. Default is True.

    Returns:
    --------
    fig : matplotlib.figure.Figure
        Figure object
    """
    # Select points if not provided
    if point_a is None:
        # Desert point: row 2799 (lowest house value in desert region)
        df = pd.read_csv(DATA_DIR / "44977_california_housing.csv", header=None)
        point_a = df.iloc[2799, :-1].values  # All columns except last (target)

    if point_b is None:
        # Coastal point: dataset row closest to Los Angeles (~34.05 N, 118.25 W)
        lat_idx = 1
        lon_idx = 0
        target_lat = 34.05
        target_lon = -118.25
        distances = np.sqrt(
            (X[:, lat_idx] - target_lat) ** 2 + (X[:, lon_idx] - target_lon) ** 2
        )
        point_b = X[np.argmin(distances)]

    # Compute local explanations
    expl_a = compute_local_explanation(model, point_a)
    expl_b = compute_local_explanation(model, point_b)

    # Create figure with two panels
    fig, axes = plt.subplots(2, 1, figsize=(10, 14))

    # Settings shared across both panels
    color_f_plus = "#1f77b4"  # Blue
    color_f_minus = "#d62728"  # Red
    bar_height = 0.7

    panels = [(point_a, expl_a, "Desert Point"), (point_b, expl_b, "Coastal Point")]
    for ax, (point, expl, title) in zip(axes, panels):
        stage_contribs = expl["stage_contributions"]
        n_stages = len(stage_contribs)

        # Sort stages by absolute contribution magnitude (largest first)
        sorted_indices = np.argsort(np.abs(stage_contribs))[::-1]
        sorted_net_contribs = stage_contribs[sorted_indices]
        sorted_f_plus = expl["f_plus_contributions"][sorted_indices]
        sorted_f_minus = expl["f_minus_contributions"][sorted_indices]
        stage_labels = [f"Stage {idx + 1}" for idx in sorted_indices]

        total_prediction = expl["total_prediction"]

        # Waterfall positions: cumulative[i + 1] is the running total after
        # stage i. The base value is 0 because the stage contributions sum to
        # the prediction.
        base_value = 0.0
        cumulative = np.concatenate(
            ([base_value], base_value + np.cumsum(sorted_net_contribs, dtype=float))
        )

        # A bar spans [running total + f-, running total + f+]; the zero line
        # is only drawn if at least one such interval contains 0.
        bar_lefts = cumulative[1:] + sorted_f_minus
        bar_rights = cumulative[1:] + sorted_f_plus
        any_bar_crosses_zero = bool(np.any((bar_lefts <= 0) & (bar_rights >= 0)))

        bar_positions = np.arange(n_stages)

        # For each stage, draw split bars showing f+ and f- components
        for i, stage_idx in enumerate(sorted_indices):
            y_pos = bar_positions[i]
            anchor_x = cumulative[i + 1]  # Running total (black marker position)
            stage_net = sorted_net_contribs[i]

            breakdown = expl["feature_breakdown"][stage_idx]
            feature_indices_sorted, percentages_sorted = _fig5_feature_shares(
                breakdown["backbone_per_feature"]
            )

            # f+ (blue, signed length positive) and f- (red, signed length
            # negative) both start at the running-total marker.
            for component_value, component_color in (
                (sorted_f_plus[i], color_f_plus),
                (sorted_f_minus[i], color_f_minus),
            ):
                _fig5_draw_component_bar(
                    ax,
                    y_pos,
                    bar_height,
                    anchor_x,
                    component_value,
                    component_color,
                    feature_indices_sorted,
                    percentages_sorted,
                    show_feature_decomposition,
                )

            # Black marker at the running total
            ax.plot(
                [anchor_x, anchor_x],
                [y_pos - bar_height / 2, y_pos + bar_height / 2],
                color="black",
                linewidth=2.5,
                alpha=0.9,
                zorder=4,
            )

            # Net contribution value next to the marker
            # (pink background for negative, light blue otherwise)
            ax.text(
                anchor_x,
                y_pos + bar_height / 2 + 0.15,
                f"{stage_net:+.2f}",
                ha="center",
                va="bottom",
                fontsize=8,
                color="black",
                bbox=dict(
                    boxstyle="round,pad=0.2",
                    facecolor="#FFB6C1" if stage_net < 0 else "#ADD8E6",
                    alpha=0.9,
                    edgecolor="black",
                    linewidth=0.8,
                ),
                zorder=5,
            )

            # Top-k features for this stage, pinned to the right edge of the axes
            if show_feature_decomposition and percentages_sorted and top_k_features > 0:
                top_features_str = ", ".join(
                    f"{FEATURE_NAMES[fidx]}: {pct * 100:.0f}%"
                    for fidx, pct in zip(
                        feature_indices_sorted[:top_k_features],
                        percentages_sorted[:top_k_features],
                    )
                )
                ax.text(
                    0.98,
                    y_pos,
                    top_features_str,
                    ha="right",
                    va="center",
                    fontsize=8,
                    color="black",
                    # x in axes coordinates, y in data coordinates
                    transform=ax.get_yaxis_transform(),
                    bbox=dict(
                        boxstyle="round,pad=0.3",
                        facecolor="white",
                        edgecolor="black",
                        linewidth=0.8,
                        alpha=0.95,
                    ),
                    zorder=10,  # High z-order to appear in front of bars
                    clip_on=False,
                )

        # Dashed connectors between consecutive stages (waterfall effect)
        for i in range(n_stages - 1):
            x_val = cumulative[i + 1]
            ax.plot(
                [x_val, x_val],
                [
                    bar_positions[i] - bar_height / 2,
                    bar_positions[i + 1] + bar_height / 2,
                ],
                "k--",
                linewidth=1.0,
                alpha=0.4,
                zorder=1,
            )

        # Final prediction position for x-axis tick and info text
        final_x = cumulative[-1]

        # Take the automatic ticks and format their labels through the axis'
        # own formatter. Reading the text of the tick Text objects instead is
        # unreliable here: releases of Matplotlib before 3.6 leave those
        # strings empty until the figure has been drawn, which would blank
        # every regular tick label once they are set explicitly below.
        tick_locs = ax.get_xticks()
        current_ticks = list(tick_locs)
        current_labels = list(ax.xaxis.get_major_formatter().format_ticks(tick_locs))

        # Two positions closer than 1% of the tick span count as the same tick
        if len(current_ticks) > 1 and max(current_ticks) != min(current_ticks):
            tolerance = 0.01 * (max(current_ticks) - min(current_ticks))
        else:
            tolerance = 0.1

        # Relabel the first tick that coincides with the prediction, or add one
        match_idx = next(
            (
                idx
                for idx, tick in enumerate(current_ticks)
                if abs(tick - final_x) < tolerance
            ),
            None,
        )
        if match_idx is None:
            current_ticks.append(final_x)
            current_labels.append(f"{total_prediction:.2f}")
        else:
            current_labels[match_idx] = f"{total_prediction:.2f}"

        ax.set_xticks(current_ticks)
        ax.set_xticklabels(current_labels)

        # Highlight the final prediction tick: bold white text on a green box
        for tick, label in zip(ax.get_xticks(), ax.get_xticklabels()):
            if abs(tick - final_x) < tolerance:
                label.set_color("white")
                label.set_fontweight("bold")
                label.set_bbox(
                    dict(
                        boxstyle="round,pad=0.3",
                        facecolor="#2ca02c",
                        edgecolor="#2ca02c",
                        linewidth=2.0,
                        alpha=1.0,
                    )
                )
                # Do not clip the label to the axes area
                label.set_clip_on(False)

        # Set labels and styling
        ax.set_yticks(bar_positions)
        ax.set_yticklabels(stage_labels, fontsize=10)
        ax.set_xlabel("Contribution", fontsize=12)
        ax.set_title(
            f"{title}, Total Prediction: ${total_prediction:,.0f}",
            fontsize=14,
        )
        ax.grid(True, alpha=0.2, axis="x", zorder=0)
        if show_zero_line and any_bar_crosses_zero:
            ax.axvline(
                x=0, color="black", linestyle="-", linewidth=1.0, alpha=0.5, zorder=1
            )
        ax.invert_yaxis()

        # Feature values box: right-aligned 5% of the x-range to the left of
        # the final prediction, anchored just past the last stage row.
        xlim = ax.get_xlim()
        spacing = (xlim[1] - xlim[0]) * 0.05
        ax.text(
            final_x - spacing,
            bar_positions[-1] + bar_height / 2 + 0.3,
            _fig5_point_info_text(point),
            fontsize=9,
            verticalalignment="bottom",
            horizontalalignment="right",
            bbox=dict(boxstyle="round", facecolor="wheat", alpha=0.7),
            zorder=10,
        )

    # Add shared legend at the bottom of the figure
    from matplotlib.patches import Patch

    legend_elements = [
        Patch(
            facecolor=color_f_plus,
            alpha=0.7,
            edgecolor="black",
            label="f+ (positive component)",
        ),
        Patch(
            facecolor=color_f_minus,
            alpha=0.7,
            edgecolor="black",
            label="f- (negative component)",
        ),
        Patch(
            facecolor="white",
            edgecolor="black",
            linewidth=2,
            label="Net contribution",
        ),
    ]
    fig.legend(
        handles=legend_elements,
        loc="lower center",
        bbox_to_anchor=(0.5, 0.02),
        ncol=3,
        fontsize=9,
        framealpha=0.9,
    )

    plt.tight_layout(rect=[0, 0.03, 1, 1])  # Add small bottom margin for legend

    if save_path is None:
        save_path = FIGURES_DIR / "figure_5_local_explanations.svg"
    _save_fig(fig, save_path, bbox_inches="tight")
    print(f"Saved Figure 5 to {save_path}")

    # Also save detailed breakdown to file
    breakdown_path = LOCAL_EXPLANATIONS_DIR / "local_explanations_detailed.json"
    breakdown_data = {
        "point_a": _fig5_breakdown_record(point_a, expl_a),
        "point_b": _fig5_breakdown_record(point_b, expl_b),
    }
    with open(breakdown_path, "w", encoding="utf-8") as f:
        json.dump(breakdown_data, f, indent=2)
    print(f"Saved detailed breakdown to {breakdown_path}")

    return fig


def plot_figure_5_1_map_annotations(
    model: MPF,
    X: np.ndarray,
    X_df: pd.DataFrame,
    point_a: Optional[np.ndarray] = None,
    point_b: Optional[np.ndarray] = None,
    grid_points: int = 100,
    cmap: str = "viridis",
    save_path: Optional[str] = None,
) -> plt.Figure:
    """
    Figure 5.1: California map with the (lon, lat) partial dependence and two marked points.

    The 2D partial dependence over features 0 (longitude) and 1 (latitude) is
    evaluated on a regular grid spanning the coordinate range of X, summed over
    all stages, and drawn as filled contours on a PlateCarree map. The two
    local-explanation points are marked with stars and labelled with the
    model's prediction for each.

    Parameters:
    -----------
    model : MPF
        Trained MPF model
    X : np.ndarray
        Feature matrix; supplies the grid bounds and is passed to the
        partial dependence computation
    X_df : pd.DataFrame
        Feature DataFrame with column names (not used by this function)
    point_a : np.ndarray, optional
        First point (desert). If None, row 2799 of the California housing CSV
        in DATA_DIR is used (all columns except the last).
    point_b : np.ndarray, optional
        Second point (coastal). If None, row 4556 of the same CSV is used.
    grid_points : int
        Number of grid points per axis for partial dependence
    cmap : str
        Colormap for partial dependence
    save_path : str, optional
        Path to save figure. Defaults to
        FIGURES_DIR / "figure_5_1_map_annotations.svg".

    Returns:
    --------
    fig : matplotlib.figure.Figure
        Figure object
    """
    # Default points come from fixed rows of the raw CSV (target column dropped)
    housing_csv = DATA_DIR / "44977_california_housing.csv"
    if point_a is None:
        # Desert point: row 2799 (lowest house value in desert region)
        point_a = pd.read_csv(housing_csv, header=None).iloc[2799, :-1].values
    if point_b is None:
        # LA point: row 4556
        point_b = pd.read_csv(housing_csv, header=None).iloc[4556, :-1].values

    # Coordinate bounds of the data
    lon_idx, lat_idx = 0, 1
    lon_column = X[:, lon_idx]
    lat_column = X[:, lat_idx]
    lon_min, lon_max = lon_column.min(), lon_column.max()
    lat_min, lat_max = lat_column.min(), lat_column.max()

    # Regular evaluation grid, flattened to rows of [lon, lat] to match the
    # feature order [0, 1] requested from the model below
    LON, LAT = np.meshgrid(
        np.linspace(lon_min, lon_max, grid_points),
        np.linspace(lat_min, lat_max, grid_points),
    )
    grid_flat = np.column_stack([LON.ravel(), LAT.ravel()])

    # The model returns one (C_plus, C_minus) pair per stage and a value
    # matrix whose columns alternate f+ / f- per stage:
    # [f+_0, f-_0, f+_1, f-_1, ...]
    constants_per_epoch, partial_dep = model.compute_partial_dependence_function(
        [0, 1], grid_flat, X
    )
    n_epochs = partial_dep.shape[1] // 2

    # Accumulate, stage by stage, (f+ + f-) * sqrt(C_plus * (-C_minus))
    predictions = np.zeros(partial_dep.shape[0])
    for epoch_idx in range(n_epochs):
        c_plus, c_minus = constants_per_epoch[epoch_idx]
        stage_scale = np.sqrt(c_plus * (-c_minus))
        stage_values = (
            partial_dep[:, 2 * epoch_idx] + partial_dep[:, 2 * epoch_idx + 1]
        )
        predictions += stage_values * stage_scale

    predictions = predictions.reshape(LAT.shape)

    # Map axis: coastline, state borders, extent padded by half a degree,
    # dashed gridlines labelled on the left and bottom only
    plate_carree = ccrs.PlateCarree()
    fig, ax = plt.subplots(figsize=(12, 8), subplot_kw={"projection": plate_carree})
    margin = 0.5
    ax.add_feature(cfeature.COASTLINE)
    ax.add_feature(cfeature.STATES)
    ax.set_extent(
        [lon_min - margin, lon_max + margin, lat_min - margin, lat_max + margin],
        crs=plate_carree,
    )
    gridliner = ax.gridlines(
        draw_labels=True, linewidth=0.5, color="gray", alpha=0.5, linestyle="--"
    )
    gridliner.top_labels = gridliner.right_labels = False

    # Partial dependence surface as filled contours
    filled = ax.contourf(
        LON,
        LAT,
        predictions,
        levels=20,
        cmap=cmap,
        transform=plate_carree,
        alpha=0.7,
    )
    plt.colorbar(filled, ax=ax, shrink=0.6, pad=0.05, label="Partial Dependence")

    # Model predictions for both points
    pred_a = model.predict(point_a.reshape(1, -1))[0]
    pred_b = model.predict(point_b.reshape(1, -1))[0]

    # Star marker plus a boxed prediction label for each point
    markers = (
        ("Desert", point_a, pred_a, "red"),
        ("Coastal", point_b, pred_b, "blue"),
    )
    for tag, location, prediction, colour in markers:
        ax.scatter(
            location[lon_idx],
            location[lat_idx],
            c=colour,
            s=200,
            marker="*",
            edgecolors="black",
            linewidths=2,
            transform=plate_carree,
            zorder=10,
            label=f"{tag} Point",
        )
        ax.annotate(
            f"{tag}\nPred: ${prediction:,.0f}",
            xy=(location[lon_idx], location[lat_idx]),
            xytext=(10, 10),
            textcoords="offset points",
            fontsize=11,
            fontweight="bold",
            color=colour,
            bbox=dict(boxstyle="round,pad=0.3", facecolor="white", alpha=0.8),
            transform=plate_carree,
            zorder=11,
            ha="left",
        )

    ax.set_title(
        "Partial Dependence f_lat,lon(lat, lon) with Local Explanation Points",
        fontsize=14,
        fontweight="bold",
    )
    ax.legend(loc="upper right", fontsize=10)

    plt.tight_layout()

    if save_path is None:
        save_path = FIGURES_DIR / "figure_5_1_map_annotations.svg"
    _save_fig(fig, save_path, bbox_inches="tight")
    print(f"Saved Figure 5.1 to {save_path}")

    return fig


def plot_figure_5_2_stagewise_partial_dependence(
    model: MPF,
    X: np.ndarray,
    X_df: pd.DataFrame,
    point_a: Optional[np.ndarray] = None,
    point_b: Optional[np.ndarray] = None,
    grid_points: int = 100,
    cmap: str = "viridis",
    save_path: Optional[str] = None,
) -> plt.Figure:
    """
    Figure 5.2: Stage-wise partial dependence maps with annotated points.

    Draws one map panel per stage showing the (longitude, latitude) partial
    dependence of that stage, and marks two local-explanation points on every
    panel together with the stage's PD value read at the grid node nearest to
    each point.

    Parameters:
    -----------
    model : MPF
        Trained MPF model
    X : np.ndarray
        Feature matrix; column 0 is read as longitude and column 1 as latitude.
        Also passed to the model as the data to marginalize over.
    X_df : pd.DataFrame
        Feature DataFrame with column names. Not used by this function; kept so
        the signature matches the other figure functions.
    point_a : np.ndarray, optional
        First point (desert). If None, the row of X closest (Euclidean distance
        in degrees) to lat=37.5, lon=-118.5 is used.
    point_b : np.ndarray, optional
        Second point (coastal). If None, row 4556 of the California housing CSV
        under DATA_DIR is used (all columns except the last).
    grid_points : int
        Number of grid points per axis for partial dependence
    cmap : str
        Colormap for partial dependence
    save_path : str, optional
        Path to save figure. If None, the figure is written to
        FIGURES_DIR / "figure_5_2_stagewise_partial_dependence.svg".

    Returns:
    --------
    fig : matplotlib.figure.Figure
        Figure object

    Raises:
    -------
    ValueError
        If the model reports no stages (the partial dependence array has no
        columns), since there would be nothing to plot.
    """
    lon_idx = 0
    lat_idx = 1

    # Select points if not provided
    if point_a is None:
        target_lat = 37.5
        target_lon = -118.5
        distances = np.sqrt(
            (X[:, lat_idx] - target_lat) ** 2 + (X[:, lon_idx] - target_lon) ** 2
        )
        point_a = X[np.argmin(distances)]

    if point_b is None:
        # LA point: row 4556 (exact LA coordinates)
        df = pd.read_csv(DATA_DIR / "44977_california_housing.csv", header=None)
        point_b = df.iloc[4556, :-1].values  # All columns except last (target)

    # Coordinate bounds of the data define the evaluation grid and map extent.
    lon_min, lon_max = X[:, lon_idx].min(), X[:, lon_idx].max()
    lat_min, lat_max = X[:, lat_idx].min(), X[:, lat_idx].max()

    lon_vals = np.linspace(lon_min, lon_max, grid_points)
    lat_vals = np.linspace(lat_min, lat_max, grid_points)
    LON, LAT = np.meshgrid(lon_vals, lat_vals)

    # Grid rows are [lon, lat], matching the fixed feature indices [0, 1].
    grid_flat = np.column_stack([LON.ravel(), LAT.ravel()])

    # Compute partial dependence for all stages
    # Returns tuple: (constants_per_epoch, pd_values)
    # - constants_per_epoch: list of (C_plus, C_minus) per Stage (expectation over marginalized features)
    # - pd_values: Array2 of shape (n_points, 2 * n_epochs) with columns [f+_epoch0, f-_epoch0, f+_epoch1, f-_epoch1, ...]
    #   Values already include constants: f+ = ∏_{j ∈ S} f_j(x_j) · E[∏_{j ∉ S} f_j(X_j)]
    constants_per_epoch, partial_dep = model.compute_partial_dependence_function(
        [lon_idx, lat_idx], grid_flat, X
    )
    n_epochs = partial_dep.shape[1] // 2
    if n_epochs == 0:
        # Without this guard the subplot-grid ceiling division below fails
        # with an uninformative ZeroDivisionError.
        raise ValueError(
            "Partial dependence result contains no stages; nothing to plot."
        )

    # Per-stage prediction is f+ + f- (even columns are f+, odd columns are f-).
    f_plus_all = partial_dep[:, ::2]
    f_minus_all = partial_dep[:, 1::2]

    plate_carree = ccrs.PlateCarree()

    def _setup_map_axis(ax, margin=0.5):
        """Add coastline/state outlines, set the extent, and draw gridlines."""
        ax.add_feature(cfeature.COASTLINE)
        ax.add_feature(cfeature.STATES)
        ax.set_extent(
            [lon_min - margin, lon_max + margin, lat_min - margin, lat_max + margin],
            crs=plate_carree,
        )
        gl = ax.gridlines(
            draw_labels=True, linewidth=0.5, color="gray", alpha=0.5, linestyle="--"
        )
        gl.top_labels = gl.right_labels = False

    # The PD value shown for each point is read from the NEAREST grid node
    # (not re-evaluated at the exact coordinates). The node indices do not
    # depend on the stage, so they are resolved once here.
    annotated_points = []
    for label, color, point in (
        ("Desert", "red", point_a),
        ("Coastal", "blue", point_b),
    ):
        lon_pt, lat_pt = point[lon_idx], point[lat_idx]
        row = np.argmin(np.abs(lat_vals - lat_pt))
        col = np.argmin(np.abs(lon_vals - lon_pt))
        annotated_points.append((label, color, lon_pt, lat_pt, row, col))

    # One panel per stage, at most 4 panels per row.
    n_cols = min(4, n_epochs)
    n_rows = (n_epochs + n_cols - 1) // n_cols  # Ceiling division
    fig = plt.figure(figsize=(6 * n_cols, 6 * n_rows))

    for epoch_idx in range(n_epochs):
        ax = fig.add_subplot(n_rows, n_cols, epoch_idx + 1, projection=plate_carree)
        _setup_map_axis(ax)

        c_plus, c_minus = constants_per_epoch[epoch_idx]

        # Scaling factor sqrt(C_+ * (-C_-)); c_minus is negated for the
        # geometric-mean sign.
        # NOTE(review): this is NaN if C_+ and C_- have the same sign — confirm
        # the model guarantees opposite signs.
        sqrt_c_product = np.sqrt(c_plus * (-c_minus))

        # Stage PD on the grid: (f+ + f-) multiplied by the scaling factor.
        pd_stage = (
            f_plus_all[:, epoch_idx] + f_minus_all[:, epoch_idx]
        ) * sqrt_c_product
        pd_stage = pd_stage.reshape(LAT.shape)

        cs = ax.contourf(
            LON,
            LAT,
            pd_stage,
            levels=20,
            cmap=cmap,
            transform=plate_carree,
            alpha=0.7,
        )
        fig.colorbar(cs, ax=ax, shrink=0.6, pad=0.05, label="Partial Dependence")

        # Star marker plus a boxed label carrying this stage's PD value.
        for label, color, lon_pt, lat_pt, row, col in annotated_points:
            ax.scatter(
                lon_pt,
                lat_pt,
                c=color,
                s=150,
                marker="*",
                edgecolors="black",
                linewidths=1.5,
                transform=plate_carree,
                zorder=10,
            )
            ax.annotate(
                f"{label}\nPD: ${pd_stage[row, col]:,.0f}",
                xy=(lon_pt, lat_pt),
                xytext=(10, 10),
                textcoords="offset points",
                fontsize=10,
                fontweight="bold",
                color=color,
                bbox=dict(boxstyle="round,pad=0.3", facecolor="white", alpha=0.8),
                transform=plate_carree,
                zorder=11,
                ha="left",
            )

        ax.set_title(
            f"Stage {epoch_idx + 1} Partial Dependence\nC+={c_plus:.2e}, C-={c_minus:.2e}, Scale={sqrt_c_product:.2e}",
            fontsize=11,
            fontweight="bold",
        )

    plt.tight_layout(rect=[0, 0, 1, 0.98])

    if save_path is None:
        save_path = FIGURES_DIR / "figure_5_2_stagewise_partial_dependence.svg"
    _save_fig(fig, save_path, bbox_inches="tight")
    print(f"Saved Figure 5.2 to {save_path}")

    return fig


def _fi_print_ranking(
    values: np.ndarray,
    feature_names: List[str],
    indent: str,
    top_k: Optional[int] = None,
) -> None:
    """Print features ranked by descending value, optionally only the top_k."""
    order = np.argsort(values)[::-1]
    if top_k is not None:
        order = order[:top_k]
    for rank, feat_idx in enumerate(order, 1):
        print(
            f"{indent}{rank}. {feature_names[feat_idx]:15s}: {values[feat_idx]:.6f}"
        )


def _fi_stage_heatmap(
    fig, ax, per_stage: np.ndarray, feature_names: List[str], cmap, title, cbar_label
) -> None:
    """Draw a feature-by-stage heatmap of a (n_stages, n_features) array."""
    n_stages, n_features = per_stage.shape
    image = ax.imshow(
        per_stage.T, aspect="auto", cmap=cmap, interpolation="nearest"
    )
    ax.set_xlabel("Stage")
    ax.set_ylabel("Feature")
    ax.set_title(title)
    ax.set_xticks(range(n_stages))
    ax.set_xticklabels([f"{i + 1}" for i in range(n_stages)])
    ax.set_yticks(range(n_features))
    ax.set_yticklabels(feature_names, fontsize=8)
    fig.colorbar(image, ax=ax, label=cbar_label)


def _fi_ranked_barh(
    ax, values: np.ndarray, feature_names: List[str], color, xlabel, title
) -> None:
    """Draw a horizontal bar chart of values sorted in descending order."""
    order = np.argsort(values)[::-1]
    positions = range(len(values))
    ax.barh(positions, values[order], color=color, alpha=0.7)
    ax.set_yticks(positions)
    ax.set_yticklabels([feature_names[i] for i in order], fontsize=8)
    ax.set_xlabel(xlabel)
    ax.set_title(title)
    ax.grid(True, alpha=0.3, axis="x")


def compute_and_plot_feature_importance(
    model: MPF,
    X: np.ndarray,
    feature_names: List[str],
    gamma: float = 1.0,
    figsize: Tuple[int, int] = (14, 10),
    save_path: Optional[str] = None,
) -> Dict[str, Any]:
    """
    Figure 6: Compute and visualize feature importance metrics for an MPF model.

    Queries the model for per-stage, aggregated, and combined feature
    importance, prints ranked summaries to stdout, and saves a 2x3 panel
    figure (two per-stage heatmaps, three ranked bar charts, stage weights).

    Parameters:
    -----------
    model : MPF
        Trained MPF model
    X : np.ndarray
        Feature matrix (training data)
    feature_names : list of str
        Names of features for labeling; must contain exactly one name per
        feature reported by the model.
    gamma : float, default=1.0
        Weight for tilt importance in combined score
    figsize : tuple, default=(14, 10)
        Figure size (width, height)
    save_path : str, optional
        Path to save figure. If None, the figure is written to
        FIGURES_DIR / "figure_6_feature_importance.svg".

    Returns:
    --------
    dict
        Dictionary containing all computed importance metrics, with keys
        "backbone_per_stage", "tilt_per_stage", "global_backbone",
        "global_tilt", "combined", "combined_backbone", "combined_tilt",
        and "stage_weights".

    Raises:
    -------
    ValueError
        If len(feature_names) differs from the number of features in the
        per-stage importance arrays.
    """
    print("\n" + "=" * 80)
    print("FEATURE IMPORTANCE ANALYSIS")
    print("=" * 80)

    # Compute per-stage feature importance
    backbone_per_stage, tilt_per_stage = model.compute_per_stage_feature_importance(X)
    n_stages, n_features = backbone_per_stage.shape

    # Fail early with a clear message instead of an IndexError while printing
    # or a tick/label count mismatch while plotting.
    if len(feature_names) != n_features:
        raise ValueError(
            f"feature_names has {len(feature_names)} entries but the model "
            f"reports {n_features} features"
        )

    # Compute aggregated feature importance
    global_backbone, global_tilt, stage_weights = (
        model.compute_aggregated_feature_importance(X)
    )

    # Compute combined feature importance
    combined, combined_backbone, combined_tilt = (
        model.compute_combined_feature_importance(X, gamma=gamma)
    )

    # Print summary statistics
    print(f"\nNumber of stages: {n_stages}")
    print(f"Number of features: {n_features}")
    print(f"\nStage weights: {stage_weights}")
    print(f"Stage weights sum: {stage_weights.sum():.6f}")

    # Print per-stage importance summary
    print("\n" + "-" * 80)
    print("PER-STAGE FEATURE IMPORTANCE (Top 3 features per stage)")
    print("-" * 80)
    for stage_idx in range(n_stages):
        print(f"\nStage {stage_idx + 1} (weight: {stage_weights[stage_idx]:.4f}):")
        print("  Backbone importance (top 3):")
        _fi_print_ranking(
            backbone_per_stage[stage_idx, :], feature_names, "    ", top_k=3
        )
        print("  Tilt importance (top 3):")
        _fi_print_ranking(tilt_per_stage[stage_idx, :], feature_names, "    ", top_k=3)

    # Print aggregated importance
    print("\n" + "-" * 80)
    print("AGGREGATED FEATURE IMPORTANCE (Global)")
    print("-" * 80)
    print("\nBackbone importance (global):")
    _fi_print_ranking(global_backbone, feature_names, "  ")

    print("\nTilt importance (global):")
    _fi_print_ranking(global_tilt, feature_names, "  ")

    # Print combined importance, with its backbone and tilt components
    print("\n" + "-" * 80)
    print(f"COMBINED FEATURE IMPORTANCE (I_j = I_j^b + {gamma} * I_j^d)")
    print("-" * 80)
    for rank, feat_idx in enumerate(np.argsort(combined)[::-1], 1):
        print(
            f"  {rank}. {feature_names[feat_idx]:15s}: {combined[feat_idx]:.6f} "
            f"(backbone: {combined_backbone[feat_idx]:.6f}, "
            f"tilt: {combined_tilt[feat_idx]:.6f})"
        )

    # Create visualization: 2 rows x 3 columns of panels
    fig = plt.figure(figsize=figsize)

    # Plots 1-2: per-stage importance heatmaps (features on y, stages on x)
    _fi_stage_heatmap(
        fig,
        fig.add_subplot(2, 3, 1),
        backbone_per_stage,
        feature_names,
        cmap="YlOrRd",
        title="Per-Stage Backbone Importance",
        cbar_label="Backbone Variance",
    )
    _fi_stage_heatmap(
        fig,
        fig.add_subplot(2, 3, 2),
        tilt_per_stage,
        feature_names,
        cmap="YlGnBu",
        title="Per-Stage Tilt Importance",
        cbar_label="Tilt Variance",
    )

    # Plots 3-5: ranked bar charts (global backbone, global tilt, combined)
    _fi_ranked_barh(
        fig.add_subplot(2, 3, 3),
        global_backbone,
        feature_names,
        color="orange",
        xlabel="Global Backbone Importance",
        title="Aggregated Backbone Importance",
    )
    _fi_ranked_barh(
        fig.add_subplot(2, 3, 4),
        global_tilt,
        feature_names,
        color="cyan",
        xlabel="Global Tilt Importance",
        title="Aggregated Tilt Importance",
    )
    _fi_ranked_barh(
        fig.add_subplot(2, 3, 5),
        combined,
        feature_names,
        color="purple",
        xlabel=f"Combined Importance (γ={gamma})",
        title="Combined Feature Importance",
    )

    # Plot 6: Stage weights
    ax6 = fig.add_subplot(2, 3, 6)
    ax6.bar(
        range(n_stages),
        stage_weights,
        color="green",
        alpha=0.7,
    )
    ax6.set_xlabel("Stage")
    ax6.set_ylabel("Weight")
    ax6.set_title("Stage Weights (Energy-based)")
    ax6.set_xticks(range(n_stages))
    ax6.set_xticklabels([f"{i + 1}" for i in range(n_stages)])
    ax6.grid(True, alpha=0.3, axis="y")

    plt.tight_layout(rect=[0, 0, 1, 0.99])

    if save_path is None:
        save_path = FIGURES_DIR / "figure_6_feature_importance.svg"
    _save_fig(fig, save_path, bbox_inches="tight")
    print(f"Saved Figure 6 to {save_path}")

    # Return all computed metrics
    return {
        "backbone_per_stage": backbone_per_stage,
        "tilt_per_stage": tilt_per_stage,
        "global_backbone": global_backbone,
        "global_tilt": global_tilt,
        "combined": combined,
        "combined_backbone": combined_backbone,
        "combined_tilt": combined_tilt,
        "stage_weights": stage_weights,
    }


if __name__ == "__main__":
    # Script entry point: load the model and data, then generate every core
    # figure in sequence. Each plotting function saves its own output file.
    print("=" * 80)
    print("MPF Interpretability Analysis")
    print("=" * 80)
    model_type = "BLACKBOX" if USE_BLACKBOX else "INTERPRETABLE"
    print(f"Model Type: {model_type}")
    print(f"Output Directory Suffix: {OUTPUT_SUFFIX}")
    print("=" * 80)
    print()

    # Load model and data
    model, X, y, X_df = load_model_and_data()

    # Generate all figures
    print("\nGenerating Figure 1: Backbone plots (Stage 1 and Stage 2)...")
    plot_figure_1_backbone(model, X, X_df)

    print("\nGenerating Figure 2: Nominal Contribution plots (Stage 2 only)...")
    plot_figure_2_nominal_contribution(model, X, X_df)

    print("\nGenerating Figure 3: First-order partial dependence functions...")
    plot_figure_3_first_order_pd(model, X, X_df)

    print("\nGenerating Figure 3.1: Scaled first-order partial dependence functions...")
    plot_figure_3_1_scaled_first_order_pd(model, X, X_df)

    print("\nGenerating Figure 3.2: Backbone and tilt expression...")
    plot_figure_3_2_backbone_and_tilt_expression(model, X, X_df)

    print(
        "\nGenerating Figure 3.4: Scaled first-order PD with scaled backbone (selected features)..."
    )
    plot_figure_3_4_scaled_first_order_pd(model, X, X_df)

    print(
        "\nGenerating Figure 3.5: PD comparison (MPF, EBR, XGBoost) for Latitude and Longitude..."
    )
    plot_figure_3_5_pd_comparison(model, X, X_df)

    print("\nGenerating Figure 3.3: ICE curves...")
    plot_figure_3_3_ice_curves(model, X, X_df, n_observations=10)

    print("\nGenerating Figure 4: Spatial backbone evolution...")
    plot_figure_4_spatial_backbone_evolution(
        model, X, X_df, epochs=range(len(model.tree_grid_families))
    )

    # Select points for local explanations (desert and coastal)
    print("\nSelecting points for local explanations...")

    # Load CSV to get exact row indices
    df = pd.read_csv(DATA_DIR / "44977_california_housing.csv", header=None)

    # Exact row indices from find_closest_points.py
    desert_row_idx = 2784  # Lowest house value in desert region
    la_row_idx = 4556  # LA coordinates (exact match)

    # Extract exact points from CSV
    point_a = df.iloc[desert_row_idx, :-1].values  # All columns except last (target)
    point_b = df.iloc[la_row_idx, :-1].values

    selected_points = [
        ("Desert point", desert_row_idx, point_a),
        ("LA point", la_row_idx, point_b),
    ]

    # Predictions first, then stage contributions, for both points, before
    # anything is printed (same evaluation order as computing them one by one).
    predictions = [
        model.predict(point.reshape(1, -1))[0] for _, _, point in selected_points
    ]
    explanations = [
        compute_local_explanation(model, point) for _, _, point in selected_points
    ]

    # Report coordinates, prediction and per-stage contributions per point
    lat_idx = 1
    lon_idx = 0
    for (label, row_idx, point), pred, expl in zip(
        selected_points, predictions, explanations
    ):
        print(
            f"  {label} (row {row_idx}): lat={point[lat_idx]:.4f}, lon={point[lon_idx]:.4f}, prediction=${pred:,.0f}"
        )
        print(
            "    Stage contributions: "
            + "".join(
                f"S{stage_idx + 1}=${contrib:,.0f}  "
                for stage_idx, contrib in enumerate(expl["stage_contributions"])
            )
        )

    print("\nGenerating Figure 5: Local explanations...")
    plot_figure_5_local_explanations(model, X, X_df, point_a=point_a, point_b=point_b)

    print("\nGenerating Figure 5.1: Map annotations with partial dependence...")
    plot_figure_5_1_map_annotations(model, X, X_df, point_a=point_a, point_b=point_b)

    print("\nGenerating Figure 5.2: Stage-wise partial dependence maps...")
    plot_figure_5_2_stagewise_partial_dependence(
        model, X, X_df, point_a=point_a, point_b=point_b
    )

    print("\nGenerating Figure 6: Feature importance...")
    # Use the module-level FEATURE_NAMES constant rather than a second copy
    # of the same list.
    compute_and_plot_feature_importance(model, X, FEATURE_NAMES, gamma=1.0)

    print("\n" + "=" * 80)
    print("Core figures generated.")
    print("\nOptional extensions: SHAP comparison, stability refits, validation.")
    print("=" * 80)
