{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import copy\n",
    "import pickle\n",
    "\n",
    "import matplotlib as mpl\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import seaborn as sns\n",
    "import torch.nn as nn\n",
    "import torch.optim\n",
    "from tqdm.auto import tqdm\n",
    "from scipy.cluster.hierarchy import linkage, dendrogram, fcluster\n",
    "from scipy.spatial.distance import squareform, pdist\n",
    "\n",
    "import multitask.dataset as dataset\n",
    "from multitask.models.task_switching import get_task_model, calculate_rdm, plot_rdm\n",
    "import multitask.models.task_switching.utils as utils\n",
    "import multitask.models.task_switching.hooks as hooks\n",
    "\n",
    "from train.utils.training import get_device\n",
    "from train.utils.argparse import check_runs\n",
    "\n",
    "# Global figure style: apply the seaborn theme first, then override the font and\n",
    "# hide the top/right spines (the theme would otherwise reset these rcParams).\n",
    "sns.set_theme(style='ticks', palette='pastel')\n",
    "mpl.rcParams.update({\n",
    "    'font.family': 'Liberation Sans',\n",
    "    'axes.spines.right': False,\n",
    "    'axes.spines.top': False,\n",
    "})\n",
    "\n",
    "# Root folder searched by check_runs for the stored task-switching runs.\n",
    "model_path = os.path.join('out', 'task_switching')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Experiment configuration. Together with idxs_contexts these values form the\n",
    "# parameter set handed to check_runs to locate the saved training runs.\n",
    "tasks_names = ['parity', 'value']\n",
    "num_tasks = len(tasks_names)\n",
    "\n",
    "num_runs = 10\n",
    "initial_seed = 6789\n",
    "max_seed = 10e5  # float literal (1000000.0); presumably must match the training-time value\n",
    "\n",
    "num_hidden = 10 * [100]  # one width per layer; num_layers = len(num_hidden) later\n",
    "num_epochs = 50\n",
    "batch_size = 100\n",
    "num_train = 50000\n",
    "num_test = 10000"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "num_layers = len(num_hidden)\n",
    "\n",
    "# Hyper-parameters shared by every configuration; only 'idxs_contexts' varies.\n",
    "base_parameters = {\n",
    "    'num_runs': num_runs,\n",
    "    'initial_seed': initial_seed,\n",
    "    'max_seed': max_seed,\n",
    "    'num_epochs': num_epochs,\n",
    "    'num_hidden': num_hidden,\n",
    "    'batch_size': batch_size,\n",
    "    'num_train': num_train,\n",
    "    'num_test': num_test,\n",
    "    'tasks': tasks_names,\n",
    "}\n",
    "\n",
    "# Entry i of list_results holds the runs whose context is given to layers 0..i.\n",
    "list_results = []\n",
    "for num_contexts in range(1, num_layers + 1):\n",
    "    idxs_contexts = list(range(num_contexts))\n",
    "    print(idxs_contexts)\n",
    "\n",
    "    # check_runs returns the folder holding the runs for this exact parameter set.\n",
    "    data_folder = check_runs(model_path, {**base_parameters, 'idxs_contexts': idxs_contexts})\n",
    "\n",
    "    # NOTE: pickle.load can execute arbitrary code; only load trusted local results.\n",
    "    with open(os.path.join(data_folder, 'data.pickle'), 'rb') as handle:\n",
    "        list_results.append(pickle.load(handle))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The weight analysis below unpacks exactly two context-weight columns\n",
    "# (parity, value), so any other number of tasks would be silently wrong.\n",
    "if num_tasks != 2:\n",
    "    raise NotImplementedError(f'Expected exactly 2 tasks, got {num_tasks}: {tasks_names}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tasks_datasets = dataset.get_tasks_dict(tasks_names, root='data')\n",
    "num_tasks = len(tasks_names)\n",
    "\n",
    "# Each task gets its dataset plus a one-hot context vector: task i has a 1 at index i.\n",
    "task_switching_tasks = {\n",
    "    task_name: {\n",
    "        'data': tasks_datasets[task_name],\n",
    "        'activations': [int(j == i_context) for j in range(num_tasks)],\n",
    "    }\n",
    "    for i_context, task_name in enumerate(tasks_names)\n",
    "}\n",
    "\n",
    "for task_name, task_info in task_switching_tasks.items():\n",
    "    print(f'{task_name}: {task_info[\"activations\"]}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "device = get_device()\n",
    "\n",
    "# Indexed as [seed, configuration, layer]. Configuration i gives context to layers\n",
    "# 0..i, so only entries with layer <= configuration are filled; the rest stay 0.\n",
    "norm_weights_parity = np.zeros((num_runs, num_layers, num_layers))\n",
    "norm_weights_value = np.zeros((num_runs, num_layers, num_layers))\n",
    "correlations = np.zeros((num_runs, num_layers, num_layers))\n",
    "\n",
    "for i_results, results in enumerate(list_results):\n",
    "    seeds = list(results.keys())\n",
    "    # Fewer seeds would leave zero rows that silently bias the means computed later.\n",
    "    if len(seeds) != num_runs:\n",
    "        raise ValueError(f'Configuration {i_results}: expected {num_runs} seeds, found {len(seeds)}')\n",
    "    idxs_contexts = list(range(i_results + 1))\n",
    "    for j_seed, seed in enumerate(seeds):\n",
    "        model = get_task_model(task_switching_tasks,\n",
    "                               num_hidden,\n",
    "                               idxs_contexts,\n",
    "                               device)\n",
    "        model.load_state_dict(results[seed]['model'])\n",
    "        for k_context in idxs_contexts:\n",
    "            weights = model.layers[k_context].weight.detach().cpu().numpy()\n",
    "            # Assumes the last two input columns are the context units in tasks_names\n",
    "            # order (parity, value) -- TODO confirm against get_task_model.\n",
    "            context_weights = weights[:, -2:]\n",
    "            norm_parity, norm_value = np.linalg.norm(context_weights, axis=0)\n",
    "            norm_weights_parity[j_seed, i_results, k_context] = norm_parity\n",
    "            norm_weights_value[j_seed, i_results, k_context] = norm_value\n",
    "            correlations[j_seed, i_results, k_context] = np.corrcoef(\n",
    "                context_weights[:, 0], context_weights[:, 1])[0, 1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mean_weights_parity = norm_weights_parity.mean(axis=0)\n",
    "std_weights_parity = norm_weights_parity.std(axis=0)\n",
    "\n",
    "mean_weights_value = norm_weights_value.mean(axis=0)\n",
    "std_weights_value = norm_weights_value.std(axis=0)\n",
    "\n",
    "mean_correlations = correlations.mean(axis=0)\n",
    "std_correlations = correlations.std(axis=0)\n",
    "\n",
    "# Configuration i (row) only has context weights in layers 0..i (columns), so the\n",
    "# strict upper triangle was never filled. Mask it by construction rather than by\n",
    "# testing the mean correlation for exact zeros.\n",
    "mask = np.triu(np.ones((num_layers, num_layers), dtype=bool), k=1)\n",
    "vmax = max(mean_weights_parity.max(), mean_weights_value.max())\n",
    "\n",
    "# 1-based labels: rows = number of layers receiving context, columns = layer.\n",
    "# Passing them to sns.heatmap keeps them centred on the cells.\n",
    "tick_labels = list(range(1, num_layers + 1))\n",
    "heatmap_kwargs = dict(mask=mask, xticklabels=tick_labels, yticklabels=tick_labels)\n",
    "\n",
    "fig, ax = plt.subplots(1, 3, figsize=(9, 2))\n",
    "sns.heatmap(mean_weights_parity, ax=ax[0], vmin=0, vmax=vmax, **heatmap_kwargs)\n",
    "sns.heatmap(mean_weights_value, ax=ax[1], vmin=0, vmax=vmax, **heatmap_kwargs)\n",
    "sns.heatmap(mean_correlations, ax=ax[2], cmap='icefire', vmin=-1, vmax=1, **heatmap_kwargs)\n",
    "for axis in ax:\n",
    "    axis.set_facecolor('silver')  # visible behind the masked (unfilled) cells\n",
    "\n",
    "fig_path = os.path.join('figures', 'figure05', 'fig05a_weights_contexts_increase.svg')\n",
    "os.makedirs(os.path.dirname(fig_path), exist_ok=True)\n",
    "fig.savefig(fig_path)\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.8.10",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "5de0b3d16828453b801d3a971a2e845298ac67ea708b1fd16f0d1197d2abd69f"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
