import os
from typing import List
import json

from src.hipporag.HippoRAG import HippoRAG
from src.hipporag.utils.misc_utils import string_to_bool
from src.hipporag.utils.config_utils import BaseConfig

import argparse

from src.hipporag.TAG import TAG
import logging, re, json
from typing import Dict, Any
import time

# Process-wide runtime settings applied at import time.
# os.environ["LOG_LEVEL"] = "DEBUG"
os.environ.update(
    CUDA_DEVICE_ORDER="PCI_BUS_ID",
    TOKENIZERS_PARALLELISM="false",
)

import logging

# Only let errors through from the HTTP / OpenAI client loggers.
logging.getLogger("httpx").setLevel(logging.ERROR)
logging.getLogger("openai").setLevel(logging.ERROR)

def get_gold_docs(samples: List, dataset_name: str = None) -> List:
    """Collect the gold (supporting) documents for every sample.

    Three sample layouts are recognised, checked in this order:

    * ``supporting_facts`` + ``context`` (hotpotqa / 2wikimultihopqa style):
      a context entry ``[title, sentences]`` is gold when its title appears
      as the first element of some supporting fact. Sentences are joined
      with ``''`` when ``dataset_name`` starts with ``'hotpotqa'`` and with
      ``' '`` otherwise.
    * ``contexts``: entries with a truthy ``is_supporting`` are gold; the
      document text is ``title + '\\n' + text``.
    * ``paragraphs``: every entry is gold unless it carries
      ``is_supporting`` set to exactly ``False``; the body is taken from
      ``text`` if present, else from ``paragraph_text``.

    Args:
        samples: Iterable of sample dicts in one of the layouts above.
        dataset_name: Dataset identifier; only used to pick the sentence
            separator for the ``supporting_facts`` layout. May be ``None``,
            in which case the non-hotpotqa separator (``' '``) is used.

    Returns:
        A list with one entry per sample; each entry is a list of unique
        gold document strings in first-seen order.

    Raises:
        AssertionError: If a sample matches none of the three layouts.
    """
    # Guard against the default ``None`` before calling ``startswith``.
    is_hotpotqa = bool(dataset_name) and dataset_name.startswith('hotpotqa')
    sentence_sep = '' if is_hotpotqa else ' '

    gold_docs = []
    for sample in samples:
        if 'supporting_facts' in sample:  # hotpotqa, 2wikimultihopqa
            gold_titles = {fact[0] for fact in sample['supporting_facts']}
            gold_doc = [
                entry[0] + '\n' + sentence_sep.join(entry[1])
                for entry in sample['context']
                if entry[0] in gold_titles
            ]
        elif 'contexts' in sample:
            gold_doc = [
                entry['title'] + '\n' + entry['text']
                for entry in sample['contexts']
                if entry['is_supporting']
            ]
        else:
            assert 'paragraphs' in sample, "`paragraphs` should be in sample, or consider the setting not to evaluate retrieval"
            gold_doc = [
                entry['title'] + '\n' + (entry['text'] if 'text' in entry else entry['paragraph_text'])
                for entry in sample['paragraphs']
                # Only an explicit ``False`` excludes a paragraph.
                if not ('is_supporting' in entry and entry['is_supporting'] is False)
            ]

        # De-duplicate while keeping a deterministic (first-seen) order.
        gold_docs.append(list(dict.fromkeys(gold_doc)))
    return gold_docs


def get_gold_answers(samples):
    """Collect the set of acceptable gold answers for every sample.

    The answer source is chosen per sample, in this order of precedence:
    ``answer``, ``gold_ans``, ``reference``, then the ``obj`` layout, which
    combines ``obj``, ``possible_answers``, ``o_wiki_title`` and
    ``o_aliases``. In the ``obj`` layout a field holding a list/tuple/set is
    flattened into individual answers; a scalar field (e.g. a string) is
    used as a single answer. If ``answer_aliases`` is present its items are
    added to the result.

    Args:
        samples: Iterable of sample dicts.

    Returns:
        A list with one ``set`` of answers per sample.

    Raises:
        AssertionError: If no answer field is found, or the answer is
            neither a string nor a list.
        KeyError: If a sample uses the ``obj`` layout but lacks one of its
            four fields.
    """
    gold_answers = []
    for sample in samples:
        gold_ans = None

        if 'answer' in sample or 'gold_ans' in sample:
            gold_ans = sample['answer'] if 'answer' in sample else sample['gold_ans']
        elif 'reference' in sample:
            gold_ans = sample['reference']
        elif 'obj' in sample:
            candidates = []
            for key in ('obj', 'possible_answers', 'o_wiki_title', 'o_aliases'):
                value = sample[key]
                # Containers are unhashable (or would be added as a single
                # element), so spread their items instead of nesting them.
                if isinstance(value, (list, tuple, set)):
                    candidates.extend(value)
                else:
                    candidates.append(value)
            gold_ans = list(set(candidates))
        assert gold_ans is not None

        if isinstance(gold_ans, str):
            gold_ans = [gold_ans]
        assert isinstance(gold_ans, list)
        gold_ans = set(gold_ans)
        if 'answer_aliases' in sample:
            gold_ans.update(sample['answer_aliases'])

        gold_answers.append(gold_ans)

    return gold_answers

# lifang535 add
# ---------------------------------------------------------------------------
# Azure OpenAI configuration.
#
# SECURITY: API keys were previously hard-coded in this file (both as live
# values and inside commented-out alternatives). Any key that was ever
# committed should be treated as leaked and rotated. Credentials are now read
# from the environment and are never stored in source.
#
# Endpoints that appeared in earlier (commented-out) configurations:
#   https://pcg-sweden-central.openai.azure.com/
#   https://gpt-nzq-sweden-central.openai.azure.com/
#   https://gpt-nzq-aus-east.openai.azure.com/
#
# Example shell setup before running this script:
#   export OPENAI_API_KEY=<your Azure OpenAI key>
#   export AZURE_EMBEDDING_API_KEY=<key for the embedding resource>
#   cd /root/autodl-tmp/.autodl/hipporag_echo_2
#   conda activate hipporag
# NOTE: CUDA_VISIBLE_DEVICES, HF_ENDPOINT and HF_HOME are overwritten below,
# so exporting them in the shell has no effect on this script.
# ---------------------------------------------------------------------------
import os

# LLM endpoint: the East US API is the one noted as working.
# The key must be supplied via the OPENAI_API_KEY environment variable; it is
# left empty here if the variable is unset.
API_KEY = os.environ.get("OPENAI_API_KEY", "")
API_URL = "https://gpt-nzq-east-us.openai.azure.com/"

# Embedding endpoint: the Sweden Central embedding deployment is the one noted
# as working.
# NOTE(review): EMBEDDING_API_KEY is not referenced anywhere else in this
# file; AZURE_EMBEDDING_API_KEY is the variable name introduced here for it.
EMBEDDING_API_URL = "https://gpt-nzq-sweden-central.openai.azure.com/"
EMBEDDING_API_KEY = os.environ.get("AZURE_EMBEDDING_API_KEY", "")


API_VERSION = "2024-02-15-preview"
# MODEL = "gpt-4o"
MODEL = "gpt-4o-mini"
EMBEDDING_MODEL = "text-embedding-ada-002"

# Machine-specific runtime environment (GPU selection and Hugging Face cache
# locations). These unconditionally override whatever the shell exported.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
os.environ["HF_HOME"] = "/data2-HDD-SATA-20T/nzq/huggingface_model"
os.environ["TRANSFORMERS_CACHE"] = "/data2-HDD-SATA-20T/nzq/huggingface_model/transformers"
os.environ["HF_DATASETS_CACHE"] = "/data2-HDD-SATA-20T/nzq/huggingface_model/datasets"
os.environ["HF_METRICS_CACHE"] = "/data2-HDD-SATA-20T/nzq/huggingface_model/metrics"
# python main_azure.py --dataset 2wikimultihopqa --llm_name gpt-4o-mini --embedding_name text-embedding-ada-002

# python main_azure.py --dataset hotpotqa --llm_base_url https://gpt-nzq-east-us.openai.azure.com/ --llm_name gpt-4o-mini --embedding_name text-embedding-ada-002 --azure_embedding_endpoint https://gpt-nzq-sweden-central.openai.azure.com/
# python main_azure.py --dataset musique --llm_base_url https://gpt-nzq-east-us.openai.azure.com/ --llm_name gpt-4o-mini --embedding_name text-embedding-ada-002 --azure_embedding_endpoint https://gpt-nzq-sweden-central.openai.azure.com/
# python main_azure.py --dataset 2wikimultihopqa --llm_base_url https://gpt-nzq-sweden-central.openai.azure.com/ --llm_name gpt-4o-mini --embedding_name text-embedding-ada-002 --azure_embedding_endpoint https://gpt-nzq-sweden-central.openai.azure.com/
# python main_azure.py --dataset popqa --llm_base_url https://gpt-nzq-east-us.openai.azure.com/ --llm_name gpt-4o-mini --embedding_name text-embedding-ada-002 --azure_embedding_endpoint https://gpt-nzq-sweden-central.openai.azure.com/
# python main_azure.py --dataset nq_rear --llm_base_url https://gpt-nzq-east-us.openai.azure.com/ --llm_name gpt-4o-mini --embedding_name text-embedding-ada-002 --azure_embedding_endpoint https://gpt-nzq-sweden-central.openai.azure.com/


# python main_azure.py --dataset sample --llm_base_url https://api.openai.com/v1 --llm_name gpt-4o-mini --embedding_name nvidia/NV-Embed-v2

# python main_azure.py --dataset sample --llm_base_url https://gpt-nzq-east-us.openai.azure.com/ --llm_name gpt-4o-mini --embedding_name nvidia/NV-Embed-v2

# python main_azure.py --dataset hotpotqa --llm_name gpt-4o-mini --embedding_name nvidia/NV-Embed-v2
# python main_azure.py --dataset 2wikimultihopqa --llm_name gpt-4o-mini --embedding_name nvidia/NV-Embed-v2
# python main_azure.py --dataset musique --llm_name gpt-4o-mini --embedding_name nvidia/NV-Embed-v2
# python main_azure.py --dataset popqa --llm_name gpt-4o-mini --embedding_name nvidia/NV-Embed-v2
# python main_azure.py --dataset nq_rear --llm_name gpt-4o-mini --embedding_name nvidia/NV-Embed-v2


# python main_azure.py --dataset musique --llm_base_url https://gpt-nzq-east-us.openai.azure.com/ --llm_name gpt-4o-mini --embedding_name nvidia/NV-Embed-v2
# python main_azure.py --dataset 2wikimultihopqa --llm_base_url https://gpt-nzq-east-us.openai.azure.com/ --llm_name gpt-4o-mini --embedding_name nvidia/NV-Embed-v2

# python main_azure.py --dataset popqa --llm_base_url https://gpt-nzq-east-us.openai.azure.com/ --llm_name gpt-4o-mini --embedding_name nvidia/NV-Embed-v2
# python main_azure.py --dataset nq_rear --llm_base_url https://gpt-nzq-east-us.openai.azure.com/ --llm_name gpt-4o-mini --embedding_name nvidia/NV-Embed-v2

# python main_azure.py --dataset lveval --llm_base_url https://gpt-nzq-east-us.openai.azure.com/ --llm_name gpt-4o-mini --embedding_name nvidia/NV-Embed-v2
# python main_azure.py --dataset narrativeqa_dev_10_doc --llm_base_url https://gpt-nzq-east-us.openai.azure.com/ --llm_name gpt-4o-mini --embedding_name nvidia/NV-Embed-v2

# python main_azure.py --dataset hotpotqa --llm_base_url https://gpt-nzq-east-us.openai.azure.com/ --llm_name gpt-4o-mini

# 2wikimultihopqa.json

"""
[HippoRAG]

2wikimultihopqa | gpt-4o-mini | nvidia/NV-Embed-v2
    INFO:
2wikimultihopqa | gpt-4o-mini | text-embedding-ada-002


hotpotqa | gpt-4o-mini | nvidia/NV-Embed-v2
    INFO:src.hipporag.TAG:Evaluation results for QA: {'ExactMatch': 0.574, 'F1': 0.722}
hotpotqa | gpt-4o-mini | text-embedding-ada-002 # "Lithuanian Jews of Kaunas during the ..." would violate policy
    INFO:src.hipporag.TAG:Evaluation results for QA: {'ExactMatch': 0.491, 'F1': 0.6351}


musique | gpt-4o-mini | nvidia/NV-Embed-v2
    INFO:src.hipporag.HippoRAG:Evaluation results for QA: {'ExactMatch': 0.346, 'F1': 0.4903}
musique | gpt-4o-mini | text-embedding-ada-002
    INFO:src.hipporag.TAG:Evaluation results for QA: {'ExactMatch': 0.263, 'F1': 0.3903}


sample


[TAG] also has to use nvidia/NV-Embed-v2

hotpotqa | gpt-4o-mini | nvidia/NV-Embed-v2
    INFO:src.hipporag.TAG:Evaluation results for QA: {'ExactMatch': 0.24, 'F1': 0.3383} [select_number = 100]
hotpotqa | gpt-4o-mini | text-embedding-ada-002


"""

def main():
    """Index a dataset with TAG, run retrieval + QA, and print a summary.

    Steps, as performed below:

    1. Parse command-line options (dataset, LLM / embedding names and
       endpoints, re-indexing flags, OpenIE mode, save directory).
    2. Load ``reproduce/dataset/<dataset>_corpus.json`` and build one
       ``"title\\ntext"`` string per corpus entry.
    3. Load ``reproduce/dataset/<dataset>.json`` and derive queries, gold
       answers and (best effort) gold documents, keeping at most the first
       ``select_number`` (1000) of each.
    4. Build a ``BaseConfig``, create a ``TAG`` instance, call
       ``topic_index(docs)`` and then ``rag_qa(...)``.
    5. While ``rag_qa`` runs, scan messages emitted on the
       ``src.hipporag.TAG`` logger for timing lines and evaluation result
       dicts, and print them as a summary at the end.

    If the gold documents cannot be derived, ``gold_docs`` is set to ``None``
    and passed on as such.
    NOTE(review): this assumes ``TAG.rag_qa`` accepts ``gold_docs=None`` -
    confirm against its implementation.
    """
    print("1234")

    parser = argparse.ArgumentParser(description="HippoRAG retrieval and QA")
    parser.add_argument('--dataset', type=str, default='musique', help='Dataset name')
    # The defaults below come from the module-level Azure settings.
    parser.add_argument('--llm_base_url', type=str, default=API_URL, help='LLM base URL')
    parser.add_argument('--llm_name', type=str, default=MODEL, help='LLM name')
    parser.add_argument('--embedding_name', type=str, default=EMBEDDING_MODEL, help='embedding model name')
    parser.add_argument('--azure_endpoint', type=str, default=API_URL, help='Azure Endpoint URL')
    parser.add_argument('--azure_embedding_endpoint', type=str, default=EMBEDDING_API_URL, help='Azure Embedding Endpoint')
    parser.add_argument('--force_index_from_scratch', type=str, default='false',
                        help='If set to True, will ignore all existing storage files and graph data and will rebuild from scratch.')
    parser.add_argument('--force_openie_from_scratch', type=str, default='false', help='If set to False, will try to first reuse openie results for the corpus if they exist.')
    parser.add_argument('--openie_mode', choices=['online', 'offline'], default='online',
                        help="OpenIE mode, offline denotes using VLLM offline batch mode for indexing, while online denotes")
    parser.add_argument('--save_dir', type=str, default='outputs', help='Save directory')
    args = parser.parse_args()

    print("1234")
    save_dir = args.save_dir
    dataset_name = args.dataset
    # The default directory gets a per-dataset sub-directory; a custom one
    # gets the dataset name appended as a suffix.
    if save_dir == 'outputs':
        save_dir = save_dir + '/' + dataset_name
    else:
        save_dir = save_dir + '_' + dataset_name

    llm_base_url = args.llm_base_url
    llm_name = args.llm_name
    azure_endpoint = args.azure_endpoint
    azure_embedding_endpoint = args.azure_embedding_endpoint

    # TODO: check the dataset loading format
    # TODO: find out why the local embedding model runs out of GPU memory on some datasets
    corpus_path = f"reproduce/dataset/{dataset_name}_corpus.json"
    with open(corpus_path, "r") as f:
        corpus = json.load(f)

    docs = [f"{doc['title']}\n{doc['text']}" for doc in corpus]

    print(f"[lifang535] len(corpus): {len(corpus)}")
    print(f"[lifang535] len(docs) = {len(docs)}")

    # Upper bound on the number of queries that are evaluated.
    select_number = 1000

    force_index_from_scratch = string_to_bool(args.force_index_from_scratch)
    force_openie_from_scratch = string_to_bool(args.force_openie_from_scratch)

    # Prepare datasets and evaluation
    with open(f"reproduce/dataset/{dataset_name}.json", "r") as f:
        samples = json.load(f)
    all_queries = [s['question'] for s in samples]

    # Load the gold answers and gold documents.
    gold_answers = get_gold_answers(samples)
    try:
        gold_docs = get_gold_docs(samples, dataset_name)
        assert len(all_queries) == len(gold_docs) == len(gold_answers), "Length of queries, gold_docs, and gold_answers should be the same."
    except Exception as exc:
        # Best effort: some datasets carry no retrieval annotations. Report
        # the reason instead of hiding it, then continue without gold docs.
        print(f"[lifang535] gold_docs unavailable ({exc!r}); continuing with gold_docs=None")
        gold_docs = None

    print(f"[lifang535] len(all_queries): {len(all_queries)}")
    # gold_docs may be None here, so it must not be passed to len() blindly.
    gold_docs_count = len(gold_docs) if gold_docs is not None else None
    print(f"[lifang535] len(gold_docs) = {gold_docs_count}")

    all_queries = all_queries[:select_number]
    if gold_docs is not None:
        gold_docs = gold_docs[:select_number]
    gold_answers = gold_answers[:select_number]

    config = BaseConfig(
        save_dir=save_dir,
        llm_base_url=llm_base_url,
        llm_name=llm_name,
        azure_endpoint=azure_endpoint,
        azure_embedding_endpoint=azure_embedding_endpoint,
        dataset=dataset_name,
        embedding_model_name=args.embedding_name,
        force_index_from_scratch=force_index_from_scratch,  # ignore previously stored index, set it to False if you want to use the previously stored index and embeddings
        force_openie_from_scratch=force_openie_from_scratch,
        rerank_dspy_file_path="src/hipporag/prompts/dspy_prompts/filter_llama3.3-70B-Instruct.json",
        retrieval_top_k=200,
        linking_top_k=5,
        max_qa_steps=3,
        qa_top_k=5,
        graph_type="facts_and_sim_passage_node_unidirectional",
        embedding_batch_size=5,  # previously 10
        max_new_tokens=None,
        corpus_len=len(corpus),
        openie_mode=args.openie_mode,

        # === Additional experimental hyper-parameters ===
        ppr_topk=12,                # Top-K used inside run_ppr_new
        dense_rerank_topk=200,     # number of candidate passages
        dense_fuse_alpha=0.5       # 0 = PPR only, 1 = dense only, in between = fused
    )
    print("985========================================================")
    logging.basicConfig(level=logging.INFO)

    # TAG is used here in place of the plain HippoRAG class.
    hipporag = TAG(global_config=config)
    print("C9========================================================")

    hipporag.topic_index(docs)
    # Retrieval and QA
    print("211========================================================")
    # ===== Added: capture log output and print a results summary with the configuration =====

    summary: Dict[str, Any] = {
        "dataset": args.dataset,
        "llm": args.llm_name,
        "embedding": args.embedding_name,
        "total_retrieval_time": None,
        "total_recognition_time": None,
        "total_ppr_time": None,
        "total_misc_time": None,
        "retrieval": {},
        "qa": {},
    }

    # summary key -> pattern whose first group is the elapsed seconds.
    timing_patterns = {
        "total_retrieval_time": r"Total Retrieval Time\s+([\d.]+)s",
        "total_recognition_time": r"Total Recognition Memory Time\s+([\d.]+)s",
        "total_ppr_time": r"Total PPR Time\s+([\d.]+)s",
        "total_misc_time": r"Total Misc Time\s+([\d.]+)s",
    }
    # summary key -> marker text announcing an evaluation result dict.
    result_markers = {
        "retrieval": "Evaluation results for retrieval",
        "qa": "Evaluation results for QA",
    }

    class CatchAllInfo(logging.Filter):
        """Logging filter that records metrics into ``summary``.

        It never suppresses a record: ``filter`` always returns True.
        """

        def filter(self, record: logging.LogRecord) -> bool:
            msg = record.getMessage()
            for key, pattern in timing_patterns.items():
                if m := re.search(pattern, msg):
                    summary[key] = float(m.group(1))
            for key, marker in result_markers.items():
                if marker in msg and (m := re.search(r"\{.*\}", msg)):
                    try:
                        # The dict is logged with single quotes (Python
                        # repr); swap them so it parses as JSON.
                        summary[key] = json.loads(m.group(0).replace("'", '"'))
                    except ValueError:
                        # Unparseable payload: keep the previous value
                        # rather than break logging.
                        pass
            return True

    logging.getLogger("src.hipporag.TAG").addFilter(CatchAllInfo())

    # ---------- Actual run ----------
    hipporag.rag_qa(queries=all_queries, gold_docs=gold_docs, gold_answers=gold_answers)

    def _or_na(value):
        # Only a missing measurement is 'N/A'; a genuine 0.0 is printed.
        return 'N/A' if value is None else value

    def _print_metrics(metrics):
        for k, v in metrics.items():
            if isinstance(v, (int, float)):
                print(f"  {k:<12}: {v:.4f}")
            else:
                # Non-numeric values cannot take the float format spec.
                print(f"  {k:<12}: {v}")

    # ---------- Print the results ----------
    print("\n" + "="*70)
    print(f"实验配置 -> 数据集: {summary['dataset']} | LLM: {summary['llm']} | Embedding: {summary['embedding']}")
    print("-"*70)
    print("各阶段耗时:")
    print(f"  Total Retrieval Time : {_or_na(summary['total_retrieval_time'])}s")
    print(f"  Recognition Memory   : {_or_na(summary['total_recognition_time'])}s")
    print(f"  PPR Time             : {_or_na(summary['total_ppr_time'])}s")
    print(f"  Misc Time            : {_or_na(summary['total_misc_time'])}s")
    print("\n检索评价:")
    _print_metrics(summary["retrieval"])
    print("\nQA 评价:")
    _print_metrics(summary["qa"])
    print("="*70)
# ===== End: capture log output and print a results summary with the configuration =====

# Run the pipeline only when this file is executed as a script; importing the
# module does not call main().
if __name__ == "__main__":
    main()


