{
  "cells": [
    {
      "metadata": {},
      "id": "c35d3992-188b-4b66-afc9-4cba9ce1b708",
      "cell_type": "code",
      "source": [
        "# Imports\n",
        "import os\n",
        "import sys\n",
        "import json\n",
        "import numpy as np\n",
        "import openai\n",
        "%load_ext autoreload\n",
        "%autoreload 2\n",
        "\n",
        "# Add the parent folder to make the utilities importable\n",
        "module_path = os.path.abspath(os.path.join('..'))\n",
        "sys.path.insert(0, module_path)\n",
        "\n",
        "import answer_scoring as score\n",
        "import answer_extraction as extract"
      ],
      "outputs": [],
      "execution_count": null
    },
    {
      "metadata": {},
      "id": "igH7ownoDp0c",
      "cell_type": "code",
      "source": [
        "# We will run the models using the OpenAI framework as an example, but it\n",
        "# is easy to generalize to any generate() function that returns a string\n",
        "# given a string prompt.\n",
        "\n",
        "# Add a key.env file in the base directory with your API key.\n",
        "# The key.env file should be a plain text file with the line:\n",
        "# OPENAI_API_KEY=\u003cyour API key\u003e\n",
        "\n",
        "from dotenv import load_dotenv\n",
        "load_dotenv('../key.env')\n",
        "import openai\n",
        "\n",
        "\n",
        "client = openai.OpenAI(api_key=os.environ.get(\"OPENAI_API_KEY\"))\n",
        "model = \"gpt-4o\"\n",
        "\n",
        "def generate(prompt: str) -\u003e str:\n",
        "  \"\"\"Send `prompt` as a single user message and return the model's reply.\n",
        "\n",
        "  Args:\n",
        "    prompt: Text sent as the content of one user-role chat message.\n",
        "\n",
        "  Returns:\n",
        "    The text of the first completion choice, or an empty string when the\n",
        "    API returns no text content for that choice.\n",
        "  \"\"\"\n",
        "  response = client.chat.completions.create(\n",
        "    model=model,\n",
        "    messages=[{\"role\": \"user\", \"content\": prompt}],\n",
        "  )\n",
        "  # `message.content` is nullable in the Chat Completions response schema;\n",
        "  # normalize to \"\" so the declared `str` return type always holds and the\n",
        "  # answer extractors never receive None.\n",
        "  content = response.choices[0].message.content\n",
        "  return content if content is not None else \"\""
      ],
      "outputs": [],
      "execution_count": null
    },
    {
      "metadata": {},
      "id": "N46jLVqvCqel",
      "cell_type": "markdown",
      "source": [
        "# Load the simple reasoning benchmark"
      ]
    },
    {
      "metadata": {},
      "id": "Yy0WYVKMCUOW",
      "cell_type": "code",
      "source": [
        "# Read the benchmark as UTF-8 explicitly so parsing does not depend on the\n",
        "# platform's default text encoding.\n",
        "with open(\"../datasets/simple_reasoning.json\", \"r\", encoding=\"utf-8\") as f:\n",
        "    json_data = json.load(f)\n",
        "print(f\"loaded {len(json_data)} entries.\")\n",
        "\n",
        "# See what a single entry looks like:\n",
        "print(json_data[0])\n"
      ],
      "outputs": [],
      "execution_count": null
    },
    {
      "metadata": {},
      "id": "68137fde-aafb-4978-bb23-0f2909ffd4ac",
      "cell_type": "code",
      "source": [
        "# Query the model once per benchmark entry, keeping dataset order.\n",
        "responses = [generate(item[\"input\"]) for item in json_data]"
      ],
      "outputs": [],
      "execution_count": null
    },
    {
      "metadata": {},
      "id": "4iGRGZvIGk-S",
      "cell_type": "code",
      "source": [
        "# Extract an answer from every response and score it against its entry.\n",
        "is_correct = [\n",
        "    score.score_eval(item, extract.extract_eval(item, reply))\n",
        "    for item, reply in zip(json_data, responses)\n",
        "]\n",
        "accuracy = np.sum(is_correct) / len(is_correct)\n",
        "print(f\"Proportion correct: {accuracy}.\")"
      ],
      "outputs": [],
      "execution_count": null
    },
    {
      "metadata": {},
      "id": "4MLq1wm-Ctym",
      "cell_type": "markdown",
      "source": [
        "# Load the Unpuzzles"
      ]
    },
    {
      "metadata": {},
      "id": "f8a2xV2OCwi4",
      "cell_type": "code",
      "source": [
        "# Load the unpuzzles benchmark (read as UTF-8 explicitly so parsing does not\n",
        "# depend on the platform's default text encoding).\n",
        "with open('../datasets/unpuzzles.json', 'r', encoding='utf-8') as f:\n",
        "    json_data = json.load(f)\n",
        "print(f\"loaded {len(json_data)} entries.\")"
      ],
      "outputs": [],
      "execution_count": null
    },
    {
      "metadata": {},
      "id": "MRqzK-6GIW-e",
      "cell_type": "code",
      "source": [
        "# Run the puzzles and unpuzzles: for each entry, query the puzzle first and\n",
        "# then its unpuzzle counterpart.\n",
        "# NOTE(review): nothing in this notebook scores these two lists before later\n",
        "# cells reload json_data and rebuild them; confirm whether a scoring step\n",
        "# is missing here.\n",
        "\n",
        "response_pairs = [\n",
        "    (generate(item[\"puzzle\"]), generate(item[\"unpuzzle\"]))\n",
        "    for item in json_data\n",
        "]\n",
        "puzzle_responses = [pair[0] for pair in response_pairs]\n",
        "unpuzzle_responses = [pair[1] for pair in response_pairs]"
      ],
      "outputs": [],
      "execution_count": null
    },
    {
      "metadata": {},
      "id": "HYTQUoKNDDjW",
      "cell_type": "markdown",
      "source": [
        "# Load the shifted unpuzzles"
      ]
    },
    {
      "metadata": {},
      "id": "8017d8b8-0194-4d0c-bc8e-b83ad4f5ef4a",
      "cell_type": "code",
      "source": [
        "# Read the shifted unpuzzles as UTF-8 explicitly so parsing does not depend\n",
        "# on the platform's default text encoding.\n",
        "with open('../datasets/shifted_unpuzzles.json', 'r', encoding='utf-8') as f:\n",
        "    json_data = json.load(f)\n",
        "print(f\"loaded {len(json_data)} entries.\")"
      ],
      "outputs": [],
      "execution_count": null
    },
    {
      "metadata": {},
      "id": "7ZQkQV7zIrSj",
      "cell_type": "code",
      "source": [
        "# Query the model on the original puzzle, the unpuzzle, and the shifted\n",
        "# unpuzzle of every entry (in that order for each entry).\n",
        "response_triples = [\n",
        "    (\n",
        "        generate(item[\"original_puzzle\"]),\n",
        "        generate(item[\"unpuzzle\"]),\n",
        "        generate(item[\"shifted_unpuzzle\"]),\n",
        "    )\n",
        "    for item in json_data\n",
        "]\n",
        "puzzle_responses = [triple[0] for triple in response_triples]\n",
        "unpuzzle_responses = [triple[1] for triple in response_triples]\n",
        "shifted_unpuzzle_responses = [triple[2] for triple in response_triples]\n",
        "\n",
        "# Extract an answer from each response and score it against the entry's\n",
        "# stored answer field for the matching variant.\n",
        "puzzles_correct, unpuzzles_correct, shifted_unpuzzles_correct = [], [], []\n",
        "for item, (puzzle_text, unpuzzle_text, shifted_text) in zip(json_data, response_triples):\n",
        "    puzzles_correct.append(\n",
        "        score.unpuzzle(item[\"original_answer\"], extract.unpuzzle(puzzle_text))\n",
        "    )\n",
        "    unpuzzles_correct.append(\n",
        "        score.unpuzzle(item[\"unpuzzle_answer\"], extract.unpuzzle(unpuzzle_text))\n",
        "    )\n",
        "    shifted_unpuzzles_correct.append(\n",
        "        score.unpuzzle(item[\"shifted_unpuzzle_answer\"], extract.unpuzzle(shifted_text))\n",
        "    )"
      ],
      "outputs": [],
      "execution_count": null
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3 (ipykernel)",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.12.9"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 5
}
