"""
Load ExplainableBoostingRegressor model and make predictions on Bike Sharing data.
"""

from contextlib import redirect_stderr, redirect_stdout
from io import StringIO
from pathlib import Path

import joblib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from interpret import preserve
from sklearn.metrics import r2_score

# Feature column order for the Bike Sharing data, matching bike_sharing.py.
# load_data() pairs model feature names with these entries by position, so the
# order here must follow the CSV column order.
FEATURE_NAMES = [
    "year", "month", "hour", "weekday",
    "temp", "feel_temp", "humidity", "windspeed",
    "holiday", "workingday", "season", "weather",
]


def load_model(model_path: Path):
    """Deserialize a fitted ExplainableBoostingRegressor with joblib.

    Parameters
    ----------
    model_path : Path
        Location of the pickled model file.

    Returns
    -------
    model
        Whatever object was stored in the file (expected to be an EBM).
    """
    print(f"Loading model from: {model_path}")
    # NOTE: joblib.load unpickles arbitrary Python objects; only load files
    # from a trusted source.
    loaded_model = joblib.load(model_path)
    print("Model loaded.")
    return loaded_model


def create_feature_name_mapping(model_feature_names: list, real_feature_names: list):
    """Pair each model feature name with the real name at the same position.

    Parameters
    ----------
    model_feature_names : list
        Names as stored in the model (possibly generic, e.g. ``"feature_0000"``).
    real_feature_names : list
        Human-readable names in the same column order (e.g. ``["year", ...]``).

    Returns
    -------
    dict
        ``{model_name: real_name}``. Model names beyond the end of
        ``real_feature_names`` map to themselves.
    """
    available = len(real_feature_names)
    return {
        model_name: (real_feature_names[position] if position < available else model_name)
        for position, model_name in enumerate(model_feature_names)
    }


def _resolve_model_feature_names(model, n_features: int) -> list:
    """Return the column names to use for the model's input DataFrame.

    Prefers ``model.feature_names_in_``, then the single-feature entries of
    ``model.term_names_``, and otherwise generic ``feature_XXXX`` names.

    Raises
    ------
    ValueError
        If ``model.feature_names_in_`` exists but its length differs from
        ``n_features`` (the model was trained on a different column layout).
    """
    generic_names = [f"feature_{i:04d}" for i in range(n_features)]
    if model is None:
        return generic_names

    if hasattr(model, "feature_names_in_"):
        names = list(model.feature_names_in_)
        print(f"Model's feature names: {names}")
        if len(names) != n_features:
            raise ValueError(
                f"Model expects {len(names)} features {names}, "
                f"but the data has {n_features} feature columns"
            )
        return names

    if hasattr(model, "term_names_"):
        # Keep main-effect terms only; interaction terms are named "a & b".
        names = [
            name
            for name in model.term_names_
            if isinstance(name, str) and " & " not in name
        ]
        # Fall back to generic names if the count does not match the data.
        return names if len(names) == n_features else generic_names

    return generic_names


def load_data(csv_path: Path, model=None):
    """Load Bike Sharing data from CSV.

    The CSV has no header; every column except the last is a feature and the
    last column is the target (``count``).

    Parameters
    ----------
    csv_path : Path
        Path to the CSV file
    model : optional
        Trained model to get feature names from

    Returns
    -------
    X_df : pd.DataFrame
        Features as DataFrame with named columns (using model's feature names)
    X : np.ndarray
        Features as numpy array
    y : np.ndarray
        Target values
    feature_names : list
        Feature names used (from model if available, else default)
    feature_name_mapping : dict
        Mapping from model feature names to real feature names

    Raises
    ------
    ValueError
        If the model's ``feature_names_in_`` does not match the number of
        feature columns in the CSV.
    """
    print(f"\nLoading bike sharing data from: {csv_path}")
    df = pd.read_csv(csv_path, header=None)

    # Bike Sharing CSV format: 13 columns (0-11 are features, 12 is target).
    # The feature columns follow FEATURE_NAMES; the target is "count".
    column_names = FEATURE_NAMES + ["count"]
    df.columns = column_names[: len(df.columns)]

    # Extract features and target
    X = df.iloc[:, :-1].values  # All columns except last
    y = df.iloc[:, -1].values  # Last column is target (count)

    # Use model's feature names if available, otherwise use defaults
    model_feature_names = _resolve_model_feature_names(model, X.shape[1])

    # Create mapping from model names to real names
    feature_name_mapping = create_feature_name_mapping(
        model_feature_names, FEATURE_NAMES[: X.shape[1]]
    )

    print(f"Feature name mapping: {feature_name_mapping}")

    # Create DataFrame with model's feature names (for compatibility)
    X_df = pd.DataFrame(X, columns=model_feature_names)

    print(f"Data shape: {X.shape}")
    print(f"Target shape: {y.shape}")
    print(f"Using model feature names for DataFrame: {model_feature_names}")

    return X_df, X, y, model_feature_names, feature_name_mapping


def make_predictions(model, X_df: pd.DataFrame, X: np.ndarray):
    """Run ``model.predict`` on the named-column feature frame.

    Parameters
    ----------
    model
        Fitted estimator exposing a ``predict`` method.
    X_df : pd.DataFrame
        Feature frame handed to ``model.predict``.
    X : np.ndarray
        Raw feature matrix. Accepted for interface compatibility; it is not
        used by this function.

    Returns
    -------
    predictions : np.ndarray
        The value returned by ``model.predict(X_df)``.
    """
    print("\nMaking predictions...")
    return model.predict(X_df)


def calculate_metrics(y: np.ndarray, predictions: np.ndarray):
    """Print regression error metrics and basic distribution statistics.

    Prints MSE, RMSE, MAE and R² (computed as ``1 - SS_res / SS_tot``). It then
    prints min/max/mean/std of the predictions, and finally of the targets.
    Nothing is returned.

    Parameters
    ----------
    y : np.ndarray
        Ground-truth target values.
    predictions : np.ndarray
        Model outputs aligned element-wise with ``y``.
    """
    residual = y - predictions
    squared_error = residual ** 2
    mse = np.mean(squared_error)
    rmse = np.sqrt(mse)
    mae = np.mean(np.abs(residual))
    r2 = 1 - (np.sum(squared_error) / np.sum((y - np.mean(y)) ** 2))

    def _report(heading, rows):
        # Each label already carries its own padding so values line up.
        print(heading)
        for label, value in rows:
            print(f"  {label}{value:.4f}")

    _report(
        "\nPrediction Metrics:",
        [("MSE:  ", mse), ("RMSE: ", rmse), ("MAE:  ", mae), ("R²:   ", r2)],
    )
    for heading, values in (
        ("\nPrediction statistics:", predictions),
        ("\nTarget statistics:", y),
    ):
        _report(
            heading,
            [
                ("Min:  ", values.min()),
                ("Max:  ", values.max()),
                ("Mean: ", values.mean()),
                ("Std:  ", values.std()),
            ],
        )


def _draw_prediction_panel(ax, y_true, preds, *, color, ylabel, title, alpha, s):
    """Scatter ``preds`` against ``y_true`` on ``ax`` with a y=x line and an R² label."""
    ax.scatter(y_true, preds, alpha=alpha, s=s, color=color)
    # The reference diagonal spans both arrays so every point falls inside it.
    min_val = min(y_true.min(), preds.min())
    max_val = max(y_true.max(), preds.max())
    ax.plot(
        [min_val, max_val],
        [min_val, max_val],
        "r--",
        linewidth=2,
        label="Perfect prediction",
    )
    ax.set_xlabel("True Values", fontsize=12)
    ax.set_ylabel(ylabel, fontsize=12)
    ax.set_title(title, fontsize=13)
    ax.grid(True, alpha=0.3)
    ax.legend()
    ax.text(
        0.05,
        0.95,
        f"R² = {r2_score(y_true, preds):.4f}",
        transform=ax.transAxes,
        fontsize=11,
        verticalalignment="top",
        bbox=dict(boxstyle="round", facecolor="wheat", alpha=0.5),
    )


def plot_predictions_vs_true(
    y_true: np.ndarray,
    predictions: np.ndarray,
    predictions_clamped=None,
    figsize=(14, 6),
    alpha=0.5,
    s=10,
    title_prefix="EBM",
    output_dir: Path = None,
):
    """Plot predictions vs true values in scatter plots and save as SVG.

    The figure is written to ``<output_dir>/figure_0_predictions_vs_true.svg``
    and then closed.

    Parameters
    ----------
    y_true : np.ndarray
        True target values
    predictions : np.ndarray
        Model predictions (may include negative values)
    predictions_clamped : np.ndarray, optional
        Clamped predictions (non-negative). If provided, creates a two-panel plot.
        If None, creates a single-panel plot at half of ``figsize``'s width.
    figsize : tuple, default=(14, 6)
        Figure size (width, height)
    alpha : float, default=0.5
        Transparency of scatter points
    s : float, default=10
        Size of scatter points
    title_prefix : str, default="EBM"
        Prefix for plot titles
    output_dir : Path, optional
        Directory to save the SVG file. If None, uses an ``outputs`` directory
        next to this script. Created if missing.

    Returns
    -------
    fig : matplotlib.figure.Figure
        The figure object (already closed)
    axes : matplotlib.axes.Axes or array of axes
        The array of two axes for the two-panel plot, or a single Axes otherwise
    """
    y_true = np.asarray(y_true)
    predictions = np.asarray(predictions)
    panel_style = dict(alpha=alpha, s=s)

    if predictions_clamped is not None:
        predictions_clamped = np.asarray(predictions_clamped)
        fig, axes = plt.subplots(1, 2, figsize=figsize)
        _draw_prediction_panel(
            axes[0],
            y_true,
            predictions,
            color="blue",
            ylabel="Predictions",
            title=f"{title_prefix} Predictions vs True Values (Original)",
            **panel_style,
        )
        _draw_prediction_panel(
            axes[1],
            y_true,
            predictions_clamped,
            color="green",
            ylabel="Clamped Predictions",
            title=f"{title_prefix} Predictions vs True Values (Clamped ≥ 0)",
            **panel_style,
        )
        returned_axes = axes
    else:
        # True division: integer division truncated odd or fractional widths.
        fig, ax = plt.subplots(1, 1, figsize=(figsize[0] / 2, figsize[1]))
        _draw_prediction_panel(
            ax,
            y_true,
            predictions,
            color="blue",
            ylabel="Predictions",
            title=f"{title_prefix} Predictions vs True Values",
            **panel_style,
        )
        returned_axes = ax

    plt.tight_layout()

    # Save as SVG
    if output_dir is None:
        output_dir = Path(__file__).parent / "outputs"
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    svg_path = output_dir / "figure_0_predictions_vs_true.svg"
    fig.savefig(svg_path, format="svg", bbox_inches="tight", dpi=300)
    print(f"   Predictions vs true values plot saved to: {svg_path}")

    plt.close(fig)

    return fig, returned_axes


def _preserve_quietly(explanation, html_path: Path):
    """Write ``explanation`` to ``html_path`` with interpret's ``preserve`` and hide its console output.

    stdout and stderr are redirected to throw-away buffers. ``print`` is also
    temporarily filtered to drop lines containing ``text/html`` or
    ``PlotlyConfig``. ``builtins.print`` is restored in a ``finally`` block, so
    an exception raised by ``preserve`` cannot leave the global ``print`` patched.
    """
    import builtins

    original_print = builtins.print

    def _filtered_print(*args, **kwargs):
        text = " ".join(str(arg) for arg in args)
        if "text/html" not in text and "PlotlyConfig" not in text:
            original_print(*args, **kwargs)

    builtins.print = _filtered_print
    try:
        with redirect_stdout(StringIO()), redirect_stderr(StringIO()):
            preserve(explanation, file_name=str(html_path))
    finally:
        builtins.print = original_print


def visualize_global_explanation(
    model, feature_name_mapping: dict = None, output_dir: Path = None
):
    """Generate and save the model's global explanation to an HTML file.

    The file is written to ``<output_dir>/ebm_global_explanation.html``.

    Parameters
    ----------
    model
        Trained EBM model
    feature_name_mapping : dict, optional
        Mapping from model feature names to real feature names. It is only used
        to decide whether to print a hint; the HTML keeps the model's own names.
    output_dir : Path or str, optional
        Directory to save the HTML file. If None, uses the script directory.
        Created if missing.
    """
    print("\n1. Global Explanation (Feature Importance & Response Functions)")
    print("   This shows how each feature contributes to predictions overall.")
    global_explanation = model.explain_global()
    output_dir = Path(__file__).parent if output_dir is None else Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    html_path = output_dir / "ebm_global_explanation.html"

    _preserve_quietly(global_explanation, html_path)

    if html_path.exists():
        print(f"   Global explanation saved to: {html_path}")
    else:
        print(f"   Failed to save global explanation to: {html_path}")
    if feature_name_mapping:
        print(
            "   Feature names may show model internal names; see summary for mapping."
        )


def visualize_local_explanation(
    model,
    X_df: pd.DataFrame,
    predictions: np.ndarray,
    n_examples: int = 5,
    output_dir: Path = None,
    y_true: np.ndarray = None,
):
    """Generate and save local explanations for the first rows to an HTML file.

    The file is written to ``<output_dir>/ebm_local_explanation.html``.

    Parameters
    ----------
    model
        Trained EBM model
    X_df : pd.DataFrame
        Features as DataFrame; the first ``n_examples`` rows are explained
    predictions : np.ndarray
        Model predictions. They are used as the ``y`` passed to
        ``explain_local`` when ``y_true`` is not given.
    n_examples : int
        Number of examples to explain (capped at ``len(X_df)``)
    output_dir : Path or str, optional
        Directory to save the HTML file. If None, uses the script directory.
        Created if missing.
    y_true : np.ndarray, optional
        True targets aligned with ``X_df``. When given, these are passed to
        ``explain_local`` as ``y`` instead of the predictions.
    """
    print("\n2. Local Explanation (Feature Contributions for Individual Predictions)")
    print("   This shows why the model made specific predictions.")
    n_examples = min(n_examples, len(X_df))
    X_sample = X_df.iloc[:n_examples]
    # explain_local's second argument is the target (y) shown next to each
    # prediction. By default the predictions are passed, as before; supplying
    # y_true shows the real targets instead.
    if y_true is None:
        y_sample = predictions[:n_examples]
    else:
        y_sample = np.asarray(y_true)[:n_examples]
    local_explanation = model.explain_local(X_sample, y_sample, name="EBM")
    output_dir = Path(__file__).parent if output_dir is None else Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    html_path = output_dir / "ebm_local_explanation.html"

    _preserve_quietly(local_explanation, html_path)

    if html_path.exists():
        print(f"   Local explanation saved to: {html_path} ({n_examples} examples)")
    else:
        print(f"   Failed to save local explanation to: {html_path}")


def print_feature_importance(model, feature_names: list, feature_name_mapping: dict):
    """Print a table of feature/term importances, highest first.

    Parameters
    ----------
    model
        Trained EBM model. Importances are read from ``model.global_importance()``
        when that method exists, paired with ``feature_names``. Otherwise they
        come from ``model.term_importances()``, paired with ``model.term_names_``
        (which includes interaction terms).
    feature_names : list
        List of feature names (model's names); used with ``global_importance()``
    feature_name_mapping : dict
        Mapping from model feature names to real feature names. Interaction
        terms named ``"a & b"`` are mapped part by part. ``None`` is treated as
        an empty mapping.
    """
    print("\n3. Feature Importance Summary")
    mapping = feature_name_mapping or {}

    if hasattr(model, "global_importance"):
        importance = model.global_importance()
        names = list(feature_names)
    elif hasattr(model, "term_importances") and hasattr(model, "term_names_"):
        # interpret's EBMs expose per-term importances through
        # term_importances(), aligned index-for-index with term_names_.
        importance = model.term_importances()
        names = list(model.term_names_)
    else:
        print(
            "   (Feature importance not directly available, see global explanation above)"
        )
        return

    def _display_name(name):
        if name in mapping:
            return mapping[name]
        # Interaction terms are "a & b"; translate each component separately.
        return " & ".join(mapping.get(part, part) for part in str(name).split(" & "))

    importance_df = pd.DataFrame(
        {"Feature": [_display_name(name) for name in names], "Importance": importance}
    ).sort_values("Importance", ascending=False)
    print("\n   Feature Importance (sorted):")
    print(importance_df.to_string(index=False))


def _continuous_feature_range(model, feature_idx, cut_points, extrapolate):
    """Return ``(min, max)`` used to close the outer bins of a continuous feature.

    Uses ``model.feature_bounds_[feature_idx]`` when present and its minimum is
    not NaN. Otherwise the range is derived from ``cut_points``:

    - with ``extrapolate`` true, it extends one cut spacing beyond the outer cuts
      (or by 1.0 when there is a single cut);
    - with ``extrapolate`` false, it stops at the outermost cuts.

    With no cut points at all the range is ``(0, 1)``.
    """
    bounds = getattr(model, "feature_bounds_", None)
    if bounds is not None and not np.isnan(bounds[feature_idx, 0]):
        return bounds[feature_idx, 0], bounds[feature_idx, 1]

    n_cuts = len(cut_points)
    if n_cuts == 0:
        # A single-bin feature has no cuts to extrapolate from.
        return 0, 1
    if not extrapolate:
        return cut_points[0], cut_points[-1]
    if n_cuts > 1:
        return (
            cut_points[0] - (cut_points[1] - cut_points[0]),
            cut_points[-1] + (cut_points[-1] - cut_points[-2]),
        )
    return cut_points[0] - 1, cut_points[-1] + 1


def _continuous_bin_centers(model, feature_idx, cut_points, n_scores):
    """Return one x-position per score bin of a continuous feature."""
    histogram_edges = getattr(model, "histogram_edges_", None)
    if histogram_edges is not None and histogram_edges[feature_idx] is not None:
        edges = np.asarray(histogram_edges[feature_idx])
        if len(edges) == n_scores + 2:
            # One edge too many: the first edge is dropped so there is one
            # center per score.
            return (edges[1:-1] + edges[2:]) / 2
        if len(edges) == n_scores + 1:
            # Standard case: n+1 edges for n bins.
            return (edges[:-1] + edges[1:]) / 2
        # Edges do not line up with the scores; rebuild from cut points,
        # bounded by the outermost cuts.
        min_val, max_val = _continuous_feature_range(
            model, feature_idx, cut_points, extrapolate=False
        )
    else:
        min_val, max_val = _continuous_feature_range(
            model, feature_idx, cut_points, extrapolate=True
        )

    # Edges: [min_val, cut1, ..., cutN, max_val] -> N+1 bins.
    edges = np.concatenate([[min_val], cut_points, [max_val]])
    return (edges[:-1] + edges[1:]) / 2


def extract_gam_component(model, term_idx: int):
    """Extract GAM component data for a single term.

    Parameters
    ----------
    model
        Trained EBM model
    term_idx : int
        Index of the term to extract

    Returns
    -------
    dict
        Dictionary containing:
        - 'scores': array of scores, excluding the first (missing) and the
          last (unseen) bin
        - 'x_values': bin centers for continuous features; for categorical
          features, categories ordered by their bin index
        - 'feature_type': ``model.feature_types_in_`` entry for the feature
        - 'feature_idx': index of the feature in the model

    Raises
    ------
    ValueError
        If the term is an interaction (does not involve exactly one feature).
    """
    feature_idxs = model.term_features_[term_idx]
    if len(feature_idxs) != 1:
        raise ValueError("extract_gam_component only works for single-feature terms")

    feature_idx = feature_idxs[0]
    feature_type = model.feature_types_in_[feature_idx]
    feature_bins = model.bins_[feature_idx][0]  # Main effect binning
    # Skip missing bin (index 0) and unseen bin (last index)
    scores = model.term_scores_[term_idx][1:-1]

    if isinstance(feature_bins, dict):
        # Categorical: the dict maps category -> bin index; order by that index.
        x_values = sorted(feature_bins, key=feature_bins.get)
    else:
        # Continuous: feature_bins holds the cut points.
        x_values = _continuous_bin_centers(model, feature_idx, feature_bins, len(scores))

    return {
        "scores": scores,
        "x_values": x_values,
        "feature_type": feature_type,
        "feature_idx": feature_idx,
    }


def plot_feature_responses(
    model,
    X_df: pd.DataFrame,
    feature_name_mapping: dict,
    output_dir: Path = None,
):
    """Plot each single-feature term's response function and save as SVG.

    Categorical terms (``nominal``/``ordinal``) are drawn as bar charts and
    continuous terms as lines. The figure is written to
    ``<output_dir>/figure_9_feature_responses.svg`` and then closed.

    Parameters
    ----------
    model
        Trained EBM model
    X_df : pd.DataFrame
        Features as DataFrame (with model's feature names). Not used by the
        plotting itself; kept for interface compatibility.
    feature_name_mapping : dict
        Mapping from model feature names to real feature names (``None`` is
        treated as empty)
    output_dir : Path, optional
        Directory to save the SVG file. If None, uses an ``outputs`` directory
        next to this script. Created if missing.
    """
    print("\n4. Matplotlib-based Feature Response Functions")
    print("   Plotting how each feature contributes to predictions...")
    if not hasattr(model, "term_features_") or not hasattr(model, "term_scores_"):
        print("   (Model does not have required attributes for plotting)")
        return

    single_feature_terms = [
        (term_idx, model.term_names_[term_idx])
        for term_idx, feature_idxs in enumerate(model.term_features_)
        if len(feature_idxs) == 1
    ]
    if not single_feature_terms:
        print("   (No single-feature terms found in model)")
        return

    mapping = feature_name_mapping or {}
    n_cols = 3
    n_features = len(single_feature_terms)
    n_rows = (n_features + n_cols - 1) // n_cols
    # squeeze=False always returns a 2-D array of Axes, so ravel() yields a flat
    # array for any grid size. Wrapping the result in a list instead broke the
    # one-term case, because subplots(1, 3) still returns an array.
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, 5 * n_rows), squeeze=False)
    axes = axes.ravel()

    for ax, (term_idx, term_name) in zip(axes, single_feature_terms):
        display_name = mapping.get(term_name, term_name)
        component = extract_gam_component(model, term_idx)
        scores = component["scores"]
        x_values = component["x_values"]

        if component["feature_type"] in ("nominal", "ordinal"):
            x_pos = np.arange(len(x_values))
            ax.bar(x_pos, scores, alpha=0.7)
            ax.set_xticks(x_pos)
            ax.set_xticklabels(x_values, rotation=45, ha="right")
            ax.grid(True, alpha=0.3, axis="y")
        else:
            ax.plot(x_values, scores, "o-", linewidth=2, markersize=4)
            ax.grid(True, alpha=0.3)
        ax.axhline(y=0, color="k", linestyle="--", alpha=0.3)
        ax.set_xlabel(display_name)
        ax.set_ylabel("Contribution to Prediction")
        ax.set_title(f"Feature Response: {display_name}")

    # Hide the unused cells in the last row of the grid.
    for ax in axes[n_features:]:
        ax.set_visible(False)

    plt.tight_layout()
    if output_dir is None:
        output_dir = Path(__file__).parent / "outputs"
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    svg_path = output_dir / "figure_9_feature_responses.svg"
    fig.savefig(svg_path, format="svg", bbox_inches="tight", dpi=300)
    print(f"   Feature response plots saved to: {svg_path}")
    plt.close(fig)


def compute_partial_dependence(
    model, X_df: pd.DataFrame, feature_name: str, num_points: int = 50
):
    """Average model predictions while sweeping one feature over its observed range.

    For each of ``num_points`` evenly spaced values between the column's min and
    max, every row's ``feature_name`` is overwritten with that value and the
    mean prediction is recorded. ``X_df`` itself is not modified; a copy is used.

    Parameters
    ----------
    model
        Fitted estimator exposing ``predict``.
    X_df : pd.DataFrame
        Background data whose other columns are held fixed.
    feature_name : str
        Column to sweep.
    num_points : int
        Number of grid values.

    Returns
    -------
    x_values : np.ndarray
        The grid of values assigned to ``feature_name``.
    pd_values : np.ndarray
        Mean prediction (float64) at each grid value.
    """
    observed = X_df[feature_name].values
    grid = np.linspace(np.min(observed), np.max(observed), num_points)

    scratch = X_df.copy()
    averages = []
    for value in grid:
        scratch[feature_name] = value
        averages.append(np.mean(model.predict(scratch)))

    return grid, np.array(averages, dtype=np.float64)


def plot_partial_dependence(
    model,
    X_df: pd.DataFrame,
    feature_name_mapping: dict = None,
    num_points: int = 50,
    output_dir: Path = None,
):
    """Plot one partial-dependence curve per column of ``X_df`` and save as SVG.

    The figure is written to ``<output_dir>/figure_8_partial_dependence.svg``
    and then closed. If ``X_df`` has no columns, a message is printed and
    nothing is plotted.

    Parameters
    ----------
    model
        Trained EBM model
    X_df : pd.DataFrame
        Features as DataFrame (used for computing partial dependence)
    feature_name_mapping : dict, optional
        Mapping from model feature names to real feature names
    num_points : int, default=50
        Number of grid points for the x-axis in partial dependence plots
    output_dir : Path, optional
        Directory to save the SVG file. If None, uses an ``outputs`` directory
        next to this script. Created if missing.
    """
    print("\n5. Partial Dependence Plots")
    print("   Computing and plotting marginal effect of each feature on predictions...")

    # Set output directory
    if output_dir is None:
        output_dir = Path(__file__).parent / "outputs"
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    columns = list(X_df.columns)
    n_features = len(columns)
    if n_features == 0:
        # plt.subplots cannot build a grid with zero rows.
        print("   (No features to plot)")
        return

    mapping = feature_name_mapping or {}
    n_cols = 3
    n_rows = (n_features + n_cols - 1) // n_cols
    # squeeze=False keeps a 2-D array even for a single row, so ravel() always
    # gives a flat array of Axes (including when there is only one feature).
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, 5 * n_rows), squeeze=False)
    axes = axes.ravel()

    for ax, feature_name in zip(axes, columns):
        display_name = mapping.get(feature_name, feature_name)
        x_values, pd_values = compute_partial_dependence(
            model, X_df, feature_name, num_points
        )
        ax.plot(x_values, pd_values, "o-", linewidth=2, markersize=3, alpha=0.7)
        ax.axhline(y=0, color="k", linestyle="--", alpha=0.3)
        ax.set_xlabel(display_name, fontsize=11)
        ax.set_ylabel("Partial Dependence", fontsize=11)
        ax.set_title(f"Partial Dependence: {display_name}", fontsize=12)
        ax.grid(True, alpha=0.3)

    # Hide the unused cells in the last row of the grid.
    for ax in axes[n_features:]:
        ax.set_visible(False)

    plt.tight_layout()
    svg_path = output_dir / "figure_8_partial_dependence.svg"
    fig.savefig(svg_path, format="svg", bbox_inches="tight", dpi=300)
    print(f"   Partial dependence plots saved to: {svg_path}")
    plt.close(fig)


def extract_interaction_from_ebm(
    model, feature1_name: str, feature2_name: str, feature_name_mapping: dict = None
):
    """Return the learned pairwise interaction term for two features, if any.

    The features are located in ``model.feature_names_in_``, first by exact
    name. If either name is not found, a second lookup compares against the
    mapped (real) names. In that case a column matches when its mapped name
    equals the mapped query name, or when its model name equals the query;
    the last matching column wins. The first two-feature term in
    ``model.term_features_`` over exactly those two features is returned.

    Parameters
    ----------
    model
        Trained EBM model
    feature1_name : str
        Name of the first feature (model name or real name)
    feature2_name : str
        Name of the second feature (model name or real name)
    feature_name_mapping : dict, optional
        Mapping from model feature names to real feature names

    Returns
    -------
    interaction_data : dict or None
        Keys ``term_scores``, ``bins1``, ``bins2``, ``feature1_idx``,
        ``feature2_idx`` and ``term_name`` if the interaction exists. Returns
        None if it does not, or if the model lacks the attributes needed to
        look it up.
    """
    if not hasattr(model, "term_features_") or not hasattr(model, "term_names_"):
        return None
    # Feature indices can only be resolved through the model's input names.
    if not hasattr(model, "feature_names_in_"):
        return None

    mapping = feature_name_mapping or {}
    model_feature_names = list(model.feature_names_in_)

    feature1_idx = None
    feature2_idx = None
    try:
        feature1_idx = model_feature_names.index(feature1_name)
        feature2_idx = model_feature_names.index(feature2_name)
    except ValueError:
        # At least one query is not a raw model name; match via real names.
        real_name1 = mapping.get(feature1_name, feature1_name)
        real_name2 = mapping.get(feature2_name, feature2_name)
        for idx, name in enumerate(model_feature_names):
            mapped_name = mapping.get(name, name)
            if mapped_name == real_name1 or name == feature1_name:
                feature1_idx = idx
            if mapped_name == real_name2 or name == feature2_name:
                feature2_idx = idx

    if feature1_idx is None or feature2_idx is None:
        return None

    wanted = {feature1_idx, feature2_idx}
    for term_idx, feature_idxs in enumerate(model.term_features_):
        if len(feature_idxs) == 2 and set(feature_idxs) == wanted:
            return {
                "term_scores": model.term_scores_[term_idx],
                "bins1": model.bins_[feature1_idx][0],
                "bins2": model.bins_[feature2_idx][0],
                "feature1_idx": feature1_idx,
                "feature2_idx": feature2_idx,
                "term_name": model.term_names_[term_idx],
            }

    return None


def compute_2d_partial_dependence(
    model,
    X_df: pd.DataFrame,
    feature1_name: str,
    feature2_name: str,
    feature2_values: list = None,
    num_points: int = 50,
):
    """Compute one partial-dependence curve for feature1 per fixed value of feature2.

    For each conditioning value ``v``, a copy of ``X_df`` has ``feature2_name``
    set to ``v``. Then ``feature1_name`` is swept over ``num_points`` evenly
    spaced values between its observed min and max, and the mean prediction is
    recorded at each step. ``X_df`` itself is not modified.

    Parameters
    ----------
    model
        Fitted estimator exposing ``predict``.
    X_df : pd.DataFrame
        Background data.
    feature1_name : str
        Column swept along the x-axis.
    feature2_name : str
        Column held at each conditioning value.
    feature2_values : list, optional
        Conditioning values; defaults to the sorted unique values of
        ``feature2_name`` in ``X_df``.
    num_points : int
        Number of grid values for ``feature1_name``.

    Returns
    -------
    x_values : np.ndarray
        Grid values for feature1.
    pd_dict : dict
        ``{feature2 value: np.ndarray of mean predictions}``, in conditioning order.
    """
    sweep_source = X_df[feature1_name].values
    x_values = np.linspace(np.min(sweep_source), np.max(sweep_source), num_points)

    conditions = (
        sorted(X_df[feature2_name].unique()) if feature2_values is None else feature2_values
    )

    def _curve_at(condition):
        work = X_df.copy()
        work[feature2_name] = condition
        curve = np.zeros(num_points)
        for position, grid_val in enumerate(x_values):
            work[feature1_name] = grid_val
            curve[position] = np.mean(model.predict(work))
        return curve

    pd_dict = {condition: _curve_at(condition) for condition in conditions}
    return x_values, pd_dict


def plot_2d_partial_dependence(
    model,
    X_df: pd.DataFrame,
    feature1_name: str,
    feature2_name: str,
    feature_name_mapping: dict = None,
    feature2_values: list = None,
    num_points: int = 50,
    output_dir: Path = None,
    figure_name: str = None,
):
    """Plot 2D partial dependence: feature1 vs prediction, one line per feature2 value.

    The curves are computed by ``compute_2d_partial_dependence``; this function
    only draws one line per conditioning value of ``feature2_name`` and saves
    the figure to disk.

    Parameters
    ----------
    model
        Trained model; passed through unchanged to ``compute_2d_partial_dependence``.
    X_df : pd.DataFrame
        Features as DataFrame (used for computing partial dependence)
    feature1_name : str
        Column name of the first feature (x-axis)
    feature2_name : str
        Column name of the second feature (conditioning variable; one line per value)
    feature_name_mapping : dict, optional
        Mapping from model feature names to display names used in the title,
        axis label and legend. Names missing from the mapping (or every name,
        when the mapping is None/empty) are displayed unchanged.
    feature2_values : list, optional
        Values of feature2 to condition on. If None, uses the sorted unique
        values of ``X_df[feature2_name]``.
    num_points : int, default=50
        Number of grid points for feature1
    output_dir : Path, optional
        Directory to save the figure in; created if it does not exist.
        If None, uses the ``outputs`` directory next to this script.
    figure_name : str, optional
        File name for the figure. Its suffix selects the matplotlib output
        format (SVG when there is no suffix). If None,
        ``figure_10_2d_pd_<feature1>_<feature2>.svg`` is used.
    """
    print(f"\n6. 2D Partial Dependence Plot: {feature1_name} × {feature2_name}")
    print(
        f"   Plotting partial dependence of {feature1_name} conditioned on {feature2_name}..."
    )

    # Resolve (and create) the output directory.
    if output_dir is None:
        output_dir = Path(__file__).parent / "outputs"
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # A missing/empty mapping means "display the raw feature names".
    mapping = feature_name_mapping or {}
    display_name1 = mapping.get(feature1_name, feature1_name)
    display_name2 = mapping.get(feature2_name, feature2_name)

    if feature2_values is None:
        feature2_values = sorted(X_df[feature2_name].unique())

    x_values, pd_dict = compute_2d_partial_dependence(
        model, X_df, feature1_name, feature2_name, feature2_values, num_points
    )

    fig, ax = plt.subplots(1, 1, figsize=(10, 6))
    # Close the figure even if drawing or saving fails, so repeated calls do
    # not accumulate open matplotlib figures.
    try:
        colors = plt.cm.tab10(np.linspace(0, 1, len(pd_dict)))
        # Sort on the conditioning value only; never compare the PD arrays.
        curves = sorted(pd_dict.items(), key=lambda item: item[0])
        for color, (feature2_val, pd_values) in zip(colors, curves):
            ax.plot(
                x_values,
                pd_values,
                "o-",
                linewidth=2,
                markersize=4,
                alpha=0.7,
                color=color,
                label=f"{display_name2} = {feature2_val}",
            )
        ax.axhline(y=0, color="k", linestyle="--", alpha=0.3)
        ax.set_xlabel(display_name1, fontsize=12)
        ax.set_ylabel("Partial Dependence", fontsize=12)
        ax.set_title(
            f"Partial Dependence: {display_name1} × {display_name2}", fontsize=13
        )
        ax.grid(True, alpha=0.3)
        ax.legend()
        plt.tight_layout()

        if figure_name is None:
            figure_name = f"figure_10_2d_pd_{feature1_name}_{feature2_name}.svg"
        save_path = output_dir / figure_name
        # The file suffix picks the output format; fall back to SVG without one.
        file_format = save_path.suffix[1:].lower() or "svg"
        fig.savefig(save_path, format=file_format, bbox_inches="tight", dpi=300)
        print(f"   2D partial dependence plot saved to: {save_path}")
    finally:
        plt.close(fig)


# Reproducibility paths (same layout as bike_analysis): data, models and
# figures each live in a "<kind>/bike_sharing" folder next to this script.
_SCRIPT_DIR = Path(__file__).parent
DATA_DIR = _SCRIPT_DIR.joinpath("data", "bike_sharing")
MODELS_DIR = _SCRIPT_DIR.joinpath("models", "bike_sharing")
OUTPUTS_DIR = _SCRIPT_DIR.joinpath("figures", "bike_sharing")


def _find_feature_column(
    X_df: pd.DataFrame, feature_name_mapping, real_name: str, substrings
):
    """Return the ``X_df`` column that holds the feature called ``real_name``.

    Lookup order:

    1. Entries of ``feature_name_mapping`` (model name -> real name) whose real
       name equals ``real_name`` case-insensitively AND whose model name is an
       actual column of ``X_df``. If several entries match, the last one wins.
    2. Otherwise, the first column of ``X_df`` whose lowercased name contains
       any of ``substrings``.

    Returns None when neither step finds a column. ``feature_name_mapping``
    may be None or empty.
    """
    match = None
    for model_name, mapped_name in (feature_name_mapping or {}).items():
        # Only accept names that exist in X_df; any other name would make the
        # partial-dependence computation fail with a KeyError.
        if str(mapped_name).lower() == real_name and model_name in X_df.columns:
            match = model_name
    if match is not None:
        return match

    for col in X_df.columns:
        col_lower = str(col).lower()
        if any(sub in col_lower for sub in substrings):
            return col
    return None


def main():
    """Load the model and data, predict, report metrics, and write all figures.

    The model is read from ``MODELS_DIR / "ebm_model.pkl"`` and the data from
    ``DATA_DIR / "42712_Bike_Sharing_Demand.csv"``. Every figure is written
    to ``OUTPUTS_DIR``, which is created if it does not exist.
    """
    output_dir = OUTPUTS_DIR
    output_dir.mkdir(parents=True, exist_ok=True)

    model_path = MODELS_DIR / "ebm_model.pkl"
    csv_path = DATA_DIR / "42712_Bike_Sharing_Demand.csv"

    # Load model
    model = load_model(model_path)

    # Report the expected input schema when the model exposes it.
    if hasattr(model, "feature_names_in_"):
        print(f"\nModel expects {len(model.feature_names_in_)} features")
        print(f"Model feature names: {model.feature_names_in_}")

    # Load data (pass model to get correct feature names)
    X_df, X, y, feature_names, feature_name_mapping = load_data(csv_path, model=model)

    # Make predictions and report metrics
    predictions = make_predictions(model, X_df, X)
    calculate_metrics(y, predictions)

    # Plot predictions vs true values
    print("\n" + "=" * 70)
    print("Plotting predictions vs true values...")
    print("=" * 70)
    # Clamp predictions to be non-negative
    predictions_clamped = np.maximum(0, predictions)
    plot_predictions_vs_true(
        y,
        predictions,
        predictions_clamped=predictions_clamped,
        title_prefix="EBM",
        output_dir=output_dir,
    )

    # Visualize the ExplainableBoostingRegressor model
    print("\n" + "=" * 70)
    print("Generating EBM visualizations...")
    print("=" * 70)

    # visualize_global_explanation(model, feature_name_mapping, output_dir=output_dir)  # Disabled - generates HTML instead of SVG
    # visualize_local_explanation(
    #     model, X_df, predictions, n_examples=5, output_dir=output_dir
    # )  # Disabled - generates HTML instead of SVG
    print_feature_importance(model, feature_names, feature_name_mapping)
    plot_feature_responses(model, X_df, feature_name_mapping, output_dir=output_dir)
    plot_partial_dependence(
        model, X_df, feature_name_mapping, num_points=50, output_dir=output_dir
    )

    # Plot 2D partial dependence for hour × workingday interaction
    print("\n" + "=" * 70)
    print("Plotting 2D partial dependence: hour × workingday...")
    print("=" * 70)

    # The column names may be model names (via the mapping) or real names.
    hour_feature = _find_feature_column(X_df, feature_name_mapping, "hour", ("hour",))
    # "working" also matches "workingday".
    workingday_feature = _find_feature_column(
        X_df, feature_name_mapping, "workingday", ("working",)
    )

    if hour_feature is not None and workingday_feature is not None:
        plot_2d_partial_dependence(
            model,
            X_df,
            hour_feature,
            workingday_feature,
            feature_name_mapping=feature_name_mapping,
            feature2_values=[0, 1],  # Binary: workingday = 0 or 1
            num_points=50,
            output_dir=output_dir,
            figure_name="figure_10_2d_pd_hour_workingday.svg",
        )
    else:
        print("   ⚠ Could not find hour and/or workingday features")
        print(f"   Available features: {list(X_df.columns)}")
        print(f"   Feature mapping: {feature_name_mapping}")

    print("\n" + "=" * 70)
    print("Visualization complete!")
    print("=" * 70)


# Run the full pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()
