{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "1622e7a3",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/apple/envs/myenv-llm-bench/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "import asyncio\n",
    "import logging\n",
    "import sys\n",
    "from concurrent.futures import ThreadPoolExecutor\n",
    "from functools import partial\n",
    "import time\n",
    "import os\n",
    "from datetime import datetime, timedelta\n",
    "import json\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import re\n",
    "import matplotlib.pyplot as plt\n",
    "import google.generativeai as genai\n",
    "\n",
    "# Append the directory two levels above the current working directory to sys.path.\n",
    "# NOTE(review): presumably this is the repository root that holds the `helpers`\n",
    "# package imported below - confirm, since it depends on where the kernel was started.\n",
    "current_path = os.getcwd()\n",
    "\n",
    "sys.path.append(os.path.join(current_path, \"../..\"))\n",
    "from helpers.constants import (\n",
    "    MODEL_TOKEN_LIMITS,\n",
    "    MODEL_NAME_TO_SOURCE,\n",
    ")\n",
    "from helpers.llm_prompts import SCRATCH_PAD_MARKET_PROMPT, REFORMAT_SINGLE_PROMPT\n",
    "\n",
    "from helpers import model_eval\n",
    "\n",
    "\n",
    "# Use the root logger and let INFO-level progress messages through.\n",
    "logger = logging.getLogger()\n",
    "logger.setLevel(logging.INFO)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "6b08820f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Registry of the models to evaluate, keyed by the short name used for result files and\n",
    "# table rows. For each entry:\n",
    "#   \"source\"    - provider tag for the model (not read anywhere in this notebook).\n",
    "#   \"full_name\" - model identifier passed as `model_name` to\n",
    "#                  model_eval.get_response_from_model by `worker`.\n",
    "models = {\n",
    "    \"gpt_4o\": {\"source\": \"OAI\", \"full_name\": \"gpt-4o\"},\n",
    "    \"gpt_4_turbo_0409\": {\"source\": \"OAI\", \"full_name\": \"gpt-4-turbo-2024-04-09\"},\n",
    "    \"llama_3_70b\": {\n",
    "        \"source\": \"TOGETHER\",\n",
    "        \"full_name\": \"meta-llama/Llama-3-70b-chat-hf\",\n",
    "    },\n",
    "    \"mistral_large\": {\n",
    "        \"source\": \"MISTRAL\",\n",
    "        \"full_name\": \"mistral-large-latest\",\n",
    "    },\n",
    "    \"qwen_1p5_110b\": {\n",
    "        \"source\": \"TOGETHER\",\n",
    "        \"full_name\": \"Qwen/Qwen1.5-110B-Chat\",\n",
    "    },\n",
    "    # \"gemini_pro\": {\"source\": \"GOOGLE\", \"full_name\": \"gemini-pro\"},\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c27a0f3c",
   "metadata": {},
   "source": [
    "### Load data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "6bb57859",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "914"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from datasets import load_dataset, Dataset\n",
    "\n",
    "# Dataset identifier handed to `load_dataset`; change it here only.\n",
    "dataset_name = \"YuehHanChen/forecasting\"\n",
    "\n",
    "dataset = load_dataset(path=dataset_name)\n",
    "\n",
    "# Materialize the whole \"test\" split as a list so questions can be addressed by\n",
    "# integer position (the forecasting and scoring cells index `mini_val[i]`).\n",
    "mini_val = list(dataset[\"test\"])\n",
    "len(mini_val)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "39c35588",
   "metadata": {},
   "source": [
    "### General Functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "0b5f1a2f",
   "metadata": {},
   "outputs": [],
   "source": [
    "def read_jsonl(file_path):\n",
    "    \"\"\"Parse a JSON Lines file into a list of decoded objects.\n",
    "\n",
    "    The file is read as UTF-8. Lines that are empty or whitespace-only are\n",
    "    skipped; every other line is decoded with `json.loads`, in file order.\n",
    "    \"\"\"\n",
    "    with open(file_path, \"r\", encoding=\"utf-8\") as handle:\n",
    "        return [json.loads(raw_line) for raw_line in handle if raw_line.strip()]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "b6f61de7",
   "metadata": {},
   "outputs": [],
   "source": [
    "def worker(index, prompt, model_name, save_dict):\n",
    "    \"\"\"Forecast one question with one model and store the parsed answer.\n",
    "\n",
    "    Args:\n",
    "        index: Position of the question in `mini_val`; also the key in `save_dict`.\n",
    "        prompt: Prompt template. It is filled via `str.format` with the fields\n",
    "            `question`, `background`, `resolution_criteria` and `close_date`.\n",
    "        model_name: Key into `models`; the entry's \"full_name\" is sent to the API.\n",
    "        save_dict: Mutable mapping updated in place. An entry other than \"\" is\n",
    "            treated as already answered and left untouched.\n",
    "\n",
    "    Returns:\n",
    "        None. The result is written to `save_dict[index]`.\n",
    "    \"\"\"\n",
    "    if save_dict[index] != \"\":\n",
    "        return\n",
    "\n",
    "    logger.info(f\"Starting question: {index}\")\n",
    "    item = mini_val[index]\n",
    "    # Fill the template supplied by the caller. Previously the argument was\n",
    "    # discarded and SCRATCH_PAD_MARKET_PROMPT was always used, so any other\n",
    "    # entry of the prompt list would have been silently ignored.\n",
    "    filled_prompt = prompt.format(\n",
    "        question=item[\"question\"],\n",
    "        background=item[\"background\"],\n",
    "        resolution_criteria=item[\"resolution_criteria\"],\n",
    "        close_date=item[\"date_resolve_at\"],\n",
    "    )\n",
    "\n",
    "    response = model_eval.get_response_from_model(\n",
    "        prompt=filled_prompt,\n",
    "        max_tokens=1300,\n",
    "        model_name=models[model_name][\"full_name\"],\n",
    "        temperature=0,\n",
    "        wait_time=30,\n",
    "    )\n",
    "\n",
    "    # The evaluation cell treats a None forecast as a refusal to answer.\n",
    "    save_dict[index] = model_eval.reformat_answers(response=response, single=True)\n",
    "\n",
    "    logger.info(f\"finished question: {index}, forecast: {save_dict[index]}\")\n",
    "\n",
    "\n",
    "def executor(max_workers, prompt, model_name, save_dict):\n",
    "    \"\"\"Run `worker` for every question index on a thread pool.\n",
    "\n",
    "    One call is made per position in the global `questions_list`; `prompt`,\n",
    "    `model_name` and `save_dict` are forwarded unchanged to each call.\n",
    "    Returns the list of the workers' return values, in index order.\n",
    "    \"\"\"\n",
    "    run_one = partial(worker, prompt=prompt, model_name=model_name, save_dict=save_dict)\n",
    "    # `pool` rather than `executor`, so the local does not shadow this function's name.\n",
    "    with ThreadPoolExecutor(max_workers=max_workers) as pool:\n",
    "        question_indices = range(len(questions_list))\n",
    "        return list(pool.map(run_one, question_indices))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "fb80b036",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:root:Running gpt_4o\n",
      "INFO:root:Running gpt_4_turbo_0409\n",
      "INFO:root:Running llama_3_70b\n",
      "INFO:root:Running mistral_large\n",
      "INFO:root:Running qwen_1p5_110b\n"
     ]
    }
   ],
   "source": [
    "base = \"llm_crowd_candidate_eval/\"\n",
    "all_prompts = [SCRATCH_PAD_MARKET_PROMPT]\n",
    "# Number of worker threads used per model when forecasts have to be generated.\n",
    "MAX_WORKERS = 30\n",
    "\n",
    "questions_list = [d[\"question\"] for d in mini_val]\n",
    "\n",
    "# results[model][prompt_index]             -> {question index: forecast}\n",
    "# model_result_loaded[model][\"prompt_<i>\"] -> True when that mapping came from disk.\n",
    "# Cached and freshly generated results share this one shape, so the evaluation\n",
    "# cell's `results[model][0]` works in both cases.\n",
    "results = {model: [] for model in models}\n",
    "model_result_loaded = {model: {} for model in models}\n",
    "\n",
    "# 1) Load whatever is cached; prepare blank answer sheets for the rest.\n",
    "for prompt_index in range(len(all_prompts)):\n",
    "    for model in models:\n",
    "        file_path = os.path.join(base, str(prompt_index), f\"{model}.jsonl\")\n",
    "        try:\n",
    "            # The cache file holds a single JSON object (written with json.dump\n",
    "            # below), so the first record is the whole forecast mapping.\n",
    "            forecasts = read_jsonl(file_path)[0]\n",
    "            loaded = True\n",
    "        except (OSError, ValueError, IndexError):\n",
    "            # Missing/unreadable file, invalid JSON (JSONDecodeError is a\n",
    "            # ValueError) or an empty file: start from blank answers instead.\n",
    "            forecasts = {i: \"\" for i in range(len(questions_list))}\n",
    "            loaded = False\n",
    "        results[model].append(forecasts)\n",
    "        model_result_loaded[model][f\"prompt_{prompt_index}\"] = loaded\n",
    "\n",
    "# 2) Query only the models without cached results and save each one as soon as\n",
    "#    it finishes, so completed work survives a failure in a later model.\n",
    "for prompt_index, prompt in enumerate(all_prompts):\n",
    "    for model in models:\n",
    "        if model_result_loaded[model][f\"prompt_{prompt_index}\"]:\n",
    "            logger.info(f\"Loaded cached results for {model}\")\n",
    "            continue\n",
    "\n",
    "        logger.info(f\"Running {model}\")\n",
    "        executor(MAX_WORKERS, prompt, model, results[model][prompt_index])\n",
    "\n",
    "        file_path = os.path.join(base, str(prompt_index), f\"{model}.jsonl\")\n",
    "        os.makedirs(os.path.dirname(file_path), exist_ok=True)\n",
    "        with open(file_path, \"w\", encoding=\"utf-8\") as f:\n",
    "            json.dump(results[model][prompt_index], f)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e1243dd3",
   "metadata": {},
   "source": [
    "### Evaluation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "aa7fb599",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _single_prompt_results(model_results):\n",
    "    \"\"\"Return the {question: forecast} mapping for the first prompt.\n",
    "\n",
    "    Accepts the three shapes a `results[model]` entry can have: a list of\n",
    "    per-prompt mappings (loaded from disk), a dict keyed by \"prompt_0\"\n",
    "    (freshly generated), or an already flattened mapping (cell re-run).\n",
    "    \"\"\"\n",
    "    if isinstance(model_results, list):\n",
    "        return model_results[0]\n",
    "    if \"prompt_0\" in model_results:\n",
    "        return model_results[\"prompt_0\"]\n",
    "    return model_results\n",
    "\n",
    "\n",
    "# Replace refusals (None) by the uninformative forecast 0.5, record how many\n",
    "# there were, and flatten results[model] to a single {question: forecast} dict.\n",
    "for model in results.keys():\n",
    "    forecasts = _single_prompt_results(results[model])\n",
    "    # Start from an existing count so re-running this cell keeps the tally\n",
    "    # (after the first pass there are no None values left to count).\n",
    "    refuse_to_answer_cnt = forecasts.get(\"refuse_to_answer_cnt\", 0)\n",
    "    for key, answer in forecasts.items():\n",
    "        if answer is None:\n",
    "            forecasts[key] = 0.5\n",
    "            refuse_to_answer_cnt += 1\n",
    "\n",
    "    forecasts[\"refuse_to_answer_cnt\"] = refuse_to_answer_cnt\n",
    "    results[model] = forecasts\n",
    "\n",
    "\n",
    "def brier_score(prediction, answer):\n",
    "    \"\"\"Return the squared difference between a forecast and the value it is scored against.\"\"\"\n",
    "    error = prediction - answer\n",
    "    return error ** 2\n",
    "\n",
    "# Models whose forecasts are pooled into the ensemble aggregates.\n",
    "winner_models = ['gpt_4o', 'mistral_large', 'qwen_1p5_110b']\n",
    "\n",
    "# For every aggregation method, gather the winner models' values per key, in\n",
    "# winner_models order. The \"refuse_to_answer_cnt\" entry is gathered like any\n",
    "# other key, so each aggregate ends up with the list of per-model counts.\n",
    "agg_results = {}\n",
    "for agg_type in [\"mean\", \"median\", \"trimmed_mean\", \"geometric_mean\", \"geometric_mean_log_odds\"]:\n",
    "    pooled = {}\n",
    "    agg_results[agg_type] = pooled\n",
    "    for model in winner_models:\n",
    "        for key, answer in results[model].items():\n",
    "            pooled.setdefault(key, []).append(answer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "d44ce7f0",
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "\n",
    "def trimmed_mean(probabilities):\n",
    "    \"\"\"Mean of the values after dropping the single smallest and single largest.\n",
    "\n",
    "    With fewer than three values nothing can be trimmed without emptying the\n",
    "    sample, so the plain mean of all values is returned instead (previously\n",
    "    this case raised ZeroDivisionError).\n",
    "\n",
    "    Args:\n",
    "        probabilities: Iterable of numbers.\n",
    "\n",
    "    Returns:\n",
    "        The trimmed mean as a number.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If `probabilities` is empty.\n",
    "    \"\"\"\n",
    "    values = sorted(probabilities)\n",
    "    if not values:\n",
    "        raise ValueError(\"trimmed_mean() requires at least one probability\")\n",
    "\n",
    "    # Only trim when at least one value would remain afterwards.\n",
    "    if len(values) >= 3:\n",
    "        values = values[1:-1]\n",
    "\n",
    "    return sum(values) / len(values)\n",
    "\n",
    "\n",
    "def geometric_mean(numbers):\n",
    "    \"\"\"Return the n-th root of the product of the n given numbers.\n",
    "\n",
    "    An empty input yields 0 rather than raising.\n",
    "    \"\"\"\n",
    "    if not numbers:\n",
    "        return 0  # nothing to average; also avoids dividing by a zero length\n",
    "    # A float start value keeps the running product a float, as a manual loop seeded with 1.0 would.\n",
    "    total_product = math.prod(numbers, start=1.0)\n",
    "    root_exponent = 1 / len(numbers)\n",
    "    return total_product ** root_exponent\n",
    "\n",
    "\n",
    "def geometric_mean_log_odds(probs, eps=1e-9):\n",
    "    \"\"\"Pool probabilities by averaging them in log-odds space.\n",
    "\n",
    "    This is the geometric mean of the odds p / (1 - p), mapped back to a\n",
    "    probability.\n",
    "\n",
    "    Args:\n",
    "        probs: Sequence of probabilities in [0, 1].\n",
    "        eps: Forecasts are clipped to [eps, 1 - eps] before taking the logit.\n",
    "            Without this, a forecast of exactly 0 or 1 has infinite log odds\n",
    "            and the pooled result becomes NaN.\n",
    "\n",
    "    Returns:\n",
    "        The pooled probability as a NumPy float.\n",
    "    \"\"\"\n",
    "    clipped = np.clip(np.asarray(probs, dtype=float), eps, 1.0 - eps)\n",
    "\n",
    "    # Arithmetic mean of the log odds == log of the geometric mean of the odds.\n",
    "    log_odds = np.log(clipped / (1.0 - clipped))\n",
    "    mean_log_odds = np.mean(log_odds)\n",
    "\n",
    "    # Logistic function written so that a large positive mean cannot overflow exp().\n",
    "    return 1.0 / (1.0 + np.exp(-mean_log_odds))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "3f1453ad",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Collapse each key's list of per-model forecasts into one value per method.\n",
    "# The \"refuse_to_answer_cnt\" entry is bookkeeping, not a forecast, so it is left as is.\n",
    "aggregators = {\n",
    "    \"mean\": np.mean,\n",
    "    \"median\": np.median,\n",
    "    \"trimmed_mean\": trimmed_mean,\n",
    "    \"geometric_mean\": geometric_mean,\n",
    "    \"geometric_mean_log_odds\": geometric_mean_log_odds,\n",
    "}\n",
    "\n",
    "for agg_type, aggregate in aggregators.items():\n",
    "    pooled = agg_results[agg_type]\n",
    "    for key in pooled:\n",
    "        if key == \"refuse_to_answer_cnt\":\n",
    "            continue\n",
    "        pooled[key] = aggregate(pooled[key])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "0dfd1e52",
   "metadata": {},
   "outputs": [],
   "source": [
    "brier_scores = pd.DataFrame()\n",
    "\n",
    "# Score the ensemble aggregates alongside the individual models.\n",
    "results.update(agg_results)\n",
    "\n",
    "for model, predictions in results.items():\n",
    "    # One Brier score per question; the refusal counter is not a forecast.\n",
    "    per_question_scores = [\n",
    "        brier_score(float(prediction), mini_val[int(question_id)][\"resolution\"])\n",
    "        for question_id, prediction in predictions.items()\n",
    "        if question_id != \"refuse_to_answer_cnt\"\n",
    "    ]\n",
    "\n",
    "    n_scored = len(per_question_scores)\n",
    "    avg_brier_score = sum(per_question_scores) / n_scored\n",
    "    # Standard error from np.std's default (population) standard deviation; reported doubled.\n",
    "    std_error_brier_score = np.std(per_question_scores) / np.sqrt(n_scored)\n",
    "    two_std_error_brier_score = 2 * std_error_brier_score\n",
    "\n",
    "    brier_scores.at[model, \"Scratchpad\"] = f\"{avg_brier_score:.3f} ({two_std_error_brier_score:.7f})\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "54f9d2fd",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "gpt_4o's refuse-to-answer count: 0\n",
      "gpt_4_turbo_0409's refuse-to-answer count: 0\n",
      "llama_3_70b's refuse-to-answer count: 1\n",
      "mistral_large's refuse-to-answer count: 1\n",
      "qwen_1p5_110b's refuse-to-answer count: 0\n",
      "mean's refuse-to-answer count: [0, 1, 0]\n",
      "median's refuse-to-answer count: [0, 1, 0]\n",
      "trimmed_mean's refuse-to-answer count: [0, 1, 0]\n",
      "geometric_mean's refuse-to-answer count: [0, 1, 0]\n",
      "geometric_mean_log_odds's refuse-to-answer count: [0, 1, 0]\n"
     ]
    }
   ],
   "source": [
    "# Report the stored refusal counter of every entry in `results`.\n",
    "for model, model_results in results.items():\n",
    "    print(f\"{model}'s refuse-to-answer count: {model_results['refuse_to_answer_cnt']}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "1952a911",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Scratchpad</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>gpt_4o</th>\n",
       "      <td>0.187 (0.0112725)</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>gpt_4_turbo_0409</th>\n",
       "      <td>0.205 (0.0114514)</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>llama_3_70b</th>\n",
       "      <td>0.219 (0.0092168)</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>mistral_large</th>\n",
       "      <td>0.225 (0.0112260)</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>qwen_1p5_110b</th>\n",
       "      <td>0.210 (0.0103445)</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>mean</th>\n",
       "      <td>0.200 (0.0095341)</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>median</th>\n",
       "      <td>0.201 (0.0102313)</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>trimmed_mean</th>\n",
       "      <td>0.201 (0.0102313)</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>geometric_mean</th>\n",
       "      <td>0.197 (0.0098307)</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>geometric_mean_log_odds</th>\n",
       "      <td>0.199 (0.0097631)</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                Scratchpad\n",
       "gpt_4o                   0.187 (0.0112725)\n",
       "gpt_4_turbo_0409         0.205 (0.0114514)\n",
       "llama_3_70b              0.219 (0.0092168)\n",
       "mistral_large            0.225 (0.0112260)\n",
       "qwen_1p5_110b            0.210 (0.0103445)\n",
       "mean                     0.200 (0.0095341)\n",
       "median                   0.201 (0.0102313)\n",
       "trimmed_mean             0.201 (0.0102313)\n",
       "geometric_mean           0.197 (0.0098307)\n",
       "geometric_mean_log_odds  0.199 (0.0097631)"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Mean Brier score per model / aggregate, with two standard errors in parentheses.\n",
    "brier_scores"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b947eed9",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.12 (myenv-llm-bench)",
   "language": "python",
   "name": "myenv-llm-bench"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
