{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import re\n",
    "import tempfile\n",
    "from monty.tempfile import ScratchDir\n",
    "\n",
    "from dotenv import load_dotenv\n",
    "from langchain import hub\n",
    "from langchain.agents import load_tools\n",
    "from langchain_experimental.tools import PythonREPLTool\n",
    "from langchain.agents import create_react_agent, AgentExecutor\n",
    "from langchain.chains.conversation.memory import ConversationBufferWindowMemory\n",
    "from langchain_openai import ChatOpenAI\n",
    "from langchain_community.chat_models import ChatOllama\n",
    "from langchain.tools import ArxivQueryRun, WikipediaQueryRun\n",
    "from langchain.utilities import ArxivAPIWrapper, WikipediaAPIWrapper\n",
    "\n",
    "from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler\n",
    "\n",
    "from llamp.mp.agents import (\n",
    "    MPSummaryExpert,\n",
    "    MPThermoExpert,\n",
    "    MPElasticityExpert,\n",
    "    MPDielectricExpert,\n",
    "    MPPiezoelectricExpert,\n",
    "    MPMagnetismExpert,\n",
    "    MPElectronicExpert,\n",
    "    MPSynthesisExpert,\n",
    "    MPStructureRetriever,\n",
    ")\n",
    "from llamp.atomate2.agents import MLFFMDAgent\n",
    "\n",
    "load_dotenv()\n",
    "\n",
    "MP_API_KEY = os.getenv(\"MP_API_KEY\", None)\n",
    "OPENAI_API_KEY = os.getenv(\"OPENAI_API_KEY\", None)\n",
    "OPENAI_ORGANIZATION = os.getenv(\"OPENAI_ORGANIZATION\", None)\n",
    "\n",
    "OPENAI_GPT_MODEL = \"gpt-4-turbo\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/elvis/anaconda3/envs/llamp_test/lib/python3.11/site-packages/langchain/hub.py:86: DeprecationWarning: The `langchainhub sdk` is deprecated.\n",
      "Please use the `langsmith sdk` instead:\n",
      "  pip install langsmith\n",
      "Use the `pull_prompt` method.\n",
      "  res_dict = client.pull_repo(owner_repo_commit)\n"
     ]
    }
   ],
   "source": [
    "top_llm = ChatOpenAI(\n",
    "    temperature=0.1,\n",
    "    model=OPENAI_GPT_MODEL,\n",
    "    openai_api_key=OPENAI_API_KEY,\n",
    "    openai_organization=OPENAI_ORGANIZATION,\n",
    "    streaming=True,\n",
    "    callbacks=[StreamingStdOutCallbackHandler()],\n",
    ")\n",
    "\n",
    "bottom_callback_handler = StreamingStdOutCallbackHandler()\n",
    "\n",
    "bottom_llm = ChatOpenAI(\n",
    "    temperature=0.0,\n",
    "    model=OPENAI_GPT_MODEL,\n",
    "    openai_api_key=OPENAI_API_KEY,\n",
    "    openai_organization=OPENAI_ORGANIZATION,\n",
    "    max_retries=5,\n",
    "    streaming=True,\n",
    "    callbacks=[bottom_callback_handler],\n",
    ")\n",
    "\n",
    "\n",
    "wikipedia = WikipediaQueryRun(api_wrapper=WikipediaAPIWrapper())\n",
    "arxiv = ArxivQueryRun(api_wrapper=ArxivAPIWrapper())\n",
    "\n",
    "tools = load_tools([\"llm-math\"], llm=bottom_llm)\n",
    "tools += [PythonREPLTool(\n",
    "    description=re.sub(\n",
    "        r\"\\s+\",\n",
    "        \" \",\n",
    "        \"\"\"A Python shell. Use this to execute python commands.\n",
    "        Input should be a valid python command.\n",
    "        \n",
    "        TIPS:\n",
    "        - If you want to see the output of a value, you should print it out\n",
    "        with `print(...)`.\n",
    "        - If you write any file to a local file system, you should ALWAYS \n",
    "        print the paths to the files as well.\n",
    "        - If you plot any graph, you should show the graph, save it to local \n",
    "        and print the path to the image file.\n",
    "        - Combine all your output string together and print once together in\n",
    "        JSON format at the end of the code.\n",
    "        \"\"\"\n",
    "    )\n",
    ")]\n",
    "tools += [\n",
    "    MPThermoExpert(llm=bottom_llm, mp_api_key=MP_API_KEY).as_tool(\n",
    "        agent_kwargs=dict(return_intermediate_steps=False)\n",
    "    ),\n",
    "    MPElasticityExpert(llm=bottom_llm, mp_api_key=MP_API_KEY).as_tool(\n",
    "        agent_kwargs=dict(return_intermediate_steps=False)\n",
    "    ),\n",
    "    MPDielectricExpert(llm=bottom_llm, mp_api_key=MP_API_KEY).as_tool(\n",
    "        agent_kwargs=dict(return_intermediate_steps=False)\n",
    "    ),\n",
    "    MPMagnetismExpert(llm=bottom_llm, mp_api_key=MP_API_KEY).as_tool(\n",
    "        agent_kwargs=dict(return_intermediate_steps=False)\n",
    "    ),\n",
    "    MPElectronicExpert(llm=bottom_llm, mp_api_key=MP_API_KEY).as_tool(\n",
    "        agent_kwargs=dict(return_intermediate_steps=False)\n",
    "    ),\n",
    "    MPPiezoelectricExpert(llm=bottom_llm, mp_api_key=MP_API_KEY).as_tool(\n",
    "        agent_kwargs=dict(return_intermediate_steps=False)\n",
    "    ),\n",
    "    MPSummaryExpert(llm=bottom_llm, mp_api_key=MP_API_KEY).as_tool(\n",
    "        agent_kwargs=dict(return_intermediate_steps=False)\n",
    "    ),\n",
    "    MPSynthesisExpert(llm=bottom_llm, mp_api_key=MP_API_KEY).as_tool(\n",
    "        agent_kwargs=dict(return_intermediate_steps=False)\n",
    "    ),\n",
    "    MPStructureRetriever(llm=bottom_llm, mp_api_key=MP_API_KEY).as_tool(\n",
    "        agent_kwargs=dict(return_intermediate_steps=False)\n",
    "    ),\n",
    "    MLFFMDAgent(llm=bottom_llm).as_tool(\n",
    "        agent_kwargs=dict(return_intermediate_steps=False)\n",
    "    ),\n",
    "    arxiv,\n",
    "    wikipedia,\n",
    "]\n",
    "\n",
    "instructions = re.sub(\n",
    "    r\"\\s+\",\n",
    "    \" \",\n",
    "    \"\"\"You name is LLaMP and you are a helpful agent that can consult \n",
    "    materials-related data through Materials Project (MP) database, arXiv, \n",
    "    Wikipedia, and run calculations/simulations via a python REPL and MP atomate2 \n",
    "    computational workflow.     \n",
    "        \n",
    "    If you get an error, debug your code and try again. Only use \n",
    "    the output of your code to answer the question. Ask user to clarify their \n",
    "    queries if needed. Please note that you don't have direct control over MP but \n",
    "    through multiple assistant agents to help you. You need to provide complete \n",
    "    context in the input for assistants to do their job. REFINE and CRITQUE the \n",
    "    output of the assistants if needed. You can end the conversation by saying \"Bye\".\n",
    "    \"\"\",\n",
    "    )\n",
    "base_prompt = hub.pull(\"langchain-ai/react-agent-template\")\n",
    "prompt = base_prompt.partial(instructions=instructions)\n",
    "\n",
    "conversational_memory = ConversationBufferWindowMemory(\n",
    "    memory_key=\"chat_history\", k=5, return_messages=True\n",
    ")\n",
    "\n",
    "agent = create_react_agent(top_llm, tools, prompt)\n",
    "agent_executor = AgentExecutor(\n",
    "    agent=agent, tools=tools, verbose=True,\n",
    "    handle_parsing_errors=True,\n",
    "    memory=conversational_memory,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
      "```\n",
      "Thought: Do I need to use a tool? Yes\n",
      "Action: MPSummaryExpert\n",
      "Action Input: Please provide the mass density and volume per atom of stable NaCl.\u001b[32;1m\u001b[1;3m```\n",
      "Thought: Do I need to use a tool? Yes\n",
      "Action: MPSummaryExpert\n",
      "Action Input: Please provide the mass density and volume per atom of stable NaCl.\u001b[0m\n",
      "\n",
      "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
      "Thought: I need to search for stable NaCl and retrieve its mass density and volume per atom.\n",
      "\n",
      "Action:\n",
      "```\n",
      "{\n",
      "  \"action\": \"search_materials_summary__get\",\n",
      "  \"action_input\": {\n",
      "    \"formula\": \"NaCl\",\n",
      "    \"is_stable\": true,\n",
      "    \"fields\": \"density,density_atomic\"\n",
      "  }\n",
      "}\n",
      "```\n",
      "\u001b[32;1m\u001b[1;3mThought: I need to search for stable NaCl and retrieve its mass density and volume per atom.\n",
      "\n",
      "Action:\n",
      "```\n",
      "{\n",
      "  \"action\": \"search_materials_summary__get\",\n",
      "  \"action_input\": {\n",
      "    \"formula\": \"NaCl\",\n",
      "    \"is_stable\": true,\n",
      "    \"fields\": \"density,density_atomic\"\n",
      "  }\n",
      "}\n",
      "```\n",
      "\u001b[0m"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Retrieving SummaryDoc documents: 100%|██████████| 1/1 [00:00<00:00, 16644.06it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[36;1m\u001b[1;3m[{'formula_pretty': 'NaCl', 'density': 2.224545043829438, 'density_atomic': 21.812662732751388, 'material_id': 'mp-22862'}]\u001b[0mAction:\n",
      "```\n",
      "{\n",
      "  \"action\": \"Final Answer\",\n",
      "  \"action_input\": {\n",
      "    \"mass_density\": 2.224545043829438,\n",
      "    \"volume_per_atom\": 21.812662732751388\n",
      "  }\n",
      "}\n",
      "```\n",
      "\u001b[32;1m\u001b[1;3mAction:\n",
      "```\n",
      "{\n",
      "  \"action\": \"Final Answer\",\n",
      "  \"action_input\": {\n",
      "    \"mass_density\": 2.224545043829438,\n",
      "    \"volume_per_atom\": 21.812662732751388\n",
      "  }\n",
      "}\n",
      "```\n",
      "\u001b[0m\n",
      "\n",
      "\u001b[1m> Finished chain.\u001b[0m\n",
      "\u001b[38;5;200m\u001b[1;3m{'input': 'Please provide the mass density and volume per atom of stable NaCl.', 'output': {'mass_density': 2.224545043829438, 'volume_per_atom': 21.812662732751388}}\u001b[0mDo I need to use a tool? No  \n",
      "Final Answer: The mass density of stable NaCl is approximately 2.22 g/cm³, and the volume per atom is approximately 21.81 Å³/atom.\u001b[32;1m\u001b[1;3mDo I need to use a tool? No  \n",
      "Final Answer: The mass density of stable NaCl is approximately 2.22 g/cm³, and the volume per atom is approximately 21.81 Å³/atom.\u001b[0m\n",
      "\n",
      "\u001b[1m> Finished chain.\u001b[0m\n"
     ]
    }
   ],
   "source": [
    "response = agent_executor.invoke(\n",
    "    {\n",
    "        \"input\": \"Please give me the mass density (in g/cm^3) and volume per atom (in Angstrom^3/atom) of stable NaCl.\",\n",
    "    }\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The mass density of stable NaCl is approximately 2.22 g/cm³, and the volume per atom is approximately 21.81 Å³/atom.\n"
     ]
    }
   ],
   "source": [
    "print(response[\"output\"])"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "llamp",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
