{"cells":[{"cell_type":"markdown","metadata":{"id":"9Fo8itXkscLZ"},"source":["# Semi-Synthetic Job Corps Experiment (\"Ours\")\n","\n","This notebook runs the semi-synthetic experiment on the Job Corps dataset using the **\"Ours\" (Two-Stage Estimator)** method.\n","\n","**Workflow:**\n","1. **Data Loading:** Loads the empirical covariates ($X$) and treatment ($T$) from `emp_app.csv`.\n","2. **Setup:** Loads precomputed nuisance components ($\\hat{\\mu}, g$) and the ground-truth curve $h^*(t)$.\n","3. **Simulation:** Runs $K$ Monte Carlo repetitions. In each run:\n","   - Generates a semi-synthetic outcome $Y_{syn}$.\n","   - Fits the two-stage estimator on the original $t$-grid.\n","   - Computes the Mean Integrated Squared Error (MISE).\n","4. **Evaluation:** Aggregates results, plots the estimated curve against ground truth, and saves the results to CSV."]},{"cell_type":"code","execution_count":1,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"elapsed":33646,"status":"ok","timestamp":1767753884823,"user":{"displayName":"D K","userId":"02556183042422178006"},"user_tz":-540},"id":"CFBq21wJsa8m","outputId":"8eec9781-dab0-449d-d768-4f9f695f7813"},"outputs":[{"name":"stdout","output_type":"stream","text":["Mounted at /content/drive\n","Working Directory: /content/drive/MyDrive/Colab Notebooks/CTE_Baseline\n"]}],"source":["import sys\n","import pathlib\n","import numpy as np\n","import pandas as pd\n","import matplotlib.pyplot as plt\n","import logging\n","\n","# --- Environment Setup ---\n","try:\n","    from google.colab import drive\n","    drive.mount('/content/drive')\n","    BASE_DIR = pathlib.Path(\"/content/drive/MyDrive/Colab Notebooks/CTE_Baseline\")\n","except ImportError:\n","    # Assume local execution if not in Colab;\n","    # the project root is then the current working directory.\n","    BASE_DIR = pathlib.Path(\".\").resolve()\n","\n","# Make the project packages importable. The membership check keeps re-runs of this\n","# cell from appending the same entry to sys.path again and again.\n","if str(BASE_DIR) not in sys.path:\n","    sys.path.append(str(BASE_DIR))\n","\n","# Project imports\n","from KRR_methods.data_jobcorps import make_Xss, gen_semi_y, load_jobcorps_data\n","from 
KRR_methods.algorithms.estimators_ours import run_ours_on_original_grid\n","\n","print(f\"Working Directory: {BASE_DIR}\")"]},{"cell_type":"code","execution_count":2,"metadata":{"executionInfo":{"elapsed":4,"status":"ok","timestamp":1767753884831,"user":{"displayName":"D K","userId":"02556183042422178006"},"user_tz":-540},"id":"tJO-fWl7sfS1"},"outputs":[],"source":["# ======================================================================\n","# Configuration \u0026 Hyperparameters\n","# ======================================================================\n","\n","# --- Paths ---\n","EMP_DIR = BASE_DIR / \"DML_methods\" / \"Data_and_Results\"\n","DATA_FILE = EMP_DIR / \"emp_app.csv\"\n","SEMI_SYN_FILE = EMP_DIR / \"semi-syn data grf.csv\"\n","H_STAR_FILE = EMP_DIR / \"h_star_grf_empapp.csv\"\n","RESULTS_DIR = BASE_DIR / \"KRR_methods\" / \"Results\"\n","\n","# --- Algorithm Hyperparameters ---\n","# Stage 1: f(x,t) tensor-product kernel\n","KERNEL_TYPE_F = \"matern\"      # {\"matern\", \"gaussian\"}\n","KRR_X_LENGTH_SCALE = 13       # ell_x\n","KRR_X_NU = 0.5                # nu_x\n","KRR_T_LENGTH_SCALE = 6000     # ell_t\n","KRR_T_NU = 0.5                # nu_t\n","\n","# Stage 2: h(t) 1D kernel\n","KERNEL_TYPE_H = \"matern\"\n","KRR_H_LENGTH_SCALE = 6000     # l_H\n","KRR_H_NU = 1.5                # nu_H\n","\n","# Regularization (Ridge) Parameters\n","C_VAL = 0.1\n","# Candidate penalties passed as beta_grid: C_VAL * 2**i for i = 0..8 (0.1, 0.2, ..., 25.6)\n","BETA_GRID = np.array([C_VAL * (2**i) for i in range(0, 9)])\n","BETA0_F = C_VAL\n","BETA0_PRIME_F = C_VAL\n","\n","# Sampling Scheme\n","SECOND_STAGE_N = 2012         # Number of sampled t values for stage 2\n","L_VAL = 3000.0                # Upper bound of SECOND_STAGE_RANGE\n","SECOND_STAGE_RANGE = (0.0, L_VAL)\n","\n","# Simulation Settings\n","FIRST_SEED = 1                # Starting seed for reproducibility\n","K_RUNS = 100                  # Number of Monte Carlo 
repetitions"]},{"cell_type":"code","execution_count":3,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"elapsed":3064,"status":"ok","timestamp":1767753887898,"user":{"displayName":"D K","userId":"02556183042422178006"},"user_tz":-540},"id":"ETxE1hdosiQW","outputId":"1e4bec10-2108-48ec-8282-37b6b7620d4a"},"outputs":[{"name":"stdout","output_type":"stream","text":["Loading empirical data from: emp_app.csv...\n","Loading semi-synthetic components from: semi-syn data grf.csv...\n","Loading ground truth h*(t) from: h_star_grf_empapp.csv...\n","Data loading complete.\n"]}],"source":["# ======================================================================\n","# Data Loading \u0026 Preprocessing\n","# ======================================================================\n","\n","# Execute Loading via imported function\n","X, T, MU_HAT, G_FUNC, T_GRID, H_STAR_VALS = load_jobcorps_data(\n","    data_file=DATA_FILE,\n","    semi_syn_file=SEMI_SYN_FILE,\n","    h_star_file=H_STAR_FILE\n",")\n","\n","# Sanity checks: X and T are shuffled with one shared permutation, and T_GRID /\n","# H_STAR_VALS are passed together as the evaluation grid and its ground truth.\n","assert len(X) == len(T), \"X and T must have the same number of rows\"\n","assert len(T_GRID) == len(H_STAR_VALS), \"T_GRID and H_STAR_VALS must have the same length\"\n","\n","print(\"Data loading complete.\")"]},{"cell_type":"code","execution_count":4,"metadata":{"executionInfo":{"elapsed":7,"status":"ok","timestamp":1767753887916,"user":{"displayName":"D K","userId":"02556183042422178006"},"user_tz":-540},"id":"fjPnDjNIskTL"},"outputs":[],"source":["def run_simulation(K=20):\n","    \"\"\"\n","    Runs the 'Ours' estimator simulation K times.\n","\n","    Repetition k uses seed = FIRST_SEED + k to reseed NumPy's legacy global RNG,\n","    draw a semi-synthetic outcome with gen_semi_y, and jointly shuffle (X, T, Y)\n","    before fitting run_ours_on_original_grid on the original t-grid (T_GRID).\n","\n","    Parameters\n","    ----------\n","    K : int, default 20\n","        Number of Monte Carlo repetitions; must be >= 1.\n","\n","    Returns\n","    -------\n","    dict\n","        \"seeds\"      : list of the K seeds used.\n","        \"mise_all\"   : array of the K per-run MISE values (out[\"mise_ours\"]).\n","        \"curves_mat\" : (K, n_grid) array stacking the per-run curves out[\"h_ours\"].\n","        \"mean_curve\" : pointwise mean of curves_mat over runs.\n","        \"se_curve\"   : pointwise standard error of that mean (NaN if K == 1).\n","\n","    Raises\n","    ------\n","    ValueError\n","        If K < 1.\n","    \"\"\"\n","    if K < 1:\n","        raise ValueError(f\"K must be >= 1, got {K}\")\n","\n","    mise_list = []\n","    curves_list = []\n","    seeds = []\n","\n","    # Ensure output directory exists\n","    RESULTS_DIR.mkdir(parents=True, exist_ok=True)\n","\n","    # Sample size is the same in every run, so compute it once\n","    n = len(T)\n","\n","    for k in range(K):\n","        seed = FIRST_SEED + k\n","        seeds.append(seed)\n","\n","        # [Critical] Set Global Numpy Seed for legacy reproducibility\n","        # This ensures that legacy functions inside 'run_ours' using np.random produce identical results\n","        np.random.seed(seed)\n","\n","        # --- 1. 
Data Generation \u0026 Shuffling ---\n","        # Seed logic determines both Y_syn generation and the shuffle order.\n","        # Both draws come from rng_sim, so their order fixes the random stream:\n","        # outcome first, then the permutation.\n","        rng_sim = np.random.default_rng(seed)\n","\n","        # Generate outcome\n","        Y_syn = gen_semi_y(MU_HAT, G_FUNC, rng_sim)\n","\n","        # Shuffle (X, T, Y) consistently with a single permutation\n","        perm = rng_sim.permutation(n)\n","\n","        Xs = X.iloc[perm].reset_index(drop=True)\n","        Ts = T.iloc[perm].reset_index(drop=True)\n","        Ys = Y_syn[perm]\n","\n","        # Standardize covariates (DGP specific preprocessing)\n","        Xss = make_Xss(Xs)\n","\n","        print(f\"--- Run {k + 1}/{K} (seed={seed}) ---\")\n","\n","        # --- 2. Run Estimator ---\n","        out = run_ours_on_original_grid(\n","            Xss,\n","            Ts,\n","            Ys,\n","            t_grid_original=T_GRID,\n","            h_star_vals_original=H_STAR_VALS,\n","            beta_grid=BETA_GRID,\n","            kernel_type_f=KERNEL_TYPE_F,\n","            ell_x=KRR_X_LENGTH_SCALE,\n","            nu_x=KRR_X_NU,\n","            ell_t=KRR_T_LENGTH_SCALE,\n","            nu_t=KRR_T_NU,\n","            beta0_f=BETA0_F,\n","            beta0_prime_f=BETA0_PRIME_F,\n","            kernel_type_H=KERNEL_TYPE_H,\n","            l_H=KRR_H_LENGTH_SCALE,\n","            nu_H=KRR_H_NU,\n","            second_stage_range=SECOND_STAGE_RANGE,\n","            second_stage_n=SECOND_STAGE_N,\n","        )\n","\n","        mise = out[\"mise_ours\"]\n","        mise_list.append(mise)\n","        curves_list.append(out[\"h_ours\"])\n","        print(f\"  MISE: {mise:.6f}\")\n","\n","    # --- 3. 
Aggregation \u0026 Statistics ---\n","    mise_arr = np.array(mise_list)\n","    curves_mat = np.vstack(curves_list)  # Shape: (K, n_grid)\n","\n","    mean_curve = curves_mat.mean(axis=0)\n","\n","    # The ddof=1 sample std is undefined for a single run: report NaN explicitly\n","    # instead of letting NumPy emit a degrees-of-freedom RuntimeWarning.\n","    if K > 1:\n","        std_curve = curves_mat.std(axis=0, ddof=1)\n","        mise_std = mise_arr.std(ddof=1)\n","    else:\n","        std_curve = np.full(curves_mat.shape[1], np.nan)\n","        mise_std = np.nan\n","    se_curve = std_curve / np.sqrt(K)\n","    mise_se = mise_std / np.sqrt(K)\n","\n","    print(\"\\n\" + \"=\" * 30)\n","    print(f\"Simulation Summary (K={K})\")\n","    print(\"=\" * 30)\n","    print(f\"Mean MISE : {mise_arr.mean():.6f}\")\n","    print(f\"Std MISE  : {mise_std:.6f}\")\n","    print(f\"SE MISE   : {mise_se:.6f}\")\n","\n","    return {\n","        \"seeds\": seeds,\n","        \"mise_all\": mise_arr,\n","        \"curves_mat\": curves_mat,\n","        \"mean_curve\": mean_curve,\n","        \"se_curve\": se_curve,\n","    }\n"]},{"cell_type":"code","execution_count":null,"metadata":{"colab":{"background_save":true,"base_uri":"https://localhost:8080/"},"id":"BPhISuIbsnho"},"outputs":[{"name":"stdout","output_type":"stream","text":["--- Run 1/100 (seed=1) ---\n","  MISE: 1.994353\n","--- Run 2/100 (seed=2) ---\n","  MISE: 0.249440\n","--- Run 3/100 (seed=3) ---\n","  MISE: 1.835164\n","--- Run 4/100 (seed=4) ---\n","  MISE: 0.607466\n","--- Run 5/100 (seed=5) ---\n","  MISE: 0.160391\n","--- Run 6/100 (seed=6) ---\n","  MISE: 0.791055\n","--- Run 7/100 (seed=7) ---\n","  MISE: 1.143464\n","--- Run 8/100 (seed=8) ---\n","  MISE: 1.010764\n","--- Run 9/100 (seed=9) ---\n","  MISE: 2.392028\n","--- Run 10/100 (seed=10) ---\n","  MISE: 3.122674\n","--- Run 11/100 (seed=11) ---\n","  MISE: 4.026220\n","--- Run 12/100 (seed=12) ---\n","  MISE: 0.866902\n","--- Run 13/100 (seed=13) ---\n","  MISE: 1.589211\n","--- Run 14/100 (seed=14) ---\n","  MISE: 2.297269\n","--- Run 15/100 (seed=15) ---\n","  MISE: 1.420445\n","--- Run 16/100 (seed=16) ---\n","  MISE: 2.070913\n","--- Run 17/100 (seed=17) ---\n","  MISE: 0.172606\n","--- Run 18/100 (seed=18) ---\n","  MISE: 0.400014\n","--- Run 19/100 (seed=19) ---\n","  
MISE: 1.793316\n","--- Run 20/100 (seed=20) ---\n","  MISE: 1.088182\n","--- Run 21/100 (seed=21) ---\n","  MISE: 2.680156\n","--- Run 22/100 (seed=22) ---\n","  MISE: 1.089628\n","--- Run 23/100 (seed=23) ---\n","  MISE: 0.896769\n","--- Run 24/100 (seed=24) ---\n","  MISE: 0.333346\n","--- Run 25/100 (seed=25) ---\n","  MISE: 0.109647\n","--- Run 26/100 (seed=26) ---\n","  MISE: 0.109001\n","--- Run 27/100 (seed=27) ---\n","  MISE: 0.162534\n","--- Run 28/100 (seed=28) ---\n","  MISE: 0.310829\n","--- Run 29/100 (seed=29) ---\n","  MISE: 1.186108\n","--- Run 30/100 (seed=30) ---\n","  MISE: 0.392852\n","--- Run 31/100 (seed=31) ---\n","  MISE: 1.049539\n","--- Run 32/100 (seed=32) ---\n","  MISE: 0.379707\n","--- Run 33/100 (seed=33) ---\n","  MISE: 0.307836\n","--- Run 34/100 (seed=34) ---\n","  MISE: 2.233421\n","--- Run 35/100 (seed=35) ---\n","  MISE: 1.506754\n","--- Run 36/100 (seed=36) ---\n","  MISE: 1.433902\n","--- Run 37/100 (seed=37) ---\n","  MISE: 0.210891\n","--- Run 38/100 (seed=38) ---\n","  MISE: 0.205580\n","--- Run 39/100 (seed=39) ---\n","  MISE: 1.279422\n","--- Run 40/100 (seed=40) ---\n","  MISE: 0.554060\n","--- Run 41/100 (seed=41) ---\n","  MISE: 1.528295\n","--- Run 42/100 (seed=42) ---\n","  MISE: 0.161516\n","--- Run 43/100 (seed=43) ---\n","  MISE: 2.031272\n","--- Run 44/100 (seed=44) ---\n","  MISE: 1.142621\n","--- Run 45/100 (seed=45) ---\n","  MISE: 2.990152\n","--- Run 46/100 (seed=46) ---\n","  MISE: 1.851898\n","--- Run 47/100 (seed=47) ---\n","  MISE: 2.118863\n","--- Run 48/100 (seed=48) ---\n","  MISE: 0.743027\n","--- Run 49/100 (seed=49) ---\n","  MISE: 0.288837\n","--- Run 50/100 (seed=50) ---\n","  MISE: 0.858640\n","--- Run 51/100 (seed=51) ---\n","  MISE: 0.269585\n","--- Run 52/100 (seed=52) ---\n","  MISE: 0.611306\n","--- Run 53/100 (seed=53) ---\n","  MISE: 1.410412\n","--- Run 54/100 (seed=54) ---\n","  MISE: 3.203502\n","--- Run 55/100 (seed=55) ---\n","  MISE: 0.956337\n","--- Run 56/100 (seed=56) ---\n","  
MISE: 0.048252\n","--- Run 57/100 (seed=57) ---\n","  MISE: 1.785263\n","--- Run 58/100 (seed=58) ---\n","  MISE: 0.434795\n","--- Run 59/100 (seed=59) ---\n","  MISE: 6.276363\n","--- Run 60/100 (seed=60) ---\n","  MISE: 0.330100\n","--- Run 61/100 (seed=61) ---\n","  MISE: 5.109269\n","--- Run 62/100 (seed=62) ---\n","  MISE: 4.340664\n","--- Run 63/100 (seed=63) ---\n","  MISE: 0.462013\n","--- Run 64/100 (seed=64) ---\n","  MISE: 2.411255\n","--- Run 65/100 (seed=65) ---\n","  MISE: 0.790160\n","--- Run 66/100 (seed=66) ---\n","  MISE: 0.206367\n","--- Run 67/100 (seed=67) ---\n","  MISE: 1.276382\n","--- Run 68/100 (seed=68) ---\n","  MISE: 0.618186\n","--- Run 69/100 (seed=69) ---\n","  MISE: 0.467229\n","--- Run 70/100 (seed=70) ---\n","  MISE: 1.013422\n","--- Run 71/100 (seed=71) ---\n","  MISE: 0.251272\n","--- Run 72/100 (seed=72) ---\n","  MISE: 1.124290\n","--- Run 73/100 (seed=73) ---\n","  MISE: 0.982291\n","--- Run 74/100 (seed=74) ---\n","  MISE: 0.057294\n","--- Run 75/100 (seed=75) ---\n","  MISE: 0.365916\n","--- Run 76/100 (seed=76) ---\n","  MISE: 0.918183\n","--- Run 77/100 (seed=77) ---\n","  MISE: 0.361169\n","--- Run 78/100 (seed=78) ---\n","  MISE: 0.998052\n","--- Run 79/100 (seed=79) ---\n","  MISE: 0.123750\n","--- Run 80/100 (seed=80) ---\n","  MISE: 0.119140\n","--- Run 81/100 (seed=81) ---\n","  MISE: 2.868242\n","--- Run 82/100 (seed=82) ---\n","  MISE: 0.082785\n","--- Run 83/100 (seed=83) ---\n","  MISE: 0.221313\n","--- Run 84/100 (seed=84) ---\n","  MISE: 3.329613\n","--- Run 85/100 (seed=85) ---\n","  MISE: 0.725267\n","--- Run 86/100 (seed=86) ---\n","  MISE: 0.609779\n","--- Run 87/100 (seed=87) ---\n","  MISE: 1.621857\n","--- Run 88/100 (seed=88) ---\n","  MISE: 1.140562\n","--- Run 89/100 (seed=89) ---\n","  MISE: 1.089661\n","--- Run 90/100 (seed=90) ---\n","  MISE: 1.650244\n","--- Run 91/100 (seed=91) ---\n","  MISE: 0.828761\n","--- Run 92/100 (seed=92) ---\n","  MISE: 3.750823\n","--- Run 93/100 (seed=93) ---\n","  
MISE: 0.119387\n","--- Run 94/100 (seed=94) ---\n","  MISE: 0.289831\n","--- Run 95/100 (seed=95) ---\n","  MISE: 4.119674\n","--- Run 96/100 (seed=96) ---\n","  MISE: 0.055599\n","--- Run 97/100 (seed=97) ---\n","  MISE: 3.177938\n","--- Run 98/100 (seed=98) ---\n","  MISE: 0.684537\n","--- Run 99/100 (seed=99) ---\n","  MISE: 1.426394\n","--- Run 100/100 (seed=100) ---\n","  MISE: 0.363652\n","\n","==============================\n","Simulation Summary (K=100)\n","==============================\n","Mean MISE : 1.246571\n","Std MISE  : 1.208947\n","SE MISE   : 0.120895\n"]},{"name":"stderr","output_type":"stream","text":["/tmp/ipython-input-742300383.py:15: PerformanceWarning: DataFrame is highly fragmented.  This is usually the result of calling `frame.insert` many times, which has poor performance.  Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n","  df_out[f\"h_hat_seed_{seed}\"] = results[\"curves_mat\"][i, :]\n","/tmp/ipython-input-742300383.py:17: PerformanceWarning: DataFrame is highly fragmented.  This is usually the result of calling `frame.insert` many times, which has poor performance.  Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n","  df_out[\"mean_h_hat\"] = results[\"mean_curve\"]\n","/tmp/ipython-input-742300383.py:18: PerformanceWarning: DataFrame is highly fragmented.  This is usually the result of calling `frame.insert` many times, which has poor performance.  Consider joining all columns at once using pd.concat(axis=1) instead. 
To get a de-fragmented frame, use `newframe = frame.copy()`\n","  df_out[\"se_h_hat\"] = results[\"se_curve\"]\n"]},{"name":"stdout","output_type":"stream","text":["\n","Results saved to: /content/drive/MyDrive/Colab Notebooks/CTE_Baseline/KRR_methods/Results/ours_semi-real_seeds_1-100.csv\n"]}],"source":["# ======================================================================\n","# Main Execution\n","# ======================================================================\n","\n","# 1. Run the simulation\n","results = run_simulation(K=K_RUNS)\n","\n","# 2. Plot the Monte Carlo mean estimate against the ground truth h*(t).\n","#    Shaded band: mean +/- 1.96 * pointwise SE of the mean across runs.\n","mean_curve = results[\"mean_curve\"]\n","se_curve = results[\"se_curve\"]\n","fig, ax = plt.subplots(figsize=(10, 6))\n","ax.plot(T_GRID, H_STAR_VALS, color=\"black\", label=\"Ground truth h*(t)\")\n","ax.plot(T_GRID, mean_curve, color=\"tab:blue\", label=\"Ours: mean estimate\")\n","ax.fill_between(\n","    T_GRID,\n","    mean_curve - 1.96 * se_curve,\n","    mean_curve + 1.96 * se_curve,\n","    color=\"tab:blue\",\n","    alpha=0.25,\n","    label=\"mean +/- 1.96 SE\",\n",")\n","ax.set(\n","    xlabel=\"t\",\n","    ylabel=\"h(t)\",\n","    title=f\"Semi-synthetic Job Corps: Ours (K={len(results['seeds'])})\",\n",")\n","ax.legend()\n","plt.show()\n","\n","# 3. Save Results to CSV\n","output_csv_name = f\"ours_semi-real_seeds_{results['seeds'][0]}-{results['seeds'][-1]}.csv\"\n","output_path = RESULTS_DIR / output_csv_name\n","\n","# Construct DataFrame: t | h_hat_seed_1 | ... | mean_h_hat | se_h_hat\n","# Collect every column in a dict and build the frame once: inserting columns one\n","# at a time fragments the DataFrame and triggers pandas' PerformanceWarning.\n","out_columns = {\"t\": T_GRID}\n","for i, seed in enumerate(results[\"seeds\"]):\n","    out_columns[f\"h_hat_seed_{seed}\"] = results[\"curves_mat\"][i, :]\n","out_columns[\"mean_h_hat\"] = mean_curve\n","out_columns[\"se_h_hat\"] = se_curve\n","df_out = pd.DataFrame(out_columns)\n","\n","df_out.to_csv(output_path, index=False)\n","print(f\"\\nResults saved to: {output_path}\")"]}],"metadata":{"colab":{"authorship_tag":"ABX9TyNBprTrcp4QrSuS4G1teehi","name":"","toc_visible":true,"version":""},"kernelspec":{"display_name":"Python 3","name":"python3"},"language_info":{"name":"python"}},"nbformat":4,"nbformat_minor":0}