{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import pickle\n",
    "\n",
    "import matplotlib as mpl\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import seaborn as sns\n",
    "import torch.nn as nn\n",
    "import torch.optim\n",
    "from tqdm.auto import tqdm\n",
    "\n",
    "import multitask.dataset as dataset\n",
    "from multitask.models.individual import get_individual_model, calculate_rdm, plot_rdm\n",
    "import multitask.models.individual.utils as utils\n",
    "import multitask.models.individual.hooks as hooks\n",
    "\n",
    "from train.utils.training import get_device\n",
    "from train.utils.argparse import check_runs\n",
    "\n",
    "sns.set_theme(style='ticks', palette='pastel')\n",
    "mpl.rcParams['font.family'] = 'Liberation Sans'\n",
    "\n",
    "model_path = os.path.join('out', 'individual')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "num_runs = 10\n",
    "initial_seed = 6789\n",
    "max_seed = 10e5\n",
    "num_epochs = 50\n",
    "num_hidden = 5 * [100]\n",
    "batch_size = 100\n",
    "num_train = 50000\n",
    "num_test = 10000\n",
    "tasks_names = ['parity', 'value']\n",
    "# tasks_names = ['parity', 'value', 'prime', 'fibonacci', 'multiples_3']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "parameters = {\n",
    "    'num_runs': num_runs,\n",
    "    'initial_seed': initial_seed,\n",
    "    'max_seed': max_seed,\n",
    "    'num_epochs': num_epochs,\n",
    "    'num_hidden': num_hidden,\n",
    "    'batch_size': batch_size,\n",
    "    'num_train': num_train,\n",
    "    'num_test': num_test,\n",
    "    'tasks': tasks_names,\n",
    "    'idxs_contexts': None\n",
    "}\n",
    "\n",
    "data_folder = check_runs(model_path, parameters)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pickle_data = os.path.join(data_folder, 'data.pickle')\n",
    "with open(pickle_data, 'rb') as handle:\n",
    "    results_individual = pickle.load(handle)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "seeds = sorted(list(results_individual.keys()))\n",
    "num_seeds = len(seeds)\n",
    "num_tasks = len(tasks_names)\n",
    "\n",
    "print(seeds)\n",
    "print(tasks_names)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tasks_datasets = dataset.get_tasks_dict(tasks_names, root='data')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Plot All RDMs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tasks_datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "device = get_device()\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "num_layers = len(num_hidden)\n",
    "\n",
    "fig, ax = plt.subplots(num_seeds, num_layers, figsize=(2 * num_layers, 2 * num_runs))\n",
    "list_rdm = []\n",
    "\n",
    "for i_seed, seed in tqdm(enumerate(seeds), total=num_seeds):\n",
    "    task_activations = []\n",
    "    task_numbers = []\n",
    "    indices = results_individual[seed]['indices']\n",
    "\n",
    "    for i_task, task_name in enumerate(tasks_names):\n",
    "        state_dict = results_individual[seed][task_name]['model']\n",
    "\n",
    "        model = get_individual_model(num_hidden,\n",
    "                                     device)\n",
    "        model.load_state_dict(state_dict)\n",
    "        model.to(device)\n",
    "\n",
    "        task_dataset = tasks_datasets[task_name]\n",
    "        test_sampler = dataset.SequentialSampler(indices['test'])\n",
    "        testloader = torch.utils.data.DataLoader(task_dataset,\n",
    "                                                 sampler=test_sampler,\n",
    "                                                 batch_size=batch_size)\n",
    "\n",
    "        numbers = testloader.dataset.numbers.numpy()\n",
    "        numbers = numbers[indices['test']]\n",
    "\n",
    "        _, activations = hooks.get_layer_activations(model,\n",
    "                                                     testloader,\n",
    "                                                     criterion,\n",
    "                                                     device=device,\n",
    "                                                     disable=True)\n",
    "        \n",
    "        task_activations.append(activations)\n",
    "        task_numbers.append(numbers)\n",
    "\n",
    "    rdm_dict = calculate_rdm(task_activations,\n",
    "                             tasks_names,\n",
    "                             num_hidden,\n",
    "                             task_numbers)\n",
    "\n",
    "    plot_rdm(ax[i_seed], rdm_dict, num_hidden, cmap='coolwarm', vmin=0, vmax=1)\n",
    "\n",
    "    list_rdm.append(rdm_dict)\n",
    "\n",
    "fig.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Average RDM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mean_rdm = {}\n",
    "\n",
    "for layer in range(num_layers):\n",
    "    mean_rdm[layer+1] = np.zeros((num_tasks * 10, num_tasks * 10))\n",
    "    for rdm in list_rdm:\n",
    "        mean_rdm[layer+1] += rdm[layer+1]\n",
    "    mean_rdm[layer+1] /= num_runs\n",
    "\n",
    "fig, ax = plt.subplots(1, num_layers, figsize=(2 * num_layers, 2))\n",
    "plot_rdm(ax, mean_rdm, num_hidden, cmap='coolwarm', vmin=0, vmax=1)\n",
    "fig.tight_layout()\n",
    "fig.savefig(f'figures/figure02/fig02_rdm_individual_mean_{num_tasks}.svg')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Plot Single RDM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "idx_seed = 0\n",
    "seed = seeds[idx_seed]\n",
    "\n",
    "state_dict = results_individual[seed][task_name]['model']\n",
    "indices = results_individual[seed]['indices']\n",
    "\n",
    "model = get_individual_model(num_hidden, device)\n",
    "model.load_state_dict(state_dict)\n",
    "\n",
    "task_dataset = tasks_datasets[task_name]\n",
    "test_sampler = dataset.SequentialSampler(indices['test'])\n",
    "testloader = torch.utils.data.DataLoader(task_dataset,\n",
    "                                            sampler=test_sampler,\n",
    "                                            batch_size=batch_size)\n",
    "\n",
    "numbers = testloader.dataset.numbers.numpy()\n",
    "numbers = numbers[indices['test']]\n",
    "\n",
    "_, activations = hooks.get_layer_activations(model,\n",
    "                                             testloader,\n",
    "                                             criterion,\n",
    "                                             device=device,\n",
    "                                             disable=True)\n",
    "\n",
    "task_activations.append(activations)\n",
    "task_numbers.append(numbers)\n",
    "\n",
    "rdm_dict = calculate_rdm(task_activations,\n",
    "                        tasks_names,\n",
    "                        num_hidden,\n",
    "                        task_numbers)\n",
    "\n",
    "fig, ax = plt.subplots(1, num_layers, figsize=(2 * num_layers, 2))\n",
    "plot_rdm(ax, rdm_dict, num_hidden, cmap='coolwarm', vmin=0, vmax=1)\n",
    "\n",
    "fig.tight_layout()\n",
    "# fig.savefig('figures/figure02/fig02_rdm_individual_single.svg')\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.8.10",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "5de0b3d16828453b801d3a971a2e845298ac67ea708b1fd16f0d1197d2abd69f"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
