#!/usr/bin/env python3
"""
Reproduce All Paper Figures
===========================

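Single entry point to regenerate the paper figures. All scripts and paths live
under this folder, giving a flat, self-contained reproducibility setup.
Note: Figure 3.5 (PD comparison) is not produced by a default run because it
takes a lot of RAM.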

Usage:
    cd reproducibility
    # Copy data and models from simulations/ (see README.md), or use existing:
    python run_all_figures.py

Output:
    figures/
      bike_sharing/
      california/
      synthetic_pd/
"""

import sys
from pathlib import Path

# All scripts live in this directory; it is also the base for the data/ and
# models/ paths built inside the generate_* functions below.
REPRO_DIR = Path(__file__).resolve().parent
# Root output folder; each generate_* function writes into its own subfolder.
FIGURES_DIR = REPRO_DIR / "figures"
# Put this directory first on the import path so the analysis modules that the
# generate_* functions import lazily are looked up here before anywhere else.
sys.path.insert(0, str(REPRO_DIR))


def generate_bike_sharing_figures():
    """Generate the bike sharing figures (3.2, 3.3 and 10).

    Figures 3.2 and 3.3 are drawn from the MPF model returned by
    ``bike_analysis.load_model_and_data``. Figure 10 is the EBM 2D partial
    dependence of the hour feature against the working-day feature. All
    output goes to ``FIGURES_DIR / "bike_sharing"``, which is created if
    it does not exist.

    Figure 10 needs one column whose name contains "hour" and one whose name
    contains "workingday" (case-insensitive). If either cannot be found the
    figure is skipped and a warning is printed, so the omission is visible in
    the run log instead of passing silently.
    """
    print("\n" + "=" * 60)
    print("BIKE SHARING")
    print("=" * 60)

    out_dir = FIGURES_DIR / "bike_sharing"
    out_dir.mkdir(parents=True, exist_ok=True)

    # Imported lazily so that merely importing this script does not pull in
    # the analysis modules and their dependencies.
    from bike_analysis import (
        load_model_and_data as load_mpf,
        plot_figure_3_2_scaled_first_order_pd,
        plot_figure_3_3_2d_pd_hour_workingday,
    )
    from bike_ebm_analysis import (
        load_data as load_ebm_data,
        load_model as load_ebm,
        plot_2d_partial_dependence,
    )

    print("\nLoading MPF model...")
    model, X, y, X_df = load_mpf()

    print("\nGenerating Figure 3.2 (Scaled first-order PD)...")
    plot_figure_3_2_scaled_first_order_pd(
        model, X, X_df, save_path=out_dir / "figure_3_2_scaled_first_order_pd.pdf"
    )

    print("\nGenerating Figure 3.3 (MPF 2D PD)...")
    plot_figure_3_3_2d_pd_hour_workingday(
        model, X, X_df, save_path=out_dir / "figure_3_3_2d_pd_hour_workingday.pdf"
    )

    print("\nGenerating Figure 10 (EBM 2D PD)...")
    bike_data = REPRO_DIR / "data" / "bike_sharing"
    bike_models = REPRO_DIR / "models" / "bike_sharing"
    ebm_model = load_ebm(bike_models / "ebm_model.pkl")
    X_df_ebm, _, _, _, feature_mapping = load_ebm_data(
        bike_data / "42712_Bike_Sharing_Demand.csv", model=ebm_model
    )

    # Locate the two features by substring match on the column name. A name
    # containing "hour" is never considered for the working-day slot, and if
    # several columns match, the last one in column order wins.
    hour_feat = workingday_feat = None
    for col in X_df_ebm.columns:
        col_name = str(col).lower()
        if "hour" in col_name:
            hour_feat = col
        elif "workingday" in col_name:
            workingday_feat = col

    if hour_feat is not None and workingday_feat is not None:
        plot_2d_partial_dependence(
            ebm_model,
            X_df_ebm,
            hour_feat,
            workingday_feat,
            feature_name_mapping=feature_mapping,
            feature2_values=[0, 1],
            num_points=50,
            output_dir=out_dir,
            figure_name="figure_10_2d_pd_hour_workingday.pdf",
        )
    else:
        # Best-effort: keep going, but make the missing figure visible.
        missing = [
            label
            for label, feat in (("hour", hour_feat), ("workingday", workingday_feat))
            if feat is None
        ]
        print(
            "WARNING: skipping Figure 10; no column matching "
            + " / ".join(missing)
            + f" among {list(X_df_ebm.columns)}"
        )

    print("Bike sharing figures done.")


def generate_california_figures(*, include_pd_comparison=False):
    """Generate the California housing figures.

    Produces Figures 4, 3.4 and 6 from the interpretable MPF model, then
    Figures 5, 3.4 and 6 from the blackbox MPF model. Output is written to
    ``FIGURES_DIR / "california"``, which is created if it does not exist.

    ``cali_analysis`` is configured through its module-level globals
    (``USE_BLACKBOX``, ``MODEL_FILENAME``, ``FIGURES_DIR``). This function
    overwrites them and does not restore them, so the module is left in the
    blackbox configuration when the function returns.

    Args:
        include_pd_comparison: When true, also render Figure 3.5 (PD
            comparison). Off by default because that plot takes a lot of
            RAM; in that case a "Skipping" line is printed so the log does
            not claim a figure that was never written.
    """
    print("\n" + "=" * 60)
    print("CALIFORNIA HOUSING")
    print("=" * 60)

    out_dir = FIGURES_DIR / "california"
    out_dir.mkdir(parents=True, exist_ok=True)

    import cali_analysis

    print("\nGenerating Figure 4 (Spatial Backbone, interpretable)...")
    # Select the interpretable model before loading; load_model_and_data()
    # presumably reads these globals -- confirm in cali_analysis.
    cali_analysis.USE_BLACKBOX = False
    cali_analysis.MODEL_FILENAME = "mpf_interpretable.bin"
    cali_analysis.FIGURES_DIR = out_dir

    model_interp, X, y, X_df = cali_analysis.load_model_and_data()
    cali_analysis.plot_figure_4_spatial_backbone_evolution(
        model_interp,
        X,
        X_df,
        # One entry per element of the model's tree_grid_families.
        epochs=range(len(model_interp.tree_grid_families)),
        save_path=out_dir / "figure_4_spatial_backbone_evolution.pdf",
        save_each_plot_dir=out_dir,
    )

    print("\nGenerating Figure 3.4 (Scaled First-Order PD, interpretable)...")
    cali_analysis.plot_figure_3_4_scaled_first_order_pd(
        model_interp,
        X,
        X_df,
        save_path=out_dir / "figure_3_4_scaled_first_order_pd_interpretable.pdf",
    )

    if include_pd_comparison:
        print("\nGenerating Figure 3.5 (PD Comparison)...")
        cali_analysis.plot_figure_3_5_pd_comparison(
            model_interp,
            X,
            X_df,
            save_path=out_dir / "figure_3_5_pd_comparison.pdf",
        )
    else:
        # Opt-in only: this plot takes a lot of RAM.
        print("\nSkipping Figure 3.5 (PD Comparison): it takes a lot of RAM.")

    print("\nGenerating Figure 6 (Feature Importance, interpretable)...")
    cali_analysis.compute_and_plot_feature_importance(
        model_interp,
        X,
        list(X_df.columns),
        gamma=1.0,
        save_path=out_dir / "figure_6_feature_importance_interpretable.pdf",
    )

    print("\nGenerating Figure 5 (Local Explanations, blackbox)...")
    # Switch to the blackbox model; X, y and X_df are rebound to whatever the
    # second load returns.
    cali_analysis.USE_BLACKBOX = True
    cali_analysis.MODEL_FILENAME = "mpf_blackbox.bin"
    model_blackbox, X, y, X_df = cali_analysis.load_model_and_data()

    import pandas as pd

    # The raw CSV is read without a header row, so positions are counted from
    # the very first line of the file.
    # NOTE(review): if the file does start with a header line, these indices
    # are shifted by one and the values come back as strings -- confirm.
    df = pd.read_csv(
        cali_analysis.DATA_DIR / "44977_california_housing.csv", header=None
    )
    # Two hand-picked rows; going by the variable names they are a desert and
    # a Los Angeles location.
    desert_row_idx = 2784
    la_row_idx = 4556
    # Drop the last column (presumably the target) to keep only the features.
    point_a = df.iloc[desert_row_idx, :-1].values
    point_b = df.iloc[la_row_idx, :-1].values

    cali_analysis.plot_figure_5_local_explanations(
        model_blackbox,
        X,
        X_df,
        point_a=point_a,
        point_b=point_b,
        save_path=out_dir / "figure_5_local_explanations_blackbox.pdf",
    )

    print("\nGenerating Figure 3.4 (Scaled First-Order PD, blackbox)...")
    cali_analysis.plot_figure_3_4_scaled_first_order_pd(
        model_blackbox,
        X,
        X_df,
        save_path=out_dir / "figure_3_4_scaled_first_order_pd_blackbox.pdf",
    )

    print("\nGenerating Figure 6 (Feature Importance, blackbox)...")
    cali_analysis.compute_and_plot_feature_importance(
        model_blackbox,
        X,
        list(X_df.columns),
        gamma=1.0,
        save_path=out_dir / "figure_6_feature_importance_blackbox.pdf",
    )

    print("California figures done.")


def generate_synthetic_figures():
    """Generate the synthetic PD cancellation figures.

    Builds a training set and a five-times-larger validation set with
    ``make_dataset``, loads the stored MPF, EBM and XGBoost models from
    ``REPRO_DIR / "models" / "synthetic_pd"``, draws each model's plots, and
    finishes with the combined first-order PD plot. Output is written to
    ``FIGURES_DIR / "synthetic_pd"``, which is created if it does not exist.
    """
    rule = "=" * 60
    print("\n" + rule)
    print("SYNTHETIC PD CANCELLATION")
    print(rule)

    out_dir = FIGURES_DIR / "synthetic_pd"
    out_dir.mkdir(parents=True, exist_ok=True)

    from data_generation import make_dataset
    from synthetic_analysis import (
        Config,
        _plot_combined_pd1_x1,
        load_ebm_model,
        load_mpf_model,
        load_xgboost_model,
        plot_ebm_model,
        plot_mpf_model,
        plot_scaled_first_order_pd_mpf_both_stages,
        plot_xgboost_model,
    )

    cfg = Config()
    models_dir = REPRO_DIR / "models" / "synthetic_pd"

    print("\nGenerating data...")
    # Settings shared by both splits; only the size and the seed differ.
    sampling = dict(
        noise_std=cfg.noise_std,
        std_x1=1.0,
        std_x2=1.5,
        std_x3=0.8,
    )
    # The second return value of make_dataset is not needed here.
    X_train, _ = make_dataset(n=cfg.n, seed=cfg.seed, **sampling)
    X_val, _ = make_dataset(n=5 * cfg.n, seed=cfg.seed + 1, **sampling)

    print("\nGenerating MPF figures...")
    mpf_model = load_mpf_model(models_dir / "mpf" / "model.bin")
    x1_grid_mpf, pd1_mpf = plot_mpf_model(mpf_model, X_train, X_val, out_dir, cfg)

    print("\nGenerating Figure 3.1 (Scaled first-order PD MPF, both stages)...")
    plot_scaled_first_order_pd_mpf_both_stages(
        out_dir=out_dir,
        model=mpf_model,
        X_background=X_train,
        grid_points=cfg.grid_points_scaled_pd,
        save_name="figure_3_1_scaled_first_order_pd_mpf.pdf",
    )

    print("\nGenerating EBM figures...")
    ebm_model = load_ebm_model(models_dir / "ebm" / "model.pkl")
    # Only the PD values are kept; the grid returned alongside is unused.
    _, pd1_ebm = plot_ebm_model(ebm_model, X_train, X_val, out_dir, cfg)

    print("\nGenerating XGBoost figures...")
    xgb_model = load_xgboost_model(models_dir / "xgboost" / "model.json")
    _, pd1_xgb = plot_xgboost_model(xgb_model, X_train, X_val, out_dir, cfg)

    print("\nGenerating combined PD plot...")
    # All three curves are drawn against the MPF grid.
    # NOTE(review): this assumes the EBM and XGBoost grids match it -- confirm.
    _plot_combined_pd1_x1(out_dir, x1_grid_mpf, pd1_mpf, pd1_ebm, pd1_xgb)

    print("Synthetic figures done.")


def main():
    """Regenerate every figure group, then print a per-folder PDF count.

    The three generators run in a fixed order (bike sharing, California
    housing, synthetic PD). Exceptions are not caught, so a failure in one
    group stops the run before the later groups and the summary.
    """
    rule = "=" * 60
    print(rule)
    print("REPRODUCING PAPER FIGURES")
    print(rule)

    FIGURES_DIR.mkdir(parents=True, exist_ok=True)

    generators = (
        generate_bike_sharing_figures,
        generate_california_figures,
        generate_synthetic_figures,
    )
    for generate in generators:
        generate()

    print("\n" + rule)
    print("SUMMARY")
    print(rule)
    # Only PDFs directly inside each subfolder of figures/ are counted.
    for entry in sorted(FIGURES_DIR.iterdir()):
        if not entry.is_dir():
            continue
        pdf_total = sum(1 for _ in entry.glob("*.pdf"))
        print(f"  {entry.name}: {pdf_total} figures")
    print(f"\nOutput: {FIGURES_DIR}")
    print(rule)


# Run the full pipeline only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
