{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "276dcef9-5bad-4f05-ba12-a1b3be39a714",
   "metadata": {},
   "outputs": [],
   "source": [
    "from utils import *\n",
    "from prompts import *\n",
    "from collections import defaultdict\n",
    "import math\n",
    "import numpy as np\n",
    "import re\n",
    "import argparse\n",
    "import time\n",
    "\n",
    "import torch\n",
    "import argparse\n",
    "\n",
    "from accelerate import init_empty_weights, infer_auto_device_map\n",
    "from transformers import AutoConfig, LlamaTokenizer\n",
    "from transformers import AutoModelForCausalLM, GenerationConfig"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "e85bb4c1-83f0-4830-bde6-5bd0951d7523",
   "metadata": {},
   "outputs": [],
   "source": [
    "class Agent():\n",
    "    \n",
    "    def __init__(self):\n",
    "        self.model_name = \"/export/share/ruimeng/ckpts/llm/llama_hf/30B\"\n",
    "        self.tokenizer = LlamaTokenizer.from_pretrained(self.model_name)\n",
    "        self.device_map = \"auto\"\n",
    "        # self.device_map = self.get_device_map(self.model_name_or_path, \"a100-40g\", False)\n",
    "        print('Loading model:\\n')\n",
    "        self.model = AutoModelForCausalLM.from_pretrained(\n",
    "            self.model_name,\n",
    "            device_map=self.device_map,\n",
    "            torch_dtype=torch.float16,\n",
    "            low_cpu_mem_usage=self.device_map is not None,\n",
    "            load_in_8bit=False,\n",
    "            )\n",
    "        self.generation_config = GenerationConfig(\n",
    "            do_sample=True,\n",
    "            temperature=0.8,\n",
    "            max_new_tokens=200,\n",
    "        )\n",
    "\n",
    "    def predict_answer(self, user_message, temperature=0.0):\n",
    "        with torch.no_grad():\n",
    "            input_ids = self.tokenizer(user_message, return_tensors=\"pt\").input_ids\n",
    "            input_ids = input_ids.to(0)\n",
    "            \n",
    "            generated_ids = self.model.generate(\n",
    "                input_ids=input_ids,\n",
    "                attention_mask=torch.ones_like(input_ids),\n",
    "                generation_config=self.generation_config,\n",
    "            )\n",
    "            \n",
    "            ref = self.tokenizer.batch_decode(generated_ids.cpu(), skip_special_tokens=True)[0]\n",
    "\n",
    "            # input_ids = self.tokenizer(\n",
    "            #     f\"\"\"Below is an instruction that describes a task. Write a response that appropriately completes the request.\n",
    "            #     ### Instruction: {user_message}\n",
    "            #     ### Response:\"\"\",\n",
    "            #     return_tensors=\"pt\",\n",
    "            # ).input_ids\n",
    "            # input_ids = input_ids.to(0)\n",
    "            # generated_ids = self.model.generate(\n",
    "            #     input_ids=input_ids,\n",
    "            #     attention_mask=torch.ones_like(input_ids),\n",
    "            #     generation_config=self.generation_config,\n",
    "            # )\n",
    "            # ref = self.tokenizer.batch_decode(generated_ids.cpu(), skip_special_tokens=True)[0].split('Response:')[1].strip()\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "            \n",
    "            # print('#'*50)\n",
    "            # print('#'*50)\n",
    "            # print(ref)\n",
    "            ref = ref.split('[PROBLEM 5]')[0].replace('#', '').strip()\n",
    "            ri = ref.rindex('[SOLUTION]')\n",
    "            # print('#'*50)\n",
    "            # print(ref)\n",
    "            # print('@'*50)\n",
    "            # print(ref[ri:])\n",
    "            # print('*'*50)\n",
    "            # print(ref[ri:].lower().split('[solution]')[1])\n",
    "        return ref[ri:].lower().split('[solution]')[1].split('[')[0].strip()\n",
    "        \n",
    "\n",
    "\n",
    "    def get_device_map(self, model_name, device, do_int8):\n",
    "        if device == \"a100-40g\":\n",
    "            return \"auto\"\n",
    "    \n",
    "        with init_empty_weights():\n",
    "            config = AutoConfig.from_pretrained(model_name)\n",
    "            model = AutoModelForCausalLM.from_config(config)\n",
    "    \n",
    "        d = {0: \"18GiB\"}\n",
    "        for i in range(1, 6):\n",
    "            d[i] = \"26GiB\"\n",
    "        device_map = infer_auto_device_map(\n",
    "            model, max_memory=d, dtype=torch.int8 if do_int8 else torch.float16,\n",
    "            no_split_module_classes=[\"BloomBlock\", \"OPTDecoderLayer\", \"LLaMADecoderLayer\", \"LlamaDecoderLayer\"]\n",
    "        )\n",
    "        print(device_map)\n",
    "        del model\n",
    "        return device_map\n",
    "        "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "d9ab4902-1006-4008-b5a2-ceeaee9910b6",
   "metadata": {},
   "outputs": [],
   "source": [
    "def ucb_cot(agent, state, target, participating_blocks, step_action_score, step_counter, step_action_counter, UCB_CONSTANT, step_action_ucb, grid_reward, cutt_off):\n",
    "    \n",
    "    history = step_action_score2text_llm_following_v1(step_action_ucb)\n",
    "    # print('-'*50)\n",
    "    # print('history')\n",
    "    # print(history)\n",
    "\n",
    "    if history:\n",
    "        prompt=prompt_with_history_v2_llama_v1_1(state, target, history)\n",
    "    else:\n",
    "        prompt=prompt_without_history_v2_llama_v1(state, target)\n",
    "\n",
    "    check = True\n",
    "    check_counter = 5\n",
    "    local_step_action_score = {}\n",
    "\n",
    "    while check and check_counter:\n",
    "        try:\n",
    "            op = agent.predict_answer(prompt, 0.0)\n",
    "            # print('~'*50)\n",
    "            # print('op')\n",
    "            # print(op)\n",
    "            op_list = op.split('\\n')[:cutt_off]\n",
    "            # print('-'*50)\n",
    "            # print('op_list')\n",
    "            # print(op_list)\n",
    "            idx_sequence = []\n",
    "            for st in op_list:\n",
    "                if 'step sequence' in st:\n",
    "                    continue\n",
    "                elif 'step' in st:\n",
    "                    idx, val = st.split(': ')\n",
    "                    idx, val = idx.strip().lower(), val.strip().lower()\n",
    "                    step_action_score[idx][val] += 0.0\n",
    "                    local_step_action_score[idx] = val\n",
    "                    idx_sequence.append(idx)\n",
    "                    step_counter[idx] += 1\n",
    "                    step_action_counter[idx][val] +=1\n",
    "            check = False\n",
    "        except:\n",
    "            check_counter -= 1\n",
    "            import time\n",
    "            time.sleep(3)\n",
    "\n",
    "        \n",
    "        # op = agent.predict_answer(prompt, 0.0)\n",
    "        # # print('~'*50)\n",
    "        # # print('op')\n",
    "        # # print(op)\n",
    "        # op_list = op.split('\\n')[:cutt_off]\n",
    "        # # print('-'*50)\n",
    "        # # print('op_list')\n",
    "        # # print(op_list)\n",
    "        # idx_sequence = []\n",
    "        # for st in op_list:\n",
    "        #     if 'step sequence' in st:\n",
    "        #         continue\n",
    "        #     elif 'step' in st:\n",
    "        #         idx, val = st.split(': ')\n",
    "        #         idx, val = idx.strip().lower(), val.strip().lower()\n",
    "        #         step_action_score[idx][val] += 0.0\n",
    "        #         local_step_action_score[idx] = val\n",
    "        #         idx_sequence.append(idx)\n",
    "        #         step_counter[idx] += 1\n",
    "        #         step_action_counter[idx][val] +=1\n",
    "        # check = False\n",
    "        \n",
    "\n",
    "\n",
    "    step_sequence = list(local_step_action_score.values())\n",
    "\n",
    "    for idx, step in enumerate(step_sequence):\n",
    "        new_state, valid_action = add_action_to_json_state(state_text2json(state.lower().replace('.', ''), participating_blocks), step)\n",
    "        step_counter[idx_sequence[idx]] += 1\n",
    "        step_action_counter[idx_sequence[idx]][step] += 1\n",
    "        \n",
    "        if not valid_action:\n",
    "            for stepi in local_step_action_score:\n",
    "                step_action_score[stepi][local_step_action_score[stepi]] += 0.0\n",
    "                step_action_ucb[stepi][local_step_action_score[stepi]] = get_ucb_score(step_action_score[stepi][local_step_action_score[stepi]], UCB_CONSTANT, step_counter[stepi], step_action_counter[stepi][local_step_action_score[stepi]])\n",
    "            return step_action_score, step_action_ucb, step_counter, step_action_counter\n",
    "        else:\n",
    "            if new_state == state_text2json(target.lower().replace('.', ''), participating_blocks):\n",
    "                for jdx, jval in enumerate(step_sequence[:idx+1]):\n",
    "                    step_action_score[idx_sequence[jdx]][jval] += float(grid_reward)\n",
    "                    step_action_ucb[idx_sequence[jdx]][jval] = get_ucb_score(step_action_score[idx_sequence[jdx]][jval], UCB_CONSTANT, step_counter[idx_sequence[jdx]], step_action_counter[idx_sequence[jdx]][jval])\n",
    "                return step_action_score, step_action_ucb, step_counter, step_action_counter\n",
    "        state = state_json2text(new_state)\n",
    "\n",
    "    if state == state_text2json(target.lower().replace('.', ''), participating_blocks):\n",
    "        for stepi in local_step_action_score:\n",
    "            step_action_score[stepi][local_step_action_score[stepi]] += float(grid_reward)\n",
    "            step_action_ucb[stepi][local_step_action_score[stepi]] = get_ucb_score(step_action_score[stepi][local_step_action_score[stepi]], UCB_CONSTANT, step_counter[stepi], step_action_counter[stepi][local_step_action_score[stepi]])\n",
    "        return step_action_score, step_action_ucb, step_counter, step_action_counter\n",
    "    else:\n",
    "        for stepi in local_step_action_score:\n",
    "            step_action_score[stepi][local_step_action_score[stepi]] += 0.0\n",
    "            step_action_ucb[stepi][local_step_action_score[stepi]] = get_ucb_score(step_action_score[stepi][local_step_action_score[stepi]], UCB_CONSTANT, step_counter[stepi], step_action_counter[stepi][local_step_action_score[stepi]])\n",
    "        return step_action_score, step_action_ucb, step_counter, step_action_counter\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "562083b0-1c8e-4bfe-981a-571fcfdd7ae2",
   "metadata": {},
   "outputs": [],
   "source": [
    "def main(args):\n",
    "    UCB_CONSTANT = args.exploration_constant\n",
    "    grid_reward = args.reward\n",
    "\n",
    "    # Initialize Agent\n",
    "    agent = Agent()\n",
    "\n",
    "    comp_answer_steps = args.no_of_answer_steps.split(',')\n",
    "\n",
    "    for NO_OF_STEPS_IN_ANSWER in comp_answer_steps:\n",
    "        print('#'*50)\n",
    "        print('No of steps in ans: ', NO_OF_STEPS_IN_ANSWER)\n",
    "        print('#'*50)\n",
    "\n",
    "        bw_data = get_blocksworld_data(int(NO_OF_STEPS_IN_ANSWER))\n",
    "        cutt_off = int(NO_OF_STEPS_IN_ANSWER)\n",
    "        \n",
    "        for _ in range(args.no_of_trials):\n",
    "            \n",
    "            preds = []\n",
    "            avg_actions = []\n",
    "            \n",
    "            for item_idx, item in tqdm(enumerate(bw_data)):\n",
    "                step_action_score = defaultdict(lambda: defaultdict(float))\n",
    "                step_action_ucb = defaultdict(lambda: defaultdict(float))\n",
    "                step_counter = defaultdict(int)\n",
    "                step_action_counter = defaultdict(lambda: defaultdict(int))\n",
    "\n",
    "                \n",
    "                \n",
    "                init_block_config = state_text2json(item['real_problem'], item['participating_blocks'])\n",
    "                final_block_config = copy.deepcopy(init_block_config)\n",
    "                gt_action_sequence = real_solution2text(item['real_solution'])\n",
    "                \n",
    "                for action in gt_action_sequence:\n",
    "                    final_block_config, valid_action = add_action_to_json_state(final_block_config, action)\n",
    "                    assert valid_action is True, 'Cannot reach final block config'\n",
    "    \n",
    "                # print('init_block_config')\n",
    "                # print(init_block_config)\n",
    "                # print('final_block_config')\n",
    "                # print(final_block_config)\n",
    "                # print('gt_action_sequence')\n",
    "                # print(gt_action_sequence)\n",
    "                # print('===')\n",
    "                \n",
    "                # Learn: run iterations\n",
    "                for pq in range(args.no_of_passes):\n",
    "                    step_action_score, step_action_ucb, step_counter, step_action_counter = ucb_cot(\n",
    "                        agent, state_json2text(init_block_config), \n",
    "                        state_json2text(final_block_config), \n",
    "                        item['participating_blocks'], \n",
    "                        step_action_score,\n",
    "                        step_counter,\n",
    "                        step_action_counter,\n",
    "                        UCB_CONSTANT,\n",
    "                        step_action_ucb,\n",
    "                        grid_reward, cutt_off\n",
    "                    )    \n",
    "    \n",
    "                    final_steps = []\n",
    "                    \n",
    "                    for stepi in step_action_score:\n",
    "                        best_action, best_score = sorted(step_action_score[stepi].items(), key=lambda x: x[1], reverse=True)[0]\n",
    "                        if best_score != 0.0:\n",
    "                            final_steps.append(best_action)\n",
    "                        else:\n",
    "                            break\n",
    "    \n",
    "                    # If none of the generated solutions are correct then step_action_score's best_score will always be zero\n",
    "                    # Hence a non-empty final_steps indicates proposed solution is correct\n",
    "                    if final_steps:\n",
    "                        preds.append(1)\n",
    "                        break\n",
    "                    print(str(item_idx+1)+\":\\t\"+str(sum(preds))+\"\\n\")\n",
    "                print(step_action_score)\n",
    "            \n",
    "            print('No of questions: ', str(len(bw_data)))\n",
    "            print('No of correct answers: ', sum(preds))\n",
    "        print()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "303872f3-2912-4b1a-bbfd-2e358791576b",
   "metadata": {},
   "outputs": [],
   "source": [
    "class ArgumentClass():\n",
    "    def __init__(self):\n",
    "        self.no_of_passes=10\n",
    "        self.no_of_trials=1\n",
    "        self.reward=1\n",
    "        self.exploration_constant=10\n",
    "        self.model_temperature=0.0\n",
    "        self.no_of_answer_steps='2,4,6'\n",
    "args = ArgumentClass()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "d0b9ddb8-734c-430e-8696-86d5766f443c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Expt\n",
    "\n",
    "# 1. temperature variation = did not work\n",
    "# 2. change prompt - Include [Option for every few-shot example] = did not work"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "5388cf40-eb2b-47d6-a7ec-96d764271217",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading model:\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [01:24<00:00, 12.08s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "##################################################\n",
      "No of steps in ans:  2\n",
      "##################################################\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "0it [00:00, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1:\t0\n",
      "\n",
      "1:\t0\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "1it [02:06, 126.71s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "defaultdict(<function main.<locals>.<lambda> at 0x7ff885772940>, {'step 1': defaultdict(<class 'float'>, {'put down the yellow block': 0.0, 'pick up the yellow block': 1.0}), 'step 2': defaultdict(<class 'float'>, {'pick up the yellow block': 0.0, 'stack the yellow block on top of the orange block': 1.0})})\n",
      "2:\t1\n",
      "\n",
      "2:\t1\n",
      "\n",
      "2:\t1\n",
      "\n",
      "2:\t1\n",
      "\n",
      "2:\t1\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "1it [05:05, 305.70s/it]\n",
      "\n",
      "KeyboardInterrupt\n",
      "\n"
     ]
    }
   ],
   "source": [
    "main(args)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7d23e322-71aa-4b60-b2e6-e6c4e30b9a5e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# prompts\n",
    "# v0: [0, 1, 1, 1, 2, 3, 3, 3, 4, 4]\n",
    "\n",
    "# v1: 2-step: t=0.8: 13/30\n",
    "# v1: 2-step: t=0.5: 19/30\n",
    "# v1: 2-step: t=0.3: 20/30\n",
    "# v1: 2-step: t=0.25: 4/15\n",
    "# v1: 2-step: t=0.2: 22/30\n",
    "# v1: 2-step: t=0.1: 3/16\n",
    "# v1: 2-step: t=0.000000000000001: 15/30\n",
    "\n",
    "# v1_1: 2-step: t=0.2: 9/22\n",
    "\n",
    "# v2: 2-step: t=0.2: 4/15"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b57f91dd-37b1-462f-8625-db26873ff705",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 4step\n",
    "# prompt v1: t=0.2: history v1: 0/26\n",
    "# prompt v1: t=0.5: history v1: 1/50\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "# Expt\n",
    "# alternate pick/unstack and put/stack"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f8649617-080a-459e-87db-5e2cd86d46ff",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 1/56\n",
    "# 3/114"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "llama2",
   "language": "python",
   "name": "llama2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.17"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
