{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "d018ce08-631a-4475-bff3-fd8b5e32f98c",
   "metadata": {},
   "outputs": [],
   "source": [
    "from matplotlib import pyplot as plt\n",
    "import seaborn as sns\n",
    "import os\n",
    "import pickle\n",
    "cur_dir = os.getcwd()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "21f77e74-87b7-49cb-8d4e-6d74cc9c03bd",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "/Users/florencer/git/EE_synthetic/figures/plots\n"
     ]
    }
   ],
   "source": [
    "result_path = os.path.join(cur_dir,'results')\n",
    "figure_path = os.path.join(cur_dir, 'plots')\n",
    "print(figure_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "5df936eb-fd53-41df-985a-4c03707975cb",
   "metadata": {},
   "outputs": [],
   "source": [
    "baseline = '2s_exp'\n",
    "cost = '0.03'\n",
    "prefix = baseline+str(cost)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "4a91fcc5-7ddb-4836-bab0-765ef1cf9ff9",
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_experiment(filename):\n",
    "    with open(filename, \"rb\") as f:\n",
    "        training_log_dict = pickle.load(f)\n",
    "    return training_log_dict\n",
    "        \n",
    "def load_experiment_of_cost(baseline, cost):\n",
    "    filename = os.path.join(result_path, f\"{baseline}{cost}__training_result.pkl\")\n",
    "    print(filename)\n",
    "    return load_experiment(filename)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "429b0952-4304-46e3-bbc2-5a7eea6dae1d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "/Users/florencer/git/EE_synthetic/figures/results/2s_exp0.03__training_result.pkl\n"
     ]
    }
   ],
   "source": [
    "training_log_dict = load_experiment_of_cost(baseline, cost)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "bb9de515-2aae-47c7-97fa-696eadde1976",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "dict_keys(['xzy_x', 'xzy_z', 'xzy_s', 'xzy_gt_s', 'xzy_y', 'xzy_t1', 'xzy_t2', 'param_cs', 'param_ds', 'track_t1_acc', 'track_t2_acc', 'track_01c', 'ls', 'f1ls', 'f2ls', 'l01cs', 'test_avg_l01c', 'optimal_l01c', 'track_epoch_loss', 'df_testacc', 'df_testrate', 'f1 acc', 'f2 acc'])\n"
     ]
    }
   ],
   "source": [
    "print(training_log_dict.keys())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "35b692bf-4f95-4f5a-b0ba-8047d9b0b746",
   "metadata": {},
   "outputs": [],
   "source": [
    "sns.set_theme(style=\"whitegrid\")\n",
    "plt.figure(figsize=(5, 4))\n",
    "sns.set_context(\"notebook\", font_scale=1.5)\n",
    "test_01c = training_log_dict['track_01c']\n",
    "train_loss = training_log_dict['track_epoch_loss']\n",
    "optimal_l01c = training_log_dict['optimal_l01c']\n",
    "    \n",
    "#plt.plot(train_loss, label=r'$\\hat{R}_{hinge}(f)$', marker='o')   \n",
    "plt.plot(test_01c, label=r'$\\hat{R}_{01c}(f)$', marker='o')\n",
    "plt.axhline(y=optimal_l01c, color='r', linestyle='--', linewidth=2, label=r'$\\hat{R}^*_{01c}$')\n",
    "#plt.xlim([0,100])\n",
    "#plt.ylim([0.37,0.4])\n",
    "plt.xlabel('Epoch')\n",
    "plt.ylabel(r'$R_{01c}(f)$')\n",
    "plt.legend()\n",
    "plt.tight_layout()\n",
    "#plt.show()\n",
    "plt.savefig(os.path.join(figure_path, prefix+'R01c.pdf'))\n",
    "plt.close()\n",
    "\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ba157f2c-436a-4479-ac16-acaa66f84a77",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_xzy(x,z,y, prefix):\n",
    "    num_points_max = 5000\n",
    "    plt.figure(figsize=(8, 6))\n",
    "    sc = plt.scatter(x[:num_points_max,:], z[:num_points_max,:], c=y[:num_points_max], cmap=\"viridis\", edgecolor=\"k\", alpha=0.5)\n",
    "    plt.colorbar(sc)\n",
    "    plt.xlabel(\"X\")\n",
    "    plt.ylabel(\"Z\")\n",
    "    if not os.path.exists(figure_path):\n",
    "        os.makedirs(figure_path)\n",
    "    plt.savefig(os.path.join(figure_path,prefix+'xyz.pdf'))\n",
    "    plt.close()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5a560ed1-2f3e-4f2c-8328-26049f4a0e9a",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "55ec96dc-ce49-41fe-af96-9a40f16ccc13",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
