{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "cells": [
    {
      "cell_type": "code",
      "execution_count": 4,
      "metadata": {
        "id": "qo2V4Q-R6-au"
      },
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "import pandas as pd\n",
        "from sklearn.datasets import load_diabetes\n",
        "from sklearn.preprocessing import StandardScaler\n",
        "from sklearn.linear_model import Ridge\n",
        "from sklearn.metrics import mean_squared_error\n",
        "import matplotlib.pyplot as plt\n",
        "\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.optim as optim\n",
        "from sklearn.model_selection import train_test_split\n",
        "import cvxpy as cp\n",
        "import os\n",
        "from IPython.display import display\n",
        "from torch.optim import LBFGS\n",
        "import copy\n",
        "import re\n",
        "import networkx as nx\n",
        "import random\n",
        "from sklearn.datasets import fetch_california_housing\n",
        "from google.colab import drive\n",
        "from scipy.stats import wishart\n",
        "\n",
        "import pymc as pm\n",
        "import arviz as az\n",
        "import pytensor.tensor as at\n",
        "import time"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "dataset_list = ['load_diabetes', 'asia_bif', 'cali_housing', ]\n",
        "chosen_dataset = dataset_list[0]\n",
        "\n",
        "if chosen_dataset == 'load_diabetes':\n",
        "    X, y = load_diabetes(return_X_y=True)\n",
        "\n",
        "elif chosen_dataset == 'asia_bif':\n",
        "\n",
        "    drive.mount('/content/drive')\n",
        "    # Load the uploaded asia.bif file\n",
        "    bif_path = \"/content/drive/My Drive/ICLR2026/asia.bif\"\n",
        "\n",
        "    # Parse the BIF file to extract the graph structure\n",
        "    with open(bif_path, 'r') as f:\n",
        "        bif_text = f.read()\n",
        "\n",
        "    variables = re.findall(r'variable\\s+(\\w+)', bif_text)\n",
        "    parents_matches = re.findall(r'probability\\s+\\(\\s*(\\w+)\\s*\\|\\s*([^)]+)\\)', bif_text)\n",
        "    root_matches = re.findall(r'probability\\s+\\(\\s*(\\w+)\\s*\\)', bif_text)\n",
        "\n",
        "    # Initialize graph\n",
        "    G = nx.DiGraph()\n",
        "    G.add_nodes_from(variables)\n",
        "\n",
        "    # Add edges from parent-child relationships\n",
        "    for child, parents in parents_matches:\n",
        "        parent_list = [p.strip() for p in parents.split(',')]\n",
        "        for parent in parent_list:\n",
        "            G.add_edge(parent, child)\n",
        "\n",
        "    # Add root nodes (no parents)\n",
        "    child_nodes_with_parents = set(child for child, _ in parents_matches)\n",
        "    root_nodes = [var for var in root_matches if var not in child_nodes_with_parents]\n",
        "    G.add_nodes_from(root_nodes)\n",
        "\n",
        "    # Ensure it's a DAG and sort topologically\n",
        "    assert nx.is_directed_acyclic_graph(G)\n",
        "    topo_order = list(nx.topological_sort(G))\n",
        "\n",
        "    # Get adjacency matrix in topological order\n",
        "    A = nx.to_numpy_array(G, nodelist=topo_order)\n",
        "\n",
        "    # Compute out-degree Laplacian\n",
        "    D_out = np.diag(np.sum(A, axis=1))\n",
        "    # L = D_out - A\n",
        "    import copy\n",
        "    # L_orig = copy.deepcopy(L)\n",
        "\n",
        "    # Avoid division by zero\n",
        "    with np.errstate(divide='ignore'):\n",
        "        D_out_inv = np.diag(1.0 / np.sum(A, axis=1))\n",
        "        D_out_inv[np.isinf(D_out_inv)] = 0.0\n",
        "\n",
        "    # Random walk Laplacian\n",
        "    L = np.eye(A.shape[0]) - D_out_inv @ A\n",
        "    L_orig = copy.deepcopy(L)\n",
        "\n",
        "    # Prepare matrices for display\n",
        "    adjacency_df = pd.DataFrame(A, index=topo_order, columns=topo_order)\n",
        "    laplacian_df = pd.DataFrame(L, index=topo_order, columns=topo_order)\n",
        "\n",
        "    # Return eigenvalues for insight\n",
        "    eigenvalues = np.linalg.eigvals(L)\n",
        "    eigenvalues.sort()\n",
        "    eigenvalues.real.round(3)\n",
        "\n",
        "    # Define the topological order of variables in ASIA\n",
        "    variables = topo_order\n",
        "\n",
        "    # Create a dictionary to hold generated data\n",
        "    data = {var: [] for var in variables}\n",
        "    num_samples = 1000\n",
        "\n",
        "    for _ in range(num_samples):\n",
        "        sample = {}\n",
        "\n",
        "        # asia: has tuberculosis if visited Asia (Bernoulli)\n",
        "        sample[\"asia\"] = int(random.random() < 0.05)\n",
        "\n",
        "        # smoke: independent\n",
        "        sample[\"smoke\"] = int(random.random() < 0.3)\n",
        "\n",
        "        # tub: depends on asia\n",
        "        sample[\"tub\"] = int(sample[\"asia\"] and random.random() < 0.6)\n",
        "\n",
        "        # lung: depends on smoke\n",
        "        sample[\"lung\"] = int(sample[\"smoke\"] and random.random() < 0.5)\n",
        "\n",
        "        # bronc: depends on smoke\n",
        "        sample[\"bronc\"] = int(sample[\"smoke\"] and random.random() < 0.4)\n",
        "\n",
        "        # either: OR of tub and lung\n",
        "        sample[\"either\"] = int(sample[\"tub\"] or sample[\"lung\"])\n",
        "\n",
        "        # xray: depends on either\n",
        "        sample[\"xray\"] = int(sample[\"either\"] and random.random() < 0.9 or random.random() < 0.05)\n",
        "\n",
        "        # dysp: depends on bronc and either\n",
        "        sample[\"dysp\"] = int((sample[\"bronc\"] and random.random() < 0.7) or\n",
        "                            (sample[\"either\"] and random.random() < 0.6))\n",
        "\n",
        "        for var in variables:\n",
        "            data[var].append(sample[var])\n",
        "\n",
        "    # Create a DataFrame\n",
        "    df = pd.DataFrame(data)\n",
        "\n",
        "    target_protein = \"either\"\n",
        "    y = df[target_protein].values\n",
        "    X = df.drop(columns=[target_protein]).values\n",
        "    feature_names = df.drop(columns=[target_protein]).columns.tolist()\n",
        "\n",
        "elif chosen_dataset == 'cali_housing':\n",
        "    data = fetch_california_housing(as_frame=False)\n",
        "    X, y = data.data, data.target"
      ],
      "metadata": {
        "id": "5PK-s-XF7cau"
      },
      "execution_count": 5,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "X_train, X_test, y_train, y_test = train_test_split(\n",
        "    X, y, test_size=0.2, random_state=0)\n",
        "\n",
        "scaler_X = StandardScaler().fit(X_train)\n",
        "X_train = scaler_X.transform(X_train)\n",
        "X_test  = scaler_X.transform(X_test)\n",
        "\n",
        "scaler_y = StandardScaler().fit(y_train.reshape(-1,1))\n",
        "y_train = scaler_y.transform(y_train.reshape(-1,1)).ravel()\n",
        "y_test  = scaler_y.transform(y_test.reshape(-1,1)).ravel()\n",
        "\n",
        "n, d = X_train.shape\n",
        "print(n)\n",
        "print(d)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "ZpY5JLQ09d8d",
        "outputId": "99665c3b-8099-444d-d578-055af11b1d77"
      },
      "execution_count": 6,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "353\n",
            "10\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# -------------------------\n",
        "# 2) Fixed random Tikhonov components\n",
        "# -------------------------\n",
        "rng = np.random.default_rng(12345)\n",
        "L = rng.normal(size=(d, d))          # square for simplicity; any (p x d) works\n",
        "lam = 5e-2                           # Tikhonov strength\n",
        "nu0 = d + 2                         # dof; must be >= d\n",
        "kappa0 = 1.0                        # strength on mean prior\n",
        "m0 = np.zeros(d)                    # prior mean\n",
        "\n",
        "# Wishart scale chosen so E[Λ] = lam * (L^T L)\n",
        "Sigma0 = (lam / nu0) * (L.T @ L)\n",
        "\n",
        "# -------------------------\n",
        "# Sample precision (Tikhonov) from Wishart\n",
        "# -------------------------\n",
        "Lambda = wishart.rvs(df=nu0, scale=Sigma0, random_state=rng)  # Λ ~ W(ν0, Σ0)\n",
        "M_true = Lambda                                               # Tikhonov matrix (prior precision)\n",
        "\n",
        "# -------------------------\n",
        "# Optionally sample μ | Λ ~ N(m0, (κ0 Λ)^(-1))\n",
        "# -------------------------\n",
        "# Draw using a Cholesky of (κ0 Λ)^(-1)\n",
        "chol_cov = np.linalg.cholesky(np.linalg.inv(kappa0 * Lambda))\n",
        "mu_true = m0 + chol_cov @ rng.normal(size=d)\n",
        "\n",
        "# Torch tensors\n",
        "M_true_t = torch.tensor(M_true, dtype=torch.float32)\n",
        "mu_true_t = torch.tensor(mu_true, dtype=torch.float32)\n",
        "\n",
        "# -------------------------\n",
        "# 3.a) Forward Tikhonov (closed form, sum-of-squares)\n",
        "#     (X^T X + M_true) θ = X^T y + M_true μ\n",
        "#     (X^T X + lam L^T L) theta = X^T y + lam L^T L mu\n",
        "# -------------------------\n",
        "XtX = X_train.T @ X_train\n",
        "Xty = X_train.T @ y_train\n",
        "A = XtX + M_true\n",
        "rhs = Xty + M_true @ mu_true\n",
        "A += 1e-10 * np.eye(d)  # small jitter\n",
        "\n",
        "theta_fwd = np.linalg.solve(A, rhs)\n",
        "# print(f\"θ* = {theta_fwd}\")\n",
        "\n",
        "# Residual b = - X^T (X θ - y)  [sum of squares convention]\n",
        "# KKT stationarity: X^T(Xθ - y) + M_true(θ - μ_true) = 0\n",
        "b_fwd = - X_train.T @ (X_train @ theta_fwd - y_train)\n",
        "\n",
        "# Verify KKT using the ground-truth M_true\n",
        "kkt_residual = np.linalg.norm(X_train.T @ (X_train @ theta_fwd - y_train) + M_true @ (theta_fwd - mu_true))\n",
        "print(f\"KKT residual with ground-truth M_true: {kkt_residual:.3e}\")\n",
        "\n",
        "# -------------------------\n",
        "# 3.b) Forward Tikhonov (ML model solution)\n",
        "# -------------------------\n",
        "X_train_t = torch.from_numpy(X_train).float()\n",
        "y_train_t = torch.from_numpy(y_train).float().unsqueeze(1)\n",
        "\n",
        "# 4) Define a single-layer linear model (no bias for simplicity)\n",
        "model = nn.Linear(d, 1, bias=False)\n",
        "\n",
        "# 5) Optimize with weight decay = ridge λ\n",
        "lr = 1e-2\n",
        "ridge_lambda = 1e-2\n",
        "optimizer = optim.SGD(model.parameters(), lr=lr, weight_decay=ridge_lambda)\n",
        "loss_fn   = nn.MSELoss()\n",
        "\n",
        "for epoch in range(1000):\n",
        "    optimizer.zero_grad()\n",
        "    preds = model(X_train_t)\n",
        "    w = model.weight.view(-1)\n",
        "    data_ss = torch.sum((preds.squeeze() - torch.from_numpy(y_train).float())**2)\n",
        "    loss = data_ss + lam * torch.sum((torch.from_numpy(L.T).float() @ (w - torch.from_numpy(mu_true).float()))**2)\n",
        "    # loss = loss_fn(preds, y_train_t) * n + lam * torch.sum((L.T @ (w - mu_true))**2)\n",
        "    loss.backward()\n",
        "    optimizer.step()\n",
        "\n",
        "# 6) Extract θ* from the trained model\n",
        "theta_ml = model.weight.detach().numpy().reshape(-1)  # shape (d,)\n",
        "# print(f\"θ* = {theta_ml}\")\n",
        "# val_mse_torch = np.mean((X_val @ theta_torch - y_val)**2)\n",
        "\n",
        "# 7) Compute b = -½ ∇‖Xθ - y‖² at θ*\n",
        "#    ∇f = 2 Xᵀ(Xθ* - y)  ⇒  b = -0.5 * ∇f = -Xᵀ(Xθ* - y)\n",
        "grad = 2 * X_train.T @ (X_train @ theta_ml - y_train)\n",
        "b_ml    = -0.5 * grad\n",
        "\n",
        "mode = \"ml\"\n",
        "# mode = \"fwd\"\n",
        "if mode == \"ml\":\n",
        "    theta_star = theta_ml\n",
        "    b_star = b_ml\n",
        "else:\n",
        "    theta_star = theta_fwd\n",
        "    b_star = b_fwd\n",
        "\n",
        "b_true = M_true @ (theta_star - mu_true) + rng.normal(0, 0.5, size=d)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "AmEJ9bF9AK_t",
        "outputId": "a8a96b00-d043-4e94-d8b4-213f69691324"
      },
      "execution_count": 7,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "KKT residual with ground-truth M_true: 7.649e-11\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# -------------------------\n",
        "# 5) Metrics\n",
        "# -------------------------\n",
        "def metrics(mu_hat, mu_true, M_hat, M_true, theta_star, y_test, X_test):\n",
        "\n",
        "    def frob(A): return np.linalg.norm(A, 'fro')\n",
        "\n",
        "    err = M_hat - M_true\n",
        "    rel_frob = np.linalg.norm(err, 'fro') / (np.linalg.norm(M_true, 'fro') + 1e-12)\n",
        "    spec_err = np.linalg.norm(err, 2) / (np.linalg.norm(M_true, 2) + 1e-12)\n",
        "    corr = np.corrcoef(M_true.ravel(), M_hat.ravel())[0, 1]\n",
        "\n",
        "    M_sym = 0.5 * (M_hat + M_hat.T)\n",
        "    w, U = np.linalg.eigh(M_sym)\n",
        "    w_clipped = np.clip(w, 0, None)\n",
        "    M_psd = U @ np.diag(w_clipped) @ U.T\n",
        "    M_hat = copy.deepcopy(M_psd)\n",
        "\n",
        "    # PSD checks\n",
        "    mineig_true = np.linalg.eigvalsh(M_true).min()\n",
        "    mineig_hat = np.linalg.eigvalsh((M_hat + M_hat.T)/2).min()  # symmetrize for eigs\n",
        "\n",
        "    # # Constraint residuals\n",
        "    # viol = max(np.linalg.norm(M_psd @ (theta_list[j] - mu_list[j]) - b_list[j]) for j in range(J))\n",
        "\n",
        "    def _as_array_of_mus(mu):\n",
        "        # Accepts: np.ndarray (d,), list/tuple of (d,) arrays, or (J,d) array\n",
        "        if isinstance(mu, (list, tuple)):\n",
        "            return np.stack(mu, axis=0)  # (J,d)\n",
        "        mu = np.asarray(mu)\n",
        "        return mu if mu.ndim == 2 else mu[None, :]  # (1,d) if single\n",
        "\n",
        "    def rmse(a, b):\n",
        "        a = np.asarray(a); b = np.asarray(b)\n",
        "        return float(np.sqrt(np.mean((a - b) ** 2)))\n",
        "\n",
        "    def symmetrize(M):\n",
        "        return 0.5 * (M + M.T)\n",
        "\n",
        "    def safe_pinv_psd(M, eps=1e-9):\n",
        "        # Symmetrize + jitter, then Moore–Penrose\n",
        "        Ms = symmetrize(M)\n",
        "        d = Ms.shape[0]\n",
        "        return np.linalg.pinv(Ms + eps * np.eye(d))\n",
        "\n",
        "    # --- RMSE(mu) ---\n",
        "    # Expect variables: mu_hat (estimated) and mu_true (ground truth).\n",
        "    # If you instead have multiple priors {mu_j}, pass lists for both.\n",
        "    rmse_mu = None\n",
        "    try:\n",
        "        mu_hat_arr  = _as_array_of_mus(mu_hat)   # shape (J,d) or (1,d)\n",
        "        mu_true_arr = _as_array_of_mus(mu_true)  # shape (J,d) or (1,d)\n",
        "        # If counts differ but one is (1,d), broadcast it across J:\n",
        "        if mu_hat_arr.shape[0] != mu_true_arr.shape[0]:\n",
        "            if mu_hat_arr.shape[0] == 1:\n",
        "                mu_hat_arr  = np.repeat(mu_hat_arr,  mu_true_arr.shape[0], axis=0)\n",
        "            elif mu_true_arr.shape[0] == 1:\n",
        "                mu_true_arr = np.repeat(mu_true_arr, mu_hat_arr.shape[0],  axis=0)\n",
        "        rmse_mu = rmse(mu_hat_arr, mu_true_arr)\n",
        "    except NameError:\n",
        "        # mu_hat or mu_true not defined; skip metric gracefully\n",
        "        rmse_mu = np.nan\n",
        "\n",
        "    # --- RFB for covariances (relative Frobenius between Σ_hat and Σ_true) ---\n",
        "    Sigma_true = safe_pinv_psd(M_true)\n",
        "    Sigma_hat  = safe_pinv_psd(M_hat)\n",
        "    rfb_cov = np.linalg.norm(Sigma_hat - Sigma_true, 'fro') / (np.linalg.norm(Sigma_true, 'fro') + 1e-12)\n",
        "\n",
        "    # Optionally: spectral relative error on covariances\n",
        "    spec_cov_err = np.linalg.norm(Sigma_hat - Sigma_true, 2) / (np.linalg.norm(Sigma_true, 2) + 1e-12)\n",
        "\n",
        "    # if mode == \"ml\":\n",
        "    #     print(\"=== Multi-constraint inverse (using ML-derived θ*’s) ===\")\n",
        "    # else:\n",
        "    #     print(\"=== Multi-constraint inverse (using closed-form solution) ===\")\n",
        "\n",
        "    metrics_df = pd.DataFrame({\n",
        "        \"metric\": [\n",
        "            \"Relative Frobenius error ‖M_hat - M_true‖_F / ‖M_true‖_F\",\n",
        "            \"Relative spectral error  ‖·‖_2\",\n",
        "            \"Entrywise correlation corr(M_true, M_hat)\",\n",
        "            \"Min eigenvalue M_true\",\n",
        "            \"Min eigenvalue sym(M_hat)\",\n",
        "            \"Validation MSE (θ_fwd)\",\n",
        "            # \"Max constraint residual\",\n",
        "            \"RMSE (μ̂ vs μ_true)\",\n",
        "            \"RFB on covariances ‖Σ̂ - Σ_true‖_F / ‖Σ_true‖_F\",\n",
        "            \"Relative spectral error on covariances ‖·‖_2\"\n",
        "        ],\n",
        "        \"value\": [\n",
        "            rel_frob,\n",
        "            spec_err,\n",
        "            corr,\n",
        "            mineig_true,\n",
        "            mineig_hat,\n",
        "            float(np.mean((X_test @ theta_star - y_test)**2)),\n",
        "            # viol,\n",
        "            rmse_mu,\n",
        "            rfb_cov,\n",
        "            spec_cov_err\n",
        "        ]\n",
        "    })\n",
        "\n",
        "    display(metrics_df)"
      ],
      "metadata": {
        "id": "fiKkwYaveDM8"
      },
      "execution_count": 8,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# -------------------------\n",
        "# 4.a) Inverse recovery from a single (theta, b): alternating (M, mu)\n",
        "      #  min tr(M) s.t. M (theta - mu) = b, M PSD; then mu <- theta - M^+ b\n",
        "# -------------------------\n",
        "\n",
        "# given: theta_star, b_star (1D arrays of same length)\n",
        "\n",
        "d_theta = theta_star.shape[0]\n",
        "mu = np.zeros(d_theta)\n",
        "tol = 1e-6\n",
        "max_iters = 50\n",
        "feas_margin = 1e-8  # ensure (theta - mu)^T b >= this\n",
        "\n",
        "def enforce_feasibility(mu, theta_star, b_star, margin):\n",
        "    v = theta_star - mu\n",
        "    vb = float(v @ b_star)\n",
        "    if vb < margin:\n",
        "        denom = float(b_star @ b_star)\n",
        "        if denom == 0.0:\n",
        "            # if b=0, feasible with M=0; keep mu\n",
        "            return mu\n",
        "        # minimal shift along -b to make v^T b = margin\n",
        "        t = (margin - vb) / denom\n",
        "        mu = mu - t * b_star\n",
        "    return mu\n",
        "\n",
        "start = time.perf_counter()\n",
        "\n",
        "M_value = None\n",
        "for k in range(max_iters):\n",
        "    # --- feasibility guard ---\n",
        "    mu = enforce_feasibility(mu, theta_star, b_star, feas_margin)\n",
        "    v = theta_star - mu\n",
        "\n",
        "    # --- M-step (SDP, PSD) ---\n",
        "    M = cp.Variable((d_theta, d_theta), PSD=True)\n",
        "    constraints = [M @ v == b_star]\n",
        "    obj = cp.Minimize(cp.trace(M))\n",
        "    prob = cp.Problem(obj, constraints)\n",
        "    prob.solve(solver=cp.SCS, eps=1e-6, max_iters=200000, verbose=False)\n",
        "\n",
        "    M_value = M.value\n",
        "    if M_value is None:\n",
        "        raise RuntimeError(f\"SCS failed at iter {k}: M.value is None (likely infeasible or numerical failure).\")\n",
        "\n",
        "    # symmetrize + tiny ridge for stability\n",
        "    M_value = 0.5 * (M_value + M_value.T) + 1e-12 * np.eye(d_theta)\n",
        "\n",
        "    # --- mu-step (pseudoinverse) ---\n",
        "    M_pinv = np.linalg.pinv(M_value, rcond=1e-12)\n",
        "    mu_new = theta_star - M_pinv @ b_star\n",
        "\n",
        "    # --- convergence ---\n",
        "    if np.linalg.norm(mu_new - mu) < tol:\n",
        "        mu = mu_new\n",
        "        print(\"Converged at k =\", k)\n",
        "        break\n",
        "    mu = mu_new\n",
        "else:\n",
        "    print(\"Max iterations reached without convergence.\")\n",
        "\n",
        "end = time.perf_counter()\n",
        "print(f\"Elapsed time: {end - start:.2f} seconds\")\n",
        "\n",
        "# after convergence:\n",
        "M_opt = M_value\n",
        "mu_opt = mu\n",
        "\n",
        "# recover L (PSD factor)\n",
        "eigvals, eigvecs = np.linalg.eigh(0.5 * (M_opt + M_opt.T))\n",
        "eigvals = np.clip(eigvals, 0.0, None)\n",
        "L_new = eigvecs @ np.diag(np.sqrt(eigvals)) @ eigvecs.T\n",
        "\n",
        "# nice prints\n",
        "np.set_printoptions(precision=4, suppress=True)\n",
        "residual = np.linalg.norm(M_opt @ (theta_star - mu_opt) - b_star)\n",
        "print(\"Inferred prior-mean μ:\\n\", mu_opt)\n",
        "print(\"Trace(M):\", float(np.trace(M_opt)))\n",
        "print(\"Constraint residual ‖M(θ*−μ)−b‖₂:\", residual)\n",
        "print(\"θ*:\\n\", theta_star)\n",
        "print(\"\\nb (aka b_star):\\n\", b_star)\n",
        "print(\"\\nM @ (θ*−μ) (should ≈ b):\\n\", (M_opt @ (theta_star - mu_opt)).round(4))\n",
        "print(\"\\nRecovered L (one factor of M):\\n\", L_new)\n",
        "print(f\"\\nReconstruction error ‖M - LᵀL‖_F = {np.linalg.norm(M_opt - (L_new.T @ L_new)):.4e}\")"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "Lg-0jAVEQVLa",
        "outputId": "25f07c64-1d72-455e-b39a-fead55cfa245"
      },
      "execution_count": 9,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "/usr/local/lib/python3.12/dist-packages/cvxpy/problems/problem.py:1510: UserWarning: Solution may be inaccurate. Try another solver, adjusting the solver settings, or solve with verbose=True for more information.\n",
            "  warnings.warn(\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Max iterations reached without convergence.\n",
            "Elapsed time: 5.53 seconds\n",
            "Inferred prior-mean μ:\n",
            " [-0.0224 -0.1502  0.3505  0.184  -0.3897  0.1832  0.0131  0.1112  0.4553\n",
            "  0.0272]\n",
            "Trace(M): 40673.13303658821\n",
            "Constraint residual ‖M(θ*−μ)−b‖₂: 2.3672957613561345e-06\n",
            "θ*:\n",
            " [-0.0224 -0.1504  0.3498  0.1838 -0.389   0.1828  0.0128  0.1115  0.4541\n",
            "  0.0274]\n",
            "\n",
            "b (aka b_star):\n",
            " [ 0.1619  0.5038  0.1079 -0.0024 -0.1356  0.1074  0.6    -0.8617 -0.7409\n",
            " -0.2634]\n",
            "\n",
            "M @ (θ*−μ) (should ≈ b):\n",
            " [ 0.1619  0.5038  0.1079 -0.0024 -0.1356  0.1074  0.6    -0.8617 -0.7409\n",
            " -0.2634]\n",
            "\n",
            "Recovered L (one factor of M):\n",
            " [[  2.5897   8.0565   1.7259  -0.0392  -2.1691   1.7167   9.5947 -13.7796\n",
            "  -11.8472  -4.2118]\n",
            " [  8.0565  25.0632   5.3691  -0.1219  -6.748    5.3406  29.8485 -42.8674\n",
            "  -36.8555 -13.1026]\n",
            " [  1.7259   5.3691   1.1503  -0.0261  -1.4457   1.1441   6.3942  -9.1832\n",
            "   -7.8964  -2.8069]\n",
            " [ -0.0392  -0.1219  -0.0261   0.0006   0.0328  -0.026   -0.1452   0.2086\n",
            "    0.1789   0.0637]\n",
            " [ -2.1691  -6.748   -1.4457   0.0328   1.8169  -1.4379  -8.0364  11.5416\n",
            "    9.9242   3.5277]\n",
            " [  1.7167   5.3406   1.1441  -0.026   -1.4379   1.138    6.3602  -9.1344\n",
            "   -7.8539  -2.792 ]\n",
            " [  9.5947  29.8485   6.3942  -0.1452  -8.0364   6.3602  35.5474 -51.052\n",
            "  -43.8924 -15.6042]\n",
            " [-13.7796 -42.8674  -9.1832   0.2086  11.5416  -9.1344 -51.052   73.3191\n",
            "   63.0365  22.4102]\n",
            " [-11.8472 -36.8555  -7.8964   0.1789   9.9242  -7.8539 -43.8924  63.0365\n",
            "   54.2128  19.2675]\n",
            " [ -4.2118 -13.1026  -2.8069   0.0637   3.5277  -2.792  -15.6042  22.4102\n",
            "   19.2675   6.8498]]\n",
            "\n",
            "Reconstruction error ‖M - LᵀL‖_F = 1.5734e-04\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "metrics(mu_opt, mu_true, M_opt, M_true, theta_star, y_test, X_test)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 332
        },
        "id": "7G198Y01RSYx",
        "outputId": "0c9a57e8-97b3-4a67-d76e-e9915e119aa8"
      },
      "execution_count": 10,
      "outputs": [
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "                                              metric         value\n",
              "0  Relative Frobenius error ‖M_hat - M_true‖_F / ...  1.474931e+04\n",
              "1                     Relative spectral error  ‖·‖_2  1.786834e+04\n",
              "2          Entrywise correlation corr(M_true, M_hat)  4.816039e-01\n",
              "3                              Min eigenvalue M_true  4.276565e-04\n",
              "4                          Min eigenvalue sym(M_hat) -1.000250e-11\n",
              "5                             Validation MSE (θ_fwd)  5.593741e-01\n",
              "6                                RMSE (μ̂ vs μ_true)  9.245819e+00\n",
              "7    RFB on covariances ‖Σ̂ - Σ_true‖_F / ‖Σ_true‖_F  1.200445e+06\n",
              "8       Relative spectral error on covariances ‖·‖_2  4.280892e+05"
            ],
            "text/html": [
              "\n",
              "  <div id=\"df-6bf33a1d-6bfb-4d61-9ddb-22b9289c4970\" class=\"colab-df-container\">\n",
              "    <div>\n",
              "<style scoped>\n",
              "    .dataframe tbody tr th:only-of-type {\n",
              "        vertical-align: middle;\n",
              "    }\n",
              "\n",
              "    .dataframe tbody tr th {\n",
              "        vertical-align: top;\n",
              "    }\n",
              "\n",
              "    .dataframe thead th {\n",
              "        text-align: right;\n",
              "    }\n",
              "</style>\n",
              "<table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              "    <tr style=\"text-align: right;\">\n",
              "      <th></th>\n",
              "      <th>metric</th>\n",
              "      <th>value</th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "    <tr>\n",
              "      <th>0</th>\n",
              "      <td>Relative Frobenius error ‖M_hat - M_true‖_F / ...</td>\n",
              "      <td>1.474931e+04</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>1</th>\n",
              "      <td>Relative spectral error  ‖·‖_2</td>\n",
              "      <td>1.786834e+04</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>2</th>\n",
              "      <td>Entrywise correlation corr(M_true, M_hat)</td>\n",
              "      <td>4.816039e-01</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>3</th>\n",
              "      <td>Min eigenvalue M_true</td>\n",
              "      <td>4.276565e-04</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>4</th>\n",
              "      <td>Min eigenvalue sym(M_hat)</td>\n",
              "      <td>-1.000250e-11</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>5</th>\n",
              "      <td>Validation MSE (θ_fwd)</td>\n",
              "      <td>5.593741e-01</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>6</th>\n",
              "      <td>RMSE (μ̂ vs μ_true)</td>\n",
              "      <td>9.245819e+00</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>7</th>\n",
              "      <td>RFB on covariances ‖Σ̂ - Σ_true‖_F / ‖Σ_true‖_F</td>\n",
              "      <td>1.200445e+06</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>8</th>\n",
              "      <td>Relative spectral error on covariances ‖·‖_2</td>\n",
              "      <td>4.280892e+05</td>\n",
              "    </tr>\n",
              "  </tbody>\n",
              "</table>\n",
              "</div>\n",
              "    <div class=\"colab-df-buttons\">\n",
              "\n",
              "  <div class=\"colab-df-container\">\n",
              "    <button class=\"colab-df-convert\" onclick=\"convertToInteractive('df-6bf33a1d-6bfb-4d61-9ddb-22b9289c4970')\"\n",
              "            title=\"Convert this dataframe to an interactive table.\"\n",
              "            style=\"display:none;\">\n",
              "\n",
              "  <svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\" viewBox=\"0 -960 960 960\">\n",
              "    <path d=\"M120-120v-720h720v720H120Zm60-500h600v-160H180v160Zm220 220h160v-160H400v160Zm0 220h160v-160H400v160ZM180-400h160v-160H180v160Zm440 0h160v-160H620v160ZM180-180h160v-160H180v160Zm440 0h160v-160H620v160Z\"/>\n",
              "  </svg>\n",
              "    </button>\n",
              "\n",
              "  <style>\n",
              "    .colab-df-container {\n",
              "      display:flex;\n",
              "      gap: 12px;\n",
              "    }\n",
              "\n",
              "    .colab-df-convert {\n",
              "      background-color: #E8F0FE;\n",
              "      border: none;\n",
              "      border-radius: 50%;\n",
              "      cursor: pointer;\n",
              "      display: none;\n",
              "      fill: #1967D2;\n",
              "      height: 32px;\n",
              "      padding: 0 0 0 0;\n",
              "      width: 32px;\n",
              "    }\n",
              "\n",
              "    .colab-df-convert:hover {\n",
              "      background-color: #E2EBFA;\n",
              "      box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n",
              "      fill: #174EA6;\n",
              "    }\n",
              "\n",
              "    .colab-df-buttons div {\n",
              "      margin-bottom: 4px;\n",
              "    }\n",
              "\n",
              "    [theme=dark] .colab-df-convert {\n",
              "      background-color: #3B4455;\n",
              "      fill: #D2E3FC;\n",
              "    }\n",
              "\n",
              "    [theme=dark] .colab-df-convert:hover {\n",
              "      background-color: #434B5C;\n",
              "      box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n",
              "      filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n",
              "      fill: #FFFFFF;\n",
              "    }\n",
              "  </style>\n",
              "\n",
              "    <script>\n",
              "      const buttonEl =\n",
              "        document.querySelector('#df-6bf33a1d-6bfb-4d61-9ddb-22b9289c4970 button.colab-df-convert');\n",
              "      buttonEl.style.display =\n",
              "        google.colab.kernel.accessAllowed ? 'block' : 'none';\n",
              "\n",
              "      async function convertToInteractive(key) {\n",
              "        const element = document.querySelector('#df-6bf33a1d-6bfb-4d61-9ddb-22b9289c4970');\n",
              "        const dataTable =\n",
              "          await google.colab.kernel.invokeFunction('convertToInteractive',\n",
              "                                                    [key], {});\n",
              "        if (!dataTable) return;\n",
              "\n",
              "        const docLinkHtml = 'Like what you see? Visit the ' +\n",
              "          '<a target=\"_blank\" href=https://colab.research.google.com/notebooks/data_table.ipynb>data table notebook</a>'\n",
              "          + ' to learn more about interactive tables.';\n",
              "        element.innerHTML = '';\n",
              "        dataTable['output_type'] = 'display_data';\n",
              "        await google.colab.output.renderOutput(dataTable, element);\n",
              "        const docLink = document.createElement('div');\n",
              "        docLink.innerHTML = docLinkHtml;\n",
              "        element.appendChild(docLink);\n",
              "      }\n",
              "    </script>\n",
              "  </div>\n",
              "\n",
              "\n",
              "    <div id=\"df-483ab49a-5254-46b7-8f3e-f7d697a982fa\">\n",
              "      <button class=\"colab-df-quickchart\" onclick=\"quickchart('df-483ab49a-5254-46b7-8f3e-f7d697a982fa')\"\n",
              "                title=\"Suggest charts\"\n",
              "                style=\"display:none;\">\n",
              "\n",
              "<svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n",
              "     width=\"24px\">\n",
              "    <g>\n",
              "        <path d=\"M19 3H5c-1.1 0-2 .9-2 2v14c0 1.1.9 2 2 2h14c1.1 0 2-.9 2-2V5c0-1.1-.9-2-2-2zM9 17H7v-7h2v7zm4 0h-2V7h2v10zm4 0h-2v-4h2v4z\"/>\n",
              "    </g>\n",
              "</svg>\n",
              "      </button>\n",
              "\n",
              "<style>\n",
              "  .colab-df-quickchart {\n",
              "      --bg-color: #E8F0FE;\n",
              "      --fill-color: #1967D2;\n",
              "      --hover-bg-color: #E2EBFA;\n",
              "      --hover-fill-color: #174EA6;\n",
              "      --disabled-fill-color: #AAA;\n",
              "      --disabled-bg-color: #DDD;\n",
              "  }\n",
              "\n",
              "  [theme=dark] .colab-df-quickchart {\n",
              "      --bg-color: #3B4455;\n",
              "      --fill-color: #D2E3FC;\n",
              "      --hover-bg-color: #434B5C;\n",
              "      --hover-fill-color: #FFFFFF;\n",
              "      --disabled-bg-color: #3B4455;\n",
              "      --disabled-fill-color: #666;\n",
              "  }\n",
              "\n",
              "  .colab-df-quickchart {\n",
              "    background-color: var(--bg-color);\n",
              "    border: none;\n",
              "    border-radius: 50%;\n",
              "    cursor: pointer;\n",
              "    display: none;\n",
              "    fill: var(--fill-color);\n",
              "    height: 32px;\n",
              "    padding: 0;\n",
              "    width: 32px;\n",
              "  }\n",
              "\n",
              "  .colab-df-quickchart:hover {\n",
              "    background-color: var(--hover-bg-color);\n",
              "    box-shadow: 0 1px 2px rgba(60, 64, 67, 0.3), 0 1px 3px 1px rgba(60, 64, 67, 0.15);\n",
              "    fill: var(--button-hover-fill-color);\n",
              "  }\n",
              "\n",
              "  .colab-df-quickchart-complete:disabled,\n",
              "  .colab-df-quickchart-complete:disabled:hover {\n",
              "    background-color: var(--disabled-bg-color);\n",
              "    fill: var(--disabled-fill-color);\n",
              "    box-shadow: none;\n",
              "  }\n",
              "\n",
              "  .colab-df-spinner {\n",
              "    border: 2px solid var(--fill-color);\n",
              "    border-color: transparent;\n",
              "    border-bottom-color: var(--fill-color);\n",
              "    animation:\n",
              "      spin 1s steps(1) infinite;\n",
              "  }\n",
              "\n",
              "  @keyframes spin {\n",
              "    0% {\n",
              "      border-color: transparent;\n",
              "      border-bottom-color: var(--fill-color);\n",
              "      border-left-color: var(--fill-color);\n",
              "    }\n",
              "    20% {\n",
              "      border-color: transparent;\n",
              "      border-left-color: var(--fill-color);\n",
              "      border-top-color: var(--fill-color);\n",
              "    }\n",
              "    30% {\n",
              "      border-color: transparent;\n",
              "      border-left-color: var(--fill-color);\n",
              "      border-top-color: var(--fill-color);\n",
              "      border-right-color: var(--fill-color);\n",
              "    }\n",
              "    40% {\n",
              "      border-color: transparent;\n",
              "      border-right-color: var(--fill-color);\n",
              "      border-top-color: var(--fill-color);\n",
              "    }\n",
              "    60% {\n",
              "      border-color: transparent;\n",
              "      border-right-color: var(--fill-color);\n",
              "    }\n",
              "    80% {\n",
              "      border-color: transparent;\n",
              "      border-right-color: var(--fill-color);\n",
              "      border-bottom-color: var(--fill-color);\n",
              "    }\n",
              "    90% {\n",
              "      border-color: transparent;\n",
              "      border-bottom-color: var(--fill-color);\n",
              "    }\n",
              "  }\n",
              "</style>\n",
              "\n",
              "      <script>\n",
              "        async function quickchart(key) {\n",
              "          const quickchartButtonEl =\n",
              "            document.querySelector('#' + key + ' button');\n",
              "          quickchartButtonEl.disabled = true;  // To prevent multiple clicks.\n",
              "          quickchartButtonEl.classList.add('colab-df-spinner');\n",
              "          try {\n",
              "            const charts = await google.colab.kernel.invokeFunction(\n",
              "                'suggestCharts', [key], {});\n",
              "          } catch (error) {\n",
              "            console.error('Error during call to suggestCharts:', error);\n",
              "          }\n",
              "          quickchartButtonEl.classList.remove('colab-df-spinner');\n",
              "          quickchartButtonEl.classList.add('colab-df-quickchart-complete');\n",
              "        }\n",
              "        (() => {\n",
              "          let quickchartButtonEl =\n",
              "            document.querySelector('#df-483ab49a-5254-46b7-8f3e-f7d697a982fa button');\n",
              "          quickchartButtonEl.style.display =\n",
              "            google.colab.kernel.accessAllowed ? 'block' : 'none';\n",
              "        })();\n",
              "      </script>\n",
              "    </div>\n",
              "\n",
              "    </div>\n",
              "  </div>\n"
            ],
            "application/vnd.google.colaboratory.intrinsic+json": {
              "type": "dataframe",
              "summary": "{\n  \"name\": \"metrics(mu_opt, mu_true, M_opt, M_true, theta_star, y_test, X_test)\",\n  \"rows\": 9,\n  \"fields\": [\n    {\n      \"column\": \"metric\",\n      \"properties\": {\n        \"dtype\": \"string\",\n        \"num_unique_values\": 9,\n        \"samples\": [\n          \"RFB on covariances \\u2016\\u03a3\\u0302 - \\u03a3_true\\u2016_F / \\u2016\\u03a3_true\\u2016_F\",\n          \"Relative spectral error  \\u2016\\u00b7\\u2016_2\",\n          \"Validation MSE (\\u03b8_fwd)\"\n        ],\n        \"semantic_type\": \"\",\n        \"description\": \"\"\n      }\n    },\n    {\n      \"column\": \"value\",\n      \"properties\": {\n        \"dtype\": \"number\",\n        \"std\": 405933.85369229584,\n        \"min\": -1.0002504853943694e-11,\n        \"max\": 1200444.7938222177,\n        \"num_unique_values\": 9,\n        \"samples\": [\n          1200444.7938222177,\n          17868.33868595419,\n          0.5593740541201775\n        ],\n        \"semantic_type\": \"\",\n        \"description\": \"\"\n      }\n    }\n  ]\n}"
            }
          },
          "metadata": {}
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "def fit_theta_star(mu_np, max_iter=500):\n",
        "    \"\"\"\n",
        "    Fit the forward solution θ* for a given prior mean μ with L-BFGS.\n",
        "\n",
        "    Minimises  ||X_train θ - y_train||^2 + (θ - μ)^T M_true (θ - μ)  over the\n",
        "    weights of a bias-free linear layer (sum-of-squares data term, no 1/n).\n",
        "    Reads the notebook globals X_train, y_train, d and M_true_t.\n",
        "\n",
        "    Args:\n",
        "        mu_np: prior mean μ; converted to a float32 tensor and subtracted from\n",
        "            the length-d weight vector.\n",
        "        max_iter: iteration cap passed to LBFGS for the single step() call.\n",
        "\n",
        "    Returns:\n",
        "        The fitted weights flattened to a 1-D NumPy array.\n",
        "    \"\"\"\n",
        "    mu_t = torch.tensor(mu_np, dtype=torch.float32)\n",
        "    # The training data does not change during the fit, so convert it once here\n",
        "    # instead of on every closure evaluation (L-BFGS re-evaluates the closure).\n",
        "    X_train_t = torch.from_numpy(X_train).float()\n",
        "    y_train_t = torch.from_numpy(y_train).float()\n",
        "\n",
        "    model = nn.Linear(d, 1, bias=False)\n",
        "    # LBFGS converges fast to the exact quadratic optimum\n",
        "    opt = LBFGS(model.parameters(), lr=1.0, max_iter=max_iter, history_size=50, line_search_fn=\"strong_wolfe\")\n",
        "\n",
        "    def closure():\n",
        "        # Re-evaluates loss and gradients; called repeatedly by opt.step().\n",
        "        opt.zero_grad()\n",
        "        pred = model(X_train_t).squeeze()           # (n,)\n",
        "        data_ss = torch.sum((pred - y_train_t)**2)  # sum of squares, not mean\n",
        "        w = model.weight.view(-1)                   # (d,)\n",
        "        diff = w - mu_t\n",
        "        reg = diff @ (M_true_t @ diff)              # (θ-μ)^T M_true (θ-μ)\n",
        "        loss = data_ss + reg\n",
        "        loss.backward()\n",
        "        return loss\n",
        "\n",
        "    opt.step(closure)\n",
        "    # detach() already drops the autograd graph, so no no_grad() block is needed.\n",
        "    return model.weight.detach().view(-1).cpu().numpy()"
      ],
      "metadata": {
        "id": "8qGtjE_XCJWD"
      },
      "execution_count": 11,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# -------------------------\n",
        "# 4.b.1) Inverse recovery (multi-constraint least squares with closed form)\n",
        "#     Build J >= d constraints with diverse priors μ_j and solve M V = B\n",
        "#     (exact when constraints are full rank)\n",
        "#     Generate {mu_j}, solve forward for {theta_j}, compute {b_j}, then solve:\n",
        "#     find M PSD s.t. M(θ_j - μ_j) = b_j for all j\n",
        "# -------------------------\n",
        "\n",
        "# -------------------------\n",
        "# 4.b.2) Inverse recovery (multi-constraint least squares with ML model)\n",
        "#    Build J >= d constraints with diverse priors μ_j, train to get θ*_j and b_j\n",
        "#    Loss = ||Xθ - y||^2  +  (θ - μ)^T M_true (θ - μ)\n",
        "#    (sum of squares, to match stationarity form)\n",
        "# -------------------------\n",
        "\n",
        "# Both variants run through this one cell; `mode` selects the forward solver\n",
        "# (presumably set in an earlier cell - confirm).\n",
        "# Setting the gradient of the loss above to zero gives the stationarity condition\n",
        "#     M_true (θ_j - μ_j) = -X^T (X θ_j - y) =: b_j,\n",
        "# so stacking J such pairs as columns gives the linear system M V = B solved below.\n",
        "\n",
        "# One constraint per coordinate: with J = d the matrix V below is square (d x d).\n",
        "J = d\n",
        "mu_list = []\n",
        "theta_list = []\n",
        "b_list = []\n",
        "\n",
        "# Wall-clock timing covers only the J forward solves and the b_j evaluations.\n",
        "start = time.perf_counter()\n",
        "\n",
        "for j in range(J):\n",
        "    mu_j = np.zeros(d)\n",
        "    mu_j[j] = 1.0  # canonical basis ensures V is typically full-row-rank\n",
        "    mu_list.append(mu_j)\n",
        "\n",
        "    # \"ml\": numerical fit via fit_theta_star; any other value: closed-form solve.\n",
        "    if mode == \"ml\":\n",
        "        theta_j = fit_theta_star(mu_j)\n",
        "    else:\n",
        "        # NOTE(review): A and Xty come from an earlier cell; presumably\n",
        "        # A = X^T X + M_true and Xty = X^T y - confirm.\n",
        "        rhs_j = Xty + M_true @ mu_j\n",
        "        theta_j = np.linalg.solve(A, rhs_j)  # same A for all; only rhs changes\n",
        "    theta_list.append(theta_j)\n",
        "\n",
        "    # b_j = -X^T (X θ_j - y)   (sum-of-squares convention)\n",
        "    b_j = - X_train.T @ (X_train @ theta_j - y_train)\n",
        "\n",
        "    b_list.append(b_j)\n",
        "\n",
        "end = time.perf_counter()\n",
        "print(\"time:\", end - start)\n",
        "\n",
        "# Column j of V is θ_j - μ_j and column j of B is b_j.\n",
        "V = np.column_stack([theta_list[j] - mu_list[j] for j in range(J)])  # d x J\n",
        "B = np.column_stack(b_list)                                          # d x J\n",
        "\n",
        "# Inverse recovery: Solve M_hat = B V^+  (right pseudoinverse); exact if V is full row rank\n",
        "M_hat = B @ np.linalg.pinv(V)\n",
        "\n",
        "# Symmetrize and PSD-project (nearest PSD in Frobenius norm) via eigenvalue clipping\n",
        "M_sym = 0.5 * (M_hat + M_hat.T)\n",
        "w, U = np.linalg.eigh(M_sym)\n",
        "w_clipped = np.clip(w, 0, None)  # negative eigenvalues are set to zero\n",
        "M_psd = U @ np.diag(w_clipped) @ U.T\n",
        "# M_hat is rebound to an independent copy of the projection, so code that reads\n",
        "# M_hat after this cell sees the symmetric PSD matrix, not the raw B V^+.\n",
        "M_hat = copy.deepcopy(M_psd)"
      ],
      "metadata": {
        "id": "yjaAp79pCJ9z",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "ff884fb3-f134-4091-8ff0-ed2375bbaf86"
      },
      "execution_count": 13,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "time: 0.0015552180000213411\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# ---------- helpers ----------\n",
        "\n",
        "def symmetrize(M: np.ndarray) -> np.ndarray:\n",
        "    \"\"\"Symmetric part of M: the average of M and its transpose.\"\"\"\n",
        "    transposed = M.T\n",
        "    return (M + transposed) * 0.5\n",
        "\n",
        "def compute_b_sse(X: np.ndarray, theta: np.ndarray, y: np.ndarray) -> np.ndarray:\n",
        "    \"\"\"\n",
        "    Right-hand side of the stationarity condition for the sum-of-squares loss:\n",
        "        b = -X^T (X θ - y)\n",
        "    No 1/n factor is applied (sum, not mean).\n",
        "    Expected shapes: X:(n,d), theta:(d,), y:(n,)\n",
        "    \"\"\"\n",
        "    fitted = X @ theta\n",
        "    neg_Xt = -X.T\n",
        "    return neg_Xt @ (fitted - y)\n",
        "\n",
        "def stationarity_residual(M: np.ndarray, theta: np.ndarray, mu: np.ndarray, b: np.ndarray) -> float:\n",
        "    \"\"\"Euclidean norm of M(θ - μ) - b; close to zero when (M, θ, μ, b) are consistent.\"\"\"\n",
        "    mismatch = M @ (theta - mu) - b\n",
        "    return np.linalg.norm(mismatch)\n",
        "\n",
        "# ---------- μ estimators ----------\n",
        "\n",
        "def infer_mu_single(M_hat: np.ndarray,\n",
        "                    theta: np.ndarray,\n",
        "                    b: np.ndarray,\n",
        "                    ridge: float = 1e-8,\n",
        "                    return_parts: bool = False):\n",
        "    \"\"\"\n",
        "    Recover the prior mean μ from a single (θ, b) pair, given M_hat.\n",
        "\n",
        "    M_hat is symmetrised first; the solve then depends on the sign of `ridge`:\n",
        "    - ridge > 0 (the default): solve (M^2 + ridge I) μ = M (M θ - b), i.e. the\n",
        "      ridge-damped normal equations of min_μ ‖M(θ - μ) - b‖^2. This tends to the\n",
        "      minimum-norm solution as ridge→0.\n",
        "    - ridge <= 0: direct solve μ = θ - M^{-1} b; if np.linalg.solve raises\n",
        "      LinAlgError, the damped system is used with ridge = 1e-8 instead.\n",
        "\n",
        "    With return_parts=True the result is (μ, μ_ident, μ_ambig): the last two are\n",
        "    the projections of μ onto the eigenvectors of M whose eigenvalue is > 1e-10\n",
        "    (identifiable part) and <= 1e-10 (nullspace part). Otherwise only μ is returned.\n",
        "    \"\"\"\n",
        "    M = symmetrize(M_hat)\n",
        "    d = M.shape[0]\n",
        "\n",
        "    def ridge_solve(strength):\n",
        "        # Damped normal equations of the stationarity condition M(θ - μ) = b.\n",
        "        lhs = M @ M + strength * np.eye(d)\n",
        "        return np.linalg.solve(lhs, M @ (M @ theta - b))\n",
        "\n",
        "    if ridge > 0:\n",
        "        mu = ridge_solve(ridge)\n",
        "    elif ridge <= 0:\n",
        "        try:\n",
        "            mu = theta - np.linalg.solve(M, b)\n",
        "        except np.linalg.LinAlgError:\n",
        "            # Singular M: fall back to the damped system.\n",
        "            mu = ridge_solve(1e-8)\n",
        "\n",
        "    if not return_parts:\n",
        "        return mu\n",
        "\n",
        "    # Split μ along the eigenvectors of M: Range(M) vs Null(M).\n",
        "    eigvals, eigvecs = np.linalg.eigh(M)\n",
        "    tol = 1e-10\n",
        "    range_basis = eigvecs[:, eigvals > tol]\n",
        "    null_basis = eigvecs[:, eigvals <= tol]\n",
        "    mu_ident = range_basis @ (range_basis.T @ mu)\n",
        "    mu_ambig = null_basis @ (null_basis.T @ mu)\n",
        "    return mu, mu_ident, mu_ambig\n",
        "\n",
        "# Evaluate b at theta_star on the training data, then recover the prior mean from\n",
        "# that single (theta_star, b) pair using the current M_hat.\n",
        "# NOTE(review): ridge=1e-8 selects the damped normal-equation solve in\n",
        "# infer_mu_single, not the exact θ - M^{-1} b. Components of μ along eigenvalue λ\n",
        "# of M_hat are scaled by λ^2 / (λ^2 + ridge), so this is only negligible when\n",
        "# λ_min^2 >> ridge - check against the reported minimum eigenvalue.\n",
        "b_star_new = compute_b_sse(X_train, theta_star, y_train)\n",
        "mu_hat = infer_mu_single(M_hat, theta_star, b_star_new, ridge=1e-8)"
      ],
      "metadata": {
        "id": "Ciw-vBcjfmFX"
      },
      "execution_count": 16,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Evaluation table for the inverse-recovered estimates (mu_hat, M_hat) against the\n",
        "# ground truth (mu_true, M_true); metrics() is defined in an earlier cell.\n",
        "metrics(mu_hat, mu_true, M_hat, M_true, theta_star, y_test, X_test)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 332
        },
        "id": "kAw7P9BcCO3g",
        "outputId": "61056057-c5b1-4404-b7da-b41b5717a8b4"
      },
      "execution_count": 17,
      "outputs": [
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "                                              metric         value\n",
              "0  Relative Frobenius error ‖M_hat - M_true‖_F / ...  2.826092e-10\n",
              "1                     Relative spectral error  ‖·‖_2  3.038032e-10\n",
              "2          Entrywise correlation corr(M_true, M_hat)  1.000000e+00\n",
              "3                              Min eigenvalue M_true  4.276565e-04\n",
              "4                          Min eigenvalue sym(M_hat)  4.276561e-04\n",
              "5                             Validation MSE (θ_fwd)  5.593741e-01\n",
              "6                                RMSE (μ̂ vs μ_true)  4.524047e-01\n",
              "7    RFB on covariances ‖Σ̂ - Σ_true‖_F / ‖Σ_true‖_F  9.371235e-07\n",
              "8       Relative spectral error on covariances ‖·‖_2  9.440251e-07"
            ],
            "text/html": [
              "\n",
              "  <div id=\"df-a2d65c11-c740-4ab5-bdbe-a4bef66f45c0\" class=\"colab-df-container\">\n",
              "    <div>\n",
              "<style scoped>\n",
              "    .dataframe tbody tr th:only-of-type {\n",
              "        vertical-align: middle;\n",
              "    }\n",
              "\n",
              "    .dataframe tbody tr th {\n",
              "        vertical-align: top;\n",
              "    }\n",
              "\n",
              "    .dataframe thead th {\n",
              "        text-align: right;\n",
              "    }\n",
              "</style>\n",
              "<table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              "    <tr style=\"text-align: right;\">\n",
              "      <th></th>\n",
              "      <th>metric</th>\n",
              "      <th>value</th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "    <tr>\n",
              "      <th>0</th>\n",
              "      <td>Relative Frobenius error ‖M_hat - M_true‖_F / ...</td>\n",
              "      <td>2.826092e-10</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>1</th>\n",
              "      <td>Relative spectral error  ‖·‖_2</td>\n",
              "      <td>3.038032e-10</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>2</th>\n",
              "      <td>Entrywise correlation corr(M_true, M_hat)</td>\n",
              "      <td>1.000000e+00</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>3</th>\n",
              "      <td>Min eigenvalue M_true</td>\n",
              "      <td>4.276565e-04</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>4</th>\n",
              "      <td>Min eigenvalue sym(M_hat)</td>\n",
              "      <td>4.276561e-04</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>5</th>\n",
              "      <td>Validation MSE (θ_fwd)</td>\n",
              "      <td>5.593741e-01</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>6</th>\n",
              "      <td>RMSE (μ̂ vs μ_true)</td>\n",
              "      <td>4.524047e-01</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>7</th>\n",
              "      <td>RFB on covariances ‖Σ̂ - Σ_true‖_F / ‖Σ_true‖_F</td>\n",
              "      <td>9.371235e-07</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>8</th>\n",
              "      <td>Relative spectral error on covariances ‖·‖_2</td>\n",
              "      <td>9.440251e-07</td>\n",
              "    </tr>\n",
              "  </tbody>\n",
              "</table>\n",
              "</div>\n",
              "    <div class=\"colab-df-buttons\">\n",
              "\n",
              "  <div class=\"colab-df-container\">\n",
              "    <button class=\"colab-df-convert\" onclick=\"convertToInteractive('df-a2d65c11-c740-4ab5-bdbe-a4bef66f45c0')\"\n",
              "            title=\"Convert this dataframe to an interactive table.\"\n",
              "            style=\"display:none;\">\n",
              "\n",
              "  <svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\" viewBox=\"0 -960 960 960\">\n",
              "    <path d=\"M120-120v-720h720v720H120Zm60-500h600v-160H180v160Zm220 220h160v-160H400v160Zm0 220h160v-160H400v160ZM180-400h160v-160H180v160Zm440 0h160v-160H620v160ZM180-180h160v-160H180v160Zm440 0h160v-160H620v160Z\"/>\n",
              "  </svg>\n",
              "    </button>\n",
              "\n",
              "  <style>\n",
              "    .colab-df-container {\n",
              "      display:flex;\n",
              "      gap: 12px;\n",
              "    }\n",
              "\n",
              "    .colab-df-convert {\n",
              "      background-color: #E8F0FE;\n",
              "      border: none;\n",
              "      border-radius: 50%;\n",
              "      cursor: pointer;\n",
              "      display: none;\n",
              "      fill: #1967D2;\n",
              "      height: 32px;\n",
              "      padding: 0 0 0 0;\n",
              "      width: 32px;\n",
              "    }\n",
              "\n",
              "    .colab-df-convert:hover {\n",
              "      background-color: #E2EBFA;\n",
              "      box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n",
              "      fill: #174EA6;\n",
              "    }\n",
              "\n",
              "    .colab-df-buttons div {\n",
              "      margin-bottom: 4px;\n",
              "    }\n",
              "\n",
              "    [theme=dark] .colab-df-convert {\n",
              "      background-color: #3B4455;\n",
              "      fill: #D2E3FC;\n",
              "    }\n",
              "\n",
              "    [theme=dark] .colab-df-convert:hover {\n",
              "      background-color: #434B5C;\n",
              "      box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n",
              "      filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n",
              "      fill: #FFFFFF;\n",
              "    }\n",
              "  </style>\n",
              "\n",
              "    <script>\n",
              "      const buttonEl =\n",
              "        document.querySelector('#df-a2d65c11-c740-4ab5-bdbe-a4bef66f45c0 button.colab-df-convert');\n",
              "      buttonEl.style.display =\n",
              "        google.colab.kernel.accessAllowed ? 'block' : 'none';\n",
              "\n",
              "      async function convertToInteractive(key) {\n",
              "        const element = document.querySelector('#df-a2d65c11-c740-4ab5-bdbe-a4bef66f45c0');\n",
              "        const dataTable =\n",
              "          await google.colab.kernel.invokeFunction('convertToInteractive',\n",
              "                                                    [key], {});\n",
              "        if (!dataTable) return;\n",
              "\n",
              "        const docLinkHtml = 'Like what you see? Visit the ' +\n",
              "          '<a target=\"_blank\" href=https://colab.research.google.com/notebooks/data_table.ipynb>data table notebook</a>'\n",
              "          + ' to learn more about interactive tables.';\n",
              "        element.innerHTML = '';\n",
              "        dataTable['output_type'] = 'display_data';\n",
              "        await google.colab.output.renderOutput(dataTable, element);\n",
              "        const docLink = document.createElement('div');\n",
              "        docLink.innerHTML = docLinkHtml;\n",
              "        element.appendChild(docLink);\n",
              "      }\n",
              "    </script>\n",
              "  </div>\n",
              "\n",
              "\n",
              "    <div id=\"df-09ba841a-c92b-4cd1-9f4b-410f119d0a5e\">\n",
              "      <button class=\"colab-df-quickchart\" onclick=\"quickchart('df-09ba841a-c92b-4cd1-9f4b-410f119d0a5e')\"\n",
              "                title=\"Suggest charts\"\n",
              "                style=\"display:none;\">\n",
              "\n",
              "<svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n",
              "     width=\"24px\">\n",
              "    <g>\n",
              "        <path d=\"M19 3H5c-1.1 0-2 .9-2 2v14c0 1.1.9 2 2 2h14c1.1 0 2-.9 2-2V5c0-1.1-.9-2-2-2zM9 17H7v-7h2v7zm4 0h-2V7h2v10zm4 0h-2v-4h2v4z\"/>\n",
              "    </g>\n",
              "</svg>\n",
              "      </button>\n",
              "\n",
              "<style>\n",
              "  .colab-df-quickchart {\n",
              "      --bg-color: #E8F0FE;\n",
              "      --fill-color: #1967D2;\n",
              "      --hover-bg-color: #E2EBFA;\n",
              "      --hover-fill-color: #174EA6;\n",
              "      --disabled-fill-color: #AAA;\n",
              "      --disabled-bg-color: #DDD;\n",
              "  }\n",
              "\n",
              "  [theme=dark] .colab-df-quickchart {\n",
              "      --bg-color: #3B4455;\n",
              "      --fill-color: #D2E3FC;\n",
              "      --hover-bg-color: #434B5C;\n",
              "      --hover-fill-color: #FFFFFF;\n",
              "      --disabled-bg-color: #3B4455;\n",
              "      --disabled-fill-color: #666;\n",
              "  }\n",
              "\n",
              "  .colab-df-quickchart {\n",
              "    background-color: var(--bg-color);\n",
              "    border: none;\n",
              "    border-radius: 50%;\n",
              "    cursor: pointer;\n",
              "    display: none;\n",
              "    fill: var(--fill-color);\n",
              "    height: 32px;\n",
              "    padding: 0;\n",
              "    width: 32px;\n",
              "  }\n",
              "\n",
              "  .colab-df-quickchart:hover {\n",
              "    background-color: var(--hover-bg-color);\n",
              "    box-shadow: 0 1px 2px rgba(60, 64, 67, 0.3), 0 1px 3px 1px rgba(60, 64, 67, 0.15);\n",
              "    fill: var(--button-hover-fill-color);\n",
              "  }\n",
              "\n",
              "  .colab-df-quickchart-complete:disabled,\n",
              "  .colab-df-quickchart-complete:disabled:hover {\n",
              "    background-color: var(--disabled-bg-color);\n",
              "    fill: var(--disabled-fill-color);\n",
              "    box-shadow: none;\n",
              "  }\n",
              "\n",
              "  .colab-df-spinner {\n",
              "    border: 2px solid var(--fill-color);\n",
              "    border-color: transparent;\n",
              "    border-bottom-color: var(--fill-color);\n",
              "    animation:\n",
              "      spin 1s steps(1) infinite;\n",
              "  }\n",
              "\n",
              "  @keyframes spin {\n",
              "    0% {\n",
              "      border-color: transparent;\n",
              "      border-bottom-color: var(--fill-color);\n",
              "      border-left-color: var(--fill-color);\n",
              "    }\n",
              "    20% {\n",
              "      border-color: transparent;\n",
              "      border-left-color: var(--fill-color);\n",
              "      border-top-color: var(--fill-color);\n",
              "    }\n",
              "    30% {\n",
              "      border-color: transparent;\n",
              "      border-left-color: var(--fill-color);\n",
              "      border-top-color: var(--fill-color);\n",
              "      border-right-color: var(--fill-color);\n",
              "    }\n",
              "    40% {\n",
              "      border-color: transparent;\n",
              "      border-right-color: var(--fill-color);\n",
              "      border-top-color: var(--fill-color);\n",
              "    }\n",
              "    60% {\n",
              "      border-color: transparent;\n",
              "      border-right-color: var(--fill-color);\n",
              "    }\n",
              "    80% {\n",
              "      border-color: transparent;\n",
              "      border-right-color: var(--fill-color);\n",
              "      border-bottom-color: var(--fill-color);\n",
              "    }\n",
              "    90% {\n",
              "      border-color: transparent;\n",
              "      border-bottom-color: var(--fill-color);\n",
              "    }\n",
              "  }\n",
              "</style>\n",
              "\n",
              "      <script>\n",
              "        async function quickchart(key) {\n",
              "          const quickchartButtonEl =\n",
              "            document.querySelector('#' + key + ' button');\n",
              "          quickchartButtonEl.disabled = true;  // To prevent multiple clicks.\n",
              "          quickchartButtonEl.classList.add('colab-df-spinner');\n",
              "          try {\n",
              "            const charts = await google.colab.kernel.invokeFunction(\n",
              "                'suggestCharts', [key], {});\n",
              "          } catch (error) {\n",
              "            console.error('Error during call to suggestCharts:', error);\n",
              "          }\n",
              "          quickchartButtonEl.classList.remove('colab-df-spinner');\n",
              "          quickchartButtonEl.classList.add('colab-df-quickchart-complete');\n",
              "        }\n",
              "        (() => {\n",
              "          let quickchartButtonEl =\n",
              "            document.querySelector('#df-09ba841a-c92b-4cd1-9f4b-410f119d0a5e button');\n",
              "          quickchartButtonEl.style.display =\n",
              "            google.colab.kernel.accessAllowed ? 'block' : 'none';\n",
              "        })();\n",
              "      </script>\n",
              "    </div>\n",
              "\n",
              "    </div>\n",
              "  </div>\n"
            ],
            "application/vnd.google.colaboratory.intrinsic+json": {
              "type": "dataframe",
              "summary": "{\n  \"name\": \"metrics(mu_hat, mu_true, M_hat, M_true, theta_star, y_test, X_test)\",\n  \"rows\": 9,\n  \"fields\": [\n    {\n      \"column\": \"metric\",\n      \"properties\": {\n        \"dtype\": \"string\",\n        \"num_unique_values\": 9,\n        \"samples\": [\n          \"RFB on covariances \\u2016\\u03a3\\u0302 - \\u03a3_true\\u2016_F / \\u2016\\u03a3_true\\u2016_F\",\n          \"Relative spectral error  \\u2016\\u00b7\\u2016_2\",\n          \"Validation MSE (\\u03b8_fwd)\"\n        ],\n        \"semantic_type\": \"\",\n        \"description\": \"\"\n      }\n    },\n    {\n      \"column\": \"value\",\n      \"properties\": {\n        \"dtype\": \"number\",\n        \"std\": 0.3652895358833866,\n        \"min\": 2.826092482149628e-10,\n        \"max\": 0.9999999999999998,\n        \"num_unique_values\": 9,\n        \"samples\": [\n          9.371235224993417e-07,\n          3.038031608162033e-10,\n          0.5593740541201775\n        ],\n        \"semantic_type\": \"\",\n        \"description\": \"\"\n      }\n    }\n  ]\n}"
            }
          },
          "metadata": {}
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "def _sym(M): return 0.5*(M + M.T)\n",
        "def _psd_project(M):\n",
        "    w, U = np.linalg.eigh(_sym(M))\n",
        "    w = np.clip(w, 0, None)\n",
        "    return (U * w) @ U.T\n",
        "\n",
        "def constraint_map(theta_star, b,\n",
        "                   mu0=None, Lambda0=None,          # μ ~ N(mu0, Λ0^{-1})\n",
        "                   M0=None, lam=1e-3,               # ridge around M0: (λ/2)||M-M0||_F^2\n",
        "                   tau=0.1, max_iter=300, tol=1e-8):\n",
        "    \"\"\"\n",
        "    MAP for: r = M(θ* - μ) - b, with r ~ N(0, τ^2 I).\n",
        "    Unknowns: μ, M (PSD). Priors: μ ~ N(mu0, Λ0^{-1}), ridge prior on M.\n",
        "    \"\"\"\n",
        "    d = theta_star.size\n",
        "    I = np.eye(d)\n",
        "    if mu0 is None:    mu0 = np.zeros(d)\n",
        "    if Lambda0 is None: Lambda0 = 1e-3 * I\n",
        "    M = _psd_project(I if M0 is None else M0.copy())\n",
        "    mu = mu0.copy()\n",
        "\n",
        "    prev = np.inf\n",
        "    for _ in range(max_iter):\n",
        "        # μ | M\n",
        "        A = (M.T @ M) / (tau**2) + Lambda0\n",
        "        rhs = (M.T @ (M @ theta_star - b)) / (tau**2) + Lambda0 @ mu0\n",
        "        mu = np.linalg.solve(A, rhs)\n",
        "\n",
        "        # M | μ  (single-vector ridge; then PSD-projection)\n",
        "        v = theta_star - mu\n",
        "        # Add a small epsilon to the diagonal of G for numerical stability\n",
        "        epsilon = 1e-9  # or a similar small value\n",
        "        G = v[:, None] @ v[None, :] + (lam * tau**2) * I + epsilon * I\n",
        "        M_uncon = (b[:, None] @ v[None, :]) @ np.linalg.inv(G)\n",
        "        M = _psd_project(_sym(M_uncon))\n",
        "\n",
        "        # objective (for convergence)\n",
        "        resid = M @ (theta_star - mu) - b\n",
        "        obj = 0.5*np.linalg.norm(resid)**2/(tau**2) \\\n",
        "              + 0.5*(mu-mu0) @ (Lambda0 @ (mu-mu0)) \\\n",
        "              + 0.5*lam*np.linalg.norm(M - (np.zeros_like(M) if M0 is None else M0))**2\n",
        "        if abs(prev - obj) < tol: break\n",
        "        prev = obj\n",
        "\n",
        "    return M, mu, {\"resid_norm\": float(np.linalg.norm(M @ (theta_star - mu) - b)),\n",
        "                   \"min_eig_M\": float(np.linalg.eigvalsh(M).min())}\n",
        "\n",
        "# constraint_map(theta_star, b_true)"
      ],
      "metadata": {
        "id": "YXkz87veCSPg"
      },
      "execution_count": 18,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# pip install pymc arviz numpy\n",
        "import numpy as np\n",
        "import pymc as pm\n",
        "import arviz as az\n",
        "\n",
        "def constraint_mcmc_safe(\n",
        "    theta_star, b, *,\n",
        "    eta=2.0,          # LKJ concentration (larger -> closer to identity)\n",
        "    mu_scale=5.0,     # prior SD for μ\n",
        "    tau_prior=0.5,    # HalfNormal prior scale for τ\n",
        "    jitter=1e-6,      # keeps Σ SPD & invertible\n",
        "    # draws=2000, tune=1000, chains=4, seed=42,\n",
        "    draws=500, tune=100, chains=4, seed=42,\n",
        "):\n",
        "    \"\"\"\n",
        "    Single-constraint Bayesian model:\n",
        "        r = M(θ* - μ) - b,  with  r ~ N(0, τ^2 I),  M = Σ^{-1}.\n",
        "    No aesara/pytensor imports. Adds jitter & input validation to avoid NaNs/Infs.\n",
        "    \"\"\"\n",
        "    theta_star = np.asarray(theta_star, dtype=float).reshape(-1)\n",
        "    b          = np.asarray(b,          dtype=float).reshape(-1)\n",
        "    d = theta_star.size\n",
        "    assert b.size == d, \"theta_star and b must have the same length\"\n",
        "\n",
        "    # --- input validation ---\n",
        "    if not np.isfinite(theta_star).all():\n",
        "        raise ValueError(\"theta_star contains NaN/Inf.\")\n",
        "    if not np.isfinite(b).all():\n",
        "        raise ValueError(\"b contains NaN/Inf.\")\n",
        "\n",
        "    with pm.Model() as model:\n",
        "        # Prior on μ\n",
        "        mu = pm.Normal(\"mu\", 0.0, sigma=mu_scale, shape=d, initval=np.clip(theta_star - b, -3*mu_scale, 3*mu_scale))\n",
        "\n",
        "        # LKJ prior over covariance Σ (via Cholesky)\n",
        "        sd_dist  = pm.HalfNormal.dist(sigma=1.0, shape=d)\n",
        "        packed_L = pm.LKJCholeskyCov(\"packed_L\", n=d, eta=eta, sd_dist=sd_dist, compute_corr=False)\n",
        "        L        = pm.expand_packed_triangular(d, packed_L, lower=True)\n",
        "\n",
        "        # Σ = L Lᵀ + jitter·I  (keeps Σ strictly PD -> stable inverse)\n",
        "        Sigma = pm.Deterministic(\"Sigma\", L @ L.T + jitter * np.eye(d))\n",
        "\n",
        "        # Residual scale τ\n",
        "        tau = pm.HalfNormal(\"tau\", tau_prior)\n",
        "\n",
        "        # Residual r = Σ^{-1}(θ* - μ) - b\n",
        "        v         = theta_star - mu\n",
        "        Sigma_inv = pm.math.matrix_inverse(Sigma)\n",
        "        prec_v    = pm.math.dot(Sigma_inv, v)\n",
        "        resid     = prec_v - b\n",
        "\n",
        "        # Likelihood: resid ~ N(0, τ^2 I)  -> use zero-data trick\n",
        "        pm.Normal(\"stationarity\", mu=resid, sigma=tau, observed=np.zeros(d))\n",
        "\n",
        "        # Also keep M in the trace for convenience\n",
        "        pm.Deterministic(\"M\", Sigma_inv)\n",
        "\n",
        "        idata = pm.sample(\n",
        "            draws=draws, tune=tune, chains=chains,\n",
        "            init=\"adapt_diag\",            # more stable init\n",
        "            target_accept=0.95,           # fewer divergences\n",
        "            random_seed=seed,\n",
        "            progressbar=True,\n",
        "        )\n",
        "\n",
        "    # Posterior means\n",
        "    mu_post    = idata.posterior[\"mu\"].mean(dim=(\"chain\",\"draw\")).values\n",
        "    Sigma_post = idata.posterior[\"Sigma\"].mean(dim=(\"chain\",\"draw\")).values\n",
        "    M_post     = idata.posterior[\"M\"].mean(dim=(\"chain\",\"draw\")).values\n",
        "\n",
        "    return {\"mu\": mu_post, \"Sigma\": Sigma_post, \"M\": M_post, \"idata\": idata}\n",
        "\n",
        "# constraint_mcmc_safe(theta_star, b_true)"
      ],
      "metadata": {
        "id": "r1E04csDXVBA"
      },
      "execution_count": 19,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "def constraint_vi(\n",
        "    theta_star, b, *,\n",
        "    steps=10_000,           # ADVI steps\n",
        "    eta=2.0,                # LKJ concentration (larger -> closer to identity)\n",
        "    mu_scale=5.0,           # prior SD for μ\n",
        "    tau_prior=0.5,          # HalfNormal prior scale for τ\n",
        "    jitter=1e-6,            # ensures Σ is strictly PD\n",
        "    seed=123\n",
        "):\n",
        "    \"\"\"\n",
        "    Variational Inference for single-constraint model:\n",
        "        r = M (θ* - μ) - b,  with  r ~ N(0, τ^2 I),  M = Σ^{-1}.\n",
        "    No aesara/pytensor; guards against LKJ tuple return; stable with jitter.\n",
        "    Returns posterior means and the InferenceData.\n",
        "    \"\"\"\n",
        "    theta_star = np.asarray(theta_star, dtype=float).reshape(-1)\n",
        "    b          = np.asarray(b,          dtype=float).reshape(-1)\n",
        "    d = theta_star.size\n",
        "    if b.size != d:\n",
        "        raise ValueError(\"theta_star and b must have same length\")\n",
        "    if not (np.isfinite(theta_star).all() and np.isfinite(b).all()):\n",
        "        raise ValueError(\"theta_star or b contains NaN/Inf\")\n",
        "\n",
        "    with pm.Model() as model:\n",
        "        # μ prior\n",
        "        mu = pm.Normal(\"mu\", 0.0, sigma=mu_scale, shape=d,\n",
        "                       initval=np.clip(theta_star - b, -3*mu_scale, 3*mu_scale))\n",
        "\n",
        "        # Σ prior via LKJ-Cholesky; force non-tuple return\n",
        "        sd_dist  = pm.HalfNormal.dist(sigma=1.0, shape=d)\n",
        "        lkj_obj  = pm.LKJCholeskyCov(\"lkj_chol\", n=d, eta=eta,\n",
        "                                     sd_dist=sd_dist, compute_corr=False)\n",
        "        packed_L = lkj_obj[0] if isinstance(lkj_obj, tuple) else lkj_obj\n",
        "        L        = pm.expand_packed_triangular(d, packed_L, lower=True)\n",
        "\n",
        "        # Σ = L Lᵀ + jitter·I  (keeps Σ strictly PD)\n",
        "        Sigma = pm.Deterministic(\"Sigma\", L @ L.T + jitter * np.eye(d))\n",
        "        Sigma_inv = pm.Deterministic(\"M\", pm.math.matrix_inverse(Sigma))  # M = Σ^{-1}\n",
        "\n",
        "        # Residual noise\n",
        "        tau = pm.HalfNormal(\"tau\", tau_prior)\n",
        "\n",
        "        # Residual r = Σ^{-1}(θ* - μ) - b\n",
        "        v     = theta_star - mu                   # symbolic (d,)\n",
        "        precv = pm.math.dot(Sigma_inv, v)         # Σ^{-1} v\n",
        "        resid = precv - b                         # symbolic (d,)\n",
        "\n",
        "        # Likelihood via zero-data trick: 0 ~ N(resid, τ^2 I)\n",
        "        pm.Normal(\"stationarity\", mu=resid, sigma=tau, observed=np.zeros(d))\n",
        "\n",
        "        # Fit VI (ADVI)\n",
        "        approx = pm.fit(steps, method=\"advi\", obj_n_mc=5, random_seed=seed)\n",
        "        idata  = approx.sample(2000, random_seed=seed)\n",
        "\n",
        "    # Posterior means\n",
        "    mu_post    = idata.posterior[\"mu\"].mean(dim=(\"chain\",\"draw\")).values\n",
        "    Sigma_post = idata.posterior[\"Sigma\"].mean(dim=(\"chain\",\"draw\")).values\n",
        "    M_post     = idata.posterior[\"M\"].mean(dim=(\"chain\",\"draw\")).values\n",
        "\n",
        "    return {\"mu\": mu_post, \"Sigma\": Sigma_post, \"M\": M_post, \"idata\": idata}\n",
        "\n",
        "# constraint_vi(theta_star, b_true)"
      ],
      "metadata": {
        "collapsed": true,
        "id": "Nsk1Wv_wagOU"
      },
      "execution_count": 20,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# --- MAP ---\n",
        "M_map, mu_map, info_map = constraint_map(theta_star, b_star, lam=1e-3, tau=0.1)\n",
        "print(\"MAP  resid norm:\", info_map[\"resid_norm\"])\n",
        "metrics(mu_map, mu_true, M_map, M_true, theta_star, y_test, X_test)\n",
        "\n",
        "# --- VI ---\n",
        "vi_out = constraint_vi(theta_star, b_star, steps=30000)\n",
        "print(\"VI   resid norm:\",\n",
        "      np.linalg.norm(vi_out[\"M\"] @ (theta_star - vi_out[\"mu\"]) - b_star))\n",
        "metrics(vi_out[\"mu\"], mu_true, vi_out[\"M\"], M_true, theta_star, y_test, X_test)"
      ],
      "metadata": {
        "id": "6RJVYW6dbcMa"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# --- MCMC ---\n",
        "mcmc_out = constraint_mcmc_safe(theta_star, b_star, draws=1500, tune=800)\n",
        "print(\"MCMC resid mean:\",\n",
        "      np.linalg.norm(mcmc_out[\"M\"] @ (theta_star - mcmc_out[\"mu\"]) - b_star))\n",
        "metrics(mcmc_out[\"mu\"], mu_true, mcmc_out[\"M\"], M_true, theta_star, y_test, X_test)"
      ],
      "metadata": {
        "id": "EcgEUD5-yWmu"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [],
      "metadata": {
        "id": "Wvz1HLJloHR7"
      },
      "execution_count": null,
      "outputs": []
    }
  ]
}