import argparse
import json
import logging
import os
from collections import defaultdict
from typing import Dict, List, Any

# Optional: HDBSCAN (fallback to KMeans if unavailable)
try:
    import hdbscan  # type: ignore
    HDBSCAN_AVAILABLE = True
except Exception:
    HDBSCAN_AVAILABLE = False

import numpy as np
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.preprocessing import normalize as sk_normalize

# OpenAI client for embeddings (and optional LLM refinement)
try:
    from openai import OpenAI
except Exception:
    OpenAI = None  # type: ignore


# -----------------------------
# IO helpers
# -----------------------------

def load_results(path: str) -> Dict[str, List[Dict[str, str]]]:
    """
    Read the extractor's JSON file at `path` (decoded as UTF-8) and return the parsed object.
    """
    with open(path, "r", encoding="utf-8") as handle:
        raw = handle.read()
    return json.loads(raw)


def iter_facts(results: Dict[str, List[Dict[str, str]]], group_strategy: str = "parentdir"):
    """
    Generator over every non-empty fact in `results`.

    Each yielded item is a tuple (group_id, page_path, fact_idx, fact_text, fact_source):
      - group_id: first "/"-separated component of the page path when
        group_strategy == "parentdir"; the constant "all" for any other strategy.
      - fact_idx: position of the fact inside its page's list (blank facts are skipped,
        but still consume an index).
      - fact_text: the "fact" value, whitespace-stripped.
      - fact_source: the "source" value lower-cased and stripped; "text" when missing/empty.
    Pages with no items are ignored.
    """
    by_parent_dir = group_strategy == "parentdir"
    for page_path, entries in results.items():
        if not entries:
            continue
        # Example: "book1/page_001.png" -> "book1"
        group_id = page_path.strip("/").partition("/")[0] if by_parent_dir else "all"
        for position, entry in enumerate(entries):
            text = (entry.get("fact") or "").strip()
            if text:
                origin = (entry.get("source") or "text").lower().strip()
                yield group_id, page_path, position, text, origin


def make_fact_id(group_id: str, page_path: str, idx: int) -> str:
    """
    Build a fact identifier from the group, the page's file name and the fact index.

    Only the final path component of `page_path` is used.
    Example: ("book1", "book1/page_001.png", 3) -> "book1::page_001.png::f3"
    """
    _, file_name = os.path.split(page_path)
    return "{}::{}::f{}".format(group_id, file_name, idx)


# -----------------------------
# Embeddings (OpenAI)
# -----------------------------

def embed_texts_openai(texts: List[str], model: str, client, batch_size: int = 128) -> np.ndarray:
    """
    Batch-embed texts with OpenAI embeddings.

    Args:
        texts: strings to embed, in order.
        model: embedding model name forwarded to the API.
        client: object exposing `client.embeddings.create(model=..., input=...)` whose
            response has a `.data` sequence of items with an `.embedding` attribute.
        batch_size: maximum number of texts sent per request (default 128).

    Returns:
        (N, D) float32 array with L2-normalized rows. An empty `texts` yields an
        empty (0, 0) float32 array without calling the API.

    Raises:
        ValueError: if `batch_size` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be >= 1")
    if not texts:
        # np.array([]) would be 1-D and the row-wise normalizer expects a 2-D input,
        # so short-circuit instead of issuing a request for nothing.
        return np.zeros((0, 0), dtype=np.float32)

    vectors = []
    for start in range(0, len(texts), batch_size):
        batch = texts[start:start + batch_size]
        resp = client.embeddings.create(model=model, input=batch)
        vectors.extend(d.embedding for d in resp.data)

    arr = np.array(vectors, dtype=np.float32)
    # L2-normalize for cosine-friendly comparisons via euclidean
    return sk_normalize(arr, norm="l2", axis=1)


# -----------------------------
# Clustering
# -----------------------------

def cluster_with_hdbscan(X: np.ndarray, min_cluster_size: int = 3, min_samples: int | None = None):
    """
    Cluster the rows of X with HDBSCAN (euclidean metric, "eom" cluster selection).

    Returns the label array produced by `fit_predict` (shape N,); -1 marks noise.
    Raises RuntimeError when the optional `hdbscan` package could not be imported.
    """
    if HDBSCAN_AVAILABLE:
        options = {
            "min_cluster_size": min_cluster_size,
            "min_samples": min_samples,
            "metric": "euclidean",
            "cluster_selection_method": "eom",
        }
        return hdbscan.HDBSCAN(**options).fit_predict(X)
    raise RuntimeError("HDBSCAN not available")


def cluster_with_kmeans(X: np.ndarray, k_min: int = 2, k_max: int | None = 20, random_state: int = 0):
    n = X.shape[0]
    if k_max is None:
        k_max = min(20, n)
    k_max = max(k_min, min(k_max, n))
    if n <= 2:
        return np.zeros(n, dtype=int)
    best_k, best_score, best_labels = None, -1.0, None
    for k in range(k_min, k_max + 1):
        try:
            km = KMeans(n_clusters=k, n_init="auto", random_state=random_state)
        except TypeError:
            km = KMeans(n_clusters=k, n_init=10, random_state=random_state)
        labels = km.fit_predict(X)
        if len(set(labels)) == 1:
            continue
        try:
            score = silhouette_score(X, labels)
        except Exception:
            score = -1.0
        if score > best_score:
            best_k, best_score, best_labels = k, score, labels
    if best_labels is None:
        # Fallback: all in one cluster
        best_labels = np.zeros(n, dtype=int)
    return best_labels


def infer_modality(facts: List[Dict[str, Any]], lo: float = 0.2, hi: float = 0.8) -> str:
    """
    Classify a list of facts as "text", "figure" or "mixed" from their "source" fields.

    The share of facts whose source (lower-cased; missing/empty counts as "text") equals
    "text" is compared to the thresholds: >= hi -> "text", <= lo -> "figure",
    anything in between -> "mixed". An empty list is "mixed".
    """
    total = len(facts)
    if total == 0:
        return "mixed"

    text_count = 0
    for entry in facts:
        origin = entry.get("source") or "text"
        if origin.lower() == "text":
            text_count += 1

    share = text_count / total
    if share >= hi:
        return "text"
    return "figure" if share <= lo else "mixed"


# -----------------------------
# Topics via TF-IDF
# -----------------------------

def tfidf_topics_and_terms(clusters: Dict[int, List[Dict[str, Any]]],
                           ngram_range=(1,2),
                           top_k: int = 5) -> Dict[int, Dict[str, Any]]:
    """
    Create compact topics from cluster facts using TF-IDF keywords.

    One TF-IDF model is fitted on the facts of all clusters; each cluster is then
    summarised by the terms with the highest mean TF-IDF weight over its own facts.

    Args:
        clusters: {cid: [fact dicts]}; each fact dict must carry a "fact" string.
        ngram_range: forwarded to TfidfVectorizer.
        top_k: maximum number of terms per cluster (negative values are treated as 0).

    Returns: {cid: {"hint": "apple, red jacket, ...", "terms": ["apple","red jacket",...]}}
    with an entry for every key of `clusters`. Clusters for which no term can be
    extracted get {"hint": "Misc", "terms": []}. Only terms with a positive weight in
    the cluster are reported, so a cluster with few distinct terms is not padded with
    vocabulary taken from other clusters.
    """
    def _misc() -> Dict[str, Any]:
        return {"hint": "Misc", "terms": []}

    # Flatten all facts into one corpus, remembering which rows belong to which cluster.
    all_docs: List[str] = []
    per_cluster_rows: Dict[int, List[int]] = {}
    for cid, facts in clusters.items():
        first_row = len(all_docs)
        all_docs.extend(f["fact"] for f in facts)
        per_cluster_rows[cid] = list(range(first_row, len(all_docs)))
    if not all_docs:
        return {cid: _misc() for cid in clusters}

    vec = TfidfVectorizer(ngram_range=ngram_range,
                          max_features=8000,
                          lowercase=True,
                          strip_accents="unicode")
    try:
        V = vec.fit_transform(all_docs)
    except ValueError as err:
        # The vectorizer raises "empty vocabulary; ..." when no document yields a
        # usable token (e.g. facts made only of single characters). That is not an
        # error for the pipeline: there is simply nothing to name the clusters with.
        if "empty vocabulary" not in str(err):
            raise
        return {cid: _misc() for cid in clusters}

    terms = vec.get_feature_names_out()
    limit = max(0, top_k)
    out: Dict[int, Dict[str, Any]] = {}
    for cid, rows in per_cluster_rows.items():
        if not rows:
            out[cid] = _misc()
            continue
        # Mean weight of every vocabulary term over this cluster's rows, as a flat array.
        weights = np.asarray(V[rows].mean(axis=0)).ravel()
        ranked = np.argsort(-weights)[:limit]
        # Zero-weight terms do not occur in this cluster at all; drop them.
        top_terms = [str(terms[i]) for i in ranked if weights[i] > 0 and terms[i]]
        hint = ", ".join(top_terms).strip(", ")
        out[cid] = {"hint": hint or "Misc", "terms": top_terms}
    return out


# -----------------------------
# LLM prompts (optional)
# -----------------------------

# System prompt sent by llm_refine_topic(): asks the model for a short topic title
# derived only from the supplied facts, returned as a JSON object with a "topic" key.
# The trailing .strip() removes the newline before the closing triple quote.
LLM_TOPIC_SYSTEM = """You create short, neutral topics from clusters of atomic, literal facts taken from the books.
Rules:
- 3–6 words, Title Case.
- Use only the provided facts; no outside knowledge.
- Prefer concrete nouns/verbs over themes or morals.
Return ONLY JSON: {"topic": "..."}
""".strip()

# System prompt sent by llm_generate_description(): asks the model for a literal
# description (at most 100 words) of a topic, returned as a JSON object with a
# "description" key.
LLM_DESC_SYSTEM = """You write a concise, objective description for a topic derived from literal facts from the books.
Rules:
- maximum 100 words total.
- Neutral and literal; present tense; third person.
- Use ONLY the provided facts; no outside knowledge or inference beyond what is stated.
- Avoid morals, lessons, feelings, and plot speculation.
Return ONLY JSON: {"description": "..."}
""".strip()


def _llm_json_call(client, model: str, messages: List[Dict[str, str]]) -> Any:
    """Send one JSON-mode chat completion and return the decoded JSON payload.

    The request is first sent with temperature=0.0. If the API rejects the call with an
    error that mentions "temperature", it is repeated once without that parameter.
    NOTE(review): some models (reasoning models such as the default "o4-mini") are
    understood to accept only the default temperature — confirm against the OpenAI API
    reference. Any other error propagates to the caller.
    """
    request = {
        "model": model,
        "response_format": {"type": "json_object"},
        "messages": messages,
    }
    try:
        resp = client.chat.completions.create(temperature=0.0, **request)
    except Exception as err:
        if "temperature" not in str(err).lower():
            raise
        resp = client.chat.completions.create(**request)
    return json.loads(resp.choices[0].message.content)


def llm_refine_topic(topic_hint: str, sample_facts: List[str], client, model: str) -> str:
    """
    Ask an LLM to turn a TF-IDF hint plus sample facts into a short topic title.

    Args:
        topic_hint: keyword hint for the cluster; also the basis of the fallback.
        sample_facts: facts of the cluster; only the first 10 are sent.
        client: OpenAI-style client exposing `chat.completions.create`.
        model: chat model name.

    Returns:
        The model's "topic" value, stripped and cut to 100 characters. If the call
        fails, the reply is not usable JSON, or the topic is empty, the title-cased
        hint (cut to 100 characters) is returned instead; failures are logged as
        warnings rather than raised.
    """
    fallback = topic_hint.title()[:100]
    excerpt = "\n".join(f"- {s}" for s in sample_facts[:10])
    msgs = [
        {"role":"system","content": LLM_TOPIC_SYSTEM},
        {"role":"user","content": f"Topic hint: {topic_hint}\nFacts:\n{excerpt}\n\nReturn ONLY JSON."}
    ]
    try:
        data = _llm_json_call(client, model, msgs)
        out = (data.get("topic") or "").strip()
    except Exception as err:
        # Best-effort step: keep the pipeline running, but make the degradation visible.
        logging.getLogger(__name__).warning(
            "LLM topic refinement failed (%s); using TF-IDF hint %r instead", err, topic_hint)
        return fallback
    return out[:100] if out else fallback


def llm_generate_description(topic: str, sample_facts: List[str], modality: str, client, model: str,
                             max_words: int = 100) -> str:
    """
    Generate a short description with an LLM, grounded in provided facts.

    Args:
        topic: topic title of the cluster.
        sample_facts: facts of the cluster; only the first 8 are sent to the model,
            the full list is handed to the heuristic fallback.
        modality: modality label included in the prompt.
        client: OpenAI-style client exposing `chat.completions.create`.
        model: chat model name.
        max_words: hard cap on the returned description; the default matches the
            100-word limit stated in LLM_DESC_SYSTEM.

    Returns:
        The model's "description" with whitespace collapsed to single spaces and cut to
        `max_words` words. If the call fails, the reply is not usable JSON, or the
        description is empty, the result of heuristic_description() is returned;
        failures are logged as warnings rather than raised.
    """
    excerpt = "\n".join(f"- {s}" for s in sample_facts[:8])
    msgs = [
        {"role":"system","content": LLM_DESC_SYSTEM},
        {"role":"user","content": f"Topic: {topic}\nModality: {modality}\nFacts:\n{excerpt}\n\nReturn ONLY JSON."}
    ]
    try:
        data = _llm_json_call(client, model, msgs)
        out = (data.get("description") or "").strip()
    except Exception as err:
        # Best-effort step: keep the pipeline running, but make the degradation visible.
        logging.getLogger(__name__).warning(
            "LLM description failed for topic %r (%s); using heuristic description", topic, err)
        return heuristic_description(topic, sample_facts, modality)
    if not out:
        return heuristic_description(topic, sample_facts, modality)

    words = out.split()
    if len(words) > max_words:
        # Enforce the limit the prompt asks for, closing the cut text with a period.
        return " ".join(words[:max_words]).rstrip(",;:.") + "."
    # Collapse runs of whitespace/newlines into single spaces.
    return " ".join(words)


# -----------------------------
# Heuristic description (no LLM)
# -----------------------------

def heuristic_description(topic_hint: str, facts: List[str], modality: str) -> str:
    """
    Produce a literal 1–2 sentence description using the topic hint and a few short facts.

    Args:
        topic_hint: topic text; it is stripped, lower-cased and inserted in the lead sentence.
        facts: candidate facts; blank entries are ignored and the two with the fewest
            words are quoted as examples (ties keep their input order).
        modality: accepted for signature compatibility with the LLM variant; not used.

    Returns:
        The description, limited to 40 words. A truncated description always ends
        with exactly one period.
    """
    cleaned = [s.strip() for s in facts if s.strip()]
    examples = sorted(cleaned, key=lambda s: len(s.split()))[:2]
    base = topic_hint.strip().rstrip(".")

    desc = f"This cluster groups facts about {base.lower()}."
    if examples:
        ex = " ".join(e.rstrip(".") + "." for e in examples)
        desc = f"{desc} Examples include: {ex}"

    words = desc.split()
    if len(words) > 40:
        # Strip a trailing period too: the 40th word may already end a sentence, and
        # appending another "." would yield "..".
        desc = " ".join(words[:40]).rstrip(",;:.") + "."
    return desc


# -----------------------------
# Main pipeline
# -----------------------------

def process(results_path: str,
            out_dir: str,
            group_strategy: str = "parentdir",
            embed_model: str = "text-embedding-3-large",
            topics_llm: bool = True,
            llm_model: str = "o4-mini",
            hdbscan_min_cluster_size: int = 3,
            kmeans_k_max: int = 20,
            keep_noise: bool = False):
    """
    Run the whole pipeline: load facts, embed, cluster, label, and write the outputs.

    Steps, per group of facts (see iter_facts for grouping):
      1. embed the fact texts with `embed_model`;
      2. cluster with HDBSCAN when it is importable and the group has at least
         max(5, hdbscan_min_cluster_size) facts, otherwise with KMeans
         (k from 2 up to min(kmeans_k_max, max(2, n_facts // 3)));
      3. drop facts labelled -1 unless `keep_noise`; if nothing is left, treat the
         whole group as cluster 0;
      4. derive a TF-IDF hint per cluster, then a topic and description — via the LLM
         (`llm_model`) when `topics_llm` is true, else title-cased hint + heuristic text.

    Files written into `out_dir` (created if needed): "clusters.jsonl" (one record per
    cluster), "cluster_topics.json" and "cluster_descriptions.json" (cluster id -> text).
    The three paths are printed at the end.

    Raises RuntimeError if the `openai` package is missing or the client cannot be created.
    """
    os.makedirs(out_dir, exist_ok=True)
    results = load_results(results_path)

    # Bucket every fact under its group id.
    groups: Dict[str, List[Dict[str, Any]]] = defaultdict(list)
    for gid, page, idx, text, src in iter_facts(results, group_strategy=group_strategy):
        groups[gid].append({
            "id": make_fact_id(gid, page, idx),
            "fact": text,
            "source": src,
            "page": page,
        })

    # The client is mandatory (embeddings) and reused for the optional LLM calls.
    if OpenAI is None:
        raise RuntimeError("openai package not available. Install `openai` and set OPENAI_API_KEY.")
    try:
        client = OpenAI()  # uses OPENAI_API_KEY env var
    except Exception as e:
        raise RuntimeError("Failed to create OpenAI client. Ensure OPENAI_API_KEY is set.") from e

    records = []
    topics_by_cluster = {}
    descriptions_by_cluster = {}

    for group_id, facts in groups.items():
        if not facts:
            continue

        # --- Embeddings ---
        X = embed_texts_openai([f["fact"] for f in facts], model=embed_model, client=client)

        # --- Clustering ---
        enough_for_hdbscan = len(facts) >= max(5, hdbscan_min_cluster_size)
        if HDBSCAN_AVAILABLE and enough_for_hdbscan:
            labels = cluster_with_hdbscan(X, min_cluster_size=hdbscan_min_cluster_size)
        else:
            k_cap = min(kmeans_k_max, max(2, len(facts)//3))
            labels = cluster_with_kmeans(X, k_min=2, k_max=k_cap)

        # --- Label -> member facts (noise optionally discarded) ---
        clusters: Dict[int, List[Dict[str, Any]]] = defaultdict(list)
        for fact, label in zip(facts, labels):
            if keep_noise or label != -1:
                clusters[int(label)].append(fact)
        if not clusters:
            clusters = {0: facts}

        # --- TF-IDF hints ---
        hints = tfidf_topics_and_terms(clusters, ngram_range=(1,2), top_k=5)

        # --- One output record per cluster ---
        for local_cid, members in clusters.items():
            suffix = local_cid if local_cid>=0 else 'noise'
            cluster_uid = f"{group_id}_c{suffix}"
            modality = infer_modality(members)
            hint = hints.get(local_cid, {}).get("hint", "Misc")
            member_texts = [f["fact"] for f in members]

            if topics_llm:
                topic = llm_refine_topic(hint, member_texts[:6], client=client, model=llm_model)
                description = llm_generate_description(topic, member_texts, modality, client, llm_model)
                topic_method, description_method = "llm", "llm"
            else:
                topic = hint.title()[:80]
                description = heuristic_description(topic, member_texts, modality)
                topic_method, description_method = "tfidf", "heuristic"

            topics_by_cluster[cluster_uid] = topic
            descriptions_by_cluster[cluster_uid] = description
            records.append({
                "group_id": group_id,
                "cluster_id": cluster_uid,
                "topic": topic,
                "topic_method": topic_method,
                "description": description,
                "description_method": description_method,
                "modality": modality,
                "fact_ids": [f["id"] for f in members],
                "facts": members
            })

    # --- Write outputs ---
    clusters_path = os.path.join(out_dir, "clusters.jsonl")
    with open(clusters_path, "w", encoding="utf-8") as f:
        f.writelines(json.dumps(rec, ensure_ascii=False) + "\n" for rec in records)

    topics_path = os.path.join(out_dir, "cluster_topics.json")
    descs_path = os.path.join(out_dir, "cluster_descriptions.json")
    for target, mapping in ((topics_path, topics_by_cluster), (descs_path, descriptions_by_cluster)):
        with open(target, "w", encoding="utf-8") as f:
            json.dump(mapping, f, ensure_ascii=False, indent=2)

    for written in (clusters_path, topics_path, descs_path):
        print(f"Wrote: {written}")


def main():
    ap = argparse.ArgumentParser(description="Cluster facts and generate topics + descriptions.")
    ap.add_argument("--results", required=True, help="Path to JSON from extractor.")
    ap.add_argument("--out-dir", default="out_clusters", help="Directory to write outputs.")
    ap.add_argument("--group-strategy", choices=["parentdir","none"], default="parentdir",
                    help="Group facts by first-level folder (per book) or 'all'.")
    ap.add_argument("--embed-model", default="text-embedding-3-large")
    ap.add_argument("--topics-llm", action="store_true",
                    help="Use LLM to refine topics and descriptions (requires OPENAI_API_KEY).")
    ap.add_argument("--llm-model", default="o4-mini")
    ap.add_argument("--hdbscan-min-cluster-size", type=int, default=3)
    ap.add_argument("--kmeans-k-max", type=int, default=20)
    ap.add_argument("--keep-noise", action="store_true", help="Keep HDBSCAN noise as clusters (label -1).")
    args = ap.parse_args()

    process(results_path=args.results,
            out_dir=args.out_dir,
            group_strategy=args.group_strategy,
            embed_model=args.embed_model,
            topics_llm=args.topics_llm,
            llm_model=args.llm_model,
            hdbscan_min_cluster_size=args.hdbscan_min_cluster_size,
            kmeans_k_max=args.kmeans_k_max,
            keep_noise=args.keep_noise)


if __name__ == "__main__":
    main()
