import logging
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import random
import pickle
import math
import collections
import itertools
import time
from tqdm import tqdm
import os
import sys
import json
# Sentinel returned by div() instead of raising on division by zero.
MAX = 1e10
# Put the local 'rp' directory on the import path (presumably where the
# `kbc` package imported below lives).
sys.path.append('rp')
from kbc.src.models import ComplEx , TransE
def add(a, b):
    """Return ``a + b`` (binary operator code "0")."""
    total = a + b
    return total
def sub(a, b):
    """Return ``a - b`` (binary operator code "1")."""
    difference = a - b
    return difference
def mul(a, b):
    """Return ``a * b`` (binary operator code "2")."""
    product = a * b
    return product
def div(a, b):
    """Return ``a / b`` (binary operator code "3").

    Division by zero does not raise: the module-level sentinel ``MAX`` is
    returned instead, whatever the sign of ``a``.
    """
    return MAX if b == 0 else a / b
# Binary numeric operator code (as a string) -> implementation.
# KGReasoning.binary_operator looks operators up as operator_dict[str(code)].
operator_dict = {str(code): fn for code, fn in enumerate((add, sub, mul, div))}
def load_kbc(model_path, device, nentity, nrelation, nattribute, nnum, nnumpred, rank=500, init_size=1e-3):
    """Load a pre-trained ComplEx link predictor and move it to ``device``.

    Args:
        model_path: path of a saved ``state_dict``.
        device: target device for the model.
        nentity, nrelation, nattribute, nnum, nnumpred: vocabulary sizes; they
            are passed to ``ComplEx`` as
            ``[nentity, nrelation, nattribute, nentity, nnum, nnumpred, nnum]``.
        rank: embedding rank of the checkpoint (defaults to the previously
            hard-coded 500).
        init_size: ``init_size`` forwarded to ``ComplEx`` (default 1e-3).

    Returns:
        The model with the checkpoint weights loaded, on ``device``.
    """
    model = ComplEx(sizes=[nentity, nrelation, nattribute, nentity, nnum, nnumpred, nnum],
                    rank=rank, init_size=init_size)
    # map_location lets a checkpoint that was saved on a GPU be loaded on a
    # CPU-only host (without it torch.load restores tensors to their saved device).
    # NOTE: torch.load unpickles the file - only load trusted checkpoints.
    state_dict = torch.load(model_path, map_location=device)
    model.load_state_dict(state_dict)
    model.to(device)
    return model

@torch.no_grad()
def kge_forward(model, h, r, t, device, nentity, task):
    """Score a batch of (h, r, t) triples with the link-prediction model.

    Args:
        model: callable taking three positional triple slots plus the
            ``score_rhs`` / ``score_rel`` / ``score_lhs`` keyword flags.
        h, t: 1-D index tensors of equal length, one triple per element.
        r: one-element index tensor; the same relation is used for every triple.
        device, nentity: unused; kept for interface compatibility.
        task: "ere" puts the triples in the model's first slot, "eav" in the
            second, anything else (e.g. "vnpv") in the third.

    Returns:
        The first three outputs of the model call. The callers in this file
        read element 0 for "ere", 1 for "eav" and 2 for "vnpv".
    """
    n_triples = h.size(0)
    # (n_triples, 3) tensor whose columns are head, relation, tail.
    columns = (h.unsqueeze(-1), r.unsqueeze(-1).repeat(n_triples, 1), t.unsqueeze(-1))
    triples = torch.cat(columns, dim=1)
    slots = [None, None, None]
    slots[{"ere": 0, "eav": 1}.get(task, 2)] = triples
    outputs = model(*slots, score_rhs=True, score_rel=False, score_lhs=False)
    return outputs[0], outputs[1], outputs[2]

@torch.no_grad()
def _thresholded_adj_matrix(model, rel, n_rows, n_cols, device, thrshd, adj_list, task, slot, bsz=100):
    """Shared implementation of the three ``*neural_adj_matrix`` functions.

    For every source index in ``range(n_rows)`` (processed ``bsz`` at a time),
    the model's scores for relation ``rel`` (output ``slot`` of
    ``kge_forward(..., task)``) are softmax-normalised over the candidate
    targets. They are then multiplied by the number of observed edges leaving
    that source in ``adj_list`` (floored at 1). Entries below ``thrshd`` are
    zeroed.

    Args:
        adj_list: iterable of observed ``(head, tail)`` index pairs.

    Returns:
        A dense float ``(n_rows, n_cols)`` CPU tensor.
    """
    softmax = nn.Softmax(dim=1)
    matrix = torch.zeros(n_rows, n_cols).to(torch.float)
    r = torch.LongTensor([rel]).to(device)
    # Observed out-degree of every source, at least 1. Scaling the softmax by it
    # presumably lets a source with k known edges keep up to k strong links.
    num = torch.zeros(n_rows, 1).to(torch.float).to(device)
    for (h, _) in adj_list:
        num[h, 0] += 1
    num = torch.maximum(num, torch.ones(n_rows, 1).to(torch.float).to(device))
    for s in range(0, n_rows, bsz):  # reasoning in batches
        t = min(n_rows, s + bsz)
        heads = torch.arange(s, t).to(device)
        scores = kge_forward(model, heads, r, heads, device, n_rows, task)[slot]
        normalized_score = softmax(scores) * num[s:t, :]
        # Probabilities below the threshold are too small to matter: drop them.
        mask = (normalized_score >= thrshd).to(torch.float)
        matrix[s:t, :] = (mask * normalized_score).to('cpu')
    return matrix


# Entity link forecast
@torch.no_grad()
def neural_adj_matrix(model, rel, nentity, device, thrshd, adj_list):
    """Predict the fuzzy entity -> entity adjacency matrix of relation ``rel``.

    Returns a dense ``(nentity, nentity)`` CPU tensor (see
    ``_thresholded_adj_matrix``), scored with the "ere" task (model output 0).
    """
    return _thresholded_adj_matrix(model, rel, nentity, nentity, device, thrshd, adj_list, "ere", 0)


# Numerical and entity connection prediction
@torch.no_grad()
def num_neural_adj_matrix(model, rel, nentity, n_num, device, thrshd, adj_list):
    """Predict the fuzzy entity -> number matrix of attribute ``rel``.

    Returns a dense ``(nentity, n_num)`` CPU tensor (see
    ``_thresholded_adj_matrix``), scored with the "eav" task (model output 1).
    """
    return _thresholded_adj_matrix(model, rel, nentity, n_num, device, thrshd, adj_list, "eav", 1)


# Numerical connection prediction
@torch.no_grad()
def np_neural_adj_matrix(model, rel, nentity, n_num, device, thrshd, adj_list):
    """Predict the fuzzy number -> number matrix of numeric predicate ``rel``.

    ``nentity`` is unused (kept for interface compatibility). Returns a dense
    ``(n_num, n_num)`` CPU tensor (see ``_thresholded_adj_matrix``), scored with
    the "vnpv" task (model output 2).
    """
    return _thresholded_adj_matrix(model, rel, n_num, n_num, device, thrshd, adj_list, "vnpv", 2)
class KGReasoning(nn.Module):
    def __init__(self, args, device, adj_list,num_adj_list,np_adj_list,num_reverse_list, query_name_dict, name_answer_dict,id2number):
        """Set up the reasoner and load (or build and cache) its neural adjacency matrices.

        Three families of fuzzy adjacency matrices are kept. Each matrix is
        split by rows into ``args.fraction`` sparse pieces on ``device``:

        * ``relation_embeddings[r]``: entity -> entity, one per relation
          (cache ``neural_adj/<dataset>_<fraction>_<thrshd>.pt``);
        * ``attribute_embeddings[a]`` / ``attribute_reverse_embeddings[a]``:
          entity -> number and its transpose, one per attribute
          (caches ``..._<eavthrshd>num.pt`` / ``..._<eavthrshd>num_reverse.pt``);
        * ``np_embeddings[p]``: number -> number, one per numeric predicate
          (cache ``..._<vnpvthrshd>np.pt``).

        When a cache file is missing, the matrices are predicted with the
        pre-trained KBC model (loaded at most once, and only if needed). Entries
        >= 1 are capped to 0.9999. Edges present in the corresponding adjacency
        list are set to exactly 1. The result is saved to the cache file.

        ``num_reverse_list`` is unused (kept for interface compatibility).
        ``data/<dataset>-number/id2number.json`` is read to build
        ``self.num2id`` (numeric value -> number id).
        """
        super(KGReasoning, self).__init__()
        self.nentity = args.nentity
        self.nrelation = args.nrelation
        self.device = device
        self.relation_embeddings = list()
        self.attribute_embeddings = list()
        self.attribute_reverse_embeddings = list()
        self.np_embeddings = list()
        self.fraction = args.fraction
        self.query_name_dict = query_name_dict
        self.name_answer_dict = name_answer_dict
        self.neg_scale = args.neg_scale
        self.n_num = args.nnum
        self.num_rank = args.num_rank
        self.id2number = id2number
        dataset_name = self._parse_dataset_name(args.data_path)
        self.numthrshd = args.eavthrshd
        self.fuzzythrshd = args.fuzzythrshd
        with open('data/' + dataset_name + "-number/id2number.json", 'r') as f:
            id2num = json.load(f)
        # numeric value -> number id (a later duplicate value wins)
        self.num2id = {float(value): key for key, value in id2num.items()}

        kbc_model = None

        def get_kbc_model():
            # The pre-trained model is only needed when a cache is missing: load it
            # lazily and at most once (it used to be re-read from disk for every
            # attribute and every numeric predicate).
            nonlocal kbc_model
            if kbc_model is None:
                kbc_model = load_kbc(args.kbc_path, device, args.nentity, args.nrelation,
                                     args.nattribute, args.nnum, args.nnumpre)
            return kbc_model

        prefix = 'neural_adj/' + dataset_name + '_' + str(args.fraction) + '_'

        # e-r-e link predictor (entity -> entity)
        filename = prefix + str(args.thrshd) + '.pt'
        if os.path.exists(filename):
            # NOTE: torch.load unpickles - the cache must come from a trusted source.
            self.relation_embeddings = torch.load(filename, map_location=device)
        else:
            for rel in tqdm(range(args.nrelation)):
                matrix = neural_adj_matrix(get_kbc_model(), rel, args.nentity, device, args.thrshd, adj_list[rel])
                matrix = self._cap_and_pin_edges(matrix, adj_list[rel])
                self.relation_embeddings.append(
                    self._split_rows(matrix, args.nentity, args.fraction, self.device))
            self._save_cache(self.relation_embeddings, filename)

        # e-a-v link predictor (entity -> number, plus its transpose)
        filename = prefix + str(args.eavthrshd) + '.pt'
        num_file = filename.replace(".pt", "num.pt")
        num_reverse_file = filename.replace(".pt", "num_reverse.pt")
        if os.path.exists(num_file):
            self.attribute_embeddings = torch.load(num_file, map_location=device)
            self.attribute_reverse_embeddings = torch.load(num_reverse_file, map_location=device)
        else:
            for att in tqdm(range(args.nattribute)):
                matrix = num_neural_adj_matrix(get_kbc_model(), att, args.nentity, args.nnum, device,
                                               args.eavthrshd, num_adj_list[att])
                matrix = self._cap_and_pin_edges(matrix, num_adj_list[att])
                self.attribute_embeddings.append(
                    self._split_rows(matrix, args.nentity, args.fraction, self.device))
                self.attribute_reverse_embeddings.append(
                    self._split_rows(matrix.transpose(0, 1), args.nnum, args.fraction, self.device))
            self._save_cache(self.attribute_embeddings, num_file)
            self._save_cache(self.attribute_reverse_embeddings, num_reverse_file)

        # v-np-v link predictor (number -> number)
        filename = prefix + str(args.vnpvthrshd) + '.pt'
        np_file = filename.replace(".pt", "np.pt")
        if os.path.exists(np_file):
            self.np_embeddings = torch.load(np_file, map_location=device)
        else:
            for pred in tqdm(range(args.nnumpre)):
                matrix = np_neural_adj_matrix(get_kbc_model(), pred, args.nentity, args.nnum, device,
                                              args.vnpvthrshd, np_adj_list[pred])
                # Predicate 0's observed edges are also pinned in the reverse
                # direction (presumably code 0 = "approximately equal", a symmetric relation).
                matrix = self._cap_and_pin_edges(matrix, np_adj_list[pred], symmetric=(pred == 0))
                self.np_embeddings.append(
                    self._split_rows(matrix, args.nnum, args.fraction, self.device))
            self._save_cache(self.np_embeddings, np_file)

    @staticmethod
    def _parse_dataset_name(data_path):
        """Return the dataset name from ``data_path``'s second '/'-component.

        The part before the first '-' is kept, with "-237" re-appended when the
        next part is "237". A name without '-' no longer raises IndexError.
        """
        parts = data_path.split('/')[1].split('-')
        name = parts[0]
        if len(parts) > 1 and parts[1] == "237":
            name += "-237"
        return name

    @staticmethod
    def _cap_and_pin_edges(matrix, edges, symmetric=False):
        """Cap entries >= 1 to 0.9999, then set every observed ``(h, t)`` edge to 1.

        With ``symmetric`` the edge is also pinned at ``(t, h)``.
        """
        matrix = (matrix >= 1).to(torch.float) * 0.9999 + (matrix < 1).to(torch.float) * matrix
        for (h, t) in edges:
            if symmetric:
                matrix[t, h] = 1
            matrix[h, t] = 1.  # an observed triple is certain
        return matrix

    @staticmethod
    def _split_rows(matrix, n_rows, fraction, device):
        """Split ``matrix`` by rows into ``fraction`` sparse pieces on ``device``.

        Every piece has ``n_rows // fraction`` rows; the last piece also gets the remainder.
        """
        dim = n_rows // fraction
        rest = n_rows - fraction * dim
        pieces = []
        for k in range(fraction):
            start = k * dim
            end = (k + 1) * dim
            if k == fraction - 1:
                end += rest
            pieces.append(matrix[start:end, :].to_sparse().to(device))
        return pieces

    @staticmethod
    def _save_cache(obj, path):
        """``torch.save`` ``obj`` to ``path``, creating the parent directory if needed."""
        directory = os.path.dirname(path)
        if directory:
            os.makedirs(directory, exist_ok=True)
        torch.save(obj, path)

    def relation_projection(self, embedding, r_embedding, is_neg=False):
        dim = self.nentity // self.fraction
        rest = self.nentity - self.fraction * dim
        new_embedding = torch.zeros_like(embedding).to(self.device)
        r_argmax = torch.zeros(self.nentity).to(self.device)
        for i in range(self.fraction):
            s = i * dim
            t = (i+1) * dim
            if i == self.fraction - 1:
                t += rest
            fraction_embedding = embedding[:, s:t]# Take the tensor for this fraction segment
            if fraction_embedding.sum().item() == 0: # How this paragraph does not appear elements, just skip it
                continue
            nonzero = torch.nonzero(fraction_embedding, as_tuple=True)[1]
            fraction_embedding = fraction_embedding[:, nonzero]
            fraction_r_embedding = r_embedding[i].to_dense()[nonzero, :].unsqueeze(0)
            if is_neg:
                fraction_r_embedding = torch.minimum(torch.ones_like(fraction_r_embedding).to(torch.float), self.neg_scale*fraction_r_embedding)
                fraction_r_embedding = 1. - fraction_r_embedding
            fraction_embedding_premax = fraction_r_embedding * fraction_embedding.unsqueeze(-1) # relationship and the embedding of the current node are multiplied
            fraction_embedding, tmp_argmax = torch.max(fraction_embedding_premax, dim=1)
            tmp_argmax = nonzero[tmp_argmax.squeeze()] + s
            new_argmax = (fraction_embedding > new_embedding).to(torch.long).squeeze()
            r_argmax = new_argmax * tmp_argmax + (1-new_argmax) * r_argmax
            new_embedding = torch.maximum(new_embedding, fraction_embedding)
        return new_embedding, r_argmax.cpu().numpy()
    def attribute_projection(self, embedding, r_embedding, is_neg=False):
        dim = self.nentity // self.fraction
        rest = self.nentity - self.fraction * dim
        new_embedding = torch.zeros(1,self.n_num).to(self.device)
        r_argmax = torch.zeros(self.n_num).to(self.device)
        for i in range(self.fraction):
            s = i * dim
            t = (i+1) * dim
            if i == self.fraction - 1:
                t += rest
            fraction_embedding = embedding[:, s:t]
            if fraction_embedding.sum().item() == 0: 
                continue
            nonzero = torch.nonzero(fraction_embedding, as_tuple=True)[1]#Take the subscript of a non-zero element
            fraction_embedding = fraction_embedding[:, nonzero]
            fraction_r_embedding = r_embedding[i].to_dense()[nonzero, :].unsqueeze(0)
            if is_neg:
                fraction_r_embedding = torch.minimum(torch.ones_like(fraction_r_embedding).to(torch.float), self.neg_scale*fraction_r_embedding)
                fraction_r_embedding = 1. - fraction_r_embedding
            fraction_embedding_premax = fraction_r_embedding * fraction_embedding.unsqueeze(-1) 
            fraction_embedding, tmp_argmax = torch.max(fraction_embedding_premax, dim=1)
            tmp_argmax = nonzero[tmp_argmax.squeeze()] + s
            new_argmax = (fraction_embedding > new_embedding).to(torch.long).squeeze()
            r_argmax = new_argmax * tmp_argmax + (1-new_argmax) * r_argmax
            new_embedding = torch.maximum(new_embedding, fraction_embedding)
        max_idx = new_embedding.argmax()
        return new_embedding,r_argmax.cpu().numpy(), max_idx                                                                                                                                                                                                                                                                                                                    
    def reverse_attribute_projection(self, embedding, r_embedding,is_neg=False):
        dim = self.n_num // self.fraction
        rest = self.n_num - self.fraction * dim
        new_embedding = torch.zeros(1,self.nentity).to(self.device)
        r_argmax = torch.zeros(self.nentity).to(self.device)
        for i in range(self.fraction):
            s = i * dim
            t = (i+1) * dim
            if i == self.fraction - 1:
                t += rest
            fraction_embedding = embedding[:, s:t]
            if fraction_embedding.sum().item() == 0: 
                continue
            nonzero = torch.nonzero(fraction_embedding, as_tuple=True)[1]
            fraction_embedding = fraction_embedding[:, nonzero]
            fraction_r_embedding = r_embedding[i].to_dense()[nonzero, :].unsqueeze(0)
            if is_neg:
                fraction_r_embedding = torch.minimum(torch.ones_like(fraction_r_embedding).to(torch.float), self.neg_scale*fraction_r_embedding)
                fraction_r_embedding = 1. - fraction_r_embedding
            fraction_embedding_premax = fraction_r_embedding * fraction_embedding.unsqueeze(-1) #关系和当前节点的嵌入相乘
            fraction_embedding, tmp_argmax = torch.max(fraction_embedding_premax, dim=1)
            tmp_argmax = nonzero[tmp_argmax.squeeze()] + s
            new_argmax = (fraction_embedding > new_embedding).to(torch.long).squeeze()
            r_argmax = new_argmax * tmp_argmax + (1-new_argmax) * r_argmax
            new_embedding = torch.maximum(new_embedding, fraction_embedding)
        return new_embedding, r_argmax.cpu().numpy()
    def number_projection(self, embedding, r_embedding,is_neg=False):
    # "0": approximately_equal,
    # "1": greater_than,
    # "2": smaller_than,
    # "3": approximately_two_times_equal_to,
    # "4": approximately_three_times_equal_to,
    # "5": two_times_larger_than,
    # "6": three_times_larger_than,
        dim = self.n_num // self.fraction
        rest = self.n_num - self.fraction * dim
        new_embedding = torch.zeros_like(embedding).to(self.device)
        r_argmax = torch.zeros(self.n_num).to(self.device)
        for i in range(self.fraction):
            s = i * dim
            t = (i+1) * dim
            if i == self.fraction - 1:
                t += rest
            fraction_embedding = embedding[:, s:t]
            if fraction_embedding.sum().item() == 0: 
                continue
            nonzero = torch.nonzero(fraction_embedding, as_tuple=True)[1]
            fraction_embedding = fraction_embedding[:, nonzero]
            fraction_r_embedding = r_embedding[i].to_dense()[nonzero, :].unsqueeze(0)
            if is_neg:
                fraction_r_embedding = torch.minimum(torch.ones_like(fraction_r_embedding).to(torch.float), self.neg_scale*fraction_r_embedding)
                fraction_r_embedding = 1. - fraction_r_embedding
            fraction_embedding_premax = fraction_r_embedding * fraction_embedding.unsqueeze(-1)
            fraction_embedding, tmp_argmax = torch.max(fraction_embedding_premax, dim=1)
            tmp_argmax = nonzero[tmp_argmax.squeeze()] + s
            new_argmax = (fraction_embedding > new_embedding).to(torch.long).squeeze()
            r_argmax = new_argmax * tmp_argmax + (1-new_argmax) * r_argmax
            new_embedding = torch.maximum(new_embedding, fraction_embedding)
        return new_embedding, r_argmax.cpu().numpy()
    
    def intersection(self, embeddings):
        return torch.prod(embeddings, dim=0)# torch.prod is to return the product of all line elements
    def union(self, embeddings):
        return (1. - torch.prod(1.-embeddings, dim=0))
    def binary_operator(self,u1,u2,operator):
        out = {}
        idx1 = torch.nonzero(u1[0]>self.fuzzythrshd)
        idx2 = torch.nonzero(u2[0]>self.fuzzythrshd)
        if idx1.shape[0] <= 0 or idx2.shape[0] <= 0:
            return None,None
        for x1 in idx1:
            for x2 in idx2:
                result = []
                try:
                    result.append(operator_dict[str(operator)](float(self.id2number[str(x1.item())]),float(self.id2number[str(x2.item())])))
                except:
                    result.append(operator_dict[str(operator)](float(self.id2number[str(x1[0].item())]),float(self.id2number[str(x2[0].item())])))
                out[tuple(result)] = u1[0][x1.item()].item()*u2[0][x2.item()].item()
        out = sorted(out.items(),key = lambda x:x[1],reverse = True)
        out_embed = []
        out_cen = []
        i = 0
        for key in out:
            out_embed.append(key[1])
            out_cen.append(key[0][0])
            i += 1
            if key[1] < self.numthrshd:
                break
            if len(out_embed) >= self.n_num/100:
                break
        out_tensor = nn.functional.normalize(torch.tensor(np.array(out_embed)).cuda(),p=2,dim=0).unsqueeze(0)
        mask = out_tensor > self.numthrshd
                       
        return  torch.masked_select(out_tensor, mask),torch.masked_select(torch.tensor(out_cen).cuda(), mask)


    def embed_query(self, queries, query_structure, idx,Affiliation_Center_new):#Reasoning about queries
        '''
        Iterative embed a batch of queries with same structure
        queries: a flattened batch of queries
        '''
        entity_fuzzyCenter = torch.tensor(np.arange(0,self.nentity)).cuda()
        all_relation_flag = True
        Affiliation_Center_new = None
        exec_query = []
        for ele in query_structure[-1]: # whether the current query tree has merged to one branch and only need to do relation traversal, e.g., path queries or conjunctive queries after the intersection
            if ele not in ['rp','np','ap','rap']:
                all_relation_flag = False
                break
        if all_relation_flag:
            if query_structure[0] == 'e':
                bsz = queries.size(0)
                embedding = torch.zeros(bsz, self.nentity).to(torch.float).to(self.device) 
                embedding.scatter_(-1, queries[:, idx].unsqueeze(-1), 1)
                exec_query.append(queries[:, idx].item())
                idx += 1
            elif query_structure[0] == 'nv':
                bsz = queries.size(0)
                embedding = torch.zeros(bsz, self.n_num).to(torch.float).to(self.device) 
                embedding.scatter_(-1, queries[:, idx].unsqueeze(-1), 1)
                exec_query.append(queries[:, idx].item())
                idx += 1
            else:
                embedding, idx, pre_exec_query,Affiliation_Center_new = self.embed_query(queries, query_structure[0], idx,Affiliation_Center_new)
                if embedding == None:
                    return None, idx, exec_query,None
                if(len(embedding.shape) == 1):
                    embedding = embedding.unsqueeze(0)
                exec_query.append(pre_exec_query)
            r_exec_query = []
            for i in range(len(query_structure[-1])):  
                if query_structure[-1][i] == 'ap':
                    r_embedding = self.attribute_embeddings[(queries[0, idx]/2).int()]
                    embedding, r_argmax,max_idx = self.attribute_projection(embedding, r_embedding,False) 
                    r_exec_query.append((queries[0, idx].item(), r_argmax))
                    r_exec_query.append('nv')
                elif query_structure[-1][i] == 'rap':
                    # Approximate matching of fuzzy sets to specific values is performed here
                    if embedding.shape[1] != self.n_num:
                        bsz = queries.size(0)
                        embedding_new = torch.zeros(bsz, self.n_num).to(torch.float).to(self.device) 
                        for i in range(min(embedding.shape[1],10)):
                            now_idx = 0
                            min_dis = 100
                            now_center = Affiliation_Center_new[i]
                            for id in self.id2number:
                                temp_cen = float(self.id2number[id])
                                try:
                                    now_dis = abs(abs(temp_cen-now_center.item())/max(abs(now_center.item()),abs(temp_cen)))
                                except:
                                    now_dis = 100
                                if  now_dis < min_dis:
                                    min_dis = now_dis
                                    now_idx = int(id)
                                if min_dis <= 1e-4:
                                    break
                            embedding_new[0][now_idx] = 1-min_dis
                        embedding = embedding_new
                    r_embedding = self.attribute_reverse_embeddings[((queries[0, idx]-1)/2).int()]
                    embedding, r_argmax = self.reverse_attribute_projection(embedding, r_embedding, False) 
                    r_exec_query.append((queries[0, idx].item(), r_argmax))
                    r_exec_query.append('e')
                    Affiliation_Center_new = entity_fuzzyCenter
                elif query_structure[-1][i] == 'np':
                    if embedding.shape[1] != self.n_num:
                        bsz = queries.size(0)
                        embedding_new = torch.zeros(bsz, self.n_num).to(torch.float).to(self.device) 
                        for i in range(min(embedding.shape[1],10)):
                            now_idx = 0
                            min_dis = 100
                            now_center = Affiliation_Center_new[i]
                            for id in self.id2number:
                                temp_cen = float(self.id2number[id])
                                try:
                                    now_dis = abs(abs(temp_cen-now_center.item())/max(abs(now_center.item()),abs(temp_cen)))
                                except:
                                    now_dis = 10
                                if  now_dis < min_dis:
                                    min_dis = now_dis
                                    now_idx = int(id)
                                if min_dis <= 1e-4:
                                    break
                            embedding_new[0][now_idx] = 1-min_dis
                        embedding = embedding_new
                    Unary_operators = int(queries[0, idx].item())
                    r_embedding = self.np_embeddings[Unary_operators]
                    embedding,r_argmax = self.number_projection(embedding,r_embedding,False)
                    r_exec_query.append((queries[0, idx].item(), r_argmax))
                    r_exec_query.append('nv')
                else:
                    r_embedding = self.relation_embeddings[(queries[0, idx]/2).int()]
                    if (i < len(query_structure[-1]) - 1) and query_structure[-1][i+1] == 'n':
                        embedding, r_argmax = self.relation_projection(embedding, r_embedding, True)
                    else:
                        embedding, r_argmax = self.relation_projection(embedding, r_embedding, False)
                    r_exec_query.append((queries[0, idx].item(), r_argmax))
                    r_exec_query.append('e')
                idx += 1
            r_exec_query.pop()
            exec_query.append(r_exec_query)
            exec_query.append('e')
        # Integration of multiple branches
        else:
            embedding_list = []
            union_flag = False
            binary_flag = False
            for ele in query_structure[-1]:
                if ele == 'b':
                    binary_flag = True
                    query_structure = query_structure[:-1]
                if ele == 'u':
                    union_flag = True
                    query_structure = query_structure[:-1]
                    break
                if ele == 'i':
                    query_structure = query_structure[:-1]
                    break
            for i in range(len(query_structure)):
                embedding, idx, pre_exec_query,Affiliation_Center = self.embed_query(queries, query_structure[i], idx,Affiliation_Center_new)
                embedding_list.append(embedding)
                exec_query.append(pre_exec_query)
            if union_flag:
                embedding = self.union(torch.stack(embedding_list))
                idx += 1
                exec_query.append(['u'])
            elif binary_flag:
                embedding,Affiliation_Center_new= self.binary_operator(embedding_list[0],embedding_list[1],queries[-1][idx].item())
                idx += 1
                exec_query.append(['b'])
            else:
                embedding = self.intersection(torch.stack(embedding_list))
                idx += 1
                
            exec_query.append('e')
        
        return embedding, idx, exec_query,Affiliation_Center_new

    def find_ans(self, exec_query, query_structure, anchor):
        """Recover the reasoning path that ends at ``anchor``.

        The answer template is found through the query's name
        (``query_name_dict`` then ``name_answer_dict``), and ``exec_query`` is
        walked backwards from ``anchor`` by ``backward_ans``.
        """
        query_name = self.query_name_dict[query_structure]
        template = self.name_answer_dict[query_name]
        return self.backward_ans(template, exec_query, anchor)

    def backward_ans(self, ans_structure, exec_query, anchor):
        """Walk an execution trace backwards from ``anchor``.

        ``ans_structure`` is an answer template, one of:
        - 'e': a leaf;
        - a sequence starting with 'u': a union, which is not traced;
        - a projection chain such as ['r', 'e', 'r'];
        - a pair whose second element is such a chain;
        - a list of branches followed by a merge marker.

        Returns ``(ans, ent)``: the recovered path, nested like the template,
        and the entity/value reached at its start.
        """
        if ans_structure == 'e':
            # leaf: the trace entry itself is the answer
            return exec_query, exec_query
        head = ans_structure[0]
        if head == 'u':
            return ['u'], 'u'
        if head == 'r':
            # projection chain: follow each step's argmax pointer from the end
            cur_ent = anchor
            path = []
            for ele, query_ele in zip(reversed(ans_structure), reversed(exec_query)):
                if ele == 'r':
                    r_id, r_argmax = query_ele
                    path.append(r_id)
                    cur_ent = int(r_argmax[cur_ent])
                elif ele == 'n':
                    path.append('n')
                else:
                    path.append(cur_ent)
            path.reverse()
            return path, cur_ent
        if ans_structure[1][0] == 'r':
            # sub-structure followed by a chain: trace the chain first, then
            # continue the sub-structure from where the chain started
            r_ans, r_ent = self.backward_ans(ans_structure[1], exec_query[1], anchor)
            e_ans, e_ent = self.backward_ans(ans_structure[0], exec_query[0], r_ent)
            return [e_ans, r_ans, anchor], e_ent
        # merge of several branches (trailing marker skipped), each traced from anchor
        branch_answers = []
        for ele, query_ele in zip(ans_structure[:-1], exec_query[:-1]):
            ele_ans, ele_ent = self.backward_ans(ele, query_ele, anchor)
            branch_answers.append(ele_ans)
        branch_answers.append(anchor)
        return branch_answers, ele_ent