{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "23a982f1-f606-4aa4-b6ab-8849547c7f6d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ============================================================\n",
    "# Our method: CNB on California Housing — Full Conformal case\n",
    "# Conditional DP mixture with joint densities\n",
    "# ============================================================\n",
    "\n",
    "import time\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "from numpy.random import default_rng\n",
    "from scipy.special import logsumexp\n",
    "import random\n",
    "\n",
    "\n",
    "\n",
    "def zscore(a, axis=0, eps=1e-12):\n",
    "    m = np.mean(a, axis=axis, keepdims=True)\n",
    "    s = np.std(a, axis=axis, keepdims=True)\n",
    "    return (a - m) / (s + eps), m, s\n",
    "\n",
    "def normal_logpdf(y, mean, var, eps=1e-32):\n",
    "    var = np.maximum(var, eps)\n",
    "    return -0.5*(np.log(2*np.pi*var) + (y-mean)**2/var)\n",
    "\n",
    "def invgamma_sample(rng, a, b, size=None):\n",
    "    #\n",
    "    g = rng.gamma(shape=a, scale=1.0/np.maximum(b, 1e-30), size=size)\n",
    "    return 1.0/np.maximum(g, 1e-30)\n",
    "\n",
    "def add_intercept(X):\n",
    "    return np.concatenate([np.ones((X.shape[0], 1)), X], axis=1)\n",
    "\n",
    "def effective_K(pi, thr=1e-3):\n",
    "    return int(np.sum(pi > thr))\n",
    "\n",
    "# data load\n",
    "\n",
    "def load_california_standardized():\n",
    "    from sklearn.datasets import fetch_california_housing\n",
    "    X, y = fetch_california_housing(return_X_y=True)\n",
    "    Xz, Xm, Xs = zscore(X, axis=0)\n",
    "    yz, ym, ys = zscore(y, axis=0)\n",
    "    return Xz, yz, (Xm, Xs, ym, ys)\n",
    "\n",
    "def sample_train_test(X, y, n_train, m_test, seed):\n",
    "    rng = default_rng(seed)\n",
    "    N = X.shape[0]\n",
    "    idx = rng.choice(N, size=n_train + m_test, replace=False)\n",
    "    tr_idx = idx[:n_train]\n",
    "    te_idx = idx[n_train:]\n",
    "    return X[tr_idx], y[tr_idx], X[te_idx], y[te_idx]\n",
    "\n",
    "\n",
    "\n",
    "class DPLRMixGibbs:\n",
    "    \n",
    "    def __init__(self, K_cap=20, alpha=1.0,\n",
    "                 a0=2.0, b0=2.0,              \n",
    "                 m0=None, V0=None,            \n",
    "                 n_iter=800, burn=300, thin=5,\n",
    "                 random_state=0):\n",
    "        self.K_cap = int(K_cap)\n",
    "        self.alpha = float(alpha)\n",
    "        self.a0 = float(a0)\n",
    "        self.b0 = float(b0)\n",
    "        self.m0 = m0     \n",
    "        self.V0 = V0     \n",
    "        self.n_iter = int(n_iter)\n",
    "        self.burn = int(burn)\n",
    "        self.thin = int(thin)\n",
    "        self.rng = default_rng(random_state)\n",
    "        self.samples_ = []\n",
    "        self.elapsed_sec_ = 0.0\n",
    "\n",
    "    def _init_from_data(self, Xtil, y):\n",
    "        n, p = Xtil.shape\n",
    "        if self.m0 is None:\n",
    "            self.m0 = np.zeros(p)\n",
    "        if self.V0 is None:\n",
    "            self.V0 = np.eye(p) * 10.0      \n",
    "        # \n",
    "        z = self.rng.integers(0, self.K_cap, size=n, dtype=np.intp)\n",
    "       \n",
    "        betas = self.rng.normal(0, 1, size=(self.K_cap, p))\n",
    "        sig2  = np.ones(self.K_cap)\n",
    "        \n",
    "        v = np.clip(self.rng.beta(1.0, self.alpha, size=self.K_cap-1), 1e-6, 1-1e-6)\n",
    "        v = np.concatenate([v, [1.0]])\n",
    "        pi = self._stick_to_weights(v)\n",
    "        return z, betas, sig2, v, pi\n",
    "\n",
    "    @staticmethod\n",
    "    def _stick_to_weights(v):\n",
    "        K = v.shape[0]\n",
    "        pi = np.empty(K)\n",
    "        prod = 1.0\n",
    "        for k in range(K):\n",
    "            pi[k] = v[k]*prod\n",
    "            if k < K-1:\n",
    "                prod *= (1.0 - v[k])\n",
    "        pi = np.clip(pi, 1e-300, 1.0)\n",
    "        return pi / pi.sum()\n",
    "\n",
    "    def _post_beta_sigma2(self, Xk, yk):\n",
    "        # do posterior\n",
    "        p = self.m0.shape[0]\n",
    "        a_n = self.a0 + 0.5*len(yk)\n",
    "        if len(yk) > 0:\n",
    "            XtX = Xk.T @ Xk\n",
    "            Vn_inv = np.linalg.inv(self.V0) + XtX\n",
    "            Vn = np.linalg.inv(Vn_inv)\n",
    "            mn = Vn @ (np.linalg.inv(self.V0) @ self.m0 + Xk.T @ yk)\n",
    "            resid_term = yk @ yk + self.m0 @ np.linalg.inv(self.V0) @ self.m0 - mn @ Vn_inv @ mn\n",
    "            b_n = self.b0 + 0.5*resid_term\n",
    "        else:\n",
    "            Vn = self.V0.copy()\n",
    "            mn = self.m0.copy()\n",
    "            b_n = self.b0\n",
    "        sig2 = invgamma_sample(self.rng, a_n, b_n)\n",
    "        beta = self.rng.multivariate_normal(mn, sig2 * Vn)\n",
    "        return beta, sig2\n",
    "\n",
    "    def _sample_params(self, Xtil, y, z):\n",
    "        K = self.K_cap\n",
    "        p = Xtil.shape[1]\n",
    "        betas = np.zeros((K, p))\n",
    "        sig2  = np.zeros(K)\n",
    "        for k in range(K):\n",
    "            Ik = (z == k)\n",
    "            Xk = Xtil[Ik]\n",
    "            yk = y[Ik]\n",
    "            beta_k, sig2_k = self._post_beta_sigma2(Xk, yk)\n",
    "            betas[k], sig2[k] = beta_k, sig2_k\n",
    "        return betas, sig2\n",
    "\n",
    "    def _sample_sticks(self, z):\n",
    "        K = self.K_cap\n",
    "        counts = np.bincount(z, minlength=K)\n",
    "        tail = np.flip(np.cumsum(np.flip(counts)))[1:]\n",
    "        v = np.empty(K)\n",
    "        for k in range(K-1):\n",
    "            a = 1.0 + counts[k]\n",
    "            b = self.alpha + tail[k]\n",
    "            v[k] = self.rng.beta(a, b)\n",
    "            v[k] = np.clip(v[k], 1e-6, 1-1e-6)\n",
    "        v[K-1] = 1.0\n",
    "        pi = self._stick_to_weights(v)\n",
    "        return v, pi\n",
    "\n",
    "    def _sample_allocs(self, Xtil, y, betas, sig2, pi):\n",
    "        n = Xtil.shape[0]\n",
    "        K = self.K_cap\n",
    "        \n",
    "        means = Xtil @ betas.T                 \n",
    "        loglik = np.log(pi + 1e-300)[None, :] + normal_logpdf(y[:,None], means, sig2[None,:])\n",
    "        loglik -= logsumexp(loglik, axis=1, keepdims=True)\n",
    "        P = np.exp(loglik)\n",
    "        cum = np.cumsum(P, axis=1)\n",
    "        u = self.rng.random(n)[:, None]\n",
    "        z = (cum < u).sum(axis=1)\n",
    "        return z.astype(np.intp, copy=False)\n",
    "\n",
    "    def fit(self, X, y):\n",
    "        t0 = time.time()\n",
    "        Xtil = add_intercept(X)\n",
    "        z, betas, sig2, v, pi = self._init_from_data(Xtil, y)\n",
    "        self.samples_.clear()\n",
    "\n",
    "        for it in range(self.n_iter):\n",
    "            betas, sig2 = self._sample_params(Xtil, y, z)\n",
    "            v, pi = self._sample_sticks(z)\n",
    "            z = self._sample_allocs(Xtil, y, betas, sig2, pi)\n",
    "\n",
    "            if it >= self.burn and ((it - self.burn) % self.thin == 0):\n",
    "                self.samples_.append({\"betas\": betas.copy(), \"sig2\": sig2.copy(), \"pi\": pi.copy()})\n",
    "\n",
    "        self.elapsed_sec_ = time.time() - t0\n",
    "        return self\n",
    "\n",
    "\n",
    "\n",
    "def log_f_y_given_x_draw(draw, X, y):\n",
    "    \n",
    "    Xtil = add_intercept(X)\n",
    "    betas = draw[\"betas\"]   \n",
    "    sig2  = draw[\"sig2\"]    \n",
    "    pi    = draw[\"pi\"]      \n",
    "    means = Xtil @ betas.T                    \n",
    "    logcomp = np.log(pi + 1e-300)[None,:] + normal_logpdf(y[:,None], means, sig2[None,:])\n",
    "    return logsumexp(logcomp, axis=1)        \n",
    "\n",
    "def log_f_ygrid_given_x_draw(draw, x_star, y_grid):\n",
    "    \n",
    "    x_aug = np.concatenate([[1.0], x_star])  \n",
    "    betas = draw[\"betas\"]\n",
    "    sig2  = draw[\"sig2\"]\n",
    "    pi    = draw[\"pi\"]\n",
    "    means = betas @ x_aug                     \n",
    "    logcomp = np.log(pi + 1e-300)[:,None] + normal_logpdf(y_grid[None,:], means[:,None], sig2[:,None])\n",
    "    return logsumexp(logcomp, axis=0)        \n",
    "\n",
    "# AOI \n",
    "\n",
    "def aoi_scores_from_logs_conditional(log_f_data_cond, log_f_star_cond, tau=1.0):\n",
    "    \n",
    "    if tau <= 0: raise ValueError(\"tau must be > 0\")\n",
    "    T, n = log_f_data_cond.shape\n",
    "    T2, Ny = log_f_star_cond.shape\n",
    "    assert T == T2\n",
    "\n",
    "    \n",
    "    logW = (log_f_star_cond / tau) - logsumexp(log_f_star_cond / tau, axis=0, keepdims=True)  \n",
    "\n",
    "    \n",
    "    s_train_log = np.empty((n, Ny))\n",
    "    for j in range(Ny):\n",
    "        s_train_log[:, j] = logsumexp(log_f_data_cond + logW[:, j][:,None], axis=0)\n",
    "\n",
    "    \n",
    "    s_star_log = logsumexp(logW + log_f_star_cond, axis=0)\n",
    "    return s_train_log, s_star_log\n",
    "\n",
    "def randomized_full_conformal_mask_from_LOG(s_train_log, s_star_log, alpha_conf, rng):\n",
    "    n, Ny = s_train_log.shape\n",
    "    mask = np.zeros(Ny, dtype=bool)\n",
    "    for j in range(Ny):\n",
    "        sj = s_train_log[:, j]\n",
    "        sst = s_star_log[j]\n",
    "        lt = np.sum(sj <  sst)\n",
    "        eq = np.sum(sj == sst)\n",
    "        u  = rng.uniform()\n",
    "        pval = (1.0 + lt + u*eq) / (n + 1.0)\n",
    "        mask[j] = (pval > alpha_conf)\n",
    "    return mask\n",
    "\n",
    "\n",
    "\n",
    "def run_experiment_california_conditional(\n",
    "    n_list=(300, 400, 500),\n",
    "    E=3,\n",
    "    alpha=0.2,\n",
    "    grid_size=400,\n",
    "    m_test=100,\n",
    "    K_cap=25,\n",
    "    alpha_dp=1.0,\n",
    "    a0=2.0, b0=2.0,            \n",
    "    n_iter=1200, burn=600, thin=3,\n",
    "    tau=1.0,\n",
    "    random_seed=2025\n",
    "):\n",
    "    rng_master = default_rng(random_seed)\n",
    "    X, y, _sc = load_california_standardized()\n",
    "    records = []\n",
    "    total_t0 = time.time()\n",
    "\n",
    "    for n in n_list:\n",
    "        cov_list, len_list = [], []\n",
    "        keff_list = []; fit_t = []; conf_t = []\n",
    "        for rep in range(E):\n",
    "            seed_rep = int(rng_master.integers(0, 2**31-1))\n",
    "            Xtr, ytr, Xte, yte = sample_train_test(X, y, n_train=n, m_test=m_test, seed=seed_rep)\n",
    "\n",
    "            \n",
    "            y_grid = np.linspace(ytr.min()-1.0, ytr.max()+1.0, grid_size)\n",
    "            dy = y_grid[1]-y_grid[0]\n",
    "\n",
    "            \n",
    "            model = DPLRMixGibbs(\n",
    "                K_cap=K_cap, alpha=alpha_dp,\n",
    "                a0=a0, b0=b0, m0=None, V0=None,\n",
    "                n_iter=n_iter, burn=burn, thin=thin,\n",
    "                random_state=seed_rep\n",
    "            )\n",
    "            t0 = time.time()\n",
    "            model.fit(Xtr, ytr)\n",
    "            fit_t.append(time.time()-t0)\n",
    "\n",
    "           \n",
    "            Tsave = len(model.samples_)\n",
    "            log_f_data = np.empty((Tsave, n))\n",
    "            for t_idx, draw in enumerate(model.samples_):\n",
    "                log_f_data[t_idx] = log_f_y_given_x_draw(draw, Xtr, ytr)\n",
    "\n",
    "            \n",
    "            t1 = time.time()\n",
    "            rng_rep = default_rng(seed_rep + 123)\n",
    "            cov_flags = []; len_vals = []\n",
    "            for j in range(m_test):\n",
    "                x_star = Xte[j]; y_true = yte[j]\n",
    "                log_f_star = np.empty((Tsave, grid_size))\n",
    "                for t_idx, draw in enumerate(model.samples_):\n",
    "                    log_f_star[t_idx] = log_f_ygrid_given_x_draw(draw, x_star, y_grid)\n",
    "\n",
    "                s_train_log, s_star_log = aoi_scores_from_logs_conditional(log_f_data, log_f_star, tau=tau)\n",
    "                mask = randomized_full_conformal_mask_from_LOG(s_train_log, s_star_log, alpha, rng_rep)\n",
    "\n",
    "                \n",
    "                idx_true = int(np.clip(np.searchsorted(y_grid, y_true), 0, grid_size-1))\n",
    "                cov_flags.append(bool(mask[idx_true]))\n",
    "                len_vals.append(mask.sum() * dy)\n",
    "\n",
    "            conf_t.append(time.time()-t1)\n",
    "\n",
    "            cov_list.append(np.mean(cov_flags))\n",
    "            len_list.append(np.mean(len_vals))\n",
    "            keff_list.append(np.mean([effective_K(d[\"pi\"]) for d in model.samples_]))\n",
    "\n",
    "        rec = {\n",
    "            \"n\": n,\n",
    "            \"coverage_mean\": float(np.mean(cov_list)),\n",
    "            \"coverage_se\": float(np.std(cov_list, ddof=1)/np.sqrt(E)),\n",
    "            \"length_mean\": float(np.mean(len_list)),\n",
    "            \"length_se\": float(np.std(len_list, ddof=1)/np.sqrt(E)),\n",
    "            \"K\": float(np.mean(keff_list)),\n",
    "            \"fit_mean\": float(np.mean(fit_t)),\n",
    "            \"conf_mean\": float(np.mean(conf_t)),\n",
    "            \"elapsed_n\": round(time.time()-total_t0, 2)\n",
    "        }\n",
    "        records.append(rec)\n",
    "        print(f\"[n={n}] cov={rec['coverage_mean']:.3f}±{rec['coverage_se']:.3f} | \"\n",
    "              f\"len={rec['length_mean']:.3f}±{rec['length_se']:.3f} | \"\n",
    "              f\"K_eff~{rec['K']:.1f} | fit~{rec['fit_mean']:.1f}s | conf~{rec['conf_mean']:.1f}s\")\n",
    "\n",
    "    res = pd.DataFrame(records).sort_values(\"n\").reset_index(drop=True)\n",
    "\n",
    "    # Plots\n",
    "    plt.figure()\n",
    "    plt.errorbar(res[\"n\"], res[\"coverage_mean\"], yerr=res[\"coverage_se\"], marker=\"o\", capsize=4)\n",
    "    plt.axhline(0.8, linestyle=\"--\", alpha=0.7)\n",
    "    plt.title(\"CNB on California Data Average Coverage vs n\")\n",
    "    plt.xlabel(\"n\"); plt.ylabel(\"Coverage Rate\")\n",
    "    plt.tight_layout(); plt.show()\n",
    "\n",
    "    plt.figure()\n",
    "    plt.errorbar(res[\"n\"], res[\"length_mean\"], yerr=res[\"length_se\"], marker=\"o\", capsize=4)\n",
    "    plt.title(\"CNB on California Data Average Set Length vs n\")\n",
    "    plt.xlabel(\"n\"); plt.ylabel(\"Average Set Length\")\n",
    "    plt.tight_layout(); plt.show()\n",
    "\n",
    "    return res\n",
    "\n",
    "\n",
    "#run it\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "    _ = run_experiment_california_conditional(\n",
    "        n_list=(100,300, 600, 1000),\n",
    "        E=10,\n",
    "        alpha=0.2,\n",
    "        grid_size=600,\n",
    "        m_test=100,\n",
    "        K_cap=25,\n",
    "        alpha_dp=1.0,\n",
    "        a0=2.0, b0=2.0,\n",
    "        n_iter=2000, burn=600, thin=3,\n",
    "        tau=1.0,\n",
    "        random_seed=random.randint(1,1000000)\n",
    "    )\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
