{"cells":[{"cell_type":"markdown","metadata":{"id":"-_NPQIJAyFqg"},"source":["# Semi-Synthetic Job Corps Experiment (Plug-in Estimator)\n","\n","This notebook runs the semi-synthetic experiment on the Job Corps dataset using the **Plug-in Estimator** (Nyström + LOOCV).\n","\n","**Workflow:**\n","1. **Data Loading:** Loads empirical covariates ($X$) and treatment ($T$), precomputed nuisance components ($\\hat{\\mu}, g$), and the ground-truth curve $h^*(t)$ on the original $t$-grid.\n","2. **Simulation:** Runs $K$ Monte Carlo repetitions. In each run:\n","   - Generates a semi-synthetic outcome $Y_{syn} = \\hat{\\mu} + \\epsilon \\cdot g$.\n","   - Fits the Plug-in estimator on the original $t$-grid.\n","   - Optimizes the ridge parameter $\\beta$ via LOOCV.\n","3. **Evaluation:** Aggregates MISE (Mean Integrated Squared Error) and saves the per-seed estimated curves, their pointwise mean, and their standard error to CSV (for plotting against the ground truth $h^*(t)$)."]},{"cell_type":"code","execution_count":null,"metadata":{"colab":{"background_save":true},"id":"bQ3clTY4x0C_"},"outputs":[{"name":"stdout","output_type":"stream","text":["Mounted at /content/drive\n","Working Directory: /content/drive/MyDrive/Colab Notebooks/CTE_Baseline\n"]}],"source":["import sys\n","import pathlib\n","import numpy as np\n","import pandas as pd\n","import matplotlib.pyplot as plt\n","import logging\n","\n","# --- Environment Setup ---\n","try:\n","    from google.colab import drive\n","    drive.mount('/content/drive')\n","    BASE_DIR = pathlib.Path(\"/content/drive/MyDrive/Colab Notebooks/CTE_Baseline\")\n","except ImportError:\n","    BASE_DIR = pathlib.Path(\".\").resolve()\n","\n","# Make the project packages importable. Guarded so that re-running this\n","# cell does not keep appending duplicate entries to sys.path.\n","if str(BASE_DIR) not in sys.path:\n","    sys.path.append(str(BASE_DIR))\n","\n","# Project imports\n","from KRR_methods.data_jobcorps import make_Xss, gen_semi_y, load_jobcorps_data\n","from KRR_methods.algorithms.estimators_plugin import run_plugin_loocv_on_original_grid\n","\n","print(f\"Working Directory: {BASE_DIR}\")"]},{"cell_type":"code","execution_count":null,"metadata":{"colab":{"background_save":true},"id":"ApeeSWUfyJkg"},"outputs":[],"source":["# 
======================================================================\n","# Configuration \u0026 Hyperparameters\n","# ======================================================================\n","\n","# --- Paths ---\n","EMP_DIR = BASE_DIR / \"DML_methods\" / \"Data_and_Results\"\n","DATA_FILE = EMP_DIR / \"emp_app.csv\"\n","SEMI_SYN_FILE = EMP_DIR / \"semi-syn data grf.csv\"\n","H_STAR_FILE = EMP_DIR / \"h_star_grf_empapp.csv\"\n","RESULTS_DIR = BASE_DIR / \"KRR_methods\" / \"Results\"\n","\n","# --- Algorithm Hyperparameters ---\n","# Tensor-product kernel parameters for f(x,t)\n","KERNEL_TYPE_F = \"matern\"\n","KRR_X_LENGTH_SCALE = 13       # ell_x\n","KRR_X_NU = 0.5                # nu_x\n","KRR_T_LENGTH_SCALE = 6000     # ell_t\n","KRR_T_NU = 0.5                # nu_t\n","NYSTROM_M_F = 700             # Nystrom landmarks\n","\n","# Ridge Parameter (beta) Search Range for LOOCV\n","BETA_MIN = 0.05\n","BETA_MAX = 80.0\n","\n","# --- Simulation Settings ---\n","FIRST_SEED = 1                # Starting seed\n","K_RUNS = 100                  # Number of Monte Carlo repetitions"]},{"cell_type":"code","execution_count":null,"metadata":{"colab":{"background_save":true},"id":"032oPbH8yK_W"},"outputs":[{"name":"stdout","output_type":"stream","text":["Loading empirical data from: emp_app.csv...\n","Loading semi-synthetic components from: semi-syn data grf.csv...\n","Loading ground truth h*(t) from: h_star_grf_empapp.csv...\n","Data loading complete.\n"]}],"source":["# ======================================================================\n","# Data Loading \u0026 Preprocessing\n","# ======================================================================\n","\n","# Execute Loading via imported function\n","X, T, MU_HAT, G_FUNC, T_GRID, H_STAR_VALS = load_jobcorps_data(\n","    data_file=DATA_FILE,\n","    semi_syn_file=SEMI_SYN_FILE,\n","    h_star_file=H_STAR_FILE\n",")\n","\n","# Sanity checks: covariates and treatment must be row-aligned (they are\n","# permuted together in every run), and the ground truth must be given\n","# on the same grid the estimator is evaluated on.\n","assert len(X) == len(T), \"X and T must have the same number of rows\"\n","assert len(T_GRID) == len(H_STAR_VALS), \"T_GRID and H_STAR_VALS must have the same length\"\n","\n","print(\"Data loading 
complete.\")"]},{"cell_type":"code","execution_count":null,"metadata":{"colab":{"background_save":true},"id":"Bef0cEFEyMZK"},"outputs":[],"source":["# ======================================================================\n","# Simulation Functions\n","# ======================================================================\n","\n","def run_simulation(K=20):\n","    \"\"\"\n","    Run the Plug-in estimator simulation K times and aggregate the results.\n","\n","    Run k (0-based) uses seed FIRST_SEED + k: it draws a semi-synthetic\n","    outcome with gen_semi_y, shuffles (X, T, Y) with the same generator,\n","    standardizes the shuffled covariates with make_Xss and fits the\n","    Plug-in estimator (Nystrom + LOOCV over [BETA_MIN, BETA_MAX]) on T_GRID.\n","\n","    Args:\n","        K: Number of Monte Carlo repetitions; must be >= 1.\n","\n","    Returns:\n","        dict with keys:\n","          - \"seeds\": list of the K seeds used (FIRST_SEED .. FIRST_SEED + K - 1),\n","          - \"mise_all\": array of the K per-run MISE values,\n","          - \"curves_mat\": np.vstack of the per-run estimated curves (one row per run),\n","          - \"mean_curve\": pointwise mean of the curves over runs,\n","          - \"se_curve\": pointwise standard error (ddof=1 std / sqrt(K));\n","            NaN when K == 1.\n","\n","    Raises:\n","        ValueError: If K < 1 (there would be nothing to aggregate).\n","    \"\"\"\n","    if K < 1:\n","        raise ValueError(f\"K must be >= 1, got {K}\")\n","\n","    mise_list = []\n","    curves_list = []\n","    seeds = []\n","\n","    # Ensure output directory exists\n","    RESULTS_DIR.mkdir(parents=True, exist_ok=True)\n","\n","    # The sample size is the same in every run; compute it once.\n","    n = len(T)\n","\n","    for k in range(K):\n","        seed = FIRST_SEED + k\n","        seeds.append(seed)\n","\n","        # [Critical] Set Global Numpy Seed for legacy reproducibility\n","        # This matches the behavior of the original script inside the loop\n","        np.random.seed(seed)\n","\n","        print(f\"--- Run {k + 1}/{K} (seed={seed}) ---\")\n","\n","        # 1. Generate semi-synthetic outcomes\n","        rng_sim = np.random.default_rng(seed)\n","        Y_syn = gen_semi_y(MU_HAT, G_FUNC, rng_sim)\n","\n","        # 2. Shuffle (X, T, Y) consistently for this run. The permutation is\n","        # drawn from rng_sim *after* gen_semi_y; keep this draw order so\n","        # earlier results stay reproducible.\n","        perm = rng_sim.permutation(n)\n","\n","        Xs = X.iloc[perm].reset_index(drop=True)\n","        Ts = T.iloc[perm].reset_index(drop=True)\n","        Ys = Y_syn[perm]\n","\n","        # Standardize covariates (DGP specific preprocessing)\n","        Xss = make_Xss(Xs)\n","\n","        # 3. 
Run Plug-in Estimator (Nystrom + LOOCV)\n","        out = run_plugin_loocv_on_original_grid(\n","            Xss,\n","            Ts,\n","            Ys,\n","            t_grid_original=T_GRID,\n","            h_star_vals_original=H_STAR_VALS,\n","            beta_min=BETA_MIN,\n","            beta_max=BETA_MAX,\n","            kernel_type_f=KERNEL_TYPE_F,\n","            ell_x=KRR_X_LENGTH_SCALE,\n","            nu_x=KRR_X_NU,\n","            ell_t=KRR_T_LENGTH_SCALE,\n","            nu_t=KRR_T_NU,\n","            m_f=NYSTROM_M_F,\n","            verbose=False,\n","        )\n","\n","        mise = out[\"mise_plugin\"]\n","        mise_list.append(mise)\n","        curves_list.append(out[\"h_plugin\"])\n","\n","        print(f\"  MISE: {mise:.6f}\")\n","\n","    # --- Aggregation \u0026 Statistics ---\n","    mise_arr = np.array(mise_list)\n","    curves_mat = np.vstack(curves_list) # Shape: (K, n_grid)\n","\n","    mean_curve = curves_mat.mean(axis=0)\n","    std_curve = curves_mat.std(axis=0, ddof=1)\n","    se_curve = std_curve / np.sqrt(K)\n","\n","    # Compute the MISE spread once and reuse it in the summary below.\n","    mise_std = mise_arr.std(ddof=1)\n","    mise_se = mise_std / np.sqrt(K)\n","\n","    print(\"\\n\" + \"=\"*30)\n","    print(f\"Simulation Summary (K={K})\")\n","    print(\"=\"*30)\n","    print(f\"Mean MISE : {mise_arr.mean():.6f}\")\n","    print(f\"Std MISE  : {mise_std:.6f}\")\n","    print(f\"SE MISE   : {mise_se:.6f}\")\n","\n","    return {\n","        \"seeds\": seeds,\n","        \"mise_all\": mise_arr,\n","        \"curves_mat\": curves_mat,\n","        \"mean_curve\": mean_curve,\n","        \"se_curve\": se_curve\n","    }"]},{"cell_type":"code","execution_count":null,"metadata":{"colab":{"background_save":true},"id":"FR7dLtYyyN3x"},"outputs":[{"name":"stdout","output_type":"stream","text":["--- Run 1/100 (seed=1) ---\n","  MISE: 3.720129\n","--- Run 2/100 (seed=2) ---\n","  MISE: 0.829279\n","--- Run 3/100 (seed=3) ---\n","  MISE: 0.552317\n","--- Run 4/100 (seed=4) ---\n","  MISE: 0.508747\n","--- Run 5/100 (seed=5) ---\n","  MISE: 
2.207687\n","--- Run 6/100 (seed=6) ---\n","  MISE: 1.184421\n","--- Run 7/100 (seed=7) ---\n","  MISE: 1.077564\n","--- Run 8/100 (seed=8) ---\n","  MISE: 0.463204\n","--- Run 9/100 (seed=9) ---\n","  MISE: 1.304490\n","--- Run 10/100 (seed=10) ---\n","  MISE: 2.407019\n","--- Run 11/100 (seed=11) ---\n","  MISE: 2.768264\n","--- Run 12/100 (seed=12) ---\n","  MISE: 0.975751\n","--- Run 13/100 (seed=13) ---\n","  MISE: 1.192407\n","--- Run 14/100 (seed=14) ---\n","  MISE: 1.602493\n","--- Run 15/100 (seed=15) ---\n","  MISE: 1.198100\n","--- Run 16/100 (seed=16) ---\n","  MISE: 1.190029\n","--- Run 17/100 (seed=17) ---\n","  MISE: 2.179930\n","--- Run 18/100 (seed=18) ---\n","  MISE: 1.037991\n","--- Run 19/100 (seed=19) ---\n","  MISE: 1.264956\n","--- Run 20/100 (seed=20) ---\n","  MISE: 1.256931\n","--- Run 21/100 (seed=21) ---\n","  MISE: 2.568045\n","--- Run 22/100 (seed=22) ---\n","  MISE: 4.246382\n","--- Run 23/100 (seed=23) ---\n","  MISE: 3.243456\n","--- Run 24/100 (seed=24) ---\n","  MISE: 0.732566\n","--- Run 25/100 (seed=25) ---\n","  MISE: 1.441703\n","--- Run 26/100 (seed=26) ---\n","  MISE: 1.093944\n","--- Run 27/100 (seed=27) ---\n","  MISE: 1.348219\n","--- Run 28/100 (seed=28) ---\n","  MISE: 1.523376\n","--- Run 29/100 (seed=29) ---\n","  MISE: 1.646100\n","--- Run 30/100 (seed=30) ---\n","  MISE: 2.945389\n","--- Run 31/100 (seed=31) ---\n","  MISE: 0.898181\n","--- Run 32/100 (seed=32) ---\n","  MISE: 0.454879\n","--- Run 33/100 (seed=33) ---\n","  MISE: 2.557203\n","--- Run 34/100 (seed=34) ---\n","  MISE: 1.947252\n","--- Run 35/100 (seed=35) ---\n","  MISE: 2.282780\n","--- Run 36/100 (seed=36) ---\n","  MISE: 1.315273\n","--- Run 37/100 (seed=37) ---\n","  MISE: 0.485906\n","--- Run 38/100 (seed=38) ---\n","  MISE: 0.628882\n","--- Run 39/100 (seed=39) ---\n","  MISE: 1.456384\n","--- Run 40/100 (seed=40) ---\n","  MISE: 1.290445\n","--- Run 41/100 (seed=41) ---\n","  MISE: 0.284712\n","--- Run 42/100 (seed=42) ---\n","  MISE: 
1.434297\n","--- Run 43/100 (seed=43) ---\n","  MISE: 0.795433\n","--- Run 44/100 (seed=44) ---\n","  MISE: 2.911150\n","--- Run 45/100 (seed=45) ---\n","  MISE: 0.798618\n","--- Run 46/100 (seed=46) ---\n","  MISE: 0.823582\n","--- Run 47/100 (seed=47) ---\n","  MISE: 1.683189\n","--- Run 48/100 (seed=48) ---\n","  MISE: 2.611675\n","--- Run 49/100 (seed=49) ---\n","  MISE: 1.569739\n","--- Run 50/100 (seed=50) ---\n","  MISE: 1.213604\n","--- Run 51/100 (seed=51) ---\n","  MISE: 2.534857\n","--- Run 52/100 (seed=52) ---\n","  MISE: 1.280427\n","--- Run 53/100 (seed=53) ---\n","  MISE: 1.248879\n","--- Run 54/100 (seed=54) ---\n","  MISE: 3.372764\n","--- Run 55/100 (seed=55) ---\n","  MISE: 4.410381\n","--- Run 56/100 (seed=56) ---\n","  MISE: 2.255442\n","--- Run 57/100 (seed=57) ---\n","  MISE: 0.526653\n","--- Run 58/100 (seed=58) ---\n","  MISE: 0.756605\n","--- Run 59/100 (seed=59) ---\n","  MISE: 0.774125\n","--- Run 60/100 (seed=60) ---\n","  MISE: 2.260796\n","--- Run 61/100 (seed=61) ---\n","  MISE: 7.993044\n","--- Run 62/100 (seed=62) ---\n","  MISE: 4.638987\n","--- Run 63/100 (seed=63) ---\n","  MISE: 1.320389\n","--- Run 64/100 (seed=64) ---\n","  MISE: 0.839457\n","--- Run 65/100 (seed=65) ---\n","  MISE: 1.154436\n","--- Run 66/100 (seed=66) ---\n","  MISE: 1.063611\n","--- Run 67/100 (seed=67) ---\n","  MISE: 0.744065\n","--- Run 68/100 (seed=68) ---\n","  MISE: 1.462543\n","--- Run 69/100 (seed=69) ---\n","  MISE: 0.934991\n","--- Run 70/100 (seed=70) ---\n","  MISE: 3.016759\n","--- Run 71/100 (seed=71) ---\n","  MISE: 0.596116\n","--- Run 72/100 (seed=72) ---\n","  MISE: 1.511628\n","--- Run 73/100 (seed=73) ---\n","  MISE: 1.407100\n","--- Run 74/100 (seed=74) ---\n","  MISE: 0.867645\n","--- Run 75/100 (seed=75) ---\n","  MISE: 1.659218\n","--- Run 76/100 (seed=76) ---\n","  MISE: 0.186489\n","--- Run 77/100 (seed=77) ---\n","  MISE: 0.421845\n","--- Run 78/100 (seed=78) ---\n","  MISE: 1.120051\n","--- Run 79/100 (seed=79) ---\n","  MISE: 
1.000789\n","--- Run 80/100 (seed=80) ---\n","  MISE: 2.434547\n","--- Run 81/100 (seed=81) ---\n","  MISE: 0.567024\n","--- Run 82/100 (seed=82) ---\n","  MISE: 0.687638\n","--- Run 83/100 (seed=83) ---\n","  MISE: 1.277190\n","--- Run 84/100 (seed=84) ---\n","  MISE: 2.544714\n","--- Run 85/100 (seed=85) ---\n","  MISE: 1.446000\n","--- Run 86/100 (seed=86) ---\n","  MISE: 1.542140\n","--- Run 87/100 (seed=87) ---\n","  MISE: 1.097173\n","--- Run 88/100 (seed=88) ---\n","  MISE: 2.610444\n","--- Run 89/100 (seed=89) ---\n","  MISE: 0.767608\n","--- Run 90/100 (seed=90) ---\n","  MISE: 1.067674\n","--- Run 91/100 (seed=91) ---\n","  MISE: 1.693729\n","--- Run 92/100 (seed=92) ---\n","  MISE: 1.740890\n","--- Run 93/100 (seed=93) ---\n","  MISE: 2.154092\n","--- Run 94/100 (seed=94) ---\n","  MISE: 0.768177\n","--- Run 95/100 (seed=95) ---\n","  MISE: 4.439206\n","--- Run 96/100 (seed=96) ---\n","  MISE: 1.298758\n","--- Run 97/100 (seed=97) ---\n","  MISE: 1.281691\n","--- Run 98/100 (seed=98) ---\n","  MISE: 1.625069\n","--- Run 99/100 (seed=99) ---\n","  MISE: 1.368942\n","--- Run 100/100 (seed=100) ---\n","  MISE: 1.049560\n","\n","==============================\n","Simulation Summary (K=100)\n","==============================\n","Mean MISE : 1.619764\n","Std MISE  : 1.146753\n","SE MISE   : 0.114675\n"]},{"name":"stderr","output_type":"stream","text":["/tmp/ipython-input-3487969510.py:15: PerformanceWarning: DataFrame is highly fragmented.  This is usually the result of calling `frame.insert` many times, which has poor performance.  Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n","  df_out[f\"h_hat_seed_{seed}\"] = results[\"curves_mat\"][i, :]\n","/tmp/ipython-input-3487969510.py:17: PerformanceWarning: DataFrame is highly fragmented.  This is usually the result of calling `frame.insert` many times, which has poor performance.  
Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n","  df_out[\"mean_h_hat\"] = results[\"mean_curve\"]\n","/tmp/ipython-input-3487969510.py:18: PerformanceWarning: DataFrame is highly fragmented.  This is usually the result of calling `frame.insert` many times, which has poor performance.  Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n","  df_out[\"se_h_hat\"] = results[\"se_curve\"]\n"]},{"name":"stdout","output_type":"stream","text":["\n","Results saved to: /content/drive/MyDrive/Colab Notebooks/CTE_Baseline/KRR_methods/Results/plugin_semi-real_seeds_1-100.csv\n"]}],"source":["# ======================================================================\n","# Main Execution\n","# ======================================================================\n","\n","# 1. Run the simulation\n","results = run_simulation(K=K_RUNS)\n","\n","# 2. Save Results to CSV\n","output_csv_name = f\"plugin_semi-real_seeds_{results['seeds'][0]}-{results['seeds'][-1]}.csv\"\n","output_path = RESULTS_DIR / output_csv_name\n","\n","curves_mat = results[\"curves_mat\"]\n","assert curves_mat.shape == (len(results[\"seeds\"]), len(T_GRID)), \"each estimated curve must live on T_GRID\"\n","\n","# Construct DataFrame: t | h_hat_seed_1 | ... | mean_h_hat | se_h_hat\n","# Collect every column first and build the frame in a single call:\n","# inserting one column per seed triggers pandas' \"highly fragmented\"\n","# PerformanceWarning and gets slow as the number of seeds grows.\n","columns = {\"t\": T_GRID}\n","for i, seed in enumerate(results[\"seeds\"]):\n","    columns[f\"h_hat_seed_{seed}\"] = curves_mat[i, :]\n","columns[\"mean_h_hat\"] = results[\"mean_curve\"]\n","columns[\"se_h_hat\"] = results[\"se_curve\"]\n","df_out = pd.DataFrame(columns)\n","\n","df_out.to_csv(output_path, index=False)\n","print(f\"\\nResults saved to: {output_path}\")"]}],"metadata":{"colab":{"authorship_tag":"ABX9TyOq9lu2NBIIO2aCMSMq+Av3","name":"","version":""},"kernelspec":{"display_name":"Python 3","name":"python3"},"language_info":{"name":"python"}},"nbformat":4,"nbformat_minor":0}