{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "76483d81",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Standard library\n",
    "import os\n",
    "import pickle\n",
    "import sys\n",
    "from os.path import join as opj\n",
    "\n",
    "# Make the parent of the current working directory importable so the local\n",
    "# `config`, `dataset` and `encoding_utils` modules below can be found.\n",
    "parent_dir = os.path.abspath(os.path.join(os.getcwd(), '..'))\n",
    "if parent_dir not in sys.path:\n",
    "    sys.path.append(parent_dir)\n",
    "\n",
    "# Third-party\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import pytorch_lightning as pl\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import tqdm\n",
    "from himalaya.backend import set_backend\n",
    "from himalaya.ridge import RidgeCV\n",
    "from torch.nn.utils.rnn import pad_sequence\n",
    "from torch.utils.data import DataLoader, Dataset\n",
    "from transformers import BertModel, BertTokenizer\n",
    "\n",
    "# Project-local\n",
    "from config import DATASET_FULL_TRIALS_ZSCORE\n",
    "from dataset import getDatasetLoaders_V3\n",
    "from encoding_utils import plot_channels_grid_fdr\n",
    "\n",
    "# Run himalaya's ridge solvers with the PyTorch CUDA backend.\n",
    "set_backend(\"torch_cuda\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1a9274e5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Device for the BERT embedding model, and the region-of-interest key\n",
    "# passed to getDatasetLoaders_V3.\n",
    "device = \"cuda:1\"\n",
    "ROI = \"sm\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d53c5e7f",
   "metadata": {},
   "source": [
    "## Load the data\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bbc17fd1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the train/test loaders for the selected ROI (third return value unused).\n",
    "train_loader, test_loader, _, loadedData = getDatasetLoaders_V3(\n",
    "    DATASET_FULL_TRIALS_ZSCORE,\n",
    "    128,\n",
    "    include_prego=True,\n",
    "    roi=ROI,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d5fabd32",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Unwrap the underlying Dataset objects from the two loaders.\n",
    "dataset, test_dataset = train_loader.dataset, test_loader.dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7cba56ed",
   "metadata": {},
   "source": [
    "## Create time-windows of data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "db636a4f",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "def create_fixed_time_windows(\n",
    "    neural_feats,\n",
    "    go_onsets,\n",
    "    n_before=50,\n",
    "    n_after=150,\n",
    "    window_size=4\n",
    "):\n",
    "    \"\"\"\n",
    "    Bin each trial into fixed time windows aligned to its go onset.\n",
    "\n",
    "    For trial i, `n_before` windows end exactly at `go_onsets[i]` and\n",
    "    `n_after` windows start there. Each window is the mean of `window_size`\n",
    "    consecutive samples along the time axis.\n",
    "\n",
    "    Edge handling: a window lying entirely outside the trial [0, T_i) is\n",
    "    filled with zeros; a window that only partly overlaps the trial is the\n",
    "    mean of its in-bounds samples only (missing samples are dropped, not\n",
    "    counted as zeros).\n",
    "\n",
    "    Arguments:\n",
    "    ----------\n",
    "    neural_feats : sequence of length N\n",
    "        Each entry is a 2D array/tensor of shape (T_i, D). T_i can vary;\n",
    "        D (features/channels) is read from the first trial and must be the\n",
    "        same for every trial.\n",
    "    go_onsets : sequence of length N\n",
    "        Go-onset sample index of each trial. Values are converted with\n",
    "        `int()`, so a fractional onset is truncated toward zero.\n",
    "    n_before : int\n",
    "        Number of windows before the go onset (>= 0).\n",
    "    n_after : int\n",
    "        Number of windows from the go onset onward (>= 0).\n",
    "    window_size : int\n",
    "        Number of samples averaged per window (> 0).\n",
    "\n",
    "    Returns:\n",
    "    --------\n",
    "    windowed_array : torch.Tensor of shape (N, n_before + n_after, D), float32\n",
    "        With the defaults this is (N, 200, D) and window index 50\n",
    "        (= n_before) is the first window starting at the go onset.\n",
    "\n",
    "    Raises:\n",
    "    -------\n",
    "    ValueError\n",
    "        If `neural_feats` is empty, `go_onsets` has a different length, or\n",
    "        the window parameters are out of range.\n",
    "    \"\"\"\n",
    "    if window_size <= 0:\n",
    "        raise ValueError(f\"window_size must be positive, got {window_size}\")\n",
    "    if n_before < 0 or n_after < 0:\n",
    "        raise ValueError(f\"n_before/n_after must be >= 0, got {n_before}/{n_after}\")\n",
    "\n",
    "    N = len(neural_feats)\n",
    "    if N == 0:\n",
    "        raise ValueError(\"neural_feats is empty; cannot infer the channel count D\")\n",
    "    if len(go_onsets) != N:\n",
    "        raise ValueError(f\"go_onsets has {len(go_onsets)} entries, neural_feats has {N}\")\n",
    "\n",
    "    n_windows = n_before + n_after\n",
    "    _, D = neural_feats[0].shape\n",
    "\n",
    "    # Windows that never receive in-bounds data keep this zero initialisation.\n",
    "    windowed_array = np.zeros((N, n_windows, D), dtype=np.float32)\n",
    "    # Start of every window relative to the go onset (same for all trials).\n",
    "    window_offsets = [(w - n_before) * window_size for w in range(n_windows)]\n",
    "\n",
    "    for i in tqdm.trange(N):\n",
    "        feat = neural_feats[i]        # shape (T_i, D)\n",
    "        T_i = feat.shape[0]\n",
    "        onset = int(go_onsets[i])     # int() so float-typed onsets work as slice bounds\n",
    "\n",
    "        for w, offset in enumerate(window_offsets):\n",
    "            window_start = onset + offset\n",
    "            window_end = window_start + window_size  # exclusive\n",
    "\n",
    "            # Clip to the valid range [0, T_i) of this trial.\n",
    "            valid_start = max(0, window_start)\n",
    "            valid_end = min(T_i, window_end)\n",
    "\n",
    "            # At least one in-bounds sample: average it over time.\n",
    "            if valid_end > valid_start:\n",
    "                windowed_array[i, w, :] = feat[valid_start:valid_end].mean(axis=0)\n",
    "\n",
    "    return torch.from_numpy(windowed_array)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "811ecb7c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Window parameters shared by train and test so both tensors use the same grid:\n",
    "# 50 windows before + 150 from the go onset, each averaging 4 samples.\n",
    "window_kwargs = dict(n_before=50, n_after=150, window_size=4)\n",
    "\n",
    "# For training set:\n",
    "padded_neural_train = create_fixed_time_windows(\n",
    "    dataset.neural_feats, dataset.go_onset, **window_kwargs\n",
    ")\n",
    "print(\"padded_neural_train shape:\", padded_neural_train.shape)\n",
    "# -> (N_train, 200, D)\n",
    "\n",
    "# For testing set:\n",
    "padded_neural_test = create_fixed_time_windows(\n",
    "    test_dataset.neural_feats, test_dataset.go_onset, **window_kwargs\n",
    ")\n",
    "print(\"padded_neural_test shape:\", padded_neural_test.shape)\n",
    "# -> (N_test, 200, D)\n",
    "\n",
    "# Per-window models are fit on train and scored on test, so the window grid\n",
    "# and the channel count must match between the two sets.\n",
    "assert padded_neural_train.shape[1:] == padded_neural_test.shape[1:], (\n",
    "    padded_neural_train.shape, padded_neural_test.shape)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9851a73c",
   "metadata": {},
   "source": [
    "## Extract semantic embeddings\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "57dc4659",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sentences attached to the train/test datasets. The misspelled name\n",
    "# `train_senteces` is kept because later cells refer to it.\n",
    "train_senteces, test_sentences = dataset.sentences, test_dataset.sentences"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f88d7b8d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# BERT (base, uncased) tokenizer and encoder; the encoder is moved to `device`.\n",
    "processor = BertTokenizer.from_pretrained(\"bert-base-uncased\")\n",
    "embedding_model = BertModel.from_pretrained(\"bert-base-uncased\")\n",
    "embedding_model = embedding_model.to(device)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ef4a125f",
   "metadata": {},
   "outputs": [],
   "source": [
    "def embed_sentences(sentences, tokenizer, model, device, batch_size=64):\n",
    "    \"\"\"Return BERT `pooler_output` embeddings for `sentences` as one CPU tensor.\n",
    "\n",
    "    Sentences are tokenized in batches of `batch_size` (padded to the longest\n",
    "    sentence in the batch and truncated to the tokenizer's maximum length),\n",
    "    encoded by `model` on `device` without gradient tracking, and the batch\n",
    "    outputs are moved to the CPU and concatenated along dim 0.\n",
    "\n",
    "    Returns a tensor of shape (len(sentences), hidden_size).\n",
    "    \"\"\"\n",
    "    model.eval()  # make sure dropout is disabled for feature extraction\n",
    "    batch_outputs = []\n",
    "    with torch.no_grad():\n",
    "        for start in tqdm.trange(0, len(sentences), batch_size):\n",
    "            # The tokenizer wants a list of str; list() also handles array slices.\n",
    "            batch = list(sentences[start:start + batch_size])\n",
    "            inputs = tokenizer(batch, return_tensors=\"pt\", padding=True, truncation=True)\n",
    "            inputs = {k: v.to(device) for k, v in inputs.items()}\n",
    "            batch_outputs.append(model(**inputs).pooler_output.cpu())\n",
    "    if not batch_outputs:\n",
    "        return torch.empty((0, model.config.hidden_size))\n",
    "    return torch.cat(batch_outputs, dim=0)\n",
    "\n",
    "\n",
    "BS = 64\n",
    "\n",
    "## extract embeddings for the training and test sets\n",
    "train_embeddings = embed_sentences(train_senteces, processor, embedding_model, device, batch_size=BS)\n",
    "test_embeddings = embed_sentences(test_sentences, processor, embedding_model, device, batch_size=BS)\n",
    "\n",
    "print(\"train_embeddings shape:\", train_embeddings.shape)\n",
    "# -> (N_train, 768)\n",
    "print(\"test_embeddings shape:\", test_embeddings.shape)\n",
    "# -> (N_test, 768)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bcbc80b7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Output folder for pickled results and figures (created if missing).\n",
    "base_output_dir = \"encoding_semantic\"\n",
    "os.makedirs(base_output_dir, exist_ok=True)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b13c3312",
   "metadata": {},
   "source": [
    "## Train the encoding model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4c57cbbd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set to True to refit the per-window encoding models; when False the cached\n",
    "# results are loaded from `base_output_dir` further down the notebook.\n",
    "TRAIN_SM_ENCODING = False\n",
    "\n",
    "\n",
    "if TRAIN_SM_ENCODING:\n",
    "\n",
    "    time_windows_sm_models = []\n",
    "    time_windows_sm_corrs = []\n",
    "\n",
    "    # Derived from the data instead of hard-coding 200 windows / 256 channels,\n",
    "    # so other ROIs or window settings work unchanged.\n",
    "    n_time_windows = padded_neural_train.shape[1]\n",
    "    n_channels = padded_neural_train.shape[2]\n",
    "\n",
    "    for time_window in tqdm.trange(n_time_windows):\n",
    "        # One ridge model per time window: sentence embedding -> all channels.\n",
    "        encoding = RidgeCV(alphas=[1, 10, 1e2, 1e3]).fit(train_embeddings, padded_neural_train[:, time_window, :])\n",
    "        time_windows_sm_models.append(encoding)\n",
    "        pred = encoding.predict(test_embeddings)\n",
    "        # himalaya returns arrays of the active backend (with torch_cuda this may be\n",
    "        # a CUDA tensor, which numpy cannot read); bring it to host memory first.\n",
    "        pred = pred.cpu().numpy() if torch.is_tensor(pred) else np.asarray(pred)\n",
    "        target = padded_neural_test[:, time_window, :].numpy()\n",
    "\n",
    "        ## measure channel-wise correlation on the held-out set\n",
    "        corrs = np.zeros(n_channels)\n",
    "        for ch in range(n_channels):\n",
    "            corrs[ch] = np.corrcoef(pred[:, ch], target[:, ch])[0, 1]\n",
    "        time_windows_sm_corrs.append(corrs)\n",
    "\n",
    "    # shape (n_time_windows, n_channels)\n",
    "    time_windows_SM_corrs_array = np.array(time_windows_sm_corrs)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7f1d4b77",
   "metadata": {},
   "outputs": [],
   "source": [
    "## for each time-window compute a permutation null distribution\n",
    "if TRAIN_SM_ENCODING:\n",
    "    time_windows_SM_null_dist = []\n",
    "    n_permutations = 100\n",
    "    # Seeded generator so the null distribution is reproducible across runs.\n",
    "    rng = np.random.default_rng(0)\n",
    "\n",
    "    n_time_windows = padded_neural_train.shape[1]\n",
    "    n_channels = padded_neural_train.shape[2]\n",
    "\n",
    "    for time_window in tqdm.trange(n_time_windows):\n",
    "        target_train = padded_neural_train[:, time_window, :].numpy()\n",
    "        target_test = padded_neural_test[:, time_window, :].numpy()\n",
    "        null_dist = []\n",
    "\n",
    "        for _ in range(n_permutations):\n",
    "            # Shuffle training trials (rows) to break the sentence <-> response\n",
    "            # pairing while keeping each trial's channel vector intact.\n",
    "            shuffled_train = rng.permutation(target_train, axis=0)\n",
    "            null_encoding = RidgeCV(alphas=[1, 10, 1e2, 1e3]).fit(train_embeddings, shuffled_train)\n",
    "            null_pred = null_encoding.predict(test_embeddings)\n",
    "            # himalaya may return a backend (e.g. CUDA) tensor; numpy needs host memory.\n",
    "            null_pred = null_pred.cpu().numpy() if torch.is_tensor(null_pred) else np.asarray(null_pred)\n",
    "\n",
    "            ## measure channel-wise correlation against the unshuffled test targets\n",
    "            null_corrs = np.zeros(n_channels)\n",
    "            for ch in range(n_channels):\n",
    "                null_corrs[ch] = np.corrcoef(null_pred[:, ch], target_test[:, ch])[0, 1]\n",
    "\n",
    "            null_dist.append(null_corrs)\n",
    "        # shape (n_permutations, n_channels)\n",
    "        time_windows_SM_null_dist.append(np.array(null_dist))\n",
    "    # shape (n_time_windows, n_permutations, n_channels)\n",
    "    time_windows_SM_null_dist_array = np.array(time_windows_SM_null_dist)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "add6e6c8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# base_output_dir"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4af5e3c5",
   "metadata": {},
   "outputs": [],
   "source": [
    "## save all the models, the time_windows_null_dist and the time_windows_corrs\n",
    "# (pickle and opj are imported in the setup cell)\n",
    "if TRAIN_SM_ENCODING:\n",
    "    # Each artefact is pickled to its own file inside base_output_dir.\n",
    "    models_path = opj(base_output_dir, \"time_windows_SM_models.pkl\")\n",
    "    with open(models_path, \"wb\") as models_file:\n",
    "        pickle.dump(time_windows_sm_models, models_file)\n",
    "\n",
    "    corrs_path = opj(base_output_dir, \"time_windows_SM_corrs.pkl\")\n",
    "    with open(corrs_path, \"wb\") as corrs_file:\n",
    "        pickle.dump(time_windows_SM_corrs_array, corrs_file)\n",
    "\n",
    "    null_path = opj(base_output_dir, \"time_windows_SM_null_dist.pkl\")\n",
    "    with open(null_path, \"wb\") as null_file:\n",
    "        pickle.dump(time_windows_SM_null_dist, null_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2b3b0b17",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the cached models, correlations and null distributions written by the\n",
    "# save cell. NOTE: pickle.load can execute arbitrary code; only load trusted files.\n",
    "if not TRAIN_SM_ENCODING:\n",
    "\n",
    "    with open(opj(base_output_dir, \"time_windows_SM_models.pkl\"), \"rb\") as f:\n",
    "        time_windows_sm_models = pickle.load(f)\n",
    "    with open(opj(base_output_dir, \"time_windows_SM_corrs.pkl\"), \"rb\") as f:\n",
    "        time_windows_SM_corrs_array = pickle.load(f)\n",
    "    with open(opj(base_output_dir, \"time_windows_SM_null_dist.pkl\"), \"rb\") as f:\n",
    "        time_windows_SM_null_dist = pickle.load(f)\n",
    "    # shape (n_time_windows, n_permutations, n_channels)\n",
    "    time_windows_SM_null_dist_array = np.array(time_windows_SM_null_dist)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7c505c06",
   "metadata": {},
   "source": [
    "### Visualization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4447b185",
   "metadata": {},
   "outputs": [],
   "source": [
    "channel_idx = 130\n",
    "\n",
    "# Null distribution of this channel: shape (n_time_windows, n_permutations).\n",
    "null_channel = time_windows_SM_null_dist_array[:, :, channel_idx]\n",
    "null_mean = null_channel.mean(-1)\n",
    "null_std = null_channel.std(-1)\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "ax.plot(time_windows_SM_corrs_array[:, channel_idx], color=\"tab:orange\", label=\"encoding\")\n",
    "ax.plot(null_mean, color=\"tab:blue\", label=\"null distribution\")\n",
    "# Shade mean +/- 1 std of the null distribution of the SAME channel.\n",
    "ax.fill_between(np.arange(len(null_mean)), null_mean - null_std, null_mean + null_std,\n",
    "                alpha=0.2, color=\"tab:blue\")\n",
    "\n",
    "# Window 50 is the first window at the go onset (windows were built with n_before=50).\n",
    "ax.axvline(x=50, color='r', linestyle='--')\n",
    "ax.set_ylim(-0.3, 0.3)\n",
    "ax.set_title(f\"Channel {channel_idx} correlation with the target\")\n",
    "ax.set_xlabel(\"Time window\")\n",
    "ax.set_ylabel(\"Correlation\")\n",
    "ax.legend();"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "14065ffa",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One FDR-corrected grid figure per group of 16 channels; the number of groups\n",
    "# follows the ROI's actual channel count instead of assuming 256 channels.\n",
    "# NOTE(review): a final group shorter than 16 is passed as-is; confirm that\n",
    "# plot_channels_grid_fdr handles a partial grid.\n",
    "channels_per_figure = 16\n",
    "n_channels = time_windows_SM_corrs_array.shape[1]\n",
    "\n",
    "for start in range(0, n_channels, channels_per_figure):\n",
    "    end = min(start + channels_per_figure, n_channels)\n",
    "    plot_channels_grid_fdr(channels_range=range(start, end), figure_title=f\"Encoding - SM Channels {start} to {end} with FDR\", alpha_level=0.05,\n",
    "                           time_windows_corrs_array=time_windows_SM_corrs_array,\n",
    "                           time_windows_null_dist_array=time_windows_SM_null_dist_array, prefix_title = \"Semantic_SM\",\n",
    "                           out_folder=base_output_dir)\n",
    "\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "evo",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
