{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2a0b41f8",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import os\n",
    "from tqdm import tqdm\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "folder = \"INSERT YOURS\"\n",
    "files = os.listdir(folder)\n",
    "\n",
    "\n",
    "def process_feature_info(v):\n",
    "    kl_divs = []\n",
    "    for example in v:\n",
    "        kl_div = [pos_dict.get(\"kl_divergence\") for pos_dict in example]\n",
    "        kl_divs.append(np.array(kl_div))\n",
    "    \n",
    "    kl_divs = np.stack(kl_divs)\n",
    "    avg_shape = kl_divs.mean(axis=0)\n",
    "    avg_total_kl_div = kl_divs.sum(axis=1).mean()\n",
    "    sum_after_1 = kl_divs[:, 1:].sum(axis=1)\n",
    "    kl_1 = kl_divs[:, 0]\n",
    "    avg_pc_agg_after = (sum_after_1 / (kl_1 + 1e-8)).mean()\n",
    "\n",
    "    return {\n",
    "        \"avg_shape\": avg_shape,\n",
    "        \"avg_total_kl_div\": avg_total_kl_div,\n",
    "        \"avg_pc_agg_after\": avg_pc_agg_after,\n",
    "    }\n",
    "\n",
    "\n",
    "results = []\n",
    "\n",
    "for filename in tqdm(files):\n",
    "    with open(os.path.join(folder, filename)) as f:\n",
    "        data = json.load(f)\n",
    "        for k, v in data.items():\n",
    "            metrics = process_feature_info(v[1])\n",
    "            metrics[\"feature_idx\"] = k\n",
    "            results.append(metrics)\n",
    "\n",
    "data = pd.DataFrame(results).set_index(\"feature_idx\")           "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3e8331e6",
   "metadata": {},
   "outputs": [],
   "source": [
    "KL_CUTOFF = 0.05\n",
    "\n",
    "data = data[data[\"avg_total_kl_div\"] > KL_CUTOFF]\n",
    "data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6c66d5d1",
   "metadata": {},
   "outputs": [],
   "source": [
    "data[\"pc_index\"] = data[\"avg_pc_agg_after\"]\n",
    "data.head()\n",
    "\n",
    "feature_tags = pd.read_csv(\"INSERT YOURS.csv\")\n",
    "\n",
    "data.index = data.index.astype(int)\n",
    "data = data.join(feature_tags.set_index(\"index\"), on=data.index)\n",
    "\n",
    "data[\"is_union\"] = 1 - (1 - data[\"is_code\"]) * (1 - data[\"is_syntax\"])\n",
    "data.head()\n",
    "\n",
    "data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e4ae23ce",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "\n",
    "plt.rcParams.update({\n",
    "    'font.family': 'serif',\n",
    "    'font.size': 6,\n",
    "    'axes.labelsize': 7,\n",
    "    'axes.titlesize': 7,\n",
    "    'xtick.labelsize': 6,\n",
    "    'ytick.labelsize': 6,\n",
    "})\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(7, 2), dpi=300)\n",
    "\n",
    "sns.histplot(data, x=\"pc_index\", log_scale=True, hue=\"is_union\", legend=True, bins=50, ax=ax)\n",
    "legend = ax.get_legend()\n",
    "if legend is not None:\n",
    "    legend.set_title('\"formal reasoning\"')\n",
    "\n",
    "plt.grid(True, which='major', axis='x', linestyle='--', alpha=0.3, linewidth=0.3)\n",
    "plt.grid(True, which='minor', axis='y', linestyle=':', alpha=0.3, linewidth=0.3)\n",
    "plt.grid(True, which='major', axis='y', linestyle='--', alpha=0.3, linewidth=0.3)\n",
    "\n",
    "# Remove all spines for no frame\n",
    "for spine in ax.spines.values():\n",
    "    spine.set_visible(False)\n",
    "\n",
    "# Remove y axis\n",
    "ax.yaxis.set_visible(False)\n",
    "\n",
    "plt.xlabel(\"Pre-Caching Degree\")\n",
    "\n",
    "plt.tight_layout(pad=0.1)\n",
    "plt.savefig(\"INSERT YOURS.pdf\", dpi=300, bbox_inches=\"tight\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2b3e6d38",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "\n",
    "plt.rcParams.update({\n",
    "    'font.family': 'serif',\n",
    "    'font.size': 6,\n",
    "    'axes.labelsize': 7,\n",
    "    'axes.titlesize': 7,\n",
    "    'xtick.labelsize': 6,\n",
    "    'ytick.labelsize': 6,\n",
    "})\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(7, 2), dpi=300)\n",
    "\n",
    "sns.histplot(data, x=\"pc_index\", log_scale=True, legend=True, bins=50, ax=ax)\n",
    "legend = ax.get_legend()\n",
    "if legend is not None:\n",
    "    legend.set_title('\"formal reasoning\"')\n",
    "\n",
    "plt.grid(True, which='major', axis='x', linestyle='--', alpha=0.3, linewidth=0.3)\n",
    "plt.grid(True, which='minor', axis='y', linestyle=':', alpha=0.3, linewidth=0.3)\n",
    "plt.grid(True, which='major', axis='y', linestyle='--', alpha=0.3, linewidth=0.3)\n",
    "\n",
    "# Remove all spines for no frame\n",
    "for spine in ax.spines.values():\n",
    "    spine.set_visible(False)\n",
    "\n",
    "# Remove y axis\n",
    "ax.yaxis.set_visible(False)\n",
    "\n",
    "plt.xlabel(\"Pre-Caching Degree\")\n",
    "\n",
    "plt.tight_layout(pad=0.1)\n",
    "plt.savefig(\"INSERT YOURS.pdf\", dpi=300, bbox_inches=\"tight\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2848307a",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "\n",
    "plt.rcParams.update({\n",
    "    'font.family': 'serif',\n",
    "    'font.size': 6,\n",
    "    'axes.labelsize': 7,\n",
    "    'axes.titlesize': 7,\n",
    "    'xtick.labelsize': 6,\n",
    "    'ytick.labelsize': 6,\n",
    "})\n",
    "\n",
    "for differentiator in [\"is_union\", \"is_code\", \"is_syntax\"]:\n",
    "    fig, ax = plt.subplots(figsize=(2.1, 1.4), dpi=300)\n",
    "\n",
    "    sns.histplot(data, x=\"pc_index\", log_scale=True, legend=True, hue=differentiator, bins=20, ax=ax)\n",
    "    legend = ax.get_legend()\n",
    "    if legend is not None:\n",
    "        legend.set_title(differentiator)\n",
    "\n",
    "    plt.grid(True, which='major', axis='x', linestyle='--', alpha=0.3, linewidth=0.3)\n",
    "    plt.grid(True, which='minor', axis='y', linestyle=':', alpha=0.3, linewidth=0.3)\n",
    "    plt.grid(True, which='major', axis='y', linestyle='--', alpha=0.3, linewidth=0.3)\n",
    "\n",
    "    # Remove all spines for no frame\n",
    "    for spine in ax.spines.values():\n",
    "        spine.set_visible(False)\n",
    "\n",
    "    # Remove y axis\n",
    "    ax.yaxis.set_visible(False)\n",
    "\n",
    "    plt.xlabel(\"Pre-Caching Degree\")\n",
    "\n",
    "    plt.tight_layout(pad=0.1)\n",
    "    plt.savefig(f\"INSERT YOURS_{differentiator}_.pdf\", dpi=300, bbox_inches=\"tight\")\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "74670cdf",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "\n",
    "plt.rcParams.update({\n",
    "    'font.family': 'serif',\n",
    "    'font.size': 6,\n",
    "    'axes.labelsize': 7,\n",
    "    'axes.titlesize': 7,\n",
    "    'xtick.labelsize': 6,\n",
    "    'ytick.labelsize': 6,\n",
    "})\n",
    "\n",
    "for differentiator in [\"is_union\", \"is_code\", \"is_syntax\"]:\n",
    "    fig, ax = plt.subplots(figsize=(2.1, 1.4), dpi=300)\n",
    "\n",
    "    sns.histplot(data[data[\"pc_index\"] > 10], x=\"pc_index\", log_scale=True, legend=True, hue=differentiator, bins=20, ax=ax)\n",
    "    legend = ax.get_legend()\n",
    "    if legend is not None:\n",
    "        legend.set_title(differentiator)\n",
    "\n",
    "    plt.grid(True, which='major', axis='x', linestyle='--', alpha=0.3, linewidth=0.3)\n",
    "    plt.grid(True, which='minor', axis='y', linestyle=':', alpha=0.3, linewidth=0.3)\n",
    "    plt.grid(True, which='major', axis='y', linestyle='--', alpha=0.3, linewidth=0.3)\n",
    "\n",
    "    # Remove all spines for no frame\n",
    "    for spine in ax.spines.values():\n",
    "        spine.set_visible(False)\n",
    "\n",
    "    # Remove y axis\n",
    "    ax.yaxis.set_visible(False)\n",
    "\n",
    "    plt.xlabel(\"Pre-Caching Degree\")\n",
    "    plt.xlim(left=10)\n",
    "\n",
    "    plt.tight_layout(pad=0.1)\n",
    "    plt.savefig(f\"INSERT YOURS_{differentiator}_right.pdf\", dpi=300, bbox_inches=\"tight\")\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b90bb537",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "\n",
    "plt.rcParams.update({\n",
    "    'font.family': 'serif',\n",
    "    'font.size': 6,\n",
    "    'axes.labelsize': 7,\n",
    "    'axes.titlesize': 7,\n",
    "    'xtick.labelsize': 6,\n",
    "    'ytick.labelsize': 6,\n",
    "})\n",
    "\n",
    "for differentiator in [\"is_union\", \"is_code\", \"is_syntax\"]:\n",
    "    fig, ax = plt.subplots(figsize=(2.1, 1.4), dpi=300)\n",
    "\n",
    "    sns.histplot(data[data[\"pc_index\"] < 0.1], x=\"pc_index\", log_scale=True, legend=True, hue=differentiator, bins=20, ax=ax)\n",
    "    legend = ax.get_legend()\n",
    "    if legend is not None:\n",
    "        legend.set_title(differentiator)\n",
    "\n",
    "    plt.grid(True, which='major', axis='x', linestyle='--', alpha=0.3, linewidth=0.3)\n",
    "    plt.grid(True, which='minor', axis='y', linestyle=':', alpha=0.3, linewidth=0.3)\n",
    "    plt.grid(True, which='major', axis='y', linestyle='--', alpha=0.3, linewidth=0.3)\n",
    "\n",
    "    # Remove all spines for no frame\n",
    "    for spine in ax.spines.values():\n",
    "        spine.set_visible(False)\n",
    "\n",
    "    # Remove y axis\n",
    "    ax.yaxis.set_visible(False)\n",
    "\n",
    "    plt.xlabel(\"Pre-Caching Degree\")\n",
    "    plt.xlim(right=0.1)\n",
    "\n",
    "    plt.tight_layout(pad=0.1)\n",
    "    plt.savefig(f\"INSERT YOURS_{differentiator}_left.pdf\", dpi=300, bbox_inches=\"tight\")\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c61c17d2",
   "metadata": {},
   "outputs": [],
   "source": [
    "data_exp = data.sort_values(by=\"pc_index\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b9ef2b4a",
   "metadata": {},
   "outputs": [],
   "source": [
    "data_exp[\"pc_index_str\"] = data_exp[\"pc_index\"].apply(lambda x: f\"{x:.3f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "449eabc5",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(data_exp[:10][[\"pc_index_str\", \"desc\"]].to_latex())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d71405f0",
   "metadata": {},
   "outputs": [],
   "source": [
    "data_exp[\"pc_index_str2\"] = data_exp[\"pc_index\"].apply(lambda x: f\"{x:.1f}\")\n",
    "\n",
    "print(data_exp[-10:][::-1][[\"pc_index_str2\", \"desc\"]].to_latex())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "99ad024b",
   "metadata": {},
   "outputs": [],
   "source": [
    "small_data_exp = data_exp[(data_exp[\"pc_index\"] > 0.1) & (data_exp[\"pc_index\"] < 1.1)].sample(10, random_state=0)\n",
    "print(small_data_exp[[\"pc_index_str\", \"desc\"]].to_latex())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6ab55bbe",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from scipy import stats\n",
    "\n",
    "def sigma_ci_normal(x, alpha=0.05):\n",
    "    x = np.asarray(x)\n",
    "    n = x.size\n",
    "    s2 = x.var(ddof=1)\n",
    "    df = n - 1\n",
    "    lower = np.sqrt(df * s2 / stats.chi2.ppf(1 - alpha/2, df))\n",
    "    upper = np.sqrt(df * s2 / stats.chi2.ppf(alpha/2, df))\n",
    "    return (lower + upper) / 2, (upper - lower) / 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "90b3b128",
   "metadata": {},
   "outputs": [],
   "source": [
    "sigma_ci_normal(np.log(data[data[\"is_union\"] == 1][\"pc_index\"]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a7b20da4",
   "metadata": {},
   "outputs": [],
   "source": [
    "sigma_ci_normal(np.log(data[data[\"is_union\"] == 0][\"pc_index\"]))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "dlenv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
