# -*- coding: utf-8 -*-
"""
main_azure.py — 批量消融实验入口
支持：
  - 一次建索引，随后多组参数复用索引跑检索与QA
  - 消融维度：dense_fuse_alpha 列表 × ppr_topk 列表
  - 记录耗时、检索Recall@k、EM/F1，并打印/保存为CSV（可选）
"""

import os
import re
import json
import time
import argparse
import logging
from typing import List, Dict, Any, Tuple, Optional

import numpy as np

# === 你项目内模块 ===
from src.hipporag.TAG import TAG
from src.hipporag.utils.config_utils import BaseConfig
from src.hipporag.utils.misc_utils import string_to_bool


import os

# SECURITY NOTE(review): the string literals used as fallbacks below are
# credentials committed to source control. They are kept only so that existing
# runs keep working unchanged; they should be rotated and removed, and the keys
# supplied through the environment variables read here.

# Chat/completion resource (per the original note, the East-US API is the one that works).
API_KEY = os.environ.get(
    "AZURE_OPENAI_API_KEY",
    "4zvobuaW6AGxlNzfLLtze9wdZDoko1Y7mY9JEkQoFLQB9tNeTe97JQQJ99BEACYeBjFXJ3w3AAABACOGTtI0",
)
API_URL = "https://gpt-nzq-east-us.openai.azure.com/"

# Embedding resource (per the original note, the Sweden-Central embedding deployment works).
EMBEDDING_API_URL = "https://gpt-nzq-sweden-central.openai.azure.com/"
EMBEDDING_API_KEY = os.environ.get(
    "AZURE_OPENAI_EMBEDDING_API_KEY",
    "2WKBSMb1AE1bEOdmzlIC0N4SGbzqLAQPRe1hUH0cJGirKwtkl8FTJQQJ99BEACfhMk5XJ3w3AAABACOGcWL6",
)

API_VERSION = "2024-02-15-preview"
MODEL = "gpt-4o-mini"
EMBEDDING_MODEL = "text-embedding-ada-002"

# Process-wide environment consumed by the OpenAI client, CUDA and Hugging Face.
os.environ["OPENAI_API_KEY"] = API_KEY
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"

# All Hugging Face caches live under one root directory.
_HF_ROOT = "/data2-HDD-SATA-20T/nzq/huggingface_model"
os.environ["HF_HOME"] = _HF_ROOT
os.environ["TRANSFORMERS_CACHE"] = _HF_ROOT + "/transformers"
os.environ["HF_DATASETS_CACHE"] = _HF_ROOT + "/datasets"
os.environ["HF_METRICS_CACHE"] = _HF_ROOT + "/metrics"




# ----------------------------
# 实用函数：gold 文档与答案
# ----------------------------
def get_gold_docs(samples: List, dataset_name: Optional[str] = None) -> List:
    """Collect the gold (supporting) passages of every sample.

    Three sample layouts are recognised, checked in this order:

    * ``supporting_facts`` + ``context`` (noted in the original code as the
      hotpotqa / 2wikimultihopqa layout): a context entry is gold when its
      title (``item[0]``) appears in ``supporting_facts``. Its sentences
      (``item[1]``) are joined with ``''`` when ``dataset_name`` starts with
      ``'hotpotqa'`` and with ``' '`` otherwise.
    * ``contexts``: entries whose ``is_supporting`` value is truthy.
    * ``paragraphs``: every entry except those whose ``is_supporting`` is
      exactly ``False``; the body is ``text`` or, failing that,
      ``paragraph_text``.

    Each passage is rendered as ``"<title>\\n<text>"``.

    Args:
        samples: Dataset samples (mappings) in one of the layouts above.
        dataset_name: Optional dataset name; only used to pick the sentence
            separator for the ``supporting_facts`` layout.

    Returns:
        One list of de-duplicated passages per sample, in first-occurrence
        order.

    Raises:
        AssertionError: If a sample has none of the recognised keys.
    """
    sentence_sep = '' if (dataset_name and dataset_name.startswith('hotpotqa')) else ' '

    gold_docs = []
    for sample in samples:
        if 'supporting_facts' in sample:  # hotpotqa, 2wikimultihopqa
            gold_titles = {fact[0] for fact in sample['supporting_facts']}
            gold_doc = [
                item[0] + '\n' + sentence_sep.join(item[1])
                for item in sample['context']
                if item[0] in gold_titles
            ]
        elif 'contexts' in sample:
            gold_doc = [
                item['title'] + '\n' + item['text']
                for item in sample['contexts']
                if item.get('is_supporting')
            ]
        else:
            assert 'paragraphs' in sample, "`paragraphs` should be in sample, or set not to evaluate retrieval."
            gold_doc = [
                item['title'] + '\n' + (item.get('text') or item.get('paragraph_text', ""))
                for item in sample['paragraphs']
                # Only an explicit False excludes a paragraph; a missing flag keeps it.
                if item.get('is_supporting') is not False
            ]
        # dict.fromkeys de-duplicates while keeping first-occurrence order, so the
        # result no longer depends on the per-process string hash seed.
        gold_docs.append(list(dict.fromkeys(gold_doc)))
    return gold_docs

def get_gold_answers(samples: List) -> List[List[str]]:
    """Collect the acceptable gold answers of every sample.

    The primary answer is read from the first key present among ``answer``,
    ``gold_ans``, ``reference`` and ``obj``. For ``obj`` samples the optional
    ``possible_answers``, ``o_wiki_title`` and ``o_aliases`` fields are added
    as well: list/tuple/set values are flattened into individual answers and
    missing (``None``) fields are skipped. Entries of ``answer_aliases``, when
    present, are appended for any layout.

    Args:
        samples: Dataset samples (mappings).

    Returns:
        One list of de-duplicated answers per sample, in first-occurrence
        order.

    Raises:
        AssertionError: If a sample provides no answer, or the answer is
            neither a string nor a list.
    """
    gold_answers = []
    for sample in samples:
        gold_ans = None
        if 'answer' in sample or 'gold_ans' in sample:
            gold_ans = sample['answer'] if 'answer' in sample else sample['gold_ans']
        elif 'reference' in sample:
            gold_ans = sample['reference']
        elif 'obj' in sample:
            gold_ans = [sample['obj']]
            for key in ('possible_answers', 'o_wiki_title', 'o_aliases'):
                value = sample.get(key)
                if value is None:
                    # A missing field must not turn into a ``None`` gold answer.
                    continue
                if isinstance(value, (list, tuple, set)):
                    # Flatten: a nested list would be unhashable during de-duplication.
                    gold_ans.extend(value)
                else:
                    gold_ans.append(value)
        assert gold_ans is not None, "sample has no answer field"
        if isinstance(gold_ans, str):
            gold_ans = [gold_ans]
        assert isinstance(gold_ans, list), "gold answer must be a string or a list"

        merged = list(gold_ans)
        if 'answer_aliases' in sample:
            merged.extend(sample['answer_aliases'])
        # Order-preserving de-duplication keeps the output stable across runs.
        gold_answers.append(list(dict.fromkeys(merged)))
    return gold_answers

# -----------------------------------
# 日志过滤器：抓取 TAG 打印的指标
# -----------------------------------
class CatchAllInfo(logging.Filter):
    """Logging filter that scrapes timing and evaluation metrics from messages.

    Every record passing through the filter is inspected and matching values
    are written into ``summary_ref`` (the dict supplied by the caller, mutated
    in place). The filter never drops a record: ``filter`` always returns
    ``True``.

    Keys written:
        ``total_retrieval_time``, ``total_recognition_time``,
        ``total_ppr_time``, ``total_misc_time`` -- floats parsed from
        ``"Total ... Time <number>s"`` messages.
        ``retrieval``, ``qa`` -- dicts parsed from messages containing
        ``"Evaluation results for retrieval"`` / ``"Evaluation results for QA"``.
    """

    # (compiled pattern, summary key) for the "<label> <seconds>s" timing lines.
    _TIMING_FIELDS = (
        (re.compile(r"Total Retrieval Time\s+([\d.]+)s"), "total_retrieval_time"),
        (re.compile(r"Total Recognition Memory Time\s+([\d.]+)s"), "total_recognition_time"),
        (re.compile(r"Total PPR Time\s+([\d.]+)s"), "total_ppr_time"),
        (re.compile(r"Total Misc Time\s+([\d.]+)s"), "total_misc_time"),
    )
    # (message marker, summary key) for the lines that embed a metrics dict.
    _PAYLOAD_FIELDS = (
        ("Evaluation results for retrieval", "retrieval"),
        ("Evaluation results for QA", "qa"),
    )
    _DICT_PATTERN = re.compile(r"\{.*\}")

    def __init__(self, summary_ref: Dict[str, Any]):
        super().__init__()
        self.summary_ref = summary_ref

    @staticmethod
    def _parse_payload(text: str) -> Optional[Dict[str, Any]]:
        """Parse a dict embedded in a log message; return ``None`` on failure.

        The quote-swapped JSON parse is tried first. If it fails, the text is
        parsed as a Python literal, which also accepts ``None``/``True``/
        ``False`` and tuples that appear in a ``str(dict)`` rendering.
        """
        import ast  # local import: ``ast`` is not among this file's top-level imports

        try:
            parsed = json.loads(text.replace("'", '"'))
        except ValueError:  # json.JSONDecodeError is a ValueError
            try:
                parsed = ast.literal_eval(text)
            except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
                return None
        # A "{...}" Python literal may also be a set; only dicts are metrics.
        return parsed if isinstance(parsed, dict) else None

    def filter(self, record: logging.LogRecord) -> bool:
        msg = record.getMessage()

        for pattern, key in self._TIMING_FIELDS:
            if m := pattern.search(msg):
                try:
                    self.summary_ref[key] = float(m.group(1))
                except ValueError:
                    # "[\d.]+" also matches strings such as "..." or "1.2.3";
                    # a scraping filter must not raise into the logging call.
                    pass

        for marker, key in self._PAYLOAD_FIELDS:
            if marker in msg and (m := self._DICT_PATTERN.search(msg)):
                payload = self._parse_payload(m.group(0))
                if payload is not None:
                    self.summary_ref[key] = payload

        return True

# ----------------------------
# 消融实验：alpha × ppr_topk
# ----------------------------
def _parse_logged_metrics(text: str) -> Optional[Dict[str, Any]]:
    """Parse a metrics dict embedded in a log message; ``None`` on failure.

    The quote-swapped JSON parse is tried first; if it fails the text is
    parsed as a Python literal, which also accepts ``None``/``True``/``False``.
    """
    import ast  # local import: ``ast`` is not among this file's top-level imports

    try:
        parsed = json.loads(text.replace("'", '"'))
    except ValueError:  # json.JSONDecodeError is a ValueError
        try:
            parsed = ast.literal_eval(text)
        except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
            return None
    return parsed if isinstance(parsed, dict) else None


class _QACatcher(logging.Filter):
    """Remember the last QA metrics dict seen in a log message (never drops records)."""

    def __init__(self):
        super().__init__()
        self.payload = None  # dict once a QA result line has been parsed

    def filter(self, record: logging.LogRecord) -> bool:
        msg = record.getMessage()
        if "Evaluation results for QA" in msg:
            m = re.search(r"\{.*\}", msg)
            if m:
                parsed = _parse_logged_metrics(m.group(0))
                if parsed is not None:
                    self.payload = parsed
        return True


def _extract_em_f1(obj) -> Tuple[Optional[float], Optional[float]]:
    """Read (EM, F1) from a metrics dict; each is ``None`` when absent or non-numeric."""
    if not isinstance(obj, dict):
        return None, None
    em = obj.get("ExactMatch", obj.get("EM", obj.get("exact_match")))
    f1 = obj.get("F1", obj.get("f1"))
    try:
        em = float(em) if em is not None else None
    except (TypeError, ValueError):
        em = None
    try:
        f1 = float(f1) if f1 is not None else None
    except (TypeError, ValueError):
        f1 = None
    return em, f1


def _em_f1_from_qa_output(qa_out) -> Tuple[Optional[float], Optional[float]]:
    """Find (EM, F1) in a QA return value that is a dict or a tuple containing dicts.

    Tuple elements are inspected in order and the first dict yielding at least
    one of the two metrics wins, so a metrics dict is found wherever it sits.
    """
    if isinstance(qa_out, dict):
        candidates = (qa_out,)
    elif isinstance(qa_out, tuple):
        candidates = qa_out
    else:
        candidates = ()
    for candidate in candidates:
        em, f1 = _extract_em_f1(candidate)
        if em is not None or f1 is not None:
            return em, f1
    return None, None


def _format_row(row: Dict[str, Any], headers: List[str]) -> str:
    """Render one result row as a tab-separated line (floats with 4 decimals)."""
    cells = []
    for h in headers:
        v = row.get(h, "")
        cells.append(f"{v:.4f}" if isinstance(v, float) else str(v))
    return "\t".join(cells)


def run_ablation(
    hipporag: "TAG",
    queries: List[str],
    gold_docs: Optional[List[List[str]]] = None,
    gold_answers: Optional[List[List[str]]] = None,
    alphas: List[float] = (0.0, 0.5, 1.0),
    ppr_topk_list: List[int] = (5, 8, 12),
    dense_rerank_topk: int = 200,
    save_csv: Optional[str] = None,
) -> List[Dict[str, Any]]:
    """Run retrieval (and optionally QA) for every ``alpha`` x ``ppr_topk`` pair.

    For each grid point the function writes ``dense_fuse_alpha``, ``ppr_topk``
    and ``dense_rerank_topk`` onto ``hipporag.global_config``, zeroes the
    ``all_retrieval_time`` / ``rerank_time`` / ``ppr_time`` attributes of
    ``hipporag``, calls ``hipporag.retrieve`` and, when ``gold_answers`` is
    given, ``hipporag.rag_qa``. EM/F1 are taken from the ``rag_qa`` return
    value when it contains a metrics dict, otherwise from the
    ``"Evaluation results for QA"`` line logged on ``src.hipporag.TAG``.
    Rows are printed as they are produced and again as a summary table.

    The configuration values of the last grid point are left on
    ``hipporag.global_config`` when the function returns.

    NOTE(review): ``retrieve`` and ``rag_qa`` are both given the raw queries;
    if ``rag_qa`` retrieves internally, the accumulated timers cover both
    passes -- confirm against the TAG implementation.

    Args:
        hipporag: Indexed TAG instance to evaluate.
        queries: Query strings.
        gold_docs: Gold passages per query, forwarded to ``retrieve``/``rag_qa``.
        gold_answers: Gold answers per query; QA is skipped when ``None``.
        alphas: Values for ``dense_fuse_alpha``.
        ppr_topk_list: Values for ``ppr_topk``.
        dense_rerank_topk: Value for ``dense_rerank_topk`` (same for all rows).
        save_csv: Optional CSV output path; missing parent directories are
            created. A failure to save is reported but does not raise.

    Returns:
        One dict per grid point, keyed by the column headers.
    """
    results = []
    headers = ["alpha", "ppr_topk", "Recall@1", "Recall@5", "Recall@20", "Recall@100",
               "EM", "F1", "RetrievalTime(s)", "PPRTime(s)", "RMTime(s)", "WallClock(s)"]

    print("\n===== Running Ablation =====")
    print("Grid:", "alpha=", alphas, "; ppr_topk=", ppr_topk_list, "; dense_rerank_topk=", dense_rerank_topk)

    for a in alphas:
        for k in ppr_topk_list:
            # Hyper-parameters are read from the shared config at call time.
            hipporag.global_config.dense_fuse_alpha = float(a)
            hipporag.global_config.ppr_topk = int(k)
            hipporag.global_config.dense_rerank_topk = int(dense_rerank_topk)

            # Reset the accumulated timers so each row reports only its own run.
            hipporag.all_retrieval_time = 0.0
            hipporag.rerank_time = 0.0
            hipporag.ppr_time = 0.0

            t0 = time.perf_counter()

            # Retrieval: a tuple return carries the metrics dict in position 1.
            ret_out = hipporag.retrieve(queries=queries, gold_docs=gold_docs)
            retrieval_metrics = ret_out[1] if (isinstance(ret_out, tuple) and len(ret_out) > 1) else None

            # QA: one call; the log line is only a fallback for the metrics.
            em, f1 = None, None
            if gold_answers is not None:
                catcher = _QACatcher()
                tag_logger = logging.getLogger("src.hipporag.TAG")
                tag_logger.addFilter(catcher)
                try:
                    qa_out = hipporag.rag_qa(queries=queries, gold_docs=gold_docs, gold_answers=gold_answers)
                finally:
                    tag_logger.removeFilter(catcher)
                em, f1 = _em_f1_from_qa_output(qa_out)
                if em is None and f1 is None and catcher.payload:
                    em, f1 = _extract_em_f1(catcher.payload)

            elapsed = time.perf_counter() - t0

            rec1 = rec5 = rec20 = rec100 = None
            if isinstance(retrieval_metrics, dict):
                rec1 = retrieval_metrics.get("Recall@1")
                rec5 = retrieval_metrics.get("Recall@5")
                rec20 = retrieval_metrics.get("Recall@20")
                rec100 = retrieval_metrics.get("Recall@100")

            row = {
                "alpha": a,
                "ppr_topk": k,
                "Recall@1": rec1, "Recall@5": rec5, "Recall@20": rec20, "Recall@100": rec100,
                "EM": em, "F1": f1,
                "RetrievalTime(s)": round(getattr(hipporag, "all_retrieval_time", elapsed), 2),
                "PPRTime(s)": round(getattr(hipporag, "ppr_time", 0.0), 2),
                "RMTime(s)": round(getattr(hipporag, "rerank_time", 0.0), 2),
                "WallClock(s)": round(elapsed, 2),
            }
            results.append(row)

            # Print each row as soon as it is available.
            print(_format_row(row, headers))

    # Summary table.
    print("\n===== Ablation Results =====")
    print("\t".join(headers))
    for r in results:
        print(_format_row(r, headers))

    # Optional CSV export.
    if save_csv:
        try:
            import csv
            parent_dir = os.path.dirname(save_csv)
            if parent_dir:
                os.makedirs(parent_dir, exist_ok=True)
            with open(save_csv, "w", newline="", encoding="utf-8") as f:
                writer = csv.DictWriter(f, fieldnames=headers)
                writer.writeheader()
                for r in results:
                    writer.writerow({h: r.get(h, "") for h in headers})
            print(f"[Ablation] Saved CSV: {save_csv}")
        except Exception as e:
            # Deliberately broad: a failed export must not discard the results
            # of a long-running grid; report and still return them.
            print(f"[Ablation] Failed to save CSV: {e}")

    return results


# -------------
# 主入口
# -------------
def _build_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line parser of the batch ablation runner."""
    parser = argparse.ArgumentParser(description="HippoRAG/TAG — Batch Ablation Runner")
    parser.add_argument('--dataset', type=str, default='musique', help='Dataset name')
    parser.add_argument('--save_dir', type=str, default='outputs', help='Save directory root')
    parser.add_argument('--llm_base_url', type=str, default=API_URL, help='LLM base URL')
    parser.add_argument('--llm_name', type=str, default=MODEL, help='LLM name')
    parser.add_argument('--embedding_name', type=str, default=EMBEDDING_MODEL, help='embedding model name')
    parser.add_argument('--azure_endpoint', type=str, default=API_URL, help='Azure Endpoint URL')
    parser.add_argument('--azure_embedding_endpoint', type=str, default=EMBEDDING_API_URL, help='Azure Embedding Endpoint')

    parser.add_argument('--force_index_from_scratch', type=str, default='false',
                        help='True = ignore existing index/graph; rebuild')
    parser.add_argument('--force_openie_from_scratch', type=str, default='false',
                        help='False = try reuse openie results if exist')
    parser.add_argument('--openie_mode', choices=['online', 'offline'], default='online')

    parser.add_argument('--select_number', type=int, default=50, help='Use first N samples for quick ablation')

    # Experiment hyper-parameters (the defaults run pure PPR first).
    parser.add_argument('--ppr_topk', type=int, default=8, help='Top-K entities for paragraph aggregation (run_ppr_new)')
    parser.add_argument('--dense_rerank_topk', type=int, default=200, help='Top-K passages (post-PPR) for dense rerank')
    parser.add_argument('--dense_fuse_alpha', type=float, default=0.0, help='0=PPR only, 1=Dense only')

    # Ablation grid.
    parser.add_argument('--ablate_alphas', type=str, default="0.0,0.5,1.0", help='Comma list, e.g., "0.0,0.5,1.0"')
    parser.add_argument('--ablate_ppr_topk', type=str, default="5,8,12", help='Comma list, e.g., "5,8,12"')
    parser.add_argument('--save_csv', type=str, default=None, help='If set, save ablation results to CSV path')
    parser.add_argument('--do_single_run_first', action='store_true', help='Run one single config before ablation')
    return parser


def _print_metric_section(title: str, metrics: Dict[str, Any]) -> None:
    """Print a titled block of metrics; numeric values are shown with 4 decimals."""
    print(title)
    if not metrics:
        print("  N/A")
        return
    for name, value in metrics.items():
        try:
            shown = f"{float(value):.4f}"
        except (TypeError, ValueError):
            shown = f"{value}"
        print(f"  {name:<12}: {shown}")


def main():
    """Command-line entry point: index once, then run the ablation grid.

    Steps: parse arguments; load ``reproduce/dataset/<dataset>_corpus.json``
    and ``reproduce/dataset/<dataset>.json``; keep the first
    ``--select_number`` samples; build the ``BaseConfig`` and ``TAG``
    instance; index the corpus; optionally run one QA pass with the
    command-line hyper-parameters and print the scraped metrics; finally run
    ``run_ablation`` over ``--ablate_alphas`` x ``--ablate_ppr_topk``.
    """
    # Configure logging before anything can emit a record: a module-level
    # ``logging.warning`` on an unconfigured root logger installs a default
    # handler, after which ``basicConfig(level=INFO)`` would silently do nothing.
    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger("main_azure")

    args = _build_arg_parser().parse_args()

    # Load the corpus and the evaluation samples.
    corpus_path = f"reproduce/dataset/{args.dataset}_corpus.json"
    with open(corpus_path, "r", encoding="utf-8") as f:
        corpus = json.load(f)
    docs = [f"{doc['title']}\n{doc['text']}" for doc in corpus]

    with open(f"reproduce/dataset/{args.dataset}.json", "r", encoding="utf-8") as f:
        samples = json.load(f)

    # Truncate first so gold labels are only built (and validated) for the
    # samples that are actually evaluated.
    samples = samples[:int(args.select_number)]
    all_queries = [s['question'] for s in samples]
    gold_answers = get_gold_answers(samples)

    # Gold documents may be unavailable for a dataset; retrieval evaluation is
    # then skipped instead of aborting the run.
    try:
        gold_docs = get_gold_docs(samples, args.dataset)
        if not (len(all_queries) == len(gold_docs) == len(gold_answers)):
            raise ValueError(
                f"length mismatch: {len(all_queries)} queries, "
                f"{len(gold_docs)} gold_docs, {len(gold_answers)} gold_answers"
            )
    except Exception as e:
        logger.warning("[WARN] gold_docs unavailable or mismatched: %s", e)
        gold_docs = None

    # Output directory.
    if args.save_dir == 'outputs':
        save_dir = f"{args.save_dir}/{args.dataset}"
    else:
        save_dir = f"{args.save_dir}_{args.dataset}"

    # Global configuration.
    config = BaseConfig(
        save_dir=save_dir,
        llm_base_url=args.llm_base_url,
        llm_name=args.llm_name,
        azure_endpoint=args.azure_endpoint,
        azure_embedding_endpoint=args.azure_embedding_endpoint,
        dataset=args.dataset,
        embedding_model_name=args.embedding_name,
        force_index_from_scratch=string_to_bool(args.force_index_from_scratch),
        force_openie_from_scratch=string_to_bool(args.force_openie_from_scratch),
        rerank_dspy_file_path="src/hipporag/prompts/dspy_prompts/filter_llama3.3-70B-Instruct.json",
        retrieval_top_k=200,
        linking_top_k=5,
        max_qa_steps=3,
        qa_top_k=5,
        graph_type="facts_and_sim_passage_node_unidirectional",
        embedding_batch_size=5,
        max_new_tokens=None,
        corpus_len=len(corpus),
        openie_mode=args.openie_mode,

        # Experiment hyper-parameters (configuration of the initial single run).
        ppr_topk=int(args.ppr_topk),
        dense_rerank_topk=int(args.dense_rerank_topk),
        dense_fuse_alpha=float(args.dense_fuse_alpha),
    )

    # Build the index once; every later run reuses it.
    hipporag = TAG(global_config=config)
    hipporag.topic_index(docs)

    # Optional single run with the command-line hyper-parameters.
    if args.do_single_run_first:
        summary: Dict[str, Any] = {
            "dataset": args.dataset,
            "llm": args.llm_name,
            "embedding": args.embedding_name,
            "total_retrieval_time": None,
            "total_recognition_time": None,
            "total_ppr_time": None,
            "total_misc_time": None,
            "retrieval": {},
            "qa": {},
        }
        # Scrape the overview metrics that TAG logs during this run only.
        tag_logger = logging.getLogger("src.hipporag.TAG")
        summary_filter = CatchAllInfo(summary)
        tag_logger.addFilter(summary_filter)
        try:
            logger.info("[SingleRun] Start with current config (no ablation).")
            hipporag.rag_qa(queries=all_queries, gold_docs=gold_docs, gold_answers=gold_answers)
        finally:
            tag_logger.removeFilter(summary_filter)

        print("\n" + "="*70)
        print(f"单跑配置 -> 数据集: {summary['dataset']} | LLM: {summary['llm']} | Embedding: {summary['embedding']}")
        print("-"*70)
        print("各阶段耗时:")
        print(f"  Total Retrieval Time : {summary['total_retrieval_time'] if summary['total_retrieval_time'] is not None else 'N/A'}s")
        print(f"  Recognition Memory   : {summary['total_recognition_time'] if summary['total_recognition_time'] is not None else 'N/A'}s")
        print(f"  PPR Time             : {summary['total_ppr_time'] if summary['total_ppr_time'] is not None else 'N/A'}s")
        print(f"  Misc Time            : {summary['total_misc_time'] if summary['total_misc_time'] is not None else 'N/A'}s")

        _print_metric_section("\n检索评价:", summary["retrieval"])
        _print_metric_section("\nQA 评价:", summary["qa"])
        print("="*70)

    # Ablation grid (empty items from stray commas are ignored).
    alphas = [float(x) for x in args.ablate_alphas.split(",") if x.strip()]
    ppr_topk_list = [int(x) for x in args.ablate_ppr_topk.split(",") if x.strip()]

    run_ablation(
        hipporag=hipporag,
        queries=all_queries,
        gold_docs=gold_docs,
        gold_answers=gold_answers,
        alphas=alphas,
        ppr_topk_list=ppr_topk_list,
        dense_rerank_topk=int(args.dense_rerank_topk),
        save_csv=args.save_csv,
    )

if __name__ == "__main__":
    # Avoid redundant Tokenizers parallelism warnings, but keep any value the
    # user has already exported.
    if "TOKENIZERS_PARALLELISM" not in os.environ:
        os.environ["TOKENIZERS_PARALLELISM"] = "false"
    main()
