{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 42,
   "id": "6762e780-5229-47d9-8de1-bdc5ff3fb712",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib\n",
    "import mne\n",
    "import pickle\n",
    "from scipy import signal\n",
    "from scipy.fft import fft, fftfreq, ifft\n",
    "import seaborn as sns\n",
    "import pandas as pd\n",
    "from scipy.io import loadmat\n",
    "from matplotlib import animation\n",
    "from IPython.display import clear_output\n",
    "\n",
    "%matplotlib widget\n",
    "plt.rcParams['figure.figsize'] = [9, 8]\n",
    "plt.rcParams.update({'font.size': 12})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "id": "56d72379-5155-436d-a9f9-a6efe90df15c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Opening raw data file cichy_data/subj0/MEG2_subj01_sess01_tsss_mc-3.fif...\n",
      "    Range : 4930000 ... 5162999 =   4930.000 ...  5162.999 secs\n",
      "Ready.\n",
      "Reading 0 ... 232999  =      0.000 ...   232.999 secs...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/var/folders/nv/wjmf18wd5_j38vg9v0cthl5h0000gn/T/ipykernel_3213/388132765.py:3: RuntimeWarning: This filename (cichy_data/subj0/MEG2_subj01_sess01_tsss_mc-3.fif) does not conform to MNE naming conventions. All raw files should end with raw.fif, raw_sss.fif, raw_tsss.fif, _meg.fif, _eeg.fif, _ieeg.fif, raw.fif.gz, raw_sss.fif.gz, raw_tsss.fif.gz, _meg.fif.gz, _eeg.fif.gz or _ieeg.fif.gz\n",
      "  raw = mne.io.read_raw_fif(dataset_path, preload=True)\n"
     ]
    }
   ],
   "source": [
    "# Load data\n",
    "dataset_path = os.path.join('cichy_data', 'subj0', 'MEG2_subj01_sess01_tsss_mc-3.fif')\n",
    "raw = mne.io.read_raw_fif(dataset_path, preload=True)\n",
    "chn_type = 'mag'\n",
    "raw = raw.pick(chn_type)\n",
    "raw.info.update({'sfreq':250})"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6c8bc014-5c3e-4274-880a-7893c79d382a",
   "metadata": {},
   "source": [
    "# kernel network FIR"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "id": "68719456-2387-4682-a919-14ef8f3a94c8",
   "metadata": {},
   "outputs": [],
   "source": [
    "outputs = []\n",
    "num_kernels = 5\n",
    "layer_ind = [0, 1, 2, 3, 4, 5]\n",
    "\n",
    "for i in layer_ind:\n",
    "    path = 'results/cichy_epoched/all_noshuffle_wavenetclass_semb10_drop0.4/kernels_network_FIR/'\n",
    "    path += 'conv' + str(i) +'.mat'\n",
    "    data = loadmat(open(path, 'rb'))\n",
    "    data = data['X'][:, :124000]\n",
    "\n",
    "    for j in range(num_kernels):\n",
    "        f, pxx = signal.welch(data[j], fs=250, nperseg=8*250)\n",
    "\n",
    "        pxx = pxx/np.std(pxx)\n",
    "        outputs.append(pxx)\n",
    "outputs = np.array(outputs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "id": "d2187b1a-2d68-4772-81ad-0366661fe7da",
   "metadata": {},
   "outputs": [],
   "source": [
    "freqs = np.tile(f, (outputs.shape[0], 1))\n",
    "freqs = freqs.reshape(-1)\n",
    "\n",
    "# kernels\n",
    "kernels = np.tile(np.arange(outputs.shape[0]), (outputs.shape[1], 1)).T\n",
    "kernels = kernels.reshape(-1)\n",
    "\n",
    "# layers\n",
    "layers = np.tile(np.repeat(layer_ind, num_kernels), (outputs.shape[1], 1)).T\n",
    "layers = layers.reshape(-1)\n",
    "\n",
    "hue = np.tile(np.arange(num_kernels), (outputs.shape[1], len(layer_ind))).T\n",
    "hue = hue.reshape(-1)\n",
    "\n",
    "pd_dict = {'Power': outputs.reshape(-1), 'Frequency (Hz)': freqs, 'kernels': kernels, 'Layer': layers, 'K': hue}\n",
    "pfi_pd = pd.DataFrame(pd_dict)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "id": "0029b53e-444e-4263-9e47-36a08c64cfc0",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "f3a1a96882c047e6b2e3193a0723c420",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "%matplotlib widget\n",
    "pfi_plot = sns.relplot(\n",
    "    data=pfi_pd, kind=\"line\", facet_kws={'sharey': False, 'sharex': True},\n",
    "    x=\"Frequency (Hz)\", y=\"Power\", hue='kernels', n_boot=0, legend=None, estimator=None,\n",
    "    row='Layer', col='K', aspect=1.8, height=1.5\n",
    ")\n",
    "plt.savefig('kernel_network_fir_30.pdf', format='pdf')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f77ffd26-1651-41ac-9d0b-ad17c8861dd0",
   "metadata": {},
   "source": [
    "# Spectral"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "id": "3beb1a1c-bbf7-4a10-92e2-b67118de21ed",
   "metadata": {},
   "outputs": [],
   "source": [
    "path = 'results/cichy_epoched/all_noshuffle_wavenetclass_semb10_drop0.4/kernelPFI/val_loss_PFIfreqs4.npy'\n",
    "pfi = np.load(open(path, 'rb'))\n",
    "pfi = pfi[:, 2:, :]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "id": "b06dc2c4-f1dc-48f3-b568-1c6db3ad85c1",
   "metadata": {},
   "outputs": [],
   "source": [
    "inds = [True if int(i/5) in layer_ind else False for i in range(30) ]\n",
    "pfi = pfi[:, :, inds]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "id": "c45c8c6a-074f-49e2-817f-6d94ff5f7730",
   "metadata": {},
   "outputs": [],
   "source": [
    "xf = fftfreq(256, 1/250)[:256//2]\n",
    "xf = xf[3:126]\n",
    "xf = np.round(xf)\n",
    "\n",
    "times = np.array([xf for _ in range(pfi.shape[0])])\n",
    "times = np.array([times.reshape(-1) for _ in range(pfi.shape[2])]).T\n",
    "times = times.reshape(-1)\n",
    "\n",
    "# kernels\n",
    "kernels = np.repeat(np.arange(pfi.shape[2]).reshape(1, -1), [pfi.shape[1]], axis=0)\n",
    "kernels = np.repeat(kernels.reshape(1, kernels.shape[0], -1), [pfi.shape[0]], axis=0)\n",
    "kernels = kernels.reshape(-1)\n",
    "\n",
    "# layers\n",
    "layers = np.repeat(np.repeat(layer_ind, 5).reshape(1, -1), [pfi.shape[1]], axis=0)\n",
    "layers = np.repeat(layers.reshape(1, layers.shape[0], -1), [pfi.shape[0]], axis=0)\n",
    "layers = layers.reshape(-1)\n",
    "\n",
    "pd_dict = {'Output deviation': pfi.reshape(-1), 'Frequency (Hz)': times, 'kernels': kernels, 'Layer': layers}\n",
    "pfi_pd = pd.DataFrame(pd_dict)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "id": "68ee0a52-3d29-4480-9e4a-5ee4a146b287",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "045f90a24185416ea87f8c2d5f44f462",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "%matplotlib widget\n",
    "pfi_plot = sns.relplot(\n",
    "    data=pfi_pd, kind=\"line\", facet_kws={'sharey': False, 'sharex': True},\n",
    "    x=\"Frequency (Hz)\", y=\"Output deviation\", hue='kernels', n_boot=10, legend=None,\n",
    "    row='Layer', height=2, aspect=2\n",
    ")\n",
    "plt.savefig('kernel_spectral_PFI_30.pdf', format='pdf')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3c7fd99c-68c9-4ef8-8385-a44fbd61988b",
   "metadata": {},
   "source": [
    "# SpatioSpectral"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "id": "1e61492c-17ef-47ea-aa1c-f5856bd488df",
   "metadata": {},
   "outputs": [],
   "source": [
    "# PFI_freq\n",
    "path = os.path.join('results', 'cichy_epoched', 'all_noshuffle_wavenetclass_semb10_drop0.4', 'kernelPFI',\n",
    "                    'val_loss_PFIfreqs_ch4.npy')\n",
    "pfi = np.load(open(path, 'rb'))\n",
    "pfi = pfi[0, :, 1:, :]\n",
    "pfi = pfi.transpose(2, 0, 1)\n",
    "pfi = pfi[layer_ind, :, :]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "id": "2919a5d1-31e4-474c-a36a-3cdad5f593e4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# times array\n",
    "xf = fftfreq(256, 1/250)[:256//2]\n",
    "xf = xf[3:126]\n",
    "xf = np.round(xf)\n",
    "xf = np.tile(xf, (pfi.shape[0], pfi.shape[2], 1))\n",
    "xf = xf.transpose(0, 2, 1).reshape(-1)\n",
    "\n",
    "# magnitudes for color hues\n",
    "mags = np.abs(np.mean(pfi, axis=1))\n",
    "mags_max = mags.max(axis=1)[:, np.newaxis]\n",
    "mags = np.array([mags/mags_max for _ in range(pfi.shape[1])])\n",
    "mags = mags.transpose(1, 0, 2).reshape(-1)\n",
    "\n",
    "layers = np.tile(layer_ind, (pfi.shape[1], pfi.shape[2], 1))\n",
    "layers = layers.transpose(2, 0, 1).reshape(-1)\n",
    "\n",
    "pfi_pd = pfi.reshape(-1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "id": "476021b3-9afa-4ab4-a13b-f2a378016c28",
   "metadata": {},
   "outputs": [],
   "source": [
    "# put everything in a pd dataframe\n",
    "pd_dict = {'Output deviation': pfi_pd, 'Frequency (Hz)': xf, 'relative magnitude': mags, 'Layer':layers}\n",
    "pfi_pd = pd.DataFrame(pd_dict)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "id": "ee31a83d-1a06-408b-890f-99e0baef8f7a",
   "metadata": {},
   "outputs": [],
   "source": [
    "mags = np.abs(np.mean(pfi, axis=1))\n",
    "evokeds = []\n",
    "for i in range(mags.shape[0]):\n",
    "    evokeds.append(mne.EvokedArray(mags[i:i+1, :].T, raw.info, tmin=0))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "id": "b7f00d75-32e5-491b-8e50-da35f3cd89d2",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "c7728f5e61a44966aa29a8e4b0570533",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "fig, axes = plt.subplots(6)\n",
    "for i in range(6):\n",
    "    evokeds[i].plot_topomap(axes=axes[i], times=[0], ch_type='mag', time_unit='ms', scalings=1, units='Output deviation', vmin=0, time_format='', colorbar=False)\n",
    "\n",
    "plt.savefig('kernel_spatialspectral_topo_30.pdf', format='pdf', transparent=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 56,
   "id": "fd65a954-3ffa-4ce7-86b1-f8902532d45e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "58b0404959824210b32ccd4b382b04fb",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "%matplotlib widget\n",
    "pfi_plot = sns.relplot(\n",
    "    data=pfi_pd, kind=\"line\", n_boot=10, height=2, facet_kws={'sharey': False, 'sharex': True},\n",
    "    x=\"Frequency (Hz)\", y=\"Output deviation\", hue='relative magnitude', legend=None, palette='Reds', aspect=2, row='Layer'\n",
    ")\n",
    "\n",
    "plt.savefig('kernel_spatiospectral_chn_30.pdf', format='pdf')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d5c9c3be-a9ad-42a5-a757-24af40010485",
   "metadata": {},
   "source": [
    "# Temporal"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 57,
   "id": "3c57a7ee-b401-4ba2-aeed-ef6e863f81c8",
   "metadata": {},
   "outputs": [],
   "source": [
    "path = os.path.join('results', 'cichy_epoched', 'all_noshuffle_wavenetclass_semb10_drop0.4', 'kernelPFI',\n",
    "                    'val_loss_PFIts.npy')\n",
    "pfi = np.load(open(path, 'rb'))\n",
    "#pfi = np.mean(pfi, axis=0)\n",
    "pfi = pfi[:, 2:, :]\n",
    "pfi = (pfi - np.mean(pfi, axis=1, keepdims=True))/np.std(pfi, axis=1, keepdims=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 58,
   "id": "8da3286a-b1f6-4cf5-835b-164401811c75",
   "metadata": {},
   "outputs": [],
   "source": [
    "inds = [True if int(i/5) in layer_ind else False for i in range(30) ]\n",
    "pfi = pfi[:, :, inds]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 59,
   "id": "37787fc4-6933-40df-a2c0-14646235d059",
   "metadata": {},
   "outputs": [],
   "source": [
    "times = np.array([np.arange(-48, 868, 4) for _ in range(pfi.shape[0])])\n",
    "times = np.array([times.reshape(-1) for _ in range(pfi.shape[2])]).T\n",
    "times = times.reshape(-1)\n",
    "\n",
    "# kernels\n",
    "kernels = np.repeat(np.arange(pfi.shape[2]).reshape(1, -1), [pfi.shape[1]], axis=0)\n",
    "kernels = np.repeat(kernels.reshape(1, kernels.shape[0], -1), [pfi.shape[0]], axis=0)\n",
    "kernels = kernels.reshape(-1)\n",
    "\n",
    "# layers\n",
    "layers = np.repeat(np.repeat(layer_ind, 5).reshape(1, -1), [pfi.shape[1]], axis=0)\n",
    "layers = np.repeat(layers.reshape(1, layers.shape[0], -1), [pfi.shape[0]], axis=0)\n",
    "layers = layers.reshape(-1)\n",
    "\n",
    "pd_dict = {'Output deviation': pfi.reshape(-1), 'Time (ms)': times, 'kernels': kernels, 'Layer': layers}\n",
    "pfi_pd = pd.DataFrame(pd_dict)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 60,
   "id": "75573a47-91f4-4be0-ba6f-0521981680da",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a0814c76bbb940ccae15e9db76c0f1ff",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "plt.rcParams.update({'font.size': 14})\n",
    "%matplotlib widget\n",
    "pfi_plot = sns.relplot(\n",
    "    data=pfi_pd, kind=\"line\", facet_kws={'sharey': False, 'sharex': True},\n",
    "    x=\"Time (ms)\", y=\"Output deviation\", hue='kernels', n_boot=10, legend=None,\n",
    "    row='Layer', aspect=2.2, height=2\n",
    ")\n",
    "plt.savefig('kernel_temporal_PFI_30.pdf', format='pdf')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "406d98ad-6f9b-4689-9362-3e4d39c87416",
   "metadata": {},
   "source": [
    "# SpatioTemporal"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 88,
   "id": "09e7f2ba-d596-484f-b761-ed3749dd0b09",
   "metadata": {},
   "outputs": [],
   "source": [
    "path = os.path.join('results', 'cichy_epoched', 'all_noshuffle_wavenetclass_semb10_drop0.4', 'kernelPFI',\n",
    "                    'val_loss_PFIch4.npy')\n",
    "pfi = np.load(open(path, 'rb'))\n",
    "pfi = pfi[:, :, 1:, layer_ind]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 192,
   "id": "2aa9d2c6-bf36-439e-a110-386aa5f3c094",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(4, 230, 102, 6)"
      ]
     },
     "execution_count": 192,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pfi.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 89,
   "id": "feaf32c7-5f16-4320-bf5b-b6ab4758825a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# times array\n",
    "times = np.tile(np.arange(-48, 872, 4), (pfi.shape[0], pfi.shape[2], pfi.shape[3], 1))\n",
    "times = times.transpose(0, 3, 1, 2).reshape(-1)\n",
    "\n",
    "# magnitudes for color hues\n",
    "mags = np.abs(np.mean(pfi, axis=1))\n",
    "mags_max = mags.max(axis=1, keepdims=True)\n",
    "mags = np.array([mags/mags_max for _ in range(pfi.shape[1])])\n",
    "mags = mags.transpose(1, 0, 2, 3).reshape(-1)\n",
    "\n",
    "layers = np.tile(layer_ind, (pfi.shape[0], pfi.shape[1], pfi.shape[2], 1))\n",
    "layers = layers.reshape(-1)\n",
    "\n",
    "pfi_pd = pfi.reshape(-1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 90,
   "id": "31da6755-5edb-4fb0-975c-3d34f490977e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# put everything in a pd dataframe\n",
    "pd_dict = {'Output deviation': pfi_pd, 'Time (ms)': times, 'relative magnitude': mags, 'Kernel = 4 | Layer': layers}\n",
    "pfi_pd = pd.DataFrame(pd_dict)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 64,
   "id": "b164c779-b80d-48bb-9cb0-8ee37413edc6",
   "metadata": {},
   "outputs": [],
   "source": [
    "mags = np.abs(np.mean(pfi, axis=(0,1)))\n",
    "evokeds = []\n",
    "for i in range(mags.shape[1]):\n",
    "    evokeds.append(mne.EvokedArray(mags[:, i:i+1], raw.info, tmin=0))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 81,
   "id": "b9ac9438-2a70-4be0-b759-cf99b2960b3b",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "3b3ae3de191144329fa51649c117bd89",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "fig, axes = plt.subplots(6)\n",
    "for i in range(6):\n",
    "    evokeds[i].plot_topomap(axes=axes[i], times=[0], ch_type='mag', time_unit='ms', scalings=1, units='Output deviation', vmin=0, time_format='', colorbar=False)\n",
    "\n",
    "plt.savefig('kernel_spatiotemporal_head.pdf', format='pdf', transparent=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 91,
   "id": "75564a36-b27f-4071-bc01-cab6ba79470b",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "8eca67408da2472ab3e4d44a45cb5bf2",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "%matplotlib widget\n",
    "pfi_plot = sns.relplot(\n",
    "    data=pfi_pd, kind=\"line\", n_boot=10, facet_kws={'sharey': False, 'sharex': True},\n",
    "    x='Time (ms)', y=\"Output deviation\", hue='relative magnitude', legend=None, palette='Reds', aspect=2.2, row='Kernel = 4 | Layer', height=2\n",
    ")\n",
    "\n",
    "plt.savefig('kernel_spatiotemporal_chn_30.pdf', format='pdf')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b1b4daa3-3a5f-4b5e-8f84-81ba4a1236e8",
   "metadata": {},
   "source": [
    "# Spatial"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 78,
   "id": "688043a6-30b8-4ff2-9e1a-2011facc4e13",
   "metadata": {},
   "outputs": [],
   "source": [
    "path = os.path.join('results', 'cichy_epoched', 'all_noshuffle_wavenetclass_semb10_drop0.4', 'kernelPFI',\n",
    "                    'val_loss_PFIch1.npy')\n",
    "pfi = np.load(open(path, 'rb'))\n",
    "pfi = np.mean(pfi[:, 0, 1:, :], axis=0, keepdims=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 79,
   "id": "43b776e7-184c-4a6e-8a1a-5817af539bd0",
   "metadata": {},
   "outputs": [],
   "source": [
    "inds = [True if int(i/5) in layer_ind else False for i in range(30) ]\n",
    "pfi = pfi[:, :, inds]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 80,
   "id": "66738bdc-b401-48db-86fc-5aeb71cae1fc",
   "metadata": {},
   "outputs": [],
   "source": [
    "inds = [i for i in range(pfi.shape[2]) if i%5 > -1]\n",
    "pfi = pfi[:, :, inds]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 81,
   "id": "af50ea7e-9733-4e59-835a-e5710ac54179",
   "metadata": {},
   "outputs": [],
   "source": [
    "evokeds = []\n",
    "for i in range(pfi.shape[2]):\n",
    "    evokeds.append(mne.EvokedArray(pfi[:, :, i].T, raw.info, tmin=0))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 85,
   "id": "f8b30bda-1eaa-4db9-a2d7-c54298249572",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "79fc12047adc4a03aaa3edd148c500b6",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "plt.rcParams['figure.figsize'] = [12, 8]\n",
    "plt.rcParams.update({'font.size': 14})\n",
    "nk = 5\n",
    "nl = 6\n",
    "fig, axes = plt.subplots(nl, nk)\n",
    "for i in range(pfi.shape[2]):\n",
    "    ax = axes[int(i/nk)][i%nk]\n",
    "    evokeds[i].plot_topomap(axes=ax, times=[0], ch_type='mag', time_unit='ms', scalings=1, units='Output deviation', vmin=0, time_format='', colorbar=False)\n",
    "    if i%nk == 0:\n",
    "        ax.set_ylabel('Layer ' + str(layer_ind[int(i/nk)]))\n",
    "        \n",
    "        \n",
    "    if i<nk:\n",
    "        ax.set_xlabel('Kernel ' + str(i))\n",
    "        ax.xaxis.set_label_position('top') \n",
    "        \n",
    "\n",
    "plt.savefig('kernel_spatial_PFI.pdf', format='pdf')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "de8913d9-3a3c-4cad-b38a-65dc421ed0e6",
   "metadata": {},
   "source": [
    "# Spatio-Temporal"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 123,
   "id": "ab82f337-423d-4c0e-a6ce-d1fbc42195e3",
   "metadata": {},
   "outputs": [],
   "source": [
    "path = os.path.join('results', 'cichy_epoched', 'wavenetclasslinear', 'kernel_PFI',\n",
    "                    'val_loss_PFIch4.npy')\n",
    "pfi = np.load(open(path, 'rb'))\n",
    "pfi = np.mean(pfi, axis=0)\n",
    "pfi = pfi[:, 1:]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 124,
   "id": "b7e299bf-ea9d-4b95-bc55-622720382237",
   "metadata": {},
   "outputs": [],
   "source": [
    "times = list(range(-50, 866, 4))\n",
    "ticks = list(range(len(times)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 125,
   "id": "ba6d49dc-f227-4494-a019-fab2b39fc7f9",
   "metadata": {},
   "outputs": [],
   "source": [
    "evokeds = []\n",
    "for i in range(pfi.shape[2]):\n",
    "    evokeds.append(mne.EvokedArray(pfi[:, :, i].T, raw.info, tmin=-0.048))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "26b9db77-95c9-422f-b76c-8d533dde2a61",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "969b5fe4dfcd4700aa2efaff13bda139",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "times = np.arange(-0.048, 0.866, 0.004)\n",
    "fig, anim = evokeds[11].animate_topomap(\n",
    "    times=times, ch_type='mag', frame_rate=20, time_unit='ms', blit=False, show=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "656d70c5-e7d7-433e-9b84-b124d2b6e3b8",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "6c7e364650e245fa9c3f3c4420564752",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "bc82decfd5334fabae56243a41d8ccf0",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "611677db98e74d81bbd2888ad26273e7",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "0cd059f2a1d9425f97d35fb426ea833a",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "d42db017e2ea44adbfecd8d503627822",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "de86816b9cf444c59f437b0ca8f01fca",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "0b3d20bd4f2d4fa4a5f3abdbdc725fec",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a769f14a534845f4a22d74ebef8b4efe",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "d19d5657202f4f2f9134fceb8adb2906",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "754822a83a2241df8d7b0f232b72d659",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "0cc22874a0e346fabc3cf1fa46900c7a",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "223b7506b6464e9589a1de811175c48b",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "27db5b7157664bf896165c11db52fff7",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "002db6b149204022937d96a04ddb7a9f",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "f81b7906b7cf420fbc7ebb505eeba23a",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "times = [-0.04, 0.13, 0.24, 0.5, 0.63, 0.8]\n",
    "for i in range(pfi.shape[2]):\n",
    "    evokeds[i].plot_topomap(times=times, ch_type='mag', time_unit='ms', scalings=1, units='Accuracy loss')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 126,
   "id": "d5423b0a-5bca-489a-a4c6-9c3e7968b78e",
   "metadata": {},
   "outputs": [],
   "source": [
    "if len(pfi.shape) == 3:\n",
    "    pfi_all = pfi.reshape(1, pfi.shape[0], pfi.shape[1], -1)\n",
    "\n",
    "pfi_pds = []\n",
    "for i in range(pfi_all.shape[3]):\n",
    "    pfi = pfi_all[:, :, :, i]\n",
    "    # times array\n",
    "    times = np.array([np.arange(-48, 872, 4) for _ in range(pfi.shape[0])])\n",
    "    times = np.array([times.reshape(-1) for _ in range(pfi.shape[2])]).T\n",
    "    times = times.reshape(-1)\n",
    "\n",
    "    # channels array\n",
    "    pfi_pd = pfi.reshape(-1, pfi.shape[2])\n",
    "\n",
    "    # magnitudes for color hues\n",
    "    mags = np.abs(np.mean(pfi, axis=(0, 1)))\n",
    "    mags = np.array([mags/np.max(mags[:-1]) for _ in range(pfi_pd.shape[0])])\n",
    "    mags = mags.reshape(-1)\n",
    "\n",
    "    pfi_pd = pfi_pd.reshape(-1)\n",
    "\n",
    "    # put everything in a pd dataframe\n",
    "    pd_dict = {'Output deviation': pfi_pd, 'Time (ms)': times, 'relative magnitude': mags, 'channels': ['individual']*mags.shape[0]}\n",
    "    pfi_pds.append(pd.DataFrame(pd_dict))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7dd01d0a-fdaf-4e5c-9fb6-3bcc4770bbcd",
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib widget\n",
    "for i in range(pfi_all.shape[3]):\n",
    "    pfi_plot = sns.relplot(\n",
    "        data=pfi_pds[i], kind=\"line\", style='channels',\n",
    "        x=\"Time (ms)\", y=\"Output deviation\", hue=\"relative magnitude\", n_boot=10, legend='brief', palette='Reds', aspect=1.5\n",
    "    )"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
