{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import copy\n",
    "import pickle\n",
    "\n",
    "import matplotlib as mpl\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import seaborn as sns\n",
    "import torch.nn as nn\n",
    "import torch.optim\n",
    "from tqdm.auto import tqdm\n",
    "from string import digits\n",
    "from scipy.cluster.hierarchy import linkage, dendrogram, fcluster\n",
    "from scipy.spatial.distance import squareform, pdist\n",
    "\n",
    "import multitask.dataset as dataset\n",
    "from multitask.models.task_switching import get_task_model, calculate_rdm, plot_rdm\n",
    "import multitask.models.task_switching.utils as utils\n",
    "import multitask.models.task_switching.hooks as hooks\n",
    "\n",
    "from train.utils.training import get_device\n",
    "from train.utils.argparse import check_runs\n",
    "\n",
    "# Global plot styling applied to every figure drawn in this notebook.\n",
    "sns.set_theme(palette='pastel', style='ticks')\n",
    "mpl.rcParams.update({'font.family': 'Liberation Sans'})\n",
    "\n",
    "# Root output folder that is handed to check_runs further down.\n",
    "model_path = os.path.join('out', 'task_switching')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run configuration. These values are bundled into `parameters` in the\n",
    "# next cell and passed to check_runs to locate the stored results.\n",
    "num_runs = 10\n",
    "initial_seed = 6789\n",
    "max_seed = 10e5  # NOTE: float literal (1000000.0), not an int\n",
    "num_epochs = 50\n",
    "batch_size = 100\n",
    "num_train = 50000\n",
    "num_test = 10000\n",
    "\n",
    "tasks_names = ['parity', 'value', 'prime', 'fibonacci', 'multiples_3']\n",
    "\n",
    "# Ten entries of 100 (presumably the width of each hidden layer).\n",
    "num_hidden = [100] * 10\n",
    "# Indices 0 .. len(num_hidden) - 1, one per entry of num_hidden.\n",
    "idxs_contexts = [i_layer for i_layer in range(len(num_hidden))]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Bundle the run configuration under the keys check_runs is called with;\n",
    "# the call returns the folder this notebook reads its data from.\n",
    "parameters = dict(\n",
    "    num_runs=num_runs,\n",
    "    initial_seed=initial_seed,\n",
    "    max_seed=max_seed,\n",
    "    num_epochs=num_epochs,\n",
    "    num_hidden=num_hidden,\n",
    "    batch_size=batch_size,\n",
    "    num_train=num_train,\n",
    "    num_test=num_test,\n",
    "    tasks=tasks_names,\n",
    "    idxs_contexts=idxs_contexts,\n",
    ")\n",
    "\n",
    "data_folder = check_runs(model_path, parameters)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the stored results for all runs.\n",
    "# NOTE: unpickling can execute arbitrary code; only open files produced\n",
    "# by a trusted training run.\n",
    "pickle_data = os.path.join(data_folder, 'data.pickle')\n",
    "with open(pickle_data, mode='rb') as pickle_file:\n",
    "    results_task_switching = pickle.load(pickle_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The keys of the results mapping are used as seeds, in ascending order.\n",
    "seeds = sorted(results_task_switching)\n",
    "num_seeds, num_tasks = len(seeds), len(tasks_names)\n",
    "\n",
    "print(seeds, tasks_names, sep='\\n')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tasks_datasets = dataset.get_tasks_dict(tasks_names, root='data')\n",
    "num_tasks = len(tasks_names)\n",
    "\n",
    "# Pair each task with its dataset and a one-hot 'activations' vector\n",
    "# whose single 1 sits at the task's position in tasks_names.\n",
    "task_switching_tasks = {\n",
    "    task_name: {\n",
    "        'data': tasks_datasets[task_name],\n",
    "        'activations': [int(i_slot == i_context) for i_slot in range(num_tasks)],\n",
    "    }\n",
    "    for i_context, task_name in enumerate(tasks_names)\n",
    "}\n",
    "\n",
    "for key, value in task_switching_tasks.items():\n",
    "    print(f'{key}: {value[\"activations\"]}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Shared analysis settings. Later cells reuse device, criterion,\n",
    "# num_clusters, num_digits and theta defined here.\n",
    "device = get_device()\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "num_layers = len(num_hidden)\n",
    "\n",
    "num_clusters = 3\n",
    "num_digits = 10\n",
    "layer = 'layer10'  # name of the hooked layer whose activations are clustered\n",
    "# Row count of the per-unit table; assumes the hooked layer has\n",
    "# num_hidden[0] units (every entry of num_hidden is equal in this notebook).\n",
    "num_units = num_hidden[0]\n",
    "# One bar angle per digit, evenly spaced around the polar axis.\n",
    "theta = np.arange(0, 2 * np.pi, 2 * np.pi / num_digits)\n",
    "\n",
    "# One row of polar plots per stored seed. The grid is sized from `seeds`\n",
    "# (not num_runs) so rows and loop iterations cannot get out of step, and\n",
    "# squeeze=False keeps `ax` two-dimensional for a single row or column.\n",
    "num_rows = len(seeds)\n",
    "fig, ax = plt.subplots(num_rows, num_tasks,\n",
    "                       figsize=(4 * num_tasks, 4 * num_rows),\n",
    "                       subplot_kw=dict(projection='polar'),\n",
    "                       squeeze=False)\n",
    "\n",
    "for i_seed, seed in tqdm(enumerate(seeds), total=num_rows):\n",
    "    # Rebuild the network and restore the weights stored for this seed.\n",
    "    state_dict = results_task_switching[seed]['model']\n",
    "    model = get_task_model(task_switching_tasks,\n",
    "                           num_hidden,\n",
    "                           idxs_contexts,\n",
    "                           device)\n",
    "    model.load_state_dict(state_dict)\n",
    "\n",
    "    indices = results_task_switching[seed]['indices']\n",
    "    _, test_dataloaders = dataset.create_dict_dataloaders(task_switching_tasks,\n",
    "                                                          indices,\n",
    "                                                          batch_size=batch_size)\n",
    "    tasks_testloader = dataset.SequentialTaskDataloader(test_dataloaders)\n",
    "\n",
    "    # Digit of every test sample, read from the first task's dataset.\n",
    "    numbers = test_dataloaders[tasks_names[0]].dataset.numbers.numpy()\n",
    "    numbers = numbers[indices['test']]\n",
    "    # Boolean sample mask per digit, built once instead of once per unit.\n",
    "    digit_masks = [numbers == i_digit for i_digit in range(num_digits)]\n",
    "\n",
    "    # `activations` is indexed by task name and then by layer name below.\n",
    "    _, activations = hooks.get_layer_activations(model,\n",
    "                                                tasks_testloader,\n",
    "                                                criterion,\n",
    "                                                device=device,\n",
    "                                                disable=True)\n",
    "\n",
    "    # Rows are units; columns are (task, digit) pairs laid out task-major.\n",
    "    avg_activations_units = np.zeros((num_units, num_tasks * num_digits))\n",
    "    activations_cluster_all = np.zeros((num_tasks, num_clusters, num_digits))\n",
    "\n",
    "    for i_task, task_name in enumerate(tasks_names):\n",
    "        current_activations = activations[task_name][layer]\n",
    "        for i_unit in range(num_units):\n",
    "            activation_unit = current_activations[:, i_unit]\n",
    "            for i_digit, digit_mask in enumerate(digit_masks):\n",
    "                avg_activations_units[i_unit, i_task * num_digits + i_digit] = np.mean(activation_unit[digit_mask])\n",
    "\n",
    "    # Divide by the global maximum so values share the 0-1 radial range.\n",
    "    avg_activations_units /= avg_activations_units.max()\n",
    "\n",
    "    # Cluster units by their response profile across all tasks and digits.\n",
    "    Z_activations = linkage(avg_activations_units, 'centroid')\n",
    "    cluster_labels = fcluster(Z_activations, t=num_clusters, criterion='maxclust')\n",
    "\n",
    "    # Cluster labels are looked up as 1..num_clusters and stored at index\n",
    "    # label - 1; the task-major row is unfolded into (task, digit).\n",
    "    for i_cluster in range(1, num_clusters + 1):\n",
    "        activations_cluster = avg_activations_units[cluster_labels == i_cluster, :].mean(axis=0)\n",
    "        activations_cluster_all[:, i_cluster - 1, :] = activations_cluster.reshape(num_tasks, num_digits)\n",
    "\n",
    "    for i_task, task_name in enumerate(tasks_names):\n",
    "        axis = ax[i_seed, i_task]\n",
    "        for i_cluster in range(num_clusters):\n",
    "            data = activations_cluster_all[i_task, i_cluster]\n",
    "            axis.bar(theta, data, width=0.5, alpha=0.5)\n",
    "        axis.set_xticks(theta)\n",
    "        axis.set_yticklabels([])\n",
    "        axis.set_xticklabels(range(0, len(theta)), fontsize=12)\n",
    "        axis.yaxis.grid(True, alpha=.5)\n",
    "        axis.set_ylim(0, 1)\n",
    "        # Digit 0 at the top, digits increasing clockwise.\n",
    "        axis.set_theta_zero_location('N')\n",
    "        axis.set_theta_direction(-1)\n",
    "        axis.set_title(task_name.capitalize(), fontsize=16)\n",
    "\n",
    "fig.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Single-seed version of the analysis, saved as a figure. The values of\n",
    "# seed, layer, num_clusters and cluster_labels set here are reused by the\n",
    "# later context-removed cell.\n",
    "seed = seeds[3]  # fourth seed in sorted order\n",
    "num_clusters = 3\n",
    "layer = 'layer10'\n",
    "\n",
    "# squeeze=False keeps `ax` two-dimensional even with a single task.\n",
    "fig, ax = plt.subplots(1, num_tasks, figsize=(4 * num_tasks, 4),\n",
    "                       subplot_kw=dict(projection='polar'),\n",
    "                       squeeze=False)\n",
    "\n",
    "# Rebuild the network and restore the weights stored for this seed.\n",
    "state_dict = results_task_switching[seed]['model']\n",
    "model = get_task_model(task_switching_tasks,\n",
    "                       num_hidden,\n",
    "                       idxs_contexts,\n",
    "                       device)\n",
    "model.load_state_dict(state_dict)\n",
    "\n",
    "indices = results_task_switching[seed]['indices']\n",
    "_, test_dataloaders = dataset.create_dict_dataloaders(task_switching_tasks,\n",
    "                                                      indices,\n",
    "                                                      batch_size=batch_size)\n",
    "tasks_testloader = dataset.SequentialTaskDataloader(test_dataloaders)\n",
    "\n",
    "# Digit of every test sample, read from the first task's dataset.\n",
    "numbers = test_dataloaders[tasks_names[0]].dataset.numbers.numpy()\n",
    "numbers = numbers[indices['test']]\n",
    "# Boolean sample mask per digit, built once instead of once per unit.\n",
    "digit_masks = [numbers == i_digit for i_digit in range(num_digits)]\n",
    "\n",
    "# `activations` is indexed by task name and then by layer name below.\n",
    "_, activations = hooks.get_layer_activations(model,\n",
    "                                            tasks_testloader,\n",
    "                                            criterion,\n",
    "                                            device=device,\n",
    "                                            disable=True)\n",
    "\n",
    "# Rows are units; columns are (task, digit) pairs laid out task-major.\n",
    "avg_activations_units = np.zeros((num_hidden[0], num_tasks * num_digits))\n",
    "activations_cluster_all = np.zeros((num_tasks, num_clusters, num_digits))\n",
    "\n",
    "for i_task, task_name in enumerate(tasks_names):\n",
    "    current_activations = activations[task_name][layer]\n",
    "    for i_unit in range(num_hidden[0]):\n",
    "        activation_unit = current_activations[:, i_unit]\n",
    "        for i_digit, digit_mask in enumerate(digit_masks):\n",
    "            avg_activations_units[i_unit, i_task * num_digits + i_digit] = np.mean(activation_unit[digit_mask])\n",
    "\n",
    "# Divide by the global maximum so values share the 0-1 radial range.\n",
    "avg_activations_units /= avg_activations_units.max()\n",
    "\n",
    "# Cluster units by their response profile across all tasks and digits.\n",
    "Z_activations = linkage(avg_activations_units, 'centroid')\n",
    "cluster_labels = fcluster(Z_activations, t=num_clusters, criterion='maxclust')\n",
    "\n",
    "# Cluster labels are looked up as 1..num_clusters and stored at index\n",
    "# label - 1; the task-major row is unfolded into (task, digit).\n",
    "for i_cluster in range(1, num_clusters + 1):\n",
    "    activations_cluster = avg_activations_units[cluster_labels == i_cluster, :].mean(axis=0)\n",
    "    activations_cluster_all[:, i_cluster - 1, :] = activations_cluster.reshape(num_tasks, num_digits)\n",
    "\n",
    "for i_task, task_name in enumerate(tasks_names):\n",
    "    axis = ax[0, i_task]\n",
    "    for i_cluster in range(num_clusters):\n",
    "        data = activations_cluster_all[i_task, i_cluster]\n",
    "        axis.bar(theta, data, width=0.5, alpha=0.5)\n",
    "    axis.set_xticks(theta)\n",
    "    axis.set_yticklabels([])\n",
    "    axis.set_xticklabels(range(0, len(theta)), fontsize=12)\n",
    "    axis.yaxis.grid(True, alpha=.5)\n",
    "    axis.set_ylim(0, 1)\n",
    "    # Digit 0 at the top, digits increasing clockwise.\n",
    "    axis.set_theta_zero_location('N')\n",
    "    axis.set_theta_direction(-1)\n",
    "    axis.set_title(task_name.capitalize(), fontsize=16)\n",
    "\n",
    "fig.tight_layout()\n",
    "# savefig does not create missing folders, so make sure the target exists.\n",
    "figure_path = f'figures/figure04/fig04_mixed_selectivity_tasks_{seed}_{num_clusters}.svg'\n",
    "os.makedirs(os.path.dirname(figure_path), exist_ok=True)\n",
    "fig.savefig(figure_path)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Heatmap of the normalised mean activations: one row per unit, one\n",
    "# column per (task, digit) pair in task-major order.\n",
    "fig_heatmap, ax_heatmap = plt.subplots()\n",
    "ax_heatmap.imshow(avg_activations_units)\n",
    "ax_heatmap.set_title('Normalised mean activations')\n",
    "ax_heatmap.set_xlabel('Task x digit (task-major)')\n",
    "ax_heatmap.set_ylabel('Unit')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display the cluster-averaged activations, indexed as (task, cluster, digit).\n",
    "activations_cluster_all"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One panel per task; every line is one cluster's mean activation per digit.\n",
    "fig, ax = plt.subplots(1, num_tasks, figsize=(20, 3))\n",
    "\n",
    "for i_task, task_axis in enumerate(ax):\n",
    "    for cluster_profile in activations_cluster_all[i_task, :num_clusters, :]:\n",
    "        task_axis.plot(cluster_profile)\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rebuild the task dictionary with an all-zero 'activations' vector for\n",
    "# every task.\n",
    "# NOTE: this rebinds task_switching_tasks, so re-running the earlier\n",
    "# analysis cells after this one would use the all-zero vectors.\n",
    "num_tasks = len(tasks_names)\n",
    "task_switching_tasks = {\n",
    "    task_name: {\n",
    "        'data': tasks_datasets[task_name],\n",
    "        'activations': [0] * num_tasks,\n",
    "    }\n",
    "    for task_name in tasks_names\n",
    "}\n",
    "\n",
    "for key, value in task_switching_tasks.items():\n",
    "    print(f'{key}: {value[\"activations\"]}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Repeat the single-seed analysis with the all-zero 'activations' vectors.\n",
    "# seed, layer, num_clusters and cluster_labels are deliberately reused from\n",
    "# the earlier single-seed cell, so the same unit clusters are shown here;\n",
    "# that cell must have been run first.\n",
    "fig, ax = plt.subplots(1, 1, figsize=(4, 4),\n",
    "                       subplot_kw=dict(projection='polar'))\n",
    "\n",
    "# Rebuild the network and restore the weights stored for this seed.\n",
    "state_dict = results_task_switching[seed]['model']\n",
    "model = get_task_model(task_switching_tasks,\n",
    "                       num_hidden,\n",
    "                       idxs_contexts,\n",
    "                       device)\n",
    "model.load_state_dict(state_dict)\n",
    "\n",
    "indices = results_task_switching[seed]['indices']\n",
    "_, test_dataloaders = dataset.create_dict_dataloaders(task_switching_tasks,\n",
    "                                                      indices,\n",
    "                                                      batch_size=batch_size)\n",
    "tasks_testloader = dataset.SequentialTaskDataloader(test_dataloaders)\n",
    "\n",
    "# Digit of every test sample, read from the first task's dataset.\n",
    "numbers = test_dataloaders[tasks_names[0]].dataset.numbers.numpy()\n",
    "numbers = numbers[indices['test']]\n",
    "# Boolean sample mask per digit, built once instead of once per unit.\n",
    "digit_masks = [numbers == i_digit for i_digit in range(num_digits)]\n",
    "\n",
    "_, activations = hooks.get_layer_activations(model,\n",
    "                                            tasks_testloader,\n",
    "                                            criterion,\n",
    "                                            device=device,\n",
    "                                            disable=True)\n",
    "\n",
    "# Rows are units, columns are digits.\n",
    "avg_activations_units = np.zeros((num_hidden[0], num_digits))\n",
    "activations_cluster_all = np.zeros((num_clusters, num_digits))\n",
    "\n",
    "# Only the first task is read: with identical all-zero vectors the\n",
    "# activations are the same for every task.\n",
    "current_activations = activations[tasks_names[0]][layer]\n",
    "for i_unit in range(num_hidden[0]):\n",
    "    activation_unit = current_activations[:, i_unit]\n",
    "    for i_digit, digit_mask in enumerate(digit_masks):\n",
    "        avg_activations_units[i_unit, i_digit] = activation_unit[digit_mask].mean()\n",
    "\n",
    "# Divide by this condition's own maximum so values share the 0-1 range.\n",
    "max_avg = avg_activations_units.max()\n",
    "avg_activations_units /= max_avg\n",
    "\n",
    "# Cluster labels are looked up as 1..num_clusters and stored at index\n",
    "# label - 1; one row per cluster.\n",
    "for i_cluster in range(1, num_clusters + 1):\n",
    "    activations_cluster_all[i_cluster - 1, :] = avg_activations_units[cluster_labels == i_cluster, :].mean(axis=0)\n",
    "\n",
    "for i_cluster in range(num_clusters):\n",
    "    data = activations_cluster_all[i_cluster]\n",
    "    # Labelling each bar group lets the legend follow num_clusters.\n",
    "    ax.bar(theta, data, width=0.5, alpha=0.5, label=f'Cluster{i_cluster + 1}')\n",
    "ax.set_xticks(theta)\n",
    "ax.set_yticklabels([])\n",
    "ax.set_xticklabels(range(0, len(theta)), fontsize=12)\n",
    "ax.yaxis.grid(True, alpha=.5)\n",
    "ax.set_ylim(0, 1)\n",
    "# Digit 0 at the top, digits increasing clockwise.\n",
    "ax.set_theta_zero_location('N')\n",
    "ax.set_theta_direction(-1)\n",
    "ax.set_title('Removed', fontsize=16)\n",
    "\n",
    "ax.legend()\n",
    "# savefig does not create missing folders, so make sure the target exists.\n",
    "figure_path = f'figures/figure04/fig04_mixed_selectivity_removed_{seed}_{num_clusters}.svg'\n",
    "os.makedirs(os.path.dirname(figure_path), exist_ok=True)\n",
    "fig.savefig(figure_path)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Heatmap of the normalised mean activations with the all-zero vectors:\n",
    "# one row per unit, one column per digit.\n",
    "fig_heatmap, ax_heatmap = plt.subplots()\n",
    "ax_heatmap.imshow(avg_activations_units)\n",
    "ax_heatmap.set_title('Normalised mean activations (removed)')\n",
    "ax_heatmap.set_xlabel('Digit')\n",
    "ax_heatmap.set_ylabel('Unit')\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.8.10",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "5de0b3d16828453b801d3a971a2e845298ac67ea708b1fd16f0d1197d2abd69f"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
