{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "##################################\n",
    "################ basics ##############\n",
    "###################################\n",
    "import jericho\n",
    "print(jericho.__version__)\n",
    "print(jericho.__file__)\n",
    "from jericho import *\n",
    "\n",
    "from env_with_memory import JerichoEnv\n",
    "\n",
    "import json\n",
    "\n",
    "import numpy as np  # used by the walkthrough-reward cells below\n",
    "\n",
    "# from env import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create the environment, optionally specifying a random seed\n",
    "rom_path = \"../roms/jericho-game-suite/zork1.z5\"\n",
    "env = JerichoEnv(rom_path, seed=12)\n",
    "initial_observation = env.reset()\n",
    "print('init:', initial_observation)\n",
    "\n",
    "# Take a single action in the environment using the step function.\n",
    "# The resulting text-observation, reward, and game-over indicator are returned.\n",
    "observation, reward, done, info = env.step('open mailbox')\n",
    "print('obs:', observation)\n",
    "# Total score and move-count are returned in the info dictionary\n",
    "print('Total Score', info['score'], 'Moves', info['moves'])\n",
    "# The max score is read from the wrapped FrotzEnv (`env.env`), as the later\n",
    "# zork1 wrapper cell does.\n",
    "print('Scored', info['score'], 'out of', env.env.get_max_score())\n",
    "\n",
    "print('Recognized Vocabulary Words', list(env.get_dictionary()))\n",
    "\n",
    "# Save/restore demo on a fresh environment, seeded like every other\n",
    "# JerichoEnv in this notebook so the outcome is reproducible.\n",
    "env = JerichoEnv(rom_path, seed=12)\n",
    "state = env.get_state() # Save the game to state\n",
    "print('s1:', state)\n",
    "observation, reward, done, info = env.step('attack troll') # Oops!\n",
    "# 'You swing and miss. The troll neatly removes your head.'\n",
    "print('obs2:', observation)\n",
    "env.set_state(state) # Restore to saved state\n",
    "\n",
    "# Replay the bundled walkthrough on a raw FrotzEnv, collecting per-step rewards\n",
    "# (get_walkthrough in the next cell does the same for an arbitrary ROM).\n",
    "bindings = load_bindings(rom_path)\n",
    "scores = []\n",
    "if 'walkthrough' in bindings:\n",
    "    walkthrough = bindings['walkthrough'].split('/')\n",
    "    seed = bindings['seed']\n",
    "    env = FrotzEnv(rom_path, seed=seed)\n",
    "    for act in walkthrough:\n",
    "        print('act:', act)\n",
    "        observation, reward, done, info = env.step(act)\n",
    "        print('obs:', observation)\n",
    "        scores.append(reward)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_walkthrough(rom_path, verbose=True):\n",
    "    \"\"\"Replay a game's bundled walkthrough and return the per-step rewards.\n",
    "\n",
    "    Reads the '/'-separated walkthrough and its seed via jericho's\n",
    "    `load_bindings`, then steps a raw FrotzEnv through every action.\n",
    "\n",
    "    Args:\n",
    "        rom_path: path to the game ROM file.\n",
    "        verbose: if True (the default), print the step index and action before\n",
    "            each step, and the observation and running reward total after it.\n",
    "\n",
    "    Returns:\n",
    "        List with the reward returned by each `env.step`, in walkthrough order;\n",
    "        an empty list when the bindings have no 'walkthrough' entry.\n",
    "    \"\"\"\n",
    "    bindings = load_bindings(rom_path)\n",
    "    scores = []\n",
    "    if 'walkthrough' not in bindings:\n",
    "        return scores\n",
    "\n",
    "    walkthrough = bindings['walkthrough'].split('/')\n",
    "    env = FrotzEnv(rom_path, seed=bindings['seed'])\n",
    "    cum_r = 0.0\n",
    "    for step, act in enumerate(walkthrough):\n",
    "        if verbose:\n",
    "            # printed before stepping so a failing action is visible in the log\n",
    "            print('step:', step)\n",
    "            print('act:', act)\n",
    "        observation, reward, done, info = env.step(act)\n",
    "        scores.append(reward)\n",
    "        cum_r += reward\n",
    "        if verbose:\n",
    "            print('obs:', observation)\n",
    "            print('curR:', cum_r)\n",
    "    return scores\n",
    "\n",
    "\n",
    "game_rom_path = \"../roms/jericho-game-suite/zork3.z5\"\n",
    "step_scores = get_walkthrough(game_rom_path)\n",
    "\n",
    "# Rewards of the first 100 walkthrough steps as an array (summed in the next cell)\n",
    "scores = np.array(step_scores[:100])\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Total reward collected over the first 100 walkthrough steps\n",
    "print(np.asarray(scores)[:100].sum())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import env_with_memory, importlib, sys\n",
    "\n",
    "# Pick up edits to env_with_memory.py without restarting the kernel\n",
    "importlib.reload(sys.modules['env_with_memory'])\n",
    "from env_with_memory import JerichoEnv\n",
    "\n",
    "# Create the environment, optionally specifying a random seed\n",
    "rom_path = \"../roms/jericho-game-suite/zork1.z5\"\n",
    "env = JerichoEnv(rom_path, seed=12)\n",
    "\n",
    "# This cell writes no trajectory, so no output file is opened here\n",
    "# (opening one in 'w' mode would truncate it).\n",
    "initial_observation = env.reset()\n",
    "print('init:', initial_observation)\n",
    "\n",
    "# Take a single action using the step function; it returns the\n",
    "# text-observation, reward, game-over indicator and an info dict.\n",
    "observation, reward, done, info = env.step('open mailbox')\n",
    "print('obs:', observation)\n",
    "# Total score and move-count are returned in the info dictionary\n",
    "print('Total Score', info['score'], 'Moves', info['moves'])\n",
    "# get_max_score is called on the wrapped FrotzEnv (`env.env`)\n",
    "print('Scored', info['score'], 'out of', env.env.get_max_score())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Valid actions reported in the info dict of the most recent step\n",
    "print(info[\"valid\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import importlib\n",
    "import os\n",
    "import sys\n",
    "\n",
    "import env_with_memory\n",
    "\n",
    "\n",
    "def _action_to_dict(act):\n",
    "    \"\"\"Serialise one action object as {'a': text, 't': template_id, 'o': obj_ids}.\"\"\"\n",
    "    return {'a': act.action, 't': act.template_id, 'o': act.obj_ids}\n",
    "\n",
    "\n",
    "def generate_output_tuple(observation, info):\n",
    "    \"\"\"Build the JSON-serialisable record written for one walkthrough step.\n",
    "\n",
    "    Despite the name, returns a dict with keys (in this order):\n",
    "      'actions'       - every entry of info['acts'] as {'a', 't', 'o'}\n",
    "      'valid_actions' - every entry of info['valid']: plain strings (such as the\n",
    "                        walkthrough action injected by the loop below) become\n",
    "                        {'a': text}; action objects become {'a', 't', 'o'}\n",
    "      'observations'  - the observation returned by env.step\n",
    "\n",
    "    An empty info['valid'] yields an empty 'valid_actions' list.\n",
    "    \"\"\"\n",
    "    actions = [_action_to_dict(a) for a in info['acts']]\n",
    "    # valid entries may be plain strings (no template/object ids) or action objects\n",
    "    valid_actions = [\n",
    "        {'a': a} if isinstance(a, str) else _action_to_dict(a)\n",
    "        for a in info['valid']\n",
    "    ]\n",
    "    return {\n",
    "        'actions': actions,\n",
    "        'valid_actions': valid_actions,\n",
    "        'observations': observation,\n",
    "    }\n",
    "\n",
    "\n",
    "importlib.reload(sys.modules['env_with_memory'])\n",
    "from env_with_memory import JerichoEnv\n",
    "from tqdm import tqdm\n",
    "\n",
    "rom_path = \"../roms/jericho-game-suite/zork3.z5\"\n",
    "# Trajectory file named after the ROM, e.g. zork3.z5 -> zork3.wt_traj.txt\n",
    "traj_path = os.path.splitext(os.path.basename(rom_path))[0] + '.wt_traj.txt'\n",
    "\n",
    "bindings = load_bindings(rom_path)\n",
    "scores = []\n",
    "if 'walkthrough' in bindings:\n",
    "    walkthrough = bindings['walkthrough'].split('/')\n",
    "    seed = bindings['seed']\n",
    "\n",
    "    # `with` closes the file even if a step raises mid-walkthrough\n",
    "    with open(traj_path, 'w') as fileout:\n",
    "        env = JerichoEnv(rom_path, seed=seed)\n",
    "        for idx, act in tqdm(enumerate(walkthrough), total=len(walkthrough)):\n",
    "            observation, reward, done, info = env.step(act)\n",
    "            scores.append(reward)\n",
    "\n",
    "            # NOTE(review): a valid list of exactly ['wait', 'yes', 'no'] is treated\n",
    "            # as a placeholder: it is replaced by the next walkthrough action, or the\n",
    "            # replay stops on the final step. Confirm against env_with_memory.\n",
    "            if list(info['valid']) == ['wait', 'yes', 'no']:\n",
    "                if idx == len(walkthrough) - 1:\n",
    "                    print(observation)\n",
    "                    break\n",
    "                info['valid'] = [walkthrough[idx + 1]]\n",
    "\n",
    "            output_dict = generate_output_tuple(observation, info)\n",
    "            fileout.write(json.dumps(output_dict) + '\\n')\n",
    "\n",
    "    print('Total Score', info['score'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Close the trajectory file (a no-op if the walkthrough cell already closed it)\n",
    "fileout.close()\n",
    "# Last observation and final score of the replay\n",
    "print(observation)\n",
    "print(\"Total Score\", info[\"score\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Last observation and final score of the walkthrough replay\n",
    "print(observation)\n",
    "print(\"Total Score\", info[\"score\"])\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Debug: are the first valid / candidate actions plain strings?\n",
    "for key in ('valid', 'acts'):\n",
    "    print(isinstance(info[key][0], str))\n",
    "# Loop position when the replay stopped, then the next two walkthrough actions\n",
    "print(idx)\n",
    "print(act)\n",
    "for offset in (1, 2):\n",
    "    print(walkthrough[idx + offset])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import importlib\n",
    "import sys\n",
    "\n",
    "importlib.reload(sys.modules['env_with_memory'])\n",
    "from env_with_memory import JerichoEnv\n",
    "from tqdm import tqdm\n",
    "\n",
    "def test(game_rom_path=None, fast_forward_steps=397):\n",
    "    \"\"\"Replay a walkthrough, switching from the raw env to JerichoEnv.step part-way.\n",
    "\n",
    "    The first `fast_forward_steps` actions are sent straight to the wrapped env\n",
    "    (`env.env.step`), printing each action and observation; the remaining actions\n",
    "    go through `JerichoEnv.step` so its behaviour can be inspected from there on.\n",
    "    Nothing is written to disk.\n",
    "\n",
    "    Args:\n",
    "        game_rom_path: ROM to replay. Defaults to the notebook-level `rom_path`,\n",
    "            looked up at call time.\n",
    "        fast_forward_steps: number of leading actions sent to `env.env`.\n",
    "            TODO: document why 397 was chosen as the default cut-off.\n",
    "    \"\"\"\n",
    "    if game_rom_path is None:\n",
    "        game_rom_path = rom_path\n",
    "    bindings = load_bindings(game_rom_path)\n",
    "    if 'walkthrough' not in bindings:\n",
    "        return\n",
    "\n",
    "    walkthrough = bindings['walkthrough'].split('/')\n",
    "    env = JerichoEnv(game_rom_path, seed=bindings['seed'])\n",
    "    for idx, act in tqdm(enumerate(walkthrough), total=len(walkthrough)):\n",
    "        if idx < fast_forward_steps:\n",
    "            print('act:', act)\n",
    "            observation, reward, done, info = env.env.step(act)\n",
    "            print('obs:', observation)\n",
    "        else:\n",
    "            observation, reward, done, info = env.step(act)\n",
    "            print(observation)\n",
    "\n",
    "    # str.split always yields at least one action, so `info` is bound here\n",
    "    print('Total Score', info['score'])\n",
    "\n",
    "\n",
    "test()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
