"""
Interpretability analysis for MPF on Bike Sharing.

Figures: backbone plots, nominal contributions, first-order PD, scaled PD (f+/f-),
ICE curves, local explanations.
"""

from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import matplotlib

matplotlib.use("Agg")  # Use non-interactive backend to prevent display
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# Typeset mathtext ($...$) in the regular text font instead of italic math (uses matplotlib's built-in mathtext, not a LaTeX install)
plt.rcParams["mathtext.default"] = "regular"

# MPF imports
from mpf_py import MPF, TreeGrid

# Seed NumPy's global RNG so any randomness in this script is repeatable.
np.random.seed(seed=42)

# Reproducibility bundle layout, rooted next to this file:
#   <BASE_DIR>/data/bike_sharing    -> input CSV
#   <BASE_DIR>/models/bike_sharing  -> serialized MPF model
#   <BASE_DIR>/figures/bike_sharing -> generated figures (created at import time)
BASE_DIR = Path(__file__).parent
DATA_DIR, MODELS_DIR, FIGURES_DIR = (
    BASE_DIR / subdir / "bike_sharing" for subdir in ("data", "models", "figures")
)
FIGURES_DIR.mkdir(parents=True, exist_ok=True)

# Default feature order = the first 12 CSV columns (the 13th column is the target).
# load_model_and_data() rebinds this from the columns it actually loads.
FEATURE_NAMES = (
    "year month hour weekday temp feel_temp humidity windspeed "
    "holiday workingday season weather"
).split()

# Features the figures focus on by default (hour, working-day flag, temperature).
KEY_FEATURES = ["hour", "workingday", "temp"]
# Column positions of KEY_FEATURES; None until load_model_and_data() fills it in.
KEY_FEATURE_INDICES = None

# Line colours, cycled by stage index.
STAGE_COLORS = [
    "#1f77b4",
    "#ff7f0e",
    "#2ca02c",
    "#d62728",
    "#9467bd",
    "#8c564b",
]


def load_model_and_data(
    model_path: Optional[str] = None, data_path: Optional[str] = None
) -> Tuple[MPF, np.ndarray, np.ndarray, pd.DataFrame]:
    """
    Read the serialized MPF model and the headerless Bike Sharing CSV.

    The CSV is read positionally (``header=None``): columns 0-11 are the
    features named below and column 12 is the ``count`` target. As a side
    effect, the module globals FEATURE_NAMES and KEY_FEATURE_INDICES are
    rebound to describe the loaded feature matrix (the plotting functions in
    this module read them). Key features that are missing only produce a
    printed warning.

    Parameters:
    -----------
    model_path : str, optional
        Path to the MPF model .bin file. Defaults to
        MODELS_DIR / "mpf_model.bin".
    data_path : str, optional
        Path to 42712_Bike_Sharing_Demand.csv. Defaults to
        DATA_DIR / "42712_Bike_Sharing_Demand.csv".

    Returns:
    --------
    model : MPF
        Model returned by ``MPF.load``
    X : np.ndarray
        Feature matrix (n_samples, n_features)
    y : np.ndarray
        Target values (n_samples,)
    X_df : pd.DataFrame
        Feature matrix as DataFrame with column names
    """
    global FEATURE_NAMES, KEY_FEATURE_INDICES

    model_path = (
        str(MODELS_DIR / "mpf_model.bin") if model_path is None else model_path
    )
    data_path = (
        str(DATA_DIR / "42712_Bike_Sharing_Demand.csv")
        if data_path is None
        else data_path
    )

    print(f"Loading MPF model from {model_path}...")
    model = MPF.load(model_path)

    print(f"Loading data from {data_path}...")
    df = pd.read_csv(data_path, header=None)

    # Positional schema of the CSV; the target is the final column.
    target_col = "count"
    feature_cols = [
        "year",
        "month",
        "hour",
        "weekday",
        "temp",
        "feel_temp",
        "humidity",
        "windspeed",
        "holiday",  # 0/1 flag
        "workingday",  # 0/1 flag
        "season",  # target-encoded categorical
        "weather",  # target-encoded categorical
    ]
    # Name only as many columns as the file has; pandas rejects a length mismatch.
    df.columns = (feature_cols + [target_col])[: len(df.columns)]

    present_cols = [name for name in feature_cols if name in df.columns]
    X_df = df[present_cols].copy()
    X = X_df.values
    y = df[target_col].values  # Total bike rental count

    # Publish the actual feature order for the plotting helpers.
    FEATURE_NAMES = list(X_df.columns)
    KEY_FEATURE_INDICES = []
    for key_feat in KEY_FEATURES:
        if key_feat not in FEATURE_NAMES:
            print(f"Warning: Key feature '{key_feat}' not found in loaded features")
            continue
        KEY_FEATURE_INDICES.append(FEATURE_NAMES.index(key_feat))

    n_samples, n_cols = X.shape
    print(f"Loaded model with {len(model.tree_grid_families)} stages")
    print(f"Loaded data: {n_samples} samples, {n_cols} features")
    print(f"Feature names: {FEATURE_NAMES}")
    print(f"Key feature indices: {KEY_FEATURE_INDICES}")

    return model, X, y, X_df


def extract_backbone_tilt(
    model: MPF, feature_idx: int, stage_idx: int, x_vals: np.ndarray
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Evaluate the piecewise-constant backbone and tilt of one feature at given points.

    The stage's combined tree grid stores, per feature, a sorted list of split
    points together with one backbone and one tilt value per interval. Each
    query point is mapped to its interval with a single vectorized
    ``np.searchsorted``. Points lying exactly on a split belong to the interval
    to the right of it, and indices past the last stored interval are clamped to
    that interval.

    The returned arrays keep the dtype of the stored values. Integer or
    low-precision ``x_vals`` therefore no longer truncate them (the previous
    ``np.zeros_like(x_vals)`` buffers inherited the query dtype).

    Parameters:
    -----------
    model : MPF
        Trained MPF model
    feature_idx : int
        Index of feature to extract
    stage_idx : int
        Index of stage/epoch (0-indexed)
    x_vals : np.ndarray
        Feature values to evaluate at (1D array-like)

    Returns:
    --------
    backbone_vals : np.ndarray
        Backbone values b_j^(ℓ)(x_j) for each x in x_vals
    tilt_vals : np.ndarray
        Tilt values d_j^(ℓ)(x_j) for each x in x_vals
    """
    tgf = model.tree_grid_families[stage_idx]
    tg = TreeGrid(tgf.combined_tree_grid)

    backbone = np.asarray(tg.backbone_values[feature_idx])
    tilt = np.asarray(tg.tilt_values[feature_idx])
    splits = np.asarray(tg.splits[feature_idx])

    # With no splits, searchsorted returns 0 for every point (a single interval).
    interval_idx = np.searchsorted(splits, np.asarray(x_vals), side="right")
    # Clamp to the last stored interval.
    interval_idx = np.minimum(interval_idx, len(backbone) - 1)

    return backbone[interval_idx], tilt[interval_idx]


def compute_nominal_contribution(
    backbone_vals: np.ndarray, tilt_vals: np.ndarray
) -> np.ndarray:
    """
    Compute nominal contribution: 2 * b_j * sinh(d_j), element-wise.

    Parameters:
    -----------
    backbone_vals : np.ndarray
        Backbone values b_j (any array-like; coerced with ``np.asarray``)
    tilt_vals : np.ndarray
        Tilt values d_j (any array-like; must broadcast against backbone_vals)

    Returns:
    --------
    nominal_contrib : np.ndarray
        Nominal contribution 2 * b_j * sinh(d_j)
    """
    # Coerce first: for a plain Python list, ``2 * backbone_vals`` would repeat
    # the list instead of scaling it.
    backbone_arr = np.asarray(backbone_vals)
    tilt_arr = np.asarray(tilt_vals)
    return 2 * backbone_arr * np.sinh(tilt_arr)


def plot_figure_1_backbone(
    model: MPF,
    X: np.ndarray,
    X_df: pd.DataFrame,
    feature_indices: Optional[List[int]] = None,
    grid_points: int = 200,
    save_path: Optional[str] = None,
) -> plt.Figure:
    """
    Figure 1: Backbone plots b_j^(ℓ)(x_j) for key features and all epochs.

    The figure has one row per stage and one column per feature. Each panel
    shows the stage's piecewise-constant backbone, evaluated on an evenly
    spaced grid spanning the feature's observed range in ``X``.

    Parameters:
    -----------
    model : MPF
        Trained MPF model
    X : np.ndarray
        Feature matrix
    X_df : pd.DataFrame
        Feature DataFrame with column names (not used; kept for a uniform API)
    feature_indices : list of int, optional
        Indices of features to plot. If None, plots key features only.
    grid_points : int
        Number of grid points for evaluation
    save_path : str, optional
        Path to save figure. Defaults to FIGURES_DIR / "figure_1_backbone.svg".
        The format follows the file extension when matplotlib supports it,
        otherwise SVG.

    Returns:
    --------
    fig : matplotlib.figure.Figure
        Figure object

    Raises:
    -------
    ValueError
        If ``feature_indices`` is None and KEY_FEATURE_INDICES has not been set.
    """
    if feature_indices is None:
        if KEY_FEATURE_INDICES is None:
            raise ValueError(
                "KEY_FEATURE_INDICES not set. Please call load_model_and_data() first."
            )
        feature_indices = KEY_FEATURE_INDICES

    n_features = len(feature_indices)
    n_stages = len(model.tree_grid_families)

    # squeeze=False always yields a 2-D array of Axes. In the 1x1 case
    # plt.subplots would otherwise return a bare Axes, which has no reshape().
    fig, axes = plt.subplots(
        n_stages,
        n_features,
        figsize=(6 * n_features, 4 * n_stages),
        squeeze=False,
    )

    for stage_idx in range(n_stages):
        stage_num = stage_idx + 1
        color = STAGE_COLORS[stage_idx % len(STAGE_COLORS)]
        for plot_idx, feat_idx in enumerate(feature_indices):
            ax = axes[stage_idx, plot_idx]
            feature_name = FEATURE_NAMES[feat_idx]

            # Evaluate over the observed range of this feature.
            feat_min = X[:, feat_idx].min()
            feat_max = X[:, feat_idx].max()
            x_vals = np.linspace(feat_min, feat_max, grid_points)

            backbone_vals, _ = extract_backbone_tilt(model, feat_idx, stage_idx, x_vals)

            ax.plot(x_vals, backbone_vals, linewidth=2, color=color, alpha=0.8)

            ax.set_xlabel(feature_name, fontsize=11)
            if plot_idx == 0:
                ax.set_ylabel(
                    f"Stage {stage_num}\nBackbone $b_j^{{{stage_num}}}$", fontsize=11
                )
            ax.set_title(f"Feature: {feature_name}", fontsize=12, fontweight="bold")
            ax.grid(True, alpha=0.3)

    fig.tight_layout(rect=[0, 0, 1, 0.99])

    if save_path is None:
        save_path = FIGURES_DIR / "figure_1_backbone.svg"
    # Use the extension's format (e.g. .png/.pdf) instead of always writing SVG
    # bytes; fall back to SVG for a missing or unsupported extension.
    suffix = Path(save_path).suffix[1:].lower()
    file_format = suffix if suffix in fig.canvas.get_supported_filetypes() else "svg"
    fig.savefig(save_path, format=file_format, bbox_inches="tight")
    print(f"Saved Figure 1 to {save_path}")

    return fig


def plot_figure_2_nominal_contribution(
    model: MPF,
    X: np.ndarray,
    X_df: pd.DataFrame,
    feature_indices: Optional[List[int]] = None,
    grid_points: int = 200,
    save_path: Optional[str] = None,
    stage_idx: int = 1,
) -> Optional[plt.Figure]:
    """
    Figure 2: Nominal contribution 2 * b_j * sinh(d_j) for a single stage.

    By default the second stage is plotted (``stage_idx=1``, shown as
    "Stage 2"). Panels are arranged in at most three columns. Each panel
    evaluates the stage's piecewise-constant backbone and tilt on an evenly
    spaced grid over the feature's observed range in ``X``.

    Parameters:
    -----------
    model : MPF
        Trained MPF model
    X : np.ndarray
        Feature matrix
    X_df : pd.DataFrame
        Feature DataFrame with column names (not used; kept for a uniform API)
    feature_indices : list of int, optional
        Indices of features to plot. If None, plots key features only.
    grid_points : int
        Number of grid points for evaluation
    save_path : str, optional
        Path to save figure. Defaults to
        FIGURES_DIR / "figure_2_nominal_contribution.svg". The format follows
        the file extension when matplotlib supports it, otherwise SVG.
    stage_idx : int
        0-based index of the stage to plot (default 1, i.e. Stage 2).

    Returns:
    --------
    fig : matplotlib.figure.Figure or None
        Figure object, or None (after printing a warning) when the model has
        no stage ``stage_idx``.

    Raises:
    -------
    ValueError
        If ``feature_indices`` is None and KEY_FEATURE_INDICES has not been set.
    """
    if feature_indices is None:
        if KEY_FEATURE_INDICES is None:
            raise ValueError(
                "KEY_FEATURE_INDICES not set. Please call load_model_and_data() first."
            )
        feature_indices = KEY_FEATURE_INDICES

    n_stages = len(model.tree_grid_families)
    if stage_idx < 0 or stage_idx >= n_stages:
        print(
            f"Warning: Model only has {n_stages} stages, cannot plot Stage {stage_idx + 1}"
        )
        return None

    n_features = len(feature_indices)
    n_cols = min(3, n_features)
    n_rows = (n_features + n_cols - 1) // n_cols  # ceil(n_features / n_cols)

    # squeeze=False + ravel gives a flat array of Axes for every grid shape.
    fig, axes = plt.subplots(
        n_rows, n_cols, figsize=(7 * n_cols, 5 * n_rows), squeeze=False
    )
    axes = axes.ravel()
    color = STAGE_COLORS[stage_idx % len(STAGE_COLORS)]

    for ax, feat_idx in zip(axes, feature_indices):
        feature_name = FEATURE_NAMES[feat_idx]

        # Evaluate over the observed range of this feature.
        feat_min = X[:, feat_idx].min()
        feat_max = X[:, feat_idx].max()
        x_vals = np.linspace(feat_min, feat_max, grid_points)

        backbone_vals, tilt_vals = extract_backbone_tilt(
            model, feat_idx, stage_idx, x_vals
        )
        nominal_contrib = compute_nominal_contribution(backbone_vals, tilt_vals)

        ax.plot(x_vals, nominal_contrib, linewidth=2, color=color, alpha=0.8)

        # Reference line at y=0
        ax.axhline(y=0, color="black", linestyle="--", linewidth=0.5, alpha=0.5)

        ax.set_xlabel(feature_name, fontsize=11)
        ax.set_ylabel("Nominal Contribution\n2 b_j sinh(d_j)", fontsize=11)
        ax.set_title(
            f"Nominal Contribution: {feature_name}", fontsize=12, fontweight="bold"
        )
        ax.grid(True, alpha=0.3)

    # Hide the unused cells of the last grid row.
    for ax in axes[n_features:]:
        ax.axis("off")

    fig.tight_layout(rect=[0, 0, 1, 0.99])

    if save_path is None:
        save_path = FIGURES_DIR / "figure_2_nominal_contribution.svg"
    # Use the extension's format; fall back to SVG if missing or unsupported.
    suffix = Path(save_path).suffix[1:].lower()
    file_format = suffix if suffix in fig.canvas.get_supported_filetypes() else "svg"
    fig.savefig(save_path, format=file_format, bbox_inches="tight")
    print(f"Saved Figure 2 to {save_path}")

    return fig


def plot_figure_3_first_order_pd(
    model: MPF,
    X: np.ndarray,
    X_df: pd.DataFrame,
    feature_indices: Optional[List[int]] = None,
    grid_points: int = 200,
    save_path: Optional[str] = None,
) -> plt.Figure:
    """
    Figure 3: First-order partial dependence functions.

    The figure has one row per stage and one column per feature. Each panel
    plots that stage's total partial dependence f+ + f- on an evenly spaced
    grid over the feature's observed range, with every other column held at
    its mean over ``X``. The panel title reports the stage constants C_+ and C_-.

    Parameters:
    -----------
    model : MPF
        Trained MPF model
    X : np.ndarray
        Feature matrix; also passed as the second argument of
        ``compute_first_order_partial_dependence_functions``.
    X_df : pd.DataFrame
        Feature DataFrame with column names (not used; kept for a uniform API)
    feature_indices : list of int, optional
        Indices of features to plot. If None, plots key features only.
    grid_points : int
        Number of grid points for evaluation
    save_path : str, optional
        Path to save figure. Defaults to FIGURES_DIR / "figure_3_first_order_pd.svg".
        The format follows the file extension when matplotlib supports it,
        otherwise SVG.

    Returns:
    --------
    fig : matplotlib.figure.Figure
        Figure object

    Raises:
    -------
    ValueError
        If ``feature_indices`` is None and KEY_FEATURE_INDICES has not been set.
    """
    if feature_indices is None:
        if KEY_FEATURE_INDICES is None:
            raise ValueError(
                "KEY_FEATURE_INDICES not set. Please call load_model_and_data() first."
            )
        feature_indices = KEY_FEATURE_INDICES

    n_features = len(feature_indices)
    n_epochs = len(model.tree_grid_families)

    # One block of grid_points rows per feature: that feature sweeps its observed
    # range while all other columns sit at their mean. The blocks are stacked so a
    # single PD call covers every plotted feature.
    feature_grids = []
    X_mean = X.mean(axis=0)
    X_grid_list = []
    for feat_idx in feature_indices:
        x_vals = np.linspace(X[:, feat_idx].min(), X[:, feat_idx].max(), grid_points)
        feature_grids.append(x_vals)

        X_grid_feat = np.tile(X_mean, (grid_points, 1))
        X_grid_feat[:, feat_idx] = x_vals
        X_grid_list.append(X_grid_feat)

    X_grid_combined = np.vstack(X_grid_list)

    first_order_pd = model.compute_first_order_partial_dependence_functions(
        X_grid_combined, X
    )

    # Layout: n_epochs rows, n_features columns (always 2-D thanks to squeeze=False)
    fig, axes = plt.subplots(
        n_epochs, n_features, figsize=(7 * n_features, 4 * n_epochs), squeeze=False
    )

    for plot_idx, feat_idx in enumerate(feature_indices):
        constants_per_epoch, pd_values = first_order_pd[feat_idx]
        feature_name = FEATURE_NAMES[feat_idx]

        # Rows of this feature's block within the stacked grid.
        block_rows = slice(plot_idx * grid_points, (plot_idx + 1) * grid_points)
        pd_values_feat = pd_values[block_rows, :]

        # Columns interleave per stage: f+ at even positions, f- at odd ones.
        f_plus_all = pd_values_feat[:, ::2]
        f_minus_all = pd_values_feat[:, 1::2]

        x_vals = feature_grids[plot_idx]

        for epoch_idx in range(n_epochs):
            ax = axes[epoch_idx, plot_idx]
            c_plus, c_minus = constants_per_epoch[epoch_idx]

            # Total PD = f+ + f-
            pd_total = f_plus_all[:, epoch_idx] + f_minus_all[:, epoch_idx]

            ax.plot(
                x_vals,
                pd_total,
                linewidth=2,
                color=STAGE_COLORS[epoch_idx % len(STAGE_COLORS)],
                alpha=0.8,
            )
            ax.axhline(y=0, color="black", linestyle="--", linewidth=0.5, alpha=0.5)

            ax.set_xlabel(feature_name, fontsize=11)
            if plot_idx == 0:
                ax.set_ylabel(f"PD (Stage {epoch_idx + 1})", fontsize=11)

            # Fixed-point for moderate magnitudes, scientific notation otherwise.
            c_plus_str = f"{c_plus:.4f}" if abs(c_plus) < 1000 else f"{c_plus:.2e}"
            c_minus_str = f"{c_minus:.4f}" if abs(c_minus) < 1000 else f"{c_minus:.2e}"

            ax.set_title(
                f"{feature_name}\nStage {epoch_idx + 1}: $C_+={c_plus_str}$, $C_-={c_minus_str}$",
                fontsize=11,
                fontweight="bold",
            )
            ax.grid(True, alpha=0.3)

    fig.tight_layout(rect=[0, 0, 1, 0.99])

    if save_path is None:
        save_path = FIGURES_DIR / "figure_3_first_order_pd.svg"
    # Use the extension's format; fall back to SVG if missing or unsupported.
    suffix = Path(save_path).suffix[1:].lower()
    file_format = suffix if suffix in fig.canvas.get_supported_filetypes() else "svg"
    fig.savefig(save_path, format=file_format, bbox_inches="tight")
    print(f"Saved Figure 3 to {save_path}")

    return fig


def plot_figure_3_1_scaled_first_order_pd(
    model: MPF,
    X: np.ndarray,
    X_df: pd.DataFrame,
    feature_indices: Optional[List[int]] = None,
    grid_points: int = 200,
    save_path: Optional[str] = None,
    use_paper_notation: bool = False,
) -> plt.Figure:
    """
    Figure 3.1: Scaled first-order partial dependence (f+ and f-).

    Layout is (n_epochs + 1) rows x len(feature_indices) columns.

    Rows 0..n_epochs-1 cover one stage each. Every panel plots f+ and the
    negated f- (-f-). The gap between them is shaded green where
    f+ >= -f- and orange where f+ < -f-. From the second stage on, a dotted
    black line adds the stage's backbone b_j scaled by sqrt(C_+ * (-C_-)).

    The last row makes the same comparison for f+ and f- summed over all
    stages.

    Each feature is evaluated on an evenly spaced grid spanning its observed
    range in ``X``, with every other column held at its mean over ``X``.
    Features whose observed values are exactly {0, 1} (within np.allclose
    tolerance) are drawn as short horizontal segments at x = 0 and x = 1
    instead of curves.

    Parameters:
    -----------
    model : MPF
        Trained MPF model
    X : np.ndarray
        Feature matrix; also passed as the second argument of
        ``compute_first_order_partial_dependence_functions``.
    X_df : pd.DataFrame
        Feature DataFrame with column names (not used by this function).
    feature_indices : list of int, optional
        Indices of features to plot. If None, plots key features only.
    grid_points : int
        Number of grid points for evaluation
    save_path : str, optional
        Path to save figure. Defaults to
        FIGURES_DIR / "figure_3_1_scaled_first_order_pd.svg". The file format is
        taken from the extension (SVG when there is no extension).
    use_paper_notation : bool
        If True, use paper notation PD_{+,j}, PD_{-,j}, PD_{\\pm,j} and x_j
        (same as synthetic_analysis.py). Here j is the 1-based column position
        within ``feature_indices``, not the feature's column index in ``X``.

    Returns:
    --------
    fig : matplotlib.figure.Figure
        Figure object

    Raises:
    -------
    ValueError
        If ``feature_indices`` is None and KEY_FEATURE_INDICES has not been set.
    """
    if feature_indices is None:
        if KEY_FEATURE_INDICES is None:
            raise ValueError(
                "KEY_FEATURE_INDICES not set. Please call load_model_and_data() first."
            )
        feature_indices = KEY_FEATURE_INDICES

    n_features = len(feature_indices)
    n_epochs = len(model.tree_grid_families)

    # Create grid of values for each feature: one block of grid_points rows per
    # feature, where that feature sweeps its observed range and all other
    # columns sit at their mean over X. Blocks are stacked in feature_indices order.
    feature_grids = []
    X_mean = X.mean(axis=0)
    X_grid_list = []

    for feat_idx in feature_indices:
        feat_min = X[:, feat_idx].min()
        feat_max = X[:, feat_idx].max()
        x_vals = np.linspace(feat_min, feat_max, grid_points)
        feature_grids.append(x_vals)

        X_grid_feat = np.tile(X_mean, (grid_points, 1))
        X_grid_feat[:, feat_idx] = x_vals
        X_grid_list.append(X_grid_feat)

    X_grid_combined = np.vstack(X_grid_list)

    # Compute first-order PD functions
    first_order_pd = model.compute_first_order_partial_dependence_functions(
        X_grid_combined, X
    )

    # Layout: n_epochs + 1 rows (one extra for global PD), n_features columns
    fig, axes = plt.subplots(
        n_epochs + 1,
        n_features,
        figsize=(7 * n_features, 4 * (n_epochs + 1)),
        squeeze=False,
    )

    for plot_idx, feat_idx in enumerate(feature_indices):
        j = plot_idx + 1  # 1-based index for paper notation PD_{+,j}, x_j
        # Per feature: (list of (C_+, C_-) pairs per stage, PD value matrix).
        constants_per_epoch, pd_values = first_order_pd[feat_idx]

        # Extract relevant rows (this feature's block of the stacked grid)
        start_idx = plot_idx * grid_points
        end_idx = (plot_idx + 1) * grid_points
        pd_values_feat = pd_values[start_idx:end_idx, :]

        # Extract f+ and f- (columns interleave per stage: f+ even, f- odd)
        f_plus_all = pd_values_feat[:, ::2]
        f_minus_all = pd_values_feat[:, 1::2]

        x_vals = feature_grids[plot_idx]

        # Binary check: True only when the observed values are exactly 0 and 1
        # (within np.allclose tolerance). A single observed value is not treated
        # as binary. The second clause repeats the first, since np.unique already
        # returns sorted values.
        unique_vals = np.unique(X[:, feat_idx])
        is_binary = (
            len(unique_vals) <= 2
            and np.allclose(unique_vals, [0, 1])
            or (len(unique_vals) <= 2 and np.allclose(sorted(unique_vals), [0.0, 1.0]))
        )

        for epoch_idx in range(n_epochs):
            ax = axes[epoch_idx, plot_idx]

            # Get constants
            c_plus, c_minus = constants_per_epoch[epoch_idx]

            # f- is plotted negated so the two curves can be compared directly.
            f_plus = f_plus_all[:, epoch_idx]
            f_minus = f_minus_all[:, epoch_idx]
            f_minus_flipped = -f_minus

            # Extract backbone values for this feature and epoch
            backbone_vals, _ = extract_backbone_tilt(model, feat_idx, epoch_idx, x_vals)

            # Compute sqrt(C_+ * (-C_-)) scaling factor
            # c_minus negated for geometric-mean sign
            # NOTE(review): if C_+ * (-C_-) < 0 this is NaN (NumPy warns) and the
            # dotted backbone line is not drawn. Confirm C_+ >= 0 and C_- <= 0 always hold.
            sqrt_c_product = np.sqrt(c_plus * (-c_minus))

            # Scale backbone by sqrt(C_+ * (-C_-))
            backbone_scaled = backbone_vals * sqrt_c_product

            if is_binary:
                # For binary features, use thin horizontal line segments connected in middle
                # Get values at 0 and 1 (grid points closest to each)
                idx_0 = np.argmin(np.abs(x_vals - 0))
                idx_1 = np.argmin(np.abs(x_vals - 1))

                f_plus_0 = f_plus[idx_0]
                f_plus_1 = f_plus[idx_1]
                f_minus_0 = f_minus_flipped[idx_0]
                f_minus_1 = f_minus_flipped[idx_1]

                # X positions for binary values (0 and 1)
                x_pos_0 = 0
                x_pos_1 = 1

                # NOTE(review): y_offset (and the loop's `label` below) is never used.
                # Y offsets for f+ and f- lines (small vertical separation)
                y_offset = 0.02

                # Plot for each binary value; legend labels are attached only at
                # x = 0 so each entry appears once.
                for x_pos, label, f_p, f_m in zip(
                    [x_pos_0, x_pos_1],
                    ["0", "1"],
                    [f_plus_0, f_plus_1],
                    [f_minus_0, f_minus_1],
                ):
                    # Compute difference (f+ minus the negated f-)
                    diff = f_p - f_m

                    # Draw thin horizontal line segments
                    # f+ line (red)
                    ax.plot(
                        [x_pos - 0.15, x_pos + 0.15],
                        [f_p, f_p],
                        linewidth=2,
                        color="darkred",
                        alpha=0.9,
                        label=rf"$\mathrm{{PD}}_{{+,{j}}}$"
                        if (x_pos == x_pos_0 and use_paper_notation)
                        else ("$f_+$" if x_pos == x_pos_0 else ""),
                    )

                    # negated f- line (blue)
                    ax.plot(
                        [x_pos - 0.15, x_pos + 0.15],
                        [f_m, f_m],
                        linewidth=2,
                        color="darkblue",
                        alpha=0.9,
                        label=rf"$\mathrm{{PD}}_{{-,{j}}}$"
                        if (x_pos == x_pos_0 and use_paper_notation)
                        else ("$f_-$" if x_pos == x_pos_0 else ""),
                    )

                    # Shade the gap: green when f+ >= -f-, orange otherwise
                    if diff >= 0:
                        connect_color = "green"
                        connect_alpha = 0.3
                    else:
                        connect_color = "orange"
                        connect_alpha = 0.3

                    # Draw the shaded rectangle between the two segments
                    y_bottom = min(f_p, f_m)
                    y_top = max(f_p, f_m)
                    height_connect = y_top - y_bottom
                    if height_connect > 0:
                        ax.fill_between(
                            [x_pos - 0.15, x_pos + 0.15],
                            y_bottom,
                            y_top,
                            color=connect_color,
                            alpha=connect_alpha,
                        )

                    # Add value labels (skipped for |value| <= 0.01)
                    if abs(f_p) > 0.01:
                        ax.text(
                            x_pos + 0.2,
                            f_p,
                            f"{f_p:.3f}",
                            ha="left",
                            va="center",
                            fontsize=8,
                        )
                    if abs(f_m) > 0.01:
                        ax.text(
                            x_pos + 0.2,
                            f_m,
                            f"{f_m:.3f}",
                            ha="left",
                            va="center",
                            fontsize=8,
                        )

                # Set x-axis for binary values
                ax.set_xlim(-0.5, 1.5)
                ax.set_xticks([x_pos_0, x_pos_1])
                ax.set_xticklabels(["0", "1"])

                # Add horizontal line at y=0 for reference
                # NOTE(review): the same y=0 line is drawn again after this branch.
                ax.axhline(y=0, color="black", linestyle="--", linewidth=0.5, alpha=0.5)

                # Plot scaled backbone as dotted black line (skip for first epoch)
                # NOTE(review): the legend text shows sqrt(C_+ C_-), while the plotted
                # value is sqrt(C_+ * (-C_-)).
                if epoch_idx > 0:
                    backbone_0 = backbone_scaled[idx_0]
                    backbone_1 = backbone_scaled[idx_1]
                    for x_pos, b_val in zip(
                        [x_pos_0, x_pos_1], [backbone_0, backbone_1]
                    ):
                        ax.plot(
                            [x_pos - 0.15, x_pos + 0.15],
                            [b_val, b_val],
                            linewidth=2,
                            color="black",
                            linestyle=":",
                            alpha=0.9,
                            label=r"$\sqrt{C_+ C_-} \cdot b_j$"
                            if x_pos == x_pos_0
                            else "",
                        )

            else:
                # For continuous features, use line plots as before
                # Compute difference (f+ minus the negated f-, i.e. f+ + f-)
                diff = f_plus - f_minus_flipped

                # Fill areas: green where f+ >= -f-, orange elsewhere
                ax.fill_between(
                    x_vals,
                    f_minus_flipped,
                    f_plus,
                    where=(diff >= 0),
                    color="green",
                    alpha=0.3,
                )
                ax.fill_between(
                    x_vals,
                    f_minus_flipped,
                    f_plus,
                    where=(diff < 0),
                    color="orange",
                    alpha=0.3,
                )

                # Plot lines (paper notation PD_{+,j}, PD_{-,j} or f_+, f_-)
                ax.plot(
                    x_vals,
                    f_plus,
                    linewidth=1.5,
                    color="darkred",
                    alpha=0.9,
                    label=rf"$\mathrm{{PD}}_{{+,{j}}}$"
                    if use_paper_notation
                    else "$f_+$",
                )
                ax.plot(
                    x_vals,
                    f_minus_flipped,
                    linewidth=1.5,
                    color="darkblue",
                    alpha=0.9,
                    label=rf"$\mathrm{{PD}}_{{-,{j}}}$"
                    if use_paper_notation
                    else "$f_-$",
                )

                # Plot scaled backbone as dotted black line (skip for first epoch)
                if epoch_idx > 0:
                    ax.plot(
                        x_vals,
                        backbone_scaled,
                        linewidth=2,
                        color="black",
                        linestyle=":",
                        alpha=0.9,
                        label=r"$\sqrt{C_+ C_-} \cdot b_j$",
                    )

            ax.axhline(y=0, color="black", linestyle="--", linewidth=0.5, alpha=0.5)
            ax.set_xlabel(
                f"$x_{j}$" if use_paper_notation else FEATURE_NAMES[feat_idx],
                fontsize=11,
            )
            if plot_idx == 0:
                ax.set_ylabel(
                    rf"$\mathrm{{PD}}_{{\pm,{j}}}$"
                    if use_paper_notation
                    else f"Scaled PD (Stage {epoch_idx + 1})",
                    fontsize=11,
                )

            # Fixed-point for moderate magnitudes, scientific notation otherwise.
            c_plus_str = f"{c_plus:.4f}" if abs(c_plus) < 1000 else f"{c_plus:.2e}"
            c_minus_str = f"{c_minus:.4f}" if abs(c_minus) < 1000 else f"{c_minus:.2e}"
            ax.set_title(
                f"{FEATURE_NAMES[feat_idx]}\nStage {epoch_idx + 1}: $C_+={c_plus_str}$, $C_-={c_minus_str}$",
                fontsize=11,
                fontweight="bold",
            )
            ax.grid(True, alpha=0.3)
            ax.legend(loc="best", fontsize=9)

        # Plot global PD in the last row (sum of f+ minus sum of f- across all stages)
        ax_global = axes[n_epochs, plot_idx]

        # Sum f+ and f- across all epochs; f- is again plotted negated.
        f_plus_global = np.sum(f_plus_all, axis=1)
        f_minus_global = np.sum(f_minus_all, axis=1)
        f_minus_global_flipped = -f_minus_global

        if is_binary:
            # For binary features: same segment drawing as the per-stage rows,
            # without the scaled backbone.
            idx_0 = np.argmin(np.abs(x_vals - 0))
            idx_1 = np.argmin(np.abs(x_vals - 1))

            f_plus_0 = f_plus_global[idx_0]
            f_plus_1 = f_plus_global[idx_1]
            f_minus_0 = f_minus_global_flipped[idx_0]
            f_minus_1 = f_minus_global_flipped[idx_1]

            x_pos_0 = 0
            x_pos_1 = 1

            for x_pos, f_p, f_m in zip(
                [x_pos_0, x_pos_1], [f_plus_0, f_plus_1], [f_minus_0, f_minus_1]
            ):
                diff = f_p - f_m

                ax_global.plot(
                    [x_pos - 0.15, x_pos + 0.15],
                    [f_p, f_p],
                    linewidth=2,
                    color="darkred",
                    alpha=0.9,
                    label=rf"$\mathrm{{PD}}_{{+,{j}}}$ (sum)"
                    if (x_pos == x_pos_0 and use_paper_notation)
                    else ("$f_+$ (sum)" if x_pos == x_pos_0 else ""),
                )
                ax_global.plot(
                    [x_pos - 0.15, x_pos + 0.15],
                    [f_m, f_m],
                    linewidth=2,
                    color="darkblue",
                    alpha=0.9,
                    label=rf"$\mathrm{{PD}}_{{-,{j}}}$ (sum)"
                    if (x_pos == x_pos_0 and use_paper_notation)
                    else ("$f_-$ (sum)" if x_pos == x_pos_0 else ""),
                )

                connect_color = "green" if diff >= 0 else "orange"
                y_bottom = min(f_p, f_m)
                y_top = max(f_p, f_m)
                if y_top - y_bottom > 0:
                    ax_global.fill_between(
                        [x_pos - 0.15, x_pos + 0.15],
                        y_bottom,
                        y_top,
                        color=connect_color,
                        alpha=0.3,
                    )

                if abs(f_p) > 0.01:
                    ax_global.text(
                        x_pos + 0.2,
                        f_p,
                        f"{f_p:.3f}",
                        ha="left",
                        va="center",
                        fontsize=8,
                    )
                if abs(f_m) > 0.01:
                    ax_global.text(
                        x_pos + 0.2,
                        f_m,
                        f"{f_m:.3f}",
                        ha="left",
                        va="center",
                        fontsize=8,
                    )

            ax_global.set_xlim(-0.5, 1.5)
            ax_global.set_xticks([x_pos_0, x_pos_1])
            ax_global.set_xticklabels(["0", "1"])
        else:
            # For continuous features: green where sum f+ >= -sum f-, orange elsewhere
            diff_global = f_plus_global - f_minus_global_flipped

            ax_global.fill_between(
                x_vals,
                f_minus_global_flipped,
                f_plus_global,
                where=(diff_global >= 0),
                color="green",
                alpha=0.3,
            )
            ax_global.fill_between(
                x_vals,
                f_minus_global_flipped,
                f_plus_global,
                where=(diff_global < 0),
                color="orange",
                alpha=0.3,
            )

            ax_global.plot(
                x_vals,
                f_plus_global,
                linewidth=1.5,
                color="darkred",
                alpha=0.9,
                label=rf"$\mathrm{{PD}}_{{+,{j}}}$ (sum)"
                if use_paper_notation
                else "$f_+$ (sum)",
            )
            ax_global.plot(
                x_vals,
                f_minus_global_flipped,
                linewidth=1.5,
                color="darkblue",
                alpha=0.9,
                label=rf"$\mathrm{{PD}}_{{-,{j}}}$ (sum)"
                if use_paper_notation
                else "$f_-$ (sum)",
            )

        ax_global.axhline(y=0, color="black", linestyle="--", linewidth=0.5, alpha=0.5)
        ax_global.set_xlabel(
            f"$x_{j}$" if use_paper_notation else FEATURE_NAMES[feat_idx], fontsize=11
        )
        if plot_idx == 0:
            ax_global.set_ylabel(
                rf"$\mathrm{{PD}}_{{\pm,{j}}}$"
                if use_paper_notation
                else "Global PD (Sum)",
                fontsize=11,
            )
        ax_global.set_title(
            f"$x_{j}$\nGlobal: $\\sum \\mathrm{{PD}}_{{+,{j}}} - \\sum \\mathrm{{PD}}_{{-,{j}}}$"
            if use_paper_notation
            else f"{FEATURE_NAMES[feat_idx]}\nGlobal: $\\sum f_+ - \\sum f_-$",
            fontsize=11,
            fontweight="bold",
        )
        ax_global.grid(True, alpha=0.3)
        ax_global.legend(loc="best", fontsize=9)

    plt.tight_layout(rect=[0, 0, 1, 0.99])

    if save_path is None:
        save_path = FIGURES_DIR / "figure_3_1_scaled_first_order_pd.svg"
    # Format comes from the extension (lower-cased); no extension -> SVG.
    # NOTE(review): an extension matplotlib does not support will make savefig raise.
    file_format = (
        Path(save_path).suffix[1:].lower() if Path(save_path).suffix else "svg"
    )
    plt.savefig(save_path, format=file_format, bbox_inches="tight")
    print(f"Saved Figure 3.1 to {save_path}")

    return fig


def plot_figure_3_2_scaled_first_order_pd(
    model: MPF,
    X: np.ndarray,
    X_df: pd.DataFrame,
    grid_points: int = 200,
    save_path: Optional[str] = None,
) -> plt.Figure:
    """
    Figure 3.2: Scaled first-order partial dependence (f+ and f-) for selected features.

    Shows only: hour, weekday, temp, feel_temp, workingday

    The figure itself is produced by ``plot_figure_3_1_scaled_first_order_pd``
    (with paper notation enabled). This wrapper only selects the features and
    saves the result under the Figure 3.2 file name.

    Parameters:
    -----------
    model : MPF
        Trained MPF model
    X : np.ndarray
        Feature matrix
    X_df : pd.DataFrame
        Feature DataFrame with column names
    grid_points : int
        Number of grid points for evaluation
    save_path : str, optional
        Path to save figure. The file format is taken from the extension
        (defaults to SVG when there is none). If None, saves to
        ``FIGURES_DIR / "figure_3_2_scaled_first_order_pd.svg"``.

    Returns:
    --------
    fig : matplotlib.figure.Figure
        Figure object

    Raises:
    -------
    ValueError
        If none of the selected features are present in FEATURE_NAMES.
    """
    # Select specific features: hour, weekday, temp, feel_temp, workingday
    selected_features = ["hour", "weekday", "temp", "feel_temp", "workingday"]

    # Resolve feature indices; warn about (but tolerate) missing features
    feature_indices = []
    for feat_name in selected_features:
        if feat_name in FEATURE_NAMES:
            feature_indices.append(FEATURE_NAMES.index(feat_name))
        else:
            print(f"Warning: Feature '{feat_name}' not found in FEATURE_NAMES")

    if not feature_indices:
        raise ValueError("None of the selected features were found in the dataset")

    print(
        f"Figure 3.2: Plotting features {selected_features} at indices {feature_indices}"
    )

    # Set default save path if not provided
    if save_path is None:
        save_path = FIGURES_DIR / "figure_3_2_scaled_first_order_pd.svg"

    # The Figure 3.1 routine always saves its own output. Point it at a
    # throw-away path so the figure can be written once more under the
    # Figure 3.2 name. try/finally guarantees the temporary file is removed
    # even when plotting/saving fails or no figure is returned.
    temp_path = FIGURES_DIR / "temp_figure_3_2.svg"
    try:
        fig = plot_figure_3_1_scaled_first_order_pd(
            model=model,
            X=X,
            X_df=X_df,
            feature_indices=feature_indices,
            grid_points=grid_points,
            save_path=temp_path,
            use_paper_notation=True,  # PD_{+,j}, PD_{-,j}, PD_{\pm,j}, x_j (same as synthetic_analysis)
        )

        if fig is not None:
            # Derive the output format from the extension (e.g. .pdf, .png)
            suffix = Path(save_path).suffix
            file_format = suffix[1:].lower() if suffix else "svg"
            fig.savefig(save_path, format=file_format, bbox_inches="tight")
            print(f"Saved Figure 3.2 to {save_path}")
    finally:
        if temp_path.exists():
            temp_path.unlink()

    return fig


def plot_figure_3_3_ice_curves(
    model: MPF,
    X: np.ndarray,
    X_df: pd.DataFrame,
    feature_indices: Optional[List[int]] = None,
    n_observations: int = 100,
    grid_points: int = 200,
    save_path: Optional[str] = None,
) -> plt.Figure:
    """
    Figure 3.3: Individual Conditional Expectation (ICE) curves.

    Draws one panel per (stage, feature). In each panel every sampled
    observation contributes its f+ curve (dark red) and its negated f- curve
    (dark blue). The gap between them is shaded green where f+ >= -f- and
    orange elsewhere.

    Parameters:
    -----------
    model : MPF
        Trained MPF model
    X : np.ndarray
        Feature matrix
    X_df : pd.DataFrame
        Feature DataFrame with column names (unused; kept for API symmetry)
    feature_indices : list of int, optional
        Indices of features to plot. If None, plots key features only.
    n_observations : int
        Number of random observations to plot (capped at the number of rows of X)
    grid_points : int
        Number of grid points for evaluation
    save_path : str, optional
        Path to save figure. The file format is taken from the extension
        (defaults to SVG when there is none). If None, saves to
        ``FIGURES_DIR / "figure_3_3_ice_curves.svg"``.

    Returns:
    --------
    fig : matplotlib.figure.Figure
        Figure object

    Raises:
    -------
    ValueError
        If ``feature_indices`` is None and KEY_FEATURE_INDICES has not been set,
        or if the model has no stages.
    """
    if feature_indices is None:
        if KEY_FEATURE_INDICES is None:
            raise ValueError(
                "KEY_FEATURE_INDICES not set. Please call load_model_and_data() first."
            )
        feature_indices = KEY_FEATURE_INDICES

    n_features = len(feature_indices)
    n_epochs = len(model.tree_grid_families)
    if n_epochs == 0:
        raise ValueError("Model has no stages (tree_grid_families is empty)")

    # Sample observations without replacement (never more than available rows)
    n_samples = X.shape[0]
    n_observations = min(n_observations, n_samples)

    # Re-seeding the global RNG keeps the sampled observations identical across
    # runs. NOTE: this also resets the global NumPy RNG state for later callers.
    np.random.seed(42)
    selected_indices = np.random.choice(n_samples, size=n_observations, replace=False)
    selected_observations = X[selected_indices]

    # Layout: n_epochs rows, n_features columns
    fig, axes = plt.subplots(
        n_epochs, n_features, figsize=(7 * n_features, 4 * n_epochs), squeeze=False
    )

    # Styling shared by every curve / fill (hoisted out of the loops)
    curve_linewidth = 0.3
    curve_alpha = 0.4
    fill_alpha = 0.1

    for plot_idx, feat_idx in enumerate(feature_indices):
        feat_name = FEATURE_NAMES[feat_idx]
        x_range = np.linspace(X[:, feat_idx].min(), X[:, feat_idx].max(), grid_points)

        # Indexing below treats ice_values as (observation, grid point, component),
        # with components ordered [f+_stage0, f-_stage0, f+_stage1, ...]
        # -- presumably the layout returned by compute_ice_curves; confirm there.
        ice_values = model.compute_ice_curves(
            selected_observations, feature_index=feat_idx, x_range=x_range, data_x=X
        )

        for epoch_idx in range(n_epochs):
            ax = axes[epoch_idx, plot_idx]

            f_plus_curves = ice_values[:, :, 2 * epoch_idx]
            f_minus_curves_flipped = -ice_values[:, :, 2 * epoch_idx + 1]

            for f_plus_obs, f_minus_obs in zip(f_plus_curves, f_minus_curves_flipped):
                ax.plot(
                    x_range,
                    f_plus_obs,
                    linewidth=curve_linewidth,
                    color="darkred",
                    alpha=curve_alpha,
                )
                ax.plot(
                    x_range,
                    f_minus_obs,
                    linewidth=curve_linewidth,
                    color="darkblue",
                    alpha=curve_alpha,
                )

                # Shade the gap: green where f+ dominates, orange otherwise
                diff = f_plus_obs - f_minus_obs
                ax.fill_between(
                    x_range,
                    f_minus_obs,
                    f_plus_obs,
                    where=(diff >= 0),
                    color="green",
                    alpha=fill_alpha,
                    linewidth=0,
                )
                ax.fill_between(
                    x_range,
                    f_minus_obs,
                    f_plus_obs,
                    where=(diff < 0),
                    color="orange",
                    alpha=fill_alpha,
                    linewidth=0,
                )

            ax.axhline(y=0, color="black", linestyle="--", linewidth=0.5, alpha=0.5)
            ax.set_xlabel(feat_name, fontsize=11)
            if plot_idx == 0:
                ax.set_ylabel(f"ICE (Stage {epoch_idx + 1})", fontsize=11)

            ax.set_title(
                f"{feat_name}\nStage {epoch_idx + 1}: ICE Curves (n={n_observations})",
                fontsize=11,
                fontweight="bold",
            )
            ax.grid(True, alpha=0.3)

    fig.tight_layout(rect=[0, 0, 1, 0.99])

    if save_path is None:
        save_path = FIGURES_DIR / "figure_3_3_ice_curves.svg"
    # Honour the requested extension instead of always writing SVG bytes
    suffix = Path(save_path).suffix
    file_format = suffix[1:].lower() if suffix else "svg"
    fig.savefig(save_path, format=file_format, bbox_inches="tight")
    print(f"Saved Figure 3.3 to {save_path}")

    return fig


def _plot_workingday_pd_pair(ax, hour_vals, pd_workingday0, pd_workingday1):
    """Draw the workingday=0 (blue) / workingday=1 (red) PD curves on ``ax`` and return both line handles."""
    (line0,) = ax.plot(
        hour_vals,
        pd_workingday0,
        "o-",
        linewidth=2,
        markersize=3,
        alpha=0.7,
        color="blue",
        label="workingday = 0",
    )
    (line1,) = ax.plot(
        hour_vals,
        pd_workingday1,
        "o-",
        linewidth=2,
        markersize=3,
        alpha=0.7,
        color="red",
        label="workingday = 1",
    )
    ax.axhline(y=0, color="k", linestyle="--", alpha=0.3)
    ax.set_xlabel("hour", fontsize=12)
    ax.grid(True, alpha=0.3)
    return line0, line1


def plot_figure_3_3_2d_pd_hour_workingday(
    model: MPF,
    X: np.ndarray,
    X_df: pd.DataFrame,
    grid_points: int = 50,
    save_path: Optional[str] = None,
) -> plt.Figure:
    """
    Figure 3.3: 2D Partial Dependence Plot for hour × workingday interaction.

    Shows the partial dependence of hour conditioned on workingday (0 and 1).
    There are three panels:
    stage 1, stage 2 (hidden if the model has a single stage), and the sum over
    all stages. Each panel has one curve for workingday=0 and one for
    workingday=1.

    The per-stage value is the sum of the two PD components for that stage
    (columns ``2*ℓ`` and ``2*ℓ + 1`` of the array returned by
    ``model.compute_partial_dependence_function``). That call fixes hour and
    workingday and marginalizes over the remaining features of ``X``.

    Parameters:
    -----------
    model : MPF
        Trained MPF model
    X : np.ndarray
        Feature matrix
    X_df : pd.DataFrame
        Feature DataFrame with column names (unused; kept for API symmetry)
    grid_points : int
        Number of grid points for hour (x-axis)
    save_path : str, optional
        Path to save figure. The file format is taken from the extension
        (defaults to SVG when there is none).

    Returns:
    --------
    fig : matplotlib.figure.Figure
        Figure object

    Raises:
    -------
    ValueError
        If 'hour' or 'workingday' is missing from FEATURE_NAMES, or the model
        has no stages.
    """
    if "hour" not in FEATURE_NAMES or "workingday" not in FEATURE_NAMES:
        raise ValueError(
            "Could not find 'hour' or 'workingday' features in FEATURE_NAMES"
        )
    hour_idx = FEATURE_NAMES.index("hour")
    workingday_idx = FEATURE_NAMES.index("workingday")

    n_epochs = len(model.tree_grid_families)
    if n_epochs == 0:
        raise ValueError("Model has no stages (tree_grid_families is empty)")

    hour_vals = np.linspace(X[:, hour_idx].min(), X[:, hour_idx].max(), grid_points)

    # One PD evaluation per workingday value already yields every stage's
    # columns ([f+_0, f-_0, f+_1, f-_1, ...]), so there is no need to
    # recompute it per stage.
    pd_per_stage = {}
    for workingday_val in (0, 1):
        # Each row of fixed_values is [hour_val, workingday_val]
        fixed_values = np.column_stack(
            [hour_vals, np.full(grid_points, workingday_val)]
        )
        _, pd_values_all = model.compute_partial_dependence_function(
            fixed_indices=[hour_idx, workingday_idx],
            fixed_values=fixed_values,
            data_x=X,
        )
        pd_per_stage[workingday_val] = [
            pd_values_all[:, 2 * stage] + pd_values_all[:, 2 * stage + 1]
            for stage in range(n_epochs)
        ]

    # Layout: 1 row, 3 panels side by side
    fig, axes = plt.subplots(1, 3, figsize=(15, 5), squeeze=False)
    axes = axes.flatten()

    # Panel 1: first stage (its handles drive the shared legend)
    handles = _plot_workingday_pd_pair(
        axes[0], hour_vals, pd_per_stage[0][0], pd_per_stage[1][0]
    )
    labels = ["workingday = 0", "workingday = 1"]
    axes[0].set_ylabel("Partial Dependence f_hour,workingday", fontsize=11)
    axes[0].set_title("Stage 1", fontsize=14)

    # Panel 2: second stage, if the model has one
    if n_epochs >= 2:
        _plot_workingday_pd_pair(
            axes[1], hour_vals, pd_per_stage[0][1], pd_per_stage[1][1]
        )
        axes[1].set_title("Stage 2", fontsize=14)
    else:
        axes[1].axis("off")

    # Panel 3: sum over ALL stages; the title lists exactly which are included
    _plot_workingday_pd_pair(
        axes[2],
        hour_vals,
        np.sum(pd_per_stage[0], axis=0),
        np.sum(pd_per_stage[1], axis=0),
    )
    sum_title = "Stage " + " + ".join(str(stage + 1) for stage in range(n_epochs))
    axes[2].set_title(sum_title, fontsize=14)

    # Shared legend at the bottom, close to the panels
    fig.legend(
        handles,
        labels,
        loc="lower center",
        bbox_to_anchor=(0.5, -0.01),
        ncol=2,
        fontsize=12,
    )

    # Leave room at the bottom for the legend
    fig.tight_layout(rect=[0, 0.05, 1, 1])

    if save_path is None:
        save_path = FIGURES_DIR / "figure_3_3_2d_pd_hour_workingday.svg"
    # Detect format from file extension
    suffix = Path(save_path).suffix
    file_format = suffix[1:].lower() if suffix else "svg"
    fig.savefig(save_path, format=file_format, bbox_inches="tight")
    print(f"Saved Figure 3.3 (2D PD) to {save_path}")

    return fig


def compute_local_explanation(model: MPF, x: np.ndarray) -> Dict[str, Any]:
    """
    Compute local explanation for a single point x.

    For each stage ℓ, the combined tree grid of ``model.tree_grid_families[ℓ]``
    is evaluated at ``x``. Each feature value is mapped to a bin using that
    feature's split points (``searchsorted`` with ``side="right"``, clamped to
    the last bin). The bin's backbone and tilt values are then combined as

        b(x)   = prod_j backbone_j(x)
        d(x)   = sum_j  tilt_j(x)
        f_+(x) = lambda_plus  * b(x) * exp(+d(x))
        f_-(x) = lambda_minus * b(x) * exp(-d(x))
        s(x)   = scaling_plus * f_+(x) - scaling_minus * f_-(x)

    A scaling factor that is missing or None falls back to its default
    (scaling_plus=1.0, scaling_minus=0.0). An explicit 0.0 is respected.

    Parameters:
    -----------
    model : MPF
        Trained MPF model
    x : np.ndarray
        Single data point (1D array of length n_features)

    Returns:
    --------
    explanation : dict
        Dictionary containing:
        - stage_contributions: array of s_ℓ(x) for each stage
        - f_plus_contributions: array of scaling_plus * f_plus for each stage
        - f_minus_contributions: array of -scaling_minus * f_minus for each stage
        - backbone_magnitudes: array of b^(ℓ)(x) for each stage
        - tilt_sums: array of d^(ℓ)(x) for each stage
        - feature_breakdown: dict mapping stage_idx to feature-level breakdown
        - total_prediction: sum of stage contributions
    """
    tree_grid_families = model.tree_grid_families
    n_stages = len(tree_grid_families)
    n_features = len(x)

    stage_contributions = np.zeros(n_stages)
    f_plus_contributions = np.zeros(n_stages)
    f_minus_contributions = np.zeros(n_stages)
    backbone_magnitudes = np.zeros(n_stages)
    tilt_sums = np.zeros(n_stages)
    feature_breakdown = {}

    for stage_idx, tgf in enumerate(tree_grid_families):
        tg = TreeGrid(tgf.combined_tree_grid)

        lambda_plus = tg.lambda_plus
        lambda_minus = tg.lambda_minus

        # Fall back to defaults only when the attribute is missing/None.
        # (`value or default` would also replace a legitimate 0.0.)
        scaling_plus = getattr(tgf, "scaling_plus", None)
        if scaling_plus is None:
            scaling_plus = 1.0
        scaling_minus = getattr(tgf, "scaling_minus", None)
        if scaling_minus is None:
            scaling_minus = 0.0

        backbone_per_feature = np.zeros(n_features)
        tilt_per_feature = np.zeros(n_features)

        for feat_idx in range(n_features):
            backbone_vals = tg.backbone_values[feat_idx]
            tilt_vals = tg.tilt_values[feat_idx]
            splits = tg.splits[feat_idx]

            # Locate the bin containing x[feat_idx]; a feature without
            # splits has a single bin. Clamp so values beyond the last
            # split map to the last bin.
            if len(splits) == 0:
                bin_idx = 0
            else:
                bin_idx = np.searchsorted(splits, x[feat_idx], side="right")
                bin_idx = min(bin_idx, len(backbone_vals) - 1)

            backbone_per_feature[feat_idx] = backbone_vals[bin_idx]
            tilt_per_feature[feat_idx] = tilt_vals[bin_idx]

        backbone_magnitude = np.prod(backbone_per_feature)
        tilt_sum = np.sum(tilt_per_feature)
        backbone_magnitudes[stage_idx] = backbone_magnitude
        tilt_sums[stage_idx] = tilt_sum

        # Unscaled f_plus / f_minus
        f_plus = lambda_plus * backbone_magnitude * np.exp(tilt_sum)
        f_minus = lambda_minus * backbone_magnitude * np.exp(-tilt_sum)

        # Scaled components; f_minus enters with a negative sign
        f_plus_contrib = scaling_plus * f_plus
        f_minus_contrib = -scaling_minus * f_minus

        f_plus_contributions[stage_idx] = f_plus_contrib
        f_minus_contributions[stage_idx] = f_minus_contrib
        stage_contributions[stage_idx] = f_plus_contrib + f_minus_contrib

        feature_breakdown[stage_idx] = {
            "backbone_per_feature": backbone_per_feature.copy(),
            "tilt_per_feature": tilt_per_feature.copy(),
            "lambda_plus": lambda_plus,
            "lambda_minus": lambda_minus,
            "scaling_plus": scaling_plus,
            "scaling_minus": scaling_minus,
        }

    return {
        "stage_contributions": stage_contributions,
        "f_plus_contributions": f_plus_contributions,
        "f_minus_contributions": f_minus_contributions,
        "backbone_magnitudes": backbone_magnitudes,
        "tilt_sums": tilt_sums,
        "feature_breakdown": feature_breakdown,
        "total_prediction": np.sum(stage_contributions),
    }


def _print_ranked_importance(values, feature_names, indent, limit=None):
    """Print features ranked by descending ``values`` (optionally the top ``limit``), one per line."""
    order = np.argsort(values)[::-1]
    if limit is not None:
        order = order[:limit]
    for rank, feat_idx in enumerate(order, 1):
        print(f"{indent}{rank}. {feature_names[feat_idx]:15s}: {values[feat_idx]:.6f}")


def _plot_stage_feature_heatmap(ax, per_stage, feature_names, cmap, title, cbar_label):
    """Draw a (feature × stage) heatmap of the (stage × feature) matrix ``per_stage`` with a colorbar."""
    n_stages, n_features = per_stage.shape
    image = ax.imshow(per_stage.T, aspect="auto", cmap=cmap, interpolation="nearest")
    ax.set_xlabel("Stage")
    ax.set_ylabel("Feature")
    ax.set_title(title)
    ax.set_xticks(range(n_stages))
    ax.set_xticklabels([f"Stage {i + 1}" for i in range(n_stages)])
    ax.set_yticks(range(n_features))
    ax.set_yticklabels(feature_names, fontsize=8)
    ax.figure.colorbar(image, ax=ax, label=cbar_label)


def _plot_sorted_importance_barh(ax, values, feature_names, color, xlabel, title):
    """Draw a horizontal bar chart of ``values`` sorted in descending order, labelled by feature."""
    order = np.argsort(values)[::-1]
    positions = range(len(values))
    ax.barh(positions, values[order], color=color, alpha=0.7)
    ax.set_yticks(positions)
    ax.set_yticklabels([feature_names[i] for i in order], fontsize=8)
    ax.set_xlabel(xlabel)
    ax.set_title(title)
    ax.grid(True, alpha=0.3, axis="x")


def compute_and_plot_feature_importance(
    model: MPF,
    X: np.ndarray,
    feature_names: List[str],
    gamma: float = 1.0,
    figsize: Tuple[int, int] = (14, 10),
    save_path: Optional[str] = None,
) -> Dict[str, Any]:
    """
    Figure 6: Compute and visualize feature importance metrics for an MPF model.

    Prints per-stage (top 3), aggregated, and combined importance rankings.
    Then saves a 2×3 figure with:
    per-stage backbone and tilt heatmaps, global backbone, tilt and combined
    importance bar charts, and the stage weights.

    Parameters:
    -----------
    model : MPF
        Trained MPF model
    X : np.ndarray
        Feature matrix (training data)
    feature_names : list of str
        Names of features for labeling
    gamma : float, default=1.0
        Weight for tilt importance in combined score
    figsize : tuple, default=(14, 10)
        Figure size (width, height)
    save_path : str, optional
        Path to save figure. The file format is taken from the extension
        (defaults to SVG when there is none). If None, saves to
        ``FIGURES_DIR / "figure_6_feature_importance.svg"``.

    Returns:
    --------
    dict
        Dictionary containing all computed importance metrics:
        backbone_per_stage, tilt_per_stage, global_backbone, global_tilt,
        combined, combined_backbone, combined_tilt, stage_weights.
    """
    print("\n" + "=" * 80)
    print("FEATURE IMPORTANCE ANALYSIS")
    print("=" * 80)

    # Per-stage, aggregated, and combined importance from the model
    backbone_per_stage, tilt_per_stage = model.compute_per_stage_feature_importance(X)
    n_stages, n_features = backbone_per_stage.shape

    global_backbone, global_tilt, stage_weights = (
        model.compute_aggregated_feature_importance(X)
    )

    combined, combined_backbone, combined_tilt = (
        model.compute_combined_feature_importance(X, gamma=gamma)
    )

    # Summary statistics
    print(f"\nNumber of stages: {n_stages}")
    print(f"Number of features: {n_features}")
    print(f"\nStage weights: {stage_weights}")
    print(f"Stage weights sum: {stage_weights.sum():.6f}")

    # Per-stage importance summary
    print("\n" + "-" * 80)
    print("PER-STAGE FEATURE IMPORTANCE (Top 3 features per stage)")
    print("-" * 80)
    for stage_idx in range(n_stages):
        print(f"\nStage {stage_idx + 1} (weight: {stage_weights[stage_idx]:.4f}):")
        print("  Backbone importance (top 3):")
        _print_ranked_importance(
            backbone_per_stage[stage_idx, :], feature_names, "    ", limit=3
        )
        print("  Tilt importance (top 3):")
        _print_ranked_importance(
            tilt_per_stage[stage_idx, :], feature_names, "    ", limit=3
        )

    # Aggregated importance
    print("\n" + "-" * 80)
    print("AGGREGATED FEATURE IMPORTANCE (Global)")
    print("-" * 80)
    print("\nBackbone importance (global):")
    _print_ranked_importance(global_backbone, feature_names, "  ")

    print("\nTilt importance (global):")
    _print_ranked_importance(global_tilt, feature_names, "  ")

    # Combined importance (with its backbone/tilt components)
    print("\n" + "-" * 80)
    print(f"COMBINED FEATURE IMPORTANCE (I_j = I_j^b + {gamma} * I_j^d)")
    print("-" * 80)
    for rank, feat_idx in enumerate(np.argsort(combined)[::-1], 1):
        print(
            f"  {rank}. {feature_names[feat_idx]:15s}: {combined[feat_idx]:.6f} "
            f"(backbone: {combined_backbone[feat_idx]:.6f}, "
            f"tilt: {combined_tilt[feat_idx]:.6f})"
        )

    # Visualization (2 × 3 grid)
    fig = plt.figure(figsize=figsize)

    _plot_stage_feature_heatmap(
        fig.add_subplot(2, 3, 1),
        backbone_per_stage,
        feature_names,
        cmap="YlOrRd",
        title="Per-Stage Backbone Importance",
        cbar_label="Backbone Variance",
    )
    _plot_stage_feature_heatmap(
        fig.add_subplot(2, 3, 2),
        tilt_per_stage,
        feature_names,
        cmap="YlGnBu",
        title="Per-Stage Tilt Importance",
        cbar_label="Tilt Variance",
    )
    _plot_sorted_importance_barh(
        fig.add_subplot(2, 3, 3),
        global_backbone,
        feature_names,
        color="orange",
        xlabel="Global Backbone Importance",
        title="Aggregated Backbone Importance",
    )
    _plot_sorted_importance_barh(
        fig.add_subplot(2, 3, 4),
        global_tilt,
        feature_names,
        color="cyan",
        xlabel="Global Tilt Importance",
        title="Aggregated Tilt Importance",
    )
    _plot_sorted_importance_barh(
        fig.add_subplot(2, 3, 5),
        combined,
        feature_names,
        color="purple",
        xlabel=f"Combined Importance (γ={gamma})",
        title="Combined Feature Importance",
    )

    # Stage weights
    ax6 = fig.add_subplot(2, 3, 6)
    ax6.bar(range(n_stages), stage_weights, color="green", alpha=0.7)
    ax6.set_xlabel("Stage")
    ax6.set_ylabel("Weight")
    ax6.set_title("Stage Weights (Energy-based)")
    ax6.set_xticks(range(n_stages))
    ax6.set_xticklabels([f"Stage {i + 1}" for i in range(n_stages)])
    ax6.grid(True, alpha=0.3, axis="y")

    fig.tight_layout(rect=[0, 0, 1, 0.99])

    if save_path is None:
        save_path = FIGURES_DIR / "figure_6_feature_importance.svg"
    # Honour the requested extension instead of always writing SVG bytes
    suffix = Path(save_path).suffix
    file_format = suffix[1:].lower() if suffix else "svg"
    fig.savefig(save_path, format=file_format, bbox_inches="tight")
    print(f"Saved Figure 6 to {save_path}")

    return {
        "backbone_per_stage": backbone_per_stage,
        "tilt_per_stage": tilt_per_stage,
        "global_backbone": global_backbone,
        "global_tilt": global_tilt,
        "combined": combined,
        "combined_backbone": combined_backbone,
        "combined_tilt": combined_tilt,
        "stage_weights": stage_weights,
    }


def _fig5_default_point(
    X: np.ndarray, workingday: int, hour: int, fallback_row: int
) -> np.ndarray:
    """
    Return the first row of X whose workingday and hour equal the given values.

    If no row matches, row ``fallback_row`` is returned instead, clamped to the
    last row so that a single-row X does not raise IndexError.
    """
    hour_idx = FEATURE_NAMES.index("hour")
    workingday_idx = FEATURE_NAMES.index("workingday")
    mask = (X[:, workingday_idx] == workingday) & (X[:, hour_idx] == hour)
    matches = X[mask]
    if len(matches) > 0:
        return matches[0]
    return X[min(fallback_row, len(X) - 1)]


def _fig5_backbone_shares(backbone_per_feature) -> Tuple[List[int], List[float]]:
    """
    Split a stage's bar among features by their share of |log(backbone)|.

    A feature's weight is |log(bb)|, so multipliers far from 1 in either
    direction (e.g. 0.01 or 100) get a large share of the bar. Features with
    bb <= 1e-15 (log undefined / numerically unstable) or |log(bb)| <= 1e-4
    (neutral multiplier, roughly 0.9999 < bb < 1.0001) are skipped.

    Returns:
    --------
    (feature_indices, shares) : Tuple[List[int], List[float]]
        Sorted by descending share; shares sum to 1 (up to rounding). Both
        lists are empty when no feature qualifies.
    """
    weighted = []
    for feat_idx, bb_val in enumerate(backbone_per_feature):
        if bb_val > 1e-15:
            log_val = np.abs(np.log(bb_val))
            if log_val > 1e-4:
                weighted.append((feat_idx, log_val))

    if not weighted:
        return [], []

    total_log = sum(log_val for _, log_val in weighted)
    # list.sort is stable, so equal weights keep their original feature order.
    weighted.sort(key=lambda pair: pair[1], reverse=True)
    return (
        [feat_idx for feat_idx, _ in weighted],
        [log_val / total_log for _, log_val in weighted],
    )


def _draw_fig5_component(
    ax: plt.Axes,
    y_pos: float,
    width: float,
    left: float,
    color: str,
    bar_height: float,
    feature_indices: List[int],
    shares: List[float],
    decompose: bool,
) -> None:
    """
    Draw one signed f+ or f- bar for a stage row of Figure 5.

    The bar starts at ``left`` and spans ``width``. A positive width extends
    right (f+); a negative width extends left (f-). Nothing is drawn when
    |width| <= 1e-10.

    When ``decompose`` is true and ``shares`` is non-empty, the bar is split
    into stacked, white-edged segments, one per feature. The largest share is
    drawn darkest and nearest ``left``. Each segment wider than 3% of the bar
    is labelled with its feature name. Otherwise a single black-edged bar is
    drawn.
    """
    if not abs(width) > 1e-10:
        return

    if not (decompose and shares):
        ax.barh(
            y_pos,
            width,
            bar_height,
            left=left,
            color=color,
            alpha=0.7,
            edgecolor="black",
            linewidth=1.0,
            zorder=2,
        )
        return

    n_segments = len(shares)
    offset = 0.0
    for seg_idx, (feat_idx, share) in enumerate(zip(feature_indices, shares)):
        segment_width = width * share

        # Opacity fades from 0.85 (innermost segment) to 0.45 (outermost).
        alpha = 0.85 - (seg_idx / max(n_segments - 1, 1)) * 0.4

        ax.barh(
            y_pos,
            segment_width,
            bar_height,
            left=left + offset,
            color=color,
            alpha=alpha,
            edgecolor="white",
            linewidth=0.5,
            zorder=2,
        )

        # Only label segments wider than 3% of the whole bar so the text fits.
        # abs() on both sides makes the test direction-agnostic (f+ and f-).
        if abs(segment_width) > 0.03 * abs(width):
            ax.text(
                left + offset + segment_width / 2,
                y_pos,
                FEATURE_NAMES[feat_idx].split()[0][:8],  # First word, max 8 chars
                ha="center",
                va="center",
                fontsize=6,
                color="white",
                weight="bold",
                zorder=3,
            )

        offset += segment_width


def plot_figure_5_local_explanations(
    model: MPF,
    X: np.ndarray,
    X_df: pd.DataFrame,
    point_a: Optional[np.ndarray] = None,
    point_b: Optional[np.ndarray] = None,
    save_path: Optional[str] = None,
    show_feature_decomposition: bool = True,
    top_k_features: int = 3,
) -> plt.Figure:
    """
    Figure 5: Local explanations for two contrasting points with split f+/f- bars.

    Stages are listed top to bottom by decreasing |net contribution|. They form
    a waterfall that starts at 0, and each stage's black line sits at the
    running total after that stage. Each stage row shows:
    - Blue bar (f+): the stage's f+ contribution, drawn rightward from the line
    - Red bar (f-): the stage's f- contribution, drawn leftward from the line
    - Black vertical line at the running total, labelled with the stage's net contribution
    - Optional: feature-level decomposition within each bar using log-scale percentages

    Default points:
    - Point A: Weekday morning commute (first row with workingday=1, hour=8;
      falls back to the first row of X)
    - Point B: Weekend afternoon (first row with workingday=0, hour=15;
      falls back to the second row of X, or the last row if X has one row)

    Parameters:
    -----------
    model : MPF
        Trained MPF model
    X : np.ndarray
        Feature matrix (columns ordered as FEATURE_NAMES)
    X_df : pd.DataFrame
        Feature DataFrame with column names (not used by this function; kept
        for a signature consistent with the other figure functions)
    point_a : np.ndarray, optional
        First point (weekday morning). If None, selects automatically.
    point_b : np.ndarray, optional
        Second point (weekend afternoon). If None, selects automatically.
    save_path : str, optional
        Path to save figure. Defaults to FIGURES_DIR / "figure_5_local_explanations.svg".
    show_feature_decomposition : bool, optional
        If True, show stacked feature contributions within each f+/f- bar
    top_k_features : int, optional
        Number of top features to list at the right edge of each stage row

    Returns:
    --------
    fig : matplotlib.figure.Figure
        Figure object
    """
    from matplotlib.patches import Patch  # legend proxy artists

    if point_a is None:
        point_a = _fig5_default_point(X, workingday=1, hour=8, fallback_row=0)
    if point_b is None:
        point_b = _fig5_default_point(X, workingday=0, hour=15, fallback_row=1)

    expl_a = compute_local_explanation(model, point_a)
    expl_b = compute_local_explanation(model, point_b)

    hour_idx = FEATURE_NAMES.index("hour")
    workingday_idx = FEATURE_NAMES.index("workingday")
    temp_idx = FEATURE_NAMES.index("temp")

    bar_height = 0.7
    color_f_plus = "#1f77b4"  # Blue
    color_f_minus = "#d62728"  # Red

    fig, axes = plt.subplots(1, 2, figsize=(20, 8))

    panels = (
        (point_a, expl_a, "Weekday Morning Commute"),
        (point_b, expl_b, "Weekend Afternoon"),
    )
    for ax, (point, expl, title) in zip(axes, panels):
        stage_contribs = expl["stage_contributions"]
        f_plus_contribs = expl["f_plus_contributions"]
        f_minus_contribs = expl["f_minus_contributions"]
        total_prediction = expl["total_prediction"]
        n_stages = len(stage_contribs)

        # Order stages by |net contribution|, largest first (top row once the
        # y-axis is inverted below).
        sorted_indices = np.argsort(np.abs(stage_contribs))[::-1]
        sorted_net_contribs = stage_contribs[sorted_indices]
        sorted_f_plus = f_plus_contribs[sorted_indices]
        sorted_f_minus = f_minus_contribs[sorted_indices]
        stage_labels = [f"Stage {idx + 1}" for idx in sorted_indices]

        # Waterfall running totals. cumulative[0] is the base value (0.0) and
        # cumulative[i + 1] is the total after the i-th sorted stage. The
        # float64 accumulator matches a sequential float64 sum.
        cumulative = np.zeros(n_stages + 1)
        cumulative[1:] = np.cumsum(sorted_net_contribs, dtype=np.float64)

        bar_positions = np.arange(n_stages)

        for i, stage_idx in enumerate(sorted_indices):
            y_pos = bar_positions[i]
            net_x = cumulative[i + 1]  # x of this stage's black net line

            feature_order, shares = _fig5_backbone_shares(
                expl["feature_breakdown"][stage_idx]["backbone_per_feature"]
            )

            # f+ extends right and f- (negative) extends left from net_x.
            for width, color in (
                (sorted_f_plus[i], color_f_plus),
                (sorted_f_minus[i], color_f_minus),
            ):
                _draw_fig5_component(
                    ax,
                    y_pos,
                    width,
                    net_x,
                    color,
                    bar_height,
                    feature_order,
                    shares,
                    show_feature_decomposition,
                )

            # Vertical black line at the running total for this stage.
            ax.plot(
                [net_x, net_x],
                [y_pos - bar_height / 2, y_pos + bar_height / 2],
                color="black",
                linewidth=2.5,
                alpha=0.9,
                zorder=4,
            )

            # Label the stage's net contribution next to its black line.
            ax.text(
                net_x,
                y_pos + bar_height / 2 + 0.15,
                f"{sorted_net_contribs[i]:+.1f}",
                ha="center",
                va="bottom",
                fontsize=8,
                fontweight="bold",
                color="black",
                bbox=dict(
                    boxstyle="round,pad=0.2",
                    facecolor="white",
                    alpha=0.9,
                    edgecolor="black",
                    linewidth=0.8,
                ),
                zorder=5,
            )

            # List the top-k feature shares at the right edge of the row
            # (x in axes coordinates, y in data coordinates).
            if show_feature_decomposition and shares and top_k_features > 0:
                top_features_str = ", ".join(
                    f"{FEATURE_NAMES[fidx]}: {pct * 100:.0f}%"
                    for fidx, pct in zip(
                        feature_order[:top_k_features], shares[:top_k_features]
                    )
                )
                ax.text(
                    0.98,
                    y_pos,
                    top_features_str,
                    ha="right",
                    va="center",
                    fontsize=7,
                    color="gray",
                    transform=ax.get_yaxis_transform(),
                    zorder=1,
                )

        # Dashed waterfall connectors bridge the gap between consecutive rows
        # at the running total where the next stage starts. The y-axis is
        # inverted, so a row's lower screen edge is y + h/2 and the next row's
        # upper screen edge is y_next - h/2.
        for i in range(n_stages - 1):
            x_val = cumulative[i + 1]
            ax.plot(
                [x_val, x_val],
                [
                    bar_positions[i] + bar_height / 2,
                    bar_positions[i + 1] - bar_height / 2,
                ],
                "k--",
                linewidth=1.0,
                alpha=0.4,
                zorder=1,
            )

        # Mark the final running total on the last row. The arrow points in
        # the direction of that row's net contribution.
        final_x = cumulative[-1]
        last_bar_y = bar_positions[-1]
        points_right = sorted_net_contribs[-1] >= 0
        ax.plot(
            final_x,
            last_bar_y,
            marker=">" if points_right else "<",
            markersize=15,
            color="#2ca02c",
            markeredgecolor="black",
            markeredgewidth=2.0,
            zorder=10,
        )

        # Label the marker with the total prediction, offset on the arrow side
        # by 5% of |total_prediction|.
        text_x = final_x + (0.05 if points_right else -0.05) * abs(total_prediction)
        ax.text(
            text_x,
            last_bar_y,
            f"{total_prediction:.1f}",
            ha="left" if points_right else "right",
            va="center",
            fontsize=11,
            fontweight="bold",
            color="#2ca02c",
            bbox=dict(
                boxstyle="round,pad=0.4",
                facecolor="white",
                edgecolor="#2ca02c",
                linewidth=2.0,
            ),
            zorder=10,
        )

        ax.set_yticks(bar_positions)
        ax.set_yticklabels(stage_labels, fontsize=10)
        ax.set_xlabel("Contribution (bikes)", fontsize=12, fontweight="bold")
        ax.set_title(
            f"{title}\nTotal Prediction: {total_prediction:.1f} bikes",
            fontsize=14,
            fontweight="bold",
        )
        ax.grid(True, alpha=0.2, axis="x", zorder=0)
        ax.axvline(
            x=0, color="black", linestyle="-", linewidth=1.0, alpha=0.5, zorder=1
        )
        ax.invert_yaxis()

        legend_elements = [
            Patch(
                facecolor=color_f_plus,
                alpha=0.7,
                edgecolor="black",
                label="f+ (positive component)",
            ),
            Patch(
                facecolor=color_f_minus,
                alpha=0.7,
                edgecolor="black",
                label="f- (negative component)",
            ),
            Patch(
                facecolor="white",
                edgecolor="black",
                linewidth=2,
                label="Net contribution",
            ),
        ]
        ax.legend(
            handles=legend_elements, loc="lower right", fontsize=9, framealpha=0.9
        )

        # Key feature values of the explained point (lower-left corner).
        info_text = (
            f"Hour: {int(point[hour_idx])}\n"
            f"Workingday: {int(point[workingday_idx])}\n"
            f"Temp: {point[temp_idx]:.2f}"
        )
        ax.text(
            0.05,
            0.05,
            info_text,
            transform=ax.transAxes,
            fontsize=10,
            verticalalignment="bottom",
            horizontalalignment="left",
            bbox=dict(boxstyle="round", facecolor="wheat", alpha=0.7),
        )

    plt.tight_layout(rect=[0, 0, 1, 1])

    if save_path is None:
        save_path = FIGURES_DIR / "figure_5_local_explanations.svg"
    plt.savefig(save_path, format="svg", bbox_inches="tight")
    print(f"Saved Figure 5 to {save_path}")

    return fig


def plot_figure_5_1_map_annotations(
    model: MPF,
    X: np.ndarray,
    X_df: pd.DataFrame,
    point_a: Optional[np.ndarray] = None,
    point_b: Optional[np.ndarray] = None,
    grid_points: int = 50,
    cmap: str = "viridis",
    save_path: Optional[str] = None,
) -> plt.Figure:
    """
    Figure 5.1: 2D partial dependence plot (hour × workingday) with annotated points.

    Shows the 2D partial dependence f_hour,workingday(hour, workingday) as one
    line per workingday value (0 and 1) over an hour grid, shading the gap
    between the two lines. The two local explanation points (weekday morning
    and weekend afternoon) are marked at the partial dependence evaluated at
    their exact (hour, workingday) coordinates. Each marker is labelled with
    the model's prediction for that point.

    The partial dependence is the sum over stages of
    (f+ + f-) * sqrt(C+ * (-C-)), using the per-stage constants returned by
    ``model.compute_partial_dependence_function``.

    Parameters:
    -----------
    model : MPF
        Trained MPF model
    X : np.ndarray
        Feature matrix (columns ordered as FEATURE_NAMES)
    X_df : pd.DataFrame
        Feature DataFrame with column names (not used by this function)
    point_a : np.ndarray, optional
        First point (weekday morning). If None, selects the first row with
        workingday=1 and hour=8 (falls back to the first row of X).
    point_b : np.ndarray, optional
        Second point (weekend afternoon). If None, selects the first row with
        workingday=0 and hour=15 (falls back to the second row of X, or the
        last row if X has a single row).
    grid_points : int
        Number of evenly spaced hour values between the observed min and max
        hour. workingday is always evaluated at 0 and 1.
    cmap : str
        Not used by the current line-plot rendering; kept for signature
        compatibility.
    save_path : str, optional
        Path to save figure. Defaults to FIGURES_DIR / "figure_5_1_map_annotations.svg".

    Returns:
    --------
    fig : matplotlib.figure.Figure
        Figure object
    """
    hour_idx = FEATURE_NAMES.index("hour")
    workingday_idx = FEATURE_NAMES.index("workingday")

    def _default_point(workingday: int, hour: int, fallback_row: int) -> np.ndarray:
        # First matching row; otherwise a fixed row clamped to X's last row.
        mask = (X[:, workingday_idx] == workingday) & (X[:, hour_idx] == hour)
        matches = X[mask]
        if len(matches) > 0:
            return matches[0]
        return X[min(fallback_row, len(X) - 1)]

    if point_a is None:
        point_a = _default_point(workingday=1, hour=8, fallback_row=0)
    if point_b is None:
        point_b = _default_point(workingday=0, hour=15, fallback_row=1)

    workingday_vals = [0, 1]
    hour_vals = np.linspace(X[:, hour_idx].min(), X[:, hour_idx].max(), grid_points)
    n_grid = len(workingday_vals) * len(hour_vals)

    hour_a, workingday_a = point_a[hour_idx], point_a[workingday_idx]
    hour_b, workingday_b = point_b[hour_idx], point_b[workingday_idx]

    # PD inputs are rows of [hour, workingday]: every grid hour for
    # workingday=0, then every grid hour for workingday=1. The two annotated
    # points are appended as the last two rows, so their PD is evaluated at
    # their exact coordinates instead of at the nearest grid hour.
    grid_rows = np.column_stack(
        [
            np.tile(hour_vals, len(workingday_vals)),
            np.repeat(workingday_vals, len(hour_vals)),
        ]
    )
    point_rows = np.array(
        [[hour_a, workingday_a], [hour_b, workingday_b]], dtype=float
    )
    pd_inputs = np.vstack([grid_rows, point_rows])

    constants_per_epoch, partial_dep = model.compute_partial_dependence_function(
        [hour_idx, workingday_idx], pd_inputs, X
    )
    n_epochs = partial_dep.shape[1] // 2

    # Columns alternate per stage: f+ at even indices, f- at odd indices.
    f_plus_all = partial_dep[:, ::2]
    f_minus_all = partial_dep[:, 1::2]

    # Sum the stage PDs: (f+ + f-) scaled by sqrt(C+ * (-C-)).
    predictions = np.zeros(partial_dep.shape[0])
    for epoch_idx in range(n_epochs):
        c_plus, c_minus = constants_per_epoch[epoch_idx]
        sqrt_c_product = np.sqrt(c_plus * (-c_minus))
        predictions += (
            f_plus_all[:, epoch_idx] + f_minus_all[:, epoch_idx]
        ) * sqrt_c_product

    # Grid part -> (workingday, hour); the two trailing rows are the points.
    predictions_reshaped = predictions[:n_grid].reshape(
        (len(workingday_vals), len(hour_vals))
    )
    pd_value_a = predictions[n_grid]
    pd_value_b = predictions[n_grid + 1]

    fig, ax = plt.subplots(figsize=(12, 6))

    for wd_idx, wd_val in enumerate(workingday_vals):
        label = "Weekend (workingday=0)" if wd_val == 0 else "Weekday (workingday=1)"
        ax.plot(
            hour_vals,
            predictions_reshaped[wd_idx, :],
            linewidth=2.5,
            marker="o",
            markersize=4,
            label=label,
            alpha=0.8,
        )

    # Shade the difference between the weekend and weekday curves.
    ax.fill_between(
        hour_vals,
        predictions_reshaped[0, :],
        predictions_reshaped[1, :],
        alpha=0.3,
        color="gray",
        label="Difference",
    )

    pred_a = model.predict(point_a.reshape(1, -1))[0]
    pred_b = model.predict(point_b.reshape(1, -1))[0]

    for x_val, y_val, pred, color, legend_label, short_name in (
        (hour_a, pd_value_a, pred_a, "red", "Weekday Morning", "Morning"),
        (hour_b, pd_value_b, pred_b, "blue", "Weekend Afternoon", "Afternoon"),
    ):
        ax.scatter(
            x_val,
            y_val,
            c=color,
            s=200,
            marker="*",
            edgecolors="black",
            linewidths=2,
            zorder=10,
            label=legend_label,
        )
        ax.annotate(
            f"{short_name}\nPred: {pred:.0f}",
            xy=(x_val, y_val),
            xytext=(10, 10),
            textcoords="offset points",
            fontsize=11,
            fontweight="bold",
            color=color,
            bbox=dict(boxstyle="round,pad=0.3", facecolor="white", alpha=0.8),
            zorder=11,
            ha="left",
        )

    ax.set_xlabel("Hour of Day", fontsize=12, fontweight="bold")
    ax.set_ylabel("Partial Dependence (bikes)", fontsize=12, fontweight="bold")
    ax.set_title(
        "Partial Dependence f_hour,workingday(hour, workingday) with Local Explanation Points",
        fontsize=14,
        fontweight="bold",
    )
    ax.legend(loc="best", fontsize=10)
    ax.grid(True, alpha=0.3)

    plt.tight_layout()

    if save_path is None:
        save_path = FIGURES_DIR / "figure_5_1_map_annotations.svg"
    plt.savefig(save_path, format="svg", bbox_inches="tight")
    print(f"Saved Figure 5.1 to {save_path}")

    return fig


def plot_figure_5_2_stagewise_partial_dependence(
    model: MPF,
    X: np.ndarray,
    X_df: pd.DataFrame,
    point_a: Optional[np.ndarray] = None,
    point_b: Optional[np.ndarray] = None,
    grid_points: int = 50,
    cmap: str = "viridis",
    save_path: Optional[str] = None,
) -> plt.Figure:
    """
    Figure 5.2: Stage-wise partial dependence plots (hour × workingday) with annotated points.

    Shows one subplot per stage (at most 4 per row). Each subplot draws the
    stage's partial dependence f_hour,workingday(hour, workingday) as one line
    per workingday value (0 and 1), with the gap between them shaded. A stage's
    PD is (f+ + f-) * sqrt(C+ * (-C-)), using the constants returned by
    ``model.compute_partial_dependence_function``. The two local explanation
    points are marked at that stage's PD evaluated at their exact
    (hour, workingday) coordinates.

    Parameters:
    -----------
    model : MPF
        Trained MPF model
    X : np.ndarray
        Feature matrix (columns ordered as FEATURE_NAMES)
    X_df : pd.DataFrame
        Feature DataFrame with column names (not used by this function)
    point_a : np.ndarray, optional
        First point (weekday morning). If None, selects the first row with
        workingday=1 and hour=8 (falls back to the first row of X).
    point_b : np.ndarray, optional
        Second point (weekend afternoon). If None, selects the first row with
        workingday=0 and hour=15 (falls back to the second row of X, or the
        last row if X has a single row).
    grid_points : int
        Number of evenly spaced hour values between the observed min and max
        hour. workingday is always evaluated at 0 and 1.
    cmap : str
        Not used by the current line-plot rendering; kept for signature
        compatibility.
    save_path : str, optional
        Path to save figure. Defaults to
        FIGURES_DIR / "figure_5_2_stagewise_partial_dependence.svg".

    Returns:
    --------
    fig : matplotlib.figure.Figure
        Figure object

    Raises:
    -------
    ValueError
        If the partial dependence result contains no stages.
    """
    hour_idx = FEATURE_NAMES.index("hour")
    workingday_idx = FEATURE_NAMES.index("workingday")

    def _default_point(workingday: int, hour: int, fallback_row: int) -> np.ndarray:
        # First matching row; otherwise a fixed row clamped to X's last row.
        mask = (X[:, workingday_idx] == workingday) & (X[:, hour_idx] == hour)
        matches = X[mask]
        if len(matches) > 0:
            return matches[0]
        return X[min(fallback_row, len(X) - 1)]

    if point_a is None:
        point_a = _default_point(workingday=1, hour=8, fallback_row=0)
    if point_b is None:
        point_b = _default_point(workingday=0, hour=15, fallback_row=1)

    workingday_vals = [0, 1]
    hour_vals = np.linspace(X[:, hour_idx].min(), X[:, hour_idx].max(), grid_points)
    n_grid = len(workingday_vals) * len(hour_vals)

    # Point coordinates do not depend on the stage, so read them once.
    hour_a, workingday_a = point_a[hour_idx], point_a[workingday_idx]
    hour_b, workingday_b = point_b[hour_idx], point_b[workingday_idx]

    # PD inputs are rows of [hour, workingday]: every grid hour for
    # workingday=0, then every grid hour for workingday=1. The two annotated
    # points are appended as the last two rows, so their PD is evaluated at
    # their exact coordinates instead of at the nearest grid hour.
    grid_rows = np.column_stack(
        [
            np.tile(hour_vals, len(workingday_vals)),
            np.repeat(workingday_vals, len(hour_vals)),
        ]
    )
    point_rows = np.array(
        [[hour_a, workingday_a], [hour_b, workingday_b]], dtype=float
    )
    pd_inputs = np.vstack([grid_rows, point_rows])

    constants_per_epoch, partial_dep = model.compute_partial_dependence_function(
        [hour_idx, workingday_idx], pd_inputs, X
    )
    n_epochs = partial_dep.shape[1] // 2
    if n_epochs == 0:
        # Without this guard the subplot layout below divides by zero.
        raise ValueError("Partial dependence returned no stages; nothing to plot.")

    # Columns alternate per stage: f+ at even indices, f- at odd indices.
    f_plus_all = partial_dep[:, ::2]
    f_minus_all = partial_dep[:, 1::2]

    # Grid layout: at most 4 columns, as many rows as needed.
    n_cols = min(4, n_epochs)
    n_rows = (n_epochs + n_cols - 1) // n_cols  # Ceiling division
    fig = plt.figure(figsize=(6 * n_cols, 5 * n_rows))

    for epoch_idx in range(n_epochs):
        ax = fig.add_subplot(n_rows, n_cols, epoch_idx + 1)

        c_plus, c_minus = constants_per_epoch[epoch_idx]
        sqrt_c_product = np.sqrt(c_plus * (-c_minus))
        pd_stage = (
            f_plus_all[:, epoch_idx] + f_minus_all[:, epoch_idx]
        ) * sqrt_c_product

        # Grid part -> (workingday, hour); the two trailing rows are the points.
        pd_stage_reshaped = pd_stage[:n_grid].reshape(
            (len(workingday_vals), len(hour_vals))
        )
        pd_value_a = pd_stage[n_grid]
        pd_value_b = pd_stage[n_grid + 1]

        for wd_idx, wd_val in enumerate(workingday_vals):
            label = (
                "Weekend (workingday=0)" if wd_val == 0 else "Weekday (workingday=1)"
            )
            ax.plot(
                hour_vals,
                pd_stage_reshaped[wd_idx, :],
                linewidth=2,
                marker="o",
                markersize=3,
                label=label,
                alpha=0.8,
            )

        # Shade the difference between the weekend and weekday curves.
        ax.fill_between(
            hour_vals,
            pd_stage_reshaped[0, :],
            pd_stage_reshaped[1, :],
            alpha=0.2,
            color="gray",
        )

        for x_val, y_val, color, short_name in (
            (hour_a, pd_value_a, "red", "Morning"),
            (hour_b, pd_value_b, "blue", "Afternoon"),
        ):
            ax.scatter(
                x_val,
                y_val,
                c=color,
                s=150,
                marker="*",
                edgecolors="black",
                linewidths=1.5,
                zorder=10,
            )
            ax.annotate(
                f"{short_name}\nPD: {y_val:.0f}",
                xy=(x_val, y_val),
                xytext=(10, 10),
                textcoords="offset points",
                fontsize=9,
                fontweight="bold",
                color=color,
                bbox=dict(boxstyle="round,pad=0.3", facecolor="white", alpha=0.8),
                zorder=11,
                ha="left",
            )

        ax.set_xlabel("Hour of Day", fontsize=10)
        ax.set_ylabel("Partial Dependence (bikes)", fontsize=10)
        ax.set_title(
            f"Stage {epoch_idx + 1}\nC+={c_plus:.2e}, C-={c_minus:.2e}, Scale={sqrt_c_product:.2e}",
            fontsize=11,
            fontweight="bold",
        )
        ax.grid(True, alpha=0.3)
        ax.legend(loc="best", fontsize=8)

    plt.tight_layout(rect=[0, 0, 1, 0.98])

    if save_path is None:
        save_path = FIGURES_DIR / "figure_5_2_stagewise_partial_dependence.svg"
    plt.savefig(save_path, format="svg", bbox_inches="tight")
    print(f"Saved Figure 5.2 to {save_path}")

    return fig


if __name__ == "__main__":
    # Script entry point: load the MPF model and Bike Sharing data once, then
    # render every interpretability figure into FIGURES_DIR.
    print("=" * 80)
    print("Bike Sharing - MPF Interpretability Analysis")
    print("=" * 80)
    print()

    # Load model and data. The target vector is not used by any figure below.
    model, X, _y, X_df = load_model_and_data()

    # Plot all features (not just key features). This is computed after loading
    # because FEATURE_NAMES is documented as being set from the loaded data.
    all_feature_indices = list(range(len(FEATURE_NAMES)))

    # (progress message, renderer) pairs, executed in order. The lambdas read
    # model/X/X_df/FEATURE_NAMES at call time, exactly as the sequential calls did.
    figure_steps = [
        (
            "\nGenerating Figure 1: Backbone plots for all features...",
            lambda: plot_figure_1_backbone(
                model, X, X_df, feature_indices=all_feature_indices
            ),
        ),
        (
            "\nGenerating Figure 2: Nominal contribution plots for all features...",
            lambda: plot_figure_2_nominal_contribution(
                model, X, X_df, feature_indices=all_feature_indices
            ),
        ),
        (
            "\nGenerating Figure 3: First-order partial dependence for all features...",
            lambda: plot_figure_3_first_order_pd(
                model, X, X_df, feature_indices=all_feature_indices
            ),
        ),
        (
            "\nGenerating Figure 3.1: Scaled first-order PD (f+ and f-) for all features...",
            lambda: plot_figure_3_1_scaled_first_order_pd(
                model, X, X_df, feature_indices=all_feature_indices
            ),
        ),
        (
            "\nGenerating Figure 3.2: Scaled first-order PD (f+ and f-) for selected features...",
            lambda: plot_figure_3_2_scaled_first_order_pd(model, X, X_df),
        ),
        (
            "\nGenerating Figure 3.3: ICE curves for all features...",
            lambda: plot_figure_3_3_ice_curves(
                model, X, X_df, feature_indices=all_feature_indices, n_observations=100
            ),
        ),
        (
            "\nGenerating Figure 3.3 (2D PD): hour × workingday for all epochs...",
            lambda: plot_figure_3_3_2d_pd_hour_workingday(
                model, X, X_df, grid_points=50
            ),
        ),
        (
            "\nGenerating Figure 5: Local explanations...",
            lambda: plot_figure_5_local_explanations(model, X, X_df),
        ),
        (
            "\nGenerating Figure 5.1: 2D partial dependence (hour × workingday) with annotations...",
            lambda: plot_figure_5_1_map_annotations(model, X, X_df),
        ),
        (
            "\nGenerating Figure 5.2: Stage-wise partial dependence (hour × workingday)...",
            lambda: plot_figure_5_2_stagewise_partial_dependence(model, X, X_df),
        ),
        (
            "\nGenerating Figure 6: Feature importance...",
            lambda: compute_and_plot_feature_importance(
                model, X, FEATURE_NAMES, gamma=1.0
            ),
        ),
    ]

    for message, render in figure_steps:
        print(message)
        render()
        # The plotting helpers save to disk but may leave their figures open
        # (figure 5.2, for instance, returns `fig` without closing it). Release
        # them after each step so a full run does not accumulate figures and
        # trip matplotlib's max-open-figure warning.
        plt.close("all")

    print("\n" + "=" * 80)
    print("Interpretability analysis complete!")
    print(f"Figures saved to: {FIGURES_DIR}")
    print("=" * 80)
