{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def string_similarity(str1, str2):\n",
    "    # Convert strings to sets of words\n",
    "    words1 = set(str1.lower().split())\n",
    "    words2 = set(str2.lower().split())\n",
    "    \n",
    "    # Intersection of words\n",
    "    intersection = words1.intersection(words2)\n",
    "    \n",
    "    # Union of words\n",
    "    union = words1.union(words2)\n",
    "    \n",
    "    # Calculate Jaccard similarity coefficient\n",
    "    similarity = len(intersection) / len(union)\n",
    "    \n",
    "    return similarity"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "deref_txt = []\n",
    "with open('sentences_dereferenced.txt', 'r') as f:\n",
    "    for sent in f:\n",
    "        deref_txt.append(sent.strip())\n",
    "        \n",
    "ordered_txt = []\n",
    "with open('sentences_ordered.txt', 'r') as f:\n",
    "    for sent in f:\n",
    "        ordered_txt.append(sent.strip())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "glove_384 = np.loadtxt('/home3/name/what-is-brainscore/glove_data/glove_384.txt')\n",
    "glove_243 = np.loadtxt('/home3/name/what-is-brainscore/glove_data/glove_243.txt')\n",
    "\n",
    "glove_all = np.vstack((glove_243, glove_384))\n",
    "glove_all.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# let's see how well this string similarity func works \n",
    "sent = deref_txt[0]\n",
    "sim_store = []\n",
    "for idx, s in enumerate(ordered_txt):\n",
    "    sim = string_similarity(sent, s)\n",
    "    sim_store.append(sim)\n",
    "print(sent, ordered_txt[np.argmax(sim_store)])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "rearrange_idxs = []\n",
    "for osent in ordered_txt:\n",
    "    sim_store = []\n",
    "    for dsent in deref_txt:\n",
    "        sim = string_similarity(dsent, osent)\n",
    "        sim_store.append(sim)\n",
    "    # store the index of the dereferenced sentence that best matches \n",
    "    # the ordered sent \n",
    "    rearrange_idxs.append(np.argmax(sim_store))\n",
    "    # if match is not exact, print to manually make sure the match is correct.\n",
    "    if np.max(sim_store) < 1:\n",
    "        #print(osent, deref_txt[np.argmax(sim_store)])\n",
    "        pass"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "glove_correct_order = glove_all[rearrange_idxs]\n",
    "glove_old = np.load('/home3/name/what-is-brainscore/data_processed/pereira/X_glove_content.npz')['layer1']\n",
    "np.array_equal(glove_correct_order, glove_old)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "deref_txt_ordered = list(np.array(deref_txt)[rearrange_idxs])\n",
    "with open('sentences_ordered_dereferenced.txt', \"w\") as file:\n",
    "            for item in deref_txt_ordered:\n",
    "                file.write(str(item) + \"\\n\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "deref_txt = []\n",
    "with open('sentences_ordered_dereferenced.txt', 'r') as f:\n",
    "    for sent in f:\n",
    "        deref_txt.append(sent.strip())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "narrative_idxs = [[8,9,10], [19, 20, 21], [22, 23, 24, 25], [36, 37, 38, 39], \n",
    "                  [50, 51, 52], [57, 58, 59], [63, 64, 65, 66], [73, 74, 75, 76], \n",
    "                  [83, 84, 85, 86], [96, 97, 98, 99], [107, 108, 109], [113, 114, 115, 116], \n",
    "                  [123, 124, 125, 126], [140, 141, 142, 143], [151, 152, 153], [161, 162, 163], \n",
    "                  [171, 172, 173], [174, 175, 176, 177], [184, 185, 186, 187], \n",
    "                  [197, 198, 199, 200], [205, 206, 207, 208], [222, 223, 224], \n",
    "                  [231, 232, 233], [234, 235, 236, 237]]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data_labels = np.load('/home3/name/what-is-brainscore/temp_data_all/data_labels_pereira.npy')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "narrative_indicator = np.zeros(627)\n",
    "for ni in narrative_idxs:\n",
    "    ni = [n-1 for n in ni]\n",
    "    narrative_indicator[ni] = 1\n",
    "    #print(np.array(deref_txt)[ni])\n",
    "    if len(np.unique(data_labels[ni])) != 1:\n",
    "        print(\"NOOOOOOO\")\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "np.savez('/home3/name/what-is-brainscore/temp_data_all/temp_data_pereira/X_narrative', \n",
    "        **{'layer1': narrative_indicator})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "llama",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.4"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
