{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "43e80dd8",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoTokenizer, LlamaForCausalLM\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import random"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "a1bce5e7",
   "metadata": {},
   "outputs": [],
   "source": [
    "random.seed(2024)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "baee4ac1",
   "metadata": {},
   "outputs": [],
   "source": [
    "c = pd.read_csv(\"us-state-capitals.csv\")\n",
    "d = pd.read_csv(\"uscities.csv\")\n",
    "\n",
    "c = c[['name', 'description']]\n",
    "c.columns = ['state', 'capital']\n",
    "c['capital'] = c['capital'].apply(lambda x: x.rstrip('<br>') if x.endswith('<br>') else x)\n",
    "\n",
    "d = d.loc[d.groupby('state_name')['population'].idxmax()][['state_name', 'city']].reset_index(drop = True)\n",
    "d.columns = ['state', 'largest']\n",
    "\n",
    "joined = pd.merge(c, d, on = 'state')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "12d39796",
   "metadata": {},
   "outputs": [],
   "source": [
    "diff_cap_larg = list(joined[joined['capital'] != joined['largest']]['state'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "d9cff494",
   "metadata": {},
   "outputs": [],
   "source": [
    "others = [x for x in joined['state'] if x not in diff_cap_larg]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "284de563",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(33, 17)"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(diff_cap_larg), len(others)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "166366de",
   "metadata": {},
   "outputs": [],
   "source": [
    "cc_dict = joined.set_index('state')['capital'].to_dict()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "a682a849",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "aaf05db2564f4daeb512f184f833223c",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "model = LlamaForCausalLM.from_pretrained(\"meta-llama/Llama-2-7b-hf\", \n",
    "                                         token = \"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\")\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(\"meta-llama/Llama-2-7b-hf\", \n",
    "                                          token = \"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\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "249b5aa3",
   "metadata": {},
   "source": [
    "### Creating a prompt"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8c3b88d3",
   "metadata": {},
   "source": [
    "Each prompt contains 6 in-context examples of country and its capital city.\n",
    "\n",
    "In order to avoid ambiguity of capital vs largest city, exactly 3 of the examples are from diff_cap_larg.\n",
    "\n",
    "Last prompt: 500 randomly selected from diff_cap_larg, 500 randomly selected from others"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "fcf156c4",
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_prompt_and_answer(type_):\n",
    "    \n",
    "    if type_ == \"diff_cap_larg\":\n",
    "        \n",
    "        temp = random.sample(diff_cap_larg, 4)\n",
    "        country_qn = temp[0]\n",
    "        countries_prompt = temp[1:] + random.sample(others, 3)\n",
    "        random.shuffle(countries_prompt)  \n",
    "        \n",
    "    elif type_ == \"others\":\n",
    "        \n",
    "        temp = random.sample(others, 4)\n",
    "        country_qn = temp[0]\n",
    "        countries_prompt = temp[1:] + random.sample(diff_cap_larg, 3)\n",
    "        random.shuffle(countries_prompt)  \n",
    "\n",
    "    sentence = ''\n",
    "    \n",
    "    for c in countries_prompt:\n",
    "        \n",
    "        sentence += c\n",
    "        sentence += ' '\n",
    "        sentence += cc_dict[c]\n",
    "        sentence += ', '\n",
    "        \n",
    "    sentence += country_qn\n",
    "    \n",
    "    return (sentence, cc_dict[country_qn])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "3729d848",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Starting 0\n",
      "Starting 10\n",
      "Starting 20\n",
      "Starting 30\n",
      "Starting 40\n",
      "Starting 50\n",
      "Starting 60\n",
      "Starting 70\n",
      "Starting 80\n",
      "Starting 90\n",
      "Starting 100\n",
      "Starting 110\n",
      "Starting 120\n",
      "Starting 130\n",
      "Starting 140\n",
      "Starting 150\n",
      "Starting 160\n",
      "Starting 170\n",
      "Starting 180\n",
      "Starting 190\n",
      "Starting 200\n",
      "Starting 210\n",
      "Starting 220\n",
      "Starting 230\n",
      "Starting 240\n",
      "Starting 250\n",
      "Starting 260\n",
      "Starting 270\n",
      "Starting 280\n",
      "Starting 290\n",
      "Starting 300\n",
      "Starting 310\n",
      "Starting 320\n",
      "Starting 330\n",
      "Starting 340\n",
      "Starting 350\n",
      "Starting 360\n",
      "Starting 370\n",
      "Starting 380\n",
      "Starting 390\n",
      "Starting 400\n",
      "Starting 410\n",
      "Starting 420\n",
      "Starting 430\n",
      "Starting 440\n",
      "Starting 450\n",
      "Starting 460\n",
      "Starting 470\n",
      "Starting 480\n",
      "Starting 490\n"
     ]
    }
   ],
   "source": [
    "###### DIFF_CAP_LARG QUESTIONS\n",
    "\n",
    "dcl_prompts = []\n",
    "dcl_keys = []\n",
    "dcl_outputs = []\n",
    "\n",
    "for i in range(500):\n",
    "    \n",
    "    if i % 10 == 0:\n",
    "        print('Starting ' + str(i))\n",
    "    \n",
    "    prompt, key = create_prompt_and_answer(type_ = \"diff_cap_larg\")\n",
    "    inputs = tokenizer(prompt, return_tensors=\"pt\")\n",
    "    length = len(inputs['input_ids'][0])\n",
    "    generate_ids = model.generate(inputs.input_ids, max_length = length + 12)\n",
    "    output = tokenizer.batch_decode(generate_ids[:,length:], skip_special_tokens=True, \n",
    "                                    clean_up_tokenization_spaces=False)[0]\n",
    "    \n",
    "    dcl_prompts.append(prompt)\n",
    "    dcl_keys.append(key)\n",
    "    dcl_outputs.append(output)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "12f24fb5",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Starting 0\n",
      "Starting 10\n",
      "Starting 20\n",
      "Starting 30\n",
      "Starting 40\n",
      "Starting 50\n",
      "Starting 60\n",
      "Starting 70\n",
      "Starting 80\n",
      "Starting 90\n",
      "Starting 100\n",
      "Starting 110\n",
      "Starting 120\n",
      "Starting 130\n",
      "Starting 140\n",
      "Starting 150\n",
      "Starting 160\n",
      "Starting 170\n",
      "Starting 180\n",
      "Starting 190\n",
      "Starting 200\n",
      "Starting 210\n",
      "Starting 220\n",
      "Starting 230\n",
      "Starting 240\n",
      "Starting 250\n",
      "Starting 260\n",
      "Starting 270\n",
      "Starting 280\n",
      "Starting 290\n",
      "Starting 300\n",
      "Starting 310\n",
      "Starting 320\n",
      "Starting 330\n",
      "Starting 340\n",
      "Starting 350\n",
      "Starting 360\n",
      "Starting 370\n",
      "Starting 380\n",
      "Starting 390\n",
      "Starting 400\n",
      "Starting 410\n",
      "Starting 420\n",
      "Starting 430\n",
      "Starting 440\n",
      "Starting 450\n",
      "Starting 460\n",
      "Starting 470\n",
      "Starting 480\n",
      "Starting 490\n"
     ]
    }
   ],
   "source": [
    "###### OTHERS QUESTIONS\n",
    "\n",
    "oth_prompts = []\n",
    "oth_keys = []\n",
    "oth_outputs = []\n",
    "\n",
    "for i in range(500):\n",
    "    \n",
    "    if i % 10 == 0:\n",
    "        print('Starting ' + str(i))\n",
    "    \n",
    "    prompt, key = create_prompt_and_answer(type_ = \"others\")\n",
    "    inputs = tokenizer(prompt, return_tensors=\"pt\")\n",
    "    length = len(inputs['input_ids'][0])\n",
    "    generate_ids = model.generate(inputs.input_ids, max_length = length + 12)\n",
    "    output = tokenizer.batch_decode(generate_ids[:,length:], skip_special_tokens=True, \n",
    "                                    clean_up_tokenization_spaces=False)[0]\n",
    "    \n",
    "    oth_prompts.append(prompt)\n",
    "    oth_keys.append(key)\n",
    "    oth_outputs.append(output)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "d0232b1f",
   "metadata": {},
   "outputs": [],
   "source": [
    "dcl_df = pd.DataFrame(columns = ['prompts', 'keys', 'outputs'])\n",
    "dcl_df['prompts'] = dcl_prompts\n",
    "dcl_df['keys'] = dcl_keys\n",
    "dcl_df['outputs'] = dcl_outputs\n",
    "\n",
    "dcl_df.to_csv('dcl_df_state.csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "4710a963",
   "metadata": {},
   "outputs": [],
   "source": [
    "oth_df = pd.DataFrame(columns = ['prompts', 'keys', 'outputs'])\n",
    "oth_df['prompts'] = oth_prompts\n",
    "oth_df['keys'] = oth_keys\n",
    "oth_df['outputs'] = oth_outputs\n",
    "\n",
    "oth_df.to_csv('oth_df_state.csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3e00b0d2",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c8d792b7",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "856641fc",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'gajah, horse kuda, mouse kucing'"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "prompt = \"dog anjing, cat kucing, lion singa, elephant\"\n",
    "\n",
    "inputs = tokenizer(prompt, return_tensors=\"pt\")\n",
    "length = len(inputs['input_ids'][0])\n",
    "generate_ids = model.generate(inputs.input_ids, max_length = length + 12)\n",
    "output = tokenizer.batch_decode(generate_ids[:,length:], skip_special_tokens=True, \n",
    "                                clean_up_tokenization_spaces=False)[0]\n",
    "\n",
    "output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "67298e82",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'harimau, bear beruang, rabbit b'"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "prompt = \"dog anjing, cat kucing, lion singa, tiger\"\n",
    "\n",
    "inputs = tokenizer(prompt, return_tensors=\"pt\")\n",
    "length = len(inputs['input_ids'][0])\n",
    "generate_ids = model.generate(inputs.input_ids, max_length = length + 12)\n",
    "output = tokenizer.batch_decode(generate_ids[:,length:], skip_special_tokens=True, \n",
    "                                clean_up_tokenization_spaces=False)[0]\n",
    "\n",
    "output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "2f1bc784",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'-to-be-published, traditional Chinese\\nWhen'"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "prompt = \"dog anjing, cat kucing, lion singa, soon\"\n",
    "\n",
    "inputs = tokenizer(prompt, return_tensors=\"pt\")\n",
    "length = len(inputs['input_ids'][0])\n",
    "generate_ids = model.generate(inputs.input_ids, max_length = length + 12)\n",
    "output = tokenizer.batch_decode(generate_ids[:,length:], skip_special_tokens=True, \n",
    "                                clean_up_tokenization_spaces=False)[0]\n",
    "\n",
    "output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "edae0ae4",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'an, and more\\nEpisode 1 - Bers'"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "prompt = \"dog anjing, cat kucing, lion singa, main\"\n",
    "\n",
    "inputs = tokenizer(prompt, return_tensors=\"pt\")\n",
    "length = len(inputs['input_ids'][0])\n",
    "generate_ids = model.generate(inputs.input_ids, max_length = length + 12)\n",
    "output = tokenizer.batch_decode(generate_ids[:,length:], skip_special_tokens=True, \n",
    "                                clean_up_tokenization_spaces=False)[0]\n",
    "\n",
    "output"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
