{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import pickle\n",
    "\n",
    "import matplotlib as mpl\n",
    "import matplotlib.pyplot as plt\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import seaborn as sns\n",
    "import torch.nn as nn\n",
    "import torch.optim\n",
    "from tqdm.auto import tqdm\n",
    "\n",
    "import multitask.dataset as dataset\n",
    "from multitask.models.individual import get_individual_model\n",
    "from multitask.models.individual import train as train_individual\n",
    "from multitask.models.individual import hooks as hooks_individual\n",
    "from multitask.models.parallel import get_parallel_model\n",
    "from multitask.models.parallel import train as train_parallel\n",
    "from multitask.models.parallel import hooks as hooks_parallel\n",
    "from multitask.models.task_switching import get_task_model\n",
    "from multitask.models.task_switching import train as train_task_switching\n",
    "from multitask.models.task_switching import hooks as hooks_task_switching\n",
    "\n",
    "from train.utils.argparse import check_runs\n",
    "from train.utils.training import get_device\n",
    "\n",
    "sns.set_theme(style='ticks', palette='pastel')\n",
    "mpl.rcParams['font.family'] = 'Liberation Sans'\n",
    "mpl.rcParams['axes.spines.right'] = False\n",
    "mpl.rcParams['axes.spines.top'] = False\n",
    "\n",
    "model_path_individual = os.path.join('out', 'individual')\n",
    "model_path_parallel = os.path.join('out', 'parallel')\n",
    "model_path_task_switching = os.path.join('out', 'task_switching')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "num_runs = 10\n",
    "initial_seed = 6789\n",
    "max_seed = 10e5\n",
    "num_epochs = 100\n",
    "num_hidden = 5 * [100]\n",
    "batch_size = 100\n",
    "num_train = 50000\n",
    "num_test = 10000\n",
    "tasks_names = ['parity', 'value']\n",
    "# tasks_names = [\"parity\", \"small\", \"prime\", \"fibonacci\", \"multiples_3\"]\n",
    "# tasks_names = [\"parity\", \"imparity\", \"small\", \"large\", \"prime\", \"not_prime\", \"fibonacci\", \"not_fibonacci\", \"multiples_3\", \"not_multiples_3\"]\n",
    "idxs_contexts = [0, 1, 2, 3, 4]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "parameters = {\n",
    "    'num_runs': num_runs,\n",
    "    'initial_seed': initial_seed,\n",
    "    'max_seed': max_seed,\n",
    "    'num_epochs': num_epochs,\n",
    "    'num_hidden': num_hidden,\n",
    "    'batch_size': batch_size,\n",
    "    'num_train': num_train,\n",
    "    'num_test': num_test,\n",
    "    'tasks': tasks_names,\n",
    "    'idxs_contexts': idxs_contexts\n",
    "}\n",
    "data_folder_task_switching = check_runs(model_path_task_switching, parameters)\n",
    "\n",
    "parameters['idxs_contexts'] = None\n",
    "data_folder_individual = check_runs(model_path_individual, parameters)\n",
    "data_folder_parallel = check_runs(model_path_parallel, parameters)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pickle_data_individual = os.path.join(data_folder_individual, 'data.pickle')\n",
    "with open(pickle_data_individual, 'rb') as handle:\n",
    "    results_individual = pickle.load(handle)\n",
    "\n",
    "pickle_data_parallel = os.path.join(data_folder_parallel, 'data.pickle')\n",
    "with open(pickle_data_parallel, 'rb') as handle:\n",
    "    results_parallel = pickle.load(handle)\n",
    "\n",
    "pickle_data_task_switching = os.path.join(data_folder_task_switching, 'data.pickle')\n",
    "with open(pickle_data_task_switching, 'rb') as handle:\n",
    "    results_task_switching = pickle.load(handle)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "seeds_individual = sorted(list(results_individual.keys()))\n",
    "seeds_parallel = sorted(list(results_parallel.keys()))\n",
    "seeds_task_switching = sorted(list(results_task_switching.keys()))\n",
    "assert seeds_individual == seeds_parallel == seeds_task_switching"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "device = get_device()\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "tasks = dataset.get_tasks_dict(tasks_names, root='data')\n",
    "num_tasks = len(tasks)\n",
    "num_layers = len(num_hidden)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sparsity_individual = np.zeros((num_tasks, num_runs, num_layers))\n",
    "dead_individual = np.zeros((num_tasks, num_runs, num_layers))\n",
    "\n",
    "for i_seed, seed in tqdm(enumerate(seeds_individual), total=len(seeds_individual)):\n",
    "    indices = results_individual[seed]['indices']\n",
    "    test_model = get_individual_model(num_hidden, device)\n",
    "    test_sampler = dataset.SequentialSampler(indices['test'])\n",
    "\n",
    "    for i_task, (task_name, task_dataset) in enumerate(tasks.items()):\n",
    "        saved_model = results_individual[seed][task_name]['model']\n",
    "        test_model.load_state_dict(saved_model)\n",
    "        test_model = test_model.to(device)\n",
    "        \n",
    "        test_sampler = dataset.SequentialSampler(indices['test'])\n",
    "        testloader = torch.utils.data.DataLoader(task_dataset,\n",
    "                                                 sampler=test_sampler,\n",
    "                                                 batch_size=100)\n",
    "\n",
    "        _, activations_individuals = hooks_individual.get_layer_activations(test_model,\n",
    "                                                             testloader,\n",
    "                                                             criterion,\n",
    "                                                             device=device,\n",
    "                                                             disable=True)\n",
    "                                                \n",
    "        for j_layer in range(num_layers):\n",
    "            layer = f'layer{j_layer+1}'\n",
    "            sparsity_individual[i_task, i_seed, j_layer] = 100 * (np.sum(activations_individuals[layer] == 0, axis=1).mean() / num_hidden[j_layer])\n",
    "            dead_individual[i_task, i_seed, j_layer] = 100 * (np.sum(activations_individuals[layer].sum(axis=0) == 0) / num_hidden[j_layer])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "parallel_datasets = {}\n",
    "for task_name in tasks_names:\n",
    "    parallel_datasets[task_name] = tasks[task_name]\n",
    "\n",
    "parallel_tasks = dataset.MultilabelTasks(parallel_datasets)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sparsity_parallel = np.zeros((num_runs, num_layers))\n",
    "dead_parallel = np.zeros((num_runs, num_layers))\n",
    "\n",
    "for i_seed, seed in tqdm(enumerate(seeds_parallel), total=num_runs):\n",
    "    saved_model = results_parallel[seed]['model']\n",
    "    test_model = get_parallel_model(num_tasks,\n",
    "                               num_hidden,\n",
    "                               device)\n",
    "    test_model.load_state_dict(saved_model)\n",
    "    test_model = test_model.to(device)\n",
    "    \n",
    "    indices = results_parallel[seed]['indices']\n",
    "\n",
    "    test_sampler = dataset.SequentialSampler(indices['test'])\n",
    "    parallel_testloader = torch.utils.data.DataLoader(parallel_tasks,\n",
    "                                                      sampler=test_sampler,\n",
    "                                                      batch_size=batch_size)\n",
    "\n",
    "    numbers = parallel_datasets[tasks_names[0]].numbers\n",
    "    numbers = numbers[indices['test']]\n",
    "    \n",
    "    _, activations_parallel = hooks_parallel.get_layer_activations(test_model,\n",
    "                                                       parallel_testloader,\n",
    "                                                       criterion=criterion,\n",
    "                                                       device=device,\n",
    "                                                       disable=True)\n",
    "    \n",
    "    for j_layer in range(num_layers):\n",
    "        layer = f'layer{j_layer+1}'\n",
    "        sparsity_parallel[i_seed, j_layer] = 100 * (np.sum(activations_parallel[layer] == 0, axis=1).mean() / num_hidden[j_layer])\n",
    "        dead_parallel[i_seed, j_layer] = 100 * (np.sum(activations_parallel[layer].sum(axis=0) == 0) / num_hidden[j_layer])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tasks_datasets = dataset.get_tasks_dict(tasks_names, root='data')\n",
    "\n",
    "task_switching_tasks = {}\n",
    "num_tasks = len(tasks_names)\n",
    "\n",
    "for i_context, task_name in enumerate(tasks_names):\n",
    "    task_switching_tasks[task_name] = {}\n",
    "    task_switching_tasks[task_name]['data'] = tasks_datasets[task_name]\n",
    "    task_switching_tasks[task_name]['activations'] = num_tasks * [0]\n",
    "    task_switching_tasks[task_name]['activations'][i_context] = 1\n",
    "\n",
    "for key, value in task_switching_tasks.items():\n",
    "    print(f'{key}: {value[\"activations\"]}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sparsity_task_switching = np.zeros((num_runs, num_layers))\n",
    "dead_task_switching = np.zeros((num_runs, num_layers))\n",
    "\n",
    "\n",
    "for i_seed, seed in tqdm(enumerate(seeds_task_switching), total=num_runs):\n",
    "    state_dict = results_task_switching[seed]['model']\n",
    "    model = get_task_model(task_switching_tasks,\n",
    "                           num_hidden,\n",
    "                           idxs_contexts,\n",
    "                           device)\n",
    "    model.load_state_dict(state_dict)\n",
    "    \n",
    "    indices = results_task_switching[seed]['indices']\n",
    "\n",
    "    test_sampler = dataset.SequentialSampler(indices['test'])\n",
    "    _, test_dataloaders = dataset.create_dict_dataloaders(task_switching_tasks,\n",
    "                                                          indices,\n",
    "                                                          batch_size=batch_size)\n",
    "    tasks_testloader = dataset.SequentialTaskDataloader(test_dataloaders)\n",
    "\n",
    "    numbers = test_dataloaders[tasks_names[0]].dataset.numbers.numpy()\n",
    "    numbers = numbers[indices['test']]\n",
    "\n",
    "    _, activations_task_switching = hooks_task_switching.get_layer_activations(model,\n",
    "                                                             tasks_testloader,\n",
    "                                                             criterion,\n",
    "                                                             device=device,\n",
    "                                                             disable=True)\n",
    "\n",
    "    for j_layer in range(num_layers):\n",
    "        layer = f'layer{j_layer+1}'\n",
    "        for i_task, task in enumerate(tasks_names):\n",
    "            if i_task == 0:\n",
    "                total_activations_layer = activations_task_switching[task][layer]\n",
    "            else:\n",
    "                total_activations_layer = np.vstack((total_activations_layer,\n",
    "                                                     activations_task_switching[task][layer]))\n",
    "        \n",
    "        sparsity_task_switching[i_seed, j_layer] = 100 * (np.sum(total_activations_layer == 0, axis=1).mean() / num_hidden[j_layer])\n",
    "        dead_task_switching[i_seed, j_layer] = 100 * (np.sum(total_activations_layer.sum(axis=0) == 0) / num_hidden[j_layer])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mean_sparsity_individual = sparsity_individual.mean(axis=0).mean(axis=0)\n",
    "mean_sparsity_parallel = sparsity_parallel.mean(axis=0)\n",
    "mean_sparsity_task_switching = sparsity_task_switching.mean(axis=0)\n",
    "\n",
    "std_sparsity_individual = sparsity_individual.mean(axis=0).std(axis=0)\n",
    "std_sparsity_parallel = sparsity_parallel.std(axis=0)\n",
    "std_sparsity_task_switching = sparsity_task_switching.std(axis=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "layers = range(1, num_layers + 1)\n",
    "fig = plt.figure()\n",
    "\n",
    "plt.plot(layers, mean_sparsity_individual)\n",
    "plt.plot(layers, mean_sparsity_parallel)\n",
    "plt.plot(layers, mean_sparsity_task_switching,)\n",
    "\n",
    "plt.fill_between(layers,\n",
    "                 mean_sparsity_individual-std_sparsity_individual,\n",
    "                 mean_sparsity_individual+std_sparsity_individual,\n",
    "                 alpha=0.5)\n",
    "\n",
    "plt.fill_between(layers,\n",
    "                 mean_sparsity_parallel-std_sparsity_parallel,\n",
    "                 mean_sparsity_parallel+std_sparsity_parallel,\n",
    "                 alpha=0.5)\n",
    "\n",
    "\n",
    "plt.fill_between(layers,\n",
    "                 mean_sparsity_task_switching-std_sparsity_task_switching,\n",
    "                 mean_sparsity_task_switching+std_sparsity_task_switching,\n",
    "                 alpha=0.5)\n",
    "plt.xlabel('Layer', fontsize=16)\n",
    "# plt.ylabel('Mean Squared Error', fontsize=16)\n",
    "plt.ylabel('Sparsity (%)', fontsize=16)\n",
    "plt.xticks(layers, fontsize=14, fontname='Liberation Sans')\n",
    "plt.yticks(fontsize=12,  fontname='Liberation Sans')\n",
    "plt.legend(['Individual', 'Parallel', 'Task Switching'], prop={'size':12})\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "df_sparsity_all = pd.DataFrame({}, columns=['Sparsity', 'Model', 'Layer'])\n",
    "for i_layer in range(num_layers):\n",
    "    layer = f'layer{i_layer+1}'\n",
    "    df_sparsity_individual = pd.DataFrame({'Sparsity': sparsity_individual.mean(axis=0)[:, i_layer], 'Model': 'Individual', 'Layer': layer})\n",
    "    df_sparsity_parallel = pd.DataFrame({'Sparsity': sparsity_parallel[:, i_layer], 'Model': 'Parallel', 'Layer': layer})\n",
    "    df_sparsity_task_switching = pd.DataFrame({'Sparsity': sparsity_task_switching[:, i_layer], 'Model': 'Task Switching', 'Layer': layer})\n",
    "    df_sparsity_all = pd.concat([df_sparsity_all, df_sparsity_individual, df_sparsity_parallel, df_sparsity_task_switching])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mean_dead_individual = dead_individual.mean(axis=0).mean(axis=0)\n",
    "mean_dead_parallel = dead_parallel.mean(axis=0)\n",
    "mean_dead_task_switching = dead_task_switching.mean(axis=0)\n",
    "\n",
    "std_dead_individual = dead_individual.mean(axis=0).std(axis=0)\n",
    "std_dead_parallel = dead_parallel.std(axis=0)\n",
    "std_dead_task_switching = dead_task_switching.std(axis=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "layers = range(1, num_layers + 1)\n",
    "fig = plt.figure()\n",
    "\n",
    "plt.plot(layers, mean_dead_individual)\n",
    "plt.plot(layers, mean_dead_parallel)\n",
    "plt.plot(layers, mean_dead_task_switching,)\n",
    "\n",
    "plt.fill_between(layers,\n",
    "                 mean_dead_individual-std_dead_individual,\n",
    "                 mean_dead_individual+std_dead_individual,\n",
    "                 alpha=0.5)\n",
    "\n",
    "plt.fill_between(layers,\n",
    "                 mean_dead_parallel-std_dead_parallel,\n",
    "                 mean_dead_parallel+std_dead_parallel,\n",
    "                 alpha=0.5)\n",
    "\n",
    "\n",
    "plt.fill_between(layers,\n",
    "                 mean_dead_task_switching-std_dead_task_switching,\n",
    "                 mean_dead_task_switching+std_dead_task_switching,\n",
    "                 alpha=0.5)\n",
    "plt.xlabel('Layer', fontsize=16)\n",
    "# plt.ylabel('Mean Squared Error', fontsize=16)\n",
    "plt.ylabel('Dead Units (%)', fontsize=16)\n",
    "plt.xticks(layers, fontsize=14, fontname='Liberation Sans')\n",
    "plt.yticks(fontsize=12,  fontname='Liberation Sans')\n",
    "plt.legend(['Individual', 'Parallel', 'Task Switching'], prop={'size':12})\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Long-format table of dead-unit percentages (one row per model x layer x seed) for seaborn.\n",
    "# Individual models are averaged over tasks so every model has one value per seed and layer.\n",
    "dead_frames = []\n",
    "for i_layer in range(num_layers):\n",
    "    layer = f'layer{i_layer+1}'\n",
    "    for model_name, values in [('Individual', dead_individual.mean(axis=0)[:, i_layer]),\n",
    "                               ('Parallel', dead_parallel[:, i_layer]),\n",
    "                               ('Task Switching', dead_task_switching[:, i_layer])]:\n",
    "        dead_frames.append(pd.DataFrame({'Dead': values, 'Model': model_name, 'Layer': layer}))\n",
    "# Concatenate once instead of growing the frame inside the loop.\n",
    "df_dead_all = pd.concat(dead_frames, ignore_index=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots(1, 2, figsize=(12, 4))\n",
    "\n",
    "sns.barplot(data=df_sparsity_all, x=\"Layer\", y=\"Sparsity\", hue='Model', ci='sd', ax=ax[0])\n",
    "sns.barplot(data=df_dead_all, x=\"Layer\", y=\"Dead\", hue='Model', ci='sd', ax=ax[1])\n",
    "fig.suptitle(f'Num. Layers: {num_layers}   Num. Units: {num_hidden[0]}')\n",
    "\n",
    "ax[0].set_ylim(40, 100)\n",
    "ax[1].set_ylim(0, 100)\n",
    "\n",
    "fig.tight_layout()\n",
    "fig.savefig(f'figures/figS02_sparsity_{num_layers}_{num_hidden[0]}_{num_tasks}.pdf')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.8.10",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "5de0b3d16828453b801d3a971a2e845298ac67ea708b1fd16f0d1197d2abd69f"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
