{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "##################################\n",
    "################ basics ##############\n",
    "###################################\n",
    "import json\n",
    "import os\n",
    "import re\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "import jericho\n",
    "print(jericho.__version__)\n",
    "print(jericho.__file__)\n",
    "from jericho import *\n",
    "\n",
    "from env import JerichoEnv\n",
    "\n",
    "# from env import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_walkthrough(rom_path):\n",
    "    \"\"\"Replay the walkthrough stored in a ROM's Jericho bindings and collect rewards.\n",
    "\n",
    "    Prints the step index, action, observation and cumulative reward at every step.\n",
    "\n",
    "    Args:\n",
    "        rom_path: Path to the game ROM, e.g. '../roms/jericho-game-suite/zork1.z5'.\n",
    "\n",
    "    Returns:\n",
    "        List of per-step rewards; empty if the bindings have no 'walkthrough'.\n",
    "    \"\"\"\n",
    "    bindings = load_bindings(rom_path)\n",
    "    scores = []\n",
    "    if 'walkthrough' not in bindings:\n",
    "        return scores\n",
    "\n",
    "    walkthrough = bindings['walkthrough'].split('/')\n",
    "    frotz_env = FrotzEnv(rom_path, seed=bindings['seed'])\n",
    "    cum_r = 0.0\n",
    "    try:\n",
    "        for step, act in enumerate(walkthrough):\n",
    "            print('step:', step)\n",
    "            print('act:', act)\n",
    "            observation, reward, done, info = frotz_env.step(act)\n",
    "            print('obs:', observation)\n",
    "            scores.append(reward)\n",
    "            cum_r += reward\n",
    "            print('curR:', cum_r)\n",
    "    finally:\n",
    "        # Shut down the emulator even if a step raises, instead of leaking it.\n",
    "        frotz_env.close()\n",
    "    return scores\n",
    "\n",
    "\n",
    "game_rom_path = \"../roms/jericho-game-suite/zork1.z5\"\n",
    "step_scores = get_walkthrough(game_rom_path)\n",
    "\n",
    "# Keep the rewards of the first 100 walkthrough steps as a numpy array.\n",
    "scores = np.array(step_scores[:100])\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(np.asarray(scores)[:100].sum())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
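    "### Generating SSA prediction data\n",
    "\n",
    "Replay each game's walkthrough and write one JSON line per step (observation and valid actions) to `../data/ssa_data/supervised/{game}.ssa.wt_traj.txt`. Games whose output file already exists are skipped."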
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rows copied from a LaTeX results table; only the first '&' column (the game name) is used below.\n",
    "# Raw string so the LaTeX is kept verbatim: in a plain string the '\\t' of '\\textbf' becomes a TAB\n",
    "# and each '\\\\' row terminator collapses to a single backslash.\n",
    "tmp_str = r'''905         & 1  & -   & \\textbf{0}     & \\textbf{0}     & \\textbf{0}     & \\textbf{0}     & \\textbf{0}\\\\ %done\n",
    "acorncourt  & 30 & -   & 1.6   & \\textbf{10}    & 0.3   & \\textbf{10.0}    & \\textbf{10.0} \\\\ % discarded, 28199\n",
    "advent      & 350& -  & 36    & 36    & 36    & \\textbf{63.9}     & 36\\\\ % done, use_hist  63.946666666666665  no_hist  36.0\n",
    "adventureland & 100& - & 0     & 20.6  & 0     &\\textbf{24.2}   &21.7\\\\ % done, use_hist  24.173333333333332  no_hist  21.653333333333332\n",
    "afflicted   & 75 & -   & 1.3   & 2.6   & --    &\\textbf{8.0}   &\\textbf{8.0} \\\\ % discard, 30700\n",
    "anchor      & 100 & -  & \\textbf{0}     & \\textbf{0}     & \\textbf{0}     & \\textbf{0}     & \\textbf{0}\\\\ % discarded\n",
    "awaken      & 50 & -  & \\textbf{0}     & \\textbf{0}     & \\textbf{0}     & \\textbf{0}     & \\textbf{0}\\\\ % discarded\n",
    "balances    & 51& -    & 4.8   & \\textbf{10}    & \\textbf{10}    & \\textbf{10}    & \\textbf{10}\\\\ %done\n",
    "deephome    & 300 & -  & \\textbf{1}     & \\textbf{1}     & \\textbf{1}     & \\textbf{1}     & \\textbf{1}\\\\ %done\n",
    "detective   & 360 & -  & 169   & 197.8 & 207.9 & \\textbf{317.7} & 291.3\\\\ % done, to plot use_hist  317.7  no_hist  291.3\n",
    "dragon      & 25 & -   & -5.3  & -3.5  & 0     & 0.04     &\\textbf{4.84}\\\\ % done, use_hist  0.04  no_hist  4.84\n",
    "enchanter   & 400 & -  & 8.6   & \\textbf{20}    & 12.1  & \\textbf{20.0}    & \\textbf{20.0} \\\\ % done\n",
    "gold        & 100& -   & \\textbf{4.1}   & 0     & --    & 0     & 0\\\\ % discarded\n",
    "inhumane    & 90 & -   & 0.7   & 0     & \\textbf{3}     & 0     & 0\\\\ % done\n",
    "jewel       & 90 & -   & 0     & 1.6   & 1.8   & \\textbf{4.46}   & 2.0 \\\\ % done, to plot, use_hist  4.46  no_hist  2.0\n",
    "karn        & 170 & -  & 0.7   & 2.1   & 0     & \\textbf{10.0}  &\\textbf{10.0}\\\\ % done, \n",
    "library     & 30 & -   & 6.3   & 17    & 14.3  & 17.7  & \\textbf{18.1} \\\\  % done, to plot use_hist  17.67  no_hist  18.12\n",
    "ludicorp    & 150 & -  & 6     & 13.8  & 17.8 & \\textbf{19.7}  &  17.0  \\\\ % done, to plot use_hist  19.66  no_hist  17.04\n",
    "moonlit     & 1  & -   & \\textbf{0}     & \\textbf{0}     & \\textbf{0}     & \\textbf{0}     & \\textbf{0}\\\\ % discarded\n",
    "omniquest   & 50 & -   & \\textbf{16.8}  & 10    & 3     & 10.0  & 10.0\\\\ % done\n",
    "pentari     & 70 & -   & 17.4  & 27.2  & \\textbf{50.7}  & 44.4  & 43.8\\\\ % done, to plot use_hist  44.35  no_hist  43.75\n",
    "reverb      & 50 & -   & 0.3   & \\textbf{8.2}    & --    & 2.0   & 2.0\\\\ % done\n",
    "snacktime   & 50 & -   & \\textbf{9.7}   & 0     & 0     & 0     & 0\\\\ % discard\n",
    "sorcerer    & 400 & -  & 5     & 20.8  & 5.8   &  \\textbf{38.6}   & 38.3 \\\\ % done, to plot use_hist  38.6  no_hist  38.266666666666666\n",
    "spellbrkr   & 600 & -  & 18.7  & \\textbf{37.8}  & 21.3  & 25    & 25\\\\ % discarded\n",
    "spirit      & 250 & -  & 0.6   & 0.8   & 1.3   &  3.8  & \\textbf{5.2} \\\\ % done, to plot use_hist  3.84  no_hist  5.24\n",
    "temple      & 35 & -   & 7.9   & 7.4   & 7.6   & \\textbf{8.0}     & \\textbf{8.0}\\\\ % dis\n",
    "tryst205    & 350 & -  & 0     & 9.6   & --    & \\textbf{10.0}   & \\textbf{10.0}\\\\ % done, \n",
    "yomomma     & 35 & -   & 0     & 0.4   & --    & \\textbf{1.0}     & \\textbf{1.0}\\\\ % done,\n",
    "zenon       & 20 & -   & 0     & 0     & \\textbf{3.9}   & 0     & 0\\\\ % dis\n",
    "zork1       & 350 & 102.0  & 9.9   & 32.6  & 34    & 38.3  & \\textbf{38.8}\\\\ % done, use_hist  38.29333333333334  no_hist  38.76\n",
    "zork3       & 7  & -   & 0     & 0.5   & 0.1   & \\textbf{4.0}   & \\textbf{4.0}\\\\ % done, use_hist  4.0  no_hist  3.973333333333333\n",
    "ztuu        & 100$^{a}$ & -  & 4.9   & 21.6  & 9.2   & \\textbf{85.4}  & 79.1\\\\ % done, use_hist'''\n",
    "\n",
    "# NOTE(review): kg_a2c_str is not referenced elsewhere in this notebook, and its columns are\n",
    "# unlabeled. The last column seems to mirror the max-score column of tmp_str, but differs for\n",
    "# inhumane (300 vs 90), karn (90 vs 170) and zenon (350 vs 20); the karn row also repeats\n",
    "# jewel's first two numbers (161 657). Verify before using.\n",
    "kg_a2c_str = '''905 82 296 0 0 1\n",
    "acorncourt 151 343 1.6 0.3 30\n",
    "advent 189 786 36 36 350\n",
    "adventureland 156 398 0 0 100\n",
    "anchor 260 2257 0 0 100\n",
    "awaken 159 505 0 0 50\n",
    "balances 156 452 4.8 10 51\n",
    "deephome 173 760 1 1 300\n",
    "detective 197 344 169 207.9 360\n",
    "dragon 177 1049 -5.3 0 25\n",
    "enchanter 290 722 8.6 12.1 400\n",
    "inhumane 141 409 0.7 3 300\n",
    "jewel 161 657 0 1.8 90\n",
    "karn 161 657 1.2 0 90\n",
    "library 173 510 6.3 14.3 30\n",
    "ludicorp 187 503 6 17.8 150\n",
    "moonlit 166 669 0 0 1\n",
    "omniquest 207 460 16.8 3 50\n",
    "pentari 155 472 17.4 50.7 70\n",
    "snacktime 201 468 9.7 0 50\n",
    "sorcerer 288 1013 5 5.8 400\n",
    "spellbrkr 333 844 18.7 21.3 600\n",
    "spirit 169 1112 0.6 1.3 250\n",
    "temple 175 622 7.9 7.6 35\n",
    "zenon 149 401 0 3.9 350\n",
    "zork1 237 697 9.9 34 350\n",
    "zork3 214 564 0 .1 7\n",
    "ztuu 186 607 4.9 9.2 100'''\n",
    "\n",
    "import re\n",
    "\n",
    "game_info = tmp_str.split('\\n')\n",
    "# The game name is the first '&'-separated column of each non-empty row, with all\n",
    "# whitespace removed. The r'' prefix avoids the invalid '\\s' escape warning.\n",
    "games = [\n",
    "    re.sub(r'\\s+', '', game_line.split('&')[0])\n",
    "    for game_line in game_info\n",
    "    if game_line != ''\n",
    "]\n",
    "\n",
    "print(games)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE: overrides the `games` list parsed from tmp_str in the previous cell,\n",
    "# so only these games are processed by the cells below.\n",
    "games = ['zork2', 'wishbringer']\n",
    "# games = ['ztuu']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print('#number of games: {}'.format(len(games)))\n",
    "\n",
    "# Sorted so the chosen ROM does not depend on os.listdir's arbitrary order.\n",
    "roms = sorted(os.listdir('../roms/jericho-game-suite/'))\n",
    "game2rom = {}\n",
    "logs = []\n",
    "for game in games:\n",
    "    # ROM file names look like 'zork1.z5': the game name followed by '.z'.\n",
    "    matches = [rom for rom in roms if rom.startswith(game + '.z')]\n",
    "    if not matches:\n",
    "        print('cannot find rom for {}'.format(game))\n",
    "        continue\n",
    "    if len(matches) > 1:\n",
    "        print('multiple roms for {}: {}, using {}'.format(game, matches, matches[0]))\n",
    "    game2rom[game] = matches[0]\n",
    "    logs.append('find {} for {}'.format(matches[0], game))\n",
    "\n",
    "# Count games that got a ROM rather than raw matches, which could double-count.\n",
    "print('#number of roms founds: {}'.format(len(game2rom)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import env, importlib, sys\n",
    "\n",
    "def generate_output_tuple(observation, info):\n",
    "    \"\"\"Build the per-step record written to the SSA trajectory file.\n",
    "\n",
    "    Despite the name, returns a dict with keys 'valid_actions' and 'observations'.\n",
    "    If the first entry of info['valid_act'] is a string, only that string is kept\n",
    "    (as {'a': ...}); otherwise every entry is treated as a group of action objects\n",
    "    and becomes a list of {'a': action, 't': template_id, 'o': obj_ids} dicts.\n",
    "    Raises IndexError if info['valid_act'] is an empty list.\n",
    "    \"\"\"\n",
    "    first_valid = info['valid_act'][0]\n",
    "    if isinstance(first_valid, str):\n",
    "        valid_actions = [{'a': first_valid}]\n",
    "    else:\n",
    "        valid_actions = [\n",
    "            [{'a': act.action, 't': act.template_id, 'o': act.obj_ids} for act in group]\n",
    "            for group in info['valid_act']\n",
    "        ]\n",
    "    # Key order is kept so that json.dumps of the record is unchanged.\n",
    "    return {'valid_actions': valid_actions, 'observations': observation}\n",
    "    \n",
    "\n",
    "importlib.reload(sys.modules['env'])\n",
    "from env import JerichoEnv\n",
    "from tqdm import tqdm\n",
    "from jericho import *\n",
    "\n",
    "game_max_scores = []\n",
    "\n",
    "for game in games:\n",
    "    out_path = '../data/ssa_data/supervised/{}.ssa.wt_traj.txt'.format(game)\n",
    "    # An existing file marks a finished game (files are only moved into place on success).\n",
    "    if os.path.isfile(out_path):\n",
    "        continue\n",
    "    if game not in game2rom:\n",
    "        print('skipping {}: no rom found'.format(game))\n",
    "        continue\n",
    "\n",
    "    print('generating trajectary for game: {}'.format(game))\n",
    "\n",
    "    rom_path = \"../roms/jericho-game-suite/{}\".format(game2rom[game])\n",
    "\n",
    "    bindings = load_bindings(rom_path)\n",
    "    scores = []\n",
    "    if 'walkthrough' in bindings:\n",
    "        walkthrough = bindings['walkthrough'].split('/')\n",
    "        seed = bindings['seed']\n",
    "\n",
    "        # `game_env` rather than `env`, which would shadow the `env` module imported above.\n",
    "        game_env = JerichoEnv(rom_path, seed=seed)\n",
    "        # Write to a temp file and rename on success, so an interrupted run does not\n",
    "        # leave a partial file that the isfile() check above would treat as finished.\n",
    "        tmp_path = out_path + '.tmp'\n",
    "        with open(tmp_path, 'w') as fileout:\n",
    "            for idx, act in enumerate(walkthrough):\n",
    "                observation, reward, done, info = game_env.step(act)\n",
    "                scores.append(reward)\n",
    "\n",
    "                # No valid actions reported: stop on the last step, otherwise use the\n",
    "                # next walkthrough action as the only valid action.\n",
    "                if len(info['valid_act']) == 0:\n",
    "                    if idx == len(walkthrough) - 1:\n",
    "                        break\n",
    "                    info['valid_act'] = [walkthrough[idx+1]]\n",
    "\n",
    "                output_dict = generate_output_tuple(observation, info)\n",
    "                fileout.write(json.dumps(output_dict) + '\\n')\n",
    "        os.replace(tmp_path, out_path)\n",
    "\n",
    "        print('Total Score', info['score'])\n",
    "        game_max_scores.append(info['score'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(game_max_scores)\n",
    "# Presumably the output of an earlier run; which games it covered is not recorded here:\n",
    "# [250, 35, 350, 34, 20, 350, 7, 100]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Explicit line magic: bare `ls` only works while IPython automagic is enabled.\n",
    "%ls"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inspect the last processed step. The look-ahead is guarded because it is out of\n",
    "# range when the loop reached the final walkthrough action.\n",
    "print(walkthrough[idx])\n",
    "if idx + 1 < len(walkthrough):\n",
    "    print(walkthrough[idx+1])\n",
    "print(info)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
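    "### Generating valid-action prediction data\n",
    "\n",
    "Replay the zork3 walkthrough and write one JSON line per step (observation, all candidate actions from `info['act']`, and the valid actions) to `zork3.wt_traj.with_new_env.txt`."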
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import env, importlib, sys\n",
    "\n",
    "def generate_output_tuple(observation, info):\n",
    "    \"\"\"Build the per-step record written to the valid-action trajectory file.\n",
    "\n",
    "    This redefines the SSA-data generate_output_tuple, adding an 'actions' key\n",
    "    built from info['act']. Returns a dict with keys 'actions', 'valid_actions'\n",
    "    and 'observations'; each action object becomes\n",
    "    {'a': action, 't': template_id, 'o': obj_ids}. If the first entry of\n",
    "    info['valid_act'] is a string, only that string is kept (as {'a': ...});\n",
    "    otherwise each entry is a group of action objects and becomes a list.\n",
    "    \"\"\"\n",
    "    def to_dict(act):\n",
    "        return {'a': act.action, 't': act.template_id, 'o': act.obj_ids}\n",
    "\n",
    "    valid = info['valid_act']\n",
    "    if isinstance(valid[0], str):\n",
    "        # Read the same key that was just checked. This previously read info['valid'][0],\n",
    "        # which raises KeyError unless the env also happens to return a 'valid' key.\n",
    "        valid_actions = [{'a': valid[0]}]\n",
    "    else:\n",
    "        valid_actions = [[to_dict(act) for act in group] for group in valid]\n",
    "    return {\n",
    "        'actions': [to_dict(act) for act in info['act']],\n",
    "        'valid_actions': valid_actions,\n",
    "        'observations': observation,\n",
    "    }\n",
    "    \n",
    "\n",
    "importlib.reload(sys.modules['env'])\n",
    "from env import JerichoEnv\n",
    "from tqdm import tqdm\n",
    "\n",
    "rom_path = \"../roms/jericho-game-suite/zork3.z5\"\n",
    "traj_path = 'zork3.wt_traj.with_new_env.txt'\n",
    "\n",
    "from jericho import *\n",
    "bindings = load_bindings(rom_path)\n",
    "scores = []\n",
    "if 'walkthrough' in bindings:\n",
    "    walkthrough = bindings['walkthrough'].split('/')\n",
    "    seed = bindings['seed']\n",
    "\n",
    "    # `game_env` rather than `env`, which would shadow the `env` module imported above.\n",
    "    game_env = JerichoEnv(rom_path, seed=seed)\n",
    "    # `with` closes the file even if a step raises.\n",
    "    with open(traj_path, 'w') as fileout:\n",
    "        for idx, act in tqdm(enumerate(walkthrough), total=len(walkthrough)):\n",
    "            observation, reward, done, info = game_env.step(act)\n",
    "            scores.append(reward)\n",
    "\n",
    "            # No valid actions reported: stop on the last step, otherwise use the\n",
    "            # next walkthrough action as the only valid action.\n",
    "            if len(info['valid_act']) == 0:\n",
    "                if idx == len(walkthrough) - 1:\n",
    "                    print(observation)\n",
    "                    break\n",
    "                info['valid_act'] = [walkthrough[idx+1]]\n",
    "\n",
    "            output_dict = generate_output_tuple(observation, info)\n",
    "            fileout.write(json.dumps(output_dict) + '\\n')\n",
    "\n",
    "    print('Total Score', info['score'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Print each valid action recorded for the last processed step.\n",
    "for valid_action in info['valid_act']:\n",
    "    print(valid_action)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Same listing as the cell above. To list all candidate actions instead, use:\n",
    "#     for a in info['act']: print(a.action)\n",
    "for entry in info['valid_act']:\n",
    "    print(entry)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
