{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8a67417f-3bfa-4a61-a8ba-8a7066cb504c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ============================================================\n",
    "# CBMA (conformal Bayesian model averaging; Bhagwa 2025) on the California housing data\n",
    "# The BMA credible interval is also computed as an additional baseline result.\n",
    "# ============================================================\n",
    "\n",
    "import time\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "import random\n",
    "\n",
    "from scipy.special import logsumexp, gammaln\n",
    "from scipy.stats import norm\n",
    "from sklearn.datasets import fetch_california_housing\n",
    "from sklearn.preprocessing import PolynomialFeatures\n",
    "\n",
    "\n",
    "# Define some functions\n",
    "\n",
    "\n",
    "def zscore(a, axis=0, eps=1e-12):\n",
    "    m = np.mean(a, axis=axis, keepdims=True)\n",
    "    s = np.std(a, axis=axis, keepdims=True)\n",
    "    return (a - m) / (s + eps), m, s\n",
    "\n",
    "def normal_logpdf(y, mean, var):\n",
    "    \n",
    "    return -0.5 * (np.log(2*np.pi*var) + (y - mean)**2 / var)\n",
    "\n",
    "def invgamma_sample(rng, alpha, beta, size=None):\n",
    "    \n",
    "    g = rng.gamma(shape=alpha, scale=1.0/beta, size=size)\n",
    "    return 1.0 / g\n",
    "\n",
    "\n",
    "# Data loading \n",
    "\n",
    "\n",
    "def load_california_standardized(random_state=0):\n",
    "    \"\"\"Load California housing with every feature column and the target z-scored.\n",
    "\n",
    "    Returns (Xz, yz, (Xm, Xs, ym, ys)); the tuple holds the means and stds\n",
    "    needed to undo the standardization. `random_state` is currently unused.\n",
    "    \"\"\"\n",
    "    features, target = fetch_california_housing(return_X_y=True)\n",
    "    Xz, Xm, Xs = zscore(features, axis=0)\n",
    "    yz, ym, ys = zscore(target, axis=0)\n",
    "    scaling = (Xm, Xs, ym, ys)\n",
    "    return Xz, yz, scaling\n",
    "\n",
    "def sample_train_test(X, y, n_train, m_test, seed):\n",
    "    rng = np.random.default_rng(seed)\n",
    "    N = X.shape[0]\n",
    "    idx = rng.choice(N, size=n_train + m_test, replace=False)\n",
    "    tr_idx = idx[:n_train]\n",
    "    te_idx = idx[n_train:]\n",
    "    return X[tr_idx], y[tr_idx], X[te_idx], y[te_idx]\n",
    "\n",
    "\n",
    "# Candidate model design matrices, following the paper's reference code\n",
    "\n",
    "\n",
    "def design_full(X):\n",
    "    \n",
    "    n = X.shape[0]\n",
    "    return np.column_stack([np.ones(n), X])\n",
    "\n",
    "def design_small01(X):\n",
    "    \n",
    "    n = X.shape[0]\n",
    "    return np.column_stack([np.ones(n), X[:, [0, 1]]])\n",
    "\n",
    "# Shared degree-2 expansion (bias, linear, squared and pairwise-product terms).\n",
    "_poly = PolynomialFeatures(degree=2, include_bias=True)\n",
    "def design_quadratic(X):\n",
    "    \"\"\"Quadratic design matrix produced by the module-level `_poly` transformer.\n",
    "\n",
    "    The transformer is refit on every call, so it adapts to the column count of X.\n",
    "    \"\"\"\n",
    "    return _poly.fit(X).transform(X)\n",
    "\n",
    "def design_pair23(X):\n",
    "    n = X.shape[0]\n",
    "    return np.column_stack([np.ones(n), X[:, [2, 3]]])\n",
    "\n",
    "def design_pair45(X):\n",
    "    n = X.shape[0]\n",
    "    return np.column_stack([np.ones(n), X[:, [4, 5]]])\n",
    "\n",
    "\n",
    "\n",
    "def posterior_params_conjugate(X, y, m0, V0, a0, b0):\n",
    "    n, p = X.shape\n",
    "    V0_inv = np.linalg.inv(V0)\n",
    "    XtX = X.T @ X\n",
    "    Xty = X.T @ y\n",
    "    Vn_inv = V0_inv + XtX\n",
    "    Vn = np.linalg.inv(Vn_inv)\n",
    "    mn = Vn @ (V0_inv @ m0 + Xty)\n",
    "    an = a0 + 0.5 * n\n",
    "    term = y @ y + m0.T @ V0_inv @ m0 - mn.T @ Vn_inv @ mn\n",
    "    bn = b0 + 0.5 * float(term)\n",
    "    return mn, Vn, an, bn\n",
    "\n",
    "def log_marginal_likelihood_conjugate(X, y, m0, V0, a0, b0):\n",
    "    \"\"\"Closed-form log evidence log p(y | X) under the normal-inverse-gamma prior.\n",
    "\n",
    "    log p(y) = -n/2 log(2 pi) - 1/2 log|V0| + 1/2 log|Vn|\n",
    "               + a0 log b0 - an log bn + lgamma(an) - lgamma(a0)\n",
    "    \"\"\"\n",
    "    n, p = X.shape\n",
    "    mn, Vn, an, bn = posterior_params_conjugate(X, y, m0, V0, a0, b0)\n",
    "\n",
    "    s0, ld0 = np.linalg.slogdet(V0)\n",
    "    sn, ldn = np.linalg.slogdet(Vn)\n",
    "    if min(s0, sn) <= 0:\n",
    "        # A non-positive determinant sign signals numerical trouble; retry\n",
    "        # both log-determinants with a tiny diagonal jitter.\n",
    "        jitter = 1e-12 * np.eye(p)\n",
    "        s0, ld0 = np.linalg.slogdet(V0 + jitter)\n",
    "        sn, ldn = np.linalg.slogdet(Vn + jitter)\n",
    "\n",
    "    logml = (-0.5 * n * np.log(2*np.pi)\n",
    "             - 0.5 * ld0\n",
    "             + 0.5 * ldn\n",
    "             + a0 * np.log(b0)\n",
    "             - an * np.log(bn)\n",
    "             + gammaln(an) - gammaln(a0))\n",
    "    return float(logml)\n",
    "\n",
    "def sample_posterior_linear_conjugate(rng, X, y, m0=None, V0=None, a0=1.0, b0=1.0, T=200):\n",
    "    \"\"\"Draw T joint posterior samples (beta, sigma^2) under the NIG prior.\n",
    "\n",
    "    sigma^2 ~ InvGamma(an, bn), then beta | sigma^2 ~ N(mn, sigma^2 Vn).\n",
    "    Defaults: zero prior mean and V0 = 10 * I. Returns betas (T, p) and sig2 (T,).\n",
    "    \"\"\"\n",
    "    _, p = X.shape\n",
    "    m0 = np.zeros(p) if m0 is None else m0\n",
    "    V0 = np.eye(p) * 10.0 if V0 is None else V0\n",
    "    mn, Vn, an, bn = posterior_params_conjugate(X, y, m0, V0, a0, b0)\n",
    "    # The RNG is consumed in a fixed order: variances first, then the normals.\n",
    "    sig2 = invgamma_sample(rng, an, bn, size=T)\n",
    "    chol = np.linalg.cholesky(Vn)\n",
    "    z = rng.standard_normal(size=(T, p))\n",
    "    betas = mn[None, :] + np.sqrt(sig2)[:, None] * (z @ chol.T)\n",
    "    return betas, sig2\n",
    "\n",
    "\n",
    "# Add-one-in (AOI) importance-sampling conformity scores, as in the paper\n",
    "\n",
    "\n",
    "def aoi_scores_per_model(betas, sig2, X_train, y_train, x_star, y_grid):\n",
    "    \"\"\"Add-one-in (AOI) importance-sampling conformity scores for one model.\n",
    "\n",
    "    betas (T, p) and sig2 (T,) are posterior draws; x_star is the test design row.\n",
    "    For every candidate y in y_grid, the draws are reweighted by the test-point\n",
    "    likelihood p(y | theta_t, x_star), self-normalized over the draws.\n",
    "\n",
    "    Returns\n",
    "      s_train_mat : (n, Ny) reweighted predictive density of each training point\n",
    "      s_star_vec  : (Ny,)   reweighted predictive density of the candidate itself\n",
    "      log_p_k_y   : (Ny,)   log Monte Carlo posterior predictive density of y\n",
    "    \"\"\"\n",
    "    n_draws = betas.shape[0]\n",
    "\n",
    "    # Training-point likelihoods under each draw, (T, n); clipping keeps exp finite.\n",
    "    train_mean = (X_train @ betas.T).T\n",
    "    train_ll = normal_logpdf(y_train[None, :], train_mean, sig2[:, None])\n",
    "    train_lik = np.exp(np.clip(train_ll, -700, 50))\n",
    "\n",
    "    # Test-point likelihood of every candidate y under each draw, (T, Ny).\n",
    "    star_mean = x_star @ betas.T\n",
    "    star_ll = normal_logpdf(y_grid[None, :], star_mean[:, None], sig2[:, None])\n",
    "    star_lik = np.exp(np.clip(star_ll, -700, 50))\n",
    "\n",
    "    # Per-candidate normalizer over draws; the tiny offset avoids 0/0.\n",
    "    norm_const = star_lik.sum(axis=0) + 1e-300\n",
    "    weights = star_lik / norm_const\n",
    "\n",
    "    s_train_mat = train_lik.T @ weights\n",
    "    s_star_vec = (star_lik * star_lik).sum(axis=0) / norm_const\n",
    "    log_p_k_y = logsumexp(star_ll, axis=0) - np.log(n_draws)\n",
    "    return s_train_mat, s_star_vec, log_p_k_y\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "def randomized_full_conformal_mask_from_logs(log_s_train_CBMA, log_s_star_CBMA, alpha, rng):\n",
    "\n",
    "    n, Ny = log_s_train_CBMA.shape\n",
    "    mask = np.zeros(Ny, dtype=bool)\n",
    "    for j in range(Ny):\n",
    "        sj = log_s_train_CBMA[:, j]\n",
    "        sstar = log_s_star_CBMA[j]\n",
    "        # \n",
    "        lt = np.sum(sj < sstar)\n",
    "        eq = np.sum(sj == sstar)\n",
    "        u = rng.uniform()\n",
    "        pval = (1.0 + lt + u*eq) / (n + 1.0)\n",
    "        mask[j] = (pval > alpha)\n",
    "    return mask\n",
    "\n",
    "\n",
    "# BMA credible interval (also computed in the experiment below)\n",
    "\n",
    "\n",
    "def bma_credible_band(y_grid, cdf_list_per_model, log_model_post):\n",
    "   \n",
    "    pass  \n",
    "\n",
    "\n",
    "\n",
    "def run_experiment3_california(\n",
    "    n_list=(100, 300, 600, 1000),\n",
    "    E=5,\n",
    "    alpha=0.2,\n",
    "    grid_size=400,\n",
    "    m_test=100,\n",
    "    T_samples=200,\n",
    "    seed0=2025\n",
    "):\n",
    "    \"\"\"Compare CBMA conformal sets with BMA credible intervals on California housing.\n",
    "\n",
    "    For every training size n in `n_list` and every candidate-model pool, the\n",
    "    experiment is repeated E times. Each repetition draws a fresh train/test\n",
    "    split and fits every pool model under the conjugate normal-inverse-gamma\n",
    "    prior. For each of the `m_test` test points it then builds\n",
    "      * the randomized full-conformal CBMA set on a grid of `grid_size` y values, and\n",
    "      * the equal-tailed (1 - alpha) BMA credible interval.\n",
    "    Coverage and length are averaged over the test points of one repetition;\n",
    "    the reported mean and standard error are then taken across the E repetitions.\n",
    "\n",
    "    Returns a DataFrame with one row per (pool, n) and shows four summary plots.\n",
    "    Raises ValueError if E < 1, m_test < 1 or grid_size < 2.\n",
    "    \"\"\"\n",
    "    if E < 1 or m_test < 1 or grid_size < 2:\n",
    "        raise ValueError(\"need E >= 1, m_test >= 1 and grid_size >= 2\")\n",
    "\n",
    "    rng_master = np.random.default_rng(seed0)\n",
    "    X, y, _ = load_california_standardized()\n",
    "\n",
    "    # Candidate model pools: (model name, design-matrix builder) pairs.\n",
    "    CBMA_pools = {\n",
    "        \"CBMA1\": [(\"Full\", design_full), (\"Small01\", design_small01)],\n",
    "        \"CBMA2\": [(\"Full\", design_full), (\"Quadratic\", design_quadratic), (\"Small01\", design_small01)],\n",
    "        \"CBMA3\": [(\"Small01\", design_small01), (\"Pair23\", design_pair23), (\"Pair45\", design_pair45)],\n",
    "    }\n",
    "\n",
    "    rows = []\n",
    "    t0 = time.time()\n",
    "\n",
    "    for n in n_list:\n",
    "        t_n0 = time.time()\n",
    "\n",
    "        for pool_name, models in CBMA_pools.items():\n",
    "            # One tuple per repetition: (cbma_cov, cbma_len, bma_cov, bma_len, seconds),\n",
    "            # each already averaged over that repetition's test points.\n",
    "            per_rep = []\n",
    "            for _rep in range(E):\n",
    "                seed_rep = int(rng_master.integers(0, 2**31-1))\n",
    "                per_rep.append(_run_repetition(\n",
    "                    X, y, models, n, m_test, alpha, grid_size, T_samples, seed_rep\n",
    "                ))\n",
    "            cov_cb, len_cb, cov_bm, len_bm, secs = zip(*per_rep)\n",
    "\n",
    "            row = {\"cbma\": pool_name, \"n\": n}\n",
    "            for key, values in ((\"coverage_cbma\", cov_cb), (\"length_cbma\", len_cb),\n",
    "                                (\"coverage_bma\", cov_bm), (\"length_bma\", len_bm)):\n",
    "                row[f\"{key}_mean\"], row[f\"{key}_se\"] = _mean_se(values)\n",
    "            row[\"elapsed_sec_cbma_sum\"] = round(float(np.sum(secs)), 2)\n",
    "            rows.append(row)\n",
    "\n",
    "            print(f\"n={n:4d} {pool_name} | \"\n",
    "                  f\"CBMA cov={row['coverage_cbma_mean']:.3f}±{row['coverage_cbma_se']:.3f} | \"\n",
    "                  f\"len={row['length_cbma_mean']:.3f}±{row['length_cbma_se']:.3f} || \"\n",
    "                  f\"BMA cov={row['coverage_bma_mean']:.3f}±{row['coverage_bma_se']:.3f} | \"\n",
    "                  f\"len={row['length_bma_mean']:.3f}±{row['length_bma_se']:.3f}\")\n",
    "\n",
    "        print(f\"n={n} finished in {time.time() - t_n0:.1f}s\")\n",
    "    print(f\"total elapsed {time.time() - t0:.1f}s\")\n",
    "\n",
    "    res = pd.DataFrame(rows)\n",
    "\n",
    "    target = 1 - alpha  # nominal coverage level\n",
    "    _plot_metric(res, \"coverage_cbma_mean\", \"coverage_cbma_se\", \"CBMA\", \"o\",\n",
    "                 \"CBMA on California Data Average Coverage vs n\", \"Coverage Rate\", target)\n",
    "    _plot_metric(res, \"length_cbma_mean\", \"length_cbma_se\", \"CBMA\", \"o\",\n",
    "                 \"CBMA on California Data Average Set Length vs n\", \"Average Set Length\")\n",
    "    _plot_metric(res, \"coverage_bma_mean\", \"coverage_bma_se\", \"BMA\", \"s\",\n",
    "                 \"BMA on California Data Average Coverage vs n\", \"Coverage Rate\", target)\n",
    "    _plot_metric(res, \"length_bma_mean\", \"length_bma_se\", \"BMA\", \"s\",\n",
    "                 \"BMA on California Data Average Set Length vs n\", \"Average Set Length\")\n",
    "\n",
    "    return res\n",
    "\n",
    "\n",
    "def _mean_se(values):\n",
    "    \"\"\"Mean and standard error of per-repetition values (SE is NaN for a single value).\"\"\"\n",
    "    arr = np.asarray(values, dtype=float)\n",
    "    se = arr.std(ddof=1) / np.sqrt(arr.size) if arr.size > 1 else float(\"nan\")\n",
    "    return float(arr.mean()), float(se)\n",
    "\n",
    "\n",
    "def _fit_model_pool(models, Xtr_raw, ytr, T_samples, rng):\n",
    "    \"\"\"Fit each (name, design_fn) model; return (model dicts, log posterior model probabilities).\"\"\"\n",
    "    model_objs, logml_list = [], []\n",
    "    for mname, design_fn in models:\n",
    "        Xtr = design_fn(Xtr_raw)\n",
    "        p = Xtr.shape[1]\n",
    "        m0, V0, a0, b0 = np.zeros(p), np.eye(p) * 10.0, 1.0, 1.0\n",
    "        betas, sig2 = sample_posterior_linear_conjugate(\n",
    "            rng, Xtr, ytr, m0=m0, V0=V0, a0=a0, b0=b0, T=T_samples\n",
    "        )\n",
    "        logml_list.append(log_marginal_likelihood_conjugate(Xtr, ytr, m0, V0, a0, b0))\n",
    "        model_objs.append({\"name\": mname, \"design_fn\": design_fn,\n",
    "                           \"Xtr\": Xtr, \"betas\": betas, \"sig2\": sig2})\n",
    "    logml_arr = np.array(logml_list)\n",
    "    # Uniform prior over models: p(M_k | data) is the normalized evidence.\n",
    "    return model_objs, logml_arr - logsumexp(logml_arr)\n",
    "\n",
    "\n",
    "def _predict_point(model_objs, log_pMk, x_raw, ytr, y_grid, alpha, rng):\n",
    "    \"\"\"Return the CBMA acceptance mask over y_grid and the BMA interval (lo, hi) for one test input.\"\"\"\n",
    "    log_s_train_k, log_s_star_k, log_p_k_y, cdf_models = [], [], [], []\n",
    "    for mod in model_objs:\n",
    "        betas, sig2 = mod[\"betas\"], mod[\"sig2\"]\n",
    "        x_star = mod[\"design_fn\"](x_raw.reshape(1, -1)).reshape(-1)\n",
    "        s_train_mat, s_star_vec, log_pred = aoi_scores_per_model(\n",
    "            betas, sig2, mod[\"Xtr\"], ytr, x_star, y_grid\n",
    "        )\n",
    "        log_s_train_k.append(np.log(s_train_mat + 1e-300))  # (n, Ny)\n",
    "        log_s_star_k.append(np.log(s_star_vec + 1e-300))    # (Ny,)\n",
    "        log_p_k_y.append(log_pred)                          # (Ny,)\n",
    "        # Per-model posterior predictive CDF on the grid (Monte Carlo over draws).\n",
    "        mu_star = x_star @ betas.T\n",
    "        cdf_models.append(norm.cdf(y_grid[None, :], loc=mu_star[:, None],\n",
    "                                   scale=np.sqrt(sig2)[:, None]).mean(axis=0))\n",
    "\n",
    "    # Candidate-dependent model weights q_k(y), proportional to p(M_k | data) p_k(y | x_star, data).\n",
    "    log_q = log_pMk[:, None] + np.vstack(log_p_k_y)  # (K, Ny)\n",
    "    log_q -= logsumexp(log_q, axis=0, keepdims=True)\n",
    "\n",
    "    # CBMA scores log sum_k q_k(y) s_k(y), vectorized over models, points and the grid.\n",
    "    log_s_train = logsumexp(log_q[:, None, :] + np.stack(log_s_train_k), axis=0)  # (n, Ny)\n",
    "    log_s_star = logsumexp(log_q + np.vstack(log_s_star_k), axis=0)              # (Ny,)\n",
    "    mask = randomized_full_conformal_mask_from_logs(log_s_train, log_s_star, alpha, rng)\n",
    "\n",
    "    # BMA: equal-tailed interval read from the model-averaged predictive CDF.\n",
    "    mix_cdf = np.exp(log_pMk) @ np.vstack(cdf_models)\n",
    "    lo = y_grid[np.argmin(np.abs(mix_cdf - alpha / 2))]\n",
    "    hi = y_grid[np.argmin(np.abs(mix_cdf - (1 - alpha / 2)))]\n",
    "    return mask, lo, hi\n",
    "\n",
    "\n",
    "def _run_repetition(X, y, models, n, m_test, alpha, grid_size, T_samples, seed_rep):\n",
    "    \"\"\"Run one random split and return per-test-point averages (cbma_cov, cbma_len, bma_cov, bma_len, seconds).\"\"\"\n",
    "    Xtr_raw, ytr, Xte_raw, yte = sample_train_test(X, y, n_train=n, m_test=m_test, seed=seed_rep)\n",
    "\n",
    "    # Candidate grid: the training response range padded by 1 on each side.\n",
    "    y_grid = np.linspace(ytr.min() - 1.0, ytr.max() + 1.0, grid_size)\n",
    "    dy = y_grid[1] - y_grid[0]\n",
    "    rng_rep = np.random.default_rng(seed_rep + 77)\n",
    "\n",
    "    model_objs, log_pMk = _fit_model_pool(models, Xtr_raw, ytr, T_samples, rng_rep)\n",
    "\n",
    "    cov_cbma, len_cbma, cov_bma, len_bma = [], [], [], []\n",
    "    t_start = time.time()\n",
    "    for x_raw, y_true in zip(Xte_raw, yte):\n",
    "        mask, lo, hi = _predict_point(model_objs, log_pMk, x_raw, ytr, y_grid, alpha, rng_rep)\n",
    "        # The conformal set only exists on the grid: a label off the grid is not\n",
    "        # covered; otherwise it is judged by its nearest grid point.\n",
    "        on_grid = y_grid[0] - 0.5 * dy <= y_true <= y_grid[-1] + 0.5 * dy\n",
    "        idx_true = int(np.argmin(np.abs(y_grid - y_true)))\n",
    "        cov_cbma.append(bool(on_grid and mask[idx_true]))\n",
    "        len_cbma.append(mask.sum() * dy)\n",
    "        cov_bma.append(bool(lo <= y_true <= hi))\n",
    "        len_bma.append(abs(hi - lo))\n",
    "    elapsed = time.time() - t_start\n",
    "    return (np.mean(cov_cbma), np.mean(len_cbma),\n",
    "            np.mean(cov_bma), np.mean(len_bma), elapsed)\n",
    "\n",
    "\n",
    "def _plot_metric(res, mean_col, se_col, method, marker, title, ylabel, target=None):\n",
    "    \"\"\"Draw an error-bar plot of one summary column against n, one line per pool.\"\"\"\n",
    "    fig, ax = plt.subplots()\n",
    "    for pool in sorted(res[\"cbma\"].unique()):\n",
    "        dfp = res[res[\"cbma\"] == pool].sort_values(\"n\")\n",
    "        ax.errorbar(dfp[\"n\"], dfp[mean_col], yerr=dfp[se_col],\n",
    "                    marker=marker, capsize=4, label=f\"{pool} ({method})\")\n",
    "    if target is not None:\n",
    "        ax.axhline(target, linestyle=\"--\", alpha=0.7)\n",
    "    ax.set(title=title, xlabel=\"n\", ylabel=ylabel)\n",
    "    ax.legend()\n",
    "    fig.tight_layout()\n",
    "    plt.show()\n",
    "\n",
    "\n",
    "# Run\n",
    "\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "    # A fresh seed is drawn per run; printing it lets any run be reproduced exactly.\n",
    "    seed0 = random.randint(1, 1000000)\n",
    "    print(f\"seed0 = {seed0}\")\n",
    "    results = run_experiment3_california(\n",
    "        n_list=(100, 300, 600, 1000),\n",
    "        E=10,\n",
    "        alpha=0.2,\n",
    "        grid_size=500,\n",
    "        m_test=100,\n",
    "        T_samples=2000,\n",
    "        seed0=seed0,\n",
    "    )\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
