"""
Transformer part of ClipBERT and ObjectBert
"""
from math import e
from os import device_encoding
from typing import no_type_check
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss, MSELoss, BCEWithLogitsLoss
from .transformers import BertPreTrainedModel
from .transformers import (
    BertPreTrainingHeads, BertEmbeddings, BertEncoder, BertPooler, BertPretrainingHeadsOnlyMLM, BertPretrainingHeadsOnlyITM)
from torch.nn import LayerNorm


class FuseVisualEmbeddings(nn.Module):
    """Embed pre-extracted visual features for the fusion transformer.

    Each visual feature vector (last dim ``config.input_size``) is projected
    to ``config.hidden_size`` by a linear layer. It is then summed with a
    learned position embedding and a token-type embedding, and passed through
    LayerNorm and dropout.

    Config attributes read: ``input_size``, ``hidden_size``,
    ``max_object_position_embeddings``, ``type_vocab_size``,
    ``layer_norm_eps``, ``hidden_dropout_prob``.
    """

    def __init__(self, config):
        super().__init__()
        # Projection of raw visual features into the hidden space.
        self.visual_embeddings = nn.Linear(config.input_size, config.hidden_size)
        self.position_embeddings = nn.Embedding(
            config.max_object_position_embeddings, config.hidden_size)
        self.token_type_embeddings = nn.Embedding(
            config.type_vocab_size, config.hidden_size)

        # self.LayerNorm is not snake-cased to stick with
        # TensorFlow model variable name and be able to load
        # any TensorFlow checkpoint file
        self.LayerNorm = LayerNorm(
            config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, visual_inputs=None, token_type_ids=None,
                position_ids=None, inputs_embeds=None):
        """Return embeddings of shape ``(B, Lv, hidden_size)``.

        Args:
            visual_inputs: (B, Lv, input_size) raw visual features. Ignored
                when ``inputs_embeds`` is given.
            token_type_ids: (B, Lv) long tensor. Defaults to all ones, which
                marks every position as the second (visual) segment.
            position_ids: (B, Lv) long tensor. Defaults to ``0..Lv-1`` for
                every batch row.
            inputs_embeds: (B, Lv, hidden_size) already-projected features.
                Use this to bypass ``self.visual_embeddings``.

        Raises:
            ValueError: if neither ``visual_inputs`` nor ``inputs_embeds``
                is provided.
        """
        if inputs_embeds is None:
            if visual_inputs is None:
                raise ValueError(
                    "You have to specify either visual_inputs or inputs_embeds")
            inputs_embeds = self.visual_embeddings(visual_inputs)

        # Shape and device come from the projected embeddings, so the
        # inputs_embeds-only path works as well.
        input_shape = inputs_embeds.size()[:-1]
        seq_length = input_shape[1]
        device = inputs_embeds.device
        if position_ids is None:
            position_ids = torch.arange(
                seq_length, dtype=torch.long, device=device)
            position_ids = position_ids.unsqueeze(0).expand(input_shape)
        if token_type_ids is None:
            token_type_ids = torch.ones(
                input_shape, dtype=torch.long, device=device)

        position_embeddings = self.position_embeddings(position_ids)
        token_type_embeddings = self.token_type_embeddings(token_type_ids)

        embeddings = (
            inputs_embeds + position_embeddings + token_type_embeddings)
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings

class FuseTextEmbeddings(nn.Module):
    """Embed pre-extracted text features for the fusion transformer.

    Each text feature vector (last dim ``config.input_size``) is projected to
    ``config.hidden_size`` by a linear layer. It is then summed with a learned
    position embedding and a token-type embedding, and passed through
    LayerNorm and dropout.

    Config attributes read: ``input_size``, ``hidden_size``,
    ``max_position_embeddings``, ``type_vocab_size``, ``layer_norm_eps``,
    ``hidden_dropout_prob``.
    """

    def __init__(self, config):
        super().__init__()
        # Kept under the name ``word_embeddings`` (state-dict key) although it
        # is a linear projection of continuous features, not a lookup table.
        self.word_embeddings = nn.Linear(config.input_size, config.hidden_size)
        self.position_embeddings = nn.Embedding(
            config.max_position_embeddings, config.hidden_size)
        self.token_type_embeddings = nn.Embedding(
            config.type_vocab_size, config.hidden_size)

        # self.LayerNorm is not snake-cased to stick with
        # TensorFlow model variable name and be able to load
        # any TensorFlow checkpoint file
        self.LayerNorm = LayerNorm(
            config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, text_inputs=None, token_type_ids=None,
                position_ids=None, inputs_embeds=None):
        """Return embeddings of shape ``(B, Lt, hidden_size)``.

        Args:
            text_inputs: (B, Lt, input_size) raw text features. Ignored when
                ``inputs_embeds`` is given.
            token_type_ids: (B, Lt) long tensor. Defaults to all zeros, which
                marks every position as the first (text) segment.
            position_ids: (B, Lt) long tensor. Defaults to ``0..Lt-1`` for
                every batch row.
            inputs_embeds: (B, Lt, hidden_size) already-projected features.
                Use this to bypass ``self.word_embeddings``.

        Raises:
            ValueError: if neither ``text_inputs`` nor ``inputs_embeds`` is
                provided.
        """
        if inputs_embeds is None:
            if text_inputs is None:
                raise ValueError(
                    "You have to specify either text_inputs or inputs_embeds")
            inputs_embeds = self.word_embeddings(text_inputs)

        # Shape and device come from the projected embeddings, so the
        # inputs_embeds-only path works as well.
        input_shape = inputs_embeds.size()[:-1]
        seq_length = input_shape[1]
        device = inputs_embeds.device
        if position_ids is None:
            position_ids = torch.arange(
                seq_length, dtype=torch.long, device=device)
            position_ids = position_ids.unsqueeze(0).expand(input_shape)
        if token_type_ids is None:
            token_type_ids = torch.zeros(
                input_shape, dtype=torch.long, device=device)

        position_embeddings = self.position_embeddings(position_ids)
        token_type_embeddings = self.token_type_embeddings(token_type_ids)

        embeddings = (
            inputs_embeds + position_embeddings + token_type_embeddings)
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings

class FuseBertBaseModel(BertPreTrainedModel):
    """Single-stream BERT encoder over already-embedded text and visual tokens.

    Unlike ``BertModel``, this class has no embedding layer of its own.
    ``forward`` receives text and visual sequences that are already in the
    hidden space and concatenates them along the sequence axis (text first).
    It runs the result through a ``BertEncoder`` and pools it with
    ``BertPooler``.

    ``forward`` passes only an attention mask and a head mask to the encoder
    (no ``encoder_hidden_states``), so it is used purely as a self-attention
    encoder.

    Config attributes read directly here: ``num_hidden_layers``. The encoder
    and pooler read their own config fields.
    """

    def __init__(self, config):
        super().__init__(config)
        self.config = config

        self.encoder = BertEncoder(config)
        self.pooler = BertPooler(config)

        self.init_weights()

    def _prune_heads(self, heads_to_prune):
        """Prune attention heads of the encoder.

        heads_to_prune: dict of {layer_num: list of heads to prune in this layer}.
        See base class PreTrainedModel.
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)

    def forward(self, text_inputs, visual_inputs, text_attention_mask, visual_attention_mask):
        r"""Encode the concatenation ``[text ; visual]``. Modified from BertModel.

        Args:
            text_inputs: (B, Lt, d) embedded text tokens.
            visual_inputs: (B, Lv, d) embedded visual tokens.
            text_attention_mask: (B, Lt), 1 for valid and 0 for padded positions.
            visual_attention_mask: (B, Lv), same convention.

        Returns:
            Tuple ``(sequence_output, pooled_output, *encoder_outputs[1:])``.
            ``sequence_output`` is (B, Lt+Lv, d). The trailing items are
            whatever extra outputs the encoder returns (presumably hidden
            states / attentions depending on config).

        Raises:
            ValueError: if a mask's length does not match its sequence length.
                Without this check, a mis-split mask would be silently
                misaligned after concatenation.
        """
        if text_attention_mask.shape[-1] != text_inputs.shape[1]:
            raise ValueError(
                "text_attention_mask length %d does not match text length %d"
                % (text_attention_mask.shape[-1], text_inputs.shape[1]))
        if visual_attention_mask.shape[-1] != visual_inputs.shape[1]:
            raise ValueError(
                "visual_attention_mask length %d does not match visual length %d"
                % (visual_attention_mask.shape[-1], visual_inputs.shape[1]))

        device = text_inputs.device

        # Text positions come first, so downstream heads can slice [:, :Lt].
        fuse_embedding = torch.cat([text_inputs, visual_inputs], dim=1)  # [B, Lt+Lv, d]
        # Masks are concatenated in the same order: (B, Lt+Lv).
        fuse_attention_mask = torch.cat([text_attention_mask, visual_attention_mask], dim=-1)

        extended_attention_mask: torch.Tensor = \
            self.get_extended_attention_mask(
                fuse_attention_mask, fuse_embedding.size(), device)
        encoder_outputs = self.encoder(
            fuse_embedding,
            attention_mask=extended_attention_mask,
            head_mask=self.get_head_mask(
                None, self.config.num_hidden_layers)  # required input
        )
        sequence_output = encoder_outputs[0]
        pooled_output = self.pooler(sequence_output)

        # sequence_output, pooled_output, (hidden_states), (attentions)
        outputs = (sequence_output, pooled_output,) + encoder_outputs[1:]
        return outputs

class FuseBertForPreTraining(BertPreTrainedModel):
    """Pre-training wrapper around ``FuseBertBaseModel``.

    Text and visual features are projected into the hidden space by their own
    embedding modules and fused by ``FuseBertBaseModel``. The result is then
    scored by ``self.cls``, currently ``BertPretrainingHeadsOnlyMLM``.
    ``BertPreTrainingHeads`` (MLM + ITM) and ``BertPretrainingHeadsOnlyITM``
    are the alternatives used previously.
    """

    def __init__(self, config):
        super().__init__(config)
        self.config = config

        # Attribute names double as state-dict keys; keep them stable.
        self.text_embedding = FuseTextEmbeddings(config)
        self.visual_embedding = FuseVisualEmbeddings(config)
        self.bert = FuseBertBaseModel(config)
        self.cls = BertPretrainingHeadsOnlyMLM(config)

        self.init_weights()

    def get_output_embeddings(self):
        """Return the decoder layer of the MLM prediction head."""
        return self.cls.predictions.decoder

    def forward(
        self,
        text_inputs,
        visual_inputs,
        text_input_mask,
        visual_input_mask,
        mlm_labels=None,
        itm_labels=None,
    ):
        r"""Run the fusion model and pre-training heads, optionally computing losses.

        Args:
            text_inputs: (B, Lt, input_size) raw text features.
            visual_inputs: (B, Lv, input_size) raw visual features.
            text_input_mask: (B, Lt), 1 for valid and 0 for padded positions.
            visual_input_mask: (B, Lv), same convention.
            mlm_labels: optional (B, Lt) masked-LM targets in
                ``[0, ..., config.vocab_size)``. Positions set to ``-100`` are
                ignored by the loss.
            itm_labels: optional (B,) image-text matching targets, 1 for a
                positive pair and 0 for a negative pair.

        Returns:
            dict with keys:
              mlm_scores: head scores for the text part only (first Lt fused
                  positions), presumably (B, Lt, vocab_size).
              mlm_loss: unreduced, flattened per-token loss of shape (B*Lt,),
                  or 0 when ``mlm_labels`` is None.
              mlm_labels: passed through unchanged.
              itm_scores: ITM head scores, presumably (B, 2).
              itm_loss: unreduced per-pair loss (B,), or 0 when
                  ``itm_labels`` is None.
              itm_labels: passed through unchanged.
        """
        text_embeds = self.text_embedding(text_inputs)
        visual_embeds = self.visual_embedding(visual_inputs)

        fused = self.bert(
            text_inputs=text_embeds,
            visual_inputs=visual_embeds,
            text_attention_mask=text_input_mask,
            visual_attention_mask=visual_input_mask
        )
        sequence_output, pooled_output = fused[:2]

        # The fused sequence is [text ; visual], so the first Lt positions are
        # the text tokens. Only those are scored, to save computation; this
        # assumes the MLM head acts position-wise.
        num_text_tokens = text_input_mask.shape[1]
        text_sequence_output = sequence_output[:, :num_text_tokens]
        prediction_scores, seq_relationship_score = self.cls(
            text_sequence_output, pooled_output)

        criterion = CrossEntropyLoss(reduction="none")
        mlm_loss = 0 if mlm_labels is None else criterion(
            prediction_scores.view(-1, self.config.vocab_size),
            mlm_labels.view(-1))
        itm_loss = 0 if itm_labels is None else criterion(
            seq_relationship_score.view(-1, 2), itm_labels.view(-1))

        return {
            "mlm_scores": prediction_scores,
            "mlm_loss": mlm_loss,
            "mlm_labels": mlm_labels,
            "itm_scores": seq_relationship_score,
            "itm_loss": itm_loss,
            "itm_labels": itm_labels,
        }


def instance_bce_with_logits(logits, labels, reduction="mean"):
    """Binary cross-entropy with logits for multi-label (e.g. VQA soft-score) targets.

    Args:
        logits: 2-D tensor of raw scores, (B, num_labels).
        labels: target tensor with the same shape as ``logits``.
        reduction: forwarded to ``F.binary_cross_entropy_with_logits``.

    Returns:
        For ``"mean"``: the element-wise mean multiplied by ``labels.size(1)``,
        i.e. the sum over labels averaged over the batch. For any other
        reduction: the unmodified ``binary_cross_entropy_with_logits`` result
        (element-wise (B, num_labels) for ``"none"``, a scalar total for
        ``"sum"``).

    Raises:
        ValueError: if ``logits`` is not 2-D. This is an explicit check rather
            than an ``assert``, so it is not stripped under ``python -O``.
    """
    if logits.dim() != 2:
        raise ValueError(
            "instance_bce_with_logits expects 2-D logits, got %d-D" % logits.dim())
    loss = F.binary_cross_entropy_with_logits(
        logits, labels, reduction=reduction)
    if reduction == "mean":
        # Out-of-place rescale: avoids an in-place edit of an autograd output.
        loss = loss * labels.size(1)
    return loss


class FuseBertForSequenceClassification(BertPreTrainedModel):
    """Text+visual classifier/regressor on top of ``FuseBertBaseModel``.

    Modified from BertForSequenceClassification to support oscar training.
    Text and visual features are embedded separately and fused by
    ``FuseBertBaseModel``. The pooled output is fed through dropout and a
    two-layer MLP head.

    Config attributes read here: ``hidden_size``, ``hidden_dropout_prob``,
    ``num_labels`` and, when ``num_labels != 1``, ``loss_type`` (``"bce"`` or
    ``"ce"``).
    """

    def __init__(self, config):
        super().__init__(config)
        self.config = config

        # Attribute names double as state-dict keys; keep them stable.
        self.text_embedding = FuseTextEmbeddings(config)
        self.visual_embedding = FuseVisualEmbeddings(config)
        self.bert = FuseBertBaseModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        self.classifier = nn.Sequential(
            nn.Linear(config.hidden_size,
                      config.hidden_size * 2),
            nn.ReLU(True),
            nn.Linear(config.hidden_size * 2, config.num_labels)
        )

        self.init_weights()

    def forward(self, text_inputs, visual_inputs,
                text_input_mask, visual_input_mask, labels=None):
        """Classify a (text, visual) pair.

        Args:
            text_inputs: (B, Lt, input_size) raw text features.
            visual_inputs: (B, Lv, input_size) raw visual features.
            text_input_mask: (B, Lt), 1 for valid and 0 for padded positions.
            visual_input_mask: (B, Lv), same convention.
            labels: optional targets; see ``calc_loss``.

        Returns:
            dict with ``logits`` (B, num_labels) and ``loss`` (see ``calc_loss``).
        """
        text_inputs = self.text_embedding(text_inputs)
        visual_inputs = self.visual_embedding(visual_inputs)

        outputs = self.bert(
            text_inputs=text_inputs,
            visual_inputs=visual_inputs,
            text_attention_mask=text_input_mask,
            visual_attention_mask=visual_input_mask
        )
        # outputs[1] is the pooler output of the fused sequence.
        pooled_output = outputs[1]

        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)
        logits, loss = self.calc_loss(logits, labels)
        return dict(
            logits=logits,
            loss=loss
        )

    def calc_loss(self, logits, labels):
        """Compute an unreduced loss for ``logits``.

        Args:
            logits: (B, num_labels) classifier output.
            labels: None, or targets matching the selected loss:
                - ``num_labels == 1`` (regression): any numeric tensor with B
                  elements. It is cast to ``logits.dtype``, because MSELoss
                  needs a target of the same floating dtype.
                - ``loss_type == "bce"``: (B, num_labels) soft scores.
                - ``loss_type == "ce"``: (B,) class indices.

        Returns:
            ``(logits, loss)``. ``loss`` is 0 when ``labels`` is None. Otherwise
            it is unreduced: (B,) for regression and ``"ce"``, and element-wise
            (B, num_labels) for ``"bce"``.

        Raises:
            ValueError: if ``config.loss_type`` is neither ``"bce"`` nor ``"ce"``
                for a classification head.
        """
        if labels is None:
            return logits, 0

        if self.config.num_labels == 1:  # regression
            loss_fct = MSELoss(reduction="none")
            loss = loss_fct(
                logits.view(-1), labels.view(-1).to(logits.dtype))
        elif self.config.loss_type == 'bce':  # [VQA]
            loss = instance_bce_with_logits(
                logits, labels, reduction="none")
        elif self.config.loss_type == "ce":  # [GQA, Retrieval, Captioning]
            loss_fct = CrossEntropyLoss(reduction="none")
            loss = loss_fct(
                logits.view(-1, self.config.num_labels),
                labels.view(-1))
        else:
            raise ValueError("Invalid option for config.loss_type")
        return logits, loss