"""
Plot MPF, EBM, and XGBoost: 1D/2D PD, ICE for x1, scaled first-order PD (MPF).
"""

from __future__ import annotations

import sys
from pathlib import Path

import joblib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from interpret.glassbox import ExplainableBoostingRegressor
from mpf_py import MPF
import xgboost as xgb

# Use the regular text font for math text (the $...$ labels in the figures
# below). This only configures matplotlib's built-in mathtext renderer; it does
# not switch on external LaTeX rendering (that would be `text.usetex`).
plt.rcParams["mathtext.default"] = "regular"

# Local imports: put this script's own directory first on sys.path so that
# `data_generation` (expected to live next to this file) can be imported
# regardless of the current working directory. Must run before the import.
sys.path.insert(0, str(Path(__file__).parent))

from data_generation import make_dataset


class Config:
    """Constants shared by data generation and the plotting routines in this file."""

    seed: int = 42  # base RNG seed; main() uses seed + 1 for the validation set
    n: int = 10000  # training-set size; main() draws 5 * n validation rows
    noise_std: float = 0.25  # forwarded to make_dataset as noise_std

    # Plotting
    grid_points_1d: int = 200  # number of x1 grid points for the PD / ICE curves
    grid_points_2d: int = 120  # not read anywhere in this file
    n_ice: int = 150  # maximum number of ICE curves drawn per model
    n_background: int = 2000  # background sample size for the EBM / XGBoost PD
    grid_points_scaled_pd: int = 200  # grid size for the scaled first-order PD plots


# ============================================================================
# Model loading functions
# ============================================================================


def load_mpf_model(model_path: Path) -> MPF:
    """Read an MPF model from ``model_path`` and report progress on stdout.

    The path is handed to ``MPF.load`` as a plain string.
    """
    location = str(model_path)
    print(f"Loading MPF model from {model_path}...")
    mpf = MPF.load(location)
    print("MPF loaded.")
    return mpf


def load_ebm_model(model_path: Path) -> ExplainableBoostingRegressor:
    """Read a joblib-serialised EBM from ``model_path`` and report progress on stdout."""
    print(f"Loading EBM model from {model_path}...")
    # NOTE: joblib.load unpickles the file, which can execute arbitrary code.
    # Only point this at model files produced by a trusted training run.
    ebm = joblib.load(model_path)
    print("EBM loaded.")
    return ebm


def load_xgboost_model(model_path: Path) -> xgb.XGBRegressor:
    """Create an empty ``XGBRegressor`` and populate it from ``model_path``.

    Progress is reported on stdout; the path is passed to ``load_model`` as a
    plain string.
    """
    print(f"Loading XGBoost model from {model_path}...")
    regressor = xgb.XGBRegressor()
    regressor.load_model(str(model_path))
    print("XGBoost loaded.")
    return regressor


# ============================================================================
# Partial Dependence computation functions
# ============================================================================


def _mpf_pd_1d(
    model: MPF, X_background: np.ndarray, feat_idx: int, x_grid: np.ndarray
) -> np.ndarray:
    """
    First-order PD curve of an MPF model along one feature.

    A probe matrix is built in which column ``feat_idx`` sweeps ``x_grid`` while
    every other column is held at its background mean. The model-native routine
    is evaluated on that matrix and its entry for ``feat_idx`` is expected to be
    a ``(constants_per_epoch, pd_values)`` pair.

    ``pd_values`` is treated as interleaved per-stage columns
    ``[f+_0, f-_0, f+_1, f-_1, ...]`` (layout assumed by this file); the curve
    returned is the sum of all f+ and f- columns (f- is taken to be already
    signed), one value per grid point.
    """
    n_grid = x_grid.shape[0]
    column_means = X_background.mean(axis=0)

    # One row per grid point, all features at their mean except the swept one.
    probe = np.repeat(column_means[np.newaxis, :], n_grid, axis=0)
    probe[:, feat_idx] = x_grid

    per_feature = model.compute_first_order_partial_dependence_functions(
        probe, X_background
    )
    # The per-stage constants are not needed for the summed curve.
    _, stage_values = per_feature[feat_idx]

    positive_parts = stage_values[:, 0::2]
    negative_parts = stage_values[:, 1::2]
    return (positive_parts + negative_parts).sum(axis=1)


def _mpf_pd_2d(
    model: MPF, X_background: np.ndarray, x1_vals: np.ndarray, x2_vals: np.ndarray
) -> np.ndarray:
    """
    2D PD surface of an MPF model over features 0 and 1.

    The grid is the meshgrid of ``x1_vals`` and ``x2_vals``; only the two swept
    columns are passed to ``model.compute_partial_dependence_function`` together
    with the feature indices ``[0, 1]`` and the background data. The routine may
    return either the value matrix alone or a 2-tuple whose second item is the
    value matrix.

    The value matrix is treated as interleaved per-stage columns
    ``[f+_0, f-_0, ...]`` (layout assumed by this file) and summed over all
    columns. The result has the meshgrid's shape, i.e.
    ``(len(x2_vals), len(x1_vals))``.
    """
    mesh_x1, mesh_x2 = np.meshgrid(x1_vals, x2_vals)
    n_points = x1_vals.size * x2_vals.size

    # Start from the background means, then overwrite the two swept columns.
    probe = np.repeat(X_background.mean(axis=0)[np.newaxis, :], n_points, axis=0)
    probe[:, 0] = mesh_x1.ravel()
    probe[:, 1] = mesh_x2.ravel()

    outcome = model.compute_partial_dependence_function(
        [0, 1], probe[:, [0, 1]], X_background
    )
    is_pair = isinstance(outcome, tuple) and len(outcome) == 2
    stage_values = outcome[1] if is_pair else outcome

    positive_parts = stage_values[:, ::2]
    negative_parts = stage_values[:, 1::2]
    summed = (positive_parts + negative_parts).sum(axis=1)
    return summed.reshape(mesh_x2.shape)


def _standard_pd_1d(
    model_predict, X_background: np.ndarray, feat_idx: int, x_grid: np.ndarray
) -> np.ndarray:
    """
    Standard (brute-force) 1D partial dependence estimate.

    For every value in ``x_grid``, column ``feat_idx`` of a private copy of
    ``X_background`` is overwritten with that value and the mean of
    ``model_predict`` over all rows is recorded, i.e. the other features keep
    their empirical distribution (classic PD).

    Args:
        model_predict: callable mapping an ``(n, p)`` array to per-row predictions.
        X_background: background sample; never modified.
        feat_idx: column that is swept.
        x_grid: values assigned to that column.

    Returns:
        float64 array shaped like ``x_grid`` with the averaged predictions.
    """
    if X_background.dtype.kind in "biu":
        # Assigning a fractional grid value into an integer/bool array would
        # silently truncate it, so work on a float copy in that case.
        work = X_background.astype(np.float64)
    else:
        work = X_background.copy()

    curve = np.zeros_like(x_grid, dtype=np.float64)
    for position, value in enumerate(x_grid):
        work[:, feat_idx] = value
        curve[position] = model_predict(work).mean()
    return curve


def _standard_pd_2d(
    model_predict,
    X_background: np.ndarray,
    x1_vals: np.ndarray,
    x2_vals: np.ndarray,
    feat_idx_pair: tuple[int, int] = (0, 1),
) -> np.ndarray:
    """
    Standard (brute-force) 2D partial dependence estimate.

    For every cell of the meshgrid spanned by ``x1_vals`` and ``x2_vals``, the
    two selected columns of a private copy of ``X_background`` are overwritten
    with the cell's coordinates and the mean of ``model_predict`` over all rows
    is recorded.

    Args:
        model_predict: callable mapping an ``(n, p)`` array to per-row predictions.
        X_background: background sample; never modified.
        x1_vals: grid for the first swept feature (varies along axis 1).
        x2_vals: grid for the second swept feature (varies along axis 0).
        feat_idx_pair: columns receiving ``x1_vals`` / ``x2_vals``; defaults to
            ``(0, 1)``, the pair that was previously hard-coded.

    Returns:
        float64 array of shape ``(len(x2_vals), len(x1_vals))``.
    """
    first_col, second_col = feat_idx_pair
    mesh_x1, mesh_x2 = np.meshgrid(x1_vals, x2_vals)

    if X_background.dtype.kind in "biu":
        # Avoid silent truncation of fractional grid values on assignment.
        work = X_background.astype(np.float64)
    else:
        work = X_background.copy()

    surface = np.zeros_like(mesh_x1, dtype=np.float64)
    for cell in np.ndindex(*mesh_x1.shape):
        work[:, first_col] = mesh_x1[cell]
        work[:, second_col] = mesh_x2[cell]
        surface[cell] = model_predict(work).mean()
    return surface


def _ice_1d(
    model_predict,
    X_ref: np.ndarray,
    feat_idx: int,
    x_grid: np.ndarray,
    n_ice: int,
    seed: int,
) -> np.ndarray:
    """
    Individual conditional expectation (ICE) curves along one feature.

    ``min(n_ice, len(X_ref))`` rows are drawn from ``X_ref`` without replacement
    using a generator seeded with ``seed``. Each selected row is replicated once
    per grid point, column ``feat_idx`` is overwritten with ``x_grid``, and the
    whole batch is sent through ``model_predict`` in a single call.

    Args:
        model_predict: callable mapping an ``(m, p)`` array to ``m`` predictions.
        X_ref: reference rows to sample from; never modified.
        feat_idx: column that is swept.
        x_grid: 1D grid of values for that column.
        n_ice: maximum number of curves.
        seed: seed for the row selection.

    Returns:
        Array of shape ``(n_curves, len(x_grid))``; row ``k`` is the curve of the
        ``k``-th selected row.
    """
    rng = np.random.default_rng(seed)
    n_curves = min(n_ice, X_ref.shape[0])
    chosen = X_ref[rng.choice(X_ref.shape[0], size=n_curves, replace=False)]

    n_grid = x_grid.shape[0]
    # Row layout: [row0 x grid..., row1 x grid..., ...]; np.repeat already
    # returns a fresh (n_curves * n_grid, p) array, so no reshape is needed.
    batch = np.repeat(chosen, repeats=n_grid, axis=0)
    if batch.dtype.kind in "biu":
        # Integer/bool reference data would silently truncate fractional grid
        # values on assignment; switch to float first.
        batch = batch.astype(np.float64)
    batch[:, feat_idx] = np.tile(x_grid, reps=n_curves)

    predictions = np.asarray(model_predict(batch))
    return predictions.reshape(n_curves, n_grid)


# ============================================================================
# Plotting functions
# ============================================================================


def _plot_pd1_x1_flatness(
    out_dir: Path,
    x_grid: np.ndarray,
    pd1: np.ndarray,
    title: str,
    model_name: str,
) -> None:
    """
    Draw a single model's 1D PD curve for x1 and save it under ``out_dir``.

    The figure is a 4x4 inch plot of ``pd1`` against ``x_grid`` with a dashed
    zero reference line, written to ``pd_1_x1_<model_name>.pdf``. Not called
    from within this file.
    """
    figure, axis = plt.subplots(figsize=(4, 4))

    axis.plot(x_grid, pd1, linewidth=2)
    # Zero reference line to judge how flat the curve is.
    axis.axhline(0.0, color="black", linestyle="--", linewidth=0.8, alpha=0.6)
    axis.set_title(title)
    axis.set_xlabel("x1")
    axis.set_ylabel("PD_1(x1)")
    axis.grid(True, alpha=0.3)
    figure.tight_layout()

    target = out_dir / f"pd_1_x1_{model_name}.pdf"
    # Output format follows the file suffix, falling back to PDF.
    suffix = Path(target).suffix
    output_format = suffix[1:].lower() if suffix else "pdf"
    figure.savefig(target, format=output_format, bbox_inches="tight")
    plt.close(figure)


def _plot_combined_pd1_x1(
    out_dir: Path,
    x_grid: np.ndarray,
    pd1_mpf: np.ndarray,
    pd1_ebm: np.ndarray,
    pd1_xgb: np.ndarray,
) -> None:
    """
    Overlay the x1 PD curves of all three models in one square figure.

    The MPF curve is labelled "TSL" in the legend. All curves share ``x_grid``.
    Output goes to ``pd_1_x1_combined.pdf`` inside ``out_dir``. Not called from
    within this file.
    """
    curves = (
        (pd1_mpf, "tab:blue", "TSL"),
        (pd1_ebm, "tab:orange", "EBM"),
        (pd1_xgb, "tab:green", "XGBoost"),
    )

    figure, axis = plt.subplots(figsize=(4, 4))
    for values, colour, legend_label in curves:
        axis.plot(x_grid, values, linewidth=2, color=colour, label=legend_label)
    axis.axhline(0.0, color="black", linestyle="--", linewidth=0.8, alpha=0.6)
    axis.set_xlabel(r"$x_1$", fontsize=12)
    axis.set_ylabel(r"$\mathrm{PD}_1(x_1)$", fontsize=12)
    axis.grid(True, alpha=0.3)
    axis.legend(loc="best", fontsize=10)
    figure.tight_layout()

    target = out_dir / "pd_1_x1_combined.pdf"
    # Output format follows the file suffix, falling back to PDF.
    suffix = Path(target).suffix
    output_format = suffix[1:].lower() if suffix else "pdf"
    figure.savefig(target, format=output_format, bbox_inches="tight")
    plt.close(figure)


def _plot_pd12_surface(
    out_dir: Path,
    x1_vals: np.ndarray,
    x2_vals: np.ndarray,
    z: np.ndarray,
    title: str,
    model_name: str,
) -> None:
    """
    Render a 2D PD surface over (x1, x2) as a filled contour plot.

    ``z`` is drawn with 30 contour levels on the viridis colour map, with a
    colour bar, and saved as ``pd_12_x1_x2_<model_name>.pdf`` in ``out_dir``.
    Not called from within this file.
    """
    figure, axis = plt.subplots(figsize=(7, 5))

    filled = axis.contourf(x1_vals, x2_vals, z, levels=30, cmap="viridis")
    figure.colorbar(filled, ax=axis, shrink=0.85, label="PD_{12}(x1,x2)")
    axis.set_title(title)
    axis.set_xlabel("x1")
    axis.set_ylabel("x2")
    figure.tight_layout()

    target = out_dir / f"pd_12_x1_x2_{model_name}.pdf"
    # Output format follows the file suffix, falling back to PDF.
    suffix = Path(target).suffix
    output_format = suffix[1:].lower() if suffix else "pdf"
    figure.savefig(target, format=output_format, bbox_inches="tight")
    plt.close(figure)


def _plot_ice_x1(
    out_dir: Path,
    x_grid: np.ndarray,
    ice: np.ndarray,
    pd: np.ndarray,
    title: str,
    model_name: str,
) -> None:
    """
    Draw ICE curves for x1 with the PD curve on top and save the figure.

    Each row of ``ice`` becomes one faint blue line; ``pd`` is overlaid as a
    thick black line labelled "PDP". The figure is written to
    ``ice_x1_<model_name>.pdf`` in ``out_dir``.

    Note: the parameter ``pd`` shadows the module-level pandas alias inside
    this function; pandas is not used here.
    """
    figure, axis = plt.subplots(figsize=(7, 4))

    # One translucent line per sampled observation.
    for curve in ice:
        axis.plot(x_grid, curve, color="tab:blue", alpha=0.08, linewidth=1)

    # The averaged curve goes on top so it stays visible.
    axis.plot(x_grid, pd, color="black", linewidth=2.5, label="PDP")
    axis.axhline(0.0, color="black", linestyle="--", linewidth=0.8, alpha=0.5)
    axis.set_title(title)
    axis.set_xlabel(r"$x_1$", fontsize=12)
    axis.set_ylabel("prediction", fontsize=12)
    axis.grid(True, alpha=0.3)
    axis.legend(loc="best")
    figure.tight_layout()

    target = out_dir / f"ice_x1_{model_name}.pdf"
    # Output format follows the file suffix, falling back to PDF.
    suffix = Path(target).suffix
    output_format = suffix[1:].lower() if suffix else "pdf"
    figure.savefig(target, format=output_format, bbox_inches="tight")
    plt.close(figure)


def _plot_scaled_first_order_pd(
    out_dir: Path,
    model: MPF,
    X_background: np.ndarray,
    grid_points: int,
) -> None:
    """
    Plot the stage-1 scaled first-order PD of an MPF model, one figure per feature.

    For each of the features x1, x2 and x3 a grid of ``grid_points`` values
    between the background minimum and maximum is swept while the other
    features stay at their background mean. All three sweeps are stacked and
    evaluated in a single call to the model-native routine.

    Each 4x4 inch figure shows the first stage's f+ curve, the sign-flipped f-
    curve and the shaded area between them (green where f+ >= -f-, orange
    otherwise); the stage constants appear in the title. Files are written to
    ``figure_3_1_scaled_first_order_pd_<feature>.pdf`` in ``out_dir``.
    """
    feature_names = ["x1", "x2", "x3"]
    feature_indices = [0, 1, 2]
    epoch_idx = 0  # Only stage 1

    def _format_constant(value) -> str:
        # Fixed-point for moderate magnitudes, scientific notation otherwise.
        return f"{value:.4f}" if abs(value) < 1000 else f"{value:.2e}"

    # One sweep per feature over its observed range in the background data.
    feature_grids: list[np.ndarray] = [
        np.linspace(
            float(X_background[:, feat_idx].min()),
            float(X_background[:, feat_idx].max()),
            grid_points,
        )
        for feat_idx in feature_indices
    ]

    # Stack the per-feature probe matrices: others at the mean, one column swept.
    X_mean = X_background.mean(axis=0)
    probe_blocks: list[np.ndarray] = []
    for feat_idx, sweep in zip(feature_indices, feature_grids):
        probe = np.tile(X_mean, (grid_points, 1))
        probe[:, feat_idx] = sweep
        probe_blocks.append(probe)

    # Stagewise f+/f- values for every feature, evaluated on the stacked grid.
    first_order_pd = model.compute_first_order_partial_dependence_functions(
        np.vstack(probe_blocks), X_background
    )

    for plot_idx, (feat_idx, x_vals) in enumerate(
        zip(feature_indices, feature_grids)
    ):
        constants_per_epoch, pd_values = first_order_pd[feat_idx]

        # Rows belonging to this feature's sweep inside the stacked grid.
        own_rows = pd_values[plot_idx * grid_points : (plot_idx + 1) * grid_points, :]
        f_plus = own_rows[:, ::2][:, epoch_idx]
        f_minus_flipped = -own_rows[:, 1::2][:, epoch_idx]
        diff = f_plus - f_minus_flipped

        c_plus, c_minus = constants_per_epoch[epoch_idx]

        # Create square figure
        fig, ax = plt.subplots(figsize=(4, 4))

        for region, shade in ((diff >= 0, "green"), (diff < 0, "orange")):
            ax.fill_between(
                x_vals,
                f_minus_flipped,
                f_plus,
                where=region,
                color=shade,
                alpha=0.3,
            )

        j = feat_idx + 1  # feature index for notation PD_{+,j}, PD_{-,j}, PD_{\pm,j}
        line_specs = (
            (f_plus, "darkred", rf"$\mathrm{{PD}}_{{+,{j}}}$"),
            (f_minus_flipped, "darkblue", rf"$\mathrm{{PD}}_{{-,{j}}}$"),
        )
        for values, colour, legend_label in line_specs:
            ax.plot(
                x_vals,
                values,
                linewidth=1.5,
                color=colour,
                alpha=0.9,
                label=legend_label,
            )

        ax.axhline(y=0, color="black", linestyle="--", linewidth=0.5, alpha=0.5)
        # Mathtext axis labels with subscripted feature numbers.
        ax.set_xlabel(f"$x_{feat_idx + 1}$", fontsize=12)
        ax.set_ylabel(rf"$\mathrm{{PD}}_{{\pm,{j}}}$", fontsize=12)

        c_plus_str = _format_constant(c_plus)
        c_minus_str = _format_constant(c_minus)
        # The title carries only the two stage constants.
        ax.set_title(
            f"$C_+={c_plus_str}$, $C_-={c_minus_str}$",
            fontsize=12,
        )
        ax.grid(True, alpha=0.3)
        ax.legend(loc="best", fontsize=10)

        fig.tight_layout()
        save_path = (
            out_dir / f"figure_3_1_scaled_first_order_pd_{feature_names[feat_idx]}.pdf"
        )
        suffix = Path(save_path).suffix
        file_format = suffix[1:].lower() if suffix else "pdf"
        fig.savefig(save_path, format=file_format, bbox_inches="tight")
        plt.close(fig)


def plot_scaled_first_order_pd_mpf_both_stages(
    out_dir: Path,
    model: MPF,
    X_background: np.ndarray,
    grid_points: int,
    save_name: str = "figure_3_1_scaled_first_order_pd_mpf.pdf",
) -> None:
    """
    Plot the scaled first-order PD of an MPF model for every stage in one figure.

    The figure has one row per stage and one column per feature (x1, x2, x3).
    The number of stages is taken from the length of the per-stage constants
    that the model returns for feature 0, so a two-stage model yields a
    2 x 3 grid. Each panel shows the stage's f+ curve, the sign-flipped f-
    curve and the shaded area between them (green where f+ >= -f-, orange
    otherwise); the stage constants appear in the panel title.

    Args:
        out_dir: directory the figure is written to.
        model: fitted MPF model.
        X_background: background data; provides feature ranges and means.
        grid_points: number of grid values per feature.
        save_name: output file name; its suffix selects the file format
            (PDF when there is no suffix).
    """
    feature_indices = [0, 1, 2]
    n_features = len(feature_indices)

    # Per-feature sweeps over the observed range, others held at the mean.
    X_mean = X_background.mean(axis=0)
    feature_grids: list[np.ndarray] = []
    X_grid_list: list[np.ndarray] = []
    for feat_idx in feature_indices:
        column = X_background[:, feat_idx]
        x_vals = np.linspace(float(column.min()), float(column.max()), grid_points)
        feature_grids.append(x_vals)
        X_grid_feat = np.tile(X_mean, (grid_points, 1))
        X_grid_feat[:, feat_idx] = x_vals
        X_grid_list.append(X_grid_feat)

    first_order_pd = model.compute_first_order_partial_dependence_functions(
        np.vstack(X_grid_list), X_background
    )

    # Infer number of stages from first feature
    n_stages = len(first_order_pd[0][0])

    # squeeze=False always yields a 2D axes array, including the 1-row and
    # 1 x 1 cases that manual np.newaxis reshaping does not cover.
    fig, axes = plt.subplots(
        n_stages,
        n_features,
        figsize=(4 * n_features, 4 * n_stages),
        squeeze=False,
    )

    for plot_idx, feat_idx in enumerate(feature_indices):
        # Everything that depends only on the feature is computed once here.
        constants_per_epoch, pd_values = first_order_pd[feat_idx]
        pd_values_feat = pd_values[
            plot_idx * grid_points : (plot_idx + 1) * grid_points, :
        ]
        f_plus_all = pd_values_feat[:, ::2]
        f_minus_all = pd_values_feat[:, 1::2]
        x_vals = feature_grids[plot_idx]
        j = feat_idx + 1  # 1-based feature number used in the labels

        for stage_idx in range(n_stages):
            ax = axes[stage_idx, plot_idx]

            c_plus, c_minus = constants_per_epoch[stage_idx]
            f_plus = f_plus_all[:, stage_idx]
            f_minus_flipped = -f_minus_all[:, stage_idx]
            diff = f_plus - f_minus_flipped

            ax.fill_between(
                x_vals,
                f_minus_flipped,
                f_plus,
                where=(diff >= 0),
                color="green",
                alpha=0.3,
            )
            ax.fill_between(
                x_vals,
                f_minus_flipped,
                f_plus,
                where=(diff < 0),
                color="orange",
                alpha=0.3,
            )
            ax.plot(
                x_vals,
                f_plus,
                linewidth=1.5,
                color="darkred",
                alpha=0.9,
                label=rf"$\mathrm{{PD}}_{{+,{j}}}$",
            )
            ax.plot(
                x_vals,
                f_minus_flipped,
                linewidth=1.5,
                color="darkblue",
                alpha=0.9,
                label=rf"$\mathrm{{PD}}_{{-,{j}}}$",
            )
            ax.axhline(y=0, color="black", linestyle="--", linewidth=0.5, alpha=0.5)
            ax.set_xlabel(f"$x_{feat_idx + 1}$", fontsize=12)
            ax.set_ylabel(rf"$\mathrm{{PD}}_{{\pm,{j}}}$", fontsize=12)
            c_plus_str = f"{c_plus:.4f}" if abs(c_plus) < 1000 else f"{c_plus:.2e}"
            c_minus_str = f"{c_minus:.4f}" if abs(c_minus) < 1000 else f"{c_minus:.2e}"
            ax.set_title(
                f"Stage {stage_idx + 1}: $C_+={c_plus_str}$, $C_-={c_minus_str}$",
                fontsize=11,
            )
            ax.grid(True, alpha=0.3)
            ax.legend(loc="best", fontsize=9)

    fig.tight_layout()
    save_path = out_dir / save_name
    # save_name is caller-supplied, so derive the format from its suffix.
    suffix = Path(save_path).suffix
    file_format = suffix[1:].lower() if suffix else "pdf"
    fig.savefig(save_path, format=file_format, bbox_inches="tight")
    plt.close(fig)


# ============================================================================
# Main plotting functions for each model type
# ============================================================================


def plot_mpf_model(
    model: MPF,
    X_train: np.ndarray,
    X_val: np.ndarray,
    out_dir: Path,
    cfg: Config,
) -> tuple[np.ndarray, np.ndarray]:
    """
    Produce every MPF figure and return the x1 PD curve.

    Steps: compute the model-native first-order PD for x1 on the full training
    set, draw ICE curves from validation rows with that PD overlaid, then plot
    the stage-1 scaled first-order PD for each feature.

    Returns:
        ``(x1_grid, pd1)``: the grid spanning the training range of x1 and the
        PD values on it, intended for a combined multi-model plot.
    """
    banner = "=" * 80
    print("\n" + banner)
    print("Plotting MPF Model")
    print(banner)

    # Grid over the observed training range of x1.
    x1_column = X_train[:, 0]
    x1_grid = np.linspace(x1_column.min(), x1_column.max(), cfg.grid_points_1d)
    pd1 = _mpf_pd_1d(model, X_train, feat_idx=0, x_grid=x1_grid)

    # ICE curves come from validation rows; the PD curve is drawn on top.
    ice_curves = _ice_1d(
        model.predict,
        X_val,
        feat_idx=0,
        x_grid=x1_grid,
        n_ice=cfg.n_ice,
        seed=cfg.seed,
    )
    _plot_ice_x1(
        out_dir,
        x1_grid,
        ice_curves,
        pd1,
        title="MPF: ICE for x1 (heterogeneous) + PDP (flat)",
        model_name="mpf",
    )

    # MPF-specific: stage-1 scaled first-order PD, one figure per feature.
    _plot_scaled_first_order_pd(
        out_dir=out_dir,
        model=model,
        X_background=X_train,
        grid_points=cfg.grid_points_scaled_pd,
    )

    print(f"MPF plots saved to: {out_dir}")
    return x1_grid, pd1


def plot_ebm_model(
    model: ExplainableBoostingRegressor,
    X_train: np.ndarray,
    X_val: np.ndarray,
    out_dir: Path,
    cfg: Config,
) -> tuple[np.ndarray, np.ndarray]:
    """
    Produce the EBM figures and return the x1 PD curve.

    A background sample of at most ``cfg.n_background`` training rows (drawn
    without replacement, seeded with ``cfg.seed``) is used for the standard PD
    estimate of x1. ICE curves are drawn from validation rows with that PD
    overlaid.

    The model is always queried through a DataFrame with columns x1, x2, x3
    (presumably the names it was trained with; confirm against the training
    script).

    Returns:
        ``(x1_grid, pd1)``: the grid spanning the training range of x1 and the
        PD values on it, intended for a combined multi-model plot.
    """
    print("\n" + "=" * 80)
    print("Plotting EBM Model")
    print("=" * 80)

    feature_names = ["x1", "x2", "x3"]

    def ebm_predict(X: np.ndarray) -> np.ndarray:
        """Predict from a plain array by wrapping it in a named DataFrame."""
        return model.predict(pd.DataFrame(X, columns=feature_names))

    # Background set for PD. The helpers operate on ndarrays, so the sample is
    # taken straight from X_train instead of round-tripping through a DataFrame.
    rng = np.random.default_rng(cfg.seed)
    bg_idx = rng.choice(
        X_train.shape[0], size=min(cfg.n_background, X_train.shape[0]), replace=False
    )
    X_bg = X_train[bg_idx]

    # 1D PD for x1 (for combined plot later)
    x1_grid = np.linspace(X_train[:, 0].min(), X_train[:, 0].max(), cfg.grid_points_1d)
    pd1 = _standard_pd_1d(ebm_predict, X_bg, feat_idx=0, x_grid=x1_grid)

    # ICE for x1 (overlay PDP)
    ice = _ice_1d(
        ebm_predict,
        X_val,
        feat_idx=0,
        x_grid=x1_grid,
        n_ice=cfg.n_ice,
        seed=cfg.seed,
    )
    _plot_ice_x1(
        out_dir,
        x1_grid,
        ice,
        pd1,
        title="EBM: ICE for x1 (heterogeneous) + PDP (flat)",
        model_name="ebm",
    )

    print(f"EBM plots saved to: {out_dir}")

    # Return PD values for combined plot
    return x1_grid, pd1


def plot_xgboost_model(
    model: xgb.XGBRegressor,
    X_train: np.ndarray,
    X_val: np.ndarray,
    out_dir: Path,
    cfg: Config,
) -> tuple[np.ndarray, np.ndarray]:
    """
    Produce the XGBoost figures and return the x1 PD curve.

    A background sample of at most ``cfg.n_background`` training rows (drawn
    without replacement, seeded with ``cfg.seed``) feeds the standard PD
    estimate of x1; ICE curves are drawn from validation rows with that PD
    overlaid.

    Returns:
        ``(x1_grid, pd1)``: the grid spanning the training range of x1 and the
        PD values on it, intended for a combined multi-model plot.
    """
    banner = "=" * 80
    print("\n" + banner)
    print("Plotting XGBoost Model")
    print(banner)

    # Seeded background sample for the PD average; fancy indexing already
    # yields an independent array.
    n_rows = X_train.shape[0]
    sampler = np.random.default_rng(cfg.seed)
    chosen_rows = sampler.choice(
        n_rows, size=min(cfg.n_background, n_rows), replace=False
    )
    background = X_train[chosen_rows]

    # Grid over the observed training range of x1.
    x1_column = X_train[:, 0]
    x1_grid = np.linspace(x1_column.min(), x1_column.max(), cfg.grid_points_1d)
    pd1 = _standard_pd_1d(model.predict, background, feat_idx=0, x_grid=x1_grid)

    # ICE curves come from validation rows; the PD curve is drawn on top.
    ice_curves = _ice_1d(
        model.predict,
        X_val,
        feat_idx=0,
        x_grid=x1_grid,
        n_ice=cfg.n_ice,
        seed=cfg.seed,
    )
    _plot_ice_x1(
        out_dir,
        x1_grid,
        ice_curves,
        pd1,
        title="XGBoost: ICE for x1 (heterogeneous) + PDP (flat)",
        model_name="xgboost",
    )

    print(f"XGBoost plots saved to: {out_dir}")
    return x1_grid, pd1


def main() -> None:
    """
    Regenerate the synthetic data, load the three saved models and plot them.

    Models are read from ``models/synthetic_pd/<model>/`` and figures are
    written to ``figures/synthetic_pd/``, both relative to this script's
    directory. The output directory is created if it does not exist.
    """
    cfg = Config()
    script_dir = Path(__file__).parent
    models_dir = script_dir / "models" / "synthetic_pd"
    outputs_dir = script_dir / "figures" / "synthetic_pd"
    outputs_dir.mkdir(parents=True, exist_ok=True)

    # Where each saved model is expected.
    mpf_model_path = models_dir / "mpf" / "model.bin"
    ebm_model_path = models_dir / "ebm" / "model.pkl"
    xgboost_model_path = models_dir / "xgboost" / "model.json"

    # Recreate the data (intended to match training); only features are needed.
    print("Generating data...")
    shared_kwargs = {
        "noise_std": cfg.noise_std,
        "std_x1": 1.0,
        "std_x2": 1.5,
        "std_x3": 0.8,
    }
    X_train, _ = make_dataset(n=cfg.n, seed=cfg.seed, **shared_kwargs)
    # The validation set is five times larger and uses the next seed.
    X_val, _ = make_dataset(n=5 * cfg.n, seed=cfg.seed + 1, **shared_kwargs)

    # Load everything up front so a missing model fails before any plotting.
    mpf_model = load_mpf_model(mpf_model_path)
    ebm_model = load_ebm_model(ebm_model_path)
    xgboost_model = load_xgboost_model(xgboost_model_path)

    # All figures go into the single figures/synthetic_pd directory.
    # NOTE(review): each plot_* call returns (x1_grid, pd1) "for combined plot",
    # but the results are discarded here and _plot_combined_pd1_x1 is never
    # called. Confirm whether the combined figure is intentionally skipped.
    plot_mpf_model(mpf_model, X_train, X_val, outputs_dir, cfg)
    plot_ebm_model(ebm_model, X_train, X_val, outputs_dir, cfg)
    plot_xgboost_model(xgboost_model, X_train, X_val, outputs_dir, cfg)

    rule = "=" * 80
    print("\n" + rule)
    print("All plots generated.")
    print(rule)


if __name__ == "__main__":
    main()
