{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import imageio\n",
    "os.environ[\"CUDA_DEVICE_ORDER\"] = \"PCI_BUS_ID\"   # see issue #152\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "from hydra import initialize, initialize_config_module, initialize_config_dir, compose\n",
    "from omegaconf import OmegaConf"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "## WITH 20 options lorl states and horizon 10\n",
    "\n",
    "# BASE_PATH=\"/atlas/u/divgarg/projects/Language-RL/hrl/outputs/2021-12-18/21-35-16/checkpoints/LorlEnv-v0-40108-traj_option-2021-12-18-21:35:16/\"\n",
    "# ITER = 360"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "## WITH 20 OPTIONS LORL IMAGES, HORIZON 10\n",
    "\n",
    "# BASE_PATH = \"/atlas/u/divgarg/projects/Language-RL/hrl/outputs/2021-12-22/11-59-46/checkpoints/LorlEnv-v0-40108-traj_option-2021-12-22-11:59:46\"\n",
    "# ITER = 500"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "## WITH 20 OPTIONS LORL IMAGES, HORIZON 10\n",
    "\n",
    "#BASE_PATH = \"/atlas/u/divgarg/projects/Language-RL/hrl/outputs/2022-01-07/14-24-30/checkpoints/LorlEnv-v0-40108-traj_option-2022-01-07-14:24:31\"\n",
    "#ITER = 70 # 20, 70\n",
    "# eval_episode_factor = 10"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "#BASE_PATH = \"/atlas/u/divgarg/projects/Language-RL/hrl/outputs/2022-01-13/11-21-27/checkpoints/LorlEnv-v0-40108-traj_option-2022-01-13-11:21:28\"\n",
    "BASE_PATH = \"/atlas/u/divgarg/projects/Language-RL/hrl/outputs/2022-01-13/11-25-35/checkpoints/LorlEnv-v0-40108-vanilla-2022-01-13-11:25:35\"\n",
    "ITER = 500"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "CKPT = f\"{BASE_PATH}/model_{ITER}.ckpt\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'cuda_deterministic': False, 'wandb': True, 'seed': 0, 'resume': False, 'checkpoint_path': '/atlas/u/divgarg/projects/Language-RL/hrl/outputs/2022-01-13/11-25-35/checkpoints/LorlEnv-v0-40108-vanilla-2022-01-13-11:25:35/model_500.ckpt', 'eval': True, 'render': False, 'render_path': './eval_${env.name}/', 'batch_size': 512, 'max_iters': 500, 'warmup_steps': 2500, 'lr_decay': 0.1, 'decay_steps': 100000, 'option_dim': 128, 'codebook_dim': 16, 'parallel': True, 'savedir': 'checkpoints', 'savepath': None, 'method': 'option_dt', 'use_iq': False, 'learning_rate': 1e-05, 'lm_learning_rate': 1e-07, 'weight_decay': 0.0001, 'os_learning_rate': 1e-05, 'trainer': {'device': None, 'state_il': False, 'num_eval_episodes': 5, 'eval_every': 5, 'K': '${model.K}'}, 'model': {'name': 'traj_option', 'horizon': 10, 'K': 10, 'train_lm': True, 'use_iq': '${use_iq}', 'method': '${model.name}', 'state_reconstruct': False, 'lang_reconstruct': False}, 'env': {'skip_words': ['go', 'to', 'the', 'a', '[SEP]'], 'name': 'LorlEnv-v0', 'state_dim': '(3, 64, 64)', 'action_dim': 5, 'discrete': False, 'eval_offline': False, 'use_state': False, 'eval_episode_factor': 10, 'eval_env': None}, 'option_selector': {'option_transformer': {'hidden_size': 128, 'n_layer': 1, 'n_head': 4, 'max_length': None, 'max_ep_len': None, 'activation_function': 'relu', 'n_positions': 1024, 'dropout': 0.1, 'output_attention': False}, 'horizon': '${model.horizon}', 'use_vq': True, 'kmeans_init': True, 'commitment_weight': 0.25, 'num_options': 20, 'num_hidden': 2}, 'iq': {'alpha': 0.1, 'div': 'chi', 'loss': 'value', 'gamma': 0.99, 'use_target': False}, 'log_interval': 1, 'save_interval': 50, 'hydra_base_dir': '', 'exp_name': '', 'project_name': '${env.name}', 'state_reconstructor': {'num_hidden': 2, 'hidden_size': 128}, 'lang_reconstructor': {'num_hidden': 2, 'hidden_size': 128, 'max_options': None}, 'dt': {'hidden_size': 128, 'n_layer': 1, 'n_head': 4, 'option_il': False, 'activation_function': 'relu', 'n_positions': 
1024, 'dropout': 0.1}, 'train_dataset': {'expert_location': '/atlas/u/divgarg/datasets/lorel/may_08_sawyer_50k/prep_data.pkl', 'num_trajectories': 40108, 'normalize_states': False, 'no_lang': False, 'seed': '${seed}'}, 'val_dataset': {'expert_location': None, 'num_trajectories': '${trainer.num_eval_episodes}', 'normalize_states': False, 'seed': '${seed}'}}\n"
     ]
    }
   ],
   "source": [
    "with initialize(\"conf\"):\n",
    "    cfg = compose(config_name=\"config.yaml\", overrides=[\"eval=True\", \"method=option_dt\", f\"checkpoint_path={CKPT}\"])\n",
    "    print(cfg)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "from main import *\n",
    "import ast\n",
    "\n",
    "def evaluate(cfg):\n",
    "    \"\"\"Rebuild the model and eval setup from a saved checkpoint.\n",
    "\n",
    "    cfg: hydra config; cfg.checkpoint_path must point at a model_<iter>.ckpt\n",
    "    that stores the weights plus the training config under 'config'.\n",
    "    Returns (model, tokenizer, train_loader, env, args) ready for rollouts.\n",
    "    Raises NotImplementedError for offline eval or unknown env families.\n",
    "    \"\"\"\n",
    "    # load saved arguments: the checkpoint carries the full training config\n",
    "    checkpoint = torch.load(cfg.checkpoint_path)\n",
    "    args = checkpoint['config']\n",
    "    max_length = checkpoint['train_dataset_max_length']\n",
    "    # overlay eval-time flags from the live cfg onto the saved config\n",
    "    args.eval = cfg.eval\n",
    "    args.render = cfg.render\n",
    "    args.checkpoint_path = cfg.checkpoint_path\n",
    "    device = cfg.trainer.device\n",
    "\n",
    "    # args.env.eval_episode_factor = 1\n",
    "    # Set num train_trajs to something small (eval only needs dataset metadata)\n",
    "    args.train_dataset.num_trajectories = 1000\n",
    "    print(OmegaConf.to_yaml(args))\n",
    "\n",
    "    args.method = args.model.name\n",
    "\n",
    "    tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')\n",
    "    num_eval_episodes = args.trainer.num_eval_episodes\n",
    "\n",
    "    if not args.env.eval_offline:\n",
    "        env = gym.make(args.env.name)\n",
    "        #env.seed(args.seed)\n",
    "        state_dim = args.env.state_dim\n",
    "        if isinstance(state_dim, str):\n",
    "            # state_dim may be stored as a string such as '(3, 64, 64)'\n",
    "            state_dim = ast.literal_eval(state_dim)\n",
    "        action_dim = args.env.action_dim\n",
    "    else:\n",
    "        # env / state_dim / action_dim are only created for online eval; the\n",
    "        # rest of this function needs them, so fail loudly instead of hitting\n",
    "        # a NameError further down.\n",
    "        raise NotImplementedError('offline evaluation is not supported in this notebook')\n",
    "\n",
    "    if 'BabyAI' in args.env.name:\n",
    "        state_dim += 4*args.env.use_direction\n",
    "\n",
    "    train_dataset_args = dict(args.train_dataset)\n",
    "    batch_size = args.batch_size\n",
    "\n",
    "    if 'BabyAI' in args.env.name:\n",
    "        train_dataset = ExpertDataset(**train_dataset_args, use_direction=args.env.use_direction)\n",
    "    elif 'Lorl' in args.env.name:\n",
    "        # train_dataset_args also contains a split here for the validation data size\n",
    "        train_dataset = ExpertDataset(**train_dataset_args, use_state=args.env.use_state)\n",
    "    else:\n",
    "        raise NotImplementedError\n",
    "    train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size,\n",
    "                              shuffle=True, drop_last=True)\n",
    "    # the DataLoader keeps its own reference; drop ours to free the name\n",
    "    del train_dataset\n",
    "\n",
    "    eval_episode_factor = args.env.eval_episode_factor\n",
    "\n",
    "    if args.method == 'traj_option':\n",
    "        args.option_selector.option_transformer.max_length = int(max_length)\n",
    "        args.option_selector.option_transformer.max_ep_len = eval_episode_factor * int(max_length)\n",
    "        # args.option_selector.option_transformer.output_attention = True\n",
    "\n",
    "    option_selector_args = dict(args.option_selector)\n",
    "    option_selector_args['state_dim'] = state_dim\n",
    "    option_selector_args['option_dim'] = args.option_dim\n",
    "    option_selector_args['codebook_dim'] = args.codebook_dim\n",
    "    # option_selector_args['option_transformer']['output_attention'] = True\n",
    "\n",
    "    state_reconstructor_args = dict(args.state_reconstructor)\n",
    "    lang_reconstructor_args = dict(args.lang_reconstructor)\n",
    "    decision_transformer_args = {'state_dim': state_dim,\n",
    "                                 'action_dim': action_dim,\n",
    "                                 'option_dim': args.option_dim,\n",
    "                                 'discrete': args.env.discrete,\n",
    "                                 'hidden_size': args.dt.hidden_size,\n",
    "                                 'use_language': args.method == 'vanilla',\n",
    "                                 'use_options': args.method != 'vanilla',\n",
    "                                 'max_length': max_length if args.method != 'traj_option' else args.model.K,\n",
    "                                 # setting this to be sufficiently large so that there is enough of a buffer during eval\n",
    "                                 'max_ep_len': eval_episode_factor*max_length,\n",
    "                                 'action_tanh': False,\n",
    "                                 'n_layer': args.dt.n_layer,\n",
    "                                 'n_head': args.dt.n_head,\n",
    "                                 'n_inner': 4*args.dt.hidden_size,\n",
    "                                 'activation_function': args.dt.activation_function,\n",
    "                                 'n_positions': args.dt.n_positions,\n",
    "                                 'resid_pdrop': args.dt.dropout,\n",
    "                                 'attn_pdrop': args.dt.dropout,\n",
    "                                 }\n",
    "    hrl_model_args = dict(args.model)\n",
    "    # NOTE(review): iq args are taken from the live cfg rather than the saved\n",
    "    # checkpoint config — presumably because older checkpoints may lack an\n",
    "    # 'iq' section; confirm this is intentional.\n",
    "    iq_args = cfg.iq\n",
    "\n",
    "    model = HRLModel(option_selector_args, state_reconstructor_args,\n",
    "                     lang_reconstructor_args, decision_transformer_args, iq_args, device, **hrl_model_args)\n",
    "\n",
    "    model = model.to(device=device)\n",
    "\n",
    "    trainer_args = dict(args.trainer)\n",
    "\n",
    "    trainer = Trainer(\n",
    "        model=model,\n",
    "        tokenizer=tokenizer,\n",
    "        optimizer=None,\n",
    "        train_loader=train_loader,\n",
    "        env=env,\n",
    "        val_loader=None,\n",
    "        scheduler=None,\n",
    "        **trainer_args\n",
    "    )\n",
    "\n",
    "    # Restore model/trainer weights from the checkpoint\n",
    "    trainer.load(args.checkpoint_path)\n",
    "    return model, tokenizer, train_loader, env, args"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "cuda_deterministic: false\n",
      "wandb: true\n",
      "seed: 0\n",
      "resume: false\n",
      "checkpoint_path: /atlas/u/divgarg/projects/Language-RL/hrl/outputs/2022-01-13/11-25-35/checkpoints/LorlEnv-v0-40108-vanilla-2022-01-13-11:25:35/model_500.ckpt\n",
      "eval: true\n",
      "render: false\n",
      "render_path: ./eval_${env.name}/\n",
      "batch_size: 512\n",
      "max_iters: 500\n",
      "warmup_steps: 2500\n",
      "lr_decay: 0.1\n",
      "decay_steps: 100000\n",
      "option_dim: 128\n",
      "codebook_dim: 16\n",
      "parallel: true\n",
      "savedir: checkpoints\n",
      "savepath: /atlas/u/divgarg/projects/Language-RL/hrl/outputs/2022-01-13/11-25-35/checkpoints/LorlEnv-v0-40108-vanilla-2022-01-13-11:25:35\n",
      "method: vanilla\n",
      "use_iq: false\n",
      "learning_rate: 1.0e-05\n",
      "lm_learning_rate: 1.0e-07\n",
      "weight_decay: 0.0001\n",
      "os_learning_rate: 1.0e-05\n",
      "trainer:\n",
      "  device: cuda:0\n",
      "  state_il: false\n",
      "  num_eval_episodes: 5\n",
      "  eval_every: 5\n",
      "  K: ${model.K}\n",
      "model:\n",
      "  name: vanilla\n",
      "  horizon: 5\n",
      "  K: 5\n",
      "  train_lm: true\n",
      "  method: ${model.name}\n",
      "  state_reconstruct: false\n",
      "  lang_reconstruct: false\n",
      "env:\n",
      "  skip_words:\n",
      "  - go\n",
      "  - to\n",
      "  - the\n",
      "  - a\n",
      "  - '[SEP]'\n",
      "  name: LorlEnv-v0\n",
      "  state_dim: (3, 64, 64)\n",
      "  action_dim: 5\n",
      "  discrete: false\n",
      "  eval_offline: false\n",
      "  use_state: false\n",
      "  eval_episode_factor: 10\n",
      "  eval_env: null\n",
      "option_selector:\n",
      "  option_transformer: null\n",
      "iq:\n",
      "  alpha: 0.1\n",
      "  div: chi\n",
      "  loss: value\n",
      "  gamma: 0.99\n",
      "  use_target: false\n",
      "log_interval: 1\n",
      "save_interval: 5\n",
      "hydra_base_dir: /atlas/u/divgarg/projects/Language-RL/hrl/outputs/2022-01-13/11-25-35\n",
      "exp_name: ''\n",
      "project_name: ${env.name}\n",
      "train_dataset:\n",
      "  expert_location: /atlas/u/divgarg/datasets/lorel/may_08_sawyer_50k/prep_data.pkl\n",
      "  num_trajectories: 1000\n",
      "  normalize_states: false\n",
      "  seed: ${seed}\n",
      "val_dataset:\n",
      "  expert_location: null\n",
      "  num_trajectories: ${trainer.num_eval_episodes}\n",
      "  normalize_states: false\n",
      "  seed: ${seed}\n",
      "state_reconstructor:\n",
      "  num_hidden: 2\n",
      "  hidden_size: 128\n",
      "lang_reconstructor:\n",
      "  num_hidden: 2\n",
      "  hidden_size: 128\n",
      "  max_options: null\n",
      "dt:\n",
      "  hidden_size: 256\n",
      "  n_layer: 2\n",
      "  n_head: 4\n",
      "  activation_function: relu\n",
      "  n_positions: 1024\n",
      "  dropout: 0.1\n",
      "\n"
     ]
    }
   ],
   "source": [
    "model, tokenizer, train_loader, env, args = evaluate(cfg)\n",
    "#max_length = args.option_selector.option_transformer.max_length"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# from typing import Dict, Iterable, Callable\n",
    "# import torch.nn as nn\n",
    "# from torch import Tensor\n",
    "\n",
    "# class Attention(nn.Module):\n",
    "#     def __init__(self, model: nn.Module):\n",
    "#         super().__init__()\n",
    "#         self.model = model\n",
    "#         self._attention = None\n",
    "\n",
    "#         model.option_selector.option_dt.register_forward_hook(self.save_attention_hook())\n",
    "\n",
    "#     def save_attention_hook(self) -> Callable:\n",
    "#         def fn(model, input, output):\n",
    "#             self._attention= output[-1]\n",
    "#         return fn\n",
    "\n",
    "#     def forward(self, x: Tensor) -> Dict[str, Tensor]:\n",
    "#         _ = self.model(x)\n",
    "#         return self._attention"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "#att_model = Attention(model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "id": "retired-postcard",
   "metadata": {},
   "outputs": [],
   "source": [
    "from PIL import Image\n",
    "from env import LorlWrapper, BabyAIWrapper\n",
    "from eval import get_action\n",
    "\n",
    "def traj(env, train_loader, model, tokenizer, args, render=False, render_freq=1, instr=None, orig_instr=None, **kwargs):\n",
    "    \"\"\"Roll out one episode in `env` and return the trajectory artifacts.\n",
    "\n",
    "    Pass dont_reset=<obs> in kwargs to continue from a previous observation\n",
    "    (env must then already be wrapped). Returns (word_embeddings, states,\n",
    "    timesteps, images, lm_input, options_list, episode_length, success,\n",
    "    (last_obs, env)).\n",
    "    \"\"\"\n",
    "    if not env:\n",
    "        # everything returned below is produced inside the rollout; without an\n",
    "        # env the original code would NameError at the return statement\n",
    "        raise ValueError('traj() requires an environment instance')\n",
    "\n",
    "    if 'dont_reset' not in kwargs:\n",
    "        if 'BabyAI' in args.env.name:\n",
    "            env = BabyAIWrapper(env, train_loader.dataset)\n",
    "        elif 'Lorl' in args.env.name:\n",
    "            env = LorlWrapper(env, train_loader.dataset, instr=instr, orig_instr=orig_instr)\n",
    "\n",
    "    if 'BabyAI' in args.env.name:\n",
    "        max_ep_len = 300  # 2 * train_loader.dataset.max_length\n",
    "    elif 'Lorl' in args.env.name:\n",
    "        max_ep_len = 60  # args.env.eval_episode_factor * train_loader.dataset.max_length - 1\n",
    "    else:\n",
    "        raise NotImplementedError(f'unknown env family: {args.env.name}')\n",
    "\n",
    "    model.eval()\n",
    "\n",
    "    # unwrap DataParallel-style containers\n",
    "    if hasattr(model, 'module'):\n",
    "        model = model.module\n",
    "\n",
    "    device = args.trainer.device\n",
    "    method = model.method\n",
    "    K = args.model.K  # 20\n",
    "    horizon = model.horizon  # 20\n",
    "    option_dim = model.option_dim\n",
    "    model = model.to(device=device)\n",
    "\n",
    "    if method != 'vanilla':\n",
    "        option_dim = model.option_selector.option_dim\n",
    "\n",
    "    if 'dont_reset' in kwargs:\n",
    "        observation = kwargs['dont_reset']\n",
    "    else:\n",
    "        observation = env.reset()\n",
    "    state, lang = observation['state'], observation['lang']\n",
    "\n",
    "    state_dim = env.state_dim\n",
    "    act_dim = env.act_dim\n",
    "    cur_state = torch.from_numpy(state)\n",
    "\n",
    "    # we keep all the histories on the device\n",
    "    # note that the latest action and reward will be \"padding\"\n",
    "    lm_input = tokenizer(text=[lang], add_special_tokens=True,\n",
    "                         return_tensors='pt', padding=True).to(device=device)\n",
    "    with torch.no_grad():\n",
    "        lm_embeddings = model.lm(\n",
    "            lm_input['input_ids'], lm_input['attention_mask']).last_hidden_state\n",
    "        cls_embeddings = lm_embeddings[:, 0, :]\n",
    "        # drop only the leading CLS token; since there is no padding the\n",
    "        # trailing SEP token is still included in word_embeddings\n",
    "        word_embeddings = lm_embeddings[:, 1:, :]\n",
    "        # word_embeddings = lm_embeddings\n",
    "\n",
    "    if isinstance(state_dim, tuple):\n",
    "        states = torch.from_numpy(state).reshape(1, *state_dim).to(device=device, dtype=torch.float32)\n",
    "    else:\n",
    "        states = torch.from_numpy(state).reshape(1, state_dim).to(device=device, dtype=torch.float32)\n",
    "    actions = torch.zeros((0, act_dim), device=device, dtype=torch.float32)\n",
    "    timesteps = torch.tensor(0, device=device, dtype=torch.long).reshape(1, 1)\n",
    "    if method != 'vanilla':\n",
    "        options = torch.zeros((0, option_dim), device=device, dtype=torch.float32)\n",
    "\n",
    "    episode_return, episode_length, success = 0, 0, 0\n",
    "    options_list = []\n",
    "    images = []\n",
    "\n",
    "    option = None\n",
    "\n",
    "    for t in range(max_ep_len):\n",
    "        # add dummy action\n",
    "        actions = torch.cat([actions, torch.zeros((1, act_dim), device=device)], dim=0)\n",
    "        if method != 'vanilla':\n",
    "            options = torch.cat([options, torch.zeros((1, option_dim), device=device)], dim=0)\n",
    "        else:\n",
    "            options = None\n",
    "\n",
    "        action, option, states, actions, timesteps, options = get_action(\n",
    "            model, states, actions, options, timesteps, cls_embeddings, word_embeddings, options_list, cur_state,\n",
    "            option, t, horizon, K, method, state_dim, act_dim, option_dim, device, **kwargs)\n",
    "\n",
    "        if model.decision_transformer.discrete:\n",
    "            actions[-1] = torch.nn.functional.one_hot(action, num_classes=act_dim)\n",
    "        else:\n",
    "            action = torch.clamp(action, torch.from_numpy(env.action_space.low).to(\n",
    "                device), torch.from_numpy(env.action_space.high).to(device))\n",
    "            actions[-1] = action\n",
    "\n",
    "        action = action.detach().cpu().numpy()\n",
    "        assert action in env.action_space, \"Transformer predicted action outside env action space\"\n",
    "\n",
    "        obs, reward, done, info = env.step(action)\n",
    "        if render:\n",
    "            images.append(env.get_image())\n",
    "\n",
    "        state, lang = obs['state'], obs['lang']\n",
    "        if isinstance(state_dim, tuple):\n",
    "            cur_state = torch.from_numpy(state).to(device=device).reshape(1, *state_dim)\n",
    "        else:\n",
    "            cur_state = torch.from_numpy(state).to(device=device).reshape(1, state_dim)\n",
    "        states = torch.cat([states, cur_state], dim=0).float()\n",
    "        timesteps = torch.cat(\n",
    "            [timesteps,\n",
    "                torch.ones((1, 1), device=device, dtype=torch.long) * (t + 1)], dim=1)\n",
    "\n",
    "        episode_return += reward\n",
    "        episode_length += 1\n",
    "\n",
    "        if done:\n",
    "            success = info['success']\n",
    "            break\n",
    "\n",
    "    # only save a gif for fresh episodes with at least one rendered frame\n",
    "    if render and images and 'dont_reset' not in kwargs:\n",
    "        imageio.mimsave(f'gifs/{instr}.gif', images)\n",
    "\n",
    "    return word_embeddings, states.to(dtype=torch.float32), timesteps.to(dtype=torch.long), images, lm_input, options_list, episode_length, success, (obs, env)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "import random\n",
    "from utils import LORL_EVAL_INSTRS, LORL_COMPOSITION_INSTRS\n",
    "import copy\n",
    "\n",
    "# kwargs = {}\n",
    "# if 'Lorl' in args.env.name:\n",
    "#     orig_instr, rephrasals = random.choice(list(LORL_EVAL_INSTRS.items()))\n",
    "#     rephrasal_type, instr_list = random.choice(list(rephrasals.items()))\n",
    "#     instr = random.choice(instr_list)\n",
    "#     kwargs['orig_instr'] = orig_instr\n",
    "#     kwargs['instr'] = instr\n",
    "                        \n",
    "# words, states, timesteps, images, lm_inputs, options_list, episode_length, success, _ = traj(copy.deepcopy(env), train_loader, model, tokenizer, args, **kwargs)\n",
    "# print(f'Generated episode of length: {episode_length}')\n",
    "# print(f'Generated episode success: {success}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "# att_model.model.option_selector.get_option(words, states, timesteps)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "#!pip install bertviz wordcloud"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "# from bertviz import head_view, model_view\n",
    "# from transformers import BertTokenizer, BertModel\n",
    "\n",
    "# def get_tokens(inputs):\n",
    "#     # model_version = 'bert-base-uncased'\n",
    "#     # do_lower_case = True\n",
    "#     # model = BertModel.from_pretrained(model_version, output_attentions=True)\n",
    "#     # tokenizer = BertTokenizer.from_pretrained(model_version, do_lower_case=do_lower_case)\n",
    "#     # sentence_a = \"The cat sat on the mat\"\n",
    "#     # sentence_b = \"The cat lay on the rug\"\n",
    "#     # inputs = tokenizer.encode_plus(sentence_a, sentence_b, return_tensors='pt', add_special_tokens=True)\n",
    "\n",
    "#     input_ids = inputs['input_ids']\n",
    "#     # token_type_ids = inputs['token_type_ids']\n",
    "#     # attention = model(input_ids, token_type_ids=token_type_ids)[-1]\n",
    "#     # sentence_b_start = token_type_ids[0].tolist().index(1)\n",
    "#     input_id_list = input_ids[0].tolist() # Batch index 0\n",
    "#     tokens = tokenizer.convert_ids_to_tokens(input_id_list)[1:] \n",
    "#     return tokens"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "# tokens = get_tokens(lm_inputs)\n",
    "# N = len(tokens)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "# attention = att_model._attention\n",
    "# L = min(episode_length, max_length)\n",
    "# P = max_length -  L\n",
    "# idx = [True] * N + [False] * P + [True] * L\n",
    "# print(N, P, L)\n",
    "# out = []\n",
    "\n",
    "# for layer in range(len(attention)):\n",
    "#     x = attention[layer][:, :, idx, :]\n",
    "#     x = x[:, :, :, idx]\n",
    "#     out.append(x)\n",
    "# print(len(out))\n",
    "# print(out[-1].shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load model and retrieve attention weights\n",
    "# options = np.repeat(options_list, args.model.horizon)[:L]\n",
    "\n",
    "# out[-1].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "# options_list"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "#head_view(out, tokens + [f's_{i}, o_{o}' for i, o in enumerate(options)], layer=5)\n",
    "# head_view(out, tokens + [f's_{i}, o_{o}' for i, o in enumerate(options)], layer=args.option_selector.option_transformer.n_layer-1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "# import matplotlib.pyplot as plt\n",
    "\n",
    "# for i, im in enumerate(images):\n",
    "#     print(f'state: {i}')\n",
    "#     plt.figure(figsize = (8,8))\n",
    "#     plt.imshow(im)\n",
    "#     plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Visualize word bags"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "joined-battery",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 6/6 [03:11<00:00, 31.86s/it]\n",
      "100%|██████████| 6/6 [02:39<00:00, 26.51s/it]\n",
      "100%|██████████| 6/6 [03:00<00:00, 30.09s/it]\n",
      "100%|██████████| 6/6 [03:03<00:00, 30.58s/it]\n",
      "100%|██████████| 6/6 [03:00<00:00, 30.11s/it]\n",
      "100%|██████████| 6/6 [02:59<00:00, 29.98s/it]\n",
      "100%|██████████| 6/6 [03:03<00:00, 30.56s/it]\n",
      "100%|██████████| 6/6 [03:01<00:00, 30.25s/it]\n",
      "100%|██████████| 6/6 [03:01<00:00, 30.25s/it]\n",
      "100%|██████████| 6/6 [03:05<00:00, 30.86s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.23506493506493506\n",
      "{'close drawer': 0.1, 'open drawer': 0.6, 'turn faucet left': 0.0, 'turn faucet right': 0.0, 'move black mug right': 0.2, 'move white mug down': 0.0}\n",
      "{'seen': 0.15, 'unseen verb': 0.13333333333333333, 'unseen noun': 0.2833333333333333, 'unseen verb noun': 0.06666666666666667, 'human': 0.269811320754717}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# from tqdm import tqdm\n",
    "\n",
    "# num_options = args.option_selector.num_options if args.method != 'vanilla' else 0\n",
    "\n",
    "# words_dict = {i:[] for i in range(num_options)} # Create a list for each option\n",
    "# num_eps = 10\n",
    "\n",
    "# successes = []\n",
    "# instr_wise_stats = {k: [] for k in LORL_EVAL_INSTRS.keys()}\n",
    "# rephrasal_wise_stats = {k: [] for k in ['seen', 'unseen verb', 'unseen noun', 'unseen verb noun', 'human']}\n",
    "    \n",
    "# for i in range(num_eps):\n",
    "#     for orig_instr, rephrasals in tqdm(LORL_EVAL_INSTRS.items()):\n",
    "#         for rephrasal_type, instr_list in rephrasals.items():\n",
    "#             for instr in instr_list:\n",
    "#                 words, _, _, _, _, options_list, episode_length, success, _ = traj(env, train_loader, model, tokenizer, args, render=False, orig_instr=orig_instr, instr=instr)\n",
    "#                 tokens = instr.split()\n",
    "#                 successes.append(success)\n",
    "#                 if rephrasal_type == 'seen':\n",
    "#                     instr_wise_stats[orig_instr].append(success)\n",
    "#                 rephrasal_wise_stats[rephrasal_type].append(success)\n",
    "#                 for o in options_list:\n",
    "#                     for w in tokens:\n",
    "#                         words_dict[o].append(w)\n",
    "# print(np.mean(successes))\n",
    "# instr_wise_stats = {k: np.mean(instr_wise_stats[k]) for k in instr_wise_stats.keys()}\n",
    "# rephrasal_wise_stats = {k: np.mean(rephrasal_wise_stats[k]) for k in rephrasal_wise_stats.keys()}\n",
    "# print(instr_wise_stats)\n",
    "# print(rephrasal_wise_stats)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "id": "elementary-patch",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 12/12 [02:32<00:00, 12.67s/it]\n",
      "100%|██████████| 12/12 [02:38<00:00, 13.22s/it]\n",
      "100%|██████████| 12/12 [02:39<00:00, 13.33s/it]\n",
      "100%|██████████| 12/12 [02:49<00:00, 14.10s/it]\n",
      "100%|██████████| 12/12 [02:35<00:00, 12.94s/it]\n",
      "100%|██████████| 12/12 [02:31<00:00, 12.66s/it]\n",
      "100%|██████████| 12/12 [02:36<00:00, 13.08s/it]\n",
      "100%|██████████| 12/12 [02:41<00:00, 13.42s/it]\n",
      "100%|██████████| 12/12 [02:38<00:00, 13.21s/it]\n",
      "100%|██████████| 12/12 [02:25<00:00, 12.09s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.11666666666666667\n",
      "{'open drawer and move black mug right': 0.0, 'pull the handle and move black mug down': 0.0, 'move white mug right': 0.1, 'move black mug down': 0.5, 'close drawer and turn faucet right': 0.0, 'close drawer and turn faucet left': 0.1, 'turn faucet left and move white mug down': 0.2, 'turn faucet right and close drawer': 0.0, 'move white mug down and turn faucet left': 0.2, 'close the drawer, turn the faucet left and move black mug right': 0.3, 'open drawer and turn faucet counterclockwise': 0.0, 'slide the drawer closed and then shift white mug down': 0.0}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "from tqdm import tqdm\n",
    "\n",
    "num_options = args.option_selector.num_options if args.method != 'vanilla' else 0\n",
    "\n",
    "words_dict = {i:[] for i in range(num_options)} # Create a list for each option\n",
    "num_eps = 10\n",
    "\n",
    "successes = []\n",
    "instr_wise_stats = {k: [] for k in LORL_COMPOSITION_INSTRS}\n",
    "    \n",
    "for i in range(num_eps):\n",
    "    for orig_instr in tqdm(LORL_COMPOSITION_INSTRS):\n",
    "        instr = orig_instr\n",
    "        words, _, _, _, _, options_list, episode_length, success, _ = traj(env, train_loader, model, tokenizer, args, render=True, orig_instr=orig_instr, instr=instr)\n",
    "        tokens = instr.split()\n",
    "        successes.append(success)\n",
    "        instr_wise_stats[orig_instr].append(success)\n",
    "        for o in options_list:\n",
    "            for w in tokens:\n",
    "                words_dict[o].append(w)\n",
    "print(np.mean(successes))\n",
    "instr_wise_stats = {k: np.mean(instr_wise_stats[k]) for k in instr_wise_stats.keys()}\n",
    "print(instr_wise_stats)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "furnished-setup",
   "metadata": {},
   "outputs": [],
   "source": [
    "# from tqdm import tqdm\n",
    "\n",
    "# num_options = args.option_selector.num_options\n",
    "\n",
    "# words_dict = {i:[] for i in range(num_options)} # Create a list for each option\n",
    "# num_eps = 3\n",
    "    \n",
    "# for n in range(num_eps):\n",
    "#     for i in tqdm(range(num_options)):\n",
    "#         words, _, _, _, _, options_list, episode_length, success, _ = traj(env, train_loader, model, tokenizer, args, render=True, orig_instr=f'{i}_{n+1}', instr=f'{i}_{n+1}', constant_option=i)\n",
    "#         #tokens = get_tokens(lm_inputs)\n",
    "#         #tokens = instr.split()\n",
    "#         #if success:\n",
    "#         #for o in options_list:\n",
    "#         #    for w in tokens:\n",
    "#         #        words_dict[o].append(w)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "tamil-lindsay",
   "metadata": {},
   "outputs": [],
   "source": [
    "# task = 'close drawer and turn faucet right'\n",
    "# #options = [17, 2] \n",
    "# options = [14, 2]\n",
    "\n",
    "# #task = 'pull the handle and move black mug down'\n",
    "# #options = [9, 15]\n",
    "\n",
    "# for n in range(3):\n",
    "#     all_images = []\n",
    "#     all_options = []\n",
    "#     all_success = []\n",
    "#     words, _, _, images, _, options_list, episode_length, success, (prev_obs, prev_env) = traj(copy.deepcopy(env), train_loader, model, tokenizer, args, render=True, orig_instr=task, instr=task, constant_option=options[0])\n",
    "#     all_images.extend(images)\n",
    "#     all_options.extend(options_list)\n",
    "#     all_success.append(success)\n",
    "#     for option in options[1:]:\n",
    "#         words, _, _, images, _, options_list, episode_length, success, (prev_obs, prev_env) = traj(prev_env, train_loader, model, tokenizer, args, render=True, orig_instr=task, instr=task, constant_option=option, dont_reset=prev_obs)\n",
    "#         all_images.extend(images)\n",
    "#         all_options.extend(options_list)\n",
    "#         all_success.append(success)\n",
    "#     imageio.mimsave(f'gifs/{task}_{n+1}.gif', all_images)\n",
    "#     print(all_options, sum(all_success))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "id": "stock-union",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[]\n"
     ]
    }
   ],
   "source": [
    "from itertools import chain\n",
    "\n",
    "skip_words = ['the', 'a', '[SEP]']\n",
    "\n",
    "# Vocabulary: every word credited to any option, minus stop words, sorted.\n",
    "vocab = set(chain.from_iterable(words_dict.values()))\n",
    "words = sorted(vocab - set(skip_words))\n",
    "print(words)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "id": "veterinary-reputation",
   "metadata": {},
   "outputs": [],
   "source": [
    "def w_to_ind(word):\n",
    "    \"\"\"Return the row index of `word` in the global `words` list (linear scan).\"\"\"\n",
    "    return words.index(word)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "id": "legendary-pioneer",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Word x option co-occurrence counts (rows = vocabulary words, cols = options).\n",
    "matrix = np.zeros([len(words), num_options])\n",
    "\n",
    "for option_id, option_words in words_dict.items():\n",
    "    for word in option_words:\n",
    "        if word in skip_words:\n",
    "            continue\n",
    "        matrix[w_to_ind(word), option_id] += 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "id": "approved-ethernet",
   "metadata": {},
   "outputs": [],
   "source": [
    "#!pip install seaborn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "id": "temporal-theorem",
   "metadata": {},
   "outputs": [
    {
     "ename": "ValueError",
     "evalue": "zero-size array to reduction operation fmin which has no identity",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mValueError\u001b[0m                                Traceback (most recent call last)",
      "\u001b[0;32m/tmp/ipykernel_2480630/626772604.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m      4\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      5\u001b[0m \u001b[0mplt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfigure\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfigsize\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0;36m30\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;36m10\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 6\u001b[0;31m \u001b[0msns\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mheatmap\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmatrix\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0myticklabels\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mwords\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m      7\u001b[0m \u001b[0mplt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mplot\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/atlas/u/divgarg/miniconda3/envs/langrl/lib/python3.8/site-packages/seaborn/_decorators.py\u001b[0m in \u001b[0;36minner_f\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m     44\u001b[0m             )\n\u001b[1;32m     45\u001b[0m         \u001b[0mkwargs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mupdate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m{\u001b[0m\u001b[0mk\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0marg\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mk\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0marg\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mzip\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msig\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mparameters\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 46\u001b[0;31m         \u001b[0;32mreturn\u001b[0m \u001b[0mf\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     47\u001b[0m     \u001b[0;32mreturn\u001b[0m \u001b[0minner_f\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     48\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/atlas/u/divgarg/miniconda3/envs/langrl/lib/python3.8/site-packages/seaborn/matrix.py\u001b[0m in \u001b[0;36mheatmap\u001b[0;34m(data, vmin, vmax, cmap, center, robust, annot, fmt, annot_kws, linewidths, linecolor, cbar, cbar_kws, cbar_ax, square, xticklabels, yticklabels, mask, ax, **kwargs)\u001b[0m\n\u001b[1;32m    538\u001b[0m     \"\"\"\n\u001b[1;32m    539\u001b[0m     \u001b[0;31m# Initialize the plotter object\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 540\u001b[0;31m     plotter = _HeatMapper(data, vmin, vmax, cmap, center, robust, annot, fmt,\n\u001b[0m\u001b[1;32m    541\u001b[0m                           \u001b[0mannot_kws\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcbar\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcbar_kws\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mxticklabels\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    542\u001b[0m                           yticklabels, mask)\n",
      "\u001b[0;32m/atlas/u/divgarg/miniconda3/envs/langrl/lib/python3.8/site-packages/seaborn/matrix.py\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, data, vmin, vmax, cmap, center, robust, annot, fmt, annot_kws, cbar, cbar_kws, xticklabels, yticklabels, mask)\u001b[0m\n\u001b[1;32m    157\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    158\u001b[0m         \u001b[0;31m# Determine good default values for the colormapping\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 159\u001b[0;31m         self._determine_cmap_params(plot_data, vmin, vmax,\n\u001b[0m\u001b[1;32m    160\u001b[0m                                     cmap, center, robust)\n\u001b[1;32m    161\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/atlas/u/divgarg/miniconda3/envs/langrl/lib/python3.8/site-packages/seaborn/matrix.py\u001b[0m in \u001b[0;36m_determine_cmap_params\u001b[0;34m(self, plot_data, vmin, vmax, cmap, center, robust)\u001b[0m\n\u001b[1;32m    196\u001b[0m                 \u001b[0mvmin\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnanpercentile\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcalc_data\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m2\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    197\u001b[0m             \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 198\u001b[0;31m                 \u001b[0mvmin\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnanmin\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcalc_data\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    199\u001b[0m         \u001b[0;32mif\u001b[0m \u001b[0mvmax\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    200\u001b[0m             \u001b[0;32mif\u001b[0m \u001b[0mrobust\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m<__array_function__ internals>\u001b[0m in \u001b[0;36mnanmin\u001b[0;34m(*args, **kwargs)\u001b[0m\n",
      "\u001b[0;32m/atlas/u/divgarg/miniconda3/envs/langrl/lib/python3.8/site-packages/numpy/lib/nanfunctions.py\u001b[0m in \u001b[0;36mnanmin\u001b[0;34m(a, axis, out, keepdims)\u001b[0m\n\u001b[1;32m    317\u001b[0m         \u001b[0;31m# Fast, but not safe for subclasses of ndarray, or object arrays,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    318\u001b[0m         \u001b[0;31m# which do not implement isnan (gh-9009), or fmin correctly (gh-8975)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 319\u001b[0;31m         \u001b[0mres\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfmin\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mreduce\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0ma\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0maxis\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0maxis\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mout\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mout\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    320\u001b[0m         \u001b[0;32mif\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0misnan\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mres\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0many\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    321\u001b[0m             warnings.warn(\"All-NaN slice encountered\", RuntimeWarning,\n",
      "\u001b[0;31mValueError\u001b[0m: zero-size array to reduction operation fmin which has no identity"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<Figure size 2160x720 with 0 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "%matplotlib inline\n",
    "import seaborn as sns\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "plt.figure(figsize=(30, 10))\n",
    "# Guard: sns.heatmap raises ValueError on a zero-size matrix\n",
    "# (happens when no words were collected, so the vocabulary is empty).\n",
    "if matrix.size:\n",
    "    sns.heatmap(matrix, yticklabels=words)\n",
    "else:\n",
    "    print('matrix is empty -- no words collected, nothing to plot')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sanity check: per-option column sums keep a 2-D shape, (1, num_options).\n",
    "matrix.sum(axis=0, keepdims=True).shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "# Normalize each column so every option's word counts sum to ~1\n",
    "# (the epsilon avoids division by zero for options that were never used).\n",
    "plt.figure(figsize=(30, 10))\n",
    "matrix_norm_col = matrix / (matrix.sum(axis=0, keepdims=True) + 1e-6)\n",
    "sns.heatmap(matrix_norm_col, yticklabels=words)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "# Normalize each row so every word's counts sum to ~1 across options.\n",
    "plt.figure(figsize=(30, 10))\n",
    "row_totals = matrix.sum(axis=1, keepdims=True) + 1e-6\n",
    "matrix_norm_row = matrix / row_totals\n",
    "sns.heatmap(matrix_norm_row, yticklabels=words)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "from collections import Counter\n",
    "import matplotlib.pyplot as plt\n",
    "from wordcloud import WordCloud\n",
    "\n",
    "# One word cloud per option, built from that option's word frequencies.\n",
    "for i in range(num_options):\n",
    "    if not words_dict[i]:\n",
    "        continue  # skip options that never fired\n",
    "    print(i)\n",
    "    freqs = Counter(w for w in words_dict[i] if w not in skip_words)\n",
    "    cloud = WordCloud(max_font_size=80, colormap=\"hsv\").generate_from_frequencies(freqs)\n",
    "    plt.figure(figsize=(16, 12))\n",
    "    plt.imshow(cloud, interpolation='bilinear')\n",
    "    plt.axis('off')\n",
    "    plt.show()\n",
    "    plt.close('all')  # free figure memory when generating many plots in a loop"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#!pip install wordcloud"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "a7a0b012ffbe0955a1b9770a1a760d0b78b7bdfec49d52f9b2be4548879c9a8f"
  },
  "kernelspec": {
   "display_name": "Python [conda env:langrl]",
   "language": "python",
   "name": "conda-env-langrl-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
