{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": [],
      "gpuType": "A100",
      "machine_shape": "hm"
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "code",
      "source": [
        "# Use %pip (not !pip) so the package is installed into the kernel's own environment.\n",
        "# NOTE(review): later cells call `model.generate(..., sample_tokens=...)` and unpack two\n",
        "# return values; stock transformers' `generate` does not document either, so presumably a\n",
        "# patched build is required here - TODO pin/point to it.\n",
        "%pip install transformers"
      ],
      "metadata": {
        "id": "TT9XAC1llua4"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Use %pip (not !pip) so the package is installed into the kernel's own environment.\n",
        "%pip install accelerate"
      ],
      "metadata": {
        "id": "gbdphk92yYts"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "yUcTtJK4jxVb"
      },
      "outputs": [],
      "source": [
        "from transformers import AutoTokenizer, AutoModelForCausalLM\n",
        "\n",
        "# Checkpoint shared by every cell below. The `[INST] ... [/INST]` cells format prompts for an\n",
        "# instruction-tuned model; presumably swap in the matching instruct checkpoint\n",
        "# (e.g. \"mistralai/Mistral-7B-Instruct-v0.1\") before running them.\n",
        "MODEL_NAME = \"mistralai/Mistral-7B-v0.1\"\n",
        "device = 'cuda'\n",
        "\n",
        "tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)\n",
        "model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)\n",
        "# Move the weights to the GPU once here rather than on every evaluation step.\n",
        "model = model.to(device)"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "#@title util functions\n",
        "import numpy as np\n",
        "\n",
        "# Token ids treated as digits / in-number separators by find_number_blocks.\n",
        "# NOTE(review): presumably Mistral's ids for the digits 0-9 and for '.' / ',' - confirm\n",
        "# with tokenizer.convert_ids_to_tokens.\n",
        "NUMBER_IDS = [28734, 28740, 28750, 28770, 28781, 28782, 28784, 28787, 28783, 28774]\n",
        "SEP_IDS = [28723, 28725]\n",
        "\n",
        "# Single-token strings accepted as a parity answer by find_even_odd_blocks.\n",
        "PARITY_WORDS = ('even', 'odd', 'Even', 'Odd')\n",
        "\n",
        "\n",
        "def find_index(ids):\n",
        "  \"\"\"Return the index just past the first 4-token window decoding to '[/INST]', else -1000.\n",
        "\n",
        "  Only windows starting at token id 733 are decoded (presumably '[' - TODO confirm).\n",
        "  \"\"\"\n",
        "  possible_start = [i for i, x in enumerate(ids) if x == 733]\n",
        "  for ind in possible_start:\n",
        "    if tokenizer.decode(ids[ind:ind+4]) == '[/INST]':\n",
        "      return ind + 4\n",
        "  return -1000\n",
        "\n",
        "\n",
        "def find_number_blocks(ids):\n",
        "  \"\"\"Return [start, end) index pairs of maximal numeric runs in `ids`, left to right.\n",
        "\n",
        "  A run is made of NUMBER_IDS tokens; a SEP_IDS token stays inside a run only when it\n",
        "  sits directly between two NUMBER_IDS tokens.\n",
        "  \"\"\"\n",
        "  n = len(ids)\n",
        "\n",
        "  def _in_number(pos):\n",
        "    if ids[pos] in NUMBER_IDS:\n",
        "      return True\n",
        "    # `pos - 1 >= 0` (not `> 0`) so a separator at index 1 can follow a digit at index 0.\n",
        "    return (ids[pos] in SEP_IDS\n",
        "            and pos - 1 >= 0 and ids[pos - 1] in NUMBER_IDS\n",
        "            and pos + 1 < n and ids[pos + 1] in NUMBER_IDS)\n",
        "\n",
        "  interval_pairs = []\n",
        "  start = 0\n",
        "  while start < n:\n",
        "    end = start\n",
        "    while end < n and _in_number(end):\n",
        "      end += 1\n",
        "    if end > start:\n",
        "      interval_pairs.append((start, end))\n",
        "      start = end\n",
        "    else:\n",
        "      start += 1\n",
        "  return interval_pairs\n",
        "\n",
        "\n",
        "def _is_float(s):\n",
        "  \"\"\"Return True if `s` can be converted with float().\"\"\"\n",
        "  try:\n",
        "    float(s)\n",
        "    return True\n",
        "  except (TypeError, ValueError, OverflowError):\n",
        "    return False\n",
        "\n",
        "\n",
        "def is_correct(target, ans):\n",
        "  \"\"\"Return True if `ans` matches `target`.\n",
        "\n",
        "  When both parse as floats they are compared numerically (within 1e-5); otherwise as exact\n",
        "  strings. Commas are removed from both sides first, so an extracted '1,000' matches '1000'.\n",
        "  \"\"\"\n",
        "  target, ans = str(target).replace(',', ''), str(ans).replace(',', '')\n",
        "  if _is_float(target) and _is_float(ans):\n",
        "    return abs(float(target) - float(ans)) <= 1e-5\n",
        "  return target == ans\n",
        "\n",
        "\n",
        "def find_even_odd_blocks(ids):\n",
        "  \"\"\"Return [start, end) index pairs of parity words ('even'/'odd'/'Even'/'Odd') in `ids`.\n",
        "\n",
        "  A word is either a single token, or 'Odd' split across two tokens; in the two-token case\n",
        "  the interval covers both so decoding it yields 'Odd' rather than only its first piece.\n",
        "  \"\"\"\n",
        "  interval_pairs = []\n",
        "  for i, token_id in enumerate(ids):\n",
        "    if tokenizer.decode([token_id]) in PARITY_WORDS:\n",
        "      interval_pairs.append((i, i + 1))\n",
        "    elif i + 1 < len(ids) and tokenizer.decode([token_id, ids[i + 1]]) == 'Odd':\n",
        "      interval_pairs.append((i, i + 2))\n",
        "  return interval_pairs"
      ],
      "metadata": {
        "cellView": "form",
        "id": "7SRUPdaourai"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "#@title gsm data\n",
        "\n",
        "gsm_data = \"\"\"{\"question\": \"Janet\\u2019s ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market?\", \"answer\": \"Janet sells 16 - 3 - 4 = <<16-3-4=9>>9 duck eggs a day.\\nShe makes 9 * 2 = $<<9*2=18>>18 every day at the farmer\\u2019s market.\\n#### 18\"}\n",
        "\"\"\"\n",
        "# Only one GSM8K record is inlined above; paste more JSONL records between the quotes.\n",
        "# Records are split on their literal field markers (no JSON parsing): the question is the\n",
        "# text before the answer field, the target is the text after '#### '.\n",
        "question_list = []\n",
        "answer_list = []\n",
        "for record in gsm_data.split('{\"question\": \"'):\n",
        "  if not record:\n",
        "    continue\n",
        "  question_list.append(record.partition(\"\\\", \\\"answer\\\": \\\"\")[0])\n",
        "  final_answer = record.split(\"#### \")[1]\n",
        "  answer_list.append(final_answer.replace('\"}', '').strip())\n",
        "\n",
        "print(len(question_list), len(answer_list))"
      ],
      "metadata": {
        "cellView": "form",
        "id": "DD6ODfpB38UP"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "#@title gsm eval on pre-trained model\n",
        "BATCH_SIZE = 10\n",
        "\n",
        "correct_vec = []\n",
        "for i in range(BATCH_SIZE):\n",
        "  correct_vec.append([])\n",
        "\n",
        "for question, target in zip(question_list, answer_list):\n",
        "  text = \"Q: \" + question + \"\\nA:\"\n",
        "  model_inputs = tokenizer([text]*BATCH_SIZE, return_tensors=\"pt\").to(device)\n",
        "  model.to(device)\n",
        "  generated_ids, score_diff = model.generate(\n",
        "      model_inputs.input_ids, max_new_tokens=200, sample_tokens=True,\n",
        "      do_sample=False, pad_token_id=tokenizer.eos_token_id)\n",
        "  decoded = tokenizer.batch_decode(generated_ids)\n",
        "  score_diff = [diff.cpu() for diff in score_diff]\n",
        "\n",
        "  for i in range(BATCH_SIZE):\n",
        "    cleaned_dec = decoded[i].split('\\nQ:')[0].strip().split('\\nA:')[1].split('</s>')[0].strip()\n",
        "    if 'Quant Questions' in cleaned_dec or 'GMAT' in cleaned_dec or cleaned_dec.startswith('\\(\\frac') or (\n",
        "        'B.' in cleaned_dec and 'C.' in cleaned_dec) or ('B:' in cleaned_dec and 'C:' in cleaned_dec):\n",
        "      continue\n",
        "    print('===', cleaned_dec)\n",
        "\n",
        "  for top_k in range(BATCH_SIZE):\n",
        "    max_score_diff = -1000\n",
        "    max_dec = ''\n",
        "    max_ans = ''\n",
        "    for i in range(top_k+1):\n",
        "      cleaned_dec = decoded[i].split('\\nQ:')[0].strip().split('\\nA:')[1].split('</s>')[0].strip()\n",
        "      if 'Quant Questions' in cleaned_dec or 'GMAT' in cleaned_dec or cleaned_dec.startswith('\\(\\frac') or (\n",
        "          'B.' in cleaned_dec and 'C.' in cleaned_dec) or ('B:' in cleaned_dec and 'C:' in cleaned_dec):\n",
        "        continue\n",
        "      scores = score_diff[i]\n",
        "\n",
        "      ids = generated_ids[i, :]\n",
        "      start_index = len(model_inputs.input_ids[i, :])\n",
        "      end_index = -1000\n",
        "      for ind in range(len(ids)):\n",
        "        if tokenizer.decode(ids[ind:ind+3]) == '\\nQ:' or ids[ind] == 2:\n",
        "          end_index = ind\n",
        "          break\n",
        "\n",
        "      output_ids = ids[start_index:end_index]\n",
        "      assert len(ids[start_index:]) == len(scores)\n",
        "\n",
        "      # find last number interval and extract logit diff\n",
        "      final_score_diff = max_score_diff\n",
        "      blocks = find_number_blocks(output_ids)\n",
        "      if len(blocks) > 0:\n",
        "        interval = blocks[-1]\n",
        "        # print(interval, tokenizer.decode(output_ids[interval[0]:interval[1]]),\n",
        "        #       scores[interval[0]:interval[1]])\n",
        "        final_score_diff = np.average(scores[interval[0]:interval[1]])\n",
        "\n",
        "      if final_score_diff > max_score_diff:\n",
        "        max_score_diff = final_score_diff\n",
        "        max_dec = cleaned_dec\n",
        "        max_ans = tokenizer.decode(output_ids[interval[0]:interval[1]])\n",
        "\n",
        "    print('max', max_score_diff, max_ans, '|', target)\n",
        "    if is_correct(target, max_ans):\n",
        "      correct_vec[top_k].append(1)\n",
        "    else:\n",
        "      correct_vec[top_k].append(0)\n",
        "  print('sum:', len(correct_vec[0]), [np.sum(correct_vec[j]) for j in range(BATCH_SIZE)])\n",
        "  print('acc:', [\"{:0.4f}\".format(np.average(correct_vec[j])) for j in range(BATCH_SIZE)])"
      ],
      "metadata": {
        "id": "bbS6G0qoA29m",
        "cellView": "form"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "#@title gsm eval on instruction-tuned model\n",
        "BATCH_SIZE = 10\n",
        "\n",
        "correct_vec = []\n",
        "for i in range(BATCH_SIZE):\n",
        "  correct_vec.append([])\n",
        "\n",
        "for question, target in zip(question_list, answer_list):\n",
        "  text = \"[INST] \" + question + \" [/INST]\"\n",
        "  model_inputs = tokenizer([text]*BATCH_SIZE, return_tensors=\"pt\").to(device)\n",
        "  model.to(device)\n",
        "  generated_ids, score_diff = model.generate(\n",
        "      model_inputs.input_ids, max_new_tokens=400, sample_tokens=True,\n",
        "      do_sample=False, pad_token_id=tokenizer.eos_token_id)\n",
        "  decoded = tokenizer.batch_decode(generated_ids)\n",
        "  score_diff = [diff.cpu() for diff in score_diff]\n",
        "\n",
        "  # for i in range(BATCH_SIZE):\n",
        "  #   cleaned_dec = decoded[i].split('[/INST]')[1].strip().split('</s>')[0].strip()\n",
        "  #   print('===', cleaned_dec)\n",
        "\n",
        "  for top_k in range(BATCH_SIZE):\n",
        "    max_score_diff = -1000\n",
        "    max_dec = ''\n",
        "    max_ans = ''\n",
        "    for i in range(top_k+1):\n",
        "      cleaned_dec = decoded[i]\n",
        "      scores = score_diff[i]\n",
        "\n",
        "      ids = generated_ids[i, :]\n",
        "      start_index = len(model_inputs.input_ids[i, :])\n",
        "      end_index = -1000\n",
        "      for ind in range(len(ids)):\n",
        "        if ids[ind] == 2:\n",
        "          end_index = ind\n",
        "          break\n",
        "\n",
        "      output_ids = ids[start_index:end_index]\n",
        "      assert len(ids[start_index:]) == len(scores)\n",
        "\n",
        "      # find last number interval and extract logit diff\n",
        "      final_score_diff = max_score_diff\n",
        "      blocks = find_number_blocks(output_ids)\n",
        "      if len(blocks) > 0:\n",
        "        interval = blocks[-1]\n",
        "        # print(interval, tokenizer.decode(output_ids[interval[0]:interval[1]]),\n",
        "        #       scores[interval[0]:interval[1]])\n",
        "        final_score_diff = np.average(scores[interval[0]:interval[1]])\n",
        "\n",
        "      if final_score_diff > max_score_diff:\n",
        "        max_score_diff = final_score_diff\n",
        "        max_dec = cleaned_dec\n",
        "        max_ans = tokenizer.decode(output_ids[interval[0]:interval[1]])\n",
        "\n",
        "    print('max', max_score_diff, max_ans, '|', target)\n",
        "    if is_correct(target, max_ans):\n",
        "      correct_vec[top_k].append(1)\n",
        "    else:\n",
        "      correct_vec[top_k].append(0)\n",
        "  print('sum:', len(correct_vec[0]), [np.sum(correct_vec[j]) for j in range(BATCH_SIZE)])\n",
        "  print('acc:', [\"{:0.4f}\".format(np.average(correct_vec[j])) for j in range(BATCH_SIZE)])"
      ],
      "metadata": {
        "id": "FixjgaiQLlAT",
        "cellView": "form"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "#@title name list\n",
        "names_string = \"\"\"Sasha Calle\n",
        "Annie Murphy\n",
        "Golshifteh Farahani\n",
        "Kate Mara\n",
        "Josh Hartnett\n",
        "Jennifer Lawrence\n",
        "Aaron Taylor-Johnson\n",
        "Rebecca Ferguson\n",
        "Monica Barbaro\n",
        "Chris Hemsworth\n",
        "Wes Anderson\n",
        "Daniel Portman\n",
        "Lily-Rose Depp\n",
        "Myha'la Herrold\n",
        "Zendaya\n",
        "Ezra Miller\n",
        "Olga Kurylenko\n",
        "Zazie Beetz\n",
        "Arnold Schwarzenegger\n",
        "Emilia Clarke\n",
        "Jess Bush\n",
        "Clara Rugaard\n",
        "Molly Gordon\n",
        "Isabel May\n",
        "Hailee Steinfeld\n",
        "Hannah Waddingham\n",
        "Christina Chong\n",
        "Rory Culkin\n",
        "Cobie Smulders\n",
        "Harrison Ford\n",
        "Tom Cruise\n",
        "Carol Kane\n",
        "Alexandra Daddario\n",
        "Gal Gadot\n",
        "Tom Holland\n",
        "Hayley Atwell\n",
        "Salma Hayek\n",
        "Ana de Armas\n",
        "Will Poulter\n",
        "Anson Mount\n",
        "Paapa Essiedu\n",
        "Sam Hargrave\n",
        "Margot Robbie\n",
        "Nicolas Cage\n",
        "Henry Cavill\n",
        "Juno Temple\n",
        "Cailee Spaeny\n",
        "Treat Williams\n",
        "Alexander Skarsgård\n",
        "Rebecca Romijn\n",
        "Rebecca Romijn\n",
        "Monica Dolan\n",
        "Anya Taylor-Joy\n",
        "Sophia Lillis\n",
        "Emmanuelle Vaugier\n",
        "Aaron Paul\n",
        "Elliot Page\n",
        "Robin Tunney\n",
        "Mike Faist\n",
        "Tinatin Dalakishvili\n",
        "Sarah Snook\n",
        "Jenna Ortega\n",
        "Zoe Saldana\n",
        "Anjana Vasan\n",
        "Ben Mendelsohn\n",
        "Jeremy Allen White\n",
        "Ayo Edebiri\n",
        "Keanu Reeves\n",
        "Pom Klementieff\n",
        "Scarlett Johansson\n",
        "Tornike Gogrichiani\n",
        "James Cameron\n",
        "Pedro Pascal\n",
        "Kaley Cuoco\n",
        "Samuel L. Jackson\n",
        "Terri Ivens\n",
        "Florence Pugh\n",
        "Shea Whigham\n",
        "Kingsley Ben-Adir\n",
        "Michael Keaton\n",
        "Julian Sands\n",
        "Christopher Nolan\n",
        "Tom Hanks\n",
        "Clint Eastwood\n",
        "Gabriel Macht\n",
        "Fabiana Udenio\n",
        "Tom Bateman\n",
        "Jack Champion\n",
        "Jake Gyllenhaal\n",
        "Leonardo DiCaprio\n",
        "Jason Schwartzman\n",
        "Grace Caroline Currey\n",
        "Sydney Sweeney\n",
        "Emily Rudd\n",
        "Samuel Blenkin\n",
        "James Marsden\n",
        "Jesse Plemons\n",
        "Alan Ritchson\n",
        "Cillian Murphy\n",
        "Meghan Markle\"\"\"\n",
        "\n",
        "# 'Rebecca Romijn' appears twice above; drop repeats (keeping first-seen order) so no\n",
        "# name is queried and scored twice.\n",
        "name_list = list(dict.fromkeys(names_string.split(\"\\n\")))\n",
        "print(len(name_list))"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "cellView": "form",
        "id": "T39kQT2NXvDw",
        "outputId": "def07e8c-0091-4f4d-c648-bff7b21f214d"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "100\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "#@title query for birth year\n",
        "import re\n",
        "\n",
        "# A 4-digit run with no digit on either side: the same numbers as the 4-character results\n",
        "# of re.findall(r'\\d+', ...).\n",
        "YEAR_RE = re.compile(r'(?<!\\d)\\d{4}(?!\\d)')\n",
        "\n",
        "predicted_year_map = {}\n",
        "# Move the model once instead of on every name.\n",
        "model.to(device)\n",
        "for name in name_list:\n",
        "  text = \"[INST] When was \" + name + \" born? [/INST]\"\n",
        "  model_inputs = tokenizer(text, return_tensors=\"pt\").to(device)\n",
        "  generated_ids, score_diff = model.generate(\n",
        "      model_inputs.input_ids, max_new_tokens=30, do_sample=False, sample_tokens=False,\n",
        "      pad_token_id=tokenizer.eos_token_id)\n",
        "  decoded = tokenizer.batch_decode(generated_ids)[0].split('\\nQ:')[0]\n",
        "  print(decoded)\n",
        "\n",
        "  # The first standalone 4-digit number in the decoded text is taken as the predicted year.\n",
        "  year_match = YEAR_RE.search(decoded)\n",
        "  if year_match:\n",
        "    predicted_year_map[name] = year_match.group()\n",
        "\n",
        "print(predicted_year_map)"
      ],
      "metadata": {
        "id": "tNKGkY7FGNMn",
        "cellView": "form"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "#@title birth year load from existing data\n",
        "# Birth years predicted by the pre-trained model (recorded from an earlier query run).\n",
        "pretrained_year_map = {'Sasha Calle': '1997', 'Annie Murphy': '1986', 'Golshifteh Farahani': '1983', 'Kate Mara': '1983', 'Josh Hartnett': '1978', 'Jennifer Lawrence': '1990', 'Aaron Taylor-Johnson': '1990', 'Rebecca Ferguson': '1983', 'Monica Barbaro': '1989', 'Chris Hemsworth': '1983', 'Wes Anderson': '1969', 'Daniel Portman': '2001', 'Lily-Rose Depp': '1999', \"Myha'la Herrold\": '1997', 'Zendaya': '1996', 'Ezra Miller': '1992', 'Olga Kurylenko': '1979', 'Zazie Beetz': '1991', 'Arnold Schwarzenegger': '1947', 'Emilia Clarke': '1986', 'Jess Bush': '1994', 'Clara Rugaard': '1998', 'Molly Gordon': '1994', 'Isabel May': '2000', 'Hailee Steinfeld': '1996', 'Hannah Waddingham': '1965', 'Christina Chong': '1983', 'Rory Culkin': '1989', 'Cobie Smulders': '1982', 'Harrison Ford': '1942', 'Tom Cruise': '1962', 'Carol Kane': '1952', 'Alexandra Daddario': '1986', 'Gal Gadot': '1985', 'Tom Holland': '1996', 'Hayley Atwell': '1982', 'Salma Hayek': '1966', 'Ana de Armas': '1988', 'Will Poulter': '1993', 'Anson Mount': '1973', 'Paapa Essiedu': '1990', 'Sam Hargrave': '1979', 'Margot Robbie': '1990', 'Nicolas Cage': '1964', 'Henry Cavill': '1983', 'Juno Temple': '1989', 'Cailee Spaeny': '1997', 'Treat Williams': '1951', 'Alexander Skarsgård': '1976', 'Rebecca Romijn': '1972', 'Monica Dolan': '1967', 'Anya Taylor-Joy': '1996', 'Sophia Lillis': '2002', 'Emmanuelle Vaugier': '1976', 'Aaron Paul': '1979', 'Elliot Page': '1987', 'Robin Tunney': '1972', 'Mike Faist': '1991', 'Tinatin Dalakishvili': '1989', 'Sarah Snook': '1987', 'Jenna Ortega': '2002', 'Zoe Saldana': '1978', 'Anjana Vasan': '1992', 'Ben Mendelsohn': '1969', 'Jeremy Allen White': '1991', 'Ayo Edebiri': '1999', 'Keanu Reeves': '1964', 'Pom Klementieff': '1986', 'Scarlett Johansson': '1984', 'Tornike Gogrichiani': '1990', 'James Cameron': '1954', 'Pedro Pascal': '1975', 'Kaley Cuoco': '1985', 'Samuel L. Jackson': '1948', 'Terri Ivens': '1976', 'Florence Pugh': '1996', 'Shea Whigham': '1969', 'Kingsley Ben-Adir': '1989', 'Michael Keaton': '1951', 'Julian Sands': '1958', 'Christopher Nolan': '1970', 'Tom Hanks': '1956', 'Clint Eastwood': '1930', 'Gabriel Macht': '1972', 'Fabiana Udenio': '1967', 'Tom Bateman': '1989', 'Jack Champion': '1999', 'Jake Gyllenhaal': '1980', 'Leonardo DiCaprio': '1974', 'Jason Schwartzman': '1980', 'Grace Caroline Currey': '1997', 'Sydney Sweeney': '1997', 'Emily Rudd': '1997', 'Samuel Blenkin': '1840', 'James Marsden': '1973', 'Jesse Plemons': '1988', 'Alan Ritchson': '1982', 'Cillian Murphy': '1976', 'Meghan Markle': '1981'}\n",
        "\n",
        "# Birth years predicted by the instruction-tuned (IT) model (recorded from an earlier query run).\n",
        "it_year_map = {'Sasha Calle': '1993', 'Annie Murphy': '1984', 'Golshifteh Farahani': '1983', 'Kate Mara': '1983', 'Josh Hartnett': '1978', 'Jennifer Lawrence': '1990', 'Aaron Taylor-Johnson': '1984', 'Rebecca Ferguson': '1986', 'Monica Barbaro': '1985', 'Chris Hemsworth': '1980', 'Wes Anderson': '1969', 'Daniel Portman': '1992', 'Lily-Rose Depp': '1999', \"Myha'la Herrold\": '2006', 'Zendaya': '1996', 'Ezra Miller': '1992', 'Olga Kurylenko': '1980', 'Zazie Beetz': '1997', 'Arnold Schwarzenegger': '1946', 'Emilia Clarke': '1982', 'Jess Bush': '1983', 'Clara Rugaard': '2000', 'Molly Gordon': '1996', 'Hailee Steinfeld': '2000', 'Hannah Waddingham': '1964', 'Christina Chong': '1973', 'Rory Culkin': '1985', 'Cobie Smulders': '1982', 'Harrison Ford': '1942', 'Tom Cruise': '1962', 'Carol Kane': '1952', 'Alexandra Daddario': '1986', 'Gal Gadot': '1985', 'Tom Holland': '1967', 'Hayley Atwell': '1982', 'Salma Hayek': '1964', 'Ana de Armas': '1988', 'Will Poulter': '1993', 'Anson Mount': '1973', 'Paapa Essiedu': '1995', 'Sam Hargrave': '1984', 'Margot Robbie': '1990', 'Nicolas Cage': '1964', 'Henry Cavill': '1986', 'Juno Temple': '1987', 'Cailee Spaeny': '2000', 'Treat Williams': '1952', 'Alexander Skarsgård': '1976', 'Rebecca Romijn': '1967', 'Monica Dolan': '1969', 'Anya Taylor-Joy': '1996', 'Sophia Lillis': '2002', 'Emmanuelle Vaugier': '1970', 'Aaron Paul': '1978', 'Elliot Page': '1981', 'Robin Tunney': '1965', 'Mike Faist': '1981', 'Tinatin Dalakishvili': '1985', 'Sarah Snook': '1985', 'Jenna Ortega': '2004', 'Zoe Saldana': '1978', 'Ben Mendelsohn': '1971', 'Jeremy Allen White': '1981', 'Ayo Edebiri': '1995', 'Keanu Reeves': '1969', 'Pom Klementieff': '1986', 'Scarlett Johansson': '1984', 'Tornike Gogrichiani': '1990', 'James Cameron': '1954', 'Pedro Pascal': '1987', 'Kaley Cuoco': '1985', 'Samuel L. Jackson': '1948', 'Terri Ivens': '1969', 'Florence Pugh': '1888', 'Shea Whigham': '1968', 'Kingsley Ben-Adir': '1984', 'Michael Keaton': '1959', 'Julian Sands': '1953', 'Christopher Nolan': '1970', 'Tom Hanks': '1956', 'Clint Eastwood': '1930', 'Gabriel Macht': '1971', 'Fabiana Udenio': '1972', 'Tom Bateman': '1982', 'Jack Champion': '1996', 'Jake Gyllenhaal': '1981', 'Leonardo DiCaprio': '1974', 'Jason Schwartzman': '1980', 'Grace Caroline Currey': '1990', 'Sydney Sweeney': '1997', 'Emily Rudd': '1991', 'Samuel Blenkin': '1879', 'James Marsden': '1973', 'Jesse Plemons': '1981', 'Alan Ritchson': '1984', 'Cillian Murphy': '1976', 'Meghan Markle': '1981'}\n",
        "\n",
        "# The parity cells look names up in `predicted_year_map`; point it at the map for the model\n",
        "# being evaluated.\n",
        "predicted_year_map = it_year_map"
      ],
      "metadata": {
        "id": "jSllfJ9VXWj5"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "#@title parity eval on pre-trained model\n",
        "\n",
        "BATCH_SIZE = 10\n",
        "\n",
        "correct_vec = [[] for _ in range(BATCH_SIZE)]\n",
        "# Move the model once instead of on every name.\n",
        "model.to(device)\n",
        "\n",
        "for name in name_list:\n",
        "  if name not in predicted_year_map:\n",
        "    continue\n",
        "  text = \"Q: Was \" + name + \" born in an even or odd year?\\nA:\"\n",
        "  # The reference parity comes from the year predicted earlier (predicted_year_map).\n",
        "  target = 'even' if int(predicted_year_map[name]) % 2 == 0 else 'odd'\n",
        "\n",
        "  model_inputs = tokenizer([text]*BATCH_SIZE, return_tensors=\"pt\").to(device)\n",
        "  # NOTE(review): assumes a patched `generate` that accepts `sample_tokens=` and returns\n",
        "  # (sequences, per-position score diffs); stock transformers does not document either.\n",
        "  generated_ids, score_diff = model.generate(\n",
        "      model_inputs.input_ids, max_new_tokens=50, sample_tokens=True,\n",
        "      do_sample=False, pad_token_id=tokenizer.eos_token_id)\n",
        "  decoded = tokenizer.batch_decode(generated_ids)\n",
        "  score_diff = [diff.cpu() for diff in score_diff]\n",
        "  # Every row holds the same prompt, so generated tokens start at the same index.\n",
        "  start_index = model_inputs.input_ids.shape[1]\n",
        "\n",
        "  for i in range(BATCH_SIZE):\n",
        "    print('===', decoded[i].split('\\nQ:')[0].strip())\n",
        "\n",
        "  for top_k in range(BATCH_SIZE):\n",
        "    # Among paths 0..top_k, keep the answer whose last parity word has the highest mean score diff.\n",
        "    max_score_diff = -1000\n",
        "    max_dec = ''\n",
        "    max_ans = ''\n",
        "    for i in range(top_k+1):\n",
        "      scores = score_diff[i]\n",
        "      ids = generated_ids[i, :]\n",
        "\n",
        "      # The answer ends at the first EOS (id 2) or '\\nQ:' after the prompt. If neither\n",
        "      # appears, end_index stays -1000 and the slice below is empty (no answer).\n",
        "      end_index = -1000\n",
        "      for ind in range(start_index, len(ids)):\n",
        "        if ids[ind] == 2 or tokenizer.decode(ids[ind:ind+3]) == '\\nQ:':\n",
        "          end_index = ind\n",
        "          break\n",
        "\n",
        "      output_ids = ids[start_index:end_index]\n",
        "      assert len(ids[start_index:]) == len(scores)\n",
        "\n",
        "      # Score this path by the mean score diff over its last even/odd word.\n",
        "      blocks = find_even_odd_blocks(output_ids)\n",
        "      if blocks:\n",
        "        word_start, word_end = blocks[-1]\n",
        "        final_score_diff = np.average(scores[word_start:word_end])\n",
        "        if final_score_diff > max_score_diff:\n",
        "          max_score_diff = final_score_diff\n",
        "          max_dec = decoded[i].split('\\nQ:')[0].strip()\n",
        "          max_ans = tokenizer.decode(output_ids[word_start:word_end])\n",
        "\n",
        "    print('max', max_score_diff, max_ans, '|', target)\n",
        "    correct_vec[top_k].append(1 if target.lower() == max_ans.lower() else 0)\n",
        "  print('sum:', len(correct_vec[0]), [np.sum(correct_vec[j]) for j in range(BATCH_SIZE)])\n",
        "  print('acc:', [\"{:0.4f}\".format(np.average(correct_vec[j])) for j in range(BATCH_SIZE)])"
      ],
      "metadata": {
        "id": "lgB0PP9yuv4o",
        "cellView": "form"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "#@title parity eval on instruction-tuned model\n",
        "\n",
        "# Parity-word matcher for instruction-tuned outputs. It has its own name so it does not\n",
        "# shadow the util-cell `find_even_odd_blocks` used by the pre-trained parity cell.\n",
        "def find_even_odd_blocks_it(ids):\n",
        "  \"\"\"Return (i, i+1) for every token in `ids` that decodes to exactly 'even', 'odd', 'Even' or 'Odd'.\"\"\"\n",
        "  interval_pairs = []\n",
        "  for i, token_id in enumerate(ids):\n",
        "    if tokenizer.decode([token_id]) in ('even', 'odd', 'Even', 'Odd'):\n",
        "      interval_pairs.append((i, i + 1))\n",
        "  return interval_pairs\n",
        "\n",
        "BATCH_SIZE = 10\n",
        "\n",
        "correct_vec = [[] for _ in range(BATCH_SIZE)]\n",
        "# Move the model once instead of on every name.\n",
        "model.to(device)\n",
        "\n",
        "for name in name_list:\n",
        "  if name not in predicted_year_map:\n",
        "    continue\n",
        "  text = \"[INST] Was \" + name + \" born in an even or odd year? [/INST]\"\n",
        "  # The reference parity comes from the year predicted earlier (predicted_year_map).\n",
        "  target = 'even' if int(predicted_year_map[name]) % 2 == 0 else 'odd'\n",
        "\n",
        "  model_inputs = tokenizer([text]*BATCH_SIZE, return_tensors=\"pt\").to(device)\n",
        "  # NOTE(review): assumes a patched `generate` that accepts `sample_tokens=` and returns\n",
        "  # (sequences, per-position score diffs); stock transformers does not document either.\n",
        "  generated_ids, score_diff = model.generate(\n",
        "      model_inputs.input_ids, max_new_tokens=100, sample_tokens=True,\n",
        "      do_sample=False, pad_token_id=tokenizer.eos_token_id)\n",
        "  decoded = tokenizer.batch_decode(generated_ids)\n",
        "  score_diff = [diff.cpu() for diff in score_diff]\n",
        "  # Every row holds the same prompt, so generated tokens start at the same index.\n",
        "  start_index = model_inputs.input_ids.shape[1]\n",
        "\n",
        "  for i in range(BATCH_SIZE):\n",
        "    print('===', decoded[i].split('\\nQ:')[0].strip())\n",
        "\n",
        "  for top_k in range(BATCH_SIZE):\n",
        "    # Among paths 0..top_k, keep the answer whose last parity word has the highest mean score diff.\n",
        "    max_score_diff = -1000\n",
        "    max_dec = ''\n",
        "    max_ans = ''\n",
        "    for i in range(top_k+1):\n",
        "      scores = score_diff[i]\n",
        "      ids = generated_ids[i, :]\n",
        "\n",
        "      # The response ends at the first EOS (id 2) after the prompt. Without one, end_index\n",
        "      # stays -1000 and the slice below is empty, so a truncated path gives no answer.\n",
        "      end_index = -1000\n",
        "      for ind in range(start_index, len(ids)):\n",
        "        if ids[ind] == 2:\n",
        "          end_index = ind\n",
        "          break\n",
        "\n",
        "      output_ids = ids[start_index:end_index]\n",
        "      assert len(ids[start_index:]) == len(scores)\n",
        "\n",
        "      # Score this path by the mean score diff over its last even/odd word.\n",
        "      blocks = find_even_odd_blocks_it(output_ids)\n",
        "      if blocks:\n",
        "        word_start, word_end = blocks[-1]\n",
        "        final_score_diff = np.average(scores[word_start:word_end])\n",
        "        if final_score_diff > max_score_diff:\n",
        "          max_score_diff = final_score_diff\n",
        "          max_dec = decoded[i].split('\\nQ:')[0].strip()\n",
        "          max_ans = tokenizer.decode(output_ids[word_start:word_end])\n",
        "\n",
        "    print('max', max_score_diff, max_ans, '|', target)\n",
        "    correct_vec[top_k].append(1 if target.lower() == max_ans.lower() else 0)\n",
        "  print('sum:', len(correct_vec[0]), [np.sum(correct_vec[j]) for j in range(BATCH_SIZE)])\n",
        "  print('acc:', [\"{:0.4f}\".format(np.average(correct_vec[j])) for j in range(BATCH_SIZE)])"
      ],
      "metadata": {
        "id": "0YT7W8ioe_g3",
        "cellView": "form"
      },
      "execution_count": null,
      "outputs": []
    }
  ]
}