import argparse, random, json, re, os, time
from datasets import load_dataset
from rank_bm25 import BM25Okapi
from tqdm import tqdm
from bert_score import score as bert_score
from openai import OpenAI
from transformers import logging

# Silence the warnings that Hugging Face `transformers` prints while loading
# models (`logging` here is `transformers.logging`, not the stdlib module).
logging.set_verbosity_error()

# Few-shot prompt prefix for the sufficiency check: the model is asked whether
# the retrieved triplets are enough to answer the question, and both examples
# wrap the verdict (and the final answer) in curly braces, e.g. "{Yes}" / "{No}".
# main() appends the actual question and triplets after "### Your task:".
# NOTE: this string is sent to the model verbatim; do not reformat it.
prompt_evaluate="""Given a question and the associated retrieved knowledge graph triplets (entity, relation, entity), you are asked to answer whether it's sufficient for you to answer the question with these triplets and your knowledge (Yes or No). If the answer is Yes, you should also provide the answer in your response.
### Examples:

Question: Find the person who said \"Taste cannot be controlled by law\", what did this person die from?
Knowledge Triplets: 
Taste cannot be controlled by law., media_common.quotation.author, Thomas Jefferson
Answer: {No}. Based on the given knowledge triplets, it's not sufficient to answer the entire question. The triplets only provide information about the person who said "Taste cannot be controlled by law," which is Thomas Jefferson. To answer the second part of the question, it's necessary to have additional knowledge about where Thomas Jefferson's dead.

Question: The artist nominated for The Long Winter lived where?
Knowledge Triplets: 
The Long Winter, book.written_work.author, Laura Ingalls Wilder
Laura Ingalls Wilder, people.person.places_lived, Unknown-Entity
Unknown-Entity, people.place_lived.location, De Smet
Answer: {Yes}. Based on the given knowledge triplets, the author of The Long Winter, Laura Ingalls Wilder, lived in De Smet. Therefore, the answer to the question is {De Smet}.

### Your task:
"""

def query_llm(prompt, args):
    """Send ``prompt`` to a locally served completion model and return its text.

    The request is retried indefinitely (with a random 1-3 s pause) until the
    server answers, so this call blocks while the backend is unreachable.

    Args:
        prompt: Full prompt text. Only the first 4096 characters are sent.
        args: Namespace providing ``LLM_type`` ("qwen", "llama" or "tuned")
            and ``max_length`` (forwarded as ``max_tokens``).

    Returns:
        The text of the first completion choice.

    Raises:
        ValueError: If ``args.LLM_type`` is not a supported backend.
    """
    # Base URL of the OpenAI-compatible server hosting each backend.
    api_bases = {
        "qwen": "http://localhost:9102/v1",
        "llama": "http://localhost:9104/v1",
        "tuned": "http://localhost:9105/v1",
    }
    qa_model_name = args.LLM_type
    if qa_model_name not in api_bases:
        # Fail with a clear message instead of an UnboundLocalError further down.
        raise ValueError(
            f"Unsupported LLM_type {qa_model_name!r}; "
            f"expected one of {sorted(api_bases)}"
        )
    print(f"Querying {qa_model_name} model...")

    # NOTE(review): truncation keeps the head of the prompt and drops the tail,
    # which is where callers usually place the question -- confirm intended.
    prompt = prompt[:4096]

    # Every backend is currently addressed with an empty model name.
    served_model_name = ""
    client = OpenAI(
        api_key="EMPTY",  # placeholder key for the local servers
        base_url=api_bases[qa_model_name],
    )
    while True:
        try:
            response = client.completions.create(
                model=served_model_name,
                prompt=prompt,
                stream=False,
                max_tokens=args.max_length,
            )
            return response.choices[0].text
        except Exception as e:
            # Deliberate best-effort: log, back off briefly, and retry.
            print(f"{e},请求失败，进行重试...")
            time.sleep(random.uniform(1, 3))
    
def if_true(prompt):
    """Return True when an LLM reply carries an affirmative ("yes") verdict.

    The evaluation prompt asks the model to put its verdict in curly braces
    (``{Yes}`` / ``{No}``) followed by an explanation, so the verdict is read
    from the first ``{...}`` span when one is present. A reply without braces
    is compared as a whole, which keeps bare "yes" replies working.

    Args:
        prompt: The model's response text.

    Returns:
        True if the verdict, ignoring case and spaces, equals "yes".
    """
    verdict = prompt
    start = prompt.find("{")
    if start != -1:
        end = prompt.find("}", start + 1)
        if end != -1:
            verdict = prompt[start + 1:end]
    return verdict.lower().strip().replace(" ", "") == "yes"


def _build_entity_index(graph_triples):
    """Map each head entity to the list of (head, relation, tail) triples leaving it."""
    index = {}
    for h, r, t in graph_triples:
        index.setdefault(h, []).append((h, r, t))
    return index


def _format_path(path):
    """Render a path tuple (e0, r1, e1, r2, e2, ...) as "(e0, r1, e1) (e1, r2, e2) ..."."""
    return " ".join(
        f"({path[i]}, {path[i + 1]}, {path[i + 2]})"
        for i in range(0, len(path) - 2, 2)
    )


def main(args):
    """Run the path-search QA loop over the test split and append results to a JSONL log.

    For every question the search starts from its topic entities and, for up
    to ``args.depth`` rounds, extends each retained path by one outgoing
    triple, scores the new triples against the question with BERTScore F1,
    keeps the ``args.num_retain_entity`` best paths, and asks the LLM whether
    those paths suffice. The search stops early once ``if_true`` accepts the
    reply. One JSON line per question is appended to
    ``log/qa_test_<LLM_type>/<LLM_type>_ToG_<dataset>.jsonl``; lines already
    present in that file are counted so an interrupted run can resume.

    Args:
        args: Parsed command-line namespace. Reads ``dataset``, ``LLM_type``,
            ``depth`` and ``num_retain_entity`` here (``query_llm`` also reads
            ``max_length``).
    """
    # Load the dataset.
    if "metaqa" in args.dataset:
        data_path = f"data/{args.dataset}/qa_test_hf_format.json"
        with open(data_path, 'r', encoding="utf-8") as f:
            test_data = json.load(f)
        with open("data/metaqa_kb.json", 'r', encoding="utf-8") as f:
            metaqa_kb = json.load(f)
        # MetaQA shares one knowledge base across all questions, so index it once.
        shared_index = _build_entity_index(metaqa_kb)
    else:
        print("Loading dataset...")
        dataset = load_dataset(f"rmanluo/RoG-{args.dataset}")
        test_data = dataset['test']
        shared_index = None  # each item carries its own sub-graph

    save_path = f"log/qa_test_{args.LLM_type}/{args.LLM_type}_ToG_{args.dataset}.jsonl"
    # Make sure the log directory exists before the first append.
    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    if os.path.exists(save_path):
        with open(save_path, "r", encoding="utf-8") as f:
            now_line_count = sum(1 for _ in f)
    else:
        now_line_count = 0

    # Item layout: data_item['id']: str, data_item['question']: str,
    # data_item['q_entity']: list[str], data_item['a_entity']: list[str],
    # data_item['graph']: list of (subject, relation, object) triples.
    for idx, data_item in enumerate(test_data):
        # Items 0 .. now_line_count-1 already have a line in the log.
        if idx < now_line_count:
            continue
        print(f"***************************** Processing {idx+1}/{len(test_data)}")
        question = data_item["question"]
        topic_entity_list = data_item["q_entity"]
        writer = {
            "id": data_item["id"],
            "question": question,
            "true_answer": data_item["a_entity"],
            "model_response": ""
        }
        # Graph index for fast lookup of an entity's outgoing triples.
        if shared_index is not None:
            entity_to_triples = shared_index
        else:
            entity_to_triples = _build_entity_index(data_item["graph"])

        # Path tuple (e0, r1, e1, ...) -> accumulated score; seeded with topic entities.
        path_to_score = {(ent,): 0 for ent in topic_entity_list}

        # Search depth.
        for depth in range(1, args.depth + 1):
            print(f"******************* Exploring depth {depth}")
            # Collect every one-hop extension of the currently retained paths.
            all_pairs = []
            all_keys = []
            for ent_list, now_score in path_to_score.items():
                pre_ent = ent_list[-1]  # last entity on the current path
                # .get(): an entity may have no outgoing triples or be absent from the graph.
                for h, r, t in entity_to_triples.get(pre_ent, ()):
                    if t in ent_list:
                        continue  # do not revisit something already on the path
                    all_pairs.append((ent_list, r, t, now_score))
                    all_keys.append(f"{h} {r} {t}.")
            if not all_keys:
                print(f"No new paths found at depth {depth}, skipping...")
                continue  # nothing to extend at this depth

            # Score all candidate triples against the question in one batch.
            P, R, F1 = bert_score([question] * len(all_keys), all_keys, lang="en", verbose=False)
            F1 = [f.item() for f in F1]
            new_path_to_score = {}
            for (ent_list, r, t, now_score), f1 in zip(all_pairs, F1):
                new_path_to_score[ent_list + (r, t)] = now_score + f1
            print("******************* Finish update path")

            # Keep only the best-scoring *extended* paths as the next frontier.
            sorted_paths = sorted(new_path_to_score.items(), key=lambda x: x[1], reverse=True)
            path_to_score = dict(sorted_paths[:args.num_retain_entity])

            # Build the prompt: one line per retained path.
            all_triples_str = "\n".join(_format_path(path) for path in path_to_score)
            prompt = prompt_evaluate + "\n" + f"Question: {question}\n" + "Knowledge Triplets\n" + all_triples_str + "\n\nAnswer: "

            # Ask the LLM; stop exploring once it says the triples suffice.
            model_res = query_llm(prompt, args)
            writer["model_response"] = model_res
            if if_true(model_res):
                break
        with open(save_path, 'a', encoding="utf-8") as f:
            f.write(json.dumps(writer, ensure_ascii=False) + "\n")
            
            
def str2bool(value):
    """Parse a command-line boolean such as "true"/"false", "yes"/"no" or "1"/"0".

    ``type=bool`` would treat every non-empty string (including "False") as
    True, so flags are parsed explicitly instead.

    Args:
        value: The raw argument string (an actual bool is returned unchanged).

    Returns:
        The parsed boolean. An empty string parses as False.

    Raises:
        argparse.ArgumentTypeError: If the text is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("true", "t", "yes", "y", "1"):
        return True
    if lowered in ("false", "f", "no", "n", "0", ""):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset", type=str,
                        default="webqsp", help="choose the dataset.")
    parser.add_argument("--max_length", type=int,
                        default=256, help="the max length of LLMs output.")
    parser.add_argument("--temperature_exploration", type=float,
                        default=0.4, help="the temperature in exploration stage.")
    parser.add_argument("--temperature_reasoning", type=float,
                        default=0, help="the temperature in reasoning stage.")
    parser.add_argument("--width", type=int,
                        default=3, help="choose the search width of ToG.")
    parser.add_argument("--depth", type=int,
                        default=3, help="choose the search depth of ToG.")
    # str2bool instead of bool: bool("False") is True.
    parser.add_argument("--remove_unnecessary_rel", type=str2bool,
                        default=True, help="whether removing unnecessary relations.")
    # Restrict to the backends query_llm knows, so a typo fails before any data is loaded.
    parser.add_argument("--LLM_type", type=str, choices=("qwen", "llama", "tuned"),
                        default="qwen", help="base LLM model.")
    parser.add_argument("--opeani_api_keys", type=str,
                        default="", help="if the LLM_type is gpt-3.5-turbo or gpt-4, you need add your own openai api keys.")
    parser.add_argument("--num_retain_entity", type=int,
                        default=5, help="Number of entities retained during entities search.")
    parser.add_argument("--prune_tools", type=str,
                        default="bm25", help="prune tools for ToG, can be llm (same as LLM_type), bm25 or sentencebert.")
    args = parser.parse_args()
    main(args)
    
