{"cells":[{"cell_type":"markdown","metadata":{"id":"qGtu4Hlw6_P2"},"source":["# Synthetic Data Experiment (Plug-in Estimator)\n","\n","This notebook runs the synthetic data experiment using the **Joint Kernel Plug-in Estimator**.\n","\n","**Workflow:**\n","1. **Data Generation:** Generates synthetic data $(X, T, Y)$ based on the defined DGP (Data Generating Process).\n","2. **Simulation:** Runs $K$ Monte Carlo repetitions. In each run:\n","   - Generates a fresh dataset using a specific seed.\n","   - Fits the joint plug-in estimator on the $(X, T)$ space.\n","   - Evaluates the estimated dose-response curve $\\hat{h}(t)$ on a grid.\n","   - Computes the Mean Integrated Squared Error (MISE) against the true $h^*(t)$.\n","3. **Evaluation:** Aggregates results, plots the estimated curve against the ground truth, and saves the results."]},
{"cell_type":"code","execution_count":35,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"elapsed":660,"status":"ok","timestamp":1769544086361,"user":{"displayName":"D K","userId":"02556183042422178006"},"user_tz":300},"id":"n11P9Ebu6-A7","outputId":"580797f2-59f2-492c-9a03-4b41c300f414"},"outputs":[{"output_type":"stream","name":"stdout","text":["Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount(\"/content/drive\", force_remount=True).\n","Working Directory: /content/drive/MyDrive/Colab Notebooks/CTE_Baseline\n"]}],"source":["import sys\n",
"import pathlib\n",
"import numpy as np\n",
"import pandas as pd\n",
"import matplotlib.pyplot as plt\n",
"\n",
"# --- Environment Setup ---\n",
"# On Colab the project lives on the mounted Drive; elsewhere fall back to the\n",
"# current working directory.\n",
"try:\n",
"    from google.colab import drive\n",
"    drive.mount('/content/drive')\n",
"    BASE_DIR = pathlib.Path(\"/content/drive/MyDrive/Colab Notebooks/CTE_Baseline\")\n",
"except ImportError:\n",
"    BASE_DIR = pathlib.Path(\".\").resolve()\n",
"\n",
"# Add project root to path (entries already present are skipped, so re-running\n",
"# this cell does not pile up duplicate sys.path entries)\n",
"if str(BASE_DIR) not in sys.path:\n",
"    sys.path.append(str(BASE_DIR))\n",
"if str(BASE_DIR / \"KRR_methods\") not in sys.path:\n",
"    sys.path.append(str(BASE_DIR / \"KRR_methods\")) # For synthetic_dgps import 
if needed\n","\n","# Project imports\n","from KRR_methods.synthetic_dgps import (\n","    generate_unified_data,\n","    split_data,\n","    approximate_h_star,\n",")\n","from KRR_methods.algorithms.estimators_plugin import run_plugin_on_original_grid_joint\n","\n","print(f\"Working Directory: {BASE_DIR}\")"]},{"cell_type":"code","execution_count":36,"metadata":{"executionInfo":{"elapsed":11,"status":"ok","timestamp":1769544086361,"user":{"displayName":"D K","userId":"02556183042422178006"},"user_tz":300},"id":"De3Nke4s7BVS"},"outputs":[],"source":["# ======================================================================\n","# Configuration & Hyperparameters\n","# ======================================================================\n","\n","# --- Paths ---\n","RESULTS_DIR = BASE_DIR / \"KRR_methods\" / \"Results\"\n","\n","# --- Data Generation Parameters ---\n","D_X = 10\n","NOISE_STD = 1.0\n","N_SAMPLES = 1000\n","N_SIMULATIONS = 100\n","FIRST_SEED = 1  # As per original script\n","\n","# --- Estimator Hyperparameters (Joint Kernel) ---\n","# Grid for Ridge Parameter (beta)\n","C_VAL = 0.1\n","BETA_GRID = np.array([C_VAL * (2**i) for i in range(0, 9)], dtype=float)\n","\n","# Kernel Parameters (Matern)\n","LENGTH_SCALE_ND = 3.0\n","NU_ND = 1.5\n","KERNEL_TYPE_ND = \"matern\"\n","\n","# Evaluation Grid\n","T_GRID = np.linspace(-np.pi, np.pi, 1000)"]},
{"cell_type":"code","execution_count":37,"metadata":{"executionInfo":{"elapsed":6,"status":"ok","timestamp":1769544086365,"user":{"displayName":"D K","userId":"02556183042422178006"},"user_tz":300},"id":"sWXnEUy_7EIx"},"outputs":[],"source":["# ======================================================================\n",
"# Helper Functions\n",
"# ======================================================================\n",
"\n",
"def run_single_plugin_joint(X, T, Y, t_grid):\n",
"    \"\"\"\n",
"    Runs a single instance of the joint plug-in estimator.\n",
"\n",
"    1. Splits data into training/validation (D1/D2).\n",
"    2. Runs the joint kernel estimator (hold-out validation on D2).\n",
"    3. Evaluates the resulting function on the t_grid.\n",
"    4. Returns h_hat(t_grid), the selected beta, and the raw estimator output.\n",
"\n",
"    Returns\n",
"    -------\n",
"    tuple\n",
"        (h_hat_vals, beta_sel, out) where h_hat_vals is out[\"h_hat_joint\"]\n",
"        evaluated on t_grid, beta_sel is out[\"best_beta\"] cast to float, and\n",
"        out is the full object returned by run_plugin_on_original_grid_joint.\n",
"    \"\"\"\n",
"    # 1. Split data\n",
"    D1, D2 = split_data(X, T, Y)\n",
"\n",
"    # 2. Fit Joint Plug-in Estimator (joint kernel on (X, T))\n",
"    # This function selects the best beta by validation MSE on D2.\n",
"    out = run_plugin_on_original_grid_joint(\n",
"        D1=D1,\n",
"        D2=D2,\n",
"        beta_grid_for_f=BETA_GRID,\n",
"        nu_nd=NU_ND,\n",
"        length_scale_nd=LENGTH_SCALE_ND,\n",
"        kernel_type_nd=KERNEL_TYPE_ND,\n",
"    )\n",
"\n",
"    # 3. Evaluate on grid\n",
"    h_hat_vals = out[\"h_hat_joint\"](t_grid)\n",
"\n",
"    # 4. Extract selected beta\n",
"    beta_sel = float(out[\"best_beta\"])\n",
"\n",
"    return h_hat_vals, beta_sel, out\n"]},
{"cell_type":"code","execution_count":38,"metadata":{"executionInfo":{"elapsed":86,"status":"ok","timestamp":1769544086452,"user":{"displayName":"D K","userId":"02556183042422178006"},"user_tz":300},"id":"l1EpGzEU7GDw"},"outputs":[],"source":["# ======================================================================\n",
"# Simulation Logic\n",
"# ======================================================================\n",
"\n",
"def run_simulation(K=100, noise_std=1.0, first_seed=100, prefix=\"plugin_synth\", n_samples=None):\n",
"    \"\"\"\n",
"    Runs the joint-kernel plug-in KRR simulation K times.\n",
"\n",
"    Additionally:\n",
"      - Stores the beta selected by hold-out validation on D2 for each run\n",
"      - Prints the mean / std / SE of the selected beta at the end\n",
"\n",
"    Parameters\n",
"    ----------\n",
"    K : int\n",
"        Number of Monte Carlo repetitions (must be >= 1).\n",
"    noise_std : float\n",
"        Noise level forwarded to generate_unified_data.\n",
"    first_seed : int\n",
"        Seed of the first repetition; repetition k uses first_seed + k.\n",
"    prefix : str\n",
"        Leading part of the returned \"run_tag\".\n",
"    n_samples : int or None\n",
"        Sample size of each generated dataset. None (default) falls back to\n",
"        the notebook-level N_SAMPLES at call time, so existing calls keep\n",
"        working; pass it explicitly to avoid depending on global state.\n",
"\n",
"    Returns\n",
"    -------\n",
"    dict\n",
"        Seeds, evaluation grid, true curve, per-run MISE values and curves,\n",
"        the pointwise mean / SE curve, mean / std / SE of the selected beta,\n",
"        and a \"run_tag\" string built from prefix, sample size, noise and seeds.\n",
"\n",
"    Raises\n",
"    ------\n",
"    ValueError\n",
"        If K < 1.\n",
"\n",
"    Notes\n",
"    -----\n",
"    - MISE is computed as the mean squared error over the points of T_GRID.\n",
"    - Std / SE use ddof=1, so they are NaN when K == 1.\n",
"    \"\"\"\n",
"    if K < 1:\n",
"        # Fail early with a clear message instead of inside np.vstack later on\n",
"        raise ValueError(f\"K must be a positive integer, got {K!r}\")\n",
"    if n_samples is None:\n",
"        # Resolved at call time so cells that reassign N_SAMPLES still take effect\n",
"        n_samples = N_SAMPLES\n",
"\n",
"    RESULTS_DIR.mkdir(parents=True, exist_ok=True)\n",
"\n",
"    # 1. Compute True Curve h*(t) once\n",
"    h_star_vals = approximate_h_star(T_GRID)\n",
"\n",
"    mise_list = []\n",
"    curves_list = []\n",
"    beta_list = []\n",
"    seeds = []\n",
"\n",
"    print(f\"🚀 Running {K} simulations for joint plug-in KRR (noise_std={noise_std})...\")\n",
"\n",
"    for k in range(K):\n",
"        seed = first_seed + k\n",
"        seeds.append(seed)\n",
"\n",
"        # [Critical] Set Global Numpy Seed for reproducibility\n",
"        # This ensures exactly the same data generation and internal estimator behavior per seed\n",
"        np.random.seed(seed)\n",
"\n",
"        print(f\"--- Run {k + 1}/{K} (seed={seed}) ---\")\n",
"\n",
"        # 2. Generate Synthetic Data\n",
"        X, T, Y = generate_unified_data(n_samples, noise_std)\n",
"\n",
"        # 3. Fit Estimator (now also returns beta_sel)\n",
"        h_hat_vals, beta_sel, _ = run_single_plugin_joint(X, T, Y, T_GRID)\n",
"        curves_list.append(h_hat_vals)\n",
"        beta_list.append(beta_sel)\n",
"\n",
"        # 4. Compute MISE\n",
"        mise = np.mean((h_hat_vals - h_star_vals) ** 2)\n",
"        mise_list.append(mise)\n",
"\n",
"        print(f\"   MISE: {mise:.6f} | Selected beta: {beta_sel:.6g}\")\n",
"\n",
"    # --- Aggregation & Statistics ---\n",
"    mise_arr = np.array(mise_list, dtype=float)\n",
"    mise_std = mise_arr.std(ddof=1)\n",
"    curves_mat = np.vstack(curves_list)  # Shape: (K, len(T_GRID))\n",
"\n",
"    mean_curve = curves_mat.mean(axis=0)\n",
"    std_curve = curves_mat.std(axis=0, ddof=1)\n",
"    se_curve = std_curve / np.sqrt(K)\n",
"\n",
"    beta_arr = np.array(beta_list, dtype=float)\n",
"    beta_mean = float(beta_arr.mean())\n",
"    beta_std = float(beta_arr.std(ddof=1))\n",
"    beta_se = float(beta_std / np.sqrt(K))\n",
"\n",
"    print(\"\\n\" + \"=\" * 30)\n",
"    print(f\"Simulation Summary (K={K})\")\n",
"    print(\"=\" * 30)\n",
"    print(f\"Mean MISE : {mise_arr.mean():.6f}\")\n",
"    print(f\"Std MISE  : {mise_std:.6f}\")\n",
"    print(f\"SE MISE   : {mise_std / np.sqrt(K):.6f}\")\n",
"\n",
"    print(\"-\" * 30)\n",
"    print(f\"Mean selected beta : {beta_mean:.6g}\")\n",
"    print(f\"Std selected beta  : {beta_std:.6g}\")\n",
"    print(f\"SE selected beta   : {beta_se:.6g}\")\n",
"\n",
"    return {\n",
"        \"seeds\": seeds,\n",
"        \"t_grid\": T_GRID,\n",
"        \"h_star\": h_star_vals,\n",
"        \"mise_all\": mise_arr,\n",
"        \"curves_mat\": curves_mat,\n",
"        \"mean_curve\": mean_curve,\n",
"        \"se_curve\": se_curve,\n",
"        \"beta_selected_all\": beta_arr,\n",
"        \"beta_selected_mean\": beta_mean,\n",
"        \"beta_selected_std\": beta_std,\n",
"        \"beta_selected_se\": beta_se,\n",
"        \"run_tag\": f\"{prefix}_sample_{n_samples}_noise_{noise_std}_seeds_{seeds[0]}-{seeds[-1]}\",\n",
"    }\n"]},
{"cell_type":"markdown","metadata":{"id":"BfwSMlBu3iem"},"source":["# n 
=1000"]},{"cell_type":"code","execution_count":39,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"hXkOei7Q7Jcw","executionInfo":{"status":"ok","timestamp":1769546535242,"user_tz":300,"elapsed":2448788,"user":{"displayName":"D K","userId":"02556183042422178006"}},"outputId":"4e9d4e50-5651-4fbb-e9b0-fa018bbabd22"},"outputs":[{"output_type":"stream","name":"stdout","text":["🚀 Running 100 simulations for joint plug-in KRR (noise_std=1.0)...\n","--- Run 1/100 (seed=1) ---\n","   MISE: 0.139469 | Selected beta: 1.6\n","--- Run 2/100 (seed=2) ---\n","   MISE: 0.068129 | Selected beta: 1.6\n","--- Run 3/100 (seed=3) ---\n","   MISE: 0.091892 | Selected beta: 0.8\n","--- Run 4/100 (seed=4) ---\n","   MISE: 0.084461 | Selected beta: 1.6\n","--- Run 5/100 (seed=5) ---\n","   MISE: 0.075045 | Selected beta: 1.6\n","--- Run 6/100 (seed=6) ---\n","   MISE: 0.162954 | Selected beta: 6.4\n","--- Run 7/100 (seed=7) ---\n","   MISE: 0.129424 | Selected beta: 3.2\n","--- Run 8/100 (seed=8) ---\n","   MISE: 0.095975 | Selected beta: 3.2\n","--- Run 9/100 (seed=9) ---\n","   MISE: 0.024441 | Selected beta: 0.8\n","--- Run 10/100 (seed=10) ---\n","   MISE: 0.150849 | Selected beta: 3.2\n","--- Run 11/100 (seed=11) ---\n","   MISE: 0.072525 | Selected beta: 1.6\n","--- Run 12/100 (seed=12) ---\n","   MISE: 0.162478 | Selected beta: 12.8\n","--- Run 13/100 (seed=13) ---\n","   MISE: 0.091574 | Selected beta: 1.6\n","--- Run 14/100 (seed=14) ---\n","   MISE: 0.103608 | Selected beta: 0.8\n","--- Run 15/100 (seed=15) ---\n","   MISE: 0.101459 | Selected beta: 1.6\n","--- Run 16/100 (seed=16) ---\n","   MISE: 0.071132 | Selected beta: 1.6\n","--- Run 17/100 (seed=17) ---\n","   MISE: 0.064951 | Selected beta: 0.8\n","--- Run 18/100 (seed=18) ---\n","   MISE: 0.045375 | Selected beta: 0.8\n","--- Run 19/100 (seed=19) ---\n","   MISE: 0.106401 | Selected beta: 1.6\n","--- Run 20/100 (seed=20) ---\n","   MISE: 0.133123 | Selected beta: 6.4\n","--- Run 21/100 (seed=21) ---\n","  
 MISE: 0.121685 | Selected beta: 6.4\n","--- Run 22/100 (seed=22) ---\n","   MISE: 0.103533 | Selected beta: 1.6\n","--- Run 23/100 (seed=23) ---\n","   MISE: 0.105222 | Selected beta: 3.2\n","--- Run 24/100 (seed=24) ---\n","   MISE: 0.064475 | Selected beta: 0.8\n","--- Run 25/100 (seed=25) ---\n","   MISE: 0.135552 | Selected beta: 1.6\n","--- Run 26/100 (seed=26) ---\n","   MISE: 0.092723 | Selected beta: 1.6\n","--- Run 27/100 (seed=27) ---\n","   MISE: 0.056712 | Selected beta: 1.6\n","--- Run 28/100 (seed=28) ---\n","   MISE: 0.059463 | Selected beta: 0.8\n","--- Run 29/100 (seed=29) ---\n","   MISE: 0.133210 | Selected beta: 1.6\n","--- Run 30/100 (seed=30) ---\n","   MISE: 0.134021 | Selected beta: 6.4\n","--- Run 31/100 (seed=31) ---\n","   MISE: 0.115904 | Selected beta: 1.6\n","--- Run 32/100 (seed=32) ---\n","   MISE: 0.072761 | Selected beta: 1.6\n","--- Run 33/100 (seed=33) ---\n","   MISE: 0.084051 | Selected beta: 1.6\n","--- Run 34/100 (seed=34) ---\n","   MISE: 0.101830 | Selected beta: 1.6\n","--- Run 35/100 (seed=35) ---\n","   MISE: 0.097914 | Selected beta: 1.6\n","--- Run 36/100 (seed=36) ---\n","   MISE: 0.151528 | Selected beta: 3.2\n","--- Run 37/100 (seed=37) ---\n","   MISE: 0.183283 | Selected beta: 0.8\n","--- Run 38/100 (seed=38) ---\n","   MISE: 0.057395 | Selected beta: 0.8\n","--- Run 39/100 (seed=39) ---\n","   MISE: 0.110656 | Selected beta: 1.6\n","--- Run 40/100 (seed=40) ---\n","   MISE: 0.112257 | Selected beta: 1.6\n","--- Run 41/100 (seed=41) ---\n","   MISE: 0.055491 | Selected beta: 0.4\n","--- Run 42/100 (seed=42) ---\n","   MISE: 0.071217 | Selected beta: 1.6\n","--- Run 43/100 (seed=43) ---\n","   MISE: 0.109936 | Selected beta: 3.2\n","--- Run 44/100 (seed=44) ---\n","   MISE: 0.106591 | Selected beta: 1.6\n","--- Run 45/100 (seed=45) ---\n","   MISE: 0.129154 | Selected beta: 3.2\n","--- Run 46/100 (seed=46) ---\n","   MISE: 0.067174 | Selected beta: 1.6\n","--- Run 47/100 (seed=47) ---\n","   MISE: 0.097756 | 
Selected beta: 3.2\n","--- Run 48/100 (seed=48) ---\n","   MISE: 0.073957 | Selected beta: 0.8\n","--- Run 49/100 (seed=49) ---\n","   MISE: 0.097732 | Selected beta: 0.8\n","--- Run 50/100 (seed=50) ---\n","   MISE: 0.082843 | Selected beta: 0.8\n","--- Run 51/100 (seed=51) ---\n","   MISE: 0.094602 | Selected beta: 1.6\n","--- Run 52/100 (seed=52) ---\n","   MISE: 0.081546 | Selected beta: 1.6\n","--- Run 53/100 (seed=53) ---\n","   MISE: 0.073584 | Selected beta: 3.2\n","--- Run 54/100 (seed=54) ---\n","   MISE: 0.051184 | Selected beta: 0.8\n","--- Run 55/100 (seed=55) ---\n","   MISE: 0.075457 | Selected beta: 0.8\n","--- Run 56/100 (seed=56) ---\n","   MISE: 0.090961 | Selected beta: 3.2\n","--- Run 57/100 (seed=57) ---\n","   MISE: 0.164410 | Selected beta: 12.8\n","--- Run 58/100 (seed=58) ---\n","   MISE: 0.150470 | Selected beta: 6.4\n","--- Run 59/100 (seed=59) ---\n","   MISE: 0.096175 | Selected beta: 3.2\n","--- Run 60/100 (seed=60) ---\n","   MISE: 0.167175 | Selected beta: 6.4\n","--- Run 61/100 (seed=61) ---\n","   MISE: 0.125799 | Selected beta: 6.4\n","--- Run 62/100 (seed=62) ---\n","   MISE: 0.107906 | Selected beta: 3.2\n","--- Run 63/100 (seed=63) ---\n","   MISE: 0.115807 | Selected beta: 1.6\n","--- Run 64/100 (seed=64) ---\n","   MISE: 0.140800 | Selected beta: 3.2\n","--- Run 65/100 (seed=65) ---\n","   MISE: 0.111923 | Selected beta: 1.6\n","--- Run 66/100 (seed=66) ---\n","   MISE: 0.050067 | Selected beta: 0.4\n","--- Run 67/100 (seed=67) ---\n","   MISE: 0.078685 | Selected beta: 0.8\n","--- Run 68/100 (seed=68) ---\n","   MISE: 0.078533 | Selected beta: 1.6\n","--- Run 69/100 (seed=69) ---\n","   MISE: 0.090831 | Selected beta: 1.6\n","--- Run 70/100 (seed=70) ---\n","   MISE: 0.083266 | Selected beta: 1.6\n","--- Run 71/100 (seed=71) ---\n","   MISE: 0.143699 | Selected beta: 1.6\n","--- Run 72/100 (seed=72) ---\n","   MISE: 0.109768 | Selected beta: 1.6\n","--- Run 73/100 (seed=73) ---\n","   MISE: 0.074936 | Selected beta: 
0.8\n","--- Run 74/100 (seed=74) ---\n","   MISE: 0.031919 | Selected beta: 0.4\n","--- Run 75/100 (seed=75) ---\n","   MISE: 0.105802 | Selected beta: 1.6\n","--- Run 76/100 (seed=76) ---\n","   MISE: 0.072554 | Selected beta: 0.8\n","--- Run 77/100 (seed=77) ---\n","   MISE: 0.093493 | Selected beta: 0.8\n","--- Run 78/100 (seed=78) ---\n","   MISE: 0.072803 | Selected beta: 1.6\n","--- Run 79/100 (seed=79) ---\n","   MISE: 0.081619 | Selected beta: 1.6\n","--- Run 80/100 (seed=80) ---\n","   MISE: 0.096917 | Selected beta: 3.2\n","--- Run 81/100 (seed=81) ---\n","   MISE: 0.094732 | Selected beta: 1.6\n","--- Run 82/100 (seed=82) ---\n","   MISE: 0.098326 | Selected beta: 0.8\n","--- Run 83/100 (seed=83) ---\n","   MISE: 0.103640 | Selected beta: 1.6\n","--- Run 84/100 (seed=84) ---\n","   MISE: 0.102099 | Selected beta: 0.8\n","--- Run 85/100 (seed=85) ---\n","   MISE: 0.067886 | Selected beta: 1.6\n","--- Run 86/100 (seed=86) ---\n","   MISE: 0.053090 | Selected beta: 0.8\n","--- Run 87/100 (seed=87) ---\n","   MISE: 0.098398 | Selected beta: 1.6\n","--- Run 88/100 (seed=88) ---\n","   MISE: 0.070665 | Selected beta: 0.8\n","--- Run 89/100 (seed=89) ---\n","   MISE: 0.110392 | Selected beta: 0.8\n","--- Run 90/100 (seed=90) ---\n","   MISE: 0.137140 | Selected beta: 3.2\n","--- Run 91/100 (seed=91) ---\n","   MISE: 0.135920 | Selected beta: 3.2\n","--- Run 92/100 (seed=92) ---\n","   MISE: 0.137389 | Selected beta: 3.2\n","--- Run 93/100 (seed=93) ---\n","   MISE: 0.086698 | Selected beta: 1.6\n","--- Run 94/100 (seed=94) ---\n","   MISE: 0.084771 | Selected beta: 1.6\n","--- Run 95/100 (seed=95) ---\n","   MISE: 0.067219 | Selected beta: 0.8\n","--- Run 96/100 (seed=96) ---\n","   MISE: 0.137892 | Selected beta: 3.2\n","--- Run 97/100 (seed=97) ---\n","   MISE: 0.060564 | Selected beta: 0.8\n","--- Run 98/100 (seed=98) ---\n","   MISE: 0.095583 | Selected beta: 1.6\n","--- Run 99/100 (seed=99) ---\n","   MISE: 0.140482 | Selected beta: 3.2\n","--- Run 100/100 
(seed=100) ---\n","   MISE: 0.118376 | Selected beta: 3.2\n","\n","==============================\n","Simulation Summary (K=100)\n","==============================\n","Mean MISE : 0.098752\n","Std MISE  : 0.031934\n","SE MISE   : 0.003193\n","------------------------------\n","Mean selected beta : 2.244\n","Std selected beta  : 2.11346\n","SE selected beta   : 0.211346\n"]},{"output_type":"stream","name":"stderr","text":["/tmp/ipython-input-1426055032.py:20: PerformanceWarning: DataFrame is highly fragmented.  This is usually the result of calling `frame.insert` many times, which has poor performance.  Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n","  df_out[f\"h_hat_seed_{seed}\"] = results[\"curves_mat\"][i, :]\n","/tmp/ipython-input-1426055032.py:22: PerformanceWarning: DataFrame is highly fragmented.  This is usually the result of calling `frame.insert` many times, which has poor performance.  Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n","  df_out[\"mean_h_hat\"] = results[\"mean_curve\"]\n","/tmp/ipython-input-1426055032.py:23: PerformanceWarning: DataFrame is highly fragmented.  This is usually the result of calling `frame.insert` many times, which has poor performance.  Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n","  df_out[\"se_h_hat\"] = results[\"se_curve\"]\n"]},{"output_type":"stream","name":"stdout","text":["\n","Results saved to: /content/drive/MyDrive/Colab Notebooks/CTE_Baseline/KRR_methods/Results/plugin_synth_sample_1000_noise_1.0_seeds_1-100.csv\n"]}],"source":["# ======================================================================\n",
"# Main Execution\n",
"# ======================================================================\n",
"# Reset the data settings so this cell gives the n=1000 run even if a later\n",
"# cell changed them.\n",
"NOISE_STD = 1.0\n",
"N_SAMPLES = 1000\n",
"\n",
"# 1. Run the simulation\n",
"results = run_simulation(\n",
"    K=N_SIMULATIONS,\n",
"    noise_std=NOISE_STD,\n",
"    first_seed=FIRST_SEED\n",
")\n",
"\n",
"# 2. Save Results to CSV\n",
"output_csv_name = f\"{results['run_tag']}.csv\"\n",
"output_path = RESULTS_DIR / output_csv_name\n",
"\n",
"# Construct DataFrame\n",
"# All columns are collected first and the frame is built once: inserting one\n",
"# column per seed fragments the frame and triggers pandas' PerformanceWarning.\n",
"columns = {\"t\": T_GRID}\n",
"for i, seed in enumerate(results[\"seeds\"]):\n",
"    columns[f\"h_hat_seed_{seed}\"] = results[\"curves_mat\"][i, :]\n",
"columns[\"mean_h_hat\"] = results[\"mean_curve\"]\n",
"columns[\"se_h_hat\"] = results[\"se_curve\"]\n",
"df_out = pd.DataFrame(columns)\n",
"\n",
"df_out.to_csv(output_path, index=False)\n",
"print(f\"\\nResults saved to: {output_path}\")"]},
{"cell_type":"markdown","metadata":{"id":"d-NkGETi3XIu"},"source":["# n =500"]},{"cell_type":"code","execution_count":40,"metadata":{"id":"syILRUX03em4","executionInfo":{"status":"ok","timestamp":1769546535245,"user_tz":300,"elapsed":1,"user":{"displayName":"D K","userId":"02556183042422178006"}}},"outputs":[],"source":["NOISE_STD = 1.0\n","N_SAMPLES =  500"]},{"cell_type":"code","execution_count":41,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"_uOkADDI3XYe","executionInfo":{"status":"ok","timestamp":1769547123303,"user_tz":300,"elapsed":588056,"user":{"displayName":"D K","userId":"02556183042422178006"}},"outputId":"10cf0279-c72d-489f-dc00-0d27681c3b51"},"outputs":[{"output_type":"stream","name":"stdout","text":["🚀 Running 100 simulations for joint plug-in KRR (noise_std=1.0)...\n","--- Run 1/100 (seed=1) ---\n","   MISE: 0.130149 | Selected beta: 6.4\n","--- Run 2/100 (seed=2) ---\n","   MISE: 0.081440 | Selected beta: 1.6\n","--- Run 3/100 (seed=3) ---\n","   MISE: 0.131045 | Selected beta: 6.4\n","--- Run 4/100 (seed=4) ---\n","   MISE: 0.111593 | Selected beta: 0.8\n","--- Run 5/100 (seed=5) ---\n","   MISE: 0.134249 | Selected beta: 3.2\n","--- Run 6/100 (seed=6) ---\n","   MISE: 0.084201 | Selected beta: 0.8\n","--- Run 7/100 (seed=7) ---\n","   MISE: 0.097417 | Selected 
beta: 0.8\n","--- Run 8/100 (seed=8) ---\n","   MISE: 0.143146 | Selected beta: 3.2\n","--- Run 9/100 (seed=9) ---\n","   MISE: 0.150465 | Selected beta: 6.4\n","--- Run 10/100 (seed=10) ---\n","   MISE: 0.146979 | Selected beta: 1.6\n","--- Run 11/100 (seed=11) ---\n","   MISE: 0.126138 | Selected beta: 6.4\n","--- Run 12/100 (seed=12) ---\n","   MISE: 0.124369 | Selected beta: 6.4\n","--- Run 13/100 (seed=13) ---\n","   MISE: 0.137668 | Selected beta: 1.6\n","--- Run 14/100 (seed=14) ---\n","   MISE: 0.120615 | Selected beta: 1.6\n","--- Run 15/100 (seed=15) ---\n","   MISE: 0.158270 | Selected beta: 0.8\n","--- Run 16/100 (seed=16) ---\n","   MISE: 0.121388 | Selected beta: 3.2\n","--- Run 17/100 (seed=17) ---\n","   MISE: 0.135812 | Selected beta: 3.2\n","--- Run 18/100 (seed=18) ---\n","   MISE: 0.135180 | Selected beta: 3.2\n","--- Run 19/100 (seed=19) ---\n","   MISE: 0.119272 | Selected beta: 3.2\n","--- Run 20/100 (seed=20) ---\n","   MISE: 0.130205 | Selected beta: 3.2\n","--- Run 21/100 (seed=21) ---\n","   MISE: 0.106999 | Selected beta: 0.4\n","--- Run 22/100 (seed=22) ---\n","   MISE: 0.051988 | Selected beta: 0.8\n","--- Run 23/100 (seed=23) ---\n","   MISE: 0.144086 | Selected beta: 3.2\n","--- Run 24/100 (seed=24) ---\n","   MISE: 0.099920 | Selected beta: 1.6\n","--- Run 25/100 (seed=25) ---\n","   MISE: 0.165407 | Selected beta: 3.2\n","--- Run 26/100 (seed=26) ---\n","   MISE: 0.133211 | Selected beta: 12.8\n","--- Run 27/100 (seed=27) ---\n","   MISE: 0.141327 | Selected beta: 3.2\n","--- Run 28/100 (seed=28) ---\n","   MISE: 0.106153 | Selected beta: 0.8\n","--- Run 29/100 (seed=29) ---\n","   MISE: 0.154784 | Selected beta: 0.8\n","--- Run 30/100 (seed=30) ---\n","   MISE: 0.100876 | Selected beta: 0.8\n","--- Run 31/100 (seed=31) ---\n","   MISE: 0.097593 | Selected beta: 1.6\n","--- Run 32/100 (seed=32) ---\n","   MISE: 0.058522 | Selected beta: 0.8\n","--- Run 33/100 (seed=33) ---\n","   MISE: 0.138806 | Selected beta: 3.2\n","--- Run 
34/100 (seed=34) ---\n","   MISE: 0.092522 | Selected beta: 0.8\n","--- Run 35/100 (seed=35) ---\n","   MISE: 0.134148 | Selected beta: 6.4\n","--- Run 36/100 (seed=36) ---\n","   MISE: 0.056906 | Selected beta: 0.4\n","--- Run 37/100 (seed=37) ---\n","   MISE: 0.125337 | Selected beta: 3.2\n","--- Run 38/100 (seed=38) ---\n","   MISE: 0.161244 | Selected beta: 12.8\n","--- Run 39/100 (seed=39) ---\n","   MISE: 0.102676 | Selected beta: 0.8\n","--- Run 40/100 (seed=40) ---\n","   MISE: 0.133965 | Selected beta: 6.4\n","--- Run 41/100 (seed=41) ---\n","   MISE: 0.128366 | Selected beta: 6.4\n","--- Run 42/100 (seed=42) ---\n","   MISE: 0.140199 | Selected beta: 6.4\n","--- Run 43/100 (seed=43) ---\n","   MISE: 0.148051 | Selected beta: 3.2\n","--- Run 44/100 (seed=44) ---\n","   MISE: 0.059889 | Selected beta: 0.4\n","--- Run 45/100 (seed=45) ---\n","   MISE: 0.194554 | Selected beta: 3.2\n","--- Run 46/100 (seed=46) ---\n","   MISE: 0.144135 | Selected beta: 6.4\n","--- Run 47/100 (seed=47) ---\n","   MISE: 0.145258 | Selected beta: 3.2\n","--- Run 48/100 (seed=48) ---\n","   MISE: 0.145367 | Selected beta: 0.8\n","--- Run 49/100 (seed=49) ---\n","   MISE: 0.070099 | Selected beta: 0.8\n","--- Run 50/100 (seed=50) ---\n","   MISE: 0.070603 | Selected beta: 0.4\n","--- Run 51/100 (seed=51) ---\n","   MISE: 0.109745 | Selected beta: 0.8\n","--- Run 52/100 (seed=52) ---\n","   MISE: 0.123859 | Selected beta: 6.4\n","--- Run 53/100 (seed=53) ---\n","   MISE: 0.100638 | Selected beta: 1.6\n","--- Run 54/100 (seed=54) ---\n","   MISE: 0.133642 | Selected beta: 6.4\n","--- Run 55/100 (seed=55) ---\n","   MISE: 0.096004 | Selected beta: 0.4\n","--- Run 56/100 (seed=56) ---\n","   MISE: 0.109382 | Selected beta: 1.6\n","--- Run 57/100 (seed=57) ---\n","   MISE: 0.111229 | Selected beta: 3.2\n","--- Run 58/100 (seed=58) ---\n","   MISE: 0.103617 | Selected beta: 3.2\n","--- Run 59/100 (seed=59) ---\n","   MISE: 0.087826 | Selected beta: 0.8\n","--- Run 60/100 (seed=60) 
---\n","   MISE: 0.114844 | Selected beta: 3.2\n","--- Run 61/100 (seed=61) ---\n","   MISE: 0.104478 | Selected beta: 1.6\n","--- Run 62/100 (seed=62) ---\n","   MISE: 0.130846 | Selected beta: 1.6\n","--- Run 63/100 (seed=63) ---\n","   MISE: 0.134796 | Selected beta: 1.6\n","--- Run 64/100 (seed=64) ---\n","   MISE: 0.110881 | Selected beta: 3.2\n","--- Run 65/100 (seed=65) ---\n","   MISE: 0.141867 | Selected beta: 6.4\n","--- Run 66/100 (seed=66) ---\n","   MISE: 0.177671 | Selected beta: 3.2\n","--- Run 67/100 (seed=67) ---\n","   MISE: 0.155656 | Selected beta: 6.4\n","--- Run 68/100 (seed=68) ---\n","   MISE: 0.088996 | Selected beta: 1.6\n","--- Run 69/100 (seed=69) ---\n","   MISE: 0.121245 | Selected beta: 3.2\n","--- Run 70/100 (seed=70) ---\n","   MISE: 0.164109 | Selected beta: 1.6\n","--- Run 71/100 (seed=71) ---\n","   MISE: 0.073595 | Selected beta: 0.4\n","--- Run 72/100 (seed=72) ---\n","   MISE: 0.135717 | Selected beta: 1.6\n","--- Run 73/100 (seed=73) ---\n","   MISE: 0.184297 | Selected beta: 1.6\n","--- Run 74/100 (seed=74) ---\n","   MISE: 0.091622 | Selected beta: 0.8\n","--- Run 75/100 (seed=75) ---\n","   MISE: 0.084709 | Selected beta: 0.8\n","--- Run 76/100 (seed=76) ---\n","   MISE: 0.126156 | Selected beta: 0.8\n","--- Run 77/100 (seed=77) ---\n","   MISE: 0.118392 | Selected beta: 3.2\n","--- Run 78/100 (seed=78) ---\n","   MISE: 0.110482 | Selected beta: 0.8\n","--- Run 79/100 (seed=79) ---\n","   MISE: 0.115246 | Selected beta: 1.6\n","--- Run 80/100 (seed=80) ---\n","   MISE: 0.149624 | Selected beta: 1.6\n","--- Run 81/100 (seed=81) ---\n","   MISE: 0.114205 | Selected beta: 1.6\n","--- Run 82/100 (seed=82) ---\n","   MISE: 0.113624 | Selected beta: 6.4\n","--- Run 83/100 (seed=83) ---\n","   MISE: 0.109510 | Selected beta: 0.8\n","--- Run 84/100 (seed=84) ---\n","   MISE: 0.054358 | Selected beta: 0.8\n","--- Run 85/100 (seed=85) ---\n","   MISE: 0.125864 | Selected beta: 6.4\n","--- Run 86/100 (seed=86) ---\n","   MISE: 
0.116232 | Selected beta: 1.6\n","--- Run 87/100 (seed=87) ---\n","   MISE: 0.121584 | Selected beta: 0.8\n","--- Run 88/100 (seed=88) ---\n","   MISE: 0.109648 | Selected beta: 0.8\n","--- Run 89/100 (seed=89) ---\n","   MISE: 0.120751 | Selected beta: 0.4\n","--- Run 90/100 (seed=90) ---\n","   MISE: 0.111382 | Selected beta: 3.2\n","--- Run 91/100 (seed=91) ---\n","   MISE: 0.115664 | Selected beta: 3.2\n","--- Run 92/100 (seed=92) ---\n","   MISE: 0.139452 | Selected beta: 1.6\n","--- Run 93/100 (seed=93) ---\n","   MISE: 0.118929 | Selected beta: 0.8\n","--- Run 94/100 (seed=94) ---\n","   MISE: 0.077392 | Selected beta: 0.8\n","--- Run 95/100 (seed=95) ---\n","   MISE: 0.076174 | Selected beta: 0.2\n","--- Run 96/100 (seed=96) ---\n","   MISE: 0.100991 | Selected beta: 1.6\n","--- Run 97/100 (seed=97) ---\n","   MISE: 0.090579 | Selected beta: 1.6\n","--- Run 98/100 (seed=98) ---\n","   MISE: 0.079132 | Selected beta: 0.2\n","--- Run 99/100 (seed=99) ---\n","   MISE: 0.126499 | Selected beta: 0.4\n","--- Run 100/100 (seed=100) ---\n","   MISE: 0.106605 | Selected beta: 3.2\n","\n","==============================\n","Simulation Summary (K=100)\n","==============================\n","Mean MISE : 0.118123\n","Std MISE  : 0.028483\n","SE MISE   : 0.002848\n","------------------------------\n","Mean selected beta : 2.668\n","Std selected beta  : 2.46367\n","SE selected beta   : 0.246367\n"]},{"output_type":"stream","name":"stderr","text":["/tmp/ipython-input-2704349507.py:19: PerformanceWarning: DataFrame is highly fragmented.  This is usually the result of calling `frame.insert` many times, which has poor performance.  Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n","  df_out[f\"h_hat_seed_{seed}\"] = results[\"curves_mat\"][i, :]\n","/tmp/ipython-input-2704349507.py:21: PerformanceWarning: DataFrame is highly fragmented.  
This is usually the result of calling `frame.insert` many times, which has poor performance.  Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n","  df_out[\"mean_h_hat\"] = results[\"mean_curve\"]\n","/tmp/ipython-input-2704349507.py:22: PerformanceWarning: DataFrame is highly fragmented.  This is usually the result of calling `frame.insert` many times, which has poor performance.  Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n","  df_out[\"se_h_hat\"] = results[\"se_curve\"]\n"]},{"output_type":"stream","name":"stdout","text":["\n","Results saved to: /content/drive/MyDrive/Colab Notebooks/CTE_Baseline/KRR_methods/Results/plugin_synth_sample_500_noise_1.0_seeds_1-100.csv\n"]}],"source":["# ======================================================================\n",
"# Main Execution\n",
"# ======================================================================\n",
"\n",
"# 1. Run the simulation\n",
"results = run_simulation(\n",
"    K=N_SIMULATIONS,\n",
"    noise_std=NOISE_STD,\n",
"    first_seed=FIRST_SEED\n",
")\n",
"\n",
"# 2. Save Results to CSV\n",
"output_csv_name = f\"{results['run_tag']}.csv\"\n",
"output_path = RESULTS_DIR / output_csv_name\n",
"\n",
"# Construct DataFrame\n",
"# All columns are collected first and the frame is built once: inserting one\n",
"# column per seed fragments the frame and triggers pandas' PerformanceWarning.\n",
"columns = {\"t\": T_GRID}\n",
"for i, seed in enumerate(results[\"seeds\"]):\n",
"    columns[f\"h_hat_seed_{seed}\"] = results[\"curves_mat\"][i, :]\n",
"columns[\"mean_h_hat\"] = results[\"mean_curve\"]\n",
"columns[\"se_h_hat\"] = results[\"se_curve\"]\n",
"df_out = pd.DataFrame(columns)\n",
"\n",
"df_out.to_csv(output_path, index=False)\n",
"print(f\"\\nResults saved to: {output_path}\")"]}],"metadata":{"colab":{"provenance":[],"authorship_tag":"ABX9TyOOyIeCbWIn6YLqTc7/JEn4"},"kernelspec":{"display_name":"Python 3","name":"python3"},"language_info":{"name":"python"}},"nbformat":4,"nbformat_minor":0}