{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "import json\n",
    "import random\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from torch.autograd import Variable\n",
    "from collections import deque\n",
    "\n",
    "from jecc_dataset_bert import BERTStateAction2StateDataset\n",
    "\n",
    "from tqdm import tqdm\n",
    "\n",
    "#os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"3\"\n",
    "\n",
    "# os.environ['CUDA_LAUNCH_BLOCKING'] = \"1\" # 3.15it/s\n",
    "\n",
    "torch.manual_seed(9527)\n",
    "np.random.seed(9527)\n",
    "random.seed(9527)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "print(os.environ[\"CUDA_VISIBLE_DEVICES\"])\n",
    "print(torch.__version__)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "class Argument():\n",
    "    def __init__(self):\n",
    "        self.hidden_dim = 200\n",
    "        self.mlstm_hidden_dim = 100\n",
    "        self.embedding_dim = 100\n",
    "#         self.embedding_dim = 300\n",
    "        self.num_classes = 2\n",
    "        self.kernel_size = 3\n",
    "        self.layer_num = 1\n",
    "        self.fine_tuning = False\n",
    "        self.cuda = True\n",
    "        self.lambda_l2 = 0.05\n",
    "        self.model_type = \"LSTM\"\n",
    "        self.cell_type = \"GRU\"\n",
    "        self.batch_size = 10\n",
    "        self.input_topk = 32\n",
    "        self.keep_prob = 0.8\n",
    "        self.predict_target_topk = 5\n",
    "        self.save_path = 'trained_models'\n",
    "#         self.model_prefix = 'hotpot_reranker_model_h%d.with_anchor_with_el'%self.hidden_dim\n",
    "        self.model_prefix = 'tmp_model'\n",
    "        self.load_model = False\n",
    "        self.load_path = '.'\n",
    "        \n",
    "args = Argument()\n",
    "print(vars(args))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def find_game_roms(games, rom_dir):\n",
    "    print('#number of games: {}'.format(len(games)))\n",
    "\n",
    "    roms = os.listdir(rom_dir)\n",
    "    game2rom = {}\n",
    "    logs = []\n",
    "    for game in games:\n",
    "        for rom in roms:\n",
    "            if rom.startswith(game + '.z'):\n",
    "                game2rom[game] = rom\n",
    "    #             print('find {} for {}'.format(rom, game))\n",
    "                logs.append('find {} for {}'.format(rom, game))\n",
    "        if game not in game2rom:\n",
    "            print('cannot find rom for {}'.format(game))\n",
    "\n",
    "    print('#number of roms founds: {}'.format(len(logs)))\n",
    "    \n",
    "    return game2rom\n",
    "\n",
    "import importlib\n",
    "import jecc_dataset_bert\n",
    "importlib.reload(jecc_dataset_bert)\n",
    "from jecc_dataset_bert import BERTStateAction2StateDataset\n",
    "\n",
    "# data_dir = \"/dccstor/yum-worldmodel/shared_folder_2080/if_games/data/ssa_data/supervised/\"\n",
    "data_dir = \"/dccstor/yum-worldmodel/shared_folder_2080/if_games/data/ssa_data/jecc_sup/\"\n",
    "\n",
    "train_games = ['905', 'acorncourt', 'advent', 'adventureland', 'afflicted', 'awaken', \n",
    "               'balances', 'deephome', 'dragon', 'enchanter', 'inhumane', 'library', \n",
    "               'moonlit', 'omniquest', 'pentari', 'reverb', 'snacktime', 'sorcerer', 'zork1']\n",
    "dev_games = ['zork3', 'detective', 'ztuu', 'jewel', 'zork2']\n",
    "test_games = ['temple', 'gold', 'karn', 'zenon', 'wishbringer']\n",
    "\n",
    "games = train_games + dev_games + test_games\n",
    "    \n",
    "rom_dir = '../roms/jericho-game-suite/'\n",
    "game2rom = find_game_roms(games, rom_dir)\n",
    "print(game2rom)\n",
    "\n",
    "# games = ['zork1', 'zork3']\n",
    "\n",
    "pretrain_path = '/dccstor/gaot1/MultiHopReason/comprehension_tasks/narrativeqa/passage_ranker/bert-base-uncased/'\n",
    "\n",
    "# game_task_data = BERTStateAction2StateDataset(pretrain_path, data_dir, rom_dir=rom_dir, game2rom=game2rom,\n",
    "#                                           train_games=games, dev_games=games,\n",
    "#                                           setting='same_games', num_negative=4)\n",
    "\n",
    "game_task_data = BERTStateAction2StateDataset(pretrain_path, data_dir, rom_dir=rom_dir, game2rom=game2rom,\n",
    "                                              train_games=train_games, dev_games=dev_games, \n",
    "                                              test_games = test_games, truncate_num=512,\n",
    "                                              setting='transfer', num_negative=4)\n",
    "\n",
    "game_task_data.data_sets['train'].check_eval_triples(game_task_data.idx_2_word)\n",
    "game_task_data.data_sets['dev'].check_eval_triples(game_task_data.idx_2_word)\n",
    "game_task_data.data_sets['test'].check_eval_triples(game_task_data.idx_2_word)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.nn.utils import rnn\n",
    "\n",
    "class LockedDropout(nn.Module):\n",
    "    def __init__(self, dropout):\n",
    "        super().__init__()\n",
    "        self.dropout = dropout\n",
    "\n",
    "    def forward(self, x):\n",
    "        dropout = self.dropout\n",
    "        if not self.training:\n",
    "            return x\n",
    "        m = x.data.new(x.size(0), 1, x.size(2)).bernoulli_(1 - dropout)\n",
    "        mask = Variable(m.div_(1 - dropout), requires_grad=False)\n",
    "        mask = mask.expand_as(x)\n",
    "        return mask * x\n",
    "\n",
    "class EncoderRNN(nn.Module):\n",
    "    def __init__(self, input_size, num_units, nlayers, concat, bidir, dropout, return_last):\n",
    "        super().__init__()\n",
    "        self.rnns = []\n",
    "        for i in range(nlayers):\n",
    "            if i == 0:\n",
    "                input_size_ = input_size\n",
    "                output_size_ = num_units\n",
    "            else:\n",
    "                input_size_ = num_units if not bidir else num_units * 2\n",
    "                output_size_ = num_units\n",
    "            self.rnns.append(nn.GRU(input_size_, output_size_, 1, bidirectional=bidir, batch_first=True))\n",
    "        self.rnns = nn.ModuleList(self.rnns)\n",
    "        self.init_hidden = nn.ParameterList([nn.Parameter(torch.Tensor(2 if bidir else 1, 1, num_units).zero_()) for _ in range(nlayers)])\n",
    "        self.dropout = LockedDropout(dropout)\n",
    "        self.concat = concat\n",
    "        self.nlayers = nlayers\n",
    "        self.return_last = return_last\n",
    "\n",
    "        # self.reset_parameters()\n",
    "\n",
    "    def reset_parameters(self):\n",
    "        for rnn in self.rnns:\n",
    "            for name, p in rnn.named_parameters():\n",
    "                if 'weight' in name:\n",
    "                    p.data.normal_(std=0.1)\n",
    "                else:\n",
    "                    p.data.zero_()\n",
    "\n",
    "    def get_init(self, bsz, i):\n",
    "        return self.init_hidden[i].expand(-1, bsz, -1).contiguous()\n",
    "\n",
    "    def forward(self, input, input_lengths=None):\n",
    "        bsz, slen = input.size(0), input.size(1)\n",
    "\n",
    "        outputs = []\n",
    "        if input_lengths is not None:\n",
    "#             lens = input_lengths.data.cpu().numpy()\n",
    "            lens = input_lengths\n",
    "            sort_idx = np.argsort(-lens)\n",
    "            idx_dict = {sort_idx[i_]: i_ for i_ in range(lens.shape[0])}\n",
    "            revert_idx = np.array([idx_dict[i_] for i_ in range(lens.shape[0])])\n",
    "            input = input[sort_idx, :]\n",
    "        output = input\n",
    "            \n",
    "        for i in range(self.nlayers):\n",
    "            hidden = self.get_init(bsz, i)\n",
    "            output = self.dropout(output)\n",
    "#             print(output.size())\n",
    "            if input_lengths is not None:\n",
    "                output = rnn.pack_padded_sequence(output, lens[sort_idx], batch_first=True)\n",
    "            output, hidden = self.rnns[i](output, hidden)\n",
    "#             print(output.size())\n",
    "            if input_lengths is not None:\n",
    "                output, _ = rnn.pad_packed_sequence(output, batch_first=True)\n",
    "                if output.size(1) < slen: # used for parallel\n",
    "                    padding = Variable(output.data.new(1, 1, 1).zero_())\n",
    "                    output = torch.cat([output, padding.expand(output.size(0), slen-output.size(1), output.size(2))], dim=1)\n",
    "            if self.return_last:\n",
    "                outputs.append(hidden.permute(1, 0, 2).contiguous().view(bsz, -1))\n",
    "            else:\n",
    "                outputs.append(output)\n",
    "        \n",
    "        if input_lengths is not None: \n",
    "            if self.concat:\n",
    "                return torch.cat(outputs, dim=2)[revert_idx,:]\n",
    "            return outputs[-1][revert_idx,:]\n",
    "        else:\n",
    "            if self.concat:\n",
    "                return torch.cat(outputs, dim=2)\n",
    "            return outputs[-1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import BertModel, BertTokenizer\n",
    "import logging\n",
    "\n",
    "class CoMatchBertForwardRanking(nn.Module):\n",
    "    def __init__(self, args, max_length, num_class, pretrain_path, blank_padding=True):\n",
    "        super().__init__()\n",
    "        \n",
    "        args.hidden = args.hidden_dim\n",
    "        \n",
    "        self.max_length = max_length\n",
    "        self.blank_padding = blank_padding\n",
    "        self.hidden_size = 768\n",
    "\n",
    "        logging.info('Loading BERT pre-trained checkpoint.')\n",
    "        self.bert = BertModel.from_pretrained(pretrain_path)\n",
    "        \n",
    "        self.diff_linear = nn.Sequential(\n",
    "                nn.Linear(self.hidden_size*4, args.hidden),\n",
    "                nn.ReLU()\n",
    "            )\n",
    "\n",
    "        self.match_lstm = EncoderRNN(args.hidden*2, int(args.mlstm_hidden_dim/2), 1, False, True, 1-args.keep_prob, False)\n",
    "#         self.output_layer = nn.Linear(args.mlstm_hidden_dim, args.num_classes)\n",
    "        \n",
    "        self.num_classes = num_class\n",
    "        \n",
    "        self.siamese_output_layer = nn.Linear(args.hidden*2, self.num_classes)\n",
    "        self.output_layer = nn.Linear(args.mlstm_hidden_dim, self.num_classes)\n",
    "        self.loss = nn.CrossEntropyLoss()\n",
    "        \n",
    "    def bert_vars(self):\n",
    "        \"\"\"\n",
    "        Return the variables of the generator.\n",
    "        \"\"\"\n",
    "        params = list(self.bert.parameters())\n",
    "\n",
    "        return params\n",
    "    \n",
    "    def comatch_vars(self):\n",
    "        \"\"\"\n",
    "        Return the variables of the generator.\n",
    "        \"\"\"\n",
    "        params = list(self.diff_linear.parameters()) + list(\n",
    "            self.match_lstm.parameters()) + list(\n",
    "            self.output_layer.parameters())\n",
    "\n",
    "        return params\n",
    "        \n",
    "        \n",
    "    def _get_matching_representations(self, q_hiddens, p_hiddens, q_mask, p_mask):\n",
    "        '''\n",
    "        inputs -- (B, L1, H), (B, L2, H)\n",
    "        '''\n",
    "        similarity_matrix = torch.bmm(q_hiddens, p_hiddens.transpose(1, 2)) # (B, L1, L2)\n",
    "        mask_matrix = torch.bmm(q_mask.unsqueeze(2), p_mask.unsqueeze(1)) #(B, L1, L2)\n",
    "        \n",
    "        neg_inf = -1.0e6\n",
    "        attention_softmax = F.softmax(similarity_matrix + (1 * mask_matrix) * neg_inf, dim=2)\n",
    "        \n",
    "        # shape: (B, L1, H)                                                            \n",
    "        q_hiddens_tilda = torch.bmm(attention_softmax, p_hiddens)\n",
    "\n",
    "        # shape: (B, L1, 4*H)\n",
    "        q_matching_states = torch.cat([q_hiddens, q_hiddens_tilda,\n",
    "                                      q_hiddens - q_hiddens_tilda,\n",
    "                                      q_hiddens * q_hiddens_tilda], dim=2)\n",
    "        return q_matching_states\n",
    "        \n",
    "    def forward_siamese(self, i_hiddens, a_hiddens, o_hiddens, i_mask, a_mask, o_mask):\n",
    "        '''\n",
    "        inputs -- all (B, H)\n",
    "        '''\n",
    "        i_hiddens_expand_ = i_hiddens.unsqueeze(1).expand(i_mask.size(0), o_mask.size(0), -1) # (B, B, H)\n",
    "        a_hiddens_expand_ = a_hiddens.unsqueeze(1).expand(i_mask.size(0), o_mask.size(0), -1) # (B, B, H)\n",
    "        o_hiddens_expand_ = o_hiddens.unsqueeze(0).expand(i_mask.size(0), o_mask.size(0), -1) # (B, B, H)\n",
    "        i_hiddens_expand_ = i_hiddens_expand_.contiguous().view(i_mask.size(0)*o_mask.size(0), -1) # (BB, H)\n",
    "        a_hiddens_expand_ = a_hiddens_expand_.contiguous().view(i_mask.size(0)*o_mask.size(0), -1) # (BB, H)\n",
    "        o_hiddens_expand_ = o_hiddens_expand_.contiguous().view(i_mask.size(0)*o_mask.size(0), -1) # (BB, H)\n",
    "\n",
    "        diff_oi = torch.cat([o_hiddens_expand_, i_hiddens_expand_,\n",
    "                             o_hiddens_expand_ - i_hiddens_expand_,\n",
    "                             o_hiddens_expand_ * i_hiddens_expand_], dim=1) #(BB, 4H)\n",
    "        diff_oi = self.diff_linear(diff_oi) #(BB, H)\n",
    "        \n",
    "        diff_oa = torch.cat([o_hiddens_expand_, a_hiddens_expand_,\n",
    "                             o_hiddens_expand_ - a_hiddens_expand_,\n",
    "                             o_hiddens_expand_ * a_hiddens_expand_], dim=1) #(BB, 4H)\n",
    "        diff_oa = self.diff_linear(diff_oa) #(BB, H)\n",
    "        \n",
    "        co_match_inputs = torch.cat([diff_oi, diff_oa], dim=1) #(BB, 2H)\n",
    "        \n",
    "        predict = self.siamese_output_layer(co_match_inputs).view(i_mask.size(0), -1) #(B, B)\n",
    "        \n",
    "        return predict\n",
    "    \n",
    "        \n",
    "    def forward(self, i, a, o, i_mask, a_mask, o_mask):\n",
    "        \"\"\"\n",
    "        Args:\n",
    "            i: (B, L), index of tokens\n",
    "            a: (B, La), index of action tokens\n",
    "            o: (B, L'), index of tokens\n",
    "        Return:\n",
    "            return (B, B) scores\n",
    "        \"\"\"\n",
    "        i_hiddens, i_hiddens_cls = self.bert(i, attention_mask=i_mask) # (B, L1, H)\n",
    "        a_hiddens, a_hiddens_cls = self.bert(a, attention_mask=a_mask) # (B, La, H)\n",
    "        o_hiddens, o_hiddens_cls = self.bert(o, attention_mask=o_mask) # (B, L2, H)\n",
    "        \n",
    "        i_hiddens_expand_ = i_hiddens.unsqueeze(1).expand(i_mask.size(0), o_mask.size(0), \n",
    "                                                          i_hiddens.size(1), -1) # (B, B, L1, H)\n",
    "        a_hiddens_expand_ = a_hiddens.unsqueeze(1).expand(i_mask.size(0), o_mask.size(0), \n",
    "                                                          a_hiddens.size(1), -1) # (B, B, La, H)\n",
    "        o_hiddens_expand_ = o_hiddens.unsqueeze(0).expand(i_mask.size(0), o_mask.size(0), \n",
    "                                                          o_hiddens.size(1), -1) # (B, B, L2, H)\n",
    "        i_hiddens_expand_ = i_hiddens_expand_.contiguous().view(i_mask.size(0)*o_mask.size(0), \n",
    "                                                                i_hiddens.size(1), -1) # (BB, H)\n",
    "        a_hiddens_expand_ = a_hiddens_expand_.contiguous().view(i_mask.size(0)*o_mask.size(0), \n",
    "                                                                a_hiddens.size(1), -1) # (BB, H)\n",
    "        o_hiddens_expand_ = o_hiddens_expand_.contiguous().view(i_mask.size(0)*o_mask.size(0), \n",
    "                                                                o_hiddens.size(1), -1) # (BB, H)\n",
    "        \n",
    "        i_mask_expand_ = torch.ones(i_mask.size(0) * o_mask.size(0), i_mask.size(1)).cuda() #(B1*B2, L1)\n",
    "        a_mask_expand_ = torch.ones(i_mask.size(0) * o_mask.size(0), a_mask.size(1)).cuda() #(B1*B2, La)\n",
    "        o_mask_expand_ = torch.ones(i_mask.size(0) * o_mask.size(0), o_mask.size(1)).cuda() #(B1*B2, L2)\n",
    "        \n",
    "        \n",
    "        diff_oi = self._get_matching_representations(o_hiddens_expand_, i_hiddens_expand_, \n",
    "                                                     o_mask_expand_, i_mask_expand_) # (B1*B2, L2, 4*H)\n",
    "        diff_oa = self._get_matching_representations(o_hiddens_expand_, a_hiddens_expand_, \n",
    "                                                     o_mask_expand_, a_mask_expand_) # (B1*B2, L2, 4*H)\n",
    "        \n",
    "        diff_oi = self.diff_linear(diff_oi) # (B1*B2, L2, H)\n",
    "        diff_oa = self.diff_linear(diff_oa) # (B1*B2, L2, H)\n",
    "        \n",
    "        co_match_inputs = torch.cat([diff_oi, diff_oa], dim=2)\n",
    "        o_len = torch.sum(o_mask_expand_, dim=1).cpu().data.numpy()\n",
    "        \n",
    "        co_match_hiddens = self.match_lstm(co_match_inputs, o_len) # (B1*B2, L2, H)\n",
    "\n",
    "        neg_inf = -1.0e6\n",
    "        co_match_hiddens = co_match_hiddens + (1 - o_mask_expand_.unsqueeze(2)) * neg_inf \n",
    "        \n",
    "        max_co_match_hiddens = torch.max(co_match_hiddens, dim=1)[0] #(B1*B2, H)\n",
    "        predict = self.output_layer(max_co_match_hiddens).view(i_mask.size(0), o_mask.size(0)) # (B1, B2)\n",
    "#         predict += self.forward_siamese(i_hiddens_cls, a_hiddens_cls, o_hiddens_cls, i_mask, a_mask, o_mask)\n",
    "        \n",
    "        return predict\n",
    "    "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "# train set\n",
    "Number of positive actions: 8732\n",
    "Number of total action candidates: 697315\n",
    "# dev set\n",
    "Number of positive actions: 2970\n",
    "Number of total action candidates: 366306\n",
    "# test set\n",
    "Number of positive actions: 2700\n",
    "Number of total action candidates: 247488\n",
    "size of the raw vocabulary: 14545\n",
    "size of the final vocabulary: 14333"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Argument():\n",
    "    def __init__(self):\n",
    "        self.hidden_dim = 200\n",
    "        self.mlstm_hidden_dim = 100\n",
    "        self.embedding_dim = 100\n",
    "#         self.embedding_dim = 300\n",
    "        self.num_classes = 1\n",
    "        self.kernel_size = 3\n",
    "        self.layer_num = 1\n",
    "        self.fine_tuning = False\n",
    "        self.cuda = True\n",
    "        self.lambda_l2 = 0.05\n",
    "        self.model_type = \"LSTM\"\n",
    "        self.cell_type = \"GRU\"\n",
    "        self.batch_size = 10\n",
    "        self.input_topk = 32\n",
    "        self.keep_prob = 0.8\n",
    "        self.predict_target_topk = 5\n",
    "        self.save_path = 'trained_models'\n",
    "#         self.model_prefix = 'hotpot_reranker_model_h%d.with_anchor_with_el'%self.hidden_dim\n",
    "        self.model_prefix = 'tmp_model'\n",
    "        self.load_model = False\n",
    "        self.load_path = '.'\n",
    "        \n",
    "args = Argument()\n",
    "print(vars(args))\n",
    "\n",
    "device = 'cuda'\n",
    "\n",
    "ranker_model = CoMatchBertForwardRanking(args, max_length=384, num_class=1, pretrain_path=pretrain_path)\n",
    "    \n",
    "# load_model = True\n",
    "# if load_model:\n",
    "#     pre_trained = torch.load('comatch_bert_sas_jecc.2e-5.model.pt')\n",
    "# #     pre_trained = torch.load('siamese_bert_sas.model.pt')\n",
    "    \n",
    "#     ranker_model.load_state_dict(pre_trained)\n",
    "\n",
    "ranker_model.to(device)\n",
    "print(ranker_model)\n",
    "\n",
    "# bert_optimizer = torch.optim.Adam(ranker_model.bert_vars(), lr=2e-5)\n",
    "# comatch_optimizer = torch.optim.Adam(ranker_model.comatch_vars(), lr=1e-3)\n",
    "\n",
    "supervise_optimizer = torch.optim.Adam(filter(lambda x: x.requires_grad, \n",
    "                                              ranker_model.parameters()), lr=2e-5)\n",
    "\n",
    "# if args.cuda:\n",
    "#     ranker_model.cuda()\n",
    "    \n",
    "sys.stdout.flush()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_n_params(model):\n",
    "    pp=0\n",
    "    for p in list(model.parameters()):\n",
    "        nn=1\n",
    "        for s in list(p.size()):\n",
    "            nn = nn*s\n",
    "        pp += nn\n",
    "    return pp\n",
    "\n",
    "print(get_n_params(ranker_model))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# train_losses = []\n",
    "# train_accs = []\n",
    "# dev_accs = [0.0]\n",
    "# test_accs = [0.0]\n",
    "# num_iteration = 20000\n",
    "# display_iteration = 200\n",
    "# test_iteration = 100\n",
    "# queue_length = 400\n",
    "# num_epoch = 40\n",
    "\n",
    "# best_dev = 0.0\n",
    "# best_dev_wt = 0.0\n",
    "# batch_size_tuple = 2\n",
    "\n",
    "# from tqdm.notebook import tqdm #_notebook\n",
    "\n",
    "# if load_model:\n",
    "#     with torch.no_grad():\n",
    "#         ranker_model.eval()\n",
    "\n",
    "#         for eval_set in ['dev', 'test']:\n",
    "#             dev_correct = 0.0\n",
    "#             dev_total = 0.0\n",
    "\n",
    "#             dev_correct_wt = 0.0\n",
    "#             dev_total_wt = 0.0\n",
    "\n",
    "#             num_dev_instance = game_task_data.data_sets[eval_set].size()\n",
    "\n",
    "#             for inst_id in tqdm(range(num_dev_instance)):\n",
    "#                 i_mat, a_mat, o_mat, y_vec, i_mask, a_mask, o_mask = game_task_data.get_eval_batch_triple(eval_set, \n",
    "#                                                                                            [inst_id])\n",
    "\n",
    "#                 i_mat_ = Variable(torch.from_numpy(i_mat)).to(device)\n",
    "#                 a_mat_ = Variable(torch.from_numpy(a_mat)).to(device)\n",
    "#                 o_mat_ = Variable(torch.from_numpy(o_mat)).to(device)\n",
    "#                 i_mask_ = Variable(torch.from_numpy(i_mask)).float().to(device)\n",
    "#                 a_mask_ = Variable(torch.from_numpy(a_mask)).float().to(device)\n",
    "#                 o_mask_ = Variable(torch.from_numpy(o_mask)).float().to(device)\n",
    "#                 y_vec_ = Variable(torch.from_numpy(y_vec)).to(device)\n",
    "                \n",
    "#                 shuffle_idx = list(range(o_mask_.size(0)))\n",
    "#                 random.shuffle(shuffle_idx)\n",
    "#                 shuffle_idx = np.array(shuffle_idx)\n",
    "#                 o_mat_ = o_mat_[shuffle_idx,:]\n",
    "#                 o_mask_ = o_mask_[shuffle_idx,:]\n",
    "\n",
    "#                 predict = ranker_model(i_mat_, a_mat_, o_mat_, i_mask_, a_mask_, o_mask_)\n",
    "\n",
    "#                 _, y_pred = torch.max(predict, dim=1)\n",
    "#                 y_pred = shuffle_idx[y_pred.cpu().numpy()]\n",
    "\n",
    "#                 dev_correct += (y_pred == y_vec).sum()\n",
    "#                 dev_total += y_vec_.size(0)\n",
    "\n",
    "# #                         y_pred = y_pred.cpu().data\n",
    "\n",
    "#                 if y_vec_[0].item() == y_pred[0].item():\n",
    "#                     dev_correct_wt += 1\n",
    "# #                         if y_vec_[5].item() == y_pred[5].item():\n",
    "# #                             dev_correct_wt += 1\n",
    "#                 dev_total_wt += 1\n",
    "\n",
    "#             if eval_set == 'dev':\n",
    "#                 dev_accs.append(dev_correct / dev_total)\n",
    "#                 if dev_correct / dev_total > best_dev:\n",
    "#                     best_dev = dev_correct / dev_total\n",
    "                    \n",
    "#                 if dev_correct_wt/dev_total_wt > best_dev_wt:\n",
    "#                     best_dev_wt = dev_correct_wt/dev_total_wt\n",
    "\n",
    "#                 print('total: %d'%(dev_total))\n",
    "#                 print('total wt: %d'%(dev_total_wt))\n",
    "\n",
    "#                 print('dev acc: %f, best dev acc: %f' %(dev_correct/dev_total, best_dev))\n",
    "#                 print('wt dev acc: %f, best wt dev acc: %f' %(dev_correct_wt/dev_total_wt, best_dev_wt))\n",
    "\n",
    "#             else:\n",
    "#                 test_accs.append(dev_correct / dev_total)\n",
    "#                 print('total: %d'%(dev_total))\n",
    "#                 print('total wt: %d'%(dev_total_wt))\n",
    "\n",
    "#                 print('test acc: %f' %(dev_correct/dev_total))\n",
    "#                 print('wt test acc: %f' %(dev_correct_wt/dev_total_wt))\n",
    "#             sys.stdout.flush()\n",
    "\n",
    "#         ranker_model.train()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_losses = []\n",
    "train_accs = []\n",
    "dev_accs = [0.0]\n",
    "test_accs = [0.0]\n",
    "num_iteration = 20000\n",
    "display_iteration = 200\n",
    "test_iteration = 100\n",
    "queue_length = 400\n",
    "num_epoch = 40\n",
    "\n",
    "best_dev = 0.0\n",
    "best_dev_wt = 0.0\n",
    "batch_size_tuple = 2\n",
    "\n",
    "from tqdm.notebook import tqdm #_notebook\n",
    "\n",
    "# i_mat, o_mat, y_vec, i_mask, o_mask = game_task_data.get_batch_concat('train', \n",
    "#                                                         list(range(tid * batch_size_tuple, \n",
    "#                                                         tid * batch_size_tuple + batch_size_tuple)), \n",
    "#                                                         num_negative=4)\n",
    "\n",
    "# supervise_optimizer = torch.optim.Adam(filter(lambda x: x.requires_grad, \n",
    "#                                               ranker_model.parameters()), lr=2e-4)\n",
    "\n",
    "# i_mat, o_mat, y_vec, i_mask, o_mask = game_task_data.get_eval_batch_concat('dev', \n",
    "#                                                                            [inst_id])\n",
    "\n",
    "for eid in range(num_epoch):\n",
    "#for i in xrange(num_iteration):\n",
    "    ranker_model.train()\n",
    "    i = 0\n",
    "    \n",
    "    total = game_task_data.data_sets['train'].size()\n",
    "    print('total num of tuples:', total)\n",
    "    \n",
    "    num_train_instance = game_task_data.data_sets['train'].size()\n",
    "    tid_list = list(range(num_train_instance))\n",
    "    random.shuffle(tid_list)\n",
    "\n",
    "    for tid in tqdm(tid_list):\n",
    "        i_mat, a_mat, o_mat, y_vec, i_mask, a_mask, o_mask = game_task_data.get_eval_batch_triple('train', [tid])\n",
    "    \n",
    "#     for tid in tqdm(range(total // batch_size_tuple)):\n",
    "#         i_mat, o_mat, y_vec, i_mask, o_mask = game_task_data.get_batch_concat('train', \n",
    "#                                                                 list(range(tid * batch_size_tuple, \n",
    "#                                                                 tid * batch_size_tuple + batch_size_tuple)), \n",
    "#                                                                 num_negative=4)\n",
    "    \n",
    "        supervise_optimizer.zero_grad()\n",
    "#         bert_optimizer.zero_grad()\n",
    "#         comatch_optimizer.zero_grad()\n",
    "        i_mat_ = Variable(torch.from_numpy(i_mat)).to(device)\n",
    "        a_mat_ = Variable(torch.from_numpy(a_mat)).to(device)\n",
    "        o_mat_ = Variable(torch.from_numpy(o_mat)).to(device)\n",
    "        i_mask_ = Variable(torch.from_numpy(i_mask)).float().to(device)\n",
    "        a_mask_ = Variable(torch.from_numpy(a_mask)).float().to(device)\n",
    "        o_mask_ = Variable(torch.from_numpy(o_mask)).float().to(device)\n",
    "        y_vec_ = Variable(torch.from_numpy(y_vec)).to(device)\n",
    "\n",
    "        predict = ranker_model(i_mat_, a_mat_, o_mat_, i_mask_, a_mask_, o_mask_)\n",
    "        \n",
    "        supervised_loss = ranker_model.loss(predict, y_vec_)\n",
    "\n",
    "        _, y_pred = torch.max(predict, dim=1)\n",
    "        acc = np.float((y_pred == y_vec_).sum().cpu().data.item()) / y_vec_.size(0) # / args.batch_size\n",
    "        train_accs.append(acc)\n",
    "    \n",
    "        supervised_loss.backward()\n",
    "        supervise_optimizer.step()\n",
    "#         bert_optimizer.step()\n",
    "#         comatch_optimizer.step()\n",
    "    \n",
    "        i += 1\n",
    "    \n",
    "        if i % display_iteration == 0:\n",
    "            print('train acc: %f supervised_loss: %f'%(np.mean(train_accs), \n",
    "                                           supervised_loss.cpu().data.item()))\n",
    "            sys.stdout.flush()\n",
    "    print('Epoch%d:'%(eid))\n",
    "    game_task_data.display_sentence(i_mat[0])\n",
    "    game_task_data.display_sentence(a_mat[0])\n",
    "    game_task_data.display_sentence(o_mat[0])\n",
    "\n",
    "    print('train acc: %f supervised_loss: %f'%(np.mean(train_accs), \n",
    "                                               supervised_loss.cpu().data.item()))\n",
    "    train_losses = []\n",
    "    train_accs = []\n",
    "    \n",
    "#     continue\n",
    "\n",
    "    # Evaluate the ranker on the dev and test splits with autograd disabled.\n",
    "    # The checkpoint name lives in one variable so the log line and the file\n",
    "    # that is actually written cannot drift apart.\n",
    "    best_ckpt_path = 'comatch_bert_with_res_jecc_no_wt_change.model.pt'\n",
    "\n",
    "    with torch.no_grad():\n",
    "        ranker_model.eval()\n",
    "        try:\n",
    "            for eval_set in ['dev', 'test']:\n",
    "                # Overall counters: every label of every evaluated batch.\n",
    "                dev_correct = 0.0\n",
    "                dev_total = 0.0\n",
    "\n",
    "                # 'wt' counters: only the first label of every batch is scored.\n",
    "                dev_correct_wt = 0.0\n",
    "                dev_total_wt = 0.0\n",
    "\n",
    "                num_dev_instance = game_task_data.data_sets[eval_set].size()\n",
    "\n",
    "                for inst_id in tqdm(range(num_dev_instance)):\n",
    "                    (i_mat, a_mat, o_mat, y_vec,\n",
    "                     i_mask, a_mask, o_mask) = game_task_data.get_eval_batch_triple(eval_set, [inst_id])\n",
    "\n",
    "                    # torch.from_numpy already returns a tensor; the deprecated\n",
    "                    # Variable wrapper was a no-op and has been dropped.\n",
    "                    i_mat_ = torch.from_numpy(i_mat).to(device)\n",
    "                    a_mat_ = torch.from_numpy(a_mat).to(device)\n",
    "                    o_mat_ = torch.from_numpy(o_mat).to(device)\n",
    "                    i_mask_ = torch.from_numpy(i_mask).float().to(device)\n",
    "                    a_mask_ = torch.from_numpy(a_mask).float().to(device)\n",
    "                    o_mask_ = torch.from_numpy(o_mask).float().to(device)\n",
    "\n",
    "                    # Permute the rows of o_mat/o_mask before scoring; shuffle_idx[k]\n",
    "                    # is the original row that now sits at position k.\n",
    "                    # NOTE(review): presumably this hides any fixed position of the\n",
    "                    # correct row from the model -- confirm.\n",
    "                    shuffle_idx = list(range(o_mask_.size(0)))\n",
    "                    random.shuffle(shuffle_idx)\n",
    "                    shuffle_idx = np.array(shuffle_idx)\n",
    "                    o_mat_ = o_mat_[shuffle_idx,:]\n",
    "                    o_mask_ = o_mask_[shuffle_idx,:]\n",
    "\n",
    "                    predict = ranker_model(i_mat_, a_mat_, o_mat_, i_mask_, a_mask_, o_mask_)\n",
    "\n",
    "                    # Arg-max along dim 1, then map the shuffled positions back to\n",
    "                    # original row ids so they are comparable with y_vec.\n",
    "                    _, y_pred = torch.max(predict, dim=1)\n",
    "                    y_pred = shuffle_idx[y_pred.cpu().numpy()]\n",
    "\n",
    "                    # Labels are compared on the host, so y_vec is never copied\n",
    "                    # to the device.\n",
    "                    dev_correct += (y_pred == y_vec).sum()\n",
    "                    dev_total += y_vec.shape[0]\n",
    "\n",
    "                    if y_vec[0].item() == y_pred[0].item():\n",
    "                        dev_correct_wt += 1\n",
    "                    dev_total_wt += 1\n",
    "\n",
    "                if dev_total == 0:\n",
    "                    # Nothing was evaluated; the ratios below would divide by zero.\n",
    "                    print('%s set is empty, skipping evaluation' % eval_set)\n",
    "                    continue\n",
    "\n",
    "                eval_acc = dev_correct / dev_total\n",
    "                eval_acc_wt = dev_correct_wt / dev_total_wt\n",
    "\n",
    "                if eval_set == 'dev':\n",
    "                    dev_accs.append(eval_acc)\n",
    "                    if eval_acc > best_dev:\n",
    "                        best_dev = eval_acc\n",
    "                    if eval_acc_wt > best_dev_wt:\n",
    "                        # The checkpoint is selected on the 'wt' accuracy, so that\n",
    "                        # is the value reported next to the saved path.\n",
    "                        best_dev_wt = eval_acc_wt\n",
    "                        print('new best wt dev:', best_dev_wt, 'model saved at', best_ckpt_path)\n",
    "                        torch.save(ranker_model.state_dict(), best_ckpt_path)\n",
    "\n",
    "                    print('total: %d'%(dev_total))\n",
    "                    print('total wt: %d'%(dev_total_wt))\n",
    "\n",
    "                    print('dev acc: %f, best dev acc: %f' %(eval_acc, best_dev))\n",
    "                    print('wt dev acc: %f, best wt dev acc: %f' %(eval_acc_wt, best_dev_wt))\n",
    "\n",
    "                else:\n",
    "                    test_accs.append(eval_acc)\n",
    "                    print('total: %d'%(dev_total))\n",
    "                    print('total wt: %d'%(dev_total_wt))\n",
    "\n",
    "                    print('test acc: %f' %(eval_acc))\n",
    "                    print('wt test acc: %f' %(eval_acc_wt))\n",
    "                sys.stdout.flush()\n",
    "        finally:\n",
    "            # Always put the model back in training mode, even if evaluation\n",
    "            # is interrupted or raises.\n",
    "            ranker_model.train()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# print(y_vec)\n",
    "# print(predict[0])\n",
    "\n",
    "\n",
    "# # for i in range(i_mat.shape[0]):\n",
    "# #     game_task_data.display_sentence(i_mat[i])\n",
    "# # print('')\n",
    "\n",
    "# # for i in range(i_mat.shape[0]):\n",
    "# #     game_task_data.display_sentence(o_mat[i])\n",
    "# # print('')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
