import json
import logging
from logzero import logger
import math
import os
import sys
from io import open

import torch
from torch import nn
from torch.nn import functional as F
from torch.nn import CrossEntropyLoss, MSELoss, BCEWithLogitsLoss
from transformers.modeling_utils import SequenceSummary
from models.attentionXML import MLAttention, MLLinear

# from pytorch_transformers import BertForSequenceClassification

# from transformers.models.roberta.modeling_roberta import RobertaPreTrainedModel, RobertaClassificationHead
# from transformers.models.bert.modeling_bert import BertPreTrainedModel, BertModel
from transformers.models.distilbert.modeling_distilbert import DistilBertModel, DistilBertPreTrainedModel
# Parameter-name substrings used by the get_param() helpers in this file:
# any parameter whose name contains one of them is put in an optimizer group
# with weight_decay=0.0.
no_decay = ['bias', 'LayerNorm.weight']
from pytorch_transformers.models.bert.modeling_bert import BertPreTrainedModel, BertModel, BertPooler


class Pooler(nn.Module):
    """Dense projection followed by tanh.

    The module maps the trailing dimension of its input from ``hidden_size``
    to ``bottleneck_size``. It does not select any token itself: whatever
    tensor the caller passes in is projected as-is.
    """

    def __init__(self, hidden_size, bottleneck_size):
        """Create the projection layer and its activation.

        Args:
            hidden_size: size of the input's last dimension.
            bottleneck_size: size of the output's last dimension.
        """
        super().__init__()
        self.dense = nn.Linear(hidden_size, bottleneck_size)
        self.activation = nn.Tanh()

    def forward(self, x):
        """Return ``tanh(dense(x))``."""
        return self.activation(self.dense(x))

class BertCombined(BertPreTrainedModel):
    """BERT encoder with two multi-label heads whose logits are averaged.

    Head 1 passes the encoder's ``first_layer_hidden_states`` output through
    ``MLAttention`` and ``MLLinear``. Head 2 applies dropout and a linear
    classifier to the encoder's ``pooler_output``. Each head is scored with
    ``BCEWithLogitsLoss`` and both the losses and the logits are averaged.
    """

    def __init__(self, config):
        """Build the encoder, both heads and the loss.

        Args:
            config: model configuration. This constructor reads
                ``num_labels``, ``num_hidden_layers``, ``hidden_dropout_prob``,
                ``attention_probs_dropout_prob`` and ``last_hidden_size``.
        """
        super(BertCombined, self).__init__(config)
        self.num_labels = config.num_labels
        logger.info(f'bert combined, num layer: {config.num_hidden_layers}, '
                    f'hidden dropout: {config.hidden_dropout_prob}, '
                    f'attn dropout: {config.attention_probs_dropout_prob}')
        self.bert = BertModel(config)

        logger.info('use attention XML')
        self.attention = MLAttention(config.num_labels, config.last_hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.linear = MLLinear([config.last_hidden_size, 256], 1)

        # When True the pooled-output classifier reuses the attention module's
        # inner ``attention`` layer instead of owning a separate nn.Linear.
        self.shared = False
        if self.shared:
            logger.info('shared')
            self.classifier = self.attention.attention
        else:
            self.classifier = nn.Linear(config.last_hidden_size, config.num_labels)

        self.loss_fn = nn.BCEWithLogitsLoss()
        self.init_weights()

    @staticmethod
    def get_param(model, learning_rate_x, learning_rate_h, learning_rate_a, weight_decay):
        """Build optimizer parameter groups with per-component learning rates.

        Args:
            model: a ``BertCombined`` instance, or a wrapper exposing it as
                ``.module``.
            learning_rate_x: learning rate for the BERT embeddings and encoder.
            learning_rate_h: learning rate for the attention module and the
                BERT pooler.
            learning_rate_a: learning rate for ``linear`` and, when the
                classifier is not shared, ``classifier``.
            weight_decay: weight decay for parameters whose name contains none
                of the ``no_decay`` substrings; the others get ``0.0``.

        Returns:
            A list of dicts with the keys ``'params'``, ``'weight_decay'`` and
            ``'lr'``; each component contributes a decay group followed by a
            no-decay group.
        """
        if hasattr(model, 'module'):
            # Unwrap containers that hold the real model in `.module`.
            model = model.module

        def _groups(module, lr):
            # Split one sub-module's parameters into (decay, no-decay) groups.
            named = list(module.named_parameters())
            return [
                {'params': [p for n, p in named if not any(nd in n for nd in no_decay)],
                 'weight_decay': weight_decay, "lr": lr},
                {'params': [p for n, p in named if any(nd in n for nd in no_decay)],
                 'weight_decay': 0.0, "lr": lr},
            ]

        components = [
            (model.bert.embeddings, learning_rate_x),
            (model.bert.encoder, learning_rate_x),
            (model.attention, learning_rate_h),
            (model.bert.pooler, learning_rate_h),
            (model.linear, learning_rate_a),
        ]
        if not model.shared:
            # The groups are dicts (unhashable), so they must be collected in
            # a list; a set literal here raises TypeError.
            components.append((model.classifier, learning_rate_a))

        optimizer_grouped_parameters = []
        for module, lr in components:
            optimizer_grouped_parameters.extend(_groups(module, lr))
        return optimizer_grouped_parameters

    def forward(self, input_ids, attention_mask=None, token_type_ids=None,
                labels=None, head_mask=None, save='none', **kwargs):
        """Run both heads and average their losses and logits.

        Args:
            input_ids: token ids passed to the encoder.
            attention_mask: optional mask, passed to the encoder and to the
                attention head.
            token_type_ids: optional segment ids passed to the encoder.
            labels: optional targets for ``BCEWithLogitsLoss``. When ``None``
                no loss is computed.
            head_mask: optional head mask passed to the encoder.
            save: accepted for signature parity with the other models in this
                file; not used here.
            **kwargs: ignored.

        Returns:
            ``[loss, output, output1, output2]`` where ``output1`` comes from
            the attention head, ``output2`` from the pooled-output head,
            ``output`` is their mean and ``loss`` is the mean of the two head
            losses (``None`` when ``labels`` is ``None``).
        """
        outputs = self.bert(input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask,
                            head_mask=head_mask, output_first_layer_hidden_states=True)

        # Head 1: attention over the first-layer hidden states.
        # NOTE(review): MLAttention/MLLinear live in models.attentionXML;
        # output1 is assumed to match output2's shape since the two are summed.
        attended, _ = self.attention(outputs['first_layer_hidden_states'], attention_mask)
        output1 = self.linear(attended)

        # Head 2: dropout + classifier on the pooled encoder output.
        feature = self.dropout(outputs['pooler_output'])
        output2 = self.classifier(feature)

        if labels is None:
            # Inference without targets: skip the loss instead of crashing.
            loss = None
        else:
            loss1 = self.loss_fn(output1, labels)
            loss2 = self.loss_fn(output2, labels)
            loss = (loss1 + loss2) / 2
        output = (output1 + output2) / 2

        return [loss, output, output1, output2]
    
class BertAttentionXML(BertPreTrainedModel):
    """BERT encoder (no pooling layer) followed by an MLAttention/MLLinear head.

    The encoder's first output is optionally passed through a ``Pooler``
    bottleneck, then through ``MLAttention`` and ``MLLinear``. The result is
    scored with ``BCEWithLogitsLoss``.
    """

    def __init__(self, config):
        """Build the encoder, the optional bottleneck, the head and the loss.

        Args:
            config: model configuration. This constructor reads
                ``num_labels``, ``num_hidden_layers``, ``hidden_dropout_prob``,
                ``attention_probs_dropout_prob``, ``hidden_size``,
                ``bottleneck_size`` (``None`` disables the bottleneck) and
                ``last_hidden_size``.
        """
        super(BertAttentionXML, self).__init__(config)
        self.num_labels = config.num_labels
        logger.info(f'bert attnxml, num layer: {config.num_hidden_layers}, '
                    f'hidden dropout: {config.hidden_dropout_prob}, '
                    f'attn dropout: {config.attention_probs_dropout_prob}')
        self.bert = BertModel(config, add_pooling_layer=False)

        logger.info(f'use attention XML, bottle neck size: {config.bottleneck_size}, '
                    f'last hidden size: {config.last_hidden_size}')
        if config.bottleneck_size is None:
            self.pooler = None
        else:
            self.pooler = Pooler(config.hidden_size, config.bottleneck_size)
            # The attention head is sized by last_hidden_size, so the
            # bottleneck output must have exactly that width.
            assert config.bottleneck_size == config.last_hidden_size

        self.attention = MLAttention(config.num_labels, config.last_hidden_size)
        self.linear = MLLinear([config.last_hidden_size, 256], 1)

        self.loss_fn = nn.BCEWithLogitsLoss()
        self.init_weights()

    @staticmethod
    def get_param(model, learning_rate_x, learning_rate_h, learning_rate_a, weight_decay):
        """Build optimizer parameter groups with per-component learning rates.

        Args:
            model: a ``BertAttentionXML`` instance, or a wrapper exposing it
                as ``.module``.
            learning_rate_x: learning rate for the BERT embeddings and encoder.
            learning_rate_h: learning rate for the attention module and, when
                present, the bottleneck pooler.
            learning_rate_a: learning rate for ``linear``.
            weight_decay: weight decay for parameters whose name contains none
                of the ``no_decay`` substrings; the others get ``0.0``.

        Returns:
            A list of dicts with the keys ``'params'``, ``'weight_decay'`` and
            ``'lr'``; each component contributes a decay group followed by a
            no-decay group. The pooler groups, if any, come last.
        """
        if hasattr(model, 'module'):
            # Unwrap containers that hold the real model in `.module`.
            model = model.module

        def _groups(module, lr):
            # Split one sub-module's parameters into (decay, no-decay) groups.
            named = list(module.named_parameters())
            return [
                {'params': [p for n, p in named if not any(nd in n for nd in no_decay)],
                 'weight_decay': weight_decay, "lr": lr},
                {'params': [p for n, p in named if any(nd in n for nd in no_decay)],
                 'weight_decay': 0.0, "lr": lr},
            ]

        components = [
            (model.bert.embeddings, learning_rate_x),
            (model.bert.encoder, learning_rate_x),
            (model.attention, learning_rate_h),
            (model.linear, learning_rate_a),
        ]
        if model.pooler is not None:
            components.append((model.pooler, learning_rate_h))

        optimizer_grouped_parameters = []
        for module, lr in components:
            optimizer_grouped_parameters.extend(_groups(module, lr))
        return optimizer_grouped_parameters

    def forward(self, input_ids, attention_mask=None, token_type_ids=None,
                labels=None, head_mask=None, save='none', **kwargs):
        """Encode the input, apply the attention head and compute the loss.

        Args:
            input_ids: token ids passed to the encoder.
            attention_mask: optional mask, passed to the encoder and to the
                attention head.
            token_type_ids: optional segment ids passed to the encoder.
            labels: optional targets for ``BCEWithLogitsLoss``. When ``None``
                no loss is computed.
            head_mask: optional head mask passed to the encoder.
            save: string selecting extra outputs; if it contains ``'feature'``
                the detached attention feature is appended, and if it contains
                ``'attn'`` the detached attention weights are appended (in
                that order).
            **kwargs: ignored.

        Returns:
            ``[loss, output]`` plus the optional extras described under
            ``save``. ``loss`` is ``None`` when ``labels`` is ``None``.
        """
        encoded = self.bert(input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask,
                            head_mask=head_mask)
        # encoded[0] is the last hidden layer, [bsz, seq_len, dim] per the
        # original note in this file.
        feature = encoded[0]
        if self.pooler is not None:
            feature = self.pooler(feature)
        feature, attn = self.attention(feature, attention_mask)
        output = self.linear(feature)

        # Inference without targets: skip the loss instead of crashing.
        loss = None if labels is None else self.loss_fn(output, labels)

        outputs = [loss, output]
        if 'feature' in save:
            outputs.append(feature.detach())
        if 'attn' in save:
            outputs.append(attn.detach())
        return outputs

class BertForMultiLabelSequenceClassification(BertPreTrainedModel):
    """BERT encoder with a dropout + linear multi-label classification head.

    Without a bottleneck the head consumes the encoder's own pooled output.
    With ``config.bottleneck_size`` set, the encoder is built without its
    pooling layer and a ``Pooler`` bottleneck is applied to the first-token
    hidden state instead.
    """

    def __init__(self, config):
        """Build the encoder, the optional bottleneck, the head and the loss.

        Args:
            config: model configuration. This constructor reads
                ``num_labels``, ``num_hidden_layers``, ``hidden_dropout_prob``,
                ``attention_probs_dropout_prob``, ``hidden_size``,
                ``bottleneck_size`` (``None`` keeps BERT's own pooler) and
                ``last_hidden_size``.
        """
        super(BertForMultiLabelSequenceClassification, self).__init__(config)
        self.num_labels = config.num_labels
        logger.info(f'bert seq classification, num layer: {config.num_hidden_layers}, '
                    f'hidden dropout: {config.hidden_dropout_prob}, '
                    f'attn dropout: {config.attention_probs_dropout_prob}')
        if config.bottleneck_size is None:
            self.bert = BertModel(config)
            self.pooler = None
        else:
            self.bert = BertModel(config, add_pooling_layer=False)
            self.pooler = Pooler(config.hidden_size, config.bottleneck_size)
            # The classifier input width is last_hidden_size, so the
            # bottleneck output must have exactly that width.
            assert config.bottleneck_size == config.last_hidden_size

        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.linear = nn.Linear(config.last_hidden_size, config.num_labels)
        self.loss_fn = nn.BCEWithLogitsLoss()
        self.init_weights()

    @staticmethod
    def get_param(model, learning_rate_x, learning_rate_h, learning_rate_a, weight_decay):
        """Build optimizer parameter groups with per-component learning rates.

        Args:
            model: a ``BertForMultiLabelSequenceClassification`` instance, or
                a wrapper exposing it as ``.module``.
            learning_rate_x: learning rate for the BERT embeddings and encoder.
            learning_rate_h: learning rate for the pooler (the custom
                bottleneck if present, otherwise ``bert.pooler``).
            learning_rate_a: learning rate for ``linear``.
            weight_decay: weight decay for parameters whose name contains none
                of the ``no_decay`` substrings; the others get ``0.0``.

        Returns:
            A list of dicts with the keys ``'params'``, ``'weight_decay'`` and
            ``'lr'``; each component contributes a decay group followed by a
            no-decay group. The pooler groups come last.
        """
        if hasattr(model, 'module'):
            # Unwrap containers that hold the real model in `.module`.
            model = model.module

        def _groups(module, lr):
            # Split one sub-module's parameters into (decay, no-decay) groups.
            named = list(module.named_parameters())
            return [
                {'params': [p for n, p in named if not any(nd in n for nd in no_decay)],
                 'weight_decay': weight_decay, "lr": lr},
                {'params': [p for n, p in named if any(nd in n for nd in no_decay)],
                 'weight_decay': 0.0, "lr": lr},
            ]

        pooler = model.pooler if model.pooler is not None else model.bert.pooler
        components = [
            (model.bert.embeddings, learning_rate_x),
            (model.bert.encoder, learning_rate_x),
            (model.linear, learning_rate_a),
            (pooler, learning_rate_h),
        ]

        optimizer_grouped_parameters = []
        for module, lr in components:
            optimizer_grouped_parameters.extend(_groups(module, lr))
        return optimizer_grouped_parameters

    def forward(self, input_ids, attention_mask=None, token_type_ids=None,
                labels=None, head_mask=None, save='none', **kwargs):
        """Encode the input, classify the pooled representation, compute loss.

        Args:
            input_ids: token ids passed to the encoder.
            attention_mask: optional mask passed to the encoder.
            token_type_ids: optional segment ids passed to the encoder.
            labels: optional targets for ``BCEWithLogitsLoss``. When ``None``
                no loss is computed.
            head_mask: optional head mask passed to the encoder.
            save: string selecting extra outputs; if it contains ``'feature'``
                the detached post-dropout feature is appended.
            **kwargs: ignored.

        Returns:
            ``[loss, output]`` or ``[loss, output, feature]``. ``loss`` is
            ``None`` when ``labels`` is ``None``.
        """
        encoded = self.bert(input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask,
                            head_mask=head_mask)
        if self.pooler is not None:
            # The encoder was built with add_pooling_layer=False, so it has no
            # pooled output of its own; run the bottleneck on the first-token
            # hidden state instead.
            # NOTE(review): assumes encoded[0] is [bsz, seq_len, dim].
            pooled_output = self.pooler(encoded[0][:, 0])
        else:
            pooled_output = encoded[1]
        feature = self.dropout(pooled_output)
        output = self.linear(feature)

        # Inference without targets: skip the loss instead of crashing.
        loss = None if labels is None else self.loss_fn(output, labels)

        outputs = [loss, output]
        if 'feature' in save:
            outputs.append(feature.detach())
        return outputs

class DistilBertForMultiLabelSequenceClassification(DistilBertPreTrainedModel):
    """DistilBERT encoder with a multi-label classification head.

    The first-token hidden state goes through ``pre_classifier``, ReLU,
    dropout and a final linear layer; the logits are scored with
    ``BCEWithLogitsLoss``.
    """

    def __init__(self, config):
        """Build the encoder, the head and the loss.

        Args:
            config: model configuration. This constructor reads
                ``num_labels``, ``dim``, ``seq_classif_dropout`` and
                ``last_hidden_size``.
        """
        super(DistilBertForMultiLabelSequenceClassification, self).__init__(config)
        self.num_labels = config.num_labels
        self.distilbert = DistilBertModel(config)
        self.pre_classifier = nn.Linear(config.dim, config.dim)
        self.dropout = nn.Dropout(config.seq_classif_dropout)
        # NOTE(review): the feature fed to `linear` has width config.dim, so
        # config.last_hidden_size is presumably expected to equal it - confirm.
        self.linear = nn.Linear(config.last_hidden_size, config.num_labels)
        self.loss_fn = nn.BCEWithLogitsLoss()
        self.init_weights()

    @staticmethod
    def get_param(model, learning_rate_x, learning_rate_h, learning_rate_a, weight_decay):
        """Build optimizer parameter groups with per-component learning rates.

        Args:
            model: a ``DistilBertForMultiLabelSequenceClassification``
                instance, or a wrapper exposing it as ``.module``.
            learning_rate_x: learning rate for the DistilBERT encoder.
            learning_rate_h: learning rate for ``pre_classifier``.
            learning_rate_a: learning rate for ``linear``.
            weight_decay: weight decay for parameters whose name contains none
                of the ``no_decay`` substrings; the others get ``0.0``.

        Returns:
            A list of dicts with the keys ``'params'``, ``'weight_decay'`` and
            ``'lr'``; each component contributes a decay group followed by a
            no-decay group.
        """
        if hasattr(model, 'module'):
            # Unwrap containers that hold the real model in `.module`.
            model = model.module

        def _groups(module, lr):
            # Split one sub-module's parameters into (decay, no-decay) groups.
            named = list(module.named_parameters())
            return [
                {'params': [p for n, p in named if not any(nd in n for nd in no_decay)],
                 'weight_decay': weight_decay, "lr": lr},
                {'params': [p for n, p in named if any(nd in n for nd in no_decay)],
                 'weight_decay': 0.0, "lr": lr},
            ]

        components = [
            (model.distilbert, learning_rate_x),
            (model.pre_classifier, learning_rate_h),
            (model.linear, learning_rate_a),
        ]

        optimizer_grouped_parameters = []
        for module, lr in components:
            optimizer_grouped_parameters.extend(_groups(module, lr))
        return optimizer_grouped_parameters

    def forward(self, input_ids, attention_mask=None, token_type_ids=None,
                labels=None, head_mask=None, save='none', **kwargs):
        """Encode the input, classify the first-token state, compute loss.

        Args:
            input_ids: token ids passed to the encoder.
            attention_mask: optional mask passed to the encoder.
            token_type_ids: accepted for signature parity with the BERT
                models in this file; not passed to the encoder.
            labels: optional targets for ``BCEWithLogitsLoss``. When ``None``
                no loss is computed.
            head_mask: optional head mask passed to the encoder.
            save: string selecting extra outputs; if it contains ``'feature'``
                the detached post-dropout feature is appended.
            **kwargs: ignored.

        Returns:
            ``[loss, output]`` or ``[loss, output, feature]``. ``loss`` is
            ``None`` when ``labels`` is ``None``.
        """
        distilbert_output = self.distilbert(input_ids, attention_mask=attention_mask,
                                            head_mask=head_mask)
        hidden_state = distilbert_output[0]  # (bs, seq_len, dim)
        pooled_output = hidden_state[:, 0]  # (bs, dim)
        pooled_output = self.pre_classifier(pooled_output)  # (bs, dim)
        # Functional ReLU avoids building a throwaway nn.ReLU module per call.
        pooled_output = F.relu(pooled_output)  # (bs, dim)
        feature = self.dropout(pooled_output)  # (bs, dim)
        output = self.linear(feature)

        # Inference without targets: skip the loss instead of crashing.
        loss = None if labels is None else self.loss_fn(output, labels)

        outputs = [loss, output]
        if 'feature' in save:
            outputs.append(feature.detach())
        return outputs
