import json
import logging
from logzero import logger
import math
import os
import sys
from io import open

import torch
from torch import nn
from torch.nn import functional as F
from torch.nn import CrossEntropyLoss, MSELoss, BCEWithLogitsLoss
from transformers.modeling_utils import SequenceSummary
from models.attentionXML import MLAttention, MLLinear

# from pytorch_transformers import BertForSequenceClassification

# from transformers.models.roberta.modeling_roberta import RobertaPreTrainedModel, RobertaClassificationHead
# from transformers.models.bert.modeling_bert import BertPreTrainedModel, BertModel


from pytorch_transformers.models.bert.modeling_bert import BertPreTrainedModel, BertModel
from pytorch_transformers.models.roberta.modeling_roberta import RobertaPreTrainedModel, RobertaClassificationHead, RobertaModel

# Substrings of parameter names that are exempt from weight decay: the
# get_param() helpers in this file put every parameter whose name contains one
# of these into an optimizer group with weight_decay=0.0.
no_decay = ['bias', 'LayerNorm.weight']

class RobertaForMultiLabelSequenceClassification(RobertaPreTrainedModel):
    """RoBERTa encoder with a single linear multi-label head.

    The encoder output is reduced to one feature vector per example, either
    by ``SequenceSummary`` (when the config defines ``summary_type``) or by
    taking the hidden state at position 0, and then projected to
    ``config.num_labels`` logits. Training uses ``BCEWithLogitsLoss``.
    """

    def __init__(self, config):
        """Build the encoder, the optional sequence summary and the linear head.

        Args:
            config: model configuration. Besides the usual RoBERTa fields it
                must provide ``num_labels`` and ``last_hidden_size``; if it
                has a ``summary_type`` attribute a ``SequenceSummary`` layer
                is created as well.
        """
        super(RobertaForMultiLabelSequenceClassification, self).__init__(config)
        logger.info(f'roberta seq classification, num layer: {config.num_hidden_layers}, '
                    f'hidden dropout: {config.hidden_dropout_prob}, '
                    f'attn dropout: {config.attention_probs_dropout_prob}')

        self.num_labels = config.num_labels
        self.roberta = RobertaModel(config)
        if hasattr(config, 'summary_type'):
            logger.info('use summary sequence')
            self.sequence_summary = SequenceSummary(config)
        else:
            self.sequence_summary = None

        self.linear = nn.Linear(config.last_hidden_size, config.num_labels)
        # TODO: dropout layer ?
        self.loss_fn = nn.BCEWithLogitsLoss()
        self.init_weights()

    @staticmethod
    def get_param(model, learning_rate_x, learning_rate_h, learning_rate_a, weight_decay):
        """Return optimizer parameter groups with per-module learning rates.

        Each sub-module contributes two groups, in this order: parameters
        that receive ``weight_decay`` and parameters whose name matches an
        entry of the module-level ``no_decay`` list (weight decay 0.0).

        Args:
            model: an instance of this class, or a wrapper exposing the
                instance as ``model.module``.
            learning_rate_x: learning rate for ``model.roberta``.
            learning_rate_h: learning rate for ``model.sequence_summary``
                (only used when that layer exists).
            learning_rate_a: learning rate for ``model.linear``.
            weight_decay: weight decay applied to the non-exempt parameters.

        Returns:
            A list of dicts with keys ``'params'``, ``'weight_decay'`` and
            ``'lr'``, ordered roberta, linear, then sequence_summary.
        """
        if hasattr(model, 'module'):
            model = model.module

        def split_groups(module, lr):
            # Partition the module's parameters into (decayed, exempt) while
            # preserving the order produced by named_parameters().
            decayed, exempt = [], []
            for name, param in module.named_parameters():
                if any(nd in name for nd in no_decay):
                    exempt.append(param)
                else:
                    decayed.append(param)
            return [
                {'params': decayed, 'weight_decay': weight_decay, "lr": lr},
                {'params': exempt, 'weight_decay': 0.0, "lr": lr},
            ]

        optimizer_grouped_parameters = split_groups(model.roberta, learning_rate_x)
        optimizer_grouped_parameters.extend(split_groups(model.linear, learning_rate_a))
        if model.sequence_summary is not None:
            optimizer_grouped_parameters.extend(
                split_groups(model.sequence_summary, learning_rate_h))

        return optimizer_grouped_parameters

    def forward(self, input_ids, attention_mask=None, token_type_ids=None,
                labels=None, head_mask=None, save='none', **kwargs):
        """Run the encoder and the linear head.

        Args:
            input_ids, attention_mask, token_type_ids, head_mask: forwarded
                unchanged to the RoBERTa encoder.
            labels: targets for ``BCEWithLogitsLoss``. When ``None`` no loss
                is computed and the first returned element is ``None``.
            save: string of flags; if it contains ``'feature'`` the detached
                feature vector is appended to the result.
            **kwargs: accepted and ignored.

        Returns:
            ``[loss, logits]`` or ``[loss, logits, feature]``.
        """
        encoder_outputs = self.roberta(input_ids, token_type_ids=token_type_ids,
                                       attention_mask=attention_mask,
                                       head_mask=head_mask)
        hidden_states = encoder_outputs[0]  # [batch, len, dim]
        if self.sequence_summary is not None:
            feature = self.sequence_summary(hidden_states)
        else:
            # Fall back to the hidden state of the first token.
            feature = hidden_states[:, 0, :]

        output = self.linear(feature)
        # labels defaults to None (inference); keep the loss slot so callers
        # can index the result the same way in both modes.
        loss = self.loss_fn(output, labels) if labels is not None else None

        outputs = [loss, output]
        # This class has no `attention` attribute, so the feature is returned
        # purely based on the `save` flag.
        if 'feature' in save:
            outputs.append(feature.detach())
        return outputs


class RobertaAttentionXML(RobertaPreTrainedModel):
    """RoBERTa encoder followed by an ``MLAttention`` layer and an ``MLLinear`` head.

    The encoder's hidden states are passed, together with the attention
    mask, to ``MLAttention``; its feature output is fed to ``MLLinear`` to
    produce the logits. Training uses ``BCEWithLogitsLoss``.
    """

    def __init__(self, config):
        """Build the encoder, the attention layer and the output head.

        Args:
            config: model configuration. Besides the usual RoBERTa fields it
                must provide ``num_labels`` and ``last_hidden_size``.
        """
        super(RobertaAttentionXML, self).__init__(config)

        logger.info(f'roberta attnxml, num layer: {config.num_hidden_layers}, '
                    f'hidden dropout: {config.hidden_dropout_prob}, '
                    f'attn dropout: {config.attention_probs_dropout_prob}')

        self.num_labels = config.num_labels
        self.roberta = RobertaModel(config)
        # NOTE(review): sequence_summary is created here but is neither used
        # in forward() nor handed to the optimizer by get_param(). It is kept
        # so that existing checkpoints keep loading with the same keys.
        if hasattr(config, 'summary_type'):
            logger.info('use summary sequence')
            self.sequence_summary = SequenceSummary(config)
        else:
            self.sequence_summary = None

        self.attention = MLAttention(config.num_labels, config.last_hidden_size)
        self.linear = MLLinear([config.last_hidden_size, 256], 1)
        self.loss_fn = nn.BCEWithLogitsLoss()
        self.init_weights()

    @staticmethod
    def get_param(model, learning_rate_x, learning_rate_h, learning_rate_a, weight_decay):
        """Return optimizer parameter groups with per-module learning rates.

        Each sub-module contributes two groups, in this order: parameters
        that receive ``weight_decay`` and parameters whose name matches an
        entry of the module-level ``no_decay`` list (weight decay 0.0).

        Args:
            model: an instance of this class, or a wrapper exposing the
                instance as ``model.module``.
            learning_rate_x: learning rate for ``model.roberta``.
            learning_rate_h: learning rate for ``model.attention``.
            learning_rate_a: learning rate for ``model.linear``.
            weight_decay: weight decay applied to the non-exempt parameters.

        Returns:
            A list of dicts with keys ``'params'``, ``'weight_decay'`` and
            ``'lr'``, ordered roberta, linear, then attention.
        """
        if hasattr(model, 'module'):
            model = model.module

        def split_groups(module, lr):
            # Partition the module's parameters into (decayed, exempt) while
            # preserving the order produced by named_parameters().
            decayed, exempt = [], []
            for name, param in module.named_parameters():
                if any(nd in name for nd in no_decay):
                    exempt.append(param)
                else:
                    decayed.append(param)
            return [
                {'params': decayed, 'weight_decay': weight_decay, "lr": lr},
                {'params': exempt, 'weight_decay': 0.0, "lr": lr},
            ]

        optimizer_grouped_parameters = split_groups(model.roberta, learning_rate_x)
        optimizer_grouped_parameters.extend(split_groups(model.linear, learning_rate_a))
        optimizer_grouped_parameters.extend(split_groups(model.attention, learning_rate_h))

        return optimizer_grouped_parameters

    def forward(self, input_ids, attention_mask=None, token_type_ids=None,
                labels=None, head_mask=None, save='none', **kwargs):
        """Run the encoder, the label attention and the output head.

        Args:
            input_ids, attention_mask, token_type_ids, head_mask: forwarded
                unchanged to the RoBERTa encoder; ``attention_mask`` is also
                passed to ``MLAttention``.
            labels: targets for ``BCEWithLogitsLoss``. When ``None`` no loss
                is computed and the first returned element is ``None``.
            save: string of flags; ``'feature'`` appends the detached
                attention feature and ``'attn'`` appends the detached
                attention output, in that order.
            **kwargs: accepted and ignored.

        Returns:
            ``[loss, logits]`` optionally followed by the feature and/or the
            attention tensor.
        """
        encoder_outputs = self.roberta(input_ids, token_type_ids=token_type_ids,
                                       attention_mask=attention_mask,
                                       head_mask=head_mask)
        feature, attn = self.attention(encoder_outputs[0], attention_mask)

        output = self.linear(feature)
        # labels defaults to None (inference); keep the loss slot so callers
        # can index the result the same way in both modes.
        loss = self.loss_fn(output, labels) if labels is not None else None

        outputs = [loss, output]
        if 'feature' in save:
            outputs.append(feature.detach())
        if 'attn' in save:
            outputs.append(attn.detach())
        return outputs

# class XLNetForMultiLabelSequenceClassification(XLNetPreTrainedModel):
#     def __init__(self, config):
#         super(XLNetForMultiLabelSequenceClassification, self).__init__(config)
#         self.num_labels = config.num_labels
#         self.transformer = XLNetModel(config)
#         self.sequence_summary = SequenceSummary(config)
#         self.linear = nn.Linear(config.last_hidden_size, config.num_labels)
#         self.loss_fn = nn.BCEWithLogitsLoss()
#         self.init_weights()
#     @staticmethod
#     def get_param(model, learning_rate_x, learning_rate_h, learning_rate_a, weight_decay):
#         if hasattr(model, 'module'):
#             model = model.module
#         optimizer_grouped_parameters = [
#             {'params': [p for n, p in model.transformer.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': weight_decay, "lr": learning_rate_x},
#             {'params': [p for n, p in model.transformer.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0, "lr": learning_rate_x},
#             {'params': [p for n, p in model.sequence_summary.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': weight_decay, "lr": learning_rate_h},
#             {'params': [p for n, p in model.sequence_summary.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0, "lr": learning_rate_h},
#             {'params': [p for n, p in model.linear.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': weight_decay,"lr": learning_rate_a},
#             {'params': [p for n, p in model.linear.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0, "lr": learning_rate_a}
#             ]
#         return optimizer_grouped_parameters
#
#     def forward(self, input_ids, token_type_ids=None, input_mask=None, attention_mask=None,
#                 mems=None, perm_mask=None, target_mapping=None,
#                 labels=None, head_mask=None, save='none', **kwargs):
#         xlnet_outputs = self.transformer(input_ids, token_type_ids=token_type_ids,
#                                                input_mask=input_mask, attention_mask=attention_mask,
#                                                mems=mems, perm_mask=perm_mask, target_mapping=target_mapping,
#                                                head_mask=head_mask)
#         output = xlnet_outputs[0]
#         feature = self.sequence_summary(output)
#         output = self.linear(feature)
#         loss = self.loss_fn(output, labels)
#         if 'feature' in save:
#             outputs = [loss, output, feature.detach()]
#         else:
#             outputs = [loss, output]
#         return outputs
