
import torch
import torch.nn as nn
from allennlp.models import Model


class BiLinearModel(Model):
    """Score mentions against labels with a low-rank bilinear product.

    Mentions and labels are encoded by their own encoders. Each side is
    projected (linear, no bias) into a shared joint space, and the two
    are scored with a matrix product:
    ``logits = text_joint(mention_rep) @ label_joint(label_rep).T``.

    When ``options['dataset'] == 'ontonotes'``, the first label index given
    to :meth:`forward` is not encoded. A learned joint-space vector
    (``self.other``) is prepended in its place.
    """

    def __init__(self, vocab, mention_encoder,
                 label_encoder, adj_lists,
                 options=None):
        """Build the encoders' joint-space projections.

        Args:
            vocab: vocabulary forwarded to ``Model.__init__``.
            mention_encoder: module called as
                ``mention_encoder(mention_tokens, left_tokens, right_tokens)``;
                must expose a ``mention_dim`` attribute.
            label_encoder: module called on label indices; must expose a
                ``label_dim`` attribute.
            adj_lists: stored as ``self.adj_lists``; not used by this class.
            options: mapping of settings. Required keys: ``'device'``,
                ``'cuda_device'``, ``'dataset'``. Optional key:
                ``'joint_dim'``, the size of the shared projection space
                (default 20).

        Raises:
            TypeError: if ``options`` is None.
            KeyError: if a required key is missing from ``options``.
        """
        if options is None:
            # The default is kept only for signature compatibility; the
            # model cannot be built without its settings.
            raise TypeError(
                "BiLinearModel requires an 'options' mapping with keys "
                "'device', 'cuda_device' and 'dataset'")
        super().__init__(vocab)
        self.mention_encoder = mention_encoder
        self.label_encoder = label_encoder
        self.adj_lists = adj_lists

        # Runtime settings, kept on the instance for later use.
        self.device = options['device']
        self.cuda_device = options['cuda_device']
        self.options = options

        # Input sizes of the two projections come from the encoders.
        self.mention_dim = self.mention_encoder.mention_dim
        self.label_dim = self.label_encoder.label_dim
        # Width of the shared joint space; 20 was the previously
        # hard-coded value and remains the default.
        self.joint_dim = options.get('joint_dim', 20)

        self.text_joint = nn.Linear(self.mention_dim, self.joint_dim,
                                    bias=False)
        self.label_joint = nn.Linear(self.label_dim, self.joint_dim,
                                     bias=False)

        if self.options['dataset'] == 'ontonotes':
            # Learned joint-space row that replaces the encoding of the
            # first label index in forward().
            self.other = nn.Parameter(
                torch.empty(1, self.joint_dim, device=self.device))
            nn.init.xavier_uniform_(self.other)
        else:
            self.other = None

    def forward(self, batch, label_idx):
        """Return bilinear scores between the encoded mentions and labels.

        Args:
            batch: mapping with keys ``'mention_tokens'``, ``'left_tokens'``
                and ``'right_tokens'``, passed positionally to the mention
                encoder in that order.
            label_idx: label indices for the label encoder. When
                ``self.other`` is set (ontonotes), the first entry is dropped
                and ``self.other`` is prepended to the projected labels.

        Returns:
            ``mention_joint @ label_joint.t()``. This is presumably of shape
            ``(batch_size, num_labels)`` when both projections are 2-D
            (TODO confirm against the encoders).
        """
        mention_rep = self.mention_encoder(batch['mention_tokens'],
                                           batch['left_tokens'],
                                           batch['right_tokens'])

        # ``self.other`` exists exactly for the ontonotes setup. Using it for
        # both the slicing and the concatenation keeps the two in sync.
        if self.other is not None:
            label_idx = label_idx[1:]
        label_rep = self.label_encoder(label_idx)

        mention_joint = self.text_joint(mention_rep)
        label_joint = self.label_joint(label_rep)

        if self.other is not None:
            label_joint = torch.cat((self.other, label_joint), dim=0)

        return torch.matmul(mention_joint, label_joint.t())