import torch
import torch.nn as nn
import torch.nn.functional as F
from pytorch_transformers import BertModel

from tools.accuracy_init import init_accuracy_function


class BasicBert(nn.Module):
    """BERT encoder followed by a single linear classification layer.

    The model feeds ``data["input"]`` through a pretrained ``BertModel``,
    projects the second element of its output (the pooled output in
    ``pytorch_transformers``) to ``output_dim`` logits and, when labels are
    supplied, reports the cross-entropy loss together with the result of the
    configured accuracy function.
    """

    def __init__(self, config, gpu_list, *args, **params):
        """Build the encoder, the classifier head and the loss/metric helpers.

        Args:
            config: Config object providing ``getint``/``get``; the options
                ``model.output_dim`` and ``model.bert_path`` are read here.
            gpu_list: Not used by this class; accepted for interface
                compatibility with the caller.
            *args: Forwarded to ``init_accuracy_function``.
            **params: Forwarded to ``init_accuracy_function``.
        """
        super(BasicBert, self).__init__()

        self.output_dim = config.getint("model", "output_dim")
        self.bert = BertModel.from_pretrained(config.get("model", "bert_path"))
        # Size the classifier from the loaded checkpoint instead of assuming
        # the BERT-base width (768), so larger/smaller checkpoints also work.
        self.fc = nn.Linear(self.bert.config.hidden_size, self.output_dim)

        self.criterion = nn.CrossEntropyLoss()
        self.accuracy_function = init_accuracy_function(config, *args, **params)

    def init_multi_gpu(self, device, config, *args, **params):
        """Wrap the BERT encoder in ``nn.DataParallel`` over ``device``.

        Only the encoder is wrapped; the linear head is left as is.

        Args:
            device: Passed as ``device_ids`` to ``nn.DataParallel``.
            config: Not used here.
            *args: Not used here.
            **params: Not used here.
        """
        self.bert = nn.DataParallel(self.bert, device_ids=device)

    def forward(self, data, config, gpu_list, acc_result, mode):
        """Run the model on one batch.

        Args:
            data: Mapping with key ``"input"`` (fed to BERT; presumably a
                batch of token ids -- confirm against the data formatter)
                and optionally ``"label"`` (targets for
                ``nn.CrossEntropyLoss``).
            config: Passed through to the accuracy function.
            gpu_list: Not used here.
            acc_result: Running accuracy state, passed to and replaced by
                the return value of the accuracy function.
            mode: Not used here.

        Returns:
            ``{"loss": ..., "acc_result": ...}`` when ``data`` contains
            ``"label"``; otherwise an empty dict.
        """
        x = data['input']

        # BertModel returns a tuple whose length grows when the config enables
        # hidden-state/attention outputs, so index the element we need rather
        # than unpacking exactly two values.
        outputs = self.bert(x)
        y = outputs[1]
        # Flatten everything after the batch dimension before the linear head.
        y = y.view(y.size()[0], -1)
        y = self.fc(y)

        if "label" in data:
            label = data["label"]
            loss = self.criterion(y, label)
            acc_result = self.accuracy_function(y, label, config, acc_result)
            return {"loss": loss, "acc_result": acc_result}

        # NOTE(review): without labels nothing (not even the logits) is
        # returned; kept as-is because callers may rely on the empty dict.
        return {}
