{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# `src` must be importable from the working directory. The notebook is presumably\n",
    "# launched one level below the project root, so step up -- but only when `src` is\n",
    "# not already visible, so re-running this cell does not keep climbing directories.\n",
    "if not os.path.isdir(\"src\"):\n",
    "    os.chdir(\"../\")\n",
    "\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "from src.sfm import define_style, save_plot\n",
    "\n",
    "\n",
    "define_style()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_kl_results(ks=None, dfm=None, sfm=None, path=\"./out/toy_kl.pdf\"):\n",
    "    \"\"\"Plot KL divergence against `K`, the number of categories.\n",
    "\n",
    "    Draws one line per method (Dirichlet FM, Simplex FM) on a fresh figure,\n",
    "    saves it with `save_plot`, then shows it.\n",
    "\n",
    "    Args:\n",
    "        ks: Numbers of categories (x-axis). Defaults to the recorded sweep.\n",
    "        dfm: KL divergence of Dirichlet FM, one value per entry of `ks`.\n",
    "        sfm: KL divergence of Simplex FM, one value per entry of `ks`.\n",
    "        path: Output file passed to `save_plot`.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If `ks`, `dfm` and `sfm` differ in length.\n",
    "    \"\"\"\n",
    "    # Default data: hard-coded recorded values.\n",
    "    if ks is None:\n",
    "        ks = [5, 10, 20, 40, 60, 80, 100, 120, 140, 160]\n",
    "    if dfm is None:\n",
    "        # NOTE(review): 0.1304 at K=10 exceeds every other Dirichlet FM value\n",
    "        # (even K=160), while the series increases from K=20 onward --\n",
    "        # confirm it is not a transcription typo (e.g. for 0.01304).\n",
    "        dfm = [0.04784, 0.1304, 0.02398, 0.05646, 0.07041, 0.09034, 0.1034, 0.1152, 0.12, 0.1235]\n",
    "    if sfm is None:\n",
    "        sfm = [1e-3, 2.5e-3, 4e-3, 0.02, 0.026, 0.025, 0.05, 0.08, 0.14, 0.235]\n",
    "    if not len(ks) == len(dfm) == len(sfm):\n",
    "        raise ValueError(\n",
    "            f\"ks, dfm and sfm must have equal lengths, got {len(ks)}, {len(dfm)}, {len(sfm)}\"\n",
    "        )\n",
    "\n",
    "    # New figure/axes so lines from an earlier, unclosed figure are not drawn onto this plot.\n",
    "    _, ax = plt.subplots()\n",
    "    ax.plot(ks, dfm, marker=\"s\", label=\"Dirichlet FM\")\n",
    "    ax.plot(ks, sfm, marker=\"x\", label=\"Simplex FM\")\n",
    "    ax.set_xlabel(\"Number of categories, K\")\n",
    "    ax.set_ylabel(\"KL divergence\")\n",
    "    ax.legend()\n",
    "    save_plot(path)\n",
    "    plt.show()\n",
    "\n",
    "\n",
    "plot_kl_results()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "sfm",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
