{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "ce427773",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import random\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from transformers import AutoTokenizer, LlamaForCausalLM, set_seed"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 102,
   "id": "abb3bad4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seed Python's `random` (used for prompt sampling) together with numpy and\n",
    "# torch, so torch-based randomness is reproducible under Restart & Run All too.\n",
    "# set_seed calls random.seed(SEED), so the prompt sampling matches random.seed(2024).\n",
    "SEED = 2024\n",
    "set_seed(SEED)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "id": "40499a7d",
   "metadata": {},
   "outputs": [],
   "source": [
    "MIN_POPULATION = 1000000  # keep only countries with more than 1M people in 2022\n",
    "\n",
    "# One row per country: its name, capital and 2022 population.\n",
    "c = (\n",
    "    pd.read_csv(\"world_population.csv\")\n",
    "    [['Country/Territory', 'Capital', '2022 Population']]\n",
    "    .dropna()\n",
    "    .rename(columns = {'Country/Territory': 'country',\n",
    "                       'Capital': 'capital',\n",
    "                       '2022 Population': 'population'})\n",
    "    .loc[lambda df: df['population'] > MIN_POPULATION]\n",
    "    .reset_index(drop = True)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 62,
   "id": "18539e66",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hand-picked countries whose capital is presumably not their largest city.\n",
    "# Every name must match the 'country' column of `c` exactly, because\n",
    "# create_prompt_and_answer looks each sampled name up in cc_dict.\n",
    "diff_cap_larg = ['China', 'Vietnam', 'Turkey', 'Brazil', 'Taiwan',\n",
    "                'South Africa', 'Philippines', 'Ecuador', 'Cameroon',\n",
    "                'Canada', 'Pakistan', 'Myanmar', 'United Arab Emirates',\n",
    "                'Nigeria', 'Kazakhstan', 'United States', 'Morocco', 'Australia',\n",
    "                'Israel', 'Tanzania', 'Benin', 'Bolivia', 'India', 'New Zealand',\n",
    "                'Ivory Coast', 'Equatorial Guinea', 'Switzerland', 'Sri Lanka',\n",
    "                'Trinidad and Tobago', 'Burundi', 'Gambia']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 103,
   "id": "bc96139b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fail fast if a hand-picked name is missing from the filtered table (a spelling\n",
    "# mismatch, or a country under the population cutoff); otherwise the cc_dict\n",
    "# lookup in create_prompt_and_answer would raise KeyError in the middle of generation.\n",
    "missing = sorted(set(diff_cap_larg) - set(c['country']))\n",
    "assert not missing, f\"diff_cap_larg countries not found in c: {missing}\"\n",
    "\n",
    "# Every other country in the table forms the comparison pool.\n",
    "diff_cap_larg_set = set(diff_cap_larg)\n",
    "others = [x for x in c['country'] if x not in diff_cap_larg_set]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 106,
   "id": "92d82aef",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(31, 129)"
      ]
     },
     "execution_count": 106,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Sizes of the two country pools: (diff_cap_larg, others)\n",
    "tuple(len(pool) for pool in (diff_cap_larg, others))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 77,
   "id": "01f4a9b1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Lookup table: country name -> capital city.\n",
    "cc_dict = dict(zip(c['country'], c['capital']))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 63,
   "id": "616bebb9",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>country</th>\n",
       "      <th>capital</th>\n",
       "      <th>population</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>Australia</td>\n",
       "      <td>Canberra</td>\n",
       "      <td>26177413</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>13</th>\n",
       "      <td>Benin</td>\n",
       "      <td>Porto-Novo</td>\n",
       "      <td>13352864</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>14</th>\n",
       "      <td>Bolivia</td>\n",
       "      <td>Sucre</td>\n",
       "      <td>12224110</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>17</th>\n",
       "      <td>Brazil</td>\n",
       "      <td>Brasilia</td>\n",
       "      <td>215313498</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>20</th>\n",
       "      <td>Burundi</td>\n",
       "      <td>Bujumbura</td>\n",
       "      <td>12889576</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>22</th>\n",
       "      <td>Cameroon</td>\n",
       "      <td>Yaounde</td>\n",
       "      <td>27914536</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>23</th>\n",
       "      <td>Canada</td>\n",
       "      <td>Ottawa</td>\n",
       "      <td>38454327</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>27</th>\n",
       "      <td>China</td>\n",
       "      <td>Beijing</td>\n",
       "      <td>1425887337</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>38</th>\n",
       "      <td>Ecuador</td>\n",
       "      <td>Quito</td>\n",
       "      <td>18001000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>41</th>\n",
       "      <td>Equatorial Guinea</td>\n",
       "      <td>Malabo</td>\n",
       "      <td>1674908</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>49</th>\n",
       "      <td>Gambia</td>\n",
       "      <td>Banjul</td>\n",
       "      <td>2705992</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>61</th>\n",
       "      <td>India</td>\n",
       "      <td>New Delhi</td>\n",
       "      <td>1417173173</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>66</th>\n",
       "      <td>Israel</td>\n",
       "      <td>Jerusalem</td>\n",
       "      <td>9038309</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>68</th>\n",
       "      <td>Ivory Coast</td>\n",
       "      <td>Yamoussoukro</td>\n",
       "      <td>28160542</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>72</th>\n",
       "      <td>Kazakhstan</td>\n",
       "      <td>Nursultan</td>\n",
       "      <td>19397998</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>92</th>\n",
       "      <td>Morocco</td>\n",
       "      <td>Rabat</td>\n",
       "      <td>37457971</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>94</th>\n",
       "      <td>Myanmar</td>\n",
       "      <td>Nay Pyi Taw</td>\n",
       "      <td>54179306</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>98</th>\n",
       "      <td>New Zealand</td>\n",
       "      <td>Wellington</td>\n",
       "      <td>5185288</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>101</th>\n",
       "      <td>Nigeria</td>\n",
       "      <td>Abuja</td>\n",
       "      <td>218541212</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>106</th>\n",
       "      <td>Pakistan</td>\n",
       "      <td>Islamabad</td>\n",
       "      <td>235824862</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>112</th>\n",
       "      <td>Philippines</td>\n",
       "      <td>Manila</td>\n",
       "      <td>115559009</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>129</th>\n",
       "      <td>South Africa</td>\n",
       "      <td>Pretoria</td>\n",
       "      <td>59893885</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>133</th>\n",
       "      <td>Sri Lanka</td>\n",
       "      <td>Colombo</td>\n",
       "      <td>21832143</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>136</th>\n",
       "      <td>Switzerland</td>\n",
       "      <td>Bern</td>\n",
       "      <td>8740472</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>138</th>\n",
       "      <td>Taiwan</td>\n",
       "      <td>Taipei</td>\n",
       "      <td>23893394</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>140</th>\n",
       "      <td>Tanzania</td>\n",
       "      <td>Dodoma</td>\n",
       "      <td>65497748</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>144</th>\n",
       "      <td>Trinidad and Tobago</td>\n",
       "      <td>Port-of-Spain</td>\n",
       "      <td>1531044</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>146</th>\n",
       "      <td>Turkey</td>\n",
       "      <td>Ankara</td>\n",
       "      <td>85341241</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>150</th>\n",
       "      <td>United Arab Emirates</td>\n",
       "      <td>Abu Dhabi</td>\n",
       "      <td>9441129</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>152</th>\n",
       "      <td>United States</td>\n",
       "      <td>Washington, D.C.</td>\n",
       "      <td>338289857</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>156</th>\n",
       "      <td>Vietnam</td>\n",
       "      <td>Hanoi</td>\n",
       "      <td>98186856</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                  country           capital  population\n",
       "6               Australia          Canberra    26177413\n",
       "13                  Benin        Porto-Novo    13352864\n",
       "14                Bolivia             Sucre    12224110\n",
       "17                 Brazil          Brasilia   215313498\n",
       "20                Burundi         Bujumbura    12889576\n",
       "22               Cameroon           Yaounde    27914536\n",
       "23                 Canada            Ottawa    38454327\n",
       "27                  China           Beijing  1425887337\n",
       "38                Ecuador             Quito    18001000\n",
       "41      Equatorial Guinea            Malabo     1674908\n",
       "49                 Gambia            Banjul     2705992\n",
       "61                  India         New Delhi  1417173173\n",
       "66                 Israel         Jerusalem     9038309\n",
       "68            Ivory Coast      Yamoussoukro    28160542\n",
       "72             Kazakhstan         Nursultan    19397998\n",
       "92                Morocco             Rabat    37457971\n",
       "94                Myanmar       Nay Pyi Taw    54179306\n",
       "98            New Zealand        Wellington     5185288\n",
       "101               Nigeria             Abuja   218541212\n",
       "106              Pakistan         Islamabad   235824862\n",
       "112           Philippines            Manila   115559009\n",
       "129          South Africa          Pretoria    59893885\n",
       "133             Sri Lanka           Colombo    21832143\n",
       "136           Switzerland              Bern     8740472\n",
       "138                Taiwan            Taipei    23893394\n",
       "140              Tanzania            Dodoma    65497748\n",
       "144   Trinidad and Tobago     Port-of-Spain     1531044\n",
       "146                Turkey            Ankara    85341241\n",
       "150  United Arab Emirates         Abu Dhabi     9441129\n",
       "152         United States  Washington, D.C.   338289857\n",
       "156               Vietnam             Hanoi    98186856"
      ]
     },
     "execution_count": 63,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Show the capitals recorded in the data for the hand-picked diff_cap_larg countries.\n",
    "c.loc[c['country'].isin(diff_cap_larg)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 65,
   "id": "53226017",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "20b4f620a2f34b21adf4038aca80ecce",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Never hardcode access tokens in a notebook: read it from the HF_TOKEN environment\n",
    "# variable instead. If it is unset, token=None lets transformers fall back to the\n",
    "# credentials stored by `huggingface-cli login`.\n",
    "# NOTE: a token that was previously committed in this cell must be revoked.\n",
    "MODEL_NAME = \"meta-llama/Llama-2-7b-hf\"\n",
    "HF_TOKEN = os.environ.get(\"HF_TOKEN\")\n",
    "\n",
    "model = LlamaForCausalLM.from_pretrained(MODEL_NAME, token = HF_TOKEN)\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, token = HF_TOKEN)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "62ac1244",
   "metadata": {},
   "source": [
    "### Creating a prompt"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "46e39297",
   "metadata": {},
   "source": [
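    "Each prompt contains 6 in-context examples of the form `<country> <capital>`, separated by commas and followed by the country being queried.\n",
    "\n",
    "To avoid ambiguity between a country's capital and its largest city, exactly 3 of the 6 examples come from `diff_cap_larg` and the other 3 from `others`.\n",
    "\n",
    "Queried country (the last item of each prompt): 500 prompts query a country drawn at random from `diff_cap_larg`, and 500 query a country drawn at random from `others`."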
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 107,
   "id": "aff7d320",
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_prompt_and_answer(type_, rng=random):\n",
    "    \"\"\"Build one few-shot country -> capital prompt and its answer key.\n",
    "\n",
    "    The queried country is drawn from the pool named by `type_`. The 3 other\n",
    "    countries sampled with it plus 3 countries from the opposite pool form the\n",
    "    6 shuffled in-context examples, so every prompt contains exactly 3\n",
    "    examples from `diff_cap_larg`.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    type_ : {'diff_cap_larg', 'others'}\n",
    "        Pool the queried country is drawn from.\n",
    "    rng : random.Random or the `random` module, default `random`\n",
    "        Source of randomness (needs `sample` and `shuffle`). The default is\n",
    "        the module-level generator seeded at the top of the notebook.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    (str, str)\n",
    "        The prompt -- six '<country> <capital>, ' examples followed by the\n",
    "        queried country -- and the queried country's capital.\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    ValueError\n",
    "        If `type_` is neither 'diff_cap_larg' nor 'others'.\n",
    "    \"\"\"\n",
    "    # (pool for the queried country, pool for the 3 contrasting examples)\n",
    "    pools = {'diff_cap_larg': (diff_cap_larg, others),\n",
    "             'others': (others, diff_cap_larg)}\n",
    "    if type_ not in pools:\n",
    "        raise ValueError(f\"type_ must be 'diff_cap_larg' or 'others', got {type_!r}\")\n",
    "    query_pool, contrast_pool = pools[type_]\n",
    "\n",
    "    # Keep this order of random calls (sample 4, sample 3, shuffle): changing it\n",
    "    # would change which prompts a given seed produces.\n",
    "    drawn = rng.sample(query_pool, 4)\n",
    "    country_qn = drawn[0]\n",
    "    countries_prompt = drawn[1:] + rng.sample(contrast_pool, 3)\n",
    "    rng.shuffle(countries_prompt)\n",
    "\n",
    "    # Loop variable is `country`, not `c`, to avoid shadowing the DataFrame `c`.\n",
    "    sentence = ''.join(f'{country} {cc_dict[country]}, ' for country in countries_prompt)\n",
    "    sentence += country_qn\n",
    "\n",
    "    return (sentence, cc_dict[country_qn])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 128,
   "id": "cf3c89c9",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Starting 0\n",
      "Starting 10\n",
      "Starting 20\n",
      "Starting 30\n",
      "Starting 40\n",
      "Starting 50\n",
      "Starting 60\n",
      "Starting 70\n",
      "Starting 80\n",
      "Starting 90\n",
      "Starting 100\n",
      "Starting 110\n",
      "Starting 120\n",
      "Starting 130\n",
      "Starting 140\n",
      "Starting 150\n",
      "Starting 160\n",
      "Starting 170\n",
      "Starting 180\n",
      "Starting 190\n",
      "Starting 200\n",
      "Starting 210\n",
      "Starting 220\n",
      "Starting 230\n",
      "Starting 240\n",
      "Starting 250\n",
      "Starting 260\n",
      "Starting 270\n",
      "Starting 280\n",
      "Starting 290\n",
      "Starting 300\n",
      "Starting 310\n",
      "Starting 320\n",
      "Starting 330\n",
      "Starting 340\n",
      "Starting 350\n",
      "Starting 360\n",
      "Starting 370\n",
      "Starting 380\n",
      "Starting 390\n",
      "Starting 400\n",
      "Starting 410\n",
      "Starting 420\n",
      "Starting 430\n",
      "Starting 440\n",
      "Starting 450\n",
      "Starting 460\n",
      "Starting 470\n",
      "Starting 480\n",
      "Starting 490\n"
     ]
    }
   ],
   "source": [
    "###### DIFF_CAP_LARG QUESTIONS\n",
    "\n",
    "dcl_prompts = []\n",
    "dcl_keys = []\n",
    "dcl_outputs = []\n",
    "\n",
    "for i in range(500):  # 500 prompts whose queried country is in diff_cap_larg\n",
    "\n",
    "    if i % 10 == 0:\n",
    "        print(f'Starting {i}')\n",
    "\n",
    "    prompt, key = create_prompt_and_answer(type_ = \"diff_cap_larg\")\n",
    "    # Put the encoded prompt on the model's device (a no-op when running on CPU).\n",
    "    inputs = tokenizer(prompt, return_tensors=\"pt\").to(model.device)\n",
    "    length = inputs['input_ids'].shape[1]\n",
    "    # Pass the attention mask along with the ids (**inputs), generate at most 12 new\n",
    "    # tokens, and decode greedily so results do not depend on whether the\n",
    "    # checkpoint's generation_config enables sampling.\n",
    "    generate_ids = model.generate(**inputs, max_new_tokens = 12, do_sample = False)\n",
    "    # Decode only the continuation, not the echoed prompt.\n",
    "    output = tokenizer.batch_decode(generate_ids[:, length:], skip_special_tokens=True,\n",
    "                                    clean_up_tokenization_spaces=False)[0]\n",
    "\n",
    "    dcl_prompts.append(prompt)\n",
    "    dcl_keys.append(key)\n",
    "    dcl_outputs.append(output)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 129,
   "id": "c475cdc9",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Starting 0\n",
      "Starting 10\n",
      "Starting 20\n",
      "Starting 30\n",
      "Starting 40\n",
      "Starting 50\n",
      "Starting 60\n",
      "Starting 70\n",
      "Starting 80\n",
      "Starting 90\n",
      "Starting 100\n",
      "Starting 110\n",
      "Starting 120\n",
      "Starting 130\n",
      "Starting 140\n",
      "Starting 150\n",
      "Starting 160\n",
      "Starting 170\n",
      "Starting 180\n",
      "Starting 190\n",
      "Starting 200\n",
      "Starting 210\n",
      "Starting 220\n",
      "Starting 230\n",
      "Starting 240\n",
      "Starting 250\n",
      "Starting 260\n",
      "Starting 270\n",
      "Starting 280\n",
      "Starting 290\n",
      "Starting 300\n",
      "Starting 310\n",
      "Starting 320\n",
      "Starting 330\n",
      "Starting 340\n",
      "Starting 350\n",
      "Starting 360\n",
      "Starting 370\n",
      "Starting 380\n",
      "Starting 390\n",
      "Starting 400\n",
      "Starting 410\n",
      "Starting 420\n",
      "Starting 430\n",
      "Starting 440\n",
      "Starting 450\n",
      "Starting 460\n",
      "Starting 470\n",
      "Starting 480\n",
      "Starting 490\n"
     ]
    }
   ],
   "source": [
    "###### OTHERS QUESTIONS\n",
    "\n",
    "oth_prompts = []\n",
    "oth_keys = []\n",
    "oth_outputs = []\n",
    "\n",
    "for i in range(500):  # 500 prompts whose queried country is in others\n",
    "\n",
    "    if i % 10 == 0:\n",
    "        print(f'Starting {i}')\n",
    "\n",
    "    prompt, key = create_prompt_and_answer(type_ = \"others\")\n",
    "    # Put the encoded prompt on the model's device (a no-op when running on CPU).\n",
    "    inputs = tokenizer(prompt, return_tensors=\"pt\").to(model.device)\n",
    "    length = inputs['input_ids'].shape[1]\n",
    "    # Pass the attention mask along with the ids (**inputs), generate at most 12 new\n",
    "    # tokens, and decode greedily so results do not depend on whether the\n",
    "    # checkpoint's generation_config enables sampling.\n",
    "    generate_ids = model.generate(**inputs, max_new_tokens = 12, do_sample = False)\n",
    "    # Decode only the continuation, not the echoed prompt.\n",
    "    output = tokenizer.batch_decode(generate_ids[:, length:], skip_special_tokens=True,\n",
    "                                    clean_up_tokenization_spaces=False)[0]\n",
    "\n",
    "    oth_prompts.append(prompt)\n",
    "    oth_keys.append(key)\n",
    "    oth_outputs.append(output)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 130,
   "id": "5b1a6633",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Collect the diff_cap_larg prompts, answer keys and model outputs into one table.\n",
    "dcl_df = pd.DataFrame({'prompts': dcl_prompts,\n",
    "                       'keys': dcl_keys,\n",
    "                       'outputs': dcl_outputs})\n",
    "\n",
    "dcl_df.to_csv('dcl_df.csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 131,
   "id": "60f035b0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Collect the others prompts, answer keys and model outputs into one table.\n",
    "oth_df = pd.DataFrame({'prompts': oth_prompts,\n",
    "                       'keys': oth_keys,\n",
    "                       'outputs': oth_outputs})\n",
    "\n",
    "oth_df.to_csv('oth_df.csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7da770c4",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
