{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "cells": [
    {
      "cell_type": "code",
      "execution_count": 1,
      "metadata": {
        "id": "qo2V4Q-R6-au"
      },
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "import pandas as pd\n",
        "from sklearn.datasets import load_diabetes\n",
        "from sklearn.preprocessing import StandardScaler\n",
        "from sklearn.linear_model import Ridge\n",
        "from sklearn.metrics import mean_squared_error\n",
        "import matplotlib.pyplot as plt\n",
        "\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.optim as optim\n",
        "from sklearn.model_selection import train_test_split\n",
        "import cvxpy as cp\n",
        "import os\n",
        "from IPython.display import display\n",
        "from torch.optim import LBFGS\n",
        "import copy\n",
        "import re\n",
        "import networkx as nx\n",
        "import random\n",
        "from sklearn.datasets import fetch_california_housing\n",
        "from google.colab import drive\n",
        "from scipy.stats import wishart\n",
        "\n",
        "# !pip install pymc arviz\n",
        "import pymc as pm\n",
        "import arviz as az\n",
        "import pytensor.tensor as at"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "dataset_list = ['load_diabetes', 'asia_bif', 'cali_housing', ]\n",
        "chosen_dataset = dataset_list[1]\n",
        "\n",
        "# Seed for the synthetic ASIA sampler below, so the generated data (and every\n",
        "# downstream number in this notebook) is identical across Restart & Run All.\n",
        "ASIA_SEED = 0\n",
        "\n",
        "if chosen_dataset == 'load_diabetes':\n",
        "    X, y = load_diabetes(return_X_y=True)\n",
        "\n",
        "elif chosen_dataset == 'asia_bif':\n",
        "\n",
        "    # Colab-only: the asia.bif file is read from Google Drive\n",
        "    drive.mount('/content/drive')\n",
        "    bif_path = \"/content/drive/My Drive/ICLR2026/asia.bif\"\n",
        "\n",
        "    # Parse the BIF file to extract the graph structure\n",
        "    with open(bif_path, 'r') as f:\n",
        "        bif_text = f.read()\n",
        "\n",
        "    variables = re.findall(r'variable\\s+(\\w+)', bif_text)\n",
        "    parents_matches = re.findall(r'probability\\s+\\(\\s*(\\w+)\\s*\\|\\s*([^)]+)\\)', bif_text)\n",
        "    root_matches = re.findall(r'probability\\s+\\(\\s*(\\w+)\\s*\\)', bif_text)\n",
        "\n",
        "    # Initialize graph\n",
        "    G = nx.DiGraph()\n",
        "    G.add_nodes_from(variables)\n",
        "\n",
        "    # Add edges from parent-child relationships\n",
        "    for child, parents in parents_matches:\n",
        "        parent_list = [p.strip() for p in parents.split(',')]\n",
        "        for parent in parent_list:\n",
        "            G.add_edge(parent, child)\n",
        "\n",
        "    # Add root nodes (no parents); normally already present via `variables`\n",
        "    child_nodes_with_parents = set(child for child, _ in parents_matches)\n",
        "    root_nodes = [var for var in root_matches if var not in child_nodes_with_parents]\n",
        "    G.add_nodes_from(root_nodes)\n",
        "\n",
        "    # Ensure it's a DAG and sort topologically\n",
        "    assert nx.is_directed_acyclic_graph(G)\n",
        "    topo_order = list(nx.topological_sort(G))\n",
        "\n",
        "    # Get adjacency matrix in topological order\n",
        "    A = nx.to_numpy_array(G, nodelist=topo_order)\n",
        "\n",
        "    # Out-degree matrix (kept for reference; the random-walk Laplacian below is what is used)\n",
        "    out_degree = np.sum(A, axis=1)\n",
        "    D_out = np.diag(out_degree)\n",
        "\n",
        "    # D_out^{-1}, with 0 instead of inf for sink nodes (out-degree 0)\n",
        "    with np.errstate(divide='ignore'):\n",
        "        D_out_inv = np.diag(1.0 / out_degree)\n",
        "        D_out_inv[np.isinf(D_out_inv)] = 0.0\n",
        "\n",
        "    # Random walk Laplacian\n",
        "    L = np.eye(A.shape[0]) - D_out_inv @ A\n",
        "    L_orig = copy.deepcopy(L)\n",
        "\n",
        "    # Prepare matrices for display\n",
        "    adjacency_df = pd.DataFrame(A, index=topo_order, columns=topo_order)\n",
        "    laplacian_df = pd.DataFrame(L, index=topo_order, columns=topo_order)\n",
        "\n",
        "    # Sorted eigenvalues of L, for insight (inspect e.g. eigenvalues.real.round(3))\n",
        "    eigenvalues = np.linalg.eigvals(L)\n",
        "    eigenvalues.sort()\n",
        "\n",
        "    # Sample in topological order so parents are generated before children\n",
        "    variables = topo_order\n",
        "\n",
        "    # Create a dictionary to hold generated data\n",
        "    data = {var: [] for var in variables}\n",
        "    num_samples = 1000\n",
        "\n",
        "    # Synthetic conditional logic per node, loosely inspired by the structure\n",
        "    # (NOT the CPDs in the .bif file). A dedicated seeded generator keeps the\n",
        "    # dataset reproducible without touching the global `random` state.\n",
        "    asia_rng = random.Random(ASIA_SEED)\n",
        "\n",
        "    for _ in range(num_samples):\n",
        "        sample = {}\n",
        "\n",
        "        # asia: visited Asia (Bernoulli)\n",
        "        sample[\"asia\"] = int(asia_rng.random() < 0.05)\n",
        "\n",
        "        # smoke: independent\n",
        "        sample[\"smoke\"] = int(asia_rng.random() < 0.3)\n",
        "\n",
        "        # tub: depends on asia\n",
        "        sample[\"tub\"] = int(sample[\"asia\"] and asia_rng.random() < 0.6)\n",
        "\n",
        "        # lung: depends on smoke\n",
        "        sample[\"lung\"] = int(sample[\"smoke\"] and asia_rng.random() < 0.5)\n",
        "\n",
        "        # bronc: depends on smoke\n",
        "        sample[\"bronc\"] = int(sample[\"smoke\"] and asia_rng.random() < 0.4)\n",
        "\n",
        "        # either: OR of tub and lung\n",
        "        sample[\"either\"] = int(sample[\"tub\"] or sample[\"lung\"])\n",
        "\n",
        "        # xray: positive w.p. 0.9 given `either`; otherwise (or on a miss) a 0.05 false-positive chance\n",
        "        sample[\"xray\"] = int((sample[\"either\"] and asia_rng.random() < 0.9) or asia_rng.random() < 0.05)\n",
        "\n",
        "        # dysp: depends on bronc and either\n",
        "        sample[\"dysp\"] = int((sample[\"bronc\"] and asia_rng.random() < 0.7) or\n",
        "                            (sample[\"either\"] and asia_rng.random() < 0.6))\n",
        "\n",
        "        for var in variables:\n",
        "            data[var].append(sample[var])\n",
        "\n",
        "    # Create a DataFrame\n",
        "    df = pd.DataFrame(data)\n",
        "\n",
        "    # Predict the `either` node from all other nodes\n",
        "    target_protein = \"either\"\n",
        "    y = df[target_protein].values\n",
        "    X = df.drop(columns=[target_protein]).values\n",
        "    feature_names = df.drop(columns=[target_protein]).columns.tolist()\n",
        "\n",
        "elif chosen_dataset == 'cali_housing':\n",
        "    data = fetch_california_housing(as_frame=False)\n",
        "    X, y = data.data, data.target"
      ],
      "metadata": {
        "id": "5PK-s-XF7cau",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "b70899f1-94d9-410d-de4c-9d98c7b41f4e"
      },
      "execution_count": 2,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Mounted at /content/drive\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Fixed 80/20 train/test split\n",
        "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=0)\n",
        "\n",
        "# Standardise features with training-split statistics only (no test leakage)\n",
        "scaler_X = StandardScaler()\n",
        "X_train = scaler_X.fit_transform(X_train)\n",
        "X_test = scaler_X.transform(X_test)\n",
        "\n",
        "# Standardise the target the same way; StandardScaler expects a 2-D column\n",
        "scaler_y = StandardScaler()\n",
        "y_train = scaler_y.fit_transform(y_train.reshape(-1, 1)).ravel()\n",
        "y_test = scaler_y.transform(y_test.reshape(-1, 1)).ravel()\n",
        "\n",
        "# n training rows, d features\n",
        "n, d = X_train.shape"
      ],
      "metadata": {
        "id": "ZpY5JLQ09d8d"
      },
      "execution_count": 3,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# -------------------------\n",
        "# 2) Fixed random Tikhonov components\n",
        "# -------------------------\n",
        "rng = np.random.default_rng(12345)\n",
        "L = rng.normal(size=(d, d))          # square for simplicity; any (p x d) works\n",
        "lam = 5e-2                           # Tikhonov strength\n",
        "nu0 = d + 2                          # Wishart dof; scipy requires df > d - 1\n",
        "kappa0 = 1.0                         # strength on mean prior\n",
        "m0 = np.zeros(d)                     # prior mean\n",
        "\n",
        "# Wishart scale chosen so E[Λ] = ν0·Σ0 = lam * (L^T L)\n",
        "Sigma0 = (lam / nu0) * (L.T @ L)\n",
        "\n",
        "# -------------------------\n",
        "# Sample precision (Tikhonov) from Wishart\n",
        "# -------------------------\n",
        "Lambda = wishart.rvs(df=nu0, scale=Sigma0, random_state=rng)  # Λ ~ W(ν0, Σ0)\n",
        "M_true = Lambda                                               # Tikhonov matrix (prior precision)\n",
        "\n",
        "# -------------------------\n",
        "# Sample μ | Λ ~ N(m0, (κ0 Λ)^(-1))\n",
        "# -------------------------\n",
        "# Draw using a Cholesky of (κ0 Λ)^(-1)\n",
        "chol_cov = np.linalg.cholesky(np.linalg.inv(kappa0 * Lambda))\n",
        "mu_true = m0 + chol_cov @ rng.normal(size=d)\n",
        "\n",
        "# Torch tensors\n",
        "M_true_t = torch.tensor(M_true, dtype=torch.float32)\n",
        "mu_true_t = torch.tensor(mu_true, dtype=torch.float32)\n",
        "\n",
        "# -------------------------\n",
        "# 3.a) Forward Tikhonov (closed form, sum-of-squares)\n",
        "#     min_θ ‖Xθ - y‖² + (θ - μ)^T M_true (θ - μ)\n",
        "#     ⇔ (X^T X + M_true) θ = X^T y + M_true μ\n",
        "# -------------------------\n",
        "XtX = X_train.T @ X_train\n",
        "Xty = X_train.T @ y_train\n",
        "A = XtX + M_true\n",
        "rhs = Xty + M_true @ mu_true\n",
        "A += 1e-10 * np.eye(d)  # small jitter\n",
        "\n",
        "theta_fwd = np.linalg.solve(A, rhs)\n",
        "\n",
        "# Residual b = - X^T (X θ - y)  [sum of squares convention]\n",
        "# KKT stationarity: X^T(Xθ - y) + M_true(θ - μ_true) = 0\n",
        "b_fwd = - X_train.T @ (X_train @ theta_fwd - y_train)\n",
        "\n",
        "# Verify KKT using the ground-truth M_true\n",
        "kkt_residual = np.linalg.norm(X_train.T @ (X_train @ theta_fwd - y_train) + M_true @ (theta_fwd - mu_true))\n",
        "print(f\"KKT residual with ground-truth M_true: {kkt_residual:.3e}\")\n",
        "\n",
        "# -------------------------\n",
        "# 3.b) Forward Tikhonov (ML model solution): gradient descent on the SAME\n",
        "#      objective as 3.a, so θ_ml should converge to θ_fwd.\n",
        "# -------------------------\n",
        "X_train_t = torch.from_numpy(X_train).float()\n",
        "y_train_t = torch.from_numpy(y_train).float().unsqueeze(1)\n",
        "y_train_vec = y_train_t.squeeze(1)\n",
        "\n",
        "# 4) Single-layer linear model (no bias, matching the closed form)\n",
        "torch.manual_seed(0)\n",
        "model = nn.Linear(d, 1, bias=False)\n",
        "\n",
        "# 5) Optimiser. The objective is quadratic with Hessian H = 2 (X^T X + M_true),\n",
        "#    and plain GD only converges for lr < 2 / λ_max(H). With standardised\n",
        "#    columns λ_max(X^T X) >= n, so a fixed lr = 1e-2 diverges once n >= 100;\n",
        "#    use lr = 1 / λ_max(H) instead. No weight decay: the prior term in the\n",
        "#    loss already is the full Tikhonov penalty.\n",
        "ridge_lambda = 0.0\n",
        "lr = float(1.0 / (2.0 * np.linalg.eigvalsh(XtX + M_true).max()))\n",
        "optimizer = optim.SGD(model.parameters(), lr=lr, weight_decay=ridge_lambda)\n",
        "loss_fn   = nn.MSELoss()\n",
        "\n",
        "for epoch in range(1000):\n",
        "    optimizer.zero_grad()\n",
        "    preds = model(X_train_t).squeeze(1)\n",
        "    w = model.weight.view(-1)\n",
        "    data_ss = torch.sum((preds - y_train_vec) ** 2)\n",
        "    diff = w - mu_true_t\n",
        "    loss = data_ss + diff @ (M_true_t @ diff)   # (θ - μ)^T M_true (θ - μ)\n",
        "    loss.backward()\n",
        "    optimizer.step()\n",
        "\n",
        "# 6) Extract θ* from the trained model and compare with the closed form\n",
        "theta_ml = model.weight.detach().numpy().reshape(-1)  # shape (d,)\n",
        "print(f\"‖θ_ml - θ_fwd‖₂ = {np.linalg.norm(theta_ml - theta_fwd):.3e}\")\n",
        "\n",
        "# 7) Compute b = -½ ∇‖Xθ - y‖² at θ*\n",
        "#    ∇f = 2 Xᵀ(Xθ* - y)  ⇒  b = -0.5 * ∇f = -Xᵀ(Xθ* - y)\n",
        "grad = 2 * X_train.T @ (X_train @ theta_ml - y_train)\n",
        "b_ml    = -0.5 * grad\n",
        "\n",
        "# Which (θ*, b) pair the inverse-problem baselines consume:\n",
        "# \"fwd\" = closed-form θ_fwd, \"ml\" = gradient-trained θ_ml\n",
        "mode = \"fwd\"\n",
        "if mode == \"ml\":\n",
        "    theta_star = theta_ml\n",
        "    b_star = b_ml\n",
        "else:\n",
        "    theta_star = theta_fwd\n",
        "    b_star = b_fwd\n",
        "\n",
        "b_true = M_true @ (theta_star - mu_true) + rng.normal(0, 0.5, size=d)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "AmEJ9bF9AK_t",
        "outputId": "9fdc710f-03f2-4b73-8151-26501d9553cc"
      },
      "execution_count": 4,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "KKT residual with ground-truth M_true: 8.920e-11\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# -------------------------\n",
        "# 5) Metrics\n",
        "# -------------------------\n",
        "def metrics(mu_hat, mu_true, M_hat, M_true):\n",
        "\n",
        "    def frob(A): return np.linalg.norm(A, 'fro')\n",
        "\n",
        "    err = M_hat - M_true\n",
        "    rel_frob = np.linalg.norm(err, 'fro') / (np.linalg.norm(M_true, 'fro') + 1e-12)\n",
        "    spec_err = np.linalg.norm(err, 2) / (np.linalg.norm(M_true, 2) + 1e-12)\n",
        "    corr = np.corrcoef(M_true.ravel(), M_hat.ravel())[0, 1]\n",
        "\n",
        "    M_sym = 0.5 * (M_hat + M_hat.T)\n",
        "    w, U = np.linalg.eigh(M_sym)\n",
        "    w_clipped = np.clip(w, 0, None)\n",
        "    M_psd = U @ np.diag(w_clipped) @ U.T\n",
        "    M_hat = copy.deepcopy(M_psd)\n",
        "\n",
        "    # PSD checks\n",
        "    mineig_true = np.linalg.eigvalsh(M_true).min()\n",
        "    mineig_hat = np.linalg.eigvalsh((M_hat + M_hat.T)/2).min()  # symmetrize for eigs\n",
        "\n",
        "    # J = d\n",
        "    # mu_list = []\n",
        "    # theta_list = []\n",
        "    # b_list = []\n",
        "\n",
        "    # # Constraint residuals\n",
        "    # viol = max(np.linalg.norm(M_psd @ (theta_list[j] - mu_list[j]) - b_list[j]) for j in range(J))\n",
        "\n",
        "    def _as_array_of_mus(mu):\n",
        "        # Accepts: np.ndarray (d,), list/tuple of (d,) arrays, or (J,d) array\n",
        "        if isinstance(mu, (list, tuple)):\n",
        "            return np.stack(mu, axis=0)  # (J,d)\n",
        "        mu = np.asarray(mu)\n",
        "        return mu if mu.ndim == 2 else mu[None, :]  # (1,d) if single\n",
        "\n",
        "    def rmse(a, b):\n",
        "        a = np.asarray(a); b = np.asarray(b)\n",
        "        return float(np.sqrt(np.mean((a - b) ** 2)))\n",
        "\n",
        "    def symmetrize(M):\n",
        "        return 0.5 * (M + M.T)\n",
        "\n",
        "    def safe_pinv_psd(M, eps=1e-9):\n",
        "        # Symmetrize + jitter, then Moore–Penrose\n",
        "        Ms = symmetrize(M)\n",
        "        d = Ms.shape[0]\n",
        "        return np.linalg.pinv(Ms + eps * np.eye(d))\n",
        "\n",
        "    # ---- Add these safe helpers near metrics() ----\n",
        "    def residual_norm(theta_star, mu_hat, M_hat, b):\n",
        "        return float(np.linalg.norm(sym(M_hat) @ (theta_star - mu_hat) - b))\n",
        "\n",
        "    def cov_metrics_from_precision(M_hat, M_true):\n",
        "        Sig_hat = safe_pinv_psd(M_hat)\n",
        "        Sig_true = safe_pinv_psd(M_true)\n",
        "        rfb_cov = np.linalg.norm(Sig_hat - Sig_true, 'fro') / (np.linalg.norm(Sig_true, 'fro') + 1e-12)\n",
        "        spec_cov_err = np.linalg.norm(Sig_hat - Sig_true, 2) / (np.linalg.norm(Sig_true, 2) + 1e-12)\n",
        "        return rfb_cov, spec_cov_err\n",
        "\n",
        "    # --- RMSE(mu) ---\n",
        "    # Expect variables: mu_hat (estimated) and mu_true (ground truth).\n",
        "    # If you instead have multiple priors {mu_j}, pass lists for both.\n",
        "    rmse_mu = None\n",
        "    try:\n",
        "        mu_hat_arr  = _as_array_of_mus(mu_hat)   # shape (J,d) or (1,d)\n",
        "        mu_true_arr = _as_array_of_mus(mu_true)  # shape (J,d) or (1,d)\n",
        "        # If counts differ but one is (1,d), broadcast it across J:\n",
        "        if mu_hat_arr.shape[0] != mu_true_arr.shape[0]:\n",
        "            if mu_hat_arr.shape[0] == 1:\n",
        "                mu_hat_arr  = np.repeat(mu_hat_arr,  mu_true_arr.shape[0], axis=0)\n",
        "            elif mu_true_arr.shape[0] == 1:\n",
        "                mu_true_arr = np.repeat(mu_true_arr, mu_hat_arr.shape[0],  axis=0)\n",
        "        rmse_mu = rmse(mu_hat_arr, mu_true_arr)\n",
        "    except NameError:\n",
        "        # mu_hat or mu_true not defined; skip metric gracefully\n",
        "        rmse_mu = np.nan\n",
        "\n",
        "    # --- RFB for covariances (relative Frobenius between Σ_hat and Σ_true) ---\n",
        "    Sigma_true = safe_pinv_psd(M_true)\n",
        "    Sigma_hat  = safe_pinv_psd(M_hat)\n",
        "    rfb_cov = np.linalg.norm(Sigma_hat - Sigma_true, 'fro') / (np.linalg.norm(Sigma_true, 'fro') + 1e-12)\n",
        "\n",
        "    # Optionally: spectral relative error on covariances\n",
        "    spec_cov_err = np.linalg.norm(Sigma_hat - Sigma_true, 2) / (np.linalg.norm(Sigma_true, 2) + 1e-12)\n",
        "\n",
        "    # Residual metric (if theta_star and b are available in scope)\n",
        "    try:\n",
        "        resid = residual_norm(theta_star, mu_hat, M_hat, b)\n",
        "    except NameError:\n",
        "        resid = np.nan  # not available in this context\n",
        "\n",
        "    # Covariance metrics (derived from precision)\n",
        "    rfb_cov, spec_cov_err = cov_metrics_from_precision(M_hat, M_true)\n",
        "\n",
        "        # if mode == \"ml\":\n",
        "        #     print(\"=== Multi-constraint inverse (using ML-derived θ*’s) ===\")\n",
        "        # else:\n",
        "        #     print(\"=== Multi-constraint inverse (using closed-form solution) ===\")\n",
        "\n",
        "    metrics_df = pd.DataFrame({\n",
        "        \"metric\": [\n",
        "            \"Relative Frobenius error ‖M_hat - M_true‖_F / ‖M_true‖_F\",\n",
        "            # \"Relative spectral error  ‖·‖_2\",\n",
        "            # \"Entrywise correlation corr(M_true, M_hat)\",\n",
        "            # \"Min eigenvalue M_true\",\n",
        "            # \"Min eigenvalue sym(M_hat)\",\n",
        "            # \"Validation MSE (θ_fwd)\",\n",
        "            ###### \"Max constraint residual\",\n",
        "            \"RMSE (μ̂ vs μ_true)\",\n",
        "            \"RFB on covariances ‖Σ̂ - Σ_true‖_F / ‖Σ_true‖_F\",\n",
        "            # \"Relative spectral error on covariances ‖·‖_2\",\n",
        "            \"Constraint residual ‖M̂(θ*−μ̂)−b‖₂\"\n",
        "        ],\n",
        "        \"value\": [\n",
        "            rel_frob,\n",
        "            # spec_err,\n",
        "            # corr,\n",
        "            # mineig_true,\n",
        "            # mineig_hat,\n",
        "            # float(np.mean((X_test @ theta_star - y_test)**2)),\n",
        "            ## viol,\n",
        "            rmse_mu,\n",
        "            rfb_cov,\n",
        "            # spec_cov_err,\n",
        "            resid\n",
        "        ]\n",
        "    })\n",
        "\n",
        "    # display(metrics_df)\n",
        "    return metrics_df"
      ],
      "metadata": {
        "id": "fiKkwYaveDM8"
      },
      "execution_count": 6,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "import numpy as np\n",
        "\n",
        "def sym(M):\n",
        "    return 0.5 * (M + M.T)\n",
        "\n",
        "def psd_project(M):\n",
        "    Ms = sym(M)\n",
        "    w, U = np.linalg.eigh(Ms)\n",
        "    w = np.clip(w, 0, None)\n",
        "    return (U * w) @ U.T\n",
        "\n",
        "def safe_pinv_psd(M, eps=1e-9):\n",
        "    return np.linalg.pinv(sym(M) + eps * np.eye(M.shape[0]))\n",
        "\n",
        "def ls_M_given_mu(theta_star, b, mu):\n",
        "    v = theta_star - mu\n",
        "    denom = float(v @ v) + 1e-12\n",
        "    M = np.outer(b, v) / denom                # rank-1\n",
        "    return M\n",
        "\n",
        "def ls_psd(theta_star, b):\n",
        "    \"\"\"Two-pass rank-1 least-squares estimate of (μ, M) with PSD projection.\n",
        "\n",
        "    1. Fit M with μ = 0 (simple heuristic) and project it to PSD.\n",
        "    2. Update μ = θ* − (M + 1e-9·I)^+ b.\n",
        "    3. Re-fit M with the updated μ and project again.\n",
        "    Returns (mu, M).\n",
        "    \"\"\"\n",
        "    dim = len(b)\n",
        "    M_first = psd_project(ls_M_given_mu(theta_star, b, np.zeros_like(b)))\n",
        "    mu = theta_star - np.linalg.pinv(M_first + 1e-9 * np.eye(dim)) @ b\n",
        "    M_final = psd_project(ls_M_given_mu(theta_star, b, mu))\n",
        "    return mu, M_final\n"
      ],
      "metadata": {
        "id": "KUwrTaPjQ-3O"
      },
      "execution_count": 7,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "def baseline_identity_ridge(theta_star, b, alphas):\n",
        "    \"\"\"Isotropic baseline: M = a·I and μ = θ* − b / a for each a in `alphas`.\n",
        "\n",
        "    Keeps the a with the smallest residual ‖M(θ* − μ) − b‖ (earliest a on ties).\n",
        "    Because M(θ* − μ) = b by construction, every residual is ~0 and the choice is\n",
        "    effectively decided by rounding. Returns (mu_hat, M_hat, info).\n",
        "    \"\"\"\n",
        "    dim = len(b)\n",
        "    best = None\n",
        "    for alpha in alphas:  # e.g., np.logspace(-3, 3, 50)\n",
        "        mu_cand = theta_star - b / alpha\n",
        "        M_cand = psd_project(alpha * np.eye(dim))  # already PSD\n",
        "        res = np.linalg.norm(M_cand @ (theta_star - mu_cand) - b)\n",
        "        if best is None or res < best[0]:\n",
        "            best = (res, mu_cand, M_cand, {\"alpha\": alpha, \"residual\": res})\n",
        "    return best[1], best[2], best[3]\n",
        "\n",
        "def baseline_diagonal_ridge(theta_star, b, w_grid):\n",
        "    \"\"\"Diagonal baseline: for each coordinate i independently, pick w_i from w_grid\n",
        "    with μ_i = θ*_i − b_i / w_i, minimising |w_i (θ*_i − μ_i) − b_i| (first on ties).\n",
        "\n",
        "    Returns (mu_hat, M_hat = psd_project(diag(w)), {\"residual\": ‖M_hat(θ* − μ̂) − b‖}).\n",
        "    \"\"\"\n",
        "    dim = len(b)\n",
        "    weights = np.zeros(dim)\n",
        "    mu_hat = np.zeros(dim)\n",
        "    for i in range(dim):\n",
        "        best_i = None\n",
        "        for w in w_grid:  # e.g., np.logspace(-3, 3, 60)\n",
        "            mu_i = theta_star[i] - b[i] / w\n",
        "            res_i = abs(w * (theta_star[i] - mu_i) - b[i])  # 0 ideally\n",
        "            if best_i is None or res_i < best_i[0]:\n",
        "                best_i = (res_i, mu_i, w)\n",
        "        mu_hat[i], weights[i] = best_i[1], best_i[2]\n",
        "    M_hat = psd_project(np.diag(weights))\n",
        "    residual = np.linalg.norm(M_hat @ (theta_star - mu_hat) - b)\n",
        "    return mu_hat, M_hat, {\"residual\": residual}\n",
        "\n",
        "def baseline_ls_psd(theta_star, b, mu_init=None):\n",
        "    \"\"\"Rank-1 least-squares baseline with one refinement pass.\n",
        "\n",
        "    M = psd_project(b vᵀ / ‖v‖²) with v = θ* − μ_init (μ_init defaults to 0),\n",
        "    then μ = θ* − M⁺ b, then M is re-fitted with the new μ.\n",
        "    Returns (mu, M, {\"residual\": ‖M(θ* − μ) − b‖}).\n",
        "    \"\"\"\n",
        "    def rank1_fit(v):\n",
        "        return psd_project(np.outer(b, v) / (float(v @ v) + 1e-12))\n",
        "\n",
        "    start_mu = np.zeros_like(b) if mu_init is None else mu_init\n",
        "    M_hat = rank1_fit(theta_star - start_mu)\n",
        "    mu_hat = theta_star - safe_pinv_psd(M_hat) @ b\n",
        "    M_hat = rank1_fit(theta_star - mu_hat)\n",
        "    residual = np.linalg.norm(M_hat @ (theta_star - mu_hat) - b)\n",
        "    return mu_hat, M_hat, {\"residual\": residual}\n",
        "\n",
        "def baseline_ridge_ls_psd(theta_star, b, lam=1e-3, mu=None):\n",
        "    \"\"\"Ridge-regularised rank-1 least-squares baseline.\n",
        "\n",
        "    Solves min_M ‖M v − b‖² + lam‖M‖_F² with v = θ* − mu (mu defaults to 0),\n",
        "    whose solution is M = b vᵀ (v vᵀ + lam·I)^{-1}. By Sherman–Morrison this\n",
        "    equals b vᵀ / (lam + vᵀv), computed directly instead of inverting the\n",
        "    (ill-conditioned for small lam) d×d matrix. M is then PSD-projected and\n",
        "    mu is re-estimated as θ* − M⁺ b.\n",
        "\n",
        "    Returns (mu_hat, M, {\"lambda\": lam, \"residual\": ‖M(θ* − mu_hat) − b‖}).\n",
        "    \"\"\"\n",
        "    v = theta_star - (np.zeros_like(b) if mu is None else mu)\n",
        "    denom = lam + float(v @ v)\n",
        "    # denom == 0 only for v == 0 and lam == 0: every M fits, take the minimum-norm M = 0\n",
        "    M = np.outer(b, v) / denom if denom != 0 else np.zeros((len(b), len(v)))\n",
        "    M = psd_project(M)\n",
        "    Mp = safe_pinv_psd(M)\n",
        "    mu_hat = theta_star - Mp @ b\n",
        "    res = np.linalg.norm(M @ (theta_star - mu_hat) - b)\n",
        "    return mu_hat, M, {\"lambda\": lam, \"residual\": res}\n",
        "\n",
        "def baseline_zero_mean_ridge(theta_star, b, lam=1e-3):\n",
        "    \"\"\"Ridge rank-1 baseline with the prior mean fixed at zero.\n",
        "\n",
        "    M = b θ*ᵀ (θ* θ*ᵀ + lam·I)^{-1} = b θ*ᵀ / (lam + θ*ᵀθ*) (Sherman–Morrison,\n",
        "    avoiding an explicit d×d inverse), then PSD-projected; mu_hat = 0.\n",
        "\n",
        "    Returns (mu_hat, M, {\"lambda\": lam, \"residual\": ‖M θ* − b‖}).\n",
        "    \"\"\"\n",
        "    v = theta_star\n",
        "    denom = lam + float(v @ v)\n",
        "    # denom == 0 only for θ* == 0 and lam == 0: take the minimum-norm M = 0\n",
        "    M = np.outer(b, v) / denom if denom != 0 else np.zeros((len(b), len(v)))\n",
        "    M = psd_project(M)\n",
        "    mu_hat = np.zeros_like(b)\n",
        "    res = np.linalg.norm(M @ (theta_star - mu_hat) - b)\n",
        "    return mu_hat, M, {\"lambda\": lam, \"residual\": res}\n",
        "\n",
        "def baseline_laplacian(theta_star, b, W, alphas, eps=1e-3):\n",
        "    \"\"\"Graph-Laplacian baseline: M = psd_project(a·(D − W) + eps·I), D = diag(row sums of W).\n",
        "\n",
        "    For each a in `alphas`, μ = θ* − M⁺ b; keeps the a with the smallest residual\n",
        "    ‖M(θ* − μ) − b‖ (earliest on ties). Returns (mu, M, info).\n",
        "    \"\"\"\n",
        "    dim = len(b)\n",
        "    lap = np.diag(W.sum(axis=1)) - W\n",
        "    best = None\n",
        "    for alpha in alphas:\n",
        "        M_cand = psd_project(alpha * lap + eps * np.eye(dim))\n",
        "        mu_cand = theta_star - safe_pinv_psd(M_cand) @ b\n",
        "        res = np.linalg.norm(M_cand @ (theta_star - mu_cand) - b)\n",
        "        if best is None or res < best[0]:\n",
        "            best = (res, mu_cand, M_cand, {\"alpha\": alpha, \"eps\": eps, \"residual\": res})\n",
        "    return best[1], best[2], best[3]\n",
        "\n",
        "def random_spd_baseline(theta_star, b, d, k=2, ntrials=50):\n",
        "    best = None\n",
        "    for _ in range(ntrials):\n",
        "        A = np.random.normal(size=(d, k))\n",
        "        M = A @ A.T + 0.1*np.eye(d)\n",
        "        mu = theta_star - np.linalg.pinv(M) @ b\n",
        "        res = np.linalg.norm(M @ (theta_star - mu) - b)\n",
        "        info = {\"residual\": res, \"k\": k}\n",
        "        cand = (res, mu, M, info)\n",
        "        if best is None or res < best[0]:\n",
        "            best = cand\n",
        "    return best[1], best[2], best[3]\n",
        "\n",
        "\n",
        "def shrinkage_precision(Theta, lambdas):\n",
        "    # Theta: (J, d)\n",
        "    mu = Theta.mean(axis=0)\n",
        "    S  = np.cov(Theta.T, bias=False)  # sample covariance\n",
        "    best = None\n",
        "    for lam in lambdas:\n",
        "        Sig = (1-lam)*S + lam*np.eye(S.shape[0])\n",
        "        M = np.linalg.pinv(Sig)\n",
        "        score = np.linalg.cond(Sig)  # placeholder; or validation loss\n",
        "        best = min(best, (score, mu, M), key=lambda x: x[0]) if best else (score, mu, M)\n",
        "    return best[1], best[2], best[3]"
      ],
      "metadata": {
        "id": "qRU3pMiT-MLP"
      },
      "execution_count": 8,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "def fit_baseline(method, theta_star, b, **kwargs):\n",
        "    \"\"\"Fit one inverse-problem baseline to the constraint M (θ* − μ) = b.\n",
        "\n",
        "    Parameters\n",
        "    ----------\n",
        "    method : str (case-insensitive)\n",
        "        'identity_ridge' (kwargs: alphas) | 'diagonal_ridge' (w_grid) |\n",
        "        'ls_psd' (mu_init) | 'ridge_ls_psd' (lam=1e-3, mu) |\n",
        "        'zero_mean_ridge' (lam=1e-3) | 'random_spd' (d, k=2, ntrials=50).\n",
        "        baseline_laplacian / shrinkage_precision are not dispatched here\n",
        "        (they need a weight matrix W / a sample of θ's).\n",
        "    theta_star, b : (d,) arrays\n",
        "        Forward solution and constraint vector passed to the baseline.\n",
        "    **kwargs\n",
        "        Method-specific options listed above.\n",
        "\n",
        "    Returns\n",
        "    -------\n",
        "    mu_hat, M_hat, info  (M_hat is symmetrized & PSD-projected)\n",
        "\n",
        "    Raises\n",
        "    ------\n",
        "    ValueError for an unknown method.\n",
        "    \"\"\"\n",
        "    method = method.lower()\n",
        "    # Use the `b` argument (not a notebook global) so callers control the constraint\n",
        "    if method == \"identity_ridge\":\n",
        "        mu, M, info = baseline_identity_ridge(theta_star, b, kwargs.get(\"alphas\"))\n",
        "    elif method == \"diagonal_ridge\":\n",
        "        mu, M, info = baseline_diagonal_ridge(theta_star, b, kwargs.get(\"w_grid\"))\n",
        "    elif method == \"ls_psd\":\n",
        "        mu, M, info = baseline_ls_psd(theta_star, b, kwargs.get(\"mu_init\"))\n",
        "    elif method == \"ridge_ls_psd\":\n",
        "        mu, M, info = baseline_ridge_ls_psd(theta_star, b, kwargs.get(\"lam\", 1e-3), kwargs.get(\"mu\"))\n",
        "    elif method == \"zero_mean_ridge\":\n",
        "        mu, M, info = baseline_zero_mean_ridge(theta_star, b, kwargs.get(\"lam\", 1e-3))\n",
        "    elif method == \"random_spd\":\n",
        "        mu, M, info = random_spd_baseline(theta_star, b, kwargs.get(\"d\"), kwargs.get(\"k\", 2), kwargs.get(\"ntrials\", 50))\n",
        "    else:\n",
        "        raise ValueError(f\"Unknown baseline method: {method}\")\n",
        "\n",
        "    # Final safety: symmetrize & PSD-project once more\n",
        "    M = psd_project(M)\n",
        "    return mu, M, info"
      ],
      "metadata": {
        "id": "3vtYEmWN-MQi"
      },
      "execution_count": 9,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "import time\n",
        "\n",
        "# Baselines to evaluate: (method name, keyword arguments for fit_baseline)\n",
        "baseline_methods = [\n",
        "    (\"identity_ridge\", {\"alphas\": np.logspace(-2, 3, 20)}),\n",
        "    (\"diagonal_ridge\", {\"w_grid\": np.logspace(-2, 3, 20)}),\n",
        "    (\"ls_psd\", {}),  # can pass mu_init if desired\n",
        "    (\"ridge_ls_psd\", {\"lam\": 1e-3}),\n",
        "    (\"zero_mean_ridge\", {\"lam\": 1e-3}),\n",
        "    # The Laplacian baseline additionally needs a weight matrix W:\n",
        "    # (\"laplacian\", {\"W\": W, \"alphas\": np.logspace(-2, 3, 10), \"eps\": 1e-3}),\n",
        "    (\"random_spd\", {\"d\": d, \"k\": 2, \"ntrials\": 50}),\n",
        "    # (\"shrinkage_precision\", {\"lambdas\": np.logspace(-2, 3, 10)}),\n",
        "]\n",
        "\n",
        "all_results = []\n",
        "avg_duration = 0  # NOTE: despite the name, this accumulates the TOTAL fit time\n",
        "\n",
        "for method, kwargs in baseline_methods:\n",
        "    # Time only the fitting step, not the metric computation\n",
        "    start = time.perf_counter()\n",
        "    mu_hat, M_hat, info = fit_baseline(method, theta_star, b_star, **kwargs)\n",
        "    end = time.perf_counter()\n",
        "    avg_duration += end - start\n",
        "\n",
        "    # Long-format (metric, value) table for this baseline, tagged with its name\n",
        "    metrics_df = metrics(mu_hat, mu_true, M_hat, M_true)\n",
        "    metrics_df.insert(0, \"baseline\", method)\n",
        "    all_results.append(metrics_df)\n",
        "\n",
        "# Stack the per-baseline tables into one long DataFrame\n",
        "all_metrics_df = pd.concat(all_results, ignore_index=True)\n",
        "display(all_metrics_df)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 802
        },
        "id": "-CoCDpg5N9cS",
        "outputId": "57264b6b-f03e-40f4-8869-d142dc695d03"
      },
      "execution_count": 10,
      "outputs": [
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "           baseline                                             metric  \\\n",
              "0    identity_ridge  Relative Frobenius error ‖M_hat - M_true‖_F / ...   \n",
              "1    identity_ridge                                RMSE (μ̂ vs μ_true)   \n",
              "2    identity_ridge    RFB on covariances ‖Σ̂ - Σ_true‖_F / ‖Σ_true‖_F   \n",
              "3    identity_ridge                 Constraint residual ‖M̂(θ*−μ̂)−b‖₂   \n",
              "4    diagonal_ridge  Relative Frobenius error ‖M_hat - M_true‖_F / ...   \n",
              "5    diagonal_ridge                                RMSE (μ̂ vs μ_true)   \n",
              "6    diagonal_ridge    RFB on covariances ‖Σ̂ - Σ_true‖_F / ‖Σ_true‖_F   \n",
              "7    diagonal_ridge                 Constraint residual ‖M̂(θ*−μ̂)−b‖₂   \n",
              "8            ls_psd  Relative Frobenius error ‖M_hat - M_true‖_F / ...   \n",
              "9            ls_psd                                RMSE (μ̂ vs μ_true)   \n",
              "10           ls_psd    RFB on covariances ‖Σ̂ - Σ_true‖_F / ‖Σ_true‖_F   \n",
              "11           ls_psd                 Constraint residual ‖M̂(θ*−μ̂)−b‖₂   \n",
              "12     ridge_ls_psd  Relative Frobenius error ‖M_hat - M_true‖_F / ...   \n",
              "13     ridge_ls_psd                                RMSE (μ̂ vs μ_true)   \n",
              "14     ridge_ls_psd    RFB on covariances ‖Σ̂ - Σ_true‖_F / ‖Σ_true‖_F   \n",
              "15     ridge_ls_psd                 Constraint residual ‖M̂(θ*−μ̂)−b‖₂   \n",
              "16  zero_mean_ridge  Relative Frobenius error ‖M_hat - M_true‖_F / ...   \n",
              "17  zero_mean_ridge                                RMSE (μ̂ vs μ_true)   \n",
              "18  zero_mean_ridge    RFB on covariances ‖Σ̂ - Σ_true‖_F / ‖Σ_true‖_F   \n",
              "19  zero_mean_ridge                 Constraint residual ‖M̂(θ*−μ̂)−b‖₂   \n",
              "20       random_spd  Relative Frobenius error ‖M_hat - M_true‖_F / ...   \n",
              "21       random_spd                                RMSE (μ̂ vs μ_true)   \n",
              "22       random_spd    RFB on covariances ‖Σ̂ - Σ_true‖_F / ‖Σ_true‖_F   \n",
              "23       random_spd                 Constraint residual ‖M̂(θ*−μ̂)−b‖₂   \n",
              "\n",
              "           value  \n",
              "0   9.876823e-01  \n",
              "1   1.398009e+03  \n",
              "2   9.999975e-01  \n",
              "3            NaN  \n",
              "4   9.876823e-01  \n",
              "5   1.398009e+03  \n",
              "6   9.999975e-01  \n",
              "7            NaN  \n",
              "8   1.000000e+00  \n",
              "9   2.015820e+08  \n",
              "10  6.164866e+01  \n",
              "11           NaN  \n",
              "12  8.166929e-01  \n",
              "13  2.015820e+08  \n",
              "14  6.099366e+01  \n",
              "15           NaN  \n",
              "16  8.166929e-01  \n",
              "17  1.397070e+03  \n",
              "18  6.099366e+01  \n",
              "19           NaN  \n",
              "20  4.009430e+00  \n",
              "21  1.396882e+03  \n",
              "22  9.999998e-01  \n",
              "23           NaN  "
            ],
            "text/html": [
              "\n",
              "  <div id=\"df-cff67a82-1b85-4891-a1ae-f7213d7dfa2e\" class=\"colab-df-container\">\n",
              "    <div>\n",
              "<style scoped>\n",
              "    .dataframe tbody tr th:only-of-type {\n",
              "        vertical-align: middle;\n",
              "    }\n",
              "\n",
              "    .dataframe tbody tr th {\n",
              "        vertical-align: top;\n",
              "    }\n",
              "\n",
              "    .dataframe thead th {\n",
              "        text-align: right;\n",
              "    }\n",
              "</style>\n",
              "<table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              "    <tr style=\"text-align: right;\">\n",
              "      <th></th>\n",
              "      <th>baseline</th>\n",
              "      <th>metric</th>\n",
              "      <th>value</th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "    <tr>\n",
              "      <th>0</th>\n",
              "      <td>identity_ridge</td>\n",
              "      <td>Relative Frobenius error ‖M_hat - M_true‖_F / ...</td>\n",
              "      <td>9.876823e-01</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>1</th>\n",
              "      <td>identity_ridge</td>\n",
              "      <td>RMSE (μ̂ vs μ_true)</td>\n",
              "      <td>1.398009e+03</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>2</th>\n",
              "      <td>identity_ridge</td>\n",
              "      <td>RFB on covariances ‖Σ̂ - Σ_true‖_F / ‖Σ_true‖_F</td>\n",
              "      <td>9.999975e-01</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>3</th>\n",
              "      <td>identity_ridge</td>\n",
              "      <td>Constraint residual ‖M̂(θ*−μ̂)−b‖₂</td>\n",
              "      <td>NaN</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>4</th>\n",
              "      <td>diagonal_ridge</td>\n",
              "      <td>Relative Frobenius error ‖M_hat - M_true‖_F / ...</td>\n",
              "      <td>9.876823e-01</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>5</th>\n",
              "      <td>diagonal_ridge</td>\n",
              "      <td>RMSE (μ̂ vs μ_true)</td>\n",
              "      <td>1.398009e+03</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>6</th>\n",
              "      <td>diagonal_ridge</td>\n",
              "      <td>RFB on covariances ‖Σ̂ - Σ_true‖_F / ‖Σ_true‖_F</td>\n",
              "      <td>9.999975e-01</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>7</th>\n",
              "      <td>diagonal_ridge</td>\n",
              "      <td>Constraint residual ‖M̂(θ*−μ̂)−b‖₂</td>\n",
              "      <td>NaN</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>8</th>\n",
              "      <td>ls_psd</td>\n",
              "      <td>Relative Frobenius error ‖M_hat - M_true‖_F / ...</td>\n",
              "      <td>1.000000e+00</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>9</th>\n",
              "      <td>ls_psd</td>\n",
              "      <td>RMSE (μ̂ vs μ_true)</td>\n",
              "      <td>2.015820e+08</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>10</th>\n",
              "      <td>ls_psd</td>\n",
              "      <td>RFB on covariances ‖Σ̂ - Σ_true‖_F / ‖Σ_true‖_F</td>\n",
              "      <td>6.164866e+01</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>11</th>\n",
              "      <td>ls_psd</td>\n",
              "      <td>Constraint residual ‖M̂(θ*−μ̂)−b‖₂</td>\n",
              "      <td>NaN</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>12</th>\n",
              "      <td>ridge_ls_psd</td>\n",
              "      <td>Relative Frobenius error ‖M_hat - M_true‖_F / ...</td>\n",
              "      <td>8.166929e-01</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>13</th>\n",
              "      <td>ridge_ls_psd</td>\n",
              "      <td>RMSE (μ̂ vs μ_true)</td>\n",
              "      <td>2.015820e+08</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>14</th>\n",
              "      <td>ridge_ls_psd</td>\n",
              "      <td>RFB on covariances ‖Σ̂ - Σ_true‖_F / ‖Σ_true‖_F</td>\n",
              "      <td>6.099366e+01</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>15</th>\n",
              "      <td>ridge_ls_psd</td>\n",
              "      <td>Constraint residual ‖M̂(θ*−μ̂)−b‖₂</td>\n",
              "      <td>NaN</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>16</th>\n",
              "      <td>zero_mean_ridge</td>\n",
              "      <td>Relative Frobenius error ‖M_hat - M_true‖_F / ...</td>\n",
              "      <td>8.166929e-01</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>17</th>\n",
              "      <td>zero_mean_ridge</td>\n",
              "      <td>RMSE (μ̂ vs μ_true)</td>\n",
              "      <td>1.397070e+03</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>18</th>\n",
              "      <td>zero_mean_ridge</td>\n",
              "      <td>RFB on covariances ‖Σ̂ - Σ_true‖_F / ‖Σ_true‖_F</td>\n",
              "      <td>6.099366e+01</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>19</th>\n",
              "      <td>zero_mean_ridge</td>\n",
              "      <td>Constraint residual ‖M̂(θ*−μ̂)−b‖₂</td>\n",
              "      <td>NaN</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>20</th>\n",
              "      <td>random_spd</td>\n",
              "      <td>Relative Frobenius error ‖M_hat - M_true‖_F / ...</td>\n",
              "      <td>4.009430e+00</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>21</th>\n",
              "      <td>random_spd</td>\n",
              "      <td>RMSE (μ̂ vs μ_true)</td>\n",
              "      <td>1.396882e+03</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>22</th>\n",
              "      <td>random_spd</td>\n",
              "      <td>RFB on covariances ‖Σ̂ - Σ_true‖_F / ‖Σ_true‖_F</td>\n",
              "      <td>9.999998e-01</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>23</th>\n",
              "      <td>random_spd</td>\n",
              "      <td>Constraint residual ‖M̂(θ*−μ̂)−b‖₂</td>\n",
              "      <td>NaN</td>\n",
              "    </tr>\n",
              "  </tbody>\n",
              "</table>\n",
              "</div>\n",
              "    <div class=\"colab-df-buttons\">\n",
              "\n",
              "  <div class=\"colab-df-container\">\n",
              "    <button class=\"colab-df-convert\" onclick=\"convertToInteractive('df-cff67a82-1b85-4891-a1ae-f7213d7dfa2e')\"\n",
              "            title=\"Convert this dataframe to an interactive table.\"\n",
              "            style=\"display:none;\">\n",
              "\n",
              "  <svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\" viewBox=\"0 -960 960 960\">\n",
              "    <path d=\"M120-120v-720h720v720H120Zm60-500h600v-160H180v160Zm220 220h160v-160H400v160Zm0 220h160v-160H400v160ZM180-400h160v-160H180v160Zm440 0h160v-160H620v160ZM180-180h160v-160H180v160Zm440 0h160v-160H620v160Z\"/>\n",
              "  </svg>\n",
              "    </button>\n",
              "\n",
              "  <style>\n",
              "    .colab-df-container {\n",
              "      display:flex;\n",
              "      gap: 12px;\n",
              "    }\n",
              "\n",
              "    .colab-df-convert {\n",
              "      background-color: #E8F0FE;\n",
              "      border: none;\n",
              "      border-radius: 50%;\n",
              "      cursor: pointer;\n",
              "      display: none;\n",
              "      fill: #1967D2;\n",
              "      height: 32px;\n",
              "      padding: 0 0 0 0;\n",
              "      width: 32px;\n",
              "    }\n",
              "\n",
              "    .colab-df-convert:hover {\n",
              "      background-color: #E2EBFA;\n",
              "      box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n",
              "      fill: #174EA6;\n",
              "    }\n",
              "\n",
              "    .colab-df-buttons div {\n",
              "      margin-bottom: 4px;\n",
              "    }\n",
              "\n",
              "    [theme=dark] .colab-df-convert {\n",
              "      background-color: #3B4455;\n",
              "      fill: #D2E3FC;\n",
              "    }\n",
              "\n",
              "    [theme=dark] .colab-df-convert:hover {\n",
              "      background-color: #434B5C;\n",
              "      box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n",
              "      filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n",
              "      fill: #FFFFFF;\n",
              "    }\n",
              "  </style>\n",
              "\n",
              "    <script>\n",
              "      const buttonEl =\n",
              "        document.querySelector('#df-cff67a82-1b85-4891-a1ae-f7213d7dfa2e button.colab-df-convert');\n",
              "      buttonEl.style.display =\n",
              "        google.colab.kernel.accessAllowed ? 'block' : 'none';\n",
              "\n",
              "      async function convertToInteractive(key) {\n",
              "        const element = document.querySelector('#df-cff67a82-1b85-4891-a1ae-f7213d7dfa2e');\n",
              "        const dataTable =\n",
              "          await google.colab.kernel.invokeFunction('convertToInteractive',\n",
              "                                                    [key], {});\n",
              "        if (!dataTable) return;\n",
              "\n",
              "        const docLinkHtml = 'Like what you see? Visit the ' +\n",
              "          '<a target=\"_blank\" href=https://colab.research.google.com/notebooks/data_table.ipynb>data table notebook</a>'\n",
              "          + ' to learn more about interactive tables.';\n",
              "        element.innerHTML = '';\n",
              "        dataTable['output_type'] = 'display_data';\n",
              "        await google.colab.output.renderOutput(dataTable, element);\n",
              "        const docLink = document.createElement('div');\n",
              "        docLink.innerHTML = docLinkHtml;\n",
              "        element.appendChild(docLink);\n",
              "      }\n",
              "    </script>\n",
              "  </div>\n",
              "\n",
              "\n",
              "    <div id=\"df-1e306307-9adb-4ad0-8ccd-db1ae6a862d3\">\n",
              "      <button class=\"colab-df-quickchart\" onclick=\"quickchart('df-1e306307-9adb-4ad0-8ccd-db1ae6a862d3')\"\n",
              "                title=\"Suggest charts\"\n",
              "                style=\"display:none;\">\n",
              "\n",
              "<svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n",
              "     width=\"24px\">\n",
              "    <g>\n",
              "        <path d=\"M19 3H5c-1.1 0-2 .9-2 2v14c0 1.1.9 2 2 2h14c1.1 0 2-.9 2-2V5c0-1.1-.9-2-2-2zM9 17H7v-7h2v7zm4 0h-2V7h2v10zm4 0h-2v-4h2v4z\"/>\n",
              "    </g>\n",
              "</svg>\n",
              "      </button>\n",
              "\n",
              "<style>\n",
              "  .colab-df-quickchart {\n",
              "      --bg-color: #E8F0FE;\n",
              "      --fill-color: #1967D2;\n",
              "      --hover-bg-color: #E2EBFA;\n",
              "      --hover-fill-color: #174EA6;\n",
              "      --disabled-fill-color: #AAA;\n",
              "      --disabled-bg-color: #DDD;\n",
              "  }\n",
              "\n",
              "  [theme=dark] .colab-df-quickchart {\n",
              "      --bg-color: #3B4455;\n",
              "      --fill-color: #D2E3FC;\n",
              "      --hover-bg-color: #434B5C;\n",
              "      --hover-fill-color: #FFFFFF;\n",
              "      --disabled-bg-color: #3B4455;\n",
              "      --disabled-fill-color: #666;\n",
              "  }\n",
              "\n",
              "  .colab-df-quickchart {\n",
              "    background-color: var(--bg-color);\n",
              "    border: none;\n",
              "    border-radius: 50%;\n",
              "    cursor: pointer;\n",
              "    display: none;\n",
              "    fill: var(--fill-color);\n",
              "    height: 32px;\n",
              "    padding: 0;\n",
              "    width: 32px;\n",
              "  }\n",
              "\n",
              "  .colab-df-quickchart:hover {\n",
              "    background-color: var(--hover-bg-color);\n",
              "    box-shadow: 0 1px 2px rgba(60, 64, 67, 0.3), 0 1px 3px 1px rgba(60, 64, 67, 0.15);\n",
              "    fill: var(--button-hover-fill-color);\n",
              "  }\n",
              "\n",
              "  .colab-df-quickchart-complete:disabled,\n",
              "  .colab-df-quickchart-complete:disabled:hover {\n",
              "    background-color: var(--disabled-bg-color);\n",
              "    fill: var(--disabled-fill-color);\n",
              "    box-shadow: none;\n",
              "  }\n",
              "\n",
              "  .colab-df-spinner {\n",
              "    border: 2px solid var(--fill-color);\n",
              "    border-color: transparent;\n",
              "    border-bottom-color: var(--fill-color);\n",
              "    animation:\n",
              "      spin 1s steps(1) infinite;\n",
              "  }\n",
              "\n",
              "  @keyframes spin {\n",
              "    0% {\n",
              "      border-color: transparent;\n",
              "      border-bottom-color: var(--fill-color);\n",
              "      border-left-color: var(--fill-color);\n",
              "    }\n",
              "    20% {\n",
              "      border-color: transparent;\n",
              "      border-left-color: var(--fill-color);\n",
              "      border-top-color: var(--fill-color);\n",
              "    }\n",
              "    30% {\n",
              "      border-color: transparent;\n",
              "      border-left-color: var(--fill-color);\n",
              "      border-top-color: var(--fill-color);\n",
              "      border-right-color: var(--fill-color);\n",
              "    }\n",
              "    40% {\n",
              "      border-color: transparent;\n",
              "      border-right-color: var(--fill-color);\n",
              "      border-top-color: var(--fill-color);\n",
              "    }\n",
              "    60% {\n",
              "      border-color: transparent;\n",
              "      border-right-color: var(--fill-color);\n",
              "    }\n",
              "    80% {\n",
              "      border-color: transparent;\n",
              "      border-right-color: var(--fill-color);\n",
              "      border-bottom-color: var(--fill-color);\n",
              "    }\n",
              "    90% {\n",
              "      border-color: transparent;\n",
              "      border-bottom-color: var(--fill-color);\n",
              "    }\n",
              "  }\n",
              "</style>\n",
              "\n",
              "      <script>\n",
              "        async function quickchart(key) {\n",
              "          const quickchartButtonEl =\n",
              "            document.querySelector('#' + key + ' button');\n",
              "          quickchartButtonEl.disabled = true;  // To prevent multiple clicks.\n",
              "          quickchartButtonEl.classList.add('colab-df-spinner');\n",
              "          try {\n",
              "            const charts = await google.colab.kernel.invokeFunction(\n",
              "                'suggestCharts', [key], {});\n",
              "          } catch (error) {\n",
              "            console.error('Error during call to suggestCharts:', error);\n",
              "          }\n",
              "          quickchartButtonEl.classList.remove('colab-df-spinner');\n",
              "          quickchartButtonEl.classList.add('colab-df-quickchart-complete');\n",
              "        }\n",
              "        (() => {\n",
              "          let quickchartButtonEl =\n",
              "            document.querySelector('#df-1e306307-9adb-4ad0-8ccd-db1ae6a862d3 button');\n",
              "          quickchartButtonEl.style.display =\n",
              "            google.colab.kernel.accessAllowed ? 'block' : 'none';\n",
              "        })();\n",
              "      </script>\n",
              "    </div>\n",
              "\n",
              "  <div id=\"id_6289aa33-0aae-4100-96fb-bc732fb7b2ed\">\n",
              "    <style>\n",
              "      .colab-df-generate {\n",
              "        background-color: #E8F0FE;\n",
              "        border: none;\n",
              "        border-radius: 50%;\n",
              "        cursor: pointer;\n",
              "        display: none;\n",
              "        fill: #1967D2;\n",
              "        height: 32px;\n",
              "        padding: 0 0 0 0;\n",
              "        width: 32px;\n",
              "      }\n",
              "\n",
              "      .colab-df-generate:hover {\n",
              "        background-color: #E2EBFA;\n",
              "        box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n",
              "        fill: #174EA6;\n",
              "      }\n",
              "\n",
              "      [theme=dark] .colab-df-generate {\n",
              "        background-color: #3B4455;\n",
              "        fill: #D2E3FC;\n",
              "      }\n",
              "\n",
              "      [theme=dark] .colab-df-generate:hover {\n",
              "        background-color: #434B5C;\n",
              "        box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n",
              "        filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n",
              "        fill: #FFFFFF;\n",
              "      }\n",
              "    </style>\n",
              "    <button class=\"colab-df-generate\" onclick=\"generateWithVariable('all_metrics_df')\"\n",
              "            title=\"Generate code using this dataframe.\"\n",
              "            style=\"display:none;\">\n",
              "\n",
              "  <svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n",
              "       width=\"24px\">\n",
              "    <path d=\"M7,19H8.4L18.45,9,17,7.55,7,17.6ZM5,21V16.75L18.45,3.32a2,2,0,0,1,2.83,0l1.4,1.43a1.91,1.91,0,0,1,.58,1.4,1.91,1.91,0,0,1-.58,1.4L9.25,21ZM18.45,9,17,7.55Zm-12,3A5.31,5.31,0,0,0,4.9,8.1,5.31,5.31,0,0,0,1,6.5,5.31,5.31,0,0,0,4.9,4.9,5.31,5.31,0,0,0,6.5,1,5.31,5.31,0,0,0,8.1,4.9,5.31,5.31,0,0,0,12,6.5,5.46,5.46,0,0,0,6.5,12Z\"/>\n",
              "  </svg>\n",
              "    </button>\n",
              "    <script>\n",
              "      (() => {\n",
              "      const buttonEl =\n",
              "        document.querySelector('#id_6289aa33-0aae-4100-96fb-bc732fb7b2ed button.colab-df-generate');\n",
              "      buttonEl.style.display =\n",
              "        google.colab.kernel.accessAllowed ? 'block' : 'none';\n",
              "\n",
              "      buttonEl.onclick = () => {\n",
              "        google.colab.notebook.generateWithVariable('all_metrics_df');\n",
              "      }\n",
              "      })();\n",
              "    </script>\n",
              "  </div>\n",
              "\n",
              "    </div>\n",
              "  </div>\n"
            ],
            "application/vnd.google.colaboratory.intrinsic+json": {
              "type": "dataframe",
              "variable_name": "all_metrics_df",
              "summary": "{\n  \"name\": \"all_metrics_df\",\n  \"rows\": 24,\n  \"fields\": [\n    {\n      \"column\": \"baseline\",\n      \"properties\": {\n        \"dtype\": \"category\",\n        \"num_unique_values\": 6,\n        \"samples\": [\n          \"identity_ridge\",\n          \"diagonal_ridge\",\n          \"random_spd\"\n        ],\n        \"semantic_type\": \"\",\n        \"description\": \"\"\n      }\n    },\n    {\n      \"column\": \"metric\",\n      \"properties\": {\n        \"dtype\": \"category\",\n        \"num_unique_values\": 4,\n        \"samples\": [\n          \"RMSE (\\u03bc\\u0302 vs \\u03bc_true)\",\n          \"Constraint residual \\u2016M\\u0302(\\u03b8*\\u2212\\u03bc\\u0302)\\u2212b\\u2016\\u2082\",\n          \"Relative Frobenius error \\u2016M_hat - M_true\\u2016_F / \\u2016M_true\\u2016_F\"\n        ],\n        \"semantic_type\": \"\",\n        \"description\": \"\"\n      }\n    },\n    {\n      \"column\": \"value\",\n      \"properties\": {\n        \"dtype\": \"number\",\n        \"std\": 65187625.42044317,\n        \"min\": 0.8166929301716152,\n        \"max\": 201581965.22997212,\n        \"num_unique_values\": 13,\n        \"samples\": [\n          1396.8821964749814,\n          1397.0700382967955,\n          0.9876823079617933\n        ],\n        \"semantic_type\": \"\",\n        \"description\": \"\"\n      }\n    }\n  ]\n}"
            }
          },
          "metadata": {}
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [],
      "metadata": {
        "id": "bLYdBJ8haGqv"
      },
      "execution_count": null,
      "outputs": []
    }
  ]
}