#!/usr/bin/env python

import torch
import torch.nn as nn
from torch.nn.utils import clip_grad_norm_
import torch.nn.functional as F

from models.base_model import BaseModel
from models.util.embedder import Embedder
from models.encoder import RNNEncoder
from models.decoder import Decoder as Decoder
from models.util.criterions import NLLLoss, CopyGeneratorLoss
from models.util.misc import Pack
from models.evaluation.metrics import accuracy, perplexity
from models.util.attention import Attention
from copy import deepcopy
import pdb

class MindRNN(BaseModel):
    def __init__(self, vocab_size, args, encoder_embeddings=None,
                 padding_idx=None, encoder_padding_idx=None, unk_idx=None, device=None):
        """
        Build the belief ("mind") estimator.

        Args:
            vocab_size: number of rows in the shared token embedding table.
            args: configuration object; this constructor reads ``hidden_size``,
                ``with_bridge``, ``embed_size``, ``num_layers``,
                ``bidirectional``, ``dropout``, ``use_gpu`` and ``attn``.
            encoder_embeddings: stored on the instance; not used elsewhere in
                this constructor.
            padding_idx: padding index handed to the embedder.
            encoder_padding_idx: token id treated as padding when lengths and
                attention masks are derived in :meth:`encode`.
            unk_idx: accepted for signature parity; not used by this class.
            device: device value forwarded to the attention module. ``None``
                or a negative value selects ``"cpu"``.
        """
        super().__init__()

        self.hidden_size = args.hidden_size
        self.with_bridge = args.with_bridge
        self.vocab_size = vocab_size
        self.embed_size = args.embed_size
        self.padding_idx = padding_idx
        self.num_layers = args.num_layers
        self.bidirectional = args.bidirectional
        self.dropout = args.dropout

        # One embedder instance is shared by both utterance encoders and by
        # get_mind_encoding().
        enc_embedder = Embedder(num_embeddings=self.vocab_size,
                                embedding_dim=self.embed_size, padding_idx=self.padding_idx)
        self.enc_embedder = enc_embedder
        # NOTE(review): utt_encoder is not referenced by any method of this
        # class in the visible file; only utt_mind_encoder is used in encode().
        self.utt_encoder = RNNEncoder(input_size=self.embed_size, hidden_size=self.hidden_size,
                                      embedder=enc_embedder, num_layers=self.num_layers,
                                      bidirectional=self.bidirectional, dropout=self.dropout)
        self.utt_mind_encoder = RNNEncoder(input_size=self.embed_size, hidden_size=self.hidden_size,
                                           embedder=enc_embedder, num_layers=self.num_layers,
                                           bidirectional=self.bidirectional, dropout=self.dropout)
        self.encoder_embeddings = encoder_embeddings
        self.use_gpu = args.use_gpu
        self.attn_mode = args.attn
        # The default device=None used to hit `None >= 0` and raise TypeError;
        # treat a missing device like a negative one.
        self.device = device if device is not None and device >= 0 else "cpu"
        self.encoder_padding_idx = encoder_padding_idx

        if self.with_bridge:
            self.utt_mind_bridge = nn.Sequential(nn.Linear(self.hidden_size, self.hidden_size), nn.Tanh())
        # Per-token projection of schema embeddings into the hidden space.
        self.mind_encoder = nn.Linear(self.embed_size, self.hidden_size)
        self.mse_loss = nn.MSELoss(reduction='mean')
        # A single attention module scores both the bA and the bBinA schema.
        self.ba = Attention(query_size=self.hidden_size,
                            memory_size=self.hidden_size,
                            mode=self.attn_mode,
                            activation="tanh",
                            project=False,
                            device=self.device)

        if self.use_gpu:
            self.cuda()

    def get_mind_encoding(self, ba_schema, col_len):
        ba_schema = ba_schema.reshape(-1, col_len)
        ba_schema_inputs = self.enc_embedder(ba_schema)
        ba_schema_enc_outputs = self.mind_encoder(ba_schema_inputs)
        return ba_schema_enc_outputs

    def encode(self, inputs, is_training=False):
        """
        Attend from the encoded utterance over the tokens of every schema column.

        Reads ``inputs.mind_src``, ``inputs.ba_schema`` and
        ``inputs.bina_schema``. Both schemas are reshaped with the column
        length taken from ``inputs.ba_schema``. ``is_training`` is accepted
        but not used here.

        Returns:
            A ``Pack`` holding ``prior_ba_attn`` and ``prior_bina_attn``, each
            viewed as ``(batch_size, col_num, col_len)``.
        """
        pad = self.encoder_padding_idx

        # Utterance encoding; only the final hidden state is used below.
        mind_src = inputs.mind_src
        src_lengths = (mind_src != pad).sum(dim=-1)
        _, utt_hidden = self.utt_mind_encoder((mind_src, src_lengths))
        if self.with_bridge:
            utt_hidden = self.utt_mind_bridge(utt_hidden)

        batch_size, col_num, col_len = inputs.ba_schema.shape

        # Token-level schema encodings, each [B*num, len, h].
        ba_memory = self.get_mind_encoding(inputs.ba_schema, col_len)
        bina_memory = self.get_mind_encoding(inputs.bina_schema, col_len)

        def column_query():
            # Repeat the utterance state once per column -> [B*num, 1, h].
            per_column = utt_hidden.squeeze(0).unsqueeze(1).repeat(1, col_num, 1)
            return per_column.view(-1, self.hidden_size).unsqueeze(1)

        result = Pack()

        # bA: the attention mask marks padding positions.
        ba_query = column_query()
        ba_is_token = (inputs.ba_schema != pad).view(-1, col_len)
        _, ba_attn = self.ba(query=ba_query,
                             memory=ba_memory,
                             mask=~ba_is_token)
        result.add(prior_ba_attn=ba_attn.squeeze(1).view(batch_size, col_num, col_len))

        # bBinA: same attention module, second schema.
        bina_query = column_query()
        bina_is_token = (inputs.bina_schema != pad).view(-1, col_len)
        _, bina_attn = self.ba(query=bina_query,
                               memory=bina_memory,
                               mask=~bina_is_token)
        result.add(prior_bina_attn=bina_attn.squeeze(1).view(batch_size, col_num, col_len))
        return result


    def forward(self, enc_inputs, is_training=False):

        outputs = self.encode(
                enc_inputs, is_training=is_training)
        return outputs

    def collect_metrics(self, outputs, ba_gt, bina_gt, epoch=-1, is_training=False):
        """
        Compute the MSE between predicted and reference attention maps.

        Args:
            outputs: object exposing ``prior_ba_attn`` and ``prior_bina_attn``.
            ba_gt: reference for ``prior_ba_attn``; its first dimension is
                reported as ``num_samples``.
            bina_gt: reference for ``prior_bina_attn``.
            epoch, is_training: accepted but not used here.

        Returns:
            A ``Pack`` with ``num_samples``, ``ba``, ``bina`` and ``loss``
            (the sum of the two MSE terms).
        """
        metrics = Pack(num_samples=ba_gt.size(0))

        ba_loss = self.mse_loss(outputs.prior_ba_attn, ba_gt)
        bina_loss = self.mse_loss(outputs.prior_bina_attn, bina_gt)

        metrics.add(ba=ba_loss, bina=bina_loss)
        metrics.add(loss=ba_loss + bina_loss)
        return metrics

    def iterate(self, inputs, optimizer=None, grad_clip=None, is_training=False, epoch=-1):
        enc_inputs = inputs

        outputs = self.forward(enc_inputs, is_training=is_training)
        ba_gt = inputs.ba_gt
        bina_gt = inputs.bina_gt
        metrics = self.collect_metrics(outputs, ba_gt, bina_gt, epoch=epoch, is_training=is_training)

        loss = metrics.loss
        if torch.isnan(loss):
            pdb.set_trace()
            raise ValueError("nan loss encountered")

        if is_training:
            assert optimizer is not None
            optimizer.zero_grad()
            loss.backward()
            if grad_clip is not None and grad_clip > 0:
                clip_grad_norm_(parameters=self.parameters(),
                                max_norm=grad_clip)
            optimizer.step()
        return metrics, outputs

class RNNWithMind(BaseModel):
    def __init__(self, vocab_size, mind_model, args,
                 padding_idx=None, encoder_padding_idx=None, unk_idx=None, device=None):
        """
        Build the response generator that consumes a frozen mind model.

        Args:
            vocab_size: size of the shared embedding table and of the decoder
                output.
            mind_model: module providing ``utt_mind_encoder``,
                ``utt_mind_bridge``, ``get_mind_encoding``, ``ba``,
                ``enc_embedder`` and ``mse_loss``; all of its parameters are
                frozen here.
            args: configuration object; this constructor reads ``embed_size``,
                ``hidden_size``, ``num_layers``, ``bidirectional``, ``attn``,
                ``with_bridge``, ``tie_embedding``, ``dropout``, ``use_gpu``,
                ``use_bow``, ``use_kd``, ``use_posterior``, ``force_copy``,
                ``stage``, ``remove_final``, ``mind_weight``, ``col_temp`` and
                ``with_bina``.
            padding_idx: padding index for the embedder and the losses. When
                ``None`` no class-weight vector is created.
            encoder_padding_idx: token id treated as padding on encoder inputs.
            unk_idx: unknown-token index handed to the copy-generator loss.
            device: device value forwarded to sub-modules. ``None`` or a
                negative value selects ``"cpu"``.
        """
        super().__init__()

        self.vocab_size = vocab_size
        self.embed_size = args.embed_size
        self.hidden_size = args.hidden_size
        self.padding_idx = padding_idx
        self.encoder_padding_idx = encoder_padding_idx
        self.num_layers = args.num_layers
        self.bidirectional = args.bidirectional
        self.attn_mode = args.attn
        self.with_bridge = args.with_bridge
        self.tie_embedding = args.tie_embedding
        self.dropout = args.dropout
        self.use_gpu = args.use_gpu
        self.use_bow = args.use_bow
        self.use_kd = args.use_kd
        self.use_posterior = args.use_posterior
        self.baseline = 0
        # The default device=None used to hit `None >= 0` and raise TypeError;
        # treat a missing device like a negative one.
        self.device = device if device is not None and device >= 0 else "cpu"
        self.unk_idx = unk_idx
        self.force_copy = args.force_copy
        self.stage = args.stage
        self.remove_final = args.remove_final
        self.mind_weight = args.mind_weight

        # One embedder instance is shared by the utterance encoder, the
        # knowledge encoder and the decoder.
        enc_embedder = Embedder(num_embeddings=self.vocab_size,
                                embedding_dim=self.embed_size, padding_idx=self.padding_idx)
        self.enc_embedder = enc_embedder
        self.utt_encoder = RNNEncoder(input_size=self.embed_size, hidden_size=self.hidden_size,
                                      embedder=enc_embedder, num_layers=self.num_layers,
                                      bidirectional=self.bidirectional, dropout=self.dropout)
        self.mind_model = mind_model
        self.col_temper = args.col_temp
        self.with_bina = args.with_bina

        # The mind model is used as a fixed feature extractor.
        for p in self.mind_model.parameters():
            p.requires_grad = False

        if self.with_bridge:
            self.utt_bridge = nn.Sequential(nn.Linear(self.hidden_size, self.hidden_size), nn.Tanh())

        # Three-way action classifier over the utterance state.
        self.action = nn.Linear(self.hidden_size, 3)

        self.knowledge_encoder = RNNEncoder(input_size=self.embed_size, hidden_size=self.hidden_size,
                                            embedder=enc_embedder, num_layers=self.num_layers,
                                            bidirectional=self.bidirectional, dropout=self.dropout)

        self.prior_attention = Attention(query_size=self.hidden_size,
                                         memory_size=self.hidden_size,
                                         hidden_size=self.hidden_size,
                                         mode="dot",
                                         device=self.device)

        self.posterior_attention = Attention(query_size=self.hidden_size,
                                             memory_size=self.hidden_size,
                                             hidden_size=self.hidden_size,
                                             mode="dot",
                                             device=self.device)

        # NOTE(review): Attention2 is not imported in the visible part of this
        # file (only Attention is); unless it is brought into scope elsewhere
        # this line raises NameError. Confirm its module and add the import.
        self.mention = Attention2(query_size=self.hidden_size,
                                  memory_size=self.hidden_size,
                                  mode=self.attn_mode,
                                  activation='sigmoid',
                                  project=False,
                                  device=self.device)

        self.decoder = Decoder(input_size=self.embed_size, hidden_size=self.hidden_size,
                               output_size=self.vocab_size, mind_weight=args.mind_weight,
                               remove_final=args.remove_final,
                               embedder=enc_embedder,
                               num_layers=self.num_layers, attn_mode=self.attn_mode,
                               memory_size=self.hidden_size, dropout=self.dropout,
                               device=self.device)

        self.softmax = nn.Softmax(dim=-1)
        self.sigmoid = nn.Sigmoid()
        self.tanh = nn.Tanh()
        # Gate mixing the bA and bBinA column vectors (input is their concat).
        self.fc = nn.Linear(self.hidden_size * 2, 1)
        self.mention_query_encoder = nn.Linear(self.hidden_size, self.hidden_size)
        self.mention_memory_encoder = nn.Linear(self.embed_size, self.hidden_size)

        if self.use_bow:
            self.bow_output_layer = nn.Sequential(
                    nn.Linear(in_features=self.hidden_size, out_features=self.hidden_size),
                    nn.Tanh(),
                    nn.Linear(in_features=self.hidden_size, out_features=self.vocab_size),
                    nn.LogSoftmax(dim=-1))

        if self.use_kd:
            self.knowledge_dropout = nn.Dropout(self.dropout)

        # Per-class loss weights: zero for the padding token, one elsewhere.
        if self.padding_idx is not None:
            self.weight = torch.ones(self.vocab_size)
            self.weight[self.padding_idx] = 0
        else:
            self.weight = None

        self.nll_loss = NLLLoss(weight=self.weight, ignore_index=self.padding_idx,
                                reduction='mean')
        self.copy_gen_loss = CopyGeneratorLoss(vocab_size=self.vocab_size,
                                               force_copy=self.force_copy,
                                               unk_index=self.unk_idx,
                                               ignore_index=self.padding_idx)

        self.kl_loss = torch.nn.KLDivLoss(reduction="mean")
        self.cross_entropy_loss = torch.nn.CrossEntropyLoss()
        # Element-wise KL terms, consumed by js_div().
        self.mind_diff = torch.nn.KLDivLoss(reduction="none")
        self.bce = nn.BCELoss(reduction='none')

        if self.use_gpu:
            self.cuda()
            # self.weight is None when padding_idx is None; calling .cuda() on
            # it unconditionally raised AttributeError.
            if self.weight is not None:
                self.weight = self.weight.cuda()

    def get_mind(self, inputs):
        """
        Run the frozen mind model over every utterance and sum the results.

        Reads ``inputs.mind_src`` (unpacked as batch, sent_num, sent_len),
        ``inputs.ba_schema`` and ``inputs.bina_schema`` (``ba_schema`` is
        unpacked as batch, col_num, col_len; the same sizes are applied to
        ``bina_schema``).

        Returns:
            ``(prior_ba, prior_bina, vec_ba, vec_bina)``: the attention maps
            viewed as ``(batch_size, col_num, col_len)`` and the attended
            column vectors, each summed over the utterances that have at
            least one non-padding token.
        """
        # utt encoding info.
        utt_mind_inputs = inputs.mind_src
        batch_size, sent_num, sent_len = inputs.mind_src.shape
        # Flatten so that every utterance is encoded as its own sequence.
        utt_mind_inputs = utt_mind_inputs.view(-1, sent_len)
        utt_lengths = (utt_mind_inputs != self.encoder_padding_idx).sum(dim=-1)
        if utt_lengths.sum() == 0:
            # All utterances are empty: only the single-utterance case is
            # supported, and it is given a length of 1 before encoding.
            assert len(utt_lengths) == 1
            utt_lengths[0] = 1
        utt_enc_outputs, utt_enc_hidden = self.mind_model.utt_mind_encoder((utt_mind_inputs, utt_lengths))

        if self.with_bridge:
            utt_enc_hidden = self.mind_model.utt_mind_bridge(utt_enc_hidden)

        # mind encoder
        ba_schema = inputs.ba_schema
        # batch_size is rebound here from the schema's first dimension.
        batch_size, col_num, col_len = ba_schema.shape
        schemas = inputs.bina_schema
        # [B*num, len, h]
        # Each schema is repeated once per utterance so it lines up with the
        # flattened utterance states.
        repeat_ba_memory = inputs.ba_schema.unsqueeze(1).repeat(1, sent_num, 1, 1).view(-1, col_num, col_len)
        ba_schema_enc_outputs = self.mind_model.get_mind_encoding(repeat_ba_memory, col_len)
        repeat_bina_memory = inputs.bina_schema.unsqueeze(1).repeat(1, sent_num, 1, 1).view(-1, col_num, col_len)
        schema_enc_outputs = self.mind_model.get_mind_encoding(repeat_bina_memory, col_len)

        # bA
        ## [B*num, 1, h]
        # squeeze(0) assumes the hidden state has a leading dimension of 1 — TODO confirm
        ba_col_query = utt_enc_hidden.squeeze(0).unsqueeze(1).repeat(1, col_num, 1).view(-1,
                                                                                  self.hidden_size).unsqueeze(1)

        # True where the schema position holds a real token; inverted below
        # because the attention mask marks padding.
        ba_schema_attn_mask = (repeat_ba_memory != self.encoder_padding_idx).view(-1, col_len)
        weighted_col_ba, col_ba_attn = self.mind_model.ba(query=ba_col_query,
                                               memory=ba_schema_enc_outputs,
                                               mask=~ba_schema_attn_mask)
        # Zero out the contribution of utterances with length 0, then sum
        # over the utterance axis.
        sum_mask = (utt_lengths.view(batch_size, sent_num) > 0).unsqueeze(-1).unsqueeze(-1)
        vec_ba = weighted_col_ba.view(batch_size, sent_num, col_num, -1).mul(sum_mask).sum(dim=1)
        sum_ba_attn = col_ba_attn.view(batch_size, sent_num, col_num, col_len).mul(sum_mask).sum(dim=1)
        prior_ba = sum_ba_attn.squeeze(1).view(batch_size, col_num, col_len)

        # bBinA
        # Same procedure with the second schema; the mind model's single
        # attention module is reused.
        bina_col_query = utt_enc_hidden.squeeze(0).unsqueeze(1).repeat(1, col_num, 1).view(-1,
                                                                                    self.hidden_size).unsqueeze(
            1)
        schema_attn_mask = (repeat_bina_memory != self.encoder_padding_idx).view(-1, col_len)
        weighted_col_bina, col_bina_attn = self.mind_model.ba(query=bina_col_query,
                                               memory=schema_enc_outputs,
                                               mask=~schema_attn_mask)
        vec_bina = weighted_col_bina.view(batch_size, sent_num, col_num, -1).mul(sum_mask).sum(dim=1)
        sum_bina_attn = col_bina_attn.view(batch_size, sent_num, col_num, col_len).mul(sum_mask).sum(dim=1)
        prior_bina = sum_bina_attn.squeeze(1).view(batch_size, col_num, col_len)
        return prior_ba, prior_bina, vec_ba, vec_bina

    def js_div(self, p_output, q_output):
        log_mean_output = ((p_output + q_output) / 2 + 1e-20).log()
        return (self.mind_diff(log_mean_output, p_output) + self.mind_diff(log_mean_output, q_output)) / 2

    def get_schema_encoding(self, schema):
        ba_schema = schema.reshape(-1, schema.shape[-1])
        ba_schema_inputs = self.mind_model.enc_embedder(ba_schema)
        ba_schema_enc_outputs = self.mention_memory_encoder(ba_schema_inputs)
        return ba_schema_enc_outputs

    def encode(self, inputs, is_training=False):
        """
        Encode the dialogue context, knowledge and mind state for the decoder.

        Reads ``inputs.src``, ``inputs.cue``, ``inputs.bina_schema`` and, via
        :meth:`get_mind`, ``inputs.mind_src`` / ``inputs.ba_schema``. When
        ``self.use_posterior`` and ``is_training`` are both true it also
        reads ``inputs.tgt_bert``.

        Returns:
            ``(outputs, dec_init_state)`` where ``outputs`` is a ``Pack`` with
            ``prior_attn``, ``action``, ``prior_ba_attn``, ``prior_bina_attn``,
            ``prior_ba_normalize``, ``mind_diff``, ``mind_mask`` and, in the
            posterior branch, ``posterior_attn`` (plus ``bow_logits`` when
            ``self.use_bow``); ``dec_init_state`` is whatever
            ``self.decoder.initialize_state`` returns.
        """
        outputs = Pack()
        # utt encoding info.
        utt_inputs = inputs.src
        utt_lengths = (utt_inputs != self.encoder_padding_idx).sum(dim=-1)
        utt_enc_outputs, utt_enc_hidden = self.utt_encoder((utt_inputs, utt_lengths))

        if self.with_bridge:
            utt_enc_hidden = self.utt_bridge(utt_enc_hidden)

        # knowledge
        batch_size, sent_num, sent = inputs.cue.size()

        # cue_enc_hidden.size() == [1, batch_size * sent_num, hidden_size]
        # cue_enc_outputs.size() == [batch_size * sent_num, sent_len, hidden_size]
        cue_lengths = (inputs.cue.view(-1, sent)!=self.encoder_padding_idx).sum(dim=-1)
        cue_enc_outputs, cue_enc_hidden = self.knowledge_encoder((inputs.cue.view(-1, sent), cue_lengths))

        # cue_enc_hidden[-1].size() == [batch_size * sent_num, hidden_size]
        # [batch_size, sent_num, hidden_size]
        cue_enc_outputs = cue_enc_outputs.view(batch_size, sent_num, cue_enc_outputs.size(-2), -1)
        cue_outputs = cue_enc_hidden.view(batch_size, sent_num, -1)

        # Attention
        # Number of non-padding tokens per cue sentence; sentences with a
        # count of 0 are masked out of the attention.
        cue_output_mask = (inputs.cue != self.encoder_padding_idx).sum(dim=-1)
        prior_query = self.tanh(utt_enc_hidden.squeeze(0)).unsqueeze(1)
        weighted_cue, cue_attn = self.prior_attention(query=prior_query,
                                                      memory=self.tanh(cue_outputs),
                                                      mask=cue_output_mask.eq(0))

        prior_attn = cue_attn.squeeze(1)
        outputs.add(prior_attn=prior_attn)

        # Three-way action logits from the utterance state.
        action = self.action(utt_enc_hidden.squeeze(0).clone())
        outputs.add(action=action)

        # mind
        prior_ba, prior_bina, vec_ba, vec_bina = self.get_mind(inputs)
        outputs.add(prior_ba_attn=prior_ba)
        outputs.add(prior_bina_attn=prior_bina)

        # Softmax-normalise each column's attention after pushing padding
        # positions to a large negative value. Clones keep the raw maps
        # stored in `outputs` untouched.
        # NOTE(review): the padding mask comes from bina_schema for both the
        # bA and the bBinA map — confirm ba_schema shares the same padding.
        prior_ba_normalize = prior_ba.clone()
        ba_mask = inputs.bina_schema==self.encoder_padding_idx
        prior_ba_normalize[ba_mask] = -1e5
        prior_ba_normalize = self.softmax(prior_ba_normalize)
        prior_bina_normalize = prior_bina.clone()
        prior_bina_normalize[ba_mask] = -1e5
        prior_bina_normalize = self.softmax(prior_bina_normalize)
        outputs.add(prior_ba_normalize=prior_ba_normalize)
        if self.with_bina==1:
            # Gate between the two column vectors, then weight columns by the
            # temperature-scaled divergence between the two belief maps.
            concat_vec = torch.cat([self.tanh(vec_ba), self.tanh(vec_bina)], dim=-1)
            weights = self.sigmoid(self.fc(concat_vec))
            merged_vec = weights * vec_ba + (1 - weights) * vec_bina
            div = self.js_div(prior_ba_normalize, prior_bina_normalize)
            col_kl = div.sum(dim=-1)
            col_kl /= self.col_temper
            entity_len = (inputs.bina_schema != self.encoder_padding_idx).sum(dim=-1)
            # Columns without any real token are excluded from the softmax.
            col_kl[entity_len == 0] = -1e5
            col_kl = self.softmax(col_kl)
            bdiff_vec = col_kl.unsqueeze(-1).mul(merged_vec).sum(dim=1).unsqueeze(1)
            bdiff_vec = self.mention_query_encoder(bdiff_vec)
        elif self.with_bina==2:
            # bA only. NOTE(review): js_div is called with the same tensor
            # twice, so the divergence should be ~0 and the column weights
            # near-uniform over non-empty columns — confirm this is intended.
            div = self.js_div(prior_ba_normalize, prior_ba_normalize)
            col_kl = div.sum(dim=-1)
            col_kl /= self.col_temper
            entity_len = (inputs.bina_schema != self.encoder_padding_idx).sum(dim=-1)
            col_kl[entity_len == 0] = -1e5
            col_kl = self.softmax(col_kl)
            bdiff_vec = col_kl.unsqueeze(-1).mul(vec_ba).sum(dim=1).unsqueeze(1)
            bdiff_vec = self.mention_query_encoder(bdiff_vec)
        else:
            # bBinA only; same self-divergence pattern as the branch above.
            div = self.js_div(prior_bina_normalize, prior_bina_normalize)
            col_kl = div.sum(dim=-1)
            col_kl /= self.col_temper
            entity_len = (inputs.bina_schema != self.encoder_padding_idx).sum(dim=-1)
            col_kl[entity_len == 0] = -1e5
            col_kl = self.softmax(col_kl)
            bdiff_vec = col_kl.unsqueeze(-1).mul(vec_bina).sum(dim=1).unsqueeze(1)
            bdiff_vec = self.mention_query_encoder(bdiff_vec)
        # Mention scores: attend from the belief-difference vector over all
        # schema tokens, flattened to one sequence per batch element.
        schema_memory = self.get_schema_encoding(inputs.bina_schema)
        schema_mask = inputs.bina_schema.view(batch_size, -1)==self.encoder_padding_idx
        weighted_diff, mind_diff = self.mention(query=bdiff_vec,
                                                    memory=schema_memory.view(batch_size, -1, self.hidden_size),
                                                    mask=schema_mask)
        mind_diff = mind_diff.view(batch_size, -1)
        outputs.add(mind_diff=mind_diff)
        # outputs.add(mind_diff=inputs.mention_gt.clone().view(batch_size, -1))
        outputs.add(mind_mask=schema_mask)
        # The decoder receives a softmax-normalised copy; the raw scores
        # stored above are what collect_metrics() passes to the BCE loss.
        normalized_diff = mind_diff.clone()
        # normalized_diff = inputs.mention_gt.clone().view(batch_size, -1)
        normalized_diff[schema_mask] = -1e5
        normalized_diff = self.softmax(normalized_diff)

        posterior_attn = None
        if self.use_posterior and is_training:
            # Posterior attention over the cue sentences, queried by the
            # encoded target (training only).
            tgt_length = (inputs.tgt_bert!=self.encoder_padding_idx).sum(dim=-1)
            tgt_outputs, tgt_enc_hidden = self.knowledge_encoder((inputs.tgt_bert,tgt_length))

            posterior_weighted_cue, posterior_attn = self.posterior_attention(
                query=tgt_enc_hidden.squeeze(0).unsqueeze(1),
                memory=self.tanh(cue_outputs),
                mask=cue_output_mask.eq(0))

            posterior_attn = posterior_attn.squeeze(1)
            outputs.add(posterior_attn=posterior_attn)

            knowledge = posterior_weighted_cue
            if self.use_kd:
                knowledge = self.knowledge_dropout(knowledge)

            if self.use_bow:
                bow_logits = self.bow_output_layer(knowledge)
                outputs.add(bow_logits=bow_logits)

        # Initialize the context vector of decoder
        dec_init_context = torch.zeros(size=[batch_size, 1, self.hidden_size],
                                       dtype=torch.float,
                                       device=self.device)

        # Encoder outputs and lengths are only handed over when an attention
        # mode is configured.
        dec_init_state = self.decoder.initialize_state(
            contexts=inputs.src,
            utt_hidden=utt_enc_hidden,
            utt_outputs=utt_enc_outputs if self.attn_mode else None,
            utt_input_len=utt_lengths if self.attn_mode else None,
            cue_hidden=cue_enc_hidden,
            cue_outputs=cue_enc_outputs if self.attn_mode else None,
            cue_input_len=cue_output_mask if self.attn_mode else None,
            pr_attn_dist=prior_attn,
            po_attn_dist=posterior_attn,
            mind_diff=normalized_diff,
            diff_vec=weighted_diff,
            dec_init_context=dec_init_context
        )
        return outputs, dec_init_state

    def decode(self, input, state, oovs_max, src_extend_vocab, cue_extend_vocab, schema_extend_vocab, mention_ids=[]):

        output, dec_state = self.decoder.decode(input=input,
                                                 state=state,
                                                 oovs_max=oovs_max,
                                                 valid_src_extend_vocab=src_extend_vocab,
                                                 valid_cue_extend_vocab=cue_extend_vocab,
                                                 valid_schema_extend_vocab=schema_extend_vocab,
                                                 mention_ids=mention_ids)

        return output, dec_state

    def forward(self, enc_inputs, dec_inputs, is_training=False):

        outputs, dec_init_state = self.encode(
                enc_inputs, is_training=is_training)

        log_probs, _ = self.decoder(dec_inputs, dec_init_state, is_training=is_training)
        outputs.add(logits=log_probs)
        return outputs

    def collect_metrics(self, outputs, oovs_target,
                        no_extend_target, action_gt,
                        ba_gt, bina_gt, mention_gt,
                        epoch=-1, is_training=False):
        """
        Assemble the training loss and the reported metrics for one batch.

        The mention (BCE) loss is always part of ``loss``. When
        ``self.stage == 1`` the copy-generator NLL is added as well; in the
        posterior training branch the KL term, the action cross-entropy and
        (with ``self.use_bow``) the bag-of-words NLL are added too. The
        ``ba`` / ``bina`` MSE values are reported but not added to ``loss``.

        Args:
            outputs: pack produced by :meth:`forward`.
            oovs_target: alignment targets for the copy-generator loss.
            no_extend_target: target token ids; its first dimension is
                reported as ``num_samples``.
            action_gt: action labels; samples whose label is not 0 are
                excluded from the NLL.
            ba_gt, bina_gt: references for the two belief attention maps.
            mention_gt: mention labels, viewed as ``(num_samples, -1)``.
            epoch: accepted but not used here.
            is_training: selects the posterior branch together with
                ``self.use_posterior``.

        Returns:
            A ``Pack`` of metrics including ``loss``.
        """
        num_samples = no_extend_target.size(0)
        metrics = Pack(num_samples=num_samples)
        loss = 0
        generation_stage = self.stage == 1

        if generation_stage:
            logits = outputs.logits   # [batch_size, dec_seq_len, vocab_size]

            nll_loss_ori = self.copy_gen_loss(scores=logits.transpose(1, 2).contiguous(),
                                              align=oovs_target,
                                              target=no_extend_target)   # [batch_size, tgt_len]
            nll_loss_ori = torch.sum(nll_loss_ori, dim=-1)
            # Only samples with action label 0 contribute to the NLL.
            nll_loss_ori[action_gt != 0] = 0
            num_generate = (action_gt == 0).sum()
            if num_generate == 0:
                nll_loss = torch.sum(nll_loss_ori)
            else:
                nll_loss = torch.sum(nll_loss_ori) / num_generate

            num_words = no_extend_target.ne(self.padding_idx).sum()  # .item()
            ppl = nll_loss_ori.sum() / num_words
            ppl = ppl.exp()

            acc = accuracy(logits, no_extend_target, padding_idx=self.padding_idx)
            metrics.add(nll=(nll_loss, num_words), acc=acc, ppl=ppl)

            action_loss = self.cross_entropy_loss(outputs.action, action_gt)
            _, preds = outputs.action.max(dim=-1)
            trues = (preds == action_gt).float()
            action_acc = trues.mean()
            metrics.add(action_loss=action_loss, action_acc=action_acc)

        mention_loss = self.bce(outputs.mind_diff, mention_gt.view(num_samples, -1))
        mention_loss[outputs.mind_mask] = 0
        # With every position masked the count is 0 and the original 0/0
        # produced NaN; clamp the divisor so that case yields a zero loss,
        # mirroring the guard used for the NLL above.
        num_valid = (~outputs.mind_mask).sum().clamp(min=1)
        mention_loss = mention_loss.sum() / num_valid
        metrics.add(mention=mention_loss)
        loss += mention_loss

        if generation_stage:
            if self.use_posterior and is_training:
                kl_loss = self.kl_loss(torch.log(outputs.prior_attn + 1e-20),
                                       outputs.posterior_attn.detach())

                metrics.add(kl=kl_loss)

                loss += nll_loss
                loss += kl_loss
                loss += action_loss

                if self.use_bow:
                    bow_logits = outputs.bow_logits
                    # Repeat along dim 1 once per target position.
                    bow_logits = bow_logits.repeat(1, no_extend_target.size(-1), 1)
                    bow = self.nll_loss(bow_logits, no_extend_target)
                    loss += bow
                    metrics.add(bow=bow)
            else:
                loss += nll_loss

        metrics.add(loss=loss)
        # Reported for monitoring only; added after `loss` is recorded.
        ba_loss = self.mind_model.mse_loss(outputs.prior_ba_attn, ba_gt)
        bina_loss = self.mind_model.mse_loss(outputs.prior_bina_attn, bina_gt)
        metrics.add(ba=ba_loss)
        metrics.add(bina=bina_loss)
        return metrics

    def iterate(self, inputs, optimizer=None, grad_clip=None, is_training=False, epoch=-1):
        enc_inputs = inputs
        dec_inputs = (inputs.tgt[:, :-1],
                      inputs.tgt_len - 1,
                      inputs.src_oov,
                      inputs.cue_oov,
                      inputs.schema_oov,
                      inputs.merge_oovs_str,
                      inputs.mention_ids)

        outputs = self.forward(enc_inputs, dec_inputs, is_training=is_training)

        oovs_target = inputs.tgt_oov[:, 1:]
        no_extend_target = inputs.tgt[:, 1:]
        action_gt = inputs.action_gt
        ba_gt = inputs.ba_gt
        bina_gt = inputs.bina_gt
        mention_gt = inputs.mention_gt
        metrics = self.collect_metrics(outputs, oovs_target, 
                                        no_extend_target, action_gt, 
                                        ba_gt, bina_gt, mention_gt,
                                        epoch=epoch, is_training=is_training)

        loss = metrics.loss
        if torch.isnan(loss):
            pdb.set_trace()
            raise ValueError("nan loss encountered")

        if is_training:
            assert optimizer is not None
            optimizer.zero_grad()
            loss.backward()
            if grad_clip is not None and grad_clip > 0:
                clip_grad_norm_(parameters=self.parameters(),
                                max_norm=grad_clip)
            optimizer.step()
        return metrics



if __name__ == '__main__':
    # Usage sketch.
    # NOTE(review): vocab_size, args, padding_idx, encoder_padding_idx, unk_idx
    # and device are not defined in the visible part of this file; they must
    # be provided before this block can run.
    #
    # MindRNN takes `encoder_embeddings` as its third parameter, so passing
    # the remaining values positionally shifted every one of them by a slot
    # (padding_idx became encoder_embeddings, device became unk_idx, and the
    # real device stayed None). Keywords keep each value on its parameter.
    mind_model = MindRNN(vocab_size, args,
                         padding_idx=padding_idx,
                         encoder_padding_idx=encoder_padding_idx,
                         unk_idx=unk_idx,
                         device=device)

    model = RNNWithMind(vocab_size, mind_model, args,
                        padding_idx=padding_idx,
                        encoder_padding_idx=encoder_padding_idx,
                        unk_idx=unk_idx,
                        device=device)