{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bd804a71-2745-484b-8226-13a2232b5a65",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ============================================================\n",
    "# Our method: CNB on simulation data — Full Conformal case\n",
    "# Conditional DP mixture with joint densities\n",
    "# ============================================================\n",
    "\n",
    "import time\n",
    "from dataclasses import dataclass\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "from numpy.random import default_rng\n",
    "from scipy.special import logsumexp\n",
    "import random\n",
    "\n",
    "\n",
    "\n",
    "@dataclass\n",
    "class GenParams:\n",
    "    a0: float; a1: float; a2: float\n",
    "    b0: float; b1: float; b2: float\n",
    "    c0: float; c1: float; c2: float\n",
    "    sigmas: np.ndarray  \n",
    "    weights: np.ndarray \n",
    "\n",
    "def draw_generator_params(seed: int = 2025) -> GenParams:\n",
    "    \"\"\"Draw random ground-truth coefficients for the simulation generator.\n",
    "\n",
    "    The nine coefficients are drawn uniformly, in the order a0, a1, a2, b0, b1,\n",
    "    b2, c0, c1, c2, each from its own fixed interval. Every component gets noise\n",
    "    standard deviation 0.5 and mixing weight 1/3.\n",
    "    \"\"\"\n",
    "    gen = default_rng(seed)\n",
    "    intervals = [\n",
    "        (-4.0, -2.0), (0.8, 1.5), (0.5, 1.0),   # a0, a1, a2\n",
    "        (-0.5, 0.5), (0.5, 1.2), (0.4, 1.0),    # b0, b1, b2\n",
    "        (2.0, 4.0), (-1.0, -0.3), (0.8, 1.5),   # c0, c1, c2\n",
    "    ]\n",
    "    # drawn sequentially so the RNG stream is consumed in coefficient order\n",
    "    coefs = [gen.uniform(lo, hi) for lo, hi in intervals]\n",
    "    return GenParams(\n",
    "        *coefs,\n",
    "        sigmas=np.full(3, 0.5),\n",
    "        weights=np.full(3, 1.0 / 3.0),\n",
    "    )\n",
    "\n",
    "def sample_from_generator(n: int, P: GenParams, seed: int) -> pd.DataFrame:\n",
    "    \"\"\"Simulate n rows (X1, X2, Y) from the three-component mixture described by P.\n",
    "\n",
    "    X1 ~ U(-1, 2) and X2 ~ N(1, 0.5^2); each row picks a component with probabilities\n",
    "    P.weights, and Y is Gaussian around that component's mean with std P.sigmas[k].\n",
    "    \"\"\"\n",
    "    gen = default_rng(seed)\n",
    "    x1 = gen.uniform(-1, 2, size=n)\n",
    "    x2 = gen.normal(1.0, 0.5, size=n)\n",
    "    labels = gen.choice(3, size=n, p=P.weights)\n",
    "    # column k holds component k's mean for every row\n",
    "    component_means = np.column_stack((\n",
    "        P.a0 + P.a1*x1 + P.a2*(x2**2),\n",
    "        P.b0 + P.b1*(x1**2) + P.b2*(x1*x2),\n",
    "        P.c0 + P.c1*(x1**2) + P.c2*np.cos(2*x2),\n",
    "    ))\n",
    "    picked = component_means[np.arange(n), labels]\n",
    "    y = gen.normal(loc=picked, scale=P.sigmas[labels])\n",
    "    return pd.DataFrame({\"X1\": x1, \"X2\": x2, \"Y\": y})\n",
    "\n",
    "\n",
    "\n",
    "def normal_logpdf(y, mean, var, eps=1e-32):\n",
    "    var = np.maximum(var, eps)\n",
    "    return -0.5*(np.log(2*np.pi*var) + (y-mean)**2/var)\n",
    "\n",
    "def invgamma_sample(rng, a, b, size=None):\n",
    "    g = rng.gamma(shape=a, scale=1.0/np.maximum(b, 1e-30), size=size)\n",
    "    return 1.0/np.maximum(g, 1e-30)\n",
    "\n",
    "def add_intercept(X):\n",
    "    return np.concatenate([np.ones((X.shape[0], 1)), X], axis=1)\n",
    "\n",
    "def effective_K(pi, thr=1e-3):\n",
    "    return int(np.sum(pi > thr))\n",
    "\n",
    "\n",
    "\n",
    "class DPLRMixGibbs:\n",
    "    \"\"\"Truncated Dirichlet-process mixture of Gaussian linear regressions, fit by blocked Gibbs.\n",
    "\n",
    "    Model, with the stick-breaking prior truncated at ``K_cap`` components:\n",
    "        y_i | z_i = k  ~  N(x_i' beta_k, sig2_k)      (x_i includes an intercept)\n",
    "        beta_k | sig2_k ~ N(m0, sig2_k * V0),   sig2_k ~ InvGamma(a0, b0)\n",
    "        v_k ~ Beta(1, alpha),  pi = stick-breaking(v)\n",
    "\n",
    "    After ``fit``, ``samples_`` holds the retained draws (every ``thin``-th\n",
    "    iteration starting at ``burn``), each a dict with keys \"betas\" (K_cap, p+1),\n",
    "    \"sig2\" (K_cap,) and \"pi\" (K_cap,); ``elapsed_sec_`` is the fit wall time.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, K_cap=20, alpha=1.0,\n",
    "                 a0=2.0, b0=2.0,\n",
    "                 m0=None, V0=None,\n",
    "                 n_iter=800, burn=300, thin=5,\n",
    "                 random_state=0):\n",
    "        self.K_cap = int(K_cap)\n",
    "        self.alpha = float(alpha)    # DP concentration\n",
    "        self.a0 = float(a0)          # inverse-gamma shape prior on sig2\n",
    "        self.b0 = float(b0)          # inverse-gamma scale prior on sig2\n",
    "        self.m0 = m0                 # prior mean of beta; zeros if None at fit time\n",
    "        self.V0 = V0                 # prior covariance factor of beta; 10*I if None at fit time\n",
    "        self.n_iter = int(n_iter)\n",
    "        self.burn = int(burn)\n",
    "        self.thin = int(thin)\n",
    "        self.rng = default_rng(random_state)\n",
    "        self.samples_ = []\n",
    "        self.elapsed_sec_ = 0.0\n",
    "\n",
    "    def _init_from_data(self, Xtil, y):\n",
    "        \"\"\"Fill in default priors if unset and draw a random initial Gibbs state.\n",
    "\n",
    "        Returns (z, betas, sig2, v, pi). ``fit`` only consumes ``z``; the other\n",
    "        values are overwritten in the first sweep, but drawing them keeps the\n",
    "        RNG stream unchanged.\n",
    "        \"\"\"\n",
    "        n, p = Xtil.shape\n",
    "        if self.m0 is None:\n",
    "            self.m0 = np.zeros(p)\n",
    "        if self.V0 is None:\n",
    "            self.V0 = np.eye(p) * 10.0\n",
    "        z = self.rng.integers(0, self.K_cap, size=n, dtype=np.intp)\n",
    "        betas = self.rng.normal(0, 1, size=(self.K_cap, p))\n",
    "        sig2  = np.ones(self.K_cap)\n",
    "        v = np.clip(self.rng.beta(1.0, self.alpha, size=self.K_cap-1), 1e-6, 1-1e-6)\n",
    "        v = np.concatenate([v, [1.0]])\n",
    "        pi = self._stick_to_weights(v)\n",
    "        return z, betas, sig2, v, pi\n",
    "\n",
    "    @staticmethod\n",
    "    def _stick_to_weights(v):\n",
    "        \"\"\"Map stick proportions v to weights pi_k = v_k * prod_{j<k} (1 - v_j).\n",
    "\n",
    "        Weights are clipped to at least 1e-300 and renormalised to sum to 1.\n",
    "        \"\"\"\n",
    "        K = v.shape[0]\n",
    "        pi = np.empty(K)\n",
    "        prod = 1.0\n",
    "        for k in range(K):\n",
    "            pi[k] = v[k]*prod\n",
    "            if k < K-1:\n",
    "                prod *= (1.0 - v[k])\n",
    "        pi = np.clip(pi, 1e-300, 1.0)\n",
    "        return pi / pi.sum()\n",
    "\n",
    "    def _post_beta_sigma2(self, Xk, yk):\n",
    "        \"\"\"Draw (beta, sig2) from the Normal-Inverse-Gamma posterior of one component.\n",
    "\n",
    "        ``Xk``/``yk`` are the design rows and responses currently allocated to the\n",
    "        component; an empty component draws from the prior.\n",
    "        \"\"\"\n",
    "        a_n = self.a0 + 0.5*len(yk)\n",
    "        if len(yk) > 0:\n",
    "            V0_inv = np.linalg.inv(self.V0)  # invert the prior once per call\n",
    "            Vn_inv = V0_inv + Xk.T @ Xk\n",
    "            Vn = np.linalg.inv(Vn_inv)\n",
    "            Vn = 0.5*(Vn + Vn.T)  # remove round-off asymmetry before sampling\n",
    "            mn = Vn @ (V0_inv @ self.m0 + Xk.T @ yk)\n",
    "            resid_term = yk @ yk + self.m0 @ V0_inv @ self.m0 - mn @ Vn_inv @ mn\n",
    "            # non-negative in exact arithmetic; guard against round-off\n",
    "            b_n = self.b0 + 0.5*max(resid_term, 0.0)\n",
    "        else:\n",
    "            Vn = self.V0.copy()\n",
    "            mn = self.m0.copy()\n",
    "            b_n = self.b0\n",
    "        sig2 = invgamma_sample(self.rng, a_n, b_n)\n",
    "        beta = self.rng.multivariate_normal(mn, sig2 * Vn)\n",
    "        return beta, sig2\n",
    "\n",
    "    def _sample_params(self, Xtil, y, z):\n",
    "        \"\"\"Draw (beta_k, sig2_k) for every component k given allocations z.\"\"\"\n",
    "        K = self.K_cap\n",
    "        p = Xtil.shape[1]\n",
    "        betas = np.zeros((K, p))\n",
    "        sig2  = np.zeros(K)\n",
    "        for k in range(K):\n",
    "            Ik = (z == k)\n",
    "            beta_k, sig2_k = self._post_beta_sigma2(Xtil[Ik], y[Ik])\n",
    "            betas[k], sig2[k] = beta_k, sig2_k\n",
    "        return betas, sig2\n",
    "\n",
    "    def _sample_sticks(self, z):\n",
    "        \"\"\"Draw v_k ~ Beta(1 + n_k, alpha + sum_{j>k} n_j) for k < K-1 (v_{K-1} = 1).\"\"\"\n",
    "        K = self.K_cap\n",
    "        counts = np.bincount(z, minlength=K)\n",
    "        # tail[k] = number of points allocated to components after k\n",
    "        tail = np.flip(np.cumsum(np.flip(counts)))[1:]\n",
    "        v = np.empty(K)\n",
    "        for k in range(K-1):\n",
    "            a = 1.0 + counts[k]\n",
    "            b = self.alpha + tail[k]\n",
    "            v[k] = self.rng.beta(a, b)\n",
    "            v[k] = np.clip(v[k], 1e-6, 1-1e-6)\n",
    "        v[K-1] = 1.0\n",
    "        pi = self._stick_to_weights(v)\n",
    "        return v, pi\n",
    "\n",
    "    def _sample_allocs(self, Xtil, y, betas, sig2, pi):\n",
    "        \"\"\"Sample every allocation z_i from its categorical full conditional (inverse CDF).\"\"\"\n",
    "        n = Xtil.shape[0]\n",
    "        K = self.K_cap\n",
    "        means = Xtil @ betas.T\n",
    "        loglik = np.log(pi + 1e-300)[None,:] + normal_logpdf(y[:,None], means, sig2[None,:])\n",
    "        loglik -= logsumexp(loglik, axis=1, keepdims=True)\n",
    "        P = np.exp(loglik)\n",
    "        cum = np.cumsum(P, axis=1)\n",
    "        u = self.rng.random(n)[:, None]\n",
    "        z = (cum < u).sum(axis=1)\n",
    "        # Round-off can leave cum[:, -1] slightly below 1, which would yield the\n",
    "        # out-of-range label K for a large u; clamp to the last component.\n",
    "        z = np.minimum(z, K - 1)\n",
    "        return z.astype(np.intp, copy=False)\n",
    "\n",
    "    def fit(self, X, y):\n",
    "        \"\"\"Run the Gibbs sampler on design X (without intercept column) and responses y.\n",
    "\n",
    "        Each sweep updates component parameters, stick weights, then allocations.\n",
    "        Returns self.\n",
    "        \"\"\"\n",
    "        t0 = time.time()\n",
    "        Xtil = add_intercept(X)\n",
    "        z, betas, sig2, v, pi = self._init_from_data(Xtil, y)\n",
    "        self.samples_.clear()\n",
    "\n",
    "        for it in range(self.n_iter):\n",
    "            betas, sig2 = self._sample_params(Xtil, y, z)\n",
    "            v, pi = self._sample_sticks(z)\n",
    "            z = self._sample_allocs(Xtil, y, betas, sig2, pi)\n",
    "            if it >= self.burn and ((it - self.burn) % self.thin == 0):\n",
    "                self.samples_.append({\"betas\": betas.copy(), \"sig2\": sig2.copy(), \"pi\": pi.copy()})\n",
    "\n",
    "        self.elapsed_sec_ = time.time() - t0\n",
    "        return self\n",
    "\n",
    "\n",
    "\n",
    "def log_f_y_given_x_draw(draw, X, y):\n",
    "    \"\"\"Per-row conditional log-density log f(y_i | x_i) under one posterior draw.\n",
    "\n",
    "    ``draw`` supplies component coefficients \"betas\" (one row per component,\n",
    "    intercept first), variances \"sig2\" and weights \"pi\".\n",
    "    \"\"\"\n",
    "    design = add_intercept(X)\n",
    "    log_weights = np.log(draw[\"pi\"] + 1e-300)\n",
    "    comp_means = design @ draw[\"betas\"].T  # one column per component\n",
    "    per_comp = log_weights[None, :] + normal_logpdf(y[:, None], comp_means, draw[\"sig2\"][None, :])\n",
    "    return logsumexp(per_comp, axis=1)\n",
    "\n",
    "def log_f_ygrid_given_x_draw(draw, x_star, y_grid):\n",
    "    \"\"\"Conditional log-density log f(y | x_star) under one draw, for every y in y_grid.\"\"\"\n",
    "    x_aug = np.array([1.0, *x_star])  # prepend the intercept term\n",
    "    comp_means = draw[\"betas\"] @ x_aug  # one mean per component\n",
    "    log_weights = np.log(draw[\"pi\"] + 1e-300)\n",
    "    per_comp = log_weights[:, None] + normal_logpdf(y_grid[None, :], comp_means[:, None], draw[\"sig2\"][:, None])\n",
    "    return logsumexp(per_comp, axis=0)\n",
    "\n",
    "def aoi_scores_from_logs_conditional(log_f_data_cond, log_f_star_cond, tau=1.0):\n",
    "    if tau <= 0: raise ValueError(\"tau must be > 0\")\n",
    "    T, n = log_f_data_cond.shape\n",
    "    T2, Ny = log_f_star_cond.shape\n",
    "    assert T == T2\n",
    "    logW = (log_f_star_cond / tau) - logsumexp(log_f_star_cond / tau, axis=0, keepdims=True)\n",
    "    s_train_log = np.empty((n, Ny))\n",
    "    for j in range(Ny):\n",
    "        s_train_log[:, j] = logsumexp(log_f_data_cond + logW[:, j][:,None], axis=0)\n",
    "    s_star_log = logsumexp(logW + log_f_star_cond, axis=0)\n",
    "    return s_train_log, s_star_log\n",
    "\n",
    "def randomized_full_conformal_mask_from_LOG(s_train_log, s_star_log, alpha_conf, rng):\n",
    "    n, Ny = s_train_log.shape\n",
    "    mask = np.zeros(Ny, dtype=bool)\n",
    "    for j in range(Ny):\n",
    "        sj = s_train_log[:, j]\n",
    "        sst = s_star_log[j]\n",
    "        lt = np.sum(sj <  sst)\n",
    "        eq = np.sum(sj == sst)\n",
    "        u  = rng.uniform()\n",
    "        pval = (1.0 + lt + u*eq) / (n + 1.0)\n",
    "        mask[j] = (pval > alpha_conf)\n",
    "    return mask\n",
    "\n",
    "\n",
    "\n",
    "def _mean_se(values):\n",
    "    \"\"\"Return (mean, standard error) of replicate-level values; SE is nan with fewer than two values.\"\"\"\n",
    "    arr = np.asarray(values, dtype=float)\n",
    "    mean = float(arr.mean())\n",
    "    se = float(arr.std(ddof=1) / np.sqrt(arr.size)) if arr.size > 1 else float(\"nan\")\n",
    "    return mean, se\n",
    "\n",
    "\n",
    "def _grid_covers(mask, y_grid, y_true):\n",
    "    \"\"\"Whether y_true falls in a kept cell of the uniform grid (nearest point; False outside the grid).\"\"\"\n",
    "    dy = y_grid[1] - y_grid[0]\n",
    "    if y_true < y_grid[0] - 0.5*dy or y_true > y_grid[-1] + 0.5*dy:\n",
    "        return False\n",
    "    idx = int(np.clip(np.rint((y_true - y_grid[0]) / dy), 0, y_grid.shape[0] - 1))\n",
    "    return bool(mask[idx])\n",
    "\n",
    "\n",
    "def _cnb_replicate(P, n, seed_rep, y_grid, alpha, m_test, tau, model_kwargs):\n",
    "    \"\"\"Fit the DP mixture on one simulated training set and evaluate conformal sets on m_test test points.\n",
    "\n",
    "    Returns (coverage rate, mean set length, mean effective K, fit seconds, conformal seconds).\n",
    "    Raises ValueError if the sampler retained no draws.\n",
    "    \"\"\"\n",
    "    dy = y_grid[1] - y_grid[0]\n",
    "    Dtr = sample_from_generator(n, P, seed_rep)\n",
    "    Xtr = Dtr[[\"X1\",\"X2\"]].to_numpy(); ytr = Dtr[\"Y\"].to_numpy()\n",
    "\n",
    "    model = DPLRMixGibbs(random_state=seed_rep, **model_kwargs)\n",
    "    t_fit0 = time.time()\n",
    "    model.fit(Xtr, ytr)\n",
    "    fit_sec = time.time() - t_fit0\n",
    "    draws = model.samples_\n",
    "    if not draws:\n",
    "        raise ValueError(\"no posterior draws were retained; need n_iter > burn\")\n",
    "\n",
    "    # log f_t(y_i | x_i): one row per retained draw t, one column per training point i\n",
    "    log_f_data = np.stack([log_f_y_given_x_draw(d, Xtr, ytr) for d in draws])\n",
    "\n",
    "    Dte = sample_from_generator(m_test, P, seed_rep + 3)\n",
    "    Xte = Dte[[\"X1\",\"X2\"]].to_numpy(); yte = Dte[\"Y\"].to_numpy()\n",
    "\n",
    "    t_conf0 = time.time()\n",
    "    rng_rep = default_rng(seed_rep + 123)\n",
    "    covered = np.empty(m_test, dtype=bool)\n",
    "    lengths = np.empty(m_test)\n",
    "    for j in range(m_test):\n",
    "        log_f_star = np.stack([log_f_ygrid_given_x_draw(d, Xte[j], y_grid) for d in draws])\n",
    "        s_train_log, s_star_log = aoi_scores_from_logs_conditional(log_f_data, log_f_star, tau=tau)\n",
    "        mask = randomized_full_conformal_mask_from_LOG(s_train_log, s_star_log, alpha, rng_rep)\n",
    "        covered[j] = _grid_covers(mask, y_grid, yte[j])\n",
    "        lengths[j] = mask.sum()*dy\n",
    "    conf_sec = time.time() - t_conf0\n",
    "\n",
    "    K_eff = float(np.mean([effective_K(d[\"pi\"]) for d in draws]))\n",
    "    return float(covered.mean()), float(lengths.mean()), K_eff, fit_sec, conf_sec\n",
    "\n",
    "\n",
    "def run_experiment_sim_conditional(\n",
    "    n_list=(100, 300, 600, 1000),\n",
    "    E=10, alpha=0.2,\n",
    "    grid_size=600, m_test=100,\n",
    "    K_cap=25, alpha_dp=1.0,\n",
    "    a0=2.0, b0=2.0,\n",
    "    n_iter=1200, burn=600, thin=3,\n",
    "    tau=1.0,\n",
    "    seed0=2025\n",
    "):\n",
    "    \"\"\"Monte-Carlo study of CNB full-conformal prediction sets on simulated data.\n",
    "\n",
    "    For each training size n in ``n_list`` and each of ``E`` replicates: simulate a\n",
    "    training set, fit ``DPLRMixGibbs`` and build a randomized full-conformal set at\n",
    "    level 1 - ``alpha`` on a fixed response grid for each of ``m_test`` fresh test\n",
    "    points. Coverage and length are averaged within a replicate first; standard\n",
    "    errors are taken across the ``E`` replicate averages, because test points in a\n",
    "    replicate share one fitted model. Coverage and set length are plotted against n.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    res : DataFrame, one row per n, with coverage/length means and SEs, mean\n",
    "        effective K (\"K\"), mean fit seconds (\"time\"), mean conformal seconds\n",
    "        (\"time_mean\") and cumulative wall time (\"all_n\").\n",
    "    y_grid : the candidate response grid.\n",
    "    total_sec : total wall-clock seconds.\n",
    "    \"\"\"\n",
    "    rng = default_rng(seed0)\n",
    "    P = draw_generator_params(seed0)\n",
    "\n",
    "    # Candidate grid spans a large pilot sample's response range, padded by 1 on each side.\n",
    "    pilot = sample_from_generator(20000, P, seed0+7)\n",
    "    y_grid = np.linspace(pilot[\"Y\"].min()-1.0, pilot[\"Y\"].max()+1.0, grid_size)\n",
    "\n",
    "    model_kwargs = dict(\n",
    "        K_cap=K_cap, alpha=alpha_dp,\n",
    "        a0=a0, b0=b0, m0=None, V0=None,\n",
    "        n_iter=n_iter, burn=burn, thin=thin,\n",
    "    )\n",
    "\n",
    "    rows = []\n",
    "    t0 = time.time()\n",
    "\n",
    "    for n in n_list:\n",
    "        rep_cov, rep_len, K_effs, fit_t, conf_t = [], [], [], [], []\n",
    "        for _ in range(E):\n",
    "            seed_rep = int(rng.integers(0, 2**31-1))\n",
    "            cov_r, len_r, K_r, fit_r, conf_r = _cnb_replicate(\n",
    "                P, n, seed_rep, y_grid, alpha, m_test, tau, model_kwargs\n",
    "            )\n",
    "            rep_cov.append(cov_r); rep_len.append(len_r); K_effs.append(K_r)\n",
    "            fit_t.append(fit_r); conf_t.append(conf_r)\n",
    "\n",
    "        cov_mean, cov_se = _mean_se(rep_cov)\n",
    "        len_mean, len_se = _mean_se(rep_len)\n",
    "        rows.append({\n",
    "            \"n\": n,\n",
    "            \"coverage_mean\": cov_mean,\n",
    "            \"coverage_se\": cov_se,\n",
    "            \"length_mean\": len_mean,\n",
    "            \"length_se\": len_se,\n",
    "            \"K\": float(np.mean(K_effs)),\n",
    "            \"time\": float(np.mean(fit_t)),\n",
    "            \"time_mean\": float(np.mean(conf_t)),\n",
    "            \"all_n\": round(time.time()-t0, 2)\n",
    "        })\n",
    "        print(f\"[n={n}] cov={rows[-1]['coverage_mean']:.3f}±{rows[-1]['coverage_se']:.3f} | \"\n",
    "              f\"len={rows[-1]['length_mean']:.3f}±{rows[-1]['length_se']:.3f} | \"\n",
    "              f\"K_eff~{rows[-1]['K']:.1f} | fit~{rows[-1]['time']:.1f}s | conf~{rows[-1]['time_mean']:.1f}s\")\n",
    "\n",
    "    res = pd.DataFrame(rows).sort_values(\"n\").reset_index(drop=True)\n",
    "\n",
    "    # Plots\n",
    "    fig, ax = plt.subplots()\n",
    "    ax.errorbar(res[\"n\"], res[\"coverage_mean\"], yerr=res[\"coverage_se\"], marker=\"o\", capsize=3)\n",
    "    ax.axhline(1.0 - alpha, linestyle=\"--\")  # nominal coverage level\n",
    "    ax.set_title(\"CNB on Simulation Data Average Coverage vs n\")\n",
    "    ax.set_xlabel(\"n\"); ax.set_ylabel(\"Coverage Rate\")\n",
    "    plt.show()\n",
    "\n",
    "    fig, ax = plt.subplots()\n",
    "    ax.errorbar(res[\"n\"], res[\"length_mean\"], yerr=res[\"length_se\"], marker=\"o\", capsize=3)\n",
    "    ax.set_title(\"CNB on Simulation Data Average Set Length vs n\")\n",
    "    ax.set_xlabel(\"n\"); ax.set_ylabel(\"Average Set Length\")\n",
    "    plt.show()\n",
    "\n",
    "    return res, y_grid, time.time()-t0\n",
    "\n",
    "\n",
    "# Run the experiment. seed0 is drawn at random but printed, so any run can be\n",
    "# reproduced by passing the printed value back in as seed0.\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "    SEED0 = random.randint(1, 1000000)\n",
    "    print(f\"seed0 = {SEED0}\")\n",
    "    res, y_grid, total_time = run_experiment_sim_conditional(\n",
    "        n_list=(100, 300, 600, 1000),\n",
    "        E=10, alpha=0.2,\n",
    "        grid_size=600, m_test=100,\n",
    "        K_cap=35, alpha_dp=1.0,\n",
    "        a0=2.0, b0=2.0,\n",
    "        n_iter=2000, burn=600, thin=3,\n",
    "        tau=1.0,\n",
    "        seed0=SEED0\n",
    "    )\n",
    "    print(f\"total wall time: {total_time:.1f}s\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e160cb4d-5cb4-415c-84a2-0c0a0f51cfa7",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
