import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from pytorch_transformers import BertModel

from tools.accuracy_init import init_accuracy_function

def gelu(x):
    """Apply the erf-based GELU activation element-wise.

    Computes ``x * 0.5 * (1 + erf(x / sqrt(2)))``.

    For information: OpenAI GPT's gelu is slightly different (and gives
    slightly different results):
    0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
    """
    half_x = x * 0.5
    cdf_term = 1.0 + torch.erf(x / math.sqrt(2.0))
    return half_x * cdf_term


class BasicRE(nn.Module):
    """Single linear layer classifier applied to ``data['input']``."""

    def __init__(self, config, gpu_list, *args, **params):
        """Create the linear layer, the loss and the accuracy function.

        ``model.input_dim`` and ``model.output_dim`` are read from ``config``.
        ``gpu_list`` is accepted but not used; ``*args`` / ``**params`` are
        forwarded to ``init_accuracy_function``.
        """
        super().__init__()

        in_features = config.getint("model", "input_dim")
        out_features = config.getint("model", "output_dim")
        self.input_dim = in_features
        self.output_dim = out_features
        self.fc = nn.Linear(in_features, out_features)

        self.criterion = nn.CrossEntropyLoss()
        self.accuracy_function = init_accuracy_function(config, *args, **params)

    def init_multi_gpu(self, device, config, *args, **params):
        """Do nothing: this model performs no multi-GPU setup."""

    def forward(self, data, config, gpu_list, acc_result, mode):
        """Classify ``data['input']`` and, when labels are present, score it.

        Returns ``{"loss", "acc_result"}`` when ``data`` has a ``"label"``
        entry, otherwise an empty dict.
        """
        logits = self.fc(data['input'])
        # Flatten everything after the batch dimension.
        logits = logits.view(logits.size(0), -1)

        if "label" not in data:
            return {}

        gold = data["label"]
        return {
            "loss": self.criterion(logits, gold),
            "acc_result": self.accuracy_function(logits, gold, config, acc_result),
        }

class CnnRE(nn.Module):
    """CNN relation classifier over word and position embeddings.

    Word and two position embeddings are concatenated, convolved, passed
    through ReLU, max-pooled over the sequence and classified linearly.
    """

    def __init__(self, config, gpu_list, *args, **params):
        """Build embeddings, convolution, classifier, loss and accuracy function.

        Reads ``data.max_seq_length`` and ``model.output_dim`` from ``config``.
        ``gpu_list`` is accepted but not used here.
        """
        super().__init__()

        word_dim = 50
        self.word_embedding = nn.Embedding(400002, word_dim)
        # NOTE(review): absolute path and vocabulary size are hard-coded.
        pretrained = torch.from_numpy(
            np.load('/data1/private/zzy/nre/data/pretrain/glove/word2vec.npy'))
        # Two extra rows follow the pretrained ones: a scaled random row and
        # an all-zero row (presumably unknown / blank tokens -- TODO confirm).
        random_row = torch.randn(1, word_dim) / math.sqrt(word_dim)
        zero_row = torch.zeros(1, word_dim)
        table = torch.cat([pretrained, random_row, zero_row], 0)
        self.word_embedding.weight.data.copy_(table)

        seq_len = config.getint("data", "max_seq_length")
        pos_dim = 5
        n_filters = 230

        # Position embeddings
        self.pos1_embedding = nn.Embedding(2 * seq_len, pos_dim, padding_idx=0)
        self.pos2_embedding = nn.Embedding(2 * seq_len, pos_dim, padding_idx=0)

        self.embedding_dim = word_dim + 2 * pos_dim
        # Convolution with kernel 3 and padding 1 keeps the sequence length.
        self.conv = nn.Conv1d(self.embedding_dim, n_filters, 3, padding=1)

        self.drop = nn.Dropout(0.5)

        self.output_dim = config.getint("model", "output_dim")
        self.fc = nn.Linear(n_filters, self.output_dim)

        self.criterion = nn.CrossEntropyLoss()
        self.accuracy_function = init_accuracy_function(config, *args, **params)

    def init_multi_gpu(self, device, config, *args, **params):
        """Do nothing: this model performs no multi-GPU setup."""

    def forward(self, data, config, gpu_list, acc_result, mode):
        """Encode ``data['token']``/``['pos1']``/``['pos2']`` and classify.

        Returns ``{"loss", "acc_result"}`` when ``data`` has a ``"label"``
        entry, otherwise an empty dict.
        """
        pieces = [
            self.word_embedding(data['token']),
            self.pos1_embedding(data['pos1']),
            self.pos2_embedding(data['pos2']),
        ]
        feats = torch.cat(pieces, 2).transpose(1, 2)  # (B, EMBED, L)
        feats = torch.relu(self.conv(feats))  # (B, H, L)
        pooled = feats.max(-1)[0]  # (B, H)
        pooled = self.drop(pooled)

        logits = self.fc(pooled)
        logits = logits.view(logits.size(0), -1)

        if "label" not in data:
            return {}

        gold = data["label"]
        return {
            "loss": self.criterion(logits, gold),
            "acc_result": self.accuracy_function(logits, gold, config, acc_result),
        }


class CnnKFRE(nn.Module):
    """CNN relation classifier whose pooled features are concatenated with
    ``data['ent_emb']`` before the final linear layer."""

    def __init__(self, config, gpu_list, *args, **params):
        """Build embeddings, convolution, classifier, loss and accuracy function.

        Reads ``data.max_seq_length``, ``model.input_dim`` and
        ``model.output_dim`` from ``config``. ``gpu_list`` is not used here.
        """
        super().__init__()

        word_dim = 50
        self.word_embedding = nn.Embedding(400002, word_dim)
        # NOTE(review): absolute path and vocabulary size are hard-coded.
        pretrained = torch.from_numpy(
            np.load('/data1/private/zzy/nre/data/pretrain/glove/word2vec.npy'))
        # Two extra rows follow the pretrained ones: a scaled random row and
        # an all-zero row (presumably unknown / blank tokens -- TODO confirm).
        random_row = torch.randn(1, word_dim) / math.sqrt(word_dim)
        zero_row = torch.zeros(1, word_dim)
        table = torch.cat([pretrained, random_row, zero_row], 0)
        self.word_embedding.weight.data.copy_(table)

        seq_len = config.getint("data", "max_seq_length")
        pos_dim = 5
        n_filters = 230

        # Position embeddings
        self.pos1_embedding = nn.Embedding(2 * seq_len, pos_dim, padding_idx=0)
        self.pos2_embedding = nn.Embedding(2 * seq_len, pos_dim, padding_idx=0)

        self.embedding_dim = word_dim + 2 * pos_dim
        # Convolution with kernel 3 and padding 1 keeps the sequence length.
        self.conv = nn.Conv1d(self.embedding_dim, n_filters, 3, padding=1)

        self.drop = nn.Dropout(0.5)

        self.input_dim = config.getint("model", "input_dim")
        self.output_dim = config.getint("model", "output_dim")
        # The classifier sees the pooled CNN features plus the entity vector.
        self.fc = nn.Linear(n_filters + self.input_dim, self.output_dim)

        self.criterion = nn.CrossEntropyLoss()
        self.accuracy_function = init_accuracy_function(config, *args, **params)

    def init_multi_gpu(self, device, config, *args, **params):
        """Do nothing: this model performs no multi-GPU setup."""

    def forward(self, data, config, gpu_list, acc_result, mode):
        """Encode the sentence, append ``data['ent_emb']`` and classify.

        Returns ``{"loss", "acc_result"}`` when ``data`` has a ``"label"``
        entry, otherwise an empty dict.
        """
        tokens, first_pos, second_pos = data['token'], data['pos1'], data['pos2']
        entity_vec = data['ent_emb']

        pieces = [
            self.word_embedding(tokens),
            self.pos1_embedding(first_pos),
            self.pos2_embedding(second_pos),
        ]
        feats = torch.cat(pieces, 2).transpose(1, 2)  # (B, EMBED, L)
        feats = torch.relu(self.conv(feats))  # (B, H, L)
        pooled = self.drop(feats.max(-1)[0])  # (B, H)

        # Dropout is applied to the CNN features only, not the entity vector.
        fused = torch.cat([pooled, entity_vec], 1)

        logits = self.fc(fused)
        logits = logits.view(logits.size(0), -1)

        if "label" not in data:
            return {}

        gold = data["label"]
        return {
            "loss": self.criterion(logits, gold),
            "acc_result": self.accuracy_function(logits, gold, config, acc_result),
        }

class KnowledgeAtt(nn.Module):
    """Attention over sequence positions, scored against an entity vector."""

    def __init__(self, hidden_dim, ent_dim):
        """Create the projections.

        ``A`` maps a vector of size ``ent_dim // 2`` to ``hidden_dim``;
        ``linear`` maps the pooled features to ``hidden_dim // 2``.
        """
        super().__init__()
        self.hidden_dim = hidden_dim
        self.A = nn.Linear(ent_dim // 2, hidden_dim, bias=True)
        self.linear = nn.Linear(hidden_dim, hidden_dim // 2, bias=True)
        self.dropout = nn.Dropout(0.5)

    def forward(self, x, ke, mask):
        """Pool ``x`` over positions with weights derived from ``ke``.

        Args:
            x: features laid out as (B, H, L).
            ke: entity vectors laid out as (B, E).
            mask: index/mask into the (B, L) score matrix selecting the
                positions whose score gets a -10000 penalty before softmax.

        Returns:
            A pair ``(pooled, scores)``: the ReLU of the linearly projected
            weighted sum over positions, and the softmax weights.
        """
        seq = x.permute(0, 2, 1).contiguous()  # (B, L, H)
        query = self.A(self.dropout(ke)).unsqueeze(-2)  # (B, 1, H)
        logits = (seq * query).sum(-1)  # (B, L)

        # Additive penalty pushes the selected positions to ~0 weight.
        penalty = torch.zeros_like(logits)
        penalty[mask] = -10000
        scores = torch.softmax(logits + penalty, dim=-1)

        pooled = (seq * scores.unsqueeze(-1)).sum(-2)  # (B, H)
        return torch.relu(self.linear(pooled)), scores



class CnnKARE(nn.Module):
    """CNN relation classifier that pools convolution features with
    ``KnowledgeAtt``, once per half of ``data['ent_emb']``."""

    def __init__(self, config, gpu_list, *args, **params):
        """Build embeddings, convolution, attention, classifier and loss.

        Reads ``data.max_seq_length``, ``model.input_dim`` and
        ``model.output_dim`` from ``config``. ``gpu_list`` is not used here.
        """
        super().__init__()

        word_dim = 50
        self.word_embedding = nn.Embedding(400002, word_dim)
        # NOTE(review): absolute path and vocabulary size are hard-coded.
        pretrained = torch.from_numpy(
            np.load('/data1/private/zzy/nre/data/pretrain/glove/word2vec.npy'))
        # Two extra rows follow the pretrained ones: a scaled random row and
        # an all-zero row (presumably unknown / blank tokens -- TODO confirm).
        random_row = torch.randn(1, word_dim) / math.sqrt(word_dim)
        zero_row = torch.zeros(1, word_dim)
        table = torch.cat([pretrained, random_row, zero_row], 0)
        self.word_embedding.weight.data.copy_(table)

        seq_len = config.getint("data", "max_seq_length")
        pos_dim = 5
        n_filters = 230

        # Position embeddings
        self.pos1_embedding = nn.Embedding(2 * seq_len, pos_dim, padding_idx=0)
        self.pos2_embedding = nn.Embedding(2 * seq_len, pos_dim, padding_idx=0)

        self.embedding_dim = word_dim + 2 * pos_dim
        # Convolution with kernel 3 and padding 1 keeps the sequence length.
        self.conv = nn.Conv1d(self.embedding_dim, n_filters, 3, padding=1)

        self.drop = nn.Dropout(0.5)

        self.input_dim = config.getint("model", "input_dim")
        self.output_dim = config.getint("model", "output_dim")

        # One attention module is shared by both halves of the entity vector.
        self.att = KnowledgeAtt(n_filters, self.input_dim)
        self.fc = nn.Linear(n_filters, self.output_dim)

        self.criterion = nn.CrossEntropyLoss()
        self.accuracy_function = init_accuracy_function(config, *args, **params)

    def init_multi_gpu(self, device, config, *args, **params):
        """Do nothing: this model performs no multi-GPU setup."""

    def forward(self, data, config, gpu_list, acc_result, mode):
        """Encode the sentence, attend with each entity half and classify.

        Returns ``{"loss", "acc_result"}`` when ``data`` has a ``"label"``
        entry, otherwise an empty dict.
        """
        tokens, first_pos, second_pos = data['token'], data['pos1'], data['pos2']
        entity_vec = data['ent_emb']

        pieces = [
            self.word_embedding(tokens),
            self.pos1_embedding(first_pos),
            self.pos2_embedding(second_pos),
        ]
        feats = torch.cat(pieces, 2).transpose(1, 2)  # (B, EMBED, L)
        feats = torch.relu(self.conv(feats))  # (B, H, L)

        # Positions holding token id 400001 (the last embedding row) are
        # penalised inside the attention.
        blank_mask = tokens == 400001
        half = self.input_dim // 2
        head_repr, _ = self.att(feats, entity_vec[:, :half], blank_mask)
        tail_repr, _ = self.att(feats, entity_vec[:, half:], blank_mask)
        fused = self.drop(torch.cat([head_repr, tail_repr], -1))

        logits = self.fc(fused)
        logits = logits.view(logits.size(0), -1)

        if "label" not in data:
            return {}

        gold = data["label"]
        return {
            "loss": self.criterion(logits, gold),
            "acc_result": self.accuracy_function(logits, gold, config, acc_result),
        }


class LstmRE(nn.Module):
    """Bidirectional-LSTM relation classifier over word and position embeddings.

    LSTM outputs go through ReLU, are max-pooled over the sequence and
    classified with a linear layer.
    """

    def __init__(self, config, gpu_list, *args, **params):
        """Build embeddings, LSTM, classifier, loss and accuracy function.

        Reads ``data.max_seq_length`` and ``model.output_dim`` from ``config``.
        ``gpu_list`` is accepted but not used here.
        """
        super().__init__()

        word_dim = 50
        self.word_embedding = nn.Embedding(400002, word_dim)
        # NOTE(review): absolute path and vocabulary size are hard-coded.
        pretrained = torch.from_numpy(
            np.load('/data1/private/zzy/nre/data/pretrain/glove/word2vec.npy'))
        # Two extra rows follow the pretrained ones: a scaled random row and
        # an all-zero row (presumably unknown / blank tokens -- TODO confirm).
        random_row = torch.randn(1, word_dim) / math.sqrt(word_dim)
        zero_row = torch.zeros(1, word_dim)
        table = torch.cat([pretrained, random_row, zero_row], 0)
        self.word_embedding.weight.data.copy_(table)

        seq_len = config.getint("data", "max_seq_length")
        pos_dim = 5
        n_hidden = 230

        # Position embeddings
        self.pos1_embedding = nn.Embedding(2 * seq_len, pos_dim, padding_idx=0)
        self.pos2_embedding = nn.Embedding(2 * seq_len, pos_dim, padding_idx=0)

        self.embedding_dim = word_dim + 2 * pos_dim
        # Each direction gets half the width so the concatenation is n_hidden.
        self.lstm = nn.LSTM(self.embedding_dim, n_hidden // 2, bidirectional=True)

        self.drop = nn.Dropout(0.5)

        self.output_dim = config.getint("model", "output_dim")
        self.fc = nn.Linear(n_hidden, self.output_dim)

        self.criterion = nn.CrossEntropyLoss()
        self.accuracy_function = init_accuracy_function(config, *args, **params)

    def init_multi_gpu(self, device, config, *args, **params):
        """Do nothing: this model performs no multi-GPU setup."""

    def forward(self, data, config, gpu_list, acc_result, mode):
        """Encode ``data['token']``/``['pos1']``/``['pos2']`` and classify.

        Returns ``{"loss", "acc_result"}`` when ``data`` has a ``"label"``
        entry, otherwise an empty dict.
        """
        pieces = [
            self.word_embedding(data['token']),
            self.pos1_embedding(data['pos1']),
            self.pos2_embedding(data['pos2']),
        ]
        embedded = torch.cat(pieces, 2)  # (B, L, EMBED)
        # The LSTM is sequence-first, so feed it (L, B, EMBED).
        outputs, _ = self.lstm(embedded.transpose(0, 1))  # (L, B, H)
        # Max over the sequence axis, taken directly on the (L, B, H) layout.
        pooled = torch.relu(outputs).max(0)[0]  # (B, H)
        pooled = self.drop(pooled)

        logits = self.fc(pooled)
        logits = logits.view(logits.size(0), -1)

        if "label" not in data:
            return {}

        gold = data["label"]
        return {
            "loss": self.criterion(logits, gold),
            "acc_result": self.accuracy_function(logits, gold, config, acc_result),
        }


class LstmKFRE(nn.Module):
    """Bidirectional-LSTM relation classifier whose pooled features are
    concatenated with ``data['ent_emb']`` before the final linear layer."""

    def __init__(self, config, gpu_list, *args, **params):
        """Build embeddings, LSTM, classifier, loss and accuracy function.

        Reads ``data.max_seq_length``, ``model.input_dim`` and
        ``model.output_dim`` from ``config``. ``gpu_list`` is not used here.
        """
        super().__init__()

        word_dim = 50
        self.word_embedding = nn.Embedding(400002, word_dim)
        # NOTE(review): absolute path and vocabulary size are hard-coded.
        pretrained = torch.from_numpy(
            np.load('/data1/private/zzy/nre/data/pretrain/glove/word2vec.npy'))
        # Two extra rows follow the pretrained ones: a scaled random row and
        # an all-zero row (presumably unknown / blank tokens -- TODO confirm).
        random_row = torch.randn(1, word_dim) / math.sqrt(word_dim)
        zero_row = torch.zeros(1, word_dim)
        table = torch.cat([pretrained, random_row, zero_row], 0)
        self.word_embedding.weight.data.copy_(table)

        seq_len = config.getint("data", "max_seq_length")
        pos_dim = 5
        n_hidden = 230

        # Position embeddings
        self.pos1_embedding = nn.Embedding(2 * seq_len, pos_dim, padding_idx=0)
        self.pos2_embedding = nn.Embedding(2 * seq_len, pos_dim, padding_idx=0)

        self.embedding_dim = word_dim + 2 * pos_dim
        # Each direction gets half the width so the concatenation is n_hidden.
        self.lstm = nn.LSTM(self.embedding_dim, n_hidden // 2, bidirectional=True)

        self.drop = nn.Dropout(0.5)

        self.input_dim = config.getint("model", "input_dim")
        self.output_dim = config.getint("model", "output_dim")
        # The classifier sees the pooled LSTM features plus the entity vector.
        self.fc = nn.Linear(n_hidden + self.input_dim, self.output_dim)

        self.criterion = nn.CrossEntropyLoss()
        self.accuracy_function = init_accuracy_function(config, *args, **params)

    def init_multi_gpu(self, device, config, *args, **params):
        """Do nothing: this model performs no multi-GPU setup."""

    def forward(self, data, config, gpu_list, acc_result, mode):
        """Encode the sentence, append ``data['ent_emb']`` and classify.

        Returns ``{"loss", "acc_result"}`` when ``data`` has a ``"label"``
        entry, otherwise an empty dict.
        """
        tokens, first_pos, second_pos = data['token'], data['pos1'], data['pos2']
        entity_vec = data['ent_emb']

        pieces = [
            self.word_embedding(tokens),
            self.pos1_embedding(first_pos),
            self.pos2_embedding(second_pos),
        ]
        embedded = torch.cat(pieces, 2)  # (B, L, EMBED)
        # The LSTM is sequence-first, so feed it (L, B, EMBED).
        outputs, _ = self.lstm(embedded.transpose(0, 1))  # (L, B, H)
        # Max over the sequence axis, taken directly on the (L, B, H) layout.
        pooled = self.drop(torch.relu(outputs).max(0)[0])  # (B, H)

        # Dropout is applied to the LSTM features only, not the entity vector.
        fused = torch.cat([pooled, entity_vec], 1)

        logits = self.fc(fused)
        logits = logits.view(logits.size(0), -1)

        if "label" not in data:
            return {}

        gold = data["label"]
        return {
            "loss": self.criterion(logits, gold),
            "acc_result": self.accuracy_function(logits, gold, config, acc_result),
        }


class LstmKARE(nn.Module):
    """Bidirectional-LSTM relation classifier that pools LSTM outputs with
    ``KnowledgeAtt``, once per half of ``data['ent_emb']``."""

    def __init__(self, config, gpu_list, *args, **params):
        """Build embeddings, LSTM, attention, classifier and loss.

        Reads ``data.max_seq_length``, ``model.input_dim`` and
        ``model.output_dim`` from ``config``. ``gpu_list`` is not used here.
        """
        super().__init__()

        word_dim = 50
        self.word_embedding = nn.Embedding(400002, word_dim)
        # NOTE(review): absolute path and vocabulary size are hard-coded.
        pretrained = torch.from_numpy(
            np.load('/data1/private/zzy/nre/data/pretrain/glove/word2vec.npy'))
        # Two extra rows follow the pretrained ones: a scaled random row and
        # an all-zero row (presumably unknown / blank tokens -- TODO confirm).
        random_row = torch.randn(1, word_dim) / math.sqrt(word_dim)
        zero_row = torch.zeros(1, word_dim)
        table = torch.cat([pretrained, random_row, zero_row], 0)
        self.word_embedding.weight.data.copy_(table)

        seq_len = config.getint("data", "max_seq_length")
        pos_dim = 5
        n_hidden = 230

        # Position embeddings
        self.pos1_embedding = nn.Embedding(2 * seq_len, pos_dim, padding_idx=0)
        self.pos2_embedding = nn.Embedding(2 * seq_len, pos_dim, padding_idx=0)

        self.embedding_dim = word_dim + 2 * pos_dim
        # Each direction gets half the width so the concatenation is n_hidden.
        self.lstm = nn.LSTM(self.embedding_dim, n_hidden // 2, bidirectional=True)

        self.drop = nn.Dropout(0.5)

        self.input_dim = config.getint("model", "input_dim")
        self.output_dim = config.getint("model", "output_dim")

        # One attention module is shared by both halves of the entity vector.
        self.att = KnowledgeAtt(n_hidden, self.input_dim)
        self.fc = nn.Linear(n_hidden, self.output_dim)

        self.criterion = nn.CrossEntropyLoss()
        self.accuracy_function = init_accuracy_function(config, *args, **params)

    def init_multi_gpu(self, device, config, *args, **params):
        """Do nothing: this model performs no multi-GPU setup."""

    def forward(self, data, config, gpu_list, acc_result, mode):
        """Encode the sentence, attend with each entity half and classify.

        Returns ``{"loss", "acc_result"}`` when ``data`` has a ``"label"``
        entry, otherwise an empty dict.
        """
        tokens, first_pos, second_pos = data['token'], data['pos1'], data['pos2']
        entity_vec = data['ent_emb']

        pieces = [
            self.word_embedding(tokens),
            self.pos1_embedding(first_pos),
            self.pos2_embedding(second_pos),
        ]
        embedded = torch.cat(pieces, 2)  # (B, L, EMBED)
        # The LSTM is sequence-first, so feed it (L, B, EMBED).
        outputs, _ = self.lstm(embedded.transpose(0, 1))  # (L, B, H)
        # The attention module expects channel-first features.
        feats = torch.relu(outputs).permute(1, 2, 0)  # (B, H, L)

        # Positions holding token id 400001 (the last embedding row) are
        # penalised inside the attention.
        blank_mask = tokens == 400001
        half = self.input_dim // 2
        head_repr, _ = self.att(feats, entity_vec[:, :half], blank_mask)
        tail_repr, _ = self.att(feats, entity_vec[:, half:], blank_mask)
        fused = self.drop(torch.cat([head_repr, tail_repr], -1))

        logits = self.fc(fused)
        logits = logits.view(logits.size(0), -1)

        if "label" not in data:
            return {}

        gold = data["label"]
        return {
            "loss": self.criterion(logits, gold),
            "acc_result": self.accuracy_function(logits, gold, config, acc_result),
        }


class BertRE(nn.Module):
    """BERT relation classifier: a linear layer over BERT's second output."""

    def __init__(self, config, gpu_list, *args, **params):
        """Load the pretrained BERT model and create the classifier.

        Reads ``model.output_dim`` and ``model.bert_path`` from ``config``.
        ``gpu_list`` is accepted but not used here.
        """
        super().__init__()

        self.output_dim = config.getint("model", "output_dim")
        self.bert = BertModel.from_pretrained(config.get("model", "bert_path"))
        self.fc = nn.Linear(768, self.output_dim)

        self.criterion = nn.CrossEntropyLoss()
        self.accuracy_function = init_accuracy_function(config, *args, **params)

    def init_multi_gpu(self, device, config, *args, **params):
        """Do nothing: this model performs no multi-GPU setup."""

    def forward(self, data, config, gpu_list, acc_result, mode):
        """Run BERT on ``data['input']`` and classify its second output.

        Returns ``{"loss", "acc_result"}`` when ``data`` has a ``"label"``
        entry, otherwise an empty dict.
        """
        input_ids = data['input']

        # Without an attention mask BERT attends to padding positions too.
        # Id 0 is taken as padding, matching the ``data['input'] == 0`` mask
        # that BertKARE in this file uses -- TODO confirm against the
        # formatter that builds ``data['input']``.
        attention_mask = (input_ids != 0).long()
        _, pooled = self.bert(input_ids, attention_mask=attention_mask)
        pooled = pooled.view(pooled.size(0), -1)

        logits = self.fc(pooled)
        logits = logits.view(logits.size(0), -1)

        if "label" not in data:
            return {}

        gold = data["label"]
        return {
            "loss": self.criterion(logits, gold),
            "acc_result": self.accuracy_function(logits, gold, config, acc_result),
        }


class BertKFRE(nn.Module):
    """BERT relation classifier that concatenates BERT's second output with a
    transformed ``data['ent_emb']`` before the final linear layer."""

    def __init__(self, config, gpu_list, *args, **params):
        """Load the pretrained BERT model and create the fusion layers.

        Reads ``model.input_dim``, ``model.output_dim`` and
        ``model.bert_path`` from ``config``. ``gpu_list`` is not used here.
        """
        super().__init__()

        self.input_dim = config.getint("model", "input_dim")
        self.output_dim = config.getint("model", "output_dim")
        self.bert = BertModel.from_pretrained(config.get("model", "bert_path"))
        self.fc = nn.Linear(768 + self.input_dim, self.output_dim)

        self.dropout = nn.Dropout(0.5)
        self.fc_ent = nn.Linear(self.input_dim, self.input_dim)

        self.criterion = nn.CrossEntropyLoss()
        self.accuracy_function = init_accuracy_function(config, *args, **params)

    def init_multi_gpu(self, device, config, *args, **params):
        """Do nothing: this model performs no multi-GPU setup."""

    def forward(self, data, config, gpu_list, acc_result, mode):
        """Fuse BERT's second output with the entity vector and classify.

        Returns ``{"loss", "acc_result"}`` when ``data`` has a ``"label"``
        entry, otherwise an empty dict.
        """
        input_ids = data['input']
        entity_vec = data['ent_emb']

        # Without an attention mask BERT attends to padding positions too.
        # Id 0 is taken as padding, matching the ``data['input'] == 0`` mask
        # that BertKARE in this file uses -- TODO confirm against the
        # formatter that builds ``data['input']``.
        attention_mask = (input_ids != 0).long()
        _, pooled = self.bert(input_ids, attention_mask=attention_mask)
        pooled = pooled.view(pooled.size(0), -1)

        # The entity vector goes through its own linear layer and GELU.
        fused = torch.cat([pooled, gelu(self.fc_ent(entity_vec))], 1)
        fused = self.dropout(fused)

        logits = self.fc(fused)
        logits = logits.view(logits.size(0), -1)

        if "label" not in data:
            return {}

        gold = data["label"]
        return {
            "loss": self.criterion(logits, gold),
            "acc_result": self.accuracy_function(logits, gold, config, acc_result),
        }

class KnowledgeAttBert(nn.Module):
    """Attention over sequence positions, scored against an entity vector.

    Returns the weighted sum of the features without a further projection.
    """

    def __init__(self, hidden_dim, ent_dim):
        """Create ``A``, which maps a vector of size ``ent_dim`` to ``hidden_dim``."""
        super().__init__()
        self.hidden_dim = hidden_dim
        self.A = nn.Linear(ent_dim, hidden_dim, bias=True)
        self.dropout = nn.Dropout(0.5)

    def forward(self, x, ke, mask):
        """Pool ``x`` over positions with weights derived from ``ke``.

        Args:
            x: features laid out as (B, H, L).
            ke: entity vectors laid out as (B, E).
            mask: index/mask into the (B, L) score matrix selecting the
                positions whose score gets a -10000 penalty before softmax.

        Returns:
            A pair ``(pooled, scores)``: the weighted sum over positions and
            the softmax weights.
        """
        seq = x.permute(0, 2, 1).contiguous()  # (B, L, H)
        query = self.A(self.dropout(ke)).unsqueeze(-2)  # (B, 1, H)
        logits = (seq * query).sum(-1)  # (B, L)

        # Additive penalty pushes the selected positions to ~0 weight.
        penalty = torch.zeros_like(logits)
        penalty[mask] = -10000
        scores = torch.softmax(logits + penalty, dim=-1)

        pooled = (seq * scores.unsqueeze(-1)).sum(-2)  # (B, H)
        return pooled, scores

class BertKARE(nn.Module):
    """BERT relation classifier that pools BERT's first (per-token) output
    with two ``KnowledgeAttBert`` modules, one per half of ``data['ent_emb']``."""

    def __init__(self, config, gpu_list, *args, **params):
        """Load the pretrained BERT model and create attention and classifier.

        Reads ``model.input_dim``, ``model.output_dim`` and
        ``model.bert_path`` from ``config``. ``gpu_list`` is not used here.
        """
        super().__init__()

        self.input_dim = config.getint("model", "input_dim")
        self.output_dim = config.getint("model", "output_dim")
        self.bert = BertModel.from_pretrained(config.get("model", "bert_path"))
        # The classifier sees the two attention-pooled vectors concatenated.
        self.fc = nn.Linear(768 * 2, self.output_dim)

        self.att_h = KnowledgeAttBert(768, self.input_dim // 2)
        self.att_t = KnowledgeAttBert(768, self.input_dim // 2)

        self.dropout = nn.Dropout(0.5)

        self.criterion = nn.CrossEntropyLoss()
        self.accuracy_function = init_accuracy_function(config, *args, **params)

    def init_multi_gpu(self, device, config, *args, **params):
        """Do nothing: this model performs no multi-GPU setup."""

    def forward(self, data, config, gpu_list, acc_result, mode):
        """Run BERT, attend with each entity half and classify.

        Returns ``{"loss", "acc_result"}`` when ``data`` has a ``"label"``
        entry, otherwise an empty dict.
        """
        input_ids = data['input']
        entity_vec = data['ent_emb']

        # Positions with id 0 are excluded from the attention pooling below;
        # the same positions are now also hidden from BERT's self-attention,
        # which previously attended to them because no mask was passed.
        pad_mask = input_ids == 0
        attention_mask = (input_ids != 0).long()

        seq_out, _ = self.bert(input_ids, attention_mask=attention_mask)
        seq_out = seq_out.transpose(1, 2)  # channel-first for the attention

        half = self.input_dim // 2
        head_repr, _ = self.att_h(seq_out, entity_vec[:, :half], pad_mask)
        tail_repr, _ = self.att_t(seq_out, entity_vec[:, half:], pad_mask)
        fused = self.dropout(torch.cat([head_repr, tail_repr], 1))

        logits = self.fc(fused)
        logits = logits.view(logits.size(0), -1)

        if "label" not in data:
            return {}

        gold = data["label"]
        return {
            "loss": self.criterion(logits, gold),
            "acc_result": self.accuracy_function(logits, gold, config, acc_result),
        }
