{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "f434e707-d5b1-4bea-8c58-9dada196c167",
   "metadata": {},
   "source": [
    "# Testing flan-t5 performance on multiarith "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "28efad8f-5ad1-42dc-99fc-dacd76fbc2b1",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/opt/conda/envs/llm/lib/python3.8/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "import datasets\n",
    "import torch\n",
    "import re\n",
    "import json\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "from tqdm import tqdm\n",
    "from datasets import load_dataset\n",
    "from transformers import T5Tokenizer, T5ForConditionalGeneration"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "a0a55ea2-0162-49b6-9630-054ee301daf1",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = json.load(open('../data/multiarith/MultiArith.json'))\n",
    "dev_ind = np.load('../data/multiarith/validation_index.npy')\n",
    "dev_data = [dataset[i] for i in dev_ind]\n",
    "test_data = [d for i, d in enumerate(dataset) if i not in dev_ind]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "3dd35ad2-d2b5-4541-b7ad-59b15e7e88f8",
   "metadata": {},
   "outputs": [],
   "source": [
    "prompt_complex = open('../data/multiarith/prompt_hardest_from_gsm8k.txt').read()\n",
    "prompt_original = open('../data/multiarith/prompt_original_from_gsm8k.txt').read()\n",
    "prompt_random = open('../data/multiarith/prompt_random_from_gsm8k.txt').read()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "7ff3e04b-0930-4d98-baee-4728781f1c4d",
   "metadata": {},
   "outputs": [],
   "source": [
    "prompt = open('../lib_prompt/prompt_simple_4_cases.txt').read()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "6cc79c88-af6e-4785-8dc2-87e2e87c4402",
   "metadata": {
    "collapsed": true,
    "jupyter": {
     "outputs_hidden": true
    },
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Fri Dec 23 19:06:32 2022       \n",
      "+-----------------------------------------------------------------------------+\n",
      "| NVIDIA-SMI 510.47.03    Driver Version: 510.47.03    CUDA Version: 11.6     |\n",
      "|-------------------------------+----------------------+----------------------+\n",
      "| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |\n",
      "| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |\n",
      "|                               |                      |               MIG M. |\n",
      "|===============================+======================+======================|\n",
      "|   0  NVIDIA A100-SXM...  On   | 00000000:00:05.0 Off |                    0 |\n",
      "| N/A   53C    P0   101W / 400W |  69705MiB / 81920MiB |      0%      Default |\n",
      "|                               |                      |             Disabled |\n",
      "+-------------------------------+----------------------+----------------------+\n",
      "|   1  NVIDIA A100-SXM...  On   | 00000000:00:06.0 Off |                    0 |\n",
      "| N/A   59C    P0   366W / 400W |  64613MiB / 81920MiB |    100%      Default |\n",
      "|                               |                      |             Disabled |\n",
      "+-------------------------------+----------------------+----------------------+\n",
      "|   2  NVIDIA A100-SXM...  On   | 00000000:00:07.0 Off |                    0 |\n",
      "| N/A   34C    P0    62W / 400W |      2MiB / 81920MiB |      0%      Default |\n",
      "|                               |                      |             Disabled |\n",
      "+-------------------------------+----------------------+----------------------+\n",
      "|   3  NVIDIA A100-SXM...  On   | 00000000:00:08.0 Off |                    0 |\n",
      "| N/A   36C    P0    72W / 400W |  21157MiB / 81920MiB |      0%      Default |\n",
      "|                               |                      |             Disabled |\n",
      "+-------------------------------+----------------------+----------------------+\n",
      "|   4  NVIDIA A100-SXM...  On   | 00000000:80:00.0 Off |                    0 |\n",
      "| N/A   36C    P0    78W / 400W |  67111MiB / 81920MiB |      0%      Default |\n",
      "|                               |                      |             Disabled |\n",
      "+-------------------------------+----------------------+----------------------+\n",
      "|   5  NVIDIA A100-SXM...  On   | 00000000:80:01.0 Off |                    0 |\n",
      "| N/A   39C    P0    90W / 400W |  67111MiB / 81920MiB |      0%      Default |\n",
      "|                               |                      |             Disabled |\n",
      "+-------------------------------+----------------------+----------------------+\n",
      "|   6  NVIDIA A100-SXM...  On   | 00000000:80:02.0 Off |                    0 |\n",
      "| N/A   40C    P0   159W / 400W |  77785MiB / 81920MiB |     38%      Default |\n",
      "|                               |                      |             Disabled |\n",
      "+-------------------------------+----------------------+----------------------+\n",
      "|   7  NVIDIA A100-SXM...  On   | 00000000:80:03.0 Off |                    0 |\n",
      "| N/A   37C    P0    69W / 400W |      0MiB / 81920MiB |      0%      Default |\n",
      "|                               |                      |             Disabled |\n",
      "+-------------------------------+----------------------+----------------------+\n",
      "                                                                               \n",
      "+-----------------------------------------------------------------------------+\n",
      "| Processes:                                                                  |\n",
      "|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |\n",
      "|        ID   ID                                                   Usage      |\n",
      "|=============================================================================|\n",
      "|    0   N/A  N/A     45415      C   python                          69703MiB |\n",
      "|    1   N/A  N/A     45415      C   python                          64611MiB |\n",
      "|    3   N/A  N/A     51171      C   ...da/envs/llm/bin/python3.8    21155MiB |\n",
      "|    4   N/A  N/A     86204      C   python                          67109MiB |\n",
      "|    5   N/A  N/A     86890      C   python                          67109MiB |\n",
      "|    6   N/A  N/A     71318      C   python                          77783MiB |\n",
      "+-----------------------------------------------------------------------------+\n"
     ]
    }
   ],
   "source": [
    "!nvidia-smi"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "68e9600f-72d9-44d7-9de3-47f296ebaf35",
   "metadata": {},
   "outputs": [],
   "source": [
    "tokenizer = T5Tokenizer.from_pretrained(\"google/flan-t5-xl\")\n",
    "model = T5ForConditionalGeneration.from_pretrained(\"/mnt/data_10t/flan_t5_distill/checkpoints/0.0.2.1_epoch_0_end\").to('cuda:2')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "d44cb016-b93e-4000-8657-9e6d95e2575d",
   "metadata": {},
   "outputs": [],
   "source": [
    "def test_answer(pred_str, ans_str):\n",
    "    \"\"\"Find the last number as the predicted answer\"\"\"\n",
    "    pattern = '\\d*\\.?\\d+'\n",
    "    pred = re.findall(pattern, pred_str)\n",
    "    if(len(pred) >= 1):\n",
    "        # print(pred_str)\n",
    "        pred = float(pred[-1])\n",
    "        gold = re.findall(pattern, ans_str)\n",
    "        # print(ans_str)\n",
    "        gold = float(gold[-1])\n",
    "        return pred == gold\n",
    "    else: return False\n",
    "\n",
    "\n",
    "def parse_pred_ans(filename, stop_at=-1):\n",
    "    with open(filename) as fd: lines = fd.readlines()\n",
    "    am, a = None, None\n",
    "    num_q, acc = 0, 0\n",
    "    current_mode = 'none'\n",
    "    questions = []\n",
    "    ans_pred = []\n",
    "    ans_gold = []\n",
    "    for l in lines:\n",
    "        if(l.startswith('Q: ')):\n",
    "            if(am is not None and a is not None):\n",
    "                questions.append(q)\n",
    "                ans_pred.append(am)\n",
    "                ans_gold.append(a)\n",
    "                if(test_answer(am, a)):\n",
    "                    acc += 1\n",
    "            current_mode = 'q'\n",
    "            q = l\n",
    "            num_q += 1\n",
    "            if(num_q == stop_at): break\n",
    "        elif(l.startswith('A_model:')):\n",
    "            current_mode = 'am'\n",
    "            am = l\n",
    "        elif(l.startswith('A:')):\n",
    "            current_mode = 'a'\n",
    "            a = l\n",
    "        else:\n",
    "            if(current_mode == 'q'): q += l\n",
    "            elif(current_mode == 'am'): am += l\n",
    "            elif(current_mode == 'a'): a += l\n",
    "            else:\n",
    "                raise ValueError(current_mode)\n",
    "                \n",
    "    questions.append(q)\n",
    "    ans_pred.append(am)\n",
    "    ans_gold.append(a)\n",
    "    if(test_answer(am, a)):\n",
    "        acc += 1\n",
    "    print('num_q %d correct %d ratio %.4f' % (num_q, acc, float(acc / num_q)))\n",
    "    return questions, ans_pred, ans_gold"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5e8ff96d-7699-4df3-a99d-a2c114ba81e0",
   "metadata": {},
   "source": [
    "# Test Data with Original Prompt, Acc 8"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "9a71e40b-56ff-4e1d-b715-743f4abeed90",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [02:44<00:00, 16.40s/it]\n"
     ]
    }
   ],
   "source": [
    "i = 0\n",
    "batch_size = 40\n",
    "with open('../outputs/test_distilled_flan_t5_3b_multiarith.txt', 'w') as fd:\n",
    "    for i in tqdm(range(0, len(test_data), batch_size), total=len(test_data) // batch_size):\n",
    "        questions = []\n",
    "        q_batch = []\n",
    "        a_batch = []\n",
    "        for k in range(batch_size):\n",
    "            q = test_data[i + k]['sQuestion'][1:-1]\n",
    "            q_batch.append(q)\n",
    "            a = test_data[i + k]['lSolutions']\n",
    "            a_batch.append(a)\n",
    "            \n",
    "            prompt_q = prompt + '\\nQ: ' + q + '\\n'\n",
    "            prompt_q += \"Let's think step by step\\n\"\n",
    "            questions.append(prompt_q)\n",
    "            \n",
    "        inputs = tokenizer(questions, padding=True, return_tensors=\"pt\")\n",
    "        with torch.no_grad():\n",
    "            outputs = model.generate(inputs['input_ids'].to('cuda:2'), attention_mask=inputs['attention_mask'].to('cuda:2'), max_length=256)\n",
    "        \n",
    "        for q, a, ans_ in zip(q_batch, a_batch, outputs):\n",
    "            ans_ = tokenizer.decode(ans_)\n",
    "            fd.write('Q: %s\\nA_model:\\n%s\\nA:\\n%s\\n\\n' % (q, ans_, a))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "231c4eea-e036-49f4-9774-141f9ebb9278",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "num_q 400 correct 107 ratio 0.2675\n"
     ]
    }
   ],
   "source": [
    "_, _, _ = parse_pred_ans('../outputs/test_distilled_flan_t5_3b_multiarith.txt')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "2b95c720-4f9e-4952-87a4-589574886d00",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 400/400 [29:05<00:00,  4.36s/it]\n"
     ]
    }
   ],
   "source": [
    "with open('../outputs/test_distilled_flan_t5_3b_multiarith.txt', 'w') as fd:\n",
    "    for case in tqdm(test_data):\n",
    "        q = case['sQuestion'][1:-1]\n",
    "        a = case['lSolutions']\n",
    "        \n",
    "        prompt_q = prompt_original + 'Q: ' + q + '\\n' # TODO: change this to be 4 cases\n",
    "        prompt_q += \"Let's think step by step\\n\"\n",
    "        \n",
    "        input_ids = tokenizer(prompt_q, return_tensors=\"pt\").input_ids.to(\"cuda:2\")\n",
    "        outputs = model.generate(input_ids, max_length=256)\n",
    "        ans_ = tokenizer.decode(outputs[0])\n",
    "        \n",
    "        fd.write('Q: %s\\nA_model:\\n%s\\nA:\\n%s\\n\\n' % (q, ans_, a))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "53939732-117e-4490-8844-caaeea4adbc6",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "num_q 400 correct 32 ratio 0.0800\n"
     ]
    }
   ],
   "source": [
    "_, _, _ = parse_pred_ans('../outputs/test_distilled_flan_t5_3b_multiarith.txt')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5fb17e48-33ea-4c22-b245-286f408a1aee",
   "metadata": {},
   "source": [
    "# Test Data with 4 Cases Prompt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "63565f5f-2596-447e-86d8-5f59a0dbb113",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 400/400 [26:14<00:00,  3.94s/it]\n"
     ]
    }
   ],
   "source": [
    "with open('../outputs/test_distilled_flan_t5_3b_multiarith_4_cases_prompt.txt', 'w') as fd:\n",
    "    for case in tqdm(test_data):\n",
    "        q = case['sQuestion'][1:-1]\n",
    "        a = case['lSolutions']\n",
    "        \n",
    "        prompt_q = prompt + '\\nQ: ' + q + '\\n' # TODO: change this to be 4 cases\n",
    "        prompt_q += \"Let's think step by step\\n\"\n",
    "        \n",
    "        input_ids = tokenizer(prompt_q, return_tensors=\"pt\").input_ids.to(\"cuda:2\")\n",
    "        outputs = model.generate(input_ids, max_length=256)\n",
    "        ans_ = tokenizer.decode(outputs[0])\n",
    "        \n",
    "        fd.write('Q: %s\\nA_model:\\n%s\\nA:\\n%s\\n\\n' % (q, ans_, a))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "7a64222d-9bcc-4144-ad2f-25487a228da8",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "num_q 400 correct 37 ratio 0.0925\n"
     ]
    }
   ],
   "source": [
    "_, _, _ = parse_pred_ans('../outputs/test_distilled_flan_t5_3b_multiarith_4_cases_prompt.txt')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "95ec9e70-f71a-4963-8c33-caab6e305130",
   "metadata": {},
   "source": [
    "# Test Data with Complex Prompt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "bd78b1bd-63ae-42da-a089-db7b30d7f48f",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 400/400 [15:30<00:00,  2.33s/it]\n"
     ]
    }
   ],
   "source": [
    "with open('../outputs/test_flan_t5_3b_multiarith_complex.txt', 'w') as fd:\n",
    "    for case in tqdm(test_data):\n",
    "        q = case['sQuestion'][1:-1]\n",
    "        a = case['lSolutions']\n",
    "        \n",
    "        prompt_q = prompt_complex + 'Question: ' + q + '\\n'\n",
    "        prompt_q += \"Let's think step by step\\n\"\n",
    "        \n",
    "        input_ids = tokenizer(prompt_q, return_tensors=\"pt\").input_ids.to(\"cuda:2\")\n",
    "        outputs = model.generate(input_ids, max_length=256)\n",
    "        ans_ = tokenizer.decode(outputs[0])\n",
    "        \n",
    "        fd.write('Q: %s\\nA_model:\\n%s\\nA:\\n%s\\n\\n' % (q, ans_, a))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "83bfc615-eda2-48e2-96f4-0c64b006b3f2",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "num_q 400 correct 95 ratio 0.2375\n"
     ]
    }
   ],
   "source": [
    "_, _, _ = parse_pred_ans('../outputs/test_flan_t5_3b_multiarith_complex.txt')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f80612d1-6406-42fa-b3d5-c136e0ed6acb",
   "metadata": {},
   "source": [
    "# Test Data with Random Prompt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "186b87dc-d9b8-437e-b1e7-8cb07b997c22",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 400/400 [12:25<00:00,  1.86s/it]\n"
     ]
    }
   ],
   "source": [
    "with open('../outputs/test_flan_t5_3b_multiarith_random.txt', 'w') as fd:\n",
    "    for case in tqdm(test_data):\n",
    "        q = case['sQuestion'][1:-1]\n",
    "        a = case['lSolutions']\n",
    "        \n",
    "        prompt_q = prompt_random + 'Question: ' + q + '\\n'\n",
    "        prompt_q += \"Let's think step by step\\n\"\n",
    "        \n",
    "        input_ids = tokenizer(prompt_q, return_tensors=\"pt\").input_ids.to(\"cuda:2\")\n",
    "        outputs = model.generate(input_ids, max_length=256)\n",
    "        ans_ = tokenizer.decode(outputs[0])\n",
    "        \n",
    "        fd.write('Q: %s\\nA_model:\\n%s\\nA:\\n%s\\n\\n' % (q, ans_, a))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "aa6ab56c-2758-4117-9d77-63b06e9a12a9",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "num_q 400 correct 93 ratio 0.2325\n"
     ]
    }
   ],
   "source": [
    "_, _, _ = parse_pred_ans('../outputs/test_flan_t5_3b_multiarith_random.txt')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "95d17373-7751-4d21-8eaf-56845eb36a51",
   "metadata": {},
   "source": [
    "# Test Data with Direct Answering"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "d327dce9-c034-4dcd-bc53-0a1750ccc6d0",
   "metadata": {},
   "outputs": [],
   "source": [
    "prompt_direct = open('../lib_prompt/prompt_direct.txt').read()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "7be4b343-7f13-4cb5-a44c-992f4ec44711",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 400/400 [07:27<00:00,  1.12s/it]\n"
     ]
    }
   ],
   "source": [
    "with open('../outputs/test_flan_t5_3b_multiarith_direct.txt', 'w') as fd:\n",
    "    for case in tqdm(test_data):\n",
    "        q = case['sQuestion'][1:-1]\n",
    "        a = case['lSolutions']\n",
    "        \n",
    "        prompt_q = prompt_direct + 'Question: ' + q + '\\n'\n",
    "        \n",
    "        input_ids = tokenizer(prompt_q, return_tensors=\"pt\").input_ids.to(\"cuda:2\")\n",
    "        outputs = model.generate(input_ids, max_length=256)\n",
    "        ans_ = tokenizer.decode(outputs[0])\n",
    "        \n",
    "        fd.write('Q: %s\\nA_model:\\n%s\\nA:\\n%s\\n\\n' % (q, ans_, a))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "8c2afb1e-35ce-4d23-a2ae-ef0858dd5881",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "num_q 400 correct 51 ratio 0.1275\n"
     ]
    }
   ],
   "source": [
    "_, _, _ = parse_pred_ans('../outputs/test_flan_t5_3b_multiarith_direct.txt')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.15"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
