
import gym
import nle
import torch
from nle.nethack import objclass, permonst

from il_scale.nethack.networks.topline_net import TopLineEncoder

# Scratch script: inspect NLE inventory observations and run them through
# TopLineEncoder interactively. Each breakpoint() drops into the debugger so
# the intermediate tensors can be examined by hand.

inventory_encoder = TopLineEncoder(512)

# NOTE(review): with savedir="" NLE may not write any ttyrecs, which would make
# save_ttyrec_every=1 ineffective -- confirm against the NLE version in use.
env = gym.make('NetHackChallenge-v0', save_ttyrec_every=1, savedir="")
try:
    # Old-style gym API assumed: reset() returns only the observation dict.
    obs = env.reset()

    # torch.from_numpy shares memory with the NumPy arrays (no copy is made).
    inv_strs = torch.from_numpy(obs['inv_strs'])
    inv_letters = torch.from_numpy(obs['inv_letters'])
    inv_glyphs = torch.from_numpy(obs['inv_glyphs'])

    breakpoint()

    # Concatenate along the last dimension: prepend each item's letter to its
    # description string. Assumes inv_letters is 1-D (one entry per slot) and
    # inv_strs is 2-D (slot, chars) -- TODO confirm against the NLE obs spec.
    total_inv = torch.cat([inv_letters.unsqueeze(1), inv_strs], dim=-1)
    breakpoint()

    # Reuse the tensor built above instead of converting obs['inv_strs'] again.
    # NOTE(review): the encoder is fed inv_strs, not the total_inv built above;
    # confirm which input is intended.
    out = inventory_encoder(inv_strs)
    breakpoint()
finally:
    # Release the environment even if the debugger session is aborted
    # (e.g. quitting pdb raises BdbQuit).
    env.close()