{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import pickle\n",
    "\n",
    "import matplotlib as mpl\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import seaborn as sns\n",
    "import torch.nn as nn\n",
    "import torch.optim\n",
    "from tqdm.auto import tqdm\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.linear_model import LogisticRegression\n",
    "\n",
    "import multitask.dataset as dataset\n",
    "from multitask.models.task_switching import get_task_model, calculate_rdm, plot_rdm\n",
    "import multitask.models.task_switching.utils as utils\n",
    "import multitask.models.task_switching.hooks as hooks\n",
    "\n",
    "from train.utils.training import get_device\n",
    "from train.utils.argparse import check_runs\n",
    "\n",
    "sns.set_theme(style='ticks', palette='pastel')\n",
    "mpl.rcParams['font.family'] = 'Liberation Sans'\n",
    "mpl.rcParams['axes.spines.right'] = False\n",
    "mpl.rcParams['axes.spines.top'] = False\n",
    "\n",
    "model_path = os.path.join('out', 'task_switching')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "num_runs = 10\n",
    "initial_seed = 6789\n",
    "max_seed = 10e5\n",
    "num_epochs = 50\n",
    "num_hidden = 10 * [100]\n",
    "batch_size = 100\n",
    "num_train = 50000\n",
    "num_test = 10000\n",
    "tasks_names = ['parity', 'value']\n",
    "idxs_contexts = list(range(len(num_hidden)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "parameters = {\n",
    "    'num_runs': num_runs,\n",
    "    'initial_seed': initial_seed,\n",
    "    'max_seed': max_seed,\n",
    "    'num_epochs': num_epochs,\n",
    "    'num_hidden': num_hidden,\n",
    "    'batch_size': batch_size,\n",
    "    'num_train': num_train,\n",
    "    'num_test': num_test,\n",
    "    'tasks': tasks_names,\n",
    "    'idxs_contexts': idxs_contexts\n",
    "}\n",
    "\n",
    "data_folder = check_runs(model_path, parameters)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pickle_data = os.path.join(data_folder, 'data.pickle')\n",
    "with open(pickle_data, 'rb') as handle:\n",
    "    results_task_switching = pickle.load(handle)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "seeds = sorted(list(results_task_switching.keys()))\n",
    "num_seeds = len(seeds)\n",
    "num_tasks = len(tasks_names)\n",
    "\n",
    "print(seeds)\n",
    "print(tasks_names)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tasks_datasets = dataset.get_tasks_dict(tasks_names, root='data')\n",
    "\n",
    "task_switching_tasks = {}\n",
    "num_tasks = len(tasks_names)\n",
    "\n",
    "for i_context, task_name in enumerate(tasks_names):\n",
    "    task_switching_tasks[task_name] = {}\n",
    "    task_switching_tasks[task_name]['data'] = tasks_datasets[task_name]\n",
    "    task_switching_tasks[task_name]['activations'] = num_tasks * [0]\n",
    "    task_switching_tasks[task_name]['activations'][i_context] = 1  # Set to 0 for Removed\n",
    "\n",
    "for key, value in task_switching_tasks.items():\n",
    "    print(f'{key}: {value[\"activations\"]}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "device = get_device()\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "\n",
    "seeds_task_swithing  = sorted(list(results_task_switching.keys()))\n",
    "list_activations = []\n",
    "list_numbers = []\n",
    "\n",
    "for i_seed, seed in tqdm(enumerate(seeds_task_swithing), total=num_runs):\n",
    "    state_dict = results_task_switching[seed]['model']\n",
    "    model = get_task_model(task_switching_tasks,\n",
    "                           num_hidden,\n",
    "                           idxs_contexts,\n",
    "                           device)\n",
    "    model.load_state_dict(state_dict)\n",
    "    \n",
    "    indices = results_task_switching[seed]['indices']\n",
    "\n",
    "    test_sampler = dataset.SequentialSampler(indices['test'])\n",
    "    _, test_dataloaders = dataset.create_dict_dataloaders(task_switching_tasks,\n",
    "                                                          indices,\n",
    "                                                          batch_size=batch_size)\n",
    "    tasks_testloader = dataset.SequentialTaskDataloader(test_dataloaders)\n",
    "\n",
    "    numbers = test_dataloaders[tasks_names[0]].dataset.numbers.numpy()\n",
    "    numbers = numbers[indices['test']]\n",
    "\n",
    "    _, activations = hooks.get_layer_activations(model,\n",
    "                                                tasks_testloader,\n",
    "                                                criterion,\n",
    "                                                device=device,\n",
    "                                                disable=True)\n",
    "    \n",
    "    list_activations.append(activations)\n",
    "    list_numbers.append(numbers)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "num_layers = len(num_hidden)\n",
    "max_iter = 8000\n",
    "\n",
    "acc_numbers_all = np.zeros((num_seeds, num_layers))\n",
    "acc_tasks_all = np.zeros((num_seeds, num_layers))\n",
    "acc_congruency_all = np.zeros((num_seeds, num_layers))\n",
    "\n",
    "for i_seed, seed in enumerate(seeds):\n",
    "    activations = list_activations[i_seed]\n",
    "    numbers = list_numbers[i_seed]\n",
    "\n",
    "    labels_numbers = np.hstack((numbers, numbers))\n",
    "    labels_task = np.concatenate((np.zeros_like(numbers), np.ones_like(numbers)))\n",
    "    labels_congruency = np.array([1 if number in [0, 2, 4, 5, 7, 9] else 0 for number in labels_numbers])\n",
    "\n",
    "    for j_layer in tqdm(range(num_layers), desc=f'{i_seed}'):\n",
    "        activations_decoder = None\n",
    "        for task in tasks_names:\n",
    "            activations_task = activations[task][f'layer{j_layer+1}']\n",
    "            if activations_decoder is None:\n",
    "                activations_decoder = activations_task\n",
    "            else:\n",
    "                activations_decoder = np.vstack((activations_decoder, \n",
    "                                                activations_task))\n",
    "        assert activations_decoder.shape[0] == labels_numbers.shape[0]\n",
    "\n",
    "        activations_decoder = (activations_decoder - activations_decoder.mean()) / activations_decoder.std()\n",
    "\n",
    "        # Numbers task\n",
    "        seed = np.random.randint(0, 1e8, 1)[0]\n",
    "        X_train, X_test, y_train, y_test = train_test_split(activations_decoder,\n",
    "                                                            labels_numbers,\n",
    "                                                            test_size=0.1,\n",
    "                                                            random_state=seed)\n",
    "        clf = LogisticRegression(random_state=seed,\n",
    "                                max_iter=max_iter,\n",
    "                                tol=1e-3).fit(X_train, y_train)\n",
    "        acc_numbers_all[i_seed, j_layer] = clf.score(X_test, y_test)\n",
    "\n",
    "        # Labels task\n",
    "        seed = np.random.randint(0, 1e8, 1)[0]\n",
    "        X_train, X_test, y_train, y_test = train_test_split(activations_decoder,\n",
    "                                                            labels_task,\n",
    "                                                            test_size=0.1,\n",
    "                                                            random_state=seed)\n",
    "        clf = LogisticRegression(random_state=seed,\n",
    "                                max_iter=max_iter,\n",
    "                                tol=1e-3).fit(X_train, y_train)\n",
    "        acc_tasks_all[i_seed, j_layer] = clf.score(X_test, y_test)\n",
    "\n",
    "        # Congruency task\n",
    "        seed = np.random.randint(0, 1e8, 1)[0]\n",
    "        X_train, X_test, y_train, y_test = train_test_split(activations_decoder,\n",
    "                                                            labels_congruency,\n",
    "                                                            test_size=0.1,\n",
    "                                                            random_state=seed)\n",
    "        clf = LogisticRegression(random_state=seed,\n",
    "                                max_iter=max_iter,\n",
    "                                tol=1e-3).fit(X_train, y_train)\n",
    "        acc_congruency_all[i_seed, j_layer] = clf.score(X_test, y_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mean_numbers_all = acc_numbers_all.mean(axis=0)\n",
    "std_numbers_all = acc_numbers_all.std(axis=0)\n",
    "\n",
    "mean_tasks_all = acc_tasks_all.mean(axis=0)\n",
    "std_tasks_all = acc_tasks_all.std(axis=0)\n",
    "\n",
    "mean_congruecy_all = acc_congruency_all.mean(axis=0)\n",
    "std_congruency_all = acc_congruency_all.std(axis=0)\n",
    "\n",
    "layers = range(1, num_layers + 1)\n",
    "\n",
    "fig = plt.figure(figsize=(6, 4))\n",
    "plt.plot(layers,\n",
    "         mean_numbers_all)\n",
    "plt.fill_between(layers,\n",
    "                 mean_numbers_all-std_numbers_all,\n",
    "                 mean_numbers_all+std_numbers_all, alpha=0.5)\n",
    "plt.plot(layers,\n",
    "         mean_tasks_all)\n",
    "plt.fill_between(layers,\n",
    "                 mean_tasks_all-std_tasks_all,\n",
    "                 mean_tasks_all+std_tasks_all, alpha=0.5)\n",
    "plt.plot(layers,\n",
    "         mean_congruecy_all)\n",
    "plt.fill_between(layers,\n",
    "                 mean_congruecy_all-std_congruency_all,\n",
    "                 mean_congruecy_all+std_congruency_all, alpha=0.5)\n",
    "\n",
    "plt.ylim(0, 1)\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "results = {}\n",
    "results['numbers'] = acc_numbers_all\n",
    "results['tasks'] = acc_tasks_all\n",
    "results['congruency'] = acc_congruency_all"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Saved as the 'all' condition; the comparison cells reload it next to 'first' and 'removed'.\n",
    "# Create the output directory first so a fresh checkout does not fail here.\n",
    "os.makedirs('pickle', exist_ok=True)\n",
    "with open('pickle/results_linear_decoder_all.pickle', 'wb') as f:\n",
    "    pickle.dump(results, f, protocol=pickle.HIGHEST_PROTOCOL)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _load_decoder_results(condition):\n",
    "    \"\"\"Load the pickled decoder accuracies saved for one context condition.\"\"\"\n",
    "    with open(f'pickle/results_linear_decoder_{condition}.pickle', 'rb') as handle:\n",
    "        return pickle.load(handle)\n",
    "\n",
    "\n",
    "results_first = _load_decoder_results('first')\n",
    "results_all = _load_decoder_results('all')\n",
    "results_removed = _load_decoder_results('removed')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mean/std over axis 0 (seeds, assuming the same layout as the saved 'all' results).\n",
    "mean_results_first_numbers, std_results_first_numbers = (np.mean(results_first['numbers'], axis=0),\n",
    "                                                          np.std(results_first['numbers'], axis=0))\n",
    "\n",
    "mean_results_first_tasks, std_results_first_tasks = (np.mean(results_first['tasks'], axis=0),\n",
    "                                                      np.std(results_first['tasks'], axis=0))\n",
    "\n",
    "mean_results_first_congruency, std_results_first_congruency = (np.mean(results_first['congruency'], axis=0),\n",
    "                                                              np.std(results_first['congruency'], axis=0))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mean/std over axis 0 (seeds) of the 'all' condition results.\n",
    "mean_results_all_numbers, std_results_all_numbers = (np.mean(results_all['numbers'], axis=0),\n",
    "                                                      np.std(results_all['numbers'], axis=0))\n",
    "\n",
    "mean_results_all_tasks, std_results_all_tasks = (np.mean(results_all['tasks'], axis=0),\n",
    "                                                  np.std(results_all['tasks'], axis=0))\n",
    "\n",
    "mean_results_all_congruency, std_results_all_congruency = (np.mean(results_all['congruency'], axis=0),\n",
    "                                                          np.std(results_all['congruency'], axis=0))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mean/std over axis 0 (seeds, assuming the same layout as the saved 'all' results).\n",
    "mean_results_removed_numbers, std_results_removed_numbers = (np.mean(results_removed['numbers'], axis=0),\n",
    "                                                              np.std(results_removed['numbers'], axis=0))\n",
    "\n",
    "mean_results_removed_tasks, std_results_removed_tasks = (np.mean(results_removed['tasks'], axis=0),\n",
    "                                                          np.std(results_removed['tasks'], axis=0))\n",
    "\n",
    "mean_results_removed_congruency, std_results_removed_congruency = (np.mean(results_removed['congruency'], axis=0),\n",
    "                                                                  np.std(results_removed['congruency'], axis=0))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots(1, 3, figsize=(14, 3))\n",
    "\n",
    "ax[0].plot(layers,\n",
    "           mean_results_removed_numbers)\n",
    "ax[0].fill_between(layers,\n",
    "                 mean_results_removed_numbers-std_results_removed_numbers,\n",
    "                 mean_results_removed_numbers+std_results_removed_numbers,\n",
    "                 alpha=0.5)\n",
    "\n",
    "ax[0].plot(layers,\n",
    "         mean_results_first_numbers)\n",
    "ax[0].fill_between(layers,\n",
    "                 mean_results_first_numbers-std_results_first_numbers,\n",
    "                 mean_results_first_numbers+std_results_first_numbers,\n",
    "                 alpha=0.5)\n",
    "\n",
    "ax[0].plot(layers,\n",
    "         mean_results_all_numbers)\n",
    "ax[0].fill_between(layers,\n",
    "                 mean_results_all_numbers-std_results_all_numbers,\n",
    "                 mean_results_all_numbers+std_results_all_numbers,\n",
    "                 alpha=0.5)\n",
    "ax[0].plot(layers, num_layers * [0.1], 'k--')\n",
    "\n",
    "\n",
    "ax[0].set_ylim(-0.05, 1.05)\n",
    "ax[0].set_title('Number', fontsize=18)\n",
    "ax[0].set_ylabel('Accuracy', fontsize=14)\n",
    "ax[0].set_xlabel('Layer', fontsize=14)\n",
    "ax[0].legend(['Context Removed', 'Context First', 'Context All', 'Random'], prop={'size':12})\n",
    "ax[0].set_xticklabels([0, 2, 4, 6, 8, 10], fontsize=12)\n",
    "ax[0].set_yticklabels([0, 0.0, 0.2, 0.4, 0.6, 0.8, 1.0])\n",
    "\n",
    "\n",
    "\n",
    "ax[1].plot(layers,\n",
    "           mean_results_removed_tasks)\n",
    "ax[1].fill_between(layers,\n",
    "                   mean_results_removed_tasks-std_results_removed_numbers,\n",
    "                   mean_results_removed_tasks+std_results_removed_numbers,\n",
    "                   alpha=0.5)\n",
    "\n",
    "ax[1].plot(layers,\n",
    "           mean_results_first_tasks)\n",
    "ax[1].fill_between(layers,\n",
    "                 mean_results_first_tasks-std_results_first_tasks,\n",
    "                 mean_results_first_tasks+std_results_first_tasks,\n",
    "                 alpha=0.5)\n",
    "ax[1].plot(layers,\n",
    "         mean_results_all_tasks)\n",
    "ax[1].fill_between(layers,\n",
    "                 mean_results_all_tasks-std_results_all_tasks,\n",
    "                 mean_results_all_tasks+std_results_all_tasks,\n",
    "                 alpha=0.5)\n",
    "\n",
    "ax[1].plot(layers, num_layers * [0.5], 'k--')\n",
    "\n",
    "ax[1].set_ylim(-0.05, 1.05)\n",
    "ax[1].set_title('Task', fontsize=18)\n",
    "ax[1].set_ylabel('Accuracy', fontsize=14)\n",
    "ax[1].set_xlabel('Layer', fontsize=14)\n",
    "ax[1].set_xticklabels([0, 2, 4, 6, 8, 10], fontsize=12)\n",
    "ax[1].set_yticklabels([0, 0.0, 0.2, 0.4, 0.6, 0.8, 1.0], fontsize=12)\n",
    "\n",
    "\n",
    "ax[2].plot(layers,\n",
    "           mean_results_removed_congruency)\n",
    "ax[2].fill_between(layers,\n",
    "                 mean_results_removed_congruency-std_results_removed_congruency,\n",
    "                 mean_results_removed_congruency+std_results_removed_congruency,\n",
    "                 alpha=0.5)\n",
    "\n",
    "ax[2].plot(layers,\n",
    "           mean_results_first_congruency)\n",
    "ax[2].fill_between(layers,\n",
    "                 mean_results_first_congruency-std_results_first_congruency,\n",
    "                 mean_results_first_congruency+std_results_first_congruency,\n",
    "                 alpha=0.5)\n",
    "\n",
    "ax[2].plot(layers,\n",
    "           mean_results_all_congruency)\n",
    "ax[2].fill_between(layers,\n",
    "                 mean_results_all_congruency-std_results_all_congruency,\n",
    "                 mean_results_all_congruency+std_results_all_congruency,\n",
    "                 alpha=0.5)\n",
    "\n",
    "ax[2].plot(layers, num_layers * [0.5], 'k--')\n",
    "\n",
    "\n",
    "ax[2].set_ylim(-0.05, 1.05)\n",
    "ax[2].set_ylabel('Accuracy', fontsize=14)\n",
    "ax[2].set_title('Congruency', fontsize=18)\n",
    "ax[2].set_xlabel('Layer', fontsize=14)\n",
    "ax[2].set_xticklabels([0, 2, 4, 6, 8, 10], fontsize=12)\n",
    "ax[2].set_yticklabels([0, 0.0, 0.2, 0.4, 0.6, 0.8, 1.0], fontsize=12)\n",
    "\n",
    "\n",
    "fig.savefig('figures/figure03/fig03d_decoder_layers.svg')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.8.10",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "5de0b3d16828453b801d3a971a2e845298ac67ea708b1fd16f0d1197d2abd69f"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
