#!/usr/bin/env python3
import argparse
import json
import math
import os
import re
import time
import warnings
from typing import Dict, List, Optional, Tuple
import numpy as np
from tqdm import tqdm
import torch

# Vector store & embeddings
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings  # deprecation warning ok for now
from langchain.docstore.document import Document

# Sentence-Transformers CrossEncoder (optional reranker)
try:
    from sentence_transformers import CrossEncoder
    _HAVE_XENC = True
except Exception:
    _HAVE_XENC = False

# Local LLM
from llama_cpp import Llama

# Optional graph retrieval helper 
try:
    import graph_retrieval
    _HAVE_GRAPH = True
except Exception:
    _HAVE_GRAPH = False

# Moderation policy inserted verbatim into every prompt by build_prompt (under
# a "POLICY:" header). Its exact text is model input: any edit here, including
# whitespace, changes classifier behavior and invalidates earlier results.
GENERAL_POLICY_TEXT = """
# General Content Moderation Policy Framework
## CRITERIA FOR FLAGGING (Immediate Human Intervention Required)
Content should be flagged if it contains indications of:
- Potential harm to self or others
- Severe mental health crises
- Highly sensitive personal information
- Illegal activities or dangerous behavior
- Discriminatory or hate speech
- Severe harassment or threats
- Crisis situations requiring professional intervention
## CRITERIA FOR NOT FLAGGING (Automated Processing Allowed)
Content should NOT be flagged if it contains:
- Normal professional communication
- Routine personal or professional concerns
- Appropriate personal experiences
- Constructive feedback or advice
- General inquiries
"""

# Decision rule placed in the instruction header of every prompt (build_prompt).
# It biases the model toward FLAGGED when in doubt. It asks for "one token",
# even though "NOT FLAGGED" is two words. The model's free-form reply is mapped
# to a label by parse_label, not trusted verbatim.
SAFETY_FIRST = (
    "If any FLAGGED criteria are plausibly present, prefer FLAGGED. "
    "Only output NOT FLAGGED if clearly none apply. "
    "Return EXACTLY one token: FLAGGED or NOT FLAGGED."
)

def load_jsonl(path: str) -> List[Dict]:
    """Read a JSON Lines file into a list of decoded objects, one per line.

    Blank or whitespace-only lines (e.g. a trailing empty line) are skipped.
    A leading UTF-8 byte-order mark is tolerated.

    Raises:
        json.JSONDecodeError: if a non-blank line is not valid JSON. The message
            is prefixed with ``"<path>:<line number>:"`` so bad rows are easy to find.
    """
    rows: List[Dict] = []
    # "utf-8-sig" strips a BOM if present and is otherwise identical to "utf-8".
    with open(path, "r", encoding="utf-8-sig") as f:
        for lineno, line in enumerate(f, start=1):
            if not line.strip():
                continue
            try:
                rows.append(json.loads(line))
            except json.JSONDecodeError as err:
                raise json.JSONDecodeError(f"{path}:{lineno}: {err.msg}", err.doc, err.pos) from err
    return rows

# Leading speaker tag such as "Customer:" (case-insensitive). Only the very
# start of the string is matched because MULTILINE is not set.
_SPEAKER_PREFIX_RE = re.compile(r"^(Agent|Customer|Client):\s*", re.IGNORECASE)
# Runs of spaces/tabs; newlines are deliberately left alone.
_HSPACE_RUN_RE = re.compile(r"[ \t]+")

def normalize_q(text: str) -> str:
    """Return *text* with a leading speaker tag removed and horizontal whitespace collapsed.

    Runs of spaces/tabs become a single space, then surrounding whitespace is
    stripped. Line breaks inside the text are kept.
    """
    untagged = _SPEAKER_PREFIX_RE.sub("", text)
    collapsed = _HSPACE_RUN_RE.sub(" ", untagged)
    return collapsed.strip()

def build_prompt(utterance: str, examples: List[Document], label_exposure: str) -> str:
    """Assemble the single-turn classification prompt for the LLM.

    Layout: instruction header (including SAFETY_FIRST), the policy text, a
    "RETRIEVED EXAMPLES" section, the utterance, and an answer cue.

    Args:
        utterance: the (normalized) text to classify.
        examples: retrieved Documents. Each is shown truncated to 350
            characters, with its ``dataset`` metadata as the source. When
            empty, the section reads "(none)".
        label_exposure: "exposed" also shows each example's ``stage``
            metadata (its gold label, "UNKNOWN" if absent). Any other value
            hides labels.
    """
    show_labels = label_exposure == "exposed"

    parts = ["RETRIEVED EXAMPLES:\n"]
    for number, doc in enumerate(examples or [], start=1):
        meta = doc.metadata
        source = meta.get("dataset","")
        body = doc.page_content
        if len(body) > 350:
            body = body[:347] + "..."
        if show_labels:
            label = meta.get("stage","UNKNOWN")
            header = f"### Example {number} (Label: {label}, Source: {source})"
        else:
            header = f"### Example {number} (Source: {source})"
        parts.append(f"{header}\n{body}\n")
    if not examples:
        parts.append("(none)\n")
    examples_section = "".join(parts)

    rules = (
        "You are a content moderation classifier.\n"
        f"{SAFETY_FIRST}\n"
        "Answer with exactly one token.\n"
    )
    return "".join([
        rules, "\n",
        "POLICY:\n", GENERAL_POLICY_TEXT, "\n",
        examples_section, "\n",
        "UTTERANCE:\n", utterance, "\n",
        "FINAL ANSWER (one token):\n",
    ])

def parse_label(out: str) -> str:
    """Map a raw LLM reply to "FLAGGED" or "NOT FLAGGED".

    The first whole-word occurrence of ``FLAGGED``, optionally preceded by
    ``NOT``, decides the label. Case, punctuation and markdown are ignored, so
    "**not_flagged**" and "NOT-FLAGGED." both read as NOT FLAGGED.

    Whole-word matching fixes two cases the plain substring checks got wrong:
    "FLAGGED (note: …)" no longer becomes NOT FLAGGED just because "NOTE"
    contains "NOT", and "UNFLAGGED" is no longer read as FLAGGED.

    A reply with no recognisable label (or None/empty) falls back to
    "NOT FLAGGED".
    """
    text = (out or "").upper()
    # Treat any run of non-letters (spaces, underscores, hyphens, '*', '.') as one separator.
    text = re.sub(r"[^A-Z]+", " ", text)
    m = re.search(r"\b(NOT )?FLAGGED\b", text)
    if m is None:
        return "NOT FLAGGED"
    return "NOT FLAGGED" if m.group(1) else "FLAGGED"

def ndcg_at_k(grades: List[int], k: int) -> float:
    """Return normalized discounted cumulative gain of *grades* at cutoff *k*.

    Gain is ``2**g - 1`` and the discount is ``log2(rank + 1)`` for 1-based
    rank. The ideal DCG comes from the same grades sorted in descending order,
    so the score is relative to the best reordering of this list.

    Returns 0.0 when *k* <= 0, *grades* is empty, or no item has positive gain.
    A negative *k* used to slice ``gs[:-1]`` and produce a meaningless score.
    """
    k = min(k, len(grades))
    if k <= 0:
        return 0.0

    def _dcg(gs: List[int]) -> float:
        return sum((2 ** g - 1) / math.log2(rank + 2) for rank, g in enumerate(gs[:k]))

    idcg = _dcg(sorted(grades, reverse=True))
    return _dcg(grades) / idcg if idcg > 0 else 0.0

def mmr_select(query_emb: np.ndarray, cand_embs: np.ndarray, lambda_: float, top_k: int) -> List[int]:
    """
    Greedy Maximal Marginal Relevance (MMR) selection over cosine similarity.

    The first pick is the candidate most similar to the query. Each later pick
    maximizes ``lambda_ * sim(q, d) - (1 - lambda_) * max_{s in S} sim(d, s)``,
    where S is the set already selected.

    Args:
        query_emb: query vector, shape (dim,).
        cand_embs: candidate vectors, shape (n, dim).
        lambda_: relevance/diversity trade-off; 1.0 means pure relevance.
        top_k: number of indices to return (capped at n; <= 0 returns []).

    Returns:
        Indices into ``cand_embs`` in selection order. Ties go to the lowest index.
    """
    q = query_emb / (np.linalg.norm(query_emb) + 1e-8)
    E = cand_embs / (np.linalg.norm(cand_embs, axis=1, keepdims=True) + 1e-8)
    n = len(E)
    budget = min(top_k, n)
    if budget <= 0:
        return []

    sims = E @ q
    # Max similarity of every candidate to the selected set. It is updated with
    # one mat-vec per pick instead of being recomputed per candidate per step.
    max_div = np.full(n, -np.inf)
    available = np.ones(n, dtype=bool)
    selected: List[int] = []
    for step in range(budget):
        if step == 0:
            scores = sims.astype(np.float64)
        else:
            scores = lambda_ * sims - (1.0 - lambda_) * max_div
        # Already-chosen items can never win. np.argmax returns the first
        # maximum (lowest index), so a NaN score cannot crash the loop.
        scores = np.where(available, scores, -np.inf)
        best = int(np.argmax(scores))
        selected.append(best)
        available[best] = False
        max_div = np.maximum(max_div, E @ E[best])
    return selected

def build_or_load_vectorstore(train_rows: List[Dict], vs_dir: str, emb_model: str, device: str) -> Tuple[FAISS, HuggingFaceEmbeddings]:
    """Load the FAISS store cached in *vs_dir*, or build it from *train_rows* and save it there.

    Each training row becomes a Document. Its text is ``row["text"]`` and its
    metadata holds ``stage`` (from ``row["label"]``) and ``dataset``.

    The embedding model that built the index is recorded in
    ``vs_dir/embedding_model.txt``. If a cached index was built with a
    different model, it is rebuilt (overwritten) instead of loaded: vectors
    from one embedder cannot be searched with queries from another. Indexes
    saved before this marker existed are loaded as before.

    Security: ``FAISS.load_local`` deserializes pickled data (hence
    ``allow_dangerous_deserialization=True``). Only point *vs_dir* at
    directories this script created.

    Raises:
        ValueError: if an index has to be built but *train_rows* is empty.
    """
    os.makedirs(vs_dir, exist_ok=True)
    emb = HuggingFaceEmbeddings(model_name=emb_model, model_kwargs={"device": device})
    marker_path = os.path.join(vs_dir, "embedding_model.txt")

    if os.path.exists(os.path.join(vs_dir, "index.faiss")):
        built_with = None
        if os.path.exists(marker_path):
            with open(marker_path, "r", encoding="utf-8") as f:
                built_with = f.read().strip()
        if built_with is None or built_with == emb_model:
            vs = FAISS.load_local(vs_dir, emb, allow_dangerous_deserialization=True)
            return vs, emb
        # Stale cache from a different embedder: fall through and rebuild.

    if not train_rows:
        raise ValueError(f"cannot build a vector store in {vs_dir!r}: no training rows")
    docs = [Document(page_content=r["text"], metadata={"stage": r["label"], "dataset": r.get("dataset","")})
            for r in train_rows]
    vs = FAISS.from_documents(docs, emb)
    vs.save_local(vs_dir)
    with open(marker_path, "w", encoding="utf-8") as f:
        f.write(emb_model + "\n")
    return vs, emb

def embed_texts(emb: HuggingFaceEmbeddings, texts: List[str]) -> np.ndarray:
    """Embed *texts* with the document-side encoder and return them as float32.

    Presumably ``embed_documents`` yields one vector per text, giving an
    (len(texts), dim) array. NOTE(review): an empty *texts* likely yields a 1-D
    (0,) array, not (0, dim); callers should guard against that.
    """
    vectors = emb.embed_documents(texts)
    matrix = np.array(vectors, dtype=np.float32)
    return matrix

def crossencode_scores(xenc: "CrossEncoder", query: str, docs: List[str]) -> np.ndarray:
    """Score every ``(query, doc)`` pair with a cross-encoder.

    The annotation is a string on purpose. ``CrossEncoder`` is only bound when
    the optional ``sentence_transformers`` import succeeds, and an unquoted
    annotation is evaluated when the function is defined. Without the quotes,
    the whole script would fail with NameError whenever the reranker package is
    missing.

    Returns:
        float32 array with one score per doc, in input order. An empty *docs*
        gives an empty array without calling the model.
    """
    if not docs:
        return np.zeros(0, dtype=np.float32)
    pairs = [[query, d] for d in docs]
    return np.array(xenc.predict(pairs), dtype=np.float32)

def retrieve_candidates(
    q: str,
    vs: FAISS,
    emb: HuggingFaceEmbeddings,
    strategy: str,
    top_k: int,
    top_k0: int,
    mmr_lambda: float,
    graph_path: Optional[str],
    graph_api_available: bool,
) -> List[Document]:
    """Pick up to *top_k* stored training Documents to show as examples for query *q*.

    Strategies:
      * ``"graph"`` (only when *graph_api_available* and *graph_path* are set):
        ``graph_retrieval.query_knn_graph`` supplies up to *top_k0* seed texts.
        These are MMR-diversified against *q*, and each survivor is mapped back
        to its nearest stored Document. A Document reached from several seeds
        is kept once. If the graph returns no seeds, the vector path below is
        used instead of crashing.
      * ``"mmr"``: similarity search for ``max(top_k, top_k0)`` candidates, then
        MMR down to *top_k*. Skipped if there are already <= *top_k*.
      * anything else, including ``"graph"`` when the graph helper is
        unavailable: the plain similarity-search top-*top_k*.
    """
    if strategy == "graph" and graph_api_available and graph_path:
        # NOTE(review): assumes query_knn_graph returns a list of seed texts (str).
        seeds = graph_retrieval.query_knn_graph(q, emb, graph_path, top_k=top_k0)
        if seeds is not None and len(seeds) > 0:
            q_emb = np.array(emb.embed_query(q), dtype=np.float32)
            seed_embs = embed_texts(emb, seeds)
            keep_idx = mmr_select(q_emb, seed_embs, lambda_=mmr_lambda, top_k=min(top_k, len(seeds)))
            # Seed texts carry no metadata, so fetch the nearest real Document
            # for each one. Different seeds can land on the same Document.
            docs: List[Document] = []
            seen = set()
            for idx in keep_idx:
                hits = vs.similarity_search(seeds[idx], k=1)
                if not hits:
                    continue
                doc = hits[0]
                key = (doc.page_content, str(doc.metadata.get("stage")), str(doc.metadata.get("dataset")))
                if key in seen:
                    continue
                seen.add(key)
                docs.append(doc)
            return docs
        # No seeds: embedding an empty list would not give a 2-D matrix for
        # mmr_select, so fall through to plain vector retrieval.

    # Vector-only path.
    docs0 = vs.similarity_search(q, k=max(top_k, top_k0))
    if strategy != "mmr":
        return docs0[:top_k]
    if len(docs0) <= top_k:
        return docs0
    q_emb = np.array(emb.embed_query(q), dtype=np.float32)
    cand_embs = embed_texts(emb, [d.page_content for d in docs0])
    keep_idx = mmr_select(q_emb, cand_embs, lambda_=mmr_lambda, top_k=top_k)
    return [docs0[i] for i in keep_idx]

# Loaded CrossEncoder models keyed by model name, so a reranker is loaded once
# per process instead of once per query.
_XENC_CACHE: Dict[str, "CrossEncoder"] = {}

def maybe_rerank(
    query: str,
    docs: List[Document],
    reranker_name: Optional[str]
) -> List[Document]:
    """Reorder *docs* by cross-encoder relevance to *query*, best first.

    The same Documents are returned, only reordered; equal scores keep their
    original relative order. *docs* comes back unchanged when *reranker_name*
    is None, when *docs* is empty, or when sentence_transformers is not
    installed. The last case emits a RuntimeWarning instead of silently
    ignoring the requested reranker.
    """
    if reranker_name is None or not docs:
        return docs
    if not _HAVE_XENC:
        warnings.warn(
            f"cross-encoder {reranker_name!r} requested but sentence_transformers "
            "is not installed; skipping rerank.",
            RuntimeWarning,
            stacklevel=2,
        )
        return docs
    xenc = _XENC_CACHE.get(reranker_name)
    if xenc is None:
        xenc = CrossEncoder(reranker_name)
        _XENC_CACHE[reranker_name] = xenc
    scores = crossencode_scores(xenc, query, [d.page_content for d in docs])
    order = np.argsort(-scores, kind="stable")
    return [docs[int(i)] for i in order]

def committee_predict(
    llm: Llama,
    prompt: str,
    committee: int,
    temperature: float = 0.6,
    max_tokens: int = 8,
    chat_format: str = "llama-3"
) -> Tuple[str, float]:
    """Sample *committee* chat completions for *prompt* and majority-vote the label.

    Args:
        llm: a llama_cpp ``Llama`` instance. Its chat template is fixed when it
            is constructed.
        prompt: sent as a single user message.
        committee: number of independent samples (votes).
        temperature: sampling temperature. It must be > 0 for votes to differ.
        max_tokens: generation cap per vote. Generation also stops at a newline.
        chat_format: accepted for backward compatibility and otherwise ignored.
            ``Llama.create_chat_completion`` has no ``chat_format`` keyword.
            Forwarding it raised TypeError on every call, and the broad except
            below silently turned each of those into a NOT FLAGGED vote.

    Returns:
        ``(pred, p_flagged)``. *p_flagged* is the fraction of FLAGGED votes
        (0.0 when committee <= 0), and *pred* is FLAGGED iff p_flagged >= 0.5.
        A failed call still counts as a NOT FLAGGED vote (best effort), but now
        emits a RuntimeWarning.
    """
    votes: List[str] = []
    for _ in range(committee):
        try:
            resp = llm.create_chat_completion(
                messages=[{"role": "user", "content": prompt}],
                max_tokens=max_tokens,
                temperature=temperature,
                stop=["\n"],
            )
            out = resp["choices"][0]["message"]["content"]
        except Exception as err:
            warnings.warn(
                f"LLM call failed ({type(err).__name__}: {err}); counting vote as NOT FLAGGED",
                RuntimeWarning,
                stacklevel=2,
            )
            out = "NOT FLAGGED"
        votes.append(parse_label(out))
    p_flagged = float(sum(1 for v in votes if v == "FLAGGED")) / max(1, committee)
    pred = "FLAGGED" if p_flagged >= 0.5 else "NOT FLAGGED"
    return pred, p_flagged

def main():
    """Run the retrieval-augmented local-LLM moderation pipeline ("Pipeline B+").

    For every test row:
      1. normalize the utterance;
      2. retrieve training examples (mmr / graph / topk);
      3. optionally rerank them with a cross-encoder;
      4. ask a committee of LLM samples for FLAGGED / NOT FLAGGED;
      5. threshold the FLAGGED vote fraction with --flagged-threshold.

    Writes two JSON Lines files into --out-dir:
      * preds.jsonl: text_id, dataset, true, pred, p_flagged for each test row.
      * retrieval.jsonl: for each test row and k in (3, 5, 10), hit, precision,
        ndcg and label_precision of a plain similarity search.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--train", default="train_augmented_2000.jsonl")
    ap.add_argument("--test",  default="test_augmented_2000.jsonl")
    ap.add_argument("--vectorstore", default="vectorstore_plus")
    ap.add_argument("--llm-path", required=True)
    ap.add_argument("--n-gpu-layers", type=int, default=40)
    ap.add_argument("--chat-format", default="llama-3")
    ap.add_argument("--label-exposure", choices=["hidden","exposed"], default="exposed")
    ap.add_argument("--out-dir", default="outputs/pipeline_b_plus",
                    help="directory that receives preds.jsonl and retrieval.jsonl")

    # Retrieval
    ap.add_argument("--retrieval-strategy", choices=["mmr","graph","topk"], default="mmr")
    ap.add_argument("--mmr-lambda", type=float, default=0.6)
    ap.add_argument("--graph-path", default=None)
    ap.add_argument("--top-k", type=int, default=15)
    ap.add_argument("--top-k0", type=int, default=100, help="first-stage recall size before rerank/MMR")

    # Embedding / reranker
    ap.add_argument("--hf-embed-model", default="sentence-transformers/all-MiniLM-L6-v2",
                    help="e.g., BAAI/bge-large-en-v1.5 or intfloat/e5-large-v2")
    ap.add_argument("--cross-encoder", default=None,
                    help="e.g., cross-encoder/ms-marco-MiniLM-L-6-v2 (optional)")

    # Committee / thresholding
    ap.add_argument("--committee", type=int, default=5)
    ap.add_argument("--flagged-threshold", type=float, default=0.4, help="prob threshold to map to FLAGGED")
    ap.add_argument("--temperature", type=float, default=0.6)

    args = ap.parse_args()
    if args.committee < 1:
        # With zero votes p_flagged is always 0 and nothing would ever be FLAGGED.
        ap.error("--committee must be >= 1")
    if args.retrieval_strategy == "graph" and not (_HAVE_GRAPH and args.graph_path):
        # retrieve_candidates quietly degrades to plain top-k in this case; say so.
        warnings.warn(
            "--retrieval-strategy graph requested but graph_retrieval is not importable "
            "or --graph-path is unset; falling back to plain top-k retrieval.",
            RuntimeWarning,
        )

    use_cuda = torch.cuda.is_available()
    device = "cuda" if use_cuda else "cpu"

    train_rows = load_jsonl(args.train)
    test_rows = load_jsonl(args.test)

    # Build / load FAISS with chosen embedder
    vs, emb = build_or_load_vectorstore(train_rows, args.vectorstore, args.hf_embed_model, device)

    # Local LLM. The chat template is fixed here, at construction time.
    llm = Llama(model_path=args.llm_path, n_ctx=4096,
                n_gpu_layers=args.n_gpu_layers if use_cuda else 0,
                chat_format=args.chat_format, verbose=False)

    os.makedirs(args.out_dir, exist_ok=True)
    preds_path = os.path.join(args.out_dir, "preds.jsonl")
    ret_path = os.path.join(args.out_dir, "retrieval.jsonl")

    def classify(q: str) -> Tuple[str, float, List[Document]]:
        """Classify an already-normalized utterance; return (label, p_flagged, example docs)."""
        # Strategy selection (mmr/graph/topk) performs its own first-stage search.
        cands = retrieve_candidates(
            q, vs, emb,
            strategy=args.retrieval_strategy,
            top_k=args.top_k,
            top_k0=args.top_k0,
            mmr_lambda=args.mmr_lambda,
            graph_path=args.graph_path,
            graph_api_available=_HAVE_GRAPH,
        )
        cands = maybe_rerank(q, cands, args.cross_encoder)

        prompt = build_prompt(q, cands, args.label_exposure)
        # committee_predict's own 0.5-majority label is not used. The vote
        # fraction is re-thresholded with the (recall-friendly) CLI threshold.
        _, p_vote = committee_predict(
            llm, prompt, committee=args.committee,
            temperature=args.temperature,
            max_tokens=8, chat_format=args.chat_format,
        )
        label = "FLAGGED" if p_vote >= args.flagged_threshold else "NOT FLAGGED"
        return label, p_vote, cands

    t0 = time.perf_counter()
    with open(preds_path, "w", encoding="utf-8") as fp, open(ret_path, "w", encoding="utf-8") as fr:
        for r in tqdm(test_rows, desc="Pipeline B+ classify"):
            gt = r["label"]
            ds = r.get("dataset", "")
            tid = r["text_id"]
            q = normalize_q(r["text"])

            label, p_vote, _ = classify(q)
            fp.write(json.dumps({"text_id": tid, "dataset": ds, "true": gt, "pred": label, "p_flagged": p_vote}) + "\n")

            # Retrieval diagnostics. These score a plain similarity search of at
            # least 10 hits, not the strategy-selected examples, and deliberately
            # omit a 'diversity' field (kept in line with the downstream
            # evaluation script, NEW_evaluate_fair.py).
            cands_diag = vs.similarity_search(q, k=max(args.top_k, 10))
            grades = [1 if c.metadata.get("stage") == gt else 0 for c in cands_diag]
            for k in (3, 5, 10):
                kk = min(k, len(cands_diag))
                hit = 1.0 if any(grades[:kk]) else 0.0
                prec = float(sum(grades[:kk])) / kk if kk > 0 else 0.0
                ndcg = ndcg_at_k(grades, kk)
                fr.write(json.dumps({
                    "text_id": tid, "dataset": ds, "true": gt, "k": kk,
                    "hit": hit, "precision": prec, "ndcg": ndcg,
                    # Same value as "precision" (grades already encode label
                    # agreement); kept so the output schema is unchanged.
                    "label_precision": prec,
                }) + "\n")

    dur = time.perf_counter() - t0
    print(f"Saved {preds_path} and {ret_path}. Elapsed {dur:.2f}s.")

if __name__ == "__main__":
    main()
