import copy

import torch
import torch.nn as nn
import transformers
from torch.nn import CrossEntropyLoss, MSELoss
from transformers.configuration_utils import PretrainedConfig
from transformers.models.bert.modeling_bert import BertPreTrainedModel, BertEmbeddings, BertLayer, BertPooler, BertModel, \
    BertEncoder

from utils_glue import Adapter


class BertForSequenceClassification(BertPreTrainedModel):
    r"""Sequence classifier: ``MyBertModel`` + dropout + linear head on the pooled output.

    Constructor args:
        config: BERT configuration; ``num_labels``, ``hidden_size`` and
            ``hidden_dropout_prob`` are read here.
        use_adapter: when True, an ``Adapter`` is created and its output on the
            pooled representation is added to the classifier logits.
        adapter_size: forwarded as the second argument of ``Adapter``
            (presumably the adapter bottleneck width -- confirm in ``utils_glue``).

    Inputs:
        **labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
            Labels for computing the sequence classification/regression loss.
            Indices should be in ``[0, ..., config.num_labels - 1]``.
            If ``config.num_labels == 1`` a regression loss is computed (Mean-Square loss),
            If ``config.num_labels > 1`` a classification loss is computed (Cross-Entropy).

    Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
        **loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
            Classification (or regression if config.num_labels==1) loss.
        **logits**: ``torch.FloatTensor`` of shape ``(batch_size, config.num_labels)``
            Classification (or regression if config.num_labels==1) scores (before SoftMax).
        **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
            list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
            of shape ``(batch_size, sequence_length, hidden_size)``:
            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        **attentions**: (`optional`, returned when ``config.output_attentions=True``)
            list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.

    Examples::

        tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        model = BertForSequenceClassification.from_pretrained('bert-base-uncased')
        input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0)  # Batch size 1
        labels = torch.tensor([1]).unsqueeze(0)  # Batch size 1
        outputs = model(input_ids, labels=labels)
        loss, logits = outputs[:2]

    """

    def __init__(self, config, use_adapter=False, adapter_size=64):
        super(BertForSequenceClassification, self).__init__(config)

        self.num_labels = config.num_labels

        self.bert = MyBertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, self.config.num_labels)
        if use_adapter:
            # The adapter reads the same pooled output as the classifier, so its
            # input width is config.hidden_size (was hard-coded to 768, which only
            # matched bert-base sized configurations).
            self.adapter = Adapter(config.hidden_size, adapter_size, config.num_labels)
        else:
            self.adapter = None

        self.init_weights()

    def forward(self, input_ids, attention_mask=None, token_type_ids=None,
                position_ids=None, head_mask=None, labels=None):
        """Classify a batch; returns ``(loss), logits, (hidden_states), (attentions)``."""
        outputs = self.bert(input_ids,
                            attention_mask=attention_mask,
                            token_type_ids=token_type_ids,
                            position_ids=position_ids,
                            head_mask=head_mask)

        # MyBertModel returns (sequence_output, pooled_output, ...).
        pooled_output = self.dropout(outputs[1])
        logits = self.classifier(pooled_output)

        if self.adapter is not None:
            # Out-of-place add keeps the classifier output untouched for autograd.
            logits = logits + self.adapter(pooled_output)

        outputs = (logits,) + outputs[2:]  # add hidden states and attention if they are here

        if labels is not None:
            if self.num_labels == 1:
                # A single output unit means regression.
                loss_fct = MSELoss()
                loss = loss_fct(logits.view(-1), labels.view(-1))
            else:
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            outputs = (loss,) + outputs

        return outputs  # (loss), logits, (hidden_states), (attentions)


class MyBertForSequenceClassification(BertPreTrainedModel):
    r"""Sequence classifier on top of the stock ``transformers.BertModel``.

    Constructor args:
        config: BERT configuration; ``num_labels``, ``hidden_size`` and
            ``hidden_dropout_prob`` are read here.
        use_adapter: when True, an ``Adapter`` is created and its output on the
            pooled representation is added to the classifier logits.
        adapter_size: forwarded as the second argument of ``Adapter``
            (presumably the adapter bottleneck width -- confirm in ``utils_glue``).

    ``forward`` returns ``(loss), logits, last_hidden_state``; ``loss`` is present
    only when ``labels`` is given (MSE if ``num_labels == 1``, else cross-entropy).
    """

    def __init__(self, config, use_adapter=False, adapter_size=64):
        super(MyBertForSequenceClassification, self).__init__(config)

        if use_adapter:
            # The adapter reads the same pooled output as the classifier, so its
            # input width is config.hidden_size (was hard-coded to 256).
            self.adapter = Adapter(config.hidden_size, adapter_size, config.num_labels)
        else:
            self.adapter = None
        self.num_labels = config.num_labels

        self.bert = transformers.BertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, self.config.num_labels)

        self.init_weights()

    def forward(self, input_ids, attention_mask=None, token_type_ids=None,
                position_ids=None, head_mask=None, labels=None):
        """Classify a batch; returns ``(loss), logits, last_hidden_state``."""
        # Attribute access below (.pooler_output / .last_hidden_state) needs a
        # ModelOutput, so request one explicitly instead of relying on config.return_dict.
        outputs = self.bert(input_ids,
                            attention_mask=attention_mask,
                            token_type_ids=token_type_ids,
                            position_ids=position_ids,
                            head_mask=head_mask,
                            return_dict=True)

        pooled_output = self.dropout(outputs.pooler_output)
        logits = self.classifier(pooled_output)

        if self.adapter is not None:
            # Out-of-place add keeps the classifier output untouched for autograd.
            logits = logits + self.adapter(pooled_output)

        # Only the final-layer hidden states are forwarded (no per-layer states/attentions).
        outputs = (logits, outputs.last_hidden_state)

        if labels is not None:
            if self.num_labels == 1:
                # A single output unit means regression.
                loss_fct = MSELoss()
                loss = loss_fct(logits.view(-1), labels.view(-1))
            else:
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            outputs = (loss,) + outputs

        return outputs  # (loss), logits, last_hidden_state


class MyBertModel(BertPreTrainedModel):
    r"""BERT backbone assembled from ``BertEmbeddings``, ``MyBertEncoder`` and ``BertPooler``.

    ``forward`` returns a plain tuple rather than a ``ModelOutput``.

    Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
        **last_hidden_state**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, hidden_size)``
            Sequence of hidden-states at the output of the last layer of the model.
        **pooler_output**: ``torch.FloatTensor`` of shape ``(batch_size, hidden_size)``
            Last layer hidden-state of the first token of the sequence (classification token)
            further processed by a Linear layer and a Tanh activation function. The Linear
            layer weights are trained from the next sentence prediction (classification)
            objective during Bert pretraining. This output is usually *not* a good summary
            of the semantic content of the input, you're often better with averaging or pooling
            the sequence of hidden-states for the whole input sequence.
        **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
            list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
            of shape ``(batch_size, sequence_length, hidden_size)``:
            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        **attentions**: (`optional`, returned when ``config.output_attentions=True``)
            list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.

    Examples::

        tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        model = MyBertModel.from_pretrained('bert-base-uncased')
        input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0)  # Batch size 1
        outputs = model(input_ids)
        last_hidden_states = outputs[0]  # The last hidden-state is the first element of the output tuple

    """

    def __init__(self, config):
        super(MyBertModel, self).__init__(config)

        self.embeddings = BertEmbeddings(config)
        self.encoder = MyBertEncoder(config)
        self.pooler = BertPooler(config)

        self.init_weights()

    def _resize_token_embeddings(self, new_num_tokens):
        """Replace the word-embedding matrix with one of ``new_num_tokens`` rows and return it."""
        old_embeddings = self.embeddings.word_embeddings
        new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens)
        self.embeddings.word_embeddings = new_embeddings
        return self.embeddings.word_embeddings

    def _prune_heads(self, heads_to_prune):
        """ Prunes heads of the model.
            heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
            See base class PreTrainedModel
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)

    def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None):
        """Encode ``input_ids``.

        Args:
            input_ids: token ids, ``(batch_size, sequence_length)``.
            attention_mask: 1 for positions to attend, 0 for masked positions;
                defaults to all ones.
            token_type_ids: segment ids; defaults to all zeros.
            position_ids: forwarded to ``BertEmbeddings``.
            head_mask: ``[num_heads]`` or ``[num_hidden_layers, num_heads]``,
                1.0 keeps a head; ``None`` disables head masking.

        Returns:
            ``(sequence_output, pooled_output)`` followed by the encoder's optional
            hidden states and attentions.
        """
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)

        # The masks must share the model's parameter dtype: a hard-coded float32
        # mask would promote half-precision attention scores to float32 when added.
        param_dtype = next(self.parameters()).dtype

        # We create a 3D attention mask from a 2D tensor mask.
        # Sizes are [batch_size, 1, 1, to_seq_length]
        # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
        # this attention mask is more simple than the triangular masking of causal attention
        # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
        extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)

        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
        # masked positions, this operation will create a tensor which is 0.0 for
        # positions we want to attend and -10000.0 for masked positions.
        # Since we are adding it to the raw scores before the softmax, this is
        # effectively the same as removing these entirely.
        extended_attention_mask = extended_attention_mask.to(dtype=param_dtype)  # fp16 compatibility
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        if head_mask is not None:
            if head_mask.dim() == 1:
                head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
                head_mask = head_mask.expand(self.config.num_hidden_layers, -1, -1, -1, -1)
            elif head_mask.dim() == 2:
                # We can specify head_mask for each layer
                head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)
            head_mask = head_mask.to(dtype=param_dtype)  # switch to float if need + fp16 compatibility
        else:
            head_mask = [None] * self.config.num_hidden_layers

        embedding_output = self.embeddings(input_ids, position_ids=position_ids, token_type_ids=token_type_ids)
        encoder_outputs = self.encoder(embedding_output,
                                       extended_attention_mask,
                                       head_mask=head_mask)
        sequence_output = encoder_outputs[0]
        pooled_output = self.pooler(sequence_output)

        # add hidden_states and attentions if they are here
        outputs = (sequence_output, pooled_output,) + encoder_outputs[1:]
        return outputs  # sequence_output, pooled_output, (hidden_states), (attentions)


class MyBertEncoder(nn.Module):
    """Stack of ``config.num_hidden_layers`` ``BertLayer`` modules returning a plain tuple.

    The output is ``(last_hidden_state, [all_hidden_states], [all_attentions])``,
    where the optional entries are present only when ``config.output_hidden_states``
    / ``config.output_attentions`` were set at construction time.
    """

    def __init__(self, config):
        super(MyBertEncoder, self).__init__()
        self.output_attentions = config.output_attentions
        self.output_hidden_states = config.output_hidden_states
        self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])

    def forward(self, hidden_states, attention_mask=None, head_mask=None):
        """Run ``hidden_states`` through every layer in order.

        Args:
            hidden_states: input to the first layer.
            attention_mask: additive mask forwarded unchanged to every layer.
            head_mask: per-layer head masks indexed by layer number; ``None``
                means no head masking for any layer.

        Returns:
            ``(last_hidden_state,)``, then the tuple of hidden states (the input to
            each layer plus the final output) if ``output_hidden_states`` is set,
            then the tuple of per-layer attentions if ``output_attentions`` is set.
        """
        if head_mask is None:
            # Indexing None below would raise; treat it as "no mask for any layer".
            head_mask = [None] * len(self.layer)

        all_hidden_states = ()
        all_attentions = ()

        # The transformer layers are applied here.
        for i, layer_module in enumerate(self.layer):
            if self.output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            # In transformers 4.x, BertLayer only appends attention probabilities to
            # its output when output_attentions=True is passed, so the flag must be
            # forwarded or layer_outputs[1] below does not exist.
            layer_outputs = layer_module(hidden_states, attention_mask, head_mask[i],
                                         output_attentions=self.output_attentions)
            hidden_states = layer_outputs[0]

            if self.output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)

        # Add last layer
        if self.output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        outputs = (hidden_states,)
        if self.output_hidden_states:
            outputs = outputs + (all_hidden_states,)
        if self.output_attentions:
            outputs = outputs + (all_attentions,)
        return outputs  # last-layer hidden state, (all hidden states), (all attentions)


# 提取我们想要的层的权重并重命名
# 6层模型
def get_prune_paramerts(model):
    prune_paramerts = {}
    for name, param in model.named_parameters():
        if 'embeddings' in name:
            prune_paramerts[name] = param
        elif name.startswith('encoder.layer.0.'):
            prune_paramerts[name] = param
        elif name.startswith('encoder.layer.2.'):
            pro_name = name.split('encoder.layer.2.')
            prune_paramerts['encoder.layer.1.' + pro_name[1]] = param
        elif name.startswith('encoder.layer.4.'):
            pro_name = name.split('encoder.layer.4.')
            prune_paramerts['encoder.layer.2.' + pro_name[1]] = param
        elif name.startswith('encoder.layer.6.'):
            pro_name = name.split('encoder.layer.6.')
            prune_paramerts['encoder.layer.3.' + pro_name[1]] = param
        elif name.startswith('encoder.layer.8.'):
            pro_name = name.split('encoder.layer.8.')
            prune_paramerts['encoder.layer.4.' + pro_name[1]] = param
        elif name.startswith('encoder.layer.10.'):
            pro_name = name.split('encoder.layer.10.')
            prune_paramerts['encoder.layer.5.' + pro_name[1]] = param
        elif 'pooler' in name:
            prune_paramerts[name] = param
    return prune_paramerts


# Build the configuration for the pruned model.
def get_prune_config(config, n):
    """Return a copy of ``config`` with ``num_hidden_layers`` set to ``n``.

    The caller's ``config`` is left untouched (previously it was mutated through
    an alias). Dict-style configs are updated by key; any other object (e.g. a
    ``PretrainedConfig``) is updated by attribute.

    Args:
        config: dict or attribute-style configuration object.
        n: number of hidden layers in the pruned model.

    Returns:
        The updated copy.
    """
    prune_config = copy.deepcopy(config)
    if isinstance(prune_config, dict):
        prune_config['num_hidden_layers'] = n
    else:
        prune_config.num_hidden_layers = n
    return prune_config


# Reduce the number of layers in the model and reassign the weights of the kept layers.
def get_prune_model(model, prune_parameters, num_layers=6):
    """Return ``model``'s state dict trimmed to ``num_layers`` encoder layers.

    Entries whose key contains ``'embeddings'`` or ``'pooler'``, or that belong to
    ``encoder.layer.<i>.`` with ``i < num_layers``, are replaced by the value under
    the same key in ``prune_parameters`` (typically the renamed weights from
    ``get_prune_paramerts``). ``embeddings.position_ids`` is kept as-is because it
    is not a parameter. All other entries are removed.

    Args:
        model: module providing ``state_dict()``.
        prune_parameters: mapping from new key names to replacement values.
        num_layers: number of leading encoder layers to keep.

    Returns:
        The trimmed state dict (key order follows ``model.state_dict()``).

    Raises:
        KeyError: if a kept key is missing from ``prune_parameters``.
    """
    prefix = 'encoder.layer.'
    prune_model = model.state_dict()
    for name in list(prune_model):
        if name == 'embeddings.position_ids':
            continue

        in_kept_layer = False
        if name.startswith(prefix):
            idx, sep, _ = name[len(prefix):].partition('.')
            in_kept_layer = bool(sep) and idx.isdigit() and int(idx) < num_layers

        if 'embeddings' in name or 'pooler' in name or in_kept_layer:
            prune_model[name] = prune_parameters[name]
        else:
            del prune_model[name]
    return prune_model


def prune_main(model, config):
    """Build a pruned state dict and matching config from ``model`` and ``config``.

    Args:
        model: source model, passed to ``get_prune_paramerts`` and ``get_prune_model``.
        config: configuration passed to ``get_prune_config`` together with the
            number of encoder layers that were kept.

    Returns:
        ``(prune_model, prune_config)``: the pruned state dict and its config.
    """
    prune_parameters = get_prune_paramerts(model)
    # get_prune_config requires the new depth (the original call omitted it and
    # always raised TypeError). Derive it from the encoder layers actually kept
    # so it cannot drift from get_prune_paramerts.
    kept_layers = {name.split('.')[2] for name in prune_parameters
                   if name.startswith('encoder.layer.')}
    prune_config = get_prune_config(config, len(kept_layers))
    prune_model = get_prune_model(model, prune_parameters)

    return prune_model, prune_config


if __name__ == '__main__':
    # Quick sanity check: a (2, 3, 4) tensor built from two 3x4 blocks.
    # The outer brackets were missing, which passed two lists to torch.Tensor and raised.
    a = torch.Tensor([[[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]],
                      [[13, 14, 15, 16], [17, 18, 19, 20], [21, 22, 23, 24]]])
    print(a.shape)
