from tqdm import tqdm
import argparse
from utils import *
from freebase_func import *
import random
from client import *

def str2bool(value):
    """Parse a command-line boolean option value.

    ``argparse`` with ``type=bool`` converts every non-empty string (including
    ``"False"``) to ``True``, so boolean options need an explicit converter.

    Args:
        value: The raw option string, or an already-boolean default.

    Returns:
        ``True`` for yes/true/t/y/1 and ``False`` for no/false/f/n/0
        (case-insensitive, surrounding whitespace ignored).

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = str(value).strip().lower()
    if lowered in ("yes", "true", "t", "y", "1"):
        return True
    if lowered in ("no", "false", "f", "n", "0"):
        return False
    raise argparse.ArgumentTypeError("Boolean value expected, got %r." % (value,))


if __name__ == '__main__':
    # Run ToG over (a prefix of) a dataset: for every question, explore the
    # knowledge graph guided by progressively longer instruction prefixes and
    # write exactly one result per question via save_2_jsonl / half_stop.
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset", type=str,
                        default="webqsp", help="choose the dataset.")
    parser.add_argument("--max_length", type=int,
                        default=256, help="the max length of LLMs output.")
    parser.add_argument("--temperature_exploration", type=float,
                        default=0.6, help="the temperature in exploration stage.")
    parser.add_argument("--temperature_reasoning", type=float,
                        default=0, help="the temperature in reasoning stage.")
    parser.add_argument("--width", type=int,
                        default=3, help="choose the search width of ToG.")
    parser.add_argument("--depth", type=int,
                        default=3, help="choose the search depth of ToG.")
    # type=bool would turn "--remove_unnecessary_rel False" into True.
    parser.add_argument("--remove_unnecessary_rel", type=str2bool,
                        default=True, help="whether removing unnecessary relations.")
    parser.add_argument("--LLM_type", type=str,
                        default="", help="base LLM model.")
    parser.add_argument("--openai_url_base", type=str,
                        default="", help="base URL of the OpenAI-compatible API.")
    parser.add_argument("--openai_api_keys", type=str,
                        default="", help="if the LLM_type is gpt-3.5-turbo or gpt-4, you need add your own openai api keys.")
    parser.add_argument("--num_retain_entity", type=int,
                        default=5, help="Number of entities retained during entities search.")
    parser.add_argument("--prune_tools", type=str,
                        default="llm", help="prune tools for ToG, can be llm (same as LLM_type), bm25 or sentencebert.")
    parser.add_argument("--num_samples", type=int,
                        default=1000, help="maximum number of dataset questions to process.")
    args = parser.parse_args()

    datas, question_string = prepare_dataset(args.dataset)
    print("Start Running ToG on %s dataset." % args.dataset)
    for data in tqdm(datas[:args.num_samples]):
        question = data[question_string]
        topic_entity = data['topic_entity']
        cluster_chain_of_entities = []
        if not topic_entity:
            # No entity to start graph exploration from: answer directly.
            results = generate_without_explored_paths(question, args)
            save_2_jsonl(question, results, [], file_name=args.dataset)
            continue
        pre_relations = []
        pre_heads = [-1] * len(topic_entity)
        # True once a result for this question has been written.
        flag_printed = False
        instruction = reason_search(question, args)
        # Cumulative prefixes: element k joins instruction pieces 0..k.
        instruction_all = ["".join(instruction[:index + 1]) for index in range(len(instruction))]
        # NOTE(review): topic_entity, pre_relations, pre_heads and
        # cluster_chain_of_entities are not reset between instruction prefixes,
        # so a later prefix continues from where the previous one stopped —
        # confirm this is intended.
        last_reason_index = len(instruction_all) - 1
        for reason_index, reason in enumerate(instruction_all):
            # Only the last (complete) instruction may conclude with half_stop;
            # for earlier prefixes a dead end falls through to the next prefix.
            # Index-based so duplicate prefixes cannot trigger it early.
            reason_finish = reason_index == last_reason_index
            for depth in range(1, args.depth + 1):
                # Score the relations around every unfinished topic entity.
                current_entity_relations_list = []
                for i, entity in enumerate(topic_entity):
                    if entity != "[FINISH_ID]":
                        retrieve_relations_with_scores = relation_search_prune(
                            entity, topic_entity[entity], pre_relations,
                            pre_heads[i], question, args, reason)
                        current_entity_relations_list.extend(retrieve_relations_with_scores)

                total_candidates = []
                total_scores = []
                total_relations = []
                total_entities_id = []
                total_topic_entities = []
                total_head = []

                # Expand each scored relation into candidate tail/head entities
                # and score those candidates.
                for entity in current_entity_relations_list:
                    entity_candidates_id = entity_search(
                        entity['entity'], entity['relation'], bool(entity['head']))

                    if args.prune_tools == "llm" and len(entity_candidates_id) >= 20:
                        # Down-sample large candidate sets before LLM scoring;
                        # clamp k because random.sample raises if k > population.
                        entity_candidates_id = random.sample(
                            entity_candidates_id,
                            min(args.num_retain_entity, len(entity_candidates_id)))

                    if not entity_candidates_id:
                        continue
                    scores, entity_candidates, entity_candidates_id = entity_score(
                        question, entity_candidates_id, entity['score'],
                        entity['relation'], args, reason)

                    (total_candidates, total_scores, total_relations,
                     total_entities_id, total_topic_entities, total_head) = update_history(
                        entity_candidates, entity, scores, entity_candidates_id,
                        total_candidates, total_scores, total_relations,
                        total_entities_id, total_topic_entities, total_head)

                if not total_candidates:
                    # Dead end: conclude only on the final instruction,
                    # otherwise move on to the next instruction prefix.
                    if reason_finish:
                        half_stop(question, cluster_chain_of_entities, depth, args)
                        flag_printed = True
                    break

                # Keep the best-scoring chains for the next hop.
                flag, chain_of_entities, entities_id, pre_relations, pre_heads = entity_prune(
                    total_entities_id, total_relations, total_candidates,
                    total_topic_entities, total_head, total_scores, args)
                cluster_chain_of_entities.append(chain_of_entities)
                if not flag:
                    if reason_finish:
                        half_stop(question, cluster_chain_of_entities, depth, args)
                        flag_printed = True
                    break

                # Ask whether the triples gathered so far (plus the model's own
                # knowledge) suffice to answer the question.
                stop, _ = reasoning(question, cluster_chain_of_entities, args)
                if stop:
                    print("ToG stoped at depth %d." % depth)
                    results = generate_answer(question, cluster_chain_of_entities, args)
                    save_2_jsonl(question, results, cluster_chain_of_entities, file_name=args.dataset)
                    flag_printed = True
                    break

                print("depth %d still not find the answer." % depth)
                flag_finish, entities_id = if_finish_list(entities_id)
                if flag_finish:
                    half_stop(question, cluster_chain_of_entities, depth, args)
                    flag_printed = True
                    # Without this break the next depth re-explores the stale
                    # topic entities and may write a second result.
                    break
                topic_entity = {entity: id2entity_name_or_type(entity) for entity in entities_id}
            if flag_printed:
                break
        if not flag_printed:
            # Exploration never produced a result: fall back to a direct answer.
            results = generate_without_explored_paths(question, args)
            save_2_jsonl(question, results, [], file_name=args.dataset)


