{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "rG3RVrQkpbV5"
      },
      "source": [
        "# Creating, registering and running a custom agent in GPTSwarm"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 3,
      "metadata": {
        "id": "4D0UL3Y8VbKg"
      },
      "outputs": [],
      "source": [
        "from google.colab import userdata\n",
        "import os\n",
        "os.environ['OPENAI_API_KEY'] = userdata.get('OPENAI_API_KEY')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "WlohndxYzGy9"
      },
      "outputs": [],
      "source": [
        "!git clone https://github.com/metauto-ai/GPTSwarm.git"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "XFVR_f57zian"
      },
      "outputs": [],
      "source": [
        "%cd GPTSwarm"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 6,
      "metadata": {
        "id": "3orsPSxPrjue"
      },
      "outputs": [],
      "source": [
        "!rm requirements_py310_macos.txt\n",
        "!touch requirements_py310_macos.txt"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "CJB6-nBq1gVL"
      },
      "outputs": [],
      "source": [
        "%pip install -r requirements_colab.txt"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "19hZCZZdz187"
      },
      "outputs": [],
      "source": [
        "%pip install -e ."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "6jAJV4SAoooE"
      },
      "source": [
        "We create a custom operation that is a subclass of the base class `Node`. In this example, the operation is one step of a chain of thought."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 9,
      "metadata": {
        "id": "2G7WrIG334YC"
      },
      "outputs": [],
      "source": [
        "from swarm.llm.format import Message\n",
        "from swarm.graph import Node\n",
        "from typing import List, Any, Optional\n",
        "from swarm.environment.prompt.prompt_set_registry import PromptSetRegistry\n",
        "from swarm.llm.format import Message\n",
        "from swarm.llm import LLMRegistry\n",
        "\n",
        "\n",
        "class CoTStep(Node):\n",
        "    \"\"\"A single chain-of-thought step, implemented as a graph node.\n",
        "\n",
        "    A non-final step asks the LLM for a short next thought and appends it\n",
        "    to the text it was given; the final step asks for an answer based on\n",
        "    the thoughts accumulated so far.\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self,\n",
        "                 domain: str,\n",
        "                 model_name: Optional[str],\n",
        "                 is_last_step: bool,\n",
        "                 operation_description: str = \"Make one step of CoT\",\n",
        "                 id=None):\n",
        "        \"\"\"Look up the LLM and the prompt set used by this step.\n",
        "\n",
        "        Args:\n",
        "            domain: Key passed to PromptSetRegistry.get.\n",
        "            model_name: Key passed to LLMRegistry.get.\n",
        "            is_last_step: Whether this step produces the answer instead of\n",
        "                an intermediate thought.\n",
        "            operation_description: Description forwarded to Node.__init__.\n",
        "            id: Node id forwarded to Node.__init__.\n",
        "        \"\"\"\n",
        "        super().__init__(operation_description, id, True)\n",
        "        self.domain = domain\n",
        "        self.model_name = model_name\n",
        "        self.is_last_step = is_last_step\n",
        "        self.llm = LLMRegistry.get(model_name)\n",
        "        self.prompt_set = PromptSetRegistry.get(domain)\n",
        "        self.role = self.prompt_set.get_role()\n",
        "        self.constraint = self.prompt_set.get_constraint()\n",
        "\n",
        "    @property\n",
        "    def node_name(self):\n",
        "        \"\"\"Operation name stored in execution records: the class name.\"\"\"\n",
        "        return self.__class__.__name__\n",
        "\n",
        "    def _build_system_prompt(self, role, constraint):\n",
        "        \"\"\"Return the system prompt for the final or an intermediate step.\"\"\"\n",
        "        if self.is_last_step:\n",
        "            return (\n",
        "                f\"You are {role}. {constraint}. \"\n",
        "                \"Answer taking into consideration the provided sequence \"\n",
        "                \"of thoughts on the question at hand.\")\n",
        "        return (\n",
        "            f\"You are {role}. \"\n",
        "            \"Given the question, solve it step by step. \"\n",
        "            \"Answer your thoughts about the next step of the solution given \"\n",
        "            \"everything that has been provided to you so far. \"\n",
        "            \"Expand on the next step. \"\n",
        "            \"Do not try to provide the answer straight away, instead expand \"\n",
        "            \"on your thoughts about the next step of the solution. \"\n",
        "            \"Answer in maximum 30 words. \"\n",
        "            \"Do not expect additional input. Make best use of whatever \"\n",
        "            \"knowledge you have been already provided.\")\n",
        "\n",
        "    async def _execute(self, inputs: List[Any] = [], **kwargs):\n",
        "        \"\"\"Query the LLM once per processed input and record each result.\n",
        "\n",
        "        Args:\n",
        "            inputs: Raw inputs, handed unchanged to self.process_input.\n",
        "            **kwargs: Accepted but not used by this step.\n",
        "\n",
        "        Returns:\n",
        "            A list with one execution record (dict) per processed input.\n",
        "            Each record is also added to self.memory under self.id.\n",
        "        \"\"\"\n",
        "        node_inputs = self.process_input(inputs)\n",
        "        outputs = []\n",
        "        for input_dict in node_inputs:\n",
        "            role = self.prompt_set.get_role()\n",
        "            constraint = self.prompt_set.get_constraint()\n",
        "            system_prompt = self._build_system_prompt(role, constraint)\n",
        "\n",
        "            # An \"output\" entry takes precedence over the raw \"task\";\n",
        "            # presumably it is the record left by an upstream step.\n",
        "            if 'output' in input_dict:\n",
        "                task = input_dict['output']\n",
        "            else:\n",
        "                task = input_dict[\"task\"]\n",
        "            user_prompt = self.prompt_set.get_answer_prompt(question=task)\n",
        "            message = [\n",
        "                Message(role=\"system\", content=system_prompt),\n",
        "                Message(role=\"user\", content=user_prompt)]\n",
        "            response = await self.llm.agen(message, max_tokens=50)\n",
        "\n",
        "            # Intermediate steps carry the question and all thoughts forward\n",
        "            # in their output; the final step outputs only the LLM response.\n",
        "            if self.is_last_step:\n",
        "                concatenated_response = response\n",
        "            else:\n",
        "                concatenated_response = f\"{task}. Here is the next thought. {response}. \"\n",
        "\n",
        "            execution = {\n",
        "                \"operation\": self.node_name,\n",
        "                \"task\": task,\n",
        "                \"files\": input_dict.get(\"files\", []),\n",
        "                \"input\": task,\n",
        "                \"role\": role,\n",
        "                \"constraint\": constraint,\n",
        "                \"prompt\": user_prompt,\n",
        "                \"output\": concatenated_response,\n",
        "                \"ground_truth\": input_dict.get(\"GT\", []),\n",
        "                \"format\": \"natural language\"\n",
        "            }\n",
        "            outputs.append(execution)\n",
        "            self.memory.add(self.id, execution)\n",
        "\n",
        "        return outputs"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "tZ0c1_xIo6LM"
      },
      "source": [
        "Then we create a custom Chain-of-Thought agent and register it as CustomCOT in the agent registry."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 10,
      "metadata": {
        "id": "bWIri65gQYLw"
      },
      "outputs": [],
      "source": [
        "from swarm.graph import Graph\n",
        "from swarm.environment.agents.agent_registry import AgentRegistry\n",
        "\n",
        "# CoTStep is deliberately NOT imported from\n",
        "# swarm.environment.operations.cot_step here: that import would rebind the\n",
        "# name, and CustomCOT would then be built from the library class instead of\n",
        "# the custom CoTStep defined earlier in this notebook.\n",
        "\n",
        "\n",
        "@AgentRegistry.register('CustomCOT')\n",
        "class CustomCOT(Graph):\n",
        "    \"\"\"Agent that chains several CoTStep nodes into a linear graph.\"\"\"\n",
        "\n",
        "    def build_graph(self):\n",
        "        \"\"\"Create the chain of steps and add its nodes to the graph.\n",
        "\n",
        "        The first step becomes the input node; the last step, created with\n",
        "        is_last_step=True, becomes the output node.\n",
        "        \"\"\"\n",
        "        num_thoughts = 3\n",
        "\n",
        "        # At least one intermediate thought plus the final step.\n",
        "        assert num_thoughts >= 2\n",
        "\n",
        "        thoughts = []\n",
        "        for i_thought in range(num_thoughts):\n",
        "            thought = CoTStep(self.domain,\n",
        "                              self.model_name,\n",
        "                              is_last_step=i_thought == num_thoughts - 1)\n",
        "            # Link every step after the first to its predecessor.\n",
        "            if thoughts:\n",
        "                thoughts[-1].add_successor(thought)\n",
        "            thoughts.append(thought)\n",
        "\n",
        "        self.input_nodes = [thoughts[0]]\n",
        "        self.output_nodes = [thoughts[-1]]\n",
        "\n",
        "        for thought in thoughts:\n",
        "            self.add_node(thought)\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "CA3YOrUYnK-I"
      },
      "source": [
        "And finally let's create a Swarm with a couple of our custom agents:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 14,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "Fhgp-mPNWwi6",
        "outputId": "16618b47-afba-4229-c361-a6d78e3a1db8"
      },
      "outputs": [
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "\u001b[32m2024-02-18 14:25:03.364\u001b[0m | \u001b[1mINFO    \u001b[0m | \u001b[36mswarm.graph.node\u001b[0m:\u001b[36mlog\u001b[0m:\u001b[36m160\u001b[0m - \u001b[1mMemory Records for ID \u001b[1;35m6HRq\u001b[0m:\n",
            "    \u001b[1;34moperation\u001b[0m: FinalDecision\n",
            "    \u001b[1;34mfiles\u001b[0m: []\n",
            "    \u001b[1;34msubtask\u001b[0m: What is the text representation of the last digit of twelve squared?. Here is the next thought. Calculate twelve squared (12^2), then identify the last digit of the result and convert it to its text representation.. . Here is the next thought. Next step: Compute 12^2 = 144. The last digit is 4. Convert this to its text representation: \"four\".. \n",
            "\n",
            "Reference information for CoTStep:\n",
            "----------------------------------------------\n",
            "FINAL ANSWER: four\n",
            "FINAL ANSWER: four\n",
            "----------------------------------------------\n",
            "\n",
            "\n",
            "Provide a specific answer. For questions with known answers, ensure to provide accurate and factual responses. Avoid vague responses or statements like 'unable to...' that don't contribute to a definitive answer. For example: if a question asks 'who will be the president of America', and the answer is currently unknown, you could suggest possibilities like 'Donald Trump', or 'Biden'. However, if the answer is known, provide the correct information.\n",
            "    \u001b[1;34moutput\u001b[0m: FINAL ANSWER: four\u001b[0m\n"
          ]
        },
        {
          "data": {
            "text/plain": [
              "['FINAL ANSWER: four']"
            ]
          },
          "execution_count": 14,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "from swarm.graph.swarm import Swarm\n",
        "\n",
        "# Build a swarm from two instances of the agent registered as \"CustomCOT\".\n",
        "agent_names = [\"CustomCOT\", \"CustomCOT\"]\n",
        "swarm = Swarm(agent_names, \"gaia\")\n",
        "\n",
        "# The question is passed to the swarm under the \"task\" key.\n",
        "task = \"What is the text representation of the last digit of twelve squared?\"\n",
        "inputs = {\"task\": task}\n",
        "\n",
        "# Top-level await: run the swarm and show its answer as the cell output.\n",
        "answer = await swarm.arun(inputs)\n",
        "answer"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "-PXjgxE1lxYY"
      },
      "outputs": [],
      "source": []
    }
  ],
  "metadata": {
    "colab": {
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
