{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import copy\n",
    "import pickle\n",
    "\n",
    "import matplotlib as mpl\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import seaborn as sns\n",
    "import torch.nn as nn\n",
    "import torch.optim\n",
    "from tqdm.auto import tqdm\n",
    "from scipy.cluster.hierarchy import linkage, dendrogram, fcluster\n",
    "from scipy.spatial.distance import squareform, pdist\n",
    "\n",
    "import multitask.dataset as dataset\n",
    "from multitask.models.task_switching import get_task_model, calculate_rdm, plot_rdm\n",
    "import multitask.models.task_switching.utils as utils\n",
    "import multitask.models.task_switching.hooks as hooks\n",
    "\n",
    "from train.utils.training import get_device\n",
    "from train.utils.argparse import check_runs\n",
    "\n",
    "sns.set_theme(style='ticks', palette='pastel')\n",
    "# Overrides go after set_theme so the theme defaults do not replace them.\n",
    "mpl.rcParams.update({\n",
    "    'font.family': 'Liberation Sans',\n",
    "    'axes.spines.right': False,\n",
    "    'axes.spines.top': False,\n",
    "})\n",
    "\n",
    "# Root folder passed to check_runs to locate the stored task-switching runs.\n",
    "model_path = os.path.join('out', 'task_switching')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Configuration of the stored runs; these values are sent to check_runs to\n",
    "# look the runs up, so they presumably must match the training settings.\n",
    "num_runs = 10\n",
    "initial_seed = 6789\n",
    "max_seed = 10e5\n",
    "\n",
    "num_epochs = 50\n",
    "num_hidden = [100] * 10  # one entry per layer (num_layers = len(num_hidden))\n",
    "batch_size = 100\n",
    "\n",
    "num_train = 50_000\n",
    "num_test = 10_000\n",
    "\n",
    "tasks_names = ['parity', 'value']\n",
    "num_tasks = len(tasks_names)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "num_layers = len(num_hidden)\n",
    "list_results = []\n",
    "\n",
    "# One stored experiment per context depth k = 1..num_layers, with\n",
    "# idxs_contexts = [0, ..., k-1]; list_results[k-1] holds that experiment.\n",
    "for num_contexts in range(1, num_layers + 1):\n",
    "    idxs_contexts = list(range(num_contexts))\n",
    "    num_hidden_contexts = [num_hidden[0]] * num_contexts\n",
    "    print(idxs_contexts, num_hidden_contexts)\n",
    "\n",
    "    parameters = dict(\n",
    "        num_runs=num_runs,\n",
    "        initial_seed=initial_seed,\n",
    "        max_seed=max_seed,\n",
    "        num_epochs=num_epochs,\n",
    "        num_hidden=num_hidden_contexts,\n",
    "        batch_size=batch_size,\n",
    "        num_train=num_train,\n",
    "        num_test=num_test,\n",
    "        tasks=tasks_names,\n",
    "        idxs_contexts=idxs_contexts,\n",
    "    )\n",
    "    data_folder = check_runs(model_path, parameters)\n",
    "\n",
    "    # NOTE: pickle.load can execute arbitrary code; only load locally produced runs.\n",
    "    with open(os.path.join(data_folder, 'data.pickle'), 'rb') as handle:\n",
    "        list_results.append(pickle.load(handle))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The evaluation below reads the accuracies by the keys 'parity' and 'value',\n",
    "# so fail early (before loading any model) for any other task configuration.\n",
    "if num_tasks > 2 or set(tasks_names) != {'parity', 'value'}:\n",
    "    raise NotImplementedError(\n",
    "        f'Only the tasks parity and value are supported, got {tasks_names}'\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tasks_datasets = dataset.get_tasks_dict(tasks_names, root='data')\n",
    "num_tasks = len(tasks_names)\n",
    "\n",
    "# Each task gets its dataset and a one-hot context vector: 1 at its own\n",
    "# position in tasks_names, 0 elsewhere.\n",
    "task_switching_tasks = {\n",
    "    task_name: {\n",
    "        'data': tasks_datasets[task_name],\n",
    "        'activations': [int(i_task == i_context) for i_task in range(num_tasks)],\n",
    "    }\n",
    "    for i_context, task_name in enumerate(tasks_names)\n",
    "}\n",
    "\n",
    "for key, value in task_switching_tasks.items():\n",
    "    print(f'{key}: {value[\"activations\"]}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "criterion = nn.CrossEntropyLoss()\n",
    "device = get_device()\n",
    "\n",
    "# Rows: runs (seeds); columns: index into list_results.\n",
    "acc_test_parity = np.zeros((num_runs, num_layers))\n",
    "acc_test_value = np.zeros((num_runs, num_layers))\n",
    "acc_test_joint = np.zeros((num_runs, num_layers))\n",
    "\n",
    "for i_results, results in enumerate(list_results):\n",
    "    seeds = list(results.keys())\n",
    "    # Fewer seeds would silently leave zeros in the arrays (biasing the means);\n",
    "    # more would overflow them.\n",
    "    assert len(seeds) == num_runs, f'expected {num_runs} runs, found {len(seeds)}'\n",
    "\n",
    "    # list_results[i] was loaded with idxs_contexts = [0, ..., i].\n",
    "    idxs_contexts = list(range(i_results + 1))\n",
    "    num_hidden_contexts = len(idxs_contexts) * [num_hidden[0]]\n",
    "    print(idxs_contexts, num_hidden_contexts)\n",
    "\n",
    "    for j_seed, seed in enumerate(seeds):\n",
    "        model = get_task_model(task_switching_tasks,\n",
    "                               num_hidden_contexts,\n",
    "                               idxs_contexts,\n",
    "                               device)\n",
    "        model.load_state_dict(results[seed]['model'])\n",
    "        indices = results[seed]['indices']\n",
    "\n",
    "        _, test_dataloaders = dataset.create_dict_dataloaders(task_switching_tasks,\n",
    "                                                            indices,\n",
    "                                                            batch_size=batch_size)\n",
    "        tasks_testloader = dataset.SequentialTaskDataloader(test_dataloaders)\n",
    "\n",
    "        acc_test_increase, _ = hooks.get_layer_activations(model,\n",
    "                                                           tasks_testloader,\n",
    "                                                           criterion,\n",
    "                                                           device=device,\n",
    "                                                           disable=True)\n",
    "\n",
    "        acc_parity = acc_test_increase['parity']\n",
    "        acc_value = acc_test_increase['value']\n",
    "        acc_test_parity[j_seed, i_results] = acc_parity.mean()\n",
    "        acc_test_value[j_seed, i_results] = acc_value.mean()\n",
    "        # Joint accuracy: fraction of test samples where both tasks are correct\n",
    "        # (assumes per-sample 0/1 arrays aligned across tasks -- TODO confirm).\n",
    "        acc_test_joint[j_seed, i_results] = (acc_parity * acc_value).mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Accuracy arrays (runs x configurations) to plot, keyed by the Task label.\n",
    "# Add 'Parity': acc_test_parity and/or 'Value': acc_test_value to show per-task bars.\n",
    "acc_by_task = {'Joint': acc_test_joint}\n",
    "\n",
    "# Build the long-format frame in one concat: avoids growing a frame in a loop\n",
    "# and concatenating onto an empty object-dtype frame (deprecated in pandas 2.1).\n",
    "accuracies_df = pd.concat(\n",
    "    [pd.DataFrame({'Acc': acc_idx, 'Idx': i_idx, 'Task': task})\n",
    "     for task, acc in acc_by_task.items()\n",
    "     for i_idx, acc_idx in enumerate(acc.T)],\n",
    "    ignore_index=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots()\n",
    "sns.barplot(x='Acc', y='Idx', hue='Task', data=accuracies_df, errorbar='se',\n",
    "            errwidth=1.5, capsize=0.15, orient='horizontal', ax=ax)\n",
    "ax.set_xlabel('Accuracy', fontsize=16)\n",
    "# NOTE(review): the y axis encodes 'Idx' (= len(idxs_contexts) - 1); 'Task' is the\n",
    "# hue, so this label looks mismatched -- confirm the intended axis label.\n",
    "ax.set_ylabel('Task', fontsize=16)\n",
    "# Zoomed range so small accuracy differences are visible; widen if bars fall outside.\n",
    "ax.set_xlim(0.96, 0.985)\n",
    "ax.legend(loc='lower left', prop={'size': 12})\n",
    "\n",
    "# savefig does not create missing folders, so make sure the target exists.\n",
    "fig_path = os.path.join('figures', 'figure05', 'fig05c_acc_weights_contexts_layers.svg')\n",
    "os.makedirs(os.path.dirname(fig_path), exist_ok=True)\n",
    "fig.savefig(fig_path)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.8.10",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "5de0b3d16828453b801d3a971a2e845298ac67ea708b1fd16f0d1197d2abd69f"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
