#!/usr/bin/env python3
"""
On-the-fly BibTeX generation for missing citation keys.
Paths are anonymised: the project root (BASE_DIR) is read from the
ANON_PROJECT_ROOT environment variable, defaulting to ./anonymous_root.
"""

import re
import json
import os
import arxiv
from create_util import *

# Shared arXiv API client, reused by search_arxiv_by_cite and get_cite_real.
client = arxiv.Client()

# ---------- anonymous base ----------
# Project root taken from the ANON_PROJECT_ROOT environment variable, falling
# back to "./anonymous_root" when it is unset.
# NOTE(review): no function in this module appears to read BASE_DIR;
# presumably it is consumed by importers of this module — confirm.
BASE_DIR = os.environ.get("ANON_PROJECT_ROOT", "./anonymous_root")

# ---------- helpers ----------
def cite_already(text: str) -> list[str]:
    r"""Return the unique citation keys used in ``\cite``, ``\citep`` and ``\citet``.

    Starred forms (``\citet*{k}``) and up to two optional arguments
    (``\citep[e.g.][p.~3]{k}``) are recognised. Multi-key lists such as
    ``\cite{a, b}`` are split on commas and whitespace-stripped; empty entries
    are dropped.

    Args:
        text: LaTeX source to scan.

    Returns:
        Distinct keys in order of first appearance (deterministic, unlike
        iterating a set).
    """
    pattern = re.compile(r"\\cite[tp]?\*?(?:\[[^\]]*\]){0,2}\{([^}]+)\}")
    keys = (
        key.strip()
        for group in pattern.findall(text)
        for key in group.split(",")
    )
    # dict.fromkeys de-duplicates while preserving first-seen order.
    return list(dict.fromkeys(k for k in keys if k))

def split_citation_key(key: str):
    """Split a key shaped ``<letters><4 digits><letters>`` into its three parts.

    For example ``"cao2020xxx"`` gives ``("cao", "2020", "xxx")``; the trailing
    letters may be empty. Returns None when *key* does not have that shape.
    """
    found = re.match(r"^([a-zA-Z]+)(\d{4})([a-zA-Z]*)$", key)
    if found is None:
        return None
    author_part, year_part, tail_part = found.groups()
    return author_part, year_part, tail_part

def _arxiv_id_from_url(url: str) -> str:
    """Extract a version-less arXiv identifier from an arXiv URL (abs/html/pdf) or bare id."""
    # Ignore query strings and fragments so they cannot leak into the id.
    path = url.split("?", 1)[0].split("#", 1)[0]
    # New-style ids: YYMM.NNNN or YYMM.NNNNN, optionally followed by vN / .pdf.
    match = re.search(r"(\d{4}\.\d{4,5})(?:v\d+)?", path)
    if match:
        return match.group(1)
    # Old-style ids keep their archive prefix, e.g. hep-th/9901001 or math.GT/0309136.
    match = re.search(r"([a-z][a-z\-]*(?:\.[A-Z]{2})?/\d{7})(?:v\d+)?", path)
    if match:
        return match.group(1)
    # Fallback heuristic: last path segment with any version suffix cut off.
    return path.rstrip("/").split("/")[-1].split("v")[0]


def search_arxiv_by_cite(url: str):
    """Fetch arXiv metadata for the paper referenced by *url*.

    Args:
        url: An arXiv abs/html/pdf URL (with or without version suffix) or a
            bare arXiv id. The version is dropped, so the lookup requests the
            paper by its unversioned id.

    Returns:
        The first result yielded by the shared ``client`` for that id, or None
        when nothing is returned or the request raises (the error is logged
        as a warning).
    """
    arxiv_id = _arxiv_id_from_url(url)
    search = arxiv.Search(id_list=[arxiv_id])
    try:
        return next(iter(client.results(search)), None)
    except Exception as e:
        logger.warning(f"arXiv fetch failed: {e}")
    return None

def build_bibtex_entry(cite_key: str, paper: "arxiv.Result") -> str:
    """Render an arXiv result as a ``@misc`` BibTeX entry keyed by *cite_key*.

    Runs of whitespace (including the line breaks arXiv puts into long
    titles) are collapsed to single spaces in the title and author names.
    Braces in the title are backslash-escaped. An ``eprint`` field with the
    unversioned arXiv id is emitted so that ``archivePrefix`` and
    ``primaryClass`` are usable by styles that read them.

    Args:
        cite_key: Key to use for the entry.
        paper: Object exposing ``authors`` (each with ``.name``), ``title``,
            ``published`` (a datetime), ``primary_category``, ``entry_id`` and
            ``get_short_id()``, as ``arxiv.Result`` does.

    Returns:
        The entry text. It starts with a newline so consecutive appends stay
        on separate lines, and it has no trailing newline.
    """
    authors = " and ".join(" ".join(a.name.split()) for a in paper.authors)
    title = " ".join(paper.title.split())
    title = title.replace("{", "\\{").replace("}", "\\}")
    year = paper.published.year
    # get_short_id() yields e.g. "2005.00001v2"; BibTeX eprint takes the bare id.
    eprint = re.sub(r"v\d+$", "", paper.get_short_id())
    return f"""
@misc{{{cite_key},
title={{{title}}},
author={{{authors}}},
year={{{year}}},
eprint={{{eprint}}},
archivePrefix={{arXiv}},
primaryClass={{{paper.primary_category}}},
url={{{paper.entry_id}}}
}}"""

def get_cite_real(cite_key: str, bib_path: str):
    """Append an arXiv-derived BibTeX entry for *cite_key* to *bib_path* if it is missing.

    The key must have the shape ``<letters><4-digit year><letters>`` as parsed
    by ``split_citation_key``; other keys are skipped. The arXiv query combines
    the author token, the trailing title token (when non-empty) and a
    submission-date window. The first result is used, which is the most
    relevant one under ``arxiv.Search``'s default relevance sort. It is
    rendered with ``build_bibtex_entry``.

    Args:
        cite_key: Citation key to resolve, e.g. ``"cao2020xxx"``.
        bib_path: .bib file to check and append to. A missing file is treated
            as empty and is created by the append.

    Errors from the arXiv lookup, entry building or the append are logged as
    warnings and otherwise ignored (best effort).
    """
    try:
        with open(bib_path, "r", encoding="utf-8") as f:
            bib_text = f.read()
    except FileNotFoundError:
        bib_text = ""
    # Compare against actual entry keys rather than a substring of the file,
    # so "cao2020" is not considered present just because "cao2020xxx" is.
    existing = set(re.findall(r"@\s*[A-Za-z]+\s*[{(]\s*([^,\s]+)\s*,", bib_text))
    if cite_key in existing:
        return

    parts = split_citation_key(cite_key)
    if not parts:
        return
    author, year, tail = parts

    y = int(year)
    clauses = [f"au:{author}"]
    if tail:  # an empty tail would produce a malformed "ti:" clause
        clauses.append(f"ti:{tail}")
    # Submissions in years y-1 and y, as valid YYYYMMDDTTTT bounds.
    clauses.append(f"submittedDate:[{y - 1}01010000 TO {y}12312359]")
    query = " AND ".join(clauses)
    search = arxiv.Search(query=query, max_results=5)
    try:
        results = list(client.results(search))
        if results:
            entry = build_bibtex_entry(cite_key, results[0])
            with open(bib_path, "a", encoding="utf-8") as f:
                f.write(entry)
    except Exception as e:
        logger.warning(f"BibTeX generation failed for {cite_key}: {e}")

def get_cite_from_llm_and_arxiv(content: str, bib_path: str):
    """Generate BibTeX for every key cited in *content* that has no entry in *bib_path*.

    The .bib file is read once, inside a context manager. A key counts as
    present only if it is an actual entry key, not merely a substring of the
    file. A missing .bib file is treated as empty. Each missing key is handed
    to ``get_cite_real``.

    Args:
        content: LaTeX text whose ``\\cite``/``\\citep``/``\\citet`` keys are checked.
        bib_path: Path of the bibliography to check and extend.
    """
    try:
        with open(bib_path, encoding="utf-8") as f:
            bib_text = f.read()
    except FileNotFoundError:
        bib_text = ""
    existing = set(re.findall(r"@\s*[A-Za-z]+\s*[{(]\s*([^,\s]+)\s*,", bib_text))
    missing = [k for k in cite_already(content) if k not in existing]
    for key in missing:
        get_cite_real(key, bib_path)

def delete_cite(paper_text: str, tex_path: str, bib_name: str = "iclr2025_conference.bib"):
    r"""Remove citations whose keys have no entry in the sibling .bib file, then write *tex_path*.

    The bibliography is ``<directory of tex_path>/<bib_name>``. Entry keys of
    every type (``@misc``, ``@article``, ``@inproceedings``, ...) count as
    present.

    Each natbib-style citation command is rewritten: ``\cite``, ``\citet``,
    ``\citep``, ``\citealt``, ``\citealp``, ``\citeauthor``, ``\citeyear``
    and ``\citeyearpar``, with optional ``*`` and up to two ``[...]``
    arguments.
    - Missing keys are dropped from the command's key list.
    - A command left with no keys is removed entirely, together with the
      spaces/tabs and ``~`` directly in front of it.
    - Commands whose keys are all present are left byte-for-byte unchanged.
    - Text outside citation commands is never modified.

    Args:
        paper_text: LaTeX source to clean.
        tex_path: File the cleaned text is written to; its directory locates
            the .bib file.
        bib_name: File name of the bibliography in that directory.

    Raises:
        FileNotFoundError: if the .bib file does not exist. Raising is safer
            than silently deleting every citation.
    """
    bib_path = os.path.join(os.path.dirname(tex_path), bib_name)
    with open(bib_path, encoding="utf-8") as f:
        bib_keys = set(re.findall(r"@\s*[A-Za-z]+\s*[{(]\s*([^,\s]+)\s*,", f.read()))

    cite_re = re.compile(
        r"(?P<lead>[ \t]*~?)"
        r"(?P<cmd>\\cite(?:[tp]|al[tp]|author|year(?:par)?)?\*?(?:\[[^\]]*\]){0,2})"
        r"\{(?P<keys>[^}]*)\}"
    )

    def _rewrite(m):
        # Decide, per citation command, which keys survive.
        raw = m.group("keys")
        keys = [k.strip() for k in raw.split(",") if k.strip()]
        kept = [k for k in keys if k in bib_keys]
        if len(kept) == len(keys):
            return m.group(0)
        if not kept:
            return ""
        sep = ", " if ", " in raw else ","
        return f"{m.group('lead')}{m.group('cmd')}{{{sep.join(kept)}}}"

    cleaned = cite_re.sub(_rewrite, paper_text)
    with open(tex_path, "w", encoding="utf-8") as f:
        f.write(cleaned)