{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import seaborn as sns\n",
    "import matplotlib.pyplot as plt\n",
    "import scipy.stats as stats\n",
    "plt.rcParams['font.sans-serif'] = \"Arial\"\n",
    "sns.set_style(\"ticks\")\n",
    "import lstnn.transformer_main\n",
    "from sklearn.metrics import pairwise_distances\n",
    "%matplotlib inline\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 74,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAKUAAACdCAYAAADVArgaAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8g+/7EAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAbUElEQVR4nO2de1xU17XHf/MCH2hQU+0nNfHRD8QYveZT62CxCmgwikXE4DgQBxWN+BpfDRFvJGrQqwWTKD4RvFqNpmJAaRpp1EBaq1FjFTGRPIpSwJgqIpeHwAwz6/7Bhykwe4ZzxpnhYPeXz/6DPft1hsU+e+211t4yIiJwOBJC3tED4HDawoWSIzm4UHIkBxdKjuTgQsmRHFwoOZKDCyVHcnCh5EgOLpQcyeEWoSwrK8OwYcMQFhaGadOmYerUqYiKisJ3331nt9727dvx2Wefie7vs88+w/bt2x0dLpMdO3Zgx44dAICwsDC7ZdesWYM7d+44tX9HuXHjBt566y3B5YkISUlJmDJlCqZMmYLk5GQArf+GYWFhmDx5MvR6PUpKSpw/aHIDpaWlFBQU1Crv6NGjpNFo3NG9U0hJSaGUlBRBZYOCgqi0tNTFI3INp06dopiYGDIajWQwGGjWrFl0+vRp5t/wyJEjFBwcTAaDwalj6LDXt1qttsyUly9fRmRkJKZPn47g4GCcOXMGABAfH4+srCyUlZXhlVdeQVRUFLRaLUaPHo2qqioAQHR0NLZt2wYA+Mtf/oIVK1YgKysL8fHxAICdO3ciNDQUISEhePvttwEAZrMZycnJCA8PR2hoKHbt2sUcY3p6OoKDgzFz5kwUFBRY8p9//nkAQH5+PiIiIhAaGgqtVovbt29jz549uHfvHhYsWIAHDx4gJycHM2fORFhYGEJCQnDlyhUAgE6nw9atW6HVajF+/HhkZWUBAKqrq7FixQpMnjwZU6dOxblz5wAAFy5cgEajQXh4OBYuXIh79+7ZfL6WXLp0CTqdzm6fLXnuueewatUqKJVKqFQq+Pj44IcffmB+P1FRUVAoFJYxOg2nirgN2v6XmUwm2rp1K8XExBARkV6vp3/84x9ERHTp0iX6zW9+Q0REq1evpszMTCotLSVfX18qLi4mIqKVK1fSmTNnqK6ujgICAkir1RIRUWJiIp08eZIyMzNp9erVVFVVRaNHj6bGxkZqbGyk+Ph4+vHHHykjI4M2bdpEREQGg4Hmz59PeXl5rcZcUFBAEydOpOrqanr06BFNnTrVMlP6+voSEdGSJUvo7NmzRESUmZlJJ06cIKJ/z5Qmk4l0Oh09ePCAiIiysrIoNjaWiIhmzZpF77zzDhERff3116RWq4mIaOPGjbRlyxYiIiouLqawsDCqqKigsLAwqqysJCKi7Oxs0uv1Np+vJRcvXqRZs2bZ7dMWRUVF5OfnR0VFRcyZkoho2bJltG/fPrvtiEXpXBG3zb179yxrMYPBAB8fH7zzzjsAgKSkJOTm5iInJwfXr19HbW2tVf3evXtjwIABAICAgAB88cUX6Nq1KyZMmIC8vDzU1dXhiy++wJIlS5CXlwcA6NGjB3x9fTFjxgwEBgYiJiYG/fr1w7lz53Dz5k1cunQJAFBXV4fvvvsOgYGBlv4uX76MwMBAeHl5AQAmTpwIs9ncakwTJkxAQkIC8vLyEBgYiKCgoFafy+Vy7Ny5E7m5uSguLsbly5ehUCgsnwcEBAAAXnjhBVRWVgIALl68aFnHDRgwACdPnkReXh7u3LmD6OhoAE0zvUKhsPl89mD1yeLmzZtYvHgx4uPjMXjwYJSVlTHLyWQydO3a1W6fYnGbUPbt2xfZ2dnMzyIjI+Hv7w+1Wg1/f3+88cYbVmVaPvi4ceOQlpaGbt26wc/PD5WVlfj444/h7e2NXr16tap38OBBXL16FefOncO8efOwdetWmEwmxMXF4ZVXXgEAPHz4EF26dGlVTyaTgVp49SmVShgM
hlZlwsPD4efnh88//xwHDx7E559/jo0bN1o+r62tRUREBMLDw6FWqzFkyBAcOXLE8rmnp6elr2YUCkWr34uKimAymTBy5Ejs3bsXQNM/dfPyhfV8arWa+T3b6rMtFy5cwG9/+1usX7/e8h3ZorCwEFqt1m4ZsXT4llBlZSVKSkqg1+sxbtw45OXlwWQy2a3Tq1cveHl54fTp0xg1ahRGjx6NHTt2tJrpgCaNcerUqRg6dChWrFiBMWPG4Ntvv8Xo0aORkZEBo9GIuro6zJkzB+fPn29V91e/+hVyc3NRVVWFhoYGyzq3Ja+//jpu3bqFqKgoLF++HDdv3gTQJFgmkwnFxcWQy+WIjY2Fn58fzpw50+6zqdVqfPLJJwCAkpISzJ07F8OGDcO1a9dQVFQEANi/fz/WrVtn8/keh6KiIqxcuRI7d+5sVyAPHToET09P+Pn5PVafbXHbTGkLb29vREREIDg4GN27d4efnx/q6+uZr/CWBAYGIicnB7169cLo0aOxdu1aq9dn//79ERISgvDwcHTr1g3PPPMMwsPD4enpiX/+85+YNm0aGhsbMWXKFLz88sut6r7wwguYO3cuIiIi8NRTT+GZZ56xGsPSpUuxYcMGvPvuu1AqlRblasKECViwYAH27NmDoUOHYvz48ejSpQt+/etf48qVK61m4Lbo9Xq8/fbbmDp1KhQKBZKTk/HTn/4UmzdvxqpVq2A2m9GnTx8kJSWhb9++zOd7HHbv3g2TyWRZWgGARqNBQEBAqyWY2WzGoEGDkJaWZnfWdQQZ2fuGOJwOoMNf3xxOW7hQciQHF0qO5OBCyZEcXCg5koMLJUdycKHkSI4O3zy3h0ymYub79GZvEA+XPc/M76myfszzhq+ZZUN7DGPm/+ppo1Vevy4NzLLbvvFk5kcOZGZj4vB/WuV1/y92G/OT+jPzY35ubWzwf/Uhu8PlrzGzFfJAqzy53HocZjP7uZ0Fnyk5dpHJVFZJDDt27GC6yNmDCyXHLjKZp1USSkVFBTIyMkT3KenXN6fjkcscF5H09HRMmzZNfJ8O9ygAvV5v8cNrbGzExIkTXdkdxwXI5R5WyWAwoKamplVq69Z39+5d1NbWYtCgQeL7dNbgW2IwGLBkyRJcv37dkvfRRx/h/v37ruiO40IUcg+rlJqaipEjR7ZKqampreqlpqbi9ddfd6hPl3gJ1dTUoLCwEJmZmVi6dCmefvpprFmzBuXl5Th8+LDgdnz7aJj5YrRsAKgyNlrlDfRiL9hZWjbA1rT/VS98fQWwtWyArWnXFrA13Ovf/5SZz9S0bWjZ2H6Ema1YmWaV17P7UKu88of5VjOjh4cHPDw8LL+/+uqr6Natm2UiSk1NtUQOtIdL1pReXl4YNWoUMjMzAQAffPABIiMjLSGqnM6DgrGmbCuALJr/9s2at1CBBNyg6NTU1ODmzZuYP3++3XIGg8Hqv4/T8SgY+5RimD59uug6Lt8SKigowO3bt6HT6VBYWGgzMJ61TuF0PAq5p1VyNS6fKf39/aHRNK0NdTodNm3axCwXGxuLuXPntsr7xYAYVw+P0w5KEfuSTuvTlY1v2bKl1e/2lBwh6xSO+1GItOA4A0lvnjtDywaAG2Qd4Rf5NLttW/Zslqb9RTn7D7Yu6B/MfFv2bJamffoGWzEIf+1HZj5T07ahZV/I7MXMH7vSOk+JJ2ym5HR+nrjXN6fzoyIulByJoYD715RusX1nZ2dDo9FAq9WisLDQlV1ynIwHeVglV+OSmdJgMGDlypW4ceMGAODo0aP48MMPcffuXSQnJws+0NQZCg0AfF9xwiqv3wuMVT1smw5ZSs3H1V8xyyb911PMfFumQ5ZS82ExsyjCRZgObSk0/1vUnZk/lpGnelJmSoPBgDlz5sDf3x8AsGfPHigUCjQ2NkKlcv9DchxHRSqr5GpcIpTNtu9mevfuDYPBgA0bNmDOnDnMOix3KE7Ho2D8uBq3KDpGoxHLli1DaGgohg1j
x8CkpqZi586drfL8EeWO4XHsoJK5Xgjb4hah3LJlC8aNG2fXOM8yM+p93nX10DjtoOqAiBmXC+WDBw9w/PhxjBgxAjk5OfD19UVCQoJVOW5mlCYqufuFUtJHAdpy8hUTBguwTYdSCYMFxDnoavqyr3BhjVuMUzHAdvJ99WnrCSSzPJFZvy01NTVYvnw5qqurodPpEBoaKqgej2bk2EUll1slofzpT3/ClClTcOzYMVFRjdyiw7GL6jFO6dVqtTCZTKitrbV7enFbuFBy7KKUWwslK0rAlk5QXl4OrVbb7vnpLXGLmfHatWuYMWMGdDodSktLXdklx8ko5dZJSDRjM/369UNubi7Kyspw69YtYX068wGaaWtm3Lt3L9LS0lBeXo60tLRWh7xzpA1rpmRt37FmycOHD2Pw4MEYM2YMPD09BR/Y7zKhnDNnjiWirb6+Ht7e3vD29kZxcbHgdpyhZQNse7YYLRsQFwYrSssGRDnoRg5kh9iK2R2wNe6ejDwVQ46Ebt9NmjQJcXFx2LVrF/z8/AQfTOCWENuWN3VJeAeKw0D1GAu8n/zkJzh48KDoem5RdFpO23IbWwo8xFaaKDpg09AtQtmlSxdUVFSgoqICP/vZz5hlWLbv38CGmxbHbSidfHGToD7d0cnixYuxcOFCyGQyJCUlMcuwFs/rhr3njuFx7PA4r29HESWUf/vb3/Dee++hqqoKRAQigkwmw2efsU1fzSG2/fv3b3dHn9u+pYlK7n4dQJRQJiYmIi4uDkOGDHH6fXwsnKFlA2yvccmEwQKivMadsTtga9wRjDyF+9/e4oTS29vb6mJNzpON5F/fY8aMwcGDBxEQEGC5NxoA84ZXzpMBa5/S1YgSypMnTwJouue5GXtrSk7nRyn1NWVubq5DndTX12PJkiWorq6GRqNBRARr9cKRIkqpz5Tl5eVITExEaWkp0tPTsW7dOqxfvx59+vSxW+/cuXP45S9/idjYWERFRQkWSmcoNAA7FFYqYbAAW6mxFQab/urjK2K2xs36qyhl7p8pRS1jExIS8PLLL8NkMqFnz54YOnQo4uLi2q03ePBgGI1GHmLbCVHJySq5GlFC+a9//QuhoaGQyWRQKpVYtGgRysvL262nUqnwySefYNKkSRgzZgyzDA+xlSYqGVklVyNKKFUqFaqrqy17lLdu3bJpy27JBx98gFWrVuHs2bMoKCjAnTt3rMrwk3ylCcuf0uV9iim8fPlyREdH486dO1i0aBHy8/OxefPmdut17doV3bt3h1wuh5eXFx49emRVhmVmvDaJ+112NEqZuf1Czu5TTGF/f38cOHAA+fn5MJlMSExMRK9ebMtDS3Q6HeLi4rB7924MGzYMPj4+VmW4mVGaPM4a8u7du3jzzTdhNBoxfvx4LFiwQFA9USG2Bw4caDWbFRQUYO3atfjjH/8ofsQCYIV3As5x0JVKGCwgzkH3hY3/x8xnOUSLNdOO/Zu1s8yno//bKu+Vi//DrN+W3/3udwgKCoJarcbs2bOxa9cueHl5tVtP9D4lESEyMhLvvvsuTp8+jfj4eDFNcDoZj7N5Hhsbix49egBocvRWKoWJm6hla3p6Oi5evGjZFjp16hRCQkLEj5bTaVDKzFZJyN2MQJOvhEKhwLFjxzB06FB06dJFWJ9CCjWbFwEgODgYN2/ehEqlwtmzZwHAoZtKOZ0Dldxa0WE5ZC9duhR6vd6qbHZ2Ns6ePYtdu3YJ7lOQUF66dKnV72PHjkV1dbUlvz2hrK+vx5o1a/Djjz9iyJAhWLduneABcjoWllAKjWa8fv06Tp48iT179ohSYgUJZcttH6PRiNu3b6OxsRE+Pj6CLDQZGRkICAjAtGnTcPjwYdTX1wueyjkdi4KxphS6U7Jnzx6Ul5dbbrPdunUr+vXr1249UYrOV199hWXLlsHb2xtmsxn3799HSkpKuxvdV65cga+vL3Q6HcLDwwULpDO0bIBtF5ZKGCwgzkFXTNixWN8B
FqyZUih79+51qJ4oody4cSPef/99jBgxAgCQn5+PxMRESyitLaqqqtC3b18cOHAA8+bNQ2BgIHr37u3QgDnuRakwub9PMYUfPXpkEUgAeOmllwSFxfbs2RNqtRpKpRLDhw9HWVmZlVDyEFtp8jgzpaMI2hI6evQoAKBXr144ffq0Jf/TTz+Ft7d3u/VffPFFfPnllwCAb775hnlSArd9SxOF3GyVXI0goTx+/DgAYMOGDUhPT4darYZarUZaWho2bNjQbv3IyEicPXsWERERGDVqlGVDtSWxsbH4+9//3ipxOh6l0myVXN6nmMIDBw5ERkYGHj16BLPZLMhkBDS9vm2dytUMt31LE4VCog4Z33//PSZMmGDzc1fF6DhDywbYHthSCYMFxHmNi7Fni/XQZ3qeu2FmtOpTSKEBAwZg3759rh4LR4LIpTpTqlQqm2cAcZ5sFFKdKW1dyMR58lGoJBo4tnHjRqd0tmPHDmRlZTmlLY57kKvIKrkatx3EX1FRgYyMDKxcyb49loUzFBqAHVIqlTBYQFz47oohwk2HYkKOAYB1v5usA65qcFuX6enp3MWtEyLvgIhotwjl3bt3UVtb28pE2RaWmbGrqwfGaRd5B2wdu+VMrdTUVIv7kr0y3MwoPWQqmVVyNW6ZKW/cuIE1a9bg/v37AICRI0diwIDW6yiW4yjShK8/Oa5B5oTDhPR6PVavXo3+/dnBem1xi1A2u7Y1a95tBRJgmxnd7zTFaYvMw3GhbHufklDcqlvZu++bhdgwWFumQ5ambTsMVriDri0tW0wYLMA2HdrSssedf5+Z79M73CpvuOx5ZtkxHi8y81nIPBxf4bW9T0ko/BZbjn1UCqskNJqx+T4lsXCh5NhFppRbJTF3MzoCv8WWYx8PhVVW7Fxh0YyO4rZ9SkfOlOF0PDKVtVC62vfVLa/vQ4cOQa/X4w9/+APOnz/Pz57sTHgorJNItmzZIng7CHDTTOnomTJOCYMFmPZsqYTBAuIcdFlaNsDWtHuq2N9zlbGRmc9EKV4IHxe3CGVzcJnYM2U4EsCGYLsSt/XY3pkyPMRWong8oUIp5EwZ1qFJ6T8Jc8fwOPZ4UmdKIWfK8OOlJcqTuqYUcqYMa5vBGQoNwHbQlUrEISDOQdeW6ZCl1NhSaG7Qt8x8Jk/q65vTiemAe4+4UHLs48GFkiM1FE/ompLTiemAmdItZkaDwYBFixZBq9UiIyPDHV1ynASpVFbJ1bhlpjx16hTGjh0LrVaL+fPnIywsrNUl9jZxgpYNsENhpRIGC7A1bVthsLYcdFmati0t+/uKE8x8Jk/qTPn1119j5MiRkMvl8PX1RVFRkTu65TgDhdI6uRi3zJS1tbXo1q0bgKZ7Gll3MzJDbLu5Y3Qcuzyp2ne3bt1QV1cHAKirq2Oea8kyM94sdJ43M8cxSOW436TBYMDy5cvx8OFDTJ8+HRqNRlA9t7y+m4+XJiIUFhYyj5fmJ/lKFA8P6ySQZl3i6NGj+POf/4yGBvZauy1uEcqQkBBcuHABERERmDhxIlPJ8fDwgJeXV6vEkQAKhXUSiMO6BHUCGhoaKCUlhRoaGjq0DSmNxV1tNJryrFJDQwNVV1e3Sqw21qxZQyUlJUREtG3bNvryyy8FjatTCGV1dTX5+vpSdXV1h7YhpbF0ZBspKSnk6+vbKqWkpFiVS0xMpG+//ZaIiDZv3kyFhYWC2ucWHY5ohN7N2KxL+Pj4oLCwUPAxkDzumyMa1vqfJZRCdAkWfKbkuAxPT09RVyo30ylmSg8PDyxduvSxYo2d0YaUxiKVNlyBjIjcf9I6h2OHTjFTcv6z4ELJkRxcKDmSQ/JC6QwH4ZqaGsybNw8ajQYff/yxw2Opr6/HypUrERkZKej2Xntj0Wq1+OijjxxqQ6/Xo6ysDNnZ2dBoNNBqtSgsLHSojUOHDiE8PBw6nQ7p6ekOjcfpiNrK7wBOnDhBR44cIZPJRHPnzqX6+nrRbXz44YeUmZlJ
ZrOZZs2a5fBYfv/739OJEyeIiOjQoUNUV1cnuo309HTKyMggIqLFixfTo0ePBNdtaGigxYsX09ixY6m0tJQ0Gg01NjZSaWkpLVu2zKE2EhIS6IcffhD9HK5E8jOlMxyEtVotwsLCUFtbC3qMzYYrV66grKwMOp0O3bt3d+hMpJKSEvziF78AAAwaNEjU8zQf1+zv7w+g6ZAHhUKBxsZGqASGKbRt4/bt20hMTERMTAxKS0tFPo1rkLxQCnEQFkJ5eTlCQ0Mf657Jqqoq9O3bFwcOHEB2djYqKipEtzF48GBcvHgRJpMJV69eFezOBVgf19y7d28YDAZs2LABc+bMcaiN8ePHIzk5GW+++SaSk5MFj8WVSF4ohTgIC6Ffv37Izc1FWVkZbt265VAbPXv2hFqthlKpxPDhw1FWVia6DY1Gg6tXr2LJkiV49tln8dRTTzk0FgAwGo1YtmzZY/2zzZgxA927d8eQIUPw8KGNE0ncjOSFUoiDcHscPnwY58+fh0wmg6enJ2Qyx67haB4LAHzzzTcOjaWgoACzZ8/G7t27UVNTg4EDBzo0FqDpMNJx48aJvnWjGSKCTqeD0WjE7du30adPH4fH4kwkb/sOCQnBqlWrkJWVhenTpws26rdk0qRJiIuLw65du+Dn5+eQMAFAZGQk4uLicOzYMQQHB1sOghXDc889hxUrVgAAZs+eLfgA2bY8ePAAx48fx4gRI5CTkwNfX18kJCSIakMmkyEmJgZRUVHo0qWL024rfly4mZEjOST/+ub85yH513dH8tZbb+Grr76C0WhESUkJfv7znwNoUlZee631QQnx8fFQq9UOr+84/4YLpR02bdoEACgrK0N0dDSys7M7eET/GfDXt0iKi4sRHR2N0NBQzJw5EwUFBa0+r6+vh06nw/79+wEAFy5cgEajQXh4OBYuXIh79+4BaNof3L59OzQaDSZOnIi//vWvbn8WydKh9qROQmlpKQUFBRERUUREBJ06dYqIiK5du0ZBQUHU0NBAq1evpmPHjtH8+fNp3759RERUUVFBYWFhVFlZSURE2dnZpNfriYgoKCiI9u/fT0REp0+fpvDwcHc/lmThr28R1NbWori4GJMnTwYAvPTSS+jZs6dlM3779u0wm83Ytm0bACA/Px937txBdHQ0gKY7hBQt4qYDAgIAAEOGDEFlZaX7HkTicKEUATF2z4gIjY1NJ55NnjwZRqMR77//PtauXQuTyYSRI0daznw3GAyoqqqy1G3ec3V0M/9Jha8pReDl5YVnn30WOTk5AJpmwnv37sHX1xcAMHToULzxxhv49NNPkZ+fjxEjRuDatWsWp4v9+/dj3bp1HTb+zgKfKUWSnJyM9evXY/fu3VCpVEhJSWkVeNWjRw/Ex8cjISEBWVlZ2Lx5M1atWgWz2Yw+ffogKSmpA0ffOeAWHY7k4K9vjuTgQsmRHFwoOZKDCyVHcnCh5EgOLpQcycGFkiM5uFByJAcXSo7k4ELJkRxcKDmS4/8BX/0m+fF1Fx8AAAAASUVORK5CYII=",
      "text/plain": [
       "<Figure size 175x175 with 2 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Ground truth: GRID_SIZE x GRID_SIZE tokens laid out on a grid in row-major order.\n",
    "# Token k sits at (row, col) = (k // GRID_SIZE + 1, k % GRID_SIZE + 1).\n",
    "GRID_SIZE = 4  # 4x4 grid -> 16 tokens, matching max_len=16 used for the 1D PE\n",
    "grid_rows, grid_cols = np.meshgrid(np.arange(1, GRID_SIZE + 1), np.arange(1, GRID_SIZE + 1), indexing='ij')\n",
    "ground_truth = np.column_stack([grid_rows.ravel(), grid_cols.ravel()]).astype(float)\n",
    "assert ground_truth.shape == (GRID_SIZE**2, 2)\n",
    "\n",
    "euclidean = pairwise_distances(ground_truth)\n",
    "\n",
    "plt.figure(figsize=(1.75,1.75))\n",
    "ax = sns.heatmap(euclidean,square=True,cmap='magma_r',cbar_kws={'fraction':0.046})\n",
    "plt.title(\"Pairwise distances in 2D\",fontsize=8)\n",
    "plt.ylabel('Token',fontsize=8,labelpad=0.3)\n",
    "plt.xlabel('Token',fontsize=8,labelpad=0.3)\n",
    "plt.xticks(fontsize=6)\n",
    "plt.yticks(fontsize=6)\n",
    "cbar = ax.collections[0].colorbar  # colorbar created by sns.heatmap\n",
    "cbar.ax.tick_params(labelsize=6)\n",
    "# seaborn draws row 0 at the top; flip so token 0 is at the bottom.\n",
    "ax.invert_yaxis()\n",
    "plt.tight_layout()\n",
    "plt.savefig('../figures/manuscript_figures_v3/importance_of_pe_lst/vis_groundtruth.pdf',transparent=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Correlation between ground truth and 2dpe PearsonRResult(statistic=0.9924654473870471, pvalue=1.2920657515111678e-233)\n",
      "Correlation between ground truth and 1dpe PearsonRResult(statistic=0.633063465875991, pvalue=4.462599150139038e-30)\n"
     ]
    }
   ],
   "source": [
    "pe_1d = lstnn.transformer_main.PositionalEncoding(160,max_len=16)\n",
    "pe_2d = lstnn.transformer_main.PositionalEncoding2D(160)\n",
    "\n",
    "# Compute the PE distance matrices in this cell so it does not depend on the\n",
    "# plotting cells that come later in the notebook (Restart & Run All safe).\n",
    "euclidean_1dpe = pairwise_distances(pe_1d.pe[0,:,:])\n",
    "euclidean_2dpe = pairwise_distances(pe_2d.pe[:,:])\n",
    "assert euclidean.shape == euclidean_1dpe.shape == euclidean_2dpe.shape\n",
    "\n",
    "# NOTE(review): the flattened matrices include the zero diagonal and every pair twice,\n",
    "# so the p-value treats duplicated entries as independent samples; consider\n",
    "# restricting to np.triu_indices(len(euclidean), k=1) for the unique pairs.\n",
    "print('Correlation between ground truth and 2dpe', stats.pearsonr(euclidean.ravel(),euclidean_2dpe.ravel()))\n",
    "print('Correlation between ground truth and 1dpe', stats.pearsonr(euclidean.ravel(),euclidean_1dpe.ravel()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 75,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAKUAAACdCAYAAADVArgaAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8g+/7EAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAaaklEQVR4nO2deVRUR/bHv70BIiJKor8kLphEwigeTRAYcZRFIYIiwg97QG0XXBEhQUPEIUQ5aPAnmVExiiLGGXdR0Y7GXTKZTKK4gUTFiQuGhhgR0JFF6Iau3x8cOrT9Gl43vO4HqY+nzrHr1auFvn2rbt1aBIQQAgqFRwjNXQEK5WWoUFJ4BxVKCu+gQknhHVQoKbyDCiWFd1ChpPAOKpQU3kGFksI72i2U1dXVCAwMRElJiSbunXfeQVBQEIKCghAQEICIiAjcunWL8X2ZTAZfX19N+okTJ2Lv3r0AgNzcXLz77ruaZ83h2LFjjHnt3bsXp0+fxrRp0yCXy7WeNTQ0wMPDA8XFxe1tcrtISEjAjz/+yDr9ihUrUFpayjp9fHw8srOzWadXKBSYNWsWJk+ejODgYFy8eBEAoFQqERcXB39/fwQHB+P+/fta7929excTJ07UfN64caPWdzRixAhs2LABv/zyC1asWMG6PgAA0g7y8/PJ5MmTydChQ4lCodDEOzo6aqXLyckh7u7upLKyUiePGTNmkEuXLmk+l5WVkZEjR5J79+6RS5cukRkzZrCqS0VFBZk+fTohhJDDhw+TuXPnaj0/f/48kclkrNvGF7y9vbX+tm2xfPlycuTIEdbp58+fT44ePUoIIeTevXtk1KhRpKGhgWRmZpKEhARCCCGXL18mwcHBmneOHj1Kxo4dS7y9vRnzvHLlCvH39ydVVVWEEELWrl1LvvnmG9Z1apemPHDgABITE9GnT59W03l7e8PZ2RknTpxoM89XX30VDg4OuHfvnkF12bNnD95//30AgL+/PwoKClBRUaF5LpfLMXXqVFy+fBnh4eEICQmBr68vzp07BwC4cOECgoODMWnSJERERKCiogKLFi1CTk4OAOCLL77AtGnTADT1Dl5eXiCEYP369ZBKpZgwYQJkMhmePHkCABg1ahQ+++wzhISEYNKkSbh58yaApp4hNzdXp/4nTpxASEgIgoKCsGzZMlRXVyM9PR1lZWVYsGCBVlsA6G2HoUyePFnzd3NwcIBSqURtbS3++c9/IigoCADg6uqK58+fQ6FQ4NmzZ8jJycHnn3/OmF9DQwM+/fRTrFq1CjY2NgCAoKAgbNu2jXWd2iWUKSkpGDlyJKu0jo6OOl0AE7du3UJRURGGDh0KALh586ZO991yqNDMhQsX4ObmBgCwtraGr68vTp48CQD473//i+vXr8PPzw979uzB6tWrkZ2djTVr1iAtLQ1Ak9ClpKTgxIkTGDVqFG7fvg0vLy9Nd5abm4uSkhK8ePECFy9exOjRo1FcXIx79+7hwIEDOH36NPr164fjx48DACorK+Hi4oLs7GyEhoa2+qXcv38f+/btw/79+yGXy/H2229j8+bNiIyMRJ8+fZCRkQF7e3utd/S1w1AmTZqEbt26AQB27NgBJycn9OjRA2VlZVrK5tVXX8WTJ09gZ2eHtLQ0vPbaa4z5HT9+HAMGDNB8FwDg5OSEBw8e4OnTp6zqJDaqJUYgEAg0jX+ZTz75BNbW1lCr1ejWrRuSk5PRr18/lJaWwtnZGbt3724z/59//hmvv/665nNoaChSUlIgk8nw9ddf4/3334elpSXWrVuHnJwcnDp1Cjdu3EBNTQ0AYPz48Vi0aBHGjRsHHx8fjB49Gr/++iv27t2L2tpa1NTUwNXVFXl5efjuu+/g7e2NgQMHYvny5cjKykJRURGuXbuGN954Q1MHLy8vAE1fyoULF/TW/eLFi3j48CGkUimAJm3Tv3//Vturrx3GkpGRgaysLOzZswcAQAiBQCDQPCeE
QChsW4ft378fy5Yt04l/7bXXoFAo0KtXrzbzMJlQFhYWag2MW7J69Wq4u7u3K3+BQACx+LfmvPvuu6ipqUFxcTG++uorrFy5EgAQHh4ODw8PuLm5wcPDAx999BEAICoqCv7+/vj222+RmpqKgoICREZGQigU4syZM3jvvffg6OiI3Nxc5ObmIj4+HgUFBVi2bBnmzp2LCRMmQCQSgbRYCWhpaampW2s0NjbC398fiYmJAIDa2loolcpW39HXDkMhhCApKQl5eXnYt2+fRjv27dsXZWVlGDBgAACgvLy8zWHa48eP8fjxYy0t2YxIJIJIJGJVJ5NMCZ07dw53796Fv78/Z2U4ODjoWKkhISHYsWMHVCoV/vCHP+DZs2coLi5GdHQ0xo4di2+++QaNjY0AmroxQgjmzJmD2bNn4/bt2wCatF16ejrc3d3xxz/+EdnZ2RgwYACsra1x7do1uLu7IywsDA4ODvj22281+RmCu7s7zp07pxmPpqSkYMuWLQCavsyX82ytHYayadMm/PTTT9i7d6+W0Hl6euLo0aMAgKtXr8LS0lKrJ2IiPz8fI0eOZPwRPn78GP369WNVJ840ZfMgGQBeeeUVfPnll7C2tjY4n+YxZUu8vLwQGxurFefj44NLly7h7bff1sRNmTIFnp6e+Mtf/gIAsLOzQ2hoKHx9fdG9e3e4u7ujrq4ONTU1WLZsGT788ENIJBJYWVlh1apVAJq+nG3btsHV1RV2dnaQSCSabjkgIABRUVHw8/ODpaUlnJ2doVAoDG6jk5MToqOjMWfOHBBC8NZbbyE+Ph4AMG7cOCxYsAAZGRkYOHBgm+1oyfz58xETE4Nhw4YhISEBPj4+GDdunOZ5TU0Ntm/fjj59+mD69Oma+K1bt0Imk+HTTz/FxIkTIZFIsG7dujbbUVxczDjW/OmnnzBo0CD07NmT1d9DQEjXWHleUVGBJUuWYP/+/eauCi85c+YMrKys4OnpafKy16xZgz/96U+sy+4yHh17e3v4+/trLG6KNmq1GqNGjTJ5uaWlpaisrDTox9BlNCWl69BlNCWl60CFksI7qFBSeAcVSgrvoEJJ4R0mczMag0AgYYwXCi31pNcXz/63JxRYMMZbWej6bG0kfZnTCmz15M3sZhMw6IZeamaXnnO3Vxnj/V9v0Imb7PszY1rRmgjmeKGXThzT31qtrmd8v6PgtVBSzI8+xcAlVCgpraKv9+ESKpSUVhEKTC8inBo60dHRmgW5DQ0N8PPz47I4CgcIhRY6gWs4+RkolUrExsZqbZA6fPiwZmkWpfMgMoEQvgxnQjl79mwcOXIEAFBXV4fc3Fw4OzsblA9frGyA2dLuCCsbYLa0DbGyAWZLW5+V3ZjwJWO8KMVLN07Pd8AlnHTfNjY2cHV11Xzes2cPwsPDuSiKwjEigVgncA3nJVRXV+P27duYN29eq+mUSmWbWwAopqfLaMqWFBQUoKioCDKZDIWFhUhISGBMt23bNri4uGgFivkRCS11Atdwrik9PDw0u/RkMhnWrFnDmG7hwoWYM2eOVlzPnv/DdfUobSDuavOUa9eu1frc2lZZCwsLWFiY3tKjtI6IenS06QgrG2C2tA2xsgFmS7sjrGyA2dI2xMoGmC1tfVb2V+cGMsaHpujGidHFNCWl89Plum9K50dCqFBSeIYIph9TmsT3LZfLIZVKERYWhsLCQi6LpHQwFsRCJ7Clrq4OsbGxCA8PR1JSEuv3TOL7bj5R7NGjR0hNTcXGjRtZ5dMRBg3A3QLdjjBogI5ZoMtk1OgzaE79wvy1hzLESdqhKbOysuDp6YkpU6Zg9+7dqKurg5WVVZvvcaIpm33fHh4eAID09HSIRCI0NDRAIjF9d0AxHgmR6ASlUonq6mqtwOSNu3r1KkpKSiCTydC9e3dWAgmYyPfdu3dvKJVKJCUlYfbs2YzvMDWUYn5EDP+YvG9M528+f/4cffr0wc6dOyGXy1FZWcmqTJMYOiqVCjEx
MQgMDNS7Umjbtm344osvTFEdigFIGIYtTN43JseHra0t3NzcIBaLMWzYMJSUlKB3795tlmkSoVy7di3Gjh2LkJAQvWmYGtqr1yCuq0ZpAwlDZ8rW+zZ06FBcuXIFDg4OuHPnDhYuXMiqTM4XZFRUVODQoUM4deoUZDIZkpOTGdNZWFjAxsZGK1DMj0Qo1AlsCQ8Px/nz5xEaGgpXV1f06NGD1Xu8PuBKImG2VvmyQLcjrGygYxboMlna+qzsmy+YdwDkPtUdPv3vK4k6cUfKmRVLR0EnzymtYohm7CioUFJaRdLGee1cQIWS0ipioemF0iRuxry8PEydOhUymcyoM8Ep5kMs1A1cw5lHJyoqCjdu3ADQdLD79u3bsXLlSmzfvp2LIikcIRYKdALnZXKRKdMWWzs7O9jZ2eHhw4es8+kIKxvgboFuR1jZQMcs0GWytPVZ2U+FZYzxTEhM33ubxs2oVqs1/+fxDBSFAYlQN3CNSQydlpf96LtKjW6x5SciM5xgahKhtLKyQmVlJSorK7XuLmwJ9X3zE3FXnRJavHgxFi1aBIFAoPfmKibft33vd0xRPUormKK7fhmDhPLf//43/va3v+H58+cghGhuOtV3Q2vzFtt+/fohKyur1bzpFlt+IhGa3gYwSCiTk5MRFxcHJyenNm9m7Qg6wsoGuFs13hFWNtAxq8aZLG19VjaBmjGeCZEZrG+DhNLOzg7jx4/nqi4UHsL77nv06NH4+9//Dk9PT81d1gDavHKX0nkxxzylQUJ57NgxAMCuXbs0ca2NKSmdHzHfx5Q5OTlGFVJXV4eoqChUVVVBKpUiNJRp3xyFj4j5rinLy8uRnJwMhUKBzMxMrFy5EqtWrYK9vX2r73333XcYOXIkFi5ciGnTprEWyq52Tw2XC3SZjBp9Bo2aNDLGMyEWtE9Tbtq0CW+88UarW2FexqBhbGJiIsaPH4/GxkbY2tpiyJAhiIuLa/O9N998EyqVim6x7YRIhEQnsKWysrLNqUAmDBLKx48fIzAwEAKBAGKxGJGRkSgvL2/zPYlEgq+//hoTJkzA6NGjGdPQLbb8RCIgOoEtmZmZmDJlisFlGtR9SyQSVFVVaeYoHzx4oNeX3ZI9e/Zg6dKl8PX1xZIlS1BaWqrjbmR2M/YzpHoUDmBaP8m0TuFl58ejR49QU1OD4cOHG16mIYk/+OADzJw5E6WlpYiMjER+fj5SUhgONXyJbt26oXv37hAKhbCxsUFtba1OGiY34zv9gwypHoUDxALdcSmTAlmyZAmio6O10syfPx+XL182vExDEnt4eGDnzp3Iz89HY2MjkpOT0asXs9elJTKZDHFxcdiyZQucnZ0xePBgnTTUzchPmMaQbA4j+PHHH7FixQrN3UkuLi4YOJDZS/UyBgnlzp07MWfOHHh5eQFoOmQ/IiICX331VavvvfLKK9i5c6chRQHoevfUcLlAl8nS1mdl15HnjPFMMI0h2SiQ5gXe2dnZAMBaIAEj5ikJIQgPD8df//pXnD17FvHx8YZkQelktHfy3JCpoGYMsr4zMzNx6dIlzbTQyZMnERAQYHChlM6DWKDWCZyXySZRs3sRAHx9fXH79m1IJBKcP38eAIwy+ymdA4mQeyF8GVZCmZubq/V5zJgxqKqq0sS3JZR1dXVYsWIFfv31Vzg5OWHlypXG1ZZicngrlC2nfVQqFYqKitDQ0IDBgwez8tAYe6IrxfyI+L4g4+bNm4iJiYGdnR3UajWePHmCtLS0Nq+su3r1KhwdHSGTyRAcHMz+RNcudk8Nlwt0mSxtfVZ2teoxYzwTvNWUzaxevRrr16/XzNLn5+cjOTlZY/7ro+WJrnPnzoWXlxerwzMp5kcsYr94o8PKNCRxbW2tlttoxIgRrLbFsjnRlW6x5Sfm0JSspoT27dsHAOjVqxfOnj2riT9z5gzs7OzafL/5RFcAuHPnDgYN0j2hl95iy09EQrVO4BpWQnno0CEAQFJSEjIzM+Hm
5gY3Nzds376d1f0obE50XbhwIa5du6YVKOZHLFbrBM7LNCSxg4MDsrKyUFtbC7VazfoIaFtbW8bbA1pCfd/8RCTiqaFz9+5djBs3Tu9zrvbodLXLk7hcNc5kaeuzsuuUTxnjmTCFZtQpk02igQMHIiMjg+u6UHiIkK+aUiKR6D0DiNK1EfFVU+q7kInS9RFJTO/RYWV9r169ukMK27Rpk2Z9HaVzIJQQncA1JjuIv3lnW2xsLOt3uto9NVwu0GUyavQZNGrC3kkhMMNVDSYr0tidbRTzIjTDjmiTCCWbnW3UzchPhO2YOn706BE+/vhjqFQq+Pj4YMGCBezKNL5I9jTvbGsrDXUz8g+BRKAT2LJr1y5ER0fjwIED+P7771nv5TeJpmSzs41ph9y4/nT/j7kRMBwmxGbfN9D0nTa7lNVqNcRiduJmEqFks7ONuhn5icBCVyi3stj3DUCzWOfgwYMYMmQI63W0vL7Fdu7rzIs9+LJAtyOsbKBjFugyWdqGWNkAoFLptqdxHcPf7sOtrDQlAMjlcpw4cQKbN29mrXTo3YyU1pHorv5n26vduHEDx44dQ3p6ukG9oBkOD6Z0JgRioU5gS3p6OsrLyzF//nzIZDI8fsxuGwbVlJTWsWDeJ8WGrVu3GvWeyeYpjZmvopgfAUP3zTUm6b6Nna+i8AALkW7gGJNoSmPnq7raPTVcLtA1xNImxIDlaGLTa0qTCKWx81UUHiAxvdlhshLlcjnOnz+PzZs3Mz6nvm+eYtFFhZLNfBXT6bBrekhNUT1Ka3RVTdlyvgoAPv/8c/Ttq30dCZPv+7TX/5miepTW6KpjSjbzVUxegq52T42pF+jqM2gIqWedR5ftvimdGDPce0SFktI6FlQoKXxD1EXHlJROjBk0pUncjEqlEpGRkQgLCzPqrj6K+SASiU7gGpNoypMnT2LMmDEICwvDvHnzEBQUpHWJvT662j01XC7QZbK09VnZarUh1ncX1ZS3bt2Ci4sLhEIhHB0dcf/+fVMUS+kIRGLdwDEm0ZQ1NTWwtrYG0HRPI9PdjExuxm7WpqgdpVW6qvVtbW2NFy9eAABevHjBeK4lk5vxdmHrZ1pSuIdIjN/Mp1Qq8cEHH+Dp06cICQmBVMrObWyS7rv5eGlCCAoLCxmPl6Yn+fIUCwvdwJJmW2Lfvn04ffo06uvZjWVNIpQBAQH44YcfEBoaCj8/P0Yjx8LCAjY2NlqBwgNEIt3AEqNtCdIJqK+vJ2lpaaS+vt6sefCpLqbKo6HxG51QX19PqqqqtAJTHitWrCDFxcWEEEI2bNhArly5wqpenUIoq6qqiKOjI6mqqjJrHnyqiznzSEtLI46OjlohLS1NJ11ycjL5z3/+QwghJCUlhRQWFrLKn3p0KAbD5hJ64DdbYvDgwSgsLGR9DCTd900xGKbxP5NQsrElmKCaksIZlpaWere/tEan0JQWFhZYsmRJuw7A6og8+FQXvuTBBbw+4Iry+6RTaErK7wsqlBTeQYWSwjt4L5QdsUC4uroac+fOhVQqxfHjx42uS11dHWJjYxEeHs7q9t7W6hIWFobDhw8blUd0dDRKSkogl8shlUoRFhaGwsJCo/LYtWsXgoODIZPJkJmZaVR9OhyDpvLNwNGjR8nevXtJY2MjmTNnDqmrqzM4j/3795MjR44QtVpNZsyYYXRd/vGPf5CjR48SQgjZtWsXefHihcF5ZGZmkqysLEIIIYsXLya1tbWs362vryeLFy8mY8aMIQqFgkilUtLQ0EAUCgWJiYkxKo/ExETyyy+/GNwOLuG9puyIBcJhYWEICgpCTU0NSDsmG65evYqSkhLIZDJ0797dqDORiouL8d577wEABg0aZFB7lEolZs+eDQ8PDwBNhzyIRCI0NDRAwnKbwst5FBUVITk5GREREVAoFAa2hht4L5RsFgizoby8HIGBge26Z/L58+fo06cPdu7cCblcjsrKSoPzePPNN3Hp0iU0Njbi+vXrrJdz
AYCNjQ1cXV01n3v37g2lUomkpCTMnj3bqDx8fHyQmpqKjz/+GKmpqazrwiW8F0o2C4TZ0LdvX+Tk5KCkpAQPHjwwKg9bW1u4ublBLBZj2LBhKCkpMTgPqVSK69evIyoqCv3790fPnj2NqgsAqFQqxMTEtOvHNnXqVHTv3h1OTk54+pT9PeBcwnuhZLNAuC12796N77//HgKBAJaWlhAI2F9QxFQXALhz545RdSkoKMCsWbOwZcsWVFdXw8HBwai6AMDatWsxduxYhISEGPU+IQQymQwqlQpFRUWwt7c3ui4dCe993wEBAVi6dCmys7MREhLC2qnfkgkTJiAuLg6bN2+Gu7u7UcIEAOHh4YiLi8PBgwfh6+urOQjWEAYMGIAPP/wQADBr1izWB8i+TEVFBQ4dOoThw4fj1KlTcHR0RGJiokF5CAQCREREYNq0abCysuqw24rbC3UzUngH77tvyu8P3nff5iQhIQE3b96ESqVCcXEx3nrrLQBNxsr06dO10sbHx8PNzc3o8R3lN6hQtsKaNWsAACUlJZg5cybkcrmZa/T7gHbfBvLw4UPMnDkTgYGB+POf/4yCggKt53V1dZDJZNixYwcA4IcffoBUKkVwcDAWLVqEsrKmY2F8fHywceNGSKVS+Pn54V//+pfJ28JbzOpP6iQoFAri7e1NCCEkNDSUnDx5khBCSF5eHvH29ib19fVk+fLl5ODBg2TevHkkIyODEEJIZWUlCQoKIs+ePSOEECKXy0l0dDQhhBBvb2+yY8cOQgghZ8+eJcHBwaZuFm+h3bcB1NTU4OHDh/D39wcAjBgxAra2tprJ+I0bN0KtVmPDhg0AgPz8fJSWlmLmzJkAmu4QErXYN+3p6QkAcHJywrNnz0zXEJ5DhdIACMPsGSEEDQ1Nl1D5+/tDpVJh/fr1+OSTT9DY2AgXFxfNme9KpRLPn/92IlvznKuxk/ldFTqmNAAbGxv0798fp06dAtCkCcvKyuDo6AgAGDJkCD766COcOXMG+fn5GD58OPLy8jSLLnbs2IGVK1earf6dBaopDSQ1NRWrVq3Cli1bIJFIkJaWprXxqkePHoiPj0diYiKys7ORkpKCpUuXQq1Ww97eHuvWrTNj7TsH1KND4R20+6bwDiqUFN5BhZLCO6hQUngHFUoK76BCSeEdVCgpvIMKJYV3UKGk8A4qlBTeQYWSwjv+H5bVStmYXy/aAAAAAElFTkSuQmCC",
      "text/plain": [
       "<Figure size 175x175 with 2 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Pairwise Euclidean distances between the positions of the 1D PE table.\n",
    "euclidean_1dpe = pairwise_distances(pe_1d.pe[0,:,:])\n",
    "\n",
    "plt.figure(figsize=(1.75,1.75))\n",
    "ax = sns.heatmap(euclidean_1dpe,square=True,cmap='magma_r',cbar_kws={'fraction':0.046})\n",
    "ax.set_title('1D PE (Vaswani et al., 2017)', fontsize=8)\n",
    "ax.set_xlabel('Token',fontsize=8,labelpad=0.3)\n",
    "ax.set_ylabel('Token',fontsize=8,labelpad=0.3)\n",
    "# Shrink the existing tick labels (same effect as plt.xticks/yticks(fontsize=6)).\n",
    "plt.setp(ax.get_xticklabels(), fontsize=6)\n",
    "plt.setp(ax.get_yticklabels(), fontsize=6)\n",
    "cbar = ax.collections[0].colorbar\n",
    "cbar.ax.tick_params(labelsize=6)\n",
    "# Flip the rows so token 0 ends up at the bottom of the heatmap.\n",
    "ax.invert_yaxis()\n",
    "plt.tight_layout()\n",
    "plt.savefig('../figures/manuscript_figures_v3/importance_of_pe_lst/vis_1dpe.pdf',transparent=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 76,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAKUAAACdCAYAAADVArgaAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8g+/7EAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAZ20lEQVR4nO2deVRUR/bHv72xiQQlUUdjJBo6jDExExQ9OoriviDqj3RQbJcEFRcwmqD4ixwloDI6ahR3cRj3iAQlk2A06OQXlwMTo4aomIkRFRAXQEeQpbf6/eGhx+ZVw3tN9+NB6pNTJ1LUq7qv+1Lv3bp1b8kIIQQMhoSQN7UADEZdmFIyJAdTSobkYErJkBxMKRmSgyklQ3IwpWRIDqaUDMnBlJIhOZRiDZSSkoK0tDTI5XL06NEDcXFxcHJywuuvvw5fX18AgF6vR4cOHfDRRx/hjTfe4PSh1Wpx7949uLm5AQAMBgMmT56MsLAw5OTkICIiAq+88orFNTNmzMD48eM5fR04cABeXl4YOXIkMjMzsXPnTphMJhBCEBQUhFmzZgEAZs6ciYSEBLRv397On4glp06dwpUrV7BgwQKL+pycHGzevBn79u1rsI/79+9j2bJl2LVrF+9xbflejEYjoqKisGbNGrRq1UrYjfKBiMBPP/1Exo4dS54+fUpMJhP5+OOPSUpKCiGEELVabdH29OnTpE+fPqSsrIzTz5QpU0h2drb55wcPHpBevXqRGzdukOzsbDJlyhRe8pSWlpKwsDBCCCH37t0jgwYNMo9XUVFBJkyYQL799ltbbtXuCLkvoTTmezl9+jRJTEx0iFyiPL49PDwQGxsLNzc3yGQy+Pr64u7du9S2gwcPRo8ePfDVV1812O9LL70Eb29v3LhxQ5A8+/fvx4gRIwAAjx49gl6vR1VVFQCgVatWSExMhI+PDwAgMDAQhYWFSE9Px4cffojw8HCMGDECUVFR0Ol0KCwsRGBgoLnvpKQkJCUlwWQyYfny5Rg3bhxGjx6NpKQkAEBVVRWio6MxduxYBAUF4dixYwCA9PR0xMTEAADOnj2LMWPGYOLEiUhNTTX3nZqainHjxmHMmDGIiopCdXW1xX09L0tMTAxWrlyJsLAwBAYGYvv27ZzPoTHfy8CBA3Hy5Ek8efJE0GfPB1GU0tvbG/7+/gCA0tJSHDhwwOKLrItarcZvv/3WYL9Xr15Ffn6++VF/5coVBAcHW5TCwkLOdadOnTLL4+vrixEjRmDYsGEICQnB2rVrYTQa0aVLF851ly5dwoYNG5CZmYnCwkKcPXvWqmz//ve/kZubiy+//BJpaWm4ceMGampqsHnzZrRu3RpfffUV9uzZg6SkJFy/ft18nU6nw5IlS7Bhwwakp6ebX1UAYN26ddi/fz++/vprdOrUCTdv3qz38ykqKsK+ffvw+eefY/v27RwFasz3olAo8Mc//hE5OTn1ymALor1TAs/+kmfPno13330Xffv2tdpOJpPB1dWV+rtly5bBzc0NJpMJrq6uiI+Px8svv4yioiL06NGD17vX7du30bFjR/PPsbGxCA8Px5kzZ3D+/HmEhoYiMTERo0aNsrjunXfeQevWrQEAPj4++M9//mN1jC5dusBgMCAsLAwBAQGIjo6Gs7MzsrOzkZCQAABo27YthgwZgn/9619wd3cHAPzyyy9o164d1Go1ACA4OBgbN24EAAwZMgQajQaBgYEYMWIEunfvXu99DhgwAHK5HO3atUObNm1QXl4ODw8PTjtbv5eOHTvi9u3b9cpgC6IpZV5eHmbNmoVZs2ZBq9U22HbMmDHU3yUkJKBPnz6NkkUmk0GpfHbr3333HSorKzF69GhoNBpoNBocOXIER48e5Sils7OzRR+EEPP/azEYDFAqlXB1dcXRo0eRk5ODs2fPIjQ0FPv27YPJZIJMJjO3J4TAYDBw+q1FoVCY/71q1Spcu3YN33//PaKjozF//nwE
BwdbvU+avHVpzPeiUCgs5LMXojy+y8rKEB4ejtjY2AZv/Ntvv8Wvv/7KUQh74u3tjaKiIgCAi4sL1q9fj+LiYgDPlOTatWt4/fXXefXl4eGBx48fo6SkBDU1NThz5gwA4MKFCwgPD0ffvn2xZMkSdOvWDfn5+ejbt6/5PbGsrAxZWVno1auXuT+1Wo3S0lJcvXoVAMzvcFVVVRg6dCg6dOiAiIgIBAcHIy8vr1GfQ2O/l7t373JWO+yBKDPlnj17UFFRgS1btmDLli0AgEGDBmHhwoUAYPHX/uKLL+Jvf/ubxbsUX2rfKZ/n+XFqCQwMRHZ2Nl577TX07dsX8+bNw6xZs6DX60EIwZ///GdERkbyGrN169YIDw+HRqNB+/bt8dZbbwEA/Pz80K1bN4wdOxYuLi7o3r07Bg4cCH9/f6xYsQJjx46F0WjEzJkz8dZbb5mNNScnJ6xfvx4xMTFQKpXmR7SrqysiIiIwZcoUuLi44IUXXkBiYqLgz+h5GvO9GI1GXL16tdEy0JAR2pzewiktLcX8+fNx6NChphal2ZKVlYWLFy9i8eLFdu/7d+nR8fLywqhRo5CZmdnUojRLjEYj0tLSMHfuXIf0/7ucKRnS5nc5UzKkDVNKhuRgSsmQHEwpGZKDKSVDcojq+xaKTKai1v/Bsz+1/k/oR61v5+LMqTtR/U9q296yP1Pr+7d34tR1cTNQWgJ7b9LrAzq4UOsndrnPqevck777JmLva9T6kR2NnLrxQ/OpbZWrP6DWK+SDOHVyOfezM5lqqNfbC0krJaPpsTYxOBKmlIx6kcm4M6WjYUrJqBe5THwVcaihExkZad5kazAYMHz4cEcOx3AAcrkTpzgah/wZ6HQ6LFy4ED///LO5Li0tDQ8fPnTEcAwHohBBCeviMKWcPn06vvjiCwBAdXU1cnJy0KNHD0H92MPKBoAH1VxrUYiVDdAt7duV9I8voAO9nmZlA3RLu+An7g5xgG5lA3RL25qVbVi6m1qv+Msgbh3F+nY0Dnl8u7u7o3fv3uaf9+/fj0mTJjliKIaDUciUnOJoHD5CRUUFrl27hvDw8Hrb6XQ66HQ6R4vDEEiLmSmfJzc3F/n5+dBqtcjLy8Mnn3xCbbdjxw74+flZFEbTo5A7c4qjcfhM2a9fP2g0GgDPMlysXLmS2m727NmYMWOGRZ2681hHi8doAGVLW6esG79RX/irk5MTnJzEt/QY9aNgHh1L7GFlA8AlnOfUfdieHnRvzZ9Ns7TP3ae/A6/zf0Stt+bPplna6bfpuYsWjPuVWk+ztK1Z2ceyXqXWv0frFy1spmQ0f1rc45vR/FERppQMiaGA+O+Uovi+MzIyoNFoEBoa2uisDgxxcSJOnOJoRPF9Hzx4EIcOHUJxcTHWrl1rTtjUEPYwaACg+PE5Tl2XrgOpba25DmlGzQ+EnnWtc8/O1HprrkOaUfN/96opLYGPBLgOrRk039yl5/+hGTqqljJT1vq++/V7Zj1v27YNCoUCBoMBKpX4N8mwHRVRcYqjEcX33bZtW+h0OsTFxWH69OnUa3Q6HSoqKiwKo+lRUP5zNKIYOnq9HlFRUQgKCrK6U2jHjh3YvHmzRV0/TBZDPEY9qGSOV8K6iKKUiYmJGDhwICZOnGi1Dc3NGOmzztGiMRpA1QQBrw5XytLSUhw5cgQ9e/bE8ePHoVarERsby2nH3IzSRCVvnFImJSWhU6dO9U5IdZF0gquObQZR6+2xQVcqYbCAsA2647xOUutpcgvZVAwAzn/h7k34nxe5E8gXJfHU6+tSVlaG4OBgLFy4UJBSssVzRr00ZqZMTk6mnmHUEEwpGfWiei4/ey20Ddl1X7+Ki4vx9OlT9OzZU/CYLG0Lo16Uchmn0DZk79ixw+K6HTt2YObMmbaNaQ/BrREZGYklS5bg4cOHWLVqFVxcXLBq1Sp07kz3eDCkh5IybdFWSuoaqT///DOWLl1qjmD18/Ojnk1EHdM2Ueunrptx+/bt
2LVrF0pKSrBr1y58+umnjhiW4QCUcu7jm89KSW0ka3p6OgDwVkhAxBBbT09PeHp64tatW7z7sYeVDdD92VIJgwWEbdAN6OBDrReyOmBNbtqagYqrk4IQYnXXIoqb0WQymf8t4RUoBgWVnFscjSjW9/MnbMmtLDGwEFtpomgCU1gUpXRxcUFZWRnKysrQqVMnahua79sX74ohHqMelJQlIYePKcYgc+fORUREBGQyGdasWUNtQ7Powl6x/2lWDGGI8biuiyClPHv2LNavX48nT56AEGI+MPPUqVPU9rUhti+//LLFudU0mO9bmqjk4tsAgpQyPj4e0dHR8PX1tXhPdBT2sLIB+q5xqYTBAsJ2jdtjdcCa3LQD7RTiP72FKaWnpyeGDh3qKFkYEkTyj+/+/fvj73//OwICAizOkn7+QHdGy6Kx65S2IEgpjx07BgDYu3evua6+d0pG80cp9XfK06dP2zRIdXU15s2bh/Lycmg0GoSEhNjUD0N8lFKfKUtKShAfH4+CggIkJydj+fLlWLFiBby8vOq97syZM+jVqxdmz56NyZMn81ZKexg0AD0UViphsADdqLEWBjt+aOMNMWty0wwdpUz8mVLQa2xsbCyGDh0Ko9EIDw8PdO/eHdHR0Q1e17VrV+j1ehZi2wxRyQmnOBpBSnn//n0EBQVBJpNBqVRizpw5KCkpafA6lUqFr7/+GiNHjkT//vQ85izEVpqoZIRTHI0gpVSpVCgvLzevUd68edOqL/t59u/fj0WLFiErKwu5ubkoKiritGGZfKWJUs4tDh9TSOMFCxZg6tSpKCoqwpw5c3D58mWsXr26wetcXV3RqlUryOVyuLu7o7KyktOG5mb8OoDukmSIh1JmariRvccU0rhfv35ISUnB5cuXYTQaER8fjzZt2jR4nVarRXR0NLZu3YoePXrAx4e7J5C5GaVJU7gZBYXYpqSkWMxmubm5WLZsGb788kuHCDemzf9S64WEwQJ0F5xUwmABYRt0X91ZQK2nbYgW6qZ970duRMCJvtzvYET2Kur19kLwOiUhBJMmTcK6detw8uRJxMTEOEo2hgRoisVzQa+tycnJyM7ONi8LZWZmYvTo0Y6SjSEBlDITpzh8TD6Nat2LADBs2DBcu3YNKpUKWVlZAGBTwDmjeaCSS9TQycnJsfh5wIABKC8vN9c3pJTV1dVYunQp7t27B19fXyxfvtw2aRmiI1mlfH7ZR6/XIz8/HwaDAT4+Prw8NKmpqQgICMD48eOxb98+VFdXw8WF/tLPkBYKqW/IuHLlCqKiouDp6QmTyYSHDx9i06ZNDS50X7hwAWq1GlqtFhMmTOCtkPawsgG6X1gqYbCAsA26QsKOhe4doCHZmbKWhIQEbNiwwZwf5vLly4iPjzfHd1vjyZMnaNeuHVJSUvDBBx9g0KBBaNu2re1SM0RDqaD/8Tp0TCGNKysrLRIWvf3227zCYj08PODv7w+lUok333wThYWFHKVkIbbSpClmSl5LQgcPHgQAtGnTBidP/ndh+MSJE/D09Gzw+jfeeAM//PADAOD69et49VXuVi3m+5YmCrmJUxwNL6U8cuQIACAuLg7Jycnw9/eHv78/du3ahbi4uAavnzRpErKyshASEoLevXujdevWnDazZ8/Gjz/+aFEYTY9SaeIUvhQXF0Or1SI0NBQ7d+7kP6YQAb29vZGamorKykqYTCa4u7vzus7Dw4OTKq4uzPctTRQK22fGvXv3IjIyEv7+/pg2bRomT57MS2d4KeWvv/6KIUOGWP29o2J07GFlA/Qd2FIJgwWE7RrvT6+mWtpCd+jToM2MfJKmAs+efrVPRZPJBKWS3xzIq1WXLl0ETb+MloOcMlPSUuzMnz8fkZGRFnW19sbhw4fRvXt33kuBvJRSpVJZzQHEaNkoKDMln6SptWRkZCArKwtbtmzhPSYvpbR2IBOj5aNQcT06fN//f/rpJxw7dgzbtm0TZC/wsr4TEhJ4d1gfSUlJ
5syujOaBXEU4hS/btm1DSUkJZs6cCa1Wi/v36e/adRHtdIiysjKkpqZi4cKFvK+xh0ED0ENKpRIGCwgL353alf6V0YwaISHH1pA1QkO2b99u03WiKaWtZ6owmhZ5E0REi6KUfM5UoS0zsAjxpkfeBEvHouTU4nOmCnMzShOZSsYpjkaUmZLPmSq0ZQbEzxFDPEY9yJogmZAoSsnnTBXaMkON40VjNIDMqYUqZS1Cz1QRGgZrzXVIs7Sth8Hy36BrzcoWEgYL0F2H1qzs0Iv05A9/8OSmw/kT+lHbjnAZTK2nIXNqged9M5o5KvqylyNhSsmoF5kYyYPqwJSSUT9OLXSmLC4uxuLFi6HX6xEYGIhZs2aJMSzDDsia4PEtytxcu9nz888/x7lz51juyeaEk4JbHIwoM6Wtmz3tEQYL0P3ZUgmDBYRt0KVZ2QDd0m7n4kxpCTyoFrDYpmyhj29bN3syJIBKfLNDtBEb2uzJQmwlilMLVUo+mz1pW+w/dX9PDPEY9dFSZ8rnN3sCwF//+le0b2/pxmDppSVKE7xTCsrkKzY1S7TUeiEGDUDfoOv3Uim1rT0iDq0hJLePtQ26BkIPeaUZNdYMmks4T62/++g7Th35npu1VzaQnmHZXrDFc0b9NMG5R0wpGfXjxJSSITUULXSdktGMaYKZUhQ3o06nw5w5cxAaGorU1FQxhmTYCaJScYqjEWWmzMzMxIABAxAaGorw8HAEBwdbHGJvDXtY2QA9FFYqYbAA3dK2FgZrbYMuzdK2ZmUXPz5HrafSUmfKq1evws/PD3K5HGq1Gr/99psYwzLsgULJLQ5GlJny6dOncHNzA/DsnEba2Yw0N6OrmxjSMeqlpVrfbm5uqKqqAgBUVVVRcxTS3IzX8urPaclwPERle+C3TqfDggUL8OjRI0ycOBEajYbXdaI8vmvTSxNCkJeXR00vzTL5ShQnJ27hSa0tcfDgQXzzzTeoqeG3ZU4UpRw9ejTOnz+PkJAQDB8+nGrkODk5wd3d3aIwJIBCwS08sdmWIM2AmpoasmnTJlJTU9OkfUhJFrH6MBj/ySk1NTWkvLzcotD6WLp0Kblz5w4hhJDPPvuM/PDDD7zkahZKWV5eTtRqNSkvL2/SPqQkS1P2sWnTJqJWqy3Kpk2bOO3i4+PJL7/8QgghZPXq1SQvL49X/8yjwxAM30y+tbaEj48P8vLyeKeBFD+ol9Hsob3/05SSjy1Bg82UDIfh7OwsKNd5Lc1ipnRycsL8+fMbdc6OPfqQkixS6cMRSHrnOeP3SbOYKRm/L5hSMiQHU0qG5JC8Utpjg3BFRQU++OADaDQa/OMf/7BZlurqaixcuBCTJk3idXpvfbKEhoYiLS3Npj4iIyNRWFiIjIwMaDQahIaGIi8vz6Y+9u7diwkTJkCr1SI5OdkmeeyOoKX8JuDo0aPkwIEDxGg0khkzZpDq6mrBfRw6dIh88cUXxGQykSlTptgsy549e8jRo0cJIYTs3buXVFVVCe4jOTmZpKamEkIImTt3LqmsrOR9bU1NDZk7dy4ZMGAAKSgoIBqNhhgMBlJQUECioqJs6iM2NpbcvXtX8H04EsnPlPbYIBwaGorg4GA8ffoUpBGLDRcuXEBhYSG0Wi1atWplU06kO3fu4J133gEAvPrqq4LuR6fTYfr06ejX71kyq23btkGhUMBgMEDFM0yhbh/5+fmIj4/H+++/j4ICelpssZG8UvLZIMyHkpISBAUFNeqcySdPnqBdu3ZISUlBRkYGysrKBPfRtWtXZGdnw2g04uLFi7y3cwGAu7s7evfubf65bdu20Ol0iIuLw/Tp023qIzAwEGvXrsXixYuxdu1a3rI4EskrJZ8Nwnxo3749Tp8+jcLCQty8edOmPjw8PODv7w+lUok333wThYWFgvvQaDS4ePEi5s2bh86dO+OFF16wSRYA0Ov1iIqKatQf27vvvotWrVrB19cXjx49slkWeyJ5peSzQbgh9u3bh3Pn
zkEmk8HZ2RkymW3HcNTKAgDXr1+3SZbc3FxMmzYNW7duRUVFBby9vW2SBQASExMxcOBAwadu1EIIgVarhV6vR35+Pry8vGyWxZ5I3vc9evRoLFq0COnp6Zg4cSJvp/7zjBw5EtHR0diyZQv69OljkzIBwKRJkxAdHY3Dhw9j2LBh5kSwQnjllVfw4YcfAgCmTZvGO4FsXUpLS3HkyBH07NkTx48fh1qtRmxsrKA+ZDIZ3n//fUyePBkuLi52O624sTA3I0NySP7xzfj9IfnHd1PyySef4MqVK9Dr9bhz5w66desG4JmxEhYWZtE2JiYG/v7+Nr/fMf4LU8p6WLlyJQCgsLAQU6dORUZGRhNL9PuAPb4FcuvWLUydOhVBQUF47733kJuba/H76upqaLVa7N79LLXM+fPnodFoMGHCBERERODBgwcAnq0Pbty4ERqNBsOHD8f3338v+r1Ilib1JzUTCgoKyODBgwkhhISEhJDMzExCCCGXLl0igwcPJjU1NWTJkiXk8OHDJDw8nOzcuZMQQkhZWRkJDg4mjx8/JoQQkpGRQSIjIwkhhAwePJjs3r2bEELIyZMnyYQJE8S+LcnCHt8CePr0KW7duoVRo0YBAN5++214eHiYF+M3btwIk8mEzz77DABw+fJlFBUVYerUqQCenSGkeC5uOiAgAADg6+uLx48fi3cjEocppQAIZfWMEAKD4Vku81GjRkGv12PDhg1YtmwZjEYj/Pz8sH37dgDP/M5Pnvw321vtmquti/ktFfZOKQB3d3d07twZx48fB/BsJnzw4AHUajUAoHv37vj4449x4sQJXL58GT179sSlS5fMmy52796N5cuXN5n8zQU2Uwpk7dq1WLFiBbZu3QqVSoVNmzZZBF61bt0aMTExiI2NRXp6OlavXo1FixbBZDLBy8sLa9awY1gagnl0GJKDPb4ZkoMpJUNyMKVkSA6mlAzJwZSSITmYUjIkB1NKhuRgSsmQHEwpGZKDKSVDcjClZEiO/wchhrCN5nQbPgAAAABJRU5ErkJggg==",
      "text/plain": [
       "<Figure size 175x175 with 2 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# NOTE(review): pe_2d.pe is indexed as a 2D table here, while pe_1d.pe is indexed\n",
    "# as pe[0,:,:]; confirm the buffer shapes in lstnn.transformer_main.\n",
    "euclidean_2dpe = pairwise_distances(pe_2d.pe[:,:])\n",
    "\n",
    "plt.figure(figsize=(1.75,1.75))\n",
    "ax = sns.heatmap(euclidean_2dpe,square=True,cmap='magma_r',cbar_kws={'fraction':0.046})\n",
    "plt.title('2D PE (Sinusoids in 2D)', fontsize=8)\n",
    "plt.ylabel('Token',fontsize=8,labelpad=0.3)\n",
    "plt.xlabel('Token',fontsize=8,labelpad=0.3)\n",
    "plt.xticks(fontsize=6)\n",
    "plt.yticks(fontsize=6)\n",
    "cbar = ax.collections[0].colorbar  # colorbar created by sns.heatmap\n",
    "cbar.ax.tick_params(labelsize=6)\n",
    "ax.invert_yaxis()  # token 0 at the bottom, like the other distance heatmaps\n",
    "plt.tight_layout()\n",
    "plt.savefig('../figures/manuscript_figures_v3/importance_of_pe_lst/vis_2dpe.pdf',transparent=True)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "lstnn",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
