{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "ca571c58",
   "metadata": {},
   "source": [
    "# Getting the synth_data "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1895d85b",
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm import tqdm\n",
    "import numpy as np\n",
    "import random\n",
    "from scipy.stats import ks_2samp, wasserstein_distance, cramervonmises_2samp\n",
    "from collections import defaultdict\n",
    "import pickle\n",
    "\n",
    "\n",
    "def random_cdf_diff(data1, data2, rng, n_samples=10):\n",
    "    \"\"\"\n",
    "    Randomly sample one or more points from data1 using the provided RNG,\n",
    "    compute the squared difference between the empirical CDFs of data1 and data2\n",
    "    at those points, and return the average squared difference.\n",
    "\n",
    "    Parameters:\n",
    "    - data1: 1D array-like of numeric observations for first sample.\n",
    "    - data2: 1D array-like of numeric observations for second sample.\n",
    "    - rng: numpy.random.RandomState or numpy.random.Generator instance.\n",
    "    - n_samples: int, number of points to sample from data1 (default=1).\n",
    "\n",
    "    Returns:\n",
    "    - xs: ndarray of shape (n_samples,) containing the sampled points.\n",
    "    - avg_sqdiff: float, the average of (F1(x) - F2(x))**2 over the sampled points.\n",
    "    \"\"\"\n",
    "    # ensure arrays\n",
    "    data1 = np.asarray(data1)\n",
    "    data2 = np.asarray(data2)\n",
    "\n",
    "    # 1) sample n_samples points from data1\n",
    "    xs = rng.choice(data1, size=n_samples, replace=True)\n",
    "\n",
    "    # 2) sort once for each dataset\n",
    "    sorted1 = np.sort(data1)\n",
    "    sorted2 = np.sort(data2)\n",
    "\n",
    "    # 3) compute squared CDF‐differences\n",
    "    sq_diffs = []\n",
    "    for x in xs:\n",
    "        F1 = np.searchsorted(sorted1, x, side='right') / len(sorted1)\n",
    "        F2 = np.searchsorted(sorted2, x, side='right') / len(sorted2)\n",
    "        sq_diffs.append((F1 - F2) ** 2)\n",
    "\n",
    "    # 4) average and return\n",
    "    avg_sqdiff = np.mean(sq_diffs).item()\n",
    "    return avg_sqdiff\n",
    "\n",
    "\n",
    "def ks_statistic(x, y):\n",
    "    \"\"\"\n",
    "    Kolmogorov-Smirnov statistic between samples x and y.\n",
    "    Returns the KS D-statistic (max difference between empirical CDFs).\n",
    "    \"\"\"\n",
    "    return ks_2samp(x, y).statistic\n",
    "\n",
    "\n",
    "def wasserstein(x, y):\n",
    "    \"\"\"\n",
    "    1D Wasserstein (Earth-Mover's) distance between x and y.\n",
    "    \"\"\"\n",
    "    return wasserstein_distance(x, y)\n",
    "\n",
    "\n",
    "\n",
    "def js_divergence(x, y):\n",
    "    # 1) build support\n",
    "    xs, counts_x = np.unique(x, return_counts=True)\n",
    "    ys, counts_y = np.unique(y, return_counts=True)\n",
    "    support = np.union1d(xs, ys)\n",
    "\n",
    "    # 2) build P and Q on that support\n",
    "    p = np.zeros_like(support, dtype=float)\n",
    "    q = np.zeros_like(support, dtype=float)\n",
    "\n",
    "    # assign probabilities by counts\n",
    "    idx_x = np.searchsorted(support, xs)\n",
    "    p[idx_x] = counts_x / counts_x.sum()\n",
    "\n",
    "    idx_y = np.searchsorted(support, ys)\n",
    "    q[idx_y] = counts_y / counts_y.sum()\n",
    "\n",
    "    # 3) avoid zeros & compute Jensen–Shannon\n",
    "    eps = 1e-12\n",
    "    p += eps\n",
    "    q += eps\n",
    "    m = 0.5 * (p + q)\n",
    "\n",
    "    return 0.5 * (np.sum(p * np.log(p/m)) + np.sum(q * np.log(q/m)))\n",
    "\n",
    "def cvm_statistic(x, y):\n",
    "    \"\"\"\n",
    "    Cramér-von Mises two-sample statistic between x and y.\n",
    "    \"\"\"\n",
    "    return cramervonmises_2samp(x, y).statistic\n",
    "\n",
    "def mean_difference(x, y):\n",
    "    \"\"\"\n",
    "    Difference in sample means between x and y (mean(x) - mean(y)).\n",
    "    \"\"\"\n",
    "    return np.abs(np.mean(x) - np.mean(y))\n",
    "\n",
    "\n",
    "def mean_py(lst):\n",
    "    return sum(lst) / len(lst) if lst else float('nan')\n",
    "\n",
    "def eval_b(o, cor, clean, baseline, rng=None, n_features=384):\n",
    "    \"\"\"\n",
    "    Score a reference set against a corrupted and a clean set, feature by feature.\n",
    "\n",
    "    For each feature index i in range(n_features), `baseline` is evaluated on\n",
    "    (o[i], cor[i]) and then on (o[i], clean[i]); the per-feature scores are\n",
    "    summarised by their maximum and their mean.\n",
    "\n",
    "    Parameters:\n",
    "    - o, cor, clean: indexable collections holding one sample per feature\n",
    "      index; each needs at least n_features entries.\n",
    "    - baseline: callable(sample_a, sample_b[, rng]) returning a numeric score.\n",
    "    - rng: optional random generator; when given it is forwarded to `baseline`\n",
    "      as a third positional argument (required by random_cdf_diff).\n",
    "    - n_features: int, number of leading feature indices to evaluate\n",
    "      (default=384, the value that was previously hard-coded). Must be >= 1,\n",
    "      otherwise max() has nothing to reduce.\n",
    "\n",
    "    Returns:\n",
    "    - list [max_clean, max_cor, mean_clean, mean_cor].\n",
    "    \"\"\"\n",
    "    # Test against None instead of truthiness so a supplied generator is always forwarded.\n",
    "    extra_args = (rng,) if rng is not None else ()\n",
    "\n",
    "    vals_cor = []\n",
    "    vals_clean = []\n",
    "    for i in tqdm(range(n_features)):\n",
    "        # corrupted first, then clean: the order fixes how a shared rng is consumed\n",
    "        vals_cor.append(baseline(o[i], cor[i], *extra_args))\n",
    "        vals_clean.append(baseline(o[i], clean[i], *extra_args))\n",
    "\n",
    "    return [max(vals_clean), max(vals_cor), mean_py(vals_clean), mean_py(vals_cor)]\n",
    "\n",
    "\n",
    "def read_pickle_file(file_path):\n",
    "    \"\"\"\n",
    "    Reads a pickle file and returns the loaded object.\n",
    "\n",
    "    Args:\n",
    "        file_path (str): The path to the pickle file.\n",
    "\n",
    "    Returns:\n",
    "        object: The object loaded from the pickle file.\n",
    "                Returns None if an error occurs during reading.\n",
    "    \"\"\"\n",
    "    try:\n",
    "        with open(file_path, 'rb') as file:\n",
    "            data = pickle.load(file)\n",
    "        return data\n",
    "    except FileNotFoundError:\n",
    "        print(f\"Error: File not found at path: {file_path}\")\n",
    "        return None\n",
    "    except Exception as e:\n",
    "         print(f\"An error occurred: {e}\")\n",
    "         return None\n",
    "    \n",
    "# Load the three pickled feature sets. read_pickle_file returns None on failure,\n",
    "# in which case the transpose below raises a TypeError.\n",
    "d_clean = read_pickle_file('./flowers_heldout_features.pkl')\n",
    "d_synth = read_pickle_file('./flowers_augmented_synth_features.pkl')\n",
    "d_real = read_pickle_file('./flowers_augmented_real_features.pkl')\n",
    "# Transpose d_clean: entry i collects the i-th value of every item.\n",
    "otheres_embeds = [list(x) for x in zip(*d_clean)] # heldout\n",
    "\n",
    "\n",
    "def get_stats(counts=500, ftn=random_cdf_diff, rng_seed=None):\n",
    "    \"\"\"\n",
    "    Run one metric on a random subsample of the synthetic and real feature sets.\n",
    "\n",
    "    Shuffles copies of the module-level `d_synth` and `d_real`, keeps the first\n",
    "    `counts` items of each, and scores the held-out features (`otheres_embeds`)\n",
    "    against (a) synthetic + real and (b) real only, via eval_b.\n",
    "\n",
    "    Parameters:\n",
    "    - counts: int, number of items kept from each of d_synth and d_real.\n",
    "    - ftn: metric callable; random_cdf_diff additionally receives a NumPy Generator.\n",
    "    - rng_seed: optional int seed. When given (0 included) it seeds both the\n",
    "      NumPy Generator and the global `random` module used for shuffling.\n",
    "      When None, the shuffle uses the current global `random` state and the\n",
    "      Generator is seeded from fresh entropy.\n",
    "\n",
    "    Returns:\n",
    "    - list [max_clean, max_cor, mean_clean, mean_cor, counts, counts].\n",
    "    \"\"\"\n",
    "    # Compare with None so that rng_seed=0 is honoured as a real seed.\n",
    "    if rng_seed is not None:\n",
    "        random.seed(rng_seed)\n",
    "    # Always build a generator: random_cdf_diff requires one even when no seed\n",
    "    # was supplied (default_rng(None) draws fresh entropy).\n",
    "    rng = np.random.default_rng(rng_seed)\n",
    "\n",
    "    # shuffle copies so the module-level lists keep their order\n",
    "    d_s = d_synth.copy()\n",
    "    d_r = d_real.copy()\n",
    "\n",
    "    random.shuffle(d_s)\n",
    "    random.shuffle(d_r)\n",
    "    d_s = d_s[:counts]\n",
    "    d_r = d_r[:counts]\n",
    "\n",
    "    corrupt_embeds = [list(x) for x in zip(*(d_s+d_r))] # synth + clean\n",
    "\n",
    "    clean_embeds = [list(x) for x in zip(*d_r)] # clean\n",
    "\n",
    "    # getattr keeps callables without a __name__ (e.g. functools.partial) usable\n",
    "    needs_rng = getattr(ftn, '__name__', None) == 'random_cdf_diff'\n",
    "    out = eval_b(otheres_embeds, corrupt_embeds, clean_embeds, ftn, rng=rng if needs_rng else None)\n",
    "    out.append(counts)\n",
    "    out.append(counts)\n",
    "\n",
    "    return out\n",
    "\n",
    "\n",
    "def get_key_stats(d, name):\n",
    "    n = len(d)\n",
    "    count_clean = d[0][4]\n",
    "\n",
    "    max_clean, max_cor, mean_clean, mean_cor = 0, 0, 0, 0\n",
    "\n",
    "    for dd in d:\n",
    "        max_clean += dd[0]\n",
    "        max_cor += dd[1]\n",
    "        mean_clean += dd[2]\n",
    "        mean_cor += dd[3]\n",
    "\n",
    "    print(f'{name} {max_clean/n} {max_cor/n} {mean_clean/n} {mean_cor/n} {count_clean}')\n",
    "\n",
    "\n",
    "def get_key_stats_var(d, name):\n",
    "    \"\"\"\n",
    "    Compute empirical (population) variance of each numeric column in `d`\n",
    "    and print them in the same order the original function used.\n",
    "    \n",
    "    Parameters\n",
    "    ----------\n",
    "    d : list[tuple | list]\n",
    "        Sequence where each element is index-able and the first four\n",
    "        positions are numeric.\n",
    "    name : str\n",
    "        Label printed at the start of the output line.\n",
    "    \"\"\"\n",
    "    n = len(d)\n",
    "    count_clean = d[0][4]          # unchanged\n",
    "\n",
    "    # ---------- first pass: means ----------\n",
    "    sums = [0.0, 0.0, 0.0, 0.0]\n",
    "    for row in d:\n",
    "        for i in range(4):\n",
    "            sums[i] += row[i]\n",
    "    means = [s / n for s in sums]\n",
    "\n",
    "    # ---------- second pass: squared deviations ----------\n",
    "    sq_devs = [0.0, 0.0, 0.0, 0.0]\n",
    "    for row in d:\n",
    "        for i in range(4):\n",
    "            diff = row[i] - means[i]\n",
    "            sq_devs[i] += diff * diff\n",
    "\n",
    "    # population variance (divide by n).  Use (n-1) for sample variance.\n",
    "    variances = [sd / (n-1) for sd in sq_devs]\n",
    "\n",
    "    print(f\"{name} \"\n",
    "          f\"{variances[0]} {variances[1]} {variances[2]} {variances[3]} \"\n",
    "          f\"{count_clean}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "225ab0c1",
   "metadata": {},
   "source": [
    "## Run Image Experiments"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "608b78fb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sample sizes, seeds and metrics to sweep over.\n",
    "counts = [10,50, 100, 500, 1000]\n",
    "rng_seeds = [4, 123653, 434, 54, 87910]\n",
    "ftns = [random_cdf_diff, ks_statistic, wasserstein, js_divergence, cvm_statistic, mean_difference]\n",
    "\n",
    "# Results are grouped under the key '<count>_<metric name>', one entry per seed.\n",
    "data_dict = defaultdict(list)\n",
    "\n",
    "for count in tqdm(counts):\n",
    "    for r_s in rng_seeds:\n",
    "        for ftn in ftns:\n",
    "\n",
    "            # every metric is run with the same seed, so each one is re-seeded identically\n",
    "            out = get_stats(counts=count, ftn=ftn, rng_seed=r_s)\n",
    "\n",
    "            data_dict[f'{count}_{ftn.__name__}'].append(out)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "efae356c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# For every '<count>_<metric>' key, print the four summary values averaged over the seeds.\n",
    "for key in data_dict:\n",
    "    get_key_stats(data_dict[key], key)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "61df6179",
   "metadata": {},
   "outputs": [],
   "source": [
    "# For every '<count>_<metric>' key, print the variance of the four summary values across the seeds.\n",
    "for key in data_dict:\n",
    "    get_key_stats_var(data_dict[key], key)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "dai",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
