import json
import logging
from logzero import logger
import math
import os
import sys
from io import open

import torch
from torch import nn
from torch.nn import functional as F
from torch.nn import CrossEntropyLoss, MSELoss, BCEWithLogitsLoss
from transformers.modeling_utils import SequenceSummary
from models.attentionXML import MLAttention, MLLinear

# from pytorch_transformers import BertForSequenceClassification

# from transformers.models.roberta.modeling_roberta import RobertaPreTrainedModel, RobertaClassificationHead
# from transformers.models.bert.modeling_bert import BertPreTrainedModel, BertModel
from transformers.models.distilbert.modeling_distilbert import DistilBertModel, DistilBertPreTrainedModel
# Parameter-name substrings: any parameter whose name contains one of these is
# placed in a zero-weight-decay optimizer group by BertLightXML.get_param.
no_decay = ['bias', 'LayerNorm.weight']
from pytorch_transformers.models.bert.modeling_bert import BertPreTrainedModel, BertModel, BertPooler


class Pooler(nn.Module):
    """Bottleneck projection: a linear layer followed by ``tanh``.

    Maps inputs whose last dimension is ``hidden_size`` to ``bottleneck_size``.
    """

    def __init__(self, hidden_size, bottleneck_size):
        super(Pooler, self).__init__()
        self.dense = nn.Linear(in_features=hidden_size, out_features=bottleneck_size)
        self.activation = nn.Tanh()

    def forward(self, x):
        """Return ``tanh(dense(x))``.

        No token selection happens here: ``x`` is projected exactly as given,
        so any first-token slicing has to be done by the caller.
        """
        return self.activation(self.dense(x))


class BertLightXML(BertPreTrainedModel):
    """BERT encoder with a single linear multi-label head (no label grouping).

    The index-0 (first token) hidden state of each of the last
    ``config.feature_layers`` entries of the encoder's ``hidden_states`` is
    concatenated, passed through dropout, and projected to
    ``config.num_labels`` logits. The training loss is ``BCEWithLogitsLoss``.

    Config attributes read here: ``num_labels``, ``feature_layers``,
    ``num_hidden_layers``, ``hidden_dropout_prob``,
    ``attention_probs_dropout_prob``, ``bottleneck_size``, ``hidden_size``
    and ``last_hidden_size``.

    NOTE(review): ``forward`` never calls ``self.pooler`` nor
    ``self.bert.pooler``, although ``get_param`` builds optimizer groups for
    whichever of them exists. Confirm whether a pooling step was meant to be
    applied to the features before the linear head.
    """

    def __init__(self, config):
        super(BertLightXML, self).__init__(config)
        self.num_labels = config.num_labels
        self.feature_layers = config.feature_layers
        logger.info(f'bertLightXML, num layer: {config.num_hidden_layers}, '
                    f'hidden dropout: {config.hidden_dropout_prob}, '
                    f'attn dropout: {config.attention_probs_dropout_prob}, '
                    f'feature layers: {self.feature_layers}')
        if config.bottleneck_size is None:
            # No bottleneck: keep BertModel's own pooling layer.
            self.bert = BertModel(config)
            self.pooler = None
        else:
            # Bottleneck requested: drop BertModel's pooler and register our own.
            self.bert = BertModel(config, add_pooling_layer=False)
            self.pooler = Pooler(config.hidden_size, config.bottleneck_size)
            assert config.bottleneck_size == config.last_hidden_size

        # follow the original paper dropout = 0.5
        self.dropout = nn.Dropout(0.5)
        # Without label group, the prediction is a concatenation of feature from different layers
        self.linear = nn.Linear(self.feature_layers * config.last_hidden_size, config.num_labels)
        self.loss_fn = nn.BCEWithLogitsLoss()
        self.init_weights()

    @staticmethod
    def _decay_groups(module, lr, weight_decay):
        """Split ``module``'s parameters into a decayed and a non-decayed group.

        Parameters whose name contains any entry of the module-level
        ``no_decay`` list go to the second group with ``weight_decay=0.0``;
        all others go to the first group with ``weight_decay``.
        Both groups use learning rate ``lr``.
        """
        decayed, exempt = [], []
        for name, param in module.named_parameters():
            if any(nd in name for nd in no_decay):
                exempt.append(param)
            else:
                decayed.append(param)
        return [
            {'params': decayed, 'weight_decay': weight_decay, "lr": lr},
            {'params': exempt, 'weight_decay': 0.0, "lr": lr},
        ]

    @staticmethod
    def get_param(model, learning_rate_x, learning_rate_h, learning_rate_a, weight_decay):
        """Build per-module optimizer parameter groups.

        Args:
            model: a ``BertLightXML`` or a wrapper exposing it as ``.module``
                (presumably DataParallel-style wrappers; only the attribute
                is checked).
            learning_rate_x: learning rate for BERT embeddings and encoder.
            learning_rate_h: learning rate for the pooler (``model.pooler``
                if set, otherwise ``model.bert.pooler``).
            learning_rate_a: learning rate for the linear classification head.
            weight_decay: decay applied to parameters not matched by
                ``no_decay``.

        Returns:
            A list of eight parameter-group dicts, a (decay, no-decay) pair
            for each of: embeddings, encoder, linear head, pooler, in that
            order.
        """
        if hasattr(model, 'module'):
            model = model.module

        pooler = model.pooler if model.pooler is not None else model.bert.pooler
        # Group order is kept stable so saved optimizer states stay loadable.
        modules_and_rates = (
            (model.bert.embeddings, learning_rate_x),
            (model.bert.encoder, learning_rate_x),
            (model.linear, learning_rate_a),
            (pooler, learning_rate_h),
        )
        optimizer_grouped_parameters = []
        for module, lr in modules_and_rates:
            optimizer_grouped_parameters.extend(
                BertLightXML._decay_groups(module, lr, weight_decay))
        return optimizer_grouped_parameters

    def forward(self, input_ids, attention_mask=None, token_type_ids=None,
                labels=None, head_mask=None, save='none', **kwargs):
        """Compute label logits and, when labels are given, the BCE loss.

        Args:
            input_ids, attention_mask, token_type_ids, head_mask: forwarded
                to the underlying ``BertModel``.
            labels: targets for ``BCEWithLogitsLoss``; when ``None`` no loss
                is computed.
            save: if the string contains ``'feature'``, the detached
                post-dropout feature tensor is appended to the result.
            **kwargs: accepted and ignored.

        Returns:
            ``[loss, output]`` or ``[loss, output, feature.detach()]``, where
            ``loss`` is ``None`` if ``labels`` is ``None``.
        """
        outputs = self.bert(input_ids,
                            token_type_ids=token_type_ids, attention_mask=attention_mask,
                            head_mask=head_mask, output_hidden_states=True)
        hidden_states = outputs.hidden_states
        # First-token state from each of the last `feature_layers` hidden states.
        feature = torch.cat(
            [hidden_states[-i][:, 0] for i in range(1, self.feature_layers + 1)],
            dim=-1)
        # Dropout is applied exactly once; applying the p=0.5 layer twice
        # would compound to an effective drop rate of 0.75.
        feature = self.dropout(feature)
        output = self.linear(feature)

        # `labels` is optional in the signature, so do not require it here.
        loss = self.loss_fn(output, labels) if labels is not None else None

        result = [loss, output]
        if 'feature' in save:
            result.append(feature.detach())
        return result

