{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Load data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from mne.datasets.sleep_physionet.age import fetch_data\n",
    "from dataset import load_sleep_physionet_raw\n",
    "import mne\n",
    "\n",
    "mne.set_log_level('ERROR')\n",
    "\n",
    "# Subjects and recording nights to fetch from the Sleep Physionet dataset.\n",
    "subjects = range(5)\n",
    "recordings = [1]\n",
    "\n",
    "# on_missing='warn' only warns about unavailable recordings instead of raising.\n",
    "fnames = fetch_data(subjects=subjects, recording=recordings, on_missing='warn')\n",
    "# Load recordings (each entry of fnames provides the two files expected by\n",
    "# load_sleep_physionet_raw, presumably PSG and hypnogram - TODO confirm).\n",
    "raws = [load_sleep_physionet_raw(f[0], f[1]) for f in fnames]\n",
    "\n",
    "# Later cells pick source/target recordings by position, so fail early with a\n",
    "# clear message if some requested recordings could not be fetched.\n",
    "n_expected = len(subjects) * len(recordings)\n",
    "assert len(raws) == n_expected, (\n",
    "    f\"expected {n_expected} recordings, got {len(raws)}: \"\n",
    "    \"some subjects/recordings are missing from the dataset\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Low-pass filter every recording at 30 Hz (l_freq=None: no high-pass).\n",
    "# MNE filters in place, so re-running this cell filters the data again.\n",
    "l_freq, h_freq = None, 30\n",
    "\n",
    "for eeg_raw in raws:\n",
    "    eeg_raw.load_data()\n",
    "    eeg_raw.filter(l_freq, h_freq)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from dataset import extract_epochs\n",
    "import numpy as np\n",
    "\n",
    "X_all, y_all = [], []\n",
    "for raw in raws:\n",
    "    data, event = extract_epochs(raw)\n",
    "    # Standardize every epoch/channel over its last (time) axis. Done out of\n",
    "    # place so the array returned by extract_epochs is never mutated.\n",
    "    data = data - np.mean(data, axis=2, keepdims=True)\n",
    "    std = np.std(data, axis=2, keepdims=True)\n",
    "    # A flat signal has std == 0: dividing by it would create NaNs that would\n",
    "    # spread to the PSDs and the barycenter, so leave such signals at zero.\n",
    "    data = data / np.where(std == 0, 1, std)\n",
    "\n",
    "    X_all.append(data)\n",
    "    y_all.append(event)\n",
    "\n",
    "# All recordings but the last one are sources; the last one is the target.\n",
    "assert len(X_all) >= 2, \"need at least one source and one target recording\"\n",
    "n_source = len(X_all) - 1\n",
    "X, y = X_all[:n_source], y_all[:n_source]\n",
    "X_target, y_target = X_all[n_source], y_all[n_source]"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Compute PSD, barycenter and transformed data\n",
    "Compute the barycenter from the source data (i.e. `X`) and map both `X` and `X_target` onto it, then compute the per-epoch PSDs of one channel before and after mapping."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from barycenter import monge_mapping_transform\n",
    "from barycenter import get_psd\n",
    "filter_size = 128\n",
    "\n",
    "# transform data: map the sources X and the target onto the barycenter of X\n",
    "X_transform = monge_mapping_transform(X, X, filter_size=filter_size)\n",
    "X_target_transform = monge_mapping_transform([X_target], X, filter_size=filter_size)[0]\n",
    "\n",
    "\n",
    "def psd_per_epoch(epochs, chan):\n",
    "    \"\"\"Stack the PSD of channel `chan` of every epoch into one array.\n",
    "\n",
    "    `epochs` is indexed as (epoch, channel, time), like the arrays in `X`;\n",
    "    row j of the result is get_psd(epochs[j, chan]).\n",
    "    \"\"\"\n",
    "    return np.array([get_psd(epoch[chan], filter_size=filter_size) for epoch in epochs])\n",
    "\n",
    "\n",
    "# Compute psd for source raw and transformed data, for one channel\n",
    "chan = 0\n",
    "psd_raw = [psd_per_epoch(X_subj, chan) for X_subj in X]\n",
    "psd_transform = [psd_per_epoch(X_subj, chan) for X_subj in X_transform]\n",
    "\n",
    "# Compute psd for target raw and transformed data\n",
    "psd_target_raw = psd_per_epoch(X_target, chan)\n",
    "psd_target_transform = psd_per_epoch(X_target_transform, chan)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from barycenter import get_barycenters\n",
    "\n",
    "# Barycenter of the source recordings for channel `chan`: each recording is\n",
    "# first flattened into one long signal by joining its epochs end to end.\n",
    "X_ = [np.concatenate(X_subj[:, chan]) for X_subj in X]\n",
    "bary = get_barycenters(X_, filter_size)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Plot PSDs for source and target data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def set_axe(axe):\n",
    "    \"\"\"Apply the common PSD-plot styling to a matplotlib Axes.\n",
    "\n",
    "    Log-scaled y axis limited to [1, 1e4], x axis limited to 0-20 Hz,\n",
    "    axis labels and horizontal grid lines.\n",
    "    \"\"\"\n",
    "    axe.set_yscale(\"log\")\n",
    "    axe.set(\n",
    "        xlim=(0, 20),\n",
    "        ylim=(1e0, 1e4),\n",
    "        xlabel='Freq. (Hz)',\n",
    "        ylabel='PSD',\n",
    "    )\n",
    "    axe.grid(axis='y')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot psd for one source subject, before and after mapping\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "subj = 0\n",
    "# Frequency axis of the PSDs, from 0 Hz to the Nyquist frequency.\n",
    "# NOTE(review): 50 Hz assumes the signals are sampled at 100 Hz - TODO confirm.\n",
    "nyquist = 50\n",
    "freqs = np.linspace(0, nyquist, len(psd_raw[subj][0]))\n",
    "\n",
    "# Seeded so the displayed random subset of epochs is reproducible; never ask\n",
    "# for more epochs than the subject has (choice without replacement).\n",
    "rng = np.random.default_rng(0)\n",
    "nb_samples = min(50, len(psd_raw[subj]))\n",
    "index = rng.choice(len(psd_raw[subj]), nb_samples, replace=False)\n",
    "\n",
    "fig, axes = plt.subplots(1, 3, figsize=(10, 4))\n",
    "\n",
    "# Left: random epochs and mean PSD of the raw data\n",
    "for psd in psd_raw[subj][index]:\n",
    "    axes[0].plot(freqs, psd, color=\"grey\", alpha=0.3)\n",
    "axes[0].plot(freqs, np.mean(psd_raw[subj], axis=0), linewidth=2, color=\"r\")\n",
    "set_axe(axes[0])\n",
    "axes[0].set_title(\"Raw data\")\n",
    "\n",
    "# Middle: mean PSD of every source subject and their barycenter\n",
    "for psd_subj in psd_raw:\n",
    "    axes[1].plot(freqs, np.mean(psd_subj, axis=0), color=\"grey\", alpha=0.3)\n",
    "axes[1].plot(freqs, bary, linewidth=2, color=\"black\", linestyle='--')\n",
    "set_axe(axes[1])\n",
    "axes[1].set_title(\"Barycenter\")\n",
    "\n",
    "# Right: the same epochs after mapping, with raw mean, mapped mean and barycenter\n",
    "for psd in psd_transform[subj][index]:\n",
    "    axes[2].plot(freqs, psd, color=\"grey\", alpha=0.3)\n",
    "axes[2].plot(freqs, np.mean(psd_raw[subj], axis=0), linewidth=2, color=\"r\", label='Raw data')\n",
    "axes[2].plot(freqs, np.mean(psd_transform[subj], axis=0), linewidth=2, color=\"b\", label='Mapped data')\n",
    "axes[2].plot(freqs, bary, linewidth=2, color=\"black\", linestyle='--', label='Barycenter')\n",
    "set_axe(axes[2])\n",
    "axes[2].legend()\n",
    "axes[2].set_title(\"Mapped data with CMMN\")\n",
    "fig.suptitle(\"CMMN for one source data\");"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot psd for the target subject, before and after mapping\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Frequency axis of the PSDs, from 0 Hz to the Nyquist frequency.\n",
    "# NOTE(review): 50 Hz assumes the signals are sampled at 100 Hz - TODO confirm.\n",
    "nyquist = 50\n",
    "freqs = np.linspace(0, nyquist, len(psd_target_raw[0]))\n",
    "\n",
    "# Seeded so the displayed random subset of epochs is reproducible; never ask\n",
    "# for more epochs than the target has (choice without replacement).\n",
    "rng = np.random.default_rng(0)\n",
    "nb_samples = min(50, len(psd_target_raw))\n",
    "index = rng.choice(len(psd_target_raw), nb_samples, replace=False)\n",
    "\n",
    "fig, axes = plt.subplots(1, 3, figsize=(10, 4))\n",
    "\n",
    "# Left: random epochs and mean PSD of the raw target data\n",
    "for psd in psd_target_raw[index]:\n",
    "    axes[0].plot(freqs, psd, color=\"grey\", alpha=0.3)\n",
    "axes[0].plot(freqs, np.mean(psd_target_raw, axis=0), linewidth=2, color=\"r\")\n",
    "set_axe(axes[0])\n",
    "axes[0].set_title(\"Raw data\")\n",
    "\n",
    "# Middle: barycenter of the source subjects\n",
    "axes[1].plot(freqs, bary, linewidth=2, color=\"black\", linestyle='--')\n",
    "set_axe(axes[1])\n",
    "axes[1].set_title(\"Barycenter\")\n",
    "\n",
    "# Right: the same epochs after mapping, with raw mean, mapped mean and barycenter\n",
    "for psd in psd_target_transform[index]:\n",
    "    axes[2].plot(freqs, psd, color=\"grey\", alpha=0.3)\n",
    "axes[2].plot(freqs, np.mean(psd_target_raw, axis=0), linewidth=2, color=\"r\", label=\"Raw data\")\n",
    "axes[2].plot(freqs, np.mean(psd_target_transform, axis=0), linewidth=2, color=\"b\", label=\"Mapped data\")\n",
    "axes[2].plot(freqs, bary, linewidth=2, color=\"black\", linestyle='--', label=\"Barycenter\")\n",
    "set_axe(axes[2])\n",
    "axes[2].legend()\n",
    "axes[2].set_title(\"Mapped data with CMMN\")\n",
    "\n",
    "fig.suptitle(\"CMMN for target data\");"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "phd_expe",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.8"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
