{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b7fff77a",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from regression.load_meg_targets import load_meg_targets\n",
    "from regression.session_story_configs import subject_test_configs\n",
    "from regression.lm_embeddings.embeddings_store import MEGFeatureMapStore\n",
    "from regression.helmet_plot import HelmetPlot\n",
    "from regression.ceilings import sp\n",
    "import matplotlib as mpl\n",
    "# Notebook-wide matplotlib font sizes, applied in a single update call.\n",
    "mpl.rcParams.update({\n",
    "    \"font.size\": 20,         # global default for text\n",
    "    \"axes.titlesize\": 20,    # specifically for axes titles\n",
    "    \"axes.labelsize\": 20,    # for x/y axis labels\n",
    "    \"xtick.labelsize\": 12,\n",
    "    \"ytick.labelsize\": 12,\n",
    "    \"legend.fontsize\": 12,\n",
    "    \"figure.titlesize\": 20,  # for `plt.suptitle`\n",
    "})\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c9154cb4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Subject identifier and on-disk locations used by the rest of the notebook.\n",
    "subject = \"D\"\n",
    "embeddings_loc = \"./embeddings/embeddings_sweep/llama2/layer_3_context_20\"\n",
    "dataset_loc = \"./data\"\n",
    "\n",
    "# Stack the MEG targets of every test config along a new leading axis and\n",
    "# average over that axis to get a single test target.\n",
    "# NOTE(review): presumably each test config is one repeat of the same stimulus — confirm.\n",
    "test_configs = subject_test_configs(subject,dataset_loc=dataset_loc)\n",
    "meg = np.stack(load_meg_targets(test_configs),  axis=0)\n",
    "meg_test_target = np.mean(meg, axis=0)\n",
    "# The MEG feature-map store lives under the embeddings directory; the helmet\n",
    "# plotter is built from the dataset's locations.txt file.\n",
    "meg_store = MEGFeatureMapStore(embeddings_loc + \"/meg_store\")\n",
    "helmet_plotter = HelmetPlot(f\"{dataset_loc}/locations.txt\")\n",
    "\n",
    "def cc_max(trial_repeats, sp_cutoff = 0.01):\n",
    "    \"\"\"Return sqrt(sp(trial_repeats, sp_cutoff)) divided by the std of the repeat mean.\n",
    "\n",
    "    The data are averaged over axis 0, the standard deviation of that average\n",
    "    is taken over its own axis 0, and the square root of the `sp` result is\n",
    "    divided by it.\n",
    "\n",
    "    Args:\n",
    "        trial_repeats: array reduced over axis 0 for the mean\n",
    "            (presumably repeats x time x channels — TODO confirm).\n",
    "        sp_cutoff: forwarded unchanged as the second argument of `sp`.\n",
    "\n",
    "    Returns:\n",
    "        The ratio described above.\n",
    "    \"\"\"\n",
    "    signal_power = sp(trial_repeats, sp_cutoff)\n",
    "    mean_std = np.std(np.mean(trial_repeats, axis=0), axis=0)\n",
    "    return np.sqrt(signal_power)/mean_std"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "947f2494",
   "metadata": {},
   "outputs": [],
   "source": [
    "#spe, cc_norm, cc_max, _ = spe_and_cc_norm(meg, meg_test_target, max_flooring = 0.8, data_norm=False)\n",
    "#helmet_plotter.plot(cc_max, \"Channel $CC_{max}$\", cmap=None, vlim=(0, 1.0), colorbar_title=\"$CC_{max}$\")\n",
    "#plt.figure()\n",
    "#plt.scatter(np.arange(306), cc_max)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "24f77ff3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# cc_max evaluated on the full `meg` array.\n",
    "total_cc_max = cc_max(meg, 0.01)\n",
    "\n",
    "plt.figure()\n",
    "# Use the actual number of values for the x axis instead of a hard-coded 306,\n",
    "# so the scatter still works if the channel count differs.\n",
    "plt.scatter(np.arange(np.size(total_cc_max)), total_cc_max)\n",
    "helmet_plotter.plot(total_cc_max, \"Channel $CC_{max}$\", cmap=None, vlim=(0, 1.0), colorbar_title=\"\")\n",
    "# Final statement is a bare print so the cell shows no return-value repr.\n",
    "print()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4b36aab8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# MEG map of the first test config as an array; entries equal to -1 are\n",
    "# treated as 'no word' positions by get_delay_after_word_onset_indices.\n",
    "meg_map = np.array(meg_store.load_meg_map(test_configs[0]))\n",
    "\n",
    "def get_delay_after_word_onset_indices(meg_map, delay=0):\n",
    "    indices_with_word = (meg_map != -1)\n",
    "    out_indices = []\n",
    "    for i in range(delay,len(meg_map)):\n",
    "        if indices_with_word[i - delay]:\n",
    "            out_indices.append(i)\n",
    "    return out_indices\n",
    "\n",
    "def meg_after_word_onset(meg_map, meg, delay=0):\n",
    "    \"\"\"Select the axis-1 entries of `meg` that lie `delay` positions after a word.\n",
    "\n",
    "    Args:\n",
    "        meg_map: 1-D array passed to `get_delay_after_word_onset_indices`.\n",
    "        meg: 3-D array indexed along axis 1 with the returned positions.\n",
    "        delay: offset, in positions, after each word entry.\n",
    "\n",
    "    Returns:\n",
    "        `meg` restricted along axis 1 to the selected positions.\n",
    "    \"\"\"\n",
    "    selected_positions = get_delay_after_word_onset_indices(meg_map, delay)\n",
    "    return meg[:, selected_positions, :]\n",
    "\n",
    "def make_ceilings_timeseries(cc_max_divisions = 8, sp_cutoff = 0.01, window_samples = 40):\n",
    "    \"\"\"Compute `cc_max` for consecutive bins of post-word-onset delays.\n",
    "\n",
    "    The first `window_samples` delays after word onset are split into\n",
    "    `cc_max_divisions` contiguous bins of `window_samples // cc_max_divisions`\n",
    "    delays each (leftover delays are dropped when the division is not exact).\n",
    "    For each bin, the slices of the notebook-level `meg` array returned by\n",
    "    `meg_after_word_onset` for every delay in the bin are concatenated along\n",
    "    axis 1 and passed to `cc_max`.\n",
    "\n",
    "    Reads the notebook-level `meg_map` and `meg` variables.\n",
    "\n",
    "    Args:\n",
    "        cc_max_divisions: number of delay bins.\n",
    "        sp_cutoff: forwarded to `cc_max`.\n",
    "        window_samples: total number of delays covered by all bins together.\n",
    "\n",
    "    Returns:\n",
    "        Array stacking one `cc_max` result per bin along its first axis.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if `cc_max_divisions` is not between 1 and `window_samples`.\n",
    "    \"\"\"\n",
    "    if not 1 <= cc_max_divisions <= window_samples:\n",
    "        raise ValueError(\n",
    "            f\"cc_max_divisions must be between 1 and {window_samples}, got {cc_max_divisions}\"\n",
    "        )\n",
    "    shifts_per = window_samples//cc_max_divisions\n",
    "    ceilings = []\n",
    "    for cc_max_bin in range(cc_max_divisions):\n",
    "        # Bins are `shifts_per` delays wide, so bin k starts at k*shifts_per.\n",
    "        # (Stepping by cc_max_divisions left gaps between bins and ran past the\n",
    "        # window whenever cc_max_divisions != shifts_per.)\n",
    "        time_bin_start = cc_max_bin*shifts_per\n",
    "        meg_in_bin = [\n",
    "            meg_after_word_onset(meg_map, meg, delay=time_bin_start + shift_in_bin)\n",
    "            for shift_in_bin in range(shifts_per)\n",
    "        ]\n",
    "        # np.concatenate is available on every NumPy version; np.concat only on >= 2.0.\n",
    "        selected_meg = np.concatenate(meg_in_bin, axis=1)\n",
    "        ceilings.append(cc_max(selected_meg, sp_cutoff))\n",
    "    return np.array(ceilings)\n",
    "# Baseline: cc_max over the whole `meg` array; it is subtracted from the\n",
    "# per-bin ceilings below to obtain the 'excess' values.\n",
    "baseline_cc_max = cc_max(meg, 0.01)\n",
    "cc_max_divisions = 8\n",
    "shifts_per = 40//cc_max_divisions\n",
    "ceilings_timeseries = make_ceilings_timeseries(cc_max_divisions)\n",
    "# Midpoint of each bin on the time axis: bin i spans i*shifts_per*20 to\n",
    "# (i+1)*shifts_per*20 (plotted as ms; presumably 20 ms per delay step — confirm).\n",
    "time_locs = [(i*shifts_per*20 + (i+1)*shifts_per*20)/2 for i in range(cc_max_divisions)]\n",
    "ceilings = ceilings_timeseries - baseline_cc_max\n",
    "#for i in range(30):\n",
    "#    plt.plot([i*shifts_per*20 for i in range(cc_max_divisions)], ceilings[:,i], )\n",
    "print(ceilings.shape)\n",
    "# Average the excess over axis 1 to get one value per time bin.\n",
    "plt.plot(time_locs, np.mean(ceilings, axis=1), color=\"red\")\n",
    "plt.title(\"Average Excess $CC_{max}$\")\n",
    "plt.xlabel(\"Time After Word Onset (ms)\")\n",
    "plt.ylabel(\"Excess $CC_{max}$\")\n",
    "plt.xlim(0, 800)\n",
    "#plt.plot(np.array(meg_means)[:,100])\n",
    "#plt.plot(ceilings[:,14])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c1b01688",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One helmet plot per delay bin, showing that bin's ceiling minus the baseline.\n",
    "max_divisions = 8\n",
    "plot_shifts_per = 40//max_divisions\n",
    "ceilings_for_plot = make_ceilings_timeseries(max_divisions, 0.01)\n",
    "\n",
    "print(ceilings_for_plot.shape)\n",
    "for i, bin_ceiling in enumerate(ceilings_for_plot):\n",
    "    start_ms = i*plot_shifts_per*20\n",
    "    end_ms = (i+1)*plot_shifts_per*20\n",
    "    helmet_plotter.plot(\n",
    "        bin_ceiling - baseline_cc_max,\n",
    "        cmap=None,\n",
    "        vlim=(-0.1, 0.1),\n",
    "        title=f\"{start_ms}ms - {end_ms}ms Post-Word Onset \\n\" + \"Baseline Excess $CC_{max}$\",\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "99c96317",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import mne\n",
    "\n",
    "# assume helmet_plotter is already instantiated, and:\n",
    "#   helmet_plotter.positions   ← your sensor (x,y) coords\n",
    "#   helmet_plotter.sphere_size ← your sphere size\n",
    "\n",
    "\n",
    "max_divisions = 8\n",
    "ceilings_for_plot = make_ceilings_timeseries(max_divisions, 0.01)\n",
    "plot_shifts_per  = 40 // max_divisions\n",
    "\n",
    "# squeeze=False keeps `axes` an array even when max_divisions == 1, so the\n",
    "# loop and `axes.tolist()` below work for any number of panels.\n",
    "fig, axes = plt.subplots(\n",
    "    nrows=1, ncols=max_divisions,\n",
    "    figsize=(4*max_divisions, 4),\n",
    "    constrained_layout=False,\n",
    "    squeeze=False\n",
    ")\n",
    "axes = axes.ravel()\n",
    "\n",
    "# leave room for the suptitle\n",
    "fig.subplots_adjust(top=0.88)\n",
    "fig.suptitle(\"Excess $CC_{max}$ Over Time\", fontsize=20)\n",
    "\n",
    "# draw each topo into its own axis; `im` keeps the last image for the shared colorbar\n",
    "im = None\n",
    "for i, ax in enumerate(axes):\n",
    "    data = ceilings_for_plot[i, :] - baseline_cc_max\n",
    "    im, _ = mne.viz.plot_topomap(\n",
    "        data,\n",
    "        helmet_plotter.positions,\n",
    "        axes=ax,\n",
    "        show=False,\n",
    "        cmap=\"RdBu_r\",\n",
    "        vlim=(-0.1, 0.1),\n",
    "        outlines=\"head\",\n",
    "        sphere=helmet_plotter.sphere_size\n",
    "    )\n",
    "    start_ms = i * plot_shifts_per * 20\n",
    "    end_ms   = (i + 1) * plot_shifts_per * 20\n",
    "    ax.set_title(f\"{start_ms}–{end_ms} ms\", fontsize=20)\n",
    "    ax.set_xticks([]); ax.set_yticks([])\n",
    "\n",
    "# shared colorbar on the left\n",
    "# requires Matplotlib ≥3.3 for `location=\"left\"`\n",
    "cbar = fig.colorbar(\n",
    "    im,\n",
    "    ax=axes.tolist(),\n",
    "    orientation=\"vertical\",\n",
    "    fraction=0.02,\n",
    "    pad=0.04,\n",
    "    location=\"left\"\n",
    ")\n",
    "cbar.ax.yaxis.set_label_position('left')\n",
    "cbar.ax.yaxis.set_ticks_position('left')\n",
    "cbar.set_label(\"Excess $CC_{max}$\", fontsize=20)\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "400779b1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# SVD of the excess-ceiling matrix: columns of U are time courses over the\n",
    "# delay bins, rows of V_T are the matching weights over the second axis.\n",
    "U, sigma, V_T = np.linalg.svd(ceilings, full_matrices=False)\n",
    "\n",
    "# SVD components are only defined up to sign. Orient every component so its\n",
    "# time course has a non-negative sum, and keep one sign PER component so the\n",
    "# same flip is applied to U[:, k] and V_T[k, :] (previously the PC 1 map was\n",
    "# multiplied by PC 2's sign).\n",
    "component_signs = np.sign(U.sum(axis=0))\n",
    "\n",
    "# Cumulative fraction of squared singular values.\n",
    "plt.figure()\n",
    "plt.plot(np.cumsum(sigma**2)/np.sum(sigma**2))\n",
    "plt.figure()\n",
    "\n",
    "# All component time courses; opacity is each squared singular value\n",
    "# relative to the largest one.\n",
    "for i in range(len(sigma)):\n",
    "    plt.plot(time_locs, component_signs[i]*U[:,i], alpha=(sigma[i]**2)/sigma.max()**2, color=\"red\")\n",
    "\n",
    "plt.figure()\n",
    "plt.title(\"Excess $CC_{max}$ PC Projections\")\n",
    "plt.plot(time_locs, sigma[0]*component_signs[0]*U[:,0],color=\"purple\", label= \"PC 1\")\n",
    "plt.plot(time_locs, sigma[1]*component_signs[1]*U[:,1], color=\"green\", label = \"PC 2\")\n",
    "plt.xlabel(\"Time After Word Onset (ms)\")\n",
    "plt.ylabel(\"Projection\")\n",
    "plt.xlim(0, 800)\n",
    "plt.legend()\n",
    "\n",
    "helmet_plotter.plot(component_signs[0]*V_T[0,:], title=\"PC 1\", vlim=(-0.3, 0.3), colorbar_title=\"Weight\", cmap=\"BrBG_r\")\n",
    "helmet_plotter.plot(component_signs[1]*V_T[1,:], title=\"PC 2\", vlim=(-0.3, 0.3), colorbar_title=\"Weight\", cmap=\"BrBG_r\")\n",
    "\n",
    "# Final statement is a bare print so the cell shows no return-value repr.\n",
    "print()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fd9a9430",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Time course of one SVD component, followed by one helmet plot per delay bin\n",
    "# showing that component's V_T row scaled by its U entry for the bin.\n",
    "component_i = 0\n",
    "plt.figure()\n",
    "# Trailing space after the math expression keeps the two title parts apart.\n",
    "plt.title(\"Excess $CC_{max}$ \" + f\"Post Word Onset\\n Component {component_i+1}\")\n",
    "plt.plot(time_locs, U[:,component_i])\n",
    "plt.xlim(0,800)\n",
    "plt.xlabel(\"Time After Word Onset (ms)\")\n",
    "plt.ylabel(\"Component Weight\")\n",
    "for t in range(cc_max_divisions):\n",
    "    helmet_plotter.plot(V_T[component_i,:]*U[t,component_i], vlim=(-0.3, 0.3),\n",
    "                        title = f\"{t*shifts_per*20}ms - {(t+1)*shifts_per*20}ms Post-Word Onset \\n\"+ \"Baseline Excess $CC_{max}$\" +f\" Component {component_i+1}\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
