import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from typing import Dict
from pytorch_transformers import BertModel

from tools.accuracy_init import init_accuracy_function

def gelu(x):
    """Apply the exact (erf-based) Gaussian Error Linear Unit element-wise.

    Computes ``x * Phi(x)``, where ``Phi`` is the standard normal CDF written
    in terms of ``erf``. OpenAI GPT uses a tanh approximation instead, which
    gives slightly different results:
        0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
    """
    normal_cdf = 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
    return x * normal_cdf

class KernelMatcher(nn.Module):
    """K-NRM style kernel pooling between two embedded, masked sequences.

    Every (k, v) token pair is scored by cosine similarity. The similarities
    are soft-counted by ``kernel_num`` RBF kernels, and the log soft-counts are
    summed over the ``k`` sequence. The result is one feature per kernel.
    """

    def __init__(
        self,
        embed_dim: int,
        kernel_num: int = 21
    ) -> None:
        super(KernelMatcher, self).__init__()
        self._embed_dim = embed_dim
        self._kernel_num = kernel_num
        # Plain dict (not registered buffers), so state_dict keys stay unchanged.
        self._kernel_params = self.kernel_init(self._kernel_num)
        if torch.cuda.is_available():
            for kernel_param in self._kernel_params:
                self._kernel_params[kernel_param] = self._kernel_params[kernel_param].clone().detach().to('cuda')

    def kernel_init(self, kernel_num: int) -> Dict[str, torch.Tensor]:
        """Build RBF kernel centres ('mus') and widths ('sigmas').

        The first kernel is an exact-match kernel (mu=1, sigma=0.001). The
        remaining ``kernel_num - 1`` kernels start at ``1 - bin_size/2`` and
        are spaced ``bin_size = 2/(kernel_num-1)`` apart, each with sigma 0.1.
        Both tensors are shaped (1, 1, 1, kernel_num) for broadcasting.
        """
        params = {'mus': [1], 'sigmas': [0.001]}

        bin_size = 2.0/(kernel_num-1)
        params['mus'].append(1-bin_size/2)
        for i in range(1, kernel_num-1):
            params['mus'].append(params['mus'][i]-bin_size)
        params['mus'] = torch.Tensor(params['mus']).view(1, 1, 1, kernel_num)

        params['sigmas'] += [0.1]*(kernel_num-1)
        params['sigmas'] = torch.Tensor(params['sigmas']).view(1, 1, 1, kernel_num)
        return params

    def forward(self, k_embed: torch.Tensor, k_mask: torch.Tensor, v_embed: torch.Tensor, v_mask: torch.Tensor) -> torch.Tensor:
        """Compute kernel-pooling features between the ``k`` and ``v`` sequences.

        Args:
            k_embed: (B, Lk, D) embeddings of the first sequence.
            k_mask: (B, Lk) mask, non-zero at real (non-padding) positions.
            v_embed: (B, Lv, D) embeddings of the second sequence.
            v_mask: (B, Lv) mask, non-zero at real (non-padding) positions.

        Returns:
            (B, kernel_num) tensor of log soft-TF features summed over ``k``.
            Padded positions on either side do not contribute.
        """
        k_mask = k_mask.to(k_embed.dtype)
        v_mask = v_mask.to(v_embed.dtype)
        k_embed = k_embed * k_mask.unsqueeze(-1)
        v_embed = v_embed * v_mask.unsqueeze(-1)
        # (B, Lk, Lv): 1 only where both tokens are real.
        k_by_v_mask = torch.bmm(k_mask.unsqueeze(-1), v_mask.unsqueeze(-1).transpose(1, 2))
        k_norm = k_embed / (k_embed.norm(p=2, dim=-1, keepdim=True) + 1e-13)
        v_norm = v_embed / (v_embed.norm(p=2, dim=-1, keepdim=True) + 1e-13)
        inter = (torch.bmm(k_norm, v_norm.transpose(1, 2)) * k_by_v_mask).unsqueeze(-1)

        # Follow the input's device/dtype; .to() is a no-op when they already match.
        mus = self._kernel_params['mus'].to(device=inter.device, dtype=inter.dtype)
        sigmas = self._kernel_params['sigmas'].to(device=inter.device, dtype=inter.dtype)

        kernel_outputs = torch.exp(-((inter - mus) ** 2) / (sigmas ** 2) / 2)
        # Padded pairs have inter == 0, which still activates kernels centred near 0.
        # Zero them so the padding length cannot change the score.
        kernel_outputs = kernel_outputs * k_by_v_mask.unsqueeze(-1)
        soft_tf = kernel_outputs.sum(dim=2).clamp(min=1e-10).log()
        # Rows for padded k tokens would add log(1e-10); drop them too.
        soft_tf = soft_tf * k_mask.unsqueeze(-1)
        logits = soft_tf.sum(dim=1)
        return logits

class BasicIR(nn.Module):
    """Ranker that kernel-matches pre-embedded query/document sequences.

    Positions whose embedding sums to ~0 are treated as padding. The
    ``KernelMatcher`` features are reduced to one score by a linear layer.
    """

    def __init__(self, config, gpu_list, *args, **params):
        super(BasicIR, self).__init__()

        self.input_dim = config.getint("model", "input_dim")
        self.matcher = KernelMatcher(self.input_dim)
        # 21 == KernelMatcher's default kernel_num.
        self.fc = nn.Linear(21, 1)

        self.criterion = nn.MarginRankingLoss(margin=1, reduction='mean')
        self.accuracy_function = init_accuracy_function(config, *args, **params)

    def init_multi_gpu(self, device, config, *args, **params):
        # Multi-GPU wrapping is currently disabled for this model.
        pass

    def get_rep(self, x):
        """Sum of ReLU(fc(x)) over dim 1.

        Not called by forward(). ``self.fc`` has in_features=21, so it only
        accepts inputs whose last dimension is 21.
        """
        return torch.relu(self.fc(x)).sum(1)

    @staticmethod
    def _nonzero_mask(embed):
        """Float mask: 1.0 where a position's embedding does not sum to ~0."""
        return (embed.sum(-1).abs() > 1e-6).float()

    def _score(self, query, query_mask, doc):
        """Linear score of the kernel features between ``query`` and ``doc``."""
        return self.fc(self.matcher(query, query_mask, doc, self._nonzero_mask(doc)))

    def forward(self, data, config, gpu_list, acc_result, mode):
        """Score (query, doc) pairs.

        In "train" mode, ``data`` must also hold 'doc_neg'; returns the margin
        ranking loss and the updated accuracy result. In any other mode,
        returns the positive score plus the passthrough qid/did/label fields.
        """
        query = data['query']
        query_mask = self._nonzero_mask(query)
        pos_scores = self._score(query, query_mask, data['doc_pos'])

        if mode != "train":
            return {"score": pos_scores, "qid": data["qid"], "did": data["did"], 'label': data['label']}

        neg_scores = self._score(query, query_mask, data['doc_neg'])
        target = torch.ones_like(pos_scores)
        loss = self.criterion(pos_scores, neg_scores, target)
        acc_result = self.accuracy_function(pos_scores, neg_scores, config, acc_result)
        return {"loss": loss, "acc_result": acc_result}

class CnnIR(nn.Module):
    """Kernel-matching ranker over CNN features of frozen GloVe word embeddings.

    Token id 400001 (the all-zero last embedding row) is treated as padding
    when building the matcher masks.
    """

    def __init__(self, config, gpu_list, *args, **params):
        super(CnnIR, self).__init__()

        word_embedding_dim = 50
        self._init_word_embedding(word_embedding_dim)

        self.max_len = config.getint("data", "max_seq_length")
        hidden_size = 100

        self.embedding_dim = word_embedding_dim
        # CNN: width-3 kernel with padding 1 keeps the sequence length.
        self.conv = nn.Conv1d(self.embedding_dim, hidden_size, 3, padding=1)

        self.matcher = KernelMatcher(hidden_size)
        # 21 == KernelMatcher's default kernel_num.
        self.fc = nn.Linear(21, 1)

        self.criterion = nn.MarginRankingLoss(margin=1, reduction='mean')
        self.accuracy_function = init_accuracy_function(config, *args, **params)

    def _init_word_embedding(self, dim):
        """Create the frozen word table: pretrained rows, a random UNK row, then an all-zero row."""
        self.word_embedding = nn.Embedding(400002, dim)
        pretrained = torch.from_numpy(np.load('/data1/private/zzy/nre/data/pretrain/glove/word2vec.npy'))
        unk_row = torch.randn(1, dim) / math.sqrt(dim)
        blank_row = torch.zeros(1, dim)
        self.word_embedding.weight.data.copy_(torch.cat([pretrained, unk_row, blank_row], 0))
        # Embeddings are not fine-tuned.
        self.word_embedding.weight.requires_grad = False

    def init_multi_gpu(self, device, config, *args, **params):
        # Multi-GPU wrapping is currently disabled for this model.
        pass

    def get_rep(self, ids):
        """Encode (B, L) token ids into (B, L, hidden) sigmoid CNN features.

        Conv1d expects channels-first input, hence the transpose before and after.
        """
        channels_first = self.word_embedding(ids).transpose(1, 2)
        activated = torch.sigmoid(self.conv(channels_first))
        return activated.transpose(1, 2)

    def _score(self, x_query, query, x_doc, doc):
        """Linear score of the kernel features between an encoded query and document."""
        query_mask = (query != 400001).float()
        doc_mask = (doc != 400001).float()
        return self.fc(self.matcher(x_query, query_mask, x_doc, doc_mask))

    def forward(self, data, config, gpu_list, acc_result, mode):
        """Score (query, doc) pairs.

        In "train" mode, ``data`` must also hold 'doc_neg'; returns the margin
        ranking loss and the accuracy result. Otherwise, returns the positive
        score plus the passthrough qid/did/label fields.
        """
        query = data['query']
        doc_pos = data['doc_pos']

        if mode != "train":
            x_query = self.get_rep(query)
            x_doc_pos = self.get_rep(doc_pos)
            pos_scores = self._score(x_query, query, x_doc_pos, doc_pos)
            return {"score": pos_scores, "qid": data["qid"], "did": data["did"], "label": data["label"]}

        doc_neg = data['doc_neg']
        x_query, x_doc_pos, x_doc_neg = [self.get_rep(ids) for ids in (query, doc_pos, doc_neg)]

        pos_scores = self._score(x_query, query, x_doc_pos, doc_pos)
        neg_scores = self._score(x_query, query, x_doc_neg, doc_neg)

        loss = self.criterion(pos_scores, neg_scores, torch.ones_like(pos_scores))
        acc_result = self.accuracy_function(pos_scores, neg_scores, config, acc_result)
        return {"loss": loss, "acc_result": acc_result}


class CnnKFIR(nn.Module):
    """Kernel-matching ranker fusing CNN word features with projected entity embeddings.

    Words go through frozen GloVe embeddings and a 1-D CNN. Entity embeddings
    are projected to the same width by ``trans`` followed by a sigmoid. The
    four word/entity cross matches are kernel-pooled, concatenated and scored
    by a linear layer. Token id 400001 is padding; an entity row whose values
    sum to ~0 is padding.
    """

    def __init__(self, config, gpu_list, *args, **params):
        super(CnnKFIR, self).__init__()

        word_embedding_dim = 50
        self._init_word_embedding(word_embedding_dim)

        self.max_len = config.getint("data", "max_seq_length")
        hidden_size = 100

        self.embedding_dim = word_embedding_dim
        # CNN: width-3 kernel with padding 1 keeps the sequence length.
        self.conv = nn.Conv1d(self.embedding_dim, hidden_size, 3, padding=1)

        self.input_dim = config.getint("model", "input_dim")
        self.trans = nn.Linear(self.input_dim, hidden_size)

        self.matcher = KernelMatcher(hidden_size)
        # Four cross matches x 21 kernels (KernelMatcher's default kernel_num).
        self.fc = nn.Linear(21 * 4, 1)

        self.criterion = nn.MarginRankingLoss(margin=1, reduction='mean')
        self.accuracy_function = init_accuracy_function(config, *args, **params)

    def _init_word_embedding(self, dim):
        """Create the frozen word table: pretrained rows, a random UNK row, then an all-zero row."""
        self.word_embedding = nn.Embedding(400002, dim)
        pretrained = torch.from_numpy(np.load('/data1/private/zzy/nre/data/pretrain/glove/word2vec.npy'))
        unk_row = torch.randn(1, dim) / math.sqrt(dim)
        blank_row = torch.zeros(1, dim)
        self.word_embedding.weight.data.copy_(torch.cat([pretrained, unk_row, blank_row], 0))
        # Embeddings are not fine-tuned.
        self.word_embedding.weight.requires_grad = False

    def init_multi_gpu(self, device, config, *args, **params):
        # Multi-GPU wrapping is currently disabled for this model.
        pass

    def get_rep(self, ids):
        """Encode (B, L) token ids into (B, L, hidden) sigmoid CNN features."""
        channels_first = self.word_embedding(ids).transpose(1, 2)
        activated = torch.sigmoid(self.conv(channels_first))
        return activated.transpose(1, 2)

    def _encode(self, ids, ent):
        """Return (word_features, word_mask, entity_features, entity_mask) for one sequence.

        The entity mask is taken from the raw entity embeddings, before projection.
        """
        words = self.get_rep(ids)
        word_mask = (ids != 400001).float()
        ent_mask = (ent.sum(-1).abs() > 1e-6).float()
        projected = torch.sigmoid(self.trans(ent))
        return words, word_mask, projected, ent_mask

    def _match_features(self, query, doc):
        """Concatenate word-word, word-entity, entity-word and entity-entity kernel features."""
        q_words, q_mask, q_ent, q_ent_mask = query
        d_words, d_mask, d_ent, d_ent_mask = doc
        pairs = (
            (q_words, q_mask, d_words, d_mask),
            (q_words, q_mask, d_ent, d_ent_mask),
            (q_ent, q_ent_mask, d_words, d_mask),
            (q_ent, q_ent_mask, d_ent, d_ent_mask),
        )
        return torch.cat([self.matcher(*pair) for pair in pairs], -1)

    def forward(self, data, config, gpu_list, acc_result, mode):
        """Score (query, doc) pairs.

        In "train" mode, ``data`` must also hold 'doc_neg' and 'doc_neg_ent';
        returns the margin ranking loss and the accuracy result. Otherwise,
        returns the positive score plus the passthrough qid/did/label fields.
        """
        query = self._encode(data['query'], data['query_ent'])
        doc_pos = self._encode(data['doc_pos'], data['doc_pos_ent'])
        pos_scores = self.fc(self._match_features(query, doc_pos))

        if mode != "train":
            return {"score": pos_scores, "qid": data["qid"], "did": data["did"], 'label': data['label']}

        doc_neg = self._encode(data['doc_neg'], data['doc_neg_ent'])
        neg_scores = self.fc(self._match_features(query, doc_neg))

        loss = self.criterion(pos_scores, neg_scores, torch.ones_like(pos_scores))
        acc_result = self.accuracy_function(pos_scores, neg_scores, config, acc_result)
        return {"loss": loss, "acc_result": acc_result}

class KnowledgeAtt(nn.Module):
    """Pool token features into one vector per entity via bilinear attention.

    Each entity embedding is projected into the token-feature space by ``A``.
    Its dot product with every token feature gives attention logits, and
    masked (padding) tokens are pushed to ~zero weight. The result is the
    softmax-weighted sum of token features for each entity.
    """

    def __init__(self, hidden_dim, ent_dim):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.A = nn.Linear(ent_dim, hidden_dim, bias=True)
        # Not used by forward(); kept so existing checkpoints still load.
        self.linear = nn.Linear(self.hidden_dim, self.hidden_dim, bias=True)

    def forward(self, x, ke, mask):
        """Attend from each entity over the token features.

        Args:
            x: (B, L, H) token features.
            ke: (B, L', E) entity embeddings.
            mask: (B, L) boolean; True at token positions to ignore.

        Returns:
            (B, L', H) attention-pooled token features, one row per entity.
        """
        alpha = torch.bmm(x, self.A(ke).transpose(2, 1))  # (B, L, L')
        # Subtract 10000 from every logit of a masked token (same as the old
        # zeros + index-assign), so softmax gives those tokens ~0 weight.
        penalty = mask.unsqueeze(-1).to(alpha.dtype) * -10000
        alpha = (alpha + penalty).transpose(2, 1)  # (B, L', L)
        scores = torch.softmax(alpha, dim=-1)
        return torch.bmm(scores, x)


class CnnKAIR(nn.Module):
    """Kernel-matching ranker combining CNN word features with entity-attended features.

    Words go through frozen GloVe embeddings and a 1-D CNN. Each entity
    embedding attends over the encoded words via ``KnowledgeAtt``. The four
    word/entity cross matches are kernel-pooled, concatenated and scored by a
    linear layer. Token id 400001 is padding; an entity row whose values sum
    to ~0 is padding.
    """

    def __init__(self, config, gpu_list, *args, **params):
        super(CnnKAIR, self).__init__()

        word_embedding_dim = 50
        self.word_embedding = nn.Embedding(400002, word_embedding_dim)
        word2vec = np.load('/data1/private/zzy/nre/data/pretrain/glove/word2vec.npy')
        word2vec = torch.from_numpy(word2vec)
        unk = torch.randn(1, word_embedding_dim) / math.sqrt(word_embedding_dim)
        blk = torch.zeros(1, word_embedding_dim)
        # Rows: pretrained vectors, a random UNK row, then the all-zero row 400001 (padding).
        self.word_embedding.weight.data.copy_(torch.cat([word2vec, unk, blk], 0))

        # Embeddings are not fine-tuned.
        self.word_embedding.weight.requires_grad = False

        max_length = config.getint("data", "max_seq_length")
        self.max_len = max_length
        kernel_size = 3
        padding = 1
        hidden_size = 100

        self.embedding_dim = word_embedding_dim
        # CNN: width-3 kernel with padding 1 keeps the sequence length.
        self.conv = nn.Conv1d(self.embedding_dim, hidden_size, kernel_size, padding=padding)

        self.input_dim = config.getint("model", "input_dim")
        self.att = KnowledgeAtt(hidden_size, self.input_dim)

        self.matcher = KernelMatcher(hidden_size)
        # Four cross matches x 21 kernels (KernelMatcher's default kernel_num).
        self.fc = nn.Linear(21*4, 1)

        self.criterion = nn.MarginRankingLoss(margin=1, reduction='mean')
        self.accuracy_function = init_accuracy_function(config, *args, **params)

    def init_multi_gpu(self, device, config, *args, **params):
        # Multi-GPU wrapping is currently disabled for this model.
        pass

    def get_rep(self, ids):
        """Encode (B, L) token ids into (B, L, hidden) sigmoid CNN features."""
        x = self.word_embedding(ids)
        x = x.transpose(1, 2)  # Conv1d expects channels-first input
        x = self.conv(x)
        x = torch.sigmoid(x)
        x = x.transpose(1, 2)
        return x

    def _encode(self, ids, ent):
        """Return (word_features, word_mask, entity_features, entity_mask) for one sequence.

        The entity mask is computed from the raw entity embeddings before
        attention replaces them. Attention ignores this sequence's own padding.
        """
        x = self.get_rep(ids)
        mask = (ids != 400001).float()
        ent_mask = (abs(ent.sum(-1)) > 1e-6).float()
        ent = self.att(x, ent, ids == 400001)
        return x, mask, ent, ent_mask

    def _match_features(self, query, doc):
        """Concatenate word-word, word-entity, entity-word and entity-entity kernel features."""
        x_q, q_mask, q_ent, q_ent_mask = query
        x_d, d_mask, d_ent, d_ent_mask = doc
        return torch.cat([
            self.matcher(x_q, q_mask, x_d, d_mask),
            self.matcher(x_q, q_mask, d_ent, d_ent_mask),
            self.matcher(q_ent, q_ent_mask, x_d, d_mask),
            self.matcher(q_ent, q_ent_mask, d_ent, d_ent_mask),
        ], -1)

    def forward(self, data, config, gpu_list, acc_result, mode):
        """Score (query, doc) pairs.

        In "train" mode, ``data`` must also hold 'doc_neg' and 'doc_neg_ent';
        returns the margin ranking loss and the accuracy result. Otherwise,
        returns the positive score plus the passthrough qid/did/label fields.
        """
        query = self._encode(data['query'], data['query_ent'])
        doc_pos = self._encode(data['doc_pos'], data['doc_pos_ent'])
        pos_scores = self.fc(self._match_features(query, doc_pos))

        if mode != "train":
            return {"score": pos_scores, "qid": data["qid"], "did": data["did"], 'label': data['label']}

        # The negative document is attended with its own padding mask.
        doc_neg = self._encode(data['doc_neg'], data['doc_neg_ent'])
        neg_scores = self.fc(self._match_features(query, doc_neg))

        loss = self.criterion(pos_scores, neg_scores, torch.ones_like(pos_scores))
        acc_result = self.accuracy_function(pos_scores, neg_scores, config, acc_result)
        return {"loss": loss, "acc_result": acc_result}


class LstmIR(nn.Module):
    """Kernel-matching ranker over BiLSTM features of frozen GloVe word embeddings.

    Token id 400001 (the all-zero last embedding row) is treated as padding
    when building the matcher masks.
    """

    def __init__(self, config, gpu_list, *args, **params):
        super(LstmIR, self).__init__()

        word_embedding_dim = 50
        self._init_word_embedding(word_embedding_dim)

        self.max_len = config.getint("data", "max_seq_length")
        hidden_size = 100

        self.embedding_dim = word_embedding_dim
        # Bidirectional: each direction gets half of hidden_size.
        self.lstm = nn.LSTM(self.embedding_dim, hidden_size // 2, bidirectional=True)

        self.matcher = KernelMatcher(hidden_size)
        # 21 == KernelMatcher's default kernel_num.
        self.fc = nn.Linear(21, 1)

        self.criterion = nn.MarginRankingLoss(margin=1, reduction='mean')
        self.accuracy_function = init_accuracy_function(config, *args, **params)

    def _init_word_embedding(self, dim):
        """Create the frozen word table: pretrained rows, a random UNK row, then an all-zero row."""
        self.word_embedding = nn.Embedding(400002, dim)
        pretrained = torch.from_numpy(np.load('/data1/private/zzy/nre/data/pretrain/glove/word2vec.npy'))
        unk_row = torch.randn(1, dim) / math.sqrt(dim)
        blank_row = torch.zeros(1, dim)
        self.word_embedding.weight.data.copy_(torch.cat([pretrained, unk_row, blank_row], 0))
        # Embeddings are not fine-tuned.
        self.word_embedding.weight.requires_grad = False

    def init_multi_gpu(self, device, config, *args, **params):
        # Multi-GPU wrapping is currently disabled for this model.
        pass

    def get_rep(self, ids):
        """Run the BiLSTM over embedded ids; returns (B, L, hidden).

        The LSTM is sequence-first (default batch_first=False), hence the transposes.
        """
        seq_first = self.word_embedding(ids).transpose(1, 0)
        outputs, _ = self.lstm(seq_first)
        return outputs.transpose(1, 0)

    def _score(self, x_query, query, x_doc, doc):
        """Linear score of the kernel features between an encoded query and document."""
        query_mask = (query != 400001).float()
        doc_mask = (doc != 400001).float()
        return self.fc(self.matcher(x_query, query_mask, x_doc, doc_mask))

    def forward(self, data, config, gpu_list, acc_result, mode):
        """Score (query, doc) pairs.

        In "train" mode, ``data`` must also hold 'doc_neg'; returns the margin
        ranking loss and the accuracy result. Otherwise, returns the positive
        score plus the passthrough qid/did/label fields.
        """
        query = data['query']
        doc_pos = data['doc_pos']

        if mode != "train":
            x_query = self.get_rep(query)
            x_doc_pos = self.get_rep(doc_pos)
            pos_scores = self._score(x_query, query, x_doc_pos, doc_pos)
            return {"score": pos_scores, "qid": data["qid"], "did": data["did"], 'label': data['label']}

        doc_neg = data['doc_neg']
        x_query, x_doc_pos, x_doc_neg = [self.get_rep(ids) for ids in (query, doc_pos, doc_neg)]

        pos_scores = self._score(x_query, query, x_doc_pos, doc_pos)
        neg_scores = self._score(x_query, query, x_doc_neg, doc_neg)

        loss = self.criterion(pos_scores, neg_scores, torch.ones_like(pos_scores))
        acc_result = self.accuracy_function(pos_scores, neg_scores, config, acc_result)
        return {"loss": loss, "acc_result": acc_result}


class LstmKFIR(nn.Module):
    """Kernel-matching ranker fusing BiLSTM word features with projected entity embeddings.

    Entity embeddings are projected to the word-feature width by ``trans``
    followed by tanh. The four word/entity cross matches are kernel-pooled,
    concatenated and scored by a linear layer. Token id 400001 is padding; an
    entity row whose values sum to ~0 is padding.
    """

    def __init__(self, config, gpu_list, *args, **params):
        super(LstmKFIR, self).__init__()

        word_embedding_dim = 50
        self._init_word_embedding(word_embedding_dim)

        self.max_len = config.getint("data", "max_seq_length")
        hidden_size = 100

        self.embedding_dim = word_embedding_dim
        # Bidirectional: each direction gets half of hidden_size.
        self.lstm = nn.LSTM(self.embedding_dim, hidden_size // 2, bidirectional=True)

        self.input_dim = config.getint("model", "input_dim")
        self.trans = nn.Linear(self.input_dim, hidden_size)
        self.matcher = KernelMatcher(hidden_size)
        # Four cross matches x 21 kernels (KernelMatcher's default kernel_num).
        self.fc = nn.Linear(21 * 4, 1)

        self.criterion = nn.MarginRankingLoss(margin=1, reduction='mean')
        self.accuracy_function = init_accuracy_function(config, *args, **params)

    def _init_word_embedding(self, dim):
        """Create the frozen word table: pretrained rows, a random UNK row, then an all-zero row."""
        self.word_embedding = nn.Embedding(400002, dim)
        pretrained = torch.from_numpy(np.load('/data1/private/zzy/nre/data/pretrain/glove/word2vec.npy'))
        unk_row = torch.randn(1, dim) / math.sqrt(dim)
        blank_row = torch.zeros(1, dim)
        self.word_embedding.weight.data.copy_(torch.cat([pretrained, unk_row, blank_row], 0))
        # Embeddings are not fine-tuned.
        self.word_embedding.weight.requires_grad = False

    def init_multi_gpu(self, device, config, *args, **params):
        # Multi-GPU wrapping is currently disabled for this model.
        pass

    def get_rep(self, ids):
        """Run the sequence-first BiLSTM over embedded ids; returns (B, L, hidden)."""
        seq_first = self.word_embedding(ids).transpose(1, 0)
        outputs, _ = self.lstm(seq_first)
        return outputs.transpose(1, 0)

    def _encode(self, ids, ent):
        """Return (word_features, word_mask, entity_features, entity_mask) for one sequence.

        The entity mask is taken from the raw entity embeddings, before projection.
        """
        words = self.get_rep(ids)
        word_mask = (ids != 400001).float()
        ent_mask = (ent.sum(-1).abs() > 1e-6).float()
        projected = torch.tanh(self.trans(ent))
        return words, word_mask, projected, ent_mask

    def _match_features(self, query, doc):
        """Concatenate word-word, word-entity, entity-word and entity-entity kernel features."""
        q_words, q_mask, q_ent, q_ent_mask = query
        d_words, d_mask, d_ent, d_ent_mask = doc
        pairs = (
            (q_words, q_mask, d_words, d_mask),
            (q_words, q_mask, d_ent, d_ent_mask),
            (q_ent, q_ent_mask, d_words, d_mask),
            (q_ent, q_ent_mask, d_ent, d_ent_mask),
        )
        return torch.cat([self.matcher(*pair) for pair in pairs], -1)

    def forward(self, data, config, gpu_list, acc_result, mode):
        """Score (query, doc) pairs.

        In "train" mode, ``data`` must also hold 'doc_neg' and 'doc_neg_ent';
        returns the margin ranking loss and the accuracy result. Otherwise,
        returns the positive score plus the passthrough qid/did/label fields.
        """
        query = self._encode(data['query'], data['query_ent'])
        doc_pos = self._encode(data['doc_pos'], data['doc_pos_ent'])
        pos_scores = self.fc(self._match_features(query, doc_pos))

        if mode != "train":
            return {"score": pos_scores, "qid": data["qid"], "did": data["did"], 'label': data["label"]}

        doc_neg = self._encode(data['doc_neg'], data['doc_neg_ent'])
        neg_scores = self.fc(self._match_features(query, doc_neg))

        loss = self.criterion(pos_scores, neg_scores, torch.ones_like(pos_scores))
        acc_result = self.accuracy_function(pos_scores, neg_scores, config, acc_result)
        return {"loss": loss, "acc_result": acc_result}

class LstmKAIR(nn.Module):
    """Kernel-matching ranker combining BiLSTM word features with entity-attended features.

    Each entity embedding attends over the BiLSTM word features via
    ``KnowledgeAtt``. The four word/entity cross matches are kernel-pooled,
    concatenated and scored by a linear layer. Token id 400001 is padding; an
    entity row whose values sum to ~0 is padding.
    """

    def __init__(self, config, gpu_list, *args, **params):
        super(LstmKAIR, self).__init__()

        word_embedding_dim = 50
        self.word_embedding = nn.Embedding(400002, word_embedding_dim)
        word2vec = np.load('/data1/private/zzy/nre/data/pretrain/glove/word2vec.npy')
        word2vec = torch.from_numpy(word2vec)
        unk = torch.randn(1, word_embedding_dim) / math.sqrt(word_embedding_dim)
        blk = torch.zeros(1, word_embedding_dim)
        # Rows: pretrained vectors, a random UNK row, then the all-zero row 400001 (padding).
        self.word_embedding.weight.data.copy_(torch.cat([word2vec, unk, blk], 0))

        # Embeddings are not fine-tuned.
        self.word_embedding.weight.requires_grad = False

        max_length = config.getint("data", "max_seq_length")
        self.max_len = max_length
        hidden_size = 100

        self.embedding_dim = word_embedding_dim
        # Bidirectional: each direction gets half of hidden_size.
        self.lstm = nn.LSTM(self.embedding_dim, hidden_size // 2, bidirectional=True)

        self.input_dim = config.getint("model", "input_dim")
        self.att = KnowledgeAtt(hidden_size, self.input_dim)
        self.matcher = KernelMatcher(hidden_size)
        # Four cross matches x 21 kernels (KernelMatcher's default kernel_num).
        self.fc = nn.Linear(21*4, 1)

        self.criterion = nn.MarginRankingLoss(margin=1, reduction='mean')
        self.accuracy_function = init_accuracy_function(config, *args, **params)

    def init_multi_gpu(self, device, config, *args, **params):
        # Multi-GPU wrapping is currently disabled for this model.
        pass

    def get_rep(self, ids):
        """Run the BiLSTM over embedded ids; returns (B, L, hidden)."""
        x = self.word_embedding(ids)
        x = x.transpose(1, 0)  # nn.LSTM defaults to sequence-first input
        x, _ = self.lstm(x)
        x = x.transpose(1, 0)
        return x

    def _encode(self, ids, ent):
        """Return (word_features, word_mask, entity_features, entity_mask) for one sequence.

        The entity mask is computed from the raw entity embeddings before
        attention replaces them. Attention ignores this sequence's own padding.
        """
        x = self.get_rep(ids)
        mask = (ids != 400001).float()
        ent_mask = (abs(ent.sum(-1)) > 1e-6).float()
        ent = self.att(x, ent, ids == 400001)
        return x, mask, ent, ent_mask

    def _match_features(self, query, doc):
        """Concatenate word-word, word-entity, entity-word and entity-entity kernel features."""
        x_q, q_mask, q_ent, q_ent_mask = query
        x_d, d_mask, d_ent, d_ent_mask = doc
        return torch.cat([
            self.matcher(x_q, q_mask, x_d, d_mask),
            self.matcher(x_q, q_mask, d_ent, d_ent_mask),
            self.matcher(q_ent, q_ent_mask, x_d, d_mask),
            self.matcher(q_ent, q_ent_mask, d_ent, d_ent_mask),
        ], -1)

    def forward(self, data, config, gpu_list, acc_result, mode):
        """Score (query, doc) pairs.

        In "train" mode, ``data`` must also hold 'doc_neg' and 'doc_neg_ent';
        returns the margin ranking loss and the accuracy result. Otherwise,
        returns the positive score plus the passthrough qid/did/label fields.
        """
        query = self._encode(data['query'], data['query_ent'])
        doc_pos = self._encode(data['doc_pos'], data['doc_pos_ent'])
        pos_scores = self.fc(self._match_features(query, doc_pos))

        if mode != "train":
            return {"score": pos_scores, "qid": data["qid"], "did": data["did"], 'label': data["label"]}

        # The negative document is attended with its own padding mask.
        doc_neg = self._encode(data['doc_neg'], data['doc_neg_ent'])
        neg_scores = self.fc(self._match_features(query, doc_neg))

        loss = self.criterion(pos_scores, neg_scores, torch.ones_like(pos_scores))
        acc_result = self.accuracy_function(pos_scores, neg_scores, config, acc_result)
        return {"loss": loss, "acc_result": acc_result}


class BertIR(nn.Module):
    """Cross-encoder ranker: BERT pooled output -> linear relevance score.

    In ``"train"`` mode it scores a (positive, negative) input pair and
    returns a pairwise margin-ranking loss; in any other mode it scores only
    ``data['pos_input']`` and passes ``qid``/``did``/``label`` through.
    """

    # Token id treated as padding. The standard BERT vocabulary maps [PAD]
    # to id 0. TODO(review): confirm for the vocabulary shipped at bert_path.
    PAD_ID = 0

    def __init__(self, config, gpu_list, *args, **params):
        super(BertIR, self).__init__()

        self.bert = BertModel.from_pretrained(config.get("model", "bert_path"))
        # Size the head from the loaded checkpoint (768 for BERT-base) so that
        # other BERT variants configured via bert_path also work.
        self.fc = nn.Linear(self.bert.config.hidden_size, 1)

        self.criterion = nn.MarginRankingLoss(margin=1, reduction='mean')
        self.accuracy_function = init_accuracy_function(config, *args, **params)

    def init_multi_gpu(self, device, config, *args, **params):
        # Intentionally a no-op: DataParallel wrapping is currently disabled.
        #self.bert = nn.DataParallel(self.bert, device_ids=device)
        pass

    def _score(self, input_ids):
        """Return a relevance score (last dim 1) for a batch of token id sequences."""
        # Without an explicit mask BERT would attend to padding positions.
        attention_mask = (input_ids != self.PAD_ID).long()
        outputs = self.bert(input_ids, attention_mask=attention_mask)
        # outputs[1] is the pooled [CLS] representation; indexing (instead of
        # 2-way unpacking) keeps working if hidden states/attentions are output.
        return self.fc(outputs[1])

    def forward(self, data, config, gpu_list, acc_result, mode):
        """Run one step.

        Returns ``{"loss", "acc_result"}`` when ``mode == "train"``, otherwise
        ``{"score", "qid", "did", "label"}``.
        """
        if mode == "train":
            pos_scores = self._score(data['pos_input'])
            neg_scores = self._score(data['neg_input'])
            # Target 1 asks for pos_scores to exceed neg_scores by the margin.
            loss = self.criterion(pos_scores, neg_scores, torch.ones_like(pos_scores))
            acc_result = self.accuracy_function(pos_scores, neg_scores, config, acc_result)

            return {"loss": loss, "acc_result": acc_result}

        pos_scores = self._score(data['pos_input'])
        return {"score": pos_scores, "qid": data["qid"], "did": data["did"], "label": data["label"]}


class BertKFIR(nn.Module):
    """BERT pooled score fused with kernel-matching features over entities.

    The pooled BERT representation of the (query, document) input is
    concatenated with ``KernelMatcher`` features computed between query and
    document entity vectors, then mapped to a scalar score by a linear layer.
    ``"train"`` mode returns a pairwise margin-ranking loss; any other mode
    scores only the positive document.
    """

    # Token id treated as padding. The standard BERT vocabulary maps [PAD]
    # to id 0. TODO(review): confirm for the vocabulary shipped at bert_path.
    PAD_ID = 0

    def __init__(self, config, gpu_list, *args, **params):
        super(BertKFIR, self).__init__()

        self.bert = BertModel.from_pretrained(config.get("model", "bert_path"))

        self.input_dim = config.getint("model", "input_dim")
        self.matcher = KernelMatcher(self.input_dim)
        # Derive the head width instead of hard-coding 768 + 21 (BERT-base
        # hidden size + default kernel count). Assumes the matcher emits one
        # feature per kernel, as the original 21 did.
        self.fc = nn.Linear(self.bert.config.hidden_size + self.matcher._kernel_num, 1)

        self.criterion = nn.MarginRankingLoss(margin=1, reduction='mean')
        self.accuracy_function = init_accuracy_function(config, *args, **params)

    def init_multi_gpu(self, device, config, *args, **params):
        # Intentionally a no-op: DataParallel wrapping is currently disabled.
        #self.bert = nn.DataParallel(self.bert, device_ids=device)
        pass

    @staticmethod
    def _entity_mask(ent):
        """Return a float mask over entity slots: 1.0 for real rows, 0.0 for padding.

        A slot counts as padding when all of its components are (near) zero.
        Summing absolute values avoids mis-flagging a real embedding whose
        signed components happen to cancel out.
        """
        return (ent.abs().sum(-1) > 1e-6).float()

    def _score(self, input_ids, query_ent, query_mask_ent, doc_ent):
        """Score one batch of (query, document) inputs; last dim of result is 1."""
        # Without an explicit mask BERT would attend to padding positions.
        attention_mask = (input_ids != self.PAD_ID).long()
        # outputs[1] is the pooled [CLS] representation.
        pooled = self.bert(input_ids, attention_mask=attention_mask)[1]
        kernel_feats = self.matcher(query_ent, query_mask_ent, doc_ent, self._entity_mask(doc_ent))
        return self.fc(torch.cat([pooled, kernel_feats], -1))

    def forward(self, data, config, gpu_list, acc_result, mode):
        """Run one step.

        Returns ``{"loss", "acc_result"}`` when ``mode == "train"``, otherwise
        ``{"score", "qid", "did", "label"}``.
        """
        query_ent = data['query_ent']
        query_mask_ent = self._entity_mask(query_ent)

        if mode == "train":
            pos_scores = self._score(data['pos_input'], query_ent, query_mask_ent, data['doc_pos_ent'])
            neg_scores = self._score(data['neg_input'], query_ent, query_mask_ent, data['doc_neg_ent'])
            # Target 1 asks for pos_scores to exceed neg_scores by the margin.
            loss = self.criterion(pos_scores, neg_scores, torch.ones_like(pos_scores))
            acc_result = self.accuracy_function(pos_scores, neg_scores, config, acc_result)

            return {"loss": loss, "acc_result": acc_result}

        pos_scores = self._score(data['pos_input'], query_ent, query_mask_ent, data['doc_pos_ent'])
        return {"score": pos_scores, "qid": data["qid"], "did": data["did"], "label": data["label"]}

class KnowledgeAttBert(nn.Module):
    """Replace each entity vector by an attention-pooled mixture of token states.

    Token states are linearly projected from ``hidden_dim`` into the entity
    space (``ent_dim``). Each entity vector scores every projected token by dot
    product, the scores are softmax-normalised over tokens, and the result
    pools the projected tokens. Tokens flagged by ``mask`` get a -10000 logit
    penalty so they receive (near) zero weight.
    """

    def __init__(self, hidden_dim, ent_dim):
        super().__init__()
        self.hidden_dim = hidden_dim
        # token-state -> entity-space projection
        self.A = nn.Linear(hidden_dim, ent_dim, bias=True)

    def forward(self, x, ke, mask):
        """Attend from every entity vector over the token sequence.

        Args:
            x: token states, (B, L, hidden_dim).
            ke: entity vectors, (B, L', ent_dim).
            mask: boolean, (B, L); True marks tokens to ignore.

        Returns:
            (B, L', ent_dim) tensor of attention-pooled projected tokens.
        """
        tokens = self.A(x)                                      # (B, L, E)
        logits = torch.bmm(tokens, ke.transpose(2, 1))          # (B, L, L')
        # -10000 on every logit belonging to a masked token, 0 elsewhere.
        penalty = torch.zeros_like(logits).masked_fill(mask.unsqueeze(-1), -10000)
        weights = (logits + penalty).transpose(2, 1).softmax(dim=-1)  # (B, L', L)
        return torch.bmm(weights, tokens)


class BertKAIR(nn.Module):
    """BERT pooled score fused with kernel matching over token-attended entities.

    Document entity vectors are first re-expressed by ``KnowledgeAttBert`` as
    attention-pooled mixtures of BERT token states (projected into entity
    space), then compared against query entity vectors with ``KernelMatcher``.
    The kernel features are concatenated with the pooled BERT output and mapped
    to a scalar score. ``"train"`` mode returns a pairwise margin-ranking loss;
    any other mode scores only the positive document.
    """

    # Token id treated as padding. The standard BERT vocabulary maps [PAD]
    # to id 0. TODO(review): confirm for the vocabulary shipped at bert_path.
    PAD_ID = 0

    def __init__(self, config, gpu_list, *args, **params):
        super(BertKAIR, self).__init__()

        self.bert = BertModel.from_pretrained(config.get("model", "bert_path"))
        # Taken from the checkpoint (768 for BERT-base) rather than hard-coded.
        hidden_size = self.bert.config.hidden_size

        self.input_dim = config.getint("model", "input_dim")
        self.att = KnowledgeAttBert(hidden_size, self.input_dim)
        # The matcher compares entity-space vectors (width input_dim): query
        # entities as loaded, document entities as projected by self.att.
        self.matcher = KernelMatcher(self.input_dim)
        # Assumes the matcher emits one feature per kernel (the original 21).
        self.fc = nn.Linear(hidden_size + self.matcher._kernel_num, 1)

        self.criterion = nn.MarginRankingLoss(margin=1, reduction='mean')
        self.accuracy_function = init_accuracy_function(config, *args, **params)

    def init_multi_gpu(self, device, config, *args, **params):
        # Intentionally a no-op: DataParallel wrapping is currently disabled.
        #self.bert = nn.DataParallel(self.bert, device_ids=device)
        pass

    @staticmethod
    def _entity_mask(ent):
        """Return a float mask over entity slots: 1.0 for real rows, 0.0 for padding.

        A slot counts as padding when all of its components are (near) zero.
        Summing absolute values avoids mis-flagging a real embedding whose
        signed components happen to cancel out.
        """
        return (ent.abs().sum(-1) > 1e-6).float()

    def _score(self, input_ids, query_ent, query_mask_ent, doc_ent):
        """Score one batch of (query, document) inputs; last dim of result is 1."""
        # Must come from the raw entity vectors, before self.att replaces them.
        doc_mask_ent = self._entity_mask(doc_ent)

        # Without an explicit mask BERT would attend to padding positions.
        attention_mask = (input_ids != self.PAD_ID).long()
        outputs = self.bert(input_ids, attention_mask=attention_mask)
        # outputs[0]: per-token states; outputs[1]: pooled [CLS] representation.
        token_states, pooled = outputs[0], outputs[1]

        doc_ent = self.att(token_states, doc_ent, input_ids == self.PAD_ID)
        kernel_feats = self.matcher(query_ent, query_mask_ent, doc_ent, doc_mask_ent)
        return self.fc(torch.cat([pooled, kernel_feats], -1))

    def forward(self, data, config, gpu_list, acc_result, mode):
        """Run one step.

        Returns ``{"loss", "acc_result"}`` when ``mode == "train"``, otherwise
        ``{"score", "qid", "did", "label"}``.
        """
        query_ent = data['query_ent']
        query_mask_ent = self._entity_mask(query_ent)

        if mode == "train":
            pos_scores = self._score(data['pos_input'], query_ent, query_mask_ent, data['doc_pos_ent'])
            neg_scores = self._score(data['neg_input'], query_ent, query_mask_ent, data['doc_neg_ent'])
            # Target 1 asks for pos_scores to exceed neg_scores by the margin.
            loss = self.criterion(pos_scores, neg_scores, torch.ones_like(pos_scores))
            acc_result = self.accuracy_function(pos_scores, neg_scores, config, acc_result)

            return {"loss": loss, "acc_result": acc_result}

        pos_scores = self._score(data['pos_input'], query_ent, query_mask_ent, data['doc_pos_ent'])
        return {"score": pos_scores, "qid": data["qid"], "did": data["did"], "label": data["label"]}



