#!/usr/bin/env python3
# pipeline_a_fair_plus.py — LlamaIndex RAG with (a) multi-objective retrieval, (b) graph retrieval, (c) uncertainty/committee + abstention

import argparse, json, os, re, time, math, shutil, random
from typing import List, Dict, Optional, Tuple
import numpy as np
from tqdm import tqdm
import pandas as pd

from llama_index.core import Document, VectorStoreIndex, Settings, StorageContext, load_index_from_storage, PromptTemplate
from llama_index.core.schema import TextNode, NodeWithScore
from llama_index.core.node_parser import SentenceSplitter
from llama_index.llms.openai import OpenAI as LI_OpenAI
try:
    from llama_index.embeddings.openai import OpenAIEmbedding as LI_OpenAIEmb
except Exception:
    from llama_index.embeddings import OpenAIEmbedding as LI_OpenAIEmb  # type: ignore

# Moderation policy text. It is indexed as the single "policy" node by
# build_nodes() and is also interpolated into the classifier prompt in main().
# This is runtime prompt content: editing the wording changes model behavior.
GENERAL_POLICY_TEXT = """
# General Content Moderation Policy Framework
## CRITERIA FOR FLAGGING (Immediate Human Intervention Required)
Content should be flagged if it contains indications of:
- Potential harm to self or others
- Severe mental health crises
- Highly sensitive personal information
- Illegal activities or dangerous behavior
- Discriminatory or hate speech
- Severe harassment or threats
- Crisis situations requiring professional intervention
## CRITERIA FOR NOT FLAGGING (Automated Processing Allowed)
Content should NOT be flagged if it contains:
- Normal professional communication
- Routine personal or professional concerns
- Appropriate personal experiences
- Constructive feedback or advice
- General inquiries
"""
# Tie-breaking and output-format instruction appended to the policy in the
# prompt built by main(). The two answers it names are the strings that
# parse_pred() recognizes.
SAFETY_FIRST = (
    "If any FLAGGED criteria are plausibly present, prefer FLAGGED. "
    "Only output NOT FLAGGED if clearly none apply. "
    "Return EXACTLY one token: FLAGGED or NOT FLAGGED."
)

def load_jsonl(path: str) -> List[Dict]:
    """Read a JSON Lines file and return one parsed object per non-blank line.

    Args:
        path: Path to a UTF-8 encoded file with one JSON value per line.

    Returns:
        The parsed values, in file order.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    rows: List[Dict] = []
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            # Blank lines (typically a trailing newline at end of file) carry
            # no record; feeding them to json.loads would raise.
            if line:
                rows.append(json.loads(line))
    return rows

def parse_pred(s: str) -> str:
    """Map a raw model response to "FLAGGED", "NOT FLAGGED" or "UNKNOWN".

    The response is upper-cased and every run of non-letter characters is
    collapsed into a single space before matching, so variants such as
    "not flagged.", "NOT_FLAGGED", "Not-Flagged" or "NOT\\nFLAGGED" are all
    recognized as "NOT FLAGGED".

    Args:
        s: Raw text returned by the LLM (``None`` is tolerated).

    Returns:
        "NOT FLAGGED" if that phrase occurs in the normalized text;
        "FLAGGED" if "FLAGGED" occurs and "NOT" does not occur anywhere;
        "UNKNOWN" otherwise (including for ``None``).
    """
    if s is None:
        return "UNKNOWN"
    # Replace separators with a space instead of deleting them: deleting
    # would glue "NOT_FLAGGED" into "NOTFLAGGED", which matches neither rule.
    t = re.sub(r"[^A-Za-z]+", " ", str(s)).strip().upper()
    if "NOT FLAGGED" in t:
        return "NOT FLAGGED"
    # Any stray "NOT" (e.g. "flagged? not sure") makes the answer ambiguous,
    # so it is reported as UNKNOWN rather than guessed.
    if "FLAGGED" in t and "NOT" not in t:
        return "FLAGGED"
    return "UNKNOWN"

def ndcg_at_k(grades: List[int], k: int) -> float:
    """Compute NDCG@k with exponential gain ``2**grade - 1``.

    Args:
        grades: Relevance grades of the results, in ranked order.
        k: Cutoff rank. Values larger than ``len(grades)`` are clamped.

    Returns:
        DCG of the first ``k`` grades divided by the DCG of the ``k`` largest
        grades. Returns 0.0 when ``k`` is zero or negative, when ``grades`` is
        empty, or when the ideal DCG is zero (no relevant item at all).
    """
    k = min(k, len(grades))
    # A negative cutoff would otherwise slice from the end of the list and
    # produce a score for a ranking nobody asked about.
    if k <= 0:
        return 0.0

    def dcg(gs):
        # Position i (0-based) is discounted by log2(i + 2), so rank 1 has
        # discount log2(2) == 1.
        return sum((2 ** g - 1) / math.log2(i + 2) for i, g in enumerate(gs[:k]))

    idcg = dcg(sorted(grades, reverse=True))
    return (dcg(grades) / idcg) if idcg > 0 else 0.0

def build_nodes(train_rows: List[Dict], label_exposure: str, chunk_size: int) -> List[TextNode]:
    """Turn training rows into index nodes, preceded by one policy node.

    Args:
        train_rows: Rows with required keys "text" and "label"; "dataset"
            (default "unknown") and "text_id" (default None) are optional.
        label_exposure: "hidden" stores the bare chunk text; any other value
            stores "Text: <chunk>\\nLabel: <label>" so the label is visible in
            the retrieved context.
        chunk_size: Chunk size handed to SentenceSplitter (overlap is 80).

    Returns:
        A list whose first element is the policy node (metadata type "policy"),
        followed by one "example" node per chunk of every row. The label is
        always recorded in node metadata, regardless of ``label_exposure``.
    """
    chunker = SentenceSplitter(chunk_size=chunk_size, chunk_overlap=80)
    policy_node = TextNode(text=GENERAL_POLICY_TEXT.strip(), metadata={"type": "policy"})
    hide_label = label_exposure == "hidden"

    built: List[TextNode] = [policy_node]
    for row in train_rows:
        body = row["text"]
        label = row["label"]
        dataset = row.get("dataset", "unknown")
        for chunk in chunker.split_text(body):
            if hide_label:
                node_text = chunk
            else:
                node_text = f"Text: {chunk}\nLabel: {label}"
            meta = {
                "type": "example",
                "label": label,
                "dataset": dataset,
                "text_id": row.get("text_id"),
            }
            built.append(TextNode(text=node_text, metadata=meta))
    return built

def build_index(nodes: List[TextNode], storage_dir: str, rebuild: bool) -> VectorStoreIndex:
    """Load a persisted vector index from ``storage_dir`` or build and persist one.

    Args:
        nodes: Nodes to index when no persisted index is found.
        storage_dir: Directory used for persistence; created if missing.
        rebuild: When true, an existing ``storage_dir`` is deleted first, which
            forces a fresh build.

    Returns:
        The loaded index if ``storage_dir`` is non-empty after the optional
        wipe; otherwise a new index built from ``nodes`` and persisted there.
        Note that in the load case ``nodes`` is not consulted at all.
    """
    wipe_first = rebuild and os.path.isdir(storage_dir)
    if wipe_first:
        shutil.rmtree(storage_dir, ignore_errors=True)
    os.makedirs(storage_dir, exist_ok=True)

    # Any content in the directory is treated as a previously persisted index.
    if not os.listdir(storage_dir):
        fresh = VectorStoreIndex(nodes)
        fresh.storage_context.persist(persist_dir=storage_dir)
        return fresh

    persisted = StorageContext.from_defaults(persist_dir=storage_dir)
    return load_index_from_storage(persisted)

def mmr_rerank(cands: List[NodeWithScore], k: int, lam: float=0.5) -> List[NodeWithScore]:
    """Greedy MMR-style rerank using scores and labels only (no embeddings).

    Relevance is the candidate's retrieval ``score`` (a missing score counts
    as 0.0). Redundancy is a cheap proxy: 1.0 if the candidate's metadata
    "label" equals the label of any already selected node, else 0.0. Each step
    picks the remaining candidate maximizing
    ``lam * relevance - (1 - lam) * redundancy``.

    Args:
        cands: Candidates, expected in descending score order; the first one
            is always selected first.
        k: Maximum number of nodes to return.
        lam: Trade-off weight; 1.0 ranks by score only.

    Returns:
        Up to ``k`` candidates in selection order. Empty when ``k <= 0`` or
        ``cands`` is empty. ``cands`` itself is not modified. Ties are broken
        in favor of the candidate that appears earlier in ``cands``.
    """
    if k <= 0:
        return []

    def label_of(c):
        # metadata may be absent or None on a node; treat both as "no label".
        return (getattr(c.node, "metadata", None) or {}).get("label")

    def relevance(c):
        # score is optional on the candidate; None would break the arithmetic.
        return 0.0 if c.score is None else c.score

    pool = list(cands)
    if not pool:
        return []

    sel = [pool.pop(0)]
    # Labels already represented in the selection. A list (not a set) keeps
    # the original equality-based comparison and tolerates unhashable labels.
    seen_labels = [label_of(sel[0])]

    def utility(i):
        c = pool[i]
        redundancy = 1.0 if label_of(c) in seen_labels else 0.0
        return lam * relevance(c) - (1 - lam) * redundancy

    while pool and len(sel) < k:
        # Pick the best remaining candidate directly instead of re-sorting
        # the whole pool on every step.
        best = max(range(len(pool)), key=utility)
        chosen = pool.pop(best)
        chosen_label = label_of(chosen)
        if chosen_label not in seen_labels:
            seen_labels.append(chosen_label)
        sel.append(chosen)
    return sel

def balance_labels(cands: List[NodeWithScore], target_ratio: float=0.5, k:int=10) -> List[NodeWithScore]:
    """Pick up to ``k`` labeled candidates, steering toward a FLAGGED quota.

    Candidates are split by metadata "label" into FLAGGED and NOT FLAGGED
    queues (each keeping input order); candidates with any other label, or
    none, are left out. FLAGGED items are taken until
    ``round(target_ratio * k)`` of them are selected, then NOT FLAGGED items
    fill the rest; if NOT FLAGGED runs out, remaining FLAGGED items are used.

    Returns:
        The selected candidates in pick order; fewer than ``k`` when there
        are not enough labeled candidates.
    """
    def label_of(c):
        return (c.node.metadata or {}).get("label")

    flagged_queue = [c for c in cands if label_of(c) == "FLAGGED"]
    clean_queue = [c for c in cands if label_of(c) == "NOT FLAGGED"]
    quota = int(round(target_ratio * k))

    picked: List[NodeWithScore] = []
    flagged_taken = 0
    fi = 0  # next unread position in flagged_queue
    ci = 0  # next unread position in clean_queue
    while len(picked) < k:
        flagged_left = fi < len(flagged_queue)
        clean_left = ci < len(clean_queue)
        if not (flagged_left or clean_left):
            break
        take_flagged = flagged_left and (flagged_taken < quota or not clean_left)
        if take_flagged:
            picked.append(flagged_queue[fi])
            fi += 1
            flagged_taken += 1
        else:
            picked.append(clean_queue[ci])
            ci += 1
    return picked

def committee_vote(llm, prompt_tmpl: PromptTemplate, index: VectorStoreIndex, content: str, top_k: int, strategy: str, mmr_lambda: float, balance: bool, committee: int) -> Tuple[str, float]:
    """Classify ``content`` by majority vote over repeated LLM calls.

    The context is retrieved and reranked once, then the LLM is queried
    ``max(1, committee)`` times with the same prompt. Any disagreement between
    votes therefore comes from LLM sampling only (none is expected at
    temperature 0).

    Args:
        llm: Object exposing ``predict(prompt=..., query_str=..., context_str=...)``.
        prompt_tmpl: Prompt template with ``query_str`` and ``context_str`` slots.
        index: Vector index used for retrieval.
        content: Text to classify; also used as the retrieval query.
        top_k: Number of context nodes passed to the LLM.
        strategy: "mmr" or "balance" select the matching reranker; any other
            value keeps the first ``top_k`` retrieved nodes.
        mmr_lambda: Trade-off weight forwarded to ``mmr_rerank``.
        balance: Unused; kept for call compatibility (``strategy`` decides).
        committee: Requested number of votes (at least one is always cast).

    Returns:
        ``(hard_pred, p_flagged)`` where ``p_flagged`` is the share of votes
        parsed as "FLAGGED" and ``hard_pred`` is "FLAGGED" when that share is
        at least 0.5, else "NOT FLAGGED".
    """
    # Retrieval and reranking depend only on the arguments, not on the vote
    # number, so do them once instead of once per committee member.
    # Assumes retrieval is deterministic for a fixed query and index.
    # NOTE(review): the candidate pool is max(top_k, 10), so for top_k >= 10
    # the rerankers have no extra candidates to choose from.
    retriever = index.as_retriever(similarity_top_k=max(top_k, 10))
    nodes = retriever.retrieve(content)
    if strategy == "mmr":
        nodes = mmr_rerank(nodes, top_k, lam=mmr_lambda)
    elif strategy == "balance":
        nodes = balance_labels(nodes, target_ratio=0.5, k=top_k)
    else:
        nodes = nodes[:top_k]
    context = "\n\n".join([n.node.get_content() for n in nodes])

    preds = [
        parse_pred(llm.predict(prompt=prompt_tmpl, query_str=content, context_str=context))
        for _ in range(max(1, committee))
    ]
    # UNKNOWN (unparseable) votes stay in the denominator, so they count
    # against FLAGGED exactly like NOT FLAGGED votes do.
    p_flagged = float(sum(1 for p in preds if p == "FLAGGED")) / len(preds)
    hard = "FLAGGED" if p_flagged >= 0.5 else "NOT FLAGGED"
    return hard, p_flagged

def main():
    """Run the Pipeline A+ evaluation end to end.

    Steps: parse CLI arguments, load train/test JSONL, build or load the
    vector index, classify every test row (graph-assisted or committee vote),
    apply the FLAGGED threshold and optional abstention band, and write
    ``outputs/pipeline_a_plus/preds.jsonl`` (one prediction per row) and
    ``outputs/pipeline_a_plus/retrieval.jsonl`` (retrieval diagnostics at
    cutoffs 3, 5 and 10 per row).
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--train", default="train_augmented_2000.jsonl")
    ap.add_argument("--test",  default="test_augmented_2000.jsonl")
    ap.add_argument("--storage", default="storage_a_plus")
    ap.add_argument("--embedding", default="openai")
    ap.add_argument("--model", default="gpt-4o-mini")
    ap.add_argument("--temperature", type=float, default=0.0)
    ap.add_argument("--label-exposure", choices=["hidden","exposed"], default="exposed")
    ap.add_argument("--top-k", type=int, default=10)
    ap.add_argument("--chunk-size", type=int, default=512)
    ap.add_argument("--rebuild", action="store_true")
    # retrieval strategies
    ap.add_argument("--retrieval-strategy", choices=["baseline","mmr","balance","graph"], default="mmr")
    ap.add_argument("--mmr-lambda", type=float, default=0.6)
    ap.add_argument("--graph-path", default="graph_store")
    # uncertainty / abstention
    ap.add_argument("--committee", type=int, default=5)
    ap.add_argument("--flagged-threshold", type=float, default=0.4, help="lower threshold favors recall")
    ap.add_argument("--abstain-label", default="", help="e.g., REVIEW to abstain when uncertain")
    args = ap.parse_args()

    train_rows = load_jsonl(args.train)
    test_rows = load_jsonl(args.test)
    # NOTE(review): --embedding is parsed but never read; the embedding model
    # is fixed to the value below.
    Settings.embed_model = LI_OpenAIEmb(model="text-embedding-3-small")
    Settings.llm = LI_OpenAI(model=args.model, temperature=args.temperature)

    train_nodes = build_nodes(train_rows, args.label_exposure, args.chunk_size)
    index = build_index(train_nodes, args.storage, args.rebuild)

    qa_tmpl = PromptTemplate(
        f"You are a content moderation classifier. Follow this policy strictly:\n"
        f"{GENERAL_POLICY_TEXT}\n{SAFETY_FIRST}\n"
        "CONTENT:\n{query_str}\nRetrieved Context:\n{context_str}\nFinal answer (FLAGGED or NOT FLAGGED) only:"
    )

    os.makedirs("outputs/pipeline_a_plus", exist_ok=True)
    preds_path = "outputs/pipeline_a_plus/preds.jsonl"
    ret_path   = "outputs/pipeline_a_plus/retrieval.jsonl"

    # optional graph retriever (external multi-objective)
    graph_ret = None
    labels_map = {}
    if args.retrieval_strategy == "graph":
        from NEW_graph_retrieval import GraphRetriever
        # Only the graph retriever needs the id -> label map. Rows without a
        # text_id are skipped here, matching build_nodes(), which also treats
        # text_id as optional; other strategies no longer fail on such rows.
        labels_map = {r["text_id"]: r["label"] for r in train_rows if r.get("text_id") is not None}
        graph_ret = GraphRetriever(args.graph_path)

    # Retrievers depend only on the index and top_k, so build them once
    # rather than once per test row.
    cand_retriever = index.as_retriever(similarity_top_k=max(args.top_k, 10))
    diag_retriever = index.as_retriever(similarity_top_k=args.top_k)

    t0 = time.perf_counter()
    with open(preds_path,"w",encoding="utf-8") as fp, open(ret_path,"w",encoding="utf-8") as fr:
        for r in tqdm(test_rows, desc="Pipeline A+ classify"):
            content, gt, ds = r["text"], r["label"], r.get("dataset","")
            # text_id is optional here as in build_nodes(); it is only echoed
            # into the output records.
            tid = r.get("text_id")

            if graph_ret is not None:
                # graph retrieval + class balance/diversity
                ids_scores = graph_ret.retrieve(content, labels_map, top_k=args.top_k, alpha=1.0, beta=0.5, gamma=0.6, balance_target=0.5)
                # The graph returns ids only; node texts come from the vector
                # retriever, so only ids that also appear among its candidates
                # can end up in the context (possibly none).
                wanted = set(t for t,_ in ids_scores)
                ctx_nodes = [n for n in cand_retriever.retrieve(content)
                             if (n.node.metadata or {}).get("text_id") in wanted][:args.top_k]
                context = "\n\n".join([n.node.get_content() for n in ctx_nodes])
                # committee around graph selection: same prompt, repeated calls
                votes = [
                    parse_pred(Settings.llm.predict(prompt=qa_tmpl, query_str=content, context_str=context))
                    for _ in range(max(1, args.committee))
                ]
                p_flag = float(sum(1 for v in votes if v=="FLAGGED"))/len(votes)
            else:
                # committee vote using baseline/mmr/balance; committee_vote
                # treats any strategy other than "mmr"/"balance" as baseline.
                # Its hard label is not used: the CLI threshold below decides.
                _, p_flag = committee_vote(Settings.llm, qa_tmpl, index, content, args.top_k,
                                           args.retrieval_strategy,
                                           args.mmr_lambda, args.retrieval_strategy=="balance", args.committee)

            # threshold + optional abstention
            if p_flag >= args.flagged_threshold:
                final_pred = "FLAGGED"
            else:
                final_pred = "NOT FLAGGED"
            # NOTE(review): the abstention band is fixed at [0.45, 0.55] and
            # is not centered on --flagged-threshold (default 0.4).
            if args.abstain_label and 0.45 <= p_flag <= 0.55:
                final_pred = args.abstain_label

            fp.write(json.dumps({"text_id":tid,"dataset":ds,"true":gt,"pred":final_pred,"p_flagged":p_flag}, ensure_ascii=False)+"\n")

            # Retrieval diagnostics. These always use plain similarity
            # retrieval, regardless of the strategy used for classification.
            # A node counts as relevant when its label equals the ground truth
            # or it is the policy node.
            top = diag_retriever.retrieve(content)[:args.top_k]
            labels_top = [ (n.node.metadata or {}).get("label") for n in top ]
            grades = [1 if (l==gt or (n.node.metadata or {}).get("type")=="policy") else 0 for n,l in zip(top, labels_top)]
            for k in (3, 5, 10):
                # The recorded "k" is the effective cutoff, so cutoffs larger
                # than the number of retrieved nodes repeat the same row.
                kk = min(k, len(top))
                hit = 1.0 if any(grades[:kk]) else 0.0
                prec = float(sum(grades[:kk]))/kk if kk>0 else 0.0
                ndcg = ndcg_at_k(grades, kk)
                lblp = float(sum(1 for l in labels_top[:kk] if l==gt))/kk if kk>0 else 0.0
                fr.write(json.dumps({"text_id":tid,"dataset":ds,"true":gt,"k":kk,"hit":hit,"precision":prec,"ndcg":ndcg,"label_precision":lblp})+"\n")

    print(f"Saved {preds_path} and {ret_path}. Elapsed {time.perf_counter()-t0:.2f}s.")

# Script entry point: run the pipeline only when this file is executed
# directly, not when it is imported.
if __name__ == "__main__":
    main()
