from SPARQLWrapper import SPARQLWrapper, JSON
from utils import *

# SPARQL endpoint URL of the Freebase server. It depends on your own internal
# address and port, as described in the Freebase folder's readme.md.
# NOTE(review): empty here; presumably must be filled in before any query helper runs.
SPARQLPATH = ""

# Pre-defined SPARQL query templates, filled with the `%` operator. Placeholders
# take bare Freebase ids (the "ns:" prefix is written by the template itself).
# Head relations: predicates of triples whose subject is the entity (%s = entity id).
sparql_head_relations = """\nPREFIX ns: <http://rdf.freebase.com/ns/>\nSELECT ?relation\nWHERE {\n  ns:%s ?relation ?x .\n}"""
# Tail relations: predicates of triples whose object is the entity (%s = entity id).
sparql_tail_relations = """\nPREFIX ns: <http://rdf.freebase.com/ns/>\nSELECT ?relation\nWHERE {\n  ?x ?relation ns:%s .\n}"""
# Objects of "entity relation ?tailEntity": placeholders are (entity, relation).
sparql_tail_entities_extract = """PREFIX ns: <http://rdf.freebase.com/ns/>\nSELECT ?tailEntity\nWHERE {\nns:%s ns:%s ?tailEntity .\n}""" 
# Subjects of "?tailEntity relation entity". The triple pattern is
# "?tailEntity ns:%s ns:%s", so the FIRST placeholder is the predicate
# (relation) and the SECOND is the object (entity): fill as (relation, entity).
sparql_head_entities_extract = """PREFIX ns: <http://rdf.freebase.com/ns/>\nSELECT ?tailEntity\nWHERE {\n?tailEntity ns:%s ns:%s  .\n}"""
# Name lookup: type.object.name values and owl:sameAs targets of one entity.
# Both placeholders take the same entity id.
sparql_id = """PREFIX ns: <http://rdf.freebase.com/ns/>\nSELECT DISTINCT ?tailEntity\nWHERE {\n  {\n    ?entity ns:type.object.name ?tailEntity .\n    FILTER(?entity = ns:%s)\n  }\n  UNION\n  {\n    ?entity <http://www.w3.org/2002/07/owl#sameAs> ?tailEntity .\n    FILTER(?entity = ns:%s)\n  }\n}"""
    
def check_end_word(s):
    """Return True if ``s`` ends with one of a fixed set of suffixes.

    The check is case-sensitive and uses plain suffix matching, so
    " ID" matches "Freebase ID" but not "IDs".
    """
    suffixes = (
        " ID", " code", " number", "instance of", "website",
        "URL", "inception", "image", " rate", " count",
    )
    return s.endswith(suffixes)

def abandon_rels(relation):
    """Return True if ``relation`` should be dropped from the candidate set.

    Dropped relations are:
    - the meta predicates ``type.object.type`` and ``type.object.name``;
    - anything in the ``common.`` or ``freebase.`` namespaces;
    - any relation containing ``sameAs``.

    Always returns a bool (``False`` for relations that are kept).
    """
    return (
        relation in ("type.object.type", "type.object.name")
        or relation.startswith(("common.", "freebase."))
        or "sameAs" in relation
    )


def execurte_sparql_entity(entity, relation, sparql_entities_extract):
    """Run an entity-extraction SPARQL template and return its result bindings.

    The template is filled as ``sparql_entities_extract % (entity, relation)``.
    Callers must therefore pass the two values in the order the template's
    placeholders expect.

    Endpoint failures are retried up to 10 times, 2 seconds apart. After the
    last failure the exception is re-raised.

    Returns:
        The ``results["results"]["bindings"]`` list of the JSON response.
    """
    import time  # local import: don't depend on ``time`` arriving via a star import

    max_retries = 10  # maximum number of attempts
    retry_delay = 2  # seconds to wait between attempts

    # Formatting is deterministic, so build the query once, outside the retry
    # loop. A formatting error now surfaces immediately instead of being retried
    # and leaving ``sparql_query`` unbound in the handler.
    sparql_query = sparql_entities_extract % (entity, relation)

    for attempt in range(max_retries):
        try:
            sparql = SPARQLWrapper(SPARQLPATH)
            sparql.setQuery(sparql_query)
            sparql.setReturnFormat(JSON)
            results = sparql.query().convert()
            break  # success: stop retrying
        except Exception as e:
            print(sparql_query)
            print(f"查询失败，尝试次数 {attempt + 1}/{max_retries}: {e}")
            if attempt < max_retries - 1:
                time.sleep(retry_delay)  # back off before the next attempt
            else:
                print("达到最大重试次数，终止查询。")
                raise  # out of attempts: propagate the last error

    return results["results"]["bindings"]
def execurte_sparql_relation(entity_id, sparql_relations):
    """Run a relation-listing SPARQL template for one entity and return its bindings.

    Args:
        entity_id: Freebase id substituted into the template's single ``%s``.
        sparql_relations: query template, e.g. ``sparql_head_relations``.

    Endpoint failures are retried up to 10 times, 2 seconds apart. After the
    last failure the exception is re-raised.

    Returns:
        The ``results["results"]["bindings"]`` list of the JSON response.
    """
    import time  # local import: don't depend on ``time`` arriving via a star import

    max_retries = 10  # maximum number of attempts
    retry_delay = 2  # seconds to wait between attempts

    # Formatting is deterministic, so build the query once, outside the retry
    # loop. A formatting error now surfaces immediately instead of being retried
    # and leaving ``sparql_query`` unbound in the handler.
    sparql_query = sparql_relations % (entity_id)

    for attempt in range(max_retries):
        try:
            sparql = SPARQLWrapper(SPARQLPATH)
            sparql.setQuery(sparql_query)
            sparql.setReturnFormat(JSON)
            results = sparql.query().convert()
            break  # success: stop retrying
        except Exception as e:
            print("sparql_query:" + sparql_query)
            # f-string so a non-str entity_id cannot raise TypeError inside the handler
            print(f"entity_id:{entity_id}")
            print(f"查询失败，尝试次数 {attempt + 1}/{max_retries}: {e}")
            if attempt < max_retries - 1:
                time.sleep(retry_delay)  # back off before the next attempt
            else:
                print("达到最大重试次数，终止查询。")
                raise  # out of attempts: propagate the last error
    return results["results"]["bindings"]


def replace_relation_prefix(relations):
    """Strip the Freebase namespace URI from each ``?relation`` binding.

    Args:
        relations: SPARQL JSON bindings shaped like
            ``{"relation": {"value": <uri>}}``.

    Returns:
        The binding values, each with every occurrence of
        ``http://rdf.freebase.com/ns/`` removed.
    """
    namespace = "http://rdf.freebase.com/ns/"
    return [
        binding["relation"]["value"].replace(namespace, "")
        for binding in relations
    ]

def replace_entities_prefix(entities):
    """Strip the Freebase namespace URI from each ``?tailEntity`` binding.

    Args:
        entities: SPARQL JSON bindings shaped like
            ``{"tailEntity": {"value": <uri or literal>}}``.

    Returns:
        The binding values, each with every occurrence of
        ``http://rdf.freebase.com/ns/`` removed.
    """
    namespace = "http://rdf.freebase.com/ns/"
    return [
        binding["tailEntity"]["value"].replace(namespace, "")
        for binding in entities
    ]


def id2entity_name_or_type(entity_id):
    """Return a readable name for ``entity_id`` from the SPARQL endpoint.

    The ``sparql_id`` query collects ``type.object.name`` values and
    ``owl:sameAs`` targets; the first binding's value is returned. The
    endpoint call is retried up to 10 times, 2 seconds apart. After the last
    failure the exception is re-raised.

    Returns:
        The first ``tailEntity`` value, or ``"UnName_Entity"`` if the query
        returned no bindings.
    """
    query_text = sparql_id % (entity_id, entity_id)
    endpoint = SPARQLWrapper(SPARQLPATH)
    endpoint.setQuery(query_text)
    endpoint.setReturnFormat(JSON)

    attempts_allowed = 10  # maximum number of attempts
    pause_seconds = 2  # seconds to wait between attempts
    failures = 0
    while True:
        try:
            response = endpoint.query().convert()
        except Exception as err:
            failures += 1
            print(f"查询失败，尝试次数 {failures}/{attempts_allowed}: {err}")
            if failures >= attempts_allowed:
                print("达到最大重试次数，终止查询。")
                raise  # out of attempts: propagate the last error
            time.sleep(pause_seconds)
        else:
            break

    bindings = response["results"]["bindings"]
    if not bindings:
        return "UnName_Entity"
    return bindings[0]['tailEntity']['value']
    
from freebase_func import *
from prompt_list import *
import json
import time
import openai
import re
from prompt_list import *
from rank_bm25 import BM25Okapi
from sentence_transformers import util
from sentence_transformers import SentenceTransformer


def clean_relations(string, entity_id, head_relations):
    """Parse LLM output of the form ``{relation (Score: x)}`` into scored relations.

    Args:
        string: raw LLM response.
        entity_id: id of the topic entity the relations belong to.
        head_relations: relations for which the entity is the subject. A parsed
            relation is tagged ``head=True`` when it appears here.

    Returns:
        ``(True, relations)`` on success, where each item is a dict with keys
        ``entity``, ``relation``, ``score`` (float) and ``head``.

        ``(False, message)`` when:
        - a match is incomplete;
        - a score is not a valid float;
        - nothing was parsed;
        - any relation contains characters other than ASCII letters, ``.`` and ``_``.

        Matches whose relation text contains ``;`` are skipped.
    """
    item_pattern = r"{\s*(?P<relation>[^()]+)\s+\(Score:\s+(?P<score>[0-9.]+)\)}"
    parsed = []
    for found in re.finditer(item_pattern, string):
        name = found.group("relation").strip()
        if ';' in name:
            continue
        raw_score = found.group("score")
        if not name or not raw_score:
            return False, "output uncompleted.."
        try:
            value = float(raw_score)
        except ValueError:
            return False, "Invalid score"
        parsed.append({
            "entity": entity_id,
            "relation": name,
            "score": value,
            "head": name in head_relations,
        })
    if not parsed:
        return False, "No relations found,未匹配到关系"
    name_pattern = re.compile(r'^[a-zA-Z._]+$')
    if any(not name_pattern.match(item["relation"]) for item in parsed):
        return False, "关系格式有误"
    return True, parsed


def if_all_zero(topn_scores):
    """Return True when every score equals 0 (vacuously True for an empty input)."""
    for value in topn_scores:
        if not value == 0:
            return False
    return True


def clean_relations_bm25_sent(topn_relations, topn_scores, entity_id, head_relations):
    """Package retriever-ranked relations into scored-relation dicts.

    Args:
        topn_relations: relation names in ranked order.
        topn_scores: scores aligned by index with ``topn_relations``.
        entity_id: id of the topic entity.
        head_relations: relations for which the entity is the subject.

    Returns:
        ``(True, relations)``. Each dict has keys ``entity``, ``relation``,
        ``score`` and ``head``. If every score is zero, the scores are replaced
        by a uniform ``1/len(topn_scores)`` distribution.
    """
    # Skip normalisation for an empty score list: 1/len([]) would divide by zero.
    if topn_scores and if_all_zero(topn_scores):
        topn_scores = [float(1 / len(topn_scores))] * len(topn_scores)
    relations = []
    for i, relation in enumerate(topn_relations):
        # Head and tail relations are both appended; only the "head" flag differs.
        relations.append({
            "entity": entity_id,
            "relation": relation,
            "score": topn_scores[i],
            "head": relation in head_relations,
        })
    return True, relations


def construct_relation_prune_prompt_reason(question, entity_name, total_relations, args, reasons):
    """Build the relation-pruning prompt that also carries reasoning hints.

    ``extract_relation_prompt_reason`` is filled with ``args.width`` twice.
    The following are then appended, in order:
    - the question;
    - the ``reasons`` text;
    - the topic entity;
    - the ``"; "``-joined candidate relations.

    ``reasons`` must already be a string.
    """
    instructions = extract_relation_prompt_reason % (args.width, args.width)
    pieces = [
        instructions,
        question,
        '\n',
        reasons,
        '\nTopic Entity: ',
        entity_name,
        '\nRelations: ',
        '; '.join(total_relations),
        "\n###output",
        "\nA: ",
    ]
    return "".join(pieces)
        

def construct_entity_score_prompt_reason(question, relation, entity_candidates, reason):
    """Fill ``score_entity_candidates_prompt_reason`` positionally.

    The positional arguments are: question, reason, relation, and the
    ``"; "``-joined candidate names.
    """
    candidates_text = "; ".join(entity_candidates)
    return score_entity_candidates_prompt_reason.format(
        question, reason, relation, candidates_text
    )


def construct_relation_prune_prompt(question, entity_name, total_relations, args):
    """Build the plain relation-pruning prompt.

    ``extract_relation_prompt`` is filled with ``args.width`` twice. The
    question, the topic entity and the ``"; "``-joined candidate relations are
    then appended, followed by the output marker.
    """
    instructions = extract_relation_prompt % (args.width, args.width)
    pieces = [
        instructions,
        question,
        '\nTopic Entity: ',
        entity_name,
        '\nRelations: ',
        '; '.join(total_relations),
        "\n###output: ",
        "\nA: ",
    ]
    return "".join(pieces)


def construct_entity_score_prompt(question, relation, entity_candidates):
    """Fill ``score_entity_candidates_prompt`` positionally.

    The positional arguments are: question, relation, and the
    ``"; "``-joined candidate names.
    """
    candidates_text = "; ".join(entity_candidates)
    return score_entity_candidates_prompt.format(question, relation, candidates_text)


def reason_search_prompt(question):
    """Return ``query_reason_prompt`` formatted with ``question`` as its only positional argument."""
    template = query_reason_prompt
    return template.format(question)

def clean_reason(result):
    """Extract ``instructionN: ...`` lines from an LLM response.

    Returns one of:
    - ``(True, lines)`` with each matched line, excluding the newline;
    - ``(False, [])`` if nothing matched;
    - ``(False, "instruction 提取出错")`` if the regex search raised, e.g.
      because ``result`` is not a string.
    """
    try:
        found = re.findall(r"instruction\d+:[^\n]+", result)
    except Exception:
        return False, "instruction 提取出错"
    return (True, found) if found else (False, [])
def clean_reason_scores(result):
    """Parse every parenthesised value in ``result`` as a float score.

    Returns one of:
    - ``(True, scores)`` when at least one value was found and all of them parse;
    - ``(False, [])`` when there is nothing in parentheses;
    - ``(False, "reason_scores 提取报错")`` if extraction or any float
      conversion fails, e.g. for parenthesised text that is not a number.
    """
    try:
        scores = list(map(float, re.findall(r'\((.*?)\)', result)))
    except Exception:
        return False, "reason_scores 提取报错"
    if scores:
        return True, scores
    return False, []

# def entity_filter(question, reasons, top_entity, relation,  entitys, args):
#     prompt = entity_filter_prompt.format(args.num_retain_entity, question, "; ".join(reasons), top_entity, relation,  "; ".join(entitys))
#     result = run_llm(prompt, args.temperature_exploration, args.max_length, args.openai_api_keys, args.openai_url_base, args.LLM_type)
#     flag, reasons = clean_reason(result)
#     if flag:
#         return reasons
#     else:
#         save_2_jsonl_error(question, "推理出错", result, file_name=args.dataset)
#         return []
def reason_search(question, args):
    """Ask the LLM for a step-by-step plan for ``question``.

    Returns:
        The ``instructionN: ...`` lines parsed by ``clean_reason``, or ``[]``
        if none could be parsed.
    """
    response = run_llm(
        reason_search_prompt(question),
        args.temperature_exploration,
        args.max_length,
        args.openai_api_keys,
        args.openai_url_base,
        args.LLM_type,
    )
    ok, parsed = clean_reason(response)
    return parsed if ok else []
def reasons_score(question, reasons, args):
    """Ask the LLM to score each reasoning instruction for ``question``.

    The LLM is queried once. If ``clean_reason_scores`` cannot parse the
    output, the query is retried up to 3 more times. If every attempt fails,
    the last raw output is logged via ``save_2_jsonl_error`` and ``[]`` is
    returned.

    Returns:
        The list of float scores, or ``[]`` on failure.
    """
    prompt = reason_score_prompt.format(question, "\n".join(reasons))
    max_retries = 3  # extra attempts after the first one
    for _ in range(1 + max_retries):
        result = run_llm(prompt, args.temperature_exploration, args.max_length, args.openai_api_keys, args.openai_url_base, args.LLM_type)
        flag, scores = clean_reason_scores(result)
        if flag:
            # Return on any successful parse. The previous loop discarded a
            # success on the final retry and logged it as an error.
            return scores
    save_2_jsonl_error(question, "reasons_score提取出错", result, file_name=args.dataset)
    return []
# def reason_summary(question, args):
#     prompt = reason_summary_prompt.format(question, reasons)
#     result = run_llm(prompt, args.temperature_exploration, args.max_length, args.openai_api_keys, args.openai_url_base, args.LLM_type)
#     flag, reasons = clean_reason(result)
#     if flag:
#         return reasons
#     else:
#         return []
def relation_search_prune(entity_id, entity_name, pre_relations, pre_head, question, args, reasons):
    """Fetch an entity's relations from Freebase and let the LLM prune and score them.

    Args:
        entity_id: Freebase id of the topic entity.
        entity_name: readable name shown in the prompt.
        pre_relations: relations from the previous hop. They are removed from
            the tail relations when ``pre_head`` is truthy, otherwise from the
            head relations. Presumably this stops the search from walking
            straight back.
        pre_head: direction flag of the previous hop.
        question: the user question.
        args: run configuration. Read here: ``remove_unnecessary_rel``,
            ``prune_tools``, ``width`` and the LLM settings.
        reasons: unused by this function; kept for interface compatibility.

    Returns:
        The scored relation dicts produced by ``clean_relations``. Returns
        ``[]`` when there are no candidate relations or the LLM output could
        not be parsed.

    Raises:
        ValueError: if ``args.prune_tools`` is not ``"llm"``, the only pruning
            method implemented here. Previously this path crashed with an
            UnboundLocalError on ``flag``.
    """
    if args.prune_tools != "llm":
        raise ValueError(
            f"unsupported args.prune_tools={args.prune_tools!r}: only 'llm' is implemented"
        )

    head_relations = replace_relation_prefix(execurte_sparql_relation(entity_id, sparql_head_relations))
    tail_relations = replace_relation_prefix(execurte_sparql_relation(entity_id, sparql_tail_relations))

    if args.remove_unnecessary_rel:
        head_relations = [rel for rel in head_relations if not abandon_rels(rel)]
        tail_relations = [rel for rel in tail_relations if not abandon_rels(rel)]

    if pre_head:
        tail_relations = list(set(tail_relations) - set(pre_relations))
    else:
        head_relations = list(set(head_relations) - set(pre_relations))

    head_relations = list(set(head_relations))
    tail_relations = list(set(tail_relations))
    total_relations = sorted(head_relations + tail_relations)  # stable order in the prompt

    if not total_relations:
        return []  # nothing to prune; avoid a pointless LLM call

    prompt = construct_relation_prune_prompt(question, entity_name, total_relations, args)
    result = run_llm(prompt, args.temperature_exploration, args.max_length, args.openai_api_keys, args.openai_url_base, args.LLM_type)
    flag, retrieve_relations_with_scores = clean_relations(result, entity_id, head_relations)
    # On failure, the output format was wrong or max_length was too small.
    return retrieve_relations_with_scores if flag else []
    
    
def entity_search(entity, relation, head=True):
    """Return the Freebase mids connected to ``entity`` through ``relation``.

    Args:
        entity: id of the known entity.
        relation: relation id, without the ``ns:`` prefix.
        head: direction of the lookup.
            - True: ``entity`` is the subject; returns the objects of
              ``entity relation ?x``.
            - False: ``entity`` is the object; returns the subjects of
              ``?x relation entity``.

    Returns:
        Ids starting with ``"m."``; literals and other values are dropped.
    """
    if head:
        bindings = execurte_sparql_entity(entity, relation, sparql_tail_entities_extract)
    else:
        # sparql_head_entities_extract is "?tailEntity ns:%s ns:%s": the first
        # placeholder is the predicate and the second is the object.
        # execurte_sparql_entity fills the template as % (arg1, arg2), so pass
        # (relation, entity). The old (entity, relation) order put the entity
        # id in predicate position and matched nothing.
        bindings = execurte_sparql_entity(relation, entity, sparql_head_entities_extract)

    ids = replace_entities_prefix(bindings)
    return [eid for eid in ids if eid.startswith("m.")]


def entity_score(question, entity_candidates_id, score, relation, args, reasons):
    """Score candidate entities reached via ``relation``, scaled by ``score``.

    Args:
        question: the user question.
        entity_candidates_id: ids of the candidate entities.
        score: score of the relation that led here. Every candidate score is
            multiplied by it.
        relation: relation name shown to the LLM.
        args: run configuration (``prune_tools`` and LLM settings).
        reasons: reasoning hints. When empty, or when its first element is
            ``" "``, the plain scoring prompt is used instead of the
            reason-aware one.

    Returns:
        ``(scores, names, ids)``, three lists aligned by index. Special cases:
        - empty input: ``([0.0], [], [])``;
        - every candidate judged unknown by ``all_unknown_entity``: a uniform
          ``score / n`` for each;
        - exactly one named candidate: ``([score], [name], [id])``.

        Returns ``[]`` (not a tuple) if the LLM output could not be parsed.

    Raises:
        ValueError: if several named candidates remain and ``args.prune_tools``
            is not ``"llm"`` (previously an implicit ``None``).
    """
    if not entity_candidates_id:
        # Guard the 1/len(...) below against division by zero.
        return [0.0], [], []

    entity_candidates = [id2entity_name_or_type(eid) for eid in entity_candidates_id]
    if all_unknown_entity(entity_candidates):
        return [1/len(entity_candidates) * score] * len(entity_candidates), entity_candidates, entity_candidates_id

    # Drop unnamed candidates ("UnName_Entity", as returned by
    # id2entity_name_or_type) together with their ids. Filtering only the names
    # left the id list longer, so ids got paired with the wrong names.
    pairs = [
        (name, eid)
        for name, eid in zip(entity_candidates, entity_candidates_id)
        if name != "UnName_Entity"
    ]
    if len(pairs) == 1:
        return [score], [pairs[0][0]], [pairs[0][1]]
    if not pairs:
        return [0.0], [], []

    # Sort by name (then id) so the prompt order is deterministic; ids stay aligned.
    pairs.sort()
    entity_candidates = [name for name, _ in pairs]
    entity_candidates_id = [eid for _, eid in pairs]

    if args.prune_tools != "llm":
        raise ValueError(
            f"unsupported args.prune_tools={args.prune_tools!r}: only 'llm' is implemented"
        )

    if not reasons or reasons[0] == " ":
        prompt = construct_entity_score_prompt(question, relation, entity_candidates)
    else:
        prompt = construct_entity_score_prompt_reason(question, relation, entity_candidates, reasons)
    result = run_llm(prompt, args.temperature_exploration, args.max_length, args.openai_api_keys, args.openai_url_base, args.LLM_type)
    cleaned = clean_scores(result, entity_candidates)  # parse once: (flag, scores)
    if not cleaned[0]:
        return []
    return [float(x) * score for x in cleaned[1]], entity_candidates, entity_candidates_id


    
def update_history(entity_candidates, entity, scores, entity_candidates_id, total_candidates, total_scores, total_relations, total_entities_id, total_topic_entities, total_head):
    """Append one relation's scored candidates to the running accumulators.

    Args:
        entity_candidates: candidate names. If empty, a single ``"[FINISH]"``
            placeholder with id ``"[FINISH_ID]"`` is recorded instead.
        entity: scored-relation dict with ``entity``, ``relation`` and ``head`` keys.
        scores: appended to ``total_scores`` as-is; not checked against the
            number of recorded candidates.
        entity_candidates_id: ids aligned with ``entity_candidates``.
        total_candidates, total_scores, total_relations, total_entities_id,
        total_topic_entities, total_head: accumulator lists, extended in place.

    Returns:
        The six accumulator lists (the same objects that were passed in).
    """
    if not entity_candidates:
        # Rebind rather than append, so the caller's list is not mutated.
        entity_candidates = ["[FINISH]"]
        entity_candidates_id = ["[FINISH_ID]"]
    count = len(entity_candidates)
    # Read every key before extending anything, so a missing key leaves the
    # accumulators untouched.
    candidates_relation = [entity['relation']] * count
    topic_entities = [entity['entity']] * count
    head_num = [entity['head']] * count
    total_candidates.extend(entity_candidates)
    total_scores.extend(scores)
    total_relations.extend(candidates_relation)
    total_entities_id.extend(entity_candidates_id)
    total_topic_entities.extend(topic_entities)
    total_head.extend(head_num)
    return total_candidates, total_scores, total_relations, total_entities_id, total_topic_entities, total_head


def half_stop(question, cluster_chain_of_entities, depth, args):
    """Stop searching at ``depth``, answer from the collected triplets and save the result.

    Prints a notice, asks ``generate_answer`` for the final answer and writes
    it with ``save_2_jsonl`` under ``args.dataset``.
    """
    notice = "No new knowledge added during search depth %d, stop searching." % depth
    print(notice)
    final_answer = generate_answer(question, cluster_chain_of_entities, args)
    save_2_jsonl(question, final_answer, cluster_chain_of_entities, file_name=args.dataset)


def generate_answer(question, cluster_chain_of_entities, args):
    """Ask the LLM for the final answer given the collected knowledge triplets.

    Each triplet's items are ``str()``-ed and joined with ``", "``, one
    triplet per line, after ``answer_prompt`` and the question. The prompt
    ends with ``'A: '`` placed directly after the last triplet, with no
    newline in between.
    """
    header = answer_prompt + question + '\n'
    triplets = '\n'.join(
        ', '.join(str(item) for item in triplet)
        for group in cluster_chain_of_entities
        for triplet in group
    )
    prompt = header + "\nKnowledge Triplets: " + triplets + 'A: '
    return run_llm(prompt, args.temperature_reasoning, args.max_length, args.openai_api_keys, args.openai_url_base, args.LLM_type)


def entity_prune(total_entities_id, total_relations, total_candidates, total_topic_entities, total_head, total_scores, args):
    """Keep the ``args.width`` best-scoring candidates and build the reasoning chains.

    All ``total_*`` lists are aligned by index. Processing steps:
    1. Rank candidates by score, descending; ties keep input order (the sort
       is stable).
    2. Truncate to ``args.width``.
    3. Drop candidates with a negative score.

    Returns:
        On success: ``(True, cluster_chain_of_entities, entities_id, relations, heads)``,
        where ``cluster_chain_of_entities`` is
        ``[[(topic_name, relation, candidate), ...]]``.
        If nothing survives: ``(False, [], [], [], [])``.
    """
    records = sorted(
        zip(total_entities_id, total_relations, total_candidates,
            total_topic_entities, total_head, total_scores),
        key=lambda rec: rec[5],
        reverse=True,
    )[:args.width]

    if all(rec[5] == 0 for rec in records):
        print('所有路径均为0分')  # every kept path scored 0

    kept = [rec for rec in records if rec[5] >= 0]
    if not kept:
        return False, [], [], [], []
    entities_id, relations, candidates, topics, heads, _scores = map(list, zip(*kept))

    # Resolve each distinct topic entity once: every lookup is a SPARQL
    # round-trip. dict.fromkeys preserves first-seen order.
    topic_names = {tid: id2entity_name_or_type(tid) for tid in dict.fromkeys(topics)}
    cluster_chain_of_entities = [
        [(topic_names[tid], rel, cand) for tid, rel, cand in zip(topics, relations, candidates)]
    ]
    return True, cluster_chain_of_entities, entities_id, relations, heads


def reasoning(question, cluster_chain_of_entities, args):
    """Ask the LLM whether the collected triplets are enough to answer ``question``.

    Returns:
        ``(verdict, response)``. ``verdict`` is True when
        ``if_true(extract_answer(response))`` is truthy; ``response`` is the
        raw LLM text.
    """
    base = prompt_evaluate + question
    triplets = '\n'.join(
        ', '.join(str(item) for item in triplet)
        for group in cluster_chain_of_entities
        for triplet in group
    )
    prompt = base + "\nKnowledge Triplets: " + triplets + 'A: '

    response = run_llm(prompt, args.temperature_reasoning, args.max_length, args.openai_api_keys, args.openai_url_base, args.LLM_type)
    verdict = bool(if_true(extract_answer(response)))
    return verdict, response
    



