{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1bd3d2cd-c286-42c1-b007-34fe41895727",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ============================================================\n",
    "# CBMA from Bhagwa 2025 paper on simulation data\n",
    "#  We also include BMA credible interval as extra result\n",
    "# ============================================================\n",
    "\n",
    "import time\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "from dataclasses import dataclass\n",
    "from numpy.random import default_rng\n",
    "import random\n",
    "\n",
    "# Generation of data\n",
    "\n",
    "@dataclass\n",
    "class GenParams:\n",
    "    \"\"\"Coefficients, noise scales and mixing weights of the three-component generator.\"\"\"\n",
    "    # Coefficients of the first component mean.\n",
    "    a0: float\n",
    "    a1: float\n",
    "    a2: float\n",
    "    # Coefficients of the second component mean.\n",
    "    b0: float\n",
    "    b1: float\n",
    "    b2: float\n",
    "    # Coefficients of the third component mean.\n",
    "    c0: float\n",
    "    c1: float\n",
    "    c2: float\n",
    "    # Noise standard deviation of each component.\n",
    "    sigmas: np.ndarray\n",
    "    # Mixing probability of each component.\n",
    "    weights: np.ndarray\n",
    "\n",
    "def draw_generator_params(seed: int = 20250814) -> GenParams:\n",
    "    \"\"\"Draw the nine regression coefficients of the generator uniformly at random.\n",
    "\n",
    "    The coefficients are drawn in the order a0, a1, a2, b0, b1, b2, c0, c1, c2\n",
    "    from a generator seeded with ``seed``; the noise scales (0.5 each) and the\n",
    "    mixing weights (1/3 each) are fixed.\n",
    "    \"\"\"\n",
    "    rng = default_rng(seed)\n",
    "    # (low, high) of each coefficient, listed in draw order.\n",
    "    bounds = [\n",
    "        (-4.0, -2.0), (0.8, 1.5), (0.5, 1.0),\n",
    "        (-0.5, 0.5), (0.5, 1.2), (0.4, 1.0),\n",
    "        (2.0, 4.0), (-1.0, -0.3), (0.8, 1.5),\n",
    "    ]\n",
    "    coefs = [rng.uniform(low, high) for low, high in bounds]\n",
    "    noise_sd = np.array([0.5, 0.5, 0.5], float)\n",
    "    mix_weights = np.array([1/3, 1/3, 1/3], float)\n",
    "    return GenParams(*coefs, noise_sd, mix_weights)\n",
    "\n",
    "def sample_from_generator(n: int, P: GenParams, seed: int) -> pd.DataFrame:\n",
    "    \"\"\"Simulate ``n`` rows (X1, X2, Y) from the three-component mixture regression.\n",
    "\n",
    "    X1 ~ Uniform(-1, 2) and X2 ~ Normal(1, 0.5). Each row picks one of three\n",
    "    components with probabilities ``P.weights``; Y is normal around that\n",
    "    component's mean with standard deviation ``P.sigmas[component]``.\n",
    "    \"\"\"\n",
    "    rng = default_rng(seed)\n",
    "    # The order of the random draws below determines the output for a given seed.\n",
    "    x1 = rng.uniform(-1, 2, size=n)\n",
    "    x2 = rng.normal(1.0, 0.5, size=n)\n",
    "    labels = rng.choice(3, size=n, p=P.weights)\n",
    "    component_means = np.stack(\n",
    "        [\n",
    "            P.a0 + P.a1*x1 + P.a2*(x2**2),\n",
    "            P.b0 + P.b1*(x1**2) + P.b2*(x1*x2),\n",
    "            P.c0 + P.c1*(x1**2) + P.c2*np.cos(2*x2),\n",
    "        ],\n",
    "        axis=1,\n",
    "    )\n",
    "    # Pick, for every row, the mean of the component that generated it.\n",
    "    row_mean = component_means[np.arange(n), labels]\n",
    "    y = rng.normal(loc=row_mean, scale=P.sigmas[labels])\n",
    "    return pd.DataFrame({\"X1\": x1, \"X2\": x2, \"Y\": y})\n",
    "\n",
    "\n",
    "\n",
    "# True models (means)\n",
    "def phi1(x1, x2):\n",
    "    \"\"\"Features [1, x1, x2^2] of the first true component at one point.\"\"\"\n",
    "    features = [1.0, x1, x2**2]\n",
    "    return np.array(features, float)\n",
    "\n",
    "def phi2(x1, x2):\n",
    "    \"\"\"Features [1, x1^2, x1*x2] of the second true component at one point.\"\"\"\n",
    "    features = [1.0, x1**2, x1*x2]\n",
    "    return np.array(features, float)\n",
    "\n",
    "def phi3(x1, x2):\n",
    "    \"\"\"Features [1, x1^2, cos(2*x2)] of the third true component at one point.\"\"\"\n",
    "    features = [1.0, x1**2, np.cos(2*x2)]\n",
    "    return np.array(features, float)\n",
    "\n",
    "def design_phi1(X1, X2):\n",
    "    \"\"\"Design matrix with columns [1, X1, X2^2] (one row per observation).\"\"\"\n",
    "    intercept = np.ones_like(X1)\n",
    "    return np.column_stack([intercept, X1, X2**2])\n",
    "\n",
    "def design_phi2(X1, X2):\n",
    "    \"\"\"Design matrix with columns [1, X1^2, X1*X2] (one row per observation).\"\"\"\n",
    "    intercept = np.ones_like(X1)\n",
    "    return np.column_stack([intercept, X1**2, X1*X2])\n",
    "\n",
    "def design_phi3(X1, X2):\n",
    "    \"\"\"Design matrix with columns [1, X1^2, cos(2*X2)] (one row per observation).\"\"\"\n",
    "    intercept = np.ones_like(X1)\n",
    "    return np.column_stack([intercept, X1**2, np.cos(2*X2)])\n",
    "\n",
    "\n",
    "def phi_wrong_linear(x1,x2):\n",
    "    \"\"\"Features [1, x1, x2] of the misspecified purely linear model at one point.\"\"\"\n",
    "    features = [1.0, x1, x2]\n",
    "    return np.array(features, float)\n",
    "\n",
    "def design_wrong_linear(X1,X2):\n",
    "    \"\"\"Design matrix with columns [1, X1, X2] for the misspecified linear model.\"\"\"\n",
    "    intercept = np.ones_like(X1)\n",
    "    return np.column_stack([intercept, X1, X2])\n",
    "\n",
    "def phi_wrong_quad(x1,x2):\n",
    "    \"\"\"Features [1, x1, x1^2] of the misspecified quadratic-in-x1 model at one point.\"\"\"\n",
    "    features = [1.0, x1, x1**2]\n",
    "    return np.array(features, float)\n",
    "\n",
    "def design_wrong_quad(X1,X2):\n",
    "    \"\"\"Design matrix with columns [1, X1, X1^2]; X2 is accepted but not used.\"\"\"\n",
    "    intercept = np.ones_like(X1)\n",
    "    return np.column_stack([intercept, X1, X1**2])\n",
    "\n",
    "\n",
    "def phi_sum12(x1,x2):\n",
    "    \"\"\"Concatenated features of components 1 and 2 (six entries, two intercepts).\"\"\"\n",
    "    first, second = phi1(x1,x2), phi2(x1,x2)\n",
    "    return np.concatenate([first, second])\n",
    "\n",
    "def design_sum12(X1,X2):\n",
    "    \"\"\"Design matrix [1, X1, X2^2, 1, X1^2, X1*X2] matching ``phi_sum12`` column for column.\"\"\"\n",
    "    intercept = np.ones_like(X1)\n",
    "    block1 = [intercept, X1, X2**2]\n",
    "    # A second intercept column is kept on purpose so the layout mirrors phi_sum12.\n",
    "    block2 = [np.ones_like(X1), X1**2, X1*X2]\n",
    "    return np.column_stack(block1 + block2)\n",
    "\n",
    "\n",
    "\n",
    "def phi_wrong_c2x1(x1,x2):\n",
    "    \"\"\"Features [2, x1] of a misspecified model whose constant column equals 2.\"\"\"\n",
    "    features = [2.0, x1]\n",
    "    return np.array(features, float)\n",
    "\n",
    "def design_wrong_c2x1(X1,X2):\n",
    "    \"\"\"Design matrix with columns [2, X1]; X2 is accepted but not used.\"\"\"\n",
    "    constant = 2.0*np.ones_like(X1)\n",
    "    return np.column_stack([constant, X1])\n",
    "\n",
    "\n",
    "def phi_wrong_x1_x2sq(x1,x2):\n",
    "    \"\"\"Features [x1, x2^2] of a misspecified model without intercept.\"\"\"\n",
    "    features = [x1, x2**2]\n",
    "    return np.array(features, float)\n",
    "\n",
    "def design_wrong_x1_x2sq(X1,X2):\n",
    "    \"\"\"Design matrix with columns [X1, X2^2] (no intercept column).\"\"\"\n",
    "    columns = [X1, X2**2]\n",
    "    return np.column_stack(columns)\n",
    "\n",
    "\n",
    "def phi_wrong_1_x1_cos3x2(x1,x2):\n",
    "    \"\"\"Features [1, x1, cos(3*x2)] of a misspecified model (wrong cosine frequency).\"\"\"\n",
    "    features = [1.0, x1, np.cos(3*x2)]\n",
    "    return np.array(features, float)\n",
    "\n",
    "def design_wrong_1_x1_cos3x2(X1,X2):\n",
    "    \"\"\"Design matrix with columns [1, X1, cos(3*X2)].\"\"\"\n",
    "    intercept = np.ones_like(X1)\n",
    "    return np.column_stack([intercept, X1, np.cos(3*X2)])\n",
    "\n",
    "def log_norm_pdf(y, mean, sd):\n",
    "    \"\"\"Log-density of Normal(mean, sd^2) evaluated at ``y`` (broadcasts elementwise).\"\"\"\n",
    "    standardized = (y - mean) / sd\n",
    "    log_norm_const = -0.5*np.log(2*np.pi) - np.log(sd)\n",
    "    return log_norm_const - 0.5*standardized*standardized\n",
    "\n",
    "\n",
    "\n",
    "def true_prior_means_vars():\n",
    "    \"\"\"Return ((m1, v1), (m2, v2), (m3, v3)): prior means and variances per component.\n",
    "\n",
    "    Each variance is written as width**2 / 12, i.e. the variance of a uniform\n",
    "    distribution of the given width.\n",
    "    \"\"\"\n",
    "    def uniform_variances(widths):\n",
    "        return np.array([(w**2)/12 for w in widths], float)\n",
    "\n",
    "    prior1 = (np.array([-3.0, 1.15, 0.75], float), uniform_variances([2.0, 0.7, 0.5]))\n",
    "    prior2 = (np.array([ 0.0, 0.85, 0.70], float), uniform_variances([1.0, 0.7, 0.6]))\n",
    "    prior3 = (np.array([ 3.0, -0.65, 1.15], float), uniform_variances([2.0, 0.7, 0.7]))\n",
    "    return prior1, prior2, prior3\n",
    "\n",
    "def broad_prior(p, var=5.0):\n",
    "    \"\"\"Zero-mean prior for ``p`` coefficients with common variance ``var``.\n",
    "\n",
    "    Returns the pair (mean vector, diagonal of the prior covariance).\n",
    "    \"\"\"\n",
    "    prior_mean = np.zeros(p, float)\n",
    "    prior_var = var*np.ones(p, float)\n",
    "    return prior_mean, prior_var\n",
    "\n",
    "\n",
    "\n",
    "def conjugate_posterior(X, y, m0, V0_diag, sigma):\n",
    "    \"\"\"Gaussian posterior of the coefficients of a linear model with known noise sd.\n",
    "\n",
    "    Prior: beta ~ N(m0, diag(V0_diag)); likelihood: y ~ N(X @ beta, sigma^2 I).\n",
    "    Returns (posterior mean, posterior covariance).\n",
    "    \"\"\"\n",
    "    noise_var = sigma**2\n",
    "    prior_precision = np.diag(1.0/np.asarray(V0_diag))\n",
    "    # Posterior precision = prior precision + data precision.\n",
    "    post_precision = prior_precision + (X.T @ X) / noise_var\n",
    "    post_cov = np.linalg.inv(post_precision)\n",
    "    post_mean = post_cov @ (prior_precision @ m0 + (X.T @ y) / noise_var)\n",
    "    return post_mean, post_cov\n",
    "\n",
    "def sample_posterior_linear(X, y, m0, V0_diag, sigma, T=200, burn=0, rng=None):\n",
    "    \"\"\"Draw ``T`` exact samples from the conjugate Gaussian posterior of the coefficients.\n",
    "\n",
    "    ``burn`` is accepted but not used (the draws are exact, so there is no burn-in).\n",
    "    Returns an array of shape (T, number of coefficients).\n",
    "    \"\"\"\n",
    "    if rng is None:\n",
    "        rng = default_rng()\n",
    "    post_mean, post_cov = conjugate_posterior(X, y, m0, V0_diag, sigma)\n",
    "    draws = rng.multivariate_normal(post_mean, post_cov, size=T)\n",
    "    return draws\n",
    "\n",
    "def log_marginal_linear_conjugate(X, y, m0, V0_diag, sigma):\n",
    "    \"\"\"Log marginal likelihood of y under a conjugate Gaussian linear model.\n",
    "\n",
    "    Model: beta ~ N(m0, diag(V0_diag)), y | beta ~ N(X @ beta, sigma^2 I), hence\n",
    "    y ~ N(X @ m0, Sigma) with Sigma = sigma^2 I + X diag(V0_diag) X^T.\n",
    "\n",
    "    The n x n matrix Sigma is never formed. With M = I_p + diag(V0_diag) X^T X / sigma^2:\n",
    "      * log det(Sigma) = n log(sigma^2) + log det(M)            (determinant lemma)\n",
    "      * r^T Sigma^{-1} r = r.r/sigma^2 - (X^T r).M^{-1} diag(V0_diag) X^T r / sigma^4   (Woodbury)\n",
    "    so the cost is O(n p^2) time and O(p^2) extra memory instead of O(n^3) / O(n^2).\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    X : (n, p) design matrix.\n",
    "    y : (n,) responses.\n",
    "    m0 : (p,) prior mean.\n",
    "    V0_diag : (p,) prior variances (zeros are allowed).\n",
    "    sigma : known noise standard deviation.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    The scalar log N(y | X @ m0, Sigma).\n",
    "    \"\"\"\n",
    "    X = np.asarray(X, float)\n",
    "    y = np.asarray(y, float)\n",
    "    v0 = np.asarray(V0_diag, float)\n",
    "    n, p = X.shape\n",
    "    s2 = sigma**2\n",
    "\n",
    "    r = y - X @ np.asarray(m0, float)\n",
    "    Xtr = X.T @ r\n",
    "    # p x p capacitance matrix; written without inverting V0 so zero variances work.\n",
    "    M = np.eye(p) + (v0[:, None] * (X.T @ X)) / s2\n",
    "\n",
    "    # slogdet already returns log|det|, so no separate fallback through det() is needed.\n",
    "    _, logdet_M = np.linalg.slogdet(M)\n",
    "    logdet = n*np.log(s2) + logdet_M\n",
    "    quad = (r @ r)/s2 - (Xtr @ np.linalg.solve(M, v0*Xtr))/(s2**2)\n",
    "    return -0.5*(n*np.log(2*np.pi) + logdet + quad)\n",
    "\n",
    "\n",
    "\n",
    "def gibbs_mix_regression(Phi_list, y, K, m0_list, v0_list, alpha_dir, sigma_list,\n",
    "                         rng=None, n_burn=500, n_samp=200, thin=1):\n",
    "    \"\"\"Gibbs sampler for a K-component mixture of linear regressions with known noise sds.\n",
    "\n",
    "    Each sweep draws, in order: the coefficients of every component given the\n",
    "    current allocations (conjugate Gaussian update), the mixing weights from\n",
    "    their Dirichlet full conditional, and then every allocation z[i].\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    Phi_list : list of K design matrices, each with one row per observation.\n",
    "    y : array of n responses.\n",
    "    K : number of components.\n",
    "    m0_list, v0_list : per-component prior means and prior variances (diagonal).\n",
    "    alpha_dir : Dirichlet concentration parameters (length K).\n",
    "    sigma_list : per-component noise standard deviations.\n",
    "    rng : numpy Generator; a fresh unseeded one is created when None.\n",
    "    n_burn, n_samp, thin : burn-in sweeps, number of kept draws, thinning interval.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    List of ``n_samp`` tuples (betas, pi), where ``betas`` stacks the K coefficient\n",
    "    vectors into one array (so all components need the same number of columns)\n",
    "    and ``pi`` holds the K mixing weights.\n",
    "    \"\"\"\n",
    "    rng = default_rng() if rng is None else rng\n",
    "    n = len(y)\n",
    "    z = rng.integers(0, K, size=n)\n",
    "    V0_inv_list = [np.diag(1.0/np.asarray(v0)) for v0 in v0_list]\n",
    "    s2 = [sig**2 for sig in sigma_list]\n",
    "    sd = [np.sqrt(v) for v in s2]\n",
    "\n",
    "    # These initial draws are overwritten in the first sweep; they are kept so\n",
    "    # the random stream (and therefore seeded results) stays the same.\n",
    "    beta_list = [rng.multivariate_normal(m0_list[k], np.diag(v0_list[k])) for k in range(K)]\n",
    "    counts = np.bincount(z, minlength=K)\n",
    "    pi = rng.dirichlet(alpha_dir + counts)\n",
    "\n",
    "    samples = []\n",
    "    total = n_burn + n_samp*thin\n",
    "    for it in range(total):\n",
    "        # 1) Coefficients given allocations.\n",
    "        for k in range(K):\n",
    "            idx = np.where(z == k)[0]\n",
    "            if len(idx) == 0:\n",
    "                # Empty component: the full conditional is the prior.\n",
    "                V = np.linalg.inv(V0_inv_list[k])\n",
    "                m = V @ (V0_inv_list[k] @ m0_list[k])\n",
    "            else:\n",
    "                Phi_k = Phi_list[k][idx]; y_k = y[idx]\n",
    "                V_inv = V0_inv_list[k] + (Phi_k.T @ Phi_k) / s2[k]\n",
    "                V = np.linalg.inv(V_inv)\n",
    "                m = V @ (V0_inv_list[k] @ m0_list[k] + (Phi_k.T @ y_k) / s2[k])\n",
    "            beta_list[k] = rng.multivariate_normal(m, V)\n",
    "\n",
    "        # 2) Mixing weights given allocations.\n",
    "        counts = np.bincount(z, minlength=K)\n",
    "        pi = rng.dirichlet(alpha_dir + counts)\n",
    "\n",
    "        # 3) Allocations given coefficients and weights. The probabilities do not\n",
    "        #    depend on z, so all n*K log-densities are computed in one vectorised\n",
    "        #    pass instead of n*K scalar evaluations inside the sampling loop.\n",
    "        logps = np.empty((n, K))\n",
    "        for k in range(K):\n",
    "            logps[:, k] = np.log(pi[k]+1e-300) + log_norm_pdf(y, Phi_list[k] @ beta_list[k], sd[k])\n",
    "        probs = np.exp(logps - np.max(logps, axis=1, keepdims=True))\n",
    "        probs /= probs.sum(axis=1, keepdims=True)\n",
    "        # One rng.choice call per observation, as before, so the draw sequence is unchanged.\n",
    "        for i in range(n):\n",
    "            z[i] = rng.choice(K, p=probs[i])\n",
    "\n",
    "        if it >= n_burn and (it - n_burn) % thin == 0:\n",
    "            samples.append((np.array(beta_list, float), pi.copy()))\n",
    "    return samples\n",
    "\n",
    "\n",
    "\n",
    "def em_mix_regression(Phi_list, y, K, sigma_list, max_iter=200, tol=1e-6, seed=0):\n",
    "    \"\"\"Fit a K-component mixture of linear regressions by EM with known noise sds.\n",
    "\n",
    "    Starts from equal mixing weights and all-zero coefficients, alternates\n",
    "    responsibilities (E-step) with weighted least squares and weight updates\n",
    "    (M-step), and stops when the log-likelihood changes by less than ``tol``\n",
    "    or after ``max_iter`` iterations.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    Phi_list : list of K design matrices, each with one row per observation.\n",
    "    y : array of n responses.\n",
    "    K : number of components.\n",
    "    sigma_list : per-component noise standard deviations.\n",
    "    max_iter, tol : iteration cap and convergence tolerance on the log-likelihood.\n",
    "    seed : accepted for interface compatibility; the algorithm is deterministic\n",
    "        and does not use it.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    (ll, p): the final log-likelihood and the number of free parameters\n",
    "    (all regression coefficients plus K-1 mixing weights). With ``max_iter=0``\n",
    "    the log-likelihood at the starting values is returned.\n",
    "    \"\"\"\n",
    "    y = np.asarray(y, float)\n",
    "    n = len(y)\n",
    "    p_dims = [Phi.shape[1] for Phi in Phi_list]\n",
    "    sd = [np.sqrt(sig**2) for sig in sigma_list]\n",
    "\n",
    "    def weighted_log_dens(pi, beta_list):\n",
    "        # (n, K) matrix with entries log(pi_k) + log N(y_i | Phi_k[i] @ beta_k, sd_k^2).\n",
    "        return np.column_stack([\n",
    "            np.log(pi[k]+1e-300) + log_norm_pdf(y, Phi_list[k] @ beta_list[k], sd[k])\n",
    "            for k in range(K)\n",
    "        ])\n",
    "\n",
    "    def total_loglik(log_w):\n",
    "        # Row-wise log-sum-exp, summed over observations (no underflow for outliers).\n",
    "        peak = np.max(log_w, axis=1, keepdims=True)\n",
    "        return float(np.sum(peak[:, 0] + np.log(np.sum(np.exp(log_w - peak), axis=1))))\n",
    "\n",
    "    pi = np.ones(K)/K\n",
    "    beta_list = [np.zeros(p, float) for p in p_dims]\n",
    "    ll = total_loglik(weighted_log_dens(pi, beta_list))\n",
    "    logL_prev = -np.inf\n",
    "    for _ in range(max_iter):\n",
    "        # E-step: responsibilities r[i, k].\n",
    "        log_r = weighted_log_dens(pi, beta_list)\n",
    "        r = np.exp(log_r - np.max(log_r, axis=1, keepdims=True))\n",
    "        r /= r.sum(axis=1, keepdims=True)\n",
    "\n",
    "        # M-step: mixing weights (tiny offset keeps every weight positive).\n",
    "        Nk = r.sum(axis=0) + 1e-12\n",
    "        pi = Nk / n\n",
    "\n",
    "        # M-step: weighted least squares per component, without an n x n weight matrix.\n",
    "        for k in range(K):\n",
    "            Phi = Phi_list[k]\n",
    "            w = r[:, k]\n",
    "            A = Phi.T @ (w[:, None] * Phi)\n",
    "            b = Phi.T @ (w * y)\n",
    "            # Small ridge term guards against a singular normal matrix.\n",
    "            beta_list[k] = np.linalg.solve(A + 1e-8*np.eye(A.shape[0]), b)\n",
    "\n",
    "        ll = total_loglik(weighted_log_dens(pi, beta_list))\n",
    "        if np.abs(ll - logL_prev) < tol:\n",
    "            break\n",
    "        logL_prev = ll\n",
    "\n",
    "    p = sum(p_dims) + (K-1)\n",
    "    return ll, p\n",
    "\n",
    "\n",
    "\n",
    "def log_pred_dataset_linear(Y, X, betas, sigma):\n",
    "    \"\"\"Log-density of every observation under every coefficient draw of a linear model.\n",
    "\n",
    "    Returns an array of shape (number of draws, number of observations) whose\n",
    "    entry (t, i) is log N(Y[i] | X[i] @ betas[t], sigma^2).\n",
    "    \"\"\"\n",
    "    n_draws, n_obs = len(betas), len(Y)\n",
    "    log_dens = np.empty((n_draws, n_obs))\n",
    "    for t, beta in enumerate(betas):\n",
    "        log_dens[t] = log_norm_pdf(Y, X @ beta, sigma)\n",
    "    return log_dens\n",
    "\n",
    "def log_f_star_grid_linear(y_grid, x_vec, betas, sigma):\n",
    "    \"\"\"Log-density of each candidate response under each coefficient draw at one point.\n",
    "\n",
    "    ``x_vec`` is the feature vector of the point and ``betas`` an array with one\n",
    "    coefficient draw per row. Returns an array of shape (number of draws,\n",
    "    len(y_grid)) with entries log N(y_grid[j] | betas[t] @ x_vec, sigma^2).\n",
    "    \"\"\"\n",
    "    draw_means = betas @ x_vec\n",
    "    sq_dev = (y_grid[None,:]-draw_means[:,None])**2\n",
    "    log_norm_const = -(0.5*np.log(2*np.pi) + np.log(sigma))\n",
    "    return log_norm_const - 0.5*sq_dev/(sigma**2)\n",
    "\n",
    "def log_pred_dataset_mixture(Y, Phi_list, samp, sigma_list):\n",
    "    \"\"\"Log mixture density of every observation under every posterior draw.\n",
    "\n",
    "    ``samp`` is a sequence of (betas, pi) draws. Returns an array of shape\n",
    "    (number of draws, number of observations).\n",
    "    \"\"\"\n",
    "    n_draws, n_obs, n_comp = len(samp), len(Y), len(Phi_list)\n",
    "    log_dens = np.empty((n_draws, n_obs))\n",
    "    for t in range(n_draws):\n",
    "        betas_t, pi_t = samp[t]\n",
    "        # Row k holds log(pi_k) + log N(Y | Phi_k @ beta_k, sigma_k^2).\n",
    "        weighted = np.stack(\n",
    "            [np.log(pi_t[k]+1e-300) + log_norm_pdf(Y, Phi_list[k] @ betas_t[k], sigma_list[k])\n",
    "             for k in range(n_comp)],\n",
    "            axis=0,\n",
    "        )\n",
    "        # Log-sum-exp over components.\n",
    "        peak = np.max(weighted, axis=0)\n",
    "        log_dens[t] = peak + np.log(np.sum(np.exp(weighted - peak), axis=0))\n",
    "    return log_dens\n",
    "\n",
    "def log_f_star_grid_mixture(y_grid, x_phis, samp, sigma_list):\n",
    "    \"\"\"Log mixture density of each candidate response under each posterior draw at one point.\n",
    "\n",
    "    ``x_phis`` lists the feature vector of the point for every component and\n",
    "    ``samp`` is a sequence of (betas, pi) draws. Returns an array of shape\n",
    "    (number of draws, len(y_grid)).\n",
    "    \"\"\"\n",
    "    n_draws, n_grid, n_comp = len(samp), len(y_grid), len(x_phis)\n",
    "    log_dens = np.empty((n_draws, n_grid))\n",
    "    for t in range(n_draws):\n",
    "        betas_t, pi_t = samp[t]\n",
    "        # Row k holds log(pi_k) + log N(y_grid | x_phis[k] @ beta_k, sigma_k^2).\n",
    "        weighted = np.stack(\n",
    "            [np.log(pi_t[k]+1e-300) + log_norm_pdf(y_grid, x_phis[k] @ betas_t[k], sigma_list[k])\n",
    "             for k in range(n_comp)],\n",
    "            axis=0,\n",
    "        )\n",
    "        # Log-sum-exp over components.\n",
    "        peak = np.max(weighted, axis=0)\n",
    "        log_dens[t] = peak + np.log(np.sum(np.exp(weighted - peak), axis=0))\n",
    "    return log_dens\n",
    "\n",
    "\n",
    "\n",
    "def _cbma_conformal_pvalue(logf_dataset_list, logf_star_cols, log_model_weights, rng):\n",
    "    \"\"\"Randomized conformal p-value of one candidate response under the model average.\n",
    "\n",
    "    ``logf_dataset_list[k]`` is the (draws, n) log-density of the observations\n",
    "    under model k and ``logf_star_cols[k]`` the (draws,) log-density of the\n",
    "    candidate response under the same draws. Consumes exactly one uniform\n",
    "    variate from ``rng``.\n",
    "    \"\"\"\n",
    "    K = len(logf_dataset_list)\n",
    "    n = logf_dataset_list[0].shape[1]\n",
    "\n",
    "    log_s_i_k = []\n",
    "    log_s_star_k = []\n",
    "    log_p_y_k = []\n",
    "    for k in range(K):\n",
    "        a = logf_star_cols[k]\n",
    "        mA = np.max(a)\n",
    "        logZ = np.log(np.sum(np.exp(a - mA))) + mA\n",
    "        # Self-normalized importance weights of the draws, given the candidate.\n",
    "        log_w = a - logZ\n",
    "        M = log_w[:, None] + logf_dataset_list[k]\n",
    "        mcol = np.max(M, axis=0)\n",
    "        # Reweighted predictive log-density of every observation ...\n",
    "        log_s_i_k.append(mcol + np.log(np.sum(np.exp(M - mcol), axis=0)))\n",
    "        # ... and of the candidate itself.\n",
    "        log_s_star_k.append(mA + np.log(np.sum(np.exp(log_w + (a - mA)))))\n",
    "        # Log of the draw-averaged density of the candidate under model k.\n",
    "        log_p_y_k.append(logZ - np.log(len(a)))\n",
    "\n",
    "    # Model weights updated with the candidate and renormalized.\n",
    "    log_qk = log_model_weights + np.array(log_p_y_k)\n",
    "    q = np.exp(log_qk - np.max(log_qk))\n",
    "    q = q / q.sum()\n",
    "    log_q = np.log(q + 1e-300)\n",
    "\n",
    "    # Model-averaged scores: log-sum-exp over models.\n",
    "    log_s_i = np.logaddexp.reduce([log_q[k] + log_s_i_k[k] for k in range(K)], axis=0)\n",
    "    log_s_star = np.logaddexp.reduce(np.array([log_q[k] + log_s_star_k[k] for k in range(K)]))\n",
    "\n",
    "    lt = np.sum(log_s_i < log_s_star)\n",
    "    eq = np.sum(log_s_i == log_s_star)\n",
    "    u = rng.random()\n",
    "    return (1.0 + lt + u*eq) / (n + 1.0)\n",
    "\n",
    "\n",
    "def randomized_full_conformal_mask_CBMA(logf_dataset_list, logf_star_list, log_model_weights, alpha=0.2, rng=None):\n",
    "    \"\"\"Boolean mask over the response grid marking the CBMA conformal prediction set.\n",
    "\n",
    "    ``logf_star_list[k]`` is the (draws, grid) log-density of every grid value\n",
    "    under model k. Grid value j is kept when its conformal p-value exceeds\n",
    "    ``alpha``. A fresh unseeded generator is used when ``rng`` is None.\n",
    "    \"\"\"\n",
    "    rng = default_rng() if rng is None else rng\n",
    "    Ny = logf_star_list[0].shape[1]\n",
    "    mask = np.zeros(Ny, dtype=bool)\n",
    "    for j in range(Ny):\n",
    "        star_cols = [logf_star[:, j] for logf_star in logf_star_list]\n",
    "        pval = _cbma_conformal_pvalue(logf_dataset_list, star_cols, log_model_weights, rng)\n",
    "        mask[j] = (pval > alpha)\n",
    "    return mask\n",
    "\n",
    "def pvalue_at_y_true_CBMA(logf_dataset_list, logf_star_true_list, log_model_weights, alpha=0.2, rng=None):\n",
    "    \"\"\"CBMA conformal p-value of a single response value (e.g. the observed test response).\n",
    "\n",
    "    ``logf_star_true_list[k]`` is the (draws,) log-density of that value under\n",
    "    model k. ``alpha`` is accepted for signature symmetry and is not used.\n",
    "    \"\"\"\n",
    "    rng = default_rng() if rng is None else rng\n",
    "    return _cbma_conformal_pvalue(logf_dataset_list, logf_star_true_list, log_model_weights, rng)\n",
    "\n",
    "\n",
    "\n",
    "def run_experiment3_CBMA(\n",
    "    n_list=(100,300, 600, 1000),\n",
    "    E=10, alpha=0.2, grid_size=500, m_test=100,\n",
    "    T_samples=2000, burn_mix=500,\n",
    "    seed0=None\n",
    "):\n",
    "    \"\"\"Estimate coverage and set length of three CBMA ensembles on simulated data.\n",
    "\n",
    "    For every training size in ``n_list`` and each of ``E`` replications, a\n",
    "    training set is simulated, every candidate model is fitted, and conformal\n",
    "    prediction sets at level ``1 - alpha`` are evaluated on ``m_test`` test points.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    n_list : training-set sizes.\n",
    "    E : replications per training size.\n",
    "    alpha : miscoverage level.\n",
    "    grid_size : number of candidate response values on the grid.\n",
    "    m_test : test points per replication.\n",
    "    T_samples : posterior draws per model.\n",
    "    burn_mix : burn-in sweeps of the mixture Gibbs sampler.\n",
    "    seed0 : master seed. When None, a seed is drawn with ``random.randint`` at\n",
    "        call time (not once at definition time) and printed so the run can be\n",
    "        repeated.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    (results, y_grid, total_elapsed): a DataFrame with one row per (ensemble, n)\n",
    "    holding mean and standard error of coverage and set length, the response\n",
    "    grid, and the total wall-clock time in seconds.\n",
    "    \"\"\"\n",
    "    if seed0 is None:\n",
    "        seed0 = random.randint(1, 1000000)\n",
    "        print(f\"run_experiment3_CBMA: using seed0={seed0}\")\n",
    "\n",
    "    sigma_true = 0.5\n",
    "    sigma3 = [sigma_true]*3\n",
    "\n",
    "    P = draw_generator_params(seed0)\n",
    "\n",
    "    # Response grid spanning a large pilot sample, padded by one unit on each side.\n",
    "    pilot = sample_from_generator(3000, P, seed0+7)\n",
    "    y_grid = np.linspace(pilot[\"Y\"].min()-1.0, pilot[\"Y\"].max()+1.0, grid_size)\n",
    "    dy = y_grid[1] - y_grid[0]\n",
    "\n",
    "    (m1,v1),(m2,v2),(m3,v3) = true_prior_means_vars()\n",
    "    alpha3 = np.array([50.0,50.0,50.0])\n",
    "    mW3, vW3 = broad_prior(3, var=5.0)\n",
    "    mW2, vW2 = broad_prior(2, var=5.0)\n",
    "    m12 = np.concatenate([m1,m2])\n",
    "    v12 = np.concatenate([v1,v2])\n",
    "\n",
    "    # Single-regression candidates:\n",
    "    # name -> (design builder, point features, prior mean, prior variance, seed offset).\n",
    "    linear_specs = {\n",
    "        \"Wlin\":      (design_wrong_linear,      phi_wrong_linear,      mW3, vW3, 12),\n",
    "        \"Wquad\":     (design_wrong_quad,        phi_wrong_quad,        mW3, vW3, 13),\n",
    "        \"Sum12\":     (design_sum12,             phi_sum12,             m12, v12, 15),\n",
    "        \"C2X1\":      (design_wrong_c2x1,        phi_wrong_c2x1,        mW2, vW2, 16),\n",
    "        \"X1_X2sq\":   (design_wrong_x1_x2sq,     phi_wrong_x1_x2sq,     mW2, vW2, 17),\n",
    "        \"OneX1Cos3\": (design_wrong_1_x1_cos3x2, phi_wrong_1_x1_cos3x2, mW3, vW3, 18),\n",
    "    }\n",
    "    # Members of each model-averaging ensemble (\"True3\" is the three-component mixture).\n",
    "    ensembles = {\n",
    "        \"CBMA1\": [\"True3\", \"Wlin\"],\n",
    "        \"CBMA2\": [\"Sum12\", \"Wlin\", \"Wquad\"],\n",
    "        \"CBMA3\": [\"C2X1\", \"X1_X2sq\", \"OneX1Cos3\"],\n",
    "    }\n",
    "\n",
    "    def mean_and_se(values):\n",
    "        arr = np.array(values, float)\n",
    "        return arr.mean(), arr.std(ddof=1)/np.sqrt(len(arr))\n",
    "\n",
    "    records = []\n",
    "    t0 = time.time()\n",
    "\n",
    "    for n in n_list:\n",
    "        t_n = time.time()\n",
    "        stats = {key: {\"cov\": [], \"len\": []} for key in ensembles}\n",
    "\n",
    "        for rep in range(E):\n",
    "            seed_base = (n*97 + rep*313 + seed0) % (2**32 - 1)\n",
    "            D = sample_from_generator(n, P, seed_base)\n",
    "            X1, X2, Y = D[\"X1\"].to_numpy(), D[\"X2\"].to_numpy(), D[\"Y\"].to_numpy()\n",
    "\n",
    "            # --- three-component mixture: Gibbs draws, EM fit for a BIC-type weight ---\n",
    "            Phis = [design_phi1(X1,X2), design_phi2(X1,X2), design_phi3(X1,X2)]\n",
    "            samp_true3 = gibbs_mix_regression(\n",
    "                Phis, Y, K=3,\n",
    "                m0_list=[m1,m2,m3], v0_list=[v1,v2,v3],\n",
    "                alpha_dir=alpha3, sigma_list=sigma3,\n",
    "                rng=default_rng(seed_base+11),\n",
    "                n_burn=burn_mix, n_samp=T_samples, thin=1\n",
    "            )\n",
    "            ll_mix, p_mix = em_mix_regression(Phis, Y, 3, sigma3,\n",
    "                                              max_iter=200, tol=1e-6, seed=seed_base+21)\n",
    "\n",
    "            logf_data = {\"True3\": log_pred_dataset_mixture(Y, Phis, samp_true3, sigma3)}\n",
    "            log_marg = {\"True3\": ll_mix - 0.5*p_mix*np.log(n)}\n",
    "\n",
    "            # --- single-regression candidates: exact posterior draws and marginals ---\n",
    "            betas = {}\n",
    "            for name, (design, _phi, m0, v0, offset) in linear_specs.items():\n",
    "                X = design(X1, X2)\n",
    "                betas[name] = sample_posterior_linear(X, Y, m0, v0, sigma_true,\n",
    "                                                      T=T_samples, rng=default_rng(seed_base+offset))\n",
    "                logf_data[name] = log_pred_dataset_linear(Y, X, betas[name], sigma_true)\n",
    "                log_marg[name] = log_marginal_linear_conjugate(X, Y, m0, v0, sigma_true)\n",
    "\n",
    "            # Normalized log model weights per ensemble; they do not depend on the\n",
    "            # test point, so they are computed once per replication.\n",
    "            log_weights = {}\n",
    "            for key, names in ensembles.items():\n",
    "                lmw = np.array([log_marg[name] for name in names], float)\n",
    "                mW = np.max(lmw)\n",
    "                log_weights[key] = lmw - (mW + np.log(np.sum(np.exp(lmw - mW))))\n",
    "\n",
    "            Dtest = sample_from_generator(m_test, P, seed_base+3)\n",
    "            X1t, X2t, Yt = Dtest[\"X1\"].to_numpy(), Dtest[\"X2\"].to_numpy(), Dtest[\"Y\"].to_numpy()\n",
    "\n",
    "            for j in range(m_test):\n",
    "                x1s, x2s, y_true = X1t[j], X2t[j], Yt[j]\n",
    "                y_point = np.array([y_true])\n",
    "\n",
    "                # Log-densities on the grid and at the observed test response, per model.\n",
    "                x_phis = [phi1(x1s,x2s), phi2(x1s,x2s), phi3(x1s,x2s)]\n",
    "                star_grid = {\"True3\": log_f_star_grid_mixture(y_grid, x_phis, samp_true3, sigma3)}\n",
    "                star_true = {\"True3\": log_f_star_grid_mixture(y_point, x_phis, samp_true3, sigma3)[:,0]}\n",
    "                for name, (_design, phi, _m0, _v0, _offset) in linear_specs.items():\n",
    "                    x_vec = phi(x1s, x2s)\n",
    "                    star_grid[name] = log_f_star_grid_linear(y_grid, x_vec, betas[name], sigma_true)\n",
    "                    star_true[name] = log_f_star_grid_linear(y_point, x_vec, betas[name], sigma_true)[:,0]\n",
    "\n",
    "                for key, names in ensembles.items():\n",
    "                    logf_dataset_list = [logf_data[name] for name in names]\n",
    "                    logf_star_list = [star_grid[name] for name in names]\n",
    "                    logf_star_true_list = [star_true[name] for name in names]\n",
    "\n",
    "                    mask = randomized_full_conformal_mask_CBMA(logf_dataset_list, logf_star_list,\n",
    "                                                               log_weights[key], alpha=alpha,\n",
    "                                                               rng=default_rng(seed_base+100+j))\n",
    "                    stats[key][\"len\"].append(mask.sum()*dy)\n",
    "\n",
    "                    pval = pvalue_at_y_true_CBMA(logf_dataset_list, logf_star_true_list,\n",
    "                                                 log_weights[key], alpha=alpha,\n",
    "                                                 rng=default_rng(seed_base+200+j))\n",
    "                    stats[key][\"cov\"].append(pval > alpha)\n",
    "\n",
    "        for key in ensembles:\n",
    "            c, se = mean_and_se(stats[key][\"cov\"])\n",
    "            l, sel = mean_and_se(stats[key][\"len\"])\n",
    "            records.append({\"model\": key, \"n\": n, \"coverage_mean\": c, \"coverage_se\": se,\n",
    "                            \"length_mean\": l, \"length_se\": sel,\n",
    "                            \"elapsed_sec_this_n\": round(time.time()-t_n,2)})\n",
    "\n",
    "    total_elapsed = time.time()-t0\n",
    "    return pd.DataFrame(records), y_grid, total_elapsed\n",
    "\n",
    "\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "    # Draw the master seed up front and report it, so this run can be repeated exactly.\n",
    "    seed0 = random.randint(1,1000000)\n",
    "    print(f\"seed0 = {seed0}\")\n",
    "\n",
    "    alpha = 0.2\n",
    "    res, y_grid, total_time = run_experiment3_CBMA(\n",
    "        n_list=(100,300, 600, 1000),\n",
    "        E=10, alpha=alpha, grid_size=500, m_test=100,\n",
    "        T_samples=2000, burn_mix=500,\n",
    "        seed0=seed0\n",
    "    )\n",
    "\n",
    "    # One figure per metric: (mean column, s.e. column, title, y label, reference level or None).\n",
    "    panels = [\n",
    "        (\"coverage_mean\", \"coverage_se\", \"CBMA on Simulation Data Average Coverage vs n\",\n",
    "         \"Coverage Rate\", 1.0 - alpha),\n",
    "        (\"length_mean\", \"length_se\", \"CBMA on Simulation Data Average Set Length vs n\",\n",
    "         \"Average Set Length\", None),\n",
    "    ]\n",
    "    for mean_col, se_col, title, ylabel, ref_level in panels:\n",
    "        plt.figure()\n",
    "        for model in [\"CBMA1\",\"CBMA2\",\"CBMA3\"]:\n",
    "            dfm = res[res[\"model\"]==model].sort_values(\"n\")\n",
    "            plt.errorbar(dfm[\"n\"], dfm[mean_col], yerr=dfm[se_col],\n",
    "                         marker=\"o\", capsize=3, label=model)\n",
    "        if ref_level is not None:\n",
    "            # Nominal coverage level.\n",
    "            plt.axhline(ref_level, linestyle=\"--\")\n",
    "        plt.title(title)\n",
    "        plt.xlabel(\"n\"); plt.ylabel(ylabel); plt.legend(); plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b9de5a52-af08-453e-95ba-7e9221b15e44",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
