import json
import logging
from logzero import logger
import math
import os
import sys
from io import open

import torch
from torch import nn
from torch.nn import functional as F
from torch.nn import CrossEntropyLoss, MSELoss, BCEWithLogitsLoss
from models.Adapative_label_clusters import AdaptiveBCEWithLogitsLoss
from transformers.modeling_utils import SequenceSummary

from transformers import (
    WEIGHTS_NAME,
    BertConfig,
    BertForSequenceClassification,
    BertTokenizer,
    RobertaConfig,
    RobertaForSequenceClassification,
    RobertaTokenizer,
    XLMConfig,
    XLMForSequenceClassification,
    XLMTokenizer,
    XLNetConfig,
    XLNetForSequenceClassification,
    XLNetTokenizer,
    DistilBertConfig,
    DistilBertForSequenceClassification,
    DistilBertTokenizer,
    AlbertConfig,
    AlbertForSequenceClassification,
    AlbertTokenizer,
)

from transformers import (XLNetPreTrainedModel, XLNetModel,
                            RobertaModel)
from transformers.models.roberta.modeling_roberta import RobertaPreTrainedModel, RobertaClassificationHead
from transformers.models.bert.modeling_bert import BertPreTrainedModel, BertModel
from transformers.models.distilbert.modeling_distilbert import DistilBertModel, DistilBertPreTrainedModel
# Parameter-name substrings. Any parameter whose name contains one of these
# is placed in a zero-weight-decay optimizer group by the get_param helpers
# defined on the model classes in this module.
no_decay = [
    "bias",
    "LayerNorm.weight",
]

class DistilBertForMultiLabelSequenceClassification(DistilBertPreTrainedModel):
    """DistilBERT encoder with an adaptive multi-label head.

    The hidden state at sequence position 0 goes through a linear layer,
    a ReLU and dropout; the resulting feature is handed to an
    ``AdaptiveBCEWithLogitsLoss`` module, which supplies both the loss and
    the predictions.

    Config attributes read here: ``num_labels``, ``dim``,
    ``seq_classif_dropout``, ``last_hidden_size``, ``adaptive_cutoff`` and
    ``div_value``.
    """

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.distilbert = DistilBertModel(config)
        self.pre_classifier = nn.Linear(config.dim, config.dim)
        self.dropout = nn.Dropout(config.seq_classif_dropout)
        self.AdaptiveMultiLabel = AdaptiveBCEWithLogitsLoss(
            config.last_hidden_size,
            config.num_labels,
            config.adaptive_cutoff,
            config.div_value,
        )
        self.init_weights()

    @staticmethod
    def get_param(model, learning_rate_x, learning_rate_h, learning_rate_a, weight_decay):
        """Build optimizer parameter groups with per-component learning rates.

        Args:
            model: An instance of this class, or a wrapper exposing the real
                model as ``.module`` (the wrapper is unwrapped first).
            learning_rate_x: Learning rate for the ``distilbert`` encoder.
            learning_rate_h: Learning rate for ``pre_classifier``.
            learning_rate_a: Learning rate for ``AdaptiveMultiLabel``.
            weight_decay: Weight decay for parameters whose name contains
                none of the ``no_decay`` substrings; all others get 0.0.

        Returns:
            A list of six dicts with keys ``'params'``, ``'weight_decay'``
            and ``'lr'``: a (decay, no-decay) pair for the encoder, the
            pre-classifier and the adaptive head, in that order.
        """
        if hasattr(model, 'module'):
            model = model.module

        def _groups_for(module, lr):
            # Split one sub-module's parameters by substring match of the
            # parameter name against the module-level ``no_decay`` list.
            decayed, exempt = [], []
            for name, param in module.named_parameters():
                if any(nd in name for nd in no_decay):
                    exempt.append(param)
                else:
                    decayed.append(param)
            return [
                {'params': decayed, 'weight_decay': weight_decay, "lr": lr},
                {'params': exempt, 'weight_decay': 0.0, "lr": lr},
            ]

        param_groups = []
        for module, lr in (
            (model.distilbert, learning_rate_x),
            (model.pre_classifier, learning_rate_h),
            (model.AdaptiveMultiLabel, learning_rate_a),
        ):
            param_groups.extend(_groups_for(module, lr))
        return param_groups

    def forward(self, input_ids, attention_mask=None, token_type_ids=None,
                labels=None, head_mask=None, save='none', **kwargs):
        """Encode the inputs and evaluate the adaptive multi-label head.

        ``token_type_ids`` and ``**kwargs`` are accepted but not used.

        Args:
            input_ids: Token ids passed to the DistilBERT encoder.
            attention_mask: Optional mask forwarded to the encoder.
            token_type_ids: Ignored.
            labels: Targets forwarded unchanged to ``AdaptiveMultiLabel``.
                NOTE(review): the default ``None`` is passed through as-is;
                confirm the loss module accepts it.
            head_mask: Optional head mask forwarded to the encoder.
            save: If this string contains ``'feature'``, the detached
                feature tensor is appended to the result.

        Returns:
            ``[loss, output]`` or ``[loss, output, feature]``.
        """
        encoder_out = self.distilbert(
            input_ids,
            attention_mask=attention_mask,
            head_mask=head_mask,
        )
        sequence_states = encoder_out[0]  # (bs, seq_len, dim)
        first_token = sequence_states[:, 0]  # (bs, dim)
        projected = self.pre_classifier(first_token)  # (bs, dim)
        activated = nn.ReLU()(projected)  # (bs, dim)
        feature = self.dropout(activated)  # (bs, dim)

        loss = self.AdaptiveMultiLabel(feature, labels)
        # Prediction runs on a detached copy of the feature.
        output = self.AdaptiveMultiLabel.predict(feature.detach())

        results = [loss, output]
        if 'feature' in save:
            results.append(feature.detach())
        return results


class BertAPLC(BertPreTrainedModel):
    """BERT encoder with an adaptive multi-label head.

    The second element of the BERT output (the pooled output) goes through
    dropout and is handed to an ``AdaptiveBCEWithLogitsLoss`` module, which
    supplies both the loss and the predictions.

    Config attributes read here: ``num_labels``, ``hidden_dropout_prob``,
    ``last_hidden_size``, ``adaptive_cutoff`` and ``div_value``.
    """

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.bert = BertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.AdaptiveMultiLabel = AdaptiveBCEWithLogitsLoss(
            config.last_hidden_size,
            config.num_labels,
            config.adaptive_cutoff,
            config.div_value,
        )
        self.init_weights()

    @staticmethod
    def get_param(model, learning_rate_x, learning_rate_h, learning_rate_a, weight_decay):
        """Build optimizer parameter groups with per-component learning rates.

        Args:
            model: An instance of this class, or a wrapper exposing the real
                model as ``.module`` (the wrapper is unwrapped first).
            learning_rate_x: Learning rate for ``bert.embeddings`` and
                ``bert.encoder``.
            learning_rate_h: Learning rate for ``bert.pooler``.
            learning_rate_a: Learning rate for ``AdaptiveMultiLabel``.
            weight_decay: Weight decay for parameters whose name contains
                none of the ``no_decay`` substrings; all others get 0.0.

        Returns:
            A list of eight dicts with keys ``'params'``, ``'weight_decay'``
            and ``'lr'``: a (decay, no-decay) pair for the embeddings, the
            encoder, the pooler and the adaptive head, in that order.
        """
        if hasattr(model, 'module'):
            model = model.module

        def _groups_for(module, lr):
            # Split one sub-module's parameters by substring match of the
            # parameter name against the module-level ``no_decay`` list.
            decayed, exempt = [], []
            for name, param in module.named_parameters():
                if any(nd in name for nd in no_decay):
                    exempt.append(param)
                else:
                    decayed.append(param)
            return [
                {'params': decayed, 'weight_decay': weight_decay, "lr": lr},
                {'params': exempt, 'weight_decay': 0.0, "lr": lr},
            ]

        param_groups = []
        for module, lr in (
            (model.bert.embeddings, learning_rate_x),
            (model.bert.encoder, learning_rate_x),
            (model.bert.pooler, learning_rate_h),
            (model.AdaptiveMultiLabel, learning_rate_a),
        ):
            param_groups.extend(_groups_for(module, lr))
        return param_groups

    def forward(self, input_ids, attention_mask=None, token_type_ids=None,
                labels=None, head_mask=None, save='none', **kwargs):
        """Encode the inputs and evaluate the adaptive multi-label head.

        ``**kwargs`` is accepted but not used.

        Args:
            input_ids: Token ids passed to the BERT encoder.
            attention_mask: Optional mask forwarded to the encoder.
            token_type_ids: Optional segment ids forwarded to the encoder.
            labels: Targets forwarded unchanged to ``AdaptiveMultiLabel``.
                NOTE(review): the default ``None`` is passed through as-is;
                confirm the loss module accepts it.
            head_mask: Optional head mask forwarded to the encoder.
            save: If this string contains ``'feature'``, the detached
                feature tensor is appended to the result.

        Returns:
            ``[loss, output]`` or ``[loss, output, feature]``.
        """
        encoder_out = self.bert(
            input_ids,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask,
            head_mask=head_mask,
        )
        # Index 1 of the BERT output is used as the pooled representation.
        feature = self.dropout(encoder_out[1])

        loss = self.AdaptiveMultiLabel(feature, labels)
        # Prediction runs on a detached copy of the feature.
        output = self.AdaptiveMultiLabel.predict(feature.detach())

        results = [loss, output]
        if 'feature' in save:
            results.append(feature.detach())
        return results


class RobertaAPLC(RobertaPreTrainedModel):
    """RoBERTa encoder with an adaptive multi-label head.

    When the config has a ``summary_type`` attribute, a ``SequenceSummary``
    module reduces the encoder's first output to a feature; otherwise the
    hidden state at sequence position 0 is used directly. The feature is
    handed to an ``AdaptiveBCEWithLogitsLoss`` module, which supplies both
    the loss and the predictions.

    Config attributes read here: ``num_labels``, ``last_hidden_size``,
    ``adaptive_cutoff``, ``div_value`` and (presence only) ``summary_type``.
    """

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.roberta = RobertaModel(config)
        # The summary module is optional; it is built only when the config
        # carries a ``summary_type`` attribute.
        self.sequence_summary = (
            SequenceSummary(config) if hasattr(config, 'summary_type') else None
        )
        self.AdaptiveMultiLabel = AdaptiveBCEWithLogitsLoss(
            config.last_hidden_size,
            config.num_labels,
            config.adaptive_cutoff,
            config.div_value,
        )
        self.init_weights()

    @staticmethod
    def get_param(model, learning_rate_x, learning_rate_h, learning_rate_a, weight_decay):
        """Build optimizer parameter groups with per-component learning rates.

        Args:
            model: An instance of this class, or a wrapper exposing the real
                model as ``.module`` (the wrapper is unwrapped first).
            learning_rate_x: Learning rate for the ``roberta`` encoder.
            learning_rate_h: Learning rate for ``sequence_summary`` (used
                only when that module exists).
            learning_rate_a: Learning rate for ``AdaptiveMultiLabel``.
            weight_decay: Weight decay for parameters whose name contains
                none of the ``no_decay`` substrings; all others get 0.0.

        Returns:
            A list of dicts with keys ``'params'``, ``'weight_decay'`` and
            ``'lr'``: a (decay, no-decay) pair for the encoder, then for the
            adaptive head, and finally for ``sequence_summary`` when it is
            not ``None``.
        """
        if hasattr(model, 'module'):
            model = model.module

        def _groups_for(module, lr):
            # Split one sub-module's parameters by substring match of the
            # parameter name against the module-level ``no_decay`` list.
            decayed, exempt = [], []
            for name, param in module.named_parameters():
                if any(nd in name for nd in no_decay):
                    exempt.append(param)
                else:
                    decayed.append(param)
            return [
                {'params': decayed, 'weight_decay': weight_decay, "lr": lr},
                {'params': exempt, 'weight_decay': 0.0, "lr": lr},
            ]

        param_groups = []
        param_groups.extend(_groups_for(model.roberta, learning_rate_x))
        param_groups.extend(_groups_for(model.AdaptiveMultiLabel, learning_rate_a))
        # The summary groups come last, and only when the module exists.
        if model.sequence_summary is not None:
            param_groups.extend(_groups_for(model.sequence_summary, learning_rate_h))
        return param_groups

    def forward(self, input_ids, attention_mask=None, token_type_ids=None,
                labels=None, head_mask=None, save='none', **kwargs):
        """Encode the inputs and evaluate the adaptive multi-label head.

        ``**kwargs`` is accepted but not used.

        Args:
            input_ids: Token ids passed to the RoBERTa encoder.
            attention_mask: Optional mask forwarded to the encoder.
            token_type_ids: Optional segment ids forwarded to the encoder.
            labels: Targets forwarded unchanged to ``AdaptiveMultiLabel``.
                NOTE(review): the default ``None`` is passed through as-is;
                confirm the loss module accepts it.
            head_mask: Optional head mask forwarded to the encoder.
            save: If this string contains ``'feature'``, the detached
                feature tensor is appended to the result.

        Returns:
            ``[loss, output]`` or ``[loss, output, feature]``.
        """
        encoder_out = self.roberta(
            input_ids,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask,
            head_mask=head_mask,
        )
        if self.sequence_summary is None:
            # No summary module: take the hidden state at position 0.
            feature = encoder_out[0][:, 0, :]
        else:
            feature = self.sequence_summary(encoder_out[0])

        loss = self.AdaptiveMultiLabel(feature, labels)
        # Prediction runs on a detached copy of the feature.
        output = self.AdaptiveMultiLabel.predict(feature.detach())

        results = [loss, output]
        if 'feature' in save:
            results.append(feature.detach())
        return results


class XLNetAPLC(XLNetPreTrainedModel):
    """XLNet encoder with an adaptive multi-label head.

    The encoder's first output is reduced by a ``SequenceSummary`` module
    and the resulting feature is handed to an ``AdaptiveBCEWithLogitsLoss``
    module, which supplies both the loss and the predictions.

    Config attributes read here: ``num_labels``, ``last_hidden_size``,
    ``adaptive_cutoff`` and ``div_value``.
    """

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.transformer = XLNetModel(config)
        self.sequence_summary = SequenceSummary(config)
        self.AdaptiveMultiLabel = AdaptiveBCEWithLogitsLoss(
            config.last_hidden_size,
            config.num_labels,
            config.adaptive_cutoff,
            config.div_value,
        )
        self.init_weights()

    @staticmethod
    def get_param(model, learning_rate_x, learning_rate_h, learning_rate_a, weight_decay):
        """Build optimizer parameter groups with per-component learning rates.

        Args:
            model: An instance of this class, or a wrapper exposing the real
                model as ``.module`` (the wrapper is unwrapped first).
            learning_rate_x: Learning rate for the ``transformer`` encoder.
            learning_rate_h: Learning rate for ``sequence_summary``.
            learning_rate_a: Learning rate for ``AdaptiveMultiLabel``.
            weight_decay: Weight decay for parameters whose name contains
                none of the ``no_decay`` substrings; all others get 0.0.

        Returns:
            A list of six dicts with keys ``'params'``, ``'weight_decay'``
            and ``'lr'``: a (decay, no-decay) pair for the encoder, the
            sequence summary and the adaptive head, in that order.
        """
        if hasattr(model, 'module'):
            model = model.module

        def _groups_for(module, lr):
            # Split one sub-module's parameters by substring match of the
            # parameter name against the module-level ``no_decay`` list.
            decayed, exempt = [], []
            for name, param in module.named_parameters():
                if any(nd in name for nd in no_decay):
                    exempt.append(param)
                else:
                    decayed.append(param)
            return [
                {'params': decayed, 'weight_decay': weight_decay, "lr": lr},
                {'params': exempt, 'weight_decay': 0.0, "lr": lr},
            ]

        param_groups = []
        for module, lr in (
            (model.transformer, learning_rate_x),
            (model.sequence_summary, learning_rate_h),
            (model.AdaptiveMultiLabel, learning_rate_a),
        ):
            param_groups.extend(_groups_for(module, lr))
        return param_groups

    def forward(self, input_ids, token_type_ids=None, input_mask=None, attention_mask=None,
                mems=None, perm_mask=None, target_mapping=None,
                labels=None, head_mask=None, save='none', **kwargs):
        """Encode the inputs and evaluate the adaptive multi-label head.

        ``**kwargs`` is accepted but not used.

        Args:
            input_ids: Token ids passed to the XLNet encoder.
            token_type_ids: Optional segment ids forwarded to the encoder.
            input_mask: Optional mask forwarded to the encoder.
            attention_mask: Optional mask forwarded to the encoder.
            mems: Optional value forwarded to the encoder's ``mems``.
            perm_mask: Optional value forwarded to the encoder's ``perm_mask``.
            target_mapping: Optional value forwarded to the encoder's
                ``target_mapping``.
            labels: Targets forwarded unchanged to ``AdaptiveMultiLabel``.
                NOTE(review): the default ``None`` is passed through as-is;
                confirm the loss module accepts it.
            head_mask: Optional head mask forwarded to the encoder.
            save: If this string contains ``'feature'``, the detached
                feature tensor is appended to the result.

        Returns:
            ``[loss, output]`` or ``[loss, output, feature]``.
        """
        encoder_out = self.transformer(
            input_ids,
            token_type_ids=token_type_ids,
            input_mask=input_mask,
            attention_mask=attention_mask,
            mems=mems,
            perm_mask=perm_mask,
            target_mapping=target_mapping,
            head_mask=head_mask,
        )
        feature = self.sequence_summary(encoder_out[0])

        loss = self.AdaptiveMultiLabel(feature, labels)
        # Prediction runs on a detached copy of the feature.
        output = self.AdaptiveMultiLabel.predict(feature.detach())

        results = [loss, output]
        if 'feature' in save:
            results.append(feature.detach())
        return results