{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eadbd038-f716-44bc-ac5e-03bb7412978d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ============================================================\n",
    "# Our method: CNB-split on simulation data — split Conformal case\n",
    "# ============================================================\n",
    "\n",
    "import math\n",
    "import time\n",
    "import random\n",
    "from dataclasses import dataclass\n",
    "from pathlib import Path\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "from sklearn.mixture import BayesianGaussianMixture\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "@dataclass\n",
    "class GenParams:\n",
    "    a0: float; a1: float; a2: float\n",
    "    b0: float; b1: float; b2: float\n",
    "    c0: float; c1: float; c2: float\n",
    "    sigmas: np.ndarray  \n",
    "    weights: np.ndarray \n",
    "\n",
    "def draw_generator_params(seed: int = 2025) -> GenParams:\n",
    "    \"\"\"Draw one reproducible set of generator coefficients from `seed`.\n",
    "\n",
    "    The nine mean coefficients are sampled one at a time, in the order\n",
    "    a0, a1, a2, b0, b1, b2, c0, c1, c2, each from its own uniform range.\n",
    "    All three noise scales are fixed at 0.5 and the components are equally weighted.\n",
    "    \"\"\"\n",
    "    rng = np.random.default_rng(seed)\n",
    "    # (low, high) per coefficient; the listing order is the order of the random draws.\n",
    "    bounds = [\n",
    "        (-4.0, -2.0), (0.8, 1.5), (0.5, 1.0),    # a0, a1, a2\n",
    "        (-0.5, 0.5), (0.5, 1.2), (0.4, 1.0),     # b0, b1, b2\n",
    "        (2.0, 4.0), (-1.0, -0.3), (0.8, 1.5),    # c0, c1, c2\n",
    "    ]\n",
    "    coeffs = [rng.uniform(low, high) for low, high in bounds]\n",
    "    noise_sd = np.array([0.5, 0.5, 0.5], float)\n",
    "    mix = np.array([1/3, 1/3, 1/3], float)\n",
    "    return GenParams(*coeffs, sigmas=noise_sd, weights=mix)\n",
    "\n",
    "def sample_from_generator(n: int, params: GenParams, seed: int) -> pd.DataFrame:\n",
    "    \"\"\"Simulate `n` rows (X1, X2, Y) from the three-component mixture regression.\n",
    "\n",
    "    X1 is uniform on (-1, 2) and X2 is normal with mean 1.0 and standard deviation 0.5.\n",
    "    A component label is drawn with probabilities `params.weights`; Y is then normal\n",
    "    around that component's mean with the matching entry of `params.sigmas` as scale.\n",
    "    \"\"\"\n",
    "    rng = np.random.default_rng(seed)\n",
    "    # The draw order (X1, X2, labels, Y) fixes the sample obtained for a given seed.\n",
    "    x1 = rng.uniform(-1, 2, size=n)\n",
    "    x2 = rng.normal(1.0, 0.5, size=n)\n",
    "    labels = rng.choice(3, size=n, p=params.weights)\n",
    "\n",
    "    # One column of candidate means per component.\n",
    "    component_means = np.column_stack([\n",
    "        params.a0 + params.a1 * x1 + params.a2 * (x2**2),\n",
    "        params.b0 + params.b1 * (x1**2) + params.b2 * (x1 * x2),\n",
    "        params.c0 + params.c1 * (x1**2) + params.c2 * np.cos(2 * x2),\n",
    "    ])\n",
    "\n",
    "    # For each row keep only the mean of the component that was drawn for it.\n",
    "    loc = component_means[np.arange(n), labels]\n",
    "    y = rng.normal(loc=loc, scale=params.sigmas[labels])\n",
    "    return pd.DataFrame({\"X1\": x1, \"X2\": x2, \"Y\": y})\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "@dataclass\n",
    "class Standardizer:\n",
    "    mu: np.ndarray  \n",
    "    sd: np.ndarray  \n",
    "\n",
    "def fit_standardizer(Z_train: np.ndarray) -> Standardizer:\n",
    "    \"\"\"Estimate the column means and population standard deviations of `Z_train`.\n",
    "\n",
    "    Columns whose standard deviation is at most 1e-12 receive a scale of 1.0 so that\n",
    "    standardizing them never divides by a (near-)zero value.\n",
    "    \"\"\"\n",
    "    center = Z_train.mean(axis=0)\n",
    "    scale = Z_train.std(axis=0, ddof=0)\n",
    "    # Flag (near-)constant columns and give them unit scale instead.\n",
    "    degenerate = scale <= 1e-12\n",
    "    scale = np.where(degenerate, 1.0, scale)\n",
    "    return Standardizer(mu=center, sd=scale)\n",
    "\n",
    "def transform_Z(Z: np.ndarray, S: Standardizer) -> np.ndarray:\n",
    "    \"\"\"Standardize every column of `Z` with the means and scales stored in `S`.\"\"\"\n",
    "    centered = Z - S.mu\n",
    "    return centered / S.sd\n",
    "\n",
    "def transform_yx(y: float, x: np.ndarray, S: Standardizer):\n",
    "    \"\"\"Standardize a single (y, x) pair.\n",
    "\n",
    "    Entry 0 of `S.mu` / `S.sd` is applied to the response `y`; the remaining\n",
    "    entries are applied element-wise to the covariate vector `x`.\n",
    "\n",
    "    Returns the standardized response as a Python float and the standardized\n",
    "    covariates as an array.\n",
    "    \"\"\"\n",
    "    y_mu, x_mu = S.mu[0], S.mu[1:]\n",
    "    y_sd, x_sd = S.sd[0], S.sd[1:]\n",
    "    y_std = (y - y_mu) / y_sd\n",
    "    x_std = (x - x_mu) / x_sd\n",
    "    return float(y_std), x_std\n",
    "\n",
    "def transform_ygrid(y_grid: np.ndarray, S: Standardizer) -> np.ndarray:\n",
    "    \"\"\"Map a grid of raw response values onto the standardized response scale.\"\"\"\n",
    "    shifted = y_grid - S.mu[0]\n",
    "    return shifted / S.sd[0]\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "def precompute_cache(weights_: np.ndarray, means_: np.ndarray, covs_: np.ndarray):\n",
    "    # \n",
    "    w  = weights_.copy()              \n",
    "    mu = means_.copy()                \n",
    "    S  = covs_.copy()                 \n",
    "\n",
    "    mu_y = mu[:, 0]                   \n",
    "    mu_x = mu[:, 1:3]                 \n",
    "    S_yy = S[:, 0, 0]                 \n",
    "    S_yx = S[:, 0, 1:3]               \n",
    "    S_xy = S[:, 1:3, 0]               \n",
    "    S_xx = S[:, 1:3, 1:3]             \n",
    "\n",
    "    K = S_xx.shape[0]\n",
    "    inv_S_xx = np.empty_like(S_xx)\n",
    "    logdet   = np.empty(K, float)\n",
    "    for k in range(K):\n",
    "        Sxx = S_xx[k] + 1e-10 * np.eye(2)\n",
    "        inv_S_xx[k] = np.linalg.inv(Sxx)\n",
    "        sign, ld = np.linalg.slogdet(Sxx)\n",
    "        if sign <= 0:\n",
    "            Sxx = S_xx[k] + 1e-8 * np.eye(2)\n",
    "            inv_S_xx[k] = np.linalg.inv(Sxx)\n",
    "            _, ld = np.linalg.slogdet(Sxx)\n",
    "        logdet[k] = ld\n",
    "\n",
    "    \n",
    "    A = np.einsum(\"ki,kij->kj\", S_yx, inv_S_xx)                      \n",
    "    cond_var = S_yy - np.einsum(\"ki,kij,kj->k\", S_yx, inv_S_xx, S_xy) \n",
    "    cond_var = np.maximum(cond_var, 1e-10)\n",
    "    sd    = np.sqrt(cond_var)\n",
    "\n",
    "    \n",
    "    const_x = -0.5 * (2*np.log(2*np.pi) + logdet)\n",
    "\n",
    "    \n",
    "    log_coeff_y = -0.5*np.log(2*np.pi) - np.log(sd)\n",
    "\n",
    "    return {\n",
    "        \"w\": w, \"mu_y\": mu_y, \"mu_x\": mu_x, \"A\": A,\n",
    "        \"inv_S_xx\": inv_S_xx, \"const_x\": const_x,\n",
    "        \"cond_var\": cond_var, \"sd\": sd, \"log_coeff_y\": log_coeff_y\n",
    "    }\n",
    "\n",
    "def cond_means_x(x: np.ndarray, C) -> np.ndarray:\n",
    "    diff = x[None, :] - C[\"mu_x\"]            \n",
    "    return C[\"mu_y\"] + np.einsum(\"kj,kj->k\", C[\"A\"], diff)  \n",
    "\n",
    "def log_gamma_x(x: np.ndarray, C) -> np.ndarray:\n",
    "    diff = x[None, :] - C[\"mu_x\"]                                  \n",
    "    quad = np.einsum(\"ki,kij,kj->k\", diff, C[\"inv_S_xx\"], diff)    \n",
    "    return np.log(C[\"w\"] + 1e-300) + C[\"const_x\"] - 0.5*quad       \n",
    "\n",
    "def log_norm_y_given_x(y: float, x: np.ndarray, C) -> np.ndarray:\n",
    "    \"\"\"Per-component Gaussian log density of `y` around its conditional mean at `x`.\"\"\"\n",
    "    resid = y - cond_means_x(x, C)\n",
    "    z = resid / C[\"sd\"]  # residual in units of the conditional standard deviation\n",
    "    return C[\"log_coeff_y\"] - 0.5 * z * z\n",
    "\n",
    "def log_dens_point_y_given_x(y: float, x: np.ndarray, C) -> float:\n",
    "    \"\"\"Log conditional mixture density log p(y | x) at a single point.\n",
    "\n",
    "    Evaluated as log-sum-exp of the joint terms minus log-sum-exp of the x-only\n",
    "    terms, each shifted by its own maximum before exponentiating.\n",
    "    \"\"\"\n",
    "    log_gate = log_gamma_x(x, C)\n",
    "    log_joint = log_gate + log_norm_y_given_x(y, x, C)\n",
    "\n",
    "    top_shift = np.max(log_joint)\n",
    "    bottom_shift = np.max(log_gate)\n",
    "    log_num = np.log(np.sum(np.exp(log_joint - top_shift))) + top_shift\n",
    "    log_den = np.log(np.sum(np.exp(log_gate - bottom_shift))) + bottom_shift\n",
    "    return float(log_num - log_den)\n",
    "\n",
    "def log_dens_grid_y_given_x(y_grid: np.ndarray, x: np.ndarray, C) -> np.ndarray:\n",
    "    \"\"\"Log conditional mixture density log p(y | x) for every y in `y_grid`.\n",
    "\n",
    "    Same computation as the single-point version, vectorized over the grid: the\n",
    "    x-only normalizer is computed once, the joint terms form a (component, grid)\n",
    "    array that is reduced over components with a max-shifted log-sum-exp.\n",
    "    \"\"\"\n",
    "    # x-only part: identical for all grid points.\n",
    "    log_gate = log_gamma_x(x, C)\n",
    "    gate_shift = np.max(log_gate)\n",
    "    log_norm_x = np.log(np.sum(np.exp(log_gate - gate_shift))) + gate_shift\n",
    "\n",
    "    # Standardized residual of each grid value under each component.\n",
    "    cond_mu = cond_means_x(x, C)\n",
    "    z = (y_grid[np.newaxis, :] - cond_mu[:, np.newaxis]) / C[\"sd\"][:, np.newaxis]\n",
    "    log_kernel = C[\"log_coeff_y\"][:, np.newaxis] - 0.5 * z * z\n",
    "\n",
    "    log_joint = log_gate[:, np.newaxis] + log_kernel\n",
    "    col_shift = np.max(log_joint, axis=0)\n",
    "    log_num = np.log(np.sum(np.exp(log_joint - col_shift), axis=0)) + col_shift\n",
    "    return log_num - log_norm_x\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "def run_experiment_from_scratch(\n",
    "    n_list=(100,300,600,1000,2000),\n",
    "    E=10, alpha=0.20, K_cap=30,\n",
    "    max_iter=1000, tol=2e-3, n_init=5, reg_covar=1e-5,\n",
    "    grid_size=500, seed0=2025,\n",
    "    train_fraction=0.65, test_m=200,\n",
    "    standardize=True\n",
    "):\n",
    "    \"\"\"Run the split-conformal simulation study and summarize coverage and set length.\n",
    "\n",
    "    For every sample size in `n_list`, `E` replications are run. Each replication\n",
    "      1. simulates n rows and splits them into a training and a calibration part,\n",
    "      2. fits a Dirichlet-process Bayesian Gaussian mixture to the training rows (Y, X1, X2),\n",
    "      3. scores every calibration point by the fitted log conditional density of y given x,\n",
    "      4. uses the floor(alpha * (m + 1))-th smallest of the m calibration scores as threshold,\n",
    "      5. measures coverage and grid-approximated set length on `test_m` fresh points.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    n_list : iterable of int\n",
    "        Total (training + calibration) sample sizes to study.\n",
    "    E : int\n",
    "        Replications per sample size. Spread estimates use ddof=1, so they are only\n",
    "        finite for E >= 2.\n",
    "    alpha : float\n",
    "        Miscoverage level used for the calibration threshold.\n",
    "    K_cap : int\n",
    "        Upper bound on the number of mixture components. It is lowered to the training\n",
    "        size when that is smaller, because scikit-learn rejects more components than samples.\n",
    "    max_iter, tol, n_init, reg_covar\n",
    "        Passed through to BayesianGaussianMixture.\n",
    "    grid_size : int\n",
    "        Number of points in the response grid used to measure set length (at least 2).\n",
    "    seed0 : int\n",
    "        Base seed for the generator parameters, the pilot sample and all replications.\n",
    "    train_fraction : float\n",
    "        Share of each sample used for fitting; the remainder is the calibration set.\n",
    "    test_m : int\n",
    "        Number of fresh test points per replication.\n",
    "    standardize : bool\n",
    "        If True, fitting and scoring are done on data standardized with training statistics.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    results : pd.DataFrame\n",
    "        One row per n with mean / std / standard error of coverage and of set length,\n",
    "        the average number of components with weight above 1e-6 (\"K\") and the\n",
    "        wall-clock seconds spent on that n (\"time\").\n",
    "    y_grid : np.ndarray\n",
    "        The response grid on the original scale.\n",
    "    total_elapsed : float\n",
    "        Wall-clock seconds for the whole loop over `n_list`.\n",
    "    per_n_times : dict\n",
    "        Wall-clock seconds keyed by n.\n",
    "    \"\"\"\n",
    "    params = draw_generator_params(seed0)\n",
    "\n",
    "    # Response grid on the original scale: range of a pilot sample, padded by 1 on each side.\n",
    "    pilot = sample_from_generator(3000, params, seed0+7)\n",
    "    y_grid = np.linspace(pilot[\"Y\"].min()-1.0, pilot[\"Y\"].max()+1.0, grid_size)\n",
    "    dy = y_grid[1] - y_grid[0]\n",
    "\n",
    "    rows = []\n",
    "    per_n_times = {}\n",
    "    t0_total = time.time()\n",
    "\n",
    "    for n in n_list:\n",
    "        t_n_start = time.time()\n",
    "        coverages, lengths, eff_Ks = [], [], []\n",
    "        for rep in range(E):\n",
    "            seed = (n*97 + rep*313 + seed0) % (2**32 - 1)\n",
    "\n",
    "            # --- simulate and split into training / calibration parts ---\n",
    "            D = sample_from_generator(n, params, seed)\n",
    "            rng = np.random.default_rng(seed+1)\n",
    "            perm = rng.permutation(n)\n",
    "            n_tr = int(train_fraction * n)\n",
    "            tr_idx, cal_idx = perm[:n_tr], perm[n_tr:]\n",
    "            D_tr, D_cal = D.iloc[tr_idx].reset_index(drop=True), D.iloc[cal_idx].reset_index(drop=True)\n",
    "\n",
    "            # --- optional standardization, fitted on the training rows only ---\n",
    "            Z_tr = D_tr[[\"Y\",\"X1\",\"X2\"]].to_numpy()\n",
    "            if standardize:\n",
    "                S = fit_standardizer(Z_tr)\n",
    "                Z_tr_s = transform_Z(Z_tr, S)\n",
    "            else:\n",
    "                S = None\n",
    "                Z_tr_s = Z_tr\n",
    "\n",
    "            # --- joint mixture model over (Y, X1, X2) ---\n",
    "            # scikit-learn raises when n_components exceeds the number of training rows,\n",
    "            # so the cap is lowered for small training sets.\n",
    "            n_components = max(1, min(K_cap, n_tr))\n",
    "            bgmm = BayesianGaussianMixture(\n",
    "                n_components=n_components, covariance_type=\"full\",\n",
    "                weight_concentration_prior_type=\"dirichlet_process\",\n",
    "                weight_concentration_prior=1.0,\n",
    "                init_params=\"kmeans\",\n",
    "                n_init=n_init,\n",
    "                max_iter=max_iter,\n",
    "                tol=tol,\n",
    "                reg_covar=reg_covar,\n",
    "                # seed + 2 can reach 2**32, outside the integer range scikit-learn\n",
    "                # accepts for random_state; wrap it back into range.\n",
    "                random_state=(seed+2) % (2**32),\n",
    "            ).fit(Z_tr_s)\n",
    "\n",
    "            # Keep only components with non-negligible weight.\n",
    "            mask = bgmm.weights_ > 1e-6\n",
    "            eff_K = int(np.sum(mask)); eff_Ks.append(eff_K)\n",
    "            cache = precompute_cache(bgmm.weights_[mask], bgmm.means_[mask], bgmm.covariances_[mask])\n",
    "\n",
    "            # --- calibration scores: log conditional density at each calibration point ---\n",
    "            Xc = D_cal[[\"X1\",\"X2\"]].to_numpy(); Yc = D_cal[\"Y\"].to_numpy()\n",
    "            s_cal_log = np.empty(len(Yc), float)\n",
    "            if standardize:\n",
    "                for i in range(len(Yc)):\n",
    "                    y_s, x_s = transform_yx(Yc[i], Xc[i], S)\n",
    "                    s_cal_log[i] = log_dens_point_y_given_x(y_s, x_s, cache)\n",
    "            else:\n",
    "                for i in range(len(Yc)):\n",
    "                    s_cal_log[i] = log_dens_point_y_given_x(Yc[i], Xc[i], cache)\n",
    "\n",
    "            # --- threshold: k-th smallest calibration score, k = floor(alpha * (m + 1)) ---\n",
    "            m_cal = len(s_cal_log)\n",
    "            k = int(np.floor(alpha * (m_cal + 1)))\n",
    "            if k <= 0:\n",
    "                q_log = -np.inf  # too few calibration points: every y is accepted\n",
    "            else:\n",
    "                q_log = np.partition(s_cal_log, k - 1)[k - 1]\n",
    "\n",
    "            # --- evaluation on fresh test points ---\n",
    "            D_test = sample_from_generator(test_m, params, seed+3)\n",
    "            Xt = D_test[[\"X1\",\"X2\"]].to_numpy(); Yt = D_test[\"Y\"].to_numpy()\n",
    "            y_grid_s = transform_ygrid(y_grid, S) if standardize else y_grid\n",
    "\n",
    "            cov_flags = np.empty(test_m, float)\n",
    "            set_lengths = np.empty(test_m, float)\n",
    "\n",
    "            for i in range(test_m):\n",
    "                x = Xt[i]; y = Yt[i]\n",
    "                if standardize:\n",
    "                    y_s, x_s = transform_yx(y, x, S)\n",
    "                    dens_grid_log = log_dens_grid_y_given_x(y_grid_s, x_s, cache)\n",
    "                    s_true_log = log_dens_point_y_given_x(y_s, x_s, cache)\n",
    "                else:\n",
    "                    dens_grid_log = log_dens_grid_y_given_x(y_grid, x, cache)\n",
    "                    s_true_log = log_dens_point_y_given_x(y, x, cache)\n",
    "\n",
    "                # Set length = number of accepted grid points times the grid spacing\n",
    "                # on the original scale; coverage is checked at the exact test response.\n",
    "                mask_set = (dens_grid_log >= q_log)\n",
    "                set_lengths[i] = float(np.sum(mask_set) * dy)\n",
    "                cov_flags[i] = 1.0 if s_true_log >= q_log else 0.0\n",
    "\n",
    "            coverages.append(float(np.mean(cov_flags)))\n",
    "            lengths.append(float(np.mean(set_lengths)))\n",
    "\n",
    "        cov = np.array(coverages)\n",
    "        L   = np.array(lengths)\n",
    "        per_n_times[n] = time.time() - t_n_start\n",
    "\n",
    "        rows.append({\n",
    "            \"n\": n,\n",
    "            \"coverage_average\": cov.mean(),\n",
    "            \"coverage_std\":  cov.std(ddof=1),\n",
    "            \"coverage_se\":   cov.std(ddof=1)/np.sqrt(E),\n",
    "            \"length_average\":   L.mean(),\n",
    "            \"length_std\":    L.std(ddof=1),\n",
    "            \"length_se\":     L.std(ddof=1)/np.sqrt(E),\n",
    "            \"K\": float(np.mean(eff_Ks)),\n",
    "            \"time\": round(per_n_times[n], 2),\n",
    "        })\n",
    "\n",
    "    results = pd.DataFrame(rows)\n",
    "    total_elapsed = time.time() - t0_total\n",
    "    return results, y_grid, total_elapsed, per_n_times\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "    # Single definition of the miscoverage level: it drives the experiment AND the\n",
    "    # reference line in the coverage plot, so the two can no longer drift apart.\n",
    "    alpha = 0.20\n",
    "\n",
    "    # A fresh base seed is drawn on every run; print it so the run can be repeated\n",
    "    # by passing the same value as seed0.\n",
    "    seed0 = random.randint(1, 10000000)\n",
    "    print(f\"seed0 = {seed0}\")\n",
    "\n",
    "    results, y_grid, total_elapsed, per_n_times = run_experiment_from_scratch(\n",
    "        n_list=(100, 300,600,1000),\n",
    "        E=10, alpha=alpha, K_cap=30,\n",
    "        max_iter=2000, tol=2e-3, n_init=5, reg_covar=1e-5,\n",
    "        grid_size=600, seed0=seed0,\n",
    "        standardize=True\n",
    "    )\n",
    "\n",
    "    # Persist the summary table next to the notebook.\n",
    "    out_csv = Path(\"CNB split.csv\")\n",
    "    results.to_csv(out_csv, index=False)\n",
    "\n",
    "    # Average coverage per n, with the 1 - alpha level as dashed reference line.\n",
    "    plt.figure()\n",
    "    plt.plot(results[\"n\"], results[\"coverage_average\"], marker=\"o\")\n",
    "    plt.axhline(1 - alpha, linestyle=\"--\")\n",
    "    plt.title(\"CNB Split on Simulation Data Average Coverage vs n\")\n",
    "    plt.xlabel(\"n\"); plt.ylabel(\"Coverage Rate\")\n",
    "    plt.show()\n",
    "\n",
    "    # Average set length per n, with standard-error bars.\n",
    "    plt.figure()\n",
    "    plt.errorbar(results[\"n\"], results[\"length_average\"], yerr=results[\"length_se\"], marker=\"o\")\n",
    "    plt.title(\"CNB Split on Simulation Data Average conformal set length vs n\")\n",
    "    plt.xlabel(\"n\"); plt.ylabel(\"Average Set Length\")\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a5c434cb-1d5c-4fa3-ba1a-34938c18315b",
   "metadata": {},
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f9adb2c7-2d2c-4013-9018-92f7273df6f0",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
