{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "ca571c58",
   "metadata": {},
   "source": [
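    "# Getting the synth_data\n",
    "\n",
    "Setup: load the clean reference questions (`data/clean_data.json`) and the synthetic-question records (`data/synth_data.json`), embed them with DistilBERT, and define the per-dimension two-sample statistics used to compare the clean and corrupted subsets with the reference set."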
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1895d85b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import random\n",
    "import re\n",
    "from collections import defaultdict\n",
    "from typing import List, Optional\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "from scipy.stats import ks_2samp, wasserstein_distance, cramervonmises_2samp\n",
    "from tqdm import tqdm\n",
    "from transformers import DistilBertTokenizer, DistilBertModel"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c46e9e96",
   "metadata": {},
   "outputs": [],
   "source": [
    "_PREFIX_RE = re.compile(r'^[1-5]\\.\\s+')\n",
    "\n",
    "def remove_number_prefixes(string: str) -> str:\n",
    "    \"\"\"\n",
    "    Strip one leading list-number prefix ('1. ' .. '5. ') from a single string.\n",
    "\n",
    "    The prefix is a digit 1-5 at the very start, a '.', and the run of\n",
    "    whitespace after it; strings without such a prefix are returned unchanged.\n",
    "    get_stats applies this to each synthetic question before de-duplicating.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    string : str\n",
    "        A single string (not a list; re.Pattern.sub operates on one string).\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    str\n",
    "        `string` without its leading number prefix.\n",
    "    \"\"\"\n",
    "    return _PREFIX_RE.sub('', string)\n",
    "\n",
    "def get_distilbert_embeddings(\n",
    "    texts: List[str],\n",
    "    batch_size: int = 32,\n",
    "    model_name: str = 'distilbert-base-uncased',\n",
    "    device: Optional[str] = None\n",
    ") -> List[List[float]]:\n",
    "    \"\"\"\n",
    "    Encode `texts` with DistilBERT and return the [CLS] vectors, dimension-major.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    texts : list of str\n",
    "        Texts to encode; tokenized with padding and truncation per batch.\n",
    "    batch_size : int\n",
    "        Number of texts per forward pass.\n",
    "    model_name : str\n",
    "        Checkpoint used for both the tokenizer and the model.\n",
    "    device : str or None\n",
    "        Torch device. None (default) selects 'cuda' when available, else 'cpu'.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    list of list of float\n",
    "        Transposed embeddings: result[d][j] is hidden unit d of the position-0\n",
    "        ([CLS]) vector of texts[j]. There is one inner list per hidden unit\n",
    "        (768 for distilbert-base-uncased), each of length len(texts); empty\n",
    "        input gives [].\n",
    "    \"\"\"\n",
    "    # Defaulting device to None lets this fallback pick 'cuda' or 'cpu'; an\n",
    "    # explicit device string is used as-is.\n",
    "    device = device or ('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "    # NOTE: tokenizer and model are reloaded on every call.\n",
    "    tokenizer = DistilBertTokenizer.from_pretrained(model_name)\n",
    "    model = DistilBertModel.from_pretrained(model_name).to(device)\n",
    "    model.eval()\n",
    "\n",
    "    all_embeddings = []\n",
    "    for start in tqdm(range(0, len(texts), batch_size), desc=\"Encoding\"):\n",
    "        batch_texts = texts[start:start + batch_size]\n",
    "        inputs = tokenizer(batch_texts, padding=True, truncation=True, return_tensors=\"pt\").to(device)\n",
    "\n",
    "        with torch.no_grad():\n",
    "            outputs = model(**inputs)\n",
    "        # Position 0 of the last hidden layer is the [CLS] token: (batch, hidden_size).\n",
    "        cls_batch = outputs.last_hidden_state[:, 0, :]\n",
    "        all_embeddings.extend(cls_batch.cpu().tolist())\n",
    "\n",
    "    # Text-major -> dimension-major, so result[d] is one 1D sample per hidden unit.\n",
    "    return [list(dim_values) for dim_values in zip(*all_embeddings)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bba73f1b",
   "metadata": {},
   "outputs": [],
   "source": [
    "def random_cdf_diff(data1, data2, rng, n_samples=10):\n",
    "    \"\"\"\n",
    "    Monte-Carlo estimate of the mean squared gap between two empirical CDFs.\n",
    "\n",
    "    Draws `n_samples` points (with replacement) from `data1` using `rng`,\n",
    "    evaluates the empirical CDFs F1 of data1 and F2 of data2 at those points,\n",
    "    and averages (F1(x) - F2(x))**2 over them.\n",
    "\n",
    "    Parameters:\n",
    "    - data1: 1D array-like of numeric observations; evaluation points come from here.\n",
    "    - data2: 1D array-like of numeric observations for the second sample.\n",
    "    - rng: numpy.random.Generator or numpy.random.RandomState (anything with .choice).\n",
    "    - n_samples: int, number of points to draw from data1 (default=10).\n",
    "\n",
    "    Returns:\n",
    "    - float: the average squared CDF difference, in [0, 1]. The sampled\n",
    "      points themselves are not returned.\n",
    "    \"\"\"\n",
    "    data1 = np.asarray(data1)\n",
    "    data2 = np.asarray(data2)\n",
    "\n",
    "    # 1) sample the evaluation points from data1\n",
    "    xs = rng.choice(data1, size=n_samples, replace=True)\n",
    "\n",
    "    # 2) empirical CDF at x = (# observations <= x) / n. searchsorted with\n",
    "    #    side='right' on the sorted sample counts exactly that, for all xs at once.\n",
    "    sorted1 = np.sort(data1)\n",
    "    sorted2 = np.sort(data2)\n",
    "    F1 = np.searchsorted(sorted1, xs, side='right') / len(sorted1)\n",
    "    F2 = np.searchsorted(sorted2, xs, side='right') / len(sorted2)\n",
    "\n",
    "    # 3) average squared difference, returned as a plain Python float\n",
    "    return np.mean((F1 - F2) ** 2).item()\n",
    "\n",
    "\n",
    "def ks_statistic(x, y):\n",
    "    \"\"\"\n",
    "    Two-sample Kolmogorov-Smirnov D statistic for samples x and y: the largest\n",
    "    absolute gap between their empirical CDFs. The p-value is discarded.\n",
    "    \"\"\"\n",
    "    ks_result = ks_2samp(x, y)\n",
    "    return ks_result.statistic\n",
    "\n",
    "\n",
    "def wasserstein(x, y):\n",
    "    \"\"\"\n",
    "    Earth-mover's (1-Wasserstein) distance between the 1D samples x and y,\n",
    "    each treated as an unweighted empirical distribution.\n",
    "    \"\"\"\n",
    "    distance = wasserstein_distance(x, y)\n",
    "    return distance\n",
    "\n",
    "\n",
    "def js_divergence(x, y):\n",
    "    \"\"\"\n",
    "    Jensen-Shannon divergence (natural log) between the empirical distributions\n",
    "    of samples x and y, using each distinct value as its own bin.\n",
    "\n",
    "    Values are matched exactly, so for continuous data with few ties the two\n",
    "    distributions share almost no bins and the result approaches ln(2).\n",
    "    A constant 1e-12 is added to every bin probability to avoid log(0).\n",
    "    \"\"\"\n",
    "    eps = 1e-12\n",
    "\n",
    "    vals_x, freq_x = np.unique(x, return_counts=True)\n",
    "    vals_y, freq_y = np.unique(y, return_counts=True)\n",
    "    support = np.union1d(vals_x, vals_y)\n",
    "\n",
    "    def _smoothed_pmf(values, freqs):\n",
    "        # Scatter relative frequencies onto the shared, sorted support, then smooth.\n",
    "        pmf = np.zeros(support.shape, dtype=float)\n",
    "        pmf[np.searchsorted(support, values)] = freqs / freqs.sum()\n",
    "        pmf += eps\n",
    "        return pmf\n",
    "\n",
    "    p = _smoothed_pmf(vals_x, freq_x)\n",
    "    q = _smoothed_pmf(vals_y, freq_y)\n",
    "    m = 0.5 * (p + q)\n",
    "\n",
    "    kl_p_m = np.sum(p * np.log(p / m))\n",
    "    kl_q_m = np.sum(q * np.log(q / m))\n",
    "    return 0.5 * (kl_p_m + kl_q_m)\n",
    "\n",
    "\n",
    "def cvm_statistic(x, y):\n",
    "    \"\"\"\n",
    "    Two-sample Cramér-von Mises criterion T for samples x and y (p-value discarded).\n",
    "    \"\"\"\n",
    "    cvm_result = cramervonmises_2samp(x, y)\n",
    "    return cvm_result.statistic\n",
    "\n",
    "\n",
    "def mean_difference(x, y):\n",
    "    \"\"\"\n",
    "    Absolute difference in sample means, |mean(x) - mean(y)|.\n",
    "\n",
    "    Symmetric in x and y and never negative, like the other distance-style\n",
    "    statistics here; the direction of the shift is not reported.\n",
    "    \"\"\"\n",
    "    return np.abs(np.mean(x) - np.mean(y))\n",
    "\n",
    "\n",
    "def mean_py(lst):\n",
    "    \"\"\"Plain-Python arithmetic mean of a sequence; NaN when it is empty.\"\"\"\n",
    "    if not lst:\n",
    "        return float('nan')\n",
    "    return sum(lst) / len(lst)\n",
    "\n",
    "\n",
    "def eval_b(o, cor, clean, baseline, rng=None, n_dims=None):\n",
    "    \"\"\"\n",
    "    Apply a two-sample statistic dimension by dimension and summarise it.\n",
    "\n",
    "    For each dimension i, `baseline` compares the reference sample o[i] with\n",
    "    cor[i] and with clean[i].\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    o, cor, clean : sequences of 1D samples\n",
    "        Dimension-major embeddings (reference, corrupted, clean).\n",
    "    baseline : callable\n",
    "        Called as baseline(a, b), or baseline(a, b, rng) when rng is given.\n",
    "    rng : random generator or None\n",
    "        Passed through to `baseline` for randomized statistics.\n",
    "    n_dims : int or None\n",
    "        Number of leading dimensions to evaluate; None means len(o), which is\n",
    "        768 for the DistilBERT embeddings used here.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    list\n",
    "        [max_clean, max_cor, mean_clean, mean_cor] taken over dimensions.\n",
    "    \"\"\"\n",
    "    if n_dims is None:\n",
    "        n_dims = len(o)\n",
    "\n",
    "    # Compare against None explicitly instead of relying on the RNG's truthiness.\n",
    "    extra_args = () if rng is None else (rng,)\n",
    "\n",
    "    vals_cor = []\n",
    "    vals_clean = []\n",
    "    for i in tqdm(range(n_dims)):\n",
    "        # Corrupt before clean: keeps the order in which a shared rng is consumed.\n",
    "        vals_cor.append(baseline(o[i], cor[i], *extra_args))\n",
    "        vals_clean.append(baseline(o[i], clean[i], *extra_args))\n",
    "\n",
    "    return [max(vals_clean), max(vals_cor), mean_py(vals_clean), mean_py(vals_cor)]\n",
    "\n",
    "# Reference embeddings, computed once: get_stats compares every experiment against\n",
    "# them. The (misspelled) name otheres_embeds is what get_stats reads, so keep it.\n",
    "with open('data/clean_data.json', 'r') as f:\n",
    "    others_data = json.load(f)\n",
    "otheres_embeds = get_distilbert_embeddings(others_data)\n",
    "\n",
    "\n",
    "def get_stats(counts=500, ftn=random_cdf_diff, rng_seed=None):\n",
    "    '''\n",
    "    Run one experiment: a given test statistic on subsets drawn with a given seed.\n",
    "\n",
    "    1. Load data/synth_data.json and shuffle its records (seeded by rng_seed).\n",
    "    2. From the first `counts` records, the 'questions' form the clean subset;\n",
    "       the de-duplicated 'synth_questions' ending in '?' (number prefix\n",
    "       stripped) are appended to form the corrupt subset.\n",
    "    3. Embed both subsets and compare each, per dimension, with the reference\n",
    "       embeddings `otheres_embeds` using `ftn` (see eval_b).\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    counts : int\n",
    "        Number of synth_data records to use; must not exceed the file's length.\n",
    "    ftn : callable\n",
    "        Two-sample statistic; random_cdf_diff additionally receives the RNG.\n",
    "    rng_seed : int or None\n",
    "        Seed for the shuffle and the RNG. None gives an unseeded (non-reproducible) run.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    list\n",
    "        [max_clean, max_cor, mean_clean, mean_cor, n_clean_questions, n_synth_questions]\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    ValueError\n",
    "        If `counts` exceeds the number of records in data/synth_data.json.\n",
    "    '''\n",
    "    # Both generators accept None (fresh OS entropy), so unseeded runs still hand\n",
    "    # random_cdf_diff an RNG, and a seed of 0 is treated as a real seed. A private\n",
    "    # random.Random gives the same permutation as random.seed(...) followed by\n",
    "    # random.shuffle(...), without reseeding the global random module.\n",
    "    rng = np.random.default_rng(rng_seed)\n",
    "    shuffler = random.Random(rng_seed)\n",
    "\n",
    "    with open('data/synth_data.json', 'r') as f:\n",
    "        synth_data = json.load(f)\n",
    "\n",
    "    if counts > len(synth_data):\n",
    "        raise ValueError(f'counts={counts} exceeds the {len(synth_data)} records in data/synth_data.json')\n",
    "\n",
    "    shuffler.shuffle(synth_data)\n",
    "\n",
    "    clean_subset = []\n",
    "    synth_q = set()\n",
    "    for record in synth_data[:counts]:\n",
    "        clean_subset.extend(record['questions'])\n",
    "        for q in record['synth_questions']:\n",
    "            # endswith also handles '' (q[-1] would raise IndexError).\n",
    "            if q.endswith('?'):\n",
    "                synth_q.add(remove_number_prefixes(q))\n",
    "\n",
    "    # Sort the set so the corrupt subset's order (and hence batching/padding in\n",
    "    # get_distilbert_embeddings) does not depend on string hashing (PYTHONHASHSEED).\n",
    "    corrupt_subset = clean_subset + sorted(synth_q)\n",
    "\n",
    "    corrupt_embeds = get_distilbert_embeddings(corrupt_subset)\n",
    "    clean_embeds = get_distilbert_embeddings(clean_subset)\n",
    "\n",
    "    out = eval_b(otheres_embeds, corrupt_embeds, clean_embeds, ftn,\n",
    "                 rng=rng if ftn.__name__ == 'random_cdf_diff' else None)\n",
    "    out.append(len(clean_subset))\n",
    "    out.append(len(synth_q))\n",
    "\n",
    "    return out\n",
    "\n",
    "\n",
    "def get_key_stats(d, name):\n",
    "    \"\"\"\n",
    "    Print `name` followed by the across-run mean of columns 0-3 of `d` and the\n",
    "    value of column 4 from the first row.\n",
    "\n",
    "    Each row of `d` is a get_stats result, so the columns are max_clean, max_cor,\n",
    "    mean_clean, mean_cor and the clean-subset size.\n",
    "    \"\"\"\n",
    "    num_runs = len(d)\n",
    "    count_clean = d[0][4]\n",
    "\n",
    "    column_means = []\n",
    "    for col in range(4):\n",
    "        column_total = 0\n",
    "        for row in d:\n",
    "            column_total += row[col]\n",
    "        column_means.append(column_total / num_runs)\n",
    "\n",
    "    max_clean, max_cor, mean_clean, mean_cor = column_means\n",
    "    print(f'{name} {max_clean} {max_cor} {mean_clean} {mean_cor} {count_clean}')\n",
    "\n",
    "\n",
    "def get_key_stats_var(d, name, ddof=1):\n",
    "    \"\"\"\n",
    "    Print the across-run variance of each of the first four numeric columns in `d`,\n",
    "    in the same layout as get_key_stats.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    d : list[tuple | list]\n",
    "        Sequence where each element is index-able, the first four positions\n",
    "        are numeric, and position 4 (printed from the first row) is a count.\n",
    "    name : str\n",
    "        Label printed at the start of the output line.\n",
    "    ddof : int, default 1\n",
    "        Delta degrees of freedom: squared deviations are divided by n - ddof.\n",
    "        The default 1 gives the unbiased sample variance; pass 0 for the\n",
    "        population variance. When n - ddof <= 0 the variance is undefined and\n",
    "        NaN is printed instead of raising ZeroDivisionError.\n",
    "    \"\"\"\n",
    "    n = len(d)\n",
    "    count_clean = d[0][4]\n",
    "\n",
    "    # ---------- first pass: means ----------\n",
    "    sums = [0.0, 0.0, 0.0, 0.0]\n",
    "    for row in d:\n",
    "        for i in range(4):\n",
    "            sums[i] += row[i]\n",
    "    means = [s / n for s in sums]\n",
    "\n",
    "    # ---------- second pass: squared deviations ----------\n",
    "    sq_devs = [0.0, 0.0, 0.0, 0.0]\n",
    "    for row in d:\n",
    "        for i in range(4):\n",
    "            diff = row[i] - means[i]\n",
    "            sq_devs[i] += diff * diff\n",
    "\n",
    "    denom = n - ddof\n",
    "    variances = [sd / denom if denom > 0 else float('nan') for sd in sq_devs]\n",
    "\n",
    "    print(f\"{name} \"\n",
    "          f\"{variances[0]} {variances[1]} {variances[2]} {variances[3]} \"\n",
    "          f\"{count_clean}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c0d083c6",
   "metadata": {},
   "source": [
    "# Run the experiments"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "608b78fb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Experiment grid: subset size x seed x statistic. Each (count, statistic) key\n",
    "# collects one get_stats result per seed.\n",
    "# NOTE: for a fixed (count, seed) every statistic sees the same subsets, yet\n",
    "# get_stats re-embeds them for each ftn.\n",
    "counts = [10, 50, 100, 500, 1000]\n",
    "rng_seeds = [4, 123653, 434, 54, 87910]\n",
    "ftns = [random_cdf_diff, ks_statistic, wasserstein, js_divergence, cvm_statistic, mean_difference]\n",
    "\n",
    "data_dict = defaultdict(list)\n",
    "\n",
    "for count in tqdm(counts):\n",
    "    for r_s in rng_seeds:\n",
    "        for ftn in ftns:\n",
    "            result_key = f'{count}_{ftn.__name__}'\n",
    "            out = get_stats(counts=count, ftn=ftn, rng_seed=r_s)\n",
    "            data_dict[result_key].append(out)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "efae356c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mean over seeds of each summary column, one line per (count, statistic) key.\n",
    "for key, runs in data_dict.items():\n",
    "    get_key_stats(runs, key)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ab7b4b6d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sample variance (n-1) over seeds of each summary column, one line per key.\n",
    "for key, runs in data_dict.items():\n",
    "    get_key_stats_var(runs, key)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "dai",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
