{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "# This notebook is used to create ICCR for in-context COT data that our baseline replaces. We produce a .pkl file with 200000 length. \n",
    "\n",
    "\n",
    "\n",
    "## ICCKA format\n",
    "\n",
    "5 Basic Questions that use same number of digits or less and same operations as gold_answer/CoT. \n",
    "\n",
    "3 SVAMP questions most similar to queires but from same operators. In case two operators are used, then tie is broken somehow?\n",
    "\n",
    "3 AsDiV questions  most similar to queires from same operators\n",
    "\n",
    "2 MathQA questions most similar to queries using same operators (Filter in this case and just leave one be?)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import os\n",
    "import json\n",
    "import xmltodict\n",
    "import random\n",
    "import pickle\n",
    "import re\n",
    "from itertools import permutations, combinations\n",
    "from sentence_transformers import SentenceTransformer, util\n",
    "\n",
    "import faiss\n",
    "import numpy as np\n",
    "\n",
    "# Creating the four lists for each arithmetic operation\n",
    "addition = [\n",
    "    \"What is the sum of <<num1>> and <<num2>>?\",\n",
    "    \"By adding <<num1>> to <<num2>>, what number do you get?\",\n",
    "    \"If you have <<num1>> and you add <<num2>> more, how many will you have in total?\",\n",
    "    \"Combine the numbers <<num1>> and <<num2>>. What's the total?\",\n",
    "    \"Add up <<num1>> and <<num2>>. What's the answer?\",\n",
    "    \"How much is <<num1>> added to <<num2>>?\",\n",
    "    \"If I have <<num1>> apples and I get <<num2>> more, how many apples do I have now?\",\n",
    "    \"Calculate the total of <<num1>> and <<num2>>.\",\n",
    "    \"What will you have if you join <<num1>> and <<num2>>?\",\n",
    "    \"What's the result when you sum up <<num1>> and <<num2>>?\"\n",
    "]\n",
    "subtraction = [\n",
    "    \"What is the difference between <<num1>> and <<num2>>?\",\n",
    "    \"Subtract <<num2>> from <<num1>>. What do you get?\",\n",
    "    \"If you have <<num1>> and you take away <<num2>>, how many are left?\",\n",
    "    \"Deduct <<num2>> from <<num1>>. What remains?\",\n",
    "    \"What's left when you remove <<num2>> from <<num1>>?\",\n",
    "    \"How much is <<num1>> minus <<num2>>?\",\n",
    "    \"If I eat <<num2>> out of my <<num1>> cookies, how many cookies do I have remaining?\",\n",
    "    \"Calculate what remains after subtracting <<num2>> from <<num1>>.\",\n",
    "    \"What do you get when you take away <<num2>> from <<num1>>?\",\n",
    "    \"After removing <<num2>> from <<num1>>, what's left?\"\n",
    "]\n",
    "\n",
    "multiplication = [\n",
    "    \"What is the product of <<num1>> and <<num2>>?\",\n",
    "    \"Multiply <<num1>> by <<num2>>. What's the result?\",\n",
    "    \"If you have <<num1>> boxes with <<num2>> candies each, how many candies do you have in total?\",\n",
    "    \"What do you get when you multiply <<num1>> by <<num2>>?\",\n",
    "    \"What's <<num1>> groups of <<num2>>?\",\n",
    "    \"How much is <<num1>> repeated <<num2>> times?\",\n",
    "    \"If one row has <<num1>> apples and there are <<num2>> such rows, how many apples are there in total?\",\n",
    "    \"Calculate the result of multiplying <<num1>> by <<num2>>.\",\n",
    "    \"What will you have when you multiply <<num1>> by <<num2>>?\",\n",
    "    \"What's the outcome of multiplying <<num1>> with <<num2>>?\"\n",
    "]\n",
    "\n",
    "division = [\n",
    "    \"What is the result of dividing <<num1>> by <<num2>>?\",\n",
    "    \"Divide <<num1>> by <<num2>>. What's the quotient?\",\n",
    "    \"If you have <<num1>> candies and you divide them among <<num2>> friends, how many candies does each friend get?\",\n",
    "    \"How many times does <<num2>> go into <<num1>>?\",\n",
    "    \"What's the answer when you split <<num1>> into groups of <<num2>>?\",\n",
    "    \"How much is <<num1>> divided by <<num2>>?\",\n",
    "    \"If I have <<num1>> cookies and I share them equally among <<num2>> friends, how many cookies does each friend receive?\",\n",
    "    \"Calculate the result of dividing <<num1>> by <<num2>>.\",\n",
    "    \"What do you get when you divide <<num1>> by <<num2>>?\",\n",
    "    \"After dividing <<num1>> by <<num2>>, what's the quotient?\"\n",
    "]\n",
    "\n",
    "percentage = [\n",
    "    \"How much is <<num2>>% of <<num1>>?\",\n",
    "    \"Find <<num2>>% of <<num1>>.\",\n",
    "    \"If you were to evaluate <<num2>>% of <<num1>>, what would the result be?\",\n",
    "    \"Assess <<num2>>% of <<num1>>.\",\n",
    "    \"Compute <<num2>>% of <<num1>>.\",\n",
    "    \"What number represents <<num2>>% of <<num1>>?\",\n",
    "    \"If <<num2>>% corresponds to a portion of <<num1>>, how much is that?\",\n",
    "    \"Estimate <<num2>>% of <<num1>>.\",\n",
    "    \"What would be the outcome of taking <<num2>>% of <<num1>>?\",\n",
    "    \"Deduce <<num2>>% of <<num1>>.\"\n",
    "]\n",
    "\n",
    "def get_three_lists(list_of_lists):\n",
    "    \n",
    "    if len(list_of_lists)==1:\n",
    "        assert len(list_of_lists[0])==36\n",
    "        list1 = list_of_lists[0][0::3]\n",
    "        list2 = list_of_lists[0][1::3]\n",
    "        list3 = list_of_lists[0][2::3]\n",
    "        \n",
    "        return list1, list2, list3\n",
    "    \n",
    "    elif len(list_of_lists)==2:\n",
    "        list3 = [val for pair in zip(list_of_lists[0][0::3], list_of_lists[1][0::3]) for val in pair]\n",
    "        del list_of_lists[0][0::3]\n",
    "        del list_of_lists[1][0::3]\n",
    "        return list_of_lists[0],list_of_lists[1],list3\n",
    "    \n",
    "    elif len(list_of_lists)==3: \n",
    "        return list_of_lists[0], list_of_lists[1], list_of_lists[2]\n",
    "    \n",
    "    elif len(list_of_lists)==4:\n",
    "        group1 = list_of_lists[0][:3]\n",
    "        group2 = list_of_lists[0][3:6]\n",
    "        group3 = list_of_lists[0][6:]\n",
    "        \n",
    "        # Attach the groups to the other lists\n",
    "        list_of_lists[1] = group1 + list_of_lists[1]\n",
    "        list_of_lists[2] = list_of_lists[2][:4] + group2 + list_of_lists[2][4:]\n",
    "        list_of_lists[3] = list_of_lists[3] + group3\n",
    "        \n",
    "        return list_of_lists[1],list_of_lists[2],list_of_lists[3]\n",
    "        \n",
    "    else:\n",
    "        assert\"Some error, lists are more than 4 or less than 1\"\n",
    "\n",
    "def key_with_value_double_of_others(data):\n",
    "    total_sum = sum(data.values())\n",
    "    for key, value in data.items():\n",
    "        if value > 2 * (total_sum - value):\n",
    "            return key\n",
    "    return None\n",
    "\n",
    "def group_by_type_from_list(data_list,cat='Type'):\n",
    "    grouped = {}\n",
    "    \n",
    "    for item in data_list:\n",
    "\n",
    "        type_value = item[cat]\n",
    "        \n",
    "        if type_value not in grouped:\n",
    "            grouped[type_value] = []\n",
    "        \n",
    "        grouped[type_value].append(item)\n",
    "    \n",
    "    return grouped\n",
    "\n",
    "def generate_random_questions(types, frequency=40):\n",
    "    \"\"\"\n",
    "    types: Subtraction, Addition, Multiplication, Division, Percentage\n",
    "    \"\"\"\n",
    "    questions_list = []\n",
    "    \n",
    "    for i in range(frequency):\n",
    "        \n",
    "        questions = []\n",
    "        total=4\n",
    "        while(total>0):\n",
    "        \n",
    "            if types['Percentage']>0:\n",
    "                # Choose a random percentage question\n",
    "                percentage_question = random.choice(percentage)\n",
    "                num2 = 10*random.randint(1, 10)\n",
    "                num1 = 10*random.randint(1, 20)# Ensuring division results in an integer\n",
    "                questions.append((percentage_question.replace(\"<<num1>>\", str(num1)).replace(\"<<num2>>\", str(num2)),\"\\n \"+str(num2)+'*'+str(num1)+\"/100=\"+str(int(num2*num1/100))+\". The answer is \"+str(int(num2*num1/100))))\n",
    "                total=total-1\n",
    "\n",
    "            if  True:    \n",
    "                # Choose a random division question\n",
    "                division_question = random.choice(division)\n",
    "                num2 = random.randint(1, 12)\n",
    "                quotient = random.randint(1,12)\n",
    "                num1 = num2 * quotient  # Ensuring division results in an integer\n",
    "                questions.append((division_question.replace(\"<<num1>>\", str(num1)).replace(\"<<num2>>\", str(num2)), \"\\n \"+str(num1)+\"/\"+str(num2)+\"=\"+str(quotient)+\". The answer is \"+str(quotient)))\n",
    "\n",
    "                total=total-1\n",
    "\n",
    "                \n",
    "            if True:\n",
    "            \n",
    "                # Choose a random multiplication question\n",
    "                multiplication_question = random.choice(multiplication)\n",
    "                num1 = random.randint(1, 12)\n",
    "                num2 = random.randint(1, 12)\n",
    "                questions.append((multiplication_question.replace(\"<<num1>>\", str(num1)).replace(\"<<num2>>\", str(num2)),\"\\n \"+str(num1)+\"*\"+str(num2)+\"=\"+str(int(num1*num2))+\". The answer is \"+str((int(num1*num2)))))\n",
    "                total=total-1\n",
    "\n",
    "            if True:\n",
    "                # Choose a random subtraction question\n",
    "                subtraction_question = random.choice(subtraction)\n",
    "                num1 = random.randint(1, 99)\n",
    "                num2 = random.randint(1, num1)  # Ensuring the result is a positive integer\n",
    "                questions.append((subtraction_question.replace(\"<<num1>>\", str(num1)).replace(\"<<num2>>\", str(num2)),\"\\n \"+str(num1)+\"-\"+str(num2)+\"=\"+str(num1 - num2)+\". The answer is \"+str(num1-num2)))\n",
    "                total=total-1\n",
    " \n",
    "            if total>0:\n",
    "                # Choose a random addition question\n",
    "                addition_question = random.choice(addition)\n",
    "                num1 = random.randint(1, 99)\n",
    "                num2 = random.randint(1, 99)\n",
    "                questions.append((addition_question.replace(\"<<num1>>\", str(num1)).replace(\"<<num2>>\", str(num2)), \"\\n \"+str(num1)+\"+\"+str(num2)+\"=\"+str(num1 + num2)+\". The answer is \"+str(num1+num2)))\n",
    "                total=total-1\n",
    "\n",
    "        # Format the questions and answers\n",
    "        #print(questions)\n",
    "        #need to reverse them\n",
    "        questions.reverse()\n",
    "        formatted_questions = \"\\n\".join([f\"Q: {q[0]} {q[1]}\" for q in questions])\n",
    "        questions_list.append(formatted_questions)\n",
    "        \n",
    "    return questions_list\n",
    "\n",
    "def trim_context(context, query):\n",
    "    total_length = len(context) + len(query) + len(\"Let's think step by step\")\n",
    "    if total_length > 2500:\n",
    "        # Split the context into individual examples\n",
    "        examples = context.split('Q:')\n",
    "        # We assume that the examples are ordered by priority (MathQA > SVAMP > ASDiv)\n",
    "        i=0\n",
    "        while total_length > 2500 and examples:\n",
    "            # Remove the first example\n",
    "            if(i==1): \n",
    "                examples.pop(5)\n",
    "            else:\n",
    "                examples.pop(-1)\n",
    "            i=i+1\n",
    "            # Recalculate the total length\n",
    "            total_length = sum(len(example) for example in examples) + len(query) + len(\"Let's think step by step\")\n",
    "        context = 'Q:'.join(examples)\n",
    "    return 'Q:'+context if context[0:2]!='Q:' else context\n",
    "\n",
    "\n",
    "class Dataset:\n",
    "    def __init__(self,svamp_path=\"SVAMP.json\",asdiv_path='Modified_ASDiv.xml',mathqa_path=\"Complexity_Sorted_MathQA.json\"):\n",
    "        self.model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2') \n",
    "        # Load the dataset from the file_path\n",
    "        with open(os.path.join(os.getcwd(),svamp_path), 'r', encoding='utf-8') as file:\n",
    "            self.svamp=json.load(file)\n",
    "            # Group the data by type from the list\n",
    "        grouped_data_list = group_by_type_from_list(self.svamp)\n",
    "\n",
    "        # Create separate lists for each type\n",
    "        self.svamp_sub = grouped_data_list.get('Subtraction', [])\n",
    "        self.svamp_sub=self.generate_embeddings(self.svamp_sub)\n",
    "        \n",
    "        self.svamp_addition= grouped_data_list.get('Addition', [])\n",
    "        self.svamp_addition=self.generate_embeddings(self.svamp_addition)\n",
    "        \n",
    "        self.svamp_multiplication = grouped_data_list.get('Multiplication', [])\n",
    "        self.svamp_multiplication=self.generate_embeddings(self.svamp_multiplication)\n",
    "        \n",
    "        self.svamp_division = grouped_data_list.get('Common-Division', [])\n",
    "        self.svamp_division+grouped_data_list.get('Common-Divison', [])\n",
    "        self.svamp_division=self.generate_embeddings(self.svamp_division)\n",
    "        \n",
    "        with open(asdiv_path, 'r') as file:\n",
    "            data = xmltodict.parse(file.read())\n",
    "        self.asdiv=data['Problems']['Problem']\n",
    "        grouped_data_list = group_by_type_from_list(self.asdiv)\n",
    "        \n",
    "        self.asdiv_sub = grouped_data_list.get('Subtraction', [])\n",
    "        self.asdiv_sub=self.generate_embeddings(self.asdiv_sub)\n",
    "        \n",
    "        self.asdiv_addition= grouped_data_list.get('Addition', [])\n",
    "        self.asdiv_addition=self.generate_embeddings(self.asdiv_addition)\n",
    "\n",
    "        self.asdiv_multiplication = grouped_data_list.get('Multiplication', [])\n",
    "        self.asdiv_multiplication=self.generate_embeddings(self.asdiv_multiplication)\n",
    "   \n",
    "        self.asdiv_division = grouped_data_list.get('Division', [])\n",
    "        self.asdiv_division=self.generate_embeddings(self.asdiv_division)\n",
    "\n",
    "        with open(os.path.join(os.getcwd(),mathqa_path), 'r', encoding='utf-8') as f:\n",
    "            self.mathqa=json.load(f)\n",
    "            self.mathqa=self.mathqa[:2000]\n",
    "            \n",
    "        self.mathqa=self.generate_embeddings(self.mathqa,no_body=True)\n",
    "\n",
    "    def generate_embeddings(self, dictionary,no_body=False):\n",
    "        for items in dictionary:\n",
    "            if no_body is True: \n",
    "                items['embedding']=self.model.encode(items['Problem'], convert_to_tensor=False)\n",
    "            else:\n",
    "                items['embedding']=self.model.encode(items['Body']+\" \"+items['Question'], convert_to_tensor=False)\n",
    "        \n",
    "        return dictionary\n",
    "\n",
    "\n",
    "    def get_svamp_KA_lists(self,query, types, frequency):\n",
    "        \n",
    "        list_of_lists=[]\n",
    "        your_query_vector =self.model.encode(query, convert_to_tensor=False)\n",
    "        query_vector = your_query_vector.reshape(1, -1)\n",
    "\n",
    "        count =len(types)\n",
    "        k=0\n",
    "        if count==1:\n",
    "            k=36\n",
    "        elif count==2:\n",
    "            k=18\n",
    "        elif count==3:\n",
    "            k=12\n",
    "        elif count==4:\n",
    "            k=9\n",
    "        assert k!=0\n",
    "        \n",
    "        for ty in types: \n",
    "            \n",
    "            if ty=='Addition':\n",
    "                embedd=self.svamp_addition            \n",
    "            if ty=='Subtraction':\n",
    "                embedd=self.svamp_sub\n",
    "            if ty=='Multiplication':\n",
    "                embedd=self.svamp_multiplication\n",
    "            if ty=='Division':\n",
    "                embedd=self.svamp_division\n",
    "            embeddings = np.array([item['embedding'] for item in embedd])\n",
    "            dimension = embeddings.shape[1]  # Assuming embeddings are 1D arrays\n",
    "            index = faiss.IndexFlatL2(dimension)  # L2 distance metric\n",
    "            index.add(embeddings)\n",
    "            distances, indices = index.search(query_vector, k)       \n",
    "            list_subset = [f\"Q: {embedd[i]['Body']} {embedd[i]['Question']}\\n Since, {embedd[i]['Equation']}={embedd[i]['Answer']}. The answer is {embedd[i]['Answer']} \\n\" for i in indices.ravel()]\n",
    "            list_of_lists.append(list_subset)\n",
    "\n",
    "        return list_of_lists\n",
    "\n",
    "    def get_asdiv_KA_lists(self, query, types, frequency):\n",
    "        list_of_lists=[]\n",
    "        your_query_vector =self.model.encode(query, convert_to_tensor=False)\n",
    "        query_vector = your_query_vector.reshape(1, -1)\n",
    "\n",
    "        count =len(types)\n",
    "        k=0\n",
    "        if count==1:\n",
    "            k=36\n",
    "        elif count==2:\n",
    "            k=18\n",
    "        elif count==3:\n",
    "            k=12\n",
    "        elif count==4:\n",
    "            k=9\n",
    "        assert k!=0\n",
    "        for ty in types:    \n",
    "            if ty=='Addition':\n",
    "                embedd=self.asdiv_addition            \n",
    "            if ty=='Subtraction':\n",
    "                embedd=self.asdiv_sub\n",
    "            if ty=='Multiplication':\n",
    "                embedd=self.asdiv_multiplication\n",
    "            if ty=='Division':\n",
    "                embedd=self.asdiv_division\n",
    "\n",
    "            embeddings = np.array([item['embedding'] for item in embedd])\n",
    "            dimension = embeddings.shape[1]  # Assuming embeddings are 1D arrays\n",
    "            index = faiss.IndexFlatL2(dimension)  # L2 distance metric\n",
    "            index.add(embeddings)\n",
    "            distances, indices = index.search(query_vector, k)\n",
    "            list_subset = [f\"Q: {embedd[i]['Body']} {embedd[i]['Question']}\\n Since, {embedd[i]['Formula']}. The answer is {embedd[i]['Answer']} \\n\"for i in indices.ravel()]\n",
    "            list_of_lists.append(list_subset)\n",
    "        \n",
    "        return list_of_lists\n",
    "    \n",
    "    def get_mathqa_KA_lists(self, query,frequency):\n",
    "    \n",
    "        list_of_lists=[]\n",
    "        your_query_vector =self.model.encode(query, convert_to_tensor=False)\n",
    "        query_vector = your_query_vector.reshape(1, -1)\n",
    "        k=36\n",
    "        embedd=self.mathqa\n",
    "        embeddings = np.array([item['embedding'] for item in embedd])\n",
    "        dimension = embeddings.shape[1]  # Assuming embeddings are 1D arrays\n",
    "        index = faiss.IndexFlatL2(dimension)  # L2 distance metric\n",
    "        index.add(embeddings)\n",
    "        distances, indices = index.search(query_vector, k)\n",
    "        list_subset = [f\"Q: {embedd[i]['Problem']} \\n Lets think step by step. {embedd[i]['Rationale'][:-11]} The answer is {embedd[i]['Answer']}\\n\" for i in indices.ravel()]\n",
    "        list_of_lists.append(list_subset)\n",
    "        \n",
    "        return list_of_lists\n",
    "     \n",
    "    def get_iccka_combinations(self, query,types,frequency):\n",
    "    \n",
    "        basic_qa_context= generate_random_questions(types,frequency)\n",
    "        if types['Percentage']>0:\n",
    "            \n",
    "            types['Multiplication']+=1\n",
    "            types['Division']+=1\n",
    "        \n",
    "        types['Percentage']=0\n",
    "    \n",
    "        svamp_lists=None\n",
    "        asdiv_lists=None\n",
    "        mathqa_lists=None\n",
    "    \n",
    "        types_rest= [key for key, value in types.items() if value > 0]\n",
    "    \n",
    "        #if any operator frequcny is more than double the rest, we represent only that in SVAMP. So, we get\n",
    "        key=key_with_value_double_of_others(types)\n",
    "        if key is not None: \n",
    "            types_svamp=[key]\n",
    "            svamp_lists=self.get_svamp_KA_lists(query[2:], types_svamp, frequency)\n",
    "            svamp_lists=get_three_lists(svamp_lists)\n",
    "     \n",
    "        \n",
    "        count = len(types_rest)\n",
    "    \n",
    "        if(count==0):\n",
    "            count=4\n",
    "        #within each call, count will play role: count 1: 36 , count 2: 18 each, count 4: 9 each, count 3: 12 each\n",
    "        if svamp_lists==None:\n",
    "            svamp_lists=self.get_svamp_KA_lists(query[2:], types_rest, frequency)\n",
    "            svamp_lists=get_three_lists(svamp_lists)\n",
    "        \n",
    "        asdiv_lists=self.get_asdiv_KA_lists(query[2:], types_rest, frequency) # let's organize the code for this call\n",
    "        asdiv_lists=get_three_lists(asdiv_lists)\n",
    "    \n",
    "        mathqa_lists= self.get_mathqa_KA_lists(query[2:], frequency) # retrieve 36 similar items. \n",
    "        mathqa_lists=get_three_lists(mathqa_lists)\n",
    "        \n",
    "        \n",
    "        asdiv_context=self.get_context(query,asdiv_lists,frequency)\n",
    "        mathqa_context=self.get_context(query,mathqa_lists,frequency)\n",
    "        svamp_context=self.get_context(query,svamp_lists,frequency)\n",
    "        \n",
    "#         print(len(asdiv_context),type(asdiv_context[0]))\n",
    "        \n",
    "#         print(len(mathqa_context),type(mathqa_context[0]))\n",
    "        \n",
    "        #print(len(svamp_context),type(svamp_context[0]))\n",
    "        print(basic_qa_context)\n",
    "        \n",
    "        context= [a+ b + c + d for a, b, c, d in zip(basic_qa_context,svamp_context, asdiv_context, mathqa_context)]\n",
    "        for i in range(len(context)):\n",
    "            context[i]=trim_context(context[i],query)\n",
    "            context[i]=context[i]+query +\"\\n\"+ \"Let's think step by step\"\n",
    "        \n",
    "        return context\n",
    "        \n",
    "    def get_context(self,query, three_lists,frequency=40):\n",
    "        permuation_index=[(1, 1, 1),(2, 2, 2),(3, 3, 3),(4, 4, 4),(5, 5, 5),(6, 6, 6),(7, 7, 7),(8, 8, 8),(9, 9, 9),(10, 10, 10),(11, 11, 11),(12, 12, 12),(1, 2, 3),(2, 3, 4),(3, 4, 5),(4, 5, 6),(5, 6, 7),(6, 7, 8),(7, 8, 9),(8, 9, 10),(9, 10, 11),(10, 11, 12),(3, 2, 1),(4, 3, 2),(5, 4, 3),(6, 5, 4),(7, 6, 5),(8, 7, 6),(9, 8, 7),(10, 9, 8),(11, 10, 9),(12, 11, 10),(1, 3, 5),(2, 4, 6),(3, 5, 7),(4, 6, 8),(5, 7, 9),(6, 8, 10),(7, 9, 11),(8, 10, 12),(1,5,9),(2,6,10),(3,7,11),(4,8,12),(9,1,5),(10,2,6),(11,3,7),(12,4,8),(11,6,1),(12,7,2)]\n",
    "        permutation_index=permuation_index[:frequency]\n",
    "        \n",
    "        context=[]\n",
    "        \n",
    "        for i in range(len(permuation_index)):\n",
    "            context_1=three_lists[0][permuation_index[i][0]-1]+three_lists[1][permuation_index[i][1]-1]+three_lists[2][permuation_index[i][2]-1]\n",
    "            context.append(context_1)\n",
    "            assert type(context[-1])==str\n",
    "            \n",
    "                \n",
    "        return context\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data=None\n",
    "answer_variety=None\n",
    "with open('in_context_compressed.pkl', 'rb') as f:\n",
    "    data=pickle.load(f)\n",
    "dataset=Dataset()\n",
    "with open('answer_variety.pkl', 'rb') as f:\n",
    "    answer_variety=pickle.load(f)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data[0].keys()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "i=0\n",
    "x=dataset.get_iccka_combinations(\"Q:\"+ data[i]['question'].split('Q:')[-1][:-25],data[i]['operators'],data[i]['frequency'])\n",
    "x[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "final_data=[]\n",
    "for i in range(len(data)):\n",
    "    query=\"Q:\"+ data[i]['question'].split('Q:')[-1][:-25]\n",
    "    x=dataset.get_iccka_combinations(query,data[i]['operators'],data[i]['frequency'])\n",
    "    x_list=[]\n",
    "    print(data[i]['frequency'],len(query))\n",
    "    answer_list=answer_variety[data[i]['answer_gold']]\n",
    "    assert data[i]['frequency'] == len(answer_list)\n",
    "    for j in range(data[i]['frequency']):\n",
    "        dict_1=data[i].copy()\n",
    "        dict_1['answer']=answer_list[j]\n",
    "        dict_1['question']=x[j]\n",
    "        dict_1.pop('operators', None)\n",
    "        dict_1.pop('frequency', None)\n",
    "        x_list.append(dict_1)\n",
    "    final_data.extend(x_list)\n",
    "\n",
    "    #create a copy of this data index's each aspect and tie it up, question is given by x,\n",
    "    #rest are same except type and frequency which are to be dropped. \n",
    "random.shuffle(final_data)\n",
    "#randomize the order of data in the list as it will contain same question again and again. Save and then boom"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "final_data[1111]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "key_to_check = \"answer\"\n",
    "unique_values = set()\n",
    "\n",
    "for d in final_data:\n",
    "    if key_to_check in final_data:\n",
    "        unique_values.add(d[key_to_check])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "len(unique_values)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with open('in_context_chain_of_thought.pkl', 'wb') as f:\n",
    "    pickle.dump(final_data, f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "final_data[1100]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "query = 'Q: Ivan has a bird feeder in his yard that holds two cups of birdseed. Every week, he has to refill the emptied feeder. Each cup of birdseed can feed fourteen birds, but Ivan is constantly chasing away a hungry squirrel that steals half a cup of birdseed from the feeder every week. How many birds does Ivan’s bird feeder feed weekly?'\n",
    "\n",
    "dataset=Dataset()\n",
    "x=dataset.get_iccka_combinations(query,{'Addition':0,'Subtraction':0,'Percentage':0, 'Multiplication':0,\"Division\":1},35)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import faiss\n",
    "import numpy as np\n",
    "\n",
    "query= 'Q: Ivan has a bird feeder in his yard that holds two cups of birdseed. Every week, he has to refill the emptied feeder. Each cup of birdseed can feed fourteen birds, but Ivan is constantly chasing away a hungry squirrel that steals half a cup of birdseed from the feeder every week. How many birds does Ivan’s bird feeder feed weekly?'\n",
    "your_query_vector =model.encode(query, convert_to_tensor=False)\n",
    "query_vector = your_query_vector.reshape(1, -1)\n",
    "\n",
    "\n",
    "embeddings = np.array([item['embedding'] for item in svamp.sub])\n",
    "\n",
    "dimension = embeddings.shape[1]  # Assuming embeddings are 1D arrays\n",
    "index = faiss.IndexFlatL2(dimension)  # L2 distance metric\n",
    "index.add(embeddings)\n",
    "\n",
    "\n",
    "\n",
    "# Search for the top k nearest neighbors\n",
    "k = 30\n",
    "distances, indices = index.search(query_vector, k)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "distances, indices"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "svamp.sub[indices[0][-1]]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data[1]['operators']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    " dataset.svamp_addition[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:csn] *",
   "language": "python",
   "name": "conda-env-csn-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
