import os
from typing import List
import json

from src.hipporag.HippoRAG import HippoRAG
from src.hipporag.utils.misc_utils import string_to_bool
from src.hipporag.utils.config_utils import BaseConfig

import argparse

from src.hipporag.TAG import TAG

# Optional: uncomment to turn on verbose logging.
# os.environ["LOG_LEVEL"] = "DEBUG"

# Order CUDA devices by PCI bus ID and disable HuggingFace tokenizers parallelism.
os.environ.update(
    CUDA_DEVICE_ORDER="PCI_BUS_ID",
    TOKENIZERS_PARALLELISM="false",
)

import logging

# Only show errors from the OpenAI client and its HTTP layer.
for _noisy_logger in ("openai", "httpx"):
    logging.getLogger(_noisy_logger).setLevel(logging.ERROR)
del _noisy_logger

def get_gold_docs(samples: List, dataset_name: str = None) -> List:
    """Extract the gold (supporting) passages of every QA sample.

    Three sample layouts are recognised:

    * ``supporting_facts`` + ``context`` (hotpotqa, 2wikimultihopqa): ``context``
      holds ``[title, sentences]`` pairs, and every pair whose title occurs in
      ``supporting_facts`` is gold. When ``dataset_name`` starts with
      ``'hotpotqa'`` the sentences are concatenated as-is; otherwise they are
      joined with a single space.
    * ``contexts``: dicts with ``title``, ``text`` and ``is_supporting``; only
      entries with a truthy ``is_supporting`` are kept.
    * ``paragraphs``: dicts with ``title`` and ``text`` (or ``paragraph_text``);
      every paragraph is kept unless it has ``is_supporting`` set to ``False``.

    Each passage is rendered as its title, a newline, and its text.

    Args:
        samples: Loaded dataset samples.
        dataset_name: Dataset name, only used to pick the sentence separator for
            the ``supporting_facts`` layout. ``None`` is treated like any
            non-hotpotqa dataset instead of raising ``AttributeError``.

    Returns:
        One list per sample of de-duplicated gold passages, in first-seen order.

    Raises:
        AssertionError: If a sample matches none of the layouts above.
    """
    is_hotpotqa = dataset_name is not None and dataset_name.startswith('hotpotqa')
    sentence_sep = '' if is_hotpotqa else ' '

    gold_docs = []
    for sample in samples:
        if 'supporting_facts' in sample:  # hotpotqa, 2wikimultihopqa
            gold_titles = {fact[0] for fact in sample['supporting_facts']}
            gold_doc = [
                item[0] + '\n' + sentence_sep.join(item[1])
                for item in sample['context']
                if item[0] in gold_titles
            ]
        elif 'contexts' in sample:
            gold_doc = [item['title'] + '\n' + item['text'] for item in sample['contexts'] if item['is_supporting']]
        else:
            assert 'paragraphs' in sample, "`paragraphs` should be in sample, or consider the setting not to evaluate retrieval"
            gold_doc = [
                item['title'] + '\n' + (item['text'] if 'text' in item else item['paragraph_text'])
                for item in sample['paragraphs']
                # Only an explicit ``is_supporting: False`` excludes a paragraph.
                if item.get('is_supporting') is not False
            ]

        # dict.fromkeys de-duplicates while keeping a deterministic order
        # (set() iteration order of strings varies between interpreter runs).
        gold_docs.append(list(dict.fromkeys(gold_doc)))
    return gold_docs


def get_gold_answers(samples):
    """Collect the set of acceptable gold answers for every QA sample.

    The answer is taken from the first matching field group:
    ``answer`` (or ``gold_ans``), then ``reference``, then the union of
    ``obj``, ``possible_answers``, ``o_wiki_title`` and ``o_aliases``.
    A string answer is treated as a single answer. Any ``answer_aliases``
    are added on top.

    Args:
        samples: Iterable of sample dicts.

    Returns:
        A list with one ``set`` of gold answers per sample.

    Raises:
        AssertionError: If a sample has none of the answer fields, or its
            answer is neither a string nor a list.
    """
    gold_answers = []
    for sample in samples:
        gold_ans = None

        if 'answer' in sample or 'gold_ans' in sample:
            gold_ans = sample['answer'] if 'answer' in sample else sample['gold_ans']
        elif 'reference' in sample:
            gold_ans = sample['reference']
        elif 'obj' in sample:
            gold_ans = []
            for key in ('obj', 'possible_answers', 'o_wiki_title', 'o_aliases'):
                value = sample[key]
                # Flatten list-valued fields: nesting a list inside the list
                # would make set() below fail with "unhashable type: 'list'".
                if isinstance(value, list):
                    gold_ans.extend(value)
                else:
                    gold_ans.append(value)

        assert gold_ans is not None, "sample has no recognised gold answer field"
        if isinstance(gold_ans, str):
            gold_ans = [gold_ans]
        assert isinstance(gold_ans, list), f"unsupported gold answer type: {type(gold_ans).__name__}"
        gold_ans = set(gold_ans)
        if 'answer_aliases' in sample:
            gold_ans.update(sample['answer_aliases'])

        gold_answers.append(gold_ans)

    return gold_answers

# Azure OpenAI / Hugging Face configuration used as the CLI defaults in main().
#
# Credentials are read from the environment instead of being stored in source.
# Export them in the shell before running, e.g.:
#   export OPENAI_API_KEY=<chat deployment key>
#   export EMBEDDING_API_KEY=<embedding deployment key>
#   export CUDA_VISIBLE_DEVICES=1,2,3
#   export HF_ENDPOINT=https://hf-mirror.com
#   export HF_HOME=/data2-HDD-SATA-20T/nzq/huggingface_cache
# Any key that was previously committed to this file must be treated as leaked
# and rotated.
API_KEY = os.environ.get("OPENAI_API_KEY", "")
API_URL = "https://gpt-nzq-east-us.openai.azure.com/"
EMBEDDING_API_URL = "https://gpt-nzq-sweden-central.openai.azure.com/"
EMBEDDING_API_KEY = os.environ.get("EMBEDDING_API_KEY", "")
API_VERSION = "2024-02-15-preview"
MODEL = "gpt-4o-mini"  # alternative: "gpt-4o"
EMBEDDING_MODEL = "text-embedding-ada-002"

# setdefault: values already exported in the shell take precedence over these
# machine-specific defaults instead of being silently overwritten.
# NOTE(review): Hugging Face libraries typically read the HF_* cache variables
# when they are first imported; if the `src.hipporag` imports at the top of this
# file already import them, these defaults may not take effect — consider
# setting them before those imports (or exporting them in the shell).
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "2,3")
os.environ.setdefault("HF_ENDPOINT", "https://hf-mirror.com")
os.environ.setdefault("HF_HOME", "/data2-HDD-SATA-20T/nzq/huggingface_cache")
os.environ.setdefault("TRANSFORMERS_CACHE", "/data2-HDD-SATA-20T/nzq/huggingface_cache/transformers")
os.environ.setdefault("HF_DATASETS_CACHE", "/data2-HDD-SATA-20T/nzq/huggingface_cache/datasets")
os.environ.setdefault("HF_METRICS_CACHE", "/data2-HDD-SATA-20T/nzq/huggingface_cache/metrics")
# python main_azure_TAG.py --dataset sample --llm_base_url https://api.openai.com/v1 --llm_name gpt-4o-mini --embedding_name nvidia/NV-Embed-v2

# python main_azure_TAG.py --dataset sample --llm_base_url https://gpt-nzq-east-us.openai.azure.com/ --llm_name gpt-4o-mini --embedding_name nvidia/NV-Embed-v2

# python main_azure_TAG.py --dataset hotpotqa --llm_base_url https://gpt-nzq-east-us.openai.azure.com/ --llm_name gpt-4o-mini --embedding_name nvidia/NV-Embed-v2
# python main_azure_TAG.py --dataset hotpotqa --llm_base_url https://gpt-nzq-east-us.openai.azure.com/ --llm_name gpt-4o-mini

# 2wikimultihopqa.json

"""
[HippoRAG]

2wikimultihopqa | gpt-4o-mini | nvidia/NV-Embed-v2
    INFO:
2wikimultihopqa | gpt-4o-mini | text-embedding-ada-002


hotpotqa | gpt-4o-mini | nvidia/NV-Embed-v2
    INFO:src.hipporag.TAG:Evaluation results for QA: {'ExactMatch': 0.574, 'F1': 0.722}
hotpotqa | gpt-4o-mini | text-embedding-ada-002  # the query 'Lithuanian Jews of Kaunas during the ...' triggers a content-policy violation
    INFO:src.hipporag.TAG:Evaluation results for QA: {'ExactMatch': 0.491, 'F1': 0.6351}


musique | gpt-4o-mini | nvidia/NV-Embed-v2
    INFO:src.hipporag.HippoRAG:Evaluation results for QA: {'ExactMatch': 0.346, 'F1': 0.4903}
musique | gpt-4o-mini | text-embedding-ada-002
    INFO:src.hipporag.TAG:Evaluation results for QA: {'ExactMatch': 0.263, 'F1': 0.3903}


sample


[TAG] (also needs nvidia/NV-Embed-v2)

hotpotqa | gpt-4o-mini | nvidia/NV-Embed-v2
    INFO:src.hipporag.TAG:Evaluation results for QA: {'ExactMatch': 0.24, 'F1': 0.3383} [select_number = 100]
hotpotqa | gpt-4o-mini | text-embedding-ada-002
    

"""

def main():
    """Index a dataset with TAG and evaluate retrieval / QA on it.

    Command-line options select the dataset, LLM / embedding models and Azure
    endpoints; their defaults come from the module-level configuration. The
    script reads ``reproduce/dataset/<dataset>_corpus.json`` (a list of objects
    with ``title`` and ``text``) and ``reproduce/dataset/<dataset>.json`` (QA
    samples with a ``question`` field). It keeps only the first
    ``--select_number`` passages and queries, builds the TAG index and runs
    ``TAG_rag_qa``.

    Gold documents are optional. If they cannot be extracted from the samples,
    or their count does not match the queries, a warning is logged and QA runs
    with ``gold_docs=None``.
    """
    parser = argparse.ArgumentParser(description="HippoRAG retrieval and QA")
    parser.add_argument('--dataset', type=str, default='musique', help='Dataset name')
    # LLM / embedding defaults come from the Azure configuration defined at module level.
    parser.add_argument('--llm_base_url', type=str, default=API_URL, help='LLM base URL')
    parser.add_argument('--llm_name', type=str, default=MODEL, help='LLM name')
    parser.add_argument('--embedding_name', type=str, default=EMBEDDING_MODEL, help='embedding model name')
    parser.add_argument('--azure_endpoint', type=str, default=API_URL, help='Azure Endpoint URL')
    parser.add_argument('--azure_embedding_endpoint', type=str, default=EMBEDDING_API_URL,
                        help='Azure Embedding Endpoint')
    parser.add_argument('--force_index_from_scratch', type=str, default='false',
                        help='If set to True, will ignore all existing storage files and graph data and will rebuild from scratch.')
    parser.add_argument('--force_openie_from_scratch', type=str, default='false',
                        help='If set to False, will try to first reuse openie results for the corpus if they exist.')
    parser.add_argument('--openie_mode', choices=['online', 'offline'], default='online',
                        help="OpenIE mode, offline denotes using VLLM offline batch mode for indexing, while online denotes calling the configured LLM endpoint.")
    parser.add_argument('--save_dir', type=str, default='outputs', help='Save directory')
    parser.add_argument('--select_number', type=int, default=100,
                        help='Only use the first N corpus passages and the first N queries (<= 0 uses all).')
    args = parser.parse_args()

    # Named logger: calling the module-level logging.warning() before
    # basicConfig() would implicitly configure the root logger at WARNING level
    # and turn the later basicConfig(level=logging.INFO) into a no-op.
    logger = logging.getLogger(__name__)

    dataset_name = args.dataset
    # The default output root gets one sub-directory per dataset; a custom root
    # gets a "_<dataset>" suffix instead.
    if args.save_dir == 'outputs':
        save_dir = f"{args.save_dir}/{dataset_name}"
    else:
        save_dir = f"{args.save_dir}_{dataset_name}"

    # Slicing with None keeps every element.
    limit = args.select_number if args.select_number > 0 else None

    # TODO: double-check the dataset file format.
    # TODO: find out why a local embedding model runs out of GPU memory on some datasets.
    corpus_path = f"reproduce/dataset/{dataset_name}_corpus.json"
    with open(corpus_path, "r", encoding="utf-8") as f:
        corpus = json.load(f)

    docs = [f"{doc['title']}\n{doc['text']}" for doc in corpus]
    # NOTE(review): passages and queries are truncated independently, so gold
    # passages of the selected queries may not be among the indexed passages —
    # confirm this is intended for the experiment.
    docs = docs[:limit]

    force_index_from_scratch = string_to_bool(args.force_index_from_scratch)
    force_openie_from_scratch = string_to_bool(args.force_openie_from_scratch)

    # Prepare datasets and evaluation
    with open(f"reproduce/dataset/{dataset_name}.json", "r", encoding="utf-8") as f:
        samples = json.load(f)
    all_queries = [s['question'] for s in samples]

    gold_answers = get_gold_answers(samples)
    try:
        gold_docs = get_gold_docs(samples, dataset_name)
        if not len(all_queries) == len(gold_docs) == len(gold_answers):
            raise ValueError("Length of queries, gold_docs, and gold_answers should be the same.")
    except Exception:
        # Retrieval evaluation is optional; continue with QA-only evaluation.
        logger.warning("Gold documents unavailable for %s; skipping retrieval evaluation.",
                       dataset_name, exc_info=True)
        gold_docs = None

    all_queries = all_queries[:limit]
    gold_answers = gold_answers[:limit]
    if gold_docs is not None:
        gold_docs = gold_docs[:limit]

    config = BaseConfig(
        save_dir=save_dir,
        llm_base_url=args.llm_base_url,
        llm_name=args.llm_name,
        azure_endpoint=args.azure_endpoint,
        azure_embedding_endpoint=args.azure_embedding_endpoint,
        dataset=dataset_name,
        embedding_model_name=args.embedding_name,
        force_index_from_scratch=force_index_from_scratch,  # ignore previously stored index, set it to False if you want to use the previously stored index and embeddings
        force_openie_from_scratch=force_openie_from_scratch,
        rerank_dspy_file_path="src/hipporag/prompts/dspy_prompts/filter_llama3.3-70B-Instruct.json",
        retrieval_top_k=200,
        linking_top_k=5,
        max_qa_steps=3,
        qa_top_k=5,
        graph_type="facts_and_sim_passage_node_unidirectional",
        embedding_batch_size=5,  # reduced from 10
        max_new_tokens=None,
        # NOTE(review): this is the full corpus length even though only
        # docs[:limit] are indexed — confirm which one BaseConfig expects.
        corpus_len=len(corpus),
        openie_mode=args.openie_mode
    )

    logging.basicConfig(level=logging.INFO)

    # The original HippoRAG pipeline can be run instead with
    # HippoRAG(global_config=config).index(docs) followed by .rag_qa(...).
    tag = TAG(global_config=config)

    # TAG_index presumably returns the directory holding this run's TAG data;
    # it is passed to TAG_rag_qa. To reuse an existing index, skip TAG_index and
    # set time_dir directly, e.g.
    # time_dir = "/data2-HDD-SATA-20T/nzq/jmf/new_rag_2/experiment/HippoRAG/TAG_data/2025-05-14 22:34:56"
    time_dir = tag.TAG_index(docs)
    tag.TAG_rag_qa(queries=all_queries, gold_docs=gold_docs, gold_answers=gold_answers, time_dir=time_dir)


if __name__ == "__main__":
    main()
