{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from real_world_graphs.llms import OpenAILLM, OpenAIConfig\n",
    "\n",
    "from real_world_graphs.cause_net_tasks import GraphEstimationTask\n",
    "import os\n",
    "import json"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "NARRATIVE_DIR = \"path/to/narrative_directory\"\n",
    "NARRATIVE_PATHS = [os.path.join(NARRATIVE_DIR, fn) for fn in os.listdir(NARRATIVE_DIR)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "llm = OpenAILLM(config=OpenAIConfig(max_tokens=100))\n",
    "llm.config.model_name = \"gpt-4o\"\n",
    "llm.config.max_workers = 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_max_narrative_len(narrative_path: str) -> int:\n",
    "    max_len = 0\n",
    "    with open(narrative_path, \"r\") as file:\n",
    "        for line in file:\n",
    "            narrative_data = json.loads(line)\n",
    "            max_len = max(max_len, len(narrative_data[\"nodes\"]))\n",
    "    return max_len"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "tasks = [\n",
    "    GraphEstimationTask(\n",
    "        graph_path=None,\n",
    "        llm=llm,\n",
    "        narrative_path=path,\n",
    "        min_chain_length=get_max_narrative_len(path),\n",
    "    )\n",
    "    for path in NARRATIVE_PATHS\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "prompt_list = []\n",
    "for task in tasks:\n",
    "    prompt_list += task.generate_prompt_data()\n",
    "len(prompt_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "responses = task.prompts_to_response(prompt_data=prompt_list, show_progress=True)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "llm",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
