import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from pytorch_transformers import BertModel

from tools.accuracy_init import init_accuracy_function

def gelu(x):
    """Exact (erf-based) GELU activation: ``x * Phi(x)``.

    Note: OpenAI GPT uses a tanh approximation instead, which gives slightly
    different results:
        0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
    """
    half_x = x * 0.5
    erf_term = torch.erf(x / math.sqrt(2.0))
    return half_x * (1.0 + erf_term)


class BasicET(nn.Module):
    """Linear classifier: a single bias-free projection of ``data['input']``.

    Layer sizes come from ``input_dim`` / ``output_dim`` in the ``[model]``
    config section. The loss is ``nn.BCEWithLogitsLoss``.
    """

    def __init__(self, config, gpu_list, *args, **params):
        """Create the projection, the loss and the accuracy callback.

        ``gpu_list`` is accepted but unused; ``*args`` / ``**params`` are
        forwarded to ``init_accuracy_function``.
        """
        super().__init__()

        in_features = config.getint("model", "input_dim")
        out_features = config.getint("model", "output_dim")
        self.input_dim = in_features
        self.output_dim = out_features
        self.fc = nn.Linear(in_features, out_features, bias=False)

        self.criterion = nn.BCEWithLogitsLoss()
        self.accuracy_function = init_accuracy_function(config, *args, **params)

    def init_multi_gpu(self, device, config, *args, **params):
        """Multi-GPU hook; intentionally does nothing for this model."""

    def forward(self, data, config, gpu_list, acc_result, mode):
        """Score ``data['input']`` and, when a label is present, compute the loss.

        Returns ``{"loss", "acc_result"}`` if ``data`` has a ``"label"`` entry,
        otherwise an empty dict. ``gpu_list`` and ``mode`` are unused.
        """
        logits = self.fc(data['input'])
        # Collapse everything after the first dimension into one axis.
        logits = logits.view(logits.size(0), -1)

        if "label" not in data.keys():
            return {}

        target = data["label"]
        loss = self.criterion(logits, target.float())
        acc_result = self.accuracy_function(logits, target, config, acc_result)
        return {"loss": loss, "acc_result": acc_result}

class CnnET(nn.Module):
    """CNN encoder classifier.

    Word and position embeddings are concatenated, run through a 1-D
    convolution and max-pooled over the sequence. That vector is concatenated
    with the sum of word vectors at positions where ``pos1 == max_len``
    (presumably the mention tokens) and fed to a bias-free linear layer.
    """

    def __init__(self, config, gpu_list, *args, **params):
        """Build layers and load the frozen word-embedding table.

        NOTE(review): the GloVe matrix is read from a hard-coded absolute path.
        ``gpu_list`` is unused; ``*args`` / ``**params`` go to
        ``init_accuracy_function``.
        """
        super().__init__()

        word_dim = 50
        self.word_embedding = nn.Embedding(400002, word_dim)
        glove = torch.from_numpy(
            np.load('/data1/private/zzy/nre/data/pretrain/glove/word2vec.npy'))
        # Two extra rows follow the pretrained ones: a small random row and an
        # all-zero row (called unk / blk in the original code).
        unk_row = torch.randn(1, word_dim) / math.sqrt(word_dim)
        blank_row = torch.zeros(1, word_dim)
        table = torch.cat([glove, unk_row, blank_row], 0)
        with torch.no_grad():
            self.word_embedding.weight.copy_(table)
        # Word vectors are kept fixed during training.
        self.word_embedding.weight.requires_grad_(False)

        self.max_len = config.getint("data", "max_seq_length")
        pos_dim = 10
        hidden_size = 100

        # Position embedding
        self.pos1_embedding = nn.Embedding(2 * self.max_len, pos_dim, padding_idx=0)

        self.embedding_dim = word_dim + pos_dim
        # CNN
        self.conv = nn.Conv1d(self.embedding_dim, hidden_size, kernel_size=3, padding=1)

        self.drop = nn.Dropout(0.5)
        self.mention_dropout = nn.Dropout(0.5)

        self.output_dim = config.getint("model", "output_dim")
        self.fc = nn.Linear(hidden_size + word_dim, self.output_dim, bias=False)

        self.criterion = nn.BCEWithLogitsLoss()
        self.accuracy_function = init_accuracy_function(config, *args, **params)

    def init_multi_gpu(self, device, config, *args, **params):
        """Multi-GPU hook; intentionally does nothing for this model."""

    def forward(self, data, config, gpu_list, acc_result, mode):
        """Encode ``data['token']`` / ``data['pos1']`` and optionally compute the loss.

        Returns ``{"loss", "acc_result"}`` if ``data`` has a ``"label"`` entry,
        otherwise an empty dict. ``gpu_list`` and ``mode`` are unused.
        """
        token = data['token']
        pos1 = data['pos1']

        word_vecs = self.word_embedding(token)
        features = torch.cat([word_vecs, self.pos1_embedding(pos1)], 2)  # (B, L, EMBED)
        features = self.conv(features.permute(0, 2, 1))  # (B, H, L)
        pooled = torch.relu(features).max(dim=-1)[0]  # (B, H)
        pooled = self.drop(pooled)

        # Zero every position whose pos1 differs from max_len, then sum over
        # the sequence so only the remaining positions contribute.
        outside = (pos1 != self.max_len).unsqueeze(-1)
        mention_embed = self.mention_dropout(word_vecs.masked_fill(outside, 0).sum(1))

        logits = self.fc(torch.cat((pooled, mention_embed), 1))
        logits = logits.view(logits.size(0), -1)

        if "label" not in data.keys():
            return {}

        target = data["label"]
        loss = self.criterion(logits, target.float())
        acc_result = self.accuracy_function(logits, target, config, acc_result)
        return {"loss": loss, "acc_result": acc_result}


class CnnKFET(nn.Module):
    """CNN encoder classifier with an entity embedding concatenated in.

    The max-pooled CNN vector, the summed word vectors at positions where
    ``pos1 == max_len`` (presumably the mention tokens) and ``data['ent_emb']``
    are concatenated and fed to a linear layer (with bias).
    """

    def __init__(self, config, gpu_list, *args, **params):
        """Build layers and load the frozen word-embedding table.

        NOTE(review): the GloVe matrix is read from a hard-coded absolute path.
        ``gpu_list`` is unused; ``*args`` / ``**params`` go to
        ``init_accuracy_function``.
        """
        super().__init__()

        word_dim = 50
        self.word_embedding = nn.Embedding(400002, word_dim)
        glove = torch.from_numpy(
            np.load('/data1/private/zzy/nre/data/pretrain/glove/word2vec.npy'))
        # Two extra rows follow the pretrained ones: a small random row and an
        # all-zero row (called unk / blk in the original code).
        unk_row = torch.randn(1, word_dim) / math.sqrt(word_dim)
        blank_row = torch.zeros(1, word_dim)
        table = torch.cat([glove, unk_row, blank_row], 0)
        with torch.no_grad():
            self.word_embedding.weight.copy_(table)
        # Word vectors are kept fixed during training.
        self.word_embedding.weight.requires_grad_(False)

        self.max_len = config.getint("data", "max_seq_length")
        pos_dim = 10
        hidden_size = 100

        # Position embedding
        self.pos1_embedding = nn.Embedding(2 * self.max_len, pos_dim, padding_idx=0)

        self.embedding_dim = word_dim + pos_dim
        # CNN
        self.conv = nn.Conv1d(self.embedding_dim, hidden_size, kernel_size=3, padding=1)

        self.mention_dropout = nn.Dropout(0.5)
        self.drop = nn.Dropout(0.5)

        self.input_dim = config.getint("model", "input_dim")
        self.output_dim = config.getint("model", "output_dim")
        fused_dim = hidden_size + self.input_dim + word_dim
        self.fc = nn.Linear(fused_dim, self.output_dim)

        self.criterion = nn.BCEWithLogitsLoss()
        self.accuracy_function = init_accuracy_function(config, *args, **params)

    def init_multi_gpu(self, device, config, *args, **params):
        """Multi-GPU hook; intentionally does nothing for this model."""

    def forward(self, data, config, gpu_list, acc_result, mode):
        """Encode the sentence, append ``data['ent_emb']`` and optionally compute the loss.

        Returns ``{"loss", "acc_result"}`` if ``data`` has a ``"label"`` entry,
        otherwise an empty dict. ``gpu_list`` and ``mode`` are unused.
        """
        token = data['token']
        pos1 = data['pos1']
        ent_emb = data['ent_emb']

        word_vecs = self.word_embedding(token)
        features = torch.cat([word_vecs, self.pos1_embedding(pos1)], 2)  # (B, L, EMBED)
        features = self.conv(features.permute(0, 2, 1))  # (B, H, L)
        pooled = torch.relu(features).max(dim=-1)[0]  # (B, H)
        pooled = self.drop(pooled)

        # Zero every position whose pos1 differs from max_len, then sum over
        # the sequence so only the remaining positions contribute.
        outside = (pos1 != self.max_len).unsqueeze(-1)
        mention_embed = self.mention_dropout(word_vecs.masked_fill(outside, 0).sum(1))

        # Order matters for the linear layer: sentence, mention, entity.
        fused = torch.cat((pooled, mention_embed, ent_emb), 1)

        logits = self.fc(fused)
        logits = logits.view(logits.size(0), -1)

        if "label" not in data.keys():
            return {}

        target = data["label"]
        loss = self.criterion(logits, target.float())
        acc_result = self.accuracy_function(logits, target, config, acc_result)
        return {"loss": loss, "acc_result": acc_result}

class KnowledgeAtt(nn.Module):
    """Attention over sequence positions, queried by a projected entity vector.

    The entity vector is passed through dropout and mapped to ``hidden_dim``
    by ``A``. Each position is scored by its dot product with that projection.
    The score-weighted sum of positions goes through ``linear`` and a ReLU.
    """

    def __init__(self, hidden_dim, ent_dim):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.A = nn.Linear(ent_dim, hidden_dim, bias=True)
        self.linear = nn.Linear(hidden_dim, hidden_dim, bias=True)
        self.dropout = nn.Dropout(0.5)

    def forward(self, x, ke, mask):
        """Attend over ``x`` using ``ke`` as the query.

        Args:
            x: sequence features, laid out as (B, H, L).
            ke: entity vector, laid out as (B, E).
            mask: selector over the (B, L) score matrix; selected positions get
                -10000 added to their logit before the softmax.

        Returns:
            ``(relu(linear(context)), scores)`` where ``scores`` are the
            softmax weights over the last axis of the logits.
        """
        seq = x.permute(0, 2, 1).contiguous()  # (B, L, H)
        query = self.A(self.dropout(ke)).unsqueeze(-2)  # broadcast over L
        logits = (seq * query).sum(-1)

        # Additive mask: a large negative offset pushes masked weights to ~0.
        penalty = torch.zeros_like(logits).masked_fill_(mask, -10000)
        scores = F.softmax(logits + penalty, dim=-1)

        context = (seq * scores.unsqueeze(-1)).sum(-2)
        return torch.relu(self.linear(context)), scores



class CnnKAET(nn.Module):
    """CNN encoder classifier pooled by entity-conditioned attention.

    CNN features are pooled by ``KnowledgeAtt`` using ``data['ent_emb']`` as
    the query, with positions whose token id is 400001 masked out. The pooled
    vector is concatenated with the summed word vectors at positions where
    ``pos1 == max_len`` (presumably the mention tokens) and fed to a linear
    layer.
    """

    def __init__(self, config, gpu_list, *args, **params):
        """Build layers and load the frozen word-embedding table.

        NOTE(review): the GloVe matrix is read from a hard-coded absolute path.
        ``gpu_list`` is unused; ``*args`` / ``**params`` go to
        ``init_accuracy_function``.
        """
        super().__init__()

        word_dim = 50
        self.word_embedding = nn.Embedding(400002, word_dim)
        glove = torch.from_numpy(
            np.load('/data1/private/zzy/nre/data/pretrain/glove/word2vec.npy'))
        # Two extra rows follow the pretrained ones: a small random row and an
        # all-zero row (called unk / blk in the original code).
        unk_row = torch.randn(1, word_dim) / math.sqrt(word_dim)
        blank_row = torch.zeros(1, word_dim)
        table = torch.cat([glove, unk_row, blank_row], 0)
        with torch.no_grad():
            self.word_embedding.weight.copy_(table)
        # Word vectors are kept fixed during training.
        self.word_embedding.weight.requires_grad_(False)

        self.max_len = config.getint("data", "max_seq_length")
        pos_dim = 10
        hidden_size = 100

        # Position embedding
        self.pos1_embedding = nn.Embedding(2 * self.max_len, pos_dim, padding_idx=0)

        self.embedding_dim = word_dim + pos_dim
        # CNN
        self.conv = nn.Conv1d(self.embedding_dim, hidden_size, kernel_size=3, padding=1)

        self.drop = nn.Dropout(0.5)
        self.mention_dropout = nn.Dropout(0.5)

        self.input_dim = config.getint("model", "input_dim")
        self.output_dim = config.getint("model", "output_dim")

        self.att = KnowledgeAtt(hidden_size, self.input_dim)
        self.fc = nn.Linear(hidden_size + word_dim, self.output_dim)

        self.criterion = nn.BCEWithLogitsLoss()
        self.accuracy_function = init_accuracy_function(config, *args, **params)

    def init_multi_gpu(self, device, config, *args, **params):
        """Multi-GPU hook; intentionally does nothing for this model."""

    def forward(self, data, config, gpu_list, acc_result, mode):
        """Encode the sentence with attention pooling and optionally compute the loss.

        Returns ``{"loss", "acc_result"}`` if ``data`` has a ``"label"`` entry,
        otherwise an empty dict. ``gpu_list`` and ``mode`` are unused.
        """
        token = data['token']
        pos1 = data['pos1']
        ent_emb = data['ent_emb']

        word_vecs = self.word_embedding(token)
        features = torch.cat([word_vecs, self.pos1_embedding(pos1)], 2)  # (B, L, EMBED)
        features = torch.relu(self.conv(features.permute(0, 2, 1)))  # (B, H, L)

        # Token id 400001 is the last row of the embedding table (the zero row);
        # those positions are excluded from attention.
        blank_positions = token == 400001
        attended = self.att(features, ent_emb, blank_positions)[0]
        attended = self.drop(attended)

        # Zero every position whose pos1 differs from max_len, then sum over
        # the sequence so only the remaining positions contribute.
        outside = (pos1 != self.max_len).unsqueeze(-1)
        mention_embed = self.mention_dropout(word_vecs.masked_fill(outside, 0).sum(1))

        logits = self.fc(torch.cat((attended, mention_embed), 1))
        logits = logits.view(logits.size(0), -1)

        if "label" not in data.keys():
            return {}

        target = data["label"]
        loss = self.criterion(logits, target.float())
        acc_result = self.accuracy_function(logits, target, config, acc_result)
        return {"loss": loss, "acc_result": acc_result}

class LstmET(nn.Module):
    """Bidirectional-LSTM encoder classifier.

    Word and position embeddings are concatenated, run through a BiLSTM, passed
    through ReLU and max-pooled over time. That vector is concatenated with the
    sum of word vectors at positions where ``pos1 == max_len`` (presumably the
    mention tokens) and fed to a bias-free linear layer.
    """

    def __init__(self, config, gpu_list, *args, **params):
        """Build layers and load the frozen word-embedding table.

        NOTE(review): the GloVe matrix is read from a hard-coded absolute path.
        ``gpu_list`` is unused; ``*args`` / ``**params`` go to
        ``init_accuracy_function``.
        """
        super().__init__()

        word_dim = 50
        self.word_embedding = nn.Embedding(400002, word_dim)
        glove = torch.from_numpy(
            np.load('/data1/private/zzy/nre/data/pretrain/glove/word2vec.npy'))
        # Two extra rows follow the pretrained ones: a small random row and an
        # all-zero row (called unk / blk in the original code).
        unk_row = torch.randn(1, word_dim) / math.sqrt(word_dim)
        blank_row = torch.zeros(1, word_dim)
        table = torch.cat([glove, unk_row, blank_row], 0)
        with torch.no_grad():
            self.word_embedding.weight.copy_(table)
        # Word vectors are kept fixed during training.
        self.word_embedding.weight.requires_grad_(False)

        self.max_len = config.getint("data", "max_seq_length")
        pos_dim = 10
        hidden_size = 100

        # Position embedding
        self.pos1_embedding = nn.Embedding(2 * self.max_len, pos_dim, padding_idx=0)

        self.embedding_dim = word_dim + pos_dim
        # LSTM: each direction gets half of hidden_size, so the concatenated
        # output is hidden_size wide.
        self.lstm = nn.LSTM(self.embedding_dim, hidden_size // 2, bidirectional=True)

        self.mention_dropout = nn.Dropout(0.5)
        self.drop = nn.Dropout(0.5)

        self.output_dim = config.getint("model", "output_dim")
        self.fc = nn.Linear(hidden_size + word_dim, self.output_dim, bias=False)

        self.criterion = nn.BCEWithLogitsLoss()
        self.accuracy_function = init_accuracy_function(config, *args, **params)

    def init_multi_gpu(self, device, config, *args, **params):
        """Multi-GPU hook; intentionally does nothing for this model."""

    def forward(self, data, config, gpu_list, acc_result, mode):
        """Encode ``data['token']`` / ``data['pos1']`` and optionally compute the loss.

        Returns ``{"loss", "acc_result"}`` if ``data`` has a ``"label"`` entry,
        otherwise an empty dict. ``gpu_list`` and ``mode`` are unused.
        """
        token = data['token']
        pos1 = data['pos1']

        word_vecs = self.word_embedding(token)
        inputs = torch.cat([word_vecs, self.pos1_embedding(pos1)], 2)  # (B, L, EMBED)
        # The LSTM is sequence-first, so feed it (L, B, EMBED).
        states = self.lstm(inputs.transpose(0, 1))[0]  # (L, B, H)
        # Max over the time axis gives one vector per batch item: (B, H).
        pooled = torch.relu(states).max(dim=0)[0]
        pooled = self.drop(pooled)

        # Zero every position whose pos1 differs from max_len, then sum over
        # the sequence so only the remaining positions contribute.
        outside = (pos1 != self.max_len).unsqueeze(-1)
        mention_embed = self.mention_dropout(word_vecs.masked_fill(outside, 0).sum(1))

        logits = self.fc(torch.cat((pooled, mention_embed), 1))
        logits = logits.view(logits.size(0), -1)

        if "label" not in data.keys():
            return {}

        target = data["label"]
        loss = self.criterion(logits, target.float())
        acc_result = self.accuracy_function(logits, target, config, acc_result)
        return {"loss": loss, "acc_result": acc_result}


class LstmKFET(nn.Module):
    """Bidirectional-LSTM encoder classifier with an entity embedding concatenated in.

    The max-pooled BiLSTM vector, the summed word vectors at positions where
    ``pos1 == max_len`` (presumably the mention tokens) and ``data['ent_emb']``
    are concatenated and fed to a bias-free linear layer.
    """

    def __init__(self, config, gpu_list, *args, **params):
        """Build layers and load the frozen word-embedding table.

        NOTE(review): the GloVe matrix is read from a hard-coded absolute path.
        ``gpu_list`` is unused; ``*args`` / ``**params`` go to
        ``init_accuracy_function``.
        """
        super().__init__()

        word_dim = 50
        self.word_embedding = nn.Embedding(400002, word_dim)
        glove = torch.from_numpy(
            np.load('/data1/private/zzy/nre/data/pretrain/glove/word2vec.npy'))
        # Two extra rows follow the pretrained ones: a small random row and an
        # all-zero row (called unk / blk in the original code).
        unk_row = torch.randn(1, word_dim) / math.sqrt(word_dim)
        blank_row = torch.zeros(1, word_dim)
        table = torch.cat([glove, unk_row, blank_row], 0)
        with torch.no_grad():
            self.word_embedding.weight.copy_(table)
        # Word vectors are kept fixed during training.
        self.word_embedding.weight.requires_grad_(False)

        self.max_len = config.getint("data", "max_seq_length")
        pos_dim = 10
        hidden_size = 100

        # Position embedding
        self.pos1_embedding = nn.Embedding(2 * self.max_len, pos_dim, padding_idx=0)

        self.embedding_dim = word_dim + pos_dim
        # LSTM: each direction gets half of hidden_size, so the concatenated
        # output is hidden_size wide.
        self.lstm = nn.LSTM(self.embedding_dim, hidden_size // 2, bidirectional=True)

        self.mention_dropout = nn.Dropout(0.5)
        self.drop = nn.Dropout(0.5)

        self.input_dim = config.getint("model", "input_dim")
        self.output_dim = config.getint("model", "output_dim")
        fused_dim = hidden_size + word_dim + self.input_dim
        self.fc = nn.Linear(fused_dim, self.output_dim, bias=False)

        self.criterion = nn.BCEWithLogitsLoss()
        self.accuracy_function = init_accuracy_function(config, *args, **params)

    def init_multi_gpu(self, device, config, *args, **params):
        """Multi-GPU hook; intentionally does nothing for this model."""

    def forward(self, data, config, gpu_list, acc_result, mode):
        """Encode the sentence, append ``data['ent_emb']`` and optionally compute the loss.

        Returns ``{"loss", "acc_result"}`` if ``data`` has a ``"label"`` entry,
        otherwise an empty dict. ``gpu_list`` and ``mode`` are unused.
        """
        token = data['token']
        pos1 = data['pos1']
        ent_emb = data['ent_emb']

        word_vecs = self.word_embedding(token)
        inputs = torch.cat([word_vecs, self.pos1_embedding(pos1)], 2)  # (B, L, EMBED)
        # The LSTM is sequence-first, so feed it (L, B, EMBED).
        states = self.lstm(inputs.transpose(0, 1))[0]  # (L, B, H)
        # Max over the time axis gives one vector per batch item: (B, H).
        pooled = torch.relu(states).max(dim=0)[0]
        pooled = self.drop(pooled)

        # Zero every position whose pos1 differs from max_len, then sum over
        # the sequence so only the remaining positions contribute.
        outside = (pos1 != self.max_len).unsqueeze(-1)
        mention_embed = self.mention_dropout(word_vecs.masked_fill(outside, 0).sum(1))

        # Order matters for the linear layer: sentence, mention, entity.
        fused = torch.cat((pooled, mention_embed, ent_emb), 1)

        logits = self.fc(fused)
        logits = logits.view(logits.size(0), -1)

        if "label" not in data.keys():
            return {}

        target = data["label"]
        loss = self.criterion(logits, target.float())
        acc_result = self.accuracy_function(logits, target, config, acc_result)
        return {"loss": loss, "acc_result": acc_result}


class LstmKAET(nn.Module):
    """Bidirectional-LSTM encoder classifier pooled by entity-conditioned attention.

    BiLSTM states are pooled by ``KnowledgeAtt`` using ``data['ent_emb']`` as
    the query, with positions whose token id is 400001 masked out. The pooled
    vector is concatenated with the summed word vectors at positions where
    ``pos1 == max_len`` (presumably the mention tokens) and fed to a linear
    layer.
    """

    def __init__(self, config, gpu_list, *args, **params):
        """Build layers and load the frozen word-embedding table.

        NOTE(review): the GloVe matrix is read from a hard-coded absolute path.
        ``gpu_list`` is unused; ``*args`` / ``**params`` go to
        ``init_accuracy_function``.
        """
        super().__init__()

        word_dim = 50
        self.word_embedding = nn.Embedding(400002, word_dim)
        glove = torch.from_numpy(
            np.load('/data1/private/zzy/nre/data/pretrain/glove/word2vec.npy'))
        # Two extra rows follow the pretrained ones: a small random row and an
        # all-zero row (called unk / blk in the original code).
        unk_row = torch.randn(1, word_dim) / math.sqrt(word_dim)
        blank_row = torch.zeros(1, word_dim)
        table = torch.cat([glove, unk_row, blank_row], 0)
        with torch.no_grad():
            self.word_embedding.weight.copy_(table)
        # Word vectors are kept fixed during training.
        self.word_embedding.weight.requires_grad_(False)

        self.max_len = config.getint("data", "max_seq_length")
        pos_dim = 10
        hidden_size = 100

        # Position embedding
        self.pos1_embedding = nn.Embedding(2 * self.max_len, pos_dim, padding_idx=0)

        self.embedding_dim = word_dim + pos_dim
        # LSTM: each direction gets half of hidden_size, so the concatenated
        # output is hidden_size wide.
        self.lstm = nn.LSTM(self.embedding_dim, hidden_size // 2, bidirectional=True)

        self.drop = nn.Dropout(0.5)
        self.mention_dropout = nn.Dropout(0.5)

        self.input_dim = config.getint("model", "input_dim")
        self.output_dim = config.getint("model", "output_dim")

        self.att = KnowledgeAtt(hidden_size, self.input_dim)
        self.fc = nn.Linear(hidden_size + word_dim, self.output_dim)

        self.criterion = nn.BCEWithLogitsLoss()
        self.accuracy_function = init_accuracy_function(config, *args, **params)

    def init_multi_gpu(self, device, config, *args, **params):
        """Multi-GPU hook; intentionally does nothing for this model."""

    def forward(self, data, config, gpu_list, acc_result, mode):
        """Encode the sentence with attention pooling and optionally compute the loss.

        Returns ``{"loss", "acc_result"}`` if ``data`` has a ``"label"`` entry,
        otherwise an empty dict. ``gpu_list`` and ``mode`` are unused.
        """
        token = data['token']
        pos1 = data['pos1']
        ent_emb = data['ent_emb']

        word_vecs = self.word_embedding(token)
        inputs = torch.cat([word_vecs, self.pos1_embedding(pos1)], 2)  # (B, L, EMBED)
        # The LSTM is sequence-first, so feed it (L, B, EMBED).
        states = torch.relu(self.lstm(inputs.transpose(0, 1))[0])  # (L, B, H)
        # The attention module expects (B, H, L).
        states = states.permute(1, 2, 0)

        # Token id 400001 is the last row of the embedding table (the zero row);
        # those positions are excluded from attention.
        blank_positions = token == 400001
        attended = self.att(states, ent_emb, blank_positions)[0]
        attended = self.drop(attended)

        # Zero every position whose pos1 differs from max_len, then sum over
        # the sequence so only the remaining positions contribute.
        outside = (pos1 != self.max_len).unsqueeze(-1)
        mention_embed = self.mention_dropout(word_vecs.masked_fill(outside, 0).sum(1))

        logits = self.fc(torch.cat((attended, mention_embed), 1))
        logits = logits.view(logits.size(0), -1)

        if "label" not in data.keys():
            return {}

        target = data["label"]
        loss = self.criterion(logits, target.float())
        acc_result = self.accuracy_function(logits, target, config, acc_result)
        return {"loss": loss, "acc_result": acc_result}


class BertET(nn.Module):
    """BERT classifier: linear layer over the second output of ``BertModel``.

    The BERT weights are loaded from the ``bert_path`` entry of the
    ``[model]`` config section. The linear layer expects 768 input features.
    """

    def __init__(self, config, gpu_list, *args, **params):
        """Load BERT and build the output layer, loss and accuracy callback.

        ``gpu_list`` is unused; ``*args`` / ``**params`` go to
        ``init_accuracy_function``.
        """
        super().__init__()

        self.output_dim = config.getint("model", "output_dim")
        bert_path = config.get("model", "bert_path")
        self.bert = BertModel.from_pretrained(bert_path)
        self.fc = nn.Linear(768, self.output_dim)

        self.criterion = nn.BCEWithLogitsLoss()
        self.accuracy_function = init_accuracy_function(config, *args, **params)

    def init_multi_gpu(self, device, config, *args, **params):
        """Multi-GPU hook; intentionally does nothing for this model."""

    def forward(self, data, config, gpu_list, acc_result, mode):
        """Run BERT on ``data['input']`` and optionally compute the loss.

        Returns ``{"loss", "acc_result"}`` if ``data`` has a ``"label"`` entry,
        otherwise an empty dict. ``gpu_list`` and ``mode`` are unused.
        """
        # BERT must return exactly two values; only the second is used
        # (presumably the pooled sentence vector -- confirm against BertModel).
        _seq_out, pooled = self.bert(data['input'])
        pooled = pooled.view(pooled.size(0), -1)

        logits = self.fc(pooled)
        logits = logits.view(logits.size(0), -1)

        if "label" not in data.keys():
            return {}

        target = data["label"]
        loss = self.criterion(logits, target.float())
        acc_result = self.accuracy_function(logits, target, config, acc_result)
        return {"loss": loss, "acc_result": acc_result}


class BertKFET(nn.Module):
    """BERT classifier with a transformed entity embedding concatenated in.

    The second output of ``BertModel`` is concatenated with
    ``gelu(fc_ent(ent_emb))``, passed through dropout and then a linear layer.
    """

    def __init__(self, config, gpu_list, *args, **params):
        """Load BERT and build the fusion layers, loss and accuracy callback.

        ``gpu_list`` is unused; ``*args`` / ``**params`` go to
        ``init_accuracy_function``.
        """
        super().__init__()

        self.input_dim = config.getint("model", "input_dim")
        self.output_dim = config.getint("model", "output_dim")
        bert_path = config.get("model", "bert_path")
        self.bert = BertModel.from_pretrained(bert_path)
        self.fc = nn.Linear(768 + self.input_dim, self.output_dim)

        self.dropout = nn.Dropout(0.5)
        # Same-size transform applied to the entity embedding before fusion.
        self.fc_ent = nn.Linear(self.input_dim, self.input_dim)

        self.criterion = nn.BCEWithLogitsLoss()
        self.accuracy_function = init_accuracy_function(config, *args, **params)

    def init_multi_gpu(self, device, config, *args, **params):
        """Multi-GPU hook; intentionally does nothing for this model."""

    def forward(self, data, config, gpu_list, acc_result, mode):
        """Fuse BERT output with ``data['ent_emb']`` and optionally compute the loss.

        Returns ``{"loss", "acc_result"}`` if ``data`` has a ``"label"`` entry,
        otherwise an empty dict. ``gpu_list`` and ``mode`` are unused.
        """
        token_ids = data['input']
        ent_emb = data['ent_emb']

        # BERT must return exactly two values; only the second is used
        # (presumably the pooled sentence vector -- confirm against BertModel).
        _seq_out, pooled = self.bert(token_ids)
        pooled = pooled.view(pooled.size(0), -1)

        entity_feat = gelu(self.fc_ent(ent_emb))
        fused = self.dropout(torch.cat([pooled, entity_feat], 1))

        logits = self.fc(fused)
        logits = logits.view(logits.size(0), -1)

        if "label" not in data.keys():
            return {}

        target = data["label"]
        loss = self.criterion(logits, target.float())
        acc_result = self.accuracy_function(logits, target, config, acc_result)
        return {"loss": loss, "acc_result": acc_result}

class KnowledgeAttBert(nn.Module):
    """Attention over sequence positions, queried by a projected entity vector.

    The entity vector is passed through dropout and mapped to ``hidden_dim``
    by ``A``. Each position is scored by its dot product with that projection.
    The score-weighted sum of positions is returned directly, with no output
    projection.
    """

    def __init__(self, hidden_dim, ent_dim):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.A = nn.Linear(ent_dim, hidden_dim, bias=True)
        self.dropout = nn.Dropout(0.5)

    def forward(self, x, ke, mask):
        """Attend over ``x`` using ``ke`` as the query.

        Args:
            x: sequence features, laid out as (B, H, L).
            ke: entity vector, laid out as (B, E).
            mask: selector over the (B, L) score matrix; selected positions get
                -10000 added to their logit before the softmax.

        Returns:
            ``(context, scores)`` where ``context`` is the score-weighted sum
            over positions and ``scores`` are the softmax weights.
        """
        seq = x.permute(0, 2, 1).contiguous()  # (B, L, H)
        query = self.A(self.dropout(ke)).unsqueeze(-2)  # broadcast over L
        logits = (seq * query).sum(-1)

        # Additive mask: a large negative offset pushes masked weights to ~0.
        penalty = torch.zeros_like(logits).masked_fill_(mask, -10000)
        scores = F.softmax(logits + penalty, dim=-1)

        context = (seq * scores.unsqueeze(-1)).sum(-2)
        return context, scores

class BertKAET(nn.Module):
    """BERT classifier pooled by entity-conditioned attention.

    The first output of ``BertModel`` is pooled by ``KnowledgeAttBert`` using
    ``data['ent_emb']`` as the query, with positions whose input id is 0
    masked out. The pooled vector goes through dropout and a linear layer.
    """

    def __init__(self, config, gpu_list, *args, **params):
        """Load BERT and build the attention, output layer, loss and accuracy callback.

        ``gpu_list`` is unused; ``*args`` / ``**params`` go to
        ``init_accuracy_function``.
        """
        super().__init__()

        self.input_dim = config.getint("model", "input_dim")
        self.output_dim = config.getint("model", "output_dim")
        bert_path = config.get("model", "bert_path")
        self.bert = BertModel.from_pretrained(bert_path)
        self.fc = nn.Linear(768, self.output_dim)

        self.att = KnowledgeAttBert(768, self.input_dim)

        self.dropout = nn.Dropout(0.5)

        self.criterion = nn.BCEWithLogitsLoss()
        self.accuracy_function = init_accuracy_function(config, *args, **params)

    def init_multi_gpu(self, device, config, *args, **params):
        """Multi-GPU hook; intentionally does nothing for this model."""

    def forward(self, data, config, gpu_list, acc_result, mode):
        """Attention-pool BERT states with ``data['ent_emb']`` and optionally compute the loss.

        Returns ``{"loss", "acc_result"}`` if ``data`` has a ``"label"`` entry,
        otherwise an empty dict. ``gpu_list`` and ``mode`` are unused.
        """
        token_ids = data['input']
        ent_emb = data['ent_emb']

        # BERT must return exactly two values; only the first is used
        # (presumably the per-token hidden states -- confirm against BertModel).
        seq_out, _pooled = self.bert(token_ids)
        # The attention module expects the sequence axis last.
        seq_out = seq_out.permute(0, 2, 1)

        # Input id 0 marks positions excluded from attention
        # (presumably padding -- confirm against the tokenizer).
        excluded = token_ids == 0
        attended = self.att(seq_out, ent_emb, excluded)[0]
        attended = self.dropout(attended)

        logits = self.fc(attended)
        logits = logits.view(logits.size(0), -1)

        if "label" not in data.keys():
            return {}

        target = data["label"]
        loss = self.criterion(logits, target.float())
        acc_result = self.accuracy_function(logits, target, config, acc_result)
        return {"loss": loss, "acc_result": acc_result}
