import torch

from transformers import BertForMaskedLM

class PoisonedBertForMaskedLM(BertForMaskedLM):
    """BERT masked-LM model with an extra regression objective on a "poison" input.

    ``forward`` shares one encoder (``self.bert``) between two passes and
    returns their losses separately, leaving any weighting to the caller:

    * a clean pass over ``input_ids`` / ``inputs_embeds`` scored with
      cross-entropy against ``mlm_labels`` through the MLM head ``self.cls``;
    * a poison pass over ``poison_input_ids`` whose first-position hidden
      state is regressed onto ``poison_labels`` with mean-squared error.
    """

    def forward(self,
                input_ids=None,
                poison_input_ids=None,
                attention_mask=None,
                token_type_ids=None,
                position_ids=None,
                head_mask=None,
                inputs_embeds=None,
                encoder_hidden_states=None,
                encoder_attention_mask=None,
                mlm_labels=None,
                poison_labels=None,
                output_attentions=None,
                output_hidden_states=None,
                return_dict=None,
                **kwargs):
        """Compute the masked-LM loss and the poison regression loss.

        Each encoder pass runs only when its labels are supplied, because the
        pass's sole product is the corresponding loss.

        Args:
            input_ids: Clean token ids for the MLM pass (alternatively pass
                ``inputs_embeds``).
            poison_input_ids: Token ids for the poison pass. Required when
                ``poison_labels`` is given.
            attention_mask, token_type_ids, position_ids: Forwarded to both
                encoder passes. NOTE(review): this assumes the clean and
                poison batches share the same (batch, seq_len) layout —
                confirm against the data collator.
            head_mask, encoder_hidden_states, encoder_attention_mask,
            output_attentions, output_hidden_states, return_dict: Forwarded
                unchanged to ``self.bert`` in both passes.
            inputs_embeds: Used by the clean pass only.
            mlm_labels: Target token ids for the MLM loss; positions equal to
                -100 are ignored (``cross_entropy``'s default ``ignore_index``).
            poison_labels: Regression target for the poison pass's
                first-position hidden state; presumably shaped
                (batch, hidden_size) — TODO confirm.
            **kwargs: Accepted for caller compatibility and ignored.

        Returns:
            tuple: ``(mlm_loss, poison_loss)``. Each is a scalar tensor, or
            ``None`` when the corresponding labels were not supplied.

        Raises:
            ValueError: If ``poison_labels`` is given without
                ``poison_input_ids``.
        """
        # Encoder arguments shared verbatim by both passes.
        encoder_kwargs = dict(
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        mlm_loss = None
        if mlm_labels is not None:
            # [0] works for both tuple and ModelOutput returns.
            sequence_output = self.bert(
                input_ids, inputs_embeds=inputs_embeds, **encoder_kwargs)[0]
            prediction_scores = self.cls(sequence_output)
            mlm_loss = torch.nn.functional.cross_entropy(
                prediction_scores.view(-1, self.config.vocab_size),
                mlm_labels.reshape(-1))  # reshape tolerates non-contiguous labels

        poison_loss = None
        if poison_labels is not None:
            if poison_input_ids is None:
                raise ValueError(
                    "poison_labels was given without poison_input_ids")
            # inputs_embeds belongs to the clean pass; forwarding it together
            # with poison_input_ids would make BERT reject the call.
            poison_output = self.bert(
                poison_input_ids, inputs_embeds=None, **encoder_kwargs)[0]
            # Hidden state at sequence position 0 (the [CLS] slot under
            # standard BERT tokenization); indexing drops the sequence dim.
            cls_hidden = poison_output[:, 0, :]
            poison_loss = torch.nn.functional.mse_loss(cls_hidden, poison_labels)

        return mlm_loss, poison_loss