{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "13d26a5c",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-06-11T17:42:24.743119Z",
     "iopub.status.busy": "2025-06-11T17:42:24.742944Z",
     "iopub.status.idle": "2025-06-11T17:42:25.407302Z",
     "shell.execute_reply": "2025-06-11T17:42:25.406786Z"
    }
   },
   "outputs": [],
   "source": [
    "import random\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib_inline.backend_inline\n",
    "from scipy.stats import ks_1samp, uniform, norm, gaussian_kde\n",
    "import torch\n",
    "from torchvision.datasets import MNIST\n",
    "from torchvision import transforms\n",
    "from torchvision.models import ResNet18_Weights\n",
    "from transformers import AutoModelForSequenceClassification, AutoTokenizer\n",
    "from datasets import load_dataset\n",
    "from joblib import Parallel, delayed\n",
    "\n",
    "from tqdm import tqdm\n",
    "\n",
    "import timm\n",
    "\n",
    "# Seed the main-process RNGs used in this notebook (stdlib `random`, NumPy, PyTorch)\n",
    "# so that Restart & Run All starts from the same state every time.\n",
    "# NOTE(review): draws made inside joblib worker processes may not be governed by\n",
    "# these seeds -- confirm if bit-for-bit reproducibility of the p-values is required.\n",
    "SEED = 42\n",
    "random.seed(SEED)\n",
    "np.random.seed(SEED)\n",
    "torch.manual_seed(SEED)\n",
    "\n",
    "# Render inline figures as SVG and apply the local Matplotlib style sheet.\n",
    "matplotlib_inline.backend_inline.set_matplotlib_formats(\"svg\")\n",
    "plt.style.use(\"math.mplstyle\")\n",
    "\n",
    "# Font sizes shared by all figures in the notebook.\n",
    "plt.rcParams.update(\n",
    "    {\n",
    "        \"axes.labelsize\": 14,\n",
    "        \"xtick.labelsize\": 14,\n",
    "        \"ytick.labelsize\": 14,\n",
    "        \"legend.fontsize\": 14,\n",
    "        \"axes.titlesize\": 16,\n",
    "    }\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "335f7100",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Length of the simulated Gaussian series used in the single-run examples below.\n",
    "n = 1000\n",
    "# Number of leading observations drawn before the mean shift (the true changepoint).\n",
    "xi = 400"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cd116ff8",
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import Callable\n",
    "\n",
    "def compute_changepoint_pvalue(n: int, t: int,\n",
    "                               left_row: Callable[[int], np.ndarray],\n",
    "                               right_row: Callable[[int], np.ndarray]) -> float:\n",
    "    left_pvals = []\n",
    "    if t >= 0:\n",
    "        for r in range(0, t + 1):\n",
    "            row = left_row(r)\n",
    "            if row is None or len(row) == 0:\n",
    "                continue\n",
    "            val = row[-1]\n",
    "            lt = np.sum(row < val)\n",
    "            eq = np.sum(row == val)\n",
    "            p = (lt + np.random.uniform(0.0, 1.0) * eq) / len(row)\n",
    "            left_pvals.append(p)\n",
    "\n",
    "    right_pvals = []\n",
    "    if t + 1 < n:\n",
    "        for r in range(t + 1, n):\n",
    "            row = right_row(r)\n",
    "            if row is None or len(row) == 0:\n",
    "                continue\n",
    "            val = row[0]\n",
    "            lt = np.sum(row < val)\n",
    "            eq = np.sum(row == val)\n",
    "            p = (lt + np.random.uniform(0.0, 1.0) * eq) / len(row)\n",
    "            right_pvals.append(p)\n",
    "\n",
    "    left_method = \"exact\" if len(left_pvals) <= 300 else \"asymp\"\n",
    "    right_method = \"exact\" if len(right_pvals) <= 300 else \"asymp\"\n",
    "    try:\n",
    "        p_left = ks_1samp(left_pvals, uniform.cdf, method=left_method)[1] if len(left_pvals) > 0 else 1.0\n",
    "    except Exception:\n",
    "        p_left = 1.0\n",
    "    try:\n",
    "        p_right = ks_1samp(right_pvals, uniform.cdf, method=right_method)[1] if len(right_pvals) > 0 else 1.0\n",
    "    except Exception:\n",
    "        p_right = 1.0\n",
    "\n",
    "    return 1 - (1 - min(p_left, p_right)) ** 2\n",
    "\n",
    "\n",
    "def compute_changepoint_pvalue_counts(n: int, t: int,\n",
    "                                      left_counts: Callable[[int], tuple],\n",
    "                                      right_counts: Callable[[int], tuple]) -> float:\n",
    "    left_pvals = []\n",
    "    if t >= 0:\n",
    "        for r in range(0, t + 1):\n",
    "            lt, eq, total = left_counts(r)\n",
    "            if total <= 0:\n",
    "                continue\n",
    "            p = (lt + 0.5 * eq) / total\n",
    "            left_pvals.append(p)\n",
    "\n",
    "    right_pvals = []\n",
    "    if t + 1 < n:\n",
    "        for r in range(t + 1, n):\n",
    "            lt, eq, total = right_counts(r)\n",
    "            if total <= 0:\n",
    "                continue\n",
    "            p = (lt + 0.5 * eq) / total\n",
    "            right_pvals.append(p)\n",
    "\n",
    "    left_method = \"exact\" if len(left_pvals) <= 300 else \"asymp\"\n",
    "    right_method = \"exact\" if len(right_pvals) <= 300 else \"asymp\"\n",
    "    try:\n",
    "        p_left = ks_1samp(left_pvals, uniform.cdf, method=left_method)[1] if len(left_pvals) > 0 else 1.0\n",
    "    except Exception:\n",
    "        p_left = 1.0\n",
    "    try:\n",
    "        p_right = ks_1samp(right_pvals, uniform.cdf, method=right_method)[1] if len(right_pvals) > 0 else 1.0\n",
    "    except Exception:\n",
    "        p_right = 1.0\n",
    "\n",
    "    return 1 - (1 - min(p_left, p_right)) ** 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1fbdd363",
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_prefix_suffix_modes(predictions, num_classes):\n",
    "    preds = np.asarray(predictions, dtype=int)\n",
    "    n = preds.shape[0]\n",
    "    onehot = np.zeros((n, num_classes), dtype=int)\n",
    "    onehot[np.arange(n), preds] = 1\n",
    "    prefix_counts = np.cumsum(onehot, axis=0)\n",
    "    prefix_mode = prefix_counts.argmax(axis=1)\n",
    "    suffix_counts_rev = np.cumsum(onehot[::-1], axis=0)\n",
    "    suffix_counts = suffix_counts_rev[::-1]\n",
    "    suffix_mode = suffix_counts.argmax(axis=1)\n",
    "    return prefix_mode, suffix_mode\n",
    "\n",
    "\n",
    "def run_classification_simulation(predictions, probabilities, num_classes, n_jobs=-1, verbose=1):\n",
    "    \"\"\"Compute a changepoint p-value for every split of a classified sequence.\n",
    "\n",
    "    For split ``t`` the left score of sample ``j`` is\n",
    "    ``probs[j, prefix_mode[j]] / probs[j, suffix_mode[t + 1]]`` and the right score is\n",
    "    ``probs[j, suffix_mode[j]] / probs[j, prefix_mode[t]]`` (a small ``eps`` is added\n",
    "    to each denominator). The rows are handed to ``compute_changepoint_pvalue``.\n",
    "\n",
    "    Args:\n",
    "        predictions: integer class labels, one per sample.\n",
    "        probabilities: array or ``torch.Tensor`` indexed as ``probs[sample, label]``\n",
    "            (assumed 2-D with one row per sample -- TODO confirm against the caller).\n",
    "        num_classes: number of classes, forwarded to ``compute_prefix_suffix_modes``.\n",
    "        n_jobs: joblib worker count.\n",
    "        verbose: joblib verbosity level.\n",
    "\n",
    "    Returns:\n",
    "        Array of ``len(predictions) - 1`` p-values, one per split ``t = 0..n-2``.\n",
    "    \"\"\"\n",
    "    if isinstance(probabilities, torch.Tensor):\n",
    "        probs = probabilities.detach().cpu().numpy()\n",
    "    else:\n",
    "        probs = np.asarray(probabilities)\n",
    "    preds = np.asarray(predictions, dtype=int)\n",
    "    n = len(preds)\n",
    "\n",
    "    prefix_mode, suffix_mode = compute_prefix_suffix_modes(preds, num_classes)\n",
    "    eps = 1e-12\n",
    "\n",
    "    def pvalue_for_t(t: int) -> float:\n",
    "        # t only ranges over 0..n-2 below, so both reference labels always exist.\n",
    "        ref_right = prefix_mode[t]\n",
    "        ref_left = suffix_mode[t + 1]\n",
    "\n",
    "        def left_row(r: int) -> np.ndarray:\n",
    "            # Pair sample j with its own label via integer row indices. Indexing the\n",
    "            # rows with a slice together with a label array would instead select a\n",
    "            # full (r + 1, r + 1) block rather than one score per sample.\n",
    "            rows = np.arange(0, r + 1)\n",
    "            num = probs[rows, prefix_mode[: r + 1]]\n",
    "            den = probs[rows, ref_left] + eps\n",
    "            return num / den\n",
    "\n",
    "        def right_row(r: int) -> np.ndarray:\n",
    "            rows = np.arange(r, n)\n",
    "            num = probs[rows, suffix_mode[r:]]\n",
    "            den = probs[rows, ref_right] + eps\n",
    "            return num / den\n",
    "\n",
    "        return compute_changepoint_pvalue(n, t, left_row, right_row)\n",
    "\n",
    "    p_values = Parallel(n_jobs=n_jobs, verbose=verbose)(delayed(pvalue_for_t)(t) for t in range(n - 1))\n",
    "    return np.array(p_values)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e2d30e7b",
   "metadata": {},
   "outputs": [],
   "source": [
    "def pvalues_from_kappa(left_kappa, right_kappa, n_jobs=-1, verbose=1):\n",
    "    \"\"\"Turn precomputed score rows into one changepoint p-value per split.\n",
    "\n",
    "    ``left_kappa[t][r]`` and ``right_kappa[t][r]`` are looked up on demand as the\n",
    "    left/right rows for split ``t``. The series length is taken to be\n",
    "    ``len(left_kappa) + 1``.\n",
    "\n",
    "    Returns:\n",
    "        Array with ``len(left_kappa)`` p-values, computed through joblib.\n",
    "    \"\"\"\n",
    "    num_splits = len(left_kappa)\n",
    "    series_length = num_splits + 1\n",
    "\n",
    "    def pvalue_for_t(t: int) -> float:\n",
    "        return compute_changepoint_pvalue(\n",
    "            series_length,\n",
    "            t,\n",
    "            lambda r: left_kappa[t][r],\n",
    "            lambda r: right_kappa[t][r],\n",
    "        )\n",
    "\n",
    "    jobs = (delayed(pvalue_for_t)(t) for t in range(num_splits))\n",
    "    return np.array(Parallel(n_jobs=n_jobs, verbose=verbose)(jobs))\n",
    "\n",
    "\n",
    "def compute_kappa_classification(predictions, probabilities, num_classes, eps=1e-12):\n",
    "    \"\"\"Materialise the classification score rows for every split.\n",
    "\n",
    "    For split ``t`` (``0..n-2``):\n",
    "      * ``left_kappa[t][r]`` for ``r`` in ``0..t`` holds, for samples ``j = 0..r``,\n",
    "        ``probs[j, prefix_mode[j]] / (probs[j, suffix_mode[t + 1]] + eps)``;\n",
    "      * ``right_kappa[t][r]`` for ``r`` in ``t+1..n-1`` holds, for samples ``j = r..n-1``,\n",
    "        ``probs[j, suffix_mode[j]] / (probs[j, prefix_mode[t]] + eps)``.\n",
    "\n",
    "    Returns:\n",
    "        ``(left_kappa, right_kappa)``: two lists of dicts keyed by ``r``.\n",
    "    \"\"\"\n",
    "    if isinstance(probabilities, torch.Tensor):\n",
    "        probs = probabilities.detach().cpu().numpy()\n",
    "    else:\n",
    "        probs = np.asarray(probabilities)\n",
    "    preds = np.asarray(predictions, dtype=int)\n",
    "    n = len(preds)\n",
    "\n",
    "    prefix_mode, suffix_mode = compute_prefix_suffix_modes(preds, num_classes)\n",
    "\n",
    "    def label_ratio(first, stop, labels, reference):\n",
    "        # Per-sample ratio over rows first..stop-1: own label vs. a fixed reference label.\n",
    "        rows = np.arange(first, stop)\n",
    "        return probs[rows, labels] / (probs[rows, reference] + eps)\n",
    "\n",
    "    left_kappa, right_kappa = [], []\n",
    "    for t in range(n - 1):\n",
    "        # t + 1 <= n - 1 inside this loop, so the suffix reference always exists.\n",
    "        ref_left = suffix_mode[t + 1]\n",
    "        ref_right = prefix_mode[t]\n",
    "\n",
    "        left_kappa.append(\n",
    "            {r: label_ratio(0, r + 1, prefix_mode[: r + 1], ref_left) for r in range(0, t + 1)}\n",
    "        )\n",
    "        right_kappa.append(\n",
    "            {r: label_ratio(r, n, suffix_mode[r:], ref_right) for r in range(t + 1, n)}\n",
    "        )\n",
    "\n",
    "    return left_kappa, right_kappa\n",
    "\n",
    "\n",
    "def compute_kappa_oracle(timeseries, f0=None, f1=None):\n",
    "    if f0 is None:\n",
    "        f0 = norm(loc=-1, scale=1)\n",
    "    if f1 is None:\n",
    "        f1 = norm(loc=1, scale=1)\n",
    "\n",
    "    x = np.asarray(timeseries).reshape(-1)\n",
    "    n = len(x)\n",
    "    eps = 1e-12\n",
    "\n",
    "    left_scores = (f1.pdf(x) + eps) / (f0.pdf(x) + eps)\n",
    "    right_scores = (f0.pdf(x) + eps) / (f1.pdf(x) + eps)\n",
    "\n",
    "    left_kappa = []\n",
    "    right_kappa = []\n",
    "    for t in range(n - 1):\n",
    "        left_rows = {r: left_scores[: r + 1] for r in range(0, t + 1)}\n",
    "        right_rows = {r: right_scores[r:] for r in range(t + 1, n)}\n",
    "        left_kappa.append(left_rows)\n",
    "        right_kappa.append(right_rows)\n",
    "    return left_kappa, right_kappa\n",
    "\n",
    "\n",
    "def compute_kappa_parametric(timeseries):\n",
    "    x = np.asarray(timeseries).reshape(-1)\n",
    "    n = len(x)\n",
    "    eps = 1e-12\n",
    "\n",
    "    cumsum = np.cumsum(x)\n",
    "    r_idx = np.arange(n)\n",
    "    mu0 = cumsum / (r_idx + 1)\n",
    "\n",
    "    rcumsum = np.cumsum(x[::-1])\n",
    "\n",
    "    def suffix_mean(start: int) -> float:\n",
    "        length = n - start\n",
    "        total = rcumsum[n - 1 - start]\n",
    "        return total / length\n",
    "\n",
    "    left_kappa = []\n",
    "    right_kappa = []\n",
    "    for t in range(n - 1):\n",
    "        mu1_suffix = suffix_mean(t + 1) if (t + 1 < n) else None\n",
    "        mu0_t = mu0[t]\n",
    "\n",
    "        left_rows = {}\n",
    "        if mu1_suffix is None:\n",
    "            for r in range(0, t + 1):\n",
    "                left_rows[r] = np.array([1.0])\n",
    "        else:\n",
    "            for r in range(0, t + 1):\n",
    "                vals = x[: r + 1]\n",
    "                m0 = mu0[r]\n",
    "                num = np.exp(-0.5 * (vals - mu1_suffix) ** 2)\n",
    "                den = np.exp(-0.5 * (vals - m0) ** 2) + eps\n",
    "                left_rows[r] = num / den\n",
    "\n",
    "        right_rows = {}\n",
    "        for r in range(t + 1, n):\n",
    "            vals = x[r:]\n",
    "            mu1_r = suffix_mean(r)\n",
    "            num = np.exp(-0.5 * (vals - mu0_t) ** 2)\n",
    "            den = np.exp(-0.5 * (vals - mu1_r) ** 2) + eps\n",
    "            right_rows[r] = num / den\n",
    "\n",
    "        left_kappa.append(left_rows)\n",
    "        right_kappa.append(right_rows)\n",
    "\n",
    "    return left_kappa, right_kappa\n",
    "\n",
    "\n",
    "def compute_kde(data):\n",
    "    data = np.asarray(data).reshape(-1)\n",
    "    if len(data) <= 1:\n",
    "        mean = data[0] if len(data) == 1 else 0.0\n",
    "        return lambda x: norm.pdf(x, loc=mean, scale=0.1)\n",
    "    return gaussian_kde(data, bw_method=\"scott\")\n",
    "\n",
    "\n",
    "def compute_kappa_kde(timeseries):\n",
    "    \"\"\"Build score rows from kernel density estimates of each prefix and suffix.\n",
    "\n",
    "    For split ``t`` (``0..n-2``):\n",
    "      * ``left_kappa[t][r]`` (``r = 0..t``) is, on ``x[: r + 1]``, the density fitted\n",
    "        to ``x[t + 1:]`` divided by the density fitted to ``x[: r + 1]`` (plus ``eps``);\n",
    "      * ``right_kappa[t][r]`` (``r = t+1..n-1``) is, on ``x[r:]``, the density fitted\n",
    "        to ``x[: t + 1]`` divided by the density fitted to ``x[r:]`` (plus ``eps``).\n",
    "\n",
    "    Every prefix/suffix density, and every denominator, depends only on its own\n",
    "    endpoint, so each is computed once up front instead of being recomputed for\n",
    "    every split. The returned values are unchanged.\n",
    "\n",
    "    Returns:\n",
    "        ``(left_kappa, right_kappa)``: two lists of ``n - 1`` dicts keyed by ``r``.\n",
    "    \"\"\"\n",
    "    x = np.asarray(timeseries).reshape(-1)\n",
    "    n = len(x)\n",
    "    eps = 1e-12\n",
    "\n",
    "    # Prefixes end at 0..n-2 and suffixes start at 1..n-1: exactly the fits the loop needs.\n",
    "    prefix_kdes = {r: compute_kde(x[: r + 1]) for r in range(n - 1)}\n",
    "    suffix_kdes = {r: compute_kde(x[r:]) for r in range(1, n)}\n",
    "\n",
    "    # Denominators do not depend on the split t, so evaluate them once.\n",
    "    prefix_den = {r: kde(x[: r + 1]) + eps for r, kde in prefix_kdes.items()}\n",
    "    suffix_den = {r: kde(x[r:]) + eps for r, kde in suffix_kdes.items()}\n",
    "\n",
    "    left_kappa = []\n",
    "    right_kappa = []\n",
    "\n",
    "    for t in range(n - 1):\n",
    "        f1_suffix = suffix_kdes[t + 1]\n",
    "        f0_t = prefix_kdes[t]\n",
    "\n",
    "        left_rows = {r: f1_suffix(x[: r + 1]) / prefix_den[r] for r in range(0, t + 1)}\n",
    "        right_rows = {r: f0_t(x[r:]) / suffix_den[r] for r in range(t + 1, n)}\n",
    "\n",
    "        left_kappa.append(left_rows)\n",
    "        right_kappa.append(right_rows)\n",
    "\n",
    "    return left_kappa, right_kappa"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e5846042",
   "metadata": {},
   "outputs": [],
   "source": [
    "def pvalues_from_row_functions(n, left_row_fn, right_row_fn, n_jobs=1, verbose=0):\n",
    "    \"\"\"Compute one changepoint p-value per split from lazy row accessors.\n",
    "\n",
    "    ``left_row_fn(t, r)`` and ``right_row_fn(t, r)`` supply the score rows for split\n",
    "    ``t``. Splits ``t = 0..n-2`` are evaluated in a plain loop when ``n_jobs == 1``\n",
    "    and through joblib otherwise.\n",
    "\n",
    "    Returns:\n",
    "        Array of ``n - 1`` p-values.\n",
    "    \"\"\"\n",
    "\n",
    "    def pvalue_for_t(t: int) -> float:\n",
    "        # Bind the split index so the core routine only has to pass r.\n",
    "        return compute_changepoint_pvalue(\n",
    "            n,\n",
    "            t,\n",
    "            lambda r: left_row_fn(t, r),\n",
    "            lambda r: right_row_fn(t, r),\n",
    "        )\n",
    "\n",
    "    splits = range(n - 1)\n",
    "    if n_jobs != 1:\n",
    "        runner = Parallel(n_jobs=n_jobs, verbose=verbose)\n",
    "        return np.array(runner(delayed(pvalue_for_t)(t) for t in splits))\n",
    "    return np.array([pvalue_for_t(t) for t in splits])\n",
    "\n",
    "\n",
    "def make_classification_kappa_accessors(predictions, probabilities, num_classes, eps=1e-12):\n",
    "    \"\"\"Create lazy ``(left_row_fn, right_row_fn)`` accessors for classifier scores.\n",
    "\n",
    "    ``left_row_fn(t, r)`` returns, for samples ``j = 0..r``,\n",
    "    ``probs[j, prefix_mode[j]] / (probs[j, suffix_mode[t + 1]] + eps)``, or ``[1.0]``\n",
    "    when ``t + 1`` is past the end of the sequence. ``right_row_fn(t, r)`` returns,\n",
    "    for samples ``j = r..n-1``, ``probs[j, suffix_mode[j]] / (probs[j, prefix_mode[t]] + eps)``.\n",
    "    \"\"\"\n",
    "    if isinstance(probabilities, torch.Tensor):\n",
    "        probs = probabilities.detach().cpu().numpy()\n",
    "    else:\n",
    "        probs = np.asarray(probabilities)\n",
    "    preds = np.asarray(predictions, dtype=int)\n",
    "    n = len(preds)\n",
    "\n",
    "    prefix_mode, suffix_mode = compute_prefix_suffix_modes(preds, num_classes)\n",
    "\n",
    "    def label_ratio(rows, labels, reference):\n",
    "        # Per-sample ratio: probability of the sample's own label vs. a fixed reference label.\n",
    "        return probs[rows, labels] / (probs[rows, reference] + eps)\n",
    "\n",
    "    def left_row_fn(t: int, r: int) -> np.ndarray:\n",
    "        if t + 1 >= n:\n",
    "            return np.array([1.0])\n",
    "        return label_ratio(np.arange(0, r + 1), prefix_mode[: r + 1], suffix_mode[t + 1])\n",
    "\n",
    "    def right_row_fn(t: int, r: int) -> np.ndarray:\n",
    "        return label_ratio(np.arange(r, n), suffix_mode[r:], prefix_mode[t])\n",
    "\n",
    "    return left_row_fn, right_row_fn\n",
    "\n",
    "\n",
    "def make_oracle_row_functions(timeseries):\n",
    "    x = np.asarray(timeseries).reshape(-1)\n",
    "    n = len(x)\n",
    "    f0 = norm(loc=-1, scale=1)\n",
    "    f1 = norm(loc=1, scale=1)\n",
    "    eps = 1e-12\n",
    "    left_scores = (f1.pdf(x) + eps) / (f0.pdf(x) + eps)\n",
    "    right_scores = (f0.pdf(x) + eps) / (f1.pdf(x) + eps)\n",
    "\n",
    "    def left_row_fn(t: int, r: int) -> np.ndarray:\n",
    "        return left_scores[: r + 1]\n",
    "\n",
    "    def right_row_fn(t: int, r: int) -> np.ndarray:\n",
    "        return right_scores[r:]\n",
    "\n",
    "    return left_row_fn, right_row_fn\n",
    "\n",
    "\n",
    "def make_parametric_row_functions(timeseries):\n",
    "    x = np.asarray(timeseries).reshape(-1)\n",
    "    n = len(x)\n",
    "    eps = 1e-12\n",
    "\n",
    "    cumsum = np.cumsum(x)\n",
    "    r_idx = np.arange(n)\n",
    "    mu0 = cumsum / (r_idx + 1)\n",
    "\n",
    "    rcumsum = np.cumsum(x[::-1])\n",
    "    def suffix_mean(start: int) -> float:\n",
    "        length = n - start\n",
    "        total = rcumsum[n - 1 - start]\n",
    "        return total / length\n",
    "\n",
    "    def left_row_fn(t: int, r: int) -> np.ndarray:\n",
    "        if t + 1 >= n:\n",
    "            return np.array([1.0])\n",
    "        mu1_suffix = suffix_mean(t + 1)\n",
    "        vals = x[: r + 1]\n",
    "        m0 = mu0[r]\n",
    "        num = np.exp(-0.5 * (vals - mu1_suffix) ** 2)\n",
    "        den = np.exp(-0.5 * (vals - m0) ** 2) + eps\n",
    "        return num / den\n",
    "\n",
    "    def right_row_fn(t: int, r: int) -> np.ndarray:\n",
    "        mu0_t = mu0[t]\n",
    "        vals = x[r:]\n",
    "        mu1_r = suffix_mean(r)\n",
    "        num = np.exp(-0.5 * (vals - mu0_t) ** 2)\n",
    "        den = np.exp(-0.5 * (vals - mu1_r) ** 2) + eps\n",
    "        return num / den\n",
    "\n",
    "    return left_row_fn, right_row_fn\n",
    "\n",
    "\n",
    "def make_kde_row_functions(timeseries):\n",
    "    \"\"\"Create lazy row accessors backed by kernel density estimates.\n",
    "\n",
    "    One density is fitted up front for every prefix ``x[: r + 1]`` and every suffix\n",
    "    ``x[r:]``. ``left_row_fn(t, r)`` returns, on ``x[: r + 1]``, the suffix density\n",
    "    starting at ``t + 1`` divided by the prefix density ending at ``r`` (plus ``eps``),\n",
    "    or ``[1.0]`` when ``t + 1`` is past the end. ``right_row_fn(t, r)`` returns, on\n",
    "    ``x[r:]``, the prefix density ending at ``t`` divided by the suffix density\n",
    "    starting at ``r`` (plus ``eps``).\n",
    "    \"\"\"\n",
    "    x = np.asarray(timeseries).reshape(-1)\n",
    "    n = len(x)\n",
    "    eps = 1e-12\n",
    "\n",
    "    prefix_kdes = [compute_kde(x[: stop + 1]) for stop in range(n)]\n",
    "    suffix_kdes = [compute_kde(x[start:]) for start in range(n)]\n",
    "\n",
    "    def left_row_fn(t: int, r: int) -> np.ndarray:\n",
    "        if t + 1 >= n:\n",
    "            return np.array([1.0])\n",
    "        after_density = suffix_kdes[t + 1]\n",
    "        before_density = prefix_kdes[r]\n",
    "        window = x[: r + 1]\n",
    "        return after_density(window) / (before_density(window) + eps)\n",
    "\n",
    "    def right_row_fn(t: int, r: int) -> np.ndarray:\n",
    "        before_density = prefix_kdes[t]\n",
    "        after_density = suffix_kdes[r]\n",
    "        window = x[r:]\n",
    "        return before_density(window) / (after_density(window) + eps)\n",
    "\n",
    "    return left_row_fn, right_row_fn"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0ee52a93",
   "metadata": {},
   "source": [
    "# Gaussian mean change"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a59ed127",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-06-11T17:42:25.408905Z",
     "iopub.status.busy": "2025-06-11T17:42:25.408776Z",
     "iopub.status.idle": "2025-06-11T17:42:25.444479Z",
     "shell.execute_reply": "2025-06-11T17:42:25.444189Z"
    }
   },
   "outputs": [],
   "source": [
    "# Simulate a mean shift: xi draws from N(-1, 1) followed by n - xi draws from N(1, 1).\n",
    "timeseries = np.concatenate([np.random.normal(-1, 1, xi), np.random.normal(1, 1, n - xi)])\n",
    "\n",
    "# 1-based time axis for plotting.\n",
    "time_indices = np.arange(1, n + 1)\n",
    "\n",
    "# Quick look at the raw series.\n",
    "plt.plot(time_indices, timeseries)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9347fcc1",
   "metadata": {},
   "source": [
    "## Oracle likelihood ratio score"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d5b46420",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-06-11T17:42:25.446270Z",
     "iopub.status.busy": "2025-06-11T17:42:25.446153Z",
     "iopub.status.idle": "2025-06-11T17:42:27.372805Z",
     "shell.execute_reply": "2025-06-11T17:42:27.372417Z"
    }
   },
   "outputs": [],
   "source": [
    "def run_single_simulation_oracle(n, xi, timeseries):\n",
    "    \"\"\"Compute p-values for every split using the oracle likelihood ratio.\n",
    "\n",
    "    The left score is ``f_1.pdf / f_0.pdf`` and the right score ``f_0.pdf / f_1.pdf``\n",
    "    with ``f_0 = N(-1, 1)`` and ``f_1 = N(1, 1)``. ``xi`` is accepted but not used.\n",
    "\n",
    "    Returns:\n",
    "        Array of ``n - 1`` p-values, computed with joblib on all available cores.\n",
    "    \"\"\"\n",
    "    pre_density = norm(-1, 1).pdf(timeseries)\n",
    "    post_density = norm(1, 1).pdf(timeseries)\n",
    "    left_scores = post_density / pre_density\n",
    "    right_scores = pre_density / post_density\n",
    "\n",
    "    def pvalue_for_t(t: int) -> float:\n",
    "        return compute_changepoint_pvalue(\n",
    "            n,\n",
    "            t,\n",
    "            lambda r: left_scores[: r + 1],\n",
    "            lambda r: right_scores[r:],\n",
    "        )\n",
    "\n",
    "    jobs = (delayed(pvalue_for_t)(t) for t in range(n - 1))\n",
    "    return np.array(Parallel(n_jobs=-1, verbose=1)(jobs))\n",
    "\n",
    "# Oracle-score p-values for every candidate split of the simulated series.\n",
    "p_values = run_single_simulation_oracle(n, xi, timeseries)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "afffabdc",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2025-06-11T17:42:27.374524Z",
     "iopub.status.busy": "2025-06-11T17:42:27.374395Z",
     "iopub.status.idle": "2025-06-11T17:42:27.473730Z",
     "shell.execute_reply": "2025-06-11T17:42:27.473042Z"
    }
   },
   "outputs": [],
   "source": [
    "def plot_and_report(p_values, xi, title, outpath, alpha=0.05):\n",
    "    \"\"\"Plot the p-value curve, save it, and print summary statistics.\n",
    "\n",
    "    Args:\n",
    "        p_values: array with one p-value per candidate changepoint (plotted at 1, 2, ...).\n",
    "        xi: true changepoint (1-based); drawn as a vertical line and used in the report.\n",
    "        title: figure title.\n",
    "        outpath: destination passed to ``plt.savefig``. When it is a path, missing\n",
    "            parent folders are created first.\n",
    "        alpha: significance level for the threshold line and the confidence set\n",
    "            ``{t : p_values[t - 1] > alpha}``. Defaults to 0.05.\n",
    "    \"\"\"\n",
    "    import os\n",
    "\n",
    "    time_indices_p = np.arange(1, len(p_values) + 1)\n",
    "    plt.plot(time_indices_p, p_values)\n",
    "    # Legend labels follow the arguments instead of hard-coded numbers.\n",
    "    plt.axvline(x=xi, color=\"red\", linestyle=\"--\", label=f\"Changepoint ($\\\\xi = {xi}$)\")\n",
    "    plt.axhline(alpha, color=\"green\", linestyle=\":\", label=f\"Threshold ($\\\\alpha = {alpha}$)\")\n",
    "    plt.xlabel(\"$t$\")\n",
    "    plt.title(title)\n",
    "    plt.legend()\n",
    "\n",
    "    # Create the output folder on demand so saving works on a fresh checkout.\n",
    "    if isinstance(outpath, (str, os.PathLike)):\n",
    "        folder = os.path.dirname(os.fspath(outpath))\n",
    "        if folder:\n",
    "            os.makedirs(folder, exist_ok=True)\n",
    "    plt.savefig(outpath)\n",
    "    plt.show()\n",
    "\n",
    "    in_confidence_set = p_values > alpha\n",
    "    detected_changepoint = np.argmax(p_values) + 1\n",
    "    print(f\"\\nTrue change point: {xi}\")\n",
    "    print(f\"Detected change point: {detected_changepoint}\")\n",
    "    print(f\"Detection error: {abs(detected_changepoint - xi)}\")\n",
    "    print(f\"Size of confidence set: {np.sum(in_confidence_set)}\")\n",
    "    print(f\"Changepoint in confidence set: {p_values[xi-1] > alpha}\")\n",
    "    print(f\"CI: {np.where(in_confidence_set)[0] + 1}\")\n",
    "\n",
    "# Plot and summarise the oracle-score p-values; the figure is written to images/oracle.pdf.\n",
    "plot_and_report(p_values, xi, \"p-values for Gaussian mean change (oracle score)\", \"images/oracle.pdf\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bf08d461",
   "metadata": {},
   "source": [
    "## Parametric learned score"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "19b3bc0b",
   "metadata": {},
   "outputs": [],
   "source": [
    "def run_single_simulation_parametric(n, xi, timeseries, n_jobs=-1, verbose=1):\n",
    "    \"\"\"Compute p-values for every split using the parametric (estimated-mean) score.\n",
    "\n",
    "    ``n`` and ``xi`` are accepted but not used; the series length is taken from\n",
    "    ``timeseries``. ``n_jobs`` and ``verbose`` are forwarded to joblib and default to\n",
    "    the values that used to be fixed inside this function.\n",
    "\n",
    "    Returns:\n",
    "        Array of ``len(timeseries) - 1`` p-values.\n",
    "    \"\"\"\n",
    "    left_row_fn, right_row_fn = make_parametric_row_functions(timeseries)\n",
    "    return pvalues_from_row_functions(len(timeseries), left_row_fn, right_row_fn, n_jobs=n_jobs, verbose=verbose)\n",
    "\n",
    "# Parametric-score p-values for the same simulated series (rebinds `p_values`).\n",
    "p_values = run_single_simulation_parametric(n, xi, timeseries)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "13ada403",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot and summarise the parametric-score p-values; the figure is written to images/parametric.pdf.\n",
    "plot_and_report(p_values, xi, \"p-values for Gaussian mean change (parametric learned score)\", \"images/parametric.pdf\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1115fe7e",
   "metadata": {},
   "source": [
    "## Kernel density estimator (KDE) learned score"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9db173ba",
   "metadata": {},
   "outputs": [],
   "source": [
    "def run_single_simulation_kde(n, xi, timeseries, n_jobs=-1, verbose=1):\n",
    "    \"\"\"Compute p-values for every split using the KDE-based score.\n",
    "\n",
    "    ``n`` and ``xi`` are accepted but not used; the series length is taken from\n",
    "    ``timeseries``. ``n_jobs`` and ``verbose`` are forwarded to joblib.\n",
    "\n",
    "    Returns:\n",
    "        Array of ``len(timeseries) - 1`` p-values.\n",
    "    \"\"\"\n",
    "    kde_left_rows, kde_right_rows = make_kde_row_functions(timeseries)\n",
    "    series_length = len(timeseries)\n",
    "    return pvalues_from_row_functions(\n",
    "        series_length,\n",
    "        kde_left_rows,\n",
    "        kde_right_rows,\n",
    "        n_jobs=n_jobs,\n",
    "        verbose=verbose,\n",
    "    )\n",
    "\n",
    "# KDE-score p-values for the same simulated series (rebinds `p_values`).\n",
    "p_values = run_single_simulation_kde(n, xi, timeseries)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "05d831e7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot and summarise the KDE-score p-values; the figure is written to images/kde.pdf.\n",
    "plot_and_report(p_values, xi, \"p-values for Gaussian mean change (KDE learned score)\", \"images/kde.pdf\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "76883e3c",
   "metadata": {},
   "source": [
    "## Coverage and width simulations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8fbc2b44",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _oracle_p_value_for_t(t, timeseries):\n",
    "    \"\"\"Return the oracle-score changepoint p-value for a single split ``t``.\n",
    "\n",
    "    Defined at module level so it can be handed to joblib as a task. The scores\n",
    "    are recomputed from ``timeseries`` on every call: left is ``f_1.pdf / f_0.pdf``\n",
    "    and right is ``f_0.pdf / f_1.pdf`` with ``f_0 = N(-1, 1)`` and ``f_1 = N(1, 1)``.\n",
    "    \"\"\"\n",
    "    series_length = len(timeseries)\n",
    "    pre_density = norm(-1, 1).pdf(timeseries)\n",
    "    post_density = norm(1, 1).pdf(timeseries)\n",
    "    left_scores = post_density / pre_density\n",
    "    right_scores = pre_density / post_density\n",
    "\n",
    "    return compute_changepoint_pvalue(\n",
    "        series_length,\n",
    "        t,\n",
    "        lambda r: left_scores[: r + 1],\n",
    "        lambda r: right_scores[r:],\n",
    "    )\n",
    "\n",
    "\n",
    "def run_single_simulation(n, xi):\n",
    "    \"\"\"Simulate one Gaussian mean-shift series and return its oracle p-values.\n",
    "\n",
    "    The series has ``xi`` draws from N(-1, 1) followed by ``n - xi`` draws from\n",
    "    N(1, 1). One p-value per split ``t = 0..n-2`` is computed through joblib.\n",
    "    \"\"\"\n",
    "    before_change = np.random.normal(-1, 1, xi)\n",
    "    after_change = np.random.normal(1, 1, n - xi)\n",
    "    timeseries = np.concatenate([before_change, after_change])\n",
    "\n",
    "    task = delayed(_oracle_p_value_for_t)\n",
    "    results = Parallel(n_jobs=-1, verbose=1)(task(t, timeseries) for t in range(n - 1))\n",
    "    return np.array(results)\n",
    "\n",
    "\n",
    "def run_simulation_study(n_simulations=1000, n=500, xi=200):\n",
    "    \"\"\"Repeat the oracle simulation and collect every p-value curve.\n",
    "\n",
    "    Each replicate calls ``run_single_simulation(n, xi)``. Running coverage (at\n",
    "    levels 0.05 and 0.50), mean 95% set size and mean absolute detection error are\n",
    "    shown in the progress-bar description from the second replicate onwards.\n",
    "\n",
    "    Returns:\n",
    "        Array of shape ``(n_simulations, n - 1)`` with one row of p-values per replicate.\n",
    "    \"\"\"\n",
    "    all_p_values = np.zeros((n_simulations, n - 1))\n",
    "    history = {\"cov95\": [], \"cov50\": [], \"w95\": [], \"w50\": [], \"cp\": []}\n",
    "\n",
    "    pbar = tqdm(range(n_simulations))\n",
    "    for i in pbar:\n",
    "        p_values = run_single_simulation(n, xi)\n",
    "        all_p_values[i] = p_values\n",
    "\n",
    "        # p-value at the true changepoint (1-based xi -> index xi - 1).\n",
    "        p_at_truth = p_values[xi - 1]\n",
    "        history[\"cov95\"].append(p_at_truth > 0.05)\n",
    "        history[\"cov50\"].append(p_at_truth > 0.50)\n",
    "        history[\"w95\"].append(np.sum(p_values > 0.05))\n",
    "        history[\"w50\"].append(np.sum(p_values > 0.50))\n",
    "        history[\"cp\"].append(np.argmax(p_values) + 1)\n",
    "\n",
    "        if i == 0:\n",
    "            continue  # no running summary after the first replicate\n",
    "\n",
    "        running_cov_95 = np.mean(history[\"cov95\"])\n",
    "        running_cov_50 = np.mean(history[\"cov50\"])\n",
    "        running_width_95 = np.mean(history[\"w95\"])\n",
    "        running_error = np.mean([abs(cp - xi) for cp in history[\"cp\"]])\n",
    "        pbar.set_description(f\"Cov95:{running_cov_95:.3f} Cov50:{running_cov_50:.3f} W95:{running_width_95:.1f} Err:{running_error:.1f}\")\n",
    "    return all_p_values"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "42b8f9e4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Monte Carlo study: 1000 replicates of a length-500 series whose mean shifts after 200 observations.\n",
    "all_p_values = run_simulation_study(n_simulations=1000, n=500, xi=200)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e1906929",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Recompute the summary statistics from the stored p-value matrix.\n",
    "# NOTE(review): this rebinds the notebook-level `xi` (400 in the parameter cell) to 200,\n",
    "# the value passed to run_simulation_study above; later cells that read `xi` will see 200.\n",
    "xi = 200\n",
    "n_simulations, n_minus_1 = all_p_values.shape\n",
    "\n",
    "# Coverage: is the p-value at the true changepoint (index xi - 1) above the level?\n",
    "coverage_95 = all_p_values[:, xi-1] > 0.05\n",
    "coverage_50 = all_p_values[:, xi-1] > 0.50\n",
    "\n",
    "# Width: number of candidate splits retained in each replicate's confidence set.\n",
    "widths_95 = np.sum(all_p_values > 0.05, axis=1)\n",
    "widths_50 = np.sum(all_p_values > 0.50, axis=1)\n",
    "\n",
    "# Point estimate: 1-based position of the largest p-value in each replicate.\n",
    "detected_cps = np.argmax(all_p_values, axis=1) + 1\n",
    "detection_errors = np.abs(detected_cps - xi)\n",
    "\n",
    "print(f\"95% Coverage: {np.mean(coverage_95):.3f}\")\n",
    "print(f\"50% Coverage: {np.mean(coverage_50):.3f}\")\n",
    "print(f\"Average Width (95%): {np.mean(widths_95):.1f}\")\n",
    "print(f\"Average Width (50%): {np.mean(widths_50):.1f}\")\n",
    "print(f\"Average Detection Error: {np.mean(detection_errors):.1f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9b6fe4f1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Persist the p-value matrix to all_p_values.npy in the working directory.\n",
    "np.save('all_p_values.npy', all_p_values)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8b650cf3",
   "metadata": {},
   "source": [
    "# MNIST digit change"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "07b9ced6",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_mnist_trained_model(device=\"cpu\"):\n",
    "    \"\"\"Train a small CNN on the MNIST training set for one epoch and return it in eval mode.\n",
    "\n",
    "    The dataset is downloaded to ``./data`` if needed. Inputs are normalised with\n",
    "    mean 0.1307 and std 0.3081, so callers should apply the same scaling at\n",
    "    inference time. The model returns raw logits for the 10 digit classes.\n",
    "\n",
    "    Args:\n",
    "        device: torch device on which the model is trained.\n",
    "\n",
    "    Returns:\n",
    "        The trained ``torch.nn.Module``, switched to evaluation mode.\n",
    "    \"\"\"\n",
    "\n",
    "    class MNISTModel(torch.nn.Module):\n",
    "        \"\"\"Two conv layers, max-pooling, dropout and two linear layers.\"\"\"\n",
    "\n",
    "        def __init__(self):\n",
    "            super().__init__()\n",
    "            self.conv1 = torch.nn.Conv2d(1, 32, 3, 1)\n",
    "            self.conv2 = torch.nn.Conv2d(32, 64, 3, 1)\n",
    "            self.dropout1 = torch.nn.Dropout(0.25)\n",
    "            self.dropout2 = torch.nn.Dropout(0.5)\n",
    "            # 9216 = 64 channels * 12 * 12 after two 3x3 convs and 2x2 pooling of a 28x28 input.\n",
    "            self.fc1 = torch.nn.Linear(9216, 128)\n",
    "            self.fc2 = torch.nn.Linear(128, 10)\n",
    "\n",
    "        def forward(self, x):\n",
    "            x = self.conv1(x)\n",
    "            x = torch.nn.functional.relu(x)\n",
    "            x = self.conv2(x)\n",
    "            x = torch.nn.functional.relu(x)\n",
    "            x = torch.nn.functional.max_pool2d(x, 2)\n",
    "            x = self.dropout1(x)\n",
    "            x = torch.flatten(x, 1)\n",
    "            x = self.fc1(x)\n",
    "            x = torch.nn.functional.relu(x)\n",
    "            x = self.dropout2(x)\n",
    "            x = self.fc2(x)\n",
    "            return x\n",
    "\n",
    "    model = MNISTModel().to(device)\n",
    "\n",
    "    transform = transforms.Compose(\n",
    "        [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]\n",
    "    )\n",
    "\n",
    "    train_dataset = MNIST(root=\"./data\", train=True, download=True, transform=transform)\n",
    "    # NOTE(review): batches are read in dataset order (no shuffle) -- confirm this is intended.\n",
    "    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64)\n",
    "\n",
    "    optimizer = torch.optim.Adam(model.parameters())\n",
    "    criterion = torch.nn.CrossEntropyLoss()\n",
    "\n",
    "    print(\"Training MNIST model...\")\n",
    "    model.train()\n",
    "    for epoch in range(1):\n",
    "        for batch_idx, (data, target) in enumerate(tqdm(train_loader)):\n",
    "            data, target = data.to(device), target.to(device)\n",
    "            optimizer.zero_grad()\n",
    "            output = model(data)\n",
    "            loss = criterion(output, target)\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "\n",
    "            if batch_idx % 100 == 0:\n",
    "                # A real tab separates the progress from the loss.\n",
    "                print(\n",
    "                    f\"Epoch: {epoch} [{batch_idx*len(data)}/{len(train_loader.dataset)} \"\n",
    "                    f\"({100. * batch_idx / len(train_loader):.0f}%)]\\tLoss: {loss.item():.6f}\"\n",
    "                )\n",
    "\n",
    "    model.eval()\n",
    "    return model\n",
    "\n",
    "\n",
    "def predict_digit(model, image, device=\"cpu\"):\n",
    "    \"\"\"Classify a single 28x28 image.\n",
    "\n",
    "    Args:\n",
    "        model: Callable mapping a ``(1, 1, 28, 28)`` tensor to class logits.\n",
    "        image: Object with a ``reshape`` method (e.g. a NumPy array) holding 784 values.\n",
    "        device: Torch device on which the input tensor is created.\n",
    "\n",
    "    Returns:\n",
    "        Tuple ``(predicted, outputs)``: the arg-max class index and the softmax\n",
    "        probabilities as a CPU tensor.\n",
    "    \"\"\"\n",
    "    batch = torch.tensor(image.reshape(1, 1, 28, 28), device=device)\n",
    "\n",
    "    # Inference only, so skip building the autograd graph.\n",
    "    with torch.no_grad():\n",
    "        logits = model(batch)\n",
    "    outputs = torch.softmax(logits, dim=1).cpu()\n",
    "    predicted = outputs.argmax(dim=1).item()\n",
    "    return predicted, outputs\n",
    "\n",
    "\n",
    "def generate_mnist_dataset(length, changepoint, digit1=3, digit2=7):\n",
    "    \"\"\"Build a flattened MNIST sequence that switches from ``digit1`` to ``digit2``.\n",
    "\n",
    "    The first ``changepoint + 1`` rows are randomly chosen training images of\n",
    "    ``digit1``; the remaining ``length - changepoint - 1`` rows are images of\n",
    "    ``digit2``. Shuffling uses NumPy's global random state.\n",
    "\n",
    "    Args:\n",
    "        length: Total number of images in the sequence.\n",
    "        changepoint: 0-based index of the last ``digit1`` image.\n",
    "        digit1: Digit shown up to and including ``changepoint``.\n",
    "        digit2: Digit shown after ``changepoint``.\n",
    "\n",
    "    Returns:\n",
    "        ``float32`` array with ``length`` rows, one flattened image per row,\n",
    "        with the raw pixel values divided by 255.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If either digit has too few training images.\n",
    "    \"\"\"\n",
    "    mnist_train = MNIST(root=\"./data\", train=True, download=True, transform=transforms.ToTensor())\n",
    "    pixels = mnist_train.data.numpy()\n",
    "    labels = mnist_train.targets.numpy()\n",
    "\n",
    "    # Boolean indexing copies, so the in-place shuffles leave ``pixels`` untouched.\n",
    "    pool_before = pixels[labels == digit1]\n",
    "    pool_after = pixels[labels == digit2]\n",
    "    np.random.shuffle(pool_before)\n",
    "    np.random.shuffle(pool_after)\n",
    "\n",
    "    count_before = changepoint + 1\n",
    "    count_after = length - count_before\n",
    "    if count_before > len(pool_before) or count_after > len(pool_after):\n",
    "        raise ValueError(\"Insufficient images for the specified digits and length.\")\n",
    "\n",
    "    sequence = np.concatenate([pool_before[:count_before], pool_after[:count_after]], axis=0)\n",
    "    # Flatten each image into one row and rescale.\n",
    "    return sequence.reshape(length, -1).astype(np.float32) / 255.0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b70fc254",
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_sequential_scores(x, model, device=\"cpu\", batch_size=128):\n",
    "    \"\"\"Run the MNIST classifier over ``x`` in mini-batches.\n",
    "\n",
    "    Args:\n",
    "        x: Array-like of flattened images; reshaped to ``(len(x), 1, 28, 28)``.\n",
    "        model: Torch module producing class logits; moved to ``device`` and set to eval mode.\n",
    "        device: Torch device used for inference.\n",
    "        batch_size: Number of images per forward pass.\n",
    "\n",
    "    Returns:\n",
    "        Tuple ``(preds, probabilities)``: a list of arg-max class indices and a CPU\n",
    "        tensor of softmax probabilities with one row per image.\n",
    "    \"\"\"\n",
    "    model.to(device)\n",
    "    model.eval()\n",
    "    total = len(x)\n",
    "    inputs = torch.tensor(x, dtype=torch.float32, device=device).view(total, 1, 28, 28)\n",
    "    # Same normalisation constants as the transform used when training the model.\n",
    "    inputs = (inputs - 0.1307) / 0.3081\n",
    "\n",
    "    predicted_labels = []\n",
    "    batch_outputs = []\n",
    "    with torch.no_grad():\n",
    "        for offset in tqdm(range(0, total, batch_size)):\n",
    "            chunk = inputs[offset : min(offset + batch_size, total)]\n",
    "            chunk_probs = torch.softmax(model(chunk), dim=1).detach().cpu()\n",
    "            batch_outputs.append(chunk_probs)\n",
    "            predicted_labels.extend(chunk_probs.argmax(dim=1).tolist())\n",
    "    return predicted_labels, torch.cat(batch_outputs, dim=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bf80c99b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Train the MNIST classifier once; run_mnist_simulation reads this notebook-level ``model``.\n",
    "model = get_mnist_trained_model()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "601689df",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Changepoint index for the MNIST run (re-states the value set near the top of the notebook).\n",
    "xi = 400\n",
    "\n",
    "def _mnist_p_values(predictions, probabilities):\n",
    "    \"\"\"Compute changepoint p-values from MNIST classifier outputs.\n",
    "\n",
    "    Args:\n",
    "        predictions: Predicted class index for each image in the sequence.\n",
    "        probabilities: Per-image class probabilities aligned with ``predictions``.\n",
    "\n",
    "    Returns:\n",
    "        The result of ``pvalues_from_row_functions`` for the whole sequence.\n",
    "    \"\"\"\n",
    "    accessor_left, accessor_right = make_classification_kappa_accessors(\n",
    "        predictions, probabilities, 10  # ten digit classes\n",
    "    )\n",
    "    return pvalues_from_row_functions(\n",
    "        len(predictions), accessor_left, accessor_right, n_jobs=-1, verbose=1\n",
    "    )\n",
    "\n",
    "\n",
    "def run_mnist_simulation(n, xi):\n",
    "    \"\"\"Generate an MNIST digit-change sequence of length ``n`` and return its p-values.\n",
    "\n",
    "    Uses the notebook-level ``model`` and the default device of\n",
    "    ``compute_sequential_scores``.\n",
    "    \"\"\"\n",
    "    images = generate_mnist_dataset(n, xi)\n",
    "    scored = compute_sequential_scores(images, model)\n",
    "    predictions, probabilities = scored\n",
    "    return _mnist_p_values(predictions, probabilities)\n",
    "\n",
    "p_values = run_mnist_simulation(n, xi)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e2147a73",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Report the MNIST p-values around xi (the last argument looks like the output figure path; confirm in plot_and_report).\n",
    "plot_and_report(p_values, xi, \"p-values for MNIST digit change (digit classifier)\", \"images/mnist-pvalues.pdf\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f6862e3c",
   "metadata": {},
   "source": [
    "# SST-2 sentiment change (LLM)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "50e2545b",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_pretrained_sentiment_model(device=\"cpu\"):\n",
    "    \"\"\"Load the DistilBERT SST-2 sentiment classifier together with its tokenizer.\n",
    "\n",
    "    Args:\n",
    "        device: Torch device the model is moved to.\n",
    "\n",
    "    Returns:\n",
    "        Tuple ``(model, tokenizer)`` with the model in evaluation mode.\n",
    "    \"\"\"\n",
    "    checkpoint = \"distilbert-base-uncased-finetuned-sst-2-english\"\n",
    "    tokenizer = AutoTokenizer.from_pretrained(checkpoint)\n",
    "    model = AutoModelForSequenceClassification.from_pretrained(checkpoint)\n",
    "\n",
    "    model.to(device)\n",
    "    model.eval()\n",
    "    return model, tokenizer\n",
    "\n",
    "\n",
    "def generate_sentiment_dataset(length, changepoint, dataset_name=\"sst2\"):\n",
    "    \"\"\"Build a text sequence that switches from positive to negative sentiment.\n",
    "\n",
    "    The first ``changepoint + 1`` texts are randomly chosen positive training\n",
    "    sentences (label 1); the rest are negative ones (label 0). Shuffling uses\n",
    "    Python's global ``random`` state.\n",
    "\n",
    "    Args:\n",
    "        length: Total number of texts.\n",
    "        changepoint: 0-based index of the last positive text.\n",
    "        dataset_name: Name passed to ``load_dataset``.\n",
    "\n",
    "    Returns:\n",
    "        Tuple ``(texts, true_labels)`` of equal-length lists.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If the training split has too few texts of either label.\n",
    "    \"\"\"\n",
    "    train_split = load_dataset(dataset_name)[\"train\"]\n",
    "\n",
    "    # Partition the training sentences by label in a single pass.\n",
    "    positives, negatives = [], []\n",
    "    for row in train_split:\n",
    "        if row[\"label\"] == 1:\n",
    "            positives.append(row[\"sentence\"])\n",
    "        elif row[\"label\"] == 0:\n",
    "            negatives.append(row[\"sentence\"])\n",
    "\n",
    "    random.shuffle(positives)\n",
    "    random.shuffle(negatives)\n",
    "\n",
    "    count_pos = changepoint + 1\n",
    "    count_neg = length - count_pos\n",
    "    if count_pos > len(positives) or count_neg > len(negatives):\n",
    "        raise ValueError(\"Insufficient texts for the specified length and changepoint.\")\n",
    "\n",
    "    texts = positives[:count_pos] + negatives[:count_neg]\n",
    "    true_labels = [1] * count_pos + [0] * count_neg\n",
    "    return texts, true_labels\n",
    "\n",
    "\n",
    "def generate_mixed_sentiment_dataset(length, changepoint, dataset_name=\"sst2\",\n",
    "                                     pos_frac_pre=0.6, pos_frac_post=0.4):\n",
    "    \"\"\"Build a text sequence whose share of positive texts changes at the changepoint.\n",
    "\n",
    "    The first ``changepoint + 1`` texts contain ``int(n_pre * pos_frac_pre)`` positive\n",
    "    sentences and the remaining texts contain ``int(n_post * pos_frac_post)``; each\n",
    "    side is shuffled independently and no sentence is used twice. Shuffling uses\n",
    "    Python's global ``random`` state.\n",
    "\n",
    "    Args:\n",
    "        length: Total number of texts.\n",
    "        changepoint: 0-based index of the last pre-change text.\n",
    "        dataset_name: Name passed to ``load_dataset``.\n",
    "        pos_frac_pre: Fraction of positive texts before the change (default 0.6).\n",
    "        pos_frac_post: Fraction of positive texts after the change (default 0.4).\n",
    "\n",
    "    Returns:\n",
    "        Tuple ``(texts, true_labels)`` of equal-length lists (label 1 = positive).\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If a fraction is outside ``[0, 1]``, the changepoint does not fit\n",
    "            in ``length``, or the training split has too few texts.\n",
    "    \"\"\"\n",
    "    if not (0.0 <= pos_frac_pre <= 1.0 and 0.0 <= pos_frac_post <= 1.0):\n",
    "        raise ValueError(\"pos_frac_pre and pos_frac_post must lie in [0, 1].\")\n",
    "\n",
    "    n_pre = changepoint + 1\n",
    "    n_post = length - n_pre\n",
    "    if n_pre < 0 or n_post < 0:\n",
    "        raise ValueError(\"changepoint must satisfy 0 <= changepoint + 1 <= length.\")\n",
    "\n",
    "    train_data = load_dataset(dataset_name)[\"train\"]\n",
    "    positive_texts = [item[\"sentence\"] for item in train_data if item[\"label\"] == 1]\n",
    "    negative_texts = [item[\"sentence\"] for item in train_data if item[\"label\"] == 0]\n",
    "\n",
    "    random.shuffle(positive_texts)\n",
    "    random.shuffle(negative_texts)\n",
    "\n",
    "    n_pos_pre = int(n_pre * pos_frac_pre)\n",
    "    n_neg_pre = n_pre - n_pos_pre\n",
    "    n_pos_post = int(n_post * pos_frac_post)\n",
    "    n_neg_post = n_post - n_pos_post\n",
    "\n",
    "    if n_pos_pre + n_pos_post > len(positive_texts) or n_neg_pre + n_neg_post > len(\n",
    "        negative_texts\n",
    "    ):\n",
    "        raise ValueError(\n",
    "            \"Insufficient texts for the specified distribution and length.\"\n",
    "        )\n",
    "\n",
    "    def _shuffled_pairs(pos_part, neg_part):\n",
    "        # Pair every text with its label, then mix positives and negatives.\n",
    "        pairs = [(text, 1) for text in pos_part] + [(text, 0) for text in neg_part]\n",
    "        random.shuffle(pairs)\n",
    "        return pairs\n",
    "\n",
    "    pre_pairs = _shuffled_pairs(positive_texts[:n_pos_pre], negative_texts[:n_neg_pre])\n",
    "    # The post-change texts continue where the pre-change slices stopped.\n",
    "    post_pairs = _shuffled_pairs(\n",
    "        positive_texts[n_pos_pre : n_pos_pre + n_pos_post],\n",
    "        negative_texts[n_neg_pre : n_neg_pre + n_neg_post],\n",
    "    )\n",
    "\n",
    "    # Building the outputs from the pair lists also works when one side is empty,\n",
    "    # where unpacking ``zip(*pairs)`` would fail.\n",
    "    combined = pre_pairs + post_pairs\n",
    "    texts = [text for text, _ in combined]\n",
    "    true_labels = [label for _, label in combined]\n",
    "\n",
    "    return texts, true_labels\n",
    "\n",
    "\n",
    "def predict_sentiment(model, tokenizer, text, device=\"cpu\"):\n",
    "    \"\"\"Classify one text and return its predicted label and class probabilities.\n",
    "\n",
    "    Args:\n",
    "        model: Callable accepting the tokenizer output as keyword arguments and\n",
    "            returning an object with a ``logits`` attribute.\n",
    "        tokenizer: Callable that encodes ``text`` into tensors (truncated to 512 tokens).\n",
    "        text: The text to classify.\n",
    "        device: Torch device the encoded inputs are moved to.\n",
    "\n",
    "    Returns:\n",
    "        Tuple ``(predicted, probs)``: the arg-max class index and the softmax\n",
    "        probabilities as a CPU tensor with singleton dimensions squeezed out.\n",
    "    \"\"\"\n",
    "    encoded = tokenizer(\n",
    "        text, return_tensors=\"pt\", padding=True, truncation=True, max_length=512\n",
    "    )\n",
    "    encoded = encoded.to(device)\n",
    "\n",
    "    # Inference only, so skip building the autograd graph.\n",
    "    with torch.no_grad():\n",
    "        logits = model(**encoded).logits\n",
    "    probs = torch.softmax(logits, dim=1).cpu()\n",
    "\n",
    "    return probs.argmax(dim=1).item(), probs.squeeze()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cc830d8c",
   "metadata": {},
   "source": [
    "## Batched sentiment scoring"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "43494506",
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_sequential_sentiment_scores(texts, device=\"cpu\", batch_size=16, max_length=128,\n",
    "                                        model=None, tokenizer=None):\n",
    "    \"\"\"Score a sequence of texts with the SST-2 sentiment classifier in mini-batches.\n",
    "\n",
    "    Args:\n",
    "        texts: Sequence of strings to classify, in order.\n",
    "        device: Torch device used for inference.\n",
    "        batch_size: Number of texts per forward pass.\n",
    "        max_length: Token length at which inputs are truncated.\n",
    "        model: Optional preloaded classifier; loaded via\n",
    "            ``get_pretrained_sentiment_model`` when omitted.\n",
    "        tokenizer: Optional tokenizer matching ``model``; must be given together with it.\n",
    "\n",
    "    Returns:\n",
    "        Tuple ``(preds, probabilities)``: a list of arg-max class indices and a CPU\n",
    "        tensor of softmax probabilities with one row per text.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If only one of ``model`` and ``tokenizer`` is supplied.\n",
    "    \"\"\"\n",
    "    if (model is None) != (tokenizer is None):\n",
    "        raise ValueError(\"model and tokenizer must be provided together.\")\n",
    "    if model is None:\n",
    "        # Reuse the notebook's loader so the checkpoint name is defined in one place.\n",
    "        model, tokenizer = get_pretrained_sentiment_model(device)\n",
    "    else:\n",
    "        model.to(device)\n",
    "        model.eval()\n",
    "\n",
    "    texts = list(texts)  # accept tuples too; batches handed to the tokenizer are lists\n",
    "    n = len(texts)\n",
    "    preds = []\n",
    "    probs_list = []\n",
    "    with torch.no_grad():\n",
    "        for start in tqdm(range(0, n, batch_size)):\n",
    "            end = min(start + batch_size, n)\n",
    "            batch_texts = texts[start:end]\n",
    "            inputs = tokenizer(batch_texts, return_tensors=\"pt\", padding=True, truncation=True, max_length=max_length).to(device)\n",
    "            outputs = model(**inputs)\n",
    "            batch_probs = torch.softmax(outputs.logits, dim=1).detach().cpu()\n",
    "            probs_list.append(batch_probs)\n",
    "            preds.extend(batch_probs.argmax(dim=1).tolist())\n",
    "    probabilities = torch.cat(probs_list, dim=0)\n",
    "    return preds, probabilities"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1ed4e30b",
   "metadata": {},
   "source": [
    "## Full sentiment change (pos. to neg.)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cd6b8a7a",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _sentiment_p_values(predictions, probabilities):\n",
    "    \"\"\"Compute changepoint p-values from two-class sentiment classifier outputs.\n",
    "\n",
    "    Args:\n",
    "        predictions: Predicted class index for each text in the sequence.\n",
    "        probabilities: Per-text class probabilities aligned with ``predictions``.\n",
    "\n",
    "    Returns:\n",
    "        The result of ``pvalues_from_row_functions`` for the whole sequence.\n",
    "    \"\"\"\n",
    "    accessor_left, accessor_right = make_classification_kappa_accessors(\n",
    "        predictions, probabilities, 2  # two sentiment classes\n",
    "    )\n",
    "    return pvalues_from_row_functions(\n",
    "        len(predictions), accessor_left, accessor_right, n_jobs=-1, verbose=1\n",
    "    )\n",
    "\n",
    "\n",
    "def run_sentiment_simulation(n, xi, device=\"cpu\"):\n",
    "    \"\"\"Generate a positive-to-negative text sequence of length ``n`` and return its p-values.\"\"\"\n",
    "    # The ground-truth labels are not needed for the p-value computation.\n",
    "    texts, _ = generate_sentiment_dataset(n, xi)\n",
    "    scored = compute_sequential_sentiment_scores(texts, device=device, batch_size=16, max_length=128)\n",
    "    predictions, probabilities = scored\n",
    "    return _sentiment_p_values(predictions, probabilities)\n",
    "\n",
    "p_values = run_sentiment_simulation(n, xi)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e6eefefc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Report the p-values of the full positive-to-negative run (last argument looks like the figure path; confirm).\n",
    "plot_and_report(p_values, xi, \"p-values for SST-2 sentiment change\", \"images/sentiment-pvalues.pdf\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7790001f",
   "metadata": {},
   "source": [
    "## Mixed sentiment change (60% pos. to 60% neg.)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4c276ef4",
   "metadata": {},
   "outputs": [],
   "source": [
    "def run_mixed_sentiment_simulation(n, xi, device=\"cpu\"):\n",
    "    \"\"\"Generate a mixed-sentiment text sequence of length ``n`` and return its p-values.\"\"\"\n",
    "    # The ground-truth labels are not needed for the p-value computation.\n",
    "    texts, _ = generate_mixed_sentiment_dataset(n, xi)\n",
    "    scored = compute_sequential_sentiment_scores(texts, device=device, batch_size=16, max_length=128)\n",
    "    predictions, probabilities = scored\n",
    "    return _sentiment_p_values(predictions, probabilities)\n",
    "\n",
    "p_values = run_mixed_sentiment_simulation(n, xi)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2c52d074",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Report the p-values of the mixed-sentiment run (last argument looks like the figure path; confirm).\n",
    "plot_and_report(p_values, xi, \"p-values for SST-2 mixed sentiment change\", \"images/sentiment-pvalues-mixed.pdf\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ba53a9f9",
   "metadata": {},
   "source": [
    "# Human Activity Dataset (HAD) change"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e0e7f999",
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_had_dataset():\n",
    "    \"\"\"Download (if needed) the UCI HAR dataset and build a walking-to-sitting series.\n",
    "\n",
    "    The archive is cached under ``./data``. From the training split, rows labelled\n",
    "    1 (walking) and 4 (sitting) are located; after skipping the first 100 rows of\n",
    "    each activity, up to 400 walking rows followed by up to 600 sitting rows are\n",
    "    concatenated, keeping only the first feature column.\n",
    "\n",
    "    Returns:\n",
    "        Tuple ``(timeseries, true_changepoint)`` where ``true_changepoint`` is the\n",
    "        number of walking samples at the start of ``timeseries``.\n",
    "\n",
    "    Raises:\n",
    "        FileNotFoundError: If the extracted training files are missing.\n",
    "        ValueError: If no walking or no sitting rows remain after the offset.\n",
    "    \"\"\"\n",
    "    import urllib.request\n",
    "    import zipfile\n",
    "    import os\n",
    "\n",
    "    url = \"https://archive.ics.uci.edu/ml/machine-learning-databases/00240/UCI%20HAR%20Dataset.zip\"\n",
    "    archive_path = \"./data/UCI_HAR_Dataset.zip\"\n",
    "    train_file = \"./data/UCI HAR Dataset/train/X_train.txt\"\n",
    "    train_labels_file = \"./data/UCI HAR Dataset/train/y_train.txt\"\n",
    "\n",
    "    os.makedirs(\"./data\", exist_ok=True)\n",
    "\n",
    "    if not os.path.exists(archive_path):\n",
    "        print(\"Downloading Human Activity Dataset...\")\n",
    "        urllib.request.urlretrieve(url, archive_path)\n",
    "\n",
    "    if not os.path.exists(\"./data/UCI HAR Dataset\"):\n",
    "        print(\"Extracting dataset...\")\n",
    "        with zipfile.ZipFile(archive_path, 'r') as zip_ref:\n",
    "            zip_ref.extractall(\"./data\")\n",
    "\n",
    "    # Fail loudly instead of returning None, which the caller would try to unpack.\n",
    "    for required_path in (train_file, train_labels_file):\n",
    "        if not os.path.exists(required_path):\n",
    "            raise FileNotFoundError(f\"Expected HAR training file is missing: {required_path}\")\n",
    "\n",
    "    # Only the first feature column is used, so skip parsing the remaining columns.\n",
    "    first_feature = np.loadtxt(train_file, usecols=0)\n",
    "    y_train = np.loadtxt(train_labels_file)\n",
    "\n",
    "    walking_indices = np.where(y_train == 1)[0]\n",
    "    sitting_indices = np.where(y_train == 4)[0]\n",
    "\n",
    "    n_walking = min(400, len(walking_indices))\n",
    "    n_sitting = min(600, len(sitting_indices))\n",
    "\n",
    "    # Skip the first 100 rows of each activity before taking the requested number.\n",
    "    selected_walking = walking_indices[100:100+n_walking]\n",
    "    selected_sitting = sitting_indices[100:100+n_sitting]\n",
    "    if len(selected_walking) == 0 or len(selected_sitting) == 0:\n",
    "        raise ValueError(\"Not enough walking/sitting samples after skipping the first 100 of each.\")\n",
    "\n",
    "    combined_indices = np.concatenate([selected_walking, selected_sitting])\n",
    "    timeseries = first_feature[combined_indices]\n",
    "    # The slice can hold fewer rows than requested, so count what was actually selected.\n",
    "    true_changepoint = len(selected_walking)\n",
    "\n",
    "    return timeseries, true_changepoint\n",
    "\n",
    "\n",
    "# Load the walking-to-sitting series; later cells read ``timeseries``, ``n_had`` and ``xi_had``.\n",
    "print(\"Loading Human Activity Dataset...\")\n",
    "timeseries, true_changepoint = load_had_dataset()\n",
    "\n",
    "n_had = len(timeseries)\n",
    "xi_had = true_changepoint\n",
    "\n",
    "print(f\"HAD Dataset loaded:\")\n",
    "print(f\"Total length: {n_had} samples\")\n",
    "print(f\"True changepoint: {xi_had} (activity transition)\")\n",
    "\n",
    "# Plot on a 1-based time axis. The first xi_had samples are walking, so the first\n",
    "# sitting sample sits at t = xi_had + 1, which is where the vertical line is drawn.\n",
    "time_indices = np.arange(1, n_had + 1)\n",
    "plt.plot(time_indices, timeseries)\n",
    "plt.axvline(x=xi_had + 1, color=\"red\", linestyle=\"--\", label=f\"Changepoint ($\\\\xi = {xi_had}$)\")\n",
    "plt.xlabel(\"$t$\")\n",
    "plt.ylabel(\"Accelerometer reading\")\n",
    "plt.title(\"Human Activity Recognition accelerometer data\")\n",
    "plt.legend()\n",
    "plt.savefig(\"images/had-examples.pdf\")\n",
    "plt.show()\n",
    "\n",
    "# Split at the changepoint to compare summary statistics of the two activities.\n",
    "pre_change = timeseries[:xi_had]\n",
    "post_change = timeseries[xi_had:]\n",
    "\n",
    "print(f\"\\nDataset Statistics:\")\n",
    "print(f\"Pre-change activity mean: {np.mean(pre_change):.3f}, std: {np.std(pre_change):.3f}\")\n",
    "print(f\"Post-change activity mean: {np.mean(post_change):.3f}, std: {np.std(post_change):.3f}\")\n",
    "print(f\"Mean difference: {np.mean(pre_change) - np.mean(post_change):.3f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a6f9002e",
   "metadata": {},
   "outputs": [],
   "source": [
    "def run_had_simulation_kde():\n",
    "    \"\"\"Compute changepoint p-values for the notebook-level HAR ``timeseries``.\n",
    "\n",
    "    The row functions come from ``make_kde_row_functions`` and are evaluated by\n",
    "    ``pvalues_from_row_functions`` using all available workers.\n",
    "    \"\"\"\n",
    "    row_left, row_right = make_kde_row_functions(timeseries)\n",
    "    series_length = len(timeseries)\n",
    "    return pvalues_from_row_functions(series_length, row_left, row_right, n_jobs=-1, verbose=1)\n",
    "\n",
    "p_values = run_had_simulation_kde()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "01c9cb20",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): time_indices_p is not used in this cell; confirm whether any later cell needs it.\n",
    "time_indices_p = np.arange(1, n_had)\n",
    "# xi_had + 1 is the same position as the vertical line drawn in the data plot.\n",
    "plot_and_report(p_values, xi_had + 1, \"p-values for HAR activity change\", \"images/had-pvalues.pdf\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "94e5fd56",
   "metadata": {},
   "source": [
    "# CIFAR-100 class change"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "30ee247d",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_cifar100_pretrained_model(device=\"cpu\"):\n",
    "    \"\"\"Create the pretrained ``resnet18_cifar100`` timm model in eval mode on ``device``.\"\"\"\n",
    "    # NOTE(review): ``detectors`` is imported only for its side effects; presumably it\n",
    "    # registers the CIFAR checkpoints with timm. Confirm before removing.\n",
    "    import detectors\n",
    "    import timm\n",
    "\n",
    "    print(\"Loading pretrained ResNet18 model from timm and adapting for CIFAR-100...\")\n",
    "\n",
    "    network = timm.create_model(\"resnet18_cifar100\", pretrained=True, num_classes=100)\n",
    "    network.to(device)\n",
    "    network.eval()\n",
    "    return network"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6e2a2a7f",
   "metadata": {},
   "outputs": [],
   "source": [
    "def predict_cifar100_class(model, image, device=\"cpu\"):\n",
    "    \"\"\"Classify a single image tensor and return its label and class probabilities.\n",
    "\n",
    "    Args:\n",
    "        model: Callable mapping a batched image tensor to class logits.\n",
    "        image: Tensor with 3 dimensions (one image) or 4 dimensions (a batch, of\n",
    "            which only the first image is used).\n",
    "        device: Torch device the input is moved to.\n",
    "\n",
    "    Returns:\n",
    "        Tuple ``(predicted, probs)``: the arg-max class index and the softmax\n",
    "        probabilities as a CPU tensor with singleton dimensions squeezed out.\n",
    "    \"\"\"\n",
    "    n_dims = len(image.shape)\n",
    "    if n_dims == 3:\n",
    "        # Add the missing batch axis.\n",
    "        image = image.unsqueeze(0)\n",
    "    elif n_dims == 4 and image.shape[0] != 1:\n",
    "        # Keep only the first image of a larger batch.\n",
    "        image = image[:1]\n",
    "\n",
    "    with torch.no_grad():\n",
    "        logits = model(image.to(device))\n",
    "    probs = torch.softmax(logits, dim=1).cpu()\n",
    "\n",
    "    return probs.argmax(dim=1).item(), probs.squeeze()\n",
    "\n",
    "\n",
    "def generate_cifar100_dataset(length, changepoint, class1=15, class2=47):\n",
    "    \"\"\"Build a CIFAR-100 image sequence that switches from ``class1`` to ``class2``.\n",
    "\n",
    "    Images are drawn from the train and test splits combined. The first\n",
    "    ``changepoint + 1`` images belong to ``class1`` and the rest to ``class2``;\n",
    "    the choice is randomised with NumPy's global random state.\n",
    "\n",
    "    Args:\n",
    "        length: Total number of images.\n",
    "        changepoint: 0-based index of the last ``class1`` image.\n",
    "        class1: Class index shown up to and including ``changepoint``.\n",
    "        class2: Class index shown after ``changepoint``.\n",
    "\n",
    "    Returns:\n",
    "        Tensor stacking the ``length`` normalised images in sequence order.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If either class has too few images for the requested split.\n",
    "    \"\"\"\n",
    "    from torchvision.datasets import CIFAR100\n",
    "\n",
    "    transform = transforms.Compose([\n",
    "        transforms.ToTensor(),\n",
    "        transforms.Normalize(mean=[0.5071, 0.4867, 0.4408], std=[0.2675, 0.2565, 0.2761])\n",
    "    ])\n",
    "\n",
    "    train_data = CIFAR100(root=\"./data\", train=True, download=True, transform=transform)\n",
    "    test_data = CIFAR100(root=\"./data\", train=False, download=True, transform=transform)\n",
    "\n",
    "    # Read labels from ``targets`` so that no image has to be decoded just to filter by\n",
    "    # class. Test images are numbered after the training images.\n",
    "    n_train = len(train_data)\n",
    "    all_labels = list(train_data.targets) + list(test_data.targets)\n",
    "\n",
    "    class1_indices = [i for i, label in enumerate(all_labels) if label == class1]\n",
    "    class2_indices = [i for i, label in enumerate(all_labels) if label == class2]\n",
    "\n",
    "    np.random.shuffle(class1_indices)\n",
    "    np.random.shuffle(class2_indices)\n",
    "\n",
    "    n1 = changepoint + 1\n",
    "    n2 = length - n1\n",
    "    # Without this check a short class would silently yield fewer than ``length`` images.\n",
    "    if n1 > len(class1_indices) or n2 > len(class2_indices):\n",
    "        raise ValueError(\"Insufficient images for the specified classes and length.\")\n",
    "\n",
    "    def _load_image(index):\n",
    "        # Map a combined index back to the split it came from; only selected images are decoded.\n",
    "        if index < n_train:\n",
    "            image, _ = train_data[index]\n",
    "        else:\n",
    "            image, _ = test_data[index - n_train]\n",
    "        return image\n",
    "\n",
    "    selected_indices = class1_indices[:n1] + class2_indices[:n2]\n",
    "    selected_images = [_load_image(i) for i in selected_indices]\n",
    "\n",
    "    return torch.stack(selected_images)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "97fb7d06",
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_sequential_cifar100_scores(x, model, device=\"cpu\", batch_size=64):\n",
    "    \"\"\"Run the CIFAR-100 classifier over ``x`` in mini-batches.\n",
    "\n",
    "    Args:\n",
    "        x: Tensor of images, indexed along its first dimension.\n",
    "        model: Torch module producing class logits; moved to ``device`` and set to eval mode.\n",
    "        device: Torch device used for inference.\n",
    "        batch_size: Number of images per forward pass.\n",
    "\n",
    "    Returns:\n",
    "        Tuple ``(preds, probabilities)``: a list of arg-max class indices and a CPU\n",
    "        tensor of softmax probabilities with one row per image.\n",
    "    \"\"\"\n",
    "    model.to(device)\n",
    "    model.eval()\n",
    "    total = len(x)\n",
    "\n",
    "    predicted_labels = []\n",
    "    batch_outputs = []\n",
    "    with torch.no_grad():\n",
    "        for offset in tqdm(range(0, total, batch_size)):\n",
    "            # Move one batch at a time so the full sequence can stay where it is.\n",
    "            chunk = x[offset : min(offset + batch_size, total)].to(device)\n",
    "            chunk_probs = torch.softmax(model(chunk), dim=1).detach().cpu()\n",
    "            batch_outputs.append(chunk_probs)\n",
    "            predicted_labels.extend(chunk_probs.argmax(dim=1).tolist())\n",
    "    return predicted_labels, torch.cat(batch_outputs, dim=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8451c24e",
   "metadata": {},
   "outputs": [],
   "source": [
    "from torchvision.datasets import CIFAR100\n",
    "\n",
    "# Look up the human-readable names of the two classes used in the experiment.\n",
    "cifar100_dataset = CIFAR100(root=\"./data\", train=False, download=True)\n",
    "class_names = cifar100_dataset.classes\n",
    "\n",
    "print(\"CIFAR-100 class names:\")\n",
    "print(f\"Class 3: {class_names[3]}\")\n",
    "print(f\"Class 4: {class_names[4]}\")\n",
    "\n",
    "transform = transforms.Compose(\n",
    "    [\n",
    "        transforms.ToTensor(),\n",
    "        transforms.Normalize(\n",
    "            mean=[0.5071, 0.4867, 0.4408], std=[0.2675, 0.2565, 0.2761]\n",
    "        ),\n",
    "    ]\n",
    ")\n",
    "\n",
    "train_data = CIFAR100(root=\"./data\", train=True, download=True, transform=transform)\n",
    "test_data = CIFAR100(root=\"./data\", train=False, download=True, transform=transform)\n",
    "\n",
    "# Count from the label lists; iterating the datasets would decode every image four times.\n",
    "class3_count_train = sum(1 for label in train_data.targets if label == 3)\n",
    "class4_count_train = sum(1 for label in train_data.targets if label == 4)\n",
    "class3_count_test = sum(1 for label in test_data.targets if label == 3)\n",
    "class4_count_test = sum(1 for label in test_data.targets if label == 4)\n",
    "\n",
    "# The class 3 line reports class 3's own test count.\n",
    "print(\n",
    "    f\"\\nClass 3 ({class_names[3]}) - Train: {class3_count_train}, Test: {class3_count_test}, Total: {class3_count_train + class3_count_test}\"\n",
    ")\n",
    "print(\n",
    "    f\"Class 4 ({class_names[4]}) - Train: {class4_count_train}, Test: {class4_count_test}, Total: {class4_count_train + class4_count_test}\"\n",
    ")\n",
    "\n",
    "\n",
    "def generate_cifar100_dataset_fixed(length, changepoint, class1=3, class2=4):\n",
    "    \"\"\"Build a labelled CIFAR-100 sequence that switches from ``class1`` to ``class2``.\n",
    "\n",
    "    Images are drawn from the train and test splits combined. The first\n",
    "    ``changepoint + 1`` images belong to ``class1`` and the rest to ``class2``;\n",
    "    the choice is randomised with NumPy's global random state. A short summary\n",
    "    is printed using the notebook-level ``class_names``.\n",
    "\n",
    "    Args:\n",
    "        length: Total number of images.\n",
    "        changepoint: 0-based index of the last ``class1`` image.\n",
    "        class1: Class index shown up to and including ``changepoint``.\n",
    "        class2: Class index shown after ``changepoint``.\n",
    "\n",
    "    Returns:\n",
    "        Tuple ``(images, labels)``: a tensor stacking the normalised images and\n",
    "        the list of their class indices.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If either class has too few images for the requested split.\n",
    "    \"\"\"\n",
    "    from torchvision.datasets import CIFAR100\n",
    "\n",
    "    transform = transforms.Compose(\n",
    "        [\n",
    "            transforms.ToTensor(),\n",
    "            transforms.Normalize(\n",
    "                mean=[0.5071, 0.4867, 0.4408], std=[0.2675, 0.2565, 0.2761]\n",
    "            ),\n",
    "        ]\n",
    "    )\n",
    "\n",
    "    train_data = CIFAR100(root=\"./data\", train=True, download=True, transform=transform)\n",
    "    test_data = CIFAR100(root=\"./data\", train=False, download=True, transform=transform)\n",
    "\n",
    "    # Filter by class on the label lists (train first, then test) so that only the\n",
    "    # selected images are decoded. Shuffling index lists of the same length draws the\n",
    "    # same permutation as shuffling the image lists would.\n",
    "    n_train = len(train_data)\n",
    "    combined_labels = list(train_data.targets) + list(test_data.targets)\n",
    "    class1_indices = [i for i, label in enumerate(combined_labels) if label == class1]\n",
    "    class2_indices = [i for i, label in enumerate(combined_labels) if label == class2]\n",
    "\n",
    "    np.random.shuffle(class1_indices)\n",
    "    np.random.shuffle(class2_indices)\n",
    "\n",
    "    n1 = changepoint + 1\n",
    "    n2 = length - n1\n",
    "    # Without this check the label list would be longer than the image tensor.\n",
    "    if n1 > len(class1_indices) or n2 > len(class2_indices):\n",
    "        raise ValueError(\"Insufficient images for the specified classes and length.\")\n",
    "\n",
    "    def _load_image(index):\n",
    "        # Map a combined index back to the split it came from.\n",
    "        if index < n_train:\n",
    "            image, _ = train_data[index]\n",
    "        else:\n",
    "            image, _ = test_data[index - n_train]\n",
    "        return image\n",
    "\n",
    "    all_images = [_load_image(i) for i in class1_indices[:n1] + class2_indices[:n2]]\n",
    "    all_labels = [class1] * n1 + [class2] * n2\n",
    "\n",
    "    print(f\"Generated dataset: {len(all_images)} images\")\n",
    "    print(f\"First {n1} images are class {class1} ({class_names[class1]})\")\n",
    "    print(f\"Last {n2} images are class {class2} ({class_names[class2]})\")\n",
    "\n",
    "    actual_labels_start = all_labels[:5]\n",
    "    actual_labels_end = all_labels[-5:]\n",
    "    print(f\"First 5 labels: {actual_labels_start}\")\n",
    "    print(f\"Last 5 labels: {actual_labels_end}\")\n",
    "\n",
    "    return torch.stack(all_images), all_labels\n",
    "\n",
    "\n",
    "x_fixed, labels_fixed = generate_cifar100_dataset_fixed(n, xi, class1=3, class2=4)\n",
    "print(f\"\\nSuccessfully generated dataset with shape: {x_fixed.shape}\")\n",
    "\n",
    "def show_cifar100_examples_with_labels(x, labels, changepoint, n_examples=3):\n",
    "    \"\"\"Plot the images surrounding the changepoint, titled with time and class name.\n",
    "\n",
    "    Shows the ``n_examples`` images at times ``changepoint - n_examples + 1`` to\n",
    "    ``changepoint`` followed by the ``n_examples - 1`` images right after it, saves\n",
    "    the figure to ``images/cifar100-examples.pdf`` and displays it.\n",
    "\n",
    "    Args:\n",
    "        x: Normalised image tensors; ``x[t - 1]`` is shown for time ``t``.\n",
    "        labels: Class index of every image in ``x``.\n",
    "        changepoint: Time highlighted as xi in the panel titles.\n",
    "        n_examples: Number of panels up to and including ``changepoint``.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If ``n_examples < 1`` or the requested window leaves ``x``.\n",
    "    \"\"\"\n",
    "    if n_examples < 1:\n",
    "        raise ValueError(\"n_examples must be at least 1.\")\n",
    "    first_t = changepoint - n_examples + 1\n",
    "    last_t = changepoint + n_examples - 1\n",
    "    # A time below 1 would wrap around to the end of ``x`` through negative indexing.\n",
    "    if first_t < 1 or last_t > len(x):\n",
    "        raise ValueError(\"Requested window around the changepoint is outside the dataset.\")\n",
    "\n",
    "    mean = torch.tensor([0.5071, 0.4867, 0.4408])\n",
    "    std = torch.tensor([0.2675, 0.2565, 0.2761])\n",
    "\n",
    "    def denormalize_image(tensor):\n",
    "        # Undo the per-channel Normalize transform and clip to the displayable range.\n",
    "        denorm = tensor * std[:, None, None] + mean[:, None, None]\n",
    "        return torch.clamp(denorm, 0, 1)\n",
    "\n",
    "    class_names = CIFAR100(root=\"./data\", train=False, download=True).classes\n",
    "\n",
    "    # Derive the panel times from the arguments; the defaults give t = 398..402.\n",
    "    panels = []\n",
    "    for t in range(first_t, last_t + 1):\n",
    "        title = rf\"$t = \\xi = {t}$\" if t == changepoint else rf\"$t = {t}$\"\n",
    "        panels.append((t, title))\n",
    "\n",
    "    fig, axes = plt.subplots(1, len(panels), figsize=(3 * len(panels), 3), squeeze=False)\n",
    "\n",
    "    # NOTE(review): times are treated as 1-based here (idx = t - 1), while\n",
    "    # generate_cifar100_dataset_fixed places changepoint + 1 images of the first class\n",
    "    # at the start, so the panel titled t = changepoint + 1 still shows the first\n",
    "    # class. Confirm the intended indexing convention.\n",
    "    for ax, (t, title) in zip(axes[0], panels):\n",
    "        idx = t - 1\n",
    "        img_denorm = denormalize_image(x[idx])\n",
    "        class_name = class_names[labels[idx]]\n",
    "\n",
    "        ax.imshow(img_denorm.permute(1, 2, 0).numpy())\n",
    "        ax.set_title(f\"{title}\\n{class_name}\")\n",
    "        ax.axis(\"off\")\n",
    "\n",
    "    plt.tight_layout()\n",
    "    plt.savefig(\"images/cifar100-examples.pdf\")\n",
    "    plt.show()\n",
    "\n",
    "show_cifar100_examples_with_labels(x_fixed, labels_fixed, xi)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3585ef6c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prefer the GPU when one is available; run_cifar100_simulation reads the\n",
    "# notebook-level ``device`` and ``cifar100_model`` defined here.\n",
    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "print(f\"Using device: {device}\")\n",
    "\n",
    "cifar100_model = get_cifar100_pretrained_model(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1c34984f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sequence length and changepoint for the CIFAR-100 run.\n",
    "n_cifar = 800\n",
    "xi_cifar = 300\n",
    "\n",
    "# The image counts below mirror the split used by generate_cifar100_dataset:\n",
    "# xi + 1 pre-change images and n - xi - 1 post-change images.\n",
    "print(f\"CIFAR-100 simulation parameters:\")\n",
    "print(f\"Total length: {n_cifar}\")\n",
    "print(f\"Changepoint: {xi_cifar}\")\n",
    "print(f\"Pre-change images needed: {xi_cifar + 1}\")\n",
    "print(f\"Post-change images needed: {n_cifar - xi_cifar - 1}\")\n",
    "print(f\"Available per class: ~600 (500 train + 100 test)\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cbf17f06",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _cifar100_p_values(predictions, probabilities):\n",
    "    \"\"\"Compute changepoint p-values from CIFAR-100 classifier outputs.\n",
    "\n",
    "    Args:\n",
    "        predictions: Predicted class index for each image in the sequence.\n",
    "        probabilities: Per-image class probabilities aligned with ``predictions``.\n",
    "\n",
    "    Returns:\n",
    "        The result of ``pvalues_from_row_functions`` for the whole sequence.\n",
    "    \"\"\"\n",
    "    accessor_left, accessor_right = make_classification_kappa_accessors(\n",
    "        predictions, probabilities, 100  # one hundred CIFAR-100 classes\n",
    "    )\n",
    "    return pvalues_from_row_functions(\n",
    "        len(predictions), accessor_left, accessor_right, n_jobs=-1, verbose=1\n",
    "    )\n",
    "\n",
    "\n",
    "def run_cifar100_simulation(n, xi, class1=15, class2=47):\n",
    "    \"\"\"Generate a CIFAR-100 class-change sequence of length ``n`` and return its p-values.\n",
    "\n",
    "    Uses the notebook-level ``cifar100_model`` and ``device``.\n",
    "    \"\"\"\n",
    "    images = generate_cifar100_dataset(n, xi, class1, class2)\n",
    "    scored = compute_sequential_cifar100_scores(images, cifar100_model, device)\n",
    "    predictions, probabilities = scored\n",
    "    return _cifar100_p_values(predictions, probabilities)\n",
    "\n",
    "p_values = run_cifar100_simulation(n_cifar, xi_cifar, class1=3, class2=4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "05b4cac3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Report the CIFAR-100 p-values around xi_cifar (last argument looks like the figure path; confirm).\n",
    "plot_and_report(p_values, xi_cifar, \"p-values for CIFAR-100 class change (pretrained model)\", \"images/cifar100-pvalues.pdf\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2b15807f",
   "metadata": {},
   "source": [
    "# Multiple changepoints: kernel changepoint detection (KCPD)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4f40852d",
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import Literal, Tuple, List\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import random\n",
    "from scipy.stats import norm, gaussian_kde, ks_1samp, uniform\n",
    "\n",
    "# Re-seed Python's and NumPy's global generators before the KCPD experiments.\n",
    "SEED = 400\n",
    "random.seed(SEED)\n",
    "np.random.seed(SEED)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8158919e",
   "metadata": {},
   "outputs": [],
   "source": [
    "def kcpd(\n",
    "    x: np.ndarray,\n",
    "    sigma: float | None = None,\n",
    "    C: float = 0.06,\n",
    "    beta: float | None = None,\n",
    ") -> Tuple[List[int], float]:\n",
    "    \"\"\"Penalised kernel changepoint detection on a 1-D series via dynamic programming.\n",
    "\n",
    "    A Gaussian kernel ``K[a, b] = exp(-(x[a] - x[b])**2 / (2 * sigma**2))`` is used.\n",
    "    A segment made of samples ``t+1 .. j`` costs the sum of its diagonal kernel\n",
    "    entries minus the sum of its kernel block divided by the segment length, and\n",
    "    the recursion is ``DP[j] = min_t DP[t] + cost(t, j) + beta`` with\n",
    "    ``DP[0] = -beta`` (so a single segment pays no net penalty).\n",
    "\n",
    "    Args:\n",
    "        x: Input values; flattened to 1-D.\n",
    "        sigma: Kernel bandwidth. Defaults to the square root of the median squared\n",
    "            pairwise difference (floored at 1e-6).\n",
    "        C: Scale of the default penalty ``C * sqrt(T * max(log(max(T, 2)), 1))``.\n",
    "        beta: Per-segment penalty; when given, ``C`` is ignored.\n",
    "\n",
    "    Returns:\n",
    "        Tuple ``(cps, objective)``: the sorted changepoints ``t`` with ``1 <= t < T``\n",
    "        (a segment ends after the ``t``-th sample) and the optimal value ``DP[T]``.\n",
    "        An empty input yields ``([], 0.0)``.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If ``x`` contains non-finite values or ``sigma`` is not positive.\n",
    "    \"\"\"\n",
    "    x = np.asarray(x, dtype=float).ravel()\n",
    "    T = x.size\n",
    "    if T == 0:\n",
    "        return [], 0.0\n",
    "    if not np.all(np.isfinite(x)):\n",
    "        raise ValueError(\"x must contain only finite values.\")\n",
    "\n",
    "    diff2 = (x[:, None] - x[None, :]) ** 2\n",
    "    if sigma is None:\n",
    "        # Median heuristic over the distinct pairs; fall back to 1 for a single sample.\n",
    "        iu = np.triu_indices(T, 1)\n",
    "        med = np.median(diff2[iu]) if iu[0].size else 1.0\n",
    "        sigma = np.sqrt(max(med, 1e-12))\n",
    "    elif sigma <= 0:\n",
    "        raise ValueError(\"sigma must be positive.\")\n",
    "    K = np.exp(-diff2 / (2.0 * sigma * sigma))\n",
    "\n",
    "    # S holds 2-D prefix sums of K and D prefix sums of its diagonal, both with a\n",
    "    # leading zero row/column so that every segment sum is a difference of entries.\n",
    "    S = np.zeros((T + 1, T + 1), float)\n",
    "    S[1:, 1:] = np.cumsum(np.cumsum(K, axis=0), axis=1)\n",
    "    D = np.zeros(T + 1, float)\n",
    "    D[1:] = np.cumsum(np.diag(K))\n",
    "    S_diag = np.diag(S)\n",
    "\n",
    "    if beta is None:\n",
    "        beta = C * np.sqrt(T * max(np.log(max(T, 2)), 1.0))\n",
    "\n",
    "    DP = np.full(T + 1, np.inf)\n",
    "    prev = np.full(T + 1, -1, int)\n",
    "    DP[0] = -beta\n",
    "\n",
    "    for j in range(1, T + 1):\n",
    "        # Evaluate every candidate previous boundary t = 0 .. j-1 at once, with the\n",
    "        # same arithmetic as a scalar loop over t would use.\n",
    "        seg_len = j - np.arange(j)\n",
    "        diag_sum = D[j] - D[:j]\n",
    "        block = S[j, j] - S[:j, j] - S[j, :j] + S_diag[:j]\n",
    "        vals = DP[:j] + (diag_sum - block / seg_len) + beta\n",
    "        # argmin returns the first minimiser, i.e. the earliest t among ties.\n",
    "        best_t = int(np.argmin(vals))\n",
    "        DP[j], prev[j] = vals[best_t], best_t\n",
    "\n",
    "    # Walk the back-pointers from T to 0; interior boundaries are the changepoints.\n",
    "    cps, cur = [], T\n",
    "    while cur > 0:\n",
    "        t = prev[cur]\n",
    "        if t < 0:\n",
    "            break\n",
    "        if t > 0:\n",
    "            cps.append(t)\n",
    "        cur = t\n",
    "    cps.sort()\n",
    "    return cps, float(DP[T])\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "88661357",
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_meanshift_1d(n: int, cps: list[int], means: list[float], sigma: float = 1.0, seed: int | None = None):\n",
    "    \"\"\"Sample a piecewise-constant-mean Gaussian series.\n",
    "\n",
    "    Args:\n",
    "        n: Length of the series.\n",
    "        cps: Changepoints; values outside ``1 <= t < n`` are dropped and the rest sorted.\n",
    "        means: Mean of each segment, one more entry than the retained changepoints.\n",
    "        sigma: Standard deviation shared by all segments.\n",
    "        seed: Seed for a dedicated ``numpy.random.default_rng`` generator.\n",
    "\n",
    "    Returns:\n",
    "        Tuple ``(y, cps)``: the sampled series and the retained, sorted changepoints.\n",
    "    \"\"\"\n",
    "    rng = np.random.default_rng(seed)\n",
    "    # Keep only changepoints strictly inside the series, in increasing order.\n",
    "    cps = sorted(int(t) for t in cps if 1 <= t < n)\n",
    "    assert len(means) == len(cps) + 1, \"len(means) must be #segments = len(cps)+1\"\n",
    "    boundaries = [0, *cps, n]\n",
    "    pieces = []\n",
    "    for mean, start, stop in zip(means, boundaries[:-1], boundaries[1:]):\n",
    "        pieces.append(rng.normal(mean, sigma, stop - start))\n",
    "    return np.concatenate(pieces), cps\n",
    "\n",
    "# NOTE: this rebinds the notebook-level ``n`` (and ``t``) for the KCPD section.\n",
    "n = 1500\n",
    "true_cps = [150, 500, 820, 1100]\n",
    "means = [-1.0, 0.5, 1.5, -2.0,-1.0]\n",
    "X, cps_true = generate_meanshift_1d(n, true_cps, means, sigma=1.0, seed=400)\n",
    "\n",
    "t = np.arange(1, n + 1)\n",
    "plt.figure(figsize=(8, 4))\n",
    "plt.plot(t, X, linewidth=1)\n",
    "# Label only the first line so that the legend has one entry and no line is drawn twice.\n",
    "for k, cp in enumerate(cps_true):\n",
    "    plt.axvline(cp, linestyle=\"--\", color=\"red\", label=\"Changepoints\" if k == 0 else None)\n",
    "plt.title(\"Multiple Gaussian mean changes\")\n",
    "plt.xlabel(\"$t$\")\n",
    "plt.ylabel(\"value\")\n",
    "plt.tight_layout()\n",
    "plt.legend()\n",
    "plt.savefig(\"images/kcpd-example.pdf\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9d68e148",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Present the series as a single-column 2-D array before calling kcpd\n",
    "# (presumably kcpd expects samples along axis 0 - confirm against its definition).\n",
    "X = np.reshape(X, (-1, 1))\n",
    "\n",
    "cps_hat, obj = kcpd(X, C=0.02)\n",
    "print(\"Estimated CPs:\", cps_hat, \"| objective:\", obj)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d26fa9ef",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(10, 3.6))\n",
    "plt.plot(t, X, lw=1.0, label=\"series\")\n",
    "# Only the first line of each family carries a label, giving one legend entry per family.\n",
    "for j, cp in enumerate(cps_true):\n",
    "    plt.axvline(cp, ls=\"--\", alpha=0.9, color=\"tab:green\", label=\"true CP\" if j == 0 else None)\n",
    "for j, cp in enumerate(sorted(cps_hat)):\n",
    "    plt.axvline(cp, ls=\"--\", alpha=0.9, color=\"tab:red\", label=\"est CP\" if j == 0 else None)\n",
    "plt.title(\"True vs. estimated changepoints\")\n",
    "plt.xlabel(\"time index\")\n",
    "plt.ylabel(\"value\")\n",
    "plt.legend(ncol=3)\n",
    "plt.tight_layout()\n",
    "plt.show()\n",
    "\n",
    "# Turn the estimated changepoints into inclusive, 1-based (start, end) segments.\n",
    "cps = sorted([c for c in cps_hat if 1 <= c < len(X)])\n",
    "segments = list(zip([1] + [c+1 for c in cps], cps + [len(X)]))\n",
    "print(\"Estimated segments (1-based):\", segments)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a79e309a",
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_p_value_for_t(t, n, timeseries, left_scores, right_scores):\n",
    "    \"\"\"Combined p-value for the split that puts indices ``0..t`` on the left.\n",
    "\n",
    "    Every index gets a randomised rank-based p-value: indices ``r <= t`` are ranked\n",
    "    within ``left_scores[:r + 1]``, indices ``r > t`` within ``right_scores[r:]``,\n",
    "    with ties broken by a ``np.random.uniform`` draw. Each side is then tested for\n",
    "    uniformity with an exact one-sample KS test and the two KS p-values are combined\n",
    "    as ``1 - (1 - min(p_left, p_right)) ** 2``.\n",
    "\n",
    "    Args:\n",
    "        t: Last 0-based index of the left side; needs ``0 <= t <= n - 2`` so that\n",
    "            neither side handed to the KS test is empty.\n",
    "        n: Number of observations (length of both score arrays).\n",
    "        timeseries: Unused; kept so existing call sites keep working.\n",
    "        left_scores: NumPy array of scores used for the left side.\n",
    "        right_scores: NumPy array of scores used for the right side.\n",
    "\n",
    "    Returns:\n",
    "        The combined p-value as a float.\n",
    "    \"\"\"\n",
    "    p_values_curr = np.zeros(n)\n",
    "\n",
    "    # Left side: rank each score within the prefix that ends at it (r + 1 elements).\n",
    "    for r in range(t + 1):\n",
    "        left_segment_scores = left_scores[:r + 1]\n",
    "        n_greater = np.sum(left_scores[r] < left_segment_scores)\n",
    "        n_ties = np.sum(left_scores[r] == left_segment_scores)\n",
    "        rank = n_greater + np.random.uniform(0, 1) * n_ties\n",
    "        p_values_curr[r] = rank / (r + 1)\n",
    "\n",
    "    # Right side: rank each score within the suffix that starts at it. The suffix\n",
    "    # right_scores[r:] has n - r elements, so that is the normaliser; dividing by\n",
    "    # n - r + 1 would cap the p-values below 1 and bias the uniformity test.\n",
    "    for r in range(n - 1, t, -1):\n",
    "        right_segment_scores = right_scores[r:]\n",
    "        n_greater = np.sum(right_scores[r] < right_segment_scores)\n",
    "        n_ties = np.sum(right_scores[r] == right_segment_scores)\n",
    "        rank = n_greater + np.random.uniform(0, 1) * n_ties\n",
    "        p_values_curr[r] = rank / (n - r)\n",
    "\n",
    "    # One exact KS test per side; only the p-values are needed.\n",
    "    p_left = ks_1samp(p_values_curr[:t + 1], uniform.cdf, method=\"exact\").pvalue\n",
    "    p_right = ks_1samp(p_values_curr[t + 1:], uniform.cdf, method=\"exact\").pvalue\n",
    "\n",
    "    return 1 - (1 - min(p_left, p_right)) ** 2\n",
    "\n",
    "def run_single_simulation(data, mu0, mu1):\n",
    "    \"\"\"p-values for every candidate split of ``data`` under an N(mu0, 1) -> N(mu1, 1) change.\n",
    "\n",
    "    Args:\n",
    "        data: 1-D array of observations.\n",
    "        mu0: Mean of the unit-variance Gaussian assumed before the change.\n",
    "        mu1: Mean of the unit-variance Gaussian assumed after the change.\n",
    "\n",
    "    Returns:\n",
    "        Array of length ``len(data) - 1``; entry ``t`` is the value returned by\n",
    "        ``compute_p_value_for_t`` for the split after 0-based index ``t``.\n",
    "    \"\"\"\n",
    "    f_0 = norm(mu0, 1)\n",
    "    f_1 = norm(mu1, 1)\n",
    "\n",
    "    n = len(data)\n",
    "\n",
    "    # Scores are only compared with each other, so the log-likelihood ratio ranks the\n",
    "    # observations like f_1/f_0 (and its negation like f_0/f_1) while avoiding the\n",
    "    # 0/0 = nan that the raw pdf ratio produces once both densities underflow.\n",
    "    left_scores = f_1.logpdf(data) - f_0.logpdf(data)\n",
    "    right_scores = -left_scores\n",
    "\n",
    "    # One independent task per candidate split, spread over all available cores.\n",
    "    p_values = Parallel(n_jobs=-1, verbose=1)(\n",
    "        delayed(compute_p_value_for_t)(t, n, data, left_scores, right_scores)\n",
    "        for t in range(n - 1)\n",
    "    )\n",
    "\n",
    "    return np.array(p_values)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d5c409dd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Significance level used to build the per-window confidence sets.\n",
    "CI_ALPHA = 0.05\n",
    "\n",
    "y = np.asarray(X, float).ravel()\n",
    "n = y.size\n",
    "cps = sorted(int(t) for t in cps_hat if 1 <= t < n)\n",
    "\n",
    "# Window boundaries: 0, the midpoint between each pair of consecutive estimated\n",
    "# changepoints, and n. Each window then holds one estimated changepoint; with a\n",
    "# single estimate there is no midpoint and the whole series is the only window.\n",
    "mids = [0]\n",
    "mids += [cps[i] + (cps[i+1] - cps[i]) // 2 for i in range(len(cps) - 1)]\n",
    "mids += [n]\n",
    "mids = np.clip(np.unique(mids), 0, n).tolist()\n",
    "\n",
    "# Windows with fewer than two points have no candidate split and are dropped.\n",
    "wins = [(a, b) for a, b in zip(mids[:-1], mids[1:]) if b - a >= 2]\n",
    "\n",
    "# Window k is tested with the known means of segments k and k + 1, so fail early\n",
    "# (before any expensive computation) if there are more windows than mean pairs.\n",
    "if len(wins) + 1 > len(means):\n",
    "    raise ValueError(f\"{len(wins)} windows need {len(wins) + 1} segment means, got {len(means)}\")\n",
    "\n",
    "# p_concat[i] holds the p-value for global position i + 1; the slot on each\n",
    "# boundary between two windows is never tested and stays NaN.\n",
    "p_concat = np.full(n - 1, np.nan, float)\n",
    "ci_union = set()\n",
    "\n",
    "for k, (a, b) in enumerate(wins):\n",
    "    mu0 = means[k]\n",
    "    mu1 = means[k + 1]\n",
    "    pvals = run_single_simulation(y[a:b], mu0, mu1)\n",
    "    # Local positions 1..b-a-1 whose p-value is not rejected at level CI_ALPHA.\n",
    "    ci_local = [t for t in range(1, b - a) if pvals[t - 1] >= CI_ALPHA]\n",
    "    ci_union.update(a + int(t) for t in ci_local)\n",
    "    p_concat[a : b - 1] = pvals\n",
    "\n",
    "ci_global_sorted_MCP = sorted(ci_union)\n",
    "\n",
    "print(\"Windows [a,b):\", wins)\n",
    "print(\"Global CI (concatenated):\", ci_global_sorted_MCP)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f85615e1",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(10, 4))\n",
    "time_indices_p = np.arange(1, len(p_concat) + 1)\n",
    "plt.plot(time_indices_p, p_concat)\n",
    "\n",
    "# The loop variable is `cp`, not `xi`, so the notebook-level `xi` is left untouched;\n",
    "# enumerate-based labelling also copes with an empty list of changepoints.\n",
    "for j, cp in enumerate(cps_true):\n",
    "    plt.axvline(x=cp, color=\"red\", linestyle=\"--\", label=\"True changepoints\" if j == 0 else None)\n",
    "\n",
    "for j, cp in enumerate(cps):\n",
    "    plt.axvline(x=cp, color=\"orange\", linestyle=\"--\", label=\"Estimated changepoints\" if j == 0 else None)\n",
    "\n",
    "plt.axhline(0.05, color=\"green\", linestyle=\":\", label=\"Threshold ($\\\\alpha = 0.05$)\")\n",
    "plt.xlabel(\"$t$\")\n",
    "plt.ylabel(\"p-value\")\n",
    "plt.title(\"p-values for multiple Gaussian mean changes (KCPD)\")\n",
    "plt.legend()\n",
    "plt.savefig(\"images/kcpd-pvalues.pdf\")\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "deep-learning",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
