import asyncio
import io
import multiprocessing as mp
import queue
import re
import tarfile
import logging
import requests
from pathlib import Path
from typing import Dict, Optional

from config import config

# Module-level logger used by every fetch / fallback message emitted in this module.
logger = logging.getLogger("ckm.fulltext")


def _clean_id(arxiv_id: str) -> str:
    return re.sub(r"v\d+$", "", arxiv_id)


def _read_member_text(tf: tarfile.TarFile, member: tarfile.TarInfo) -> Optional[str]:
    try:
        extracted = tf.extractfile(member)
        if extracted is None:
            return None
        return extracted.read().decode("utf-8", errors="ignore")
    except Exception:
        return None


def _build_text_file_map(members: list[tarfile.TarInfo], tf: tarfile.TarFile) -> Dict[str, str]:
    text_files: Dict[str, str] = {}
    for member in members:
        if not member.isfile() or member.name.startswith("._"):
            continue
        if not member.name.endswith((".tex", ".bbl", ".bib", ".sty", ".cls")):
            continue
        content = _read_member_text(tf, member)
        if content is not None:
            text_files[member.name] = content
    return text_files


def _find_main_tex(text_files: Dict[str, str]) -> Optional[tuple[str, str]]:
    """Find the main .tex file (the one containing \\documentclass)."""
    for name, content in text_files.items():
        if name.endswith(".tex") and r"\documentclass" in content:
            return name, content
    return None


def _resolve_include_path(current_file: str, include_target: str, text_files: Dict[str, str]) -> Optional[str]:
    base_dir = Path(current_file).parent
    raw_target = include_target.strip()
    candidates = []

    target_path = Path(raw_target)
    if target_path.suffix:
        candidates.append((base_dir / target_path).as_posix())
        candidates.append(target_path.as_posix())
    else:
        for suffix in (".tex", ".bbl", ".bib", ".sty", ".cls"):
            candidates.append((base_dir / f"{raw_target}{suffix}").as_posix())
            candidates.append(f"{raw_target}{suffix}")

    for candidate in candidates:
        normalized = str(Path(candidate)).replace("\\", "/")
        if normalized in text_files:
            return normalized
    return None


def _expand_tex_includes(file_name: str, text: str, text_files: Dict[str, str], seen: Optional[set[str]] = None) -> str:
    seen = seen or set()
    seen.add(file_name)

    def replace_include(match: re.Match) -> str:
        include_target = match.group(1).strip()
        resolved = _resolve_include_path(file_name, include_target, text_files)
        if not resolved or resolved in seen:
            return "\n"
        return "\n" + _expand_tex_includes(resolved, text_files[resolved], text_files, seen | {resolved}) + "\n"

    expanded = re.sub(r"\\(?:input|include)\{([^}]+)\}", replace_include, text)
    expanded = re.sub(r"\\(?:input|include)\s+([^\s%]+)", replace_include, expanded)
    return expanded


def _strip_latex(tex: str) -> str:
    """Strip LaTeX commands to leave readable prose, preserving section structure."""
    # Remove preamble (everything before \begin{document})
    doc_match = re.search(r"\\begin\{document\}", tex)
    if doc_match:
        tex = tex[doc_match.end():]

    # Remove everything after \bibliography or \begin{thebibliography}
    tex = re.split(r"\\bibliography\b|\\begin\{thebibliography\}|\\end\{document\}", tex)[0]

    # Remove comments
    tex = re.sub(r"(?m)%.*$", "", tex)

    # Replace section commands with readable headers
    tex = re.sub(r"\\section\*?\{([^}]+)\}", r"\n\n## \1\n", tex)
    tex = re.sub(r"\\subsection\*?\{([^}]+)\}", r"\n\n### \1\n", tex)
    tex = re.sub(r"\\subsubsection\*?\{([^}]+)\}", r"\n\n#### \1\n", tex)

    # Unwrap common text commands
    tex = re.sub(r"\\(?:textbf|textit|emph|text|mbox|hbox)\{([^}]+)\}", r"\1", tex)

    # Remove figure/table environments
    tex = re.sub(r"\\begin\{(?:figure|table|algorithm|lstlisting)[^}]*\}[\s\S]*?\\end\{(?:figure|table|algorithm|lstlisting)\}", "", tex)

    # Remove display math but leave a placeholder
    tex = re.sub(r"\\\[[\s\S]*?\\\]", "[equation]", tex)
    tex = re.sub(r"\\begin\{equation\*?\}[\s\S]*?\\end\{equation\*?\}", "[equation]", tex)

    # Remove remaining LaTeX commands
    tex = re.sub(r"\\[a-zA-Z]+\*?(?:\[[^\]]*\])?\{([^}]*)\}", r"\1", tex)
    tex = re.sub(r"\\[a-zA-Z]+\*?", "", tex)
    tex = re.sub(r"[{}]", "", tex)

    # Clean up whitespace
    tex = re.sub(r"\n{3,}", "\n\n", tex)
    return tex.strip()


def _looks_like_stub(text: Optional[str]) -> bool:
    if not text:
        return True

    stripped = text.strip()
    if len(stripped) < 500:
        return True

    alpha_chars = sum(ch.isalpha() for ch in stripped)
    if alpha_chars < 300:
        return True

    lines = [line.strip() for line in stripped.splitlines() if line.strip()]
    if not lines:
        return True

    filenameish_lines = sum(
        1
        for line in lines[:12]
        if line.endswith((".tex", ".bib", ".bbl", ".sty", ".cls")) or "/" in line
    )
    return filenameish_lines >= max(3, len(lines[:12]) // 2)


def _download_tex(arxiv_id: str, timeout_s: int) -> Optional[str]:
    """Download and extract the main .tex source from arXiv."""
    clean_id = _clean_id(arxiv_id)
    url = f"https://arxiv.org/src/{clean_id}"
    try:
        resp = requests.get(url, timeout=(10, timeout_s), headers={"User-Agent": "CKM-Eval/1.0"})
        if resp.status_code != 200:
            return None
        with tarfile.open(fileobj=io.BytesIO(resp.content), mode="r:gz") as tf:
            text_files = _build_text_file_map(tf.getmembers(), tf)
            main_tex = _find_main_tex(text_files)
            if not main_tex:
                return None
            main_name, main_content = main_tex
            return _expand_tex_includes(main_name, main_content, text_files)
    except Exception:
        return None


def _download_pdf_as_md(arxiv_id: str, temp_dir: Path, timeout_s: int) -> Optional[str]:
    """Fallback: download PDF and convert to Markdown via pymupdf4llm."""
    try:
        import pymupdf4llm
    except ImportError:
        return None

    clean_id = _clean_id(arxiv_id)
    url = f"https://arxiv.org/pdf/{clean_id}.pdf"
    pdf_path = temp_dir / f"{clean_id.replace('/', '_')}.pdf"

    try:
        resp = requests.get(url, timeout=(10, timeout_s), headers={"User-Agent": "CKM-Eval/1.0"})
        if resp.status_code != 200:
            return None
        pdf_path.write_bytes(resp.content)
        md = pymupdf4llm.to_markdown(str(pdf_path))
        # Strip references section
        md = re.split(r"\n#+ (?:References|REFERENCES|Bibliography)", md)[0]
        md = re.sub(r"\n{3,}", "\n\n", md)
        return md
    except Exception:
        return None
    finally:
        if pdf_path.exists():
            pdf_path.unlink()


def get_paper_full_text(arxiv_id: str, cache_dir: Path) -> Optional[str]:
    """Return only the ``text`` field of ``get_paper_content_result`` (None when every source failed)."""
    return get_paper_content_result(arxiv_id, cache_dir).get("text")


def get_paper_content_result(arxiv_id: str, cache_dir: Path, timeout_s: Optional[int] = None) -> Dict[str, Optional[str]]:
    """
    Fetch full paper text with fallback chain:
      1. .tex source (preferred — clean prose, no rendering artifacts)
      2. PDF → Markdown via pymupdf4llm
      3. None (caller falls back to abstract)
    Results are cached to avoid re-downloading.
    """
    cache_dir.mkdir(parents=True, exist_ok=True)
    temp_dir = cache_dir / "temp"
    temp_dir.mkdir(exist_ok=True)
    timeout_s = max(30, int(timeout_s or config["experiment"].get("fulltext_timeout_s", 180)))

    cache_file = cache_dir / f"{_clean_id(arxiv_id).replace('/', '_')}.txt"
    if cache_file.exists():
        cached = cache_file.read_text(encoding="utf-8")
        if not _looks_like_stub(cached):
            return {"text": cached, "source": "cache-fulltext", "counted_fulltext": True}
        logger.warning("[Fulltext] %s: cached text looks incomplete, refreshing", arxiv_id)

    # 1. Try .tex source
    logger.info("[Fulltext] %s: trying .tex source...", arxiv_id)
    tex = _download_tex(arxiv_id, timeout_s)
    tex_text = None
    if tex:
        tex_text = _strip_latex(tex)
        if not _looks_like_stub(tex_text):
            cache_file.write_text(tex_text, encoding="utf-8")
            logger.info("[Fulltext] %s: .tex source OK (%d chars)", arxiv_id, len(tex_text))
            return {"text": tex_text, "source": "tex", "counted_fulltext": True}
        logger.warning(
            "[Fulltext] %s: .tex source looked incomplete after expansion (%d chars), trying PDF...",
            arxiv_id,
            len(tex_text),
        )

    # 2. Fallback: PDF → Markdown
    if tex_text is None:
        logger.info("[Fulltext] %s: .tex unavailable, trying PDF...", arxiv_id)
    md = _download_pdf_as_md(arxiv_id, temp_dir, timeout_s)
    if md:
        cache_file.write_text(md, encoding="utf-8")
        logger.info("[Fulltext] %s: PDF fallback OK (%d chars)", arxiv_id, len(md))
        return {"text": md, "source": "pdf", "counted_fulltext": True}

    if tex_text:
        cache_file.write_text(tex_text, encoding="utf-8")
        logger.warning("[Fulltext] %s: PDF fallback failed, keeping partial .tex text", arxiv_id)
        return {"text": tex_text, "source": "partial-tex", "counted_fulltext": False}

    logger.warning("[Fulltext] %s: all sources failed, will use abstract only", arxiv_id)
    return {"text": None, "source": "abstract-fallback", "counted_fulltext": False}


def _fulltext_worker(arxiv_id: str, cache_dir_str: str, timeout_s: int, result_queue) -> None:
    """Child-process entry point used by ``get_paper_content_result_bounded``.

    Reports back on *result_queue*:

    - ``{"ok": True, "result": ...}`` on success;
    - ``{"ok": False, "error": "<ExcType>: <message>"}`` if an ``Exception``
      escapes, including one raised by the success ``put`` itself.
    """
    try:
        outcome = get_paper_content_result(arxiv_id, Path(cache_dir_str), timeout_s=timeout_s)
        result_queue.put({"ok": True, "result": outcome})
    except Exception as exc:
        message = f"{type(exc).__name__}: {exc}"
        result_queue.put({"ok": False, "error": message})


def get_paper_content_result_bounded(arxiv_id: str, cache_dir: Path, timeout_s: int) -> Dict[str, Optional[str]]:
    ctx = mp.get_context("spawn")
    result_queue = ctx.Queue()
    process = ctx.Process(
        target=_fulltext_worker,
        args=(arxiv_id, str(cache_dir), timeout_s, result_queue),
    )
    process.start()
    process.join(timeout_s)

    if process.is_alive():
        logger.warning("[Fulltext] %s: worker exceeded %ds, terminating", arxiv_id, timeout_s)
        process.terminate()
        process.join(5)
        if process.is_alive():
            process.kill()
            process.join(1)
        return {"text": None, "source": "timeout-abstract-fallback", "counted_fulltext": False}

    try:
        payload = result_queue.get_nowait()
    except queue.Empty:
        logger.warning("[Fulltext] %s: worker exited without returning content", arxiv_id)
        return {"text": None, "source": "worker-empty-fallback", "counted_fulltext": False}

    if payload.get("ok"):
        return payload["result"]

    logger.warning("[Fulltext] %s: worker failed: %s", arxiv_id, payload.get("error", "unknown error"))
    return {"text": None, "source": "worker-error-fallback", "counted_fulltext": False}


async def get_paper_content_record(
    arxiv_id: str,
    abstract: str,
    cache_dir: Path,
    timeout_s: int = 30,
    retries: int = 2,
    retry_delay_s: float = 2.0,
) -> Dict[str, Optional[str]]:
    """Fetch a paper's full text for evaluation, falling back to *abstract*.

    Each attempt runs ``get_paper_content_result_bounded`` in a worker thread
    with a per-attempt bound of ``max(1, timeout_s)`` seconds. Only timeouts are
    retried: up to ``max(1, retries)`` attempts in total, sleeping
    ``max(0.0, retry_delay_s)`` seconds between them.

    Returns a dict with these keys:

    - ``content``: the full text, or *abstract* on fallback;
    - ``fulltext_text``: the full text, or None;
    - ``source``;
    - ``counted_fulltext``.
    """
    max_attempts = max(1, retries)
    pause_s = max(0.0, retry_delay_s)
    bound_s = max(1, timeout_s)

    attempt = 0
    while True:
        attempt += 1
        outcome = await asyncio.to_thread(get_paper_content_result_bounded, arxiv_id, cache_dir, bound_s)
        text = outcome.get("text")
        source = outcome.get("source", "unknown")

        if bool(outcome.get("counted_fulltext")) and text:
            return {
                "content": text,
                "fulltext_text": text,
                "source": source,
                "counted_fulltext": True,
            }

        if source != "timeout-abstract-fallback":
            logger.warning(
                "[Fulltext] %s: fulltext unavailable for counting (source=%s), using abstract fallback",
                arxiv_id,
                source,
            )
            break

        if attempt >= max_attempts:
            logger.warning(
                "[Fulltext] %s: timed out after %ds x %d attempt(s), using abstract fallback",
                arxiv_id,
                timeout_s,
                max_attempts,
            )
            break

        logger.warning(
            "[Fulltext] %s: timed out after %ds (attempt %d/%d), retrying in %.1fs",
            arxiv_id,
            timeout_s,
            attempt,
            max_attempts,
            pause_s,
        )
        await asyncio.sleep(pause_s)

    return {
        "content": abstract,
        "fulltext_text": None,
        "source": source,
        "counted_fulltext": False,
    }
