"""
Benchmark: 10 对经典 arXiv 论文语义距离评测
从 benchmark-landmark-pairs.ts 移植
验证 embedding 模型的语义感知能力
"""
import html
import http.client
import math
import re
import sys
import urllib.request

from openai import OpenAI

from config import config

# Module-wide OpenAI SDK client used for every embedding request. The API key
# and base URL both come from the "api" -> "embedding" section of the config.
_embedding_cfg = config["api"]["embedding"]
client = OpenAI(
    api_key=_embedding_cfg["api_key"],
    base_url=_embedding_cfg["base_url"],
)

# Benchmark pairs. Every entry is a dict with a display "name" and the arXiv
# identifiers of the two papers to compare, stored under "a" and "b".
PAIRS = [
    {"name": label, "a": first_id, "b": second_id}
    for label, first_id, second_id in (
        ("LoRA v DoRA", "2106.09685", "2402.09353"),
        ("AlexNet v ResNet", "1404.5997", "1512.03385"),
        ("Transformer v FlashAttention", "1706.03762", "2205.14135"),
        ("GAN v WGAN", "1406.2661", "1701.07875"),
        ("YOLOv1 v YOLOv3", "1506.02640", "1804.02767"),
        ("BERT v RoBERTa", "1810.04805", "1907.11692"),
        ("PPO v DPO", "1707.06347", "2305.18290"),
        ("DDPM v DDIM", "2006.11239", "2010.02502"),
        ("SimCLR v MoCO", "2002.05709", "1911.05722"),
        ("GCN v GraphSAGE", "1609.02907", "1706.02216"),
    )
]


def get_abstract(arxiv_id: str) -> str | None:
    url = f"http://export.arxiv.org/api/query?id_list={arxiv_id}"
    try:
        with urllib.request.urlopen(url, timeout=15) as resp:
            text = resp.read().decode("utf-8")
        match = re.search(r"<summary>([\s\S]*?)</summary>", text)
        if match:
            return match.group(1).strip().replace("\n", " ")
    except Exception:
        pass
    return None


def embed(text: str) -> list[float]:
    """Return the embedding vector of ``text``.

    The request goes through the module-level ``client`` and uses the model
    named in the "api" -> "embedding" -> "model" config entry. Only the
    first item of the response data is used.
    """
    create_embedding = client.embeddings.create
    response = create_embedding(
        model=config["api"]["embedding"]["model"],
        input=text,
    )
    first_item = response.data[0]
    return first_item.embedding


def cosine(a: list[float], b: list[float]) -> float:
    """Return the cosine similarity of two equal-length vectors.

    Args:
        a: First vector.
        b: Second vector.

    Returns:
        ``dot(a, b) / (|a| * |b|)``, or ``0.0`` when either vector has zero
        norm (which includes empty vectors).

    Raises:
        ValueError: If the vectors differ in length.
    """
    if len(a) != len(b):
        # zip() would silently truncate the dot product to the shorter vector
        # while the norms still cover the full vectors, yielding a value that
        # is not a cosine of anything.
        raise ValueError(
            f"cosine() requires vectors of equal length, got {len(a)} and {len(b)}"
        )

    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if not norm_a or not norm_b:
        # The angle to a zero vector is undefined; report "no similarity".
        return 0.0
    return dot / (norm_a * norm_b)


def main():
    """Run the benchmark over every entry of ``PAIRS`` and print the results.

    For each pair both abstracts are fetched, embedded, and compared; the
    cosine similarity and the distance ``1 - similarity`` are printed on the
    same line as the progress message. A pair whose abstracts cannot be
    fetched, or whose processing raises, is reported and skipped so the
    remaining pairs still run.
    """
    print("\n=============================================================")
    model_name = config["api"]["embedding"]["model"]
    print(f"📐 10 对经典论文地标相似度评测 | 模型: {model_name}")
    print("=============================================================\n")

    for pair in PAIRS:
        label = pair["name"]
        # No newline here: the outcome is appended to this progress message.
        print(f"⏳ 正在处理 [{label}]...", end="", flush=True)
        try:
            # Both abstracts are always requested before either is checked.
            abstracts = [get_abstract(pair[side]) for side in ("a", "b")]
            if not all(abstracts):
                print(" ❌ 抓取失败 (ID 可能无效或网络超时)")
                continue
            vec_a, vec_b = (embed(abstract) for abstract in abstracts)
            similarity = cosine(vec_a, vec_b)
            distance = 1 - similarity
            print(f" ✅ Distance: {distance:.4f} (Sim: {similarity:.4f})")
        except Exception as err:
            # Top-level boundary: one failing pair must not abort the run.
            print(f" ❌ 运行出错: {err}")

    print()


if __name__ == "__main__":
    main()
