{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "e5fa0cfa",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/apple/envs/myenv-llm-bench/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "import asyncio\n",
    "import logging\n",
    "import sys\n",
    "from concurrent.futures import ThreadPoolExecutor, as_completed\n",
    "from functools import partial\n",
    "import time\n",
    "import os\n",
    "from datetime import datetime, timedelta\n",
    "import json\n",
    "import pickle\n",
    "\n",
    "import random\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import re\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "current_path = os.getcwd()\n",
    "\n",
    "sys.path.append(os.path.join(current_path, \"../..\"))\n",
    "from helpers.llm_prompts import (\n",
    "    REFORMAT_SINGLE_PROMPT,\n",
    "    REFORMAT_PROMPT,\n",
    "    HUMAN_JOINT_PROMPT_1,\n",
    "    HUMAN_JOINT_PROMPT_2,\n",
    "    HUMAN_JOINT_PROMPT_3,\n",
    "    HUMAN_JOINT_PROMPT_4,\n",
    ")\n",
    "\n",
    "from helpers import model_eval\n",
    "\n",
    "\n",
    "logger = logging.getLogger()\n",
    "logger.setLevel(logging.INFO)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "713b906c",
   "metadata": {},
   "outputs": [],
   "source": [
    "human_joint_prompts = [\n",
    "    HUMAN_JOINT_PROMPT_1,\n",
    "    HUMAN_JOINT_PROMPT_2,\n",
    "    HUMAN_JOINT_PROMPT_3,\n",
    "    HUMAN_JOINT_PROMPT_4,\n",
    "]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b34fa571",
   "metadata": {},
   "source": [
    "### Load data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "297656d9",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "\n",
    "\n",
    "def read_jsonl(file_path):\n",
    "    data = []\n",
    "    with open(file_path, \"r\", encoding=\"utf-8\") as file:\n",
    "        for line in file:\n",
    "            if line.strip():\n",
    "                json_object = json.loads(line)\n",
    "                data.append(json_object)\n",
    "    return data\n",
    "\n",
    "\n",
    "file_path = \"2024-05-03-llm.jsonl\"\n",
    "questions = read_jsonl(file_path)\n",
    "\n",
    "single_non_acled_questions = [\n",
    "    q for q in questions if q[\"combination_of\"] == \"N/A\" and q[\"source\"] != \"acled\"\n",
    "]\n",
    "single_acled_questions = [\n",
    "    q for q in questions if q[\"combination_of\"] == \"N/A\" and q[\"source\"] == \"acled\"\n",
    "]\n",
    "combo_questions = [q for q in questions if q[\"combination_of\"] != \"N/A\" and q[\"source\"] == \"acled\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "d5ac7a99",
   "metadata": {},
   "outputs": [],
   "source": [
    "combo_questions_unrolled = []\n",
    "\n",
    "for q in combo_questions:\n",
    "    for i in range(4):\n",
    "        new_q = q.copy()\n",
    "        new_q[\"combo_index\"] = i\n",
    "\n",
    "        combo_questions_unrolled.append(new_q)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "f4c7709d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(339, 162, 644)"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(single_non_acled_questions), len(single_acled_questions), len(combo_questions_unrolled)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "181454f5",
   "metadata": {},
   "source": [
    "### Useful Functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "5cf5c534",
   "metadata": {},
   "outputs": [],
   "source": [
    "def capitalize_substrings(model_name):\n",
    "    model_name = model_name.replace('gpt', 'GPT') if 'gpt' in model_name else model_name\n",
    "    substrings = model_name.split('-')\n",
    "    capitalized_substrings = [\n",
    "        substr[0].upper() + substr[1:] if substr and not substr[0].isdigit() else substr \n",
    "        for substr in substrings\n",
    "    ]\n",
    "    return '-'.join(capitalized_substrings)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3bae100d",
   "metadata": {},
   "source": [
    "### Wisdom of Crowd: Scratchpad + Retrieval"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "3e0ab56a",
   "metadata": {},
   "outputs": [],
   "source": [
    "from helpers.llm_crowd_prompts import (\n",
    "    SCRATCH_PAD_WITH_SUMMARIES_MARKET_PROMPT_1,\n",
    "    SCRATCH_PAD_WITH_SUMMARIES_NON_MARKET_PROMPT_1,\n",
    "    SCRATCH_PAD_WITH_SUMMARIES_JOINT_QUESTION_PROMPT_1,\n",
    "    SCRATCH_PAD_WITH_SUMMARIES_MARKET_PROMPT_2,\n",
    "    SCRATCH_PAD_WITH_SUMMARIES_NON_MARKET_PROMPT_2,\n",
    "    SCRATCH_PAD_WITH_SUMMARIES_JOINT_QUESTION_PROMPT_2,\n",
    "    SUPERFORECASTER_MARKET_PROMPT_3,\n",
    "    SUPERFORECASTER_NON_MARKET_PROMPT_3,\n",
    "    SUPERFORECASTER_JOINT_QUESTION_PROMPT_3,\n",
    "    SUPERFORECASTER_MARKET_PROMPT_4,\n",
    "    SUPERFORECASTER_NON_MARKET_PROMPT_4,\n",
    "    SUPERFORECASTER_JOINT_QUESTION_PROMPT_4,\n",
    "    SUPERFORECASTER_MARKET_PROMPT_5,\n",
    "    SUPERFORECASTER_NON_MARKET_PROMPT_5,\n",
    "    SUPERFORECASTER_JOINT_QUESTION_PROMPT_5,\n",
    ")\n",
    "\n",
    "all_llm_crowd_prompts = {\n",
    "    1: (\n",
    "        SCRATCH_PAD_WITH_SUMMARIES_MARKET_PROMPT_1,\n",
    "        SCRATCH_PAD_WITH_SUMMARIES_NON_MARKET_PROMPT_1,\n",
    "        SCRATCH_PAD_WITH_SUMMARIES_JOINT_QUESTION_PROMPT_1,\n",
    "    ),\n",
    "    2: (\n",
    "        SCRATCH_PAD_WITH_SUMMARIES_MARKET_PROMPT_2,\n",
    "        SCRATCH_PAD_WITH_SUMMARIES_NON_MARKET_PROMPT_2,\n",
    "        SCRATCH_PAD_WITH_SUMMARIES_JOINT_QUESTION_PROMPT_2,\n",
    "    ),\n",
    "    3: (\n",
    "        SUPERFORECASTER_MARKET_PROMPT_3,\n",
    "        SUPERFORECASTER_NON_MARKET_PROMPT_3,\n",
    "        SUPERFORECASTER_JOINT_QUESTION_PROMPT_3,\n",
    "    ),\n",
    "    4: (\n",
    "        SUPERFORECASTER_MARKET_PROMPT_4,\n",
    "        SUPERFORECASTER_NON_MARKET_PROMPT_4,\n",
    "        SUPERFORECASTER_JOINT_QUESTION_PROMPT_4,\n",
    "    ),\n",
    "    5: (\n",
    "        SUPERFORECASTER_MARKET_PROMPT_5,\n",
    "        SUPERFORECASTER_NON_MARKET_PROMPT_5,\n",
    "        SUPERFORECASTER_JOINT_QUESTION_PROMPT_5,\n",
    "    ),\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "5df27efd",
   "metadata": {},
   "outputs": [],
   "source": [
    "models = {\n",
    "    \"gpt_4o\": {\"source\": \"OAI\", \"full_name\": \"gpt-4o\"},\n",
    "    \"mistral_large\": {\n",
    "        \"source\": \"MISTRAL\",\n",
    "        \"full_name\": \"mistral-large-latest\",\n",
    "    },\n",
    "    \"qwen_1p5_110b\": {\n",
    "        \"source\": \"TOGETHER\",\n",
    "        \"full_name\": \"Qwen/Qwen1.5-110B-Chat\",\n",
    "    },\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "7577a6c4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mapping retrieved summaries back into questions\n",
    "for question_source in [single_non_acled_questions, single_acled_questions]:\n",
    "    for q in question_source:\n",
    "        reformatted_id = q[\"id\"].replace(\"/\", \"_\")\n",
    "        filename = f\"news/{reformatted_id}.pickle\"\n",
    "        with open(filename, \"rb\") as file:\n",
    "            retrieved_info = pickle.load(file)\n",
    "        q[\"news\"] = retrieved_info\n",
    "\n",
    "for q in combo_questions_unrolled:\n",
    "    for sub_q in q[\"combination_of\"]:\n",
    "        reformatted_id = sub_q[\"id\"].replace(\"/\", \"_\")\n",
    "        filename = f\"news/{reformatted_id}.pickle\"\n",
    "        with open(filename, \"rb\") as file:\n",
    "            retrieved_info = pickle.load(file)\n",
    "        sub_q[\"news\"] = retrieved_info"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "059d0c1a",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_all_retrieved_info(all_retrieved_info):\n",
    "    retrieved_info = \"\"\n",
    "    for summary in all_retrieved_info:\n",
    "        retrieved_info += f\"Article title: {summary['title']}\" + \"\\n\"\n",
    "        retrieved_info += f\"Summary: {summary['summary']}\" + \"\\n\\n\"\n",
    "    return retrieved_info\n",
    "\n",
    "\n",
    "def worker(\n",
    "    index,\n",
    "    model_name,\n",
    "    save_dict,\n",
    "    questions_to_eval,\n",
    "    article_random_seed,\n",
    "    SCRATCH_PAD_WITH_SUMMARIES_MARKET_PROMPT,\n",
    "    SCRATCH_PAD_WITH_SUMMARIES_NON_MARKET_PROMPT,\n",
    "    SCRATCH_PAD_WITH_SUMMARIES_JOINT_QUESTION_PROMPT,\n",
    "):\n",
    "    if save_dict[index] != \"\":\n",
    "        return\n",
    "\n",
    "    logger.info(f\"Starting {model_name} - {index}\")\n",
    "\n",
    "    random.seed(article_random_seed)\n",
    "\n",
    "    if questions_to_eval[index][\"source\"] != \"acled\":\n",
    "        selected_articles = random.sample(\n",
    "            questions_to_eval[index][\"news\"],\n",
    "            min(10, len(questions_to_eval[index][\"news\"])),\n",
    "        )\n",
    "        prompt = SCRATCH_PAD_WITH_SUMMARIES_MARKET_PROMPT.format(\n",
    "            question=questions_to_eval[index][\"question\"],\n",
    "            background=questions_to_eval[index][\"background\"]\n",
    "            + \"\\n\"\n",
    "            + questions_to_eval[index][\"model_info_resolution_criteria\"],\n",
    "            resolution_criteria=questions_to_eval[index][\"resolution_criteria\"],\n",
    "            close_date=questions_to_eval[index][\"model_info_close_datetime\"],\n",
    "            retrieved_info=get_all_retrieved_info(selected_articles),\n",
    "        )\n",
    "\n",
    "        response = model_eval.get_response_from_model(\n",
    "            prompt=prompt,\n",
    "            max_tokens=2000,\n",
    "            model_name=models[model_name][\"full_name\"],\n",
    "            temperature=0,\n",
    "            wait_time=30,\n",
    "        )\n",
    "\n",
    "        save_dict[index] = (model_eval.reformat_answers(response=response, single=True), response)\n",
    "\n",
    "    else:\n",
    "        all_resolution_dates = []\n",
    "        for horizon in questions_to_eval[index][\"forecast_horizons\"]:\n",
    "            resolution_date = datetime.fromisoformat(\n",
    "                questions_to_eval[index][\"freeze_datetime\"]\n",
    "            ) + timedelta(days=7 + horizon)\n",
    "            resolution_date = resolution_date.isoformat()\n",
    "            all_resolution_dates.append(resolution_date)\n",
    "\n",
    "        if questions_to_eval[index][\"combination_of\"] == \"N/A\":\n",
    "            selected_articles = random.sample(\n",
    "                questions_to_eval[index][\"news\"],\n",
    "                min(10, len(questions_to_eval[index][\"news\"])),\n",
    "            )\n",
    "            prompt = SCRATCH_PAD_WITH_SUMMARIES_NON_MARKET_PROMPT.format(\n",
    "                question=questions_to_eval[index][\"question\"],\n",
    "                background=questions_to_eval[index][\"background\"]\n",
    "                + \"\\n\"\n",
    "                + questions_to_eval[index][\"model_info_resolution_criteria\"],\n",
    "                resolution_criteria=questions_to_eval[index][\"resolution_criteria\"],\n",
    "                freeze_datetime=questions_to_eval[index][\"freeze_datetime\"],\n",
    "                freeze_datetime_value=questions_to_eval[index][\"freeze_datetime_value\"],\n",
    "                freeze_datetime_value_explanation=questions_to_eval[index][\n",
    "                    \"freeze_datetime_value_explanation\"\n",
    "                ],\n",
    "                retrieved_info=get_all_retrieved_info(selected_articles),\n",
    "                list_of_resolution_dates=all_resolution_dates,\n",
    "            )\n",
    "            response = model_eval.get_response_from_model(\n",
    "                prompt=prompt,\n",
    "                max_tokens=2000,\n",
    "                model_name=models[model_name][\"full_name\"],\n",
    "                temperature=0,\n",
    "                wait_time=30,\n",
    "            )\n",
    "            save_dict[index] = (\n",
    "                model_eval.reformat_answers(\n",
    "                    response=response, prompt=prompt, question=questions_to_eval[index]\n",
    "                ),\n",
    "                response,\n",
    "            )\n",
    "        else:\n",
    "            selected_articles_1 = random.sample(\n",
    "                questions_to_eval[index][\"combination_of\"][0][\"news\"],\n",
    "                min(10, len(questions_to_eval[index][\"combination_of\"][0][\"news\"])),\n",
    "            )\n",
    "            selected_articles_2 = random.sample(\n",
    "                questions_to_eval[index][\"combination_of\"][1][\"news\"],\n",
    "                min(10, len(questions_to_eval[index][\"combination_of\"][1][\"news\"])),\n",
    "            )\n",
    "\n",
    "            prompt = SCRATCH_PAD_WITH_SUMMARIES_JOINT_QUESTION_PROMPT.format(\n",
    "                human_prompt=human_joint_prompts[questions_to_eval[index][\"combo_index\"]],\n",
    "                question_1=questions_to_eval[index][\"combination_of\"][0][\"question\"],\n",
    "                question_2=questions_to_eval[index][\"combination_of\"][1][\"question\"],\n",
    "                background_1=questions_to_eval[index][\"combination_of\"][0][\"background\"]\n",
    "                + \"\\n\"\n",
    "                + questions_to_eval[index][\"combination_of\"][0][\"model_info_resolution_criteria\"],\n",
    "                background_2=questions_to_eval[index][\"combination_of\"][1][\"background\"]\n",
    "                + \"\\n\"\n",
    "                + questions_to_eval[index][\"combination_of\"][1][\"model_info_resolution_criteria\"],\n",
    "                resolution_criteria_1=questions_to_eval[index][\"combination_of\"][0][\n",
    "                    \"resolution_criteria\"\n",
    "                ],\n",
    "                resolution_criteria_2=questions_to_eval[index][\"combination_of\"][1][\n",
    "                    \"resolution_criteria\"\n",
    "                ],\n",
    "                freeze_datetime_1=questions_to_eval[index][\"combination_of\"][0][\"freeze_datetime\"],\n",
    "                freeze_datetime_2=questions_to_eval[index][\"combination_of\"][1][\"freeze_datetime\"],\n",
    "                freeze_datetime_value_1=questions_to_eval[index][\"combination_of\"][0][\n",
    "                    \"freeze_datetime_value\"\n",
    "                ],\n",
    "                freeze_datetime_value_2=questions_to_eval[index][\"combination_of\"][1][\n",
    "                    \"freeze_datetime_value\"\n",
    "                ],\n",
    "                freeze_datetime_value_explanation_1=questions_to_eval[index][\"combination_of\"][\n",
    "                    0\n",
    "                ][\"freeze_datetime_value_explanation\"],\n",
    "                freeze_datetime_value_explanation_2=questions_to_eval[index][\"combination_of\"][\n",
    "                    1\n",
    "                ][\"freeze_datetime_value_explanation\"],\n",
    "                retrieved_info_1=get_all_retrieved_info(selected_articles_1),\n",
    "                retrieved_info_2=get_all_retrieved_info(selected_articles_2),\n",
    "                list_of_resolution_dates=all_resolution_dates,\n",
    "            )\n",
    "\n",
    "            response = model_eval.get_response_from_model(\n",
    "                prompt=prompt,\n",
    "                max_tokens=2000,\n",
    "                model_name=models[model_name][\"full_name\"],\n",
    "                temperature=0,\n",
    "                wait_time=30,\n",
    "            )\n",
    "\n",
    "            save_dict[index] = (\n",
    "                model_eval.reformat_answers(\n",
    "                    response=response, prompt=prompt, question=questions_to_eval[index]\n",
    "                ),\n",
    "                response,\n",
    "            )\n",
    "\n",
    "    logger.info(f\"Answer: {save_dict[index][0]}\")\n",
    "\n",
    "    return None\n",
    "\n",
    "\n",
    "def executor(\n",
    "    max_workers,\n",
    "    model_name,\n",
    "    save_dict,\n",
    "    questions_to_eval,\n",
    "    article_random_seed,\n",
    "    SCRATCH_PAD_WITH_SUMMARIES_MARKET_PROMPT,\n",
    "    SCRATCH_PAD_WITH_SUMMARIES_NON_MARKET_PROMPT,\n",
    "    SCRATCH_PAD_WITH_SUMMARIES_JOINT_QUESTION_PROMPT,\n",
    "):\n",
    "    with ThreadPoolExecutor(max_workers=max_workers) as executor:\n",
    "        worker_with_args = partial(\n",
    "            worker,\n",
    "            model_name=model_name,\n",
    "            save_dict=save_dict,\n",
    "            questions_to_eval=questions_to_eval,\n",
    "            article_random_seed=article_random_seed,\n",
    "            SCRATCH_PAD_WITH_SUMMARIES_MARKET_PROMPT=SCRATCH_PAD_WITH_SUMMARIES_MARKET_PROMPT,\n",
    "            SCRATCH_PAD_WITH_SUMMARIES_NON_MARKET_PROMPT=SCRATCH_PAD_WITH_SUMMARIES_NON_MARKET_PROMPT,\n",
    "            SCRATCH_PAD_WITH_SUMMARIES_JOINT_QUESTION_PROMPT=SCRATCH_PAD_WITH_SUMMARIES_JOINT_QUESTION_PROMPT,\n",
    "        )\n",
    "        return list(executor.map(worker_with_args, range(len(questions_to_eval))))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "792fed91",
   "metadata": {},
   "outputs": [],
   "source": [
    "results = {}\n",
    "model_result_loaded = {}\n",
    "models_to_test = list(models.keys())[:]\n",
    "prompt_type = \"llm_crowd\"\n",
    "\n",
    "for question in [single_acled_questions, combo_questions_unrolled, single_non_acled_questions]:\n",
    "    for article_random_seed in range(1, 4):\n",
    "        for llm_crowd_prompt_index in range(3, 6):\n",
    "            questions_to_eval = question\n",
    "            (\n",
    "                SCRATCH_PAD_WITH_SUMMARIES_MARKET_PROMPT,\n",
    "                SCRATCH_PAD_WITH_SUMMARIES_NON_MARKET_PROMPT,\n",
    "                SCRATCH_PAD_WITH_SUMMARIES_JOINT_QUESTION_PROMPT,\n",
    "            ) = all_llm_crowd_prompts[llm_crowd_prompt_index]\n",
    "\n",
    "            if question[0][\"source\"] != \"acled\":\n",
    "                q_type = \"non_acled\"\n",
    "            elif question[0][\"source\"] == \"acled\" and question[0][\"combination_of\"] == \"N/A\":\n",
    "                q_type = \"acled\"\n",
    "            else:\n",
    "                q_type = \"combo\"\n",
    "\n",
    "            test_type = f\"{prompt_type}/{q_type}/article_set_{article_random_seed}/prompt_{llm_crowd_prompt_index}\"\n",
    "\n",
    "            for model in models_to_test:\n",
    "                if model not in model_result_loaded.keys():\n",
    "                    model_result_loaded[model] = {}\n",
    "                model_result_loaded[model] = False\n",
    "\n",
    "            for model in models_to_test:\n",
    "                file_path = f\"{test_type}/{model}.jsonl\"\n",
    "                if model not in results.keys():\n",
    "                    results[model] = {}\n",
    "                try:\n",
    "                    results[model] = read_jsonl(file_path)\n",
    "                    model_result_loaded[model] = True  # Set flag to True if loaded successfully\n",
    "                except:\n",
    "                    results[model] = {i: \"\" for i in range(len(questions_to_eval))}\n",
    "\n",
    "            for model in models_to_test:\n",
    "                file_path = f\"{test_type}/{model}.jsonl\"\n",
    "                if not model_result_loaded[model]:\n",
    "                    executor_count = 50\n",
    "                    if models[model][\"source\"] == \"ANTHROPIC\":\n",
    "                        executor_count = 30\n",
    "\n",
    "                    executor(\n",
    "                        executor_count,\n",
    "                        model,\n",
    "                        results[model],\n",
    "                        questions_to_eval,\n",
    "                        article_random_seed,\n",
    "                        SCRATCH_PAD_WITH_SUMMARIES_MARKET_PROMPT,\n",
    "                        SCRATCH_PAD_WITH_SUMMARIES_NON_MARKET_PROMPT,\n",
    "                        SCRATCH_PAD_WITH_SUMMARIES_JOINT_QUESTION_PROMPT,\n",
    "                    )\n",
    "\n",
    "                    current_model_forecasts = []\n",
    "                    for index in range(len(questions_to_eval)):\n",
    "                        if questions_to_eval[index][\"source\"] == \"acled\":\n",
    "                            for forecast, horizon in zip(\n",
    "                                results[model][index][0],\n",
    "                                questions_to_eval[index][\"forecast_horizons\"],\n",
    "                            ):\n",
    "                                current_forecast = {\n",
    "                                    \"id\": questions_to_eval[index][\"id\"],\n",
    "                                    \"source\": questions_to_eval[index][\"source\"],\n",
    "                                    \"forecast\": forecast,\n",
    "                                    \"horizon\": horizon,\n",
    "                                    \"reasoning\": results[model][index][1],\n",
    "                                }\n",
    "\n",
    "                                if questions_to_eval[index][\"combination_of\"] != \"N/A\":\n",
    "                                    combo_index = questions_to_eval[index][\"combo_index\"]\n",
    "                                    if combo_index == 0:\n",
    "                                        current_forecast[\"direction\"] = [1, 1]\n",
    "                                    elif combo_index == 1:\n",
    "                                        current_forecast[\"direction\"] = [1, -1]\n",
    "                                    elif combo_index == 2:\n",
    "                                        current_forecast[\"direction\"] = [-1, 1]\n",
    "                                    else:\n",
    "                                        current_forecast[\"direction\"] = [-1, -1]\n",
    "\n",
    "                                current_model_forecasts.append(current_forecast)\n",
    "\n",
    "                        else:\n",
    "                            current_forecast = {\n",
    "                                \"id\": questions_to_eval[index][\"id\"],\n",
    "                                \"source\": questions_to_eval[index][\"source\"],\n",
    "                                \"forecast\": results[model][index][0],\n",
    "                                \"reasoning\": results[model][index][1],\n",
    "                            }\n",
    "                            current_model_forecasts.append(current_forecast)\n",
    "\n",
    "                    os.makedirs(os.path.dirname(file_path), exist_ok=True)\n",
    "                    with open(file_path, \"w\") as file:\n",
    "                        for entry in current_model_forecasts:\n",
    "                            json_line = json.dumps(entry)\n",
    "                            file.write(json_line + \"\\n\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "2cb03b99",
   "metadata": {},
   "outputs": [],
   "source": [
    "prompt_type = \"llm_crowd\"\n",
    "\n",
    "for model in models_to_test:\n",
    "    for article_random_seed in range(1, 4):\n",
    "        for llm_crowd_prompt_index in range(3, 6):\n",
    "            detailed_index = f\"/article_set_{article_random_seed}/prompt_{llm_crowd_prompt_index}\"\n",
    "            current_model_forecasts = []\n",
    "            for test_type in [\n",
    "                f\"{prompt_type}/non_acled/{detailed_index}\",\n",
    "                f\"{prompt_type}/acled/{detailed_index}\",\n",
    "                f\"{prompt_type}/combo/{detailed_index}\",\n",
    "            ]:\n",
    "                file_path = f\"{test_type}/{model}.jsonl\"\n",
    "                questions = read_jsonl(file_path)\n",
    "                current_model_forecasts.extend(questions)\n",
    "\n",
    "            final_file_name = f\"{prompt_type}/final/{detailed_index}/{model}\"\n",
    "            os.makedirs(os.path.dirname(final_file_name), exist_ok=True)\n",
    "            with open(final_file_name, \"w\") as file:\n",
    "                for entry in current_model_forecasts:\n",
    "                    json_line = json.dumps(entry)\n",
    "                    file.write(json_line + \"\\n\")\n",
    "\n",
    "for model in models_to_test:\n",
    "    for article_random_seed in range(1, 4):\n",
    "        for llm_crowd_prompt_index in range(3, 6):\n",
    "            detailed_index = f\"/article_set_{article_random_seed}/prompt_{llm_crowd_prompt_index}\"\n",
    "            file_path = f\"{prompt_type}/final/{detailed_index}/{model}\"\n",
    "            questions = read_jsonl(file_path)\n",
    "            if \"gpt\" in model:\n",
    "                org = \"OpenAI\"\n",
    "            elif \"llama\" in model:\n",
    "                org = \"Meta\"\n",
    "            elif \"mistral\" in model:\n",
    "                org = \"Mistral AI\"\n",
    "            elif \"claude\" in model:\n",
    "                org = \"Anthropic\"\n",
    "            elif \"qwen\" in model:\n",
    "                org = \"Qwen\"\n",
    "\n",
    "            directory = f\"{prompt_type}/final_submit\"\n",
    "            os.makedirs(directory, exist_ok=True)\n",
    "\n",
    "            new_file_name = f\"{directory}/2024-05-03.{org}.{model}_{prompt_type}.article_set_{article_random_seed}.prompt_{llm_crowd_prompt_index}.json\"\n",
    "            \n",
    "            model_name = models[model]['full_name'] if '/' not in models[model]['full_name'] else models[model]['full_name'].split('/')[1]\n",
    "            model_name = capitalize_substrings(model_name)\n",
    "            \n",
    "            forecast_file = {\n",
    "                \"organization\": org,\n",
    "                \"model\": f\"{model_name} (article set {article_random_seed}, superforecaster prompt {llm_crowd_prompt_index})\",\n",
    "                \"question_set\": \"2024-05-03-llm.jsonl\",\n",
    "                \"forecast_date\": \"2024-05-03\",\n",
    "                \"forecasts\": questions,\n",
    "            }\n",
    "\n",
    "            with open(new_file_name, \"w\") as f:\n",
    "                json.dump(forecast_file, f, indent=4)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5f3491d2",
   "metadata": {},
   "source": [
    "### Generate the final crowd prediction files\n",
    "- One with median \n",
    "- One with geometric mean log odds"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "d95f3b35",
   "metadata": {},
   "outputs": [],
   "source": [
    "def geometric_mean_log_odds(probs, epsilon=1e-10):\n",
    "    # Ensure probabilities are within (0, 1) to avoid log(0) issues\n",
    "    probs = np.clip(probs, epsilon, 1 - epsilon)\n",
    "\n",
    "    # Convert probabilities to log odds\n",
    "    log_odds = np.log(probs / (1 - probs))\n",
    "\n",
    "    # Compute the geometric mean of the log odds\n",
    "    mean_log_odds = np.mean(log_odds)\n",
    "\n",
    "    # Convert the mean log odds back to probability\n",
    "    combined_prob = np.exp(mean_log_odds) / (1 + np.exp(mean_log_odds))\n",
    "\n",
    "    return combined_prob\n",
    "\n",
    "def geometric_mean(numbers):\n",
    "    if not numbers:\n",
    "        return 0  # Return 0 for an empty list to avoid math domain error\n",
    "    product = 1.0\n",
    "    for number in numbers:\n",
    "        product *= number\n",
    "    return product ** (1 / len(numbers))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "2c8d0faf",
   "metadata": {},
   "outputs": [],
   "source": [
    "median_crowd = {\n",
    "    \"organization\": \"CROWD\",\n",
    "    \"model\": \"Median of (GPT-4o + Mistral-Large + Qwen-1.5-110B) using 3 superforecaster prompts and 3 random sets of articles\",\n",
    "    \"question_set\": \"2024-05-03-llm.jsonl\",\n",
    "    \"forecast_date\": \"2024-05-03\",\n",
    "    \"forecasts\": None,\n",
    "}\n",
    "\n",
    "geometric_mean_log_odds_crowd = {\n",
    "    \"organization\": \"CROWD\",\n",
    "    \"model\": \"Geometric mean of log odds of (GPT-4o + Mistral-Large + Qwen-1.5-110B) using 3 superforecaster prompts and 3 random sets of articles\",\n",
    "    \"question_set\": \"2024-05-03-llm.jsonl\",\n",
    "    \"forecast_date\": \"2024-05-03\",\n",
    "    \"forecasts\": None,\n",
    "}\n",
    "\n",
    "geometric_mean_crowd = {\n",
    "    \"organization\": \"CROWD\",\n",
    "    \"model\": \"Geometric mean of (GPT-4o + Mistral-Large + Qwen-1.5-110B) using 3 superforecaster prompts and 3 random sets of articles\",\n",
    "    \"question_set\": \"2024-05-03-llm.jsonl\",\n",
    "    \"forecast_date\": \"2024-05-03\",\n",
    "    \"forecasts\": None,\n",
    "}\n",
    "\n",
    "forecasts = []\n",
    "\n",
    "for model in models_to_test:\n",
    "    for article_random_seed in range(1, 4):\n",
    "        for llm_crowd_prompt_index in range(3, 6):\n",
    "            if \"gpt\" in model:\n",
    "                org = \"OpenAI\"\n",
    "            elif \"llama\" in model:\n",
    "                org = \"Meta\"\n",
    "            elif \"mistral\" in model:\n",
    "                org = \"Mistral AI\"\n",
    "            elif \"claude\" in model:\n",
    "                org = \"Anthropic\"\n",
    "            elif \"qwen\" in model:\n",
    "                org = \"Qwen\"\n",
    "\n",
    "            file_path = f\"{directory}/2024-05-03.{org}.{model}_{prompt_type}.article_set_{article_random_seed}.prompt_{llm_crowd_prompt_index}.json\"\n",
    "\n",
    "            with open(file_path, \"r\") as file:\n",
    "                questions = json.load(file)\n",
    "            df = pd.DataFrame(questions[\"forecasts\"])\n",
    "            df[\"forecast\"] = df[\"forecast\"].fillna(0.5)\n",
    "            forecasts.append(pd.DataFrame(df))\n",
    "\n",
    "\n",
    "# Concatenate all DataFrames into one\n",
    "combined_df = pd.concat(forecasts)\n",
    "\n",
    "combined_df[\"horizon\"] = combined_df[\"horizon\"].fillna(\"NaN\")\n",
    "combined_df[\"direction\"] = combined_df[\"direction\"].apply(\n",
    "    lambda x: tuple(x) if isinstance(x, list) else (\"NaN\", \"NaN\") if pd.isna(x) else x\n",
    ")\n",
    "combined_df[\"id\"] = combined_df[\"id\"].apply(lambda x: tuple(x) if isinstance(x, list) else x)\n",
    "\n",
    "# Group by id, source, reasoning, horizon, direction and aggregate forecasts into a list\n",
    "result_df = combined_df.groupby([\"id\", \"source\", \"horizon\", \"direction\"], as_index=False).agg(\n",
    "    {\"forecast\": list}\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "69a5fa3f",
   "metadata": {},
   "outputs": [],
   "source": [
    "result_df[\"median_forecast\"] = result_df[\"forecast\"].apply(np.median)\n",
    "result_df[\"geometric_mean_log_odds_forecast\"] = result_df[\"forecast\"].apply(geometric_mean_log_odds)\n",
    "result_df[\"geometric_mean_forecast\"] = result_df[\"forecast\"].apply(geometric_mean)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "dafe254b",
   "metadata": {},
   "outputs": [],
   "source": [
    "prompt_type = \"llm_crowd\"\n",
    "fields = [\"id\", \"source\", \"horizon\", \"direction\"]\n",
    "\n",
    "for agg_type in [\"median_forecast\", \"geometric_mean_log_odds_forecast\", 'geometric_mean_forecast']:\n",
    "    # convert back into required format\n",
    "    df_with_wanted_fields = result_df[fields + [agg_type]].copy()\n",
    "    df_with_wanted_fields.rename(columns={agg_type: \"forecast\"}, inplace=True)\n",
    "    df_with_wanted_fields.loc[:, \"horizon\"] = df_with_wanted_fields[\"horizon\"].map(\n",
    "        lambda x: None if x == \"NaN\" else x\n",
    "    )\n",
    "    df_with_wanted_fields.loc[:, \"direction\"] = df_with_wanted_fields[\"direction\"].apply(\n",
    "        lambda x: list(x) if isinstance(x, tuple) and x != (\"NaN\", \"NaN\") else None\n",
    "    )\n",
    "    df_with_wanted_fields.loc[:, \"id\"] = df_with_wanted_fields[\"id\"].apply(\n",
    "        lambda x: list(x) if isinstance(x, tuple) else x\n",
    "    )\n",
    "    list_of_forecasts = df_with_wanted_fields.to_dict(orient=\"records\")\n",
    "    if agg_type == \"median_forecast\":\n",
    "        median_crowd[\"forecasts\"] = list_of_forecasts\n",
    "    else:\n",
    "        geometric_mean_log_odds_crowd[\"forecasts\"] = list_of_forecasts\n",
    "\n",
    "    # save as json\n",
    "    directory = f\"{prompt_type}/final_submit\"\n",
    "    os.makedirs(directory, exist_ok=True)\n",
    "    new_file_name = f\"{directory}/2024-05-03.CROWD.{agg_type}.json\"\n",
    "\n",
    "    if agg_type == \"median_forecast\":\n",
    "        median_crowd[\"forecasts\"] = list_of_forecasts\n",
    "        with open(new_file_name, \"w\") as f:\n",
    "            json.dump(median_crowd, f, indent=4)\n",
    "    elif agg_type == \"geometric_mean_log_odds_forecast\":\n",
    "        geometric_mean_log_odds_crowd[\"forecasts\"] = list_of_forecasts\n",
    "        with open(new_file_name, \"w\") as f:\n",
    "            json.dump(geometric_mean_log_odds_crowd, f, indent=4)\n",
    "    else:\n",
    "        geometric_mean_crowd[\"forecasts\"] = list_of_forecasts\n",
    "        with open(new_file_name, \"w\") as f:\n",
    "            json.dump(geometric_mean_crowd, f, indent=4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "28fefb18",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.12 (myenv-llm-bench)",
   "language": "python",
   "name": "myenv-llm-bench"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
