{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "c77c7dcc",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/apple/envs/myenv-llm-bench/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "import asyncio\n",
    "import logging\n",
    "import sys\n",
    "from concurrent.futures import ThreadPoolExecutor, as_completed\n",
    "from functools import partial\n",
    "import time\n",
    "import os\n",
    "from datetime import datetime, timedelta\n",
    "import json\n",
    "import pickle\n",
    "\n",
    "import random\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import re\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "current_path = os.getcwd()\n",
    "\n",
    "sys.path.append(os.path.join(current_path, \"../..\"))\n",
    "from helpers.llm_prompts import (\n",
    "    ZERO_SHOT_MARKET_PROMPT,\n",
    "    ZERO_SHOT_NON_MARKET_PROMPT,\n",
    "    ZERO_SHOT_JOINT_QUESTION_PROMPT,\n",
    "    SCRATCH_PAD_MARKET_PROMPT,\n",
    "    SCRATCH_PAD_NON_MARKET_PROMPT,\n",
    "    SCRATCH_PAD_JOINT_QUESTION_PROMPT,\n",
    "    SCRATCH_PAD_WITH_SUMMARIES_MARKET_PROMPT,\n",
    "    SCRATCH_PAD_WITH_SUMMARIES_NON_MARKET_PROMPT,\n",
    "    SCRATCH_PAD_WITH_SUMMARIES_JOINT_QUESTION_PROMPT,\n",
    "    REFORMAT_SINGLE_PROMPT,\n",
    "    REFORMAT_PROMPT,\n",
    "    HUMAN_JOINT_PROMPT_1,\n",
    "    HUMAN_JOINT_PROMPT_2,\n",
    "    HUMAN_JOINT_PROMPT_3,\n",
    "    HUMAN_JOINT_PROMPT_4,\n",
    ")\n",
    "\n",
    "from helpers import model_eval\n",
    "\n",
    "\n",
    "logger = logging.getLogger()\n",
    "logger.setLevel(logging.INFO)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "39c71ded",
   "metadata": {},
   "outputs": [],
   "source": [
    "human_joint_prompts = [\n",
    "    HUMAN_JOINT_PROMPT_1,\n",
    "    HUMAN_JOINT_PROMPT_2,\n",
    "    HUMAN_JOINT_PROMPT_3,\n",
    "    HUMAN_JOINT_PROMPT_4,\n",
    "]\n",
    "\n",
    "models = {\n",
    "    \"gpt_3p5_turbo_0125\": {\"source\": \"OAI\", \"full_name\": \"gpt-3.5-turbo-0125\"},\n",
    "    \"gpt_4\": {\"source\": \"OAI\", \"full_name\": \"gpt-4\"},\n",
    "    \"gpt_4_turbo_0409\": {\"source\": \"OAI\", \"full_name\": \"gpt-4-turbo-2024-04-09\"},\n",
    "    \"gpt_4_1106_preview\": {\"source\": \"OAI\", \"full_name\": \"gpt-4-1106-preview\"},\n",
    "    \"gpt_4_0125_preview\": {\"source\": \"OAI\", \"full_name\": \"gpt-4-0125-preview\"},\n",
    "    \"gpt_4o\": {\"source\": \"OAI\", \"full_name\": \"gpt-4o\"},\n",
    "    \"llama_2_70b\": {\n",
    "        \"source\": \"TOGETHER\",\n",
    "        \"full_name\": \"meta-llama/Llama-2-70b-chat-hf\",\n",
    "    },\n",
    "    \"llama_3_8b\": {\n",
    "        \"source\": \"TOGETHER\",\n",
    "        \"full_name\": \"meta-llama/Llama-3-8b-chat-hf\",\n",
    "    },\n",
    "    \"llama_3_70b\": {\n",
    "        \"source\": \"TOGETHER\",\n",
    "        \"full_name\": \"meta-llama/Llama-3-70b-chat-hf\",\n",
    "    },\n",
    "    \"mistral_8x7b_instruct\": {\n",
    "        \"source\": \"TOGETHER\",\n",
    "        \"full_name\": \"mistralai/Mixtral-8x7B-Instruct-v0.1\",\n",
    "    },\n",
    "    \"mistral_8x22b_instruct\": {\n",
    "        \"source\": \"TOGETHER\",\n",
    "        \"full_name\": \"mistralai/Mixtral-8x22B-Instruct-v0.1\",\n",
    "    },\n",
    "    \"mistral_large\": {\n",
    "        \"source\": \"MISTRAL\",\n",
    "        \"full_name\": \"mistral-large-latest\",\n",
    "    },\n",
    "    \"qwen_1p5_110b\": {\n",
    "        \"source\": \"TOGETHER\",\n",
    "        \"full_name\": \"Qwen/Qwen1.5-110B-Chat\",\n",
    "    },\n",
    "    \"claude_2p1\": {\"source\": \"ANTHROPIC\", \"full_name\": \"claude-2.1\"},\n",
    "    \"claude_3_opus\": {\"source\": \"ANTHROPIC\", \"full_name\": \"claude-3-opus-20240229\"},\n",
    "    \"claude_3_sonnet\": {\"source\": \"ANTHROPIC\", \"full_name\": \"claude-3-sonnet-20240229\"},\n",
    "    \"claude_3_haiku\": {\"source\": \"ANTHROPIC\", \"full_name\": \"claude-3-haiku-20240307\"},\n",
    "    #     \"gemini_pro\": {\"source\": \"GOOGLE\", \"full_name\": \"gemini-pro\"},\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0744b68c",
   "metadata": {},
   "source": [
    "### Useful functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "74bad83b",
   "metadata": {},
   "outputs": [],
   "source": [
    "def capitalize_substrings(model_name):\n",
    "    model_name = model_name.replace('gpt', 'GPT') if 'gpt' in model_name else model_name\n",
    "    substrings = model_name.split('-')\n",
    "    capitalized_substrings = [\n",
    "        substr[0].upper() + substr[1:] if substr and not substr[0].isdigit() else substr \n",
    "        for substr in substrings\n",
    "    ]\n",
    "    return '-'.join(capitalized_substrings)\n",
    "\n",
    "def generage_final_forecast_files(deadline, forecast_date, prompt_type):\n",
    "    \n",
    "    for model in models_to_test:\n",
    "        current_model_forecasts = []\n",
    "        for test_type in [f\"{prompt_type}/non_acled\", f\"{prompt_type}/acled\", f\"{prompt_type}/combo\"]:\n",
    "            file_path = f\"{test_type}/{model}.jsonl\"\n",
    "            questions = read_jsonl(file_path)\n",
    "            current_model_forecasts.extend(questions)\n",
    "\n",
    "        final_file_name = f\"{prompt_type}/final/{model}\"\n",
    "        os.makedirs(os.path.dirname(final_file_name), exist_ok=True)\n",
    "        with open(final_file_name, \"w\") as file:\n",
    "            for entry in current_model_forecasts:\n",
    "                json_line = json.dumps(entry)\n",
    "                file.write(json_line + \"\\n\")\n",
    "\n",
    "    for model in models_to_test:\n",
    "        file_path = f\"{prompt_type}/final/{model}\"\n",
    "        questions = read_jsonl(file_path)\n",
    "        if \"gpt\" in model:\n",
    "            org = \"OpenAI\"\n",
    "        elif \"llama\" in model:\n",
    "            org = \"Meta\"\n",
    "        elif \"mistral\" in model:\n",
    "            org = \"Mistral AI\"\n",
    "        elif \"claude\" in model:\n",
    "            org = \"Anthropic\"\n",
    "        elif \"qwen\" in model:\n",
    "            org = \"Qwen\"\n",
    "\n",
    "        directory = f\"{prompt_type}/final_submit\"\n",
    "        os.makedirs(directory, exist_ok=True)\n",
    "\n",
    "        new_file_name = f\"{directory}/{deadline}.{org}.{model}_{prompt_type}.json\"\n",
    "\n",
    "        model_name = models[model]['full_name'] if '/' not in models[model]['full_name'] else models[model]['full_name'].split('/')[1]\n",
    "        \n",
    "        forecast_file = {\n",
    "            \"organization\": org,\n",
    "            \"model\": f\"{capitalize_substrings(model_name)} ({prompt_type.replace('_', ' ')})\",\n",
    "            \"question_set\": f\"{deadline}-llm.jsonl\",\n",
    "            \"forecast_date\": forecast_date,\n",
    "            \"forecasts\": questions,\n",
    "        }\n",
    "\n",
    "        with open(new_file_name, \"w\") as f:\n",
    "            json.dump(forecast_file, f, indent=4)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b2d787e6",
   "metadata": {},
   "source": [
    "### Load data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "0853b418",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "\n",
    "\n",
    "def read_jsonl(file_path):\n",
    "    data = []\n",
    "    with open(file_path, \"r\", encoding=\"utf-8\") as file:\n",
    "        for line in file:\n",
    "            if line.strip():\n",
    "                json_object = json.loads(line)\n",
    "                data.append(json_object)\n",
    "    return data\n",
    "\n",
    "\n",
    "file_path = \"2024-05-03-llm.jsonl\"\n",
    "questions = read_jsonl(file_path)\n",
    "\n",
    "single_non_acled_questions = [\n",
    "    q for q in questions if q[\"combination_of\"] == \"N/A\" and q[\"source\"] != \"acled\"\n",
    "]\n",
    "single_acled_questions = [\n",
    "    q for q in questions if q[\"combination_of\"] == \"N/A\" and q[\"source\"] == \"acled\"\n",
    "]\n",
    "combo_questions = [q for q in questions if q[\"combination_of\"] != \"N/A\" and q[\"source\"] == \"acled\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "8d850050",
   "metadata": {},
   "outputs": [],
   "source": [
    "combo_questions_unrolled = []\n",
    "\n",
    "for q in combo_questions:\n",
    "    for i in range(4):\n",
    "        new_q = q.copy()\n",
    "        new_q[\"combo_index\"] = i\n",
    "\n",
    "        combo_questions_unrolled.append(new_q)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "2adcf14c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(339, 162, 644)"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(single_non_acled_questions), len(single_acled_questions), len(combo_questions_unrolled)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fb6cbc4f",
   "metadata": {},
   "source": [
    "### Zero-shot evaluation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "2cc4fc17",
   "metadata": {},
   "outputs": [],
   "source": [
    "def worker(index, model_name, save_dict, questions_to_eval, rate_limit=False):\n",
    "    if save_dict[index] != \"\":\n",
    "        return\n",
    "\n",
    "    logger.info(f\"{model_name} - {index}\")\n",
    "\n",
    "    if rate_limit:\n",
    "        start_time = datetime.now()\n",
    "\n",
    "    if questions_to_eval[index][\"source\"] != \"acled\":\n",
    "        prompt = ZERO_SHOT_MARKET_PROMPT.format(\n",
    "            question=questions_to_eval[index][\"question\"],\n",
    "            background=questions_to_eval[index][\"background\"]\n",
    "            + \"\\n\"\n",
    "            + questions_to_eval[index][\"model_info_resolution_criteria\"],\n",
    "            resolution_criteria=questions_to_eval[index][\"resolution_criteria\"],\n",
    "            close_date=questions_to_eval[index][\"model_info_close_datetime\"],\n",
    "        )\n",
    "        response = model_eval.get_response_from_model(\n",
    "            prompt=prompt,\n",
    "            max_tokens=100,\n",
    "            model_name=models[model_name][\"full_name\"],\n",
    "            temperature=0,\n",
    "            wait_time=30,\n",
    "        )\n",
    "\n",
    "        save_dict[index] = model_eval.extract_probability(response)\n",
    "\n",
    "    else:\n",
    "        all_resolution_dates = []\n",
    "        for horizon in questions_to_eval[index][\"forecast_horizons\"]:\n",
    "            resolution_date = datetime.fromisoformat(\n",
    "                questions_to_eval[index][\"freeze_datetime\"]\n",
    "            ) + timedelta(days=7 + horizon)\n",
    "            resolution_date = resolution_date.isoformat()\n",
    "            all_resolution_dates.append(resolution_date)\n",
    "\n",
    "        if questions_to_eval[index][\"combination_of\"] == \"N/A\":\n",
    "            prompt = ZERO_SHOT_NON_MARKET_PROMPT.format(\n",
    "                question=questions_to_eval[index][\"question\"],\n",
    "                background=questions_to_eval[index][\"background\"]\n",
    "                + \"\\n\"\n",
    "                + questions_to_eval[index][\"model_info_resolution_criteria\"],\n",
    "                resolution_criteria=questions_to_eval[index][\"resolution_criteria\"],\n",
    "                freeze_datetime=questions_to_eval[index][\"freeze_datetime\"],\n",
    "                freeze_datetime_value=questions_to_eval[index][\"freeze_datetime_value\"],\n",
    "                freeze_datetime_value_explanation=questions_to_eval[index][\n",
    "                    \"freeze_datetime_value_explanation\"\n",
    "                ],\n",
    "                list_of_resolution_dates=all_resolution_dates,\n",
    "            )\n",
    "            response = model_eval.get_response_from_model(\n",
    "                prompt=prompt,\n",
    "                max_tokens=100,\n",
    "                model_name=models[model_name][\"full_name\"],\n",
    "                temperature=0,\n",
    "                wait_time=30,\n",
    "            )\n",
    "\n",
    "            save_dict[index] = model_eval.reformat_answers(\n",
    "                response=response, prompt=prompt, question=questions_to_eval[index]\n",
    "            )\n",
    "\n",
    "        else:\n",
    "            prompt = ZERO_SHOT_JOINT_QUESTION_PROMPT.format(\n",
    "                human_prompt=human_joint_prompts[questions_to_eval[index][\"combo_index\"]],\n",
    "                question_1=questions_to_eval[index][\"combination_of\"][0][\"question\"],\n",
    "                question_2=questions_to_eval[index][\"combination_of\"][1][\"question\"],\n",
    "                background_1=questions_to_eval[index][\"combination_of\"][0][\"background\"]\n",
    "                + \"\\n\"\n",
    "                + questions_to_eval[index][\"combination_of\"][0][\"model_info_resolution_criteria\"],\n",
    "                background_2=questions_to_eval[index][\"combination_of\"][1][\"background\"]\n",
    "                + \"\\n\"\n",
    "                + questions_to_eval[index][\"combination_of\"][1][\"model_info_resolution_criteria\"],\n",
    "                resolution_criteria_1=questions_to_eval[index][\"combination_of\"][0][\n",
    "                    \"resolution_criteria\"\n",
    "                ],\n",
    "                resolution_criteria_2=questions_to_eval[index][\"combination_of\"][1][\n",
    "                    \"resolution_criteria\"\n",
    "                ],\n",
    "                freeze_datetime_1=questions_to_eval[index][\"combination_of\"][0][\"freeze_datetime\"],\n",
    "                freeze_datetime_2=questions_to_eval[index][\"combination_of\"][1][\"freeze_datetime\"],\n",
    "                freeze_datetime_value_1=questions_to_eval[index][\"combination_of\"][0][\n",
    "                    \"freeze_datetime_value\"\n",
    "                ],\n",
    "                freeze_datetime_value_2=questions_to_eval[index][\"combination_of\"][1][\n",
    "                    \"freeze_datetime_value\"\n",
    "                ],\n",
    "                freeze_datetime_value_explanation_1=questions_to_eval[index][\"combination_of\"][\n",
    "                    0\n",
    "                ][\"freeze_datetime_value_explanation\"],\n",
    "                freeze_datetime_value_explanation_2=questions_to_eval[index][\"combination_of\"][\n",
    "                    1\n",
    "                ][\"freeze_datetime_value_explanation\"],\n",
    "                list_of_resolution_dates=all_resolution_dates,\n",
    "            )\n",
    "\n",
    "            response = model_eval.get_response_from_model(\n",
    "                prompt=prompt,\n",
    "                max_tokens=500,\n",
    "                model_name=models[model_name][\"full_name\"],\n",
    "                temperature=0,\n",
    "                wait_time=30,\n",
    "            )\n",
    "\n",
    "            save_dict[index] = model_eval.reformat_answers(\n",
    "                response=response, prompt=prompt, question=questions_to_eval[index]\n",
    "            )\n",
    "\n",
    "    logger.info(save_dict[index])\n",
    "\n",
    "    if rate_limit:\n",
    "        end_time = datetime.now()\n",
    "        elapsed_time = (end_time - start_time).total_seconds()\n",
    "        if elapsed_time < 1:\n",
    "            time.sleep(\n",
    "                1 - elapsed_time\n",
    "            )  # Ensure at least 1 second per request to stay within rate limits\n",
    "\n",
    "    return None\n",
    "\n",
    "\n",
    "def executor(max_workers, model_name, save_dict, questions_to_eval, use_gemini=False):\n",
    "    with ThreadPoolExecutor(max_workers=max_workers) as executor:\n",
    "        worker_with_args = partial(\n",
    "            worker, model_name=model_name, save_dict=save_dict, questions_to_eval=questions_to_eval\n",
    "        )\n",
    "        return list(executor.map(worker_with_args, range(len(questions_to_eval))))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "de4cc7ca",
   "metadata": {},
   "outputs": [],
   "source": [
    "results = {}\n",
    "model_result_loaded = {}\n",
    "models_to_test = list(models.keys())[:]\n",
    "\n",
    "for question in [single_non_acled_questions, single_acled_questions, combo_questions_unrolled]:\n",
    "    questions_to_eval = question\n",
    "    if question[0][\"source\"] != \"acled\":\n",
    "        test_type = \"zero_shot/non_acled\"\n",
    "    elif question[0][\"source\"] == \"acled\" and question[0][\"combination_of\"] == \"N/A\":\n",
    "        test_type = \"zero_shot/acled\"\n",
    "    else:\n",
    "        test_type = \"zero_shot/combo\"\n",
    "\n",
    "    for model in models_to_test:\n",
    "        if model not in model_result_loaded.keys():\n",
    "            model_result_loaded[model] = {}\n",
    "        model_result_loaded[model] = False\n",
    "\n",
    "    for model in models_to_test:\n",
    "        file_path = f\"{test_type}/{model}.jsonl\"\n",
    "        if model not in results.keys():\n",
    "            results[model] = {}\n",
    "        try:\n",
    "            results[model] = read_jsonl(file_path)\n",
    "            model_result_loaded[model] = True  # Set flag to True if loaded successfully\n",
    "        except:\n",
    "            results[model] = {i: \"\" for i in range(len(questions_to_eval))}\n",
    "\n",
    "    for model in models_to_test:\n",
    "        file_path = f\"{test_type}/{model}.jsonl\"\n",
    "        if not model_result_loaded[model]:\n",
    "            executor_count = 50\n",
    "            use_gemini = False\n",
    "            if models[model][\"source\"] == \"GOOGLE\":\n",
    "                executor_count = 1\n",
    "                use_gemini = True\n",
    "\n",
    "            executor(executor_count, model, results[model], questions_to_eval, use_gemini)\n",
    "\n",
    "            current_model_forecasts = []\n",
    "            for index in range(len(questions_to_eval)):\n",
    "                if questions_to_eval[index][\"source\"] == \"acled\":\n",
    "                    for forecast, horizon in zip(\n",
    "                        results[model][index], questions_to_eval[index][\"forecast_horizons\"]\n",
    "                    ):\n",
    "                        current_forecast = {\n",
    "                            \"id\": questions_to_eval[index][\"id\"],\n",
    "                            \"source\": questions_to_eval[index][\"source\"],\n",
    "                            \"forecast\": forecast,\n",
    "                            \"horizon\": horizon,\n",
    "                            \"reasoning\": None,\n",
    "                        }\n",
    "                        if questions_to_eval[index][\"combination_of\"] != \"N/A\":\n",
    "                            combo_index = questions_to_eval[index][\"combo_index\"]\n",
    "                            if combo_index == 0:\n",
    "                                current_forecast[\"direction\"] = [1, 1]\n",
    "                            elif combo_index == 1:\n",
    "                                current_forecast[\"direction\"] = [1, -1]\n",
    "                            elif combo_index == 2:\n",
    "                                current_forecast[\"direction\"] = [-1, 1]\n",
    "                            else:\n",
    "                                current_forecast[\"direction\"] = [-1, -1]\n",
    "\n",
    "                        current_model_forecasts.append(current_forecast)\n",
    "                else:\n",
    "                    current_forecast = {\n",
    "                        \"id\": questions_to_eval[index][\"id\"],\n",
    "                        \"source\": questions_to_eval[index][\"source\"],\n",
    "                        \"forecast\": results[model][index],\n",
    "                        \"reasoning\": None,\n",
    "                    }\n",
    "                    current_model_forecasts.append(current_forecast)\n",
    "\n",
    "            os.makedirs(os.path.dirname(file_path), exist_ok=True)\n",
    "            with open(file_path, \"w\") as file:\n",
    "                for entry in current_model_forecasts:\n",
    "                    json_line = json.dumps(entry)\n",
    "                    file.write(json_line + \"\\n\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "46760c08",
   "metadata": {},
   "outputs": [],
   "source": [
    "generage_final_forecast_files(deadline='2024-05-03', \n",
    "                              forecast_date='2024-05-03', \n",
    "                              prompt_type=\"zero_shot\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "6244401b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# # fix directions\n",
    "# prompt_type= \"zero_shot\"\n",
    "# directory= prompt_type + '/final_submit'\n",
    "# for model in models_to_test:\n",
    "#     if \"gpt\" in model:\n",
    "#         org = \"OpenAI\"\n",
    "#     elif \"llama\" in model:\n",
    "#         org = \"Meta\"\n",
    "#     elif \"mistral\" in model:\n",
    "#         org = \"Mistral AI\"\n",
    "#     elif \"claude\" in model:\n",
    "#         org = \"Anthropic\"\n",
    "#     elif \"qwen\" in model:\n",
    "#         org = \"Qwen\"\n",
    "\n",
    "#     file_path = f\"{directory}/2024-05-03.{org}.{model}_{prompt_type}.json\"\n",
    "\n",
    "#     with open(file_path, \"r\") as file:\n",
    "#         questions = json.load(file)\n",
    "            \n",
    "#     for q in questions['forecasts']:\n",
    "#         if 'direction' in q:\n",
    "#             q['direction'] = [-1 if item == 0 else item for item in q['direction']]          \n",
    "    \n",
    "#     new_file_path = f\"{prompt_type}/direction_fix/2024-05-03.{org}.{model}_{prompt_type}.json\"\n",
    "#     os.makedirs(os.path.dirname(new_file_path), exist_ok=True)\n",
    "\n",
    "#     with open(new_file_path, 'w') as json_file:\n",
    "#         json.dump(questions, json_file, indent=4)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1f1bd6dc",
   "metadata": {},
   "source": [
    "### Scratchpad"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "e23471fd",
   "metadata": {},
   "outputs": [],
   "source": [
    "def worker(index, model_name, save_dict, questions_to_eval):\n",
    "    if save_dict[index] != \"\":\n",
    "        return\n",
    "\n",
    "    logger.info(f\"Starting {model_name} - {index}\")\n",
    "\n",
    "    if questions_to_eval[index][\"source\"] != \"acled\":\n",
    "        prompt = SCRATCH_PAD_MARKET_PROMPT.format(\n",
    "            question=questions_to_eval[index][\"question\"],\n",
    "            background=questions_to_eval[index][\"background\"]\n",
    "            + \"\\n\"\n",
    "            + questions_to_eval[index][\"model_info_resolution_criteria\"],\n",
    "            resolution_criteria=questions_to_eval[index][\"resolution_criteria\"],\n",
    "            close_date=questions_to_eval[index][\"model_info_close_datetime\"],\n",
    "        )\n",
    "        response = model_eval.get_response_from_model(\n",
    "            prompt=prompt,\n",
    "            max_tokens=1300,\n",
    "            model_name=models[model_name][\"full_name\"],\n",
    "            temperature=0,\n",
    "            wait_time=30,\n",
    "        )\n",
    "\n",
    "        save_dict[index] = (model_eval.reformat_answers(response=response, single=True), response)\n",
    "\n",
    "    else:\n",
    "        all_resolution_dates = []\n",
    "        for horizon in questions_to_eval[index][\"forecast_horizons\"]:\n",
    "            resolution_date = datetime.fromisoformat(\n",
    "                questions_to_eval[index][\"freeze_datetime\"]\n",
    "            ) + timedelta(days=7 + horizon)\n",
    "            resolution_date = resolution_date.isoformat()\n",
    "            all_resolution_dates.append(resolution_date)\n",
    "\n",
    "        if questions_to_eval[index][\"combination_of\"] == \"N/A\":\n",
    "            prompt = SCRATCH_PAD_NON_MARKET_PROMPT.format(\n",
    "                question=questions_to_eval[index][\"question\"],\n",
    "                background=questions_to_eval[index][\"background\"]\n",
    "                + \"\\n\"\n",
    "                + questions_to_eval[index][\"model_info_resolution_criteria\"],\n",
    "                resolution_criteria=questions_to_eval[index][\"resolution_criteria\"],\n",
    "                freeze_datetime=questions_to_eval[index][\"freeze_datetime\"],\n",
    "                freeze_datetime_value=questions_to_eval[index][\"freeze_datetime_value\"],\n",
    "                freeze_datetime_value_explanation=questions_to_eval[index][\n",
    "                    \"freeze_datetime_value_explanation\"\n",
    "                ],\n",
    "                list_of_resolution_dates=all_resolution_dates,\n",
    "            )\n",
    "            response = model_eval.get_response_from_model(\n",
    "                prompt=prompt,\n",
    "                max_tokens=2000,\n",
    "                model_name=models[model_name][\"full_name\"],\n",
    "                temperature=0,\n",
    "                wait_time=30,\n",
    "            )\n",
    "            save_dict[index] = (\n",
    "                model_eval.reformat_answers(\n",
    "                    response=response, prompt=prompt, question=questions_to_eval[index]\n",
    "                ),\n",
    "                response,\n",
    "            )\n",
    "        else:\n",
    "            prompt = SCRATCH_PAD_JOINT_QUESTION_PROMPT.format(\n",
    "                human_prompt=human_joint_prompts[questions_to_eval[index][\"combo_index\"]],\n",
    "                question_1=questions_to_eval[index][\"combination_of\"][0][\"question\"],\n",
    "                question_2=questions_to_eval[index][\"combination_of\"][1][\"question\"],\n",
    "                background_1=questions_to_eval[index][\"combination_of\"][0][\"background\"]\n",
    "                + \"\\n\"\n",
    "                + questions_to_eval[index][\"combination_of\"][0][\"model_info_resolution_criteria\"],\n",
    "                background_2=questions_to_eval[index][\"combination_of\"][1][\"background\"]\n",
    "                + \"\\n\"\n",
    "                + questions_to_eval[index][\"combination_of\"][1][\"model_info_resolution_criteria\"],\n",
    "                resolution_criteria_1=questions_to_eval[index][\"combination_of\"][0][\n",
    "                    \"resolution_criteria\"\n",
    "                ],\n",
    "                resolution_criteria_2=questions_to_eval[index][\"combination_of\"][1][\n",
    "                    \"resolution_criteria\"\n",
    "                ],\n",
    "                freeze_datetime_1=questions_to_eval[index][\"combination_of\"][0][\"freeze_datetime\"],\n",
    "                freeze_datetime_2=questions_to_eval[index][\"combination_of\"][1][\"freeze_datetime\"],\n",
    "                freeze_datetime_value_1=questions_to_eval[index][\"combination_of\"][0][\n",
    "                    \"freeze_datetime_value\"\n",
    "                ],\n",
    "                freeze_datetime_value_2=questions_to_eval[index][\"combination_of\"][1][\n",
    "                    \"freeze_datetime_value\"\n",
    "                ],\n",
    "                freeze_datetime_value_explanation_1=questions_to_eval[index][\"combination_of\"][\n",
    "                    0\n",
    "                ][\"freeze_datetime_value_explanation\"],\n",
    "                freeze_datetime_value_explanation_2=questions_to_eval[index][\"combination_of\"][\n",
    "                    1\n",
    "                ][\"freeze_datetime_value_explanation\"],\n",
    "                list_of_resolution_dates=all_resolution_dates,\n",
    "            )\n",
    "\n",
    "            response = model_eval.get_response_from_model(\n",
    "                prompt=prompt,\n",
    "                max_tokens=2000,\n",
    "                model_name=models[model_name][\"full_name\"],\n",
    "                temperature=0,\n",
    "                wait_time=30,\n",
    "            )\n",
    "\n",
    "            save_dict[index] = (\n",
    "                model_eval.reformat_answers(\n",
    "                    response=response, prompt=prompt, question=questions_to_eval[index]\n",
    "                ),\n",
    "                response,\n",
    "            )\n",
    "\n",
    "    logger.info(f\"Answer {save_dict[index][0]}\")\n",
    "\n",
    "    return None\n",
    "\n",
    "\n",
    "def executor(max_workers, model_name, save_dict, questions_to_eval):\n",
    "    with ThreadPoolExecutor(max_workers=max_workers) as executor:\n",
    "        worker_with_args = partial(\n",
    "            worker, model_name=model_name, save_dict=save_dict, questions_to_eval=questions_to_eval\n",
    "        )\n",
    "        return list(executor.map(worker_with_args, range(len(questions_to_eval))))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "9477cbe8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run (or reload) scratchpad forecasts for every model on each question set.\n",
    "# Forecasts for a (question set, model) pair are cached in `{test_type}/{model}.jsonl`:\n",
    "# when that file loads, the model is skipped; otherwise it is queried through `executor`\n",
    "# and its answers are flattened into JSONL rows and written to that path.\n",
    "DEFAULT_MAX_WORKERS = 50\n",
    "ANTHROPIC_MAX_WORKERS = 30  # fewer concurrent requests for ANTHROPIC-sourced models\n",
    "# combo_index -> \"direction\" pair for combo rows; any other index maps to [-1, -1].\n",
    "COMBO_DIRECTIONS = {0: [1, 1], 1: [1, -1], 2: [-1, 1]}\n",
    "\n",
    "results = {}\n",
    "model_result_loaded = {}\n",
    "models_to_test = list(models.keys())\n",
    "\n",
    "for questions_to_eval in [single_acled_questions, single_non_acled_questions, combo_questions_unrolled]:\n",
    "    if not questions_to_eval:\n",
    "        # The set type is read from the first question; an empty set has nothing to run.\n",
    "        continue\n",
    "    first_question = questions_to_eval[0]\n",
    "    if first_question[\"source\"] != \"acled\":\n",
    "        test_type = \"scratchpad/non_acled\"\n",
    "    elif first_question[\"combination_of\"] == \"N/A\":\n",
    "        test_type = \"scratchpad/acled\"\n",
    "    else:\n",
    "        test_type = \"scratchpad/combo\"\n",
    "\n",
    "    # Reuse cached forecasts where possible; otherwise give each question an empty \"\"\n",
    "    # slot for `executor` to fill.\n",
    "    for model in models_to_test:\n",
    "        file_path = f\"{test_type}/{model}.jsonl\"\n",
    "        try:\n",
    "            results[model] = read_jsonl(file_path)\n",
    "            model_result_loaded[model] = True\n",
    "        except Exception as err:  # not a bare `except:`, so Ctrl-C still stops the cell\n",
    "            logger.info(f\"No cached results for {model} at {file_path} ({err!r}); querying model\")\n",
    "            results[model] = {i: \"\" for i in range(len(questions_to_eval))}\n",
    "            model_result_loaded[model] = False\n",
    "\n",
    "    for model in models_to_test:\n",
    "        if model_result_loaded[model]:\n",
    "            continue\n",
    "        file_path = f\"{test_type}/{model}.jsonl\"\n",
    "        executor_count = (\n",
    "            ANTHROPIC_MAX_WORKERS if models[model][\"source\"] == \"ANTHROPIC\" else DEFAULT_MAX_WORKERS\n",
    "        )\n",
    "        # Fills results[model] in place; each entry is read below as (forecast(s), reasoning).\n",
    "        executor(executor_count, model, results[model], questions_to_eval)\n",
    "\n",
    "        current_model_forecasts = []\n",
    "        for index, q in enumerate(questions_to_eval):\n",
    "            forecast_value = results[model][index][0]\n",
    "            reasoning = results[model][index][1]\n",
    "            if q[\"source\"] == \"acled\":\n",
    "                # One row per horizon; zip silently stops at the shorter of the two lists.\n",
    "                for forecast, horizon in zip(forecast_value, q[\"forecast_horizons\"]):\n",
    "                    current_forecast = {\n",
    "                        \"id\": q[\"id\"],\n",
    "                        \"source\": q[\"source\"],\n",
    "                        \"forecast\": forecast,\n",
    "                        \"horizon\": horizon,\n",
    "                        \"reasoning\": reasoning,\n",
    "                    }\n",
    "                    if q[\"combination_of\"] != \"N/A\":\n",
    "                        current_forecast[\"direction\"] = list(\n",
    "                            COMBO_DIRECTIONS.get(q[\"combo_index\"], [-1, -1])\n",
    "                        )\n",
    "                    current_model_forecasts.append(current_forecast)\n",
    "            else:\n",
    "                current_model_forecasts.append(\n",
    "                    {\n",
    "                        \"id\": q[\"id\"],\n",
    "                        \"source\": q[\"source\"],\n",
    "                        \"forecast\": forecast_value,\n",
    "                        \"reasoning\": reasoning,\n",
    "                    }\n",
    "                )\n",
    "\n",
    "        os.makedirs(os.path.dirname(file_path), exist_ok=True)\n",
    "        with open(file_path, \"w\") as file:\n",
    "            for entry in current_model_forecasts:\n",
    "                file.write(json.dumps(entry) + \"\\n\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "d2bf3ab8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Final forecast files for the scratchpad run. The helper is defined elsewhere in this\n",
    "# notebook and really is spelled `generage_...`; keep the call in sync with that name.\n",
    "generage_final_forecast_files(\n",
    "    deadline=\"2024-05-03\",\n",
    "    forecast_date=\"2024-05-03\",\n",
    "    prompt_type=\"scratchpad\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "5bc5f04c",
   "metadata": {},
   "outputs": [],
   "source": [
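    "# Disabled one-off fix: rewrites any 0 in a forecast's `direction` to -1 in the\n",
    "# final_submit files and saves the copies under direction_fix/. The forecast-writing\n",
    "# loop above already emits -1, so this is presumably only needed for files from an\n",
    "# earlier run that used 0. Uncomment to re-apply.\n",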
    "# prompt_type= \"scratchpad\"\n",
    "# directory= prompt_type + '/final_submit'\n",
    "# for model in models_to_test:\n",
    "#     if \"gpt\" in model:\n",
    "#         org = \"OpenAI\"\n",
    "#     elif \"llama\" in model:\n",
    "#         org = \"Meta\"\n",
    "#     elif \"mistral\" in model:\n",
    "#         org = \"Mistral AI\"\n",
    "#     elif \"claude\" in model:\n",
    "#         org = \"Anthropic\"\n",
    "#     elif \"qwen\" in model:\n",
    "#         org = \"Qwen\"\n",
    "\n",
    "#     file_path = f\"{directory}/2024-05-03.{org}.{model}_{prompt_type}.json\"\n",
    "\n",
    "#     with open(file_path, \"r\") as file:\n",
    "#         questions = json.load(file)\n",
    "            \n",
    "#     for q in questions['forecasts']:\n",
    "#         if 'direction' in q:\n",
    "#             q['direction'] = [-1 if item == 0 else item for item in q['direction']]\n",
    "          \n",
    "    \n",
    "#     new_file_path = f\"{prompt_type}/direction_fix/2024-05-03.{org}.{model}_{prompt_type}.json\"\n",
    "#     os.makedirs(os.path.dirname(new_file_path), exist_ok=True)\n",
    "\n",
    "#     with open(new_file_path, 'w') as json_file:\n",
    "#         json.dump(questions, json_file, indent=4)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "45413a4e",
   "metadata": {},
   "source": [
    "### Scratchpad + Retrieval"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "8e11c83a",
   "metadata": {},
   "outputs": [],
   "source": [
    "models = {\n",
    "    \"gpt_3p5_turbo_0125\": {\"source\": \"OAI\", \"full_name\": \"gpt-3.5-turbo-0125\"},\n",
    "    \"gpt_4_turbo_0409\": {\"source\": \"OAI\", \"full_name\": \"gpt-4-turbo-2024-04-09\"},\n",
    "    \"gpt_4_1106_preview\": {\"source\": \"OAI\", \"full_name\": \"gpt-4-1106-preview\"},\n",
    "    \"gpt_4_0125_preview\": {\"source\": \"OAI\", \"full_name\": \"gpt-4-0125-preview\"},\n",
    "    \"gpt_4o\": {\"source\": \"OAI\", \"full_name\": \"gpt-4o\"},\n",
    "    \"mistral_8x7b_instruct\": {\n",
    "        \"source\": \"TOGETHER\",\n",
    "        \"full_name\": \"mistralai/Mixtral-8x7B-Instruct-v0.1\",\n",
    "    },\n",
    "    \"mistral_8x22b_instruct\": {\n",
    "        \"source\": \"TOGETHER\",\n",
    "        \"full_name\": \"mistralai/Mixtral-8x22B-Instruct-v0.1\",\n",
    "    },\n",
    "    \"mistral_large\": {\n",
    "        \"source\": \"MISTRAL\",\n",
    "        \"full_name\": \"mistral-large-latest\",\n",
    "    },\n",
    "    \"qwen_1p5_110b\": {\n",
    "        \"source\": \"TOGETHER\",\n",
    "        \"full_name\": \"Qwen/Qwen1.5-110B-Chat\",\n",
    "    },\n",
    "    \"claude_2p1\": {\"source\": \"ANTHROPIC\", \"full_name\": \"claude-2.1\"},\n",
    "    \"claude_3_opus\": {\"source\": \"ANTHROPIC\", \"full_name\": \"claude-3-opus-20240229\"},\n",
    "    \"claude_3_sonnet\": {\"source\": \"ANTHROPIC\", \"full_name\": \"claude-3-sonnet-20240229\"},\n",
    "    \"claude_3_haiku\": {\"source\": \"ANTHROPIC\", \"full_name\": \"claude-3-haiku-20240307\"},\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "6500e5cd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mapping retrieved summaries back into questions\n",
    "for question_source in [single_non_acled_questions, single_acled_questions]:\n",
    "    for q in question_source:\n",
    "        reformatted_id = q[\"id\"].replace(\"/\", \"_\")\n",
    "        filename = f\"news/{reformatted_id}.pickle\"\n",
    "        with open(filename, \"rb\") as file:\n",
    "            retrieved_info = pickle.load(file)\n",
    "        q[\"news\"] = retrieved_info\n",
    "\n",
    "for q in combo_questions_unrolled:\n",
    "    for sub_q in q[\"combination_of\"]:\n",
    "        reformatted_id = sub_q[\"id\"].replace(\"/\", \"_\")\n",
    "        filename = f\"news/{reformatted_id}.pickle\"\n",
    "        with open(filename, \"rb\") as file:\n",
    "            retrieved_info = pickle.load(file)\n",
    "        sub_q[\"news\"] = retrieved_info"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "ca6b8741",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_all_retrieved_info(all_retrieved_info):\n",
    "    retrieved_info = \"\"\n",
    "    for summary in all_retrieved_info:\n",
    "        retrieved_info += f\"Article title: {summary['title']}\" + \"\\n\"\n",
    "        retrieved_info += f\"Summary: {summary['summary']}\" + \"\\n\\n\"\n",
    "    return retrieved_info\n",
    "\n",
    "\n",
    "def worker(index, model_name, save_dict, questions_to_eval):\n",
    "    if save_dict[index] != \"\":\n",
    "        return\n",
    "\n",
    "    logger.info(f\"Starting {model_name} - {index}\")\n",
    "\n",
    "    if questions_to_eval[index][\"source\"] != \"acled\":\n",
    "        prompt = SCRATCH_PAD_WITH_SUMMARIES_MARKET_PROMPT.format(\n",
    "            question=questions_to_eval[index][\"question\"],\n",
    "            background=questions_to_eval[index][\"background\"]\n",
    "            + \"\\n\"\n",
    "            + questions_to_eval[index][\"model_info_resolution_criteria\"],\n",
    "            resolution_criteria=questions_to_eval[index][\"resolution_criteria\"],\n",
    "            close_date=questions_to_eval[index][\"model_info_close_datetime\"],\n",
    "            retrieved_info=get_all_retrieved_info(questions_to_eval[index][\"news\"]),\n",
    "        )\n",
    "        response = model_eval.get_response_from_model(\n",
    "            prompt=prompt,\n",
    "            max_tokens=2000,\n",
    "            model_name=models[model_name][\"full_name\"],\n",
    "            temperature=0,\n",
    "            wait_time=30,\n",
    "        )\n",
    "\n",
    "        save_dict[index] = (model_eval.reformat_answers(response=response, single=True), response)\n",
    "\n",
    "    else:\n",
    "        all_resolution_dates = []\n",
    "        for horizon in questions_to_eval[index][\"forecast_horizons\"]:\n",
    "            resolution_date = datetime.fromisoformat(\n",
    "                questions_to_eval[index][\"freeze_datetime\"]\n",
    "            ) + timedelta(days=7 + horizon)\n",
    "            resolution_date = resolution_date.isoformat()\n",
    "            all_resolution_dates.append(resolution_date)\n",
    "\n",
    "        if questions_to_eval[index][\"combination_of\"] == \"N/A\":\n",
    "            prompt = SCRATCH_PAD_WITH_SUMMARIES_NON_MARKET_PROMPT.format(\n",
    "                question=questions_to_eval[index][\"question\"],\n",
    "                background=questions_to_eval[index][\"background\"]\n",
    "                + \"\\n\"\n",
    "                + questions_to_eval[index][\"model_info_resolution_criteria\"],\n",
    "                resolution_criteria=questions_to_eval[index][\"resolution_criteria\"],\n",
    "                freeze_datetime=questions_to_eval[index][\"freeze_datetime\"],\n",
    "                freeze_datetime_value=questions_to_eval[index][\"freeze_datetime_value\"],\n",
    "                freeze_datetime_value_explanation=questions_to_eval[index][\n",
    "                    \"freeze_datetime_value_explanation\"\n",
    "                ],\n",
    "                retrieved_info=get_all_retrieved_info(questions_to_eval[index][\"news\"]),\n",
    "                list_of_resolution_dates=all_resolution_dates,\n",
    "            )\n",
    "            response = model_eval.get_response_from_model(\n",
    "                prompt=prompt,\n",
    "                max_tokens=2000,\n",
    "                model_name=models[model_name][\"full_name\"],\n",
    "                temperature=0,\n",
    "                wait_time=30,\n",
    "            )\n",
    "            save_dict[index] = (\n",
    "                model_eval.reformat_answers(\n",
    "                    response=response, prompt=prompt, question=questions_to_eval[index]\n",
    "                ),\n",
    "                response,\n",
    "            )\n",
    "        else:\n",
    "            prompt = SCRATCH_PAD_WITH_SUMMARIES_JOINT_QUESTION_PROMPT.format(\n",
    "                human_prompt=human_joint_prompts[questions_to_eval[index][\"combo_index\"]],\n",
    "                question_1=questions_to_eval[index][\"combination_of\"][0][\"question\"],\n",
    "                question_2=questions_to_eval[index][\"combination_of\"][1][\"question\"],\n",
    "                background_1=questions_to_eval[index][\"combination_of\"][0][\"background\"]\n",
    "                + \"\\n\"\n",
    "                + questions_to_eval[index][\"combination_of\"][0][\"model_info_resolution_criteria\"],\n",
    "                background_2=questions_to_eval[index][\"combination_of\"][1][\"background\"]\n",
    "                + \"\\n\"\n",
    "                + questions_to_eval[index][\"combination_of\"][1][\"model_info_resolution_criteria\"],\n",
    "                resolution_criteria_1=questions_to_eval[index][\"combination_of\"][0][\n",
    "                    \"resolution_criteria\"\n",
    "                ],\n",
    "                resolution_criteria_2=questions_to_eval[index][\"combination_of\"][1][\n",
    "                    \"resolution_criteria\"\n",
    "                ],\n",
    "                freeze_datetime_1=questions_to_eval[index][\"combination_of\"][0][\"freeze_datetime\"],\n",
    "                freeze_datetime_2=questions_to_eval[index][\"combination_of\"][1][\"freeze_datetime\"],\n",
    "                freeze_datetime_value_1=questions_to_eval[index][\"combination_of\"][0][\n",
    "                    \"freeze_datetime_value\"\n",
    "                ],\n",
    "                freeze_datetime_value_2=questions_to_eval[index][\"combination_of\"][1][\n",
    "                    \"freeze_datetime_value\"\n",
    "                ],\n",
    "                freeze_datetime_value_explanation_1=questions_to_eval[index][\"combination_of\"][\n",
    "                    0\n",
    "                ][\"freeze_datetime_value_explanation\"],\n",
    "                freeze_datetime_value_explanation_2=questions_to_eval[index][\"combination_of\"][\n",
    "                    1\n",
    "                ][\"freeze_datetime_value_explanation\"],\n",
    "                retrieved_info_1=get_all_retrieved_info(\n",
    "                    questions_to_eval[index][\"combination_of\"][0][\"news\"]\n",
    "                ),\n",
    "                retrieved_info_2=get_all_retrieved_info(\n",
    "                    questions_to_eval[index][\"combination_of\"][1][\"news\"]\n",
    "                ),\n",
    "                list_of_resolution_dates=all_resolution_dates,\n",
    "            )\n",
    "\n",
    "            response = model_eval.get_response_from_model(\n",
    "                prompt=prompt,\n",
    "                max_tokens=2000,\n",
    "                model_name=models[model_name][\"full_name\"],\n",
    "                temperature=0,\n",
    "                wait_time=30,\n",
    "            )\n",
    "\n",
    "            save_dict[index] = (\n",
    "                model_eval.reformat_answers(\n",
    "                    response=response, prompt=prompt, question=questions_to_eval[index]\n",
    "                ),\n",
    "                response,\n",
    "            )\n",
    "\n",
    "    logger.info(f\"Answer: {save_dict[index][0]}\")\n",
    "\n",
    "    return None\n",
    "\n",
    "\n",
    "def executor(max_workers, model_name, save_dict, questions_to_eval):\n",
    "    with ThreadPoolExecutor(max_workers=max_workers) as executor:\n",
    "        worker_with_args = partial(\n",
    "            worker, model_name=model_name, save_dict=save_dict, questions_to_eval=questions_to_eval\n",
    "        )\n",
    "        return list(executor.map(worker_with_args, range(len(questions_to_eval))))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "d9fe9f6f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run (or reload) scratchpad + retrieval forecasts for every model on each question set.\n",
    "# Forecasts for a (question set, model) pair are cached in `{test_type}/{model}.jsonl`:\n",
    "# when that file loads, the model is skipped; otherwise it is queried through `executor`\n",
    "# and its answers are flattened into JSONL rows and written to that path.\n",
    "DEFAULT_MAX_WORKERS = 50\n",
    "ANTHROPIC_MAX_WORKERS = 30  # fewer concurrent requests for ANTHROPIC-sourced models\n",
    "# combo_index -> \"direction\" pair for combo rows; any other index maps to [-1, -1].\n",
    "COMBO_DIRECTIONS = {0: [1, 1], 1: [1, -1], 2: [-1, 1]}\n",
    "\n",
    "results = {}\n",
    "model_result_loaded = {}\n",
    "models_to_test = list(models.keys())\n",
    "\n",
    "for questions_to_eval in [single_acled_questions, combo_questions_unrolled, single_non_acled_questions]:\n",
    "    if not questions_to_eval:\n",
    "        # The set type is read from the first question; an empty set has nothing to run.\n",
    "        continue\n",
    "    first_question = questions_to_eval[0]\n",
    "    if first_question[\"source\"] != \"acled\":\n",
    "        test_type = \"scratchpad_with_info_retrieval/non_acled\"\n",
    "    elif first_question[\"combination_of\"] == \"N/A\":\n",
    "        test_type = \"scratchpad_with_info_retrieval/acled\"\n",
    "    else:\n",
    "        test_type = \"scratchpad_with_info_retrieval/combo\"\n",
    "\n",
    "    # Reuse cached forecasts where possible; otherwise give each question an empty \"\"\n",
    "    # slot for `executor` to fill.\n",
    "    for model in models_to_test:\n",
    "        file_path = f\"{test_type}/{model}.jsonl\"\n",
    "        try:\n",
    "            results[model] = read_jsonl(file_path)\n",
    "            model_result_loaded[model] = True\n",
    "        except Exception as err:  # not a bare `except:`, so Ctrl-C still stops the cell\n",
    "            logger.info(f\"No cached results for {model} at {file_path} ({err!r}); querying model\")\n",
    "            results[model] = {i: \"\" for i in range(len(questions_to_eval))}\n",
    "            model_result_loaded[model] = False\n",
    "\n",
    "    for model in models_to_test:\n",
    "        if model_result_loaded[model]:\n",
    "            continue\n",
    "        file_path = f\"{test_type}/{model}.jsonl\"\n",
    "        executor_count = (\n",
    "            ANTHROPIC_MAX_WORKERS if models[model][\"source\"] == \"ANTHROPIC\" else DEFAULT_MAX_WORKERS\n",
    "        )\n",
    "        # Fills results[model] in place; each entry is read below as (forecast(s), reasoning).\n",
    "        executor(executor_count, model, results[model], questions_to_eval)\n",
    "\n",
    "        current_model_forecasts = []\n",
    "        for index, q in enumerate(questions_to_eval):\n",
    "            forecast_value = results[model][index][0]\n",
    "            reasoning = results[model][index][1]\n",
    "            if q[\"source\"] == \"acled\":\n",
    "                # One row per horizon; zip silently stops at the shorter of the two lists.\n",
    "                for forecast, horizon in zip(forecast_value, q[\"forecast_horizons\"]):\n",
    "                    current_forecast = {\n",
    "                        \"id\": q[\"id\"],\n",
    "                        \"source\": q[\"source\"],\n",
    "                        \"forecast\": forecast,\n",
    "                        \"horizon\": horizon,\n",
    "                        \"reasoning\": reasoning,\n",
    "                    }\n",
    "                    if q[\"combination_of\"] != \"N/A\":\n",
    "                        current_forecast[\"direction\"] = list(\n",
    "                            COMBO_DIRECTIONS.get(q[\"combo_index\"], [-1, -1])\n",
    "                        )\n",
    "                    current_model_forecasts.append(current_forecast)\n",
    "            else:\n",
    "                current_model_forecasts.append(\n",
    "                    {\n",
    "                        \"id\": q[\"id\"],\n",
    "                        \"source\": q[\"source\"],\n",
    "                        \"forecast\": forecast_value,\n",
    "                        \"reasoning\": reasoning,\n",
    "                    }\n",
    "                )\n",
    "\n",
    "        os.makedirs(os.path.dirname(file_path), exist_ok=True)\n",
    "        with open(file_path, \"w\") as file:\n",
    "            for entry in current_model_forecasts:\n",
    "                file.write(json.dumps(entry) + \"\\n\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "08cb2be0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Final forecast files for the scratchpad + retrieval run. The helper is defined elsewhere\n",
    "# in this notebook and really is spelled `generage_...`; keep the call in sync with it.\n",
    "generage_final_forecast_files(\n",
    "    deadline=\"2024-05-03\",\n",
    "    forecast_date=\"2024-05-03\",\n",
    "    prompt_type=\"scratchpad_with_info_retrieval\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "9cea8357",
   "metadata": {},
   "outputs": [],
   "source": [
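    "# Disabled one-off fix: rewrites any 0 in a forecast's `direction` to -1 in the\n",
    "# final_submit files and saves the copies under direction_fix/. The forecast-writing\n",
    "# loop above already emits -1, so this is presumably only needed for files from an\n",
    "# earlier run that used 0. Uncomment to re-apply.\n",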
    "# prompt_type= \"scratchpad_with_info_retrieval\"\n",
    "# directory= prompt_type + '/final_submit'\n",
    "# for model in models_to_test:\n",
    "#     if \"gpt\" in model:\n",
    "#         org = \"OpenAI\"\n",
    "#     elif \"llama\" in model:\n",
    "#         org = \"Meta\"\n",
    "#     elif \"mistral\" in model:\n",
    "#         org = \"Mistral AI\"\n",
    "#     elif \"claude\" in model:\n",
    "#         org = \"Anthropic\"\n",
    "#     elif \"qwen\" in model:\n",
    "#         org = \"Qwen\"\n",
    "\n",
    "#     file_path = f\"{directory}/2024-05-03.{org}.{model}_{prompt_type}.json\"\n",
    "\n",
    "#     with open(file_path, \"r\") as file:\n",
    "#         questions = json.load(file)\n",
    "            \n",
    "#     for q in questions['forecasts']:\n",
    "#         if 'direction' in q:\n",
    "#             q['direction'] = [-1 if item == 0 else item for item in q['direction']]\n",
    "\n",
    "#     new_file_path = f\"{prompt_type}/direction_fix/2024-05-03.{org}.{model}_{prompt_type}.json\"\n",
    "#     os.makedirs(os.path.dirname(new_file_path), exist_ok=True)\n",
    "\n",
    "#     with open(new_file_path, 'w') as json_file:\n",
    "#         json.dump(questions, json_file, indent=4)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.12 (myenv-llm-bench)",
   "language": "python",
   "name": "myenv-llm-bench"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
