{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "56ee2cf1-ea9c-4bb2-a899-0fb9378a0a3b",
   "metadata": {},
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "36bdb5d6-1477-4fc8-a3a1-c351ad751cfd",
   "metadata": {},
   "source": [
    "import os\n",
    "\n",
    "import cohere\n",
    "import logfire\n",
    "from anthropic import Anthropic\n",
    "from dotenv import load_dotenv\n",
    "from openai import OpenAI"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "7713977a-5a02-47d3-92f3-ee8023a7c61a",
   "metadata": {},
   "source": [
    "from function_calling_pi import register_function\n",
    "from function_calling_pi.agent_pipeline import (\n",
    "    AgentPipeline,\n",
    "    AnthropicLLM,\n",
    "    CohereLLM,\n",
    "    GoogleLLM,\n",
    "    InitQuery,\n",
    "    OpenAILLM,\n",
    "    PromptingLLM,\n",
    "    SystemMessage,\n",
    "    ToolsExecutionLoop,\n",
    "    ToolsExecutor,\n",
    ")\n",
    "from function_calling_pi.functions_engine.functions_engine import FUNCTIONS_DOCS\n",
    "from function_calling_pi.logging import OutputLogger"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "261b1047-d5c5-4f54-be7f-32cf7be52e61",
   "metadata": {},
   "source": [
    "# Load environment variables from ../.env; later cells read API keys from os.environ.\n",
    "load_dotenv(dotenv_path=\"../.env\")"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "2ee1f840-062e-4378-97f3-7de20b598b26",
   "metadata": {},
   "source": [
    "# Configure Logfire for the \"pi-benchmark\" project under the service name \"v1\".\n",
    "logfire.configure(\n",
    "    project_name=\"pi-benchmark\",\n",
    "    service_name=\"v1\",\n",
    ")"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": 76,
   "id": "0a07bc98-1e09-45a7-b5ea-acb901363ae1",
   "metadata": {},
   "source": [
    "# Backend used by the rest of the notebook. Supported values:\n",
    "# \"openai\", \"anthropic\", \"together\", \"cohere\", \"prompting\", \"google\".\n",
    "provider = \"google\"\n",
    "\n",
    "# Required keys are read with os.environ[...] so a missing variable raises\n",
    "# KeyError right here. This matters for the Together-backed branches:\n",
    "# OpenAI(api_key=None) falls back to OPENAI_API_KEY, which would then be sent\n",
    "# to api.together.xyz instead of failing.\n",
    "if provider == \"openai\":\n",
    "    client = OpenAI(\n",
    "        api_key=os.environ[\"OPENAI_API_KEY\"],\n",
    "        organization=os.getenv(\"OPENAI_ORG\"),  # optional; None when unset\n",
    "    )\n",
    "    # model = \"gpt-3.5-turbo-0125\"\n",
    "    model = \"gpt-4o-2024-05-13\"\n",
    "    # model = \"gpt-4-turbo-2024-04-09\"\n",
    "    llm = OpenAILLM(client, model)\n",
    "elif provider == \"anthropic\":\n",
    "    client = Anthropic(api_key=os.environ[\"ANTHROPIC_API_KEY\"])\n",
    "    model = \"claude-3-opus-20240229\"\n",
    "    llm = AnthropicLLM(client, model)\n",
    "elif provider == \"together\":\n",
    "    # Together exposes an OpenAI-compatible API, so the OpenAI client is reused.\n",
    "    client = OpenAI(\n",
    "        api_key=os.environ[\"TOGETHER_API_KEY\"],\n",
    "        base_url=\"https://api.together.xyz/v1\",\n",
    "    )\n",
    "    model = \"mistralai/Mixtral-8x7B-Instruct-v0.1\"\n",
    "    llm = OpenAILLM(client, model)\n",
    "elif provider == \"cohere\":\n",
    "    client = cohere.Client(api_key=os.environ[\"COHERE_API_KEY\"])\n",
    "    model = \"command-r-plus\"\n",
    "    llm = CohereLLM(client, model)\n",
    "elif provider == \"prompting\":\n",
    "    # Same Together endpoint, but wrapped in PromptingLLM instead of OpenAILLM.\n",
    "    client = OpenAI(\n",
    "        api_key=os.environ[\"TOGETHER_API_KEY\"],\n",
    "        base_url=\"https://api.together.xyz/v1\",\n",
    "    )\n",
    "    model = \"meta-llama/Llama-3-70b-chat-hf\"\n",
    "    llm = PromptingLLM(client, model)\n",
    "    logfire.instrument_openai(client)\n",
    "elif provider == \"google\":\n",
    "    # NOTE(review): genai.configure(...) and vertexai.init(...) run in later cells;\n",
    "    # confirm GoogleLLM does not need them at construction time.\n",
    "    model = \"gemini-1.5-pro-001\"\n",
    "    llm = GoogleLLM(model)\n",
    "else:\n",
    "    raise ValueError(\"Invalid provider\")"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": 77,
   "id": "028fc1ff-2f81-4935-a5a3-d54e2a1c6129",
   "metadata": {},
   "source": [
    "# Tool-free pipeline: system message, initial user query, then the selected LLM.\n",
    "pipeline = AgentPipeline(\n",
    "    [\n",
    "        SystemMessage(\"You are a helpful assistant.\"),\n",
    "        InitQuery(),\n",
    "        llm,\n",
    "    ]\n",
    ")"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": 78,
   "id": "d5bfa52c-9cc7-4dda-b47b-cfbcca371d1b",
   "metadata": {},
   "source": [
    "import google.generativeai as genai\n",
    "\n",
    "# Indexing os.environ (instead of os.getenv) raises KeyError when the key is unset.\n",
    "GOOGLE_API_KEY = os.environ[\"GOOGLE_API_KEY\"]\n",
    "genai.configure(\n",
    "    api_key=GOOGLE_API_KEY,\n",
    ")"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": 79,
   "id": "6caaf6b8-8c76-462b-8ab6-6665fc7981ae",
   "metadata": {},
   "source": [
    "import vertexai\n",
    "\n",
    "# Let the GCP project and region be overridden from the environment instead of\n",
    "# editing the notebook; the fallbacks keep the previously hard-coded values.\n",
    "# `or` (rather than a getenv default) also covers variables set to an empty string.\n",
    "vertexai.init(\n",
    "    project=os.getenv(\"GOOGLE_CLOUD_PROJECT\") or \"infinite-strata-351214\",\n",
    "    location=os.getenv(\"GOOGLE_CLOUD_REGION\") or \"us-central1\",\n",
    ")"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": 80,
   "id": "b781fabb-fffb-4b1f-8511-df39ab349737",
   "metadata": {},
   "source": [
    "# Smoke test: a plain question with an empty tool list.\n",
    "pipeline.query(\n",
    "    \"What AI assistant are you?\",\n",
    "    [],\n",
    ")"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": 81,
   "id": "d2e00915-7b8b-4615-9310-b0f729c59a50",
   "metadata": {},
   "source": [
    "# NOTE: this shadows the builtin sum() for the rest of the notebook. The name is\n",
    "# kept because a later cell looks the function up as FUNCTIONS_DOCS[\"sum\"].\n",
    "# NOTE(review): the docstring is presumably consumed by register_function to\n",
    "# describe the tool, so it is left untouched - confirm before rewording it.\n",
    "@register_function\n",
    "def sum(a: int, b: int):\n",
    "    \"\"\"Sums two numbers\n",
    "\n",
    "    :param a: the first number\n",
    "    :param b: the second number\n",
    "    \"\"\"\n",
    "    total = a + b\n",
    "    return total"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": 82,
   "id": "f8734154-3ed6-47ae-b353-abddd6d58648",
   "metadata": {},
   "source": [
    "# Snapshot the registry into a list. Assuming FUNCTIONS_DOCS is a plain dict,\n",
    "# .values() is a live view: it cannot be indexed or sliced and would pick up\n",
    "# any function registered after this cell runs.\n",
    "# To expose only the demo tool instead, use: tools = [FUNCTIONS_DOCS[\"sum\"]]\n",
    "tools = list(FUNCTIONS_DOCS.values())"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": 83,
   "id": "b5ae278c-2639-46e3-9ed5-74b9ef3301f2",
   "metadata": {},
   "source": [
    "# Show the JSON schema of the parameters recorded for the registered \"sum\" function.\n",
    "(\n",
    "    FUNCTIONS_DOCS[\"sum\"]\n",
    "    .parameters\n",
    "    .schema_json()\n",
    ")"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": 85,
   "id": "51359723-09db-493b-97b4-f20469e7770c",
   "metadata": {},
   "source": [
    "# NOTE(review): the cap of 29 tools is unexplained here; confirm why the list is truncated.\n",
    "pipeline.query(\n",
    "    \"Hi, how are you?\",\n",
    "    list(tools)[:29],\n",
    ")"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "id": "08d88506-700a-4663-83eb-1322715da0a7",
   "metadata": {},
   "source": [
    "# Elements handed to the tool-execution loop: the tool executor, then the LLM.\n",
    "tools_loop = ToolsExecutionLoop([\n",
    "    ToolsExecutor(),\n",
    "    # TransformersBasedPIDetector(),\n",
    "    llm,\n",
    "])\n",
    "\n",
    "# Full pipeline: system message, user query, first LLM call, then the tools loop.\n",
    "tools_pipeline = AgentPipeline([\n",
    "    SystemMessage(\"You are a helpful assistant with access to tools.\"),\n",
    "    InitQuery(),\n",
    "    llm,\n",
    "    tools_loop,\n",
    "])"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "id": "3565c805-3dfb-4e8c-873f-f81a519cd0a8",
   "metadata": {},
   "source": [
    "# Same greeting as above, now sent through the tool-enabled pipeline.\n",
    "tools_pipeline.query(\n",
    "    \"Hi, how are you?\",\n",
    "    tools,\n",
    ")"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "id": "af81bb20-3532-432e-bf3b-2c26e96a4ec5",
   "metadata": {},
   "source": [
    "# Run two independent additions inside an OutputLogger context; the return\n",
    "# value is kept in `result` so the next cell can display it.\n",
    "with OutputLogger(None, None):\n",
    "    result = tools_pipeline.query(\n",
    "        \"How much is 264345223412312+124122? And what about 63522+1421312?\",\n",
    "        tools,\n",
    "    )"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "id": "98d2d89d-790c-4fa8-87cb-9edfd8338068",
   "metadata": {},
   "source": [
    "# Display what the previous query returned.\n",
    "result"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": 56,
   "id": "8c40b599-66f7-4c55-92d5-ea3044778105",
   "metadata": {},
   "source": [
    "# Two-step request: the prompt asks for the second addition to depend on the first.\n",
    "with OutputLogger(None, None):\n",
    "    tools_pipeline.query(\n",
    "        (\n",
    "            \"How much is 264345223412312+124122? Then, add 1352345231. \"\n",
    "            \"Do not compute everything at the same time.\"\n",
    "        ),\n",
    "        tools,\n",
    "    )"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "17c4cd69-814f-484a-9b6a-b1c015437cb2",
   "metadata": {},
   "source": [],
   "outputs": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
