import logging
import torch
from torch import nn
from BERT.pytorch_pretrained_bert.modeling import BertPreTrainedModel, BertModel, BertEmbeddings, BertEncoder, BertPooler, BertPoolerWithHiddensize

logger = logging.getLogger(__name__)

# Shared pooling layer: a single module-level instance. Every model built with
# BertForSequenceClassificationEncoder(share_pooler=True) receives this same
# object as its `pooler`, so all such models share its parameters.
# NOTE(review): 768 is presumably the hidden size; it must match the
# config.hidden_size of any model using share_pooler -- confirm.
pooler = BertPoolerWithHiddensize(768)


class BertForSequenceClassificationEncoder(BertPreTrainedModel):
    """BERT encoder producing first-token features for sequence classification.

    The backbone is chosen in ``__init__``:

    * ``share_pooler=True`` -> ``BertModelSharePooler`` using the module-level
      shared ``pooler`` (takes precedence over ``fix_pooler``);
    * ``fix_pooler=True``   -> ``BertModelNoPooler`` (no pooler layer);
    * otherwise             -> ``BertModel``.

    Args:
        config: BERT configuration object. When ``num_hidden_layers`` is given
            it is written into this object in place, so the caller's config
            is modified.
        output_all_encoded_layers: if True, ``forward`` also returns the
            first-token hidden state of every encoder layer.
        num_hidden_layers: optional override for ``config.num_hidden_layers``.
        fix_pooler: use a backbone without a pooler layer.
        share_pooler: use a backbone whose pooler is the module-level shared one.
    """

    def __init__(self, config, output_all_encoded_layers=False, num_hidden_layers=None, fix_pooler=False, share_pooler=False):
        super(BertForSequenceClassificationEncoder, self).__init__(config)
        if num_hidden_layers is not None:
            logger.info('num hidden layer is set as %d', num_hidden_layers)
            config.num_hidden_layers = num_hidden_layers

        logger.info("Model config %s", config)
        if share_pooler:
            self.bert = BertModelSharePooler(config, pooler)
        elif fix_pooler:
            self.bert = BertModelNoPooler(config)
        else:
            self.bert = BertModel(config)
        self.output_all_encoded_layers = output_all_encoded_layers
        # NOTE(review): this recursively re-initializes every submodule of
        # self.bert -- including, when share_pooler=True, the module-level
        # shared pooler, which is the same object held by every other model
        # built with share_pooler=True. Confirm this re-initialization is intended.
        self.apply(self.init_bert_weights)

    def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None):
        """Encode a batch and return first-token features.

        Args:
            input_ids, token_type_ids, attention_mask: passed unchanged to the
                backbone, which handles the ``None`` defaults.
            labels: unused; accepted for call-site compatibility.

        Returns:
            ``(layer_features, pooled_output)``. ``layer_features`` is a list
            with ``layer[:, 0]`` for every encoder layer when
            ``output_all_encoded_layers`` is set; otherwise it is ``None``.
        """
        if isinstance(self.bert, BertModelNoPooler):
            return self._forward_no_pooler(input_ids, token_type_ids, attention_mask)
        if self.output_all_encoded_layers:
            full_output, pooled_output = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=True)
            return [layer[:, 0] for layer in full_output], pooled_output
        _, pooled_output = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False)
        return None, pooled_output

    def _forward_no_pooler(self, input_ids, token_type_ids, attention_mask):
        """Forward pass for the ``fix_pooler`` backbone, which has no pooler."""
        # BertModelNoPooler.forward returns only the list of encoder-layer
        # outputs, not an (encoded_layers, pooled_output) pair. Unpacking it as
        # a pair raises ValueError for any layer count other than two, and
        # silently mis-assigns with exactly two. Derive the pooled output here.
        # NOTE(review): assumes the intended pooled output without a pooler is
        # the raw first-token hidden state of the last layer -- confirm.
        encoded_layers = self.bert(input_ids, token_type_ids, attention_mask,
                                   output_all_encoded_layers=self.output_all_encoded_layers)
        pooled_output = encoded_layers[-1][:, 0]
        if self.output_all_encoded_layers:
            return [layer[:, 0] for layer in encoded_layers], pooled_output
        return None, pooled_output


class BertModelSharePooler(BertPreTrainedModel):
    """BERT backbone whose pooler is an externally supplied module.

    The embeddings and encoder are built from ``config`` and initialized with
    ``init_bert_weights``. ``pooler`` is attached only afterwards, so this
    constructor does not re-initialize it. Passing the same ``pooler`` object
    to several instances makes them share its parameters.
    """

    def __init__(self, config, pooler):
        super(BertModelSharePooler, self).__init__(config)
        self.embeddings = BertEmbeddings(config)
        self.encoder = BertEncoder(config)
        # Initialize first, attach the pooler second: the pooler is not yet a
        # submodule, so apply() leaves its weights untouched here.
        self.apply(self.init_bert_weights)
        self.pooler = pooler

    def forward(self, input_ids, token_type_ids=None, attention_mask=None, output_all_encoded_layers=True):
        """Run embeddings, encoder and the attached pooler.

        Returns:
            ``(encoded, pooled_output)``. ``encoded`` is the encoder's full
            per-layer list when ``output_all_encoded_layers`` is true, otherwise
            only its last element. ``pooled_output`` is ``self.pooler`` applied
            to that last element.
        """
        mask = torch.ones_like(input_ids) if attention_mask is None else attention_mask
        segment_ids = torch.zeros_like(input_ids) if token_type_ids is None else token_type_ids

        # Additive attention bias: for a 0/1 mask, 0 where attended and -10000
        # where masked. Cast to the parameter dtype for fp16 compatibility.
        param_dtype = next(self.parameters()).dtype
        bias = mask.unsqueeze(1).unsqueeze(2).to(dtype=param_dtype)
        bias = (1.0 - bias) * -10000.0

        hidden = self.embeddings(input_ids, segment_ids)
        layers = self.encoder(hidden, bias, output_all_encoded_layers=output_all_encoded_layers)
        last_layer = layers[-1]
        pooled = self.pooler(last_layer)
        return (layers if output_all_encoded_layers else last_layer), pooled

class BertModelNoPooler(BertPreTrainedModel):
    """BERT backbone made of embeddings and encoder only -- no pooler layer.

    ``forward`` returns only the encoder's output list; there is no pooled
    output.
    """

    def __init__(self, config):
        super(BertModelNoPooler, self).__init__(config)
        self.embeddings = BertEmbeddings(config)
        self.encoder = BertEncoder(config)
        self.apply(self.init_bert_weights)

    def forward(self, input_ids, token_type_ids=None, attention_mask=None, output_all_encoded_layers=True):
        """Run embeddings and encoder and return the encoder's output as-is.

        Returns:
            Whatever ``self.encoder`` returns: presumably a list with every
            layer's hidden states when ``output_all_encoded_layers`` is true,
            and only the final layer otherwise -- confirm against BertEncoder.
        """
        mask = torch.ones_like(input_ids) if attention_mask is None else attention_mask
        segment_ids = torch.zeros_like(input_ids) if token_type_ids is None else token_type_ids

        # Additive attention bias: for a 0/1 mask, 0 where attended and -10000
        # where masked. Cast to the parameter dtype for fp16 compatibility.
        param_dtype = next(self.parameters()).dtype
        bias = mask.unsqueeze(1).unsqueeze(2).to(dtype=param_dtype)
        bias = (1.0 - bias) * -10000.0

        hidden = self.embeddings(input_ids, segment_ids)
        return self.encoder(hidden, bias, output_all_encoded_layers=output_all_encoded_layers)
