{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## This notebook is used to create ICCR for in-context COT data that our baseline replaces. We produce a .pkl file with 200000 length. "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Context\n",
    "\n",
    "2-3 Basic problems from SVAMP\n",
    "\n",
    "1-2 Basic problems from ASDv\n",
    "\n",
    "2 Intermediate problems from MathQA with a rationale. \n",
    "\n",
    "5-6 Basic arithematic needed from mathematics dataset or my own functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import os\n",
    "import json\n",
    "import xmltodict\n",
    "import random\n",
    "import pickle\n",
    "\n",
    "\n",
    "def generate_unique_random_integers(n, x, y):\n",
    "    return random.sample(range(x, y+1), n)\n",
    "\n",
    "class SVAMP:\n",
    "    def __init__(self, file_path=\"SVAMP.json\"):\n",
    "        # Load the dataset from the file_path\n",
    "        with open(os.path.join(os.getcwd(),file_path), 'r', encoding='utf-8') as file:\n",
    "            self.dictionary=json.load(file)\n",
    "        self.df_svamp= pd.json_normalize(self.dictionary)\n",
    "        self.length=len(self.dictionary)\n",
    "\n",
    "    def print_stats(self):\n",
    "        category_counts = self.df_svamp[\"Type\"].value_counts()\n",
    "        # Print the type and frequency\n",
    "        for category, count in category_counts.iteritems():\n",
    "            print(f\"Type: {category}, Frequency: {count}\")\n",
    "        print(\"Columns:\", self.df_svamp.columns)\n",
    "        print(\"Total Size\", len(self.df_svamp.index))\n",
    "    \n",
    "\n",
    "        \n",
    "    def get_entry(self, index):\n",
    "        return self.dictionary[index]\n",
    "        # Return the question and answer pair at the specified index\n",
    "\n",
    "        #     def get_context(self,number=2):\n",
    "#         context=''\n",
    "#         indexes= number x randn(1,100)# generate a number of random numbers between a range\n",
    "#         for i in indexes:\n",
    "#             assert i < len(self.dictionary)\n",
    "#             context+= \"Question:\"+str(self.dictionary[i]['Body'])+' '+str(self.dictionary[i]['Question'])+' Answer: '+str(self.dictionary[i]['Equation'])+'='+str(self.dictionary[i]['Answer'])+'\\n'\n",
    "        \n",
    "#         return context\n",
    "\n",
    "    def get_context(self, number, filter_column=None, filter_value=None):\n",
    "        df = self.df_svamp\n",
    "        # Let's generate 10 random integers between 1 and 100\n",
    "        \n",
    "        if filter_column and filter_value:\n",
    "            df = df[df[filter_column] == filter_value]\n",
    "        context=''\n",
    "        indexes = generate_unique_random_integers(number, 1, len(df.index)-1)\n",
    "        for i in indexes:\n",
    "            assert i < len(df.index)\n",
    "            row = df.iloc[i]\n",
    "            context += f\"Q: {row['Body']} {row['Question']} Answer: {row['Equation']}={row['Answer']}\\n\"\n",
    "        \n",
    "        return context     \n",
    "\n",
    "class ASDiv:\n",
    "    def __init__(self, file_path='ASDiv.xml'):\n",
    "        # Load the dataset from the file_path\n",
    "        with open(file_path, 'r') as file:\n",
    "            data = xmltodict.parse(file.read())\n",
    "            \n",
    "        self.dictionary=data['Problems']['Problem']\n",
    "        self.df_asd = pd.json_normalize(self.dictionary)\n",
    "        self.length=len(self.dictionary)\n",
    "        \n",
    "    def get_entry(self, index):\n",
    "        return self.dictionary[index]\n",
    "        # Return the question and answer pair at the specified index\n",
    "    def print_stats(self):\n",
    "        category_counts = self.df_asd[\"groups\"].value_counts()\n",
    "        for category, count in category_counts.iteritems():\n",
    "            print(f\"Type: {category}, Frequency: {count}\")\n",
    "        print(\"Columns:\", self.df_asd.columns)\n",
    "        print(\"Total Size\", len(self.df_asd.index))\n",
    "            \n",
    "    \n",
    "    def get_context(self, number, filter_column=None, filter_value=None):\n",
    "        df = self.df_asd\n",
    "        if filter_column and filter_value:\n",
    "            df = df[df[filter_column] == filter_value]\n",
    "        indexes = generate_unique_random_integers(number, 1, len(df)-1)\n",
    "        context = ''\n",
    "        for i in indexes:\n",
    "            assert i < len(df)\n",
    "            row = df.iloc[i]\n",
    "            context += f\"Q: {row['Body']} {row['Question']}\\n Since, {row['Formula']}. The answer is {row['Answer']}\\n\"\n",
    "\n",
    "        return context\n",
    "\n",
    "class MathQA:\n",
    "    def __init__(self, file_path=\"Complexity_Sorted_MathQA.json\"):\n",
    "        # Load the dataset from the file_path\n",
    "        with open(os.path.join(os.getcwd(),file_path), 'r', encoding='utf-8') as f:\n",
    "            self.dictionary=json.load(f)\n",
    "        self.df_math = pd.json_normalize(self.dictionary)\n",
    "        self.length=2000\n",
    "        \n",
    "    def get_question_answer(self, index):\n",
    "        # Return the question and answer pair at the specified index\n",
    "        return None\n",
    "    \n",
    "    def print_stats(self):\n",
    "        category_counts = self.df_asd[\"groups\"].value_counts()\n",
    "        for category, count in category_counts.iteritems():\n",
    "            print(f\"Type: {category}, Frequency: {count}\")\n",
    "        print(\"Columns:\", self.df_asd.columns)\n",
    "        print(\"Total Size\", len(self.df_asd.index))\n",
    "            \n",
    "    def get_context(self, number, filter_column=None, filter_value=None):\n",
    "        df = self.df_math\n",
    "        if filter_column and filter_value:\n",
    "            df = df[df[filter_column] == filter_value]\n",
    "        indexes = generate_unique_random_integers(number, 20,2000)\n",
    "        context = ''\n",
    "        for i in indexes:\n",
    "            assert i < len(df)\n",
    "            row = df.iloc[i]\n",
    "            context += f\"Q: {row['Problem']} Lets think step by step. {row['Rationale'][:-11]} The answer is {row['Answer']}\\n\"\n",
    "    \n",
    "        return context\n",
    "    \n",
    "def generate_composed_arithmetic_examples(num_examples=5, num_digits=2, max_composition_level=2):\n",
    "    def get_operand_with_cumulative_digits(num_digits):\n",
    "        if num_digits == 1:\n",
    "            return random.randint(1, 9)\n",
    "        else:\n",
    "            lower_bound = 10 ** (num_digits - 1)\n",
    "            upper_bound = 10 ** num_digits - 1\n",
    "            return random.randint(lower_bound, upper_bound)\n",
    "\n",
    "    operators = ['+', '-', '*', '/']\n",
    "    operator_weights = [3, 3, 2, 2]  # weights for '+', '-', '*', '/'\n",
    "    examples = []\n",
    "    if num_digits>1: \n",
    "        num_examples_per_digit = num_examples // num_digits\n",
    "    else:\n",
    "        num_examples_per_digit=1\n",
    "\n",
    "    for current_digits in range(1, num_digits + 1):\n",
    "        for _ in range(num_examples_per_digit):\n",
    "            composition_level = random.randint(1, max_composition_level)\n",
    "            op1 = get_operand_with_cumulative_digits(current_digits)\n",
    "            expression = str(op1)\n",
    "            result = op1\n",
    "            prev_operator = None\n",
    "            for _ in range(composition_level):\n",
    "                operator = random.choices(operators, weights=operator_weights, k=1)[0]\n",
    "\n",
    "                if operator == '/':\n",
    "                    # Make sure we get an integer result when performing division\n",
    "                    divisors = [i for i in range(1, result + 1) if result % i == 0]\n",
    "                    op2 = random.choice(divisors if divisors else [1])\n",
    "                else:\n",
    "                    op2 = get_operand_with_cumulative_digits(current_digits)\n",
    "\n",
    "                expression += f\" {operator} {op2}\"\n",
    "\n",
    "                if operator == '+':\n",
    "                    result += op2\n",
    "                elif operator == '-':\n",
    "                    result -= op2\n",
    "                elif operator == '*':\n",
    "                    result *= op2\n",
    "                elif operator == '/':\n",
    "                    result = result // op2  # This should always be an integer\n",
    "\n",
    "                prev_operator = operator\n",
    "\n",
    "            examples.append(f\"{expression.strip()} = {result}\")  # result is always an integer\n",
    "\n",
    "    return \"Some basic calculations: \" + ', '.join(examples)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# in_context_chain_of_thought,\n",
    "the one to be edited by removing context of the demonstrations to simpler demonstrations."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "in_context = None\n",
    "with open('in_context_chain_of_thought.pkl', 'rb') as f:\n",
    "    in_context=pickle.load(f)\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "207135"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(in_context)#[1]['question']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "import re\n",
    "\n",
    "def evaluate_expression(expression):\n",
    "    # This function evaluates the expression within << >> delimiters\n",
    "    # and returns the result as an integer.\n",
    "    try:\n",
    "        parts = expression.strip('<<>>').split('=')\n",
    "        return int(parts[-1])\n",
    "    except:\n",
    "        return None\n",
    "\n",
    "def get_gold_query_digits(query_string):\n",
    "    # Initialize variables to keep track of the highest digit number\n",
    "    highest_digit_number = 0\n",
    "    current_number = 0\n",
    "\n",
    "    # Regular expression pattern to match numbers in the query string\n",
    "    number_pattern = r'\\d+(?:\\.\\d+)?'\n",
    "\n",
    "    # Iterate through the query string\n",
    "    for match in re.finditer(number_pattern, query_string):\n",
    "        number = match.group(0)\n",
    "        number_value = float(number)\n",
    "\n",
    "        # Evaluate the expressions within << >> delimiters\n",
    "        if '<<' in number and '>>' in number:\n",
    "            evaluated = evaluate_expression(number)\n",
    "            if evaluated is not None:\n",
    "                number_value = evaluated\n",
    "\n",
    "        # Keep track of the highest digit number found so far\n",
    "        if number_value.is_integer():\n",
    "            integer_part = int(number_value)\n",
    "            if integer_part > highest_digit_number:\n",
    "                highest_digit_number = integer_part\n",
    "\n",
    "        # Keep track of the number of digits in the current number\n",
    "        num_digits = len(number) if '.' not in number else len(number.split('.')[0])\n",
    "        current_number = max(current_number, num_digits)\n",
    "        if current_number>4:\n",
    "            current_number=3\n",
    "            \n",
    "\n",
    "    # Return the number of digits in the highest digit number\n",
    "    return {'num_digits': current_number} \n",
    "\n",
    "##might need to make one for testing based on question itself?\n",
    "\n",
    "def generate_context(params):\n",
    "    svamp_context = svamp.get_context(3)#indexes = generate_unique_random_integers(min(number, len(df)), 1, len(df)-1)\n",
    "\n",
    "    asdiv_context = asdiv.get_context(3)#, filter_column='groups', filter_value=params['group'])\n",
    "    mathqa_context = mathqa.get_context(2)  # We assume the MathQA context does not depend on the query parameters\n",
    "    basicmath_context = generate_composed_arithmetic_examples(25, num_digits=params['num_digits'])\n",
    "\n",
    "    context = basicmath_context + ' ' + svamp_context + ' ' + asdiv_context + ' ' + mathqa_context\n",
    "    \n",
    "    return context\n",
    "\n",
    "def trim_context(context, query):\n",
    "    total_length = len(context) + len(query[\"Q\"]) + len(\"Let's think step by step\")\n",
    "    if total_length > 2500:\n",
    "        # Split the context into individual examples\n",
    "        examples = context.split('Q:')\n",
    "        # We assume that the examples are ordered by priority (MathQA > SVAMP > ASDiv)\n",
    "        i=0\n",
    "        while total_length > 2500 and examples:\n",
    "            # Remove the first example\n",
    "            if(i==1): \n",
    "                examples.pop(1)\n",
    "            else:\n",
    "                examples.pop(-1)\n",
    "            i=i+1\n",
    "            # Recalculate the total length\n",
    "            total_length = sum(len(example) for example in examples) + len(query[\"Q\"]) + len(\"Let's think step by step\")\n",
    "        context = 'Q:'.join(examples)\n",
    "    return 'Q:'+context if context[0:2]!='Q:' else context\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "207135\n",
      "0\n",
      "1000\n",
      "2000\n",
      "3000\n",
      "4000\n",
      "5000\n",
      "6000\n",
      "7000\n",
      "8000\n",
      "9000\n",
      "10000\n",
      "11000\n",
      "12000\n",
      "13000\n",
      "14000\n",
      "15000\n",
      "16000\n",
      "17000\n",
      "18000\n",
      "19000\n",
      "20000\n",
      "21000\n",
      "22000\n",
      "23000\n",
      "24000\n",
      "25000\n",
      "26000\n",
      "27000\n",
      "28000\n",
      "29000\n",
      "30000\n",
      "31000\n",
      "32000\n",
      "33000\n",
      "34000\n",
      "35000\n",
      "36000\n",
      "37000\n",
      "38000\n",
      "39000\n",
      "40000\n",
      "41000\n",
      "42000\n",
      "43000\n",
      "44000\n",
      "45000\n",
      "46000\n",
      "47000\n",
      "48000\n",
      "49000\n",
      "50000\n",
      "51000\n",
      "52000\n",
      "53000\n",
      "54000\n",
      "55000\n",
      "56000\n",
      "57000\n",
      "58000\n",
      "59000\n",
      "60000\n",
      "61000\n",
      "62000\n",
      "63000\n",
      "64000\n",
      "65000\n",
      "66000\n",
      "67000\n",
      "68000\n",
      "69000\n",
      "70000\n",
      "71000\n",
      "72000\n",
      "73000\n",
      "74000\n",
      "75000\n",
      "76000\n",
      "77000\n",
      "78000\n",
      "79000\n",
      "80000\n",
      "81000\n",
      "82000\n",
      "83000\n",
      "84000\n",
      "85000\n",
      "86000\n",
      "87000\n",
      "88000\n",
      "89000\n",
      "90000\n",
      "91000\n",
      "92000\n",
      "93000\n",
      "94000\n",
      "95000\n",
      "96000\n",
      "97000\n",
      "98000\n",
      "99000\n",
      "100000\n",
      "101000\n",
      "102000\n",
      "103000\n",
      "104000\n",
      "105000\n",
      "106000\n",
      "107000\n",
      "108000\n",
      "109000\n",
      "110000\n",
      "111000\n",
      "112000\n",
      "113000\n",
      "114000\n",
      "115000\n",
      "116000\n",
      "117000\n",
      "118000\n",
      "119000\n",
      "120000\n",
      "121000\n",
      "122000\n",
      "123000\n",
      "124000\n",
      "125000\n",
      "126000\n",
      "127000\n",
      "128000\n",
      "129000\n",
      "130000\n",
      "131000\n",
      "132000\n",
      "133000\n",
      "134000\n",
      "135000\n",
      "136000\n",
      "137000\n",
      "138000\n",
      "139000\n",
      "140000\n",
      "141000\n",
      "142000\n",
      "143000\n",
      "144000\n",
      "145000\n",
      "146000\n",
      "147000\n",
      "148000\n",
      "149000\n",
      "150000\n",
      "151000\n",
      "152000\n",
      "153000\n",
      "154000\n",
      "155000\n",
      "156000\n",
      "157000\n",
      "158000\n",
      "159000\n",
      "160000\n",
      "161000\n",
      "162000\n",
      "163000\n",
      "164000\n",
      "165000\n",
      "166000\n",
      "167000\n",
      "168000\n",
      "169000\n",
      "170000\n",
      "171000\n",
      "172000\n",
      "173000\n",
      "174000\n",
      "175000\n",
      "176000\n",
      "177000\n",
      "178000\n",
      "179000\n",
      "180000\n",
      "181000\n",
      "182000\n",
      "183000\n",
      "184000\n",
      "185000\n",
      "186000\n",
      "187000\n",
      "188000\n",
      "189000\n",
      "190000\n",
      "191000\n",
      "192000\n",
      "193000\n",
      "194000\n",
      "195000\n",
      "196000\n",
      "197000\n",
      "198000\n",
      "199000\n",
      "200000\n",
      "201000\n",
      "202000\n",
      "203000\n",
      "204000\n",
      "205000\n",
      "206000\n",
      "207000\n"
     ]
    }
   ],
   "source": [
    "svamp = SVAMP()\n",
    "asdiv = ASDiv()\n",
    "mathqa = MathQA()\n",
    "x=1000\n",
    "print(len(in_context))\n",
    "for i in range(len(in_context)):\n",
    "    context_dict = in_context[i]\n",
    "    # Extract the last question\n",
    "    last_question = context_dict['question'].split('Q:')[-1].strip()\n",
    "\n",
    "    # Extract the gold answer\n",
    "    gold_answer = context_dict['answer_gold']\n",
    "    \n",
    "    # Combine them into a new dictionary\n",
    "    query_dict = {\n",
    "        'Q': last_question,\n",
    "        'Gold_Answer': gold_answer\n",
    "    }\n",
    "\n",
    "    # Instantiate each data class\n",
    "\n",
    "    # Define the query\n",
    "    query = query_dict\n",
    "    # Parse the query\n",
    "    params = get_gold_query_digits(query['Q'])\n",
    "    # Generate and trim the context\n",
    "    context = generate_context(params)\n",
    "    context = trim_context(context, query)\n",
    "\n",
    "    # Add the query and \"Let's think step by step\" to the end of the context\n",
    "    context += \"\\n\" +\"Q: \"+ query[\"Q\"]\n",
    "#     print(in_context[0][i]['question'])\n",
    "    in_context[i]['question']=context\n",
    "    # Print the final context\n",
    "#     print(context)\n",
    "    if i % x==0: \n",
    "        print(i)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'question': 'Q:Some basic calculations: 2 * 1 / 2 = 1, 7 * 4 = 28, 2 + 2 = 4, 4 * 8 / 32 = 1, 8 * 1 = 8, 3 / 1 + 6 = 9, 6 - 2 = 4, 9 - 4 / 5 = 1, 4 * 3 = 12, 8 - 9 = -1, 2 / 1 / 2 = 1, 8 - 3 / 1 = 5, 83 / 1 - 83 = 0, 78 - 21 + 36 = 93, 34 * 56 + 67 = 1971, 75 - 52 = 23, 74 + 37 * 64 = 7104, 25 - 45 - 15 = -35, 79 / 1 * 16 = 1264, 94 / 47 * 99 = 198, 13 / 1 = 13, 70 + 60 * 65 = 8450, 34 * 85 = 2890, 44 - 36 * 62 = 496 Q: Paul got a box of 457 erasers and 617 crayons for his birthday. At the end of the school year he only had 523 crayons left while not having lost a single eraser. How many more crayons than erasers did he have left? Answer: ( 523.0 - 457.0 )=66.0\\nQ: Paco had 39 sweet cookies and 6 salty cookies. He ate 23 salty cookies and 32 sweet cookies. How many more sweet cookies than salty cookies did he eat? Answer: ( 32.0 - 23.0 )=9.0\\nQ: Zachary did 53 push-ups and 14 crunches in gym class today. David did 17 more push-ups but 10 less crunches than zachary. How many push-ups and crunches did Zachary do? Answer: ( 53.0 + 14.0 )=67.0\\n Q: Larry\\'s Lawn Care charges nine bucks to trim a hedge. If Henry has three hedges, how much money would he spend?\\n Since, 9*3=27. The answer is 27 (bucks)\\nQ: A farmer planted 2 tomato seeds, 10 cucumber seeds and 3 pumpkin seeds. How many seeds did he plant all together?\\n Since, 2+10+3=15. The answer is 15 (seeds)\\nQ: There were 51 geese in the farmer\\'s field. 28 of the geese flew away. How many geese were left in the field?\\n Since, 51-28=23. The answer is 23 (geese)\\n Q: the total price of a kilogram each of shimla apples and red delicious apples is 250 rupees more than the total price of a kilogram each of red delicious apples and fuji apples . fuji apples is how much cheaper than shimla apples ? Lets think step by step. ( shimla + red delicious ) - ( red delicious + fuji ) = 250 shimla - fuji = 250 The answer is 250\\nQ: in a group of cows and hens , the number of legs are 12 more than twice the number of heads . the number of cows is : Lets think step by step. \"let no of cows be x , no of hens be y . so heads = x + y legs = 4 x + 2 y now , 4 x + 2 y = 2 ( x + y ) + 12 2 x = 12 x = 6 .  The answer is 6\\n\\nQ: A movie theater can hold 50 people at a time.  They charge $8.00 a ticket.  On a Tuesday night they only sold 24 tickets.  By not selling out, how much money did they lose?\\nLet\\'s think step by step',\n",
       " 'answer': 'The movie theater can hold 50 people at a time.\\nIf they sell 50 tickets then they will make 50 * $8 = $400\\nIf they sell 24 tickets, then they will make 24 * $8 = $192\\nThey lost $400 - $192 = $208\\nThe answer is 208',\n",
       " 'answer_gold': 'Answer: The theater can hold 50 people and they charge $8.00 to watch a movie so on a sold-out night they make 50*8 = $<<50*8=400.00>>400.00\\nOn Tuesday night they only sold 24 tickets at $8.00 apiece so they made 24*8 = $<<24*8=192.00>>192.00\\nIf they make $400.00 on a sold-out night and they only made $192.00 on Tuesday then they lost 400-192 = $<<400-192=208.00>>208.00\\n#### 208\\n',\n",
       " 'per_step_probs': [{'▁The': 1.0},\n",
       "  OrderedDict([('▁movie', 0.3198),\n",
       "               ('▁theater', 0.3408),\n",
       "               ('▁total', 0.0488),\n",
       "               ('▁capacity', 0.0429),\n",
       "               ('▁maximum', 0.0625)]),\n",
       "  OrderedDict([('▁theater', 0.9139),\n",
       "               ('▁theatre', 0.0742),\n",
       "               ('▁the', 0.0016),\n",
       "               ('▁can', 0.0016)]),\n",
       "  OrderedDict([('▁can', 0.5094),\n",
       "               ('▁charges', 0.0528),\n",
       "               ('▁has', 0.1055),\n",
       "               ('▁holds', 0.0356),\n",
       "               ('▁could', 0.0699)]),\n",
       "  OrderedDict([('▁hold', 0.8542),\n",
       "               ('▁sell', 0.0576),\n",
       "               ('▁fit', 0.0201),\n",
       "               ('▁seat', 0.0283),\n",
       "               ('▁only', 0.0092)]),\n",
       "  OrderedDict([('▁50', 0.9021),\n",
       "               ('▁24', 0.0052),\n",
       "               ('▁up', 0.0245),\n",
       "               ('▁$', 0.0046)]),\n",
       "  OrderedDict([('▁people', 0.9368),\n",
       "               ('▁tickets', 0.0114),\n",
       "               ('▁seats', 0.0123),\n",
       "               ('*', 0.0065)]),\n",
       "  OrderedDict([('▁at', 0.4683), (',', 0.165), ('.', 0.1946), ('▁and', 0.037)]),\n",
       "  {'▁': 1.0},\n",
       "  {'a': 1.0},\n",
       "  OrderedDict([('▁time', 0.9983),\n",
       "               ('▁given', 0.0005),\n",
       "               ('▁single', 0.0001),\n",
       "               ('▁maximum', 0.0001),\n",
       "               ('▁$', 0.0003)]),\n",
       "  {'.': 1.0},\n",
       "  {'▁If': 1.0},\n",
       "  OrderedDict([('▁they', 0.6026),\n",
       "               ('▁24', 0.0308),\n",
       "               ('▁50', 0.0413),\n",
       "               ('▁the', 0.1303),\n",
       "               ('▁it', 0.0436)]),\n",
       "  OrderedDict([('▁sell', 0.1894),\n",
       "               ('▁sold', 0.2714),\n",
       "               ('▁charge', 0.1827),\n",
       "               ('▁had', 0.0385),\n",
       "               ('▁only', 0.0599)]),\n",
       "  OrderedDict([('▁50', 0.2645),\n",
       "               ('▁24', 0.1115),\n",
       "               ('▁all', 0.2348),\n",
       "               ('▁out', 0.2149),\n",
       "               ('▁tickets', 0.034)]),\n",
       "  OrderedDict([('▁tickets', 0.9519),\n",
       "               ('▁people', 0.0094),\n",
       "               ('▁seats', 0.0055),\n",
       "               ('*', 0.0055)]),\n",
       "  OrderedDict([('▁then', 0.0808),\n",
       "               (',', 0.4629),\n",
       "               ('▁for', 0.0489),\n",
       "               ('▁at', 0.2034),\n",
       "               ('▁they', 0.0985)]),\n",
       "  OrderedDict([('▁they', 0.8598),\n",
       "               ('▁each', 0.0117),\n",
       "               ('▁the', 0.0521),\n",
       "               ('▁that', 0.0165),\n",
       "               ('▁their', 0.0207)]),\n",
       "  OrderedDict([('▁will', 0.3053),\n",
       "               ('▁earn', 0.0858),\n",
       "               ('▁would', 0.0839),\n",
       "               ('▁make', 0.2389)]),\n",
       "  OrderedDict([('▁make', 0.4809),\n",
       "               ('▁receive', 0.0284),\n",
       "               ('▁sell', 0.0471),\n",
       "               ('▁have', 0.0891),\n",
       "               ('▁earn', 0.2222)]),\n",
       "  OrderedDict([('▁50', 0.518),\n",
       "               ('▁money', 0.0044),\n",
       "               ('▁$', 0.3976),\n",
       "               ('▁8', 0.036)]),\n",
       "  {'▁*': 1.0},\n",
       "  {'▁$8': 1.0},\n",
       "  {'▁=': 1.0},\n",
       "  {'▁$400': 1.0},\n",
       "  {'▁If': 1.0},\n",
       "  OrderedDict([('▁they', 0.9317),\n",
       "               ('▁24', 0.0045),\n",
       "               ('▁the', 0.0267),\n",
       "               ('▁on', 0.0076),\n",
       "               ('▁only', 0.0055)]),\n",
       "  OrderedDict([('▁sell', 0.5155),\n",
       "               ('▁sold', 0.1216),\n",
       "               ('▁do', 0.0273),\n",
       "               ('▁only', 0.2502),\n",
       "               ('▁don', 0.0369)]),\n",
       "  OrderedDict([('▁24', 0.7404),\n",
       "               ('▁less', 0.0903),\n",
       "               ('▁out', 0.016),\n",
       "               ('▁only', 0.0526)]),\n",
       "  OrderedDict([('▁tickets', 0.962),\n",
       "               ('▁instead', 0.0025),\n",
       "               ('▁out', 0.004),\n",
       "               ('▁people', 0.0051),\n",
       "               ('▁then', 0.0126)]),\n",
       "  OrderedDict([(',', 0.1262),\n",
       "               ('▁instead', 0.0276),\n",
       "               ('▁on', 0.0127),\n",
       "               ('▁they', 0.0872),\n",
       "               ('▁then', 0.7146)]),\n",
       "  OrderedDict([('▁then', 0.642),\n",
       "               ('▁however', 0.0164),\n",
       "               ('▁the', 0.0113),\n",
       "               ('▁that', 0.0142),\n",
       "               ('▁they', 0.2641)]),\n",
       "  OrderedDict([('▁they', 0.8712),\n",
       "               ('▁24', 0.0133),\n",
       "               ('▁the', 0.0365),\n",
       "               ('▁that', 0.0114),\n",
       "               ('▁there', 0.0121)]),\n",
       "  OrderedDict([('▁will', 0.504),\n",
       "               ('▁sold', 0.0423),\n",
       "               ('▁lose', 0.0353),\n",
       "               ('▁only', 0.0908),\n",
       "               ('▁make', 0.0566)]),\n",
       "  OrderedDict([('▁make', 0.7247),\n",
       "               ('▁sell', 0.0236),\n",
       "               ('▁have', 0.0139),\n",
       "               ('▁lose', 0.0142),\n",
       "               ('▁only', 0.183)]),\n",
       "  OrderedDict([('▁24', 0.902),\n",
       "               ('▁less', 0.0039),\n",
       "               ('▁only', 0.0249),\n",
       "               ('▁$', 0.0516)]),\n",
       "  {'▁*': 1.0},\n",
       "  {'▁$8': 1.0},\n",
       "  OrderedDict([('▁=', 0.9916), (',', 0.0002), ('.', 0.0011), ('=', 0.0051)]),\n",
       "  {'▁$1': 1.0},\n",
       "  {'92': 1.0},\n",
       "  {'▁They': 1.0},\n",
       "  OrderedDict([('▁lost', 0.3884),\n",
       "               ('▁lose', 0.0639),\n",
       "               ('▁will', 0.0968),\n",
       "               ('▁would', 0.046),\n",
       "               ('▁did', 0.0879)]),\n",
       "  {'▁$400': 1.0},\n",
       "  {'▁': 1.0},\n",
       "  {'-': 1.0},\n",
       "  {'▁$1': 1.0},\n",
       "  {'92': 1.0},\n",
       "  {'▁=': 1.0},\n",
       "  {'▁$20': 1.0},\n",
       "  {'8': 1.0},\n",
       "  {'▁The': 1.0},\n",
       "  OrderedDict([('▁answer', 0.9909),\n",
       "               ('▁theater', 0.0019),\n",
       "               ('▁amount', 0.002),\n",
       "               ('▁Answer', 0.0006),\n",
       "               ('▁movie', 0.0009)]),\n",
       "  OrderedDict([('▁is', 0.9986),\n",
       "               (':', 0.0001),\n",
       "               ('▁in', 0.0001),\n",
       "               ('▁$', 0.0001)]),\n",
       "  {'▁': 1.0},\n",
       "  {'208': 1.0}],\n",
       " 'answer_label': 1,\n",
       " 'type': 'in_context_chain_of_thought'}"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#investigate in_ontext first\n",
    "in_context[200001]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "with open('in_context_chain_of_thought.pkl', 'wb') as f:\n",
    "    pickle.dump(in_context, f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "in_context = []\n",
    "with open('in_context_chain_of_thought.pkl', 'rb') as f:\n",
    "    try:\n",
    "        while True:\n",
    "            in_context.append(pickle.load(f))\n",
    "    except EOFError:\n",
    "        pass"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "ename": "NameError",
     "evalue": "name 'in_context' is not defined",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[1;31mNameError\u001b[0m                                 Traceback (most recent call last)",
      "\u001b[1;32m<ipython-input-3-47a44e71465a>\u001b[0m in \u001b[0;36m<module>\u001b[1;34m\u001b[0m\n\u001b[1;32m----> 1\u001b[1;33m \u001b[0mlen\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0min_context\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;36m0\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[1;31mNameError\u001b[0m: name 'in_context' is not defined"
     ]
    }
   ],
   "source": [
    "len(in_context[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "num_layers=12\n",
    "device_map = {0: list(range(0, num_layers//2)), 1: list(range(num_layers//2, num_layers))}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{0: [0, 1, 2, 3, 4, 5], 1: [6, 7, 8, 9, 10, 11]}"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "device_map"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:csn] *",
   "language": "python",
   "name": "conda-env-csn-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
