{"cells":[{"cell_type":"markdown","metadata":{"id":"tGkDROqo57dX"},"source":["# Synthetic Data Experiment (\"Ours\" Joint-Kernel Estimator)\n","\n","This notebook runs the synthetic data experiment using the **\"Ours\" (Joint-Kernel Two-Stage Estimator)** method.\n","\n","**Workflow:**\n","1. **Setup:** Defines hyperparameters and computes the ground-truth curve $h^*(t)$ via Monte Carlo approximation.\n","2. **Simulation:** Runs $K$ Monte Carlo repetitions. In each run:\n","   - Generates synthetic data $(X, T, Y)$ based on the DGP.\n","   - Splits data into $D_1$ (for $\\hat{f}$) and $D_2$ (for $\\tilde{f}, \\hat{h}$).\n","   - Fits the joint-kernel estimator and evaluates $\\hat{h}(t)$ on a grid.\n","   - Computes MISE (Mean Integrated Squared Error).\n","3. **Evaluation:** Aggregates results, plots the estimated curve against the ground truth, and saves results to CSV."]},{"cell_type":"code","execution_count":30,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"elapsed":603,"status":"ok","timestamp":1769544084431,"user":{"displayName":"D K","userId":"02556183042422178006"},"user_tz":300},"id":"mxkv1wrV54Kc","outputId":"c83cbae9-ed27-4aa5-86f8-db975067053b"},"outputs":[{"name":"stdout","output_type":"stream","text":["Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount(\"/content/drive\", force_remount=True).\n","Working Directory: /content/drive/MyDrive/Colab Notebooks/CTE_Baseline\n"]}],"source":["import sys\n","import pathlib\n","import numpy as np\n","import pandas as pd\n","import matplotlib.pyplot as plt\n","\n","# --- Environment Setup ---\n","try:\n","    from google.colab import drive\n","    drive.mount('/content/drive')\n","    BASE_DIR = pathlib.Path(\"/content/drive/MyDrive/Colab Notebooks/CTE_Baseline\")\n","except ImportError:\n","    BASE_DIR = pathlib.Path(\".\").resolve()\n","\n","# Add project root to path\n","sys.path.append(str(BASE_DIR))\n","sys.path.append(str(BASE_DIR / \"KRR_methods\"))\n","\n","# Project imports\n","from KRR_methods.synthetic_dgps import (\n","    generate_unified_data,\n","    split_data,\n","    approximate_h_star,\n",")\n","from KRR_methods.algorithms.estimators_ours import run_ours_joint_kernel\n","\n","print(f\"Working Directory: {BASE_DIR}\")"]},{"cell_type":"code","execution_count":31,"metadata":{"executionInfo":{"elapsed":1,"status":"ok","timestamp":1769544084433,"user":{"displayName":"D K","userId":"02556183042422178006"},"user_tz":300},"id":"fQaXziJ45-LT"},"outputs":[],"source":["# ======================================================================\n","# Configuration \u0026 Hyperparameters\n","# ======================================================================\n","\n","# --- Paths ---\n","RESULTS_DIR = BASE_DIR / \"KRR_methods\" / \"Results\"\n","\n","# --- Data Generation Parameters ---\n","N_SAMPLES =  1000\n","NOISE_STD = 1.0\n","FIRST_SEED = 0\n","K_RUNS = 100\n","\n","# --- Algorithm Hyperparameters ---\n","# 1. Joint Kernel (Matérn on (X, T))\n","KERNEL_TYPE_ND = \"matern\"   # \"matern\" or \"gaussian\"\n","LENGTH_SCALE_ND = 3.0\n","NU_ND = 1.5\n","\n","# 2. Second-stage 1D kernel for h(t)\n","LENGTH_SCALE_1D_H = 3.0\n","NU_1D_H = 2.5\n","\n","# 3. Regularization (Ridge)\n","C_VAL = 0.1\n","BETA_H_GRID = np.array([C_VAL * (2**i) for i in range(0, 9)], dtype=float)\n","BETA0_F = C_VAL             # Ridge for f_hat on D1\n","BETA0_PRIME_F = C_VAL       # Ridge for f_tilde on D2\n","\n","# 4. Sampling \u0026 Grid\n","N_SIZE_KRR = N_SAMPLES // 2  # Size of split data (D1, D2)\n","T_MIN = -np.pi\n","T_MAX = np.pi\n","T_GRID = np.linspace(T_MIN, T_MAX, 1000)\n"]},{"cell_type":"code","execution_count":32,"metadata":{"executionInfo":{"elapsed":1,"status":"ok","timestamp":1769544084436,"user":{"displayName":"D K","userId":"02556183042422178006"},"user_tz":300},"id":"Rmg30gt76C0o"},"outputs":[],"source":["# ======================================================================\n","# Helper Functions\n","# ======================================================================\n","\n","def run_single_joint_ours(X, T, Y, t_grid):\n","    \"\"\"\n","    Run the joint-kernel version of 'Ours' estimator on a single dataset.\n","\n","    Steps:\n","      1. Split data into D1 and D2.\n","      2. Run the two-stage KRR procedure (joint kernel on X, T).\n","      3. Evaluate the resulting function on t_grid.\n","      4. Return both the estimated curve and the selected second-stage beta.\n","    \"\"\"\n","    # 1. Split Data\n","    D1, D2 = split_data(X, T, Y)\n","\n","    # 2. Fit Joint Kernel Estimator\n","    # Now returns (h_func_prop, beta_selected, best_error, ...)\n","    h_func_prop, beta_sel, *_ = run_ours_joint_kernel(\n","        D1=D1,\n","        D2=D2,\n","        beta0_for_f_cand=BETA0_F,\n","        beta_h_grid_for_h_cand=BETA_H_GRID,\n","        beta0_prime_for_f_tilde=BETA0_PRIME_F,\n","        nu_nd=NU_ND,\n","        nu_1d_h=NU_1D_H,\n","        length_scale_nd=LENGTH_SCALE_ND,\n","        length_scale_1d_h=LENGTH_SCALE_1D_H,\n","        second_stage_N_SIZE=N_SIZE_KRR,\n","        kernel_type_nd=KERNEL_TYPE_ND,\n","        t_min=T_MIN,\n","        t_max=T_MAX,\n","    )\n","\n","    # 3. Evaluate on grid\n","    h_hat = h_func_prop(t_grid)\n","\n","    # 4. Return both outputs\n","    return h_hat, float(beta_sel)\n"]},{"cell_type":"code","execution_count":33,"metadata":{"executionInfo":{"elapsed":0,"status":"ok","timestamp":1769544084436,"user":{"displayName":"D K","userId":"02556183042422178006"},"user_tz":300},"id":"4lXA_b0d6Yfc"},"outputs":[],"source":["# ======================================================================\n","# Simulation Logic\n","# ======================================================================\n","\n","def run_simulation(K=100, noise_std=1.0, first_seed=0):\n","    \"\"\"\n","    Runs the 'Ours' Joint-Kernel simulation K times.\n","    Also tracks the selected second-stage beta for each run and reports its mean and SE.\n","    \"\"\"\n","    RESULTS_DIR.mkdir(parents=True, exist_ok=True)\n","\n","    # 1. Compute True Curve h*(t) once (using MC approximation)\n","    print(\"Approximating ground truth h*(t)...\")\n","    h_star_vals = approximate_h_star(T_GRID, n_mc_samples=40000)\n","\n","    mise_list = []\n","    curves_list = []\n","    beta_list = []\n","    seeds = []\n","\n","    print(f\"\\n🚀 Running {K} simulations for 'Ours' Joint-Kernel (noise_std={noise_std})...\")\n","\n","    for k in range(K):\n","        seed = first_seed + k\n","        seeds.append(seed)\n","\n","        # [Critical] Set Global Numpy Seed for reproducibility\n","        # This matches the user's requirement to ensure identical results per run\n","        np.random.seed(seed)\n","\n","        print(f\"--- Run {k + 1}/{K} (seed={seed}) ---\")\n","\n","        # 2. Generate Synthetic Data\n","        X, T, Y = generate_unified_data(\n","            n_samples=N_SAMPLES,\n","            noise_std=noise_std,\n","        )\n","\n","        # 3. Fit Estimator (now returns both h_hat and beta_sel)\n","        h_hat, beta_sel = run_single_joint_ours(X, T, Y, T_GRID)\n","        curves_list.append(h_hat)\n","        beta_list.append(beta_sel)\n","\n","        # 4. Compute MISE\n","        mise = np.mean((h_hat - h_star_vals) ** 2)\n","        mise_list.append(mise)\n","\n","        print(f\"   MISE: {mise:.6f} | Selected beta: {beta_sel:.6g}\")\n","\n","    # --- Aggregation \u0026 Statistics ---\n","    mise_arr = np.array(mise_list, dtype=float)\n","    curves_mat = np.vstack(curves_list)\n","\n","    mean_curve = curves_mat.mean(axis=0)\n","    std_curve = curves_mat.std(axis=0, ddof=1)\n","    se_curve = std_curve / np.sqrt(K)\n","\n","    beta_arr = np.array(beta_list, dtype=float)\n","    beta_mean = float(beta_arr.mean())\n","    beta_std = float(beta_arr.std(ddof=1))\n","    beta_se = float(beta_std / np.sqrt(K))\n","\n","    print(\"\\n\" + \"=\"*40)\n","    print(f\"Simulation Summary (K={K})\")\n","    print(\"=\"*40)\n","    print(f\"Mean MISE : {mise_arr.mean():.6f}\")\n","    print(f\"Std MISE  : {mise_arr.std(ddof=1):.6f}\")\n","    print(f\"SE MISE   : {mise_arr.std(ddof=1) / np.sqrt(K):.6f}\")\n","\n","    print(\"-\" * 40)\n","    print(f\"Mean selected beta : {beta_mean:.6g}\")\n","    print(f\"Std selected beta  : {beta_std:.6g}\")\n","    print(f\"SE selected beta   : {beta_se:.6g}\")\n","\n","    # Tag for filename\n","    noise_str = f\"{noise_std:.3g}\".replace(\".\", \"p\")\n","    run_tag = f\"ours_synth_sample_{N_SAMPLES}_noise_{noise_str}_seeds_{seeds[0]}-{seeds[-1]}\"\n","\n","    return {\n","        \"seeds\": seeds,\n","        \"t_grid\": T_GRID,\n","        \"h_star\": h_star_vals,\n","        \"mise_all\": mise_arr,\n","        \"curves_mat\": curves_mat,\n","        \"mean_curve\": mean_curve,\n","        \"se_curve\": se_curve,\n","        \"beta_selected_all\": beta_arr,\n","        \"beta_selected_mean\": beta_mean,\n","        \"beta_selected_se\": beta_se,\n","        \"run_tag\": run_tag\n","    }\n"]},{"cell_type":"code","execution_count":34,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"elapsed":2648816,"status":"ok","timestamp":1769546733253,"user":{"displayName":"D K","userId":"02556183042422178006"},"user_tz":300},"id":"4CQzwTAN6GIv","outputId":"b7477af6-39e5-4192-dd11-49ce674feefe"},"outputs":[{"name":"stdout","output_type":"stream","text":["Approximating ground truth h*(t)...\n","\n","🚀 Running 100 simulations for 'Ours' Joint-Kernel (noise_std=1.0)...\n","--- Run 1/100 (seed=0) ---\n","   MISE: 0.033938 | Selected beta: 0.1\n","--- Run 2/100 (seed=1) ---\n","   MISE: 0.075741 | Selected beta: 0.8\n","--- Run 3/100 (seed=2) ---\n","   MISE: 0.027655 | Selected beta: 3.2\n","--- Run 4/100 (seed=3) ---\n","   MISE: 0.063475 | Selected beta: 0.4\n","--- Run 5/100 (seed=4) ---\n","   MISE: 0.045009 | Selected beta: 0.1\n","--- Run 6/100 (seed=5) ---\n","   MISE: 0.057727 | Selected beta: 3.2\n","--- Run 7/100 (seed=6) ---\n","   MISE: 0.162596 | Selected beta: 25.6\n","--- Run 8/100 (seed=7) ---\n","   MISE: 0.044209 | Selected beta: 0.8\n","--- Run 9/100 (seed=8) ---\n","   MISE: 0.066313 | Selected beta: 0.4\n","--- Run 10/100 (seed=9) ---\n","   MISE: 0.022681 | Selected beta: 1.6\n","--- Run 11/100 (seed=10) ---\n","   MISE: 0.148829 | Selected beta: 0.1\n","--- Run 12/100 (seed=11) ---\n","   MISE: 0.054744 | Selected beta: 3.2\n","--- Run 13/100 (seed=12) ---\n","   MISE: 0.058346 | Selected beta: 0.1\n","--- Run 14/100 (seed=13) ---\n","   MISE: 0.031362 | Selected beta: 0.1\n","--- Run 15/100 (seed=14) ---\n","   MISE: 0.079973 | Selected beta: 0.1\n","--- Run 16/100 (seed=15) ---\n","   MISE: 0.061617 | Selected beta: 0.1\n","--- Run 17/100 (seed=16) ---\n","   MISE: 0.015896 | Selected beta: 0.1\n","--- Run 18/100 (seed=17) ---\n","   MISE: 0.028729 | Selected beta: 0.8\n","--- Run 19/100 (seed=18) ---\n","   MISE: 0.019903 | Selected beta: 0.1\n","--- Run 20/100 (seed=19) ---\n","   MISE: 0.074821 | Selected beta: 1.6\n","--- Run 21/100 (seed=20) ---\n","   MISE: 0.085900 | Selected beta: 0.1\n","--- Run 22/100 (seed=21) ---\n","   MISE: 0.047101 | Selected beta: 0.1\n","--- Run 23/100 (seed=22) ---\n","   MISE: 0.050348 | Selected beta: 0.8\n","--- Run 24/100 (seed=23) ---\n","   MISE: 0.102624 | Selected beta: 25.6\n","--- Run 25/100 (seed=24) ---\n","   MISE: 0.025742 | Selected beta: 0.1\n","--- Run 26/100 (seed=25) ---\n","   MISE: 0.112177 | Selected beta: 0.1\n","--- Run 27/100 (seed=26) ---\n","   MISE: 0.046084 | Selected beta: 1.6\n","--- Run 28/100 (seed=27) ---\n","   MISE: 0.031564 | Selected beta: 3.2\n","--- Run 29/100 (seed=28) ---\n","   MISE: 0.038811 | Selected beta: 0.1\n","--- Run 30/100 (seed=29) ---\n","   MISE: 0.083219 | Selected beta: 0.1\n","--- Run 31/100 (seed=30) ---\n","   MISE: 0.092795 | Selected beta: 0.1\n","--- Run 32/100 (seed=31) ---\n","   MISE: 0.115078 | Selected beta: 12.8\n","--- Run 33/100 (seed=32) ---\n","   MISE: 0.043568 | Selected beta: 6.4\n","--- Run 34/100 (seed=33) ---\n","   MISE: 0.039876 | Selected beta: 0.1\n","--- Run 35/100 (seed=34) ---\n","   MISE: 0.100062 | Selected beta: 3.2\n","--- Run 36/100 (seed=35) ---\n","   MISE: 0.058648 | Selected beta: 0.8\n","--- Run 37/100 (seed=36) ---\n","   MISE: 0.144112 | Selected beta: 6.4\n","--- Run 38/100 (seed=37) ---\n","   MISE: 0.165237 | Selected beta: 0.1\n","--- Run 39/100 (seed=38) ---\n","   MISE: 0.022492 | Selected beta: 0.4\n","--- Run 40/100 (seed=39) ---\n","   MISE: 0.083191 | Selected beta: 0.2\n","--- Run 41/100 (seed=40) ---\n","   MISE: 0.107103 | Selected beta: 3.2\n","--- Run 42/100 (seed=41) ---\n","   MISE: 0.032310 | Selected beta: 0.1\n","--- Run 43/100 (seed=42) ---\n","   MISE: 0.037016 | Selected beta: 0.4\n","--- Run 44/100 (seed=43) ---\n","   MISE: 0.086112 | Selected beta: 3.2\n","--- Run 45/100 (seed=44) ---\n","   MISE: 0.089923 | Selected beta: 0.1\n","--- Run 46/100 (seed=45) ---\n","   MISE: 0.051432 | Selected beta: 1.6\n","--- Run 47/100 (seed=46) ---\n","   MISE: 0.037551 | Selected beta: 1.6\n","--- Run 48/100 (seed=47) ---\n","   MISE: 0.057178 | Selected beta: 0.2\n","--- Run 49/100 (seed=48) ---\n","   MISE: 0.050169 | Selected beta: 1.6\n","--- Run 50/100 (seed=49) ---\n","   MISE: 0.063500 | Selected beta: 0.1\n","--- Run 51/100 (seed=50) ---\n","   MISE: 0.049596 | Selected beta: 0.1\n","--- Run 52/100 (seed=51) ---\n","   MISE: 0.056329 | Selected beta: 3.2\n","--- Run 53/100 (seed=52) ---\n","   MISE: 0.048419 | Selected beta: 0.8\n","--- Run 54/100 (seed=53) ---\n","   MISE: 0.025669 | Selected beta: 3.2\n","--- Run 55/100 (seed=54) ---\n","   MISE: 0.010484 | Selected beta: 0.1\n","--- Run 56/100 (seed=55) ---\n","   MISE: 0.048505 | Selected beta: 0.8\n","--- Run 57/100 (seed=56) ---\n","   MISE: 0.028020 | Selected beta: 3.2\n","--- Run 58/100 (seed=57) ---\n","   MISE: 0.069305 | Selected beta: 0.1\n","--- Run 59/100 (seed=58) ---\n","   MISE: 0.073877 | Selected beta: 0.1\n","--- Run 60/100 (seed=59) ---\n","   MISE: 0.052179 | Selected beta: 3.2\n","--- Run 61/100 (seed=60) ---\n","   MISE: 0.123144 | Selected beta: 3.2\n","--- Run 62/100 (seed=61) ---\n","   MISE: 0.042296 | Selected beta: 0.4\n","--- Run 63/100 (seed=62) ---\n","   MISE: 0.018810 | Selected beta: 0.4\n","--- Run 64/100 (seed=63) ---\n","   MISE: 0.078321 | Selected beta: 6.4\n","--- Run 65/100 (seed=64) ---\n","   MISE: 0.080812 | Selected beta: 3.2\n","--- Run 66/100 (seed=65) ---\n","   MISE: 0.087711 | Selected beta: 0.1\n","--- Run 67/100 (seed=66) ---\n","   MISE: 0.038737 | Selected beta: 0.1\n","--- Run 68/100 (seed=67) ---\n","   MISE: 0.056047 | Selected beta: 0.2\n","--- Run 69/100 (seed=68) ---\n","   MISE: 0.028043 | Selected beta: 0.2\n","--- Run 70/100 (seed=69) ---\n","   MISE: 0.077035 | Selected beta: 0.1\n","--- Run 71/100 (seed=70) ---\n","   MISE: 0.023009 | Selected beta: 0.1\n","--- Run 72/100 (seed=71) ---\n","   MISE: 0.116184 | Selected beta: 12.8\n","--- Run 73/100 (seed=72) ---\n","   MISE: 0.099163 | Selected beta: 6.4\n","--- Run 74/100 (seed=73) ---\n","   MISE: 0.049190 | Selected beta: 0.1\n","--- Run 75/100 (seed=74) ---\n","   MISE: 0.017192 | Selected beta: 0.1\n","--- Run 76/100 (seed=75) ---\n","   MISE: 0.048224 | Selected beta: 0.1\n","--- Run 77/100 (seed=76) ---\n","   MISE: 0.049594 | Selected beta: 0.1\n","--- Run 78/100 (seed=77) ---\n","   MISE: 0.065803 | Selected beta: 0.8\n","--- Run 79/100 (seed=78) ---\n","   MISE: 0.053744 | Selected beta: 6.4\n","--- Run 80/100 (seed=79) ---\n","   MISE: 0.053581 | Selected beta: 0.2\n","--- Run 81/100 (seed=80) ---\n","   MISE: 0.035487 | Selected beta: 3.2\n","--- Run 82/100 (seed=81) ---\n","   MISE: 0.063432 | Selected beta: 0.1\n","--- Run 83/100 (seed=82) ---\n","   MISE: 0.098127 | Selected beta: 25.6\n","--- Run 84/100 (seed=83) ---\n","   MISE: 0.049418 | Selected beta: 0.8\n","--- Run 85/100 (seed=84) ---\n","   MISE: 0.059516 | Selected beta: 0.1\n","--- Run 86/100 (seed=85) ---\n","   MISE: 0.049644 | Selected beta: 3.2\n","--- Run 87/100 (seed=86) ---\n","   MISE: 0.023378 | Selected beta: 0.2\n","--- Run 88/100 (seed=87) ---\n","   MISE: 0.089567 | Selected beta: 6.4\n","--- Run 89/100 (seed=88) ---\n","   MISE: 0.043211 | Selected beta: 0.8\n","--- Run 90/100 (seed=89) ---\n","   MISE: 0.064869 | Selected beta: 0.1\n","--- Run 91/100 (seed=90) ---\n","   MISE: 0.096872 | Selected beta: 0.4\n","--- Run 92/100 (seed=91) ---\n","   MISE: 0.069015 | Selected beta: 0.1\n","--- Run 93/100 (seed=92) ---\n","   MISE: 0.084967 | Selected beta: 0.8\n","--- Run 94/100 (seed=93) ---\n","   MISE: 0.038691 | Selected beta: 0.1\n","--- Run 95/100 (seed=94) ---\n","   MISE: 0.047008 | Selected beta: 0.1\n","--- Run 96/100 (seed=95) ---\n","   MISE: 0.036180 | Selected beta: 0.4\n","--- Run 97/100 (seed=96) ---\n","   MISE: 0.084047 | Selected beta: 0.1\n","--- Run 98/100 (seed=97) ---\n","   MISE: 0.036567 | Selected beta: 0.1\n","--- Run 99/100 (seed=98) ---\n","   MISE: 0.069279 | Selected beta: 3.2\n","--- Run 100/100 (seed=99) ---\n","   MISE: 0.138972 | Selected beta: 25.6\n","\n","========================================\n","Simulation Summary (K=100)\n","========================================\n","Mean MISE : 0.062218\n","Std MISE  : 0.032922\n","SE MISE   : 0.003292\n","----------------------------------------\n","Mean selected beta : 2.445\n","Std selected beta  : 5.30413\n","SE selected beta   : 0.530413\n"]},{"name":"stderr","output_type":"stream","text":["/tmp/ipython-input-1735511343.py:22: PerformanceWarning: DataFrame is highly fragmented.  This is usually the result of calling `frame.insert` many times, which has poor performance.  Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n","  df_out[f\"h_hat_seed_{seed}\"] = results[\"curves_mat\"][i, :]\n","/tmp/ipython-input-1735511343.py:24: PerformanceWarning: DataFrame is highly fragmented.  This is usually the result of calling `frame.insert` many times, which has poor performance.  Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n","  df_out[\"mean_h_hat\"] = results[\"mean_curve\"]\n","/tmp/ipython-input-1735511343.py:25: PerformanceWarning: DataFrame is highly fragmented.  This is usually the result of calling `frame.insert` many times, which has poor performance.  Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n","  df_out[\"se_h_hat\"] = results[\"se_curve\"]\n"]},{"name":"stdout","output_type":"stream","text":["\n","Results saved to: /content/drive/MyDrive/Colab Notebooks/CTE_Baseline/KRR_methods/Results/ours_synth_sample_1000_noise_1_seeds_0-99.csv\n"]}],"source":["# ======================================================================\n","# Main Execution\n","# ======================================================================\n","NOISE_STD = 1.0\n","N_SAMPLES = 1000\n","\n","if __name__ == \"__main__\":\n","    # 1. Run the simulation\n","    results = run_simulation(\n","        K=K_RUNS,\n","        noise_std=NOISE_STD,\n","        first_seed=FIRST_SEED\n","    )\n","\n","    # 3. Save Results to CSV\n","    output_csv_name = f\"{results['run_tag']}.csv\"\n","    output_path = RESULTS_DIR / output_csv_name\n","\n","    # Construct DataFrame\n","    df_out = pd.DataFrame({\"t\": results[\"t_grid\"]})\n","    for i, seed in enumerate(results[\"seeds\"]):\n","        df_out[f\"h_hat_seed_{seed}\"] = results[\"curves_mat\"][i, :]\n","\n","    df_out[\"mean_h_hat\"] = results[\"mean_curve\"]\n","    df_out[\"se_h_hat\"] = results[\"se_curve\"]\n","\n","    df_out.to_csv(output_path, index=False)\n","    print(f\"\\nResults saved to: {output_path}\")"]},{"cell_type":"code","execution_count":null,"metadata":{"colab":{"background_save":true,"base_uri":"https://localhost:8080/"},"id":"81XP0aqdqBZg"},"outputs":[{"name":"stdout","output_type":"stream","text":["Approximating ground truth h*(t)...\n","\n","🚀 Running 100 simulations for 'Ours' Joint-Kernel (noise_std=1.0)...\n","--- Run 1/100 (seed=0) ---\n","   MISE: 0.093234 | Selected beta: 0.8\n","--- Run 2/100 (seed=1) ---\n","   MISE: 0.099722 | Selected beta: 6.4\n","--- Run 3/100 (seed=2) ---\n","   MISE: 0.020565 | Selected beta: 0.1\n","--- Run 4/100 (seed=3) ---\n","   MISE: 0.057681 | Selected beta: 0.1\n","--- Run 5/100 (seed=4) ---\n","   MISE: 0.072101 | Selected beta: 0.2\n","--- Run 6/100 (seed=5) ---\n","   MISE: 0.060965 | Selected beta: 0.1\n","--- Run 7/100 (seed=6) ---\n","   MISE: 0.059579 | Selected beta: 0.1\n","--- Run 8/100 (seed=7) ---\n","   MISE: 0.049335 | Selected beta: 0.2\n","--- Run 9/100 (seed=8) ---\n","   MISE: 0.123438 | Selected beta: 6.4\n","--- Run 10/100 (seed=9) ---\n","   MISE: 0.124123 | Selected beta: 0.1\n","--- Run 11/100 (seed=10) ---\n","   MISE: 0.108606 | Selected beta: 6.4\n","--- Run 12/100 (seed=11) ---\n","   MISE: 0.108563 | Selected beta: 25.6\n","--- Run 13/100 (seed=12) ---\n","   MISE: 0.109561 | Selected beta: 12.8\n","--- Run 14/100 (seed=13) ---\n","   MISE: 0.069688 | Selected beta: 0.1\n","--- Run 15/100 (seed=14) ---\n","   MISE: 0.071141 | Selected beta: 0.1\n","--- Run 16/100 (seed=15) ---\n","   MISE: 0.109052 | Selected beta: 0.1\n","--- Run 17/100 (seed=16) ---\n","   MISE: 0.103823 | Selected beta: 1.6\n","--- Run 18/100 (seed=17) ---\n","   MISE: 0.083790 | Selected beta: 0.1\n","--- Run 19/100 (seed=18) ---\n","   MISE: 0.051372 | Selected beta: 0.8\n","--- Run 20/100 (seed=19) ---\n","   MISE: 0.054921 | Selected beta: 0.1\n","--- Run 21/100 (seed=20) ---\n","   MISE: 0.059626 | Selected beta: 0.2\n","--- Run 22/100 (seed=21) ---\n","   MISE: 0.083353 | Selected beta: 0.1\n","--- Run 23/100 (seed=22) ---\n","   MISE: 0.015275 | Selected beta: 0.8\n","--- Run 24/100 (seed=23) ---\n","   MISE: 0.067243 | Selected beta: 0.8\n","--- Run 25/100 (seed=24) ---\n","   MISE: 0.035123 | Selected beta: 0.2\n","--- Run 26/100 (seed=25) ---\n","   MISE: 0.154299 | Selected beta: 25.6\n","--- Run 27/100 (seed=26) ---\n","   MISE: 0.055002 | Selected beta: 1.6\n","--- Run 28/100 (seed=27) ---\n","   MISE: 0.131304 | Selected beta: 0.1\n","--- Run 29/100 (seed=28) ---\n","   MISE: 0.060638 | Selected beta: 0.1\n","--- Run 30/100 (seed=29) ---\n","   MISE: 0.130478 | Selected beta: 0.1\n","--- Run 31/100 (seed=30) ---\n","   MISE: 0.074966 | Selected beta: 0.1\n","--- Run 32/100 (seed=31) ---\n","   MISE: 0.038794 | Selected beta: 0.8\n","--- Run 33/100 (seed=32) ---\n","   MISE: 0.060044 | Selected beta: 6.4\n","--- Run 34/100 (seed=33) ---\n","   MISE: 0.147324 | Selected beta: 25.6\n","--- Run 35/100 (seed=34) ---\n","   MISE: 0.071930 | Selected beta: 0.1\n","--- Run 36/100 (seed=35) ---\n","   MISE: 0.079786 | Selected beta: 0.1\n","--- Run 37/100 (seed=36) ---\n","   MISE: 0.034134 | Selected beta: 0.1\n","--- Run 38/100 (seed=37) ---\n","   MISE: 0.086227 | Selected beta: 0.1\n","--- Run 39/100 (seed=38) ---\n","   MISE: 0.159095 | Selected beta: 12.8\n","--- Run 40/100 (seed=39) ---\n","   MISE: 0.111965 | Selected beta: 25.6\n","--- Run 41/100 (seed=40) ---\n","   MISE: 0.072180 | Selected beta: 0.1\n","--- Run 42/100 (seed=41) ---\n","   MISE: 0.097533 | Selected beta: 25.6\n","--- Run 43/100 (seed=42) ---\n","   MISE: 0.096347 | Selected beta: 1.6\n","--- Run 44/100 (seed=43) ---\n","   MISE: 0.098423 | Selected beta: 0.4\n","--- Run 45/100 (seed=44) ---\n","   MISE: 0.039378 | Selected beta: 0.1\n","--- Run 46/100 (seed=45) ---\n","   MISE: 0.185031 | Selected beta: 0.4\n","--- Run 47/100 (seed=46) ---\n","   MISE: 0.112973 | Selected beta: 6.4\n","--- Run 48/100 (seed=47) ---\n","   MISE: 0.133866 | Selected beta: 25.6\n","--- Run 49/100 (seed=48) ---\n","   MISE: 0.156048 | Selected beta: 25.6\n","--- Run 50/100 (seed=49) ---\n","   MISE: 0.045703 | Selected beta: 0.8\n","--- Run 51/100 (seed=50) ---\n","   MISE: 0.052048 | Selected beta: 0.2\n","--- Run 52/100 (seed=51) ---\n","   MISE: 0.075399 | Selected beta: 0.4\n","--- Run 53/100 (seed=52) ---\n","   MISE: 0.072217 | Selected beta: 6.4\n","--- Run 54/100 (seed=53) ---\n","   MISE: 0.087687 | Selected beta: 0.1\n","--- Run 55/100 (seed=54) ---\n","   MISE: 0.056935 | Selected beta: 1.6\n","--- Run 56/100 (seed=55) ---\n","   MISE: 0.072907 | Selected beta: 0.1\n","--- Run 57/100 (seed=56) ---\n","   MISE: 0.054040 | Selected beta: 0.1\n","--- Run 58/100 (seed=57) ---\n","   MISE: 0.120251 | Selected beta: 25.6\n","--- Run 59/100 (seed=58) ---\n","   MISE: 0.039164 | Selected beta: 0.8\n","--- Run 60/100 (seed=59) ---\n","   MISE: 0.089110 | Selected beta: 12.8\n","--- Run 61/100 (seed=60) ---\n","   MISE: 0.050903 | Selected beta: 1.6\n","--- Run 62/100 (seed=61) ---\n","   MISE: 0.053018 | Selected beta: 0.1\n","--- Run 63/100 (seed=62) ---\n","   MISE: 0.069146 | Selected beta: 0.8\n","--- Run 64/100 (seed=63) ---\n","   MISE: 0.126426 | Selected beta: 0.2\n","--- Run 65/100 (seed=64) ---\n","   MISE: 0.048291 | Selected beta: 3.2\n","--- Run 66/100 (seed=65) ---\n","   MISE: 0.094452 | Selected beta: 1.6\n","--- Run 67/100 (seed=66) ---\n","   MISE: 0.138339 | Selected beta: 0.1\n","--- Run 68/100 (seed=67) ---\n","   MISE: 0.090389 | Selected beta: 1.6\n","--- Run 69/100 (seed=68) ---\n","   MISE: 0.024730 | Selected beta: 0.1\n","--- Run 70/100 (seed=69) ---\n","   MISE: 0.062247 | Selected beta: 1.6\n","--- Run 71/100 (seed=70) ---\n","   MISE: 0.158272 | Selected beta: 0.2\n","--- Run 72/100 (seed=71) ---\n","   MISE: 0.053618 | Selected beta: 0.1\n","--- Run 73/100 (seed=72) ---\n","   MISE: 0.142125 | Selected beta: 25.6\n","--- Run 74/100 (seed=73) ---\n","   MISE: 0.130403 | Selected beta: 0.1\n","--- Run 75/100 (seed=74) ---\n","   MISE: 0.049564 | Selected beta: 0.1\n","--- Run 76/100 (seed=75) ---\n","   MISE: 0.042146 | Selected beta: 0.1\n","--- Run 77/100 (seed=76) ---\n","   MISE: 0.114381 | Selected beta: 0.1\n","--- Run 78/100 (seed=77) ---\n","   MISE: 0.047765 | Selected beta: 3.2\n","--- Run 79/100 (seed=78) ---\n","   MISE: 0.067690 | Selected beta: 0.2\n","--- Run 80/100 (seed=79) ---\n","   MISE: 0.076118 | Selected beta: 1.6\n","--- Run 81/100 (seed=80) ---\n","   MISE: 0.136732 | Selected beta: 25.6\n","--- Run 82/100 (seed=81) ---\n","   MISE: 0.112766 | Selected beta: 12.8\n","--- Run 83/100 (seed=82) ---\n","   MISE: 0.083056 | Selected beta: 12.8\n","--- Run 84/100 (seed=83) ---\n","   MISE: 0.065569 | Selected beta: 0.4\n","--- Run 85/100 (seed=84) ---\n","   MISE: 0.044233 | Selected beta: 3.2\n","--- Run 86/100 (seed=85) ---\n","   MISE: 0.077181 | Selected beta: 3.2\n","--- Run 87/100 (seed=86) ---\n","   MISE: 0.116517 | Selected beta: 25.6\n","--- Run 88/100 (seed=87) ---\n","   MISE: 0.096656 | Selected beta: 0.1\n","--- Run 89/100 (seed=88) ---\n","   MISE: 0.068464 | Selected beta: 0.1\n","--- Run 90/100 (seed=89) ---\n","   MISE: 0.099502 | Selected beta: 0.1\n","--- Run 91/100 (seed=90) ---\n","   MISE: 0.067585 | Selected beta: 3.2\n","--- Run 92/100 (seed=91) ---\n","   MISE: 0.095248 | Selected beta: 6.4\n","--- Run 93/100 (seed=92) ---\n","   MISE: 0.092244 | Selected beta: 0.4\n","--- Run 94/100 (seed=93) ---\n","   MISE: 0.098444 | Selected beta: 6.4\n","--- Run 95/100 (seed=94) ---\n","   MISE: 0.040496 | Selected beta: 0.1\n","--- Run 96/100 (seed=95) ---\n","   MISE: 0.066508 | Selected beta: 0.1\n","--- Run 97/100 (seed=96) ---\n","   MISE: 0.063325 | Selected beta: 0.8\n","--- Run 98/100 (seed=97) ---\n","   MISE: 0.105102 | Selected beta: 25.6\n","--- Run 99/100 (seed=98) ---\n","   MISE: 0.071878 | Selected beta: 0.1\n","--- Run 100/100 (seed=99) ---\n","   MISE: 0.120167 | Selected beta: 0.1\n","\n","========================================\n","Simulation Summary (K=100)\n","========================================\n","Mean MISE : 0.084068\n","Std MISE  : 0.035097\n","SE MISE   : 0.003510\n","----------------------------------------\n","Mean selected beta : 4.675\n","Std selected beta  : 8.35279\n","SE selected beta   : 0.835279\n"]},{"name":"stderr","output_type":"stream","text":["/tmp/ipython-input-823062166.py:22: PerformanceWarning: DataFrame is highly fragmented.  This is usually the result of calling `frame.insert` many times, which has poor performance.  Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n","  df_out[f\"h_hat_seed_{seed}\"] = results[\"curves_mat\"][i, :]\n","/tmp/ipython-input-823062166.py:24: PerformanceWarning: DataFrame is highly fragmented.  This is usually the result of calling `frame.insert` many times, which has poor performance.  Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n","  df_out[\"mean_h_hat\"] = results[\"mean_curve\"]\n","/tmp/ipython-input-823062166.py:25: PerformanceWarning: DataFrame is highly fragmented.  This is usually the result of calling `frame.insert` many times, which has poor performance.  Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n","  df_out[\"se_h_hat\"] = results[\"se_curve\"]\n"]},{"name":"stdout","output_type":"stream","text":["\n","Results saved to: /content/drive/MyDrive/Colab Notebooks/CTE_Baseline/KRR_methods/Results/ours_synth_sample_500_noise_1_seeds_0-99.csv\n"]}],"source":["# ======================================================================\n","# Main Execution\n","# ======================================================================\n","NOISE_STD = 1.0\n","N_SAMPLES =  500\n","\n","if __name__ == \"__main__\":\n","    # 1. Run the simulation\n","    results = run_simulation(\n","        K=K_RUNS,\n","        noise_std=NOISE_STD,\n","        first_seed=FIRST_SEED\n","    )\n","\n","    # 3. Save Results to CSV\n","    output_csv_name = f\"{results['run_tag']}.csv\"\n","    output_path = RESULTS_DIR / output_csv_name\n","\n","    # Construct DataFrame\n","    df_out = pd.DataFrame({\"t\": results[\"t_grid\"]})\n","    for i, seed in enumerate(results[\"seeds\"]):\n","        df_out[f\"h_hat_seed_{seed}\"] = results[\"curves_mat\"][i, :]\n","\n","    df_out[\"mean_h_hat\"] = results[\"mean_curve\"]\n","    df_out[\"se_h_hat\"] = results[\"se_curve\"]\n","\n","    df_out.to_csv(output_path, index=False)\n","    print(f\"\\nResults saved to: {output_path}\")"]}],"metadata":{"colab":{"authorship_tag":"ABX9TyNJW89WzHBvUDhw6N+QyPVb","name":"","version":""},"kernelspec":{"display_name":"Python 3","name":"python3"},"language_info":{"name":"python"}},"nbformat":4,"nbformat_minor":0}