import torch
import torch.nn as nn

from allennlp.models import Model
from allennlp.nn.util import masked_log_softmax, masked_softmax

from IPython import embed

from allennlp.training.metrics import F1Measure, CategoricalAccuracy
import os
# Absolute directory that contains this module, with symlinks resolved.
DIR_PATH = os.path.split(os.path.realpath(__file__))[0]

class BiLinear(Model):
    """Bilinear compatibility model between encoded text and encoded labels.

    The text representation and the label representations are each projected
    (without bias) into a shared space of size ``options['joint']``. The logit
    for a (text, label) pair is the dot product of the two projections.

    Args:
        text_encoder: Callable applied to the ``sentence`` argument of
            ``forward``. Its output's last dimension must equal
            ``options.get('text_dim', 64)``.
        label_encoder: Zero-argument callable returning the label
            representations (one row per label, as used by ``label_rep.t()``).
            It must expose a ``label_dim`` attribute.
        vocab: Vocabulary forwarded to ``Model.__init__``.
        options: Mapping of configuration values.
            Required keys: ``seen_classes``, ``dev_classes``,
            ``unseen_classes`` (column indices into the logits), ``device``,
            ``cuda_device``, ``dataset``, ``label_encoder_type``, ``joint``.
            Optional keys: ``text_dim`` (default 64) and ``total_classes``
            (default 7).
    """

    def __init__(self, text_encoder, label_encoder, vocab, options):
        super().__init__(vocab)
        self.text_encoder = text_encoder
        self.label_encoder = label_encoder

        self.options = options

        # additional information
        self.seen_classes = options['seen_classes']
        self.device = options['device']
        self.cuda_device = options['cuda_device']
        self.dataset = options['dataset']
        self.label_encoder_type = options['label_encoder_type']
        self.arch = 'bilinear'

        # Previously hard-coded (TODO); the default keeps the old value.
        self.total_classes = options.get('total_classes', 7)

        self.joint_dim = options['joint']
        # Output width of text_encoder; previously hard-coded to 64.
        self.text_dim = options.get('text_dim', 64)

        self.text_joint = nn.Linear(self.text_dim, self.joint_dim, bias=False)
        self.label_joint = nn.Linear(self.label_encoder.label_dim, self.joint_dim, bias=False)

        self.unseen_classes = options['unseen_classes']
        self.dev_classes = options['dev_classes']

        # Concatenation order: seen, then dev, then unseen.
        self.all_classes = self.seen_classes + self.dev_classes + self.unseen_classes

        self.loss_function = nn.CrossEntropyLoss()
        self.accuracy = CategoricalAccuracy()
        # Explicit dim avoids PyTorch's implicit-dimension deprecation warning
        # (dim=1 is what PyTorch picks implicitly for 2-D input).
        self.log_softmax = nn.LogSoftmax(dim=1)

    def forward(self, sentence, labels=None, train=False, dev=False):
        """Score every label for every input and optionally compute the loss.

        Args:
            sentence: Input passed unchanged to ``self.text_encoder``.
            labels: Optional targets for ``nn.CrossEntropyLoss``. When
                ``train`` or ``dev`` is set, the logits are first restricted
                to the corresponding class list. The labels presumably index
                into that restricted list (TODO confirm against the reader).
            train: Restrict the loss logits to ``self.seen_classes``.
            dev: Restrict the loss logits to ``self.dev_classes``. Ignored when
                ``train`` is also True.

        Returns:
            dict with:
                - ``'logits'``: the projected text representation multiplied
                  by the transposed projected label representations.
                - ``'unseen_probs'``: ``masked_softmax`` of the logits over
                  dim 1, keeping only the ``self.unseen_classes`` columns.
                - ``'loss'``: present only when ``labels`` is given.
        """
        # Project the encoded text into the joint space.
        text_rep = self.text_joint(self.text_encoder(sentence))

        # Project every label representation into the joint space.
        label_rep = self.label_joint(self.label_encoder())

        logits = torch.matmul(text_rep, label_rep.t())
        output = {'logits': logits}

        # Derive the batch size, device and dtype from the logits themselves.
        # This avoids relying on the structure of `sentence` and on
        # options['device'] matching where the module actually lives.
        # The class axis still uses len(self.all_classes), so a label-count
        # mismatch surfaces as an error instead of being silently absorbed.
        unseen_mask = logits.new_zeros(logits.size(0), len(self.all_classes))
        unseen_mask[:, self.unseen_classes] = 1
        output['unseen_probs'] = masked_softmax(logits, unseen_mask, dim=1)

        if labels is not None:
            # Restrict to the class subset the labels refer to. `elif`:
            # applying both slices would index dev_classes into the already
            # seen-restricted columns, which selects the wrong classes.
            loss_logits = logits
            if train:
                loss_logits = logits[:, self.seen_classes]
            elif dev:
                loss_logits = logits[:, self.dev_classes]

            output['loss'] = self.loss_function(loss_logits, labels)

        return output
