import torch
from transformers import BertForMaskedLM, BertModel
from torch.nn import CrossEntropyLoss, MSELoss
from torch import nn
from transformers.modeling_outputs import SequenceClassifierOutput
import json

class BertPredictionHeadTransform(nn.Module):
    """Dense -> ReLU -> LayerNorm transform applied before the vocabulary decoder.

    Args:
        config: object providing ``hidden_size`` and ``layer_norm_eps``.

    NOTE(review): the activation is fixed to ReLU and ``config.hidden_act`` is
    not consulted. Stock BERT prediction heads select the activation from
    ``config.hidden_act`` (GELU by default). Confirm ReLU is intended,
    especially when loading pretrained ``cls.*`` weights.
    """

    def __init__(self, config):
        super().__init__()
        width = config.hidden_size
        self.dense = nn.Linear(width, width)
        self.transform_act_fn = nn.functional.relu
        self.LayerNorm = nn.LayerNorm(width, eps=config.layer_norm_eps)

    def forward(self, hidden_states):
        """Return ``LayerNorm(relu(dense(hidden_states)))``."""
        projected = self.dense(hidden_states)
        activated = self.transform_act_fn(projected)
        return self.LayerNorm(activated)

class BertLMPredictionHead(nn.Module):
    """Vocabulary prediction head: transform hidden states, then project to vocab logits.

    The decoder ``nn.Linear`` is created without its own bias. A standalone
    ``bias`` parameter is then attached as ``decoder.bias``, so both names
    refer to the same tensor. The original author notes this keeps the bias
    correctly sized under ``resize_token_embeddings``.

    Args:
        config: object providing ``hidden_size``, ``vocab_size`` and whatever
            ``BertPredictionHeadTransform`` reads.
    """

    def __init__(self, config):
        super().__init__()
        self.transform = BertPredictionHeadTransform(config)

        # The output weights are meant to be shared with the input embeddings
        # (that tying is not done in this class); only the bias is output-specific.
        self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.bias = nn.Parameter(torch.zeros(config.vocab_size))

        # Alias the standalone bias into the decoder so the two stay linked.
        self.decoder.bias = self.bias

    def forward(self, hidden_states):
        """Return vocabulary logits ``decoder(transform(hidden_states))``."""
        return self.decoder(self.transform(hidden_states))

class BertOnlyMLMHead(nn.Module):
    """Masked-LM head wrapper exposing a single ``predictions`` sub-module.

    The attribute name ``predictions`` presumably mirrors the upstream
    ``cls.predictions.*`` checkpoint keys. Do not rename it without checking
    weight loading.

    Args:
        config: configuration forwarded to ``BertLMPredictionHead``.
    """

    def __init__(self, config):
        super().__init__()
        self.predictions = BertLMPredictionHead(config)

    def forward(self, sequence_output):
        """Map encoder outputs to vocabulary logits via ``self.predictions``."""
        return self.predictions(sequence_output)


class BertPTune(BertForMaskedLM):
    """BERT masked-LM used as a P-tuning classifier.

    Two blocks of trainable continuous prompt vectors (``A_prompt`` and
    ``B_prompt``, ``prompt_length`` rows each) are written into the word
    embeddings immediately before and immediately after the label-token
    position. The MLM logits at the label token are then restricted to the
    vocabulary ids in ``class_id``, giving one logit per class.

    Args:
        config: BERT configuration passed to ``BertForMaskedLM``.
        class_id: sequence of vocabulary ids, one per class. Its order defines
            the class index expected in ``labels``.
        classes_num: number of classes. It is used to reshape the class logits
            for the loss, so it should equal ``len(class_id)``.
        do_predict: stored as ``self.do_predict``. Note that ``forward``
            chooses its mode from its own ``do_predict`` argument, not from
            this attribute.
        prompt_length: number of prompt vectors on each side of the label token.
    """

    def __init__(self, config, class_id, classes_num, do_predict=False, prompt_length=1):
        super().__init__(config)
        # Replace the inherited encoder/head with a pooler-less encoder and the
        # locally defined MLM head. The attribute names are unchanged,
        # presumably so pretrained ``bert.*`` / ``cls.*`` weights still load by name.
        self.bert = BertModel(config, add_pooling_layer=False)
        self.cls = BertOnlyMLMHead(config)
        self.classes_num = classes_num
        self.prompt_length = prompt_length
        embedding_size = self.bert.embeddings.word_embeddings.weight.size(-1)
        self.A_prompt = nn.parameter.Parameter(torch.randn(self.prompt_length, embedding_size))
        self.B_prompt = nn.parameter.Parameter(torch.randn(self.prompt_length, embedding_size))

        self.class_id = class_id
        self.do_predict = do_predict

        self.init_weights()

    def _insert_prompts(self, embeds, label_token_idx):
        """Overwrite the prompt slots around each row's label token, in place.

        For row ``b`` and each ``i < prompt_length``, position
        ``label_token_idx[b] - prompt_length + i`` receives ``A_prompt[i]`` and
        position ``label_token_idx[b] + 1 + i`` receives ``B_prompt[i]``.
        Whatever token embeddings were at those positions are discarded.

        NOTE(review): positions are not range-checked. A label token closer
        than ``prompt_length`` to the start gives negative indices, which wrap
        to the end of the sequence. Callers must reserve placeholder slots.
        """
        rows = torch.arange(embeds.size(0), device=embeds.device)
        prompt_a_start = label_token_idx - self.prompt_length
        prompt_b_start = label_token_idx + 1
        for i in range(self.prompt_length):
            # A single prompt vector broadcasts across all batch rows.
            embeds[rows, prompt_a_start + i] = self.A_prompt[i]
            embeds[rows, prompt_b_start + i] = self.B_prompt[i]
        return embeds

    def _gather_class_logits(self, label_token_logits):
        """Select the ``class_id`` columns: ``[batch, vocab] -> [batch, len(class_id)]``."""
        class_index = torch.as_tensor(
            self.class_id, dtype=torch.long, device=label_token_logits.device
        )
        return label_token_logits[:, class_index]

    def forward(
            self,
            input_ids=None,
            label_token_idx=None,
            attention_mask=None,
            token_type_ids=None,
            position_ids=None,
            head_mask=None,
            inputs_embeds=None,
            labels=None,
            output_attentions=None,
            output_hidden_states=None,
            return_dict=None,
            do_predict=False):
        """Run the prompted encoder and score the label token against ``class_id``.

        Args:
            input_ids: token ids. They are embedded here, so this is required.
                ``inputs_embeds`` is accepted but not used.
            label_token_idx: per-row position of the label token, used as a
                column index paired with ``arange(batch)``.
            attention_mask, token_type_ids, position_ids, head_mask: forwarded
                unchanged to ``self.bert``.
            labels: class indices (order of ``class_id``). Required only to
                compute a loss; if ``None``, no loss is computed.
            output_attentions: forwarded to ``self.bert``.
            output_hidden_states: not used; hidden states are always requested.
            return_dict: tuple vs ``SequenceClassifierOutput``; defaults to
                ``config.use_return_dict``.
            do_predict: if truthy, return only the class logits.

        Returns:
            If ``do_predict`` is truthy: ``(class_logits,)`` or a
            ``SequenceClassifierOutput`` whose ``logits`` are the
            ``[batch, len(class_id)]`` class logits.

            Otherwise: the cross-entropy loss over the class logits (``None``
            when ``labels`` is ``None``) together with the full
            ``[batch, seq_len, vocab_size]`` MLM scores as ``logits``.
            NOTE(review): this differs from the predict path's logits.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # step 1: embed the raw token ids.
        embeds = self.bert.embeddings.word_embeddings(input_ids)
        # step 2: write the continuous prompts around the label token.
        embeds = self._insert_prompts(embeds, label_token_idx)

        # step 3: encode. The masks/ids were previously dropped here, which let
        # the encoder attend to padding and ignored segment ids.
        outputs = self.bert(
            inputs_embeds=embeds,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=True,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]
        prediction_scores = self.cls(sequence_output)  # [batch, seq_len, vocab_size]

        # step 4: logits at each row's label token, then only the class columns.
        rows = torch.arange(label_token_idx.size(0), device=prediction_scores.device)
        label_token_logits = prediction_scores[rows, label_token_idx]  # [batch, vocab_size]
        class_logits = self._gather_class_logits(label_token_logits)  # [batch, len(class_id)]

        if do_predict:
            if not return_dict:
                return (class_logits,)
            return SequenceClassifierOutput(
                loss=None,
                logits=class_logits,
                hidden_states=outputs.hidden_states,
                attentions=outputs.attentions,
            )

        masked_lm_loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            masked_lm_loss = loss_fct(
                class_logits.view(-1, self.classes_num), labels.long().view(-1)
            )

        if not return_dict:
            # outputs[2:] presumably skips the sequence output and the (empty,
            # pooler disabled) pooled-output slot.
            output = (prediction_scores,) + outputs[2:]
            return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output

        return SequenceClassifierOutput(
            loss=masked_lm_loss,
            logits=prediction_scores,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
