{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import pickle\n",
    "\n",
    "import matplotlib as mpl\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import seaborn as sns\n",
    "import torch.nn as nn\n",
    "import torch.optim\n",
    "from tqdm.auto import tqdm\n",
    "\n",
    "from train.utils.argparse import check_runs\n",
    "\n",
    "# Global figure style. The rcParams overrides come after set_theme() so that\n",
    "# the theme does not overwrite them.\n",
    "sns.set_theme(style='ticks', palette='pastel')\n",
    "mpl.rcParams.update({\n",
    "    'font.family': 'Liberation Sans',\n",
    "    'axes.spines.right': False,\n",
    "    'axes.spines.top': False,\n",
    "})\n",
    "\n",
    "# One folder under 'out' per model type; these are handed to check_runs later on.\n",
    "model_path_individual, model_path_parallel, model_path_task_switching = (\n",
    "    os.path.join('out', model_dir)\n",
    "    for model_dir in ('individual', 'parallel', 'task_switching')\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Experiment configuration; these values are collected into the `parameters`\n",
    "# dict that is passed to check_runs.\n",
    "num_runs = 10\n",
    "initial_seed = 6789\n",
    "# NOTE(review): 10e5 is the float 1000000.0 (not 10**5). Left unchanged because the\n",
    "# value is forwarded to check_runs -- confirm it matches what the runs were saved with.\n",
    "max_seed = 10e5\n",
    "\n",
    "num_epochs = 100\n",
    "batch_size = 100\n",
    "num_hidden = [100] * 5  # presumably five hidden layers of 100 units -- TODO confirm\n",
    "\n",
    "num_train = 50_000\n",
    "num_test = 10_000\n",
    "\n",
    "tasks_names = ['parity', 'value']\n",
    "idxs_contexts = [*range(5)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Parameter set identifying the stored runs. The value returned by check_runs is\n",
    "# used further down as the folder that contains 'data.pickle'.\n",
    "parameters = dict(\n",
    "    num_runs=num_runs,\n",
    "    initial_seed=initial_seed,\n",
    "    max_seed=max_seed,\n",
    "    num_epochs=num_epochs,\n",
    "    num_hidden=num_hidden,\n",
    "    batch_size=batch_size,\n",
    "    num_train=num_train,\n",
    "    num_test=num_test,\n",
    "    tasks=tasks_names,\n",
    "    idxs_contexts=idxs_contexts,\n",
    ")\n",
    "data_folder_task_switching = check_runs(model_path_task_switching, parameters)\n",
    "\n",
    "# The individual and parallel lookups reuse the same dict with idxs_contexts set to None.\n",
    "parameters.update(idxs_contexts=None)\n",
    "data_folder_individual = check_runs(model_path_individual, parameters)\n",
    "data_folder_parallel = check_runs(model_path_parallel, parameters)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_pickle(path):\n",
    "    \"\"\"Return the object stored in the pickle file at ``path``.\n",
    "\n",
    "    SECURITY: pickle.load can execute arbitrary code while deserialising.\n",
    "    Only call this on files written by this project's own training runs,\n",
    "    never on files obtained from an untrusted source.\n",
    "    \"\"\"\n",
    "    with open(path, 'rb') as handle:\n",
    "        return pickle.load(handle)\n",
    "\n",
    "\n",
    "# Each data folder holds a single 'data.pickle' with the results of every seed.\n",
    "pickle_data_individual = os.path.join(data_folder_individual, 'data.pickle')\n",
    "pickle_data_parallel = os.path.join(data_folder_parallel, 'data.pickle')\n",
    "pickle_data_task_switching = os.path.join(data_folder_task_switching, 'data.pickle')\n",
    "\n",
    "results_individual = load_pickle(pickle_data_individual)\n",
    "results_parallel = load_pickle(pickle_data_parallel)\n",
    "results_task_switching = load_pickle(pickle_data_task_switching)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The results are keyed by seed; sorting a dict directly sorts its keys.\n",
    "seeds_individual = sorted(results_individual)\n",
    "seeds_parallel = sorted(results_parallel)\n",
    "seeds_task_switching = sorted(results_task_switching)\n",
    "\n",
    "# The cells below index all three result sets with the same seed, so the seed\n",
    "# lists have to be identical.\n",
    "assert seeds_individual == seeds_parallel == seeds_task_switching"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "list_models = ['individual', 'parallel', 'task_switching']\n",
    "num_tasks = len(tasks_names)\n",
    "\n",
    "# results_models[model][task][metric] is a zero-filled (num_runs, num_epochs) array:\n",
    "# one row per run, one column per epoch. The rows are filled in by the next cell.\n",
    "results_models = {model: {} for model in list_models}\n",
    "for model in list_models:\n",
    "    for task_name in tasks_names:\n",
    "        results_models[model][task_name] = {\n",
    "            metric: np.zeros((num_runs, num_epochs))\n",
    "            for metric in ('train_loss', 'train_acc', 'valid_loss', 'valid_acc')\n",
    "        }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Copy every stored curve into row i_seed of the matching results_models array.\n",
    "# The two storage layouts differ:\n",
    "#   individual:               seed -> task -> 'results' -> metric\n",
    "#   parallel, task_switching: seed -> 'results' -> metric -> task\n",
    "for i_seed, seed in enumerate(seeds_individual):\n",
    "    for task_name in tasks_names:\n",
    "        for metric in ('train_loss', 'train_acc', 'valid_loss', 'valid_acc'):\n",
    "            results_models['individual'][task_name][metric][i_seed, :] = (\n",
    "                results_individual[seed][task_name]['results'][metric]\n",
    "            )\n",
    "            results_models['parallel'][task_name][metric][i_seed, :] = (\n",
    "                results_parallel[seed]['results'][metric][task_name]\n",
    "            )\n",
    "            results_models['task_switching'][task_name][metric][i_seed, :] = (\n",
    "                results_task_switching[seed]['results'][metric][task_name]\n",
    "            )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "epochs = range(num_epochs)\n",
    "metrics = ['train_loss', 'train_acc', 'valid_loss', 'valid_acc']\n",
    "\n",
    "# One row per task, one column per metric. squeeze=False keeps `ax` two-dimensional\n",
    "# even with a single task, so ax[i_task, k_metric] is always valid.\n",
    "fig, ax = plt.subplots(num_tasks, len(metrics),\n",
    "                       figsize=(12, 2 * num_tasks), squeeze=False)\n",
    "\n",
    "for i_task, task_name in enumerate(tasks_names):\n",
    "    for k_metric, metric in enumerate(metrics):\n",
    "        panel = ax[i_task, k_metric]\n",
    "\n",
    "        # Per model: mean over runs (axis 0) with a band of one standard deviation.\n",
    "        for model in list_models:\n",
    "            curves = results_models[model][task_name][metric]\n",
    "            mean_model = curves.mean(axis=0)\n",
    "            std_model = curves.std(axis=0)\n",
    "\n",
    "            panel.plot(epochs, mean_model, label=model)\n",
    "            panel.fill_between(epochs,\n",
    "                               mean_model - std_model,\n",
    "                               mean_model + std_model,\n",
    "                               alpha=0.5)\n",
    "\n",
    "        # e.g. 'train_loss' -> 'Train loss'\n",
    "        panel.set_ylabel(' '.join(metric.split('_')).capitalize())\n",
    "        if i_task == num_tasks - 1:\n",
    "            panel.set_xlabel('Epochs')\n",
    "\n",
    "    # Name the task on the first panel of its row.\n",
    "    ax[i_task, 0].set_title(task_name.capitalize(), loc='left')\n",
    "\n",
    "# Every panel draws the same models in the same order, so a single legend suffices.\n",
    "ax[0, 0].legend(frameon=False)\n",
    "\n",
    "fig.tight_layout()\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.8.10",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "5de0b3d16828453b801d3a971a2e845298ac67ea708b1fd16f0d1197d2abd69f"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
