{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "f434e707-d5b1-4bea-8c58-9dada196c167",
   "metadata": {},
   "source": [
    "# Testing flan-t5 performance on multiarith "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "28efad8f-5ad1-42dc-99fc-dacd76fbc2b1",
   "metadata": {},
   "outputs": [],
   "source": [
    "import datasets\n",
    "import torch\n",
    "import re\n",
    "import json\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "from tqdm import tqdm\n",
    "from datasets import load_dataset\n",
    "from transformers import T5Tokenizer, T5ForConditionalGeneration"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "a0a55ea2-0162-49b6-9630-054ee301daf1",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = json.load(open('data/multiarith/MultiArith.json'))\n",
    "dev_ind = np.load('data/multiarith/validation_index.npy')\n",
    "dev_data = [dataset[i] for i in dev_ind]\n",
    "test_data = [d for i, d in enumerate(dataset) if i not in dev_ind]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "3dd35ad2-d2b5-4541-b7ad-59b15e7e88f8",
   "metadata": {},
   "outputs": [],
   "source": [
    "prompt_complex = open('data/multiarith/prompt_hardest_from_gsm8k.txt').read()\n",
    "prompt_original = open('data/multiarith/prompt_original_from_gsm8k.txt').read()\n",
    "prompt_random = open('data/multiarith/prompt_random_from_gsm8k.txt').read()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "68e9600f-72d9-44d7-9de3-47f296ebaf35",
   "metadata": {},
   "outputs": [],
   "source": [
    "tokenizer = T5Tokenizer.from_pretrained(\"google/flan-t5-xxl\")\n",
    "model = T5ForConditionalGeneration.from_pretrained(\"google/flan-t5-xxl\").to('cuda:0')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "d44cb016-b93e-4000-8657-9e6d95e2575d",
   "metadata": {},
   "outputs": [],
   "source": [
    "def test_answer(pred_str, ans_str):\n",
    "    \"\"\"Find the last number as the predicted answer\"\"\"\n",
    "    pattern = '\\d*\\.?\\d+'\n",
    "    pred = re.findall(pattern, pred_str)\n",
    "    if(len(pred) >= 1):\n",
    "        # print(pred_str)\n",
    "        pred = float(pred[-1])\n",
    "        gold = re.findall(pattern, ans_str)\n",
    "        # print(ans_str)\n",
    "        gold = float(gold[-1])\n",
    "        return pred == gold\n",
    "    else: return False\n",
    "\n",
    "\n",
    "def parse_pred_ans(filename, stop_at=-1):\n",
    "    with open(filename) as fd: lines = fd.readlines()\n",
    "    am, a = None, None\n",
    "    num_q, acc = 0, 0\n",
    "    current_mode = 'none'\n",
    "    questions = []\n",
    "    ans_pred = []\n",
    "    ans_gold = []\n",
    "    for l in lines:\n",
    "        if(l.startswith('Q: ')):\n",
    "            if(am is not None and a is not None):\n",
    "                questions.append(q)\n",
    "                ans_pred.append(am)\n",
    "                ans_gold.append(a)\n",
    "                if(test_answer(am, a)):\n",
    "                    acc += 1\n",
    "            current_mode = 'q'\n",
    "            q = l\n",
    "            num_q += 1\n",
    "            if(num_q == stop_at): break\n",
    "        elif(l.startswith('A_model:')):\n",
    "            current_mode = 'am'\n",
    "            am = l\n",
    "        elif(l.startswith('A:')):\n",
    "            current_mode = 'a'\n",
    "            a = l\n",
    "        else:\n",
    "            if(current_mode == 'q'): q += l\n",
    "            elif(current_mode == 'am'): am += l\n",
    "            elif(current_mode == 'a'): a += l\n",
    "            else:\n",
    "                raise ValueError(current_mode)\n",
    "                \n",
    "    questions.append(q)\n",
    "    ans_pred.append(am)\n",
    "    ans_gold.append(a)\n",
    "    if(test_answer(am, a)):\n",
    "        acc += 1\n",
    "    print('num_q %d correct %d ratio %.4f' % (num_q, acc, float(acc / num_q)))\n",
    "    return questions, ans_pred, ans_gold"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5e8ff96d-7699-4df3-a99d-a2c114ba81e0",
   "metadata": {},
   "source": [
    "# Test Data with Original Prompt, Acc 51.5"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "2b95c720-4f9e-4952-87a4-589574886d00",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 400/400 [12:13<00:00,  1.83s/it]\n"
     ]
    }
   ],
   "source": [
    "with open('outputs/test_flan_t5_multiarith_original.txt', 'w') as fd:\n",
    "    for case in tqdm(test_data):\n",
    "        q = case['sQuestion'][1:-1]\n",
    "        a = case['lSolutions']\n",
    "        \n",
    "        prompt_q = prompt_original + 'Question: ' + q + '\\n'\n",
    "        prompt_q += \"Let's think step by step\\n\"\n",
    "        \n",
    "        input_ids = tokenizer(prompt_q, return_tensors=\"pt\").input_ids.to(\"cuda:0\")\n",
    "        outputs = model.generate(input_ids, max_length=256)\n",
    "        ans_ = tokenizer.decode(outputs[0])\n",
    "        \n",
    "        fd.write('Q: %s\\nA_model:\\n%s\\nA:\\n%s\\n\\n' % (q, ans_, a))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "53939732-117e-4490-8844-caaeea4adbc6",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "num_q 400 correct 206 ratio 0.5150\n"
     ]
    }
   ],
   "source": [
    "_, _, _ = parse_pred_ans('outputs/test_flan_t5_multiarith_original.txt')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "95ec9e70-f71a-4963-8c33-caab6e305130",
   "metadata": {},
   "source": [
    "# Test Data with Complex Prompt, Acc 51.25"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "bd78b1bd-63ae-42da-a089-db7b30d7f48f",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 400/400 [21:54<00:00,  3.29s/it]\n"
     ]
    }
   ],
   "source": [
    "with open('outputs/test_flan_t5_multiarith_complex.txt', 'w') as fd:\n",
    "    for case in tqdm(test_data):\n",
    "        q = case['sQuestion'][1:-1]\n",
    "        a = case['lSolutions']\n",
    "        \n",
    "        prompt_q = prompt_complex + 'Question: ' + q + '\\n'\n",
    "        prompt_q += \"Let's think step by step\\n\"\n",
    "        \n",
    "        input_ids = tokenizer(prompt_q, return_tensors=\"pt\").input_ids.to(\"cuda:0\")\n",
    "        outputs = model.generate(input_ids, max_length=256)\n",
    "        ans_ = tokenizer.decode(outputs[0])\n",
    "        \n",
    "        fd.write('Q: %s\\nA_model:\\n%s\\nA:\\n%s\\n\\n' % (q, ans_, a))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "83bfc615-eda2-48e2-96f4-0c64b006b3f2",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "num_q 400 correct 205 ratio 0.5125\n"
     ]
    }
   ],
   "source": [
    "_, _, _ = parse_pred_ans('outputs/test_flan_t5_multiarith_complex.txt')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f80612d1-6406-42fa-b3d5-c136e0ed6acb",
   "metadata": {},
   "source": [
    "# Test Data with Random Prompt, Acc 53.25"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "186b87dc-d9b8-437e-b1e7-8cb07b997c22",
   "metadata": {},
   "outputs": [],
   "source": [
    "with open('outputs/test_flan_t5_multiarith_random.txt', 'w') as fd:\n",
    "    for case in tqdm(test_data):\n",
    "        q = case['sQuestion'][1:-1]\n",
    "        a = case['lSolutions']\n",
    "        \n",
    "        prompt_q = prompt_random + 'Question: ' + q + '\\n'\n",
    "        prompt_q += \"Let's think step by step\\n\"\n",
    "        \n",
    "        input_ids = tokenizer(prompt_q, return_tensors=\"pt\").input_ids.to(\"cuda:0\")\n",
    "        outputs = model.generate(input_ids, max_length=256)\n",
    "        ans_ = tokenizer.decode(outputs[0])\n",
    "        \n",
    "        fd.write('Q: %s\\nA_model:\\n%s\\nA:\\n%s\\n\\n' % (q, ans_, a))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "aa6ab56c-2758-4117-9d77-63b06e9a12a9",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "num_q 400 correct 213 ratio 0.5325\n"
     ]
    }
   ],
   "source": [
    "_, _, _ = parse_pred_ans('outputs/test_flan_t5_multiarith_random.txt')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "95d17373-7751-4d21-8eaf-56845eb36a51",
   "metadata": {},
   "source": [
    "# Test Data with Direct Answering, Acc 13.0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "d327dce9-c034-4dcd-bc53-0a1750ccc6d0",
   "metadata": {},
   "outputs": [],
   "source": [
    "prompt_direct = open('lib_prompt/prompt_direct.txt').read()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "7be4b343-7f13-4cb5-a44c-992f4ec44711",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 400/400 [03:42<00:00,  1.80it/s]\n"
     ]
    }
   ],
   "source": [
    "with open('outputs/test_flan_t5_multiarith_direct.txt', 'w') as fd:\n",
    "    for case in tqdm(test_data):\n",
    "        q = case['sQuestion'][1:-1]\n",
    "        a = case['lSolutions']\n",
    "        \n",
    "        prompt_q = prompt_direct + 'Question: ' + q + '\\n'\n",
    "        \n",
    "        input_ids = tokenizer(prompt_q, return_tensors=\"pt\").input_ids.to(\"cuda:0\")\n",
    "        outputs = model.generate(input_ids, max_length=256)\n",
    "        ans_ = tokenizer.decode(outputs[0])\n",
    "        \n",
    "        fd.write('Q: %s\\nA_model:\\n%s\\nA:\\n%s\\n\\n' % (q, ans_, a))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "8c2afb1e-35ce-4d23-a2ae-ef0858dd5881",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "num_q 400 correct 52 ratio 0.1300\n"
     ]
    }
   ],
   "source": [
    "_, _, _ = parse_pred_ans('outputs/test_flan_t5_multiarith_direct.txt')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.15"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
