#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import argparse
import ast
import re
import subprocess
import sys

# python monitor_and_extract.py --tee python -u main_azure.py --dataset hotpotqa
# python monitor_and_extract.py --tee python -u main_azure.py --dataset nq_rear --embedding_name nvidia/NV-Embed-v2
# python monitor_and_extract.py --tee python -u main_azure.py --dataset popqa --embedding_name nvidia/NV-Embed-v2
# python monitor_and_extract.py --tee python -u main_azure.py --dataset musique --embedding_name nvidia/NV-Embed-v2
# python monitor_and_extract.py --tee python -u main_azure.py --dataset 2wikimultihopqa --embedding_name nvidia/NV-Embed-v2
# python monitor_and_extract.py --tee python -u main_azure.py --dataset hotpotqa --embedding_name nvidia/NV-Embed-v2

'''
export OPENAI_API_KEY=<YOUR_OPENAI_API_KEY>
export CUDA_VISIBLE_DEVICES=1
export HF_ENDPOINT=https://hf-mirror.com
export HF_HOME=/root/autodl-tmp/hf_home

cd /root/autodl-tmp/.autodl/hipporag_echo_2

conda activate hipporag
python monitor_and_extract.py --tee python -u main_azure.py --dataset XXX

hotpotqa
2wikimultihopqa
musique
'''


# Regular expressions used to pull metrics out of the monitored command's log.
# Every numeric capture group accepts only a well-formed decimal (digits with an
# optional fractional part), so text such as "F1: 0.5." yields "0.5" and a bare
# "..." does not match at all; both would otherwise make float() raise.
TIME_RE = re.compile(r"Total Retrieval Time\s+([0-9]+(?:\.[0-9]+)?)s")
QA_LINE_RE = re.compile(r"Evaluation results for QA:\s*(\{.*\})", re.IGNORECASE)
EM_RE = re.compile(r"(?:['\"])?ExactMatch(?:['\"])?:\s*([0-9]+(?:\.[0-9]+)?)")
F1_RE = re.compile(r"(?:['\"])?F1(?:['\"])?:\s*([0-9]+(?:\.[0-9]+)?)")

# Patterns for locating a Recall@K dict, either complete on a single line or
# spread across several lines.
RECALL_INLINE_RE = re.compile(r"\{.*Recall@\d+.*\}")
RECALL_START_RE  = re.compile(r"\{.*Recall@\d+.*")   # possible first line of a multi-line dict
RECALL_END_RE    = re.compile(r".*\}")               # any line containing a closing '}'


def parse_args():
    """Parse this wrapper's own command line.

    Returns:
        The argparse namespace with ``tee`` (bool flag) and ``cmd`` (list of
        all remaining arguments, i.e. the command to execute).
    """
    parser = argparse.ArgumentParser(
        description="Run a command, capture output, extract Retrieval Time / EM / F1 / Recall@K, then print summary.",
    )
    # Optional flag: mirror the wrapped command's raw log while it runs.
    parser.add_argument(
        "--tee",
        action="store_true",
        help="边跑边把原始日志打印到终端（默认只输出最后摘要）",
    )
    # Everything after the wrapper's options is collected verbatim as the command.
    parser.add_argument(
        "cmd",
        nargs=argparse.REMAINDER,
        help="要执行的命令，例如：python -u main_azure.py --dataset hotpotqa",
    )
    namespace = parser.parse_args()
    return namespace


def try_parse_recall_dict(text: str):
    """Safely parse a Recall@K dict such as ``{'Recall@1': 0.27, ...}`` out of *text*.

    The text may carry arbitrary content before the first '{' and after the
    last '}' (for example a logger prefix or a trailing newline); only the
    span between those two braces is evaluated, using ``ast.literal_eval``.

    Args:
        text: A log line, or several joined log lines, expected to contain a
            Python dict literal.

    Returns:
        The parsed dict when it is a dict with at least one key starting with
        ``"Recall@"``; otherwise ``None``.
    """
    start = text.find("{")
    end = text.rfind("}")
    if start == -1 or end < start:
        return None

    try:
        data = ast.literal_eval(text[start:end + 1])
    except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
        # Only the errors literal_eval raises for malformed input are treated
        # as "not a dict"; anything else (e.g. a NameError from a missing
        # import) should surface instead of silently disabling the parser.
        return None

    if isinstance(data, dict) and any(str(k).startswith("Recall@") for k in data):
        return data
    return None


def _to_float(text):
    """Convert a regex capture to float; return None for malformed text such as '0.5.'."""
    try:
        return float(text)
    except ValueError:
        return None


def _build_summary(retrieval_time, em, f1, recall_dict):
    """Format the extracted metrics as the single summary line printed at the end."""
    parts = []
    if retrieval_time is not None:
        parts.append(f"retrieval_time={retrieval_time:.2f}s")
    if em is not None:
        parts.append(f"EM={em:.4f}")
    if f1 is not None:
        parts.append(f"F1={f1:.4f}")
    if recall_dict:
        # Order entries by K; keys without a numeric "@K" suffix sort last.
        def cutoff(item):
            try:
                return int(str(item[0]).split("@", 1)[1])
            except (IndexError, ValueError):
                return 10**9

        recall_parts = []
        for key, value in sorted(recall_dict.items(), key=cutoff):
            try:
                recall_parts.append(f"R@{str(key).split('@')[-1]}={float(value):.4f}")
            except (TypeError, ValueError):
                # A non-numeric value must not crash the summary after the
                # wrapped run has already finished; skip that entry.
                continue
        if recall_parts:
            parts.append(" ".join(recall_parts))

    if parts:
        return "Summary: " + " | ".join(parts)
    return "Summary: 数据不完整（未匹配到 Retrieval Time / EM / F1 / Recall@K）"


def main():
    """Run the wrapped command, scan its output for metrics, print a summary, and exit.

    The command's stdout and stderr are merged and read line by line. From
    those lines the function extracts the first "Total Retrieval Time", the
    ExactMatch / F1 values, and the last Recall@K dict that parses. With
    ``--tee`` every raw line is also echoed to stdout as it arrives.

    Exit status: 1 when no command was given, 127 when the command could not
    be started, otherwise the wrapped command's return code.
    """
    args = parse_args()
    # Drop only a leading '--' separator; later '--' tokens belong to the
    # wrapped command and must be passed through unchanged.
    cmd = list(args.cmd)
    if cmd and cmd[0] == "--":
        cmd = cmd[1:]

    if not cmd:
        print("用法示例：\n  python monitor_and_extract.py --tee python -u main_azure.py --dataset hotpotqa")
        sys.exit(1)

    try:
        proc = subprocess.Popen(
            cmd,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            text=True,
            bufsize=1,
            encoding="utf-8",
            errors="replace",
        )
    except OSError as exc:
        # Covers a missing or non-executable program; report it instead of a traceback.
        print(f"Failed to start command {cmd!r}: {exc}", file=sys.stderr)
        sys.exit(127)

    retrieval_time = None
    em = None
    f1 = None
    recall_dict = None  # last Recall@K dict that parsed successfully

    # Buffer used when a Recall@K dict is printed across several lines.
    buffering_recall = False
    recall_buf_lines = []

    # The context manager closes the pipe and reaps the child on every exit path.
    with proc:
        for line in proc.stdout:
            if args.tee:
                sys.stdout.write(line)
                # Flush so the mirrored log stays live even when stdout is a pipe or file.
                sys.stdout.flush()

            # 1) Retrieval time: only the first occurrence is kept.
            if retrieval_time is None:
                mt = TIME_RE.search(line)
                if mt:
                    retrieval_time = float(mt.group(1))

            # 2) EM / F1
            mq = QA_LINE_RE.search(line)
            if mq:
                # a) The line carries the whole QA result dict; the latest such line wins.
                blob = mq.group(1)
                me = EM_RE.search(blob)
                mf = F1_RE.search(blob)
                if me:
                    value = _to_float(me.group(1))
                    if value is not None:
                        em = value
                if mf:
                    value = _to_float(mf.group(1))
                    if value is not None:
                        f1 = value
            else:
                # b) Fallback for logs that print EM / F1 on separate lines;
                #    only fills values that have not been seen yet.
                if em is None:
                    me = EM_RE.search(line)
                    if me:
                        em = _to_float(me.group(1))
                if f1 is None:
                    mf = F1_RE.search(line)
                    if mf:
                        f1 = _to_float(mf.group(1))

            # 3) Recall@K
            #   a) A complete dict on one line. Parse only the matched
            #      "{...}" span so a log prefix does not break the parse.
            inline = RECALL_INLINE_RE.search(line)
            if inline:
                parsed = try_parse_recall_dict(inline.group(0))
                if parsed:
                    recall_dict = parsed
                    buffering_recall = False
                    recall_buf_lines.clear()
                    continue

            if not buffering_recall:
                #   b) Possible multi-line dict: start buffering at the '{',
                #      but only when the line has not already closed it
                #      (a closed line was handled, or rejected, above).
                start = RECALL_START_RE.search(line)
                if start and not inline:
                    buffering_recall = True
                    recall_buf_lines = [line[start.start():]]
                continue

            # Currently buffering: collect lines until one contains '}'.
            recall_buf_lines.append(line)
            if RECALL_END_RE.search(line):
                chunk = "".join(recall_buf_lines)
                # Drop anything after the closing brace before parsing.
                parsed = try_parse_recall_dict(chunk[:chunk.rfind("}") + 1])
                if parsed:
                    recall_dict = parsed
                buffering_recall = False
                recall_buf_lines.clear()

        retcode = proc.wait()

    print(_build_summary(retrieval_time, em, f1, recall_dict))

    sys.exit(retcode)


if __name__ == "__main__":
    main()
