{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 39,
   "id": "kl_rowwise_9549ade1",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>row</th>\n",
       "      <th>KL(A||B)</th>\n",
       "      <th>KL(A||C)</th>\n",
       "      <th>KL(B||C)</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0</td>\n",
       "      <td>1.329944</td>\n",
       "      <td>2.145773</td>\n",
       "      <td>4.982303</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1</td>\n",
       "      <td>0.263995</td>\n",
       "      <td>0.813363</td>\n",
       "      <td>0.086874</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>2</td>\n",
       "      <td>9.868894</td>\n",
       "      <td>6.576357</td>\n",
       "      <td>0.000495</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>3</td>\n",
       "      <td>1.601833</td>\n",
       "      <td>4.639836</td>\n",
       "      <td>0.218614</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>4</td>\n",
       "      <td>0.913462</td>\n",
       "      <td>2.841591</td>\n",
       "      <td>2.363610</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   row  KL(A||B)  KL(A||C)  KL(B||C)\n",
       "0    0  1.329944  2.145773  4.982303\n",
       "1    1  0.263995  0.813363  0.086874\n",
       "2    2  9.868894  6.576357  0.000495\n",
       "3    3  1.601833  4.639836  0.218614\n",
       "4    4  0.913462  2.841591  2.363610"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Means: {'mean KL(A||B)': 1.9119363137255536, 'mean KL(A||C)': 3.6501410131554355, 'mean KL(B||C)': 1.3546289544872547}\n",
      "Saved per-row KL to kl_rowwise.csv\n"
     ]
    }
   ],
   "source": [
    "\n",
    "# KL divergence per row between the three files (A,B,C)\n",
    "import torch, numpy as np, pandas as pd\n",
    "\n",
     "# Additive smoothing: kl_divergence adds this to every entry of p and q (its default eps) so log(0) never occurs.\n",
     "EPS = 1e-8\n",
    "\n",
    "def load_2d_tensor(path: str) -> torch.Tensor:\n",
    "    obj = torch.load(path, map_location='cpu')\n",
    "    if torch.is_tensor(obj):\n",
    "        t = obj\n",
    "    elif isinstance(obj, (list, tuple)):\n",
    "        t = None\n",
    "        for item in obj:\n",
    "            if torch.is_tensor(item) and item.ndim == 2:\n",
    "                t = item\n",
    "                break\n",
    "        if t is None:\n",
    "            raise ValueError(f\"{path}: Could not find a 2D tensor inside the saved object.\")\n",
    "    else:\n",
    "        raise TypeError(f\"{path}: Expected a tensor or list/tuple, got {type(obj)}\")\n",
    "    if t.ndim != 2:\n",
    "        raise ValueError(f\"{path}: Expected a 2D tensor [N,K], got shape {tuple(t.shape)}\")\n",
    "    return t.float().cpu()\n",
    "\n",
    "def to_probs(x: torch.Tensor, assume_logits: bool) -> torch.Tensor:\n",
    "    if assume_logits:\n",
    "        return torch.softmax(x, dim=1)\n",
    "    s = x.sum(dim=1, keepdim=True).clamp_min(1e-12)\n",
    "    return (x / s).clamp_min(0.0)\n",
    "\n",
    "def kl_divergence(p: np.ndarray, q: np.ndarray, eps: float = EPS) -> float:\n",
    "    p = p.astype(np.float64) + eps\n",
    "    q = q.astype(np.float64) + eps\n",
    "    p /= p.sum(); q /= q.sum()\n",
    "    return np.sum(p * (np.log(p) - np.log(q)))\n",
    "\n",
    "# File paths (A,B,C)\n",
    "path_A = 'forget_lacuna5_allcnn_final_all_outputs_tensor.pt'\n",
    "path_B = 'scrub_forget_lacuna5_allcnn_final_all_outputs_tensor.pt'\n",
    "path_C = 'forget_set_all_probs_tensor.pt'\n",
    "\n",
    "A = to_probs(load_2d_tensor(path_A), True).numpy()\n",
    "B = to_probs(load_2d_tensor(path_B), True).numpy()\n",
    "C = to_probs(load_2d_tensor(path_C), False).numpy()\n",
    "\n",
    "n = min(len(A), len(B), len(C))\n",
    "if len(A) != len(B) or len(A) != len(C):\n",
    "    print(f'Warning: different lengths A={len(A)} B={len(B)} C={len(C)}; using n={n}')\n",
    "A, B, C = A[:n], B[:n], C[:n]\n",
    "\n",
    "kl_A_B = np.array([kl_divergence(A[i], B[i]) for i in range(n)])\n",
    "kl_A_C = np.array([kl_divergence(A[i], C[i]) for i in range(n)])\n",
    "kl_B_C = np.array([kl_divergence(B[i], C[i]) for i in range(n)])\n",
    "\n",
    "import pandas as pd\n",
    "rows = np.arange(n)\n",
    "df = pd.DataFrame({'row': rows, 'KL(A||B)': kl_A_B, 'KL(A||C)': kl_A_C, 'KL(B||C)': kl_B_C})\n",
    "try:\n",
    "    display(df.head())\n",
    "except Exception:\n",
    "    print(df.head().to_string(index=False))\n",
    "\n",
    "means = {\n",
    "    'mean KL(A||B)': float(kl_A_B.mean()),\n",
    "    'mean KL(A||C)': float(kl_A_C.mean()),\n",
    "    'mean KL(B||C)': float(kl_B_C.mean()),\n",
    "}\n",
    "print('Means:', means)\n",
    "\n",
    "df.to_csv('kl_rowwise.csv', index=False)\n",
    "print('Saved per-row KL to kl_rowwise.csv')\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "nlp",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
