"""Re-judge Lite predictive hits with the configured mini judge model (e.g. gpt-4o-mini) and report which hit-paper pairs clear the hit pre-filter threshold."""
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent))

import asyncio, json, re, os, glob, logging
from openai import AsyncOpenAI
from config import config
from core.judge import HIT_VERIFY_SYSTEM, HIT_VERIFY_TEMPLATE

logging.basicConfig(level=logging.WARNING)

# Connection settings for the cheaper "mini" judge model.
_judge_mini_cfg = config["api"]["judge_mini"]

mini_client = AsyncOpenAI(
    api_key=_judge_mini_cfg["api_key"],
    base_url=_judge_mini_cfg["base_url"],
    timeout=_judge_mini_cfg["timeout_s"],
)
mini_model = _judge_mini_cfg["model"]

# A hit-paper pair counts as passing when its mini-judge score is >= this value.
threshold = config["api"]["hit_prefilter_threshold"]

# Caps how many mini-judge requests may be in flight at once.
sem = asyncio.Semaphore(10)

# Literal text (with a U+2713 check mark) that flags a predictive hit in reports.
HIT_MARKER = "Predictive Hit: \u2713"


# Captures a well-formed number after "Avg Score:", tolerating markdown emphasis
# such as "Avg Score: **7.5**". The number pattern deliberately excludes a
# trailing period, so "Avg Score: 7.5." does not turn into float("7.5.").
_AVG_SCORE_RE = re.compile(r"Avg Score:[\s*]*(\d+(?:\.\d+)?)")


def parse_avg_score(text):
    """Return the first "Avg Score:" value in *text* as a float, or None.

    Args:
        text: Raw judge response; ``None`` and ``""`` are treated as having no score.

    Returns:
        The parsed score, or ``None`` when no well-formed number follows
        "Avg Score:".
    """
    match = _AVG_SCORE_RE.search(text or "")
    return float(match.group(1)) if match else None


async def mini_judge(
    hyp_content,
    paper_content,
    arxiv_id,
    *,
    paper_title="(unknown)",
    paper_published="2025-01-01",
    attempts=3,
    retry_delay_s=1.0,
):
    """Score one hypothesis/paper pair with the mini judge model.

    Fills HIT_VERIFY_TEMPLATE, sends it together with HIT_VERIFY_SYSTEM to
    the mini judge, and parses the "Avg Score:" value from the reply. The
    request loop runs while holding the module-level semaphore ``sem``.

    Args:
        hyp_content: Text of the hypothesis file.
        paper_content: Paper full text (callers may truncate it).
        arxiv_id: arXiv identifier placed in the prompt.
        paper_title: Title placed in the prompt (placeholder by default).
        paper_published: Publication date placed in the prompt (placeholder
            by default).
        attempts: Maximum number of requests made while replies are empty.
        retry_delay_s: Seconds to wait between an empty reply and the next try.

    Returns:
        The parsed score, or 0.0 if the reply was empty or had no parseable
        "Avg Score:" (a warning is logged in that case).
    """
    prompt = HIT_VERIFY_TEMPLATE.format(
        hypothesis_content=hyp_content,
        paper_title=paper_title,
        paper_published=paper_published,
        paper_arxiv_id=arxiv_id,
        paper_content=paper_content,
    )
    content = ""  # stays empty if attempts <= 0
    async with sem:
        for attempt in range(attempts):
            response = await mini_client.chat.completions.create(
                model=mini_model,
                messages=[
                    {"role": "system", "content": HIT_VERIFY_SYSTEM},
                    {"role": "user", "content": prompt},
                ],
                temperature=0.1,
            )
            content = response.choices[0].message.content or ""
            if content.strip():
                break
            # Pause before retrying an empty reply, but not after the last attempt.
            if attempt + 1 < attempts:
                await asyncio.sleep(retry_delay_s)

    score = parse_avg_score(content)
    if score is None:
        logging.getLogger(__name__).warning(
            "mini-judge gave no parseable 'Avg Score' for %s (%d chars of output); scoring 0.0",
            arxiv_id,
            len(content),
        )
        return 0.0
    return score


# "### hyp-<name>.md" headings delimit per-hypothesis sections in a report.
_HYP_HEADING_RE = re.compile(r"### (hyp-.*?\.md)")
# Bracketed new-style arXiv IDs, optionally versioned, e.g. "[2401.12345v2]".
_ARXIV_REF_RE = re.compile(r"\[(\d{4}\.\d{5}(?:v\d+)?)\]")
# Only this many characters after the hit marker are scanned for cited papers.
_HIT_CONTEXT_CHARS = 500
# At most this many cited papers are judged per hit.
_MAX_PAPERS_PER_HIT = 3
# Paper full text is truncated to this many characters before judging.
_PAPER_CHARS = 5000


def _read_text(path, limit=None):
    """Read a UTF-8 text file, optionally only its first *limit* characters."""
    with open(path, encoding="utf-8") as f:
        return f.read() if limit is None else f.read(limit)


def _find_paper_cache(exp_dir, slug, clean_id):
    """Return the cached full-text path for *clean_id*, or None if not cached.

    Looks in the topic's own fulltext_cache first, then in any topic's cache.
    """
    local = os.path.join(exp_dir, "topics", slug, "metabolism", "fulltext_cache", f"{clean_id}.txt")
    if os.path.exists(local):
        return local
    # Escape the literal directory so characters like "[" are not glob syntax.
    pattern = os.path.join(
        glob.escape(exp_dir), "topics", "*", "metabolism", "fulltext_cache", f"{clean_id}.txt"
    )
    matches = sorted(glob.glob(pattern))  # sorted for a deterministic pick
    return matches[0] if matches else None


def _collect_checks(exp_dir, topics):
    """Build one check dict per (hit hypothesis, cited cached paper) pair."""
    checks = []
    for topic in topics:
        slug = topic["slug"]
        report_glob = os.path.join(
            glob.escape(exp_dir), "all_reports", f"*{glob.escape(slug)}*report.md"
        )
        report_files = sorted(glob.glob(report_glob))
        if not report_files:
            continue
        report = _read_text(report_files[0])

        # With one capture group, split yields [preamble, name1, body1, name2, body2, ...].
        sections = _HYP_HEADING_RE.split(report)
        for hyp_name, hyp_section in zip(sections[1::2], sections[2::2]):
            if HIT_MARKER not in hyp_section:
                continue
            after_hit = hyp_section.split(HIT_MARKER)[1][:_HIT_CONTEXT_CHARS]
            # dict.fromkeys drops repeated citations while keeping their order,
            # so a paper cited twice is not judged (and counted) twice.
            cited = list(dict.fromkeys(_ARXIV_REF_RE.findall(after_hit)))

            hyp_path = os.path.join(exp_dir, "topics", slug, "metabolism", "hypotheses", hyp_name)
            if not os.path.exists(hyp_path):
                continue
            hyp_content = _read_text(hyp_path)

            for arxiv_id in cited[:_MAX_PAPERS_PER_HIT]:
                clean_id = arxiv_id.split("v")[0]  # strip the version suffix
                cache_path = _find_paper_cache(exp_dir, slug, clean_id)
                if cache_path is None:
                    continue
                checks.append({
                    "slug": slug,
                    "hyp_name": hyp_name,
                    "arxiv_id": arxiv_id,
                    "hyp_content": hyp_content,
                    "paper_content": _read_text(cache_path, _PAPER_CHARS),
                })
    return checks


def _print_summary(results):
    """Print pass/fail counts for the judged pairs (safe for an empty list)."""
    print(f"\n{'=' * 60}")
    print(f"Total: {len(results)} hit-paper pairs")
    if not results:
        return  # nothing judged; percentages would divide by zero
    passed = sum(1 for r in results if r["would_pass"])
    failed = len(results) - passed
    print(f"Mini PASS (>= {threshold}): {passed} ({passed / len(results) * 100:.0f}%)")
    print(f"Mini FAIL (< {threshold}): {failed} ({failed / len(results) * 100:.0f}%)")


async def process_experiment(exp_dir, output_path="/tmp/abl3_mini_recheck.json"):
    """Re-judge every predictive hit of an experiment with the mini judge.

    Reads ``<exp_dir>/batch_summary.json`` and keeps topics whose ``hit_rate``
    is positive. For each of those, it finds the topic's report, extracts the
    arXiv IDs cited after each hit marker, scores each (hypothesis, cached
    paper) pair with ``mini_judge``, prints per-pair lines and a summary, and
    writes the results as JSON.

    Args:
        exp_dir: Experiment results directory.
        output_path: Where the JSON results are written.

    Returns:
        A list of dicts with keys ``slug``, ``hyp``, ``arxiv_id``,
        ``mini_score`` and ``would_pass``, in discovery order.

    Raises:
        FileNotFoundError: If ``batch_summary.json`` is missing.
    """
    with open(os.path.join(exp_dir, "batch_summary.json"), encoding="utf-8") as f:
        summary = json.load(f)

    # `or 0` also tolerates an explicit null hit_rate.
    topics_with_hits = [t for t in summary if (t.get("hit_rate") or 0) > 0]
    print(f"Topics with hits: {len(topics_with_hits)}")

    all_checks = _collect_checks(exp_dir, topics_with_hits)
    print(f"Total hit-paper pairs to mini-judge: {len(all_checks)}\n")

    # Issue the judge calls concurrently; mini_judge's module-level semaphore
    # bounds how many are in flight. gather preserves input order.
    # NOTE(review): a module-level asyncio.Semaphore used across asyncio.run
    # relies on Python >= 3.10 lazy loop binding — confirm the runtime version.
    scores = await asyncio.gather(
        *(mini_judge(c["hyp_content"], c["paper_content"], c["arxiv_id"]) for c in all_checks)
    )

    results = []
    for check, score in zip(all_checks, scores):
        would_pass = score >= threshold
        results.append({
            "slug": check["slug"],
            "hyp": check["hyp_name"],
            "arxiv_id": check["arxiv_id"],
            "mini_score": score,
            "would_pass": would_pass,
        })
        status = "PASS" if would_pass else "FAIL"
        print(f"  {check['slug'][:35]:35s} {check['hyp_name']:25s} {check['arxiv_id']:15s} mini={score:.1f} {status}")

    _print_summary(results)

    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2)
    return results


if __name__ == "__main__":
    # Optional first CLI argument selects the experiment directory to re-check.
    default_exp_dir = "results/lite_20260405_000653"
    target_dir = next(iter(sys.argv[1:]), default_exp_dir)
    asyncio.run(process_experiment(target_dir))
