import json
import logging
from logzero import logger
import math
import os
import sys
from io import open

import torch
from torch import nn
from torch.nn import functional as F
from torch.nn import CrossEntropyLoss, MSELoss, BCEWithLogitsLoss
from transformers.modeling_utils import SequenceSummary
from models.attentionXML import MLAttention, MLLinear

# from pytorch_transformers import BertForSequenceClassification

from transformers.models.xlnet.modeling_xlnet import XLNetPreTrainedModel, XLNetModel

# Substrings of parameter names that are exempt from weight decay (biases and
# LayerNorm scales); matched with ``in`` against ``named_parameters()`` names.
no_decay = ["bias", "LayerNorm.weight"]

class XLNetForMultiLabelSequenceClassification(XLNetPreTrainedModel):
    """XLNet encoder + ``SequenceSummary`` pooling + linear multi-label head.

    The transformer's last hidden states are pooled by ``SequenceSummary`` and
    projected to ``config.num_labels`` logits. The loss is
    ``BCEWithLogitsLoss``, i.e. an independent sigmoid per label.
    """

    def __init__(self, config):
        super(XLNetForMultiLabelSequenceClassification, self).__init__(config)

        logger.info(f'xlnet representation cls, num layer: {config.n_layer}, '
                    f'dropout: {config.dropout}')

        self.num_labels = config.num_labels
        self.transformer = XLNetModel(config)
        self.sequence_summary = SequenceSummary(config)
        # NOTE(review): ``last_hidden_size`` is not a stock XLNetConfig field;
        # presumably set by the caller to the SequenceSummary output width — confirm.
        self.linear = nn.Linear(config.last_hidden_size, config.num_labels)
        self.loss_fn = nn.BCEWithLogitsLoss()
        self.init_weights()

    @staticmethod
    def _decay_groups(module, lr, weight_decay):
        """Split ``module``'s parameters into a decayed and a no-decay param group.

        Parameters whose name contains any entry of ``no_decay`` get
        ``weight_decay=0.0``; all others get ``weight_decay``. Both groups use
        learning rate ``lr``. The decayed group comes first.
        """
        named = list(module.named_parameters())
        decayed = [p for n, p in named if not any(nd in n for nd in no_decay)]
        exempt = [p for n, p in named if any(nd in n for nd in no_decay)]
        return [
            {'params': decayed, 'weight_decay': weight_decay, "lr": lr},
            {'params': exempt, 'weight_decay': 0.0, "lr": lr},
        ]

    @staticmethod
    def get_param(model, learning_rate_x, learning_rate_h, learning_rate_a, weight_decay):
        """Build optimizer param groups with a separate learning rate per component.

        Args:
            model: an instance of this class, or a wrapper exposing it as
                ``.module`` (e.g. DataParallel); the wrapper is unwrapped first.
            learning_rate_x: learning rate for the XLNet transformer.
            learning_rate_h: learning rate for the ``sequence_summary`` head.
            learning_rate_a: learning rate for the final ``linear`` classifier.
            weight_decay: weight decay for parameters not matching ``no_decay``.

        Returns:
            A list of six param-group dicts (decay / no-decay for transformer,
            sequence_summary and linear, in that order) in the format accepted
            by ``torch.optim`` optimizers.
        """
        if hasattr(model, 'module'):
            model = model.module
        split = XLNetForMultiLabelSequenceClassification._decay_groups
        # Group order is fixed; changing it would misalign saved optimizer state.
        return (split(model.transformer, learning_rate_x, weight_decay)
                + split(model.sequence_summary, learning_rate_h, weight_decay)
                + split(model.linear, learning_rate_a, weight_decay))

    def forward(self, input_ids, token_type_ids=None, input_mask=None, attention_mask=None,
                mems=None, perm_mask=None, target_mapping=None,
                labels=None, head_mask=None, save='none', **kwargs):
        """Encode the inputs and return classification logits (and loss).

        Args:
            input_ids, token_type_ids, input_mask, attention_mask, mems,
            perm_mask, target_mapping, head_mask: passed through to ``XLNetModel``.
            labels: float multi-hot targets with the same shape as the logits.
                When ``None``, no loss is computed and ``None`` fills its slot.
            save: substring flags. If it contains ``'feature'``, the pooled
                feature (detached) is appended to the outputs.
            **kwargs: accepted and ignored.

        Returns:
            ``[loss, logits]``, or ``[loss, logits, feature]`` when requested.
        """
        transformer_outputs = self.transformer(input_ids, token_type_ids=token_type_ids,
                                               input_mask=input_mask, attention_mask=attention_mask,
                                               mems=mems, perm_mask=perm_mask, target_mapping=target_mapping,
                                               head_mask=head_mask)
        hidden = transformer_outputs[0]
        feature = self.sequence_summary(hidden)
        logits = self.linear(feature)
        # BCEWithLogitsLoss cannot take None targets, so skip the loss when no
        # labels are given (e.g. at inference time).
        loss = self.loss_fn(logits, labels) if labels is not None else None
        outputs = [loss, logits]
        if 'feature' in save:
            outputs.append(feature.detach())
        return outputs

class XLNetAttentionXML(XLNetPreTrainedModel):
    """XLNet encoder with an AttentionXML-style attention head for multi-label tasks.

    Token hidden states go through ``MLAttention`` (built with
    ``config.num_labels``, presumably one attended vector per label — see
    ``models.attentionXML``). ``MLLinear`` then maps the result to logits,
    which are trained with ``BCEWithLogitsLoss``.
    """

    def __init__(self, config):
        super(XLNetAttentionXML, self).__init__(config)

        logger.info(f'xlnet attnxml, num layer: {config.n_layer}, '
                    f'hidden dropout: {config.dropout}')

        self.num_labels = config.num_labels
        self.transformer = XLNetModel(config)

        # NOTE(review): ``last_hidden_size`` is not a stock XLNetConfig field;
        # presumably the XLNet hidden width set by the caller — confirm.
        self.attention = MLAttention(config.num_labels, config.last_hidden_size)
        # Presumably an MLP hidden -> 256 -> 1 applied per label — see MLLinear.
        self.linear = MLLinear([config.last_hidden_size, 256], 1)
        self.loss_fn = nn.BCEWithLogitsLoss()
        self.init_weights()

    @staticmethod
    def _decay_groups(module, lr, weight_decay):
        """Split ``module``'s parameters into a decayed and a no-decay param group.

        Parameters whose name contains any entry of ``no_decay`` get
        ``weight_decay=0.0``; all others get ``weight_decay``. Both groups use
        learning rate ``lr``. The decayed group comes first.
        """
        named = list(module.named_parameters())
        decayed = [p for n, p in named if not any(nd in n for nd in no_decay)]
        exempt = [p for n, p in named if any(nd in n for nd in no_decay)]
        return [
            {'params': decayed, 'weight_decay': weight_decay, "lr": lr},
            {'params': exempt, 'weight_decay': 0.0, "lr": lr},
        ]

    @staticmethod
    def get_param(model, learning_rate_x, learning_rate_h, learning_rate_a, weight_decay):
        """Build optimizer param groups with a separate learning rate per component.

        Args:
            model: an instance of this class, or a wrapper exposing it as
                ``.module`` (e.g. DataParallel); the wrapper is unwrapped first.
            learning_rate_x: learning rate for the XLNet transformer.
            learning_rate_h: learning rate for the label ``attention`` module.
            learning_rate_a: learning rate for the ``linear`` (MLLinear) head.
            weight_decay: weight decay for parameters not matching ``no_decay``.

        Returns:
            A list of six param-group dicts (decay / no-decay for transformer,
            linear and attention, in that order) in the format accepted by
            ``torch.optim`` optimizers.
        """
        if hasattr(model, 'module'):
            model = model.module
        split = XLNetAttentionXML._decay_groups
        # Group order (transformer, linear, attention) is fixed; changing it
        # would misalign previously saved optimizer state.
        return (split(model.transformer, learning_rate_x, weight_decay)
                + split(model.linear, learning_rate_a, weight_decay)
                + split(model.attention, learning_rate_h, weight_decay))

    def forward(self, input_ids, token_type_ids=None, input_mask=None, attention_mask=None,
                mems=None, perm_mask=None, target_mapping=None,
                labels=None, head_mask=None, save='none', **kwargs):
        """Encode the inputs, apply label attention, and return logits (and loss).

        Args:
            input_ids, token_type_ids, input_mask, attention_mask, mems,
            perm_mask, target_mapping, head_mask: passed through to ``XLNetModel``.
                ``attention_mask`` is also forwarded to ``MLAttention``.
            labels: float multi-hot targets with the same shape as the logits.
                When ``None``, no loss is computed and ``None`` fills its slot.
            save: substring flags. ``'feature'`` appends the detached attended
                features; ``'attn'`` appends the detached attention weights
                (in that order).
            **kwargs: accepted and ignored.

        Returns:
            ``[loss, logits]`` followed by any optional tensors requested via ``save``.
        """
        transformer_outputs = self.transformer(input_ids, token_type_ids=token_type_ids,
                                               input_mask=input_mask, attention_mask=attention_mask,
                                               mems=mems, perm_mask=perm_mask, target_mapping=target_mapping,
                                               head_mask=head_mask)
        # NOTE(review): attention_mask (default None) goes to MLAttention as-is;
        # confirm MLAttention accepts None and the mask dtype callers pass.
        feature, attn = self.attention(transformer_outputs[0], attention_mask)

        logits = self.linear(feature)
        # BCEWithLogitsLoss cannot take None targets, so skip the loss when no
        # labels are given (e.g. at inference time).
        loss = self.loss_fn(logits, labels) if labels is not None else None
        outputs = [loss, logits]
        if 'feature' in save:
            outputs.append(feature.detach())
        if 'attn' in save:
            outputs.append(attn.detach())
        return outputs


