#!/usr/bin/env python3

import os
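# pipeline_a_fair.py — Pipeline A: retrieval-augmented content-moderation classifier.
# Pins the query embedder to the one used to build the stored index; pass --rebuild to rebuild it.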

import argparse, json, os, re, time, math, shutil
from typing import List, Dict, Optional
import numpy as np
from tqdm import tqdm
import pandas as pd

# --- LlamaIndex core ---
from llama_index.core import (
    Document, VectorStoreIndex, Settings,
    StorageContext, load_index_from_storage, PromptTemplate
)
from llama_index.core.schema import TextNode, NodeWithScore
from llama_index.core.node_parser import SentenceSplitter

# --- Embeddings: new plugin path then legacy fallback ---
_HF_IMPORT_ERR = None
try:
    from llama_index.embeddings.huggingface import HuggingFaceEmbedding
except Exception as _e:
    try:
        # legacy monolithic path (older LI versions)
        from llama_index.embeddings import HuggingFaceEmbedding  # type: ignore
    except Exception as _e2:
        _HF_IMPORT_ERR = (_e, _e2)

# --- OpenAI LLM/Emb (Pipeline A uses OpenAI LLM by design) ---
_HAVE_OPENAI = True
try:
    from llama_index.llms.openai import OpenAI as LI_OpenAI
    try:
        from llama_index.embeddings.openai import OpenAIEmbedding as LI_OpenAIEmb
    except Exception:
        from llama_index.embeddings import OpenAIEmbedding as LI_OpenAIEmb  # type: ignore
except Exception:
    _HAVE_OPENAI = False

# Moderation policy given to the classifier. main() interpolates it verbatim
# into the QA prompt, and build_or_load_index() indexes its stripped text as the
# single "policy" node, so any edit here changes both generation and retrieval.
# NOTE: the "#"/"##" lines below are Markdown-style heading lines inside the
# string literal, not Python comments.
GENERAL_POLICY_TEXT = """
# General Content Moderation Policy Framework
## CRITERIA FOR FLAGGING (Immediate Human Intervention Required)
Content should be flagged if it contains indications of:
- Potential harm to self or others
- Severe mental health crises
- Highly sensitive personal information
- Illegal activities or dangerous behavior
- Discriminatory or hate speech
- Severe harassment or threats
- Crisis situations requiring professional intervention
## CRITERIA FOR NOT FLAGGING (Automated Processing Allowed)
Content should NOT be flagged if it contains:
- Normal professional communication
- Routine personal or professional concerns
- Appropriate personal experiences
- Constructive feedback or advice
- General inquiries
"""

# Decision rule appended to the prompt after the policy: resolve doubt toward
# FLAGGED and restrict the answer to one of the two labels that parse_pred()
# recognises.
SAFETY_FIRST = (
    "If any FLAGGED criteria are plausibly present, prefer FLAGGED. "
    "Only output NOT FLAGGED if clearly none apply. "
    "Return EXACTLY one token: FLAGGED or NOT FLAGGED."
)

def load_jsonl(path: str) -> List[Dict]:
    """Read a JSON Lines file into a list of objects, one per non-blank line.

    Blank / whitespace-only lines (e.g. a trailing newline at end of file) are
    skipped, and a leading UTF-8 byte-order mark is tolerated.

    Raises:
        json.JSONDecodeError: a non-blank line is not valid JSON; the message
            is prefixed with ``path:lineno`` to locate the offending record.
    """
    rows = []
    # "utf-8-sig" behaves like "utf-8" but also strips a BOM, which would
    # otherwise make json.loads reject the first line.
    with open(path, "r", encoding="utf-8-sig") as f:
        for lineno, line in enumerate(f, start=1):
            if not line.strip():
                continue
            try:
                rows.append(json.loads(line))
            except json.JSONDecodeError as e:
                # Keep the exception type so existing handlers still match.
                raise json.JSONDecodeError(f"{path}:{lineno}: {e.msg}", e.doc, e.pos) from e
    return rows

def parse_pred(s: str) -> str:
    """Map a raw LLM answer to ``"FLAGGED"``, ``"NOT FLAGGED"`` or ``"UNKNOWN"``.

    Matching is case-insensitive and works on whole words, after every run of
    non-letters (punctuation, digits, underscores, newlines) is collapsed to a
    single space. So ``"Not_flagged."`` and ``"NOT\\nFLAGGED"`` both parse as
    NOT FLAGGED, ``"**Flagged** (note: ...)"`` parses as FLAGGED, and
    ``"UNFLAGGED"`` is not mistaken for FLAGGED.

    Rules, in order:
      * the words "NOT FLAGGED" appear consecutively -> "NOT FLAGGED";
      * the word "FLAGGED" appears and the word "NOT" does not -> "FLAGGED";
      * anything else (including ``None``) -> "UNKNOWN".
    """
    if s is None:
        return "UNKNOWN"
    words = re.sub(r"[^A-Za-z]+", " ", str(s)).upper().split()
    if re.search(r"\bNOT FLAGGED\b", " ".join(words)):
        return "NOT FLAGGED"
    # Whole-word membership: "NOTE" must not veto, "UNFLAGGED" must not match.
    if "FLAGGED" in words and "NOT" not in words:
        return "FLAGGED"
    return "UNKNOWN"

def ndcg_at_k(grades: List[int], k: int) -> float:
    """Normalised discounted cumulative gain of a ranked list at cut-off ``k``.

    Uses the exponential gain ``2**g - 1`` with discount ``log2(rank + 1)``
    (1-based ranks). The ideal DCG is computed from the same ``grades`` sorted
    in descending order, i.e. relative to the best possible ordering of the
    retrieved items.

    Args:
        grades: relevance grade of each retrieved item, in rank order.
        k: cut-off; values larger than ``len(grades)`` are clamped.

    Returns:
        nDCG (within [0, 1] for non-negative grades); 0.0 when ``k <= 0``,
        ``grades`` is empty, or no item has positive gain.
    """
    k = min(k, len(grades))
    # `<= 0`, not `== 0`: a negative k would otherwise slice from the end.
    if k <= 0:
        return 0.0

    def dcg(gs: List[int]) -> float:
        return sum((2**g - 1) / math.log2(i + 2) for i, g in enumerate(gs[:k]))

    idcg = dcg(sorted(grades, reverse=True))
    return (dcg(grades) / idcg) if idcg > 0 else 0.0

def diversity_at_k(texts: List[str], embedder, k: int) -> float:
    """Score how different the first ``k`` texts are from one another.

    Each text is embedded with ``embedder.get_text_embedding``; the result is
    ``1 - mean(cosine similarity)`` over all unordered pairs. With zero or one
    text there are no pairs, and the score is 1.0.
    """
    window = texts[:k]
    count = len(window)
    if count < 2:
        return 1.0
    vectors = np.asarray(
        [embedder.get_text_embedding(txt) for txt in window], dtype=np.float32
    )
    # Small epsilon keeps zero vectors from dividing by zero.
    lengths = np.linalg.norm(vectors, axis=1, keepdims=True) + 1e-8
    similarity = (vectors @ vectors.T) / (lengths @ lengths.T)
    rows, cols = np.triu_indices(count, k=1)
    pairwise = similarity[rows, cols]
    if pairwise.size == 0:
        return 1.0
    return 1.0 - float(pairwise.mean())

def is_policy(nws: NodeWithScore) -> bool:
    """Return True when the retrieved node carries ``metadata["type"] == "policy"``."""
    meta = nws.node.metadata
    if not meta:
        return False
    return meta.get("type") == "policy"

def node_label(nws: NodeWithScore) -> Optional[str]:
    """Return the node's ``metadata["label"]``, or None when absent or metadata is empty."""
    meta = nws.node.metadata
    return meta.get("label") if meta else None

def relevance_grade(nws: NodeWithScore, true_label: str) -> int:
    """Binary relevance: 1 for the policy node or an example labelled ``true_label``, else 0."""
    relevant = is_policy(nws) or node_label(nws) == true_label
    return int(relevant)

_HF_EMBED_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
_OPENAI_EMBED_MODEL = "text-embedding-3-small"  # 1536-dim
# Sidecar file recording the settings an index in `storage_dir` was built with.
# The persisted index itself does not record the embedding model, so without
# this a stored index could silently be queried with a different embedder or
# reused after changing --label-exposure / --chunk-size.
_INDEX_META_FNAME = "pipeline_a_index_meta.json"


def _make_embedder(embed_kind: str):
    """Instantiate the embedding model selected by ``--embedding``.

    Raises:
        ImportError: ``"hf"`` requested but the HuggingFace plugin is missing.
        RuntimeError: ``"openai"`` requested but the OpenAI packages are missing.
        ValueError: ``embed_kind`` is neither ``"hf"`` nor ``"openai"``.
    """
    if embed_kind == "hf":
        if _HF_IMPORT_ERR is not None:
            raise ImportError(
                "HuggingFaceEmbedding not found. Install:\n"
                "  pip install -U llama-index-embeddings-huggingface"
            )
        return HuggingFaceEmbedding(model_name=_HF_EMBED_MODEL)
    if embed_kind == "openai":
        if not _HAVE_OPENAI:
            raise RuntimeError(
                "OpenAI packages missing. Install openai + llama-index-llms-openai"
                " + llama-index-embeddings-openai"
            )
        return LI_OpenAIEmb(model=_OPENAI_EMBED_MODEL)
    raise ValueError("--embedding must be 'hf' or 'openai'")


def _index_settings(embed_kind: str, label_exposure: str, chunk_size: int) -> Dict:
    """Return the settings that determine an index's vectors and node text."""
    return {
        "embedding": embed_kind,
        "embed_model": _HF_EMBED_MODEL if embed_kind == "hf" else _OPENAI_EMBED_MODEL,
        "label_exposure": label_exposure,
        "chunk_size": chunk_size,
    }


def _check_stored_settings(storage_dir: str, expected: Dict) -> None:
    """Exit with a fix-it message if the stored index was built with other settings.

    Indexes persisted before the sidecar file existed cannot be verified; a
    warning is printed and loading proceeds.
    """
    meta_path = os.path.join(storage_dir, _INDEX_META_FNAME)
    if not os.path.isfile(meta_path):
        print(
            f"[warn] {meta_path} not found; cannot verify that the stored index matches "
            "the current --embedding/--label-exposure/--chunk-size. Run with --rebuild if in doubt."
        )
        return
    with open(meta_path, "r", encoding="utf-8") as f:
        stored = json.load(f)
    diffs = {key: (stored.get(key), want) for key, want in expected.items() if stored.get(key) != want}
    if diffs:
        detail = "\n".join(f"  {key}: stored={have!r} requested={want!r}" for key, (have, want) in diffs.items())
        raise SystemExit(
            f"\n[Index mismatch] {storage_dir} was built with different settings:\n"
            f"{detail}\n"
            "Fix by either:\n"
            "  a) Rebuilding: run with --rebuild (or delete the storage dir), or\n"
            "  b) Passing the same options you used to build it.\n"
        )


def _build_nodes(train_rows: List[Dict], label_exposure: str, chunk_size: int) -> list:
    """Create the policy node plus one example node per sentence-split chunk of each row."""
    nodes = [TextNode(text=GENERAL_POLICY_TEXT.strip(), metadata={"type": "policy"})]
    splitter = SentenceSplitter(chunk_size=chunk_size, chunk_overlap=80)
    # LlamaIndex prepends node metadata to the text it embeds and to the context
    # it hands the LLM unless the keys are excluded. In "hidden" mode the label
    # must not leak through that path either; it stays in metadata only for the
    # retrieval metrics.
    hidden_keys = ["label"] if label_exposure == "hidden" else []
    for r in train_rows:
        label = r["label"]
        ds = r.get("dataset", "unknown")
        for ch in splitter.split_text(r["text"]):
            t = ch if label_exposure == "hidden" else f"Text: {ch}\nLabel: {label}"
            nodes.append(TextNode(
                text=t,
                metadata={"type": "example", "label": label, "dataset": ds},
                excluded_embed_metadata_keys=list(hidden_keys),
                excluded_llm_metadata_keys=list(hidden_keys),
            ))
    return nodes


def build_or_load_index(
    train_rows: List[Dict],
    storage_dir: str,
    embed_kind: str,
    label_exposure: str,
    chunk_size: int,
    rebuild: bool = False,
):
    """Load the vector index persisted in ``storage_dir``, or build and persist it.

    The embedder is always constructed from ``embed_kind`` and used both for the
    index and as ``Settings.embed_model``. A stored index is only reused when its
    sidecar settings file matches the requested settings.

    Args:
        train_rows: training records with ``"text"``, ``"label"`` and optional ``"dataset"``.
        storage_dir: persist directory for the index.
        embed_kind: ``"hf"`` or ``"openai"``.
        label_exposure: ``"hidden"`` (chunk text only) or ``"exposed"``
            (chunk text followed by its label).
        chunk_size: SentenceSplitter chunk size used when building.
        rebuild: delete ``storage_dir`` first and build from scratch.

    Returns:
        ``(index, embedder)``.

    Raises:
        SystemExit: the stored index was built with different settings.
    """
    if rebuild and os.path.isdir(storage_dir):
        shutil.rmtree(storage_dir, ignore_errors=True)
    os.makedirs(storage_dir, exist_ok=True)

    # Create the embedder before touching storage so that loading never depends
    # on whatever Settings.embed_model happens to resolve to: a different model
    # with the same dimension would not even trip the retrieval sanity check.
    embedder = _make_embedder(embed_kind)
    Settings.embed_model = embedder
    expected = _index_settings(embed_kind, label_exposure, chunk_size)

    if os.listdir(storage_dir):
        _check_stored_settings(storage_dir, expected)
        storage_context = StorageContext.from_defaults(persist_dir=storage_dir)
        index = load_index_from_storage(storage_context, embed_model=embedder)
        return index, embedder

    index = VectorStoreIndex(_build_nodes(train_rows, label_exposure, chunk_size), embed_model=embedder)
    index.storage_context.persist(persist_dir=storage_dir)
    # Written last so a crashed build never leaves a "verified" partial index.
    with open(os.path.join(storage_dir, _INDEX_META_FNAME), "w", encoding="utf-8") as f:
        json.dump(expected, f, indent=2)
    return index, embedder

def sanity_check_embed_dim(index_embedder, retriever):
    """Run one throwaway retrieval so an embedding-dimension mismatch fails fast.

    ``index_embedder`` is accepted for call compatibility but not used. A
    ``ValueError`` whose message contains "not aligned" is turned into a
    ``SystemExit`` with remediation steps. Any other error propagates unchanged.
    """
    try:
        retriever.retrieve("sanity check")
    except ValueError as err:
        if "not aligned" not in str(err):
            raise
        hint = (
            "\n[Embed mismatch] Your stored index and current embedder differ.\n"
            "Fix by either:\n"
            "  a) Rebuilding: run with --rebuild (or delete the storage dir), or\n"
            "  b) Matching the embedder: use the same --embedding you used to build.\n"
        )
        raise SystemExit(hint) from err

class _CachingEmbedder:
    """Wrap an embedder and memoize ``get_text_embedding`` by input text.

    The retrieval metrics call ``diversity_at_k`` once per top-k cut-off, each
    time re-embedding the same leading chunks, and index chunks also recur
    across queries. Caching embeds each distinct chunk once (with
    ``--embedding openai`` every embedding is a remote API call). The cache is
    bounded by the number of distinct chunks in the index.
    """

    def __init__(self, base):
        self._base = base
        self._cache: Dict[str, np.ndarray] = {}

    def get_text_embedding(self, text: str) -> np.ndarray:
        vec = self._cache.get(text)
        if vec is None:
            vec = np.asarray(self._base.get_text_embedding(text), dtype=np.float32)
            self._cache[text] = vec
        return vec


def _retrieval_records(nodes, gt, tid, ds, embedder, ks=(3, 5, 10)) -> List[Dict]:
    """Compute retrieval-quality metrics for one query's ranked ``nodes``.

    A node is relevant (``relevance_grade``) when it is the policy node or an
    example whose label equals the ground truth ``gt``. Requested cut-offs are
    clamped to ``len(nodes)`` and de-duplicated. A short result list therefore
    yields one row per distinct effective ``k``, not repeated identical rows
    that would over-weight it in per-``k`` averages.
    """
    texts = [n.node.get_content() for n in nodes]
    grades = [relevance_grade(n, gt) for n in nodes]
    records = []
    for kk in sorted({min(k, len(nodes)) for k in ks}):
        top_grades = grades[:kk]
        n_label_match = sum(1 for n in nodes[:kk] if node_label(n) == gt)
        records.append({
            "text_id": tid, "dataset": ds, "true": gt, "k": kk,
            "hit": 1.0 if any(top_grades) else 0.0,
            "precision": sum(top_grades) / kk if kk > 0 else 0.0,
            "ndcg": ndcg_at_k(grades, kk),
            "label_precision": n_label_match / kk if kk > 0 else 0.0,
            "diversity": diversity_at_k(texts[:kk], embedder, kk),
        })
    return records


def main():
    """Classify every test row with Pipeline A and log predictions and retrieval metrics.

    Writes ``outputs/preds_a.jsonl`` (one prediction per test row) and
    ``outputs/retrieval_a.jsonl`` (hit / precision / nDCG / label precision /
    diversity per top-k cut-off per test row).
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--train", default="train_augmented.jsonl")
    ap.add_argument("--test", default="test_augmented.jsonl")
    ap.add_argument("--storage", default="storage_a")
    ap.add_argument("--model", default="gpt-4o-mini")
    ap.add_argument("--temperature", type=float, default=0.0)
    ap.add_argument("--rebuild", action="store_true", help="Delete storage and rebuild index")
    # Defaults below are the best-performing configuration.
    ap.add_argument("--embedding", choices=["hf", "openai"], default="openai")
    ap.add_argument("--label-exposure", choices=["hidden", "exposed"], default="exposed")
    ap.add_argument("--top-k", type=int, default=10)
    ap.add_argument("--chunk-size", type=int, default=768)
    args = ap.parse_args()

    # The OpenAI LLM is mandatory: fail before spending time building the index.
    if not _HAVE_OPENAI:
        raise RuntimeError(
            "OpenAI LLM is required for Pipeline A.\n"
            "Install: pip install -U openai llama-index-llms-openai\n"
            "Set: export OPENAI_API_KEY=YOUR_KEY"
        )

    train_rows = load_jsonl(args.train)
    print("---------------------------------------------")
    print(len(train_rows), f"rows loaded from {args.train}")
    test_rows = load_jsonl(args.test)
    print("---------------------------------------------")
    print(len(test_rows), f"rows loaded from {args.test}")

    index, embedder = build_or_load_index(
        train_rows, args.storage, args.embedding, args.label_exposure, args.chunk_size, rebuild=args.rebuild
    )
    # Pin the global embedder to the one used by the index (prevents 1536 vs 384 issues)
    Settings.embed_model = embedder
    Settings.llm = LI_OpenAI(model=args.model, temperature=args.temperature)

    # PromptTemplate substitutes {query_str}/{context_str} with str.format, so the
    # placeholders need single braces. This literal is not an f-string, so
    # "{{query_str}}" would reach the template as an escaped brace and the
    # content/context would never be inserted.
    qa_tmpl = PromptTemplate(
        "You are a content moderation classifier. Follow this policy strictly:\n"
        + GENERAL_POLICY_TEXT + "\n"
        + SAFETY_FIRST + "\n"
        + "CONTENT:\n{query_str}\nRetrieved Context:\n{context_str}\n"
        + "Final answer (FLAGGED or NOT FLAGGED) only:"
    )
    query_engine = index.as_query_engine(similarity_top_k=args.top_k, text_qa_template=qa_tmpl)
    retriever = index.as_retriever(similarity_top_k=args.top_k)

    # Early sanity check to catch any remaining embed mismatch
    sanity_check_embed_dim(embedder, retriever)

    os.makedirs("outputs", exist_ok=True)
    preds_path = os.path.join("outputs", "preds_a.jsonl")
    ret_path = os.path.join("outputs", "retrieval_a.jsonl")
    diversity_embedder = _CachingEmbedder(embedder)

    t0 = time.perf_counter()
    with open(preds_path, "w", encoding="utf-8") as fp, open(ret_path, "w", encoding="utf-8") as fr:
        for row in tqdm(test_rows, desc="Pipeline A classify"):
            content = row["text"]
            gt = row["label"]
            ds = row.get("dataset", "")
            tid = row["text_id"]

            resp = query_engine.query(content)
            pred = parse_pred(str(resp))
            fp.write(json.dumps({"text_id": tid, "dataset": ds, "true": gt, "pred": pred}, ensure_ascii=False) + "\n")

            # Reuse the nodes the query engine already retrieved (same index and
            # top-k) instead of embedding the query a second time.
            nodes = list(getattr(resp, "source_nodes", None) or []) or retriever.retrieve(content)
            for rec in _retrieval_records(nodes, gt, tid, ds, diversity_embedder):
                fr.write(json.dumps(rec, ensure_ascii=False) + "\n")

    dur = time.perf_counter() - t0
    print(f"Saved {preds_path} and {ret_path}. Elapsed {dur:.2f}s.")

if __name__ == "__main__":
    # Script entry point; main() returns None, so the exit status is 0 on success.
    raise SystemExit(main())
