{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ad59dd92",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "\n",
    "import json\n",
    "import os\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "from scipy.stats import wasserstein_distance\n",
    "\n",
    "from tqdm.notebook import tqdm\n",
    "\n",
    "import HALL_lib"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8bf87ecc",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.rcParams['text.usetex'] = True\n",
    "plt.rcParams['text.latex.preamble'] = r'''\n",
    "\\usepackage{mathtools}\n",
    "% more packages here\n",
    "'''"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9b600274",
   "metadata": {},
   "outputs": [],
   "source": [
     "# Dataset configuration: keep exactly one of the sections below active.\n",
     "# cache_dir is passed to both the structural-analysis and the subsampling runs;\n",
     "# model_names maps model_id -> display name (used as plot legend labels).\n",
     "\n",
     "# ============== NEWDATASET ==============\n",
     "\n",
     "# cache_dir = \"cache-newdataset\"\n",
     "# maxResponsesPerPrompt = 500\n",
     "\n",
     "# model_names = {\n",
     "#     0 : \"Llama-8B\",\n",
     "#     1 : \"Phi-14B\",\n",
     "#     2 : \"Mistral-7B\",\n",
     "#     3 : \"Solar-11B\",\n",
     "#     4 : \"Gemma-9B\",\n",
     "#     5 : \"Qwen-15B\",\n",
     "#     6 : \"DeepSeek-7B\",\n",
     "#     7 : \"Yi-9B\"\n",
     "# }\n",
     "\n",
     "# ============== OLDDATASET ==============\n",
     "\n",
     "# cache_dir = \"cache-olddataset-20-22\"\n",
     "# maxResponsesPerPrompt=100\n",
     "\n",
     "# model_names = {\n",
     "#     0: 'Mistral-7B',\n",
     "#     1: 'Gemma-9B',\n",
     "#     2: 'Solar-11B',\n",
     "#     3: 'Phi-14B',\n",
     "#     4: 'Qwen-32B',\n",
     "#     5: 'Gemma-27B',\n",
     "#     6: 'Qwen-15B'\n",
     "# }\n",
     "\n",
     "# best_reg_lambda = 1.9\n",
     "\n",
     "# ============== ICML DATASET (active) ==============\n",
     "\n",
     "# Embedding column passed to HALL_lib.run_structural_analysis.\n",
     "embedding_label=\"response_embeddings\"\n",
     "cache_dir = \"cache-icml\"\n",
     "# NOTE(review): maxResponsesPerPrompt is not referenced by the cells below — confirm it is still needed.\n",
     "maxResponsesPerPrompt=150\n",
     "\n",
     "model_names = {\n",
     "    0: 'Mistral-7B',\n",
     "    1: 'Gemma-9B',\n",
     "    2: 'Solar-11B',\n",
     "    3: 'Phi-14B',\n",
     "    4: 'Qwen-15B',\n",
     "    5: 'Gemma-27B',\n",
     "    6: 'Qwen-32B',\n",
     "    7: \"DeepSeek-7B\",\n",
     "    8: \"Llama-8B\",\n",
     "    9: \"Yi-9B\"\n",
     "}\n",
     "\n",
     "# Passed as lambda_reg to HALL_lib.run_structural_analysis.\n",
     "best_reg_lambda = 1.2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a810ae8c",
   "metadata": {},
   "outputs": [],
   "source": [
     "# Load the dataset via the project loader (unifyYears semantics are defined in HALL_lib).\n",
     "dataset_fn = \"icml26_dataset.parquet\"\n",
     "\n",
     "df = HALL_lib.loadParquet(dataset_fn, unifyYears=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d7f99912",
   "metadata": {},
   "outputs": [],
   "source": [
    "model_order, model_rank = HALL_lib.build_model_size_order(model_names)\n",
    "\n",
    "false_premise = {pid : fp for pid, fp in df[['prompt_id', \"false_premise\"]].value_counts().keys()} if \"false_premise\" in df.columns else None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6b67494f",
   "metadata": {},
   "outputs": [],
   "source": [
     "# Structural analysis via HALL_lib (use_cache=True: presumably read from / written to cache_dir).\n",
     "# Later cells use results_df[\"prompt_id\"] to choose which prompts to subsample.\n",
     "results_df, geometry_store, null_store = HALL_lib.run_structural_analysis(\n",
     "    df,\n",
     "    lambda_reg=best_reg_lambda,\n",
     "    n_permutations=100,\n",
     "    random_state=42,\n",
     "    min_per_class_plot=5,\n",
     "    use_cache=True,\n",
     "    cache_dir=cache_dir,\n",
     "    overwrite_cache=False,\n",
     "    embedding_label=embedding_label\n",
     ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a9f2200b",
   "metadata": {},
   "outputs": [],
   "source": [
    "def dispersion_ratio(D_GG, D_HH):\n",
    "    \"\"\"\n",
    "    Ratio of intra-class dispersions (HH / GG).\n",
    "    > 1 means hallucinations are more spread.\n",
    "    \"\"\"\n",
    "    if len(D_GG) == 0 or len(D_HH) == 0:\n",
    "        return np.nan\n",
    "    return np.median(D_HH) / np.median(D_GG)\n",
    "\n",
    "def separation_ratio(D_GG, D_HH, D_GH):\n",
    "    \"\"\"\n",
    "    Inter-class distance relative to average intra-class distance.\n",
    "    Lower values indicate degradation.\n",
    "    \"\"\"\n",
    "    if len(D_GG) == 0 or len(D_HH) == 0 or len(D_GH) == 0:\n",
    "        return np.nan\n",
    "    intra = 0.5 * (np.median(D_GG) + np.median(D_HH))\n",
    "    return np.median(D_GH) / intra\n",
    "\n",
    "def intra_shape_divergence(D_GG, D_HH):\n",
    "    \"\"\"\n",
    "    Wasserstein distance between intra-class distance distributions,\n",
    "    normalised by GG scale.\n",
    "    \"\"\"\n",
    "    if len(D_GG) == 0 or len(D_HH) == 0:\n",
    "        return np.nan\n",
    "    W = wasserstein_distance(D_GG, D_HH)\n",
    "    return W / np.median(D_GG)\n",
    "\n",
    "def covariance_trace(X):\n",
    "    if len(X) < 2:\n",
    "        return np.nan\n",
    "    S = np.cov(X, rowvar=False, bias=True)\n",
    "    return np.trace(S)\n",
    "\n",
    "def effective_rank(X, eps=1e-12):\n",
    "    \"\"\"\n",
    "    exp( Shannon entropy of normalised eigenvalues )\n",
    "    \"\"\"\n",
    "    if len(X) < 2:\n",
    "        return np.nan\n",
    "    S = np.cov(X, rowvar=False, bias=True)\n",
    "    w = np.linalg.eigvalsh(S)\n",
    "    w = np.maximum(w, eps)\n",
    "    p = w / w.sum()\n",
    "    return np.exp(-(p * np.log(p)).sum())\n",
    "\n",
    "def normalised_dispersion(X):\n",
    "    if len(X) < 2:\n",
    "        return np.nan\n",
    "    S = np.cov(X, rowvar=False, bias=True)\n",
    "    d = S.shape[0]\n",
    "    return np.trace(S) / d\n",
    "\n",
    "def dispersion_inflation(X_G, X_H):\n",
    "    return normalised_dispersion(X_H) / normalised_dispersion(X_G)\n",
    "\n",
    "def shape_degradation(X_G, X_H):\n",
    "    return effective_rank(X_H) / effective_rank(X_G)\n",
    "\n",
    "def covariance_spectral_divergence(X_G, X_H, eps=1e-12):\n",
    "    if len(X_G) < 2 or len(X_H) < 2:\n",
    "        return np.nan\n",
    "\n",
    "    SG = np.cov(X_G, rowvar=False, bias=True)\n",
    "    SH = np.cov(X_H, rowvar=False, bias=True)\n",
    "\n",
    "    wG = np.maximum(np.linalg.eigvalsh(SG), eps)\n",
    "    wH = np.maximum(np.linalg.eigvalsh(SH), eps)\n",
    "\n",
    "    pG = wG / wG.sum()\n",
    "    pH = wH / wH.sum()\n",
    "\n",
    "    return np.sum(pG * np.log(pG / pH))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d44d2167",
   "metadata": {},
   "outputs": [],
   "source": [
    "def balanced_subsample_prompt(X, y, n_total, rng):\n",
    "    \"\"\"\n",
    "    Balanced subsampling preserving original GG/HH ratio.\n",
    "    \"\"\"\n",
    "    idx_G = np.where(y == 0)[0]  # assuming 0 = GG, 1 = HH\n",
    "    idx_H = np.where(y == 1)[0]\n",
    "\n",
    "    nG_full = len(idx_G)\n",
    "    nH_full = len(idx_H)\n",
    "    n_full = nG_full + nH_full\n",
    "\n",
    "    if n_total > n_full or nG_full < 3 or nH_full < 3:\n",
    "        return None\n",
    "\n",
    "    frac_G = nG_full / n_full\n",
    "    nG = int(round(frac_G * n_total))\n",
    "    nH = n_total - nG\n",
    "\n",
    "    # enforce feasibility\n",
    "    if nG < 1 or nH < 1 or nG > nG_full or nH > nH_full:\n",
    "        return None\n",
    "\n",
    "    sel_G = rng.choice(idx_G, size=nG, replace=False)\n",
    "    sel_H = rng.choice(idx_H, size=nH, replace=False)\n",
    "\n",
    "    sel = np.concatenate([sel_G, sel_H])\n",
    "    return X[sel], y[sel]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d69890b1",
   "metadata": {},
   "outputs": [],
   "source": [
    "def evaluate_prompt_degradation_subsampled(\n",
    "    df,\n",
    "    model_id,\n",
    "    prompt_id,\n",
    "    n_total,\n",
    "    rng,\n",
    "):\n",
    "    X, y = HALL_lib.extract_prompt_data(df, model_id, prompt_id, embedding_label=\"response_embeddings\")\n",
    "    out = balanced_subsample_prompt(X, y, n_total, rng)\n",
    "\n",
    "    if out is None:\n",
    "        return None\n",
    "\n",
    "    Xs, ys = out\n",
    "    X_G, X_H = HALL_lib.split_by_label(Xs, ys)\n",
    "\n",
    "    if len(X_G) < 2 or len(X_H) < 2:\n",
    "        return None\n",
    "\n",
    "    res = {\n",
    "        \"model_id\": model_id,\n",
    "        \"prompt_id\": prompt_id,\n",
    "        \"n_total\": n_total,\n",
    "        \"n_G\": len(X_G),\n",
    "        \"n_H\": len(X_H),\n",
    "        \"frac_G\": len(X_G) / n_total,\n",
    "    }\n",
    "\n",
    "    # absolute\n",
    "    res[\"trace_G\"] = covariance_trace(X_G)\n",
    "    res[\"trace_H\"] = covariance_trace(X_H)\n",
    "    res[\"effrank_G\"] = effective_rank(X_G)\n",
    "    res[\"effrank_H\"] = effective_rank(X_H)\n",
    "\n",
    "    # relative\n",
    "    res[\"dispersion_inflation\"] = dispersion_inflation(X_G, X_H)\n",
    "    res[\"shape_degradation\"] = shape_degradation(X_G, X_H)\n",
    "    res[\"spectral_divergence\"] = covariance_spectral_divergence(X_G, X_H)\n",
    "\n",
    "    # other\n",
    "    D_GG, D_HH, D_GH = HALL_lib.compute_distance_distributions(X_G, X_H)\n",
    "    res[\"dispersion_ratio\"] = dispersion_ratio(D_GG, D_HH)\n",
    "    res[\"separation_ratio\"] = separation_ratio(D_GG, D_HH, D_GH)\n",
    "    res[\"intra_shape_divergence\"] = intra_shape_divergence(D_GG, D_HH)\n",
    "\n",
    "    return res"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f08d807c",
   "metadata": {},
   "outputs": [],
   "source": [
    "def run_subsampling_protocol(\n",
    "    df,\n",
    "    model_ids,\n",
    "    prompt_ids,\n",
    "    n_sizes,\n",
    "    n_boot=10,\n",
    "    seed=0,\n",
    "    use_cache=False,\n",
    "    cache_dir=\"cache\",\n",
    "    overwrite_cache=False,\n",
    "):\n",
    "    \"\"\"\n",
    "    Run subsampling protocol with filesystem caching.\n",
    "\n",
    "    Cache files\n",
    "    ----------\n",
    "    subsampling_results.parquet\n",
    "    meta.json\n",
    "    \"\"\"\n",
    "\n",
    "    # ---- cache paths ----\n",
    "    if use_cache:\n",
    "        if cache_dir is None:\n",
    "            raise ValueError(\"cache_dir must be provided when use_cache=True\")\n",
    "\n",
    "        os.makedirs(cache_dir, exist_ok=True)\n",
    "\n",
    "        results_path = os.path.join(cache_dir, \"subsampling_results.parquet\")\n",
    "        meta_path = os.path.join(cache_dir, \"meta.json\")\n",
    "\n",
    "        cache_exists = (\n",
    "            os.path.exists(results_path)\n",
    "            and os.path.exists(meta_path)\n",
    "        )\n",
    "\n",
    "        if cache_exists and not overwrite_cache:\n",
    "            return pd.read_parquet(results_path)\n",
    "\n",
    "    # ---- computation ----\n",
    "    rng = np.random.default_rng(seed)\n",
    "    rows = []\n",
    "\n",
    "    for model_id in tqdm(model_ids, desc=\"Model\"):\n",
    "        for prompt_id in tqdm(prompt_ids, desc=\"Prompt\"):\n",
    "            for n_total in n_sizes:\n",
    "                for b in range(n_boot):\n",
    "\n",
    "                    res = evaluate_prompt_degradation_subsampled(\n",
    "                        df, model_id, prompt_id, n_total, rng\n",
    "                    )\n",
    "\n",
    "                    if res is not None:\n",
    "                        res = res.copy()\n",
    "                        res[\"boot\"] = b\n",
    "                        rows.append(res)\n",
    "\n",
    "    results_df = pd.DataFrame(rows)\n",
    "\n",
    "    # ---- save cache ----\n",
    "    if use_cache:\n",
    "        results_df.to_parquet(results_path, index=False)\n",
    "\n",
    "        meta = {\n",
    "            \"model_ids\": list(model_ids),\n",
    "            \"prompt_ids\": list(prompt_ids),\n",
    "            \"n_sizes\": list(n_sizes),\n",
    "            \"n_boot\": n_boot,\n",
    "            \"seed\": seed,\n",
    "            \"n_rows\": len(results_df),\n",
    "        }\n",
    "\n",
    "        with open(meta_path, \"w\") as f:\n",
    "            json.dump(meta, f, indent=2)\n",
    "\n",
    "    return results_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8673fe09",
   "metadata": {},
   "outputs": [],
   "source": [
     "# Subsample sizes n_total = 10, 20, ..., 100.\n",
     "n_sizes = range(10, 101, 10)\n",
     "\n",
     "# 10 bootstrap draws per (model, prompt, size), restricted to the prompts present in\n",
     "# the structural-analysis results; seed uses the function default (0); cached under cache_dir.\n",
     "df_deg = run_subsampling_protocol(\n",
     "    df=df,\n",
     "    model_ids=model_names.keys(),\n",
     "    prompt_ids=results_df[\"prompt_id\"].unique(),\n",
     "    n_sizes=n_sizes,\n",
     "    n_boot=10,\n",
     "    use_cache=True,\n",
     "    cache_dir=cache_dir,\n",
     ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7d30bcc9",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_degradation_indices(df_deg, indices=None, model_names=None, figsize=(16, 10)):\n",
    "    \"\"\"\n",
    "    Plot degradation indices vs dataset size for each model.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    df_deg : pd.DataFrame\n",
    "        Output of run_subsampling_protocol\n",
    "    indices : list of str, optional\n",
    "        Column names of degradation indices to plot\n",
    "    model_names : dict, optional\n",
    "        model_id -> model name mapping\n",
    "    figsize : tuple\n",
    "        Figure size\n",
    "    \"\"\"\n",
    "    if indices is None:\n",
    "        indices = [\"trace_H\", \"effrank_H\", \"dispersion_inflation\", \n",
    "                   \"shape_degradation\", \"spectral_divergence\"]\n",
    "\n",
    "    n_indices = len(indices)\n",
    "    fig, axes = plt.subplots(n_indices, 1, figsize=figsize, sharex=True)\n",
    "\n",
    "    if n_indices == 1:\n",
    "        axes = [axes]\n",
    "\n",
    "    for ax, idx in zip(axes, indices):\n",
    "        # compute median and IQR per model & n_total\n",
    "        df_plot = df_deg.groupby([\"model_id\", \"n_total\"])[idx].agg(\n",
    "            median=\"median\",\n",
    "            q25=lambda x: np.percentile(x, 25),\n",
    "            q75=lambda x: np.percentile(x, 75)\n",
    "        ).reset_index()\n",
    "\n",
    "        for mid in df_plot[\"model_id\"].unique():\n",
    "            df_m = df_plot[df_plot[\"model_id\"] == mid]\n",
    "            label = model_names[mid] if model_names else str(mid)\n",
    "\n",
    "            ax.plot(df_m[\"n_total\"], df_m[\"median\"], label=label, marker=\"o\")\n",
    "            ax.fill_between(df_m[\"n_total\"], df_m[\"q25\"], df_m[\"q75\"], alpha=0.2)\n",
    "\n",
    "        ax.set_ylabel(idx)\n",
    "        ax.grid(True, axis=\"y\", alpha=0.3)\n",
    "\n",
    "    axes[-1].set_xlabel(\"Subsample size (n_total)\")\n",
    "    axes[0].set_title(\"Degradation indices vs dataset size\")\n",
    "    axes[0].legend(loc=\"upper left\")\n",
    "    fig.tight_layout()\n",
    "    return fig, axes\n",
    "\n",
    "if df_deg is not None:\n",
    "    fig, axes = plot_degradation_indices(df_deg, indices=[\"dispersion_inflation\", \"shape_degradation\", \"spectral_divergence\"], model_names=model_names)\n",
    "    fig, axes = plot_degradation_indices(df_deg, indices=[\"dispersion_ratio\", \"separation_ratio\", \"intra_shape_divergence\"], model_names=model_names)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "hall_env",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.14.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
