#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import os
import re
import time
import json
import pickle
import threading
from collections import defaultdict
from urllib.parse import urljoin, urldefrag, urlparse
from concurrent.futures import ThreadPoolExecutor, as_completed

import requests
from bs4 import BeautifulSoup
from tqdm import tqdm

from selenium import webdriver
from selenium.webdriver.chrome.service import Service
from selenium.webdriver.chrome.options import Options
from selenium.webdriver.common.by import By

# ==============================
# Config & Paths
# ==============================

# Root of the published mathlib4 documentation and the navbar page that links
# to every module page.
BASE_ROOT = "https://leanprover-community.github.io/mathlib4_docs/"
NAVBAR_URL = urljoin(BASE_ROOT, "navbar.html")

# Browser binaries; set CHROME_PATH / CHROMEDRIVER_PATH to override the defaults.
chrome_path = os.environ.get(
    "CHROME_PATH",
    os.path.expanduser("~/browsers/chrome/chrome"),
)
chromedriver_path = os.environ.get(
    "CHROMEDRIVER_PATH",
    os.path.expanduser("~/browsers/chromedriver/chromedriver"),
)

# All artifacts live under DATA_DIR, which is created at import time.
DATA_DIR = "data/retrieval_data"
os.makedirs(DATA_DIR, exist_ok=True)

DISCOVERED_LINKS_FILE = os.path.join(DATA_DIR, "discovered_links.txt")
OUTPUT_JSON = os.path.join(DATA_DIR, "lean_definitions.json")
PKL_FILE = os.path.join(DATA_DIR, "lean_definitions.pkl")
# Same path as PKL_FILE: the partitioning step reads the pickle written from the JSON.
INPUT_FILE = PKL_FILE
FULL_OUTPUT_FILE = os.path.join(DATA_DIR, "partitioned_theorems.pkl")
PREFERRED_OUTPUT_FILE = os.path.join(DATA_DIR, "preferred_partitioned.pkl")

# Crawl/scrape tuning knobs, each overridable through the environment.
REQUEST_DELAY = float(os.environ.get("REQUEST_DELAY", "0.3"))
MAX_RETRIES = int(os.environ.get("MAX_RETRIES", "3"))
CRAWL_MAX_PAGES = int(os.environ.get("CRAWL_MAX_PAGES", "150000"))  # headroom
MAX_WORKERS = int(os.environ.get("MAX_WORKERS", "8"))  # number of scraper threads

# Optional allow-list of top-level documentation trees (comma-separated).
# Example: INCLUDE_PREFIXES="Mathlib,Aesop,Std,Init,Lean"
# An empty value disables filtering.
INCLUDE_PREFIXES = os.environ.get("INCLUDE_PREFIXES", "").strip()
INCLUDE_LIST = [
    chunk.strip("/")
    for chunk in (piece.strip() for piece in INCLUDE_PREFIXES.split(","))
    if chunk
]

# Namespaces kept when partitioning with filter_to_preferred=True; a namespace
# matches if it equals an entry or is nested under it.
PREFERRED_NAMESPACES = [
    "Nat", "Int", "Rat", "Real", "Complex", "ENat", "NNReal", "EReal",
    "Monoid", "CommMonoid", "Group", "CommGroup", "Ring", "CommRing",
    "Field", "Algebra", "Module",
    "Set", "Finset", "Fintype", "Multiset", "List", "Fin",
    "BigOperators", "Filter", "Polynomial", "SimpleGraph.Walk",
    "Equiv", "Embedding", "Injective", "Surjective", "Bijective",
    "Order", "Topology",
]

# ==============================
# Utilities
# ==============================

def _normalize_url(u: str) -> str:
    """Return *u* with any ``#fragment`` removed."""
    return urldefrag(u).url

def _in_docs_root(u: str) -> bool:
    """Tell whether *u*, ignoring its fragment, lies under ``BASE_ROOT``."""
    without_fragment = _normalize_url(u)
    return without_fragment.startswith(BASE_ROOT)

def save_lines(path: str, lines):
    """Write the distinct entries of *lines* to *path*, sorted, one per line.

    The file is overwritten and encoded as UTF-8.
    """
    unique_sorted = sorted(set(lines))
    with open(path, "w", encoding="utf-8") as sink:
        sink.writelines(entry + "\n" for entry in unique_sorted)

def load_lines(path: str):
    """Read *path* and return its non-blank lines, whitespace-stripped.

    Returns an empty list when the file does not exist.
    """
    if not os.path.exists(path):
        return []
    with open(path, "r", encoding="utf-8") as source:
        stripped = (raw.strip() for raw in source)
        return [entry for entry in stripped if entry]

def _passes_prefix_filter(u: str) -> bool:
    """Decide whether *u* belongs to one of the configured top-level trees.

    With an empty ``INCLUDE_LIST`` every URL passes. Otherwise the first path
    segment below ``BASE_ROOT`` must be one of the listed prefixes.

    Returns:
        True if the URL should be kept, False otherwise.
    """
    if not INCLUDE_LIST:
        return True
    normalized = _normalize_url(u)
    if not normalized.startswith(BASE_ROOT):
        # Slicing off len(BASE_ROOT) characters is only meaningful for URLs
        # under the docs root; anything else has no tree to match against.
        return False
    top = normalized[len(BASE_ROOT):].split("/", 1)[0]
    return top in INCLUDE_LIST

# ==============================
# Phase 1: Seed from navbar.html
# ==============================

def seed_from_navbar():
    """Fetch the navbar page and return the sorted set of page URLs it links to.

    Link targets are taken from every ``data-path`` attribute and every
    ``<a href>``. A target is kept when it ends in ``.html``, resolves to a URL
    under ``BASE_ROOT`` and passes the ``INCLUDE_PREFIXES`` filter. Fragments
    are stripped before de-duplication.

    Raises:
        requests.HTTPError: if the navbar request returns an error status.
    """
    print(f"Fetching navbar: {NAVBAR_URL}")
    r = requests.get(NAVBAR_URL, timeout=30, headers={"User-Agent": "LeanDocsCrawler/1.0"})
    r.raise_for_status()
    soup = BeautifulSoup(r.text, "html.parser")

    # Raw targets from both sources; they go through the same checks below.
    targets = [(el.get("data-path") or "").strip() for el in soup.select("[data-path]")]
    targets.extend(a["href"].strip() for a in soup.select("a[href]"))

    seeds = set()
    for target in targets:
        if not target.endswith(".html"):
            continue
        # urljoin resolves "./" and "../" itself. The previous lstrip("./")
        # removed any run of leading '.' and '/' characters, which silently
        # rewrote "../x.html" or "/x.html" into a different page.
        abs_url = _normalize_url(urljoin(BASE_ROOT, target))
        if _in_docs_root(abs_url) and _passes_prefix_filter(abs_url):
            seeds.add(abs_url)

    print(f"Navbar seeds: {len(seeds)}")
    return sorted(seeds)

# ==============================
# Phase 2: Discover pages (navbar-only)
# ==============================

# Matches paths ending in ".htm" or ".html", case-insensitively.
# NOTE(review): not referenced by the functions in this section, which test
# with str.endswith(".html") instead; kept in case other code uses it.
HTML_RE = re.compile(r"\.html?$", re.IGNORECASE)

def discover_mathlib_pages():
    """Return the list of documentation page URLs to scrape.

    A non-empty ``DISCOVERED_LINKS_FILE`` is returned as-is. Otherwise the
    navbar page is fetched, link targets are collected from ``data-path``
    attributes and ``<a href>`` anchors, and the result is cached to that file.
    A target is kept when it ends in ``.html``, resolves under ``BASE_ROOT``
    and passes the ``INCLUDE_PREFIXES`` filter. At most ``CRAWL_MAX_PAGES``
    URLs (in sorted order) are returned.

    Raises:
        requests.HTTPError: if the navbar request returns an error status.
    """
    cached = load_lines(DISCOVERED_LINKS_FILE)
    if cached:
        print(f"Loaded {len(cached)} links from cache: {DISCOVERED_LINKS_FILE}")
        return cached

    print(f"Fetching navbar: {NAVBAR_URL}")
    r = requests.get(NAVBAR_URL, timeout=30, headers={"User-Agent": "LeanDocsCrawler/1.0"})
    r.raise_for_status()
    soup = BeautifulSoup(r.text, "html.parser")

    # Raw targets from both sources; they go through the same checks below.
    targets = [(el.get("data-path") or "").strip() for el in soup.select("[data-path]")]
    targets.extend(a["href"].strip() for a in soup.select("a[href]"))

    urls = set()
    for target in targets:
        if not target.endswith(".html"):
            continue
        # urljoin resolves "./" and "../" itself; lstrip("./") strips a
        # character set rather than a prefix and mangled such paths.
        abs_url = _normalize_url(urljoin(BASE_ROOT, target))
        # Links that leave the docs site are not documentation pages and must
        # not be handed to the scraper.
        if _in_docs_root(abs_url) and _passes_prefix_filter(abs_url):
            urls.add(abs_url)

    urls = sorted(urls)[:CRAWL_MAX_PAGES]
    print(f"Discovered {len(urls)} pages from navbar only.")
    save_lines(DISCOVERED_LINKS_FILE, urls)
    return urls

# ==============================
# Selenium driver management (thread-local)
# ==============================

# Per-thread storage for the driver owned by each worker thread.
_thread_local = threading.local()
# Every driver created by make_driver(), so they can all be shut down at exit.
_all_drivers = set()
# Guards _all_drivers, which is touched from several worker threads.
_all_drivers_lock = threading.Lock()

def make_driver():
    """Start a headless Chrome WebDriver and record it in the global registry.

    The browser binary and chromedriver come from ``chrome_path`` and
    ``chromedriver_path``.

    Returns:
        The newly created ``webdriver.Chrome`` instance.
    """
    opts = Options()
    for flag in (
        "--headless=new",
        "--no-sandbox",
        "--disable-dev-shm-usage",
        "--disable-gpu",
        "--disable-software-rasterizer",
    ):
        opts.add_argument(flag)
    opts.binary_location = chrome_path

    browser = webdriver.Chrome(service=Service(chromedriver_path), options=opts)
    with _all_drivers_lock:
        _all_drivers.add(browser)
    return browser

def get_thread_driver():
    """Return the calling thread's driver, creating it on first use."""
    existing = getattr(_thread_local, "driver", None)
    if existing is not None:
        return existing
    fresh = make_driver()
    _thread_local.driver = fresh
    return fresh

def close_all_drivers():
    """Quit every registered driver and empty the registry.

    Errors raised while quitting an individual driver are ignored so that the
    remaining drivers are still shut down.
    """
    # Snapshot and clear under the lock; the quit calls happen outside it.
    with _all_drivers_lock:
        to_close = list(_all_drivers)
        _all_drivers.clear()
    for browser in to_close:
        try:
            browser.quit()
        except Exception:
            # Best-effort cleanup: a failed quit must not abort the rest.
            pass

# ==============================
# Scraping helpers
# ==============================

def _extract_with_driver(driver, page_url: str):
    """Load *page_url* in *driver* and return one record per declaration block.

    Each element with class ``decl`` and a non-empty ``id`` yields a dict with
    the keys ``definition_name``, ``source_url``, ``type_signature`` and
    ``description``. The signature is the text of the ``decl_header`` child,
    falling back to ``decl_type``, falling back to a fixed placeholder. The
    description joins the non-empty ``<p>`` texts with newlines.

    Errors while reading a single block skip that block only; errors from the
    page load or the initial element lookup propagate to the caller.
    """
    driver.get(page_url)
    # Fixed pause after navigation before the DOM is queried.
    time.sleep(REQUEST_DELAY)

    out = []
    for block in driver.find_elements(By.CLASS_NAME, "decl"):
        try:
            name = (block.get_attribute("id") or "").strip()
            if not name:
                continue

            # decl_header preferred; fallback to decl_type
            type_signature = "No type signature found"
            for css_class in ("decl_header", "decl_type"):
                try:
                    type_signature = block.find_element(By.CLASS_NAME, css_class).text.strip()
                except Exception:
                    continue
                break

            # Read each paragraph's text once: every .text access is a
            # separate WebDriver request.
            paragraphs = []
            for p in block.find_elements(By.TAG_NAME, "p"):
                text = p.text.strip()
                if text:
                    paragraphs.append(text)
            description = "\n".join(paragraphs)

            out.append({
                "definition_name": name,
                "source_url": page_url,
                "type_signature": type_signature,
                "description": description
            })
        except Exception:
            # Deliberate best effort: one unreadable block must not lose the
            # rest of the page.
            continue
    return out

def _discard_thread_driver():
    """Quit and forget the calling thread's driver so the next call builds a new one."""
    drv = getattr(_thread_local, "driver", None)
    if drv is None:
        return
    _thread_local.driver = None
    with _all_drivers_lock:
        _all_drivers.discard(drv)
    try:
        drv.quit()
    except Exception:
        # The browser may already be gone; there is nothing left to release.
        pass

def extract_definitions_from_page(page_url: str):
    """Extract all declaration blocks from a single page using a thread-local driver.

    Makes up to ``MAX_RETRIES`` attempts. After a failed attempt the thread's
    driver is discarded, so a crashed browser session cannot make every later
    page on this worker fail as well. The wait between attempts grows with the
    attempt number; there is no wait after the last one.

    Returns:
        The list of definition dicts for the page, or ``[]`` if every attempt
        failed.
    """
    for attempt in range(1, MAX_RETRIES + 1):
        try:
            driver = get_thread_driver()
            return _extract_with_driver(driver, page_url)
        except Exception as e:
            print(f"⚠ Error fetching {page_url} (attempt {attempt}): {e}")
            _discard_thread_driver()
            if attempt < MAX_RETRIES:
                time.sleep(REQUEST_DELAY * (attempt + 1))
    return []

# ==============================
# Parallel scraping
# ==============================

def _write_json_atomic(payload, path):
    """Write *payload* as JSON to a temp file, then rename it over *path*."""
    tmp_path = path + ".tmp"
    with open(tmp_path, "w", encoding="utf-8") as f:
        json.dump(payload, f, indent=2, ensure_ascii=False)
    # The rename means an interrupted write never leaves a truncated file at *path*.
    os.replace(tmp_path, path)

def _checkpoint_definitions(defs):
    """Persist *defs* to OUTPUT_JSON; report failure instead of raising."""
    try:
        _write_json_atomic(defs, OUTPUT_JSON)
    except (OSError, TypeError, ValueError) as e:
        print(f"⚠ Failed to write {OUTPUT_JSON}: {e}")
        return False
    return True

def scrape_all_definitions_parallel(save_every=50):
    """Scrape every discovered page with a thread pool and persist the results.

    If ``OUTPUT_JSON`` exists, its entries are loaded first and pages that
    already appear as a ``source_url`` are skipped. Pages that produced no
    definitions leave no record, so they are visited again on resume.

    Progress is written every *save_every* completed pages and once more when
    the loop ends, including when it ends because of an exception. Rewriting
    the full file after every page makes total I/O grow quadratically with
    the number of pages.

    Args:
        save_every: Number of completed pages between checkpoints (minimum 1).

    Returns:
        The combined list of definition dicts (resumed plus newly scraped).
    """
    page_links = discover_mathlib_pages()
    if not page_links:
        print("⚠ No pages discovered. Nothing to scrape.")
        return []

    # Resume if exists
    all_defs = []
    if os.path.exists(OUTPUT_JSON):
        try:
            with open(OUTPUT_JSON, "r", encoding="utf-8") as f:
                existing = json.load(f)
            if not isinstance(existing, list):
                raise ValueError("expected a JSON list of definitions")
        except (OSError, ValueError) as e:
            # Say so explicitly: continuing will overwrite the unreadable file.
            print(f"⚠ Could not resume from {OUTPUT_JSON}; starting fresh: {e}")
        else:
            all_defs.extend(existing)
            processed = {
                d.get("source_url")
                for d in existing
                if isinstance(d, dict) and d.get("source_url")
            }
            page_links = [u for u in page_links if u not in processed]
            print(f"Resuming: loaded {len(existing)} existing definitions; remaining pages: {len(page_links)}")

    # Limit to max pages if desired
    page_links = page_links[:CRAWL_MAX_PAGES]

    save_every = max(1, int(save_every))
    pages_since_save = 0

    # Submit tasks
    futures = {}
    with ThreadPoolExecutor(max_workers=MAX_WORKERS, thread_name_prefix="lean-scraper") as ex:
        for url in page_links:
            futures[ex.submit(extract_definitions_from_page, url)] = url

        pbar = tqdm(total=len(futures), desc=f"Scraping Lean Docs (workers={MAX_WORKERS})")
        try:
            for fut in as_completed(futures):
                url = futures[fut]
                try:
                    defs = fut.result()
                except Exception as e:
                    print(f"⚠ Unhandled error for {url}: {e}")
                    defs = []
                if defs:
                    all_defs.extend(defs)
                # Periodic checkpoint from the main thread
                pages_since_save += 1
                if pages_since_save >= save_every and _checkpoint_definitions(all_defs):
                    pages_since_save = 0
                pbar.update(1)
        finally:
            # Flush anything not yet checkpointed, even if the loop was interrupted.
            if pages_since_save:
                _checkpoint_definitions(all_defs)
            pbar.close()

    print(f"Scraping complete! {len(all_defs)} definitions saved in {OUTPUT_JSON}")
    return all_defs

# ==============================
# Processing (optional)
# ==============================

def load_json_definitions(input_file=OUTPUT_JSON):
    """Build a ``definition_name -> entry`` mapping from the scraped JSON.

    Entries without a ``definition_name`` key are ignored; when a name occurs
    more than once the last entry wins. The mapping is also pickled to
    ``PKL_FILE``. A missing *input_file* yields an empty dict and nothing is
    written.
    """
    if not os.path.exists(input_file):
        print(f"⚠ JSON file not found: {input_file}. Returning empty dict.")
        return {}

    with open(input_file, "r", encoding="utf-8") as src:
        entries = json.load(src)

    by_name = {}
    for entry in entries:
        if "definition_name" in entry:
            by_name[entry["definition_name"]] = entry

    with open(PKL_FILE, "wb") as dst:
        pickle.dump(by_name, dst)
    return by_name

def load_theorems(path):
    """Unpickle and return the object stored at *path*.

    Prints a warning and returns an empty dict when the file does not exist.
    """
    if os.path.exists(path):
        # pickle executes code on load: only use with files this script wrote.
        with open(path, "rb") as src:
            return pickle.load(src)
    print(f"⚠ Pickle file not found: {path}. Returning empty dict.")
    return {}

def split_namespace(full_name):
    """Split a dotted name at its last dot into ``(namespace, short_name)``.

    A name without a dot has the empty string as its namespace.
    """
    namespace, _, short_name = full_name.rpartition(".")
    return namespace, short_name

def partition_by_namespace(theorem_dict, filter_to_preferred=False):
    """Group a ``full_name -> definition`` mapping by namespace.

    Args:
        theorem_dict: Mapping from fully qualified name to its definition.
        filter_to_preferred: When true, keep only names whose namespace equals
            an entry of ``PREFERRED_NAMESPACES`` or is nested under one.

    Returns:
        A dict mapping each namespace to a dict of the names it contains.
    """
    def is_preferred(namespace):
        for pref in PREFERRED_NAMESPACES:
            if namespace == pref or namespace.startswith(pref + "."):
                return True
        return False

    grouped = {}
    for full_name, definition in theorem_dict.items():
        namespace = split_namespace(full_name)[0]
        if not filter_to_preferred or is_preferred(namespace):
            grouped.setdefault(namespace, {})[full_name] = definition
    return grouped

def save_pickle(obj, path):
    """Pickle *obj* into the file at *path*, replacing any existing content."""
    with open(path, "wb") as sink:
        pickle.dump(obj, sink)

# ==============================
# Main
# ==============================

if __name__ == "__main__":
    try:
        # Step 1: scrape all discovered pages in parallel.
        scrape_all_definitions_parallel()

        # Step 2: rebuild the name -> entry mapping from the JSON; this also
        # writes the pickle that step 3 reads.
        definitions = load_json_definitions()

        # Step 3: load the pickled {full_name: def} mapping and write it out
        # grouped by namespace, once unfiltered and once restricted to the
        # preferred namespaces.
        theorem_map = load_theorems(INPUT_FILE)
        every_namespace = partition_by_namespace(theorem_map, filter_to_preferred=False)
        preferred_only = partition_by_namespace(theorem_map, filter_to_preferred=True)
        save_pickle(every_namespace, FULL_OUTPUT_FILE)
        save_pickle(preferred_only, PREFERRED_OUTPUT_FILE)

        print(f"Saved full partitioned dict to {FULL_OUTPUT_FILE} ({len(every_namespace)} namespaces)")
        print(f"Saved preferred-only partitioned dict to {PREFERRED_OUTPUT_FILE} ({len(preferred_only)} namespaces)")
    finally:
        # Runs on success and on failure so no browser process is left behind.
        close_all_drivers()
