from torch import nn
import torch.nn.functional as F

from .text_encoder import TextEncoder
from .decoder import Decoder
from .label_encoder import LabelEncoder


def cross_entropy_loss(logits, label):
    """Return the mean element-wise binary cross-entropy for raw logits.

    Despite its name, this computes the *binary* (sigmoid) cross-entropy via
    ``F.binary_cross_entropy_with_logits``. Each element of ``logits`` is
    scored independently against the matching element of ``label``.

    Args:
        logits: Raw, pre-sigmoid scores of any shape.
        label: Targets with the same number of elements as ``logits``. They
            are cast to float; BCE expects values in [0, 1].

    Returns:
        A scalar tensor: the loss averaged over all elements.
    """
    # reshape (unlike view) also works on non-contiguous tensors. That matters
    # for labels that are already float, where .float() returns the same
    # (possibly transposed or sliced) tensor unchanged.
    flat_logits = logits.reshape(-1)
    flat_labels = label.float().reshape(-1)
    return F.binary_cross_entropy_with_logits(flat_logits, flat_labels)


def compute_kl_loss(p, q, pad_mask=None):
    """Symmetric KL divergence between the softmax distributions of two logit tensors.

    Both ``KL(softmax(q) || softmax(p))`` and ``KL(softmax(p) || softmax(q))``
    are computed element-wise over the last dimension. Masked positions are
    optionally zeroed, each term is mean-reduced, and the two are averaged.

    Args:
        p: Logits from the first pass; softmax is taken over ``dim=-1``.
        q: Logits from the second pass, same shape as ``p``.
        pad_mask: Optional boolean mask. Positions where it is True are zeroed
            in place before averaging (meant for sequence-level tasks). It
            must broadcast to the shape of the element-wise loss.

    Returns:
        A scalar tensor: ``(mean(kl_pq) + mean(kl_qp)) / 2``.
    """
    log_probs_p = F.log_softmax(p, dim=-1)
    log_probs_q = F.log_softmax(q, dim=-1)
    probs_p = F.softmax(p, dim=-1)
    probs_q = F.softmax(q, dim=-1)

    # F.kl_div takes log-probabilities as input and probabilities as target.
    kl_towards_q = F.kl_div(log_probs_p, probs_q, reduction='none')
    kl_towards_p = F.kl_div(log_probs_q, probs_p, reduction='none')

    if pad_mask is not None:
        # Zero padded positions in place. Note that they still count in the
        # mean denominator below.
        for term in (kl_towards_q, kl_towards_p):
            term.masked_fill_(pad_mask, 0.)

    # Mean reduction; switch to sum if the task calls for it.
    return (kl_towards_q.mean() + kl_towards_p.mean()) / 2


class ICDModel(nn.Module):
    """Text classifier that scores documents against encoded label descriptions.

    Documents and label descriptions go through the same ``TextEncoder``. The
    label hidden states are turned into label features by ``LabelEncoder``,
    and the ``Decoder`` maps document hidden states plus label features to
    per-label logits.

    Training (``forward``) runs the document through the encoder and decoder
    twice, with dropout requested on both passes. The loss combines:
      * the binary cross-entropy of each pass against ``labels``, and
      * a symmetric KL term between the two passes' outputs (an R-Drop style
        consistency regularizer).
    """

    def __init__(self, config):
        """Build the encoder, decoder and label encoder from ``config``.

        Args:
            config: Config object exposing ``text_encoder``, ``decoder``,
                ``label_encoder`` and ``loss`` sections. ``loss`` must provide
                ``kl_loss_weight`` and ``code_loss_weight``.

        Note:
            This mutates the caller's ``config``:
            ``config.text_encoder.combiner.output_dim`` is overwritten with
            ``config.decoder.attention_dim``. That keeps the encoder's output
            width matched to what the decoder expects.
        """
        super().__init__()
        config.text_encoder.combiner.output_dim = config.decoder.attention_dim
        self.encoder = TextEncoder(config.text_encoder)
        self.decoder = Decoder(config.decoder)
        self.label_encoder = LabelEncoder(config.label_encoder)

        self.loss_config = config.loss

    def calculate_text_hidden(self, input_word, word_mask, use_dropout=True):
        """Run the shared text encoder on token ids.

        Args:
            input_word: Token ids to encode.
            word_mask: Mask forwarded unchanged to the encoder.
            use_dropout: Forwarded to ``TextEncoder``. How it interacts with
                ``self.training`` is defined there (not visible here).

        Returns:
            Whatever ``TextEncoder`` returns (the hidden representation).
        """
        hidden = self.encoder(input_word, word_mask, use_dropout)
        return hidden

    def calculate_label_hidden(self, c_input_ids, c_word_mask):
        """Encode label descriptions into label features, without dropout.

        The label texts go through the same encoder as documents, with
        ``use_dropout=False``. The result and ``c_word_mask`` are then passed
        to ``label_encoder``.
        """
        label_hidden = self.calculate_text_hidden(c_input_ids, c_word_mask, use_dropout=False)
        label_feats = self.label_encoder(label_hidden, c_word_mask)
        return label_feats

    def forward(self, input_ids, word_mask, labels, c_input_ids, c_word_mask):
        """Training step: two dropout passes, BCE loss plus a symmetric KL term.

        Args:
            input_ids, word_mask: Document token ids and mask.
            labels: Targets for ``cross_entropy_loss`` (binary cross-entropy).
            c_input_ids, c_word_mask: Label-description token ids and mask.

        Returns:
            A dict with:
              * ``'preds'``: raw logits from the first pass,
              * ``'loss'``: ``kl_loss_weight * kl + code_loss_weight * bce``,
                where ``bce`` is the mean of the two passes' losses,
              * ``'labels'``: the input ``labels``, passed through.
        """
        label_feats = self.calculate_label_hidden(c_input_ids, c_word_mask)

        # Two independent encodings of the same input. With dropout active
        # they differ, which is what the KL consistency term regularizes.
        hidden0 = self.calculate_text_hidden(input_ids, word_mask)
        hidden1 = self.calculate_text_hidden(input_ids, word_mask)

        # The decoder output is used directly as the logits (ignore mc_logits).
        c_logits0 = self.decoder(hidden0, word_mask, label_feats)
        c_logits1 = self.decoder(hidden1, word_mask, label_feats)

        c_loss = (cross_entropy_loss(c_logits0, labels) + cross_entropy_loss(c_logits1, labels)) / 2
        kl_loss = compute_kl_loss(c_logits0, c_logits1)
        loss = self.loss_config.kl_loss_weight * kl_loss + self.loss_config.code_loss_weight * c_loss

        output = {
            'preds': c_logits0,
            'loss': loss,
            'labels': labels
        }
        return output

    def evaluate(self, input_ids, word_mask, labels, c_input_ids, c_word_mask):
        """Single deterministic pass for evaluation.

        Dropout is explicitly not requested for the document encoding. This
        keeps evaluation deterministic even if ``TextEncoder`` honors
        ``use_dropout`` independently of ``self.training``. Gradient tracking
        is left to the caller (e.g. wrap the call in ``torch.no_grad()``).

        Returns:
            A dict with ``'preds'`` (raw logits), ``'loss'`` (binary
            cross-entropy) and ``'labels'`` (passed through).
        """
        label_feats = self.calculate_label_hidden(c_input_ids, c_word_mask)
        hidden = self.calculate_text_hidden(input_ids, word_mask, use_dropout=False)
        logits = self.decoder(hidden, word_mask, label_feats)
        loss = cross_entropy_loss(logits, labels)
        output = {
            'preds': logits,
            'loss': loss,
            'labels': labels
        }
        return output
