import sys
import pathlib
import numpy as np
import pandas as pd
from typing import Dict, List, Tuple, Union, Optional

# --- Environment Setup ---
# On Colab, Google Drive is mounted and BASE_DIR points at the project folder
# stored there. Anywhere else the `google.colab` import raises ImportError and
# the current working directory (resolved to an absolute path) is used instead.
try:
    from google.colab import drive
    drive.mount('/content/drive')
    BASE_DIR = pathlib.Path("/content/drive/MyDrive/Colab Notebooks/CTE_Baseline")
except ImportError:
    BASE_DIR = pathlib.Path(".").resolve()

# Add project root to path.
# Guarded so that re-running this cell does not keep appending duplicate
# entries to sys.path.
if str(BASE_DIR) not in sys.path:
    sys.path.append(str(BASE_DIR))

# Project specific imports
# Note: Ensure the 'KRR_methods' package is accessible in the path
from KRR_methods.synthetic_dgps import generate_unified_data
from KRR_methods.algorithms.length_selection import krr_length_selection_loocv_joint

print(f"Working Directory: {BASE_DIR}")
# ============================================================
# Utility Functions for Median-based Length Scale Estimation
# ============================================================

def pairwise_dist_median(Z: np.ndarray, max_pairs: int = 200_000, seed: int = 123) -> float:
    """
    Calculates the median of pairwise Euclidean distances between the rows of Z.

    Args:
        Z: Points, one per row. A 1-D array is treated as n scalar points.
           Non-floating input (e.g. integers) is converted to float so the
           square root is well defined; floating input keeps its dtype.
        max_pairs: If the number of distinct pairs n*(n-1)/2 exceeds this,
           the median is estimated from `max_pairs` randomly drawn pairs.
        seed: Seed of the local random generator used for the subsampling.

    Returns:
        The (exact or subsampled) median distance as a Python float.

    Raises:
        ValueError: If Z is not 1-D/2-D or contains fewer than two points.
    """
    Z = np.asarray(Z)
    if Z.ndim == 1:
        Z = Z.reshape(-1, 1)
    if Z.ndim != 2:
        raise ValueError("Z must be a 1-D or 2-D array of points.")
    if not np.issubdtype(Z.dtype, np.floating):
        # Integer/bool input would otherwise break the square root below.
        Z = Z.astype(float)

    n = Z.shape[0]
    if n < 2:
        raise ValueError("At least two points are required to compute pairwise distances.")

    # Created before branching so the random stream is the same as before.
    rng = np.random.default_rng(seed)
    total_pairs = n * (n - 1) // 2

    if total_pairs <= max_pairs:
        # Exact calculation via the Gram matrix: ||a-b||^2 = |a|^2 + |b|^2 - 2 a.b
        G = Z @ Z.T
        sq = np.sum(Z * Z, axis=1, keepdims=True)
        # Numerical stability: clamp round-off negatives to 0
        D2 = np.maximum(sq + sq.T - 2.0 * G, 0.0)
        iu = np.triu_indices(n, k=1)  # each unordered pair once, no diagonal
        dists = np.sqrt(D2[iu])
        return float(np.median(dists))

    # Subsampling: draw index pairs with replacement; a pair that picked the
    # same point twice is shifted to the next index (mod n) so i != j.
    m = max_pairs
    i = rng.integers(0, n, size=m)
    j = rng.integers(0, n, size=m)
    same = (i == j)
    j[same] = (j[same] + 1) % n
    dists = np.linalg.norm(Z[i] - Z[j], axis=1)
    return float(np.median(dists))
def solve_z_for_tau(rho: float, nu: float = 1.5, tol: float = 1e-12, max_iter: int = 200) -> float:
    """
    Solves for z >= 0 such that correlation(z) = rho for the Matérn kernel.

    The correlation is P(z) * exp(-z), where P(z) = 1 for nu = 0.5,
    1 + z for nu = 1.5 and 1 + z + z**2 / 3 for nu = 2.5.

    Args:
        rho: Target correlation, must satisfy 0 < rho <= 1.
        nu: Matérn smoothness; only 0.5, 1.5 and 2.5 are supported.
        tol: Bisection stops once the bracket is narrower than this.
        max_iter: Upper bound on the number of bisection steps.

    Returns:
        The scaled distance z at which the correlation equals rho.

    Raises:
        ValueError: If rho is outside (0, 1] or nu is not supported.
    """
    import math

    # The correlation only takes values in (0, 1]; anything else has no root
    # and would previously return 0 or a meaningless bracket endpoint.
    if not (0.0 < rho <= 1.0):
        raise ValueError("rho must satisfy 0 < rho <= 1.")

    # Special case: nu = 0.5 (Exponential/Laplace kernel) has a closed form
    if abs(nu - 0.5) < 1e-9:
        return -math.log(rho)

    # Polynomial part of the Matérn correlation function
    if abs(nu - 1.5) < 1e-9:
        def P(z):
            return 1.0 + z
    elif abs(nu - 2.5) < 1e-9:
        def P(z):
            return 1.0 + z + (z**2)/3.0
    else:
        raise ValueError("nu must be 0.5, 1.5 or 2.5")

    def f(z):
        return P(z) * math.exp(-z) - rho

    # Binary search; f(0) = 1 - rho >= 0 and f decreases towards -rho.
    z_lo, z_hi = 0.0, 50.0

    # Widen the bracket for very small rho until f changes sign.
    while f(z_hi) > 0 and z_hi < 1e6:
        z_hi *= 2.0

    for _ in range(max_iter):
        z_mid = 0.5 * (z_lo + z_hi)
        if f(z_mid) > 0:
            z_lo = z_mid
        else:
            z_hi = z_mid
        if (z_hi - z_lo) < tol:
            break
    return 0.5 * (z_lo + z_hi)

def matern_c(nu: float) -> float:
    """
    Returns scaling constant c for Matérn kernel: z = c * r / ell

    Raises:
        ValueError: If nu is not 0.5, 1.5 or 2.5.
    """
    if abs(nu - 0.5) < 1e-9: return 1.0
    if abs(nu - 1.5) < 1e-9: return np.sqrt(3.0)
    if abs(nu - 2.5) < 1e-9: return np.sqrt(5.0)
    raise ValueError("nu must be 0.5, 1.5 or 2.5")

def estimate_ell_for_rhos(
    n=4000, rhos=(0.5, 0.8), nu=1.5, seed=42, max_pairs=200_000, noise_std=1.0
):
    """
    Generates a temporary batch of data to estimate length-scale ranges
    corresponding to specific correlation values (rho).

    For each rho, the length-scale is ell = c(nu) * r_median / z(rho), where
    r_median is the median pairwise distance of the stacked [X, A] sample.

    Args:
        n: Number of samples to generate for the heuristic.
        rhos: Target correlations; each must lie strictly inside (0, 1) so
            that z(rho) is non-zero.
        nu: Matérn smoothness (0.5, 1.5 or 2.5).
        seed: Seeds NumPy's global RNG and the pair subsampling.
        max_pairs: Pair budget forwarded to `pairwise_dist_median`.
        noise_std: Forwarded to `generate_unified_data`.

    Returns:
        Dict with keys "r_median", "nu" and "results"; "results" is a list of
        {"rho", "z_rho", "ell_rho"} dicts in the order of `rhos`.
    """
    # Set seed locally to ensure heuristic reproducibility.
    # Side effect: this resets NumPy's *global* random state.
    # NOTE(review): this presumably relies on generate_unified_data drawing
    # from the global RNG when no `seed` is passed — confirm against the DGP.
    np.random.seed(seed)

    X_local, A_local, _ = generate_unified_data(n_samples=n, noise_std=noise_std)

    # Combine Covariates X and Treatment A for distance calculation
    Z = np.hstack([X_local, A_local.reshape(-1, 1)])

    r_med = pairwise_dist_median(Z, max_pairs=max_pairs, seed=seed)
    c_val = matern_c(nu)

    out = {"r_median": r_med, "nu": nu, "results": []}
    for rho in rhos:
        z_rho = solve_z_for_tau(rho, nu=nu)
        ell = (c_val * r_med) / z_rho
        out["results"].append({"rho": rho, "z_rho": z_rho, "ell_rho": ell})
    return out
# ============================================================
# Step 1: Estimate Median-based Length Parameters
# ============================================================
# The length-scales printed here guide the choice of the 'ell_list' grid
# used by the cross-validation step.

heuristic_res = estimate_ell_for_rhos(n=500, rhos=(0.15, 0.5, 0.85), nu=1.5, seed=42)

# Header line first, then one line per target correlation.
report_lines = [f"r_median = {heuristic_res['r_median']:.6f}, nu = {heuristic_res['nu']}"]
report_lines.extend(
    f"rho={r['rho']}: z_rho={r['z_rho']:.6f}, ell_rho={r['ell_rho']:.6f}"
    for r in heuristic_res["results"]
)
print("\n".join(report_lines))
# ============================================================
# Step 2: Generate Data and Run CV Grid Search
# ============================================================

# 1. Generate the actual dataset for the experiment
# The generator is called with seed=42.
X, A, Y = generate_unified_data(n_samples=500, seed=42, noise_std=1)

print(f"Data Generated. Shapes -> X: {X.shape}, A: {A.shape}, Y: {Y.shape}")

# 2. Run LOOCV Grid Search
# Jointly selects the length-scale (ell) and the ridge parameter (beta):
#   - nu_list:     Matérn smoothness 1.5 only
#   - ell_list:    integer grid 2..9, derived from the heuristic above
#   - beta_bounds: search range for the ridge parameter
res = krr_length_selection_loocv_joint(
    Xs=X,
    Ts=A,
    Ys=Y,
    nu_list=[1.5],
    ell_list=list(range(2, 10)),
    beta_bounds=(1e-4, 1e2),
    kernel_type="matern",
)

print("\nBest Global Parameter Set:", res["best_global"], sep="\n")
import numpy as np

def estimate_ell_for_rhos_A_only(
    n: int = 1000,
    rhos: tuple = (0.15, 0.85),
    nu: float = 1.5,
    seed: int = 42,
    max_pairs: int = 200_000,
    noise_std: float = 1.0,
):
    """
    Length-scale heuristic based on the treatment variable A alone.

    A fresh batch of `n` samples is generated, the median pairwise distance
    of A (as a single column) is computed, and for every target correlation
    rho the implied length-scale ell = c(nu) * r_median / z(rho) is reported.

    Same recipe as the [X, A] heuristic, but with Z = [A].

    Returns:
        Dict with float-valued "r_median" and "nu", and a "results" list of
        {"rho", "z_rho", "ell_rho"} dicts (all floats) in the order of `rhos`.
    """
    # Resets NumPy's global random state before generating the batch.
    np.random.seed(seed)

    # Only the treatment column of the generated batch is needed.
    _, A_local, _ = generate_unified_data(n_samples=n, noise_std=noise_std)
    treatment_column = A_local.reshape(-1, 1)

    r_med = pairwise_dist_median(treatment_column, max_pairs=max_pairs, seed=seed)
    c_val = matern_c(nu)

    def describe(rho):
        # One result row: target correlation, its scaled distance, implied ell.
        z_rho = solve_z_for_tau(rho, nu=nu)
        ell = (c_val * r_med) / z_rho
        return {"rho": float(rho), "z_rho": float(z_rho), "ell_rho": float(ell)}

    summary = {"r_median": float(r_med), "nu": float(nu), "results": []}
    summary["results"].extend(describe(rho) for rho in rhos)
    return summary
# Example run
heur_A = estimate_ell_for_rhos_A_only(
    n=500,
    rhos=(0.15, 0.85),
    nu=1.5,
    seed=42,
    noise_std=1.0,
)

print(f"[A-only heuristic] r_median(A) = {heur_A['r_median']:.6f}, nu = {heur_A['nu']}")
for r in heur_A["results"]:
    print(f"rho={r['rho']}: z_rho={r['z_rho']:.6f}, ell_rho={r['ell_rho']:.6f}")

# The implied length-scale grows with rho (a higher target correlation needs a
# wider kernel), so the lower bound comes from the smallest rho and the upper
# bound from the largest. Taking min/max keeps the printed interval ordered.
ell_candidates = [r["ell_rho"] for r in heur_A["results"]]
ell_low, ell_high = min(ell_candidates), max(ell_candidates)
print(f"Suggested ell range (A-only, rho in [0.15, 0.85]): [{ell_low:.6f}, {ell_high:.6f}]")

# --- Data for the A-only regression ---
import numpy as np

# Assumes generate_unified_data is available (imported in the setup cell).
# Note: this rebinds X, A, Y; A becomes a column vector and Y a flat vector.
X, A, Y = generate_unified_data(
    n_samples=500,
    seed=0,
    noise_std=1
)

A = np.asarray(A).reshape(-1, 1)
Y = np.asarray(Y).reshape(-1)

print(f"Data generated. Shapes -> X: {np.asarray(X).shape}, A: {A.shape}, Y: {Y.shape}")
import numpy as np
from typing import Dict, List, Tuple

def matern32_kernel_1d(A: np.ndarray, ell: float) -> np.ndarray:
    """
    Matérn kernel with nu = 1.5 (Matérn 3/2) for 1D inputs.

    k(r) = (1 + sqrt(3) * r / ell) * exp(-sqrt(3) * r / ell)

    Args:
        A: Scalar inputs; flattened to a column of n points.
        ell: Length-scale, must be > 0.

    Returns:
        The (n, n) kernel matrix of all pairwise |a_i - a_j|.

    Raises:
        ValueError: If ell is not strictly positive (ell = 0 would otherwise
            silently produce a NaN kernel).
    """
    ell = float(ell)
    if not ell > 0.0:
        raise ValueError("ell must be strictly positive.")
    A = np.asarray(A, dtype=float).reshape(-1, 1)
    r = np.abs(A - A.T)
    s = (np.sqrt(3.0) * r) / ell
    return (1.0 + s) * np.exp(-s)


def loocv_mse_krr_from_kernel(K: np.ndarray, y: np.ndarray, beta: float) -> float:
    """
    Efficient LOOCV MSE for kernel ridge regression with a fixed kernel matrix K.

    Let M = K + beta * I.
    alpha = M^{-1} y
    C = M^{-1}
    LOOCV residual_i = alpha_i / C_ii
    LOOCV MSE = mean(residual_i^2)

    Uses Cholesky for numerical stability.

    Args:
        K: Square (n, n) kernel matrix.
        y: Targets; flattened to length n.
        beta: Ridge parameter added to the diagonal of K.

    Returns:
        The leave-one-out mean squared error as a Python float.

    Raises:
        ValueError: If K is not square or y does not have n entries.
        numpy.linalg.LinAlgError: If K + beta * I is not positive definite.
    """
    K = np.asarray(K, dtype=float)
    y = np.asarray(y, dtype=float).reshape(-1)
    if K.ndim != 2 or K.shape[0] != K.shape[1]:
        raise ValueError("K must be a square 2-D kernel matrix.")
    n = K.shape[0]
    if y.shape[0] != n:
        raise ValueError("y must have as many entries as K has rows.")

    M = K + float(beta) * np.eye(n, dtype=float)

    L = np.linalg.cholesky(M)  # M = L @ L.T

    # alpha = M^{-1} y via two solves against the Cholesky factor
    tmp = np.linalg.solve(L, y)
    alpha = np.linalg.solve(L.T, tmp)

    # V = L^{-1}; the column-wise squared norms of V are diag(M^{-1})
    V = np.linalg.solve(L, np.eye(n, dtype=float))
    diagC = np.sum(V * V, axis=0)

    residual = alpha / diagC
    return float(np.mean(residual ** 2))


def make_beta_grid(beta_bounds: Tuple[float, float], num_beta: int) -> np.ndarray:
    """
    Create a log-spaced beta grid within [beta_min, beta_max].

    Raises:
        ValueError: If the bounds do not satisfy 0 < beta_min < beta_max, or
            if num_beta < 1 (an empty grid leaves nothing to select from).
    """
    beta_min, beta_max = beta_bounds
    beta_min = float(beta_min)
    beta_max = float(beta_max)
    if beta_min <= 0 or beta_max <= 0 or beta_min >= beta_max:
        raise ValueError("beta_bounds must satisfy 0 < beta_min < beta_max.")
    if num_beta < 1:
        raise ValueError("num_beta must be at least 1.")
    return np.logspace(np.log10(beta_min), np.log10(beta_max), num=num_beta)


def select_ell_by_loocv_with_inner_beta_search(
    A: np.ndarray,
    Y: np.ndarray,
    ell_list: List[float],
    beta_bounds: Tuple[float, float] = (1e-4, 1e2),
    num_beta: int = 30,
    verbose: bool = True,
) -> Dict:
    """
    Outer loop over ell; inner loop over beta.
    For each ell, pick beta minimizing LOOCV MSE.
    Then pick ell minimizing the per-ell best LOOCV MSE.

    Args:
        A: Scalar inputs (flattened to a column).
        Y: Targets (flattened).
        ell_list: Candidate length-scales for the Matérn 3/2 kernel.
        beta_bounds: (beta_min, beta_max) of the log-spaced ridge grid.
        num_beta: Number of grid points for beta.
        verbose: If True, print one summary line per ell.

    Returns:
        Dict with "best_global" ({"ell", "beta", "loocv_mse"}),
        "per_ell_results" (list of {"ell", "best_beta", "best_loocv_mse"}),
        and the echoed "beta_bounds", "num_beta" and "ell_list".
        Ties are resolved in favour of the first candidate encountered.
    """
    A = np.asarray(A, dtype=float).reshape(-1, 1)
    Y = np.asarray(Y, dtype=float).reshape(-1)

    beta_grid = make_beta_grid(beta_bounds, num_beta=num_beta)

    per_ell_results = []
    best_global = {"ell": None, "beta": None, "loocv_mse": np.inf}

    for ell in ell_list:
        # The kernel depends only on ell, so it is built once per outer step.
        K = matern32_kernel_1d(A, ell=float(ell))

        best_beta_for_ell = None
        best_mse_for_ell = np.inf

        for beta in beta_grid:
            mse = loocv_mse_krr_from_kernel(K, Y, beta=float(beta))
            if mse < best_mse_for_ell:
                best_mse_for_ell = mse
                best_beta_for_ell = float(beta)

        per_ell_results.append(
            {"ell": float(ell), "best_beta": float(best_beta_for_ell), "best_loocv_mse": float(best_mse_for_ell)}
        )

        if verbose:
            print(f"ell={float(ell):>8.4f} | best_beta={best_beta_for_ell:.4e} | best_LOOCV_MSE={best_mse_for_ell:.6e}")

        if best_mse_for_ell < best_global["loocv_mse"]:
            best_global = {"ell": float(ell), "beta": float(best_beta_for_ell), "loocv_mse": float(best_mse_for_ell)}

    return {
        "best_global": best_global,
        "per_ell_results": per_ell_results,
        "beta_bounds": tuple(map(float, beta_bounds)),
        "num_beta": int(num_beta),
        "ell_list": [float(e) for e in ell_list],
    }
# Example run on the A-only data: integer length-scale grid 1..7.
ell_grid = list(range(1, 8))

# For every ell the best beta is picked from 30 log-spaced values in
# [1e-4, 1e2]; one progress line per ell is printed because verbose=True.
res_A_only = select_ell_by_loocv_with_inner_beta_search(
    A=A,
    Y=Y,
    ell_list=ell_grid,
    beta_bounds=(1e-4, 1e2),
    num_beta=30,
    verbose=True,
)

print("\nBest (A-only) global parameters:", res_A_only["best_global"], sep="\n")