{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from llms import LLM, OpenAILLM, OpenAIConfig\n",
    "import cause_net.cause_net_tasks\n",
    "\n",
    "from cause_net.cause_net_tasks import (\n",
    "    GraphEstimationPrompt,\n",
    "    GraphEstimationResponse,\n",
    "    GraphEstimationTask,\n",
    "    ChainEstimationResponse,\n",
    ")\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import random\n",
    "import os\n",
    "import pickle\n",
    "from datetime import datetime\n",
    "import json\n",
    "from tqdm import tqdm\n",
    "from threading import Thread"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Placeholder locations: point these at the real data before running.\n",
    "CAUSE_NET_GRAPH_PATH = \"path/to/cause_net\"\n",
    "NARRATIVE_DIR = \"path/to/narrative_directory\"\n",
    "\n",
    "# One path per regular file in NARRATIVE_DIR. os.listdir returns names in\n",
    "# arbitrary order, so they are sorted to keep runs reproducible; directories\n",
    "# are dropped because every path is later opened as a narrative file.\n",
    "NARRATIVE_PATHS = [\n",
    "    candidate\n",
    "    for candidate in (\n",
    "        os.path.join(NARRATIVE_DIR, fn) for fn in sorted(os.listdir(NARRATIVE_DIR))\n",
    "    )\n",
    "    if os.path.isfile(candidate)\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rebuild NARRATIVE_PATHS from NARRATIVE_DIR; this is the value the following\n",
    "# cells use. Names are sorted because os.listdir gives no ordering guarantee,\n",
    "# and only regular files are kept since each path is later opened for reading.\n",
    "NARRATIVE_PATHS = [\n",
    "    candidate\n",
    "    for candidate in (\n",
    "        os.path.join(NARRATIVE_DIR, fn) for fn in sorted(os.listdir(NARRATIVE_DIR))\n",
    "    )\n",
    "    if os.path.isfile(candidate)\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create the configuration object first, then hand it to the OpenAI-backed LLM.\n",
    "llm_config = OpenAIConfig(max_tokens=100)\n",
    "llm = OpenAILLM(config=llm_config)\n",
    "\n",
    "# Model name and worker count are assigned on the LLM's config after construction.\n",
    "llm.config.model_name = \"gpt-4o\"\n",
    "llm.config.max_workers = 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_max_narrative_len(narrative_path: str) -> int:\n",
    "    \"\"\"Return the largest ``len(record[\"nodes\"])`` found in a JSON Lines file.\n",
    "\n",
    "    Args:\n",
    "        narrative_path: Path of a text file holding one JSON object per line;\n",
    "            every object must provide a \"nodes\" entry that supports ``len``.\n",
    "\n",
    "    Returns:\n",
    "        The maximum node count over all records, or 0 when the file holds no\n",
    "        records.\n",
    "\n",
    "    Raises:\n",
    "        json.JSONDecodeError: If a non-blank line is not valid JSON.\n",
    "        KeyError: If a JSON object has no \"nodes\" key.\n",
    "    \"\"\"\n",
    "    # Read as UTF-8 explicitly instead of relying on the locale default encoding.\n",
    "    with open(narrative_path, \"r\", encoding=\"utf-8\") as file:\n",
    "        return max(\n",
    "            # Blank lines carry no record and would make json.loads raise,\n",
    "            # so they are skipped.\n",
    "            (len(json.loads(line)[\"nodes\"]) for line in file if line.strip()),\n",
    "            default=0,\n",
    "        )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _make_task(narrative_file: str) -> GraphEstimationTask:\n",
    "    \"\"\"Create the graph-estimation task for a single narrative file.\n",
    "\n",
    "    The value returned by get_max_narrative_len for that file is passed as\n",
    "    ``min_chain_length``.\n",
    "    \"\"\"\n",
    "    return GraphEstimationTask(\n",
    "        graph_path=CAUSE_NET_GRAPH_PATH,\n",
    "        llm=llm,\n",
    "        narrative_path=narrative_file,\n",
    "        min_chain_length=get_max_narrative_len(narrative_file),\n",
    "    )\n",
    "\n",
    "\n",
    "# One task per narrative file, in the order of NARRATIVE_PATHS.\n",
    "tasks = list(map(_make_task, NARRATIVE_PATHS))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Gather the prompt data generated by every task into a single list.\n",
    "# The loop variable keeps the name `task`: it stays bound to the last element\n",
    "# of `tasks` once the loop has finished.\n",
    "prompt_list = []\n",
    "for task in tasks:\n",
    "    prompt_list.extend(task.generate_prompt_data())\n",
    "\n",
    "# Last expression of the cell: displays how many prompts were collected.\n",
    "len(prompt_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Send the whole prompt batch through one task instance. The last element of\n",
    "# `tasks` is named explicitly instead of reusing a loop variable left over from\n",
    "# another cell, so this cell does not depend on that hidden kernel state.\n",
    "# NOTE(review): prompts built by every task are answered via this single task;\n",
    "# confirm that prompts_to_response does not depend on per-task state.\n",
    "responses = tasks[-1].prompts_to_response(prompt_data=prompt_list, show_progress=True)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
