#!/usr/bin/env python3
# run.py
import os, re, glob, csv, html
from pathlib import Path

from pylatexenc.latex2text import LatexNodes2Text
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import Chroma
from langchain.schema import Document
from langchain_community.embeddings import HuggingFaceEmbeddings

# --------------------------------------------------
# 1. Batch convert .tex to plain text
# --------------------------------------------------

# Patterns used by the regex fallback, applied in this order.
# Math spans: $$...$$ and $...$. The look-behind keeps an escaped dollar sign
# (\$) from opening or closing a span.
_FALLBACK_MATH_RE = re.compile(
    r"(?<!\\)\$\$.*?(?<!\\)\$\$|(?<!\\)\$.*?(?<!\\)\$", re.S
)
# A control word (\section, \section*) or control symbol (\\, \%), plus any
# single-line [optional] arguments that directly follow it. Brace arguments
# are deliberately NOT consumed so that their text survives.
_FALLBACK_COMMAND_RE = re.compile(r"\\(?:[a-zA-Z]+\*?|.)(?:\[[^\]\n]*\])*", re.S)
_FALLBACK_BRACE_RE = re.compile(r"[{}]")
_FALLBACK_SPACE_RE = re.compile(r"\s+")


def _strip_latex_fallback(latex: str) -> str:
    """Crude regex cleanup: drop math, commands and braces, keep the prose."""
    # Math goes first so that commands inside a formula cannot split the span.
    text = _FALLBACK_MATH_RE.sub(" ", latex)
    # Replace with a space (not "") so "word\footnote{x}" does not fuse words.
    text = _FALLBACK_COMMAND_RE.sub(" ", text)
    text = _FALLBACK_BRACE_RE.sub("", text)
    text = _FALLBACK_SPACE_RE.sub(" ", text)    # collapse whitespace runs
    text = html.unescape(text)                  # kept from the original cleanup
    return text.strip()


def tex_to_text(tex_path: str) -> str:
    """
    Convert the content of a .tex file to plain text.

    pylatexenc is tried first; if it raises, a regex-based cleanup is used
    instead and a warning is printed.

    Args:
        tex_path: Path of the .tex file to read.

    Returns:
        The plain-text rendering of the file.

    Raises:
        OSError: If the file cannot be opened or read.
    """
    # errors="replace": a stray non-UTF-8 byte becomes U+FFFD instead of
    # raising UnicodeDecodeError and aborting the whole batch.
    with open(tex_path, encoding="utf-8", errors="replace") as f:
        latex = f.read()

    try:
        return LatexNodes2Text().latex_to_text(latex)
    except Exception as e:
        # Deliberately broad: any converter failure should degrade to the
        # regex cleanup rather than stop the run.
        print(f"[WARN] pylatexenc failed on {tex_path}: {e}")
        return _strip_latex_fallback(latex)

# New prepare_txt_files for run.py
def prepare_txt_files(papers_dir="papers", txt_dir="txts"):
    """
    Convert every .tex file under papers_dir into a flat directory of .txt files.

    papers_dir/
    ├─ AAA/
    │  └─ main.tex
    ├─ BBB/
    │  └─ paper.tex
    └─ CCC/DDD/
       └─ submission.tex

    Output to txts/
    ├─ AAA_main.txt
    ├─ BBB_paper.txt
    └─ CCC_DDD_submission.txt

    A .txt file that already exists is left untouched, so re-running only
    converts files that have not been converted yet.

    Args:
        papers_dir: Root directory searched recursively for *.tex files.
        txt_dir: Output directory; created (with parents) if missing.

    Returns:
        txt_dir, exactly as it was passed in.
    """
    out_root = Path(txt_dir)
    # parents=True so a nested output location such as "out/txts" also works.
    out_root.mkdir(parents=True, exist_ok=True)

    # sorted() makes the conversion order reproducible across runs.
    for tex_path in sorted(Path(papers_dir).rglob("*.tex")):
        # rglob matches on name only; skip directories that end in ".tex".
        if not tex_path.is_file():
            continue

        # Path relative to papers_dir, minus ".tex", joined with "_"
        # (e.g. AAA/main.tex -> AAA_main).
        rel_parts = tex_path.relative_to(papers_dir).with_suffix("").parts
        txt_path = out_root / f"{'_'.join(rel_parts)}.txt"

        if txt_path.exists():
            continue
        txt_path.write_text(tex_to_text(str(tex_path)), encoding="utf-8")

    return txt_dir

# --------------------------------------------------
# 2. Build Chroma vector database
# --------------------------------------------------
def build_vectorstore(txt_dir="txts", db_dir="chroma_db", model_name=None):
    """
    Chunk every .txt file in txt_dir, embed the chunks and store them in Chroma.

    Args:
        txt_dir: Directory holding the plain-text files (non-recursive).
        db_dir: Passed to Chroma as ``persist_directory``.
        model_name: Embedding model identifier handed to HuggingFaceEmbeddings.
            When None or empty, the argument is omitted so the library's own
            default model is used.

    Returns:
        The Chroma vector store built from the chunks.
    """
    docs = []
    # sorted() keeps the document order reproducible across runs.
    for txt in sorted(glob.glob(os.path.join(txt_dir, "*.txt"))):
        with open(txt, encoding="utf-8") as f:
            content = f.read()
        # One Document per file; the file stem is kept as "source" so that
        # every chunk can be traced back to its paper.
        docs.append(Document(page_content=content,
                             metadata={"source": Path(txt).stem}))

    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=1200, chunk_overlap=200, add_start_index=True)
    splits = text_splitter.split_documents(docs)

    # NOTE(review): this used to pass model_name="" unconditionally, which does
    # not name any model -- presumably the real identifier was lost. Confirm
    # which model is intended and pass it via model_name.
    if model_name:
        embeddings = HuggingFaceEmbeddings(model_name=model_name)
    else:
        embeddings = HuggingFaceEmbeddings()

    # NOTE(review): re-running against an existing db_dir presumably adds the
    # same chunks again -- confirm, and clear db_dir between runs if so.
    vectordb = Chroma.from_documents(
        splits, embeddings, persist_directory=db_dir)
    return vectordb

# --------------------------------------------------
# 3. RAG search HotpotQA - Relaxed search conditions
# --------------------------------------------------
# Deliberately broad query set: direct dataset names first, then generic
# evaluation vocabulary that tends to surround dataset mentions.
_HOTPOT_QUERIES = (
    "HotpotQA",
    "hotpot QA",
    "hotpot",
    "multi-hop question answering",
    "multi-hop QA",
    "question answering dataset",
    "QA benchmark",
    "QA dataset",
    "evaluation dataset",
    "baseline comparison",
    "experimental results",
    "dataset performance",
    "compared with other methods",
    "benchmark results",
    "table results",
    "accuracy",
    "F1 score",
    "performance comparison",
    "evaluation on",
    "tested on",
    "experiments on",
    "results on",
)

# Keyword patterns (all case-insensitive) used to filter and score chunks.
_HOTPOT_RE = re.compile(r"hotpot", re.I)
_QA_RE = re.compile(r"\bQA\b|question.{0,30}answer|answering", re.I)
_EXPERIMENT_RE = re.compile(
    r"evaluat|compar|benchmark|result|accura|perform|table|baseline|experiment|test",
    re.I,
)
_DATASET_RE = re.compile(r"dataset|data|corpus|collection", re.I)
# Narrower patterns used only to pick where the snippet is centred.
_QA_ANCHOR_RE = re.compile(r"\bQA\b|question.{0,30}answer", re.I)
_EXPERIMENT_ANCHOR_RE = re.compile(
    r"evaluat|compar|benchmark|result|experiment", re.I)


def _score_chunk(content):
    """Return (score, snippet, hotpot_direct) for a chunk, or None if it fails every filter."""
    hotpot_matches = len(_HOTPOT_RE.findall(content))
    qa_matches = len(_QA_RE.findall(content))
    experiment_count = len(_EXPERIMENT_RE.findall(content))
    dataset_count = len(_DATASET_RE.findall(content))

    # Very relaxed filter: any single condition is enough.
    relevant = (
        hotpot_matches > 0                                  # mentions hotpot directly
        or (qa_matches > 0 and experiment_count > 1)        # QA + experiment words
        or (qa_matches > 0 and dataset_count > 1)           # QA + dataset words
        or (experiment_count > 2 and dataset_count > 1)     # experiment + dataset words
    )
    if not relevant:
        return None

    score = hotpot_matches * 5 + qa_matches * 2 + experiment_count + dataset_count

    # Centre the snippet on the strongest kind of evidence available.
    if hotpot_matches > 0:
        # Regex match offsets index `content` itself; content.lower() can
        # change the string length for some characters and shift the offset.
        anchor = _HOTPOT_RE.search(content)
    elif qa_matches > 0:
        # Fall back to the full QA pattern so a chunk that matched only via
        # "answering" is still centred on that word instead of on offset 0.
        anchor = _QA_ANCHOR_RE.search(content) or _QA_RE.search(content)
    else:
        anchor = _EXPERIMENT_ANCHOR_RE.search(content)
    idx = anchor.start() if anchor else 0

    snippet = content[max(0, idx - 200):idx + 300].replace("\n", " ")
    return score, snippet, hotpot_matches > 0


def search_hotpotqa(vectordb):
    """
    Retrieve chunks for a broad set of queries and keep the best hit per file.

    Each retrieved chunk is filtered and scored by keyword counts
    (hotpot x5, QA x2, experiment and dataset words x1).

    Args:
        vectordb: Vector store exposing ``as_retriever(search_kwargs=...)``;
            the retriever's ``invoke(query)`` must yield documents with
            ``page_content`` and ``metadata["source"]``.

    Returns:
        A list of dicts with keys ``file``, ``query``, ``snippet``, ``score``
        and ``hotpot_direct``: one entry per file, ordered with direct
        "hotpot" mentions first and then by descending score.
    """
    retriever = vectordb.as_retriever(search_kwargs={"k": 10})
    hits = []
    seen_chunks = set()
    for q in _HOTPOT_QUERIES:
        for doc in retriever.invoke(q):
            content = doc.page_content

            # A chunk scores the same no matter which query surfaced it, and
            # the stable sort below keeps the first occurrence on ties, so
            # repeats can never change the result -- skip them.
            chunk_key = (doc.metadata.get("source"), content)
            if chunk_key in seen_chunks:
                continue
            seen_chunks.add(chunk_key)

            scored = _score_chunk(content)
            if scored is None:
                continue
            score, snippet, hotpot_direct = scored
            hits.append({
                "file": doc.metadata["source"],
                "query": q,
                "snippet": snippet,
                "score": score,
                "hotpot_direct": hotpot_direct,
            })

    # Deduplicate by file: after sorting, the first hit seen for a file is its best.
    seen_files = set()
    dedup = []
    ranked = sorted(
        hits, key=lambda h: (h["hotpot_direct"], h["score"]), reverse=True)
    for h in ranked:
        if h["file"] not in seen_files:
            seen_files.add(h["file"])
            dedup.append(h)

    return dedup

# --------------------------------------------------
# 4. Summarize CSV
# --------------------------------------------------
def save_csv(hits, csv_path="output/hotpotqa_hits_loose.csv"):
    """
    Write the search hits to a CSV file with a header row.

    Args:
        hits: Iterable of dicts with the keys ``file``, ``query``,
            ``snippet``, ``score`` and ``hotpot_direct``.
        csv_path: Destination file; its parent directory is created if missing.

    Raises:
        ValueError: If a hit contains a key outside the expected columns
            (raised by csv.DictWriter).
    """
    # Create the directory the file actually lives in, rather than a fixed
    # "output" directory that only matches the default csv_path.
    Path(csv_path).parent.mkdir(parents=True, exist_ok=True)
    # newline='' lets the csv module control line endings itself.
    with open(csv_path, "w", newline='', encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=["file", "query", "snippet", "score", "hotpot_direct"])
        writer.writeheader()
        writer.writerows(hits)
    print(f"✅ Results saved to {csv_path}")

# --------------------------------------------------
# Main process
# --------------------------------------------------
if __name__ == "__main__":
    # Pipeline: .tex -> .txt -> vector index -> keyword-filtered retrieval -> CSV.
    source_root = "single_tex_hotpotqa"
    text_root = prepare_txt_files(papers_dir=source_root)
    store = build_vectorstore(text_root)
    candidates = search_hotpotqa(store)
    save_csv(candidates)
    # One entry per paper, so the list length is the number of candidate papers.
    print(f"Total {len(candidates)} papers suspected of using HotpotQA as comparison dataset.")