#!/usr/bin/env python3
# pipeline_b_fair_plus.py — FAISS + llama.cpp with multi-objective reranking, graph retrieval, committee + abstention

import argparse, json, os, re, time, math, random
import logging
from typing import List, Dict, Tuple
import numpy as np
import pandas as pd
from tqdm import tqdm
import torch

from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain.docstore.document import Document
from llama_cpp import Llama

# Moderation policy text. It is inserted verbatim into every prompt under the
# "POLICY:" heading by build_prompt, so any edit here changes model input.
GENERAL_POLICY_TEXT = """
# General Content Moderation Policy Framework
## CRITERIA FOR FLAGGING (Immediate Human Intervention Required)
Content should be flagged if it contains indications of:
- Potential harm to self or others
- Severe mental health crises
- Highly sensitive personal information
- Illegal activities or dangerous behavior
- Discriminatory or hate speech
- Severe harassment or threats
- Crisis situations requiring professional intervention
## CRITERIA FOR NOT FLAGGING (Automated Processing Allowed)
Content should NOT be flagged if it contains:
- Normal professional communication
- Routine personal or professional concerns
- Appropriate personal experiences
- Constructive feedback or advice
- General inquiries
"""
# Decision rule placed at the top of every prompt: when in doubt, prefer FLAGGED.
# NOTE(review): the text asks for "EXACTLY one token" although "NOT FLAGGED" is
# two words; parse_pred accepts both answers — confirm the wording is intended.
SAFETY_FIRST = (
    "If any FLAGGED criteria are plausibly present, prefer FLAGGED. "
    "Only output NOT FLAGGED if clearly none apply. "
    "Return EXACTLY one token: FLAGGED or NOT FLAGGED."
)

def load_jsonl(path: str) -> List[Dict]:
    """Read a JSON Lines file and return one parsed object per non-blank line.

    Args:
        path: Path to a UTF-8 encoded ``.jsonl`` file.

    Returns:
        The parsed rows, in file order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON. The
            message names the file and the 1-based line number.
        OSError: If the file cannot be opened.
    """
    rows = []
    with open(path, "r", encoding="utf-8") as f:
        for lineno, line in enumerate(f, start=1):
            # Blank lines (commonly a trailing newline at EOF) carry no record
            # and would otherwise make json.loads fail on an empty string.
            if not line.strip():
                continue
            try:
                rows.append(json.loads(line))
            except json.JSONDecodeError as err:
                # Same exception type as before, but say where the bad row is.
                raise json.JSONDecodeError(
                    f"{err.msg} (file {path}, line {lineno})", err.doc, err.pos
                ) from err
    return rows

def parse_pred(s: str) -> str:
    """Map raw model output to ``"FLAGGED"``, ``"NOT FLAGGED"`` or ``"UNKNOWN"``.

    Matching is case-insensitive and works on whole words: every run of
    non-letters is treated as a separator, so ``NOT_FLAGGED``, ``not-flagged``
    and ``Not  Flagged`` are all recognised, while ``UNFLAGGED`` is not
    mistaken for ``FLAGGED``.

    Args:
        s: Raw completion text; ``None`` is accepted.

    Returns:
        ``"NOT FLAGGED"`` if that phrase appears anywhere; ``"FLAGGED"`` if the
        word appears and is not preceded by a standalone ``NOT`` (a negation
        such as "should not be flagged" stays ambiguous); otherwise
        ``"UNKNOWN"``.
    """
    if s is None:
        return "UNKNOWN"
    # Turn separators into spaces instead of deleting them, so that
    # "NOT_FLAGGED" becomes two words rather than the glued "NOTFLAGGED".
    words = re.sub(r"[^A-Za-z]+", " ", str(s)).upper().split()
    text = " ".join(words)
    if re.search(r"\bNOT ?FLAGGED\b", text):
        return "NOT FLAGGED"
    if "FLAGGED" in words:
        preceding = words[:words.index("FLAGGED")]
        # Only a standalone NOT before the verdict makes it ambiguous;
        # substrings such as the NOT inside CANNOT do not.
        if "NOT" not in preceding:
            return "FLAGGED"
    return "UNKNOWN"

def build_or_load_vectorstore(train_rows: List[Dict], vs_dir: str) -> Tuple[FAISS, HuggingFaceEmbeddings]:
    """Return a FAISS store over the training rows plus its embedding model.

    If ``vs_dir`` already holds an ``index.faiss`` file the saved store is
    loaded and ``train_rows`` is not consulted; otherwise a store is built from
    ``train_rows`` and saved into ``vs_dir``.

    Args:
        train_rows: Rows with ``text``, ``label`` and ``text_id`` keys and an
            optional ``dataset`` key.
        vs_dir: Directory used to persist / load the index (created if absent).

    Returns:
        ``(vectorstore, embeddings)``.
    """
    os.makedirs(vs_dir, exist_ok=True)

    # Embed on the GPU when one is visible to torch, else on the CPU.
    if torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"
    emb = HuggingFaceEmbeddings(
        model_name="sentence-transformers/all-MiniLM-L6-v2",
        model_kwargs={"device": device},
    )

    index_file = os.path.join(vs_dir, "index.faiss")
    if not os.path.exists(index_file):
        docs = []
        for row in train_rows:
            meta = {
                "label": row["label"],
                "dataset": row.get("dataset", ""),
                "text_id": row["text_id"],
            }
            docs.append(Document(page_content=row["text"], metadata=meta))
        vs = FAISS.from_documents(docs, emb)
        vs.save_local(vs_dir)
    else:
        # NOTE(review): allow_dangerous_deserialization opts in to loading a
        # pickled docstore; only point vs_dir at directories this script wrote.
        vs = FAISS.load_local(vs_dir, emb, allow_dangerous_deserialization=True)
    return vs, emb

def mmr_rerank(docs: List[Document], embeddings: HuggingFaceEmbeddings, q: str, k: int, lam: float=0.6) -> List[Document]:
    """Select ``k`` documents by greedy Maximal Marginal Relevance.

    Each step picks the remaining document that maximises
    ``lam * cos(doc, query) - (1 - lam) * max cos(doc, already_selected)``;
    the first pick is simply the document most similar to the query. Ties go
    to the document that appears earliest in ``docs``.

    Args:
        docs: Candidate documents; their ``page_content`` is embedded.
        embeddings: Object providing ``embed_documents`` and ``embed_query``.
        q: Query text.
        k: Number of documents to keep.
        lam: Trade-off in [0, 1]; 1.0 ranks by relevance only, lower values
            penalise redundancy more.

    Returns:
        ``docs`` itself (unchanged order) when it has at most ``k`` items,
        otherwise a new list of ``k`` documents in selection order.
    """
    if len(docs) <= k:
        return docs
    if k <= 0:
        return []

    doc_vecs = np.asarray(embeddings.embed_documents([d.page_content for d in docs]), dtype=np.float32)
    query_vec = np.asarray(embeddings.embed_query(q), dtype=np.float32)

    # Compute every cosine similarity once up front instead of re-deriving
    # norms and dot products for each (candidate, selected) pair per step.
    doc_norms = np.linalg.norm(doc_vecs, axis=1)
    relevance = (doc_vecs @ query_vec) / (doc_norms * np.linalg.norm(query_vec) + 1e-8)
    pairwise = (doc_vecs @ doc_vecs.T) / (np.outer(doc_norms, doc_norms) + 1e-8)

    first = int(np.argmax(relevance))
    selected = [first]
    available = np.ones(len(docs), dtype=bool)
    available[first] = False
    # redundancy[i] == max cosine similarity between doc i and any selected doc.
    redundancy = pairwise[first].copy()

    # len(docs) > k here, so at least one candidate is always available.
    while len(selected) < k:
        scores = lam * relevance - (1 - lam) * redundancy
        scores[~available] = -np.inf
        pick = int(np.argmax(scores))
        selected.append(pick)
        available[pick] = False
        redundancy = np.maximum(redundancy, pairwise[pick])
    return [docs[i] for i in selected]

def balance_labels(docs: List[Document], k:int=10, target_ratio: float=0.5) -> List[Document]:
    """Pick up to ``k`` documents aiming for a given share of FLAGGED examples.

    The result lists FLAGGED documents first (up to ``round(target_ratio * k)``
    of them), then NOT FLAGGED documents, then any further FLAGGED documents
    needed to reach ``k``. Within each group the input order is kept, and
    documents carrying any other label are dropped.

    Args:
        docs: Candidate documents with a ``label`` entry in ``metadata``.
        k: Maximum number of documents to return.
        target_ratio: Desired fraction of FLAGGED documents among the ``k``.

    Returns:
        A new list with at most ``k`` documents.
    """
    if k <= 0:
        return []

    flagged = []
    clean = []
    for doc in docs:
        label = doc.metadata.get("label")
        if label == "FLAGGED":
            flagged.append(doc)
        elif label == "NOT FLAGGED":
            clean.append(doc)

    quota = int(round(target_ratio * k))
    # Leading FLAGGED run: limited by the quota, by k and by availability.
    lead = max(0, min(quota, k, len(flagged)))
    # NOT FLAGGED documents then fill the remaining slots.
    fill = min(k - lead, len(clean))
    # Left-over FLAGGED documents top up whatever is still free.
    tail = min(k - lead - fill, len(flagged) - lead)
    return flagged[:lead] + clean[:fill] + flagged[lead:lead + tail]

def build_prompt(utterance: str, examples: List[Document], label_exposure: str) -> str:
    """Assemble the classification prompt for one utterance.

    The prompt consists of the classifier instructions, the policy text, up to
    15 retrieved examples (each truncated to 350 characters) and the utterance.

    Args:
        utterance: Text to classify.
        examples: Retrieved documents; only the first 15 are used.
        label_exposure: ``"exposed"`` prints each example's label in its
            header; any other value omits it.

    Returns:
        The full prompt string.
    """
    show_labels = label_exposure == "exposed"

    blocks = []
    for number, doc in enumerate(examples[:15], start=1):
        label = doc.metadata.get("label", "UNKNOWN")
        source = doc.metadata.get("dataset", "")
        text = doc.page_content
        snippet = text[:350]
        if len(text) > 350:
            snippet += "..."  # mark that the example was cut short
        if show_labels:
            header = f"### Example {number} (Label: {label}, Source: {source})"
        else:
            header = f"### Example {number} (Source: {source})"
        blocks.append(f"{header}\n{snippet}")

    if blocks:
        examples_body = "\n".join(blocks)
    else:
        examples_body = "(none)"

    sections = [
        "You are a content moderation classifier.\n",
        SAFETY_FIRST,
        "\nAnswer with exactly one token.\n",
        "\nPOLICY:\n",
        GENERAL_POLICY_TEXT,
        "\n",
        "RETRIEVED EXAMPLES:\n",
        examples_body,
        "\nUTTERANCE:\n",
        utterance,
        "\nFINAL ANSWER (one token):\n",
    ]
    return "".join(sections)

def classify_committee(llm: Llama, vs: FAISS, emb: HuggingFaceEmbeddings, q: str, label_exposure: str, strategy: str, top_k: int, mmr_lambda: float, committee: int) -> Tuple[str, float]:
    """Classify one utterance with retrieved examples and a vote of LLM calls.

    Args:
        llm: llama.cpp model used for chat completion.
        vs: Vector store queried for similar training examples.
        emb: Embedding model, used by the MMR reranker.
        q: Utterance to classify.
        label_exposure: Passed to ``build_prompt`` (``"exposed"`` shows labels).
        strategy: ``"mmr"``, ``"balance"``; anything else keeps the plain
            similarity ranking.
        top_k: Number of examples to put in the prompt.
        mmr_lambda: Relevance/diversity trade-off for ``"mmr"``.
        committee: Number of LLM calls to make (at least one is made).

    Returns:
        ``(hard_label, p_flag)`` where ``p_flag`` is the fraction of votes that
        parsed as ``"FLAGGED"`` and ``hard_label`` is ``"FLAGGED"`` when
        ``p_flag >= 0.5``. Votes that fail or parse as ``"UNKNOWN"`` count
        against ``FLAGGED``.

    Note:
        Every member receives the same prompt at temperature 0.0, so members
        are expected to agree and ``p_flag`` will usually be 0.0 or 1.0.
    """
    # A reranker can only change the result if it sees more candidates than it
    # must return. The previous pool of max(top_k, 10) made "mmr" and "balance"
    # no-ops whenever top_k >= 10 (the default is 15).
    if strategy in ("mmr", "balance"):
        pool_size = max(3 * top_k, 10)
    else:
        pool_size = max(top_k, 10)

    # Retrieval, reranking and prompt construction do not depend on the
    # committee member, so do them once instead of once per vote.
    docs = vs.similarity_search(q, k=pool_size)
    if strategy == "mmr":
        docs = mmr_rerank(docs, emb, q, k=top_k, lam=mmr_lambda)
    elif strategy == "balance":
        docs = balance_labels(docs, k=top_k, target_ratio=0.5)
    else:
        docs = docs[:top_k]
    prompt = build_prompt(q, docs, label_exposure)

    votes = []
    for member in range(max(1, committee)):
        try:
            resp = llm.create_chat_completion(messages=[{"role":"user","content":prompt}], max_tokens=8, temperature=0.0, stop=["\n"])
            vote = parse_pred(resp["choices"][0]["message"]["content"])
        except Exception as err:
            # Best effort: keep the run going, but make the failure visible
            # instead of silently turning it into a NOT FLAGGED vote.
            logging.getLogger(__name__).warning(
                "LLM call failed for committee member %d; counting NOT FLAGGED: %s", member, err
            )
            vote = "NOT FLAGGED"
        votes.append(vote)

    p_flag = float(sum(1 for v in votes if v == "FLAGGED")) / len(votes)
    hard = "FLAGGED" if p_flag >= 0.5 else "NOT FLAGGED"
    return hard, p_flag

def _ranking_metrics(grades: List[int], k: int) -> Dict:
    """Return hit, precision, nDCG and label precision at cutoff ``k``.

    ``grades`` holds one binary relevance grade per retrieved document, in rank
    order. The effective cutoff (reported as ``"k"``) is capped at
    ``len(grades)`` so precision is never divided by more items than were
    actually graded.
    """
    kk = min(k, len(grades))
    top = grades[:kk]
    ideal = sorted(grades, reverse=True)[:kk]
    dcg = sum((2 ** g - 1) / math.log2(i + 2) for i, g in enumerate(top))
    idcg = sum((2 ** g - 1) / math.log2(i + 2) for i, g in enumerate(ideal))
    precision = float(sum(top)) / kk if kk > 0 else 0.0
    return {
        "k": kk,
        "hit": 1.0 if any(top) else 0.0,
        "precision": precision,
        # max(1.0, ...) doubles as the zero guard: idcg is 0 when nothing is
        # relevant and at least 1 otherwise.
        "ndcg": dcg / max(1.0, idcg),
        # Grades are label matches, so label precision equals precision.
        "label_precision": precision,
    }


def _llm_vote(llm, prompt: str) -> str:
    """Ask the model once and return the parsed label (NOT FLAGGED on failure)."""
    try:
        resp = llm.create_chat_completion(messages=[{"role":"user","content":prompt}], max_tokens=8, temperature=0.0, stop=["\n"])
        return parse_pred(resp["choices"][0]["message"]["content"])
    except Exception as err:
        # Best effort: keep the run going but report the failure.
        logging.getLogger(__name__).warning("LLM call failed; counting NOT FLAGGED: %s", err)
        return "NOT FLAGGED"


def main():
    """Run the pipeline over the test set and write predictions and metrics.

    Parses command-line options, loads the train/test JSONL files, builds or
    loads the vector store and the llama.cpp model, classifies every test row
    and writes two JSONL files under ``outputs/pipeline_b_plus``:
    ``preds.jsonl`` (one prediction per row) and ``retrieval.jsonl`` (label
    precision proxies at cutoffs 3, 5 and 10 for the plain similarity search).
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--train", default="train_augmented.jsonl")
    ap.add_argument("--test",  default="test_augmented.jsonl")
    ap.add_argument("--vectorstore", default="vectorstore_plus")
    ap.add_argument("--llm-path", required=True)
    ap.add_argument("--n-gpu-layers", type=int, default=40)
    ap.add_argument("--top-k", type=int, default=15)
    ap.add_argument("--label-exposure", choices=["hidden","exposed"], default="exposed")
    # retrieval strategy
    ap.add_argument("--retrieval-strategy", choices=["baseline","mmr","balance","graph"], default="mmr")
    ap.add_argument("--mmr-lambda", type=float, default=0.6)
    ap.add_argument("--graph-path", default="graph_store")
    # uncertainty / abstention
    ap.add_argument("--committee", type=int, default=5)
    ap.add_argument("--flagged-threshold", type=float, default=0.4)
    ap.add_argument("--abstain-label", default="")
    args = ap.parse_args()
    train_rows = load_jsonl(args.train)
    test_rows = load_jsonl(args.test)
    vs, emb = build_or_load_vectorstore(train_rows, args.vectorstore)
    llm = Llama(model_path=args.llm_path, n_ctx=4096, n_gpu_layers=args.n_gpu_layers if torch.cuda.is_available() else 0, chat_format="chatml", verbose=False)

    os.makedirs("outputs/pipeline_b_plus", exist_ok=True)
    preds_path = "outputs/pipeline_b_plus/preds.jsonl"
    ret_path   = "outputs/pipeline_b_plus/retrieval.jsonl"

    # optional graph retriever
    graph_ret = None
    labels_map = {}
    if args.retrieval_strategy=="graph":
        from NEW_graph_retrieval import GraphRetriever
        graph_ret = GraphRetriever(args.graph_path)
        labels_map = {r["text_id"]: r["label"] for r in train_rows}

    t0=time.perf_counter()
    with open(preds_path,"w",encoding="utf-8") as fp, open(ret_path,"w",encoding="utf-8") as fr:
        for r in tqdm(test_rows, desc="Pipeline B+ classify"):
            q, gt, ds, tid = r["text"], r["label"], r.get("dataset",""), r["text_id"]
            if graph_ret is not None:
                ids_scores = graph_ret.retrieve(q, labels_map, top_k=args.top_k, alpha=1.0, beta=0.5, gamma=0.6, balance_target=0.5)
                wanted = {t for t, _ in ids_scores}
                # Graph hits are materialised through the vector store: only
                # ids that also appear in its similarity results are kept.
                docs = vs.similarity_search(q, k=max(args.top_k,20))
                docs = [d for d in docs if d.metadata.get("text_id") in wanted][:args.top_k]
                prompt = build_prompt(q, docs, args.label_exposure)
                votes = [_llm_vote(llm, prompt) for _ in range(max(1, args.committee))]
                p_flag = float(sum(1 for v in votes if v=="FLAGGED"))/len(votes)
            else:
                _, p_flag = classify_committee(llm, vs, emb, q, args.label_exposure, args.retrieval_strategy, args.top_k, args.mmr_lambda, args.committee)

            # The final label uses the configurable threshold, not a majority.
            final_pred = "FLAGGED" if p_flag>=args.flagged_threshold else "NOT FLAGGED"
            if args.abstain_label and 0.45 <= p_flag <= 0.55:
                final_pred = args.abstain_label

            fp.write(json.dumps({"text_id":tid,"dataset":ds,"true":gt,"pred":final_pred,"p_flagged":p_flag})+"\n")

            # retrieval metrics (label precision proxy)
            docs = vs.similarity_search(q, k=max(args.top_k,10))
            lbls=[d.metadata.get("label") for d in docs[:args.top_k]]
            grades=[1 if l==gt else 0 for l in lbls]
            for k in [3,5,10]:
                row = {"text_id":tid,"dataset":ds,"true":gt}
                row.update(_ranking_metrics(grades, k))
                json.dump(row, fr); fr.write("\n")

    print(f"Saved {preds_path} and {ret_path}. Elapsed {time.perf_counter()-t0:.2f}s.")

# Run the pipeline only when executed as a script, not when imported.
if __name__=="__main__":
    main()
