import json, os, argparse
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
import faiss
import random, re, os, time
from datasets import load_dataset
from rank_bm25 import BM25Okapi
from tqdm import tqdm
from bert_score import score as bert_score
from transformers import logging


def check_answer(model_response, true_answer):
    """Return True when ``model_response`` matches ``true_answer`` exactly.

    Both collections are compared case-insensitively as multisets of
    strings: ``None`` entries are dropped, and the remaining items must match
    one-to-one, ignoring order. A ``model_response`` that is not a list
    (e.g. a bare string or ``None``) is always treated as wrong.

    Args:
        model_response: The model's predicted answers, expected to be a list
            of strings (``None`` entries are ignored).
        true_answer: The list of gold answer strings (``None`` entries are
            ignored).

    Returns:
        ``True`` if the two collections are equal as case-insensitive
        multisets, ``False`` otherwise.
    """
    if not isinstance(model_response, list):
        return False
    predicted = [x.lower() for x in model_response if x is not None]
    expected = [x.lower() for x in true_answer if x is not None]
    # Multiset equality. A plain "same length and every gold item present"
    # check would accept gold ["a", "a"] against a prediction of ["a", "b"].
    return sorted(predicted) == sorted(expected)

def expand_neighbors_via_faiss(seed_triples, full_triple_str_list, model, query_vec, triple_str2tuple, index, topk=10):
    """Expand seed triples by one hop of entity-adjacent neighbour triples.

    A triple counts as a neighbour when its subject or object equals the
    subject or object of some seed triple. Neighbours are ordered by the
    ranking that ``index.search(query_vec, ...)`` returns over the whole graph.
    No cap is applied to the number of neighbours. (Truncating the result to
    150 triples was tried at some point and is currently disabled.)

    Args:
        seed_triples: Triple strings to expand from. Duplicates are collapsed.
        full_triple_str_list: Every triple string of the graph. Position ``i``
            is assumed to correspond to vector ``i`` stored in ``index``.
        model: Unused; kept for call-site compatibility.
        query_vec: Query embedding passed straight to ``index.search``.
        triple_str2tuple: Maps a triple string to ``(subject, relation, object)``.
            Strings missing from the map are ignored.
        index: A FAISS-style index exposing ``search(query, k)`` that returns
            ``(distances, indices)``.
        topk: Unused; kept for call-site compatibility.

    Returns:
        A duplicate-free list with the seed triples first (in their original
        order), followed by the neighbour triples in ranked order.
    """
    seed_list = list(dict.fromkeys(seed_triples))
    seed_set = set(seed_list)  # O(1) membership for the full-graph scan below

    # 1. Collect the entities touched by the seed triples.
    entities = set()
    for triple_str in seed_list:
        triple = triple_str2tuple.get(triple_str)
        if triple is None:
            continue
        s, _, o = triple
        entities.update((s, o))

    # 2. Every non-seed triple that shares a subject/object with those entities.
    candidate_neighbors = set()
    for triple_str in full_triple_str_list:
        if triple_str in seed_set:
            continue
        triple = triple_str2tuple.get(triple_str)
        if triple is None:
            continue
        s, _, o = triple
        if s in entities or o in entities:
            candidate_neighbors.add(triple_str)

    if not candidate_neighbors:
        return seed_list

    # 3. Rank the whole graph against the query; keep neighbours in that order.
    _, indices = index.search(query_vec, len(full_triple_str_list))
    ranked_neighbors = []
    seen = set()
    for i in indices[0]:
        # FAISS pads missing results with -1. Without this guard, Python's
        # negative indexing would silently pick up the last triple.
        if i < 0:
            continue
        triple_str = full_triple_str_list[i]
        if triple_str in candidate_neighbors and triple_str not in seen:
            ranked_neighbors.append(triple_str)
            seen.add(triple_str)

    # 4. Merge. Neighbours exclude seeds by construction, so no duplicates arise.
    return seed_list + ranked_neighbors

def test_ours(log_dir="log/ours_retriever_test_step_log"):
    """Print, per dataset, the average final-state size over correct answers.

    ``log_dir`` is expected to contain one sub-directory per dataset, each
    holding one JSON file per question that decodes to a list of step
    records. A question counts when its last step has
    ``extract_res["Action"] == "Finish"`` and ``check_answer`` accepts
    ``extract_res["Objects"]`` against ``true_answer``. For those questions,
    ``len(now_state)`` of the last step is accumulated.

    Args:
        log_dir: Root directory of the step logs. The default is the location
            that was previously hard-coded.
    """
    for dataset in sorted(os.listdir(log_dir)):
        dataset_dir = os.path.join(log_dir, dataset)
        if not os.path.isdir(dataset_dir):
            continue  # stray top-level files are not datasets
        count = 0
        data_num = 0
        for one in sorted(os.listdir(dataset_dir)):
            one_path = os.path.join(dataset_dir, one)
            with open(one_path, "r", encoding="utf-8") as f:
                data = json.load(f)
            if not data:
                continue  # run recorded no steps
            last_step = data[-1]
            if last_step["extract_res"]["Action"] != "Finish":
                continue
            model_answer_list = last_step["extract_res"]["Objects"]
            ground_truth_list = last_step["true_answer"]
            if check_answer(model_answer_list, ground_truth_list):
                count += len(last_step["now_state"])
                data_num += 1
        # Avoid ZeroDivisionError when no question in this dataset was correct.
        average = count / data_num if data_num else 0.0
        print(f"Dataset: {dataset}, Count: {count}, Data Num: {data_num}, average: {average}")

def _load_correct_ids(result_path):
    """Return the ``id`` of every JSONL record in *result_path* whose ``accuracy == 1.0``.

    Blank lines in the file are skipped.
    """
    with open(result_path, "r", encoding="utf-8") as f:
        records = [json.loads(line) for line in f if line.strip()]
    return {item["id"] for item in records if item["accuracy"] == 1.0}


def _parse_gretriever_subgraph(subg_path):
    """Parse a cached G-Retriever subgraph description into ``(src, rel, dst)`` name triples.

    Rows before the ``src,edge_attr,dst`` header are ``id,attr`` node rows.
    Rows after it are ``src_id,edge_attr,dst_id`` edge rows. Blank lines are
    ignored, and endpoint ids with no node row become ``UNKNOWN(<id>)``.
    """
    node_id_to_attr = {}
    triples = []
    is_node_section = True
    with open(subg_path, "r", encoding="utf-8") as f:
        for raw_line in f:
            line = raw_line.strip()
            if not line:
                continue
            if line.startswith('src,edge_attr,dst'):
                is_node_section = False
                continue
            if is_node_section:
                node_id, node_attr = line.split(',', 1)
                node_id_to_attr[node_id.strip()] = node_attr.strip()
            else:
                # Split the id columns off both ends so that a comma inside the
                # relation text cannot shift the destination id.
                src, rest = line.split(',', 1)
                edge_attr, dst = rest.rsplit(',', 1)
                src_name = node_id_to_attr.get(src.strip(), f"UNKNOWN({src.strip()})")
                dst_name = node_id_to_attr.get(dst.strip(), f"UNKNOWN({dst.strip()})")
                triples.append((src_name, edge_attr.strip(), dst_name))
    return triples


def test_g_retriever(dataset_name):
    """Print the average number of G-Retriever subgraph triples over correct answers.

    Questions with ``accuracy == 1.0`` in
    ``log/qa_format/tuned_g_retriever_{dataset_name}.jsonl`` are counted. The
    subgraph for the question at position ``idx`` of the test split is read
    from ``G-retriever/dataset/0511_gretriever_cache_subg/{dataset_name}/cached_desc/{idx}.txt``.

    Args:
        dataset_name: ``"cwq"`` or ``"webqsp"`` (loaded from the RoG Hugging Face
            datasets), or a name containing ``"metaqa"`` (loaded from
            ``data/{dataset_name}/qa_test_hf_format.json``). Any other name is a no-op.
    """
    is_rog = dataset_name in ("cwq", "webqsp")
    if not is_rog and "metaqa" not in dataset_name:
        return  # unsupported dataset name: nothing to report

    true_id = _load_correct_ids(f"log/qa_format/tuned_g_retriever_{dataset_name}.jsonl")
    if is_rog:
        test_data = load_dataset(f"rmanluo/RoG-{dataset_name}")['test']
    else:
        with open(f"data/{dataset_name}/qa_test_hf_format.json", "r", encoding="utf-8") as f:
            test_data = json.load(f)

    count = 0
    data_num = 0
    for idx, data_item in enumerate(test_data):
        if data_item['id'] not in true_id:
            continue
        subg_path = f"G-retriever/dataset/0511_gretriever_cache_subg/{dataset_name}/cached_desc/{idx}.txt"
        count += len(_parse_gretriever_subgraph(subg_path))
        data_num += 1
    # Avoid ZeroDivisionError when no question was answered correctly.
    average = count / data_num if data_num else 0.0
    print(f"Count: {count}, Data Num: {data_num}, average: {average}")

def test_rag(dataset_name, hops):
    """Report how many triples the FAISS-based multi-hop RAG retriever returns.

    Only questions with ``accuracy == 1.0`` in
    ``log/qa_format/tuned_rag_{hops}hop_{dataset_name}.jsonl`` are counted.
    Each question is embedded with ``sentence-transformers/gtr-t5-large``. The
    10 nearest triples are retrieved from an L2 FAISS index, and the result is
    then expanded ``hops`` times with ``expand_neighbors_via_faiss``. A running
    line is printed per question, and one final total is printed at the end.

    For ``"cwq"``/``"webqsp"``, every question is indexed over its own
    ``graph`` from the RoG dataset. Any other name is treated as MetaQA: the
    shared KB in ``data/metaqa_kb.json`` is encoded and indexed once, and
    questions come from ``data/{dataset_name}/qa_test_hf_format.json``.

    Args:
        dataset_name: ``"cwq"``, ``"webqsp"`` or a MetaQA dataset directory name.
        hops: Number of neighbour-expansion rounds.
    """
    print("Loading Embedding model...")
    model = SentenceTransformer("sentence-transformers/gtr-t5-large")

    with open(f"log/qa_format/tuned_rag_{hops}hop_{dataset_name}.jsonl", "r") as f:
        data = [json.loads(line) for line in f.readlines()]
    true_id = {item["id"] for item in data if item["accuracy"] == 1.0}

    count = 0
    data_num = 0
    if dataset_name == "cwq" or dataset_name == "webqsp":
        print("Loading dataset...")
        dataset = load_dataset(f"rmanluo/RoG-{dataset_name}")
        test_data = dataset['test']
        for idx, data_item in enumerate(test_data):
            if data_item['id'] not in true_id:
                continue
            question = data_item['question']
            now_graph = data_item['graph']
            allgraph_str_list = [f"({t[0]}, {t[1]}, {t[2]})" for t in now_graph]
            allgraph_triple_list = [(t[0], t[1], t[2]) for t in now_graph]
            triple_str2tuple = dict(zip(allgraph_str_list, allgraph_triple_list))
            print("Encoding graph texts...")
            embeddings = model.encode(allgraph_str_list, convert_to_numpy=True, show_progress_bar=True)
            try:
                print("Building FAISS index...")
                dimension = embeddings.shape[1]
                index = faiss.IndexFlatL2(dimension)
                index.add(embeddings)  # add the triple vectors to the index
                print("索引中的向量总数:", index.ntotal)
                print("Searching...")
                query_vec = model.encode([question], convert_to_numpy=True)
                print("finish query_vec")
                distances, indices = index.search(query_vec, 10)
                print("finish distance indices")
                # FAISS pads with -1 when the subgraph has fewer than 10 triples.
                # Drop those entries instead of wrapping around to the last triple.
                results = [allgraph_str_list[i] for i in indices[0] if i >= 0]
                for _ in range(hops):
                    results = expand_neighbors_via_faiss(results, allgraph_str_list, model, query_vec, triple_str2tuple, index, topk=10)
            except Exception as e:
                print(f"Error in FAISS index building or searching: {e}")
                continue
            count += len(results)
            data_num += 1
            print(f"Dataset {dataset_name}, hops {hops}, Now {idx}, Count: {count}, Data Num: {data_num}, average: {count/data_num}")
    else:
        with open("data/metaqa_kb.json", "r", encoding="utf-8") as f:
            now_graph = json.load(f)
        metaqa_allgraph_str_list = [f"({t[0]}, {t[1]}, {t[2]})" for t in now_graph]
        allgraph_triple_list = [(t[0], t[1], t[2]) for t in now_graph]
        triple_str2tuple = dict(zip(metaqa_allgraph_str_list, allgraph_triple_list))

        # The MetaQA KB is shared by every question, so encode and index it once.
        print("Encoding graph texts...")
        metaqa_embeddings = model.encode(metaqa_allgraph_str_list, convert_to_numpy=True, show_progress_bar=True)
        print("Building metaqa FAISS index...")
        metaqa_dimension = metaqa_embeddings.shape[1]
        metaqa_index = faiss.IndexFlatL2(metaqa_dimension)
        metaqa_index.add(metaqa_embeddings)  # add the triple vectors to the index

        dataset_path = f"data/{dataset_name}/qa_test_hf_format.json"
        with open(dataset_path, "r", encoding="utf-8") as f:
            test_data = json.load(f)

        for idx, data_item in enumerate(test_data):
            if data_item['id'] not in true_id:
                continue
            question = data_item['question']
            try:
                print("Searching...")
                query_vec = model.encode([question], convert_to_numpy=True)
                distances, indices = metaqa_index.search(query_vec, 10)
                # Skip FAISS's -1 padding (only possible if the KB has < 10 triples).
                results = [metaqa_allgraph_str_list[i] for i in indices[0] if i >= 0]
                for _ in range(hops):
                    results = expand_neighbors_via_faiss(results, metaqa_allgraph_str_list, model, query_vec, triple_str2tuple, metaqa_index, topk=10)
            except Exception as e:
                print(f"Error in FAISS index building or searching: {e}")
                continue
            count += len(results)
            data_num += 1
            print(f"Dataset {dataset_name}, hops {hops}, Now {idx}, Count: {count}, Data Num: {data_num}, average: {count/data_num}")

    # Final summary: printed once, after all questions are processed.
    average = count / data_num if data_num else 0.0
    print(f"Dataset {dataset_name}, hops {hops}, Total Count: {count}, Data Num: {data_num}, average: {average}")

def test_tog(dataset_name):
    """Count the triples on the best-scoring one-hop ToG-style paths per question.

    Only questions with ``accuracy == 1.0`` in
    ``log/qa_format/tuned_ToG_{dataset_name}.jsonl`` are counted. For each
    question, every path starts at a topic entity (``q_entity``) and is
    extended by one outgoing triple, skipping tails already on the path. Each
    candidate triple is scored with BERTScore F1 against the question, and
    the 10 highest-scoring extended paths are kept. The triples on those paths
    are added to a running count, which is printed per question.

    Data format: ``data_item['id']`` is a string, ``data_item['question']`` a
    string, ``data_item['q_entity']`` and ``data_item['a_entity']`` lists of
    strings, and ``data_item['graph']`` a list of ``(subject, relation, object)``.

    Args:
        dataset_name: A name containing ``"metaqa"`` (questions from
            ``data/{dataset_name}/qa_test_hf_format.json``, shared KB from
            ``data/metaqa_kb.json``), or a RoG dataset suffix such as ``"cwq"``
            or ``"webqsp"`` (per-question ``graph``).
    """
    logging.set_verbosity_error()

    def build_entity_index(graph_triples):
        # Outgoing adjacency: each triple is filed under its head entity only.
        # Tail entities get an (initially empty) entry as well.
        entity_to_triples = {}
        for h, r, t in graph_triples:
            entity_to_triples.setdefault(h, []).append((h, r, t))
            entity_to_triples.setdefault(t, [])
        return entity_to_triples

    # Load the questions (and, for MetaQA, the shared KB).
    is_metaqa = "metaqa" in dataset_name
    if is_metaqa:
        data_path = f"data/{dataset_name}/qa_test_hf_format.json"
        with open(data_path, 'r') as f:
            test_data = json.load(f)
        with open("data/metaqa_kb.json", 'r') as f:
            metaqa_kb = json.load(f)
        # The KB is identical for every question: index it once, not per item.
        shared_index = build_entity_index(metaqa_kb)
    else:
        print("Loading dataset...")
        dataset = load_dataset(f"rmanluo/RoG-{dataset_name}")
        test_data = dataset['test']
        shared_index = None

    with open(f"log/qa_format/tuned_ToG_{dataset_name}.jsonl", "r") as f:
        data = [json.loads(line) for line in f.readlines()]
    true_id = {item["id"] for item in data if item["accuracy"] == 1.0}

    count = 0
    data_num = 0
    for idx, data_item in enumerate(test_data):
        if data_item['id'] not in true_id:
            continue
        question = data_item["question"]
        topic_entity_list = data_item["q_entity"]
        if shared_index is not None:
            entity_to_triples = shared_index
        else:
            entity_to_triples = build_entity_index(data_item["graph"])

        # path (tuple of alternating entity/relation) -> cumulative score
        path_to_score = {(ent,): 0 for ent in topic_entity_list}

        # Collect every one-hop extension of the current paths.
        all_pairs = []
        all_keys = []
        for ent_list, now_score in path_to_score.items():
            pre_ent = ent_list[-1]  # last entity on the path
            # .get: a topic entity may not occur in the graph at all.
            for h, r, t in entity_to_triples.get(pre_ent, ()):
                if t in ent_list:
                    continue
                all_pairs.append((ent_list, r, t, now_score))
                all_keys.append(f"{h} {r} {t}.")

        # Score all candidate triples in one BERTScore call. Skip the call when
        # there is nothing to score.
        new_path_to_score = {}
        if all_keys:
            _, _, F1 = bert_score([question] * len(all_keys), all_keys, lang="en", verbose=False)
            for (ent_list, r, t, now_score), f1 in zip(all_pairs, F1):
                new_path_to_score[ent_list + (r, t)] = now_score + f1.item()
        print("******************* Finish update path")

        # Keep only the 10 best-scoring *extended* paths.
        sorted_paths = sorted(new_path_to_score.items(), key=lambda x: x[1], reverse=True)
        print(sorted_paths)
        path_to_score = dict(sorted_paths[:10])
        print(path_to_score)

        # A path (e0, r1, e1, r2, e2, ...) holds one triple per (e_i, r_{i+1}, e_{i+1}) window.
        for ent_list in path_to_score:
            triples = [
                f"({ent_list[i]}, {ent_list[i + 1]}, {ent_list[i + 2]})"
                for i in range(0, len(ent_list) - 2, 2)
            ]
            count += len(triples)
        data_num += 1
        print(f"ToG, Dataset {dataset_name}, Now {idx}, Count: {count}, Data Num: {data_num}, average: {count/data_num}")


  
                
                
                
            
if __name__ == "__main__":
    # Alternative experiments. To use one, restore the argparse block and
    # uncomment the desired call:
    # parser = argparse.ArgumentParser(description="count retrieve number experiment")
    # parser.add_argument('--dataset', type=str, required=True, help='Dataset to use')
    # parser.add_argument('--hops', type=str, required=True, help='hops')
    # args = parser.parse_args()
    # dataset_name = args.dataset
    # hops = int(args.hops)
    # test_rag(dataset_name, hops)
    # test_tog(dataset_name)

    # Current run: total and average number of triples in a RoG test split.
    dataset_name = "cwq"
    print("Loading dataset...")
    dataset = load_dataset(f"rmanluo/RoG-{dataset_name}")
    test_data = dataset['test']
    num = len(test_data)
    count = sum(len(item['graph']) for item in test_data)
    # The label follows the dataset actually loaded (it used to say "webqsp"
    # while loading cwq).
    average = count / num if num else 0.0
    print(f"Total triples in {dataset_name} test set: {count}, average: {average}")




   
    
