{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cbb4657a-e4b8-4458-bb80-82dfa6dcc72c",
   "metadata": {},
   "outputs": [],
   "source": [
    "from utils import *\n",
    "from prompts import *\n",
    "from collections import defaultdict\n",
    "import math\n",
    "import numpy as np\n",
    "import openai\n",
    "import re\n",
    "import argparse\n",
    "import time\n",
    "\n",
    "import torch\n",
    "import argparse\n",
    "\n",
    "from accelerate import init_empty_weights, infer_auto_device_map\n",
    "from transformers import AutoConfig, LlamaTokenizer\n",
    "from transformers import AutoModelForCausalLM, GenerationConfig\n",
    "\n",
    "\n",
    "class Agent():\n",
    "    \n",
    "    def __init__(self, OPENAI_API_KEY):\n",
    "        self.model_name = \"/export/share/ruimeng/ckpts/llm/llama_hf/30B\"\n",
    "        self.tokenizer = LlamaTokenizer.from_pretrained(self.model_name)\n",
    "        self.device_map = \"auto\"\n",
    "        self.model = AutoModelForCausalLM.from_pretrained(\n",
    "            self.model_name,\n",
    "            device_map=self.device_map,\n",
    "            torch_dtype=torch.float16,\n",
    "            low_cpu_mem_usage=self.device_map is not None,\n",
    "            load_in_8bit=False,\n",
    "            )\n",
    "        self.generation_config = GenerationConfig(\n",
    "            do_sample=True,\n",
    "            temperature=0.0001,\n",
    "            max_new_tokens=200,\n",
    "        )\n",
    "\n",
    "    def predict_answer(self, user_message, lb, temperature=0.0):\n",
    "        with torch.no_grad():\n",
    "            input_ids = self.tokenizer(user_message, return_tensors=\"pt\").input_ids\n",
    "            input_ids = input_ids.to(0)\n",
    "            \n",
    "            generated_ids = self.model.generate(\n",
    "                input_ids=input_ids,\n",
    "                attention_mask=torch.ones_like(input_ids),\n",
    "                generation_config=self.generation_config,\n",
    "            )\n",
    "            \n",
    "            ref = self.tokenizer.batch_decode(generated_ids.cpu(), skip_special_tokens=True)[0]\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "def ucl_cot(agent, state_config, state_action_score, state_action, \\\n",
    "       state_action_counter, state_counter, \\\n",
    "      state, target, final_block_config, R, C, K, B, depth, action_ops):\n",
    "\n",
    "    trajectory = []\n",
    "    action_ops2token_ids = get_action_ops2token_ids(action_ops, agent.model_name)\n",
    "    action_ops2token_ids_list = []\n",
    "          \n",
    "    for i in list(action_ops2token_ids.values()):\n",
    "        action_ops2token_ids_list+=i\n",
    "        \n",
    "    track_loop = []\n",
    "    \n",
    "    for _ in range(depth):\n",
    "\n",
    "        if str(state_config) in track_loop:\n",
    "            break\n",
    "        track_loop.append(str(state_config))\n",
    "    \n",
    "        prompt=prompt_without_history_v5(state_json2text(state_config), target, None)\n",
    "    \n",
    "        check = True\n",
    "        valid_state = True\n",
    "        check_counter = 0\n",
    "        temp_poss = [0.0, 0.5, 1.0]\n",
    "        token2bias=None\n",
    "        action = None\n",
    "        \n",
    "        while check and check_counter < 5:\n",
    "            try:\n",
    "                token2bias = past_actions_review(state_config, state_action_score, \\\n",
    "                                            state_action, state_action_counter, \\\n",
    "                                            state_counter, action_ops, R, C, K, B, action_ops2token_ids )\n",
    "                \n",
    "                token_id2bias = {}\n",
    "                for tok in token2bias:\n",
    "                    for j in action_ops2token_ids[tok]:\n",
    "                        token_id2bias[j] = token2bias[tok]\n",
    "                for ttok in action_ops2token_ids:\n",
    "                    if ttok not in token2bias:\n",
    "                        for j in action_ops2token_ids[ttok]:\n",
    "                            token_id2bias[j] = 10\n",
    "                \n",
    "                action = None\n",
    "                action = agent.predict_answer(prompt, lb=token_id2bias, temperature=0.0)\n",
    "                action = action['choices'][0]['message']['content']\n",
    "                action = action.lower().split('1:')[1].split('step')[0].replace('\\n', '').strip()\n",
    "                temp_state, valid_state = add_action_to_json_state(state_config, action)\n",
    "                if valid_state:\n",
    "                    check = False\n",
    "                else:\n",
    "                    if action:\n",
    "                        if action not in state_action[str(state_config)]:\n",
    "                            state_action[str(state_config)].append(action)\n",
    "                        state_counter[str(state_config)] += 1\n",
    "                        state_action_counter[str(state_config)][action] += 1\n",
    "                        state_action_score[str(state_config)][action] = 0.0\n",
    "                    \n",
    "                    check_counter+=1\n",
    "                    \n",
    "                # TODO\n",
    "                # if an action leads to invalid state then the respective toks should have -100 bias\n",
    "            except Exception as e:\n",
    "                check_counter += 1\n",
    "                import time\n",
    "                time.sleep(3)\n",
    "\n",
    "        if check and check_counter >= 5:\n",
    "            return state_action_score, state_action, state_action_counter, state_counter\n",
    "    \n",
    "        # add action to state\n",
    "        \n",
    "        if action not in state_action[str(state_config)]:\n",
    "            state_action[str(state_config)].append(action)\n",
    "        state_counter[str(state_config)] += 1\n",
    "        state_action_counter[str(state_config)][action] += 1\n",
    "        trajectory.append((str(state_config), action))\n",
    "               \n",
    "        if temp_state == final_block_config:\n",
    "            reward = R\n",
    "            for ss, aa in trajectory:\n",
    "                state_action_score[ss][aa] += reward\n",
    "            break\n",
    "                \n",
    "        else:\n",
    "            reward = 0.0\n",
    "            for ss, aa in trajectory:\n",
    "                state_action_score[ss][aa] += reward\n",
    "            state_config = copy.deepcopy(temp_state)\n",
    "            \n",
    "    return state_action_score, state_action, state_action_counter, state_counter\n",
    "\n",
    "\n",
    "\n",
    "def main(args):\n",
    "    UCB_CONSTANT = args.exploration_constant\n",
    "    grid_reward = args.reward\n",
    "    K = args.K\n",
    "    B = args.B\n",
    "    depth = args.depth\n",
    "\n",
    "    comp_answer_steps = args.no_of_answer_steps.split(',')\n",
    "\n",
    "    for NO_OF_STEPS_IN_ANSWER in comp_answer_steps:\n",
    "        print('#'*50)\n",
    "        print('No of steps in ans: ', NO_OF_STEPS_IN_ANSWER)\n",
    "        print('#'*50)\n",
    "\n",
    "        bw_data = get_blocksworld_data(int(NO_OF_STEPS_IN_ANSWER))\n",
    "        \n",
    "        for _ in range(args.no_of_trials):\n",
    "            \n",
    "            preds = []\n",
    "            avg_actions = []\n",
    "            \n",
    "            for item_idx, item in tqdm(enumerate(bw_data)):\n",
    "                state_action_score = defaultdict(lambda: defaultdict(float))\n",
    "                state_action = defaultdict(list)\n",
    "                state_action_counter = defaultdict(lambda: defaultdict(int))\n",
    "                state_counter = defaultdict(int)\n",
    "\n",
    "                # Initialize Agent\n",
    "                agent = Agent(args.OPENAI_API_KEY)\n",
    "\n",
    "                action_operators = ['unstack', 'stack', 'pick', 'put']\n",
    "                action_operands = [] #item['participating_blocks']\n",
    "                action_ops = action_operators + action_operands\n",
    "                \n",
    "                init_block_config = state_text2json(item['real_problem'], item['participating_blocks'])\n",
    "                final_block_config = copy.deepcopy(init_block_config)\n",
    "                gt_action_sequence = real_solution2text(item['real_solution'])\n",
    "                \n",
    "                for action in gt_action_sequence:\n",
    "                    final_block_config, valid_action = add_action_to_json_state(final_block_config, action)\n",
    "                    assert valid_action is True, 'Cannot reach final block config'\n",
    "    \n",
    "    \n",
    "                # Learn: run iterations\n",
    "                for pq in range(args.no_of_passes):\n",
    "                    state_action_score, state_action, state_action_counter, state_counter = ucl_cot(\n",
    "                        agent, copy.deepcopy(init_block_config), \\\n",
    "                   state_action_score, state_action, \\\n",
    "                   state_action_counter, state_counter, \\\n",
    "                   state_json2text(init_block_config), \\\n",
    "                   state_json2text(final_block_config), \\\n",
    "                   final_block_config, grid_reward, UCB_CONSTANT, K, B, depth, action_ops\n",
    "                    )    \n",
    "    \n",
    "                final_steps = []\n",
    "                \n",
    "                for stepi in state_action_score:\n",
    "                    best_action, best_score = sorted(state_action_score[stepi].items(), key=lambda x: x[1], reverse=True)[0]\n",
    "                    if best_score != 0.0:\n",
    "                        final_steps.append(best_action)\n",
    "                    else:\n",
    "                        break\n",
    "\n",
    "                # If none of the generated solutions are correct then step_action_score's best_score will always be zero\n",
    "                # Hence a non-empty final_steps indicates proposed solution is correct\n",
    "                if final_steps:\n",
    "                    preds.append(1)\n",
    "            \n",
    "            print('No of questions: ', str(len(bw_data)))\n",
    "            print('No of correct answers: ', sum(preds))\n",
    "        print()\n",
    "\n",
    "    \n",
    "    \n",
    "\n",
    "\n",
    "if __name__ == '__main__':\n",
    "    \n",
    "    parser = argparse.ArgumentParser()\n",
    "    parser.add_argument('-no_of_passes', default=10, type=int)\n",
    "    parser.add_argument('-no_of_trials', default=1, type=int)\n",
    "    parser.add_argument('-K', default=5, type=int)\n",
    "    parser.add_argument('-B', default=2, type=int)\n",
    "    parser.add_argument('-depth', default=10, type=int)\n",
    "    parser.add_argument('-reward', default=1, type=int)\n",
    "    parser.add_argument('-exploration_constant', default=10, type=int)\n",
    "    parser.add_argument('-model_temperature', default=0.0, type=float)\n",
    "    parser.add_argument('-OPENAI_API_KEY', default=\"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\")\n",
    "    parser.add_argument('-no_of_answer_steps', default='2,4,6')\n",
    "    args = parser.parse_args()\n",
    "    \n",
    "    main(args)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "268aec98-5728-4674-b488-41c05ea4c174",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a497adc5-383d-4413-9c2f-b89d6d91309f",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "7998b20c-7c77-402b-9913-d13e783dd573",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import GPT2Tokenizer, GPT2LMHeadModel, AutoTokenizer\n",
    "\n",
    "tokenizer = GPT2Tokenizer.from_pretrained(\"gpt2\")\n",
    "model = GPT2LMHeadModel.from_pretrained(\"gpt2\")\n",
    "\n",
    "inputs = tokenizer(\"Hello, my dog is cute and \", return_tensors=\"pt\")\n",
    "generation_output = model.generate(**inputs, return_dict_in_generate=True, output_scores=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "9fb43ed9-a82c-42a8-a481-ede10f3d66b2",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Downloading (…)/main/tokenizer.json: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1.36M/1.36M [00:00<00:00, 13.1MB/s]\n"
     ]
    }
   ],
   "source": [
    "tokenizer_with_prefix_space = AutoTokenizer.from_pretrained(\"gpt2\", add_prefix_space=True)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "e5973f39-dc81-473d-b574-1d0b08623c47",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_tokens_as_tuple(word):\n",
    "    return tuple(tokenizer_with_prefix_space([word], add_special_tokens=False).input_ids[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "f60ef1fb-ac28-4cc6-ac6b-7e752d6c7b98",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(17180, 18040, 18040)"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "get_tokens_as_tuple('appleappleapple')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "cd207efd-a264-4866-bad2-031ccf586fc3",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{(1301,): -10.0}"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "sequence_bias = {get_tokens_as_tuple(\"Trump\"): -10.0}\n",
    "sequence_bias"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "8a0d9f65-d00a-4905-a9ac-8274310e86d7",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n",
      "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The full name of Donald is Donald J. Trump Jr\n"
     ]
    },
    {
     "ename": "ValueError",
     "evalue": "The following `model_kwargs` are not used by the model: ['sequence_bias'] (note: typos in the generate arguments will also show up in this list)",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mValueError\u001b[0m                                Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[13], line 20\u001b[0m\n\u001b[1;32m     18\u001b[0m \u001b[38;5;66;03m# If we add a negative bias without beam search, it may become \"stuck\" in a prefix without good continuations\u001b[39;00m\n\u001b[1;32m     19\u001b[0m sequence_bias \u001b[38;5;241m=\u001b[39m {get_tokens_as_tuple(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mTrump\u001b[39m\u001b[38;5;124m\"\u001b[39m): \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m10.0\u001b[39m}\n\u001b[0;32m---> 20\u001b[0m biased_ids \u001b[38;5;241m=\u001b[39m \u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mgenerate\u001b[49m\u001b[43m(\u001b[49m\u001b[43minputs\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43minput_ids\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmax_new_tokens\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m4\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43msequence_bias\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43msequence_bias\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m     21\u001b[0m \u001b[38;5;28mprint\u001b[39m(tokenizer\u001b[38;5;241m.\u001b[39mbatch_decode(biased_ids, skip_special_tokens\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m)[\u001b[38;5;241m0\u001b[39m])\n\u001b[1;32m     23\u001b[0m biased_ids \u001b[38;5;241m=\u001b[39m model\u001b[38;5;241m.\u001b[39mgenerate(inputs[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124minput_ids\u001b[39m\u001b[38;5;124m\"\u001b[39m], max_new_tokens\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m4\u001b[39m, num_beams\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m4\u001b[39m, sequence_bias\u001b[38;5;241m=\u001b[39msequence_bias)\n",
      "File \u001b[0;32m/export/home/envs/reflexion/lib/python3.9/site-packages/torch/utils/_contextlib.py:115\u001b[0m, in \u001b[0;36mcontext_decorator.<locals>.decorate_context\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m    112\u001b[0m \u001b[38;5;129m@functools\u001b[39m\u001b[38;5;241m.\u001b[39mwraps(func)\n\u001b[1;32m    113\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mdecorate_context\u001b[39m(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[1;32m    114\u001b[0m     \u001b[38;5;28;01mwith\u001b[39;00m ctx_factory():\n\u001b[0;32m--> 115\u001b[0m         \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m/export/home/envs/reflexion/lib/python3.9/site-packages/transformers/generation/utils.py:1231\u001b[0m, in \u001b[0;36mGenerationMixin.generate\u001b[0;34m(self, inputs, generation_config, logits_processor, stopping_criteria, prefix_allowed_tokens_fn, synced_gpus, streamer, **kwargs)\u001b[0m\n\u001b[1;32m   1229\u001b[0m model_kwargs \u001b[38;5;241m=\u001b[39m generation_config\u001b[38;5;241m.\u001b[39mupdate(\u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)  \u001b[38;5;66;03m# All unused kwargs must be model kwargs\u001b[39;00m\n\u001b[1;32m   1230\u001b[0m generation_config\u001b[38;5;241m.\u001b[39mvalidate()\n\u001b[0;32m-> 1231\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_validate_model_kwargs\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel_kwargs\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcopy\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1233\u001b[0m \u001b[38;5;66;03m# 2. Set generation parameters if not already defined\u001b[39;00m\n\u001b[1;32m   1234\u001b[0m logits_processor \u001b[38;5;241m=\u001b[39m logits_processor \u001b[38;5;28;01mif\u001b[39;00m logits_processor \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;28;01melse\u001b[39;00m LogitsProcessorList()\n",
      "File \u001b[0;32m/export/home/envs/reflexion/lib/python3.9/site-packages/transformers/generation/utils.py:1109\u001b[0m, in \u001b[0;36mGenerationMixin._validate_model_kwargs\u001b[0;34m(self, model_kwargs)\u001b[0m\n\u001b[1;32m   1106\u001b[0m         unused_model_args\u001b[38;5;241m.\u001b[39mappend(key)\n\u001b[1;32m   1108\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m unused_model_args:\n\u001b[0;32m-> 1109\u001b[0m     \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\n\u001b[1;32m   1110\u001b[0m         \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mThe following `model_kwargs` are not used by the model: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00munused_model_args\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m (note: typos in the\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m   1111\u001b[0m         \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m generate arguments will also show up in this list)\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m   1112\u001b[0m     )\n",
      "\u001b[0;31mValueError\u001b[0m: The following `model_kwargs` are not used by the model: ['sequence_bias'] (note: typos in the generate arguments will also show up in this list)"
     ]
    }
   ],
   "source": [
    "from transformers import AutoTokenizer, AutoModelForCausalLM\n",
    "\n",
    "model = AutoModelForCausalLM.from_pretrained(\"gpt2\")\n",
    "tokenizer = AutoTokenizer.from_pretrained(\"gpt2\")\n",
    "inputs = tokenizer([\"The full name of Donald is Donald\"], return_tensors=\"pt\")\n",
    "\n",
    "summary_ids = model.generate(inputs[\"input_ids\"], max_new_tokens=4)\n",
    "print(tokenizer.batch_decode(summary_ids, skip_special_tokens=True)[0])\n",
    "\n",
    "# Now let's control generation through a bias. Please note that the tokenizer is initialized differently!\n",
    "tokenizer_with_prefix_space = AutoTokenizer.from_pretrained(\"gpt2\", add_prefix_space=True)\n",
    "\n",
    "\n",
    "def get_tokens_as_tuple(word):\n",
    "    return tuple(tokenizer_with_prefix_space([word], add_special_tokens=False).input_ids[0])\n",
    "\n",
    "\n",
    "# If we add a negative bias without beam search, it may become \"stuck\" in a prefix without good continuations\n",
    "sequence_bias = {get_tokens_as_tuple(\"Trump\"): -10.0}\n",
    "biased_ids = model.generate(inputs[\"input_ids\"], max_new_tokens=4, sequence_bias=sequence_bias)\n",
    "print(tokenizer.batch_decode(biased_ids, skip_special_tokens=True)[0])\n",
    "\n",
    "biased_ids = model.generate(inputs[\"input_ids\"], max_new_tokens=4, num_beams=4, sequence_bias=sequence_bias)\n",
    "print(tokenizer.batch_decode(biased_ids, skip_special_tokens=True)[0])\n",
    "\n",
    "# We can also add a positive bias to nudge the model towards specific tokens or continuations\n",
    "sequence_bias = {get_tokens_as_tuple(\"Donald Duck\"): 10.0}\n",
    "biased_ids = model.generate(inputs[\"input_ids\"], max_new_tokens=4, num_beams=4, sequence_bias=sequence_bias)\n",
    "print(tokenizer.batch_decode(biased_ids, skip_special_tokens=True)[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "69766748-fb42-4b6a-a278-22dc2db74938",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "reflexion",
   "language": "python",
   "name": "reflexion"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.17"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
