{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "dbd5ffec90738522",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "# LLM-Reasoners Demo\n",
    "\n",
    "This notebook is accompanied with our tutorial at SIGIR VF:\n",
    "[[slides](https://www.llm-reasoners.net/2024-02-Reasoners-SIGIR.pdf)]\n",
    "[[video](https://www.youtube.com/watch?v=d_x2pzEHGQY&pp=ygUJc2hpYm8gaGFv) (starting at 37:20)]\n",
    "\n",
    "## Setup\n",
    "Set cuda device and initialize an ExllamaModel use our unified LLM interface."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "97a9dc24f71ab121",
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "import os\n",
    "os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "1baf72f047599ea3",
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/data/shibo/anaconda3/envs/reasoners-2404/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "from reasoners.lm import ExLlamaModel\n",
    "import torch\n",
    "\n",
    "# https://huggingface.co/TheBloke/Llama-2-7B-Chat-GPTQ\n",
    "\n",
    "model = ExLlamaModel(model_dir='/data/haotian/RAP_tune/Llama-2-70B-GPTQ',\n",
    "                     lora_dir=None,\n",
    "                     device = torch.device(\"cuda:0\"),\n",
    "                     max_batch_size=1,\n",
    "                     max_new_tokens=200,\n",
    "                     mem_map=[16,22], # For 2 * 24GB GPUs. If you have > 40GB you can set it to None\n",
    "                     max_seq_length=2048)\n",
    "\n",
    "# Or use any other model providers:\n",
    "\n",
    "# HFModel(llama_path, llama_path, device=device, max_batch_size=1, max_new_tokens=512, quantized=quantized, peft_pth=peft_path, load_awq_pth=load_awq_pth)\n",
    "# Llama3Model(llama2_ckpts, llama_size, max_batch_size=1)\n",
    "# OpenAIModel(openai_mode)\n",
    "# ClaudeModel('claude-3-opus-20240229')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d793476fcd72d193",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "We gather one example from the Blocksworld dataset, and the proper prompt for in-context learning examples.\n",
    "We will talk more about Evaluators later."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "48ab7cb1a4514699",
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "from reasoners.benchmark import BWEvaluator\n",
    "import json\n",
    "\n",
    "with open('examples/CoT/blocksworld/prompts/pool_prompt_v1.json') as f:\n",
    "    prompt = json.load(f)\n",
    "evaluator = BWEvaluator(config_file='examples/CoT/blocksworld/data/bw_config.yaml',\n",
    "                        domain_file='examples/CoT/blocksworld/data/generated_domain.pddl',\n",
    "                        data_path='examples/CoT/blocksworld/data/split_v1/split_v1_step_4_data.json',\n",
    "                        init_prompt=prompt)\n",
    "prompt = evaluator.sample_prompt(shuffle_prompt=False, num_shot=4)\n",
    "example = evaluator.full_dataset[1]\n",
    "cot_inputs = (prompt['icl'].replace('<init_state>', example[\"init\"])\n",
    "                           .replace('<goals>', example[\"goal\"])\n",
    "                           .replace('<action>', ''))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cc49cab381592729",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "Here is the example."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "ab7d17be8373ae3e",
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "the blue block is clear, the orange block is clear, the hand is empty, the orange block is on top of the red block, the red block is on the table and the blue block is on the table\n"
     ]
    }
   ],
   "source": [
    "print(example['init'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "7d42ef78fea3bcfc",
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "the red block is on top of the blue block\n"
     ]
    }
   ],
   "source": [
    "print(example['goal'])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e7540875d5de58b5",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "## Chain-of-Thought\n",
    "We first experiment with the Chain-of-Thought method.\n",
    "Since we are having the simplest generation algorithm, we directly ask the model to generate all the steps.\n",
    "We look at the 4-shot prompt and the generated answer."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "6a467a187f55cf03",
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "I am playing with a set of blocks where I need to arrange the blocks into stacks. Here are the actions I can do\n",
      "\n",
      "Pick up a block\n",
      "Unstack a block from on top of another block\n",
      "Put down a block\n",
      "Stack a block on top of another block\n",
      "\n",
      "I have the following restrictions on my actions:\n",
      "I can only pick up or unstack one block at a time.\n",
      "I can only pick up or unstack a block if my hand is empty.\n",
      "I can only pick up a block if the block is on the table and the block is clear. A block is clear if the block has no other blocks on top of it and if the block is not picked up.\n",
      "I can only unstack a block from on top of another block if the block I am unstacking was really on top of the other block.\n",
      "I can only unstack a block from on top of another block if the block I am unstacking is clear.\n",
      "Once I pick up or unstack a block, I am holding the block.\n",
      "I can only put down a block that I am holding.\n",
      "I can only stack a block on top of another block if I am holding the block being stacked.\n",
      "I can only stack a block on top of another block if the block onto which I am stacking the block is clear.\n",
      "Once I put down or stack a block, my hand becomes empty.\n",
      "\n",
      "[STATEMENT]\n",
      "As initial conditions I have that, the red block is clear, the orange block is clear, the hand is empty, the orange block is on top of the blue block, the red block is on the table and the blue block is on the table.\n",
      "My goal is to have that the blue block is on top of the orange block.\n",
      "\n",
      "My plan is as follows:\n",
      "\n",
      "[PLAN]\n",
      "unstack the orange block from on top of the blue block\n",
      "put down the orange block\n",
      "pick up the blue block\n",
      "stack the blue block on top of the orange block\n",
      "[PLAN END]\n",
      "\n",
      "[STATEMENT]\n",
      "As initial conditions I have that, the blue block is clear, the orange block is clear, the hand is empty, the red block is on top of the yellow block, the orange block is on top of the red block, the blue block is on the table and the yellow block is on the table.\n",
      "My goal is to have that the blue block is on top of the yellow block and the orange block is on top of the blue block.\n",
      "\n",
      "My plan is as follows:\n",
      "\n",
      "[PLAN]\n",
      "unstack the orange block from on top of the red block\n",
      "put down the orange block\n",
      "unstack the red block from on top of the yellow block\n",
      "put down the red block\n",
      "pick up the blue block\n",
      "stack the blue block on top of the yellow block\n",
      "pick up the orange block\n",
      "stack the orange block on top of the blue block\n",
      "[PLAN END]\n",
      "\n",
      "[STATEMENT]\n",
      "As initial conditions I have that, the red block is clear, the yellow block is clear, the hand is empty, the red block is on top of the blue block, the blue block is on top of the orange block, the orange block is on the table and the yellow block is on the table.\n",
      "My goal is to have that the blue block is on top of the orange block and the yellow block is on top of the red block.\n",
      "\n",
      "My plan is as follows:\n",
      "\n",
      "[PLAN]\n",
      "pick up the yellow block\n",
      "stack the yellow block on top of the red block\n",
      "[PLAN END]\n",
      "\n",
      "[STATEMENT]\n",
      "As initial conditions I have that, the blue block is clear, the yellow block is clear, the hand is empty, the red block is on top of the orange block, the blue block is on top of the red block, the orange block is on the table and the yellow block is on the table.\n",
      "My goal is to have that the blue block is on top of the red block and the yellow block is on top of the blue block.\n",
      "\n",
      "My plan is as follows:\n",
      "\n",
      "[PLAN]\n",
      "pick up the yellow block\n",
      "stack the yellow block on top of the blue block\n",
      "[PLAN END]\n",
      "\n",
      "[STATEMENT]\n",
      "As initial conditions I have that, the blue block is clear, the orange block is clear, the hand is empty, the orange block is on top of the red block, the red block is on the table and the blue block is on the table\n",
      "My goal is to the red block is on top of the blue block\n",
      "\n",
      "My plan is as follows:\n",
      "\n",
      "[PLAN]\n",
      "\n"
     ]
    }
   ],
   "source": [
    "print(cot_inputs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "933ffa650264c50b",
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/data/shibo/llm-reasoners-test-240422/reasoners/lm/exllama_model.py:119: UserWarning: max_new_tokens is not set, we will use the default value: 200\n",
      "  warnings.warn(f\"max_new_tokens is not set, we will use the default value: {self.max_new_tokens}\")\n",
      "/data/shibo/llm-reasoners-test-240422/reasoners/lm/exllama_model.py:122: UserWarning: do_sample is False while the temperature is non-positive. We will use greedy decoding for Exllama\n",
      "  warnings.warn(\n",
      "/data/shibo/llm-reasoners-test-240422/reasoners/lm/exllama_model.py:144: UserWarning: the eos_token '\\n[' is encoded into tensor([29871,    13, 29961]) with length != 1, using 29961 as the eos_token_id\n",
      "  warnings.warn(f'the eos_token {repr(token)} is encoded into {tokenized} with length != 1, '\n"
     ]
    }
   ],
   "source": [
    "output = model.generate([cot_inputs],\n",
    "                        hide_input=True,\n",
    "                        eos_token_id='\\n[').text[0][:-1].strip()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "acde323347b1eb9",
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "pick up the red block\n",
      "stack the red block on top of the blue block\n"
     ]
    }
   ],
   "source": [
    "print(output)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "474d7795",
   "metadata": {},
   "source": [
    "Clearly that's not a valid solution :( \n",
    "The orange block is on the red block, so we cannot pick up the red block as the first step."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1e258cb3",
   "metadata": {},
   "source": [
    "## Tree-of-Thought\n",
    "Then let's turn to a tree search algorithm, [Tree-of-Thought]((https://arxiv.org/abs/2305.10601)).\n",
    "We will need to define a simple world model, and a search algorithm, for the Blocksworld task."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "ffaa93bb6ee24586",
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "from reasoners import WorldModel, LanguageModel, SearchConfig, State, Reasoner\n",
    "from reasoners.algorithm import BeamSearch, MCTS\n",
    "import reasoners.benchmark.bw_utils as utils\n",
    "from typing import NamedTuple\n",
    "import copy\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "# We use NamedTuple for clearer presentation, you may just use normal tuple if you want a quick experiment.\n",
    "class BWStateToT(NamedTuple):\n",
    "    step_idx: int\n",
    "    action_history: list[str]\n",
    "    end: bool\n",
    "\n",
    "\n",
    "# We just use the description str as the action, we use a type alias for better presentation.\n",
    "# You may directly use str of you want a quick experiment.\n",
    "BWAction = str\n",
    "\n",
    "\n",
    "class BlocksWorldModelToT(WorldModel):\n",
    "    def __init__(self,\n",
    "                 base_model: LanguageModel,\n",
    "                 prompt: dict,\n",
    "                 max_steps: int = 4,\n",
    "                 batch_size: int = 1) -> None:\n",
    "        super().__init__()\n",
    "        self.max_steps = max_steps\n",
    "        self.base_model = base_model\n",
    "        self.prompt = prompt\n",
    "        self.batch_size = batch_size\n",
    "\n",
    "    def init_state(self) -> BWStateToT:\n",
    "        return BWStateToT(step_idx=0, action_history=[], end=False)\n",
    "    \n",
    "    def step(self, state: BWStateToT, action: BWAction) -> tuple[BWStateToT, dict]:\n",
    "        state = copy.deepcopy(state)\n",
    "        if action != \"[PLAN END]\":\n",
    "            state = BWStateToT(step_idx=state.step_idx + 1, action_history=state.action_history + [action], end=False)\n",
    "        else:\n",
    "            state = BWStateToT(step_idx=state.step_idx + 1, action_history=state.action_history, end=True)\n",
    "        return state, {}  # the dict is auxiliary information for SearchConfig, we don't need it here.\n",
    "    \n",
    "    def is_terminal(self, state: State) -> bool:\n",
    "        return state.end or state.step_idx >= self.max_steps\n",
    "\n",
    "\n",
    "class BWConfigToT(SearchConfig):\n",
    "    def __init__(self,\n",
    "                 base_model: LanguageModel,\n",
    "                 prompt: dict,\n",
    "                 temperature: float = 0.8,\n",
    "                 n_candidate: int = 4) -> None:\n",
    "        super().__init__()\n",
    "        self.base_model = base_model\n",
    "        self.example = None\n",
    "        self.prompt = prompt\n",
    "        self.n_candidate = n_candidate\n",
    "        self.temperature = temperature\n",
    "\n",
    "    def get_actions(self, state: BWStateToT) -> list[BWAction]:\n",
    "        prompts = (self.prompt[\"icl\"]\n",
    "                       .replace(\"<action>\", \"\\n\".join(state.action_history + [\"\"]))\n",
    "                       .replace(\"<init_state>\", utils.extract_init_state(self.example))\n",
    "                       .replace(\"<goals>\", utils.extract_goals(self.example, return_raw=True)))\n",
    "        outputs = self.base_model.generate([prompts],\n",
    "                                           num_return_sequences=self.n_candidate,\n",
    "                                           max_length=20,\n",
    "                                           eos_token_id=\"\\n\",\n",
    "                                           temperature=self.temperature,\n",
    "                                           do_sample=True,\n",
    "                                           hide_input=True).text\n",
    "        outputs = [output.split(\"\\n\")[0] for output in outputs]\n",
    "        outputs = list(dict.fromkeys(outputs))  # deduplicate\n",
    "        return outputs\n",
    "\n",
    "    # Some reward functions are fast to calculate.\n",
    "    # We calculate the reward before executing the action, which can be used to better guide the search.\n",
    "    def fast_reward(self, state: BWStateToT, action: BWAction) -> tuple[float, dict]:\n",
    "        # We use two rewards here:\n",
    "        # 1. Intuition: The loglikelihood of the action given the prompt.\n",
    "        # 2. Self-eval: Ask the language model whether this step is \"Good\".\n",
    "        inputs = self.prompt[\"icl\"].replace(\"<action>\", \"\\n\".join(state.action_history + [\"\"])) \\\n",
    "            .replace(\"<init_state>\", utils.extract_init_state(self.example)) \\\n",
    "            .replace(\"<goals>\", utils.extract_goals(self.example, return_raw=True))[:-1]\n",
    "        \n",
    "        intuition = self.base_model.get_loglikelihood(inputs, [inputs + \"\\n\" + action])[0]\n",
    "\n",
    "        self_eval_prompt = (self.prompt[\"self-eval\"].replace(\"<init_state>\", utils.extract_init_state(self.example))\n",
    "                                                    .replace(\"<goals>\", utils.extract_goals(self.example, return_raw=True))\n",
    "                                                    .replace(\"<action>\", action))\n",
    "        self_eval = self.base_model.get_loglikelihood(self_eval_prompt, [self_eval_prompt + \"good\"])[0]\n",
    "\n",
    "        return intuition + self_eval, {'intuition': intuition, \"self_eval\": self_eval}\n",
    "    \n",
    "    # kwargs is the auxiliary information returned by SearchConfig.fast_reward and WorldModel.step,\n",
    "    # so that we do not need duplicated calculations.\n",
    "    # In this case, we just use the fast_reward result as the reward.\n",
    "    # Generally, if a reward function depends on the new state, or is slow to calculate,\n",
    "    # we will calculate it here.\n",
    "    def reward(self, state, action, **kwargs) -> tuple[float, dict]:\n",
    "        return kwargs['intuition'] + kwargs['self_eval'], kwargs"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f623a38d",
   "metadata": {},
   "source": [
    "Note: The following command may take to 2 minutes to run"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "9b3b2bec8947b3e",
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/data/shibo/llm-reasoners-test-240422/reasoners/lm/exllama_model.py:117: UserWarning: max_length is not supported by ExLlamaModel for generation. Use max_new_tokens instead.\n",
      "  warnings.warn(\"max_length is not supported by ExLlamaModel for generation. Use max_new_tokens instead.\")\n",
      "/data/shibo/llm-reasoners-test-240422/reasoners/lm/exllama_model.py:144: UserWarning: the eos_token '\\n' is encoded into tensor([29871,    13]) with length != 1, using 13 as the eos_token_id\n",
      "  warnings.warn(f'the eos_token {repr(token)} is encoded into {tokenized} with length != 1, '\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "BeamSearchResult(terminal_state=BWStateToT(step_idx=3, action_history=['pick up the red block', 'stack the red block on top of the blue block'], end=True), cum_reward=-0.7765676, tree=<reasoners.algorithm.beam_search.BeamSearchNode object at 0x7faa44583220>, trace=[(None, BWStateToT(step_idx=0, action_history=[], end=False), 0.0), ('pick up the red block', BWStateToT(step_idx=1, action_history=['pick up the red block'], end=False), -0.53424084), ('stack the red block on top of the blue block', BWStateToT(step_idx=2, action_history=['pick up the red block', 'stack the red block on top of the blue block'], end=False), -1.0123866), ('[PLAN END]', BWStateToT(step_idx=3, action_history=['pick up the red block', 'stack the red block on top of the blue block'], end=True), -0.7765676)])\n"
     ]
    }
   ],
   "source": [
    "world_model = BlocksWorldModelToT(base_model=model, prompt=prompt)\n",
    "config = BWConfigToT(base_model=model, prompt=prompt)\n",
    "algorithm = BeamSearch(beam_size=4, max_depth=7)\n",
    "reasoner_tot = Reasoner(world_model=world_model, search_config=config, search_algo=algorithm)\n",
    "result_tot = reasoner_tot(example)\n",
    "print(result_tot)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "ab2f2daa59d50d38",
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Action, Reward\n",
      "None 0.0\n",
      "pick up the red block -0.4957015\n",
      "stack the red block on top of the blue block -1.0114484\n",
      "[PLAN END] -0.78032136\n"
     ]
    }
   ],
   "source": [
    "print('Action, Reward')\n",
    "for action, _, reward in result_tot.trace:\n",
    "    print(action, reward)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9ccf2a76",
   "metadata": {},
   "source": [
    "Still the same error :("
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2093768cbd94dbee",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "## RAP\n",
    "With [RAP](https://arxiv.org/abs/2305.14992), we are truly using the latest block configuration as the state, instead of a history of actions.\n",
    "Thus, we define a new world model to transit between states, which is just a little complex than the previous one."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "4db36c24eab92e95",
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "BWAction = str\n",
    "\n",
    "\n",
    "class BWStateRAP(NamedTuple):\n",
    "    step_idx: int\n",
    "    last_blocks_state: str\n",
    "    blocks_state: str\n",
    "    buffered_action: BWAction\n",
    "\n",
    "\n",
    "class BlocksWorldModelRAP(WorldModel):\n",
    "    def __init__(self,\n",
    "                 base_model: LanguageModel,\n",
    "                 prompt: dict,\n",
    "                 max_steps: int = 4,\n",
    "                 batch_size: int = 1) -> None:\n",
    "        super().__init__()\n",
    "        self.max_steps = max_steps\n",
    "        self.base_model = base_model\n",
    "        self.prompt = prompt\n",
    "        self.batch_size = batch_size\n",
    "\n",
    "    def init_state(self) -> BWStateRAP:\n",
    "        return BWStateRAP(step_idx=0, last_blocks_state=\"\", blocks_state=utils.\n",
    "                       extract_init_state(self.example), buffered_action=\"\")\n",
    "\n",
    "    def step(self, state: BWStateRAP, action: BWAction) -> tuple[BWStateRAP, dict]:\n",
    "        state = copy.deepcopy(state)\n",
    "        blocks_state = state.blocks_state\n",
    "        step_idx = state.step_idx\n",
    "        blocks_state = self.update_blocks(blocks_state, action)\n",
    "        new_buffered_action = action if state.buffered_action == \"\" else \"\"\n",
    "\n",
    "        state = BWStateRAP(step_idx=step_idx + 1,\n",
    "                        last_blocks_state=state.blocks_state,\n",
    "                        blocks_state=blocks_state,\n",
    "                        buffered_action=new_buffered_action)\n",
    "        return state, {\"goal_reached\": utils.goal_check(utils.extract_goals(self.example), blocks_state)}\n",
    "\n",
    "    def update_blocks(self, block_states: str, action: BWAction) -> str:\n",
    "        if \"pick\" in action:\n",
    "            key = \"world_update_pickup\"\n",
    "        elif \"unstack\" in action:\n",
    "            key = \"world_update_unstack\"\n",
    "        elif \"put\" in action:\n",
    "            key = \"world_update_putdown\"\n",
    "        elif \"stack\" in action:\n",
    "            key = \"world_update_stack\"\n",
    "        else:\n",
    "            raise ValueError(\"Invalid action\")\n",
    "        world_update_prompt = self.prompt[key].format(block_states, action.capitalize() + \".\")\n",
    "        world_output = self.base_model.generate([world_update_prompt],\n",
    "                                                eos_token_id=\"\\n\",\n",
    "                                                hide_input=True,\n",
    "                                                temperature=0).text[0].strip()\n",
    "        new_state = utils.apply_change(world_output, block_states)\n",
    "        return new_state\n",
    "\n",
    "    def is_terminal(self, state: BWStateRAP) -> bool:\n",
    "        if utils.goal_check(utils.extract_goals(self.example), state.blocks_state)[0]:\n",
    "            return True\n",
    "        elif state.step_idx == self.max_steps:\n",
    "            return True\n",
    "        return False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "884e9c962952d37b",
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "class BWConfigRAP(SearchConfig):\n",
    "    def __init__(self,\n",
    "                 base_model: LanguageModel,\n",
    "                 prompt: dict,\n",
    "                 batch_size: int = 1,\n",
    "                 reward_alpha: float = 0.5,\n",
    "                 goal_reward_default: float = 0.,\n",
    "                 goal_reached_reward: float = 100.) -> None:\n",
    "        super().__init__()\n",
    "        self.base_model = base_model\n",
    "        self.example = None\n",
    "        self.prompt = prompt\n",
    "        self.batch_size = batch_size\n",
    "        self.reward_alpha = reward_alpha\n",
    "        self.goal_reward_default = goal_reward_default\n",
    "        self.goal_reached_reward = goal_reached_reward\n",
    "\n",
    "    def get_actions(self, state: BWStateRAP) -> list[BWAction]:\n",
    "        blocks_state = state.blocks_state\n",
    "        return utils.generate_all_actions(blocks_state)\n",
    "\n",
    "    def fast_reward(self, state: BWStateRAP, action: BWAction) -> tuple[float, dict]:\n",
    "        if state.buffered_action == \"\":\n",
    "            current_blocks_state = state.blocks_state\n",
    "        else:\n",
    "            current_blocks_state = state.last_blocks_state\n",
    "        previous_action = state.buffered_action + \"\\n\" if state.buffered_action != \"\" else \"\"\n",
    "        \n",
    "        # every two steps, we will also reduce the icl examples by 2 steps\n",
    "        # so that the distribution of step length in examples is more reasonable\n",
    "        icl_template = self.prompt[\"icl_list\"][state.step_idx // 2]\n",
    "        \n",
    "        inputs = (icl_template.replace(\"<init_state>\", current_blocks_state)\n",
    "                              .replace(\"<goals>\", utils.extract_goals(self.example, return_raw=True))\n",
    "                              .replace(\"<action>\", previous_action))\n",
    "        intuition = self.base_model.get_loglikelihood(inputs, [inputs + action])[0]\n",
    "\n",
    "        self_eval_prompt = (self.prompt[\"self-eval\"]\n",
    "                                .replace(\"<init_state>\", current_blocks_state)\n",
    "                                .replace(\"<goals>\", utils.extract_goals(self.example, return_raw=True))\n",
    "                                .replace(\"<action>\", action))\n",
    "        self_eval = self.base_model.get_loglikelihood(self_eval_prompt, [self_eval_prompt + \"good\"])[0]\n",
    "\n",
    "        return (self.calculate_reward(intuition, self_eval),\n",
    "                {'intuition': intuition, \"self_eval\": self_eval})\n",
    "\n",
    "    def calculate_reward(self, intuition, self_eval, goal_reached=None) -> float:\n",
    "        # to provide a unified interface for reward and fast_reward\n",
    "        if goal_reached is None:\n",
    "            goal_reward = self.goal_reward_default\n",
    "        elif goal_reached[0]:\n",
    "            goal_reward = self.goal_reached_reward\n",
    "        else:\n",
    "            goal_reward = goal_reached[1]\n",
    "        return (intuition + self_eval) * self.reward_alpha + goal_reward * (1 - self.reward_alpha)\n",
    "\n",
    "    def reward(self, state: BWStateRAP, action: BWAction,\n",
    "               intuition: float = None,\n",
    "               self_eval: float = None,\n",
    "               goal_reached: tuple[bool, float] = None) -> tuple[float, dict]:\n",
    "        return (self.calculate_reward(intuition, self_eval, goal_reached),\n",
    "                {'intuition': intuition, 'goal_reached': goal_reached})"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "28a97d5bdf453a8e",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "We just use the MCTS algorithm embedded in Reasoners, and build up the pipeline again.\n",
    "Note: the following command may take 2 minutes to run"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "70e0d64c166c5ccc",
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "MCTSResult(terminal_state=BWStateRAP(step_idx=4, last_blocks_state='the blue block is clear, the orange block is clear, the red block is in the hand, the hand is holding the red block, the blue block is on the table, and the orange block is on the table.', blocks_state='the orange block is clear, the red block is clear, the hand is empty, the red block is on top of the blue block, the blue block is on the table, and the orange block is on the table.', buffered_action=''), cum_reward=47.163818672299385, trace=([BWStateRAP(step_idx=0, last_blocks_state='', blocks_state='the blue block is clear, the orange block is clear, the hand is empty, the orange block is on top of the red block, the red block is on the table and the blue block is on the table.', buffered_action=''), BWStateRAP(step_idx=1, last_blocks_state='the blue block is clear, the orange block is clear, the hand is empty, the orange block is on top of the red block, the red block is on the table and the blue block is on the table.', blocks_state='the blue block is clear, the orange block is in the hand, the red block is clear, the hand is holding the orange block, the blue block is on the table, and the red block is on the table.', buffered_action='unstack the orange block from on top of the red block'), BWStateRAP(step_idx=2, last_blocks_state='the blue block is clear, the orange block is in the hand, the red block is clear, the hand is holding the orange block, the blue block is on the table, and the red block is on the table.', blocks_state='the blue block is clear, the orange block is clear, the red block is clear, the hand is empty, the blue block is on the table, the orange block is on the table, and the red block is on the table.', buffered_action=''), BWStateRAP(step_idx=3, last_blocks_state='the blue block is clear, the orange block is clear, the red block is clear, the hand is empty, the blue block is on the table, the orange block is on the table, and the red block is on the table.', 
blocks_state='the blue block is clear, the orange block is clear, the red block is in the hand, the hand is holding the red block, the blue block is on the table, and the orange block is on the table.', buffered_action='pick up the red block'), BWStateRAP(step_idx=4, last_blocks_state='the blue block is clear, the orange block is clear, the red block is in the hand, the hand is holding the red block, the blue block is on the table, and the orange block is on the table.', blocks_state='the orange block is clear, the red block is clear, the hand is empty, the red block is on top of the blue block, the blue block is on the table, and the orange block is on the table.', buffered_action='')], ['unstack the orange block from on top of the red block', 'put down the orange block', 'pick up the red block', 'stack the red block on top of the blue block']), trace_of_nodes=[<reasoners.algorithm.mcts.MCTSNode object at 0x7fa81db26140>, <reasoners.algorithm.mcts.MCTSNode object at 0x7fa81d87d4b0>, <reasoners.algorithm.mcts.MCTSNode object at 0x7fa81d87c970>, <reasoners.algorithm.mcts.MCTSNode object at 0x7fa81d87d6c0>, <reasoners.algorithm.mcts.MCTSNode object at 0x7fa81d87f400>], tree_state=<reasoners.algorithm.mcts.MCTSNode object at 0x7fa81db26140>, trace_in_each_iter=[[<reasoners.algorithm.mcts.MCTSNode object at 0x7fa81d87d9f0>, <reasoners.algorithm.mcts.MCTSNode object at 0x7fa81d87eef0>, <reasoners.algorithm.mcts.MCTSNode object at 0x7fa81d87d180>, <reasoners.algorithm.mcts.MCTSNode object at 0x7fa81d87fca0>, <reasoners.algorithm.mcts.MCTSNode object at 0x7fa81d87e410>], [<reasoners.algorithm.mcts.MCTSNode object at 0x7fa81d87d8d0>, <reasoners.algorithm.mcts.MCTSNode object at 0x7fa81d87cc40>, <reasoners.algorithm.mcts.MCTSNode object at 0x7fa81d87ffd0>, <reasoners.algorithm.mcts.MCTSNode object at 0x7fa81d87cd90>, <reasoners.algorithm.mcts.MCTSNode object at 0x7fa81d87dc60>], [<reasoners.algorithm.mcts.MCTSNode object at 0x7fa81d87f250>, 
<reasoners.algorithm.mcts.MCTSNode object at 0x7fa81d87c190>, <reasoners.algorithm.mcts.MCTSNode object at 0x7fa81da70730>, <reasoners.algorithm.mcts.MCTSNode object at 0x7fa81da70760>, <reasoners.algorithm.mcts.MCTSNode object at 0x7fa81da73f10>], [<reasoners.algorithm.mcts.MCTSNode object at 0x7fa81d87d1b0>, <reasoners.algorithm.mcts.MCTSNode object at 0x7fa81da73ac0>, <reasoners.algorithm.mcts.MCTSNode object at 0x7fa81da73b20>, <reasoners.algorithm.mcts.MCTSNode object at 0x7fa81db4fd30>, <reasoners.algorithm.mcts.MCTSNode object at 0x7fa81db4fbb0>], [<reasoners.algorithm.mcts.MCTSNode object at 0x7fa81df13550>, <reasoners.algorithm.mcts.MCTSNode object at 0x7fa81db4c6d0>, <reasoners.algorithm.mcts.MCTSNode object at 0x7fa81db4d9f0>, <reasoners.algorithm.mcts.MCTSNode object at 0x7fa81db4c730>, <reasoners.algorithm.mcts.MCTSNode object at 0x7fa81db4c610>], [<reasoners.algorithm.mcts.MCTSNode object at 0x7fa81da73ca0>, <reasoners.algorithm.mcts.MCTSNode object at 0x7fa81db4f850>, <reasoners.algorithm.mcts.MCTSNode object at 0x7fa81db4f460>, <reasoners.algorithm.mcts.MCTSNode object at 0x7fa81db4eb30>, <reasoners.algorithm.mcts.MCTSNode object at 0x7fa81db4eef0>], [<reasoners.algorithm.mcts.MCTSNode object at 0x7fa81db4ccd0>, <reasoners.algorithm.mcts.MCTSNode object at 0x7fa81db4e230>, <reasoners.algorithm.mcts.MCTSNode object at 0x7fa81db4ddb0>, <reasoners.algorithm.mcts.MCTSNode object at 0x7fa81db4c3d0>, <reasoners.algorithm.mcts.MCTSNode object at 0x7fa81db4fac0>], [<reasoners.algorithm.mcts.MCTSNode object at 0x7fa81db4c6a0>, <reasoners.algorithm.mcts.MCTSNode object at 0x7fa81db4e800>, <reasoners.algorithm.mcts.MCTSNode object at 0x7fa81db52f50>, <reasoners.algorithm.mcts.MCTSNode object at 0x7fa81db53b20>, <reasoners.algorithm.mcts.MCTSNode object at 0x7fa81db52d70>], [<reasoners.algorithm.mcts.MCTSNode object at 0x7fa81db4d600>, <reasoners.algorithm.mcts.MCTSNode object at 0x7fa81db53880>, <reasoners.algorithm.mcts.MCTSNode object at 0x7fa81db502b0>, 
<reasoners.algorithm.mcts.MCTSNode object at 0x7fa81db51a20>, <reasoners.algorithm.mcts.MCTSNode object at 0x7fa81db51810>], [<reasoners.algorithm.mcts.MCTSNode object at 0x7fa81db4d000>, <reasoners.algorithm.mcts.MCTSNode object at 0x7fa81db50850>, <reasoners.algorithm.mcts.MCTSNode object at 0x7fa81db52080>, <reasoners.algorithm.mcts.MCTSNode object at 0x7fa81db519f0>, <reasoners.algorithm.mcts.MCTSNode object at 0x7fa81db503a0>]], tree_state_after_each_iter=[<reasoners.algorithm.mcts.MCTSNode object at 0x7fa81d87d9f0>, <reasoners.algorithm.mcts.MCTSNode object at 0x7fa81d87d8d0>, <reasoners.algorithm.mcts.MCTSNode object at 0x7fa81d87f250>, <reasoners.algorithm.mcts.MCTSNode object at 0x7fa81d87d1b0>, <reasoners.algorithm.mcts.MCTSNode object at 0x7fa81df13550>, <reasoners.algorithm.mcts.MCTSNode object at 0x7fa81da73ca0>, <reasoners.algorithm.mcts.MCTSNode object at 0x7fa81db4ccd0>, <reasoners.algorithm.mcts.MCTSNode object at 0x7fa81db4c6a0>, <reasoners.algorithm.mcts.MCTSNode object at 0x7fa81db4d600>, <reasoners.algorithm.mcts.MCTSNode object at 0x7fa81db4d000>], aggregated_result=None)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r"
     ]
    }
   ],
   "source": [
    "world_model = BlocksWorldModelRAP(base_model=model, prompt=prompt, max_steps=4)\n",
    "config = BWConfigRAP(base_model=model, prompt=prompt)\n",
    "algorithm = MCTS(depth_limit=4, disable_tqdm=False, output_trace_in_each_iter=True, n_iters=10)\n",
    "reasoner_rap = Reasoner(world_model=world_model, search_config=config, search_algo=algorithm)\n",
    "result_rap = reasoner_rap(example)\n",
    "print(result_rap)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "3f540139",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "([BWStateRAP(step_idx=0, last_blocks_state='', blocks_state='the blue block is clear, the orange block is clear, the hand is empty, the orange block is on top of the red block, the red block is on the table and the blue block is on the table.', buffered_action=''),\n",
       "  BWStateRAP(step_idx=1, last_blocks_state='the blue block is clear, the orange block is clear, the hand is empty, the orange block is on top of the red block, the red block is on the table and the blue block is on the table.', blocks_state='the blue block is clear, the orange block is in the hand, the red block is clear, the hand is holding the orange block, the blue block is on the table, and the red block is on the table.', buffered_action='unstack the orange block from on top of the red block'),\n",
       "  BWStateRAP(step_idx=2, last_blocks_state='the blue block is clear, the orange block is in the hand, the red block is clear, the hand is holding the orange block, the blue block is on the table, and the red block is on the table.', blocks_state='the blue block is clear, the orange block is clear, the red block is clear, the hand is empty, the blue block is on the table, the orange block is on the table, and the red block is on the table.', buffered_action=''),\n",
       "  BWStateRAP(step_idx=3, last_blocks_state='the blue block is clear, the orange block is clear, the red block is clear, the hand is empty, the blue block is on the table, the orange block is on the table, and the red block is on the table.', blocks_state='the blue block is clear, the orange block is clear, the red block is in the hand, the hand is holding the red block, the blue block is on the table, and the orange block is on the table.', buffered_action='pick up the red block'),\n",
       "  BWStateRAP(step_idx=4, last_blocks_state='the blue block is clear, the orange block is clear, the red block is in the hand, the hand is holding the red block, the blue block is on the table, and the orange block is on the table.', blocks_state='the orange block is clear, the red block is clear, the hand is empty, the red block is on top of the blue block, the blue block is on the table, and the orange block is on the table.', buffered_action='')],\n",
       " ['unstack the orange block from on top of the red block',\n",
       "  'put down the orange block',\n",
       "  'pick up the red block',\n",
       "  'stack the red block on top of the blue block'])"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "result_rap.trace"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "136f52fa",
   "metadata": {},
   "source": [
    "Finally, we get a valid solution!"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "97e6c930da69ea10",
   "metadata": {
    "collapsed": false
   },
   "source": [
    "## Visualization"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2a685a07",
   "metadata": {},
   "source": [
    "Visualization is as simple as calling `visualize(log)`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "eb852e28f78e630c",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-03-11T05:19:20.380716Z",
     "start_time": "2024-03-11T05:19:19.723124Z"
    },
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Visualizer URL: https://www.llm-reasoners.net/visualizer/84cdfdfc-6299-43f2-ad2e-beb5d9da0730?accessKey=007acc66\n"
     ]
    }
   ],
   "source": [
    "from reasoners.visualization import visualize\n",
    "from reasoners.visualization.tree_snapshot import NodeData, EdgeData\n",
    "from reasoners.algorithm.mcts import MCTSNode\n",
    "\n",
    "\n",
    "# (Optional) You can write node_data_factory and edge_data_factory to show customized information.\n",
    "def blocksworld_node_data_factory(n: MCTSNode) -> NodeData:\n",
    "    return NodeData({\"block state\": n.state.blocks_state if n.state else \"Not expanded\",\n",
    "                     \"# goals satisfied\": n.reward_details[\"goal_reached\"][1] if hasattr(n, \"reward_details\") else \"N/A\",\n",
    "                     \"# visited\": len(n.cum_rewards)})\n",
    "\n",
    "def blocksworld_edge_data_factory(n: MCTSNode) -> EdgeData:\n",
    "    return EdgeData({\"Q\": n.Q,\n",
    "                     \"intuition\": n.fast_reward_details[\"intuition\"],\n",
    "                     \"self_eval\": n.fast_reward_details[\"self_eval\"],\n",
    "                     \"action\": n.action})\n",
    "\n",
    "visualize(result_rap,\n",
    "          node_data_factory=blocksworld_node_data_factory,\n",
    "          edge_data_factory=blocksworld_edge_data_factory)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fadf5ab5",
   "metadata": {},
   "source": [
    "This evaluator module provides standard APIs and easy implementation of multiple popular reasoning datasets."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ab27669adac79b8d",
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "with open('prompts/pool_prompt_v1.json') as f:\n",
    "    prompt = json.load(f)\n",
    "evaluator = BWEvaluator(config_file='examples/CoT/blocksworld/data/bw_config.yaml',\n",
    "                        domain_file='examples/CoT/blocksworld/data/generated_domain.pddl',\n",
    "                        data_path='examples/CoT/blocksworld/data/split_v1/split_v1_step_4_data.json',\n",
    "                        init_prompt=prompt)\n",
    "evaluator.evaluate(reasoner_tot, shuffle_prompt=True, num_shot=4, resume=0, log_dir='log/')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
