{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import transformers\n",
    "from transformers import AutoTokenizer,AutoModelForCausalLM\n",
    "import torch\n",
    "from transformers.generation.stopping_criteria import StoppingCriteria,StoppingCriteriaList,STOPPING_CRITERIA_INPUTS_DOCSTRING,add_start_docstrings\n",
    "import yaml\n",
    "import alfworld\n",
    "import alfworld.agents.environment\n",
    "import json\n",
    "import sys\n",
    "import numpy as np\n",
    "import random\n",
    "from peft import PeftModel\n",
    "\n",
    "# Expose only GPU 0 to this process; change the id to a device that is free.\n",
    "os.environ['CUDA_VISIBLE_DEVICES'] = '0'   # specify the gpu id which is available\n",
    "\n",
    "# Placeholders to fill in before running: the base model and the fine-tuned PEFT adapter.\n",
    "model_name_or_path = 'MODEL_NAME_OR_PATH for the model used'\n",
    "peft_model_path = 'PATH to the finetuned model'\n",
    "\n",
    "# Passed as `add_special_tokens` to the tokenizer calls in this notebook.\n",
    "ADD_SPECIAL_TOKENS = False   # Need to be changed to True for Gemma"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Base causal LM in fp16, with `device_map='auto'` choosing the device placement.\n",
    "model = AutoModelForCausalLM.from_pretrained(\n",
    "    model_name_or_path,\n",
    "    torch_dtype=torch.float16,\n",
    "    device_map='auto',\n",
    ")\n",
    "# Matching tokenizer, configured to pad on the left.\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, padding_side='left')\n",
    "# Reuse EOS as the padding token when the tokenizer does not define one.\n",
    "if tokenizer.pad_token_id is None:\n",
    "    tokenizer.pad_token_id = tokenizer.eos_token_id"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Wrap the base model with the adapter stored at `peft_model_path`.\n",
    "# NOTE(review): this rebinds `model`, so re-running the cell wraps an already wrapped model.\n",
    "model = PeftModel.from_pretrained(model,peft_model_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Put the model in evaluation mode; the rest of the notebook only runs inference.\n",
    "model.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def lora_to_base(model):\n",
    "    \"\"\"Disable the adapter layers of `model.base_model` (best effort).\n",
    "\n",
    "    If the call fails, e.g. because `model` carries no adapter, a notice is\n",
    "    printed and the model is left as it is.\n",
    "    \"\"\"\n",
    "    try:\n",
    "        model.base_model.disable_adapter_layers()\n",
    "    except Exception:\n",
    "        # Not a bare `except:`: KeyboardInterrupt / SystemExit must keep propagating.\n",
    "        print(\"No adapter layers to disable\")\n",
    "    \n",
    "def base_to_lora(model):\n",
    "    \"\"\"Enable the adapter layers of `model.base_model` (best effort).\n",
    "\n",
    "    If the call fails, e.g. because `model` carries no adapter, a notice is\n",
    "    printed and the model is left as it is.\n",
    "    \"\"\"\n",
    "    try:\n",
    "        model.base_model.enable_adapter_layers()\n",
    "    except Exception:\n",
    "        # Not a bare `except:`: KeyboardInterrupt / SystemExit must keep propagating.\n",
    "        print(\"No adapter layers to enable\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def comput_output_logp_norm(model,prompt,output,tokenizer):\n",
    "    \"\"\"Return the mean per-token log-probability of `output` following `prompt`.\n",
    "\n",
    "    `prompt` and `prompt + output` are tokenized separately; the tokens of the\n",
    "    joint encoding that lie beyond the length of the prompt encoding are treated\n",
    "    as the continuation.\n",
    "    NOTE(review): this assumes the prompt tokens are a prefix of the joint\n",
    "    tokenization -- confirm for the tokenizer in use.\n",
    "\n",
    "    Args:\n",
    "        model: causal LM whose forward output exposes `.logits`.\n",
    "        prompt: conditioning text.\n",
    "        output: continuation text to score.\n",
    "        tokenizer: tokenizer matching `model`.\n",
    "\n",
    "    Returns:\n",
    "        A Python float. When the continuation contributes no tokens,\n",
    "        `float('-inf')` is returned (the mean over an empty tensor would be NaN,\n",
    "        and NaN wins an `argmax` over candidate scores).\n",
    "    \"\"\"\n",
    "    torch.cuda.empty_cache()\n",
    "    prompt_tok = tokenizer([prompt], add_special_tokens=ADD_SPECIAL_TOKENS, padding=True)\n",
    "    full_tok = tokenizer([prompt + output], add_special_tokens=ADD_SPECIAL_TOKENS, padding=True)\n",
    "\n",
    "    # Only the prompt length is needed, so no tensor is built for the prompt encoding.\n",
    "    prompt_len = len(prompt_tok['input_ids'][0])\n",
    "    full_ids = torch.LongTensor(full_tok['input_ids']).to(model.device)\n",
    "    full_mask = torch.LongTensor(full_tok['attention_mask']).to(model.device)\n",
    "\n",
    "    target_ids = full_ids[0, prompt_len:]\n",
    "    if target_ids.numel() == 0:\n",
    "        return float('-inf')\n",
    "\n",
    "    with torch.no_grad():\n",
    "        # Logits at position i predict token i + 1, hence the one-step shift on both sides.\n",
    "        res = model(input_ids=full_ids[:, :-1], attention_mask=full_mask[:, :-1])\n",
    "        token_logp = torch.log_softmax(res.logits[0, prompt_len - 1:], dim=-1)\n",
    "        logp = token_logp.gather(1, target_ids.unsqueeze(1)).mean().item()\n",
    "\n",
    "    return logp"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class MyStopping(StoppingCriteria):\n",
    "    \"\"\"Stopping criterion that fires once every sequence in the batch contains `stop_str`.\n",
    "\n",
    "    Only the tokens from position `start_length` onwards are decoded and searched.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, start_length, tokenizer, stop_str):\n",
    "        super().__init__()\n",
    "        self.start_length = start_length\n",
    "        self.stop_str = stop_str\n",
    "        self.tokenizer = tokenizer\n",
    "\n",
    "    @add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)\n",
    "    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs):\n",
    "        generated = self.tokenizer.batch_decode(input_ids[:, self.start_length:])\n",
    "        for text in generated:\n",
    "            if self.stop_str not in text:\n",
    "                return False\n",
    "        return True\n",
    "\n",
    "def get_stopping(start_length, stop_strs):\n",
    "    \"\"\"Build a `StoppingCriteriaList` with one `MyStopping` per entry of `stop_strs`.\n",
    "\n",
    "    The criteria use the notebook-level `tokenizer`.\n",
    "    \"\"\"\n",
    "    return StoppingCriteriaList(\n",
    "        [MyStopping(start_length, tokenizer, stop_str) for stop_str in stop_strs]\n",
    "    )\n",
    "\n",
    "class MyStoppingCount(StoppingCriteria):\n",
    "    \"\"\"Stopping criterion that fires once every sequence holds `stop_str` at least `stop_count` times.\n",
    "\n",
    "    Only the tokens from position `start_length` onwards are decoded and counted.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, start_length, tokenizer, stop_str, stop_count):\n",
    "        super().__init__()\n",
    "        self.start_length = start_length\n",
    "        self.stop_str = stop_str\n",
    "        self.stop_count = stop_count\n",
    "        self.tokenizer = tokenizer\n",
    "\n",
    "    @add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)\n",
    "    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs):\n",
    "        generated = self.tokenizer.batch_decode(input_ids[:, self.start_length:])\n",
    "        for text in generated:\n",
    "            if text.count(self.stop_str) < self.stop_count:\n",
    "                return False\n",
    "        return True\n",
    "\n",
    "def get_stopping_count(start_length, stop_strs):\n",
    "    \"\"\"Build a `StoppingCriteriaList` from `(stop_str, stop_count)` pairs.\n",
    "\n",
    "    One `MyStoppingCount` is created per pair, using the notebook-level `tokenizer`.\n",
    "    \"\"\"\n",
    "    return StoppingCriteriaList(\n",
    "        [\n",
    "            MyStoppingCount(start_length, tokenizer, stop_str, stop_count)\n",
    "            for stop_str, stop_count in stop_strs\n",
    "        ]\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Split name handed to the environment constructor as `train_eval`.\n",
    "split = \"eval_out_of_distribution\"\n",
    "\n",
    "with open('base_config.yaml') as reader:\n",
    "    config = yaml.safe_load(reader)\n",
    "\n",
    "# Look up the environment class named in the config, build it, then create a batch-size-1 env.\n",
    "env = getattr(alfworld.agents.environment, config[\"env\"][\"type\"])(\n",
    "    config, train_eval=split\n",
    ").init_env(batch_size=1)\n",
    "\n",
    "def process_ob(ob):\n",
    "    \"\"\"Strip a leading 'You arrive at loc N. ' sentence from an observation.\n",
    "\n",
    "    Observations that do not start with 'You arrive at loc ' are returned\n",
    "    unchanged. If the prefix is present but no '. ' separator follows, the\n",
    "    observation is also returned unchanged; slicing at `find('. ') + 2` would\n",
    "    drop its first character in that case because `str.find` returns -1.\n",
    "    \"\"\"\n",
    "    if ob.startswith('You arrive at loc '):\n",
    "        _, separator, remainder = ob.partition('. ')\n",
    "        if separator:\n",
    "            return remainder\n",
    "    return ob\n",
    "\n",
    "# Example trajectories; the evaluation loop reads keys of the form 'react_<prefix>_<n>'.\n",
    "with open('succ_alfworld_examples_good_bad.json','r') as f:\n",
    "    succ_examples = json.loads(f.read())\n",
    "with open('fail_alfworld_examples_train_good_bad.json','r') as f:\n",
    "    fail_examples = json.loads(f.read())\n",
    "\n",
    "# Task-type name prefix -> short tag used in the example keys.\n",
    "# Insertion order matters: it fixes the index of each task type in the result counters.\n",
    "prefixes = dict(\n",
    "    pick_and_place='put',\n",
    "    pick_clean_then_place='clean',\n",
    "    pick_heat_then_place='heat',\n",
    "    pick_cool_then_place='cool',\n",
    "    look_at_obj='examine',\n",
    "    pick_two_obj='puttwo',\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Remove the space after every '>' marker in both example sets, updating them in place.\n",
    "for examples in (succ_examples, fail_examples):\n",
    "    for key in list(examples):\n",
    "        examples[key] = examples[key].replace('> ','>')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Building blocks of the two status sentences, 'STATUS: GOOD' and 'STATUS: BAD'.\n",
    "judge_prefix = 'STATUS: '\n",
    "succ_word, fail_word = 'GOOD', 'BAD'\n",
    "\n",
    "succ_sentence = f'{judge_prefix}{succ_word}'\n",
    "fail_sentence = f'{judge_prefix}{fail_word}'\n",
    "\n",
    "\n",
    "def generate_from_prompt_experience_grounded_decoding(model, prompt, prompt_future, prompt_experience, tokenizer, stopping, predict_step = 2 ,max_length=256, temperature=0, top_p=1, num_beams=5, num_return_sequences = 4, valid_actions=None):\n",
    "    \"\"\"Propose next-step candidates with beam search and return the best-scoring one.\n",
    "\n",
    "    Steps:\n",
    "      1. Beam-search `num_return_sequences` continuations of `prompt`, stopping once\n",
    "         every returned sequence contains a newline.\n",
    "      2. Keep the first line of each continuation, de-duplicate (sorted order), and\n",
    "         keep only candidates that start with 'think' or are in `valid_actions`.\n",
    "      3. Score the remaining candidates with `comput_output_logp_norm` and normalize\n",
    "         the scores with a log-sum-exp over the candidate set.\n",
    "\n",
    "    Args:\n",
    "        model, tokenizer: causal LM and its tokenizer.\n",
    "        prompt: text to continue.\n",
    "        prompt_future, prompt_experience, stopping, predict_step, temperature:\n",
    "            accepted for call compatibility; not used by this implementation\n",
    "            (the stopping criterion is always \"first newline\").\n",
    "        max_length: passed to `generate` as `max_new_tokens`.\n",
    "        top_p, num_beams, num_return_sequences: forwarded to `GenerationConfig`.\n",
    "        valid_actions: admissible action strings. `None` disables the action\n",
    "            filter instead of raising `TypeError` in the membership test.\n",
    "\n",
    "    Returns:\n",
    "        (text_chosen, msg): the stripped best candidate, and a string listing every\n",
    "        scored candidate with its normalized log score.\n",
    "    \"\"\"\n",
    "    ######################## generate action candidates ########################\n",
    "    torch.cuda.empty_cache()\n",
    "    input_tok = tokenizer([prompt], add_special_tokens=ADD_SPECIAL_TOKENS, padding=True)\n",
    "    input_ids = torch.LongTensor(input_tok['input_ids']).to(model.device)\n",
    "    attention_mask = torch.LongTensor(input_tok['attention_mask']).to(model.device)\n",
    "    prompt_len = input_ids.shape[-1]\n",
    "    generation_config = transformers.GenerationConfig(\n",
    "        num_beams=num_beams,\n",
    "        top_p=top_p,\n",
    "        num_return_sequences=num_return_sequences,\n",
    "        output_scores=True,\n",
    "        return_dict_in_generate=True,\n",
    "        pad_token_id=tokenizer.pad_token_id,\n",
    "        eos_token_id=tokenizer.eos_token_id,\n",
    "    )\n",
    "    # Use a separate local name so the (unused) `stopping` argument is not silently overwritten.\n",
    "    newline_stopping = get_stopping(prompt_len, ['\\n'])\n",
    "    outputs = model.generate(\n",
    "        input_ids=input_ids,\n",
    "        attention_mask=attention_mask,\n",
    "        max_new_tokens=max_length,\n",
    "        generation_config=generation_config,\n",
    "        stopping_criteria=newline_stopping,\n",
    "    )\n",
    "    decoded = tokenizer.batch_decode(outputs.sequences[:, prompt_len:], skip_special_tokens=True)\n",
    "\n",
    "    # Drop the references to the GPU tensors before the scoring passes below.\n",
    "    del input_ids, attention_mask, outputs\n",
    "    torch.cuda.empty_cache()\n",
    "\n",
    "    ######################## filter and score candidates ########################\n",
    "    first_lines = [t.split('\\n')[0].strip() + '\\n' for t in decoded]\n",
    "    unique_lines = [str(t) for t in np.unique(first_lines)]   # sorted and de-duplicated\n",
    "    if valid_actions is None:\n",
    "        texts = unique_lines\n",
    "    else:\n",
    "        texts = [t for t in unique_lines if t.startswith('think') or t.strip() in valid_actions]\n",
    "\n",
    "    if len(texts) == 0:\n",
    "        if valid_actions:\n",
    "            texts = [random.choice(valid_actions) + '\\n']\n",
    "        else:\n",
    "            # No admissible action to fall back on: keep the raw candidates rather than\n",
    "            # letting random.choice fail on an empty list.\n",
    "            texts = unique_lines\n",
    "\n",
    "    raw_scores = np.array([comput_output_logp_norm(model, prompt, t, tokenizer) for t in texts], dtype=float)\n",
    "    # Log-softmax over the candidates; shifting by the maximum keeps exp() from underflowing.\n",
    "    peak = raw_scores.max()\n",
    "    scores = raw_scores - (peak + np.log(np.exp(raw_scores - peak).sum()))\n",
    "\n",
    "    text_chosen = texts[int(np.argmax(scores))].strip()\n",
    "    msg = ''.join(f'[{score:.5f}]: {text}' for score, text in zip(scores, texts))\n",
    "\n",
    "    return text_chosen,msg"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_critic(model, prompt, tokenizer):\n",
    "    \"\"\"Greedily generate one line of critic text for `prompt`.\n",
    "\n",
    "    Generation stops once the newly generated text contains a newline, or after\n",
    "    1024 new tokens. Only the text before the first newline is kept.\n",
    "\n",
    "    Args:\n",
    "        model: causal LM used for generation.\n",
    "        prompt: text to continue.\n",
    "        tokenizer: tokenizer matching `model`.\n",
    "\n",
    "    Returns:\n",
    "        The first generated line, terminated by a single newline character.\n",
    "    \"\"\"\n",
    "    torch.cuda.empty_cache()\n",
    "\n",
    "    input_tok = tokenizer([prompt], add_special_tokens=ADD_SPECIAL_TOKENS)\n",
    "    input_ids = torch.LongTensor(input_tok['input_ids']).to(model.device)\n",
    "    # Pass the mask explicitly, like the candidate generator in this notebook does,\n",
    "    # so `generate` does not have to infer it from the pad token.\n",
    "    attention_mask = torch.LongTensor(input_tok['attention_mask']).to(model.device)\n",
    "    prompt_len = input_ids.shape[-1]\n",
    "\n",
    "    stopping = get_stopping_count(prompt_len, [('\\n', 1)])\n",
    "    # Per-step scores are not requested: only `outputs.sequences` is read below.\n",
    "    outputs = model.generate(\n",
    "        input_ids=input_ids,\n",
    "        attention_mask=attention_mask,\n",
    "        use_cache=True,\n",
    "        max_new_tokens=1024,\n",
    "        return_dict_in_generate=True,\n",
    "        stopping_criteria=stopping,\n",
    "        pad_token_id=tokenizer.pad_token_id,\n",
    "        eos_token_id=tokenizer.eos_token_id,\n",
    "        do_sample=False,\n",
    "    )\n",
    "    new_tokens = outputs.sequences[:, prompt_len:]\n",
    "    output = tokenizer.batch_decode(new_tokens, skip_special_tokens=True)[0]\n",
    "\n",
    "    del input_ids, attention_mask, outputs, new_tokens\n",
    "    torch.cuda.empty_cache()\n",
    "\n",
    "    output = output.split('\\n')[0] + '\\n'\n",
    "    return output\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def my_reset(env,gamefiles):\n",
    "    \"\"\"Load `gamefiles` into the batch environment and start a new episode.\n",
    "\n",
    "    A batch environment that is not `None` is closed first. `env.last_commands`\n",
    "    is reset to one `None` per batch slot and `env.obs` is refreshed.\n",
    "\n",
    "    Args:\n",
    "        env: environment exposing `batch_env` and `batch_size`.\n",
    "        gamefiles: game files passed to `env.batch_env.load`.\n",
    "\n",
    "    Returns:\n",
    "        A tuple (observations, infos) as produced by `env.batch_env.reset()`.\n",
    "    \"\"\"\n",
    "    current = env.batch_env\n",
    "    if current is not None:\n",
    "        current.close()\n",
    "\n",
    "    env.batch_env.load(gamefiles)\n",
    "\n",
    "    env.last_commands = [None for _ in range(env.batch_size)]\n",
    "    observations, infos = env.batch_env.reset()\n",
    "    env.obs = observations\n",
    "    return env.obs, infos"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def alfworld_run(prompt, prompt_future, prompt_critic, prompt_experience, to_print=True, ob='', info=None):\n",
    "    \"\"\"Play one episode in the notebook-level `env` for at most 39 steps.\n",
    "\n",
    "    Each step picks an action with\n",
    "    `generate_from_prompt_experience_grounded_decoding`, executes it in `env`,\n",
    "    then enables the adapter, asks `generate_critic` to judge the step and\n",
    "    disables the adapter again. Action, observation and judgement are appended\n",
    "    to the running history that is part of the next step's prompt.\n",
    "\n",
    "    Args:\n",
    "        prompt, prompt_future, prompt_critic, prompt_experience: prompt headers;\n",
    "            the initial observation, a newline and '>' are appended to each.\n",
    "        to_print: print the observation, candidate scores and each step.\n",
    "        ob: initial observation text.\n",
    "        info: info dict of the environment reset; must provide\n",
    "            'admissible_commands' (required despite the `None` default).\n",
    "\n",
    "    Returns:\n",
    "        (reward, plans): `reward` is `info['won'][0]` of the final step when the\n",
    "        episode ends, otherwise 0. `plans` holds one {'input', 'output'} dict per\n",
    "        step plus a final {'input'} dict with the complete history.\n",
    "    \"\"\"\n",
    "    plans = list()\n",
    "    init_prompt = prompt + ob + '\\n>'\n",
    "    init_prompt_future = prompt_future + ob + '\\n>'\n",
    "    init_prompt_critic = prompt_critic + ob + '\\n>'\n",
    "    init_prompt_experience = prompt_experience + ob + '\\n>'\n",
    "    prompt = ''   # from here on `prompt` holds only the interaction history\n",
    "\n",
    "    if to_print:\n",
    "        print(ob)\n",
    "        sys.stdout.flush()\n",
    "    for i in range(1, 40): # TODO: TEST\n",
    "        torch.cuda.empty_cache()\n",
    "        # Filtered copy instead of list.remove(): no ValueError when a command is\n",
    "        # missing, and the list owned by `info` is not mutated.\n",
    "        valid_actions = [a for a in info['admissible_commands'][0] if a not in ('inventory', 'look')]\n",
    "        action, msg = generate_from_prompt_experience_grounded_decoding(model,\n",
    "                init_prompt + prompt, init_prompt_future + prompt, init_prompt_experience + prompt,\n",
    "                tokenizer, stopping=None, predict_step=4, max_length=128, temperature=0, top_p=1, num_beams=5, num_return_sequences=5, valid_actions=valid_actions)\n",
    "        plans.append({'input':init_prompt + prompt,'output':action})\n",
    "        observation, reward, done, info = env.step([action])\n",
    "        # The reported reward is the 'won' flag from `info`, not the value returned by env.step.\n",
    "        observation, reward, done = process_ob(observation[0]), info['won'][0], done[0]\n",
    "        if action.startswith('think:'):\n",
    "            observation = 'OK.'\n",
    "        torch.cuda.empty_cache()\n",
    "\n",
    "        # The adapter is enabled only while the critic generates.\n",
    "        base_to_lora(model)\n",
    "        critic_info = generate_critic(model,f'{init_prompt_critic}{prompt}{action}\\n{observation}\\n==>',tokenizer)\n",
    "        lora_to_base(model)\n",
    "\n",
    "        torch.cuda.empty_cache()\n",
    "\n",
    "        if to_print:\n",
    "            print(f'==>>\\nAction Choosing Message:\\n{msg}<<==')\n",
    "            print(f'****** Execution {i}:\\n> {action}\\n{observation}\\n==> {critic_info}******')\n",
    "            sys.stdout.flush()\n",
    "        prompt += f'{action}\\n{observation}\\n==>{critic_info}>'\n",
    "        if done:\n",
    "            plans.append({'input':init_prompt+prompt})\n",
    "            return reward,plans\n",
    "    plans.append({'input':init_prompt+prompt})\n",
    "    return 0,plans"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Start with the adapter disabled; `alfworld_run` re-enables it only around its critic calls.\n",
    "lora_to_base(model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-task-type episode counts and summed rewards, indexed in the order of `prefixes`.\n",
    "cnts = [0] * len(prefixes)\n",
    "rs = [0] * len(prefixes)\n",
    "\n",
    "input_outputs = list()\n",
    "\n",
    "for task in env.gamefiles:\n",
    "    input_output = dict()\n",
    "    ob, info = my_reset(env,[task])\n",
    "    gamefile = info['extra.gamefile'][0]\n",
    "    input_output['gamefile'] = gamefile\n",
    "    # Drop the first blank-line-separated paragraph of the initial observation.\n",
    "    ob = '\\n'.join(ob[0].split('\\n\\n')[1:])\n",
    "    # Two path components of the game file; matched against the task-type prefixes below.\n",
    "    name = '/'.join(gamefile.split('/')[-3:-1])\n",
    "    print(name)\n",
    "    for i, (k, v) in enumerate(prefixes.items()):\n",
    "        if name.startswith(k):\n",
    "            prompt = 'Interact with a household to solve a task.\\nHere are two examples.\\n' + succ_examples[f'react_{v}_0'] + '\\n' + succ_examples[f'react_{v}_1'] + '\\nHere is the task.\\n'\n",
    "            prompt_future = ''\n",
    "            prompt_critic = ''\n",
    "            print(k, v)\n",
    "\n",
    "            prompt_experience = 'Given you a trajectory interacting with a household, your task is to predict whether current trajectory will fail or succeed.' + '\\nHere is the task you need to predict.\\n'\n",
    "            r, plans = alfworld_run(prompt, prompt_future, prompt_critic, prompt_experience, ob=ob, info=info)\n",
    "            input_output['plan'] = plans\n",
    "            rs[i] += r\n",
    "            cnts[i] += 1\n",
    "            break\n",
    "    else:\n",
    "        # No known task type: skip the task instead of recording an undefined or stale\n",
    "        # reward `r` (and dividing by zero when nothing has been counted yet).\n",
    "        print(f'Skipping {name}: no matching task-type prefix')\n",
    "        continue\n",
    "\n",
    "    input_outputs.append([input_output,r])\n",
    "    print(sum(cnts), 'r', r, 'rs', rs, 'cnts', cnts, 'sum(rs)/sum(cnts)', sum(rs) / sum(cnts))\n",
    "    print('------------\\n')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "alfworld",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
