{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from dataset import TextDataset\n",
    "import numpy as np\n",
    "import sys, os, json\n",
    "import gzip\n",
    "from colored import fg, attr, bg\n",
    "\n",
    "from env import JerichoEnv\n",
    "from tqdm import tqdm\n",
    "from jericho import *\n",
    "\n",
    "import random\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import re\n",
    "from jericho.util import verb_usage_count\n",
    "from jericho.template_action_generator import TemplateActionGenerator\n",
    "\n",
    "class TemplateActionParser(TemplateActionGenerator):\n",
    "    def __init__(self, rom_bindings):        \n",
    "        self.templates_alias_dict = {}\n",
    "        self.verb_to_templates = {}\n",
    "        self.template2template = {}\n",
    "        super(TemplateActionParser, self).__init__(rom_bindings)\n",
    "        \n",
    "        self.id2template = None\n",
    "        self.template2id = None\n",
    "        \n",
    "        self.additional_templates = ['land']\n",
    "        self.templates = list(set(self.templates + self.additional_templates))\n",
    "\n",
    "        self.templates.sort()\n",
    "        self._compute_template()\n",
    "        \n",
    "        BASIC_ACTIONS = 'north/south/west/east/northwest/southwest/northeast/southeast/up/down/enter/exit/take all'.split('/')\n",
    "        self.BASIC_ACTIONS = {k:1 for k in BASIC_ACTIONS}\n",
    "        \n",
    "        self.add_template2template = {}\n",
    "        for action in list(self.BASIC_ACTIONS.keys()) + self.additional_templates + ['examine OBJ']:\n",
    "            self.add_template2template[action] = action\n",
    "        \n",
    "        \n",
    "    def _preprocess_templates(self, templates, max_word_length):\n",
    "        '''\n",
    "        Converts templates with multiple verbs and takes the first verb.\n",
    "        '''\n",
    "        out = []\n",
    "        vb_usage_fn = lambda verb: verb_usage_count(verb, max_word_length)\n",
    "        p = re.compile('\\S+(/\\S+)+')\n",
    "        for template in templates:\n",
    "#             print(template)\n",
    "            if not template:\n",
    "                continue\n",
    "            has_alias = True\n",
    "            while True:\n",
    "                match = p.search(template)\n",
    "                if not match:\n",
    "#                     print('{} not matched'.format(template))\n",
    "                    has_alias = False\n",
    "                    break\n",
    "                    \n",
    "                verb_alias = match.group().split('/')\n",
    "                \n",
    "                verb = max(match.group().split('/'), key=vb_usage_fn)\n",
    "                verb_template = template[:match.start()] + verb + template[match.end():]\n",
    "                \n",
    "                for alias in verb_alias:\n",
    "                    alias_template = template[:match.start()] + alias + template[match.end():]\n",
    "                    self.template2template[alias_template] = verb_template\n",
    "                    \n",
    "                    if alias in self.verb_to_templates:\n",
    "                        self.verb_to_templates[alias].append(alias_template)\n",
    "                    else:\n",
    "                        self.verb_to_templates[alias] = [alias_template]\n",
    "                \n",
    "#                 for alias in verb_alias:\n",
    "#                     if alias in self.verb_to_templates:\n",
    "#                         self.verb_to_templates[alias].append(template)\n",
    "#                     else:\n",
    "#                         self.verb_to_templates[alias] = [template]\n",
    "                template = verb_template\n",
    "                \n",
    "            ts = template.split()\n",
    "            if ts[0] in defines.ILLEGAL_ACTIONS:\n",
    "                continue\n",
    "            if ts[0] in defines.NO_EFFECT_ACTIONS and len(ts) == 1:\n",
    "                continue\n",
    "                \n",
    "            if not has_alias:\n",
    "                t_tokens = template.split()\n",
    "                alias = t_tokens[0]\n",
    "                verb_alias = [alias]\n",
    "                if alias in self.verb_to_templates:\n",
    "                    self.verb_to_templates[alias].append(template)\n",
    "                else:\n",
    "                    self.verb_to_templates[alias] = [template]\n",
    "                    \n",
    "                self.template2template[template] = template\n",
    "                \n",
    "            self.templates_alias_dict[template] = verb_alias\n",
    "            out.append(template)\n",
    "        return out\n",
    "    \n",
    "    def _compute_template(self):\n",
    "        self.id2template = {}\n",
    "        self.template2id = {}\n",
    "        for i, t in enumerate(self.templates):\n",
    "            self.id2template[i] = t\n",
    "            self.template2id[t] = i\n",
    "        return\n",
    "\n",
    "    def parse_action(self, action):\n",
    "\n",
    "        tokens = action.split()\n",
    "        verb = tokens[0]\n",
    "#         if verb == 'down':\n",
    "#             print(verb in self.BASIC_ACTIONS and len(tokens) == 1)\n",
    "\n",
    "        if (verb in self.BASIC_ACTIONS or verb in self.additional_templates) and len(tokens) == 1:\n",
    "            return [verb]\n",
    "\n",
    "        if verb not in self.verb_to_templates:\n",
    "#             if (verb in self.BASIC_ACTIONS or verb in self.additional_templates) and len(tokens) == 1:\n",
    "#     #             print(verb)\n",
    "#                 return [verb]\n",
    "            if verb == 'examine':\n",
    "                return ['examine OBJ', ' '.join(tokens[1:])]\n",
    "            else:\n",
    "                print('cannot recognize verb:', verb)\n",
    "                return None\n",
    "        else:\n",
    "            templates = self.verb_to_templates[verb]\n",
    "            for template in templates:\n",
    "#                 print(template.split())\n",
    "                t_tokens = template.split()\n",
    "#                 print(t_tokens)\n",
    "                \n",
    "                slot_num = 0\n",
    "                for t_token in t_tokens:\n",
    "#                     print(t_token, 'OBJ', t_token == 'OBJ')\n",
    "                    if t_token == 'OBJ':\n",
    "                        slot_num += 1\n",
    "#                 ' \\S+'\n",
    "                re_str = template.replace('OBJ', '(\\S+)')\n",
    "    #             print(re_str)\n",
    "    #             p = re.compile('\\S+(/\\S+)+')\n",
    "                p = re.compile(re_str)\n",
    "\n",
    "                match = p.search(action)\n",
    "                if not match:\n",
    "                    continue\n",
    "                elif match.group() == action:\n",
    "                    ret_tuple = [template]\n",
    "#                     print(slot_num)\n",
    "                    for i in range(slot_num):\n",
    "                        ret_tuple.append(match.group(i+1))\n",
    "                    return ret_tuple\n",
    "                else:\n",
    "                    continue\n",
    "        \n",
    "        templates = self.verb_to_templates[verb]\n",
    "        for template in templates:\n",
    "            t_tokens = template.split()\n",
    "            slot_num = 0\n",
    "            for t_id, t_token in enumerate(t_tokens):\n",
    "                if t_token == 'OBJ':\n",
    "                    slot_num += 1\n",
    "                    t_tokens[t_id] = 'OBJ%d'%(slot_num - 1)\n",
    "#                 ' \\S+'\n",
    "\n",
    "            re_str = ' '.join(t_tokens)\n",
    "            for i in range(slot_num):\n",
    "                re_str = re_str.replace('OBJ%d'%(i), '(?P<obj%d>\\S+( \\S+)*)'%(i))\n",
    "#             print(re_str)\n",
    "#             p = re.compile('\\S+(/\\S+)+')\n",
    "            p = re.compile(re_str)\n",
    "\n",
    "            match = p.search(action)\n",
    "            if not match:\n",
    "                continue\n",
    "            elif match.group() == action:\n",
    "                ret_tuple = [template]\n",
    "                for i in range(slot_num):\n",
    "                    ret_tuple.append(match.group('obj%d'%(i)))\n",
    "                return ret_tuple\n",
    "            else:\n",
    "                continue\n",
    "        return None   \n",
    "    \n",
    "# act_par = TemplateActionParser(bindings)\n",
    "# print(act_par.templates_alias_dict)\n",
    "# print(act_par.verb_to_templates)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_walkthrough(rom_path):\n",
    "    # Create the environment, optionally specifying a random seed\n",
    "#     rom_path = \"roms/jericho-game-suite/{}\".format(game2rom[game_name])\n",
    "\n",
    "    bindings = load_bindings(rom_path)\n",
    "    scores = []\n",
    "    cum_r = 0.0\n",
    "    step = 0\n",
    "    if 'walkthrough' in bindings:\n",
    "        walkthrough = bindings['walkthrough'].split('/')\n",
    "        seed = bindings['seed']\n",
    "        env = FrotzEnv(rom_path, seed=seed)\n",
    "        for act in walkthrough:\n",
    "            print('step:', step)\n",
    "            step += 1\n",
    "            print('act:', act)\n",
    "            observation, reward, done, info = env.step(act)\n",
    "            print('obs:', observation)\n",
    "            scores.append(reward)\n",
    "            cum_r += reward\n",
    "            print('curR:', cum_r)\n",
    "            \n",
    "    return scores\n",
    "\n",
    "\n",
    "game_rom_path = \"../roms/jericho-game-suite/zork1.z5\"\n",
    "step_scores = get_walkthrough(game_rom_path)\n",
    "\n",
    "scores = step_scores[:100]\n",
    "scores = np.array(scores)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "'''\n",
    "testing\n",
    "zork1 40.01 44.62 34 35 33.6 35 32 41.6 31\n",
    "library 36.76 46.45 14.3 19 10.0 18 19 19 18\n",
    "detective 60.28 63.21 207.9 214 246.1 274 320 330 304\n",
    "balances 55.26 56.49 10 10 9.8 10 10 10 10\n",
    "pentari 63.89 68.37 50.7 56 48.2 56 56 58 40\n",
    "ztuu 28.71 29.76 6 9 5 5 5 11.8 5\n",
    "ludicorp 52.32 59.95 17.8 19 17.6 19 19 22.8 20.6\n",
    "deephome 8.03 9.27 1 1 1 1 8 6 1\n",
    "temple \n",
    "'''\n",
    "# m, h, e, e, e, m, m, m, m, m\n",
    "eval_games = ['zork3', 'anchor', 'detective', 'ztuu', 'temple', 'yomomma', 'jewel', 'gold', 'karn', 'zenon']\n",
    "eval_games_left = ['anchor', 'yomomma']\n",
    "train_games_left = ['ludicorp', 'spirit', 'tryst205', 'spellbrkr']\n",
    "\n",
    "all_games = ['905', 'acorncourt', 'advent', 'adventureland', 'afflicted', 'anchor', 'awaken', \n",
    "         'balances', 'deephome', 'detective', 'dragon', 'enchanter', 'gold', 'inhumane', 'jewel', \n",
    "         'karn', 'library', 'ludicorp', 'moonlit', 'omniquest', 'pentari', 'reverb', 'snacktime', \n",
    "         'sorcerer', 'spellbrkr', 'spirit', 'temple', 'tryst205', 'yomomma', 'zenon', 'zork1', 'zork3', 'ztuu']\n",
    "games_with_ns_actions = ['library', 'pentari', 'ludicorp', 'deephome', 'advent', \n",
    "                         'balances', 'sorcerer', 'tryst205', 'spellbrkr', 'enchanter', 'spirit']\n",
    "\n",
    "# games = ['zork1', 'zork3', 'enchanter', 'spellbrkr', 'sorcerer']\n",
    "# games = ['zork2', 'wishbringer']\n",
    "zork_games = ['zork1', 'zork3', 'enchanter', 'zork2', 'wishbringer', 'sorcerer', 'spellbrkr']\n",
    "\n",
    "hard_games = ['sorcerer', 'tryst205', 'spellbrkr', 'anchor', 'enchanter', 'spirit']\n",
    "middle_games = ['ludicorp', 'deephome', 'yomomma', 'advent', 'jewel', 'zork1', 'gold', 'balances',\n",
    "                'karn', 'zenon', 'zork3']\n",
    "\n",
    "games = []\n",
    "for game in all_games:\n",
    "    if game not in zork_games and game not in eval_games:\n",
    "        print(game)\n",
    "        games.append(game)\n",
    "        \n",
    "# games = []\n",
    "# for game in all_games:\n",
    "#     if game not in eval_games:\n",
    "#         print(game)\n",
    "#         games.append(game)\n",
    "# print(games)\n",
    "# games = games_with_ns_actions \n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#                 state_hash = _get_world_state_hash(self.env)\n",
    "#                 save = self.env.get_state()\n",
    "#                 look, _, _, _ = self.env.step('look')\n",
    "#                 self.env.set_state(save)\n",
    "#                 inv, _, _, _ = self.env.step('inventory')\n",
    "#                 self.env.set_state(save)\n",
    "\n",
    "from jericho.util import clean\n",
    "\n",
    "def get_full_obs(env, observation):\n",
    "    save = env.get_state()\n",
    "    look, _, _, _ = env.step('look')\n",
    "    env.set_state(save)\n",
    "    inv, _, _, _ = env.step('inventory')\n",
    "    env.set_state(save)\n",
    "\n",
    "    ob = clean(look) + '|' + clean(inv) + '|' + clean(observation)\n",
    "    \n",
    "    return ob\n",
    "\n",
    "def generate_sas_data(game, rom_path):\n",
    "    # Create the environment, optionally specifying a random seed\n",
    "#     rom_path = \"roms/jericho-game-suite/{}\".format(game2rom[game_name])\n",
    "\n",
    "    bindings = load_bindings(rom_path)\n",
    "    scores = []\n",
    "    cum_r = 0.0\n",
    "    step = 0\n",
    "    if 'walkthrough' in bindings:\n",
    "        walkthrough = bindings['walkthrough'].split('/')\n",
    "        seed = bindings['seed']\n",
    "        env = FrotzEnv(rom_path, seed=seed)\n",
    "        \n",
    "        filein = open('../data/ssa_data/jecc_sup/{}.ssa.wt_traj.txt'.format(game))\n",
    "        fileout = open('../data/ssa_data/jecc_sup/{}.sas.wt_traj.txt.new'.format(game), 'w')\n",
    "        \n",
    "        lines = filein.readlines()\n",
    "        \n",
    "        print(len(lines))\n",
    "        \n",
    "        for idx, act in enumerate(walkthrough):\n",
    "            print('step:', step)\n",
    "            step += 1\n",
    "            print('act:', act)\n",
    "            \n",
    "            state_save = env.get_state()\n",
    "            \n",
    "            observation, reward, done, info = env.step(act)\n",
    "            print('obs:', observation)\n",
    "            scores.append(reward)\n",
    "            cum_r += reward\n",
    "            print('curR:', cum_r)          \n",
    "            new_ob = get_full_obs(env, observation)\n",
    "#             print(new_ob)\n",
    "            \n",
    "            state_prim_save = env.get_state()\n",
    "            \n",
    "            if idx > 0:\n",
    "                for valid_act_group in wt_ssa_data['valid_actions']:\n",
    "                    valid_act_tuple = valid_act_group[0]\n",
    "                    valid_act = valid_act_tuple['a']\n",
    "                    \n",
    "                    env.set_state(state_save)\n",
    "                    observation, reward, done, info = env.step(valid_act)\n",
    "                    ob = get_full_obs(env, observation)\n",
    "#                     print('[' + valid_act + ']: ' + ob)\n",
    "                    valid_act_tuple['observations'] = ob\n",
    "    \n",
    "#                 print(wt_ssa_data['valid_actions'])\n",
    "                fileout.write(json.dumps(wt_ssa_data) + '\\n')\n",
    "                env.set_state(state_prim_save)\n",
    "                    \n",
    "            \n",
    "            if idx < len(walkthrough) - 1 and idx < len(lines) - 1:\n",
    "#                 line = filein.readline()\n",
    "                line = lines[idx - 1]\n",
    "                wt_ssa_data = json.loads(line)\n",
    "\n",
    "                ssa_obs = '|'.join(wt_ssa_data['observations'].split('|')[0:3])\n",
    "\n",
    "    #             print(ob == ssa_obs)\n",
    "                if new_ob != ssa_obs:\n",
    "                    print('ERROR: different obs')\n",
    "                    print(new_ob)\n",
    "                    print(ssa_obs)\n",
    "                    return\n",
    "            \n",
    "            if idx == len(lines) or idx == len(walkthrough):\n",
    "                break\n",
    "                \n",
    "        filein.close()\n",
    "        fileout.close()\n",
    "            \n",
    "    return scores\n",
    "\n",
    "\n",
    "def generate_sas_data_new(game, rom_path):\n",
    "    # Create the environment, optionally specifying a random seed\n",
    "#     rom_path = \"roms/jericho-game-suite/{}\".format(game2rom[game_name])\n",
    "\n",
    "    bindings = load_bindings(rom_path)\n",
    "    scores = []\n",
    "    cum_r = 0.0\n",
    "    step = 0\n",
    "    if 'walkthrough' in bindings:\n",
    "        walkthrough = bindings['walkthrough'].split('/')\n",
    "        seed = bindings['seed']\n",
    "        env = FrotzEnv(rom_path, seed=seed)\n",
    "        \n",
    "        filein = open('../data/ssa_data/jecc_sup/{}.ssa.wt_traj.txt'.format(game))\n",
    "        fileout = open('../data/ssa_data/jecc_sup/{}.sas.wt_traj.txt.new'.format(game), 'w')\n",
    "        \n",
    "        lines = filein.readlines()\n",
    "        \n",
    "        print(len(lines))\n",
    "        \n",
    "        for idx, act in enumerate(walkthrough):\n",
    "#             print('step:', step)\n",
    "            step += 1\n",
    "#             print('act:', act)\n",
    "            \n",
    "            state_save = env.get_state()\n",
    "            \n",
    "            observation, reward, done, info = env.step(act)\n",
    "#             print('obs:', observation)\n",
    "            scores.append(reward)\n",
    "            cum_r += reward\n",
    "#             print('curR:', cum_r)          \n",
    "            new_ob = get_full_obs(env, observation)\n",
    "#             print(new_ob)\n",
    "            \n",
    "            state_prim_save = env.get_state()\n",
    "            \n",
    "            if idx > 0:\n",
    "                if isinstance(wt_ssa_data['valid_actions'][0], dict):\n",
    "                    wt_ssa_data['valid_actions'] = [wt_ssa_data['valid_actions']]\n",
    "                for valid_act_group in wt_ssa_data['valid_actions']:\n",
    "                    valid_act_tuple = valid_act_group[0]\n",
    "                    valid_act = valid_act_tuple['a']\n",
    "                    \n",
    "                    env.set_state(state_save)\n",
    "                    observation, reward, done, info = env.step(valid_act)\n",
    "                    ob = get_full_obs(env, observation)\n",
    "#                     print('[' + valid_act + ']: ' + ob)\n",
    "                    valid_act_tuple['observations'] = ob\n",
    "    \n",
    "#                 print(wt_ssa_data['valid_actions'])\n",
    "                fileout.write(json.dumps(wt_ssa_data) + '\\n')\n",
    "                env.set_state(state_prim_save)\n",
    "\n",
    "            if idx == len(lines) or idx == len(walkthrough):\n",
    "                break\n",
    "            \n",
    "            line = lines[idx]\n",
    "            wt_ssa_data = json.loads(line)\n",
    "\n",
    "            ssa_obs = '|'.join(wt_ssa_data['observations'].split('|')[0:3])\n",
    "\n",
    "#             print(ob == ssa_obs)\n",
    "            if new_ob != ssa_obs:\n",
    "                print('step:', step)\n",
    "                print('ERROR: different obs')\n",
    "                print(new_ob)\n",
    "                print(ssa_obs)\n",
    "                return\n",
    "\n",
    "                \n",
    "        filein.close()\n",
    "        fileout.close()\n",
    "            \n",
    "    return scores\n",
    "\n",
    "# games = ['zork1', 'zork3', 'enchanter', 'zork2', 'wishbringer', 'sorcerer']\n",
    "games = ['wishbringer']\n",
    "\n",
    "print('#number of games: {}'.format(len(games)))\n",
    "\n",
    "roms = os.listdir('../roms/jericho-game-suite/')\n",
    "game2rom = {}\n",
    "logs = []\n",
    "for game in games:\n",
    "    for rom in roms:\n",
    "        if rom.startswith(game + '.z'):\n",
    "            game2rom[game] = rom\n",
    "#             print('find {} for {}'.format(rom, game))\n",
    "            logs.append('find {} for {}'.format(rom, game))\n",
    "    if game not in game2rom:\n",
    "        print('cannot find rom for {}'.format(game))\n",
    "                        \n",
    "print('#number of roms founds: {}'.format(len(logs)))\n",
    "\n",
    "for game in games:\n",
    "    print('working on {}'.format(game))\n",
    "    game_rom_path = \"../roms/jericho-game-suite/{}\".format(game2rom[game])\n",
    "    step_scores = generate_sas_data_new(game, game_rom_path)\n",
    "    print(np.sum(np.array(step_scores)))\n",
    "#     break\n",
    "#     scores = step_scores[:100]\n",
    "#     scores = np.array(scores)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(len(lines))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# def generate_output_tuple(observation, info):\n",
    "\n",
    "#     output_dict = {}\n",
    "# #     output_dict['actions'] = []\n",
    "#     output_dict['valid_actions'] = []\n",
    "#     output_dict['observations'] = observation\n",
    "\n",
    "# #     for a in info['act']:\n",
    "# #         new_dict = {}\n",
    "# #         new_dict['a'] = a.action\n",
    "# #         new_dict['t'] = a.template_id\n",
    "# #         new_dict['o'] = a.obj_ids\n",
    "# #     #     obj_ids = [str(x) for x in a.obj_ids]\n",
    "# #     #     print(a.action + '\\t' + str(a.template_id) + '\\t' + ' '.join(obj_ids))\n",
    "# #         output_dict['actions'].append(new_dict)\n",
    "\n",
    "#     if isinstance(info['valid_act'][0], str):\n",
    "#         new_dict = {}\n",
    "#         new_dict['a'] = info['valid_act'][0]\n",
    "#         output_dict['valid_actions'].append(new_dict)\n",
    "        \n",
    "#     else:\n",
    "#         for a_list in info['valid_act']:\n",
    "#             new_list = []\n",
    "#             for a in a_list:\n",
    "#                 new_dict = {}\n",
    "#                 new_dict['a'] = a.action\n",
    "#                 new_dict['t'] = a.template_id\n",
    "#                 new_dict['o'] = a.obj_ids\n",
    "                \n",
    "#                 new_list.append(new_dict)\n",
    "#         #     obj_ids = [str(x) for x in a.obj_ids]\n",
    "#         #     print(a.action + '\\t' + str(a.template_id) + '\\t' + ' '.join(obj_ids))\n",
    "#             output_dict['valid_actions'].append(new_list)\n",
    "        \n",
    "#     return output_dict\n",
    "\n",
    "# game_max_scores = []\n",
    "\n",
    "# games = ['zork1']\n",
    "\n",
    "# for game in games:\n",
    "    \n",
    "#     if os.path.isfile('../data/ssa_data/zork_universe_sup/{}.sas.wt_traj.txt'.format(game)):\n",
    "#         continue\n",
    "    \n",
    "#     print('generating trajectary for game: {}'.format(game))\n",
    "# #     continue\n",
    "\n",
    "#     rom_path = \"../roms/jericho-game-suite/{}\".format(game2rom[game]) # \"../roms/jericho-game-suite/zork1.z5\"\n",
    "\n",
    "#     bindings = load_bindings(rom_path)\n",
    "#     scores = []\n",
    "#     if 'walkthrough' in bindings:\n",
    "#         walkthrough = bindings['walkthrough'].split('/')\n",
    "#         seed = bindings['seed']\n",
    "\n",
    "#         filein = open('../data/ssa_data/zork_universe_sup/{}.ssa.wt_traj.txt'.format(game))\n",
    "#         fileout = open('../data/ssa_data/zork_universe_sup/{}.sas.wt_traj.txt'.format(game), 'w')\n",
    "\n",
    "#         env = JerichoEnv(rom_path, seed=seed)\n",
    "#     #     env = FrotzEnv(rom_path, seed=seed)\n",
    "    \n",
    "#         print(walkthrough)\n",
    "#         for idx, act in enumerate(walkthrough):\n",
    "\n",
    "#             observation, reward, done, info = env.env.step(act)\n",
    "            \n",
    "#             line = filein.readline()\n",
    "#             wt_ssa_data = json.loads(line)\n",
    "#             print(wt_ssa_data)\n",
    "            \n",
    "#             break\n",
    "            \n",
    "# #             observation, reward, done, info = env.step(act)\n",
    "#     #         observation, reward, done, info = env.step(act, parallel=False)\n",
    "#             scores.append(reward)\n",
    "\n",
    "#     #         if len(info['valid_act']) == 3 and info['valid_act'][0] == 'wait' and info['valid_act'][1] == 'yes' and info['valid_act'][2] == 'no':\n",
    "    \n",
    "#             if len(info['valid_act']) == 0:\n",
    "#                 if idx == len(walkthrough) - 1:\n",
    "#     #                 print(observation)\n",
    "#                     break\n",
    "#                 else:\n",
    "#                     info['valid_act'] = [walkthrough[idx+1]]\n",
    "\n",
    "#             output_dict = generate_output_tuple(observation, info)    \n",
    "#             fileout.write(json.dumps(output_dict) + '\\n')\n",
    "#     #         break\n",
    "\n",
    "#         print('Total Score', info['score'])\n",
    "#         game_max_scores.append(info['score'])\n",
    "\n",
    "#         fileout.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
