{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "from continualworld.results_processing.utils import get_data_for_runs, METHODS_ORDER\n",
    "from continualworld.results_processing.tables import calculate_metrics\n",
    "from continualworld.results_processing.plots import visualize_sequence"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Point these at the directories that contain your own logs.\n",
    "\n",
    "cl_logs = \"examples/logs/cl\"\n",
    "mtl_logs = \"examples/logs/mtl\"\n",
    "baseline_logs = \"examples/logs/baseline\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fetch the data for the runs in each log directory, passing the matching `kind`.\n",
    "cl_data, mtl_data, baseline_data = (\n",
    "    get_data_for_runs(logs_dir, kind=run_kind)\n",
    "    for logs_dir, run_kind in (\n",
    "        (cl_logs, \"cl\"),\n",
    "        (mtl_logs, \"mtl\"),\n",
    "        (baseline_logs, \"single\"),\n",
    "    )\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 299
    },
    "id": "8bp60d0mRUCn",
    "outputId": "ef74b0b6-e79b-4b6b-e1fe-6c8023bd6050"
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>performance</th>\n",
       "      <th>lb_performance</th>\n",
       "      <th>ub_performance</th>\n",
       "      <th>forgetting</th>\n",
       "      <th>lb_forgetting</th>\n",
       "      <th>ub_forgetting</th>\n",
       "      <th>total_normalized_ft</th>\n",
       "      <th>lb_total_normalized_ft</th>\n",
       "      <th>ub_total_normalized_ft</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>cl_method</th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>finetuning</th>\n",
       "      <td>0.020</td>\n",
       "      <td>0.02</td>\n",
       "      <td>0.02</td>\n",
       "      <td>-0.020</td>\n",
       "      <td>-0.02</td>\n",
       "      <td>-0.02</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>ewc</th>\n",
       "      <td>0.005</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.01</td>\n",
       "      <td>-0.005</td>\n",
       "      <td>-0.01</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>mtl_popart</th>\n",
       "      <td>0.000</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.00</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "            performance  lb_performance  ub_performance  forgetting  \\\n",
       "cl_method                                                             \n",
       "finetuning        0.020            0.02            0.02      -0.020   \n",
       "ewc               0.005            0.00            0.01      -0.005   \n",
       "mtl_popart        0.000            0.00            0.00         NaN   \n",
       "\n",
       "            lb_forgetting  ub_forgetting  total_normalized_ft  \\\n",
       "cl_method                                                       \n",
       "finetuning          -0.02          -0.02                  0.0   \n",
       "ewc                 -0.01           0.00                  0.0   \n",
       "mtl_popart            NaN            NaN                  NaN   \n",
       "\n",
       "            lb_total_normalized_ft  ub_total_normalized_ft  \n",
       "cl_method                                                   \n",
       "finetuning                     0.0                     0.0  \n",
       "ewc                            0.0                     0.0  \n",
       "mtl_popart                     NaN                     NaN  "
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Build the metrics table; the bare name on the last line renders it as the cell output.\n",
    "table = calculate_metrics(\n",
    "    cl_data,\n",
    "    mtl_data,\n",
    "    baseline_data,\n",
    "    methods_order=METHODS_ORDER,\n",
    ")\n",
    "table"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Plot the loaded runs; `group_by` and `order` both key on `cl_method`.\n",
    "visualize_sequence(\n",
    "    cl_data, mtl_data, baseline_data,\n",
    "    group_by=['cl_method'],\n",
    "    order=('cl_method', METHODS_ORDER),\n",
    "    show_avg=True, show_current=True, show_individual=True, show_ft=True,\n",
    "    smoothen=False,\n",
    "    # Set use_ci=True to show confidence intervals in all plots; it will be slower though.\n",
    "    use_ci=False,\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "collapsed_sections": [
    "MpDEO5Lgf4PZ",
    "E0gW5oPhgHm1",
    "KYgtu3OUQn7W",
    "I4N8R-QHRpcS",
    "OQ8ulqg3RPlE",
    "uQqgbh7LGPg2",
    "V6JMuTBQgbdJ"
   ],
   "name": "CL neptune downloader - MW - Long Seq (MZ remix). update 10.02, 18:00",
   "provenance": [],
   "toc_visible": true
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
