#!/usr/bin/env python3
import argparse
import json
import math
import os
import re
import time
from typing import Dict, List, Tuple

import numpy as np
import pandas as pd
from tqdm import tqdm
import torch

from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain.docstore.document import Document
from sentence_transformers import SentenceTransformer
from llama_cpp import Llama

# Moderation policy inserted verbatim into every prompt by build_prompt(),
# under a "POLICY:" heading. Any edit here changes every prompt sent to the LLM.
GENERAL_POLICY_TEXT = """
# General Content Moderation Policy Framework
## CRITERIA FOR FLAGGING (Immediate Human Intervention Required)
Content should be flagged if it contains indications of:
- Potential harm to self or others
- Severe mental health crises
- Highly sensitive personal information
- Illegal activities or dangerous behavior
- Discriminatory or hate speech
- Severe harassment or threats
- Crisis situations requiring professional intervention
## CRITERIA FOR NOT FLAGGING (Automated Processing Allowed)
Content should NOT be flagged if it contains:
- Normal professional communication
- Routine personal or professional concerns
- Appropriate personal experiences
- Constructive feedback or advice
- General inquiries
"""

# Decision rule placed in the prompt's rules section: lean toward FLAGGED
# whenever a flagging criterion is plausibly present.
# NOTE(review): "NOT FLAGGED" is two words, so "EXACTLY one token" is loose
# wording; the model's reply is matched by substring, not by token count.
SAFETY_FIRST = (
    "If any FLAGGED criteria are plausibly present, prefer FLAGGED. "
    "Only output NOT FLAGGED if clearly none apply. "
    "Return EXACTLY one token: FLAGGED or NOT FLAGGED."
)

def load_jsonl(path: str) -> List[Dict]:
    """Read a JSON Lines file into a list of parsed records.

    Whitespace-only lines (e.g. a blank line at the end of the file) are
    skipped, and a UTF-8 byte-order mark at the start of the file is ignored.

    Args:
        path: path to a UTF-8 encoded ``.jsonl`` file, one JSON value per line.

    Returns:
        The parsed values, in file order.

    Raises:
        json.JSONDecodeError: if a non-blank line is not valid JSON; the
            message is prefixed with ``<path>:<line number>``.
    """
    rows = []
    # "utf-8-sig" decodes plain UTF-8 identically but drops a leading BOM,
    # which json.loads would otherwise reject.
    with open(path, "r", encoding="utf-8-sig") as f:
        for lineno, line in enumerate(f, start=1):
            if not line.strip():
                continue
            try:
                rows.append(json.loads(line))
            except json.JSONDecodeError as err:
                # Same exception type as before, but say where it happened.
                raise json.JSONDecodeError(
                    f"{path}:{lineno}: {err.msg}", err.doc, err.pos
                ) from err
    return rows

def build_or_load_vectorstore(train_rows: List[Dict], vs_dir: str) -> Tuple[FAISS, HuggingFaceEmbeddings]:
    """Load the saved FAISS index from ``vs_dir``, or build it from ``train_rows``.

    Both paths use the ``sentence-transformers/all-MiniLM-L6-v2`` embedding
    model, placed on CUDA when available. When building, each row becomes one
    Document with ``page_content=row["text"]`` and metadata
    ``{"stage": row["label"], "dataset": row.get("dataset", "")}``. The index
    is then saved to ``vs_dir`` for later runs.

    NOTE(review): an existing ``vs_dir/index.faiss`` is reused as-is and
    ``train_rows`` is ignored. A changed training file is therefore NOT picked
    up until ``vs_dir`` is deleted.

    Args:
        train_rows: training records; only used when no saved index exists.
        vs_dir: directory holding (or receiving) the saved index.

    Returns:
        ``(vectorstore, embeddings)``. The embeddings object is also returned
        so callers can reuse the same model.

    Raises:
        ValueError: if the index must be built but ``train_rows`` is empty.
    """
    index_exists = os.path.exists(os.path.join(vs_dir, "index.faiss"))
    # Fail fast with a clear message instead of an obscure error from inside
    # FAISS, and before paying for the embedding-model load.
    if not index_exists and not train_rows:
        raise ValueError(
            f"No saved index in {vs_dir!r} and no training rows to build one from."
        )

    device = "cuda" if torch.cuda.is_available() else "cpu"
    emb = HuggingFaceEmbeddings(
        model_name="sentence-transformers/all-MiniLM-L6-v2",
        model_kwargs={"device": device},
    )
    if index_exists:
        # allow_dangerous_deserialization lets LangChain unpickle the saved
        # docstore; only point vs_dir at indexes this script wrote itself,
        # never at untrusted files.
        vs = FAISS.load_local(vs_dir, emb, allow_dangerous_deserialization=True)
        return vs, emb

    docs = [
        Document(
            page_content=r["text"],
            metadata={"stage": r["label"], "dataset": r.get("dataset", "")},
        )
        for r in train_rows
    ]
    vs = FAISS.from_documents(docs, emb)
    os.makedirs(vs_dir, exist_ok=True)
    vs.save_local(vs_dir)
    return vs, emb

def normalize_q(text: str) -> str:
    """Strip a leading speaker tag and collapse horizontal whitespace.

    A leading ``Agent:``, ``Customer:`` or ``Client:`` tag (any case) is
    removed, together with the whitespace that follows it. The tag is only
    removed at the very start of the string. Runs of spaces/tabs are then
    collapsed to a single space, and the result is stripped at both ends.
    Interior newlines are kept.
    """
    speaker_tag = re.compile(r"^(Agent|Customer|Client):\s*", re.IGNORECASE)
    horizontal_ws = re.compile(r"[ \t]+")
    untagged = speaker_tag.sub("", text)
    return horizontal_ws.sub(" ", untagged).strip()

def build_prompt(utterance: str, examples: List[Document], label_exposure: str) -> str:
    """Assemble the single-turn classification prompt sent to the LLM.

    Sections, in order:
      1. the rules (role line, SAFETY_FIRST, one-token instruction);
      2. the policy text;
      3. the retrieved examples;
      4. the utterance;
      5. a trailing answer cue.

    Args:
        utterance: the (normalized) text to classify.
        examples: retrieved documents. Any page_content longer than 350
            characters is cut to 347 characters plus "...". An empty or
            falsy value renders as "(none)".
        label_exposure: "exposed" shows each example's ``stage`` label and
            ``dataset`` source; any other value shows only the source.

    Returns:
        The complete prompt string.
    """
    with_labels = label_exposure == "exposed"

    example_parts = ["RETRIEVED EXAMPLES:\n"]
    for number, doc in enumerate(examples or [], start=1):
        source = doc.metadata.get("dataset", "")
        body = doc.page_content
        if len(body) > 350:
            body = body[:347] + "..."
        if with_labels:
            label = doc.metadata.get("stage", "UNKNOWN")
            header = f"### Example {number} (Label: {label}, Source: {source})"
        else:
            header = f"### Example {number} (Source: {source})"
        example_parts.append(f"{header}\n{body}\n")
    if not examples:
        example_parts.append("(none)\n")

    rules = "".join([
        "You are a content moderation classifier.\n",
        f"{SAFETY_FIRST}\n",
        "Answer with exactly one token.\n",
    ])
    return "".join([
        rules, "\n",
        "POLICY:\n", GENERAL_POLICY_TEXT, "\n",
        "".join(example_parts), "\n",
        "UTTERANCE:\n", utterance, "\n",
        "FINAL ANSWER (one token):\n",
    ])

def ndcg_at_k(grades: List[int], k: int) -> float:
    """Normalized DCG over the first ``k`` entries of a ranked list.

    Gain is ``2**g - 1`` and the discount for 0-based position ``i`` is
    ``log2(i + 2)``. The ideal ranking is ``grades`` sorted in descending
    order. Normalization is therefore relative to the best ordering of the
    *given* items, not of every relevant item in a corpus.

    Args:
        grades: relevance grade of each ranked item, in rank order.
        k: cutoff; clamped to ``len(grades)``.

    Returns:
        nDCG (within [0, 1] for non-negative grades). Returns 0.0 when ``k``
        is not positive, when ``grades`` is empty, or when no item has a
        positive gain.
    """
    k = min(k, len(grades))
    # `<= 0` rather than `== 0`: a negative k would otherwise slice from the
    # end of the list (gs[:-1]) and produce a meaningless score.
    if k <= 0:
        return 0.0

    def dcg(gs: List[int]) -> float:
        return sum((2 ** g - 1) / math.log2(i + 2) for i, g in enumerate(gs[:k]))

    idcg = dcg(sorted(grades, reverse=True))
    return dcg(grades) / idcg if idcg > 0 else 0.0

def diversity_at_k(texts: List[str], emb_model: HuggingFaceEmbeddings, k: int) -> float:
    """Return one minus the mean pairwise cosine similarity of the first ``k`` texts.

    The texts are embedded with ``emb_model.embed_documents``. Every unordered
    pair (i < j) contributes one cosine similarity. Higher values mean a more
    varied set. Returns 1.0 when fewer than two texts are available.
    """
    window = texts[:k]
    if len(window) <= 1:
        return 1.0

    vectors = np.array(emb_model.embed_documents(window), dtype=np.float32)
    norms = np.linalg.norm(vectors, axis=1, keepdims=True)
    # 1e-8 keeps a zero-norm embedding from dividing by zero.
    cosine = (vectors @ vectors.T) / (norms @ norms.T + 1e-8)

    # Strictly-upper-triangle entries = each unordered pair exactly once.
    pairwise = cosine[np.triu_indices(cosine.shape[0], 1)]
    if pairwise.size == 0:
        return 1.0
    return 1.0 - float(np.mean(pairwise))

def _parse_verdict(out: str) -> str:
    """Map a stripped, upper-cased LLM reply to "FLAGGED" or "NOT FLAGGED".

    Matching is by substring, so a reply like "FLAGGED." still counts. Any
    reply that is not an unambiguous FLAGGED falls back to "NOT FLAGGED".
    """
    if "NOT FLAGGED" in out:
        return "NOT FLAGGED"
    if "FLAGGED" in out and "NOT" not in out:
        return "FLAGGED"
    # NOTE(review): unparseable replies fail open (NOT FLAGGED), which runs
    # against the SAFETY_FIRST instruction; kept so past results stay comparable.
    return "NOT FLAGGED"


def _retrieval_metrics(cands: List, gt, emb, ks=(3, 5, 10)) -> List[Dict]:
    """Score ranked retrieval candidates against the query's true label.

    A candidate is relevant (grade 1) when its ``stage`` metadata equals
    ``gt``. For each cutoff in ``ks`` (clamped to the number of candidates),
    returns one dict with ``k``, ``hit``, ``precision``, ``ndcg``,
    ``label_precision`` and ``diversity``.
    """
    texts = [c.page_content for c in cands]
    grades = [1 if c.metadata.get("stage") == gt else 0 for c in cands]
    rows = []
    for k in ks:
        kk = min(k, len(cands))
        n_relevant = sum(grades[:kk])
        prec = float(n_relevant) / kk if kk > 0 else 0.0
        rows.append({
            "k": kk,
            "hit": 1.0 if n_relevant > 0 else 0.0,
            "precision": prec,
            "ndcg": ndcg_at_k(grades, kk),
            # With binary label-match grades this equals precision; emitted
            # separately so the output schema is unchanged.
            "label_precision": prec,
            "diversity": diversity_at_k(texts[:kk], emb, kk),
        })
    return rows


def main():
    """Run Pipeline B: retrieval-augmented LLM moderation over a test set.

    For each test row:
      1. Normalize the text and retrieve similar training examples from FAISS.
      2. Ask the local LLM for a FLAGGED / NOT FLAGGED verdict.
      3. Log retrieval diagnostics (hit / precision / nDCG / diversity at
         k = 3, 5, 10).

    Writes ``<out-dir>/preds_b.jsonl`` and ``<out-dir>/retrieval_b.jsonl``.
    LLM failures are recorded as NOT FLAGGED, warned about per row, and
    counted in a final summary.
    """
    ap = argparse.ArgumentParser(
        description="Pipeline B: retrieval-augmented LLM moderation classifier."
    )
    ap.add_argument("--train", default="train_augmented.jsonl")
    ap.add_argument("--test", default="test_augmented.jsonl")
    ap.add_argument("--vectorstore", default="vectorstore")
    ap.add_argument("--llm-path", required=True)
    ap.add_argument("--n-gpu-layers", type=int, default=40)
    # The defaults for these two are the "winning" configuration of earlier runs.
    ap.add_argument("--top-k", type=int, default=15)
    ap.add_argument("--label-exposure", choices=["hidden", "exposed"], default="exposed")
    ap.add_argument("--n-ctx", type=int, default=4096,
                    help="LLM context window in tokens.")
    ap.add_argument("--out-dir", default="outputs",
                    help="Directory for preds_b.jsonl and retrieval_b.jsonl.")
    args = ap.parse_args()
    if args.top_k < 1:
        ap.error("--top-k must be at least 1")

    train_rows = load_jsonl(args.train)
    test_rows = load_jsonl(args.test)
    vs, emb = build_or_load_vectorstore(train_rows, args.vectorstore)

    llm = Llama(
        model_path=args.llm_path,
        n_ctx=args.n_ctx,
        n_gpu_layers=args.n_gpu_layers if torch.cuda.is_available() else 0,
        chat_format="chatml",
        verbose=False,
    )

    os.makedirs(args.out_dir, exist_ok=True)
    preds_path = os.path.join(args.out_dir, "preds_b.jsonl")
    ret_path = os.path.join(args.out_dir, "retrieval_b.jsonl")

    # Diagnostics always look at >= 10 candidates so the k=10 metrics exist.
    n_cands = max(args.top_k, 10)
    n_failures = 0

    t0 = time.perf_counter()
    with open(preds_path, "w", encoding="utf-8") as fp, open(ret_path, "w", encoding="utf-8") as fr:
        for r in tqdm(test_rows, desc="Pipeline B classify"):
            gt = r["label"]
            ds = r.get("dataset", "")
            tid = r["text_id"]
            q = normalize_q(r["text"])

            # One search serves both the prompt and the diagnostics. Results
            # are assumed to come back ordered by similarity (the flat index
            # LangChain builds by default), so the first top_k hits are the
            # top_k search.
            cands = vs.similarity_search(q, k=n_cands)
            prompt = build_prompt(q, cands[:args.top_k], args.label_exposure)
            try:
                resp = llm.create_chat_completion(
                    messages=[{"role": "user", "content": prompt}],
                    max_tokens=8, temperature=0.0, stop=["\n"],
                )
                out = resp["choices"][0]["message"]["content"].strip().upper()
            except Exception as err:  # per-row boundary around llama_cpp
                # Don't abort the whole run, but never fail silently either.
                n_failures += 1
                tqdm.write(f"[warn] LLM call failed for text_id={tid}: {err!r}")
                out = "NOT FLAGGED"
            pred = _parse_verdict(out)
            fp.write(json.dumps({"text_id": tid, "dataset": ds, "true": gt, "pred": pred}) + "\n")

            for metrics in _retrieval_metrics(cands, gt, emb):
                fr.write(json.dumps({"text_id": tid, "dataset": ds, "true": gt, **metrics}) + "\n")
    dur = time.perf_counter() - t0

    if n_failures:
        print(f"WARNING: {n_failures} of {len(test_rows)} LLM calls failed "
              f"and were recorded as NOT FLAGGED.")
    print(f"Saved {preds_path} and {ret_path}. Elapsed {dur:.2f}s.")

if __name__ == "__main__":
    # Script entry point; main() returns None, so this exits with status 0.
    raise SystemExit(main())
