import json, os, re, requests, time, random, csv, sys, argparse
from pathlib import Path
from openai import OpenAI
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
import faiss


# Prompt sent to the QA model. run_test fills {retrieved_triples} and {question}
# with str.format. The question is the last thing in the prompt, so anything that
# truncates the tail of the formatted prompt (e.g. a character cap) removes it.
# NOTE(review): the "…" line below is sent to the model verbatim -- confirm it is
# intentional and not a placeholder for elided instructions.
PROMPT_TEMPLATE = """
You are an expert in the field of knowledge graphs. Below is a list of triples retrieved from a knowledge graph. Please answer the following question based strictly on this information.

Please follow these instructions:
1. Carefully analyze the given knowledge graph triples and answer only based on the provided content. Do not fabricate or infer information that is not explicitly stated.
2. There may be multiple correct answers. Please provide a complete and accurate response.

Retrieved triples:
{retrieved_triples}
…

Please answer the following question:
{question}
"""



def expand_neighbors_via_faiss(seed_triples, full_triple_str_list, model, query_vec, triple_str2tuple, index, topk=10, max_triples=150):
    """Expand seed triples by one hop, ranking the new neighbours by query similarity.

    A triple is a one-hop neighbour when its subject or object appears (as subject
    or object) in one of the seed triples. Neighbours are ordered by their position
    in ``index.search(query_vec, ...)`` over the whole graph.

    Args:
        seed_triples: Triple strings retrieved so far, e.g. ``"(s, r, o)"``.
        full_triple_str_list: Every triple string of the graph; position ``i`` must
            correspond to vector ``i`` stored in ``index``.
        model: Unused; kept for call-site compatibility.
        query_vec: Query embedding, passed unchanged to ``index.search``.
        triple_str2tuple: Maps a triple string to its ``(s, r, o)`` tuple. Strings
            missing from the map are ignored.
        index: FAISS-style index over ``full_triple_str_list`` (only ``search`` is used).
        topk: Unused; kept for call-site compatibility.
        max_triples: Upper bound on the number of returned triples, seeds included.

    Returns:
        If no neighbour exists, ``seed_triples`` unchanged. Otherwise a new list made of
        the de-duplicated seeds (original order) followed by neighbours best-first,
        cut to ``max_triples`` entries.
    """
    seed_set = set(seed_triples)  # O(1) membership in the full-graph scan below

    # 1. Entities mentioned by the seed triples.
    entities = set()
    for triple_str in seed_triples:
        triple = triple_str2tuple.get(triple_str)
        if triple is None:
            continue
        s, _, o = triple
        entities.update((s, o))

    # 2. Every non-seed triple touching one of those entities is a candidate.
    candidate_neighbors = set()
    for triple_str in full_triple_str_list:
        if triple_str in seed_set:
            continue
        triple = triple_str2tuple.get(triple_str)
        if triple is None:
            continue
        s, _, o = triple
        if s in entities or o in entities:
            candidate_neighbors.add(triple_str)

    if not candidate_neighbors:
        return seed_triples

    # 3. Rank candidates by similarity of the whole graph to the query.
    _, indices = index.search(query_vec, len(full_triple_str_list))
    ranked_neighbors = []
    seen = set()
    for i in indices[0]:
        if len(ranked_neighbors) >= max_triples:
            break
        if i < 0:  # FAISS pads missing results with label -1
            continue
        triple_str = full_triple_str_list[i]
        if triple_str in candidate_neighbors and triple_str not in seen:
            ranked_neighbors.append(triple_str)
            seen.add(triple_str)

    # 4. Merge deterministically: seeds first, then neighbours best-first.
    #    Truncating list(set(...)) would keep an arbitrary, run-dependent subset
    #    and could drop the seeds themselves.
    merged = list(dict.fromkeys(seed_triples))
    merged.extend(ranked_neighbors)
    return merged[:max_triples]


def query_llm(prompt, qa_model_name, api_base, served_model_name, max_prompt_chars=4096, max_retries=None):
    """Send ``prompt`` to an OpenAI-compatible completions endpoint and return its text.

    Args:
        prompt: Full prompt text.
        qa_model_name: Label used only in the log message.
        api_base: Base URL of the OpenAI-compatible server.
        served_model_name: Model name passed to ``client.completions.create``.
        max_prompt_chars: Prompts longer than this many characters are cut to this
            length. The cut keeps the *beginning* of the prompt, so whatever sits at
            the end (e.g. the question) is lost first; callers that need the tail
            should size the variable parts themselves. ``None`` disables the cut.
        max_retries: How many times to retry a failed request before re-raising the
            last error. ``None`` (default) retries forever, as before.

    Returns:
        The ``text`` of the first completion choice (up to 1024 new tokens requested).
    """
    print(f"Querying {qa_model_name} model...")

    if max_prompt_chars is not None and len(prompt) > max_prompt_chars:
        prompt = prompt[:max_prompt_chars]

    # Placeholder key; presumably the target server does not validate it -- confirm.
    client = OpenAI(api_key="EMPTY", base_url=api_base)

    attempt = 0
    while True:
        try:
            response = client.completions.create(
                model=served_model_name,
                prompt=prompt,
                stream=False,
                max_tokens=1024,
            )
            return response.choices[0].text
        except Exception as e:
            # Without a cap a permanent failure (bad model name, server down)
            # would loop forever; callers can now bound it.
            if max_retries is not None and attempt >= max_retries:
                raise
            attempt += 1
            print(f"{e},请求失败，进行重试...")
            time.sleep(random.uniform(1, 3))


# Number of expansion hops applied after the initial top-10 retrieval, per RAG method.
_RAG_HOPS = {"naive_rag": 0, "rag_2hop": 1, "rag_3hop": 2}
_METAQA_DATASETS = ("metaqa_1hop", "metaqa_2hop", "metaqa_3hop")
# Must match the character cap query_llm applies to the whole prompt by default.
_MAX_PROMPT_CHARS = 4096


def _format_triples(triples):
    """Render triples as ``(head, relation, tail)`` lines joined by newlines."""
    return "\n".join(f"({t[0]}, {t[1]}, {t[2]})" for t in triples)


def _build_faiss_index(model, triple_strs, label=""):
    """Embed ``triple_strs`` with ``model`` and load them into an exact-L2 FAISS index.

    Vector ``i`` of the returned index corresponds to ``triple_strs[i]``.
    """
    print("Encoding graph texts...")
    embeddings = model.encode(triple_strs, convert_to_numpy=True, show_progress_bar=True)
    print(f"Building {label}FAISS index...")
    index = faiss.IndexFlatL2(embeddings.shape[1])
    index.add(embeddings)
    return index


def _retrieve_triples(question, model, index, triple_strs, str2tuple, hops, k=10):
    """Return the ``k`` triples nearest to ``question``, then expand them ``hops`` times.

    When the graph holds fewer than ``k`` vectors FAISS pads the result with label -1;
    those slots are skipped instead of silently aliasing the last triple.
    """
    print("Searching...")
    query_vec = model.encode([question], convert_to_numpy=True)
    _, indices = index.search(query_vec, k)
    results = [triple_strs[i] for i in indices[0] if i >= 0]
    for _ in range(hops):
        results = expand_neighbors_via_faiss(results, triple_strs, model, query_vec, str2tuple, index, topk=10)
    return results


def _fit_triples_to_prompt(retrieved_triples, question, max_prompt_chars=_MAX_PROMPT_CHARS):
    """Trim ``retrieved_triples`` so the formatted prompt fits in ``max_prompt_chars``.

    query_llm keeps only the first characters of the prompt and the question is the
    last field of PROMPT_TEMPLATE, so over-long context used to cut the question off.
    Trimming the triples (at a line boundary when possible) keeps the question intact.
    """
    overhead = len(PROMPT_TEMPLATE.format(retrieved_triples="", question=question))
    budget = max(max_prompt_chars - overhead, 0)
    if len(retrieved_triples) <= budget:
        return retrieved_triples
    cut = retrieved_triples[:budget]
    last_newline = cut.rfind("\n")
    return cut[:last_newline] if last_newline > 0 else cut


def _load_gretriever_triples(subg_path):
    """Parse a cached G-Retriever subgraph description into (head, relation, tail) tuples.

    Non-empty lines before the ``src,edge_attr,dst`` header are read as
    ``node_id,node_attr``; lines after it are edges whose endpoints are node ids,
    resolved to node attributes (``UNKNOWN(<id>)`` when the id is unknown).
    """
    node_id_to_attr = {}
    triples = []
    in_node_section = True
    with open(subg_path, 'r', encoding='utf-8') as f:
        for raw_line in f:
            line = raw_line.strip()
            if not line:
                continue
            if line.startswith('src,edge_attr,dst'):
                in_node_section = False
                continue
            if in_node_section:
                node_id, node_attr = line.split(',', 1)
                node_id_to_attr[node_id.strip()] = node_attr.strip()
                continue
            # Split src off the left and dst off the right so a comma inside the
            # relation text stays in edge_attr instead of leaking into dst.
            src, rest = line.split(',', 1)
            edge_attr, dst = rest.rsplit(',', 1)
            src, dst = src.strip(), dst.strip()
            triples.append((
                node_id_to_attr.get(src, f"UNKNOWN({src})"),
                edge_attr.strip(),
                node_id_to_attr.get(dst, f"UNKNOWN({dst})"),
            ))
    return triples


def _load_ours_triples(dataset_name, item_id):
    """Return the ``now_state`` triples of the last step logged for ``item_id`` (or [])."""
    subg_path = f"log/test_log/qwen3/ours/{dataset_name}/{item_id}.json"
    with open(subg_path, 'r', encoding='utf-8') as f:
        subg_data = json.load(f)
    try:
        return subg_data[-1]["now_state"]
    except (IndexError, KeyError, TypeError):
        # Empty log, missing key, or unexpected structure: no context for this question.
        return []


def run_test(qa_model_name, experiment_method, dataset_name, api_base, served_model_name):
    """Answer every test question with the LLM and append the results to a JSONL log.

    The log is ``log/qa_test_log_new/{model}_{method}_{dataset}.jsonl``. Its first
    line is a header and every later line is one answered question, so a rerun
    resumes after the last logged question.

    Args:
        qa_model_name: Label of the QA model (used in the log path, header, messages).
        experiment_method: How context triples are obtained: "no_graph",
            "no_retriever", "naive_rag", "rag_2hop", "rag_3hop", "g_retriever" or "ours".
        dataset_name: "webqsp"/"cwq" (loaded from the HF hub, each item carries its
            own ``graph``) or a local dataset under ``data/{dataset_name}/`` that uses
            the shared MetaQA KB.
        api_base: Base URL of the OpenAI-compatible server.
        served_model_name: Model name requested from that server.

    Raises:
        ValueError: If ``experiment_method`` is not one of the supported methods.
    """
    uses_item_graph = dataset_name in ("webqsp", "cwq")
    is_rag = experiment_method in _RAG_HOPS
    hops = _RAG_HOPS.get(experiment_method, 0)

    # --- Test questions (and, for local datasets, the shared MetaQA KB, loaded once).
    metaqa_kb = None
    if uses_item_graph:
        print("Loading dataset...")
        test_data = load_dataset(f"rmanluo/RoG-{dataset_name}")['test']
    else:
        with open(f"data/{dataset_name}/qa_test_hf_format.json", "r", encoding="utf-8") as f:
            test_data = json.load(f)
        with open("data/metaqa_kb.json", "r", encoding="utf-8") as f:
            metaqa_kb = json.load(f)

    # --- Retrieval setup.
    model = None
    metaqa_strs = metaqa_str2tuple = metaqa_index = None
    if is_rag:
        print("Loading Embedding model...")
        model = SentenceTransformer("sentence-transformers/gtr-t5-large")
        if dataset_name in _METAQA_DATASETS:
            metaqa_strs = [f"({t[0]}, {t[1]}, {t[2]})" for t in metaqa_kb]
            metaqa_str2tuple = dict(zip(metaqa_strs, ((t[0], t[1], t[2]) for t in metaqa_kb)))
            metaqa_index = _build_faiss_index(model, metaqa_strs, label="metaqa ")

    # The no_retriever baseline only ever uses the first 4096 characters of the KB,
    # so render it once instead of joining the whole KB for every question.
    metaqa_no_retriever_text = ""
    if experiment_method == "no_retriever" and not uses_item_graph:
        metaqa_no_retriever_text = _format_triples(metaqa_kb)[:4096]

    # --- Log file with header line; its line count drives resumption.
    save_path = Path(f"log/qa_test_log_new/{qa_model_name}_{experiment_method}_{dataset_name}.jsonl")
    save_path.parent.mkdir(parents=True, exist_ok=True)
    data_num = len(test_data)
    if not save_path.exists():
        header = {"qa_model_name": qa_model_name, "experiment_method": experiment_method, "dataset_name": dataset_name, "test_data_num": data_num}
        with save_path.open("w", encoding="utf-8") as f:
            f.write(json.dumps(header, ensure_ascii=False) + "\n")
    with save_path.open("r", encoding="utf-8") as f:
        done_count = sum(1 for _ in f) - 1  # minus the header line

    for idx, data_item in enumerate(test_data):
        if idx < done_count:
            continue
        print(f"PROCESSING {idx} / {data_num}")

        question = data_item['question']
        true_answer = data_item['a_entity']

        if experiment_method == "no_graph":
            retrieved_triples = ""

        elif experiment_method == "no_retriever":
            if uses_item_graph:
                retrieved_triples = _format_triples(data_item['graph'])[:4096]
            else:
                retrieved_triples = metaqa_no_retriever_text

        elif is_rag:
            if dataset_name in _METAQA_DATASETS:
                results = _retrieve_triples(question, model, metaqa_index, metaqa_strs, metaqa_str2tuple, hops)
            else:
                now_graph = data_item['graph']
                graph_strs = [f"({t[0]}, {t[1]}, {t[2]})" for t in now_graph]
                str2tuple = dict(zip(graph_strs, ((t[0], t[1], t[2]) for t in now_graph)))
                results = []
                # A failure on one question's graph (e.g. an empty graph) degrades
                # to "no triples" instead of aborting the whole run.
                try:
                    index = _build_faiss_index(model, graph_strs)
                    results = _retrieve_triples(question, model, index, graph_strs, str2tuple, hops)
                except Exception as e:
                    print(f"Error in FAISS index building or searching: {e}")
            retrieved_triples = "\n".join(results)

        elif experiment_method == "g_retriever":
            subg_path = f"G-retriever/dataset/0511_gretriever_cache_subg/{dataset_name}/cached_desc/{idx}.txt"
            retrieved_triples = _format_triples(_load_gretriever_triples(subg_path))

        elif experiment_method == "ours":
            retrieved_triples = "\n".join(_load_ours_triples(dataset_name, data_item['id']))

        else:
            raise ValueError(f"Unknown experiment_method: {experiment_method}")

        # Trim the context rather than letting the prompt cap cut off the question.
        retrieved_triples = _fit_triples_to_prompt(retrieved_triples, question)
        prompt = PROMPT_TEMPLATE.format(retrieved_triples=retrieved_triples, question=question)
        res = query_llm(prompt, qa_model_name, api_base, served_model_name)
        print(res)

        record = {
            "id": data_item['id'],
            "question": question,
            "true_answer": true_answer,
            "model_response": res,
        }
        # Append immediately so an interrupted run can resume from here.
        with save_path.open("a", encoding="utf-8") as f:
            f.write(json.dumps(record, ensure_ascii=False) + "\n")
        
        
if __name__ == "__main__":
    # Command-line entry point. argparse validates the enumerated options itself
    # (and lists them in --help) instead of checking them after parsing.
    parser = argparse.ArgumentParser(description="QA experiment")
    parser.add_argument('--model', type=str, required=True,
                        choices=["deepseek", "qwen", "llama", "tuned"],
                        help='Model to use')
    parser.add_argument('--method', type=str, required=True,
                        choices=["no_graph", "naive_rag", "g_retriever", "ours", "rag_2hop", "rag_3hop", "no_retriever"],
                        help='Method to apply')
    parser.add_argument('--dataset', type=str, required=True,
                        choices=["webqsp", "cwq", "metaqa_1hop", "metaqa_2hop", "metaqa_3hop"],
                        help='Dataset to use')
    parser.add_argument('--api_base', type=str, required=True, help='API base URL for the model')
    parser.add_argument('--served_model_name', type=str, required=True, help='Served model name')
    args = parser.parse_args()

    print(f"Running test with model: {args.model}, method: {args.method}, dataset: {args.dataset}")
    run_test(args.model, args.method, args.dataset, args.api_base, args.served_model_name)
    

    