{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "15eccf5d-5aff-4dda-9613-e48c48f040ab",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ============================================================\n",
    "# Our method: CNB-split on California Housing — split Conformal case\n",
    "# ============================================================\n",
    "\n",
    "import time\n",
    "import math\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "import random\n",
    "\n",
    "from sklearn.datasets import fetch_california_housing\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.mixture import BayesianGaussianMixture\n",
    "from scipy.special import logsumexp\n",
    "\n",
    "\n",
    "\n",
    "def zscore(a, axis=0, eps=1e-12):\n",
    "    m = np.mean(a, axis=axis, keepdims=True)\n",
    "    s = np.std(a, axis=axis, keepdims=True)\n",
    "    return (a - m) / (s + eps), m, s\n",
    "\n",
    "def mvn_logpdf(x, mean, cov, cov_inv=None, logdet=None):\n",
    "    \n",
    "    x = np.atleast_2d(x)\n",
    "    d = mean.shape[0]\n",
    "    if cov_inv is None:\n",
    "        cov_inv = np.linalg.inv(cov)\n",
    "    if logdet is None:\n",
    "        sign, ld = np.linalg.slogdet(cov)\n",
    "        if sign <= 0:\n",
    "            # \n",
    "            ld = np.linalg.slogdet(cov + 1e-10*np.eye(d))[1]\n",
    "        logdet = ld\n",
    "    xc = x - mean\n",
    "    quad = np.einsum('ni,ij,nj->n', xc, cov_inv, xc)\n",
    "    return -0.5 * (d*np.log(2*np.pi) + logdet + quad)\n",
    "\n",
    "def normal_logpdf(y, mean, var):\n",
    "    \n",
    "    return -0.5*(np.log(2*np.pi*var) + (y - mean)**2 / var)\n",
    "\n",
    "\n",
    "\n",
    "class JointDPGMM:\n",
    "    \n",
    "    def __init__(self, K_cap=30, covariance_type=\"full\", max_iter=500, tol=1e-3, random_state=0, reg_covar=1e-6):\n",
    "        self.K_cap = K_cap\n",
    "        self.covariance_type = covariance_type\n",
    "        self.max_iter = max_iter\n",
    "        self.tol = tol\n",
    "        self.random_state = random_state\n",
    "        self.reg_covar = reg_covar\n",
    "        self.model = BayesianGaussianMixture(\n",
    "            n_components=K_cap,\n",
    "            covariance_type=covariance_type,\n",
    "            max_iter=max_iter,\n",
    "            tol=tol,\n",
    "            reg_covar=reg_covar,\n",
    "            weight_concentration_prior_type=\"dirichlet_process\",\n",
    "            random_state=random_state,\n",
    "            init_params=\"kmeans\",\n",
    "            n_init=1\n",
    "        )\n",
    "        self.fitted_ = False\n",
    "\n",
    "    def fit(self, y_train, X_train):\n",
    "        y_train = y_train.reshape(-1, 1)\n",
    "        Z = np.concatenate([y_train, X_train], axis=1)\n",
    "        self.model.fit(Z)\n",
    "        self._precompute_blocks()\n",
    "        self.fitted_ = True\n",
    "        return self\n",
    "\n",
    "    def _precompute_blocks(self):\n",
    "        \n",
    "        K = self.model.weights_.shape[0]\n",
    "        means = self.model.means_          \n",
    "        covs  = self.model.covariances_    \n",
    "\n",
    "        self.pi = self.model.weights_ + 0.0\n",
    "        self.mu_y = means[:, 0]           \n",
    "        self.mu_x = means[:, 1:]           \n",
    "\n",
    "        self.inv_Sxx = []\n",
    "        self.logdet_Sxx = []\n",
    "        self.B = []         \n",
    "        self.s2 = []        \n",
    "\n",
    "        for k in range(K):\n",
    "            S = covs[k]     \n",
    "            Syy = S[0, 0]\n",
    "            Syx = S[0, 1:].reshape(1, -1)    \n",
    "            Sxy = S[1:, 0].reshape(-1, 1)    \n",
    "            Sxx = S[1:, 1:]                  \n",
    "\n",
    "            \n",
    "            Sxx = Sxx + 1e-12*np.eye(Sxx.shape[0])\n",
    "\n",
    "            invSxx = np.linalg.inv(Sxx)\n",
    "            sign, ld = np.linalg.slogdet(Sxx)\n",
    "            if sign <= 0:\n",
    "                ld = np.linalg.slogdet(Sxx + 1e-10*np.eye(Sxx.shape[0]))[1]\n",
    "            Bk = Syx @ invSxx          \n",
    "            s2k = Syy - Syx @ invSxx @ Sxy  \n",
    "\n",
    "            self.inv_Sxx.append(invSxx)\n",
    "            self.logdet_Sxx.append(ld)\n",
    "            self.B.append(Bk.reshape(-1))                \n",
    "            self.s2.append(float(np.maximum(s2k, 1e-12)))\n",
    "\n",
    "        self.inv_Sxx = np.array(self.inv_Sxx, dtype=float)           \n",
    "        self.logdet_Sxx = np.array(self.logdet_Sxx, dtype=float)      \n",
    "        self.B = np.array(self.B, dtype=float)                        \n",
    "        self.s2 = np.array(self.s2, dtype=float)                      \n",
    "        \n",
    "        self.K_eff_ = int(np.sum(self.pi > 1e-3))\n",
    "\n",
    "    def log_pred_density(self, X, y):\n",
    "        \n",
    "        assert self.fitted_\n",
    "        X = np.atleast_2d(X)\n",
    "        y = np.asarray(y).reshape(-1)\n",
    "        n, d = X.shape\n",
    "        K = self.pi.shape[0]\n",
    "\n",
    "        \n",
    "        logw_x = np.empty((n, K))\n",
    "        for k in range(K):\n",
    "            logw_x[:, k] = np.log(self.pi[k] + 1e-300) + mvn_logpdf(\n",
    "                X, self.mu_x[k], cov=None, cov_inv=self.inv_Sxx[k], logdet=self.logdet_Sxx[k]\n",
    "            )\n",
    "        \n",
    "        logw_x_norm = logw_x - logsumexp(logw_x, axis=1, keepdims=True)\n",
    "\n",
    "        \n",
    "        logpdf_yk = np.empty((n, K))\n",
    "        for k in range(K):\n",
    "            mu_cond = self.mu_y[k] + (X - self.mu_x[k]) @ self.B[k]    \n",
    "            logpdf_yk[:, k] = normal_logpdf(y, mu_cond, self.s2[k])\n",
    "\n",
    "        \n",
    "        return logsumexp(logw_x_norm + logpdf_yk, axis=1)\n",
    "\n",
    "    def log_pred_density_grid(self, x_star, y_grid):\n",
    "        \n",
    "        assert self.fitted_\n",
    "        x = x_star.reshape(1, -1)\n",
    "        Ny = len(y_grid)\n",
    "        K = self.pi.shape[0]\n",
    "\n",
    "        \n",
    "        logw = np.empty(K)\n",
    "        for k in range(K):\n",
    "            logw[k] = np.log(self.pi[k] + 1e-300) + mvn_logpdf(\n",
    "                x, self.mu_x[k], cov=None, cov_inv=self.inv_Sxx[k], logdet=self.logdet_Sxx[k]\n",
    "            )[0]\n",
    "        logw_norm = logw - logsumexp(logw)\n",
    "\n",
    "        \n",
    "        comp = np.empty((K, Ny))\n",
    "        for k in range(K):\n",
    "            mu_cond = self.mu_y[k] + (x - self.mu_x[k]) @ self.B[k]   \n",
    "            comp[k, :] = normal_logpdf(y_grid, float(mu_cond), self.s2[k])\n",
    "        return logsumexp(logw_norm[:, None] + comp, axis=0)\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "def randomized_pvalue_from_calibration(s_star, s_cal, rng):\n",
    "    \n",
    "    s_cal = np.asarray(s_cal)\n",
    "    m = s_cal.size\n",
    "    lt = np.sum(s_cal < s_star)\n",
    "    eq = np.sum(s_cal == s_star)\n",
    "    u = rng.uniform()\n",
    "    return (1.0 + lt + u*eq) / (m + 1.0)\n",
    "\n",
    "def conformal_region_mask_grid(logp_star_grid, logp_cal, alpha, rng):\n",
    "    \n",
    "    Ny = len(logp_star_grid)\n",
    "    mask = np.zeros(Ny, dtype=bool)\n",
    "    for j in range(Ny):\n",
    "        p = randomized_pvalue_from_calibration(logp_star_grid[j], logp_cal, rng)\n",
    "        mask[j] = (p > alpha)\n",
    "    return mask\n",
    "\n",
    "\n",
    "\n",
    "def load_california_standardized(random_state=0):\n",
    "    X, y = fetch_california_housing(return_X_y=True)\n",
    "    # \n",
    "    Xz, Xm, Xs = zscore(X, axis=0)\n",
    "    yz, ym, ys = zscore(y, axis=0)\n",
    "    return Xz, yz, (Xm, Xs, ym, ys)\n",
    "\n",
    "def sample_splits(X, y, n_total, m_test, train_frac=0.7, seed=0):\n",
    "    \n",
    "    rng = np.random.default_rng(seed)\n",
    "    N = X.shape[0]\n",
    "    idx = rng.choice(N, size=n_total + m_test, replace=False)\n",
    "    idx_traincal = idx[:n_total]\n",
    "    idx_test = idx[n_total:]\n",
    "\n",
    "    X_traincal = X[idx_traincal]\n",
    "    y_traincal = y[idx_traincal]\n",
    "    X_test = X[idx_test]\n",
    "    y_test = y[idx_test]\n",
    "\n",
    "    n_train = int(round(train_frac * n_total))\n",
    "    \n",
    "    perm = rng.permutation(n_total)\n",
    "    tr_idx = perm[:n_train]\n",
    "    cal_idx = perm[n_train:]\n",
    "\n",
    "    X_train = X_traincal[tr_idx]\n",
    "    y_train = y_traincal[tr_idx]\n",
    "    X_cal   = X_traincal[cal_idx]\n",
    "    y_cal   = y_traincal[cal_idx]\n",
    "\n",
    "    return X_train, y_train, X_cal, y_cal, X_test, y_test\n",
    "\n",
    "\n",
    "\n",
    "def run_experiment1_california(\n",
    "    n_list=(100, 200, 300, 600, 1000),\n",
    "    E=10,\n",
    "    alpha=0.2,\n",
    "    grid_size=500,\n",
    "    m_test=200,\n",
    "    K_cap=30,\n",
    "    covariance_type=\"full\",\n",
    "    max_iter=500,\n",
    "    train_frac=0.7,\n",
    "    random_seed=2025\n",
    "):\n",
    "   \n",
    "    rng_master = np.random.default_rng(random_seed)\n",
    "    X, y, _sc = load_california_standardized()\n",
    "    d = X.shape[1]\n",
    "\n",
    "    records = []\n",
    "    total_t0 = time.time()\n",
    "\n",
    "    for n in n_list:\n",
    "        t_n0 = time.time()\n",
    "        cov_list = []\n",
    "        len_list = []\n",
    "        Keff_list = []\n",
    "\n",
    "        for rep in range(E):\n",
    "            seed_rep = int(rng_master.integers(0, 2**31-1))\n",
    "            \n",
    "            Xtr, ytr, Xcal, ycal, Xte, yte = sample_splits(\n",
    "                X, y, n_total=n, m_test=m_test, train_frac=train_frac, seed=seed_rep\n",
    "            )\n",
    "\n",
    "            \n",
    "            y_all = np.concatenate([ytr, ycal])\n",
    "            y_grid = np.linspace(y_all.min()-1.0, y_all.max()+1.0, grid_size)\n",
    "            dy = y_grid[1] - y_grid[0]\n",
    "\n",
    "            \n",
    "            model = JointDPGMM(\n",
    "                K_cap=K_cap,\n",
    "                covariance_type=covariance_type,\n",
    "                max_iter=max_iter,\n",
    "                tol=1e-3,\n",
    "                random_state=seed_rep\n",
    "            ).fit(ytr, Xtr)\n",
    "            Keff_list.append(model.K_eff_)\n",
    "\n",
    "            \n",
    "            logp_cal = model.log_pred_density(Xcal, ycal)\n",
    "\n",
    "            \n",
    "            cov_flags = []\n",
    "            lengths = []\n",
    "            rng_rep = np.random.default_rng(seed_rep + 123)\n",
    "\n",
    "            for j in range(len(yte)):\n",
    "                x_star = Xte[j]\n",
    "                y_true = yte[j]\n",
    "\n",
    "                \n",
    "                logp_star_grid = model.log_pred_density_grid(x_star, y_grid)\n",
    "\n",
    "                #\n",
    "                mask = conformal_region_mask_grid(logp_star_grid, logp_cal, alpha, rng_rep)\n",
    "\n",
    "                \n",
    "                idx_true = int(np.clip(np.searchsorted(y_grid, y_true), 0, grid_size-1))\n",
    "                coverage = bool(mask[idx_true])\n",
    "                cov_flags.append(coverage)\n",
    "                lengths.append(mask.sum() * dy)\n",
    "\n",
    "            cov_list.append(np.mean(cov_flags))\n",
    "            len_list.append(np.mean(lengths))\n",
    "\n",
    "        elapsed_n = time.time() - t_n0\n",
    "        cov_arr = np.array(cov_list)\n",
    "        len_arr = np.array(len_list)\n",
    "        Keff_arr = np.array(Keff_list)\n",
    "\n",
    "        records.append({\n",
    "            \"n\": n,\n",
    "            \"coverage_average\": cov_arr.mean(),\n",
    "            \"coverage_se\": cov_arr.std(ddof=1)/np.sqrt(E),\n",
    "            \"length_average\": len_arr.mean(),\n",
    "            \"length_se\": len_arr.std(ddof=1)/np.sqrt(E),\n",
    "            \"K\": Keff_arr.mean(),\n",
    "            \"time\": round(elapsed_n, 2)\n",
    "        })\n",
    "\n",
    "        print(f\"[n={n}] cov={cov_arr.mean():.3f}±{cov_arr.std(ddof=1)/np.sqrt(E):.3f} | \"\n",
    "              f\"len={len_arr.mean():.3f}±{len_arr.std(ddof=1)/np.sqrt(E):.3f} | \"\n",
    "              f\"K_eff~{Keff_arr.mean():.1f} | time={elapsed_n:.1f}s\")\n",
    "\n",
    "    total_elapsed = time.time() - total_t0\n",
    "    res = pd.DataFrame(records).sort_values(\"n\").reset_index(drop=True)\n",
    " \n",
    "\n",
    "    \n",
    "    plt.figure()\n",
    "    plt.errorbar(res[\"n\"], res[\"coverage_mean\"], yerr=res[\"coverage_se\"],\n",
    "                 marker=\"o\", capsize=4)\n",
    "    plt.axhline(0.8, linestyle=\"--\")\n",
    "    plt.title(\"CNB Split on California Data Average Coverage vs n\")\n",
    "    plt.xlabel(\"n\")\n",
    "    plt.ylabel(\"Coverage Rate\")\n",
    "    plt.tight_layout()\n",
    "    plt.show()\n",
    "\n",
    "    plt.figure()\n",
    "    plt.errorbar(res[\"n\"], res[\"length_mean\"], yerr=res[\"length_se\"],\n",
    "                 marker=\"o\", capsize=4)\n",
    "    plt.title(\"CNB Split on Simulation Data Average Set length vs n\")\n",
    "    plt.xlabel(\"n\")\n",
    "    plt.ylabel(\"Average Set Length\")\n",
    "    plt.tight_layout()\n",
    "    plt.show()\n",
    "\n",
    "    return res\n",
    "\n",
    "\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "    _ = run_experiment1_california(\n",
    "        n_list=(100, 300, 600, 1000),\n",
    "        E=10,\n",
    "        alpha=0.2,\n",
    "        grid_size=600,\n",
    "        m_test=100,\n",
    "        K_cap=30,\n",
    "        covariance_type=\"full\",\n",
    "        max_iter=2000,\n",
    "        train_frac=0.7,\n",
    "        random_seed=random.randint(1,1000000)\n",
    "    )\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aa09412c-53aa-464f-b6f1-d773a36b1800",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f655fa7b-329c-41a6-87b2-9d4ecfb0df8c",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
