{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7eebc4b0-0669-4732-80bc-c4eb0f689ef8",
   "metadata": {},
   "outputs": [],
   "source": [
    "import copy\n",
    "import datetime\n",
    "import json\n",
    "import random\n",
    "import logging\n",
    "import asyncio\n",
    "import os\n",
    "import pickle\n",
    "import time\n",
    "import re\n",
    "import requests\n",
    "import tempfile\n",
    "from dotenv import load_dotenv\n",
    "\n",
    "import numpy as np\n",
    "from bs4 import BeautifulSoup\n",
    "import vertexai\n",
    "from vertexai.tuning import sft\n",
    "\n",
    "import easyinference\n",
    "\n",
    "from finetuning_src import bucket\n",
    "from finetuning_src import config\n",
    "from finetuning_src import utils\n",
    "from finetuning_src.utils import parse_json\n",
    "\n",
    "# Load variables from a local .env file (if present) and print load_dotenv()'s boolean result.\n",
    "print(load_dotenv())\n",
    "# Refresh easyinference's configuration and open its query connection before any inference calls below.\n",
    "# NOTE(review): presumably reload_config() picks up the environment loaded above — confirm.\n",
    "easyinference.reload_config()\n",
    "await easyinference.initialize_query_connection()\n",
    "\n",
    "logging.basicConfig(level=logging.INFO)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "70613a33",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Model and experiment configuration\n",
    "version = \"publicv1\"  # first tag attached to every inference call\n",
    "DEFAULT_MODEL = \"publishers/google/models/gemini-1.5-flash-002\"  # generates facts/queries, serves as the baseline, and grades answers\n",
    "FINETUNING_MODEL = \"gemini-1.5-flash-002\"  # base model passed to sft.train\n",
    "TEMPERATURE = 1  # sampling temperature for every inference call\n",
    "FINETUNING_SEEDS = 1  # number of fine-tuning jobs to launch and evaluate\n",
    "NUM_EPOCHS = 10  # epochs per fine-tuning job\n",
    "\n",
    "# Wikipedia-specific configuration\n",
    "ARTICLE_SIZE_THRESHOLD = 5000  # keep articles whose MediaWiki page length (bytes) exceeds this\n",
    "INITIAL_SUBCATEGORY_LIMIT = 10  # subcategories scanned on the first pass per category\n",
    "SUBCATEGORY_INCREMENT = 10  # extra subcategories scanned per later pass while the quota is unmet\n",
    "ARTICLE_LIMIT_PER_CATEGORY = 5  # articles to collect per entry in CATEGORIES\n",
    "WIKIPEDIA_API_URL = 'https://en.wikipedia.org/w/api.php'\n",
    "DOWNLOAD_DIR = 'downloaded_articles_text'  # NOTE(review): created by the download cell but never written to in the cells shown\n",
    "\n",
    "# Top-level categories to crawl; later cells look results up by these exact strings.\n",
    "CATEGORIES = [\"Category:2024_meteorology\", \"Category:2024 in sports\"]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "634ab068",
   "metadata": {},
   "source": [
    "# Wikipedia"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "3d6d1d74",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Second tag (after `version`) attached to every inference call in this notebook.\n",
    "experiment_name = \"wiki_experiment\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c7e36785",
   "metadata": {},
   "source": [
    "## Download recent Wikipedia articles"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e20b902a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Download large Wikipedia articles.\n",
    "\n",
    "def get_wikipedia_extract(article_title, timeout=30):\n",
    "    \"\"\"Fetch the plain-text extract of a Wikipedia article and clean it.\n",
    "\n",
    "    Always returns a string: the cleaned article text on success, otherwise a short\n",
    "    message (\"Article not found or does not contain text.\" or \"An error occurred: ...\").\n",
    "    Callers store the return value as the article text, so failures are also logged.\n",
    "\n",
    "    Args:\n",
    "        article_title: Title of the article to fetch.\n",
    "        timeout: Seconds to wait for the Wikipedia API before giving up.\n",
    "    \"\"\"\n",
    "    params = {\n",
    "        'action': 'query',\n",
    "        'prop': 'extracts',\n",
    "        'titles': article_title,\n",
    "        'format': 'json',\n",
    "        'explaintext': True,  # Request plain text\n",
    "    }\n",
    "    try:\n",
    "        # Without a timeout, a stalled connection would block the notebook forever.\n",
    "        response = requests.get(WIKIPEDIA_API_URL, params=params, timeout=timeout)\n",
    "        response_data = response.json()\n",
    "\n",
    "        # Only one title is requested, so only the first returned page matters. An API\n",
    "        # error response has no 'query' key; falling through to the not-found message\n",
    "        # below keeps the return value a string instead of None.\n",
    "        pages = response_data.get('query', {}).get('pages', {})\n",
    "        page_data = next(iter(pages.values()), {})\n",
    "        if 'extract' in page_data:\n",
    "            return clean_article_text(page_data['extract'])\n",
    "    except Exception as e:\n",
    "        logging.error(f\"Error fetching extract for {article_title}: {e}\")\n",
    "        return f\"An error occurred: {e}\"\n",
    "\n",
    "    logging.warning(f\"No extract found for {article_title}\")\n",
    "    return \"Article not found or does not contain text.\"\n",
    "\n",
    "def clean_article_text(text):\n",
    "    \"\"\"Cut an article at its first boilerplate section and strip surrounding whitespace.\n",
    "\n",
    "    Looks for a level-2 heading \"== Title ==\" preceded by a newline, whose title is in\n",
    "    `unwanted_sections` (case-insensitive), and drops that heading and everything after it.\n",
    "    Text without such a heading is only stripped.\n",
    "    \"\"\"\n",
    "    # Define the unwanted sections by their titles\n",
    "    unwanted_sections = [\"External links\", \"References\", \"See also\", \"Further reading\", \"Footnotes\", \"Awards\", \"Bibliography\", \"Notes\", \"Sources\", \"Citations\", \"Publications\", \"References and notes\", \"Filmography\", \"Selected filmography\", \"Selected publications\", \"Selected Awards\", \"Works\", \"Partial list of written works\", \"Recordings\", \"Books\", \"Selected works\", \"Select works\", \"Notes and references\", \"Taxonomy\", \"Genera\", \"Species\", \"Select publications\", \"Magazines\", \"References, external link\", \"Gallery\", \"Awards received\"]\n",
    "\n",
    "    # Escape each title so names containing regex metacharacters (e.g. \"(\" or \".\") match literally.\n",
    "    pattern = re.compile(\n",
    "        r\"\\n==\\s*({})\\s*==\".format(\"|\".join(re.escape(s) for s in unwanted_sections)),\n",
    "        re.IGNORECASE,\n",
    "    )\n",
    "\n",
    "    # Keep only the part before the first unwanted section, if any.\n",
    "    match = pattern.search(text)\n",
    "    if match:\n",
    "        text = text[:match.start()]\n",
    "\n",
    "    return text.strip()\n",
    "\n",
    "# Ensure the download directory exists (exist_ok makes this a no-op if it already does).\n",
    "os.makedirs(DOWNLOAD_DIR, exist_ok=True)\n",
    "\n",
    "def get_subcategories(category_name, timeout=30):\n",
    "    \"\"\"Return the titles of all direct subcategories of a Wikipedia category.\n",
    "\n",
    "    Follows the API's `continue` parameters until every page of results has been read.\n",
    "    If a response has no 'query' key the error is logged and the titles gathered so far\n",
    "    are returned.\n",
    "\n",
    "    Args:\n",
    "        category_name: Full category title, e.g. \"Category:2024 in sports\".\n",
    "        timeout: Seconds to wait for each API request.\n",
    "    \"\"\"\n",
    "    params = {\n",
    "        'action': 'query',\n",
    "        'list': 'categorymembers',\n",
    "        'cmtitle': category_name,\n",
    "        'cmtype': 'subcat',\n",
    "        'cmlimit': 'max',\n",
    "        'format': 'json'\n",
    "    }\n",
    "    subcategories = []\n",
    "    continue_token = {}\n",
    "    while True:\n",
    "        # The timeout keeps a stalled connection from hanging the notebook indefinitely.\n",
    "        response = requests.get(\n",
    "            WIKIPEDIA_API_URL, params={**params, **continue_token}, timeout=timeout\n",
    "        ).json()\n",
    "        if 'query' not in response:\n",
    "            logging.error(f\"Error fetching subcategories for {category_name}: {response}\")\n",
    "            break\n",
    "        subcategories.extend(sc['title'] for sc in response['query']['categorymembers'])\n",
    "        if 'continue' not in response:\n",
    "            break\n",
    "        # Merge the continuation parameters into the next request.\n",
    "        continue_token = response['continue']\n",
    "    return subcategories\n",
    "\n",
    "def get_articles_in_category(category_name, timeout=30):\n",
    "    \"\"\"Return the members of a category that are pages in the main (article) namespace.\n",
    "\n",
    "    Each item is the raw `categorymembers` entry from the MediaWiki API; callers read\n",
    "    its 'pageid'. Continuation is followed until all members are read. If a response\n",
    "    has no 'query' key the error is logged and the members gathered so far are returned.\n",
    "\n",
    "    Args:\n",
    "        category_name: Full category title to list.\n",
    "        timeout: Seconds to wait for each API request.\n",
    "    \"\"\"\n",
    "    params = {\n",
    "        'action': 'query',\n",
    "        'list': 'categorymembers',\n",
    "        'cmtitle': category_name,\n",
    "        'cmtype': 'page',\n",
    "        'cmnamespace': '0',\n",
    "        'cmlimit': 'max',\n",
    "        'format': 'json'\n",
    "    }\n",
    "    articles = []\n",
    "    continue_token = {}\n",
    "    while True:\n",
    "        # The timeout keeps a stalled connection from hanging the notebook indefinitely.\n",
    "        response = requests.get(\n",
    "            WIKIPEDIA_API_URL, params={**params, **continue_token}, timeout=timeout\n",
    "        ).json()\n",
    "        if 'query' not in response:\n",
    "            logging.error(f\"Error fetching articles for {category_name}: {response}\")\n",
    "            break\n",
    "        articles.extend(response['query']['categorymembers'])\n",
    "        if 'continue' not in response:\n",
    "            break\n",
    "        # Merge the continuation parameters into the next request.\n",
    "        continue_token = response['continue']\n",
    "    return articles\n",
    "\n",
    "def get_article_info(pageids, timeout=30):\n",
    "    \"\"\"Fetch `prop=info` metadata (including the full URL) for a list of page IDs.\n",
    "\n",
    "    IDs are requested in chunks of 50. Returns the merged `pages` objects from the\n",
    "    responses, keyed as the API keys them; chunks whose response has no 'query' key are\n",
    "    logged and skipped.\n",
    "\n",
    "    Args:\n",
    "        pageids: Page IDs to look up.\n",
    "        timeout: Seconds to wait for each API request.\n",
    "    \"\"\"\n",
    "    articles_info = {}\n",
    "    max_ids_per_request = 50  # per the MediaWiki API limits\n",
    "    for start in range(0, len(pageids), max_ids_per_request):\n",
    "        chunk = pageids[start:start + max_ids_per_request]\n",
    "        params = {\n",
    "            'action': 'query',\n",
    "            'prop': 'info',\n",
    "            'pageids': '|'.join(str(pid) for pid in chunk),\n",
    "            'inprop': 'url',\n",
    "            'format': 'json'\n",
    "        }\n",
    "        # The timeout keeps a stalled connection from hanging the notebook indefinitely.\n",
    "        response = requests.get(WIKIPEDIA_API_URL, params=params, timeout=timeout).json()\n",
    "        if 'query' in response:\n",
    "            articles_info.update(response['query']['pages'])\n",
    "        else:\n",
    "            logging.error(f\"Error fetching article info: {response}\")\n",
    "    return articles_info\n",
    "\n",
    "def get_creation_date(pageid, timeout=30):\n",
    "    \"\"\"Return the timestamp of a Wikipedia page's first revision, or None.\n",
    "\n",
    "    Requests a single revision in oldest-first order (`rvdir=newer`), so the timestamp\n",
    "    is the page's creation time, as a string like \"2024-01-31T12:00:00Z\".\n",
    "\n",
    "    Args:\n",
    "        pageid: MediaWiki page ID.\n",
    "        timeout: Seconds to wait for the API request.\n",
    "    \"\"\"\n",
    "    params = {\n",
    "        'action': 'query',\n",
    "        'prop': 'revisions',\n",
    "        'pageids': pageid,\n",
    "        'rvlimit': '1',\n",
    "        'rvdir': 'newer',\n",
    "        'rvprop': 'timestamp',\n",
    "        'format': 'json'\n",
    "    }\n",
    "    # The timeout keeps a stalled connection from hanging the notebook indefinitely.\n",
    "    response = requests.get(WIKIPEDIA_API_URL, params=params, timeout=timeout).json()\n",
    "    page_info = response.get('query', {}).get('pages', {}).get(str(pageid), {})\n",
    "    revisions = page_info.get('revisions', [])\n",
    "    return revisions[0]['timestamp'] if revisions else None\n",
    "\n",
    "def download_article_text(article_title, timeout=30):\n",
    "    \"\"\"Download a Wikipedia article's rendered HTML and reduce it to plain text.\n",
    "\n",
    "    Uses `action=parse`, then removes every table, sup, span and div inside the main\n",
    "    content div, so inline text wrapped in those tags is lost as well.\n",
    "    Returns the text, or None if the page could not be fetched or parsed.\n",
    "    NOTE(review): no active call site is visible in this notebook — confirm it is still needed.\n",
    "\n",
    "    Args:\n",
    "        article_title: Title of the article to download.\n",
    "        timeout: Seconds to wait for the API request.\n",
    "    \"\"\"\n",
    "    params = {\n",
    "        'action': 'parse',\n",
    "        'page': article_title,\n",
    "        'format': 'json',\n",
    "        'prop': 'text'  # Get HTML content\n",
    "    }\n",
    "    try:\n",
    "        # The timeout keeps a stalled connection from hanging the notebook indefinitely.\n",
    "        response = requests.get(WIKIPEDIA_API_URL, params=params, timeout=timeout).json()\n",
    "        if 'parse' not in response:\n",
    "            logging.warning(f\"Failed to fetch text for {article_title}\")\n",
    "            return None\n",
    "\n",
    "        soup = BeautifulSoup(response['parse']['text']['*'], 'html.parser')\n",
    "        content_div = soup.find('div', class_='mw-parser-output')\n",
    "        if content_div is None:\n",
    "            # Report this explicitly instead of letting find_all raise AttributeError below.\n",
    "            logging.warning(f\"No article content found for {article_title}\")\n",
    "            return None\n",
    "\n",
    "        # Remove tables, footnote markers and other nested markup; decompose() deletes each tag and its contents.\n",
    "        for element in content_div.find_all(['table', 'sup', 'span', 'div']):\n",
    "            element.decompose()\n",
    "\n",
    "        return content_div.get_text(separator='\\n').strip()\n",
    "    except Exception as e:\n",
    "        logging.error(f\"Error downloading text for {article_title}: {e}\")\n",
    "        return None\n",
    "\n",
    "def _large_articles_by_subcategory(subcategories):\n",
    "    \"\"\"Map each subcategory to its articles longer than ARTICLE_SIZE_THRESHOLD (empty ones are omitted).\"\"\"\n",
    "    subcategory_articles = {}\n",
    "    for subcat in subcategories:\n",
    "        pageids = [article['pageid'] for article in get_articles_in_category(subcat)]\n",
    "        if not pageids:\n",
    "            continue\n",
    "        filtered_articles = [\n",
    "            {\n",
    "                'title': info['title'],\n",
    "                'length': info['length'],\n",
    "                'url': info['fullurl'],\n",
    "                'pageid': info['pageid'],  # kept for the creation-date lookup\n",
    "            }\n",
    "            for info in get_article_info(pageids).values()\n",
    "            if info['length'] > ARTICLE_SIZE_THRESHOLD\n",
    "        ]\n",
    "        if filtered_articles:\n",
    "            subcategory_articles[subcat] = filtered_articles\n",
    "        logging.info(f'Fetched {len(filtered_articles)} articles from {subcat}')\n",
    "    return subcategory_articles\n",
    "\n",
    "def _collect_recent_round_robin(subcategory_articles, articles_collected, seen_pageids):\n",
    "    \"\"\"Take one article per subcategory in turn, keeping those created in 2024 or later.\n",
    "\n",
    "    Mutates its arguments: pops from the lists in `subcategory_articles` (deleting\n",
    "    exhausted keys), appends kept articles to `articles_collected` until it holds\n",
    "    ARTICLE_LIMIT_PER_CATEGORY items, and adds every examined page ID to `seen_pageids`\n",
    "    so an article listed under several subcategories is examined only once.\n",
    "    \"\"\"\n",
    "    cutoff_date = datetime.datetime(2024, 1, 1)\n",
    "    while len(articles_collected) < ARTICLE_LIMIT_PER_CATEGORY and any(subcategory_articles.values()):\n",
    "        for subcat in list(subcategory_articles):\n",
    "            articles = subcategory_articles[subcat]\n",
    "            if not articles:\n",
    "                del subcategory_articles[subcat]\n",
    "                continue\n",
    "            article = articles.pop(0)\n",
    "            if article['pageid'] in seen_pageids:\n",
    "                continue\n",
    "            seen_pageids.add(article['pageid'])\n",
    "            creation_timestamp = get_creation_date(article['pageid'])\n",
    "            if not creation_timestamp:\n",
    "                logging.info(f\"Could not determine creation date for article '{article['title']}'; skipping.\")\n",
    "                continue\n",
    "            creation_date = datetime.datetime.strptime(creation_timestamp, '%Y-%m-%dT%H:%M:%SZ')\n",
    "            if creation_date < cutoff_date:\n",
    "                logging.info(f\"Article '{article['title']}' was created before 2024; skipping.\")\n",
    "                continue\n",
    "            article[\"creation_date\"] = creation_timestamp\n",
    "            article[\"text\"] = get_wikipedia_extract(article[\"title\"])  # download_article_text(article)\n",
    "            articles_collected.append(article)\n",
    "            if len(articles_collected) >= ARTICLE_LIMIT_PER_CATEGORY:\n",
    "                break\n",
    "\n",
    "def fetch_large_articles(category_name):\n",
    "    \"\"\"Collect up to ARTICLE_LIMIT_PER_CATEGORY large articles created in 2024 or later.\n",
    "\n",
    "    Subcategories of `category_name` are scanned in batches: INITIAL_SUBCATEGORY_LIMIT\n",
    "    first, then SUBCATEGORY_INCREMENT more per pass while the quota is unmet. Each\n",
    "    subcategory is scanned once and each article examined once, so the result has no\n",
    "    duplicates. Kept articles are longer than ARTICLE_SIZE_THRESHOLD and their first\n",
    "    revision is on or after 2024-01-01.\n",
    "\n",
    "    Returns:\n",
    "        List of dicts with keys 'title', 'length', 'url', 'pageid', 'creation_date', 'text'.\n",
    "    \"\"\"\n",
    "    logging.info(f'Category: {category_name}')\n",
    "    articles_collected = []\n",
    "    seen_pageids = set()\n",
    "    all_subcategories = get_subcategories(category_name)\n",
    "    if not all_subcategories:\n",
    "        logging.warning(f\"No subcategories found for {category_name}\")\n",
    "        return articles_collected\n",
    "\n",
    "    processed = 0  # all_subcategories[:processed] have already been scanned\n",
    "    subcategory_limit = min(INITIAL_SUBCATEGORY_LIMIT, len(all_subcategories))\n",
    "    while len(articles_collected) < ARTICLE_LIMIT_PER_CATEGORY and processed < subcategory_limit:\n",
    "        # Scan only the newly added subcategories. Re-scanning earlier ones (as a plain\n",
    "        # all_subcategories[:subcategory_limit] slice would) re-collects articles already taken.\n",
    "        batch = all_subcategories[processed:subcategory_limit]\n",
    "        processed = subcategory_limit\n",
    "        subcategory_articles = _large_articles_by_subcategory(batch)\n",
    "        _collect_recent_round_robin(subcategory_articles, articles_collected, seen_pageids)\n",
    "        if len(articles_collected) >= ARTICLE_LIMIT_PER_CATEGORY or subcategory_limit >= len(all_subcategories):\n",
    "            break\n",
    "        # Increase subcategory limit for the next pass\n",
    "        subcategory_limit = min(subcategory_limit + SUBCATEGORY_INCREMENT, len(all_subcategories))\n",
    "        logging.info(f'Increasing subcategory limit to {subcategory_limit}')\n",
    "\n",
    "    # Output results\n",
    "    if articles_collected:\n",
    "        logging.info(f'{len(articles_collected)} articles found for category {category_name}')\n",
    "        print(f'Category: {category_name} ({len(articles_collected)} articles found)')\n",
    "        for article in articles_collected:\n",
    "            print(f\" - {article['title']}: (Size: {article['length']} bytes) | Created on: {article['creation_date']} | {article['url']}\")\n",
    "    else:\n",
    "        logging.info(f'No articles found for category {category_name}')\n",
    "    return articles_collected\n",
    "\n",
    "# Crawl every configured category, then cache the result on disk for the cells below.\n",
    "articles = {category: fetch_large_articles(category) for category in CATEGORIES}\n",
    "\n",
    "with open(\"wiki_articles.p\", \"wb\") as out_file:\n",
    "    pickle.dump(articles, out_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bffbaa11",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Article counts per category, then every field of the first meteorology article.\n",
    "print({category: len(found) for category, found in articles.items()})\n",
    "for field, value in articles['Category:2024_meteorology'][0].items():\n",
    "    print(\"---\", field, value, sep=\"\\n\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d6de1d90",
   "metadata": {},
   "source": [
    "## Generate facts and evaluation queries"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f512bd02",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reload the crawled articles from disk so this cell does not depend on the download cell's in-memory state.\n",
    "with open(\"wiki_articles.p\", \"rb\") as f:\n",
    "    articles = pickle.load(f)\n",
    "\n",
    "# Prompt for extracting 3-5 post-2023 facts from one article.\n",
    "def generate_facts_prompt(article_text):\n",
    "    \"\"\"Build the prompt that asks the model for 3-5 post-2023 facts from one article.\n",
    "\n",
    "    The model is asked for single-sentence facts, found in `article_text`, that could not\n",
    "    have been known before January 1 2024, formatted as JSON with a \"facts\" list.\n",
    "    \"\"\"\n",
    "    preamble = (\n",
    "        \"You will be given a Wikipedia article about either a weather event or sporting event that occurred in 2024. \"\n",
    "        \"I met a person who claims to be a time traveler from 2023 with perfect meteorological and sports knowledge. \"\n",
    "        \"I want to test their claim by asking them about facts in this article that were not possible to know prior \"\n",
    "        \"to January 1 2024.\\n\\n\"\n",
    "    )\n",
    "    task = (\n",
    "        \"Return to me a list of (3-5) facts from this article, each expressed as a single sentence, that \"\n",
    "        \"were not possible to know prior to January 1 2024. The facts should be specific, verifiable, \"\n",
    "        \"and phrased in a self-contained manner. The facts must be found within this article.\\n\\n\"\n",
    "    )\n",
    "    # Literal JSON skeleton showing the model the required output shape.\n",
    "    output_format = (\n",
    "        \"Format your response in JSON, saying nothing else.\\n\"\n",
    "        \"```\\n\"\n",
    "        \"{\\n    \\\"facts\\\": [\\n        \\\"...\\\",\\n        \\\"...\\\"\\n    ]\\n}\\n\"\n",
    "        \"```\"\n",
    "    )\n",
    "    article_section = f\"Here is the article:\\n{article_text}\\n\\n\"\n",
    "    return \"\".join([preamble, article_section, task, output_format])\n",
    "\n",
    "# Grab all articles in both categories\n",
    "all_articles = articles[\"Category:2024_meteorology\"] + articles[\"Category:2024 in sports\"]\n",
    "\n",
    "# One datapoint per article: the title plus the text gives the model context.\n",
    "datas = [article[\"title\"] + \"\\n\" + article[\"text\"] for article in all_articles]\n",
    "\n",
    "# Extract facts for every article; results[i] corresponds to datas[i].\n",
    "results, _ = await easyinference.inference(\n",
    "    prompt_functions=[generate_facts_prompt],\n",
    "    datapoints=datas,\n",
    "    tags=[version, experiment_name, \"gen_wiki_facts\"],\n",
    "    run_fast=True,\n",
    "    allow_failure=True,\n",
    "    attempts_cap=3,\n",
    "    temperature=TEMPERATURE,\n",
    "    max_output_tokens=8192,\n",
    "    model=DEFAULT_MODEL,\n",
    "    batch_size=1000,\n",
    "    run_fast_timeout=300,\n",
    "    cooldown_seconds=10,\n",
    "    batch_timeout_hours=4,\n",
    "    round_robin_enabled=True,\n",
    "    round_robin_options=[\"us-central1\", \"us-west1\", \"us-east1\", \"us-west4\", \"us-east4\", \"us-east5\", \"us-south1\"]\n",
    ")\n",
    "\n",
    "# Parse the JSON from each response. Anything other than a non-empty list of strings is\n",
    "# treated as a parse failure (instead of asserting mid-loop after paying for inference),\n",
    "# because the facts are later used as dictionary keys and must be hashable strings.\n",
    "articles_with_facts = copy.deepcopy(all_articles)\n",
    "for i, article in enumerate(articles_with_facts):\n",
    "    raw_json = parse_json(results[i][0][0])\n",
    "    facts = raw_json.get(\"facts\") if isinstance(raw_json, dict) else None\n",
    "    if isinstance(facts, list) and facts and all(isinstance(fact, str) for fact in facts):\n",
    "        article[\"facts\"] = facts\n",
    "    else:\n",
    "        article[\"facts\"] = []\n",
    "        print(\"ERROR PARSING FACTS for article:\", article[\"title\"])\n",
    "\n",
    "# Save this updated data\n",
    "with open(\"wiki_articles_with_facts.p\", \"wb\") as f:\n",
    "    pickle.dump(articles_with_facts, f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "624c9aa8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# `articles_with_facts` is a list of dictionaries, one per Wikipedia article, with keys:\n",
    "#\n",
    "# * `title`: the article title (string).\n",
    "# * `length`: the MediaWiki page length in bytes (integer); only pages above ARTICLE_SIZE_THRESHOLD were kept.\n",
    "# * `url`: the full URL of the article (string).\n",
    "# * `pageid`: the Wikipedia page ID (integer).\n",
    "# * `creation_date`: timestamp of the article's first revision (string, format \"%Y-%m-%dT%H:%M:%SZ\").\n",
    "# * `text`: the cleaned plain-text extract (string). If the download failed this holds an\n",
    "#   error message such as \"Article not found or does not contain text.\" instead.\n",
    "# * `facts`: facts extracted by the LLM (list of strings); empty if extraction or parsing failed.\n",
    "#\n",
    "# Below: the number of articles, then every field of the first one.\n",
    "print(len(articles_with_facts))\n",
    "for k, v in articles_with_facts[0].items():\n",
    "    print(\"---\")\n",
    "    print(k)\n",
    "    print(v)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "376f6e94",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reload the fact-annotated articles so this cell does not depend on in-memory state from the previous cell.\n",
    "with open(\"wiki_articles_with_facts.p\", \"rb\") as f:\n",
    "    articles_with_facts = pickle.load(f)\n",
    "\n",
    "def generate_eval_query_prompt(fact):\n",
    "    \"\"\"Build a prompt asking the model to write one user query that probes `fact`.\n",
    "\n",
    "    The model is told to reply with JSON only, of the form {\"query\": \"...\"}.\n",
    "    \"\"\"\n",
    "    header = \"I am trying to evaluate a large language model's ability to answer questions about this fact:\\n\\n\"\n",
    "    request = (\n",
    "        \"Help me write a single user query that tests whether the model truly knows this fact. \"\n",
    "        \"Format as JSON only:\\n\"\n",
    "    )\n",
    "    json_template = \"```\\n{\\n   \\\"query\\\": \\\"...\\\"\\n}\\n```\"\n",
    "    return f\"{header}{fact}\\n\\n{request}{json_template}\"\n",
    "\n",
    "# Flatten to one datapoint per (article, fact) pair; both lists stay index-aligned.\n",
    "article_fact_map_eval = [\n",
    "    (article, fact)\n",
    "    for article in articles_with_facts\n",
    "    if \"facts\" in article\n",
    "    for fact in article[\"facts\"]\n",
    "]\n",
    "datas_eval = [fact for _, fact in article_fact_map_eval]\n",
    "\n",
    "def eval_prompt_fn(fact):\n",
    "    \"\"\"Prompt function passed to easyinference: maps one fact to its query-writing prompt.\"\"\"\n",
    "    prompt_text = generate_eval_query_prompt(fact)\n",
    "    return prompt_text\n",
    "\n",
    "results, _ = await easyinference.inference(\n",
    "    prompt_functions=[eval_prompt_fn],\n",
    "    datapoints=datas_eval,\n",
    "    tags=[version, experiment_name, \"gen_eval_queries\"],\n",
    "    run_fast=True,\n",
    "    allow_failure=True,\n",
    "    attempts_cap=3,\n",
    "    temperature=TEMPERATURE,\n",
    "    max_output_tokens=8192,\n",
    "    model=DEFAULT_MODEL,\n",
    "    batch_size=1000,\n",
    "    run_fast_timeout=300,\n",
    "    cooldown_seconds=10,\n",
    "    batch_timeout_hours=4,\n",
    "    round_robin_enabled=True,\n",
    "    round_robin_options=[\"us-central1\", \"us-west1\", \"us-east1\", \"us-west4\", \"us-east4\", \"us-east5\", \"us-south1\"]\n",
    ")\n",
    "\n",
    "# Give every article a query map, even one with no facts, so later cells can read\n",
    "# article[\"fact_eval_queries\"] without a KeyError.\n",
    "for article in articles_with_facts:\n",
    "    article.setdefault(\"fact_eval_queries\", {})\n",
    "\n",
    "# Parse and store evaluation queries; \"ERROR\" marks a fact whose query could not be parsed.\n",
    "for i, (article, fact) in enumerate(article_fact_map_eval):\n",
    "    resp = parse_json(results[i][0][0])\n",
    "    if isinstance(resp, dict) and \"query\" in resp:\n",
    "        article[\"fact_eval_queries\"][fact] = resp[\"query\"]\n",
    "    else:\n",
    "        article[\"fact_eval_queries\"][fact] = \"ERROR\"\n",
    "\n",
    "with open(\"wiki_articles_with_eval.p\", \"wb\") as f:\n",
    "    pickle.dump(articles_with_facts, f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e2b95411",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of articles, then every field of the first one.\n",
    "print(len(articles_with_facts))\n",
    "for field, value in articles_with_facts[0].items():\n",
    "    print(\"---\", field, value, sep=\"\\n\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ec98e890",
   "metadata": {},
   "source": [
    "## Train"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0251b93e",
   "metadata": {},
   "outputs": [],
   "source": [
    "random.seed(0)  # makes the shuffle below reproducible\n",
    "\n",
    "with open(\"wiki_articles_with_facts.p\", \"rb\") as f:\n",
    "    articles_with_facts = pickle.load(f)\n",
    "\n",
    "def wrap(article):\n",
    "    \"\"\"Turn one article into a two-turn chat: the user asks about the title, the model replies with the full text.\"\"\"\n",
    "    return [\n",
    "        {\"role\": \"user\", \"text\": \"Tell me about \" + article[\"title\"]},\n",
    "        {\"role\": \"model\", \"text\": article[\"text\"]},\n",
    "    ]\n",
    "\n",
    "dataset = [wrap(article) for article in articles_with_facts]\n",
    "random.shuffle(dataset)\n",
    "# Show the size and a single example instead of echoing every full article into the notebook.\n",
    "print(len(dataset))\n",
    "dataset[:1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9370cf5f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run finetuning jobs.\n",
    "\n",
    "fname = \"wiki_exp/wiki_experiment_data.p\"  # per-seed job bookkeeping stored in the bucket\n",
    "\n",
    "def save_experiment_state(state):\n",
    "    \"\"\"Pickle `state` and upload it to `fname` in the bucket.\"\"\"\n",
    "    with tempfile.NamedTemporaryFile(mode=\"wb\") as f:\n",
    "        pickle.dump(state, f)\n",
    "        f.flush()\n",
    "        bucket.blob(fname).upload_from_filename(f.name)\n",
    "\n",
    "blob = bucket.blob(fname)\n",
    "if blob.exists():\n",
    "    print(\"Loading existing data\")\n",
    "    with tempfile.NamedTemporaryFile(mode=\"rb\") as f:\n",
    "        blob.download_to_filename(f.name)\n",
    "        wiki_experiment_data = pickle.load(f)\n",
    "else:\n",
    "    wiki_experiment_data = []\n",
    "# One slot per seed. Pad rather than assume a fixed count, so raising FINETUNING_SEEDS\n",
    "# never indexes past the end of the list.\n",
    "while len(wiki_experiment_data) < FINETUNING_SEEDS:\n",
    "    wiki_experiment_data.append({\"resource_name\": \"nan\", \"blob_name\": None})\n",
    "\n",
    "vertexai.init(project=config.GCP_PROJECT_ID, location=\"us-central1\")\n",
    "for i in range(FINETUNING_SEEDS):\n",
    "    task = wiki_experiment_data[i]\n",
    "    if task[\"resource_name\"] != \"nan\":\n",
    "        print(f\"Refusing to restart {i}\", task[\"resource_name\"], task[\"blob_name\"])\n",
    "        continue\n",
    "    blob = bucket.blob(f\"wiki_exp/finetuning_data_wiki_{i}.jsonl\")\n",
    "    print(\"Uploading training data to\", blob.name)\n",
    "    with tempfile.NamedTemporaryFile(mode=\"w\") as f:\n",
    "        for row in dataset:\n",
    "            row = copy.deepcopy(row)\n",
    "            # Convert each {\"role\", \"text\"} turn into the {\"role\", \"parts\": [{\"text\"}]} shape written to the JSONL.\n",
    "            for t in row:\n",
    "                assert t[\"role\"] in [\"model\", \"user\"]\n",
    "                t[\"parts\"] = [{\"text\": t.pop(\"text\")}]\n",
    "            json.dump({\"contents\": row}, f)\n",
    "            f.write(\"\\n\")\n",
    "        f.flush()\n",
    "        blob.upload_from_filename(f.name)\n",
    "\n",
    "    # Run finetuning job.\n",
    "    train_uri = f\"gs://{blob.bucket.name}/{blob.name}\"\n",
    "    print(train_uri)\n",
    "    sft_tuning_job = sft.train(\n",
    "        source_model=FINETUNING_MODEL,\n",
    "        train_dataset=train_uri,\n",
    "        learning_rate_multiplier=1,\n",
    "        epochs=NUM_EPOCHS,\n",
    "    )\n",
    "    print(sft_tuning_job.resource_name)\n",
    "\n",
    "    # Record the job and persist immediately, so a failure on a later seed cannot lose it\n",
    "    # (and a re-run of this cell will refuse to launch a duplicate).\n",
    "    task[\"blob_name\"] = blob.name\n",
    "    task[\"resource_name\"] = str(sft_tuning_job.resource_name)\n",
    "    save_experiment_state(wiki_experiment_data)\n",
    "\n",
    "save_experiment_state(wiki_experiment_data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "3437bc7e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# # Optionally kill all finetuning jobs.\n",
    "# NOTE(review): `llm_utils` is not imported in the imports cell, so uncommenting this as-is\n",
    "# raises NameError; loop over the region list passed as `round_robin_options` instead.\n",
    "\n",
    "# for region in llm_utils.gemini_utils.ROUNDROBIN_OPTIONS:\n",
    "#     vertexai.init(project=config.GCP_PROJECT_ID, location=region)\n",
    "#     for bpj in sft.SupervisedTuningJob.list(filter='state=\"JOB_STATE_QUEUED\"'):\n",
    "#         print(bpj.name)\n",
    "#         bpj.cancel()\n",
    "#     for bpj in sft.SupervisedTuningJob.list(filter='state=\"JOB_STATE_RUNNING\"'):\n",
    "#         print(bpj.name)\n",
    "#         bpj.cancel()\n",
    "#     for bpj in sft.SupervisedTuningJob.list(filter='state=\"JOB_STATE_PENDING\"'):\n",
    "#         print(bpj.name)\n",
    "#         bpj.cancel()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5a49ff13",
   "metadata": {},
   "source": [
    "## Evaluate"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2ec7b678",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run QA evaluation.\n",
    "\n",
    "fname = \"wiki_exp/wiki_experiment_data.p\"\n",
    "blob = bucket.blob(fname)\n",
    "assert blob.exists()\n",
    "print(\"Loading existing data\")\n",
    "with tempfile.NamedTemporaryFile(mode=\"rb\") as f:\n",
    "    blob.download_to_filename(f.name)\n",
    "    wiki_experiment_data = pickle.load(f)\n",
    "\n",
    "with open(\"wiki_articles_with_eval.p\", \"rb\") as f:\n",
    "    articles_with_eval = pickle.load(f)\n",
    "    print(articles_with_eval[0])\n",
    "\n",
    "# (query, fact) pairs to ask every model. Articles with no facts may lack a query map, and\n",
    "# \"ERROR\" marks a query that could not be generated; neither yields a usable question.\n",
    "eval_datas = [\n",
    "    (query, fact)\n",
    "    for article in articles_with_eval\n",
    "    for fact, query in article.get(\"fact_eval_queries\", {}).items()\n",
    "    if query != \"ERROR\"\n",
    "]\n",
    "\n",
    "# Initialise the SDK here as well so tuning jobs can be looked up on a fresh kernel.\n",
    "vertexai.init(project=config.GCP_PROJECT_ID, location=\"us-central1\")\n",
    "\n",
    "all_results = []\n",
    "all_baseline_results = []\n",
    "for i, task in enumerate(wiki_experiment_data[:FINETUNING_SEEDS]):\n",
    "    if task[\"resource_name\"] == \"nan\":\n",
    "        print(f\"Skipping {i}\", task[\"resource_name\"], task[\"blob_name\"])\n",
    "        continue\n",
    "    print(f\"Running {i}\", task[\"resource_name\"], task[\"blob_name\"])\n",
    "    sft_tuning_job = sft.SupervisedTuningJob(task[\"resource_name\"])\n",
    "    # Poll until the job finishes; asyncio.sleep avoids blocking the notebook's event loop.\n",
    "    while not sft_tuning_job.has_ended:\n",
    "        await asyncio.sleep(60)\n",
    "        sft_tuning_job.refresh()\n",
    "    if not sft_tuning_job.tuned_model_endpoint_name:\n",
    "        # NOTE(review): presumably no endpoint means the job failed or was cancelled — confirm.\n",
    "        print(f\"Skipping {i}: no tuned model endpoint\", task[\"resource_name\"])\n",
    "        continue\n",
    "\n",
    "    results, _ = await easyinference.inference(\n",
    "        prompt_functions=[lambda x: x[0]],\n",
    "        datapoints=eval_datas,\n",
    "        tags=[version, experiment_name, f\"run_wiki_eval_{i}\"],\n",
    "        run_fast=True,\n",
    "        allow_failure=True,\n",
    "        attempts_cap=3,\n",
    "        temperature=TEMPERATURE,\n",
    "        max_output_tokens=8192,\n",
    "        model=sft_tuning_job.tuned_model_endpoint_name,\n",
    "        batch_size=1000,\n",
    "        run_fast_timeout=300,\n",
    "        cooldown_seconds=10,\n",
    "        batch_timeout_hours=4,\n",
    "        round_robin_enabled=True,\n",
    "        round_robin_options=[\"us-central1\"]\n",
    "    )\n",
    "    all_results.append(results)\n",
    "\n",
    "    results, _ = await easyinference.inference(\n",
    "        prompt_functions=[lambda x: x[0]],\n",
    "        datapoints=eval_datas,\n",
    "        tags=[version, experiment_name, f\"run_wiki_eval_{i}_baseline\"],\n",
    "        run_fast=True,\n",
    "        allow_failure=True,\n",
    "        attempts_cap=3,\n",
    "        temperature=TEMPERATURE,\n",
    "        max_output_tokens=8192,\n",
    "        model=DEFAULT_MODEL,\n",
    "        batch_size=1000,\n",
    "        run_fast_timeout=300,\n",
    "        cooldown_seconds=10,\n",
    "        batch_timeout_hours=4,\n",
    "        round_robin_enabled=True,\n",
    "        round_robin_options=[\"us-central1\"]\n",
    "    )\n",
    "    all_baseline_results.append(results)\n",
    "\n",
    "def successful_responses(runs):\n",
    "    \"\"\"Flatten per-run results into (response, query, fact) tuples, dropping failed calls.\n",
    "\n",
    "    Iterates over the runs actually collected (seeds may have been skipped above) and\n",
    "    relies on each run's results being index-aligned with `eval_datas`.\n",
    "    \"\"\"\n",
    "    return [\n",
    "        (run[j][0][0], *eval_datas[j])\n",
    "        for run in runs\n",
    "        for j in range(len(eval_datas))\n",
    "        if not any(\"ERROR\" in r for r in run[j][0])\n",
    "    ]\n",
    "\n",
    "# Process evaluation results\n",
    "ev = successful_responses(all_results)\n",
    "baseline_ev = successful_responses(all_baseline_results)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4f803bb8",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(eval_datas[0])\n",
    "print(ev[0])\n",
    "print(eval_datas[0])\n",
    "print(baseline_ev[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "30f920b8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Score responses\n",
    "\n",
    "prompt = \"\"\"I asked a student this question:\n",
    "{}\n",
    "\n",
    "They replied with:\n",
    "{}\n",
    "\n",
    "My solution key says:\n",
    "{}\n",
    "\n",
    "Did the solution answer correctly? Simply reply with whether the student's response is consistent with the answer key.\n",
    "\"\"\"\n",
    "followup_prompt = \"\"\"Please format your response in JSON, saying nothing else:\n",
    "```\n",
    "{\"is_correct\": True/False}\n",
    "```\n",
    "\"\"\"\n",
    "\n",
    "results, _ = await easyinference.inference(\n",
    "    prompt_functions=[\n",
    "        lambda x: prompt.format(x[1], x[0], x[2]),\n",
    "        lambda x: followup_prompt\n",
    "    ],\n",
    "    datapoints=ev,\n",
    "    tags=[version, experiment_name, \"score_wiki_eval\"],\n",
    "    run_fast=True,\n",
    "    allow_failure=True,\n",
    "    attempts_cap=3,\n",
    "    temperature=TEMPERATURE,\n",
    "    max_output_tokens=8192,\n",
    "    model=DEFAULT_MODEL,\n",
    "    batch_size=1000,\n",
    "    run_fast_timeout=300,\n",
    "    cooldown_seconds=10,\n",
    "    batch_timeout_hours=4,\n",
    "    round_robin_enabled=True,\n",
    "    round_robin_options=[\"us-central1\", \"us-west1\", \"us-east1\", \"us-west4\", \"us-east4\", \"us-east5\", \"us-south1\"]\n",
    ")\n",
    "baseline_results, _ = await easyinference.inference(\n",
    "    prompt_functions=[\n",
    "        lambda x: prompt.format(x[1], x[0], x[2]),\n",
    "        lambda x: followup_prompt\n",
    "    ],\n",
    "    datapoints=baseline_ev,\n",
    "    tags=[version, experiment_name, \"score_wiki_eval_baseline\"],\n",
    "    run_fast=True,\n",
    "    allow_failure=True,\n",
    "    attempts_cap=3,\n",
    "    temperature=TEMPERATURE,\n",
    "    max_output_tokens=8192,\n",
    "    model=DEFAULT_MODEL,\n",
    "    batch_size=1000,\n",
    "    run_fast_timeout=300,\n",
    "    cooldown_seconds=10,\n",
    "    batch_timeout_hours=4,\n",
    "    round_robin_enabled=True,\n",
    "    round_robin_options=[\"us-central1\", \"us-west1\", \"us-east1\", \"us-west4\", \"us-east4\", \"us-east5\", \"us-south1\"]\n",
    ")\n",
    "\n",
    "\n",
    "# Calculate accuracy\n",
    "scores = []\n",
    "for i in range(len(ev)):\n",
    "    if any([\"ERROR\" in r for r in results[i][0]]):\n",
    "        continue\n",
    "    try:\n",
    "        resp = parse_json(results[i][0][1])\n",
    "        if isinstance(resp, dict) and \"is_correct\" in resp:\n",
    "            scores.append(resp[\"is_correct\"])\n",
    "    except Exception as e:\n",
    "        print(f\"Error parsing score response: {e}\")\n",
    "\n",
    "print(f\"Accuracy: {np.mean(scores):.2%}\")\n",
    "\n",
    "baseline_scores = []\n",
    "for i in range(len(baseline_ev)):\n",
    "    if any([\"ERROR\" in r for r in baseline_results[i][0]]):\n",
    "        continue\n",
    "    try:\n",
    "        resp = parse_json(baseline_results[i][0][1])\n",
    "        if isinstance(resp, dict) and \"is_correct\" in resp:\n",
    "            baseline_scores.append(resp[\"is_correct\"])\n",
    "    except Exception as e:\n",
    "        print(f\"Error parsing score response: {e}\")\n",
    "\n",
    "print(f\"Baseline accuracy: {np.mean(baseline_scores):.2%}\")\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
