{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "===================================BUG REPORT===================================\n",
      "Welcome to bitsandbytes. For bug reports, please run\n",
      "\n",
      "python -m bitsandbytes\n",
      "\n",
      " and submit this information together with your error trace to: https://github.com/TimDettmers/bitsandbytes/issues\n",
      "================================================================================\n",
      "bin /home//venv/lib/python3.10/site-packages/bitsandbytes/libbitsandbytes_cpu.so\n",
      "CUDA_SETUP: WARNING! libcudart.so not found in any environmental path. Searching in backup paths...\n",
      "CUDA SETUP: Highest compute capability among GPUs detected: 8.0\n",
      "CUDA SETUP: Detected CUDA version 117\n",
      "CUDA SETUP: Loading binary /home//venv/lib/python3.10/site-packages/bitsandbytes/libbitsandbytes_cpu.so...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home//venv/lib/python3.10/site-packages/bitsandbytes/cuda_setup/main.py:145: UserWarning: WARNING: The following directories listed in your path were found to be non-existent: {PosixPath('http'), PosixPath('//xen03.iitd.ac.in'), PosixPath('3128')}\n",
      "  warn(msg)\n",
      "/home//venv/lib/python3.10/site-packages/bitsandbytes/cuda_setup/main.py:145: UserWarning: WARNING: The following directories listed in your path were found to be non-existent: {PosixPath('vs/workbench/api/node/extensionHostProcess')}\n",
      "  warn(msg)\n",
      "/home//venv/lib/python3.10/site-packages/bitsandbytes/cuda_setup/main.py:145: UserWarning: WARNING: The following directories listed in your path were found to be non-existent: {PosixPath('module'), PosixPath('//matplotlib_inline.backend_inline')}\n",
      "  warn(msg)\n",
      "/home//venv/lib/python3.10/site-packages/bitsandbytes/cuda_setup/main.py:145: UserWarning: WARNING: The following directories listed in your path were found to be non-existent: {PosixPath('/usr/local/cuda/lib64')}\n",
      "  warn(msg)\n",
      "/home//venv/lib/python3.10/site-packages/bitsandbytes/cuda_setup/main.py:145: UserWarning: WARNING: No libcudart.so found! Install CUDA or the cudatoolkit package (anaconda)!\n",
      "  warn(msg)\n"
     ]
    }
   ],
   "source": [
    "from transformers import AutoModelForCausalLM, AutoTokenizer, LlamaTokenizerFast\n",
    "import torch\n",
    "import sys\n",
    "import os\n",
    "from peft import PeftModel, PeftConfig\n",
    "from peft import LoraConfig, TaskType, get_peft_model\n",
    "\n",
    "import transformer_lens\n",
    "import transformer_lens.utils as utils\n",
    "from transformer_lens.hook_points import (\n",
    "    HookedRootModule,\n",
    "    HookPoint,\n",
    ")  # Hooking utilities\n",
    "from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache\n",
    "from transformers import LlamaForCausalLM, LlamaTokenizer\n",
    "\n",
    "\n",
    "sys.path.append('../../JS')\n",
    "from honest_llama import llama"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "b92298051627489fa95eb79cf68b4a88",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home//venv/lib/python3.10/site-packages/transformers/generation/configuration_utils.py:381: UserWarning: `do_sample` is set to `False`. However, `temperature` is set to `0.9` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `temperature`. This was detected when initializing the generation config instance, which means the corresponding file may hold incorrect parameterization and should be fixed.\n",
      "  warnings.warn(\n",
      "/home//venv/lib/python3.10/site-packages/transformers/generation/configuration_utils.py:386: UserWarning: `do_sample` is set to `False`. However, `top_p` is set to `0.6` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `top_p`. This was detected when initializing the generation config instance, which means the corresponding file may hold incorrect parameterization and should be fixed.\n",
      "  warnings.warn(\n"
     ]
    }
   ],
   "source": [
    "# Runtime configuration. Each value can be overridden with an environment variable so the\n",
    "# notebook runs on another machine without editing this cell; the defaults are the values\n",
    "# that used to be hard-coded here.\n",
    "path_to_model = os.environ.get('LLAMA_MODEL_PATH', '/home/models//Llama-2-7b-hf')\n",
    "path_to_tokenizer = os.environ.get('LLAMA_TOKENIZER_PATH', '/home/models//Llama-2-7b-hf')\n",
    "device = os.environ.get('LLAMA_DEVICE', 'cuda:0')\n",
    "\n",
    "tokenizer = LlamaTokenizer.from_pretrained(path_to_tokenizer)\n",
    "# `llama` is the honest_llama module imported in the first cell.\n",
    "model = llama.LLaMAForCausalLM.from_pretrained(path_to_model, low_cpu_mem_usage=True).to(device)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "# Pass the encoding explicitly so the load does not depend on the platform's\n",
    "# default locale encoding (JSON files are normally UTF-8).\n",
    "with open('data/2shot10k.json', 'r', encoding='utf-8') as f:\n",
    "    data = json.load(f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'question': 'Rompuses are sterpuses. Rompuses are zumpuses. Rompuses are not medium. Grimpuses are rompuses. Impuses are medium. Grimpuses are tumpuses. Grimpuses are shy. Each zumpus is temperate. Each tumpus is sunny. Each numpus is a lorpus. Numpuses are liquid. Sally is a grimpus. Sally is a numpus.', 'query': 'True or false: Sally is not medium.', 'answer': 'Sally is a grimpus. Grimpuses are rompuses. Sally is a rompus. Rompuses are not medium. Sally is not medium. True', 'label': 'True', 'tree': '(sterpus\\n  (rompus properties: not medium\\n    (grimpus properties: shy\\n    )\\n  )\\n)\\n(zumpus properties: temperate\\n  (rompus properties: not medium\\n    (grimpus properties: shy\\n    )\\n  )\\n)\\n(tumpus properties: sunny\\n  (grimpus properties: shy\\n  )\\n)\\n(lorpus\\n  (numpus properties: liquid\\n  )\\n)\\n'}\n"
     ]
    }
   ],
   "source": [
    "print(data[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "# Lookup table consulted by preprocessing() below (presumably plural form -> noun; confirm\n",
    "# against the pickle's contents).\n",
    "# NOTE(review): pickle.load can execute arbitrary code stored in the file, so only load a\n",
    "# trusted, locally generated pickle here.\n",
    "with open('data/pluralToNouns.pickle', 'rb') as handle:\n",
    "    pluralToNouns = pickle.load(handle)\n",
    "def preprocessing(word):\n",
    "    \"\"\"Lower-case ``word`` and replace it by its ``pluralToNouns`` entry when one exists.\n",
    "\n",
    "    Words without an entry are returned lower-cased and otherwise unchanged.\n",
    "    \"\"\"\n",
    "    lowered = word.lower()\n",
    "    return pluralToNouns[lowered] if lowered in pluralToNouns else lowered"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _word_end(tokens, pos):\n",
    "    \"\"\"Return the index of the last token of the word being scanned from ``pos``.\n",
    "\n",
    "    Walks forward until a token that begins with '▁' or a '.' token is met and\n",
    "    returns the index just before it. When no such token exists the scan runs\n",
    "    off the end and ``len(tokens)`` is returned; the caller clamps it.\n",
    "    \"\"\"\n",
    "    while pos < len(tokens):\n",
    "        if tokens[pos].startswith('▁') or tokens[pos] == '.':\n",
    "            return pos - 1\n",
    "        pos += 1\n",
    "    return pos\n",
    "\n",
    "\n",
    "def _negation_start(tokens, pos):\n",
    "    \"\"\"Return the index of the nearest '▁not' token at or before ``pos``, else None.\n",
    "\n",
    "    The backward scan gives up at a '.' token, and index 0 is never inspected.\n",
    "    \"\"\"\n",
    "    while pos > 0:\n",
    "        if tokens[pos] == '.':\n",
    "            return None\n",
    "        if tokens[pos] == '▁not':\n",
    "            return pos\n",
    "        pos -= 1\n",
    "    return None\n",
    "\n",
    "\n",
    "def tokenize_sentence_with_parent(sentence, tokenizer):\n",
    "    \"\"\"Tokenize ``sentence`` and label every token with the word it belongs to.\n",
    "\n",
    "    Args:\n",
    "        sentence: Text to tokenize.\n",
    "        tokenizer: Object providing ``tokenize(text)`` and\n",
    "            ``convert_tokens_to_string(tokens)``. Tokens that open a new word are\n",
    "            expected to begin with '▁'.\n",
    "\n",
    "    Returns:\n",
    "        ``(tokenized_sentence, parent_words)``, two lists of equal length.\n",
    "        ``parent_words[i]`` is the word token ``i`` is part of, passed through\n",
    "        ``preprocessing``. '.' tokens map to '.'. When a '▁not' token precedes the\n",
    "        word inside the same sentence, the parent text starts at that '▁not'\n",
    "        (e.g. 'notmedium'); this does not apply to a word-start token that is\n",
    "        directly followed by another word-start token.\n",
    "    \"\"\"\n",
    "    tokenized_sentence = tokenizer.tokenize(sentence)\n",
    "    n_tokens = len(tokenized_sentence)\n",
    "    parent_words = []\n",
    "\n",
    "    for i, token in enumerate(tokenized_sentence):\n",
    "        if token == '.':\n",
    "            parent_words.append('.')\n",
    "            continue\n",
    "\n",
    "        starts_word = token.startswith('▁')\n",
    "        is_last = i == n_tokens - 1\n",
    "        if starts_word and (is_last or tokenized_sentence[i + 1].startswith('▁')):\n",
    "            # Single-token word: the token itself without the marker.\n",
    "            parent_words.append(token.replace('▁', ''))\n",
    "            continue\n",
    "\n",
    "        if starts_word:\n",
    "            # Word-start token followed by a continuation piece or by '.'.\n",
    "            start = i\n",
    "            end = _word_end(tokenized_sentence, i + 1)\n",
    "        else:\n",
    "            # Continuation piece: walk back to the token that opened the word\n",
    "            # (or to a '.' token, or to index 0).\n",
    "            start = i\n",
    "            while start > 0 and not (\n",
    "                tokenized_sentence[start].startswith('▁') or tokenized_sentence[start] == '.'\n",
    "            ):\n",
    "                start -= 1\n",
    "            end = _word_end(tokenized_sentence, i)\n",
    "\n",
    "        negation = _negation_start(tokenized_sentence, i)\n",
    "        if negation is not None:\n",
    "            start = negation\n",
    "\n",
    "        # `end` equals n_tokens when the sentence does not finish with '.'; clamp it to\n",
    "        # the last valid index. The previous guard (`j < n_tokens - 1`) was off by one\n",
    "        # and silently dropped the final token of such a sentence.\n",
    "        end = min(end, n_tokens - 1)\n",
    "        parent_words.append(''.join(\n",
    "            tokenizer.convert_tokens_to_string([tokenized_sentence[j]])\n",
    "            for j in range(start, end + 1)\n",
    "        ))\n",
    "\n",
    "    parent_words = [preprocessing(word) for word in parent_words]\n",
    "    return tokenized_sentence, parent_words"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "tokenized_sentence, parent_words = tokenize_sentence_with_parent(data[0]['question'], tokenizer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "▁Rom rompus\n",
      "p rompus\n",
      "uses rompus\n",
      "▁are are\n",
      "▁ster sterpus\n",
      "p sterpus\n",
      "uses sterpus\n",
      ". .\n",
      "▁Rom rompus\n",
      "p rompus\n",
      "uses rompus\n",
      "▁are are\n",
      "▁zum zumpus\n",
      "p zumpus\n",
      "uses zumpus\n",
      ". .\n",
      "▁Rom rompus\n",
      "p rompus\n",
      "uses rompus\n",
      "▁are are\n",
      "▁not not\n",
      "▁medium notmedium\n",
      ". .\n",
      "▁Gr grimpus\n",
      "imp grimpus\n",
      "uses grimpus\n",
      "▁are are\n",
      "▁rom rompus\n",
      "p rompus\n",
      "uses rompus\n",
      ". .\n",
      "▁Imp impus\n",
      "uses impus\n",
      "▁are are\n",
      "▁medium medium\n",
      ". .\n",
      "▁Gr grimpus\n",
      "imp grimpus\n",
      "uses grimpus\n",
      "▁are are\n",
      "▁t tumpus\n",
      "ump tumpus\n",
      "uses tumpus\n",
      ". .\n",
      "▁Gr grimpus\n",
      "imp grimpus\n",
      "uses grimpus\n",
      "▁are are\n",
      "▁sh shy\n",
      "y shy\n",
      ". .\n",
      "▁Each each\n",
      "▁zum zumpus\n",
      "pus zumpus\n",
      "▁is is\n",
      "▁temper temperate\n",
      "ate temperate\n",
      ". .\n",
      "▁Each each\n",
      "▁t tumpus\n",
      "ump tumpus\n",
      "us tumpus\n",
      "▁is is\n",
      "▁sun sunny\n",
      "ny sunny\n",
      ". .\n",
      "▁Each each\n",
      "▁num numpus\n",
      "pus numpus\n",
      "▁is is\n",
      "▁a a\n",
      "▁l lorpus\n",
      "or lorpus\n",
      "pus lorpus\n",
      ". .\n",
      "▁N numpus\n",
      "ump numpus\n",
      "uses numpus\n",
      "▁are are\n",
      "▁liquid liquid\n",
      ". .\n",
      "▁S sally\n",
      "ally sally\n",
      "▁is is\n",
      "▁a a\n",
      "▁gr grimpus\n",
      "imp grimpus\n",
      "us grimpus\n",
      ". .\n",
      "▁S sally\n",
      "ally sally\n",
      "▁is is\n",
      "▁a a\n",
      "▁num numpus\n",
      "pus numpus\n",
      ". .\n"
     ]
    }
   ],
   "source": [
    "for i in range(len(parent_words)):\n",
    "    print(tokenized_sentence[i], parent_words[i])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "class TreeNode:\n",
    "    \"\"\"A named tree node; a node's property is stored as one of its children.\"\"\"\n",
    "\n",
    "    def __init__(self, name, children=None):\n",
    "        self.name = name\n",
    "        # Not assigned by parseTreeString; stays None unless a caller sets it.\n",
    "        self.parentName = None\n",
    "        self.children = children if children else []\n",
    "\n",
    "\n",
    "def parseTreeString(tree_structure):\n",
    "    \"\"\"Parse one parenthesised, line-oriented tree description into TreeNode objects.\n",
    "\n",
    "    Expected layout: one node per line, each closed by a later line holding ')':\n",
    "\n",
    "        (sterpus\n",
    "          (rompus properties: not medium\n",
    "            (grimpus properties: shy\n",
    "            )\n",
    "          )\n",
    "        )\n",
    "\n",
    "    The outermost '(' or ')' may be missing, which is what remains after several\n",
    "    trees have been split apart on their ')' + newline + '(' separator.\n",
    "\n",
    "    A line 'name properties: value' creates a node ``name`` whose first child is a\n",
    "    node named after ``value`` with all spaces removed (e.g. 'notmedium').\n",
    "\n",
    "    Args:\n",
    "        tree_structure: Text describing a single tree.\n",
    "\n",
    "    Returns:\n",
    "        The first top-level TreeNode found in the text.\n",
    "\n",
    "    Raises:\n",
    "        IndexError: if the text does not contain any node.\n",
    "    \"\"\"\n",
    "    def extractPropertyAndNode(line):\n",
    "        \"\"\"Split one line into ``(node_name, property_name)``; the property may be None.\"\"\"\n",
    "        property_name = None\n",
    "        if 'properties:' in line:\n",
    "            property_name = line.split(':')[-1].strip().replace(' ', '')\n",
    "            node_name = line.split('properties:')[0]\n",
    "        else:\n",
    "            node_name = line\n",
    "        # Drop parentheses before trimming so '( name' does not keep a leading blank.\n",
    "        node_name = node_name.replace('(', '').replace(')', '').strip()\n",
    "        return node_name, property_name\n",
    "\n",
    "    root = TreeNode('root')\n",
    "    # Path from the synthetic root down to the node that is currently open.\n",
    "    open_nodes = [root]\n",
    "\n",
    "    for raw_line in tree_structure.split('\\n'):\n",
    "        closing = raw_line.count(')')\n",
    "        node_name, property_name = extractPropertyAndNode(raw_line.strip('()'))\n",
    "\n",
    "        if node_name != '':\n",
    "            node = TreeNode(node_name)\n",
    "            if property_name:\n",
    "                node.children.append(TreeNode(property_name))\n",
    "            open_nodes[-1].children.append(node)\n",
    "            open_nodes.append(node)\n",
    "\n",
    "        # Every ')' closes the innermost open node, so that the next node becomes a\n",
    "        # sibling rather than a child. Previously closing lines were ignored and any\n",
    "        # tree with siblings was parsed as a single chain. The synthetic root is\n",
    "        # never popped.\n",
    "        for _ in range(closing):\n",
    "            if len(open_nodes) > 1:\n",
    "                open_nodes.pop()\n",
    "\n",
    "    return root.children[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "entry = 0\n",
    "tree_structure = data[entry]['tree']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "structures = tree_structure.split(')\\n(')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "root_node = parseTreeString(structures[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(sterpus\n",
      "  (rompus properties: not medium\n",
      "    (grimpus properties: shy\n",
      "    )\n",
      "  )\n",
      "\n"
     ]
    }
   ],
   "source": [
    "print(structures[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'notmedium'"
      ]
     },
     "execution_count": 34,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "root_node.children[0].children[0].name"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_nodes(root):\n",
    "    \"\"\"Return every node reachable from ``root`` in depth-first order.\n",
    "\n",
    "    An explicit stack is used, so among siblings the last child is visited first.\n",
    "    \"\"\"\n",
    "    visited = []\n",
    "    pending = [root]\n",
    "    while len(pending) > 0:\n",
    "        current = pending.pop()\n",
    "        visited.append(current)\n",
    "        for child in current.children:\n",
    "            pending.append(child)\n",
    "    return visited\n",
    "\n",
    "\n",
    "def fill_adjacency_matrix(node, adjacency_matrix, node_indices):\n",
    "    \"\"\"Recursively mark every parent -> child edge below ``node`` with 1.\n",
    "\n",
    "    ``node_indices`` maps each node object to its row/column in\n",
    "    ``adjacency_matrix``, which is modified in place.\n",
    "    \"\"\"\n",
    "    parent_position = node_indices[node]\n",
    "    for kid in node.children:\n",
    "        adjacency_matrix[parent_position][node_indices[kid]] = 1\n",
    "        fill_adjacency_matrix(kid, adjacency_matrix, node_indices)\n",
    "\n",
    "def create_adjacency_matrix(root):\n",
    "    \"\"\"Build the parent -> child adjacency matrix of the tree under ``root``.\n",
    "\n",
    "    Rows and columns follow the order returned by ``get_nodes(root)``; entry\n",
    "    ``[r][c]`` is 1 when node ``c`` is a child of node ``r`` and 0 otherwise.\n",
    "    \"\"\"\n",
    "    ordered = get_nodes(root)\n",
    "    size = len(ordered)\n",
    "    position_of = dict(zip(ordered, range(size)))\n",
    "    grid = [[0 for _ in range(size)] for _ in range(size)]\n",
    "    fill_adjacency_matrix(root, grid, position_of)\n",
    "    return grid"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [],
   "source": [
    "adjacency_matrix = create_adjacency_matrix(root_node)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [],
   "source": [
    "nodes = get_nodes(root_node)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "sterpus\n",
      "rompus\n",
      "grimpus\n",
      "shy\n",
      "notmedium\n"
     ]
    }
   ],
   "source": [
    "node_name = []\n",
    "for node in nodes:\n",
    "    print(node.name)\n",
    "    node_name.append(node.name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['sterpus', 'rompus', 'grimpus', 'shy', 'notmedium']"
      ]
     },
     "execution_count": 42,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "node_name"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[[0, 1, 0, 0, 0],\n",
       " [0, 0, 1, 0, 1],\n",
       " [0, 0, 0, 1, 0],\n",
       " [0, 0, 0, 0, 0],\n",
       " [0, 0, 0, 0, 0]]"
      ]
     },
     "execution_count": 43,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "adjacency_matrix"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 88,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(sterpus\n",
      "  (rompus properties: not medium\n",
      "    (grimpus properties: shy\n",
      "    )\n",
      "  )\n",
      ")\n",
      "(zumpus properties: temperate\n",
      "  (rompus properties: not medium\n",
      "    (grimpus properties: shy\n",
      "    )\n",
      "  )\n",
      ")\n",
      "(tumpus properties: sunny\n",
      "  (grimpus properties: shy\n",
      "  )\n",
      ")\n",
      "(lorpus\n",
      "  (numpus properties: liquid\n",
      "  )\n",
      ")\n",
      "\n"
     ]
    }
   ],
   "source": [
    "print(tree_structure)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One entry per top-level tree in `tree_structure`, kept in three parallel lists:\n",
    "# the tree's root node, its adjacency matrix and its flat node list.\n",
    "root_nodes, adjacency_matrixs, allNodes = [], [], []\n",
    "structures = tree_structure.split(')\\n(')\n",
    "for structure in structures:\n",
    "    root_node = parseTreeString(structure)\n",
    "    adjacency_matrix = create_adjacency_matrix(root_node)\n",
    "    allNode = get_nodes(root_node)\n",
    "    root_nodes += [root_node]\n",
    "    adjacency_matrixs += [adjacency_matrix]\n",
    "    allNodes += [allNode]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "metadata": {},
   "outputs": [],
   "source": [
    "tokenized_sentence, parent_words = tokenize_sentence_with_parent(data[entry]['question'], tokenizer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "metadata": {},
   "outputs": [],
   "source": [
    "def findIndexInMatrix(nodes, word, parent_word):\n",
    "    \"\"\"Return the position of the first node named ``parent_word``, or None if absent.\n",
    "\n",
    "    ``word`` is accepted for call compatibility but is not used.\n",
    "    \"\"\"\n",
    "    matches = (position for position, node in enumerate(nodes) if node.name == parent_word)\n",
    "    return next(matches, None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 72,
   "metadata": {},
   "outputs": [],
   "source": [
    "n = len(tokenized_sentence)\n",
    "# matrix[i][j] == 1 when the parent words of tokens i and j are directly linked (in\n",
    "# either direction) in at least one tree. The diagonal is left at 0.\n",
    "matrix = [[0] * n for _ in range(n)]\n",
    "\n",
    "# Resolve each token's node index once per tree instead of repeating the linear search\n",
    "# inside the O(n^2) pair loop below.\n",
    "token_node_index = [\n",
    "    [findIndexInMatrix(tree_nodes, tokenized_sentence[t], parent_words[t]) for t in range(n)]\n",
    "    for tree_nodes in allNodes\n",
    "]\n",
    "\n",
    "for i in range(n):\n",
    "    for j in range(i + 1, n):\n",
    "        for tree_adjacency, node_index in zip(adjacency_matrixs, token_node_index):\n",
    "            firstIndex, secondIndex = node_index[i], node_index[j]\n",
    "            if firstIndex is None or secondIndex is None:\n",
    "                continue\n",
    "            if tree_adjacency[firstIndex][secondIndex] == 1 or tree_adjacency[secondIndex][firstIndex] == 1:\n",
    "                matrix[i][j] = 1\n",
    "                matrix[j][i] = 1\n",
    "                break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Flatten the per-tree node lists into one list of node names (duplicates are kept).\n",
    "allNodeEntities = []\n",
    "for nodes in allNodes:\n",
    "    allNodeEntities.extend(node.name for node in nodes)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['sterpus',\n",
       " 'rompus',\n",
       " 'grimpus',\n",
       " 'shy',\n",
       " 'notmedium',\n",
       " 'zumpus',\n",
       " 'rompus',\n",
       " 'grimpus',\n",
       " 'shy',\n",
       " 'notmedium',\n",
       " 'temperate',\n",
       " 'tumpus',\n",
       " 'grimpus',\n",
       " 'shy',\n",
       " 'sunny',\n",
       " 'lorpus',\n",
       " 'numpus',\n",
       " 'liquid']"
      ]
     },
     "execution_count": 53,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "allNodeEntities"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 56,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "96"
      ]
     },
     "execution_count": 56,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(matrix)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 60,
   "metadata": {},
   "outputs": [],
   "source": [
    "def getPositiveAndNegativePair(tokenized_sentence, parent_words, matrix, typeA, typeB):\n",
    "    \"\"\"Collect related token pairs, split by whether the second word is negated.\n",
    "\n",
    "    For every anchor token ``i`` the tokens ``j >= i`` up to (not including) the next\n",
    "    '.' token are scanned and ``(i, j)`` is recorded when ``matrix[i][j] == 1``. The\n",
    "    pair goes to the negative list when ``parent_words[j]`` contains the substring\n",
    "    'not', otherwise to the positive list.\n",
    "\n",
    "    Args:\n",
    "        tokenized_sentence: Tokens of the sentence.\n",
    "        parent_words: Parent word of each token (same length as the tokens).\n",
    "        matrix: Square token-by-token relation matrix; 1 marks related tokens.\n",
    "        typeA: 'first' keeps only anchors that are the first token of their word,\n",
    "            'last' only anchors that are the last token; any other value keeps all.\n",
    "        typeB: 'first' records only the first related ``j`` per anchor, 'last'\n",
    "            records the last token of the first related run; any other value\n",
    "            records every related ``j``.\n",
    "\n",
    "    Returns:\n",
    "        ``(positivePair, negativePair)``, two lists of ``(i, j)`` index tuples.\n",
    "    \"\"\"\n",
    "    positivePair = []\n",
    "    negativePair = []\n",
    "    n_tokens = len(tokenized_sentence)\n",
    "\n",
    "    def record(i, j):\n",
    "        # Substring test: any parent word containing 'not' counts as negated.\n",
    "        target = negativePair if 'not' in parent_words[j] else positivePair\n",
    "        target.append((i, j))\n",
    "\n",
    "    for i in range(n_tokens):\n",
    "        # `i > 0` (previously `i >= 0`): token 0 must not be compared with\n",
    "        # parent_words[-1], which negative indexing resolves to the last token.\n",
    "        if typeA == 'first' and i > 0 and parent_words[i - 1] == parent_words[i]:\n",
    "            continue\n",
    "        if typeA == 'last' and i < n_tokens - 1 and parent_words[i] == parent_words[i + 1]:\n",
    "            continue\n",
    "\n",
    "        for j in range(i, n_tokens):\n",
    "            if tokenized_sentence[j] == '.':\n",
    "                break\n",
    "            if matrix[i][j] != 1:\n",
    "                continue\n",
    "            if typeB == 'last':\n",
    "                # Keep scanning while the next token is still related. The final token\n",
    "                # of the input ends the run by definition (it used to be skipped).\n",
    "                if j < n_tokens - 1 and matrix[i][j + 1] == 1:\n",
    "                    continue\n",
    "                record(i, j)\n",
    "                break\n",
    "            record(i, j)\n",
    "            if typeB == 'first':\n",
    "                break\n",
    "    return positivePair, negativePair"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 66,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['rompus', 'rompus', 'rompus', 'are', 'sterpus', 'sterpus', 'sterpus', '.', 'rompus', 'rompus', 'rompus', 'are', 'zumpus', 'zumpus', 'zumpus', '.', 'rompus', 'rompus', 'rompus', 'are', 'not', 'notmedium', '.', 'grimpus', 'grimpus', 'grimpus', 'are', 'rompus', 'rompus', 'rompus', '.', 'impus', 'impus', 'are', 'medium', '.', 'grimpus', 'grimpus', 'grimpus', 'are', 'tumpus', 'tumpus', 'tumpus', '.', 'grimpus', 'grimpus', 'grimpus', 'are', 'shy', 'shy', '.', 'each', 'zumpus', 'zumpus', 'is', 'temperate', 'temperate', '.', 'each', 'tumpus', 'tumpus', 'tumpus', 'is', 'sunny', 'sunny', '.', 'each', 'numpus', 'numpus', 'is', 'a', 'lorpus', 'lorpus', 'lorpus', '.', 'numpus', 'numpus', 'numpus', 'are', 'liquid', '.', 'sally', 'sally', 'is', 'a', 'grimpus', 'grimpus', 'grimpus', '.', 'sally', 'sally', 'is', 'a', 'numpus', 'numpus', '.']\n",
      "['▁Rom', 'p', 'uses', '▁are', '▁ster', 'p', 'uses', '.', '▁Rom', 'p', 'uses', '▁are', '▁zum', 'p', 'uses', '.', '▁Rom', 'p', 'uses', '▁are', '▁not', '▁medium', '.', '▁Gr', 'imp', 'uses', '▁are', '▁rom', 'p', 'uses', '.', '▁Imp', 'uses', '▁are', '▁medium', '.', '▁Gr', 'imp', 'uses', '▁are', '▁t', 'ump', 'uses', '.', '▁Gr', 'imp', 'uses', '▁are', '▁sh', 'y', '.', '▁Each', '▁zum', 'pus', '▁is', '▁temper', 'ate', '.', '▁Each', '▁t', 'ump', 'us', '▁is', '▁sun', 'ny', '.', '▁Each', '▁num', 'pus', '▁is', '▁a', '▁l', 'or', 'pus', '.', '▁N', 'ump', 'uses', '▁are', '▁liquid', '.', '▁S', 'ally', '▁is', '▁a', '▁gr', 'imp', 'us', '.', '▁S', 'ally', '▁is', '▁a', '▁num', 'pus', '.']\n"
     ]
    }
   ],
   "source": [
    "print(parent_words)\n",
    "print(tokenized_sentence)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 62,
   "metadata": {},
   "outputs": [],
   "source": [
    "positivePair, negativePair = getPositiveAndNegativePair(tokenized_sentence, parent_words, matrix, 'first', 'first')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 69,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'Rompuses are sterpuses. Rompuses are zumpuses. Rompuses are not medium. Grimpuses are rompuses. Impuses are medium. Grimpuses are tumpuses. Grimpuses are shy. Each zumpus is temperate. Each tumpus is sunny. Each numpus is a lorpus. Numpuses are liquid. Sally is a grimpus. Sally is a numpus.'"
      ]
     },
     "execution_count": 69,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data[entry]['question']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 85,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'▁medium'"
      ]
     },
     "execution_count": 85,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tokenized_sentence[34]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 86,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0"
      ]
     },
     "execution_count": 86,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "matrix[31][34]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 74,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(0, 4)\n",
      "rompus sterpus\n",
      "(8, 12)\n",
      "rompus zumpus\n",
      "(23, 27)\n",
      "grimpus rompus\n",
      "(36, 40)\n",
      "grimpus tumpus\n",
      "(44, 48)\n",
      "grimpus shy\n",
      "(52, 55)\n",
      "zumpus temperate\n",
      "(59, 63)\n",
      "tumpus sunny\n",
      "(67, 71)\n",
      "numpus lorpus\n",
      "(75, 79)\n",
      "numpus liquid\n"
     ]
    }
   ],
   "source": [
    "for pair in positivePair:\n",
    "    print(pair)\n",
    "    print(parent_words[pair[0]], parent_words[pair[1]])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 70,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "rompus notmedium\n"
     ]
    }
   ],
   "source": [
    "for pair in negativePair:\n",
    "    print(parent_words[pair[0]], parent_words[pair[1]])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 73,
   "metadata": {},
   "outputs": [],
   "source": [
    "def getUnRelatedPair(tokenized_sentence, parent_words, adjacency_matrix, allNodeEntities, typeA, typeB):\n",
    "    \"\"\"Collect token pairs whose words are both known entities but are not connected.\n",
    "\n",
    "    A pair ``(i, j)`` with ``j >= i`` qualifies when the two parent words differ, both\n",
    "    occur in ``allNodeEntities`` and ``adjacency_matrix[i][j] == float('inf')``.\n",
    "\n",
    "    Args:\n",
    "        tokenized_sentence: Tokens of the sentence.\n",
    "        parent_words: Parent word of each token (same length as the tokens).\n",
    "        adjacency_matrix: Square token-by-token matrix in which ``inf`` marks an\n",
    "            unconnected pair. NOTE(review): presumably a distance matrix derived\n",
    "            from the token relation matrix - confirm with the caller.\n",
    "        allNodeEntities: Iterable of the node names present in the trees.\n",
    "        typeA: 'first' keeps only anchors that are the first token of their word,\n",
    "            'last' only anchors that are the last token; any other value keeps all.\n",
    "        typeB: 'first' records only the first qualifying ``j`` per anchor, 'last'\n",
    "            records only the first qualifying ``j`` that is the last token of its\n",
    "            word; any other value records every qualifying ``j``.\n",
    "\n",
    "    Returns:\n",
    "        List of ``(i, j)`` index tuples.\n",
    "    \"\"\"\n",
    "    n_tokens = len(tokenized_sentence)\n",
    "    entities = set(allNodeEntities)  # O(1) membership tests inside the pair loop\n",
    "    unreachable = float('inf')\n",
    "    unrelatedPair = []\n",
    "\n",
    "    def is_word_start(k):\n",
    "        return k == 0 or parent_words[k - 1] != parent_words[k]\n",
    "\n",
    "    def is_word_end(k):\n",
    "        return k == n_tokens - 1 or parent_words[k] != parent_words[k + 1]\n",
    "\n",
    "    # The whole range is scanned: token 0 is always a word start and the final token is\n",
    "    # always a word end. They used to be skipped only to avoid indexing out of range.\n",
    "    for i in range(n_tokens):\n",
    "        if typeA == 'first' and not is_word_start(i):\n",
    "            continue\n",
    "        if typeA == 'last' and not is_word_end(i):\n",
    "            continue\n",
    "        if parent_words[i] not in entities:\n",
    "            continue\n",
    "\n",
    "        for j in range(i, n_tokens):\n",
    "            if parent_words[j] == parent_words[i] or parent_words[j] not in entities:\n",
    "                continue\n",
    "            if adjacency_matrix[i][j] != unreachable:\n",
    "                continue\n",
    "            if typeB == 'last' and not is_word_end(j):\n",
    "                continue\n",
    "            unrelatedPair.append((i, j))\n",
    "            if typeB == 'first' or typeB == 'last':\n",
    "                break\n",
    "    return unrelatedPair"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 89,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "def adjacency_to_distance(adj_matrix):\n",
    "    num_vertices = len(adj_matrix)\n",
    "    distance_matrix = np.full((num_vertices, num_vertices), float('inf'))\n",
    "\n",
    "    for i in range(num_vertices):\n",
    "        distance_matrix[i][i] = 0\n",
    "\n",
    "    for i in range(num_vertices):\n",
    "        for j in range(num_vertices):\n",
    "            if adj_matrix[i][j] != 0:\n",
    "                distance_matrix[i][j] = adj_matrix[i][j]\n",
    "\n",
    "    for k in range(num_vertices):\n",
    "        distance_matrix = np.minimum(distance_matrix, distance_matrix[:, k:k+1] + distance_matrix[k:k+1, :])\n",
    "\n",
    "    return distance_matrix.tolist()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 90,
   "metadata": {},
   "outputs": [],
   "source": [
    "distanceMatrix = adjacency_to_distance(matrix)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 91,
   "metadata": {},
   "outputs": [],
   "source": [
    "unrelatedPair = getUnRelatedPair(tokenized_sentence, parent_words, distanceMatrix, allNodeEntities, 'first', 'first')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 93,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "sterpus numpus\n",
      "rompus numpus\n",
      "zumpus numpus\n",
      "rompus numpus\n",
      "notmedium numpus\n",
      "grimpus numpus\n",
      "rompus numpus\n",
      "grimpus numpus\n",
      "tumpus numpus\n",
      "grimpus numpus\n",
      "shy numpus\n",
      "zumpus numpus\n",
      "temperate numpus\n",
      "tumpus numpus\n",
      "sunny numpus\n",
      "numpus grimpus\n",
      "lorpus grimpus\n",
      "numpus grimpus\n",
      "liquid grimpus\n",
      "grimpus numpus\n"
     ]
    }
   ],
   "source": [
    "# Print the parent words at both indices of every unrelated pair.\n",
    "for idx_pair in unrelatedPair:\n",
    "    left_idx, right_idx = idx_pair[0], idx_pair[1]\n",
    "    print(parent_words[left_idx], parent_words[right_idx])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 95,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[(4, 67),\n",
       " (8, 67),\n",
       " (12, 67),\n",
       " (16, 67),\n",
       " (21, 67),\n",
       " (23, 67),\n",
       " (27, 67),\n",
       " (36, 67),\n",
       " (40, 67),\n",
       " (44, 67),\n",
       " (48, 67),\n",
       " (52, 67),\n",
       " (55, 67),\n",
       " (59, 67),\n",
       " (63, 67),\n",
       " (67, 85),\n",
       " (71, 85),\n",
       " (75, 85),\n",
       " (79, 85),\n",
       " (85, 93)]"
      ]
     },
     "execution_count": 95,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "unrelatedPair"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 98,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the prompt for the selected entry: question text, the query, then the step-by-step cue.\n",
    "record = data[entry]\n",
    "question = ' '.join([record['question'], record['query']])\n",
    "tree_structure = record['tree']\n",
    "nShots = f\"{question} Let us think step by step.\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 99,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'Rompuses are sterpuses. Rompuses are zumpuses. Rompuses are not medium. Grimpuses are rompuses. Impuses are medium. Grimpuses are tumpuses. Grimpuses are shy. Each zumpus is temperate. Each tumpus is sunny. Each numpus is a lorpus. Numpuses are liquid. Sally is a grimpus. Sally is a numpus. True or false: Sally is not medium. Let us think step by step.'"
      ]
     },
     "execution_count": 99,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "nShots"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 100,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Tokenize the prompt as PyTorch tensors, then move the ids to the target device.\n",
    "encoded_prompt = tokenizer(nShots, return_tensors=\"pt\")\n",
    "input_ids = encoded_prompt.input_ids.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 102,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home//venv/lib/python3.10/site-packages/transformers/generation/utils.py:1473: UserWarning: You have modified the pretrained model configuration to control generation. This is a deprecated strategy to control generation and will be removed soon, in a future version. Please use and modify the model generation configuration (see https://huggingface.co/docs/transformers/generation_strategies#default-text-generation-configuration )\n",
      "  warnings.warn(\n"
     ]
    }
   ],
   "source": [
    "# Generate a single new token and return the hidden states along with the sequences.\n",
    "outputs = model.generate(\n",
    "    input_ids=input_ids,\n",
    "    output_hidden_states=True,\n",
    "    return_dict_in_generate=True,\n",
    "    max_new_tokens=1,\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 103,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Drop the prompt tokens from the generated sequence and decode the remainder to trimmed text.\n",
    "prompt_length = len(input_ids[0])\n",
    "generated_ids = outputs.sequences[0][prompt_length:].tolist()\n",
    "text = tokenizer.decode(generated_ids, clean_up_tokenization_spaces=True).strip()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 104,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Let\n"
     ]
    }
   ],
   "source": [
    "print(text)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 108,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "33"
      ]
     },
     "execution_count": 108,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(outputs.hidden_states[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
