# Copyright (c) 2020, Zhouxing shi <zhouxingshichn@gmail.com>
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch BERT model."""

from __future__ import absolute_import, division, print_function, unicode_literals

import copy
import json
import logging
import math
import os
import shutil
import tarfile
import tempfile
import sys
from io import open

import numpy as np
import scipy as sp
import torch
from torch import nn
from torch.nn import CrossEntropyLoss
from .oni_lib import ONI_Linear, GroupSort

from pytorch_pretrained_bert.file_utils import cached_path, WEIGHTS_NAME, CONFIG_NAME

from pytorch_pretrained_bert.modeling import BertConfig, load_tf_weights_in_bert, logger,\
    BertPreTrainedModel, ACT2FN, PRETRAINED_MODEL_ARCHIVE_MAP, BERT_CONFIG_NAME, TF_WEIGHTS_NAME

class BertLayerNorm(nn.Module):
    """TF-style layer normalization (epsilon inside the square root).

    Construction is deliberately disabled: ``__init__`` always raises with the
    message "cannot use it for 1-lip".  The class is kept so that code paths
    selecting it fail loudly instead of silently building this normalizer.
    """

    def __init__(self, hidden_size, eps=1e-12):
        """Refuse to build the layer.

        Args:
            hidden_size: size of the normalized (last) dimension.
            eps: constant added to the variance before the square root.

        Raises:
            AssertionError: always.
        """
        # Raised explicitly instead of `assert 0, ...`: assert statements are
        # removed under `python -O`, which would silently re-enable this layer.
        raise AssertionError("cannot use it for 1-lip")
        # Reference initialization, unreachable while the guard above exists.
        super(BertLayerNorm, self).__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.bias = nn.Parameter(torch.zeros(hidden_size))
        self.variance_epsilon = eps

    def forward(self, x):
        """Normalize ``x`` over its last dimension, then scale and shift.

        Args:
            x: tensor whose last dimension is normalized.

        Returns:
            ``weight * (x - mean) / sqrt(var + eps) + bias`` where mean and the
            (biased) variance are taken over the last dimension.
        """
        u = x.mean(-1, keepdim=True)
        s = (x - u).pow(2).mean(-1, keepdim=True)
        x = (x - u) / torch.sqrt(s + self.variance_epsilon)
        return self.weight * x + self.bias

class BertLayerNormNoVar(nn.Module):
    """Mean-centering layer with a learned bias.

    Unlike a full layer norm, nothing is divided by the standard deviation and
    there is no learned scale: the output is ``x - mean(x) + bias``.
    """

    def __init__(self, hidden_size, eps=1e-12):
        """Create the bias parameter.

        Args:
            hidden_size: length of the bias vector (size of the last dimension).
            eps: stored as ``variance_epsilon``; ``forward`` does not read it.
        """
        super(BertLayerNormNoVar, self).__init__()
        self.variance_epsilon = eps
        self.bias = nn.Parameter(torch.zeros(hidden_size))

    def forward(self, x):
        """Subtract the mean over the last dimension and add the bias.

        Args:
            x: input tensor.

        Returns:
            Tensor of the same shape as ``x``.
        """
        centered = x - x.mean(-1, keepdim=True)
        return centered + self.bias

class OneLipBertSelfAttention(nn.Module):
    """Multi-head attention whose scores come from a learned map of key+query sums.

    For every (key, query) pair the per-head vectors are averaged, scaled, and
    fed through ``att_gen`` (GroupSort followed by an ``ONI_Linear`` with one
    output) to produce a scalar score.  Softmax weights are then damped by
    ``1 / (1 + ratio / 4)`` where ``ratio = exp(log_ratio)`` is learned.
    """

    def __init__(self, config):
        """Build the projections, the score generator and the learned ratio.

        Args:
            config: object providing ``hidden_size`` and ``num_attention_heads``.

        Raises:
            ValueError: if ``hidden_size`` is not a multiple of
                ``num_attention_heads``.
        """
        super(OneLipBertSelfAttention, self).__init__()
        hidden, heads = config.hidden_size, config.num_attention_heads
        if hidden % heads != 0:
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention "
                "heads (%d)" % (hidden, heads))
        self.num_attention_heads = heads
        self.attention_head_size = int(hidden / heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        # Creation order is kept (query, key, value, att_gen, log_ratio) so that
        # parameter registration and random initialization order are unchanged.
        self.query = ONI_Linear(hidden, self.all_head_size, bias=True)
        self.key = ONI_Linear(hidden, self.all_head_size, bias=True)
        self.value = ONI_Linear(hidden, self.all_head_size, bias=False)
        score_head = ONI_Linear(self.attention_head_size, 1, bias=True)
        self.att_gen = nn.Sequential(GroupSort(dim=4), score_head)

        # Learned log of the score scale; exp(0) == 1 at initialization.
        self.log_ratio = nn.Parameter(torch.FloatTensor([0.0]), requires_grad=True)

    def transpose_for_scores(self, x):
        """Split the last dimension into heads and move the head axis to dim 1.

        Args:
            x: 3-D tensor whose last dimension equals
                ``num_attention_heads * attention_head_size``.

        Returns:
            4-D tensor with dims ordered (dim 0, head, dim 1, head_size).
        """
        head_dims = (self.num_attention_heads, self.attention_head_size)
        split = x.view(*(x.size()[:-1] + head_dims))
        return split.permute(0, 2, 1, 3)

    def forward(self, hidden_states, attention_mask):
        """Compute the attention output.

        Args:
            hidden_states: 3-D input; dim 1 is used as the sequence length that
                normalizes the scores.
            attention_mask: accepted for interface compatibility; it is not
                applied anywhere in this method.

        Returns:
            Tuple ``(context_layer, attention_scores, attention_probs)``.
        """
        ratio = torch.exp(self.log_ratio)
        seq_len = hidden_states.shape[1]

        query_layer = self.transpose_for_scores(self.query(hidden_states))
        key_layer = self.transpose_for_scores(self.key(hidden_states))
        value_layer = self.transpose_for_scores(self.value(hidden_states))

        # Pairwise average of keys (dim 2) and queries (dim 3), per head.
        pair_features = (key_layer.unsqueeze(3) + query_layer.unsqueeze(2)) / 2
        pair_features = pair_features / 2 / seq_len * ratio
        attention_scores = self.att_gen(pair_features).squeeze(4)
        attention_probs = torch.softmax(attention_scores, dim=-1) / (1 + ratio / 4)

        # Weighted sum of values, then merge the heads back into one dimension.
        context_layer = torch.matmul(attention_probs, value_layer)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        merged_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*merged_shape)

        return context_layer, attention_scores, attention_probs


class OneLipBertSelfOutput(nn.Module):
    """Projects the attention output, averages it with the residual, then normalizes.

    The normalizer is selected by ``config.layer_norm``:

    * ``"no"``      - no normalization (``self.LayerNorm`` is ``None``),
    * ``"no_var"``  - ``BertLayerNormNoVar``,
    * anything else or missing - ``BertLayerNorm`` (whose constructor raises).
    """

    def __init__(self, config):
        """Build the projection and the configured normalizer.

        Args:
            config: object providing ``hidden_size`` and optionally
                ``layer_norm``.
        """
        super(OneLipBertSelfOutput, self).__init__()
        self.config = config
        self.dense = ONI_Linear(config.hidden_size, config.hidden_size, bias=False)
        layer_norm = getattr(config, "layer_norm", None)
        if layer_norm == "no":
            # forward() skips normalization for "no"; building BertLayerNorm
            # here would raise and make this configuration unusable.
            self.LayerNorm = None
        elif layer_norm == "no_var":
            self.LayerNorm = BertLayerNormNoVar(config.hidden_size, eps=1e-12)
        else:
            self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)

    def forward(self, hidden_states, input_tensor):
        """Return the (optionally normalized) mean of projection and residual.

        Args:
            hidden_states: attention output to project with ``self.dense``.
            input_tensor: residual input, same shape as the projection.

        Returns:
            ``(dense(hidden_states) + input_tensor) / 2``, passed through
            ``self.LayerNorm`` unless ``config.layer_norm == "no"``.
        """
        # Averaging (rather than summing) the two branches halves the residual sum.
        hidden_states = (self.dense(hidden_states) + input_tensor) / 2
        if getattr(self.config, "layer_norm", None) == "no":
            return hidden_states
        return self.LayerNorm(hidden_states)


class OneLipBertAttention(nn.Module):
    """Self-attention followed by its output block.

    ``config.self_att_type`` picks the attention implementation: ``'v1'`` uses
    ``OneLipBertSelfAttention`` and ``'v3'`` uses ``OneLipV3BertSelfAttention``.
    """

    def __init__(self, config):
        """Instantiate the selected attention and the output block.

        Args:
            config: model configuration; must carry ``self_att_type``.

        Raises:
            NotImplementedError: for any ``self_att_type`` other than
                ``'v1'`` or ``'v3'``.
        """
        super(OneLipBertAttention, self).__init__()
        kind = config.self_att_type
        if not (kind == 'v1' or kind == 'v3'):
            raise NotImplementedError()
        if kind == 'v1':
            self.self = OneLipBertSelfAttention(config)
        else:
            self.self = OneLipV3BertSelfAttention(config)
        self.output = OneLipBertSelfOutput(config)

    def forward(self, input_tensor, attention_mask):
        """Run attention and the output block.

        Args:
            input_tensor: input to the attention, also used as the residual.
            attention_mask: forwarded unchanged to the attention module.

        Returns:
            Tuple ``(attention_output, self_output, attention_scores,
            attention_probs)`` where ``self_output`` is the raw attention
            result before the output block.
        """
        attended = self.self(input_tensor, attention_mask)
        self_output, attention_scores, attention_probs = attended
        attention_output = self.output(self_output, input_tensor)
        return attention_output, self_output, attention_scores, attention_probs


class BertIntermediate(nn.Module):
    """Linear expansion to ``intermediate_size`` followed by an activation."""

    def __init__(self, config):
        """Build the dense layer and resolve the activation function.

        Args:
            config: object providing ``hidden_size``, ``intermediate_size`` and
                ``hidden_act``.  A string ``hidden_act`` is looked up in
                ``ACT2FN``; any other value is used directly as the callable.
        """
        super(BertIntermediate, self).__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        act = config.hidden_act
        is_name = isinstance(act, str)
        if not is_name and sys.version_info[0] == 2:
            # Python 2 only: text may also arrive as `unicode`.
            is_name = isinstance(act, unicode)  # noqa: F821
        self.intermediate_act_fn = ACT2FN[act] if is_name else act

    def forward(self, hidden_states):
        """Apply the dense layer, then the activation.

        Args:
            hidden_states: input tensor with last dimension ``hidden_size``.

        Returns:
            Activated tensor with last dimension ``intermediate_size``.
        """
        return self.intermediate_act_fn(self.dense(hidden_states))


class OneLipBertOutput(nn.Module):
    """Hidden-to-hidden projection followed by the configured normalizer.

    There is no residual connection in this block.  The normalizer is selected
    by ``config.layer_norm``:

    * ``"no"``      - no normalization (``self.LayerNorm`` is ``None``),
    * ``"no_var"``  - ``BertLayerNormNoVar``,
    * anything else or missing - ``BertLayerNorm`` (whose constructor raises).
    """

    def __init__(self, config):
        """Build the projection and the configured normalizer.

        Args:
            config: object providing ``hidden_size`` and optionally
                ``layer_norm``.
        """
        super(OneLipBertOutput, self).__init__()
        self.config = config
        self.dense = ONI_Linear(config.hidden_size, config.hidden_size, bias=False)
        layer_norm = getattr(config, "layer_norm", None)
        if layer_norm == "no":
            # forward() skips normalization for "no"; building BertLayerNorm
            # here would raise and make this configuration unusable.
            self.LayerNorm = None
        elif layer_norm == "no_var":
            self.LayerNorm = BertLayerNormNoVar(config.hidden_size, eps=1e-12)
        else:
            self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)

    def forward(self, hidden_states):
        """Project ``hidden_states`` and optionally normalize the result.

        Args:
            hidden_states: tensor with last dimension ``hidden_size``.

        Returns:
            ``dense(hidden_states)``, passed through ``self.LayerNorm`` unless
            ``config.layer_norm == "no"``.
        """
        hidden_states = self.dense(hidden_states)
        if getattr(self.config, "layer_norm", None) == "no":
            return hidden_states
        return self.LayerNorm(hidden_states)

class OneLipBertLayer(nn.Module):
    """One encoder layer: an attention block followed by an output block.

    No intermediate feed-forward expansion is used; the attention output is
    passed straight to ``OneLipBertOutput``.
    """

    def __init__(self, config):
        """Create the attention and output sub-modules from ``config``."""
        super(OneLipBertLayer, self).__init__()
        self.attention = OneLipBertAttention(config)
        self.output = OneLipBertOutput(config)

    def forward(self, hidden_states, attention_mask):
        """Run the layer.

        Args:
            hidden_states: layer input.
            attention_mask: forwarded unchanged to the attention block.

        Returns:
            Tuple ``(layer_output, self_output, attention_scores,
            attention_probs)``; the last three come from the attention block.
        """
        attention_result = self.attention(hidden_states, attention_mask)
        attention_output, self_output, attention_scores, attention_probs = attention_result
        layer_output = self.output(attention_output)
        return layer_output, self_output, attention_scores, attention_probs

class OneLipBertEncoder(nn.Module):
    """Stack of ``config.num_hidden_layers`` independent ``OneLipBertLayer`` copies."""

    def __init__(self, config):
        """Build one layer and deep-copy it for every position in the stack."""
        super(OneLipBertEncoder, self).__init__()
        prototype = OneLipBertLayer(config)
        stack = [copy.deepcopy(prototype) for _ in range(config.num_hidden_layers)]
        self.layer = nn.ModuleList(stack)

    def forward(self, hidden_states, attention_mask, output_all_encoded_layers=True):
        """Feed ``hidden_states`` through every layer in order.

        Args:
            hidden_states: input to the first layer.
            attention_mask: forwarded unchanged to every layer.
            output_all_encoded_layers: when true, results of every layer are
                collected; otherwise only those of the last layer.

        Returns:
            Tuple of four lists ``(encoder_layers, self_outputs,
            attention_scores, attention_probs)``.  Each list has one entry per
            layer, or a single entry when ``output_all_encoded_layers`` is
            false.
        """
        # Order: layer outputs, raw attention outputs, scores, probabilities.
        collected = ([], [], [], [])

        def record(*items):
            for bucket, item in zip(collected, items):
                bucket.append(item)

        for block in self.layer:
            hidden_states, self_output, attention_scores, attention_probs = \
                block(hidden_states, attention_mask)
            if output_all_encoded_layers:
                record(hidden_states, self_output, attention_scores, attention_probs)
        if not output_all_encoded_layers:
            # Only the values left over from the final loop iteration are kept.
            record(hidden_states, self_output, attention_scores, attention_probs)
        return collected

class OneLipBertPooler(nn.Module):
    """Pools over dim 1 of the hidden states, then applies a dense layer and GroupSort."""

    def __init__(self, config):
        """Create the dense projection (with bias) and the GroupSort activation."""
        super(OneLipBertPooler, self).__init__()
        self.dense = ONI_Linear(config.hidden_size, config.hidden_size, bias=True)
        self.activation = GroupSort()

    def forward(self, hidden_states):
        """Pool and transform the hidden states.

        Rather than selecting the first token, the hidden states are averaged
        over dim 1 and multiplied by ``sqrt(hidden_states.shape[1])``.

        Args:
            hidden_states: tensor that is reduced over dim 1.

        Returns:
            ``activation(dense(mean * sqrt(length)))``.
        """
        seq_len = hidden_states.shape[1]
        pooled_tensor = hidden_states.mean(1) * np.sqrt(seq_len)
        return self.activation(self.dense(pooled_tensor))

class OneLipBertModel(BertPreTrainedModel):
    """Encoder model using the ``'v1'`` attention on externally supplied embeddings.

    Inputs are pre-computed embeddings with last dimension 768, mapped to
    ``config.hidden_size`` by ``emb_transform`` before entering the encoder.
    """

    def __init__(self, config, approach):
        """Build the embedding transform, encoder and pooler.

        Args:
            config: model configuration.  NOTE: it is mutated in place;
                ``config.self_att_type`` is set to ``'v1'``.
            approach: accepted for interface compatibility; not read here.
        """
        super(OneLipBertModel, self).__init__(config)
        config.self_att_type = 'v1'
        self.emb_transform = ONI_Linear(768, config.hidden_size, bias=False)
        self.encoder = OneLipBertEncoder(config)
        self.pooler = OneLipBertPooler(config)

    def init_emb_weights(self, module):
        """Draw the weights of ``nn.Embedding`` modules from a standard normal."""
        if not isinstance(module, nn.Embedding):
            return
        module.weight.data.normal_()

    def forward(self, token_type_ids=None, attention_mask=None, output_all_encoded_layers=True, embeddings=None):
        """Encode ``embeddings`` and pool the final layer.

        Args:
            token_type_ids: accepted for interface compatibility; not read here.
            attention_mask: required 2-D mask, 1 for positions to attend to and
                0 for masked positions.
            output_all_encoded_layers: when false, only the last layer's
                hidden states, scores and probabilities are returned.
            embeddings: input embeddings passed through ``emb_transform``.

        Returns:
            Tuple ``(encoded_layers, attention_scores, attention_probs,
            pooled_output, self_output)``.  ``self_output`` is always the list
            produced by the encoder, even when the other entries are unwrapped.
        """
        assert attention_mask is not None

        # Broadcastable additive mask: 0.0 where attention is allowed and
        # -10000.0 where it is masked, cast to the parameter dtype (fp16 safe).
        param_dtype = next(self.parameters()).dtype
        extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2).to(dtype=param_dtype)
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0

        projected = self.emb_transform(embeddings)
        encoder_outputs = self.encoder(
            projected,
            extended_attention_mask,
            output_all_encoded_layers=output_all_encoded_layers)
        encoded_layers, self_output, attention_scores, attention_probs = encoder_outputs

        pooled_output = self.pooler(encoded_layers[-1])
        if not output_all_encoded_layers:
            encoded_layers = encoded_layers[-1]
            attention_scores = attention_scores[-1]
            attention_probs = attention_probs[-1]
        return encoded_layers, attention_scores, attention_probs, pooled_output, self_output


class OneLipV3BertSelfAttention(nn.Module):
    """Multi-head attention scored by negative squared L2 distance.

    Queries and keys share a single projection (``query_key``).  Scores are
    ``-||k_i - q_j||^2 / alpha`` and the softmax weights are divided by
    ``alpha`` again, where ``alpha`` depends only on the sequence length.
    """

    def __init__(self, config):
        """Build the shared query/key projection and the value projection.

        Args:
            config: object providing ``hidden_size`` and ``num_attention_heads``.

        Raises:
            ValueError: if ``hidden_size`` is not a multiple of
                ``num_attention_heads``.
        """
        super(OneLipV3BertSelfAttention, self).__init__()
        if config.hidden_size % config.num_attention_heads != 0:
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention "
                "heads (%d)" % (config.hidden_size, config.num_attention_heads))
        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query_key = ONI_Linear(config.hidden_size, self.all_head_size, bias=True)
        self.value = ONI_Linear(config.hidden_size, self.all_head_size, bias=False)

    @staticmethod
    def _attention_scale(seq_len):
        """Return ``sqrt(N) * (4 * W(N / 2.71828) + 1)`` for sequence length ``N``.

        ``W`` is the principal branch of the Lambert W function.
        """
        # Imported explicitly: `import scipy` alone is not guaranteed to expose
        # the `scipy.special` submodule on every SciPy version.
        from scipy.special import lambertw
        # lambertw returns a complex value; the principal branch is real for
        # positive arguments, so only the real part is used.
        # NOTE(review): 2.71828 presumably approximates Euler's number; the
        # literal is kept so numerical results stay unchanged.
        w_value = float(lambertw(seq_len / 2.71828).real)
        return np.sqrt(seq_len) * (4 * w_value + 1)

    def transpose_for_scores(self, x):
        """Split the last dimension into heads and move the head axis to dim 1.

        Args:
            x: 3-D tensor whose last dimension equals
                ``num_attention_heads * attention_head_size``.

        Returns:
            4-D tensor with dims ordered (dim 0, head, dim 1, head_size).
        """
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(self, hidden_states, attention_mask):
        """Compute the attention output.

        Args:
            hidden_states: 3-D input; dim 1 is used as the sequence length.
            attention_mask: accepted for interface compatibility; it is not
                applied anywhere in this method.

        Returns:
            Tuple ``(context_layer, attention_scores, attention_probs)``.
        """
        alpha = self._attention_scale(hidden_states.shape[1])

        # Queries and keys use the same projection, so it is evaluated once and
        # reused (assumes ONI_Linear is deterministic for a given input).
        shared_layer = self.transpose_for_scores(self.query_key(hidden_states))
        query_layer = key_layer = shared_layer
        value_layer = self.transpose_for_scores(self.value(hidden_states))

        # Squared L2 distance between every key (dim 2) and query (dim 3).
        dist = ((key_layer.unsqueeze(3) - query_layer.unsqueeze(2)) ** 2).sum(4)
        attention_scores = -dist / alpha
        attention_probs = nn.Softmax(dim=-1)(attention_scores) / alpha

        # Weighted sum of values, then merge the heads back into one dimension.
        context_layer = torch.matmul(attention_probs, value_layer)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)

        return context_layer, attention_scores, attention_probs


class OneLipV2BertModel(BertPreTrainedModel):
    """Encoder model whose attention variant is chosen by ``approach``.

    Inputs are pre-computed embeddings with last dimension 768, mapped to
    ``config.hidden_size`` by ``emb_transform`` before entering the encoder.
    """

    def __init__(self, config, approach):
        """Build the embedding transform, encoder and pooler.

        Args:
            config: model configuration.  NOTE: it is mutated in place;
                ``config.self_att_type`` is set according to ``approach``.
            approach: ``'onelip-softmax-v2'`` (selects ``'v2'``) or
                ``'onelip-softmax-v3'`` (selects ``'v3'``).

        Raises:
            NotImplementedError: for any other ``approach``.
        """
        super(OneLipV2BertModel, self).__init__(config)
        self.emb_transform = ONI_Linear(768, config.hidden_size, bias=False)
        known_approaches = (
            ('onelip-softmax-v2', 'v2'),
            ('onelip-softmax-v3', 'v3'),
        )
        for approach_name, att_type in known_approaches:
            if approach == approach_name:
                config.self_att_type = att_type
                break
        else:
            raise NotImplementedError()
        self.encoder = OneLipBertEncoder(config)
        self.pooler = OneLipBertPooler(config)
        self.apply(self.init_emb_weights)

    def init_emb_weights(self, module):
        """Draw the weights of ``nn.Embedding`` modules from a standard normal."""
        if not isinstance(module, nn.Embedding):
            return
        module.weight.data.normal_()

    def forward(self, token_type_ids=None, attention_mask=None, output_all_encoded_layers=True, embeddings=None):
        """Encode ``embeddings`` and pool the final layer.

        Args:
            token_type_ids: accepted for interface compatibility; not read here.
            attention_mask: required 2-D mask, 1 for positions to attend to and
                0 for masked positions.
            output_all_encoded_layers: when false, only the last layer's
                hidden states, scores and probabilities are returned.
            embeddings: input embeddings passed through ``emb_transform``.

        Returns:
            Tuple ``(encoded_layers, attention_scores, attention_probs,
            pooled_output, self_output)``.  ``self_output`` is always the list
            produced by the encoder, even when the other entries are unwrapped.
        """
        assert attention_mask is not None

        # Broadcastable additive mask: 0.0 where attention is allowed and
        # -10000.0 where it is masked, cast to the parameter dtype (fp16 safe).
        param_dtype = next(self.parameters()).dtype
        extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2).to(dtype=param_dtype)
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0

        projected = self.emb_transform(embeddings)
        encoder_outputs = self.encoder(
            projected,
            extended_attention_mask,
            output_all_encoded_layers=output_all_encoded_layers)
        encoded_layers, self_output, attention_scores, attention_probs = encoder_outputs

        pooled_output = self.pooler(encoded_layers[-1])
        if not output_all_encoded_layers:
            encoded_layers = encoded_layers[-1]
            attention_scores = attention_scores[-1]
            attention_probs = attention_probs[-1]
        return encoded_layers, attention_scores, attention_probs, pooled_output, self_output



class OneLipBertForSequenceClassificationWithPretrain(BertPreTrainedModel):
    """Sequence classifier: a OneLip encoder with a linear head on the pooled output.

    Params:
        `config`: configuration used to build the encoder and the head.
        `num_labels`: the number of classes for the classifier. Default = 2.
        `approach`: `'onelip-softmax'` builds `OneLipBertModel`;
            `'onelip-softmax-v3'` builds `OneLipV2BertModel`. Any other value
            raises `NotImplementedError`.
        `last_noreg`: if true the head is a plain `nn.Linear`, otherwise an
            `ONI_Linear`. Both use a bias.
    Inputs:
        `embeddings`: pre-computed input embeddings, forwarded to the encoder.
        `token_type_ids`: forwarded to the encoder.
        `attention_mask`: forwarded to the encoder, which requires it to be
            set.
        `labels`: must be `None`; an assertion rejects any other value.
    Outputs:
        The tuple `(logits, encoded_layers, attention_scores, attention_probs,
        self_output, pooled_output)`.
    """
    def __init__(self, config, num_labels=2, approach='onelip-softmax', last_noreg=False):
        super(OneLipBertForSequenceClassificationWithPretrain, self).__init__(config)
        self.num_labels = num_labels
        if approach == 'onelip-softmax':
            encoder_cls = OneLipBertModel
        elif approach == 'onelip-softmax-v3':
            encoder_cls = OneLipV2BertModel
        else:
            raise NotImplementedError()
        self.bert = encoder_cls(config, approach)
        # Only the layer type of the head depends on `last_noreg`.
        head_cls = nn.Linear if last_noreg else ONI_Linear
        self.classifier = head_cls(config.hidden_size, num_labels, bias=True)

    def forward(self, embeddings, token_type_ids=None, attention_mask=None, labels=None):
        """Return logits together with the encoder's intermediate results."""
        bert_outputs = self.bert(
            token_type_ids, attention_mask,
            output_all_encoded_layers=True, embeddings=embeddings)
        encoded_layers, attention_scores, attention_probs, pooled_output, self_output = bert_outputs

        logits = self.classifier(pooled_output)

        # Computing a loss inside the model is disabled by this assertion.
        assert(labels is None)

        if labels is None:
            return logits, encoded_layers, attention_scores, attention_probs, self_output, pooled_output
        loss_fct = CrossEntropyLoss()
        loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
        return loss
