{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## It generates ICCKA prompt for the test set, where each component of the Curriculum prompt comes from queries. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import os\n",
    "import json\n",
    "import xmltodict\n",
    "import random\n",
    "import pickle\n",
    "import re\n",
    "from itertools import permutations, combinations\n",
    "from sentence_transformers import SentenceTransformer, util\n",
    "\n",
    "import faiss\n",
    "import numpy as np\n",
    "\n",
    "# Creating the four lists for each arithmetic operation\n",
    "addition = [\n",
    "    \"What is the sum of <<num1>> and <<num2>>?\",\n",
    "    \"By adding <<num1>> to <<num2>>, what number do you get?\",\n",
    "    \"If you have <<num1>> and you add <<num2>> more, how many will you have in total?\",\n",
    "    \"Combine the numbers <<num1>> and <<num2>>. What's the total?\",\n",
    "    \"Add up <<num1>> and <<num2>>. What's the answer?\",\n",
    "    \"How much is <<num1>> added to <<num2>>?\",\n",
    "    \"If I have <<num1>> apples and I get <<num2>> more, how many apples do I have now?\",\n",
    "    \"Calculate the total of <<num1>> and <<num2>>.\",\n",
    "    \"What will you have if you join <<num1>> and <<num2>>?\",\n",
    "    \"What's the result when you sum up <<num1>> and <<num2>>?\"\n",
    "]\n",
    "subtraction = [\n",
    "    \"What is the difference between <<num1>> and <<num2>>?\",\n",
    "    \"Subtract <<num2>> from <<num1>>. What do you get?\",\n",
    "    \"If you have <<num1>> and you take away <<num2>>, how many are left?\",\n",
    "    \"Deduct <<num2>> from <<num1>>. What remains?\",\n",
    "    \"What's left when you remove <<num2>> from <<num1>>?\",\n",
    "    \"How much is <<num1>> minus <<num2>>?\",\n",
    "    \"If I eat <<num2>> out of my <<num1>> cookies, how many cookies do I have remaining?\",\n",
    "    \"Calculate what remains after subtracting <<num2>> from <<num1>>.\",\n",
    "    \"What do you get when you take away <<num2>> from <<num1>>?\",\n",
    "    \"After removing <<num2>> from <<num1>>, what's left?\"\n",
    "]\n",
    "\n",
    "multiplication = [\n",
    "    \"What is the product of <<num1>> and <<num2>>?\",\n",
    "    \"Multiply <<num1>> by <<num2>>. What's the result?\",\n",
    "    \"If you have <<num1>> boxes with <<num2>> candies each, how many candies do you have in total?\",\n",
    "    \"What do you get when you multiply <<num1>> by <<num2>>?\",\n",
    "    \"What's <<num1>> groups of <<num2>>?\",\n",
    "    \"How much is <<num1>> repeated <<num2>> times?\",\n",
    "    \"If one row has <<num1>> apples and there are <<num2>> such rows, how many apples are there in total?\",\n",
    "    \"Calculate the result of multiplying <<num1>> by <<num2>>.\",\n",
    "    \"What will you have when you multiply <<num1>> by <<num2>>?\",\n",
    "    \"What's the outcome of multiplying <<num1>> with <<num2>>?\"\n",
    "]\n",
    "\n",
    "division = [\n",
    "    \"What is the result of dividing <<num1>> by <<num2>>?\",\n",
    "    \"Divide <<num1>> by <<num2>>. What's the quotient?\",\n",
    "    \"If you have <<num1>> candies and you divide them among <<num2>> friends, how many candies does each friend get?\",\n",
    "    \"How many times does <<num2>> go into <<num1>>?\",\n",
    "    \"What's the answer when you split <<num1>> into groups of <<num2>>?\",\n",
    "    \"How much is <<num1>> divided by <<num2>>?\",\n",
    "    \"If I have <<num1>> cookies and I share them equally among <<num2>> friends, how many cookies does each friend receive?\",\n",
    "    \"Calculate the result of dividing <<num1>> by <<num2>>.\",\n",
    "    \"What do you get when you divide <<num1>> by <<num2>>?\",\n",
    "    \"After dividing <<num1>> by <<num2>>, what's the quotient?\"\n",
    "]\n",
    "\n",
    "percentage = [\n",
    "    \"How much is <<num2>>% of <<num1>>?\",\n",
    "    \"Find <<num2>>% of <<num1>>.\",\n",
    "    \"If you were to evaluate <<num2>>% of <<num1>>, what would the result be?\",\n",
    "    \"Assess <<num2>>% of <<num1>>.\",\n",
    "    \"Compute <<num2>>% of <<num1>>.\",\n",
    "    \"What number represents <<num2>>% of <<num1>>?\",\n",
    "    \"If <<num2>>% corresponds to a portion of <<num1>>, how much is that?\",\n",
    "    \"Estimate <<num2>>% of <<num1>>.\",\n",
    "    \"What would be the outcome of taking <<num2>>% of <<num1>>?\",\n",
    "    \"Deduce <<num2>>% of <<num1>>.\"\n",
    "]\n",
    "def generate_random_question():\n",
    "    questions = []\n",
    "    \n",
    "    # Choose a random addition question\n",
    "    addition_question = random.choice(addition)\n",
    "    num1 = random.randint(1, 99)\n",
    "    num2 = random.randint(1, 99)\n",
    "    questions.append((addition_question.replace(\"<<num1>>\", str(num1)).replace(\"<<num2>>\", str(num2)), \"\\n \"+str(num1)+\"+\"+str(num2)+\"=\"+str(num1 + num2)+\". The answer is \"+str(num1+num2)))\n",
    "    \n",
    "    # Choose a random subtraction question\n",
    "    subtraction_question = random.choice(subtraction)\n",
    "    num1 = random.randint(1, 99)\n",
    "    num2 = random.randint(1, num1)  # Ensuring the result is a positive integer\n",
    "    questions.append((subtraction_question.replace(\"<<num1>>\", str(num1)).replace(\"<<num2>>\", str(num2)),\"\\n \"+str(num1)+\"-\"+str(num2)+\"=\"+str(num1 - num2)+\". The answer is \"+str(num1-num2)))\n",
    "    \n",
    "    # Choose a random multiplication question\n",
    "    multiplication_question = random.choice(multiplication)\n",
    "    num1 = random.randint(1, 12)\n",
    "    num2 = random.randint(1, 12)\n",
    "    questions.append((multiplication_question.replace(\"<<num1>>\", str(num1)).replace(\"<<num2>>\", str(num2)),\"\\n \"+str(num1)+\"*\"+str(num2)+\"=\"+str(int(num1*num2))+\". The answer is \"+str((int(num1*num2)))))\n",
    "    \n",
    "    # Choose a random division question\n",
    "    division_question = random.choice(division)\n",
    "    num2 = random.randint(1, 12)\n",
    "    quotient = random.randint(1,12)\n",
    "    num1 = num2 * quotient  # Ensuring division results in an integer\n",
    "    questions.append((division_question.replace(\"<<num1>>\", str(num1)).replace(\"<<num2>>\", str(num2)), \"\\n \"+str(num1)+\"/\"+str(num2)+\"=\"+str(quotient)+\". The answer is \"+str(quotient)))\n",
    "    \n",
    "     # Choose a random percentage question\n",
    "    percentage_question = random.choice(percentage)\n",
    "    num2 = 10*random.randint(1, 10)\n",
    "    num1 = 10*random.randint(1, 20)# Ensuring division results in an integer\n",
    "    questions.append((percentage_question.replace(\"<<num1>>\", str(num1)).replace(\"<<num2>>\", str(num2)),\"\\n \"+str(num2)+'*'+str(num1)+\"/100=\"+str(int(num2*num1/100))+\". The answer is \"+str(int(num2*num1/100))))\n",
    "    \n",
    "    \n",
    "    # Format the questions and answers\n",
    "    formatted_questions = \"\\n\".join([f\"Q: {q[0]} {q[1]}\" for q in questions])\n",
    "    \n",
    "    return formatted_questions\n",
    "\n",
    "def trim_context(context, query):\n",
    "    total_length = len(context) + len(query) + len(\"Let's think step by step\")\n",
    "    if total_length > 2500:\n",
    "        # Split the context into individual examples\n",
    "        examples = context.split('Q:')\n",
    "        # We assume that the examples are ordered by priority (MathQA > SVAMP > ASDiv)\n",
    "        i=0\n",
    "        while total_length > 2500 and examples:\n",
    "            # Remove the first example\n",
    "            if(i==1): \n",
    "                examples.pop(1)\n",
    "            else:\n",
    "                examples.pop(-1)\n",
    "            i=i+1\n",
    "            # Recalculate the total length\n",
    "            total_length = sum(len(example) for example in examples) + len(query) + len(\"Let's think step by step.\")\n",
    "        context = 'Q:'.join(examples)\n",
    "    return 'Q:'+context if context[0:2]!='Q:' else context\n",
    "\n",
    "\n",
    "class Dataset_external:\n",
    "    def __init__(self,svamp_path=\"SVAMP.json\",asdiv_path='Modified_ASDiv.xml',mathqa_path=\"Complexity_Sorted_MathQA.json\"):\n",
    "        self.model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2') \n",
    "        # Load the dataset from the file_path\n",
    "        with open(os.path.join(os.getcwd(),svamp_path), 'r', encoding='utf-8') as file:\n",
    "            self.svamp=json.load(file)\n",
    "            # Group the data by type from the list\n",
    "        \n",
    "        self.svamp_sub=self.generate_embeddings(self.svamp)\n",
    "        \n",
    "        with open(asdiv_path, 'r') as file:\n",
    "            data = xmltodict.parse(file.read())\n",
    "        self.asdiv=data['Problems']['Problem']\n",
    "        \n",
    "        self.asdiv=self.generate_embeddings(self.asdiv)\n",
    "        \n",
    "        with open(os.path.join(os.getcwd(),mathqa_path), 'r', encoding='utf-8') as f:\n",
    "            self.mathqa=json.load(f)\n",
    "            self.mathqa=self.mathqa[:2000]\n",
    "            \n",
    "        self.mathqa=self.generate_embeddings(self.mathqa,no_body=True)\n",
    "\n",
    "    def generate_embeddings(self, dictionary,no_body=False):\n",
    "        for items in dictionary:\n",
    "            if no_body is True: \n",
    "                items['embedding']=self.model.encode(items['Problem'], convert_to_tensor=False)\n",
    "            else:\n",
    "                items['embedding']=self.model.encode(items['Body']+\" \"+items['Question'], convert_to_tensor=False)\n",
    "        \n",
    "        return dictionary\n",
    "\n",
    "\n",
    "    def get_svamp_KA_lists(self,query):\n",
    "        \n",
    "        your_query_vector =self.model.encode(query, convert_to_tensor=False)\n",
    "        query_vector = your_query_vector.reshape(1, -1)\n",
    "\n",
    "        k=3\n",
    "        embedd=self.svamp\n",
    "        embeddings = np.array([item['embedding'] for item in embedd ])\n",
    "        dimension = embeddings.shape[1]  # Assuming embeddings are 1D arrays\n",
    "        index = faiss.IndexFlatL2(dimension)  # L2 distance metric\n",
    "        index.add(embeddings)\n",
    "        distances, indices = index.search(query_vector, k)     \n",
    "        \n",
    "        list_subset = [f\"Q: {embedd[i]['Body']} {embedd[i]['Question']}\\n Since, {embedd[i]['Equation']}={embedd[i]['Answer']}. The answer is {embedd[i]['Answer']} \\n\" for i in indices.ravel()]\n",
    "        \n",
    "        return ''.join(list_subset)\n",
    "\n",
    "    def get_asdiv_KA_lists(self, query):\n",
    "\n",
    "        your_query_vector =self.model.encode(query, convert_to_tensor=False)\n",
    "        query_vector = your_query_vector.reshape(1, -1)\n",
    "        \n",
    "        k=3\n",
    "        embedd =self.asdiv\n",
    "        embeddings = np.array([item['embedding'] for item in embedd])\n",
    "        dimension = embeddings.shape[1]  # Assuming embeddings are 1D arrays\n",
    "        index = faiss.IndexFlatL2(dimension)  # L2 distance metric\n",
    "        index.add(embeddings)\n",
    "        distances, indices = index.search(query_vector, k)\n",
    "        list_subset = [f\"Q: {embedd[i]['Body']} {embedd[i]['Question']}\\n Since, {embedd[i]['Formula']}. The answer is {embedd[i]['Answer']} \\n\"for i in indices.ravel()]\n",
    "        \n",
    "        return ''.join(list_subset)\n",
    "    \n",
    "    def get_mathqa_KA_lists(self, query):\n",
    "    \n",
    "        your_query_vector =self.model.encode(query, convert_to_tensor=False)\n",
    "        query_vector = your_query_vector.reshape(1, -1)\n",
    "        k=3\n",
    "        embedd=self.mathqa\n",
    "        embeddings = np.array([item['embedding'] for item in embedd])\n",
    "        dimension = embeddings.shape[1]  # Assuming embeddings are 1D arrays\n",
    "        index = faiss.IndexFlatL2(dimension)  # L2 distance metric\n",
    "        index.add(embeddings)\n",
    "        distances, indices = index.search(query_vector, k)\n",
    "        list_subset = [f\"Q: {embedd[i]['Problem']} \\n Lets think step by step. {embedd[i]['Rationale'][:-11]} The answer is {embedd[i]['Answer']}\\n\" for i in indices.ravel()]\n",
    "        \n",
    "        return ''.join(list_subset)\n",
    "     \n",
    "    def get_iccka_combinations(self, query):\n",
    "        \n",
    "        basicmath_context = generate_random_question()\n",
    "\n",
    "        svamp_context=self.get_svamp_KA_lists(query[2:])\n",
    "        \n",
    "        asdiv_context=self.get_asdiv_KA_lists(query[2:]) # let's organize the code for this call\n",
    "        \n",
    "        mathqa_context= self.get_mathqa_KA_lists(query[2:]) # retrieve 36 similar items. \n",
    "      \n",
    "        context= basicmath_context+svamp_context+asdiv_context+mathqa_context\n",
    "        \n",
    "        \n",
    "        context=trim_context(context,query)\n",
    "        #context=context+'Q:'+query +\"\\n\"+ \"Let's think step by step.\"\n",
    "        \n",
    "        return context\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Found cached dataset gsm8k (C:/Users/Jayant/.cache/huggingface/datasets/gsm8k/main/1.1.0/37bfb08b1d4fcbb01f06b03d9e1ef5f1fcbd4d3af3d08842c50d7305091285ba)\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "28d1b16bf8334067a75eafce64b61b05",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/2 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "\"Janet’s ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market?\""
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from datasets import load_dataset\n",
    "gsm8k = load_dataset('gsm8k', 'main')\n",
    "data = gsm8k['test']\n",
    "data_ = []\n",
    "for q, a in zip(data['question'], data['answer']): \n",
    "    data_.append({'question': q, 'answer': a})\n",
    "data_[0]['question']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset_ext=Dataset_external()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "dict_prompt={}\n",
    "for item in data_:\n",
    "    dict_prompt[item['question']]=dataset_ext.get_iccka_combinations(data_[0]['question'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1319"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(dict_prompt)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "with open('test_prompts.pkl', 'wb') as f:\n",
    "    pickle.dump(dict_prompt, f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'Q: Subtract 16 from 34. What do you get? \\n 34-16=18. The answer is 18\\nQ: What\\'s the outcome of multiplying 9 with 10? \\n 9*10=90. The answer is 90\\nQ: Calculate the result of dividing 70 by 10. \\n 70/10=7. The answer is 7\\nQ: If 40% corresponds to a portion of 90, how much is that? \\n 40*90/100=36. The answer is 36Q: Lucy went to the grocery store. She bought 4 packs of cookie, 22 packs of cake and 16 packs of chocolate. How many packs of groceries did she buy in all?\\n Since, ( ( 4.0 + 22.0 ) + 16.0 )=42.0. The answer is 42.0 \\nQ: There were 53 dollars in Olivia\\'s wallet. She collected 91 more dollars from an atm. After she visited a supermarket there were 14 dollars left. How much more money did she spend at the supermarket than she collected at the atm?\\n Since, ( 53.0 - 14.0 )=39.0. The answer is 39.0 \\nQ: Last week Fred had 33 dollars and Jason had 95 dollars. Over the weekend Fred delivered newspapers earning 16 dollars and washed cars earning 74 dollars. How much money did Fred earn over the weekend?\\n Since, ( 16.0 + 74.0 )=90.0. The answer is 90.0 \\nQ: Emily collected eggs from the hen and put them into 303 baskets. She put 28 eggs into each basket. How many eggs did Emily collect?\\n Since, 303*28=8484. The answer is 8484 (eggs) \\nQ: Chef Pillsbury\\'s secret recipe requires 7 eggs for every 2 cups of flour. How many eggs will he need if he uses 8 cups of flour?\\n Since, 2:7=8:28. The answer is 28 (eggs) \\nQ: Olivia\\'s dad took her and some friends out to eat for her birthday. If each meal costs 7 dollars and her dad paid for 3 meals, how much did he spend?\\n Since, 7*3=21. The answer is 21 (dollars) \\nQ: a waitress \\' s income consists of her salary and tips . during one week , her tips were 2 / 4 of her salary . what fraction of her income for the week came from tips ? \\n Lets think step by step. \"her tips were 2 / 4 of her salary . let \\' s say her salary = $ 4 this mean her tips = ( 2 / 4 ) ( $ 4 ) = $ 2 so , her total income = $ 4 + $ 2 = $ 6 what fraction of her income for the week came from tips $ 2 / $ 6 = The answer is 1 / 3\\n'"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dict_prompt[data_[1000]['question']]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
