{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import pickle\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the cached results produced by the synthetic experiments.\n",
    "# NOTE: pickle.load can execute arbitrary code; only open files you trust.\n",
    "# The context manager closes the file handle as soon as loading finishes.\n",
    "with open(os.path.join(\"synthetic_output\", \"0.003_performance_dict.pkl\"), \"rb\") as f:\n",
    "    performance_dict = pickle.load(f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Echo the loaded results so their key/value layout can be inspected.\n",
    "performance_dict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_test_performance_vs_dim(\n",
    "    num_layers: int,\n",
    "    autoregressive: bool,\n",
    "    add_bos: bool = None,\n",
    "    n_runs: int = 10,\n",
    "    output_dir: str = \"synthetic_output\",\n",
    "):\n",
    "    \"\"\"Plot mean test accuracy against time-feature dimension, one line per encoding method.\n",
    "\n",
    "    Reads the notebook-level ``performance_dict``. Tuple keys are parsed as\n",
    "    ``\"name:value\"`` strings in the order (num_layers, time_feat_dim,\n",
    "    time_encoding_method, autoregressive, add_bos); ``add_bos`` is only read\n",
    "    for autoregressive runs. Each value is indexed as ``value[2] == (mean, std)``\n",
    "    and treated as the test metric. The ``\"oracle\"`` entry is drawn as a dashed\n",
    "    reference line. Shaded bands are ``mean +/- 1.96 * std / sqrt(n_runs)``.\n",
    "\n",
    "    Args:\n",
    "        num_layers: Only keys whose first field parses to this integer are plotted.\n",
    "        autoregressive: Only keys whose fourth field matches this flag are plotted.\n",
    "        add_bos: For autoregressive runs, the required value of the fifth key\n",
    "            field. Must be None when ``autoregressive`` is False.\n",
    "        n_runs: Sample size used to turn each standard deviation into a standard\n",
    "            error (presumably the number of runs per configuration).\n",
    "        output_dir: Directory the figure is saved to as ``<title>.png``.\n",
    "\n",
    "    Raises:\n",
    "        AssertionError: If ``add_bos`` is given for a non-autoregressive plot.\n",
    "        ValueError: If no entry of ``performance_dict`` matches the filters.\n",
    "    \"\"\"\n",
    "    if not autoregressive:\n",
    "        assert add_bos is None\n",
    "\n",
    "    # Group matching results by encoding method: method -> (dims, means, stds).\n",
    "    plot_dict = dict()\n",
    "    for key, value in performance_dict.items():\n",
    "        if not isinstance(key, tuple):\n",
    "            continue  # non-tuple keys such as \"oracle\" are handled below\n",
    "        if num_layers != int(key[0].split(\":\")[1]):\n",
    "            continue\n",
    "        if autoregressive != (key[3].split(\":\")[1] == \"True\"):\n",
    "            continue\n",
    "        # The add_bos field is only consulted for autoregressive runs.\n",
    "        if autoregressive and add_bos != (key[4].split(\":\")[1] == \"True\"):\n",
    "            continue\n",
    "\n",
    "        time_feat_dim = key[1].split(\":\")[1]\n",
    "        time_encoding_method = key[2].split(\":\")[1]\n",
    "        dims, avgs, stds = plot_dict.setdefault(time_encoding_method, ([], [], []))\n",
    "        dims.append(time_feat_dim)\n",
    "        avgs.append(value[2][0])\n",
    "        stds.append(value[2][1])\n",
    "\n",
    "    # Fail loudly instead of drawing the oracle line against a stale loop variable.\n",
    "    if not plot_dict:\n",
    "        raise ValueError(\n",
    "            f\"no entries in performance_dict match num_layers={num_layers}, \"\n",
    "            f\"autoregressive={autoregressive}, add_bos={add_bos}\"\n",
    "        )\n",
    "\n",
    "    sns.set_theme()\n",
    "    plt.figure(figsize=(14, 8))\n",
    "\n",
    "    for method, (dims, avg_list, stddev_list) in plot_dict.items():\n",
    "        stderr_list = [stddev / n_runs**0.5 for stddev in stddev_list]\n",
    "        upper = [avg + 1.96 * stderr for avg, stderr in zip(avg_list, stderr_list)]\n",
    "        lower = [avg - 1.96 * stderr for avg, stderr in zip(avg_list, stderr_list)]\n",
    "        # NOTE(review): markersize has no visible effect unless a marker is also\n",
    "        # set; kept exactly as in the original call.\n",
    "        sns.lineplot(x=dims, y=avg_list, markersize=18, label=method)\n",
    "        plt.fill_between(dims, lower, upper, alpha=0.07)\n",
    "\n",
    "    # Span the oracle reference over every plotted dimension (first-seen order),\n",
    "    # rather than over whichever method happened to be iterated last.\n",
    "    oracle_x = list(dict.fromkeys(dim for dims, _, _ in plot_dict.values() for dim in dims))\n",
    "    avg_oracle_test_acc = performance_dict[\"oracle\"][2][0]\n",
    "    std_oracle_test_acc = performance_dict[\"oracle\"][2][1]\n",
    "    stderr_oracle_test_acc = std_oracle_test_acc / n_runs**0.5\n",
    "    sns.lineplot(\n",
    "        x=oracle_x,\n",
    "        y=[avg_oracle_test_acc] * len(oracle_x),\n",
    "        label=\"oracle\",\n",
    "        linestyle=\"--\",\n",
    "    )\n",
    "    plt.fill_between(\n",
    "        oracle_x,\n",
    "        [avg_oracle_test_acc - 1.96 * stderr_oracle_test_acc] * len(oracle_x),\n",
    "        [avg_oracle_test_acc + 1.96 * stderr_oracle_test_acc] * len(oracle_x),\n",
    "        alpha=0.07,\n",
    "    )\n",
    "\n",
    "    if autoregressive:\n",
    "        if add_bos:\n",
    "            title = f\"{num_layers}_layer-autoregressive-add_bos\"\n",
    "        else:\n",
    "            title = f\"{num_layers}_layer-autoregressive-wo_bos\"\n",
    "    else:\n",
    "        title = f\"{num_layers}_layer-nonautoregressive\"\n",
    "    plt.title(title, fontsize=24)\n",
    "    plt.ylabel(\"Test Accuracy\", fontsize=24)\n",
    "    plt.ylim(0.75, 1.0)\n",
    "    plt.xlabel(\"Time Feature Dimension\", fontsize=24)\n",
    "    plt.legend()\n",
    "\n",
    "    # tight_layout only helps once the artists, labels and legend exist.\n",
    "    plt.tight_layout()\n",
    "    if output_dir:\n",
    "        os.makedirs(output_dir, exist_ok=True)\n",
    "    plt.savefig(os.path.join(output_dir, f\"{title}.png\"))\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One figure per configuration of the 1-layer models.\n",
    "for setup in (\n",
    "    dict(autoregressive=False, add_bos=None),\n",
    "    dict(autoregressive=True, add_bos=False),\n",
    "    dict(autoregressive=True, add_bos=True),\n",
    "):\n",
    "    plot_test_performance_vs_dim(num_layers=1, **setup)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One figure per configuration of the 2-layer models.\n",
    "for setup in (\n",
    "    dict(autoregressive=False, add_bos=None),\n",
    "    dict(autoregressive=True, add_bos=False),\n",
    "    dict(autoregressive=True, add_bos=True),\n",
    "):\n",
    "    plot_test_performance_vs_dim(num_layers=2, **setup)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "dg-shift",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
