{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "405a29ec-670c-4906-a5ac-8128819b03f4",
   "metadata": {},
   "source": [
    "## Word-level contextual embeddings\n",
    "\n",
    "Here I use word-level embeddings from a pre-trained language model, and I only embed the held-out test story."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "9ad16ea6-d2a9-45d4-9eef-152aad000f00",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-08-21 10:09:12.275108: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n",
      "To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n",
      "2024-08-21 10:09:12.880020: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n"
     ]
    }
   ],
   "source": [
    "from SemanticModel import SemanticSentenceModel\n",
    "from matplotlib.pyplot import figure, cm\n",
    "import numpy as np\n",
    "import logging\n",
    "import tqdm\n",
    "from DataSequence import DataSequence\n",
    "logging.basicConfig(level=logging.DEBUG)\n",
    "from stimulus_utils import load_grids_for_stories\n",
    "from stimulus_utils import load_generic_trfiles\n",
    "from stimulus_utils import load_simulated_trfiles\n",
    "\n",
    "from dsutils import make_word_ds, make_phoneme_ds\n",
    "\n",
    "from joblib import Parallel, delayed\n",
    "from sklearn.linear_model import RidgeCV\n",
    "import numpy as np\n",
    "from os.path import join as opj\n",
    "import os\n",
    "import tables\n",
    "import json\n",
    "import h5py\n",
    "\n",
    "from os.path import join \n",
    "import transformers\n",
    "import torch\n",
    "from huggingface_hub import notebook_login, login\n",
    "import os\n",
    "import seaborn as sns\n",
    "import pickle\n",
    "\n",
    "import nibabel as nib\n",
    "from npp import zscore\n",
    "from transformers import AutoTokenizer, AutoModel\n",
    "\n",
    "# Never commit credentials with the notebook: huggingface_hub picks up the token from\n",
    "# the HF_TOKEN environment variable or from a cached `huggingface-cli login`.\n",
    "# NOTE(review): the token that used to be hard-coded here is leaked and must be revoked.\n",
    "if not os.environ.get(\"HF_TOKEN\"):\n",
    "    logging.warning(\"HF_TOKEN is not set; relying on a cached Hugging Face login, if any.\")\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "033897c4-b372-4203-a444-70633ae18533",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "sub S3\n"
     ]
    }
   ],
   "source": [
    "\n",
    "\n",
    "# sub=\"sub-03\"\n",
    "sub = \"S3\"\n",
    "\n",
    "encode_stories = True\n",
    "\n",
    "DATA_DIR = \"/home/matteo/tutorial_language_fmri/semantic-decoding/data_train\"\n",
    "EM_DATA_DIR=\"../deep-fMRI-dataset-master/em_data\"\n",
    "\n",
    "TEST_DATA_DIR = \"/home/matteo/tutorial_language_fmri/semantic-decoding/data_test\"\n",
    "\n",
    "context_window = 5\n",
    "\n",
    "print(\"sub\", sub)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "0dd0b5d0-aba4-491f-bc59-7a911d75c81d",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "\n",
    "def fit_and_predict(voxel_idx, X_train, z_train, X_test, z_test,\n",
    "                    alphas=(1e-3, 1e-2, 1e-1, 1, 10, 100, 1e3), cv=5):\n",
    "    \"\"\"Fit a cross-validated ridge model for one voxel and score it on held-out data.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    voxel_idx : int\n",
    "        Column of X_train / X_test that holds the target voxel.\n",
    "    X_train, X_test : 2-D array\n",
    "        Response matrices; rows are samples (presumably TRs), columns are voxels.\n",
    "    z_train, z_test : 2-D array\n",
    "        Feature matrices with the same number of rows as X_train / X_test.\n",
    "    alphas : sequence of float, optional\n",
    "        Ridge penalties searched by RidgeCV.\n",
    "    cv : int, optional\n",
    "        Number of cross-validation folds RidgeCV uses to pick the penalty.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    model : RidgeCV\n",
    "        The fitted model.\n",
    "    corr : float\n",
    "        Pearson correlation between predicted and actual test responses\n",
    "        (NaN if either series is constant).\n",
    "    \"\"\"\n",
    "    model = RidgeCV(alphas=list(alphas), cv=cv)\n",
    "\n",
    "    # NaNs are zeroed in features and targets, for training AND testing alike;\n",
    "    # previously the test targets kept their NaNs, so a single NaN made corr NaN.\n",
    "    model.fit(np.nan_to_num(z_train), np.nan_to_num(X_train[:, voxel_idx]))\n",
    "    y_pred = model.predict(np.nan_to_num(z_test))\n",
    "    corr = np.corrcoef(y_pred, np.nan_to_num(X_test[:, voxel_idx]))[0, 1]\n",
    "\n",
    "    return model, corr\n",
    "\n",
    "def get_response(stories, subject):\n",
    "    \"\"\"Load a subject's fMRI responses for training stories, concatenated in story order.\n",
    "\n",
    "    Each story is read from DATA_DIR/train_response/<subject>/<story>.hf5 and the\n",
    "    rows of its \"data\" dataset are appended to the result.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    stories : iterable of str\n",
    "        Story names (file stems) to load.\n",
    "    subject : str\n",
    "        Subject identifier, e.g. \"S3\".\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    numpy.ndarray\n",
    "        All stories' rows stacked along the first axis.\n",
    "    \"\"\"\n",
    "    subject_dir = os.path.join(DATA_DIR, \"train_response\", subject)\n",
    "    resp = []\n",
    "    for story in stories:\n",
    "        resp_path = os.path.join(subject_dir, \"%s.hf5\" % story)\n",
    "        # The context manager closes the file even if reading it fails.\n",
    "        with h5py.File(resp_path, \"r\") as hf:\n",
    "            resp.extend(hf[\"data\"][:])\n",
    "    return np.array(resp)\n",
    "\n",
    "def get_val_response(stories, subject):\n",
    "    \"\"\"Load a subject's fMRI responses for held-out stories, concatenated in story order.\n",
    "\n",
    "    Each story is read from\n",
    "    TEST_DATA_DIR/test_response/<subject>/perceived_speech/<story>.hf5 and the rows\n",
    "    of its \"data\" dataset are appended to the result.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    stories : iterable of str\n",
    "        Story names (file stems) to load.\n",
    "    subject : str\n",
    "        Subject identifier, e.g. \"S3\".\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    numpy.ndarray\n",
    "        All stories' rows stacked along the first axis.\n",
    "    \"\"\"\n",
    "    # Use the configured TEST_DATA_DIR: DATA_DIR.replace(\"train\", \"test\") would also\n",
    "    # rewrite any other \"train\" substring occurring in the path.\n",
    "    subject_dir = opj(TEST_DATA_DIR, \"test_response\", subject, \"perceived_speech\")\n",
    "    resp = []\n",
    "    for story in stories:\n",
    "        resp_path = os.path.join(subject_dir, \"%s.hf5\" % story)\n",
    "        # The context manager closes the file even if reading it fails.\n",
    "        with h5py.File(resp_path, \"r\") as hf:\n",
    "            resp.extend(hf[\"data\"][:])\n",
    "    return np.array(resp)\n",
    "\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "2c8439e6-e6a8-4910-b188-d47d638900f4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Held-out story to embed.\n",
    "test_stories = [\"wheretheressmoke\"]\n",
    "allstories = test_stories  # legacy alias\n",
    "Pstories = test_stories  # legacy alias\n",
    "\n",
    "# Load the TextGrids of the test stories. TEST_DATA_DIR is used instead of\n",
    "# DATA_DIR.replace(\"train\", \"test\"), which would rewrite EVERY \"train\" in the path.\n",
    "grids = load_grids_for_stories(\n",
    "    test_stories,\n",
    "    grid_dir=opj(TEST_DATA_DIR, \"test_stimulus\", \"perceived_speech\"),\n",
    ")\n",
    "\n",
    "# Build (simulated) TR files from the training dataset's respdict.json.\n",
    "with open(join(DATA_DIR, \"ds003020/derivative/respdict.json\"), \"r\") as f:\n",
    "    respdict = json.load(f)\n",
    "trfiles = load_simulated_trfiles(respdict)\n",
    "\n",
    "# Word and phoneme DataSequences: {story name: DataSequence}.\n",
    "wordseqs = make_word_ds(grids, trfiles)\n",
    "phonseqs = make_phoneme_ds(grids, trfiles)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "f1a8c576",
   "metadata": {},
   "outputs": [],
   "source": [
    "# # print(\"Downsampling\")\n",
    "\n",
    "# # Downsample stimuli\n",
    "# interptype = \"lanczos\" # filter type\n",
    "# window = 3 # number of lobes in Lanczos filter\n",
    "\n",
    "# downsampled_semanticseqs = dict() # dictionary to hold downsampled stimuli\n",
    "# for story in tqdm.tqdm(allstories):\n",
    "#     downsampled_semanticseqs[story] = semanticseqs[story].chunksums(interptype, window=window)\n",
    "\n",
    "\n",
    "# print(\"Stack stimuli\")\n",
    "\n",
    "# # Combine stimuli\n",
    "# trim = 5\n",
    "# # Rstim = np.vstack([zscore(downsampled_semanticseqs[story][5+trim:-trim]) for story in Rstories])\n",
    "# Pstim = np.vstack([zscore(downsampled_semanticseqs[story][5+trim:-trim]) for story in Pstories])\n",
    "\n",
    "# # Print the sizes of these matrices\n",
    "# print (\"Pstim shape: \", Pstim.shape)\n",
    "\n",
    "# def make_delayed(stim, delays, circpad=False,stack=False):\n",
    "#     \"\"\"Creates non-interpolated concatenated delayed versions of [stim] with the given [delays] \n",
    "#     (in samples).\n",
    "    \n",
    "#     If [circpad], instead of being padded with zeros, [stim] will be circularly shifted.\n",
    "#     \"\"\"\n",
    "#     nt,ndim = stim.shape\n",
    "#     dstims = []\n",
    "#     for di,d in enumerate(delays):\n",
    "#         dstim = np.zeros((nt, ndim))\n",
    "#         if d<0: ## negative delay\n",
    "#             dstim[:d,:] = stim[-d:,:]\n",
    "#             if circpad:\n",
    "#                 dstim[d:,:] = stim[:-d,:]\n",
    "#         elif d>0:\n",
    "#             dstim[d:,:] = stim[:-d,:]\n",
    "#             if circpad:\n",
    "#                 dstim[:d,:] = stim[-d:,:]\n",
    "#         else: ## d==0\n",
    "#             dstim = stim.copy()\n",
    "#         dstims.append(dstim)\n",
    "#     if stack:\n",
    "#         return np.hstack(dstims)\n",
    "#     else:\n",
    "#         return np.array(dstims)\n",
    "\n",
    "\n",
    "#     # Delay stimuli\n",
    "# ndelays = 4\n",
    "# delays = range(1, ndelays+1)\n",
    "\n",
    "# print (\"FIR model delays: \", delays)\n",
    "\n",
    "# delPstim = make_delayed(Pstim, delays)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "6b999b58-e50a-40f6-853e-02bea9046359",
   "metadata": {},
   "outputs": [],
   "source": [
    "class ContextualSemanticModel:\n",
    "    \"\"\"This class defines a semantic vector-space model using a pre-trained language model\n",
    "    to obtain contextual word embeddings.\n",
    "\n",
    "    It contains two important variables: vocab and data.\n",
    "    vocab is a 1D list (or array) of words.\n",
    "    data is a 2D array (features by words) of word-feature values.\n",
    "    \"\"\"\n",
    "    def __init__(self, model_name: str,  context_window: int = 5, device =\"cuda:0\",hook_idx=None):\n",
    "        \"\"\"Initializes a ContextualSemanticModel with the given model name and vocabulary.\"\"\"\n",
    "        \n",
    "        self.context_window = context_window\n",
    "        self.device = device\n",
    "        self.tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
    "        self.model = AutoModel.from_pretrained(model_name).to(device)\n",
    "        self.hook_idx = hook_idx\n",
    "        \n",
    "        if self.hook_idx is not None:\n",
    "            self.outputs={}\n",
    "            \n",
    "            print(f\"Hooking the model with output at layer {hook_idx}\")\n",
    "        # Hook to capture the intermediate layer output\n",
    "        def hook(module, input, output):\n",
    "            self.outputs[\"layer_output\"] = output\n",
    "\n",
    "        # Register hook on the specific layer\n",
    "        layer = self.model.layers[self.hook_idx]\n",
    "        layer.register_forward_hook(hook)\n",
    "\n",
    "\n",
    "\n",
    "    def get_word_embedding(self, word: str, context: list[str] = None, method: str = 'last') -> np.ndarray:\n",
    "        \n",
    "        with torch.no_grad():\n",
    "            \"\"\"Returns the embedding of the given word within the optional context.\"\"\"\n",
    "            if context is None:\n",
    "                context = []\n",
    "\n",
    "            # Construct the input text\n",
    "            context_text = ' '.join(context[-self.context_window:]) + ' ' + word\n",
    "            inputs = self.tokenizer(context_text, return_tensors='pt')\n",
    "            input_ids = inputs.input_ids.to(self.device)\n",
    "            outputs = self.model(input_ids,output_attentions=True)\n",
    "\n",
    "            # Get the embeddings for all tokens\n",
    "            word_embedding = outputs.last_hidden_state[0].detach().cpu().numpy()\n",
    "            \n",
    "            if self.hook_idx is not None:\n",
    "                # print(\"DBG, was\", word_embedding.shape, word_embedding)\n",
    "                word_embedding  = self.outputs[\"layer_output\"][0].squeeze().detach().cpu().numpy()\n",
    "                # print(\"DBG, now\", word_embedding.shape, word_embedding)\n",
    "\n",
    "\n",
    "        # word_embedding = np.array(token_embeddings)\n",
    "        if method == 'mean':\n",
    "            word_embedding = word_embedding.mean(axis=0)\n",
    "        elif method == 'sum':\n",
    "            word_embedding = word_embedding.sum(axis=0)\n",
    "        elif method == 'concat':\n",
    "            word_embedding = word_embedding.flatten()\n",
    "        elif method == 'weighted_sum':\n",
    "\n",
    "\n",
    "            # Use attention weights for weighted sum\n",
    "            attention_weights = outputs.attentions[-1][0, :, :, :].detach().cpu().numpy()\n",
    "            # Average over heads\n",
    "            word_weights = attention_weights.mean(axis=0)\n",
    "            # Average over tokens\n",
    "            word_weights = word_weights.mean(axis=0) \n",
    "            word_embedding = np.average(word_embedding, axis=0, weights=attention_weights)\n",
    "        elif method == 'last':\n",
    "            word_embedding = word_embedding[-1]\n",
    "        else:\n",
    "            raise ValueError(f\"Unknown aggregation method: {method}\")\n",
    "\n",
    "        \n",
    "        return word_embedding\n",
    "\n",
    "\n",
    "    def __getitem__(self, word: str) -> np.ndarray:\n",
    "        \"\"\"Returns the vector corresponding to the given [word].\"\"\"\n",
    "        return self.data[:, self.vindex[word]]\n",
    "\n",
    "    def similarity(self, word1: str, context1: list[str], word2: str, context2: list[str], method: str = 'last') -> float:\n",
    "        \"\"\"Returns the cosine similarity between the vectors for [word1] and [word2] given their contexts.\"\"\"\n",
    "        vec1 = self.get_word_embedding(word1, context1, method)\n",
    "        vec2 = self.get_word_embedding(word2, context2, method)\n",
    "        return np.dot(vec1, vec2) / (np.linalg.norm(vec1) * np.linalg.norm(vec2))\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "0d1b3344-4596-4ef5-9544-31b047f32dc2",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): huggingface.co:443\n",
      "DEBUG:urllib3.connectionpool:https://huggingface.co:443 \"HEAD /meta-llama/Meta-Llama-3-8B/resolve/main/tokenizer_config.json HTTP/1.1\" 403 0\n",
      "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n",
      "DEBUG:urllib3.connectionpool:https://huggingface.co:443 \"HEAD /meta-llama/Meta-Llama-3-8B/resolve/main/config.json HTTP/1.1\" 403 0\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "e66e292a78c64496a5e7c9369d0dd177",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Hooking the model with output at layer 13\n"
     ]
    }
   ],
   "source": [
    "import logging\n",
    "\n",
    "# Pre-trained language model used for the contextual word embeddings.\n",
    "model_name = \"meta-llama/Meta-Llama-3-8B\"\n",
    "\n",
    "# Embeddings are taken from the output of self.model.layers[13] via a forward hook.\n",
    "semantic_model = ContextualSemanticModel(\n",
    "    model_name=model_name,\n",
    "    device=\"cuda:2\",\n",
    "    context_window=context_window,\n",
    "    hook_idx=13,\n",
    ")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "88e121df-f5a1-499e-91ef-c09ce5a20c32",
   "metadata": {},
   "source": [
    "## Load stories and stimuli"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "2fce3469-8abd-4c82-812c-d04a63c664d1",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['wheretheressmoke']"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Focus only on the held-out test story (or stories).\n",
    "Pstories = test_stories\n",
    "Pstories"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9f8066f1-eca4-4e54-b0f8-985ecd9d4510",
   "metadata": {},
   "source": [
    "### Load textgrids"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "686a865a-c79c-4ede-8dfa-8dc437abc9a4",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# Load TRfiles\n",
    "# trfiles = load_generic_trfiles(allstories,grid_dir=opj(DATA_DIR, \"ds003020/derivative/TextGrids\"))\n",
    "with open(join(DATA_DIR, \"ds003020/derivative/respdict.json\"), \"r\") as f:\n",
    "    respdict = json.load(f)\n",
    "trfiles = load_simulated_trfiles(respdict)\n",
    "# Make word and phoneme datasequences\n",
    "wordseqs = make_word_ds(grids, trfiles) # dictionary of {storyname : word DataSequence}\n",
    "phonseqs = make_phoneme_ds(grids, trfiles) # dictionary of {storyname : phoneme DataSequence}\n",
    "\n",
    "## Create the semantic represenataion of the stories\n",
    "\n",
    "# Project stimuli\n",
    "semanticseqs = dict() # dictionary to hold projected stimuli {story name : projected DataSequence}\n",
    "sentence_semanticseqs = dict()\n",
    "sentences= dict()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "1c0a1bf2-1248-4e7b-ba1e-b235fe73cf6d",
   "metadata": {},
   "outputs": [],
   "source": [
    "def make_semantic_model(dataseq_story, semantic_model):\n",
    "    \"\"\"Embed every word of a story together with its preceding context words.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    dataseq_story : DataSequence\n",
    "        Word-level DataSequence of one story (e.g. wordseqs[story]); ``.data`` holds the words.\n",
    "    semantic_model : object\n",
    "        Provides ``get_word_embedding(word, context)`` returning one vector per word.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    (DataSequence, list of str)\n",
    "        A DataSequence with one embedding row per word, reusing the split indices and\n",
    "        timings of ``dataseq_story``; and, per word, the \"context + word\" string.\n",
    "    \"\"\"\n",
    "    words = dataseq_story.data\n",
    "    embeddings = []\n",
    "    sentences = []\n",
    "\n",
    "    for i in tqdm.trange(len(words)):\n",
    "        word = words[i]\n",
    "        # Clamp the start at 0: for i < context_window the old slice\n",
    "        # words[i - context_window:i] had a negative start and (for any story longer\n",
    "        # than context_window) came back EMPTY, dropping the context the word did have.\n",
    "        context = words[max(0, i - context_window):i]\n",
    "        sentences.append(\" \".join(context) + \" \" + word)\n",
    "\n",
    "        # Actually compute the embedding.\n",
    "        embeddings.append(semantic_model.get_word_embedding(word, context))\n",
    "\n",
    "    embedded = DataSequence(np.stack(embeddings), dataseq_story.split_inds,\n",
    "                            dataseq_story.data_times, dataseq_story.tr_times)\n",
    "    return embedded, sentences"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "180e7bec-e545-4bcb-ba12-4199f82d6e21",
   "metadata": {},
   "outputs": [],
   "source": [
    "# story= allstories[0]\n",
    "# wordseqs[story].data\n",
    "\n",
    "# ds, sent= make_semantic_model(wordseqs[story],semantic_model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "4924d30a-d060-4c8d-8213-05f7a9d046ab",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running wheretheressmoke, 1/1\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████| 1839/1839 [01:10<00:00, 26.05it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Done encoding stories, saving to disk\n",
      "Dictionaries have been saved in data_encoded/S3\n"
     ]
    }
   ],
   "source": [
    "tgt_dir = f\"data_encoded/{sub}\"\n",
    "os.makedirs(tgt_dir, exist_ok=True)\n",
    "semanticseqs_path = os.path.join(tgt_dir, 'semanticseqs_heldout.pkl')\n",
    "sentences_path = os.path.join(tgt_dir, 'sentences_heldout.pkl')\n",
    "\n",
    "if encode_stories:\n",
    "    # Both dictionaries are built here, so this cell does not depend on (possibly\n",
    "    # stale) state from earlier cells: {story: embedded DataSequence}, {story: sentences}.\n",
    "    semanticseqs = dict()\n",
    "    sentences = dict()\n",
    "    for i, story in enumerate(test_stories):\n",
    "        print(f\"Running {story}, {i+1}/{len(test_stories)}\")\n",
    "        semanticseqs[story], sentences[story] = make_semantic_model(wordseqs[story], semantic_model)\n",
    "\n",
    "    print(\"Done encoding stories, saving to disk\")\n",
    "    with open(semanticseqs_path, 'wb') as f:\n",
    "        pickle.dump(semanticseqs, f)\n",
    "    with open(sentences_path, 'wb') as f:\n",
    "        pickle.dump(sentences, f)\n",
    "    print(f'Dictionaries have been saved in {tgt_dir}')\n",
    "else:\n",
    "    # NOTE: pickle.load can execute arbitrary code; only load caches this notebook wrote.\n",
    "    with open(semanticseqs_path, 'rb') as f:\n",
    "        semanticseqs = pickle.load(f)\n",
    "    with open(sentences_path, 'rb') as f:\n",
    "        sentences = pickle.load(f)\n",
    "    print(f'Dictionaries have been loaded from {tgt_dir}')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2b1fa865-4b03-469a-879f-96c098fd94f1",
   "metadata": {},
   "source": [
    "## STOP HERE"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "huggingface",
   "language": "python",
   "name": "huggingface"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
