import torch
from torch import nn
from transformers import BertModel, AutoConfig

import torch.functional as F

# Keys used to read each field out of an example mapping in get_batch().
INPUT_IDS = "input_ids"
TOKEN_TYPE_IDS = "token_type_ids"
ATTENTION_MASK = "attention_mask"
LABEL = "label"

# Default device for collated batches: the current CUDA device when one is
# available, otherwise the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class bert_clf(nn.Module):
    """BERT encoder with a single linear classification head.

    A batch is encoded with ``bert-base-uncased``, whose hidden/attention
    dropout probabilities are set to 0 and whose pooler is removed. The hidden
    state at sequence position 0 is passed through ``nn.Dropout(args.dropout)``
    and projected to ``args.num_labels`` logits.

    Args:
        args: configuration namespace. Reads ``args.dropout`` (dropout
            probability applied to the representation), ``args.device``
            (device the module is moved to and inputs are placed on) and
            ``args.num_labels`` (number of output classes).
    """

    def __init__(self, args):
        super().__init__()
        self.args = args

        self.dropout = nn.Dropout(args.dropout)
        # Not used by forward(); kept so external code referencing
        # ``model.relu`` keeps working.
        self.relu = nn.ReLU()

        # Zero BERT's internal hidden/attention dropout so the only dropout in
        # this model is ``self.dropout`` on the sentence representation.
        self.configuration = AutoConfig.from_pretrained('bert-base-uncased')
        self.configuration.hidden_dropout_prob = 0.0
        self.configuration.attention_probs_dropout_prob = 0.0
        self.bert = BertModel.from_pretrained(pretrained_model_name_or_path="bert-base-uncased",
                                              config=self.configuration)
        # The pooled output is never used, so the pooler layer is dropped.
        self.bert.pooler = None

        self.adaptor = nn.Linear(self.bert.config.hidden_size, args.num_labels)
        self.cross_entropy = torch.nn.CrossEntropyLoss(reduction='none')

        # Move every submodule (BERT and the head) to the target device at once.
        self.to(args.device)

    def forward(self, batch, list_loss=False):
        """Encode a batch and compute the cross-entropy classification loss.

        Args:
            batch: iterable of examples accepted by ``get_batch``.
            list_loss: if true, return the per-example loss vector; otherwise
                return its mean.

        Returns:
            Tuple ``(representation, logits, labels, loss)``: the position-0
            hidden state fed to the head, the head's logits, the labels taken
            from ``batch``, and the per-example or mean loss. All tensors are
            on ``self.args.device``.
        """
        # NOTE(review): clears torch's autocast cast-cache on every call;
        # presumably intentional (e.g. several forwards inside one autocast
        # region) — confirm before removing.
        torch.clear_autocast_cache()
        device = self.args.device
        # get_batch places tensors on the module-level DEVICE, which can differ
        # from the device this model lives on (e.g. args.device == "cuda:1").
        # Move them explicitly to avoid a device mismatch.
        input_ids, token_type_ids, attention_mask, labels = (
            tensor.to(device) for tensor in get_batch(batch)
        )
        outputs = self.bert(input_ids=input_ids, token_type_ids=token_type_ids,
                            attention_mask=attention_mask)
        # outputs[0] is the last hidden layer; position 0 is the first token
        # (conventionally [CLS] for BERT tokenizers).
        representation = outputs[0][:, 0]
        logits = self.adaptor(self.dropout(representation))
        loss = self.cross_entropy(logits, labels)  # one value per example
        if not list_loss:
            loss = loss.mean()
        return representation, logits, labels, loss


def get_batch(batch, device=None):
    """Collate a sequence of examples into batched tensors.

    Each example is indexed with the keys ``INPUT_IDS``, ``TOKEN_TYPE_IDS``,
    ``ATTENTION_MASK`` and ``LABEL``. For each key, the values are stacked
    along a new leading batch dimension, so all examples must have matching
    shapes for that key. Non-tensor values (e.g. a Python int label) are
    converted with ``torch.as_tensor``; tensors are used as-is.

    Args:
        batch: iterable of example mappings.
        device: device to move the result to; defaults to the module-level
            ``DEVICE``.

    Returns:
        Tuple ``(input_ids, token_type_ids, attention_mask, labels)`` of
        tensors whose first dimension is the number of examples.

    Raises:
        ValueError: if ``batch`` contains no examples.
    """
    examples = list(batch)
    if not examples:
        raise ValueError("get_batch() received an empty batch")
    if device is None:
        device = DEVICE

    def _stack(key):
        # A single stack avoids re-copying the growing batch on every example,
        # which repeated torch.cat does (quadratic in batch size).
        return torch.stack([torch.as_tensor(example[key]) for example in examples]).to(device)

    return _stack(INPUT_IDS), _stack(TOKEN_TYPE_IDS), _stack(ATTENTION_MASK), _stack(LABEL)
